diff --git a/tico/quantization/algorithm/gptq/quantizer.py b/tico/quantization/algorithm/gptq/quantizer.py
index 351d0b5c..ba3d05b5 100644
--- a/tico/quantization/algorithm/gptq/quantizer.py
+++ b/tico/quantization/algorithm/gptq/quantizer.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import types
 from typing import Any, Callable, Dict, List, Optional
 
@@ -133,6 +134,15 @@ def forward(layer, *args, **kwargs):
             self.num_batches += 1
             raise StopForward  # stop after the first layer
 
+        gptq_conf = self.config
+        assert isinstance(gptq_conf, GPTQConfig)
+        if gptq_conf.use_orig_model_inference:
+            device = next(model.parameters()).device
+            model = model.cpu()
+            self.orig_model = copy.deepcopy(model)
+            model = model.to(device)
+        else:
+            self.orig_model = None
         # Replace the first layer with defined function to capture calibration data.
         if hasattr(model, "model"):
             if hasattr(model.model, "layers") and isinstance(
@@ -211,9 +221,12 @@ def convert(self, model):
         gptq_conf.validate()
 
         # Identify layers
+        orig_layers = None
         if hasattr(model, "model"):
             if hasattr(model.model, "layers"):
                 target_layers = model.model.layers
+                if self.orig_model is not None:
+                    orig_layers = self.orig_model.model.layers
             else:
                 target_layers = [model]
         else:
@@ -347,7 +360,12 @@ def _hook(_, inp, out):
                 )
                 cache_kwargs_batch = move_to_device(cache_kwargs_batch, device)
 
-                outs = layer(*cache_args_batch, **cache_kwargs_batch)
+                if orig_layers is None:
+                    outs = layer(*cache_args_batch, **cache_kwargs_batch)
+                else:
+                    orig_layer = orig_layers[l_idx].to(device)
+                    outs = orig_layer(*cache_args_batch, **cache_kwargs_batch)
+                    orig_layer.cpu()
                 # LLaMA's decoder layer return type differs across Transformers versions:
                 # some return a tuple (hidden_states, ...), others return just a tensor.
                 # This line ensures we always take the first element when it's a tuple.
@@ -412,8 +430,12 @@ def _quantize_lm_head(self, model, quantizers):
         ):
             hidden_states = gather_single_batch_from_list(self.cache_args, batch_idx)[0]
             hidden_states = move_to_device(hidden_states, device)
-
-            hidden_states = model.model.norm(hidden_states)
+            if self.orig_model is None:
+                hidden_states = model.model.norm(hidden_states)
+            else:
+                norm = self.orig_model.model.norm.to(device)
+                hidden_states = norm(hidden_states)
+                norm = norm.cpu()
 
             if len(self.cache_args) > 0:
                 self.cache_args[0][batch_idx] = move_to_cpu(hidden_states)
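Why the patch routes next-layer inputs through a float copy: in vanilla GPTQ, layer l is quantized and its already-quantized outputs then become the calibration inputs for layer l+1, so rounding error compounds with depth. The hunks above instead deep-copy the original model in prepare() and, when use_orig_model_inference is set, run each layer's forward pass on the float copy in convert(). Below is a minimal, self-contained sketch of the two propagation modes; the names are illustrative only (quantize_fn stands in for the per-layer GPTQ update), not the patch's actual API:

    from typing import Callable, Sequence

    import torch

    @torch.no_grad()
    def propagate_calibration(
        q_layers: Sequence[torch.nn.Module],  # layers being quantized in place
        f_layers: Sequence[torch.nn.Module],  # float deepcopy (self.orig_model)
        hidden: torch.Tensor,                 # captured inputs to layer 0
        quantize_fn: Callable[[torch.nn.Module, torch.Tensor], None],
        use_orig: bool = False,
    ) -> torch.Tensor:
        for q_layer, f_layer in zip(q_layers, f_layers):
            # GPTQ weight update for this layer, driven by the current inputs.
            quantize_fn(q_layer, hidden)
            # Vanilla GPTQ feeds the quantized layer's outputs to the next
            # layer, so rounding error compounds with depth. use_orig=True
            # feeds the float layer's outputs instead, keeping calibration
            # inputs on the original model's distribution.
            src = f_layer if use_orig else q_layer
            out = src(hidden)
            hidden = out[0] if isinstance(out, tuple) else out
        return hidden

The trade-off is memory: the patch keeps the float copy on CPU and shuttles each orig_layer to the device only for its forward pass, which is what the .to(device)/.cpu() pairs above implement.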
diff --git a/tico/quantization/config/gptq.py b/tico/quantization/config/gptq.py
index c42a19df..31302e24 100644
--- a/tico/quantization/config/gptq.py
+++ b/tico/quantization/config/gptq.py
@@ -66,6 +66,9 @@ class GPTQConfig(BaseConfig):
     actorder: bool = True
     static_groups: bool = False
 
+    # Use the original (float) model to produce each layer's inputs; stabilizes GPTQ for deep models.
+    use_orig_model_inference: bool = False
+
     @property
     def name(self) -> str:
         return "gptq"
diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py
index 60f0ff42..f671ebc5 100644
--- a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py
+++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py
@@ -273,7 +273,12 @@ def parse_args():
         action="store_true",
         help="Verbose logging for debugging (e.g., GPTQ injection coverage)",
     )
-
+    parser.add_argument(
+        "--gptq_use_orig_model_inference",
+        action="store_true",
+        default=False,
+        help="Run next-layer inputs through the original (float) model to stabilize GPTQ",
+    )
     return parser.parse_args()
 
 
@@ -432,6 +437,7 @@ def build_gptq_config(
         mse=args.gptq_mse,
         sensitivity=sensitivity,
         quantize_lm_head=args.gptq_lm_head,
+        use_orig_model_inference=args.gptq_use_orig_model_inference,
     )
 
 
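A minimal usage sketch of the new option, assuming the GPTQConfig defaults shown above (other constructor arguments keep their defaults):

    from tico.quantization.config.gptq import GPTQConfig

    # Float model produces each layer's calibration inputs during GPTQ.
    cfg = GPTQConfig(use_orig_model_inference=True)

From the command line, the same behavior is enabled on the example script with --gptq_use_orig_model_inference, alongside whatever model and dataset arguments the script already requires.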