6 changes: 4 additions & 2 deletions gptqmodel/__init__.py
@@ -110,9 +110,11 @@ def _patch_transformers_causal_conv1d_hub_kernel_compat():
return

with _MONKEY_PATCH_LOCK:
if not hasattr(hf_integrations, "lazy_load_kernel"):
return

if getattr(hub_kernels, "_gptqmodel_local_causal_conv1d_kernel", False):
if hasattr(hf_integrations, "lazy_load_kernel"):
hf_integrations.lazy_load_kernel = hub_kernels.lazy_load_kernel
hf_integrations.lazy_load_kernel = hub_kernels.lazy_load_kernel
return

original_lazy_load_kernel = hub_kernels.lazy_load_kernel
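
For context, a minimal, self-contained sketch of the idempotent monkey-patch pattern this hunk relies on. The namespaces and the sentinel attribute below are stand-ins, not the real transformers modules:

import threading
import types

_MONKEY_PATCH_LOCK = threading.Lock()

# Stand-ins for transformers' hub-kernel and integration namespaces.
hub_kernels = types.SimpleNamespace(lazy_load_kernel=lambda name: f"hub:{name}")
hf_integrations = types.SimpleNamespace(lazy_load_kernel=lambda name: f"hf:{name}")


def patch_lazy_load_kernel_compat() -> None:
    with _MONKEY_PATCH_LOCK:
        # A sentinel attribute keeps the patch idempotent across repeated calls.
        if getattr(hub_kernels, "_gptqmodel_patched", False):
            return
        if hasattr(hf_integrations, "lazy_load_kernel"):
            hf_integrations.lazy_load_kernel = hub_kernels.lazy_load_kernel
        hub_kernels._gptqmodel_patched = True


patch_lazy_load_kernel_compat()
patch_lazy_load_kernel_compat()  # second call is a no-op
print(hf_integrations.lazy_load_kernel("causal_conv1d"))  # -> hub:causal_conv1d
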
7 changes: 5 additions & 2 deletions gptqmodel/looper/forward_executor.py
@@ -232,6 +232,7 @@ def run_single(
"""Run the forward pass sequentially on the current device."""

outputs: List[List[torch.Tensor]] = []
write_shared_kv_cache = bool(getattr(self.looper.gptq_model, "write_shared_kv_cache", False))
prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None
total_batches, batch_row_counts, total_rows = self._resolve_batch_progress(
processor,
@@ -320,7 +321,7 @@ def run_single(
del additional_inputs

if (
reuse_kv
(reuse_kv or write_shared_kv_cache)
and module_output is not None
and isinstance(module_output, tuple)
and len(module_output) > 0
@@ -462,6 +463,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
ctx.__enter__()
moe_contexts.append(ctx)

write_shared_kv_cache = bool(getattr(self.looper.gptq_model, "write_shared_kv_cache", False))
prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None
results: Dict[int, torch.Tensor | None] = {}
processed_rows = 0
@@ -553,6 +555,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
need_output=need_outputs,
reuse_kv=reuse_kv,
prev_kv=prev_kv,
write_shared_kv_cache=write_shared_kv_cache,
)
)

@@ -566,7 +569,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
primary = module_output[0] if isinstance(module_output, tuple) else module_output
results[batch_idx] = move_to(primary, device=target_device)
del module_output
if reuse_kv and kv_next is not None and shared_kv_cache_dict.get(layer_index) is None:
if (reuse_kv or write_shared_kv_cache) and kv_next is not None and shared_kv_cache_dict.get(layer_index) is None:
shared_kv_cache_dict[layer_index] = nested_move_to(kv_next, device=cur_layer_device)

rows_for_batch = batch_row_counts[batch_idx] if batch_idx < len(batch_row_counts) else 0
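
The two hunks above apply the same gating rule in the sequential and parallel paths. A hedged, stand-alone sketch of that rule (illustrative names, not the real executor signature):

from typing import Any, Dict, Optional


def maybe_publish_kv(
    shared_kv_cache_dict: Dict[int, Any],
    layer_index: int,
    kv_next: Optional[Any],
    *,
    reuse_kv: bool,
    write_shared_kv_cache: bool,
) -> None:
    # Publish when the layer reuses KV itself OR the model opted in via
    # `write_shared_kv_cache`; only the first batch for a layer wins.
    if (reuse_kv or write_shared_kv_cache) and kv_next is not None:
        if shared_kv_cache_dict.get(layer_index) is None:
            shared_kv_cache_dict[layer_index] = kv_next


cache: Dict[int, Any] = {}
maybe_publish_kv(cache, 0, ("k", "v"), reuse_kv=False, write_shared_kv_cache=True)
maybe_publish_kv(cache, 0, ("k2", "v2"), reuse_kv=False, write_shared_kv_cache=True)
assert cache[0] == ("k", "v")  # later batches do not overwrite the first entry
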
2 changes: 1 addition & 1 deletion gptqmodel/looper/stage_layer.py
@@ -272,7 +272,7 @@ def _replay_layer_outputs(
shared_kv_cache_dict=shared_kv_cache_dict,
layer_index=layer_index,
need_outputs=True,
reuse_kv=False,
reuse_kv=getattr(module, "reuse_kv", False),
progress_pb=replay_pb,
progress_title=replay_msg,
progress_stage="Forward replay",
6 changes: 6 additions & 0 deletions gptqmodel/models/base.py
@@ -231,6 +231,12 @@ class BaseQModel(nn.Module):
# some models have broken attention mask codes so we need to only use batch 1 with no masks
support_batch_quantize = True

# Whether this model should publish a layer's KV tuple into the shared
# replay cache even when that layer does not itself consume `kv_last_layer`.
# Models like "hymba" need some layers to write KV so that later layers can read it.
write_shared_kv_cache = False

# allow models to define optional notes that output messages to users that want to use this model
# list of supported keys: [ "notes" = print the notes value on model load ]
info: Dict[str, str] = {}
1 change: 1 addition & 0 deletions gptqmodel/models/definitions/hymba.py
@@ -11,6 +11,7 @@ class HymbaQModel(BaseQModel):
supports_desc_act = [False]
require_trust_remote_code = True
require_monkeypatch = True
write_shared_kv_cache = True
require_pkgs = ["tiktoken>=0.7.0",
"sentencepiece>=0.2.0",
"protobuf>=5.28.3",
43 changes: 41 additions & 2 deletions gptqmodel/utils/hf.py
@@ -21,7 +21,6 @@
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
PreTrainedConfig,
PreTrainedModel,
)

@@ -41,7 +40,14 @@
try:
from transformers.initialization import no_init_weights
except ImportError:
from transformers.modeling_utils import no_init_weights
from transformers.modeling_utils import no_init_weights  # Compatibility wrapper for no_init_weights across different transformers versions

# transformers >= 5.0.0: from transformers import PreTrainedConfig
# transformers < 5.0.0: from transformers import PretrainedConfig
try:
from transformers import PreTrainedConfig
except ImportError:
from transformers import PretrainedConfig as PreTrainedConfig

from ..utils.logger import setup_logger

@@ -1016,6 +1022,14 @@ def _normalize_remote_code_config_compat(config: Any) -> None:
model_type = getattr(config, "model_type", None)
model_type_lower = model_type.lower() if isinstance(model_type, str) else None

if model_type_lower == "hymba":
# hymba uses Flex attention by default, but `modeling_hymba` has not yet been
# adapted to the latest PyTorch Flex API, so patch the config here to fall back
# to flash_attention_2 or sdpa.
if getattr(config, 'attn_implementation_new', None) == "flex":
from transformers.utils import is_flash_attn_2_available
config.attn_implementation_new = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"

if model_type_lower == "dream" or model_type == "brumby":
import transformers.modeling_rope_utils as rope_utils
# dream remote models expect "default"
@@ -1234,6 +1248,31 @@ def tie_weights_compat(self, *args, **kwargs):
formatter_cls.support_tokenizer_types = support_tokenizer_types
formatter_cls._gptqmodel_tokenizer_backend_patch = True

if getattr(config, "model_type", None) == "hymba" and remote_module is not None:
rotary_cls = getattr(remote_module, "LlamaRotaryEmbedding", None)
attention_cls = getattr(remote_module, "HymbaAttention", None)
if (
rotary_cls is not None
and attention_cls is not None
and not getattr(attention_cls, "_gptqmodel_init_rope_meta_patch", False)
):
def hymba_init_rope_compat(self):
# Hymba remote code hard-codes CUDA here, which forces a
# meta->real-device materialization during __init__ under
# transformers 5.x. Keep the device as None until HF
# finishes model loading and device placement.
device = None

self.rotary_emb = rotary_cls(
config=self.config,
dim=self.kq_head_dim,
base=self.rope_theta,
device=device,
)

attention_cls._init_rope = hymba_init_rope_compat
attention_cls._gptqmodel_init_rope_meta_patch = True

if getattr(config, "model_type", None) != "phi4mm":
return

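
The `_init_rope` patch above can be illustrated with a self-contained sketch; the dummy classes below are stand-ins for the Hymba remote-code classes, not the real ones:

class DummyRotaryEmbedding:
    def __init__(self, *, config=None, dim=None, base=None, device=None):
        # device stays None, so nothing is materialized during __init__
        self.device = device


class DummyAttention:
    def __init__(self, config):
        self.config = config
        self.kq_head_dim = 64
        self.rope_theta = 10000.0
        self._init_rope()

    def _init_rope(self):
        # the original remote code would hard-code device="cuda" here
        self.rotary_emb = DummyRotaryEmbedding(
            config=self.config, dim=self.kq_head_dim, base=self.rope_theta, device="cuda"
        )


def patched_init_rope(self):
    self.rotary_emb = DummyRotaryEmbedding(
        config=self.config, dim=self.kq_head_dim, base=self.rope_theta, device=None
    )


# Patch once, guarded by a sentinel attribute, mirroring the hunk above.
if not getattr(DummyAttention, "_init_rope_patched", False):
    DummyAttention._init_rope = patched_init_rope
    DummyAttention._init_rope_patched = True

attn = DummyAttention(config=object())
assert attn.rotary_emb.device is None
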
3 changes: 2 additions & 1 deletion gptqmodel/utils/looper_helpers.py
@@ -370,6 +370,7 @@ def forward_batch_worker(
need_output: bool,
reuse_kv: bool,
prev_kv,
write_shared_kv_cache: bool = False,
):
processor._set_current_batch_index(batch_index)
module_device = getattr(module, "_gptqmodule_device_hint", None) or get_device(module)
@@ -433,7 +434,7 @@ def forward_batch_worker(
mask_tls.value = None
processor._set_current_batch_index(None)

if reuse_kv and module_output is not None and isinstance(module_output, tuple) and len(module_output) > 0:
if (reuse_kv or write_shared_kv_cache) and module_output is not None and isinstance(module_output, tuple) and len(module_output) > 0:
kv_next = module_output[-1]

result_output = None
9 changes: 6 additions & 3 deletions tests/models/test_hymba.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from gptqmodel import BACKEND
from model_test import ModelTest


@@ -11,8 +11,8 @@ class TestHymba(ModelTest):
EVAL_TASKS_SLOW = {
"arc_challenge": {
"chat_template": True,
"acc": {"value": 0.2073, "floor_pct": 0.75},
"acc_norm": {"value": 0.2713, "floor_pct": 0.75},
"acc": {"value": {"A100": 0.3737}, "floor_pct": 0.75},
"acc_norm": {"value": {"A100": 0.3703}, "floor_pct": 0.75},
},
}
EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW)
@@ -25,6 +25,9 @@ class TestHymba(ModelTest):
# Hymba currently requires DESC_ACT=False to get better results.
# If DESC_ACT=True, the output will be terrible.
DESC_ACT = False
OFFLOAD_TO_DISK = False # FIXME: hymba does not work with OFFLOAD_TO_DISK=True
LOAD_BACKEND = BACKEND.AUTO
USE_FLASH_ATTN = True


def test_hymba(self):
76 changes: 74 additions & 2 deletions tests/test_compute_device_filter.py
@@ -7,8 +7,9 @@


class _DummyModel:
def __init__(self, compute_device_filter):
def __init__(self, compute_device_filter, *, write_shared_kv_cache=False):
self.support_batch_quantize = False
self.write_shared_kv_cache = write_shared_kv_cache
self.quantize_config = types.SimpleNamespace(
device=torch.device("cpu"),
dense_vram_strategy="exclusive",
@@ -118,4 +119,75 @@ def fake_submit(device, fn, *args, **kwargs):
assert outputs == []
assert called_devices
assert torch.device("meta") not in called_devices
assert all(device.type == "cpu" for device in called_devices)


def test_parallel_forward_writes_shared_kv_cache_when_model_requests_it(monkeypatch):
devices = [torch.device("cpu"), torch.device("meta")]
looper = ModuleLooper(
model=_DummyModel(lambda candidates: [candidates[0]], write_shared_kv_cache=True),
processors=[],
)

class DummyProcessor:
num_batches = 1

def _set_current_batch_index(self, _index):
return None

def fake_clone_module_for_devices(_module, target_devices, progress_callback=None):
return {device: object() for device in target_devices}

kv_payload = ("kv",)

def fake_forward_batch_worker(*args, **kwargs):
del args
assert kwargs["reuse_kv"] is False
assert kwargs["write_shared_kv_cache"] is True
return 0, None, kv_payload

class DummyFuture:
def __init__(self, result):
self._result = result

def result(self):
return self._result

def fake_submit(device, fn, *args, **kwargs):
return DummyFuture(fn(*args, **kwargs))

monkeypatch.setattr(
"gptqmodel.looper.module_looper.clone_module_for_devices",
fake_clone_module_for_devices,
)
monkeypatch.setattr(
"gptqmodel.looper.module_looper.forward_batch_worker",
fake_forward_batch_worker,
)
monkeypatch.setattr(
"gptqmodel.looper.module_looper.DEVICE_THREAD_POOL.submit",
fake_submit,
)
monkeypatch.setattr(
"gptqmodel.looper.module_looper.DEVICE_THREAD_POOL.submit_serial",
fake_submit,
)

shared_kv_cache_dict = {}
outputs = looper._run_forward_batches_parallel(
module=torch.nn.Linear(1, 1),
processor=DummyProcessor(),
layer_inputs=[[torch.zeros(1, 1)]],
layer_input_kwargs=[{}],
position_ids=[],
attention_masks=[torch.zeros(1, 1)],
cur_layer_device=torch.device("cpu"),
is_lm_head_module=False,
shared_kv_cache_dict=shared_kv_cache_dict,
layer_index=0,
need_outputs=False,
reuse_kv=False,
devices=devices,
)

assert outputs == []
assert shared_kv_cache_dict[0] == kv_payload