diff --git a/marlin/__init__.py b/marlin/__init__.py
index ee7ccf2..eecc98d 100644
--- a/marlin/__init__.py
+++ b/marlin/__init__.py
@@ -295,7 +295,7 @@ def __init__(self, infeatures, outfeatures, groupsize=-1):
         self.n = outfeatures
         self.groupsize = groupsize
         self.register_buffer(
-            "B", torch.empty((self.k // 16, self.n * 16 // 8), dtype=torch.int)
+            "B", torch.empty((self.k // 16 // 2, self.n * 16 // 8), dtype=torch.int)
         )
         self.register_buffer(
             "meta", torch.empty((self.n, self.k // 16), dtype=torch.int16)
@@ -365,8 +365,15 @@ def pack(self, linear, scales, trans=False):
             s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
 
         mask = mask_creator(w.T).cuda().bool()
+        # Shift by 1 so elements quantized to 0 (signed -8, stored as -8 + (maxq + 1) // 2 = 0) are not confused with pruned elements
+        w += 1
+
         w = mask * w.T
         w, meta = sparse_semi_structured_from_dense_cutlass(w)
+
+        # Recover the original values
+        w -= 1
+
         w = w.t()
         self.k = self.k // 2
         self.groupsize = self.groupsize // 2
diff --git a/marlin/_semi_structured_conversions.py b/marlin/_semi_structured_conversions.py
index f85092d..dbf8fb8 100644
--- a/marlin/_semi_structured_conversions.py
+++ b/marlin/_semi_structured_conversions.py
@@ -304,6 +304,11 @@ def mask_creator(tensor):
 
     num_groups = tensor.numel() // M
 
+    # Subtract the zero-point offset so magnitudes are ranked relative to the quantized zero
+    maxq = 2**4 - 1
+    ZERO_VALUE = (maxq + 1) // 2
+    tensor = tensor - ZERO_VALUE
+
     # N:M sparsity for linear layers
     tensor_temp = tensor.detach().abs().reshape(num_groups, M)
     index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]
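
For reviewers, a minimal standalone sketch of why the ±1 shift around the compression step matters. It assumes only `torch`; the example tensor, the mask, and the `!= 0` mask inference are illustrative stand-ins for what `sparse_semi_structured_from_dense_cutlass` does internally, not the actual implementation:

```python
import torch

# Quantized 4-bit weights are stored in [0, maxq]; subtracting the
# zero-point (maxq + 1) // 2 = 8 recovers the signed value, so a
# stored 0 means signed -8, and a stored 8 means a true zero weight.
maxq = 2**4 - 1

w = torch.tensor([[0, 3, 9, 12]], dtype=torch.int32)  # one 2:4 group (stored values)
mask = torch.tensor([[True, False, False, True]])     # keep columns 0 and 3

# If the dense->sparse conversion infers kept positions from nonzeros,
# a kept element whose stored value is 0 is misread as pruned:
masked = mask * w
print(masked != 0)   # tensor([[False, False, False,  True]])  -- loses column 0

# Shifting by 1 first moves kept values into [1, 16], so none collide
# with the pruned zeros; the shift is undone after compression:
shifted = mask * (w + 1)
print(shifted != 0)  # tensor([[ True, False, False,  True]])  -- matches mask
recovered = shifted[shifted != 0] - 1
print(recovered)     # tensor([ 0, 12], dtype=torch.int32)     -- original values kept
```

This mirrors the `w += 1` / `w -= 1` pair in `pack`: the shift is purely transient around the CUTLASS compression, so the packed values and scales are unchanged.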