From 3b779281af744c3ce0879812c0d3ddd716fc3848 Mon Sep 17 00:00:00 2001 From: SanBast Date: Mon, 1 Dec 2025 12:58:18 +0100 Subject: [PATCH 1/3] added tests for ra_sae --- overcomplete/sae/__init__.py | 1 + overcomplete/sae/ra_sae.py | 85 ++++++++++++++++++++++++++++++++++++ tests/sae/test_ra_sae.py | 41 +++++++++++++++++ 3 files changed, 127 insertions(+) create mode 100644 overcomplete/sae/ra_sae.py create mode 100644 tests/sae/test_ra_sae.py diff --git a/overcomplete/sae/__init__.py b/overcomplete/sae/__init__.py index d92f814..cbe7229 100644 --- a/overcomplete/sae/__init__.py +++ b/overcomplete/sae/__init__.py @@ -12,6 +12,7 @@ from .factory import EncoderFactory from .jump_sae import JumpSAE, jump_relu, heaviside from .topk_sae import TopKSAE +from .ra_sae import RATopKSAE, RAJumpSAE from .qsae import QSAE from .batchtopk_sae import BatchTopKSAE from .mp_sae import MpSAE diff --git a/overcomplete/sae/ra_sae.py b/overcomplete/sae/ra_sae.py new file mode 100644 index 0000000..b579857 --- /dev/null +++ b/overcomplete/sae/ra_sae.py @@ -0,0 +1,85 @@ +import torch +from torch import nn +from typing import Optional, Any + +from .topk_sae import TopKSAE +from .jump_sae import ( + JumpSAE, +) +from .archetypal_dictionary import RelaxedArchetypalDictionary + + +class RATopKSAE(TopKSAE): + """ + Relaxed Archetypal TopK SAE. + + This class implements a TopK SAE that utilizes a Relaxed Archetypal Dictionary. + The dictionary atoms are initialized/constrained to be convex combinations of + data points. + """ + + def __init__( + self, + input_shape: int, + nb_concepts: int, + points: torch.Tensor, + top_k: int, + **kwargs: Any + ): + """ + Args: + input_shape (int): Dimension of the input data. + nb_concepts (int): Number of dictionary atoms (concepts). + points (torch.Tensor): The data points used to initialize/define the archetypes. + Shape should be (num_points, input_shape). + top_k (int): The k in TopK. + **kwargs: Additional arguments passed to the parent TopKSAE. + """ + # Initialize the parent TopKSAE + super().__init__( + input_shape=input_shape, nb_concepts=nb_concepts, top_k=top_k, **kwargs + ) + + # Overwrite the standard dictionary with the Relaxed Archetypal Dictionary + # We assume points are on the correct device or will be moved via .to(device) later + self.dictionary = RelaxedArchetypalDictionary( + in_dimensions=input_shape, nb_concepts=nb_concepts, points=points + ) + + +# TODO: chck if it's "bandwith" or "bandwidth" in the original paper +class RAJumpSAE(JumpSAE): + """ + Relaxed Archetypal Jump SAE. + + This class implements a Jump SAE that utilizes a Relaxed Archetypal Dictionary. + """ + + def __init__( + self, + input_shape: int, + nb_concepts: int, + points: torch.Tensor, + bandwith: float = 0.001, + **kwargs: Any + ): + """ + Args: + input_shape (int): Dimension of the input data. + nb_concepts (int): Number of dictionary atoms. + points (torch.Tensor): The data points used for the dictionary. + bandwidth (float): Bandwidth parameter for Jump SAE. + **kwargs: Additional arguments passed to the parent JumpSAE. + """ + # Initialize the parent JumpSAE + super().__init__( + input_shape=input_shape, + nb_concepts=nb_concepts, + bandwith=bandwith, + **kwargs + ) + + # Overwrite the dictionary + self.dictionary = RelaxedArchetypalDictionary( + in_dimensions=input_shape, nb_concepts=nb_concepts, points=points + ) diff --git a/tests/sae/test_ra_sae.py b/tests/sae/test_ra_sae.py new file mode 100644 index 0000000..c41f41f --- /dev/null +++ b/tests/sae/test_ra_sae.py @@ -0,0 +1,41 @@ +import torch +from overcomplete.sae import RATopKSAE, RAJumpSAE + +def test_ra_implementations(): + # Mock data + batch_size = 32 + input_dim = 128 + nb_concepts = 20 + + #RelaxedArchetypalDictionary usually expects points to be [N, input_dim] ??? + points = torch.randn(100, input_dim) + input_data = torch.randn(batch_size, input_dim) + + print("Testing RATopKSAE...") + ra_topk = RATopKSAE( + input_shape=input_dim, + nb_concepts=nb_concepts, + points=points, + top_k=5 + ) + + z_pre, z, x_hat = ra_topk(input_data) + print(f"TopK Output shape: {x_hat.shape}") + assert x_hat.shape == input_data.shape + + print("\nTesting RAJumpSAE...") + ra_jump = RAJumpSAE( + input_shape=input_dim, + nb_concepts=nb_concepts, + points=points, + bandwith=0.001 + ) + + z_pre, z, x_hat = ra_jump(input_data) + print(f"Jump Output shape: {x_hat.shape}") + assert x_hat.shape == input_data.shape + + print("\nSuccess! Both RA classes instantiated and forwarded.") + +if __name__ == "__main__": + test_ra_implementations() \ No newline at end of file From c559bb921b775600e9bee163b205591116fbd470 Mon Sep 17 00:00:00 2001 From: Thomas Fel Date: Thu, 4 Dec 2025 12:34:24 +0100 Subject: [PATCH 2/3] jumpsae: typo with keyword bandwidth --- overcomplete/sae/jump_sae.py | 54 ++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/overcomplete/sae/jump_sae.py b/overcomplete/sae/jump_sae.py index 632be7b..d08d771 100644 --- a/overcomplete/sae/jump_sae.py +++ b/overcomplete/sae/jump_sae.py @@ -15,7 +15,7 @@ class JumpReLU(torch.autograd.Function): JumpReLU activation function with pseudo-gradient for threshold. """ @staticmethod - def forward(ctx, x, threshold, kernel_fn, bandwith): + def forward(ctx, x, threshold, kernel_fn, bandwidth): """ Forward pass of the JumpReLU activation function. Save the necessary variables for the backward pass. @@ -28,11 +28,11 @@ def forward(ctx, x, threshold, kernel_fn, bandwith): Threshold tensor, learnable parameter. kernel_fn : callable Kernel function. - bandwith : float - Bandwith of the kernel. + bandwidth : float + Bandwidth of the kernel. """ ctx.save_for_backward(x, threshold) - ctx.bandwith = bandwith + ctx.bandwidth = bandwidth ctx.kernel_fn = kernel_fn output = x.clone() @@ -52,7 +52,7 @@ def backward(ctx, grad_output): Gradient of the loss w.r.t. the output. """ x, threshold = ctx.saved_tensors - bandwith = ctx.bandwith + bandwidth = ctx.bandwidth kernel_fn = ctx.kernel_fn # gradient w.r.t. input (normal gradient) @@ -61,11 +61,11 @@ def backward(ctx, grad_output): # pseudo-gradient w.r.t. threshold parameters delta = x - threshold - kernel_values = kernel_fn(delta, bandwith) + kernel_values = kernel_fn(delta, bandwidth) # @tfel: we have a singularity at threshold=0, thus the # re-parametrization trick in JumpSAE class - grad_threshold = - (threshold / bandwith) * kernel_values * grad_output + grad_threshold = - (threshold / bandwidth) * kernel_values * grad_output grad_threshold = grad_threshold.sum(0) return grad_input, grad_threshold, None, None @@ -81,7 +81,7 @@ class HeavisidePseudoGradient(torch.autograd.Function): The pseudo-gradient is used to approximate the gradient at the threshold. """ @staticmethod - def forward(ctx, x, threshold, kernel_fn, bandwith): + def forward(ctx, x, threshold, kernel_fn, bandwidth): """ Forward pass of the Heaviside step function. Save the necessary variables for the backward pass. @@ -94,11 +94,11 @@ def forward(ctx, x, threshold, kernel_fn, bandwith): Threshold tensor, learnable parameter. kernel_fn : callable Kernel function. - bandwith : float - Bandwith of the kernel. + bandwidth : float + Bandwidth of the kernel. """ ctx.save_for_backward(x, threshold) - ctx.bandwith = bandwith + ctx.bandwidth = bandwidth ctx.kernel_fn = kernel_fn output = (x > threshold).float() @@ -117,14 +117,14 @@ def backward(ctx, grad_output): Gradient of the loss w.r.t. the output. """ x, threshold = ctx.saved_tensors - bandwith = ctx.bandwith + bandwidth = ctx.bandwidth kernel_fn = ctx.kernel_fn delta = x - threshold - kernel_values = kernel_fn(delta, bandwith) + kernel_values = kernel_fn(delta, bandwidth) # see the paper for the formula - grad_threshold = - (1 / bandwith) * kernel_values * grad_output + grad_threshold = - (1 / bandwidth) * kernel_values * grad_output grad_threshold = grad_threshold.sum(0) grad_input = torch.zeros_like(x) @@ -132,7 +132,7 @@ def backward(ctx, grad_output): return grad_input, grad_threshold, None, None -def jump_relu(x, threshold, kernel_fn, bandwith): +def jump_relu(x, threshold, kernel_fn, bandwidth): """ Apply the JumpReLU activation function to the input tensor. @@ -144,18 +144,18 @@ def jump_relu(x, threshold, kernel_fn, bandwith): Threshold tensor, learnable parameter. kernel_fn : callable Kernel function. - bandwith : float - Bandwith of the kernel. + bandwidth : float + Bandwidth of the kernel. Returns ------- torch.Tensor Output tensor. """ - return JumpReLU.apply(x, threshold, kernel_fn, bandwith) + return JumpReLU.apply(x, threshold, kernel_fn, bandwidth) -def heaviside(x, threshold, kernel_fn, bandwith): +def heaviside(x, threshold, kernel_fn, bandwidth): """ Apply the Heaviside step function to the input tensor. @@ -167,15 +167,15 @@ def heaviside(x, threshold, kernel_fn, bandwith): Threshold tensor, learnable parameter. kernel_fn : callable Kernel function. - bandwith : float - Bandwith of the kernel. + bandwidth : float + Bandwidth of the kernel. Returns ------- torch.Tensor Output tensor. """ - return HeavisidePseudoGradient.apply(x, threshold, kernel_fn, bandwith) + return HeavisidePseudoGradient.apply(x, threshold, kernel_fn, bandwidth) class JumpSAE(SAE): @@ -222,8 +222,8 @@ class JumpSAE(SAE): - 'quartic' - 'silverman' - 'cauchy'. - bandwith : float, optional - Bandwith of the kernel, by default 1e-3. + bandwidth : float, optional + Bandwidth of the kernel, by default 1e-3. encoder_module : nn.Module or string, optional Custom encoder module, by default None. If None, a simple Linear + BatchNorm default encoder is used. @@ -257,7 +257,7 @@ class JumpSAE(SAE): 'cauchy': cauchy_kernel } - def __init__(self, input_shape, nb_concepts, kernel='silverman', bandwith=1e-3, + def __init__(self, input_shape, nb_concepts, kernel='silverman', bandwidth=1e-3, encoder_module=None, dictionary_params=None, device='cpu'): assert isinstance(encoder_module, (str, nn.Module, type(None))) assert isinstance(input_shape, (int, tuple, list)) @@ -267,7 +267,7 @@ def __init__(self, input_shape, nb_concepts, kernel='silverman', bandwith=1e-3, dictionary_params, device) self.kernel_fn = self._KERNELS[kernel] - self.bandwith = torch.tensor(bandwith, device=device) + self.bandwidth = torch.tensor(bandwidth, device=device) # exp(-3) make the thresholds start around 0.05 self.thresholds = nn.Parameter(torch.ones(nb_concepts, device=device)*(-3.0), requires_grad=True) @@ -297,7 +297,7 @@ def encode(self, x): # see paper, appendix J codes = torch.relu(pre_codes) - codes = jump_relu(codes, exp_thresholds, bandwith=self.bandwith, + codes = jump_relu(codes, exp_thresholds, bandwidth=self.bandwidth, kernel_fn=self.kernel_fn) return pre_codes, codes From 6e8c06a6bdb8f4278e47e3b0c7abc4f305e10227 Mon Sep 17 00:00:00 2001 From: Thomas Fel Date: Thu, 4 Dec 2025 12:34:57 +0100 Subject: [PATCH 3/3] rasae: introduce archetypal helper for topk and jumprelu sae & test suite --- overcomplete/sae/__init__.py | 2 +- overcomplete/sae/ra_sae.py | 85 ------------------------ overcomplete/sae/rasae.py | 110 +++++++++++++++++++++++++++++++ tests/sae/test_base_sae.py | 32 ++++++--- tests/sae/test_ra_sae.py | 86 ++++++++++++++---------- tests/sae/test_sae_cuda.py | 19 ++++-- tests/sae/test_sae_dictionary.py | 2 +- tests/sae/test_save_and_load.py | 28 ++++++-- tests/sae/test_train_sae.py | 31 ++++++--- 9 files changed, 246 insertions(+), 149 deletions(-) delete mode 100644 overcomplete/sae/ra_sae.py create mode 100644 overcomplete/sae/rasae.py diff --git a/overcomplete/sae/__init__.py b/overcomplete/sae/__init__.py index cbe7229..61d1fe4 100644 --- a/overcomplete/sae/__init__.py +++ b/overcomplete/sae/__init__.py @@ -12,7 +12,7 @@ from .factory import EncoderFactory from .jump_sae import JumpSAE, jump_relu, heaviside from .topk_sae import TopKSAE -from .ra_sae import RATopKSAE, RAJumpSAE +from .rasae import RATopKSAE, RAJumpSAE from .qsae import QSAE from .batchtopk_sae import BatchTopKSAE from .mp_sae import MpSAE diff --git a/overcomplete/sae/ra_sae.py b/overcomplete/sae/ra_sae.py deleted file mode 100644 index b579857..0000000 --- a/overcomplete/sae/ra_sae.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -from torch import nn -from typing import Optional, Any - -from .topk_sae import TopKSAE -from .jump_sae import ( - JumpSAE, -) -from .archetypal_dictionary import RelaxedArchetypalDictionary - - -class RATopKSAE(TopKSAE): - """ - Relaxed Archetypal TopK SAE. - - This class implements a TopK SAE that utilizes a Relaxed Archetypal Dictionary. - The dictionary atoms are initialized/constrained to be convex combinations of - data points. - """ - - def __init__( - self, - input_shape: int, - nb_concepts: int, - points: torch.Tensor, - top_k: int, - **kwargs: Any - ): - """ - Args: - input_shape (int): Dimension of the input data. - nb_concepts (int): Number of dictionary atoms (concepts). - points (torch.Tensor): The data points used to initialize/define the archetypes. - Shape should be (num_points, input_shape). - top_k (int): The k in TopK. - **kwargs: Additional arguments passed to the parent TopKSAE. - """ - # Initialize the parent TopKSAE - super().__init__( - input_shape=input_shape, nb_concepts=nb_concepts, top_k=top_k, **kwargs - ) - - # Overwrite the standard dictionary with the Relaxed Archetypal Dictionary - # We assume points are on the correct device or will be moved via .to(device) later - self.dictionary = RelaxedArchetypalDictionary( - in_dimensions=input_shape, nb_concepts=nb_concepts, points=points - ) - - -# TODO: chck if it's "bandwith" or "bandwidth" in the original paper -class RAJumpSAE(JumpSAE): - """ - Relaxed Archetypal Jump SAE. - - This class implements a Jump SAE that utilizes a Relaxed Archetypal Dictionary. - """ - - def __init__( - self, - input_shape: int, - nb_concepts: int, - points: torch.Tensor, - bandwith: float = 0.001, - **kwargs: Any - ): - """ - Args: - input_shape (int): Dimension of the input data. - nb_concepts (int): Number of dictionary atoms. - points (torch.Tensor): The data points used for the dictionary. - bandwidth (float): Bandwidth parameter for Jump SAE. - **kwargs: Additional arguments passed to the parent JumpSAE. - """ - # Initialize the parent JumpSAE - super().__init__( - input_shape=input_shape, - nb_concepts=nb_concepts, - bandwith=bandwith, - **kwargs - ) - - # Overwrite the dictionary - self.dictionary = RelaxedArchetypalDictionary( - in_dimensions=input_shape, nb_concepts=nb_concepts, points=points - ) diff --git a/overcomplete/sae/rasae.py b/overcomplete/sae/rasae.py new file mode 100644 index 0000000..35380b9 --- /dev/null +++ b/overcomplete/sae/rasae.py @@ -0,0 +1,110 @@ +""" +Module for Relaxed Archetypal SAE implementations. +For the implementation of the Relaxed Archetypal Dictionary, see archetypal_dictionary.py. +""" + +import torch +from torch import nn + +from .topk_sae import TopKSAE +from .jump_sae import JumpSAE +from .archetypal_dictionary import RelaxedArchetypalDictionary + + +class RATopKSAE(TopKSAE): + """ + Relaxed Archetypal TopK SAE. + + This class implements a TopK SAE that utilizes a Relaxed Archetypal Dictionary. + The dictionary atoms are initialized and constrained to be convex combinations + of data points. + + For more information, see: + - "Archetypal SAE: Adaptive and Stable Dictionary Learning for Concept Extraction in + Large Vision Models" by T. Fel et al., ICML 2025 (https://arxiv.org/abs/2502.12892). + + Parameters + ---------- + input_shape : int + Dimensionality of the input data (excluding the batch dimension). + nb_concepts : int + Number of dictionary atoms (concepts). + points : torch.Tensor + The data points used to initialize/define the archetypes. + Shape should be (num_points, input_shape). + top_k : int + Number of top activations to keep in the latent representation. + By default, 10% sparsity is used. + delta : float, optional + Delta parameter for the archetypal dictionary, by default 1.0. + use_multiplier : bool, optional + Whether to use a learnable multiplier that parametrize the ball (e.g. if this parameter + is 3 then the dictionary atoms are all on the ball of radius 3). By default True. + **kwargs : dict, optional + Additional arguments passed to the parent TopKSAE (e.g., encoder_module, device). + """ + + def __init__(self, input_shape, nb_concepts, points, top_k=None, delta=1.0, use_multiplier=True, **kwargs): + assert isinstance(input_shape, int), "RATopKSAE input_shape must be an integer." + + super().__init__(input_shape=input_shape, nb_concepts=nb_concepts, + top_k=top_k, **kwargs) + + # enforce archetypal dictionary after the init of the parent class + self.dictionary = RelaxedArchetypalDictionary( + in_dimensions=input_shape, + nb_concepts=nb_concepts, + points=points, + delta=delta, + use_multiplier=use_multiplier, + device=self.device + ) + + +class RAJumpSAE(JumpSAE): + """ + Relaxed Archetypal Jump SAE. + + This class implements a Jump SAE that utilizes a Relaxed Archetypal Dictionary. + The dictionary atoms are initialized and constrained to be convex combinations + of data points. + + For more information, see: + - "Archetypal SAE: Adaptive and Stable Dictionary Learning for Concept Extraction in + Large Vision Models" by T. Fel et al., ICML 2025 (https://arxiv.org/abs/2502.12892). + + Parameters + ---------- + input_shape : int + Dimensionality of the input data (excluding the batch dimension). + nb_concepts : int + Number of dictionary atoms (concepts). + points : torch.Tensor + The data points used to initialize/define the archetypes. + Shape should be (num_points, input_shape). + bandwidth : float, optional + Bandwidth parameter for the Jump SAE kernel, by default 1e-3. + delta : float, optional + Delta parameter for the archetypal dictionary, by default 1.0. + use_multiplier : bool, optional + Whether to use a learnable multiplier that parametrize the ball (e.g. if this parameter + is 3 then the dictionary atoms are all on the ball of radius 3). By default True. + **kwargs : dict, optional + Additional arguments passed to the parent JumpSAE (e.g., encoder_module, device). + """ + + def __init__(self, input_shape, nb_concepts, points, bandwidth=1e-3, delta=1.0, use_multiplier=True, **kwargs): + assert isinstance(input_shape, int), "RAJumpSAE input_shape must be an integer." + + super().__init__(input_shape=input_shape, nb_concepts=nb_concepts, + bandwidth=bandwidth, **kwargs) + + # enforce archetypal dictionary after the init of the parent class + self.dictionary = RelaxedArchetypalDictionary( + in_dimensions=input_shape, + nb_concepts=nb_concepts, + points=points, + delta=delta, + use_multiplier=use_multiplier, + device=self.device + ) diff --git a/tests/sae/test_base_sae.py b/tests/sae/test_base_sae.py index fc7482c..4089053 100644 --- a/tests/sae/test_base_sae.py +++ b/tests/sae/test_base_sae.py @@ -1,12 +1,21 @@ import pytest import torch -from overcomplete.sae import SAE, DictionaryLayer, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE +from overcomplete.sae import SAE, DictionaryLayer, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE from overcomplete.sae.modules import TieableEncoder from ..utils import epsilon_equal -all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE] +all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE] + + +def get_sae_kwargs(sae_class, input_size, nb_concepts, device): + """Return specific kwargs required for certain SAE classes.""" + kwargs = {} + # archetypal SAEs require 'points' + if sae_class in [RATopKSAE, RAJumpSAE]: + kwargs['points'] = torch.randn(nb_concepts * 2, input_size, device=device) + return kwargs def test_dictionary_layer(): @@ -23,7 +32,9 @@ def test_dictionary_layer(): def test_sae(sae_class): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) x = torch.randn(3, input_size) output = model(x) @@ -43,13 +54,15 @@ def test_sae_device(sae_class): input_size = 10 nb_components = 5 - model = sae_class(input_size, nb_components, device='meta') + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_components, device='meta') + model = sae_class(input_size, nb_components, device='meta', **extra_kwargs) # ensure dictionary is on the meta device dictionary = model.get_dictionary() assert dictionary.device.type == 'meta' - model = sae_class(input_size, nb_components, device='cpu') + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_components, device='cpu') + model = sae_class(input_size, nb_components, device='cpu', **extra_kwargs) # ensure dictionary is on the cpu device dictionary = model.get_dictionary() @@ -111,7 +124,8 @@ def test_sae_tied_untied(sae_class): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) # Tie weights model.tied() @@ -130,7 +144,8 @@ def test_sae_tied_forward(sae_class): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) model.tied() x = torch.randn(3, input_size) @@ -146,7 +161,8 @@ def test_sae_untied_copy_weights(sae_class): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) model.tied() # Get dictionary before untying diff --git a/tests/sae/test_ra_sae.py b/tests/sae/test_ra_sae.py index c41f41f..4d2d263 100644 --- a/tests/sae/test_ra_sae.py +++ b/tests/sae/test_ra_sae.py @@ -1,41 +1,55 @@ +import pytest import torch + from overcomplete.sae import RATopKSAE, RAJumpSAE -def test_ra_implementations(): - # Mock data - batch_size = 32 - input_dim = 128 - nb_concepts = 20 - - #RelaxedArchetypalDictionary usually expects points to be [N, input_dim] ??? - points = torch.randn(100, input_dim) - input_data = torch.randn(batch_size, input_dim) - - print("Testing RATopKSAE...") - ra_topk = RATopKSAE( - input_shape=input_dim, - nb_concepts=nb_concepts, - points=points, - top_k=5 - ) - - z_pre, z, x_hat = ra_topk(input_data) - print(f"TopK Output shape: {x_hat.shape}") - assert x_hat.shape == input_data.shape +INPUT_SIZE = 128 +NB_CONCEPTS = 20 +POINTS_COUNT = 100 +BATCH_SIZE = 32 +TOP_K = 5 +BANDWIDTH = 0.001 - print("\nTesting RAJumpSAE...") - ra_jump = RAJumpSAE( - input_shape=input_dim, - nb_concepts=nb_concepts, - points=points, - bandwith=0.001 - ) - - z_pre, z, x_hat = ra_jump(input_data) - print(f"Jump Output shape: {x_hat.shape}") - assert x_hat.shape == input_data.shape - - print("\nSuccess! Both RA classes instantiated and forwarded.") -if __name__ == "__main__": - test_ra_implementations() \ No newline at end of file +@pytest.mark.parametrize("device", ['cpu']) +@pytest.mark.parametrize("ra_class", [RATopKSAE, RAJumpSAE]) +def test_ra_sae_device_propagation(device, ra_class): + # points need to be on the same device as the module for initialization + points = torch.randn(POINTS_COUNT, INPUT_SIZE, device=device) + + if ra_class == RATopKSAE: + model = ra_class(INPUT_SIZE, NB_CONCEPTS, points=points, top_k=TOP_K, device=device) + else: + model = ra_class(INPUT_SIZE, NB_CONCEPTS, points=points, bandwidth=BANDWIDTH, device=device) + + # check encoder parameters + for param in model.encoder.parameters(): + assert param.device.type == device + + # check dictionary parameters + for param in model.dictionary.parameters(): + assert param.device.type == device + + # check all parameters + for param in model.parameters(): + assert param.device.type == device + + +@pytest.mark.parametrize("ra_class", [RATopKSAE, RAJumpSAE]) +def test_ra_sae_forward_shape(ra_class): + # run forward pass on cpu to ensure shape correctness + device = 'cpu' + points = torch.randn(POINTS_COUNT, INPUT_SIZE, device=device) + input_data = torch.randn(BATCH_SIZE, INPUT_SIZE, device=device) + + if ra_class == RATopKSAE: + model = ra_class(INPUT_SIZE, NB_CONCEPTS, points=points, top_k=TOP_K, device=device) + else: + model = ra_class(INPUT_SIZE, NB_CONCEPTS, points=points, bandwidth=BANDWIDTH, device=device) + + z_pre, z, x_hat = model(input_data) + + # check output shapes + assert x_hat.shape == input_data.shape + assert z.shape == (BATCH_SIZE, NB_CONCEPTS) + assert z_pre.shape == (BATCH_SIZE, NB_CONCEPTS) diff --git a/tests/sae/test_sae_cuda.py b/tests/sae/test_sae_cuda.py index eb295ad..e262f64 100644 --- a/tests/sae/test_sae_cuda.py +++ b/tests/sae/test_sae_cuda.py @@ -1,7 +1,8 @@ import pytest - +import torch from overcomplete.sae import (MLPEncoder, AttentionEncoder, ResNetEncoder, - EncoderFactory, SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE) + EncoderFactory, SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, + RATopKSAE, RAJumpSAE) INPUT_SIZE = 32 @@ -11,7 +12,16 @@ HEIGHT = 7 WIDTH = 7 -all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE] +all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE] + + +def get_sae_kwargs(sae_class, input_size, nb_concepts, device): + """Return specific kwargs required for certain SAE classes.""" + kwargs = {} + # archetypal SAEs require 'points' + if sae_class in [RATopKSAE, RAJumpSAE]: + kwargs['points'] = torch.randn(nb_concepts * 2, input_size, device=device) + return kwargs @pytest.mark.parametrize("device", ['cpu', 'meta']) @@ -41,7 +51,8 @@ def test_resnet_encoder_device_propagation(device): @pytest.mark.parametrize("device, ", ['cpu', 'meta']) @pytest.mark.parametrize("sae_class", all_sae) def test_default_sae_device_propagation(device, sae_class): - model = sae_class(32, 5, encoder_module=None, device=device) + extra_kwargs = get_sae_kwargs(sae_class, INPUT_SIZE, N_COMPONENTS, device=device) + model = sae_class(32, 5, encoder_module=None, device=device, **extra_kwargs) for param in model.encoder.parameters(): assert param.device.type == device diff --git a/tests/sae/test_sae_dictionary.py b/tests/sae/test_sae_dictionary.py index bc2d8f3..66e5649 100644 --- a/tests/sae/test_sae_dictionary.py +++ b/tests/sae/test_sae_dictionary.py @@ -1,7 +1,7 @@ import torch import pytest -from overcomplete.sae import DictionaryLayer, SAE, QSAE, TopKSAE, JumpSAE, BatchTopKSAE, MpSAE, OMPSAE +from overcomplete.sae import DictionaryLayer, SAE, QSAE, TopKSAE, JumpSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE from overcomplete.sae.modules import TieableEncoder from ..utils import epsilon_equal diff --git a/tests/sae/test_save_and_load.py b/tests/sae/test_save_and_load.py index f61c21d..5ae9219 100644 --- a/tests/sae/test_save_and_load.py +++ b/tests/sae/test_save_and_load.py @@ -2,18 +2,27 @@ import pytest import torch -from overcomplete.sae import SAE, DictionaryLayer, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE +from overcomplete.sae import SAE, DictionaryLayer, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE from overcomplete.sae.modules import TieableEncoder from ..utils import epsilon_equal -all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE] +all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE] def _load(path): return torch.load(path, map_location="cpu", weights_only=False) +def get_sae_kwargs(sae_class, input_size, nb_concepts, device): + """Return specific kwargs required for certain SAE classes.""" + kwargs = {} + # archetypal SAEs require 'points' + if sae_class in [RATopKSAE, RAJumpSAE]: + kwargs['points'] = torch.randn(nb_concepts * 2, input_size, device=device) + return kwargs + + @pytest.mark.parametrize("nb_concepts, dimensions", [(5, 10)]) def test_save_and_load_dictionary_layer(nb_concepts, dimensions, tmp_path): # Initialize and run layer @@ -41,7 +50,9 @@ def test_save_and_load_dictionary_layer(nb_concepts, dimensions, tmp_path): def test_save_and_load_sae(sae_class, tmp_path): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) x = torch.randn(3, input_size) output = model(x) @@ -69,7 +80,9 @@ def test_save_and_load_sae(sae_class, tmp_path): def test_eval_and_save_sae(sae_class, tmp_path): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) x = torch.randn(3, input_size) output = model(x) @@ -99,7 +112,8 @@ def test_save_and_load_tied_sae(sae_class, tmp_path): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) model.tied() x = torch.randn(3, input_size) @@ -128,7 +142,9 @@ def test_save_and_load_untied_with_copy(sae_class, tmp_path): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) + model.tied() model.untied(copy_from_dictionary=True) diff --git a/tests/sae/test_train_sae.py b/tests/sae/test_train_sae.py index 4359aff..bdfbf88 100644 --- a/tests/sae/test_train_sae.py +++ b/tests/sae/test_train_sae.py @@ -8,15 +8,24 @@ from overcomplete.sae.train import train_sae, train_sae_amp from overcomplete.sae.losses import mse_l1 -from overcomplete.sae import SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE +from overcomplete.sae import SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE from overcomplete.sae.modules import TieableEncoder from ..utils import epsilon_equal -all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE] +all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE] saes_attention_conv_format = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE] +def get_sae_kwargs(sae_class, input_size, nb_concepts, device): + """Return specific kwargs required for certain SAE classes.""" + kwargs = {} + # archetypal SAEs require 'points' + if sae_class in [RATopKSAE, RAJumpSAE]: + kwargs['points'] = torch.randn(nb_concepts * 2, input_size, device=device) + return kwargs + + @pytest.mark.parametrize( "module_name", [ @@ -38,7 +47,8 @@ def test_train_mlp_sae(module_name, sae_class): criterion = mse_l1 n_components = 2 - model = sae_class(data.shape[1], n_components, encoder_module=module_name) + extra_kwargs = get_sae_kwargs(sae_class, data.shape[1], n_components, device='cpu') + model = sae_class(data.shape[1], n_components, encoder_module=module_name, **extra_kwargs) optimizer = optim.SGD(model.parameters(), lr=0.001) scheduler = None @@ -90,7 +100,8 @@ def criterion(x, x_hat, z_pre, z, dictionary): dataloader = DataLoader(dataset, batch_size=10) n_components = 2 - model = sae_class(data.shape[1:], n_components, encoder_module="resnet_3b") + extra_kwargs = get_sae_kwargs(sae_class, data.shape[1:], n_components, device='cpu') + model = sae_class(data.shape[1:], n_components, encoder_module="resnet_3b", **extra_kwargs) optimizer = optim.SGD(model.parameters(), lr=0.001) scheduler = None @@ -121,7 +132,8 @@ def criterion(x, x_hat, z_pre, z, dictionary): dataloader = DataLoader(dataset, batch_size=10) n_components = 2 - model = sae_class(data.shape[1:], n_components, encoder_module="attention_3b") + extra_kwargs = get_sae_kwargs(sae_class, data.shape[1:], n_components, device='cpu') + model = sae_class(data.shape[1:], n_components, encoder_module="attention_3b", **extra_kwargs) optimizer = optim.SGD(model.parameters(), lr=0.001) scheduler = None @@ -159,7 +171,8 @@ def test_train_without_amp(module_name, sae_class): criterion = mse_l1 n_components = 2 - model = sae_class(data.shape[1], n_components, encoder_module=module_name) + extra_kwargs = get_sae_kwargs(sae_class, data.shape[1], n_components, device='cpu') + model = sae_class(data.shape[1], n_components, encoder_module=module_name, **extra_kwargs) optimizer = optim.SGD(model.parameters(), lr=0.001) scheduler = None @@ -333,7 +346,8 @@ def test_train_tied_sae(sae_class): criterion = mse_l1 n_components = 2 - model = sae_class(data.shape[1], n_components) + extra_kwargs = get_sae_kwargs(sae_class, data.shape[1], n_components, device='cpu') + model = sae_class(data.shape[1], n_components, **extra_kwargs) model.tied() optimizer = optim.SGD(model.parameters(), lr=0.001) @@ -361,7 +375,8 @@ def test_train_untied_after_tied(sae_class): criterion = mse_l1 n_components = 2 - model = sae_class(data.shape[1], n_components) + extra_kwargs = get_sae_kwargs(sae_class, data.shape[1], n_components, device='cpu') + model = sae_class(data.shape[1], n_components, **extra_kwargs) model.tied() # Train tied