diff --git a/gptqmodel/__init__.py b/gptqmodel/__init__.py
index 9f2d674dd..11a8b93e4 100644
--- a/gptqmodel/__init__.py
+++ b/gptqmodel/__init__.py
@@ -110,9 +110,11 @@ def _patch_transformers_causal_conv1d_hub_kernel_compat():
         return
 
     with _MONKEY_PATCH_LOCK:
+        if not hasattr(hf_integrations, "lazy_load_kernel"):
+            return
+
         if getattr(hub_kernels, "_gptqmodel_local_causal_conv1d_kernel", False):
-            if hasattr(hf_integrations, "lazy_load_kernel"):
-                hf_integrations.lazy_load_kernel = hub_kernels.lazy_load_kernel
+            hf_integrations.lazy_load_kernel = hub_kernels.lazy_load_kernel
             return
 
         original_lazy_load_kernel = hub_kernels.lazy_load_kernel
diff --git a/gptqmodel/looper/forward_executor.py b/gptqmodel/looper/forward_executor.py
index d01592ed5..b98a7548c 100644
--- a/gptqmodel/looper/forward_executor.py
+++ b/gptqmodel/looper/forward_executor.py
@@ -232,6 +232,7 @@ def run_single(
         """Run the forward pass sequentially on the current device."""
         outputs: List[List[torch.Tensor]] = []
 
+        write_shared_kv_cache = bool(getattr(self.looper.gptq_model, "write_shared_kv_cache", False))
         prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None
         total_batches, batch_row_counts, total_rows = self._resolve_batch_progress(
             processor,
@@ -320,7 +321,7 @@ def run_single(
             del additional_inputs
 
             if (
-                reuse_kv
+                (reuse_kv or write_shared_kv_cache)
                 and module_output is not None
                 and isinstance(module_output, tuple)
                 and len(module_output) > 0
@@ -462,6 +463,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
             ctx.__enter__()
             moe_contexts.append(ctx)
 
+        write_shared_kv_cache = bool(getattr(self.looper.gptq_model, "write_shared_kv_cache", False))
         prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None
         results: Dict[int, torch.Tensor | None] = {}
         processed_rows = 0
@@ -553,6 +555,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
                     need_output=need_outputs,
                     reuse_kv=reuse_kv,
                     prev_kv=prev_kv,
+                    write_shared_kv_cache=write_shared_kv_cache,
                 )
             )
 
@@ -566,7 +569,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
                 primary = module_output[0] if isinstance(module_output, tuple) else module_output
                 results[batch_idx] = move_to(primary, device=target_device)
             del module_output
-            if reuse_kv and kv_next is not None and shared_kv_cache_dict.get(layer_index) is None:
+            if (reuse_kv or write_shared_kv_cache) and kv_next is not None and shared_kv_cache_dict.get(layer_index) is None:
                 shared_kv_cache_dict[layer_index] = nested_move_to(kv_next, device=cur_layer_device)
 
             rows_for_batch = batch_row_counts[batch_idx] if batch_idx < len(batch_row_counts) else 0
diff --git a/gptqmodel/looper/stage_layer.py b/gptqmodel/looper/stage_layer.py
index 0b3cfa4d1..3a1c81f34 100644
--- a/gptqmodel/looper/stage_layer.py
+++ b/gptqmodel/looper/stage_layer.py
@@ -272,7 +272,7 @@ def _replay_layer_outputs(
         shared_kv_cache_dict=shared_kv_cache_dict,
         layer_index=layer_index,
         need_outputs=True,
-        reuse_kv=False,
+        reuse_kv=getattr(module, "reuse_kv", False),
         progress_pb=replay_pb,
         progress_title=replay_msg,
         progress_stage="Forward replay",
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 2b6c62933..e10867709 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -231,6 +231,12 @@ class BaseQModel(nn.Module):
     # some models have broken attention mask codes so we need to only use batch 1 with no masks
     support_batch_quantize = True
 
+    # Whether this model should publish a layer's KV tuple into the shared
+    # replay cache even when that same layer does not consume `kv_last_layer`.
+    # Models like "hymba" need earlier layers to write KV into the cache so
+    # that later layers can consume it.
+    write_shared_kv_cache = False
+
     # allow models to define optional notes that output messages to users that want to use this model
     # list of supported keys: [ "notes" = print the notes value on model load ]
     info: Dict[str, str] = {}
diff --git a/gptqmodel/models/definitions/hymba.py b/gptqmodel/models/definitions/hymba.py
index 1575743b9..9ef18a0bd 100644
--- a/gptqmodel/models/definitions/hymba.py
+++ b/gptqmodel/models/definitions/hymba.py
@@ -11,6 +11,7 @@ class HymbaQModel(BaseQModel):
     supports_desc_act = [False]
     require_trust_remote_code = True
     require_monkeypatch = True
+    write_shared_kv_cache = True
     require_pkgs = ["tiktoken>=0.7.0",
                     "sentencepiece>=0.2.0",
                     "protobuf>=5.28.3",
diff --git a/gptqmodel/utils/hf.py b/gptqmodel/utils/hf.py
index ec331d283..95e4749a3 100644
--- a/gptqmodel/utils/hf.py
+++ b/gptqmodel/utils/hf.py
@@ -21,7 +21,6 @@
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
-    PreTrainedConfig,
     PreTrainedModel,
 )
 
@@ -41,7 +40,14 @@
 try:
     from transformers.initialization import no_init_weights
 except ImportError:
-    from transformers.modeling_utils import no_init_weights
+    from transformers.modeling_utils import no_init_weights  # Compatibility wrapper for no_init_weights across different transformers versions
+
+# transformers >= 5.0.0: from transformers import PreTrainedConfig
+# transformers < 5.0.0: from transformers import PretrainedConfig
+try:
+    from transformers import PreTrainedConfig
+except ImportError:
+    from transformers import PretrainedConfig as PreTrainedConfig
 
 from ..utils.logger import setup_logger
 
@@ -1016,6 +1022,14 @@ def _normalize_remote_code_config_compat(config: Any) -> None:
     model_type = getattr(config, "model_type", None)
     model_type_lower = model_type.lower() if isinstance(model_type, str) else None
 
+    if model_type_lower == "hymba":
+        # Hymba uses flex attention by default, but `modeling_hymba` has not
+        # yet been adapted to the latest PyTorch FlexAttention API, so switch
+        # to flash_attention_2 (when available) or sdpa instead.
+        if getattr(config, "attn_implementation_new", None) == "flex":
+            from transformers.utils import is_flash_attn_2_available
+            config.attn_implementation_new = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
+
     if model_type_lower == "dream" or model_type == "brumby":
         import transformers.modeling_rope_utils as rope_utils
         # dream remote models expect "default"
@@ -1234,6 +1248,31 @@ def tie_weights_compat(self, *args, **kwargs):
         formatter_cls.support_tokenizer_types = support_tokenizer_types
         formatter_cls._gptqmodel_tokenizer_backend_patch = True
 
+    if getattr(config, "model_type", None) == "hymba" and remote_module is not None:
+        rotary_cls = getattr(remote_module, "LlamaRotaryEmbedding", None)
+        attention_cls = getattr(remote_module, "HymbaAttention", None)
+        if (
+            rotary_cls is not None
+            and attention_cls is not None
+            and not getattr(attention_cls, "_gptqmodel_init_rope_meta_patch", False)
+        ):
+            def hymba_init_rope_compat(self):
+                # Hymba remote code hard-codes CUDA here, which forces a
+                # meta->real-device materialization during __init__ under
+                # transformers 5.x. Keep the device as None until HF
+                # finishes model loading and device placement.
+                device = None
+
+                self.rotary_emb = rotary_cls(
+                    config=self.config,
+                    dim=self.kq_head_dim,
+                    base=self.rope_theta,
+                    device=device,
+                )
+
+            attention_cls._init_rope = hymba_init_rope_compat
+            attention_cls._gptqmodel_init_rope_meta_patch = True
+
     if getattr(config, "model_type", None) != "phi4mm":
         return
diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py
index f5c953751..e9cf49bb1 100644
--- a/gptqmodel/utils/looper_helpers.py
+++ b/gptqmodel/utils/looper_helpers.py
@@ -370,6 +370,7 @@ def forward_batch_worker(
     need_output: bool,
     reuse_kv: bool,
     prev_kv,
+    write_shared_kv_cache: bool = False,
 ):
     processor._set_current_batch_index(batch_index)
     module_device = getattr(module, "_gptqmodule_device_hint", None) or get_device(module)
@@ -433,7 +434,7 @@
         mask_tls.value = None
         processor._set_current_batch_index(None)
 
-    if reuse_kv and module_output is not None and isinstance(module_output, tuple) and len(module_output) > 0:
+    if (reuse_kv or write_shared_kv_cache) and module_output is not None and isinstance(module_output, tuple) and len(module_output) > 0:
         kv_next = module_output[-1]
 
     result_output = None
diff --git a/tests/models/test_hymba.py b/tests/models/test_hymba.py
index de2481511..bc73ec6c9 100644
--- a/tests/models/test_hymba.py
+++ b/tests/models/test_hymba.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
-
+from gptqmodel import BACKEND
 from model_test import ModelTest
 
 
@@ -11,8 +11,8 @@ class TestHymba(ModelTest):
     EVAL_TASKS_SLOW = {
         "arc_challenge": {
             "chat_template": True,
-            "acc": {"value": 0.2073, "floor_pct": 0.75},
-            "acc_norm": {"value": 0.2713, "floor_pct": 0.75},
+            "acc": {"value": {"A100": 0.3737}, "floor_pct": 0.75},
+            "acc_norm": {"value": {"A100": 0.3703}, "floor_pct": 0.75},
         },
     }
     EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW)
@@ -25,6 +25,9 @@ class TestHymba(ModelTest):
 
     # Hymba currently requires DESC_ACT=False to get better results.
     # If DESC_ACT=True, the output will be terrible.
     DESC_ACT = False
+    OFFLOAD_TO_DISK = False  # FIXME: hymba does not work with OFFLOAD_TO_DISK=True
+    LOAD_BACKEND = BACKEND.AUTO
+    USE_FLASH_ATTN = True
 
     def test_hymba(self):
diff --git a/tests/test_compute_device_filter.py b/tests/test_compute_device_filter.py
index 1114416af..80cf1ff5e 100644
--- a/tests/test_compute_device_filter.py
+++ b/tests/test_compute_device_filter.py
@@ -7,8 +7,9 @@
 
 
 class _DummyModel:
-    def __init__(self, compute_device_filter):
+    def __init__(self, compute_device_filter, *, write_shared_kv_cache=False):
         self.support_batch_quantize = False
+        self.write_shared_kv_cache = write_shared_kv_cache
         self.quantize_config = types.SimpleNamespace(
             device=torch.device("cpu"),
             dense_vram_strategy="exclusive",
@@ -118,4 +119,75 @@ def fake_submit(device, fn, *args, **kwargs):
     assert outputs == []
     assert called_devices
     assert torch.device("meta") not in called_devices
-    assert all(device.type == "cpu" for device in called_devices)
+
+
+def test_parallel_forward_writes_shared_kv_cache_when_model_requests_it(monkeypatch):
+    devices = [torch.device("cpu"), torch.device("meta")]
+    looper = ModuleLooper(
+        model=_DummyModel(lambda candidates: [candidates[0]], write_shared_kv_cache=True),
+        processors=[],
+    )
+
+    class DummyProcessor:
+        num_batches = 1
+
+        def _set_current_batch_index(self, _index):
+            return None
+
+    def fake_clone_module_for_devices(_module, target_devices, progress_callback=None):
+        return {device: object() for device in target_devices}
+
+    kv_payload = ("kv",)
+
+    def fake_forward_batch_worker(*args, **kwargs):
+        del args
+        assert kwargs["reuse_kv"] is False
+        assert kwargs["write_shared_kv_cache"] is True
+        return 0, None, kv_payload
+
+    class DummyFuture:
+        def __init__(self, result):
+            self._result = result
+
+        def result(self):
+            return self._result
+
+    def fake_submit(device, fn, *args, **kwargs):
+        return DummyFuture(fn(*args, **kwargs))
+
+    monkeypatch.setattr(
+        "gptqmodel.looper.module_looper.clone_module_for_devices",
+        fake_clone_module_for_devices,
+    )
+    monkeypatch.setattr(
+        "gptqmodel.looper.module_looper.forward_batch_worker",
+        fake_forward_batch_worker,
+    )
+    monkeypatch.setattr(
+        "gptqmodel.looper.module_looper.DEVICE_THREAD_POOL.submit",
+        fake_submit,
+    )
+    monkeypatch.setattr(
+        "gptqmodel.looper.module_looper.DEVICE_THREAD_POOL.submit_serial",
+        fake_submit,
+    )
+
+    shared_kv_cache_dict = {}
+    outputs = looper._run_forward_batches_parallel(
+        module=torch.nn.Linear(1, 1),
+        processor=DummyProcessor(),
+        layer_inputs=[[torch.zeros(1, 1)]],
+        layer_input_kwargs=[{}],
+        position_ids=[],
+        attention_masks=[torch.zeros(1, 1)],
+        cur_layer_device=torch.device("cpu"),
+        is_lm_head_module=False,
+        shared_kv_cache_dict=shared_kv_cache_dict,
+        layer_index=0,
+        need_outputs=False,
+        reuse_kv=False,
+        devices=devices,
+    )
+
+    assert outputs == []
+    assert shared_kv_cache_dict[0] == kv_payload