diff --git a/overcomplete/sae/__init__.py b/overcomplete/sae/__init__.py index d92f814..61d1fe4 100644 --- a/overcomplete/sae/__init__.py +++ b/overcomplete/sae/__init__.py @@ -12,6 +12,7 @@ from .factory import EncoderFactory from .jump_sae import JumpSAE, jump_relu, heaviside from .topk_sae import TopKSAE +from .rasae import RATopKSAE, RAJumpSAE from .qsae import QSAE from .batchtopk_sae import BatchTopKSAE from .mp_sae import MpSAE diff --git a/overcomplete/sae/jump_sae.py b/overcomplete/sae/jump_sae.py index 632be7b..d08d771 100644 --- a/overcomplete/sae/jump_sae.py +++ b/overcomplete/sae/jump_sae.py @@ -15,7 +15,7 @@ class JumpReLU(torch.autograd.Function): JumpReLU activation function with pseudo-gradient for threshold. """ @staticmethod - def forward(ctx, x, threshold, kernel_fn, bandwith): + def forward(ctx, x, threshold, kernel_fn, bandwidth): """ Forward pass of the JumpReLU activation function. Save the necessary variables for the backward pass. @@ -28,11 +28,11 @@ def forward(ctx, x, threshold, kernel_fn, bandwith): Threshold tensor, learnable parameter. kernel_fn : callable Kernel function. - bandwith : float - Bandwith of the kernel. + bandwidth : float + Bandwidth of the kernel. """ ctx.save_for_backward(x, threshold) - ctx.bandwith = bandwith + ctx.bandwidth = bandwidth ctx.kernel_fn = kernel_fn output = x.clone() @@ -52,7 +52,7 @@ def backward(ctx, grad_output): Gradient of the loss w.r.t. the output. """ x, threshold = ctx.saved_tensors - bandwith = ctx.bandwith + bandwidth = ctx.bandwidth kernel_fn = ctx.kernel_fn # gradient w.r.t. input (normal gradient) @@ -61,11 +61,11 @@ def backward(ctx, grad_output): # pseudo-gradient w.r.t. threshold parameters delta = x - threshold - kernel_values = kernel_fn(delta, bandwith) + kernel_values = kernel_fn(delta, bandwidth) # @tfel: we have a singularity at threshold=0, thus the # re-parametrization trick in JumpSAE class - grad_threshold = - (threshold / bandwith) * kernel_values * grad_output + grad_threshold = - (threshold / bandwidth) * kernel_values * grad_output grad_threshold = grad_threshold.sum(0) return grad_input, grad_threshold, None, None @@ -81,7 +81,7 @@ class HeavisidePseudoGradient(torch.autograd.Function): The pseudo-gradient is used to approximate the gradient at the threshold. """ @staticmethod - def forward(ctx, x, threshold, kernel_fn, bandwith): + def forward(ctx, x, threshold, kernel_fn, bandwidth): """ Forward pass of the Heaviside step function. Save the necessary variables for the backward pass. @@ -94,11 +94,11 @@ def forward(ctx, x, threshold, kernel_fn, bandwith): Threshold tensor, learnable parameter. kernel_fn : callable Kernel function. - bandwith : float - Bandwith of the kernel. + bandwidth : float + Bandwidth of the kernel. """ ctx.save_for_backward(x, threshold) - ctx.bandwith = bandwith + ctx.bandwidth = bandwidth ctx.kernel_fn = kernel_fn output = (x > threshold).float() @@ -117,14 +117,14 @@ def backward(ctx, grad_output): Gradient of the loss w.r.t. the output. """ x, threshold = ctx.saved_tensors - bandwith = ctx.bandwith + bandwidth = ctx.bandwidth kernel_fn = ctx.kernel_fn delta = x - threshold - kernel_values = kernel_fn(delta, bandwith) + kernel_values = kernel_fn(delta, bandwidth) # see the paper for the formula - grad_threshold = - (1 / bandwith) * kernel_values * grad_output + grad_threshold = - (1 / bandwidth) * kernel_values * grad_output grad_threshold = grad_threshold.sum(0) grad_input = torch.zeros_like(x) @@ -132,7 +132,7 @@ def backward(ctx, grad_output): return grad_input, grad_threshold, None, None -def jump_relu(x, threshold, kernel_fn, bandwith): +def jump_relu(x, threshold, kernel_fn, bandwidth): """ Apply the JumpReLU activation function to the input tensor. @@ -144,18 +144,18 @@ def jump_relu(x, threshold, kernel_fn, bandwith): Threshold tensor, learnable parameter. kernel_fn : callable Kernel function. - bandwith : float - Bandwith of the kernel. + bandwidth : float + Bandwidth of the kernel. Returns ------- torch.Tensor Output tensor. """ - return JumpReLU.apply(x, threshold, kernel_fn, bandwith) + return JumpReLU.apply(x, threshold, kernel_fn, bandwidth) -def heaviside(x, threshold, kernel_fn, bandwith): +def heaviside(x, threshold, kernel_fn, bandwidth): """ Apply the Heaviside step function to the input tensor. @@ -167,15 +167,15 @@ def heaviside(x, threshold, kernel_fn, bandwith): Threshold tensor, learnable parameter. kernel_fn : callable Kernel function. - bandwith : float - Bandwith of the kernel. + bandwidth : float + Bandwidth of the kernel. Returns ------- torch.Tensor Output tensor. """ - return HeavisidePseudoGradient.apply(x, threshold, kernel_fn, bandwith) + return HeavisidePseudoGradient.apply(x, threshold, kernel_fn, bandwidth) class JumpSAE(SAE): @@ -222,8 +222,8 @@ class JumpSAE(SAE): - 'quartic' - 'silverman' - 'cauchy'. - bandwith : float, optional - Bandwith of the kernel, by default 1e-3. + bandwidth : float, optional + Bandwidth of the kernel, by default 1e-3. encoder_module : nn.Module or string, optional Custom encoder module, by default None. If None, a simple Linear + BatchNorm default encoder is used. @@ -257,7 +257,7 @@ class JumpSAE(SAE): 'cauchy': cauchy_kernel } - def __init__(self, input_shape, nb_concepts, kernel='silverman', bandwith=1e-3, + def __init__(self, input_shape, nb_concepts, kernel='silverman', bandwidth=1e-3, encoder_module=None, dictionary_params=None, device='cpu'): assert isinstance(encoder_module, (str, nn.Module, type(None))) assert isinstance(input_shape, (int, tuple, list)) @@ -267,7 +267,7 @@ def __init__(self, input_shape, nb_concepts, kernel='silverman', bandwith=1e-3, dictionary_params, device) self.kernel_fn = self._KERNELS[kernel] - self.bandwith = torch.tensor(bandwith, device=device) + self.bandwidth = torch.tensor(bandwidth, device=device) # exp(-3) make the thresholds start around 0.05 self.thresholds = nn.Parameter(torch.ones(nb_concepts, device=device)*(-3.0), requires_grad=True) @@ -297,7 +297,7 @@ def encode(self, x): # see paper, appendix J codes = torch.relu(pre_codes) - codes = jump_relu(codes, exp_thresholds, bandwith=self.bandwith, + codes = jump_relu(codes, exp_thresholds, bandwidth=self.bandwidth, kernel_fn=self.kernel_fn) return pre_codes, codes diff --git a/overcomplete/sae/rasae.py b/overcomplete/sae/rasae.py new file mode 100644 index 0000000..35380b9 --- /dev/null +++ b/overcomplete/sae/rasae.py @@ -0,0 +1,110 @@ +""" +Module for Relaxed Archetypal SAE implementations. +For the implementation of the Relaxed Archetypal Dictionary, see archetypal_dictionary.py. +""" + +import torch +from torch import nn + +from .topk_sae import TopKSAE +from .jump_sae import JumpSAE +from .archetypal_dictionary import RelaxedArchetypalDictionary + + +class RATopKSAE(TopKSAE): + """ + Relaxed Archetypal TopK SAE. + + This class implements a TopK SAE that utilizes a Relaxed Archetypal Dictionary. + The dictionary atoms are initialized and constrained to be convex combinations + of data points. + + For more information, see: + - "Archetypal SAE: Adaptive and Stable Dictionary Learning for Concept Extraction in + Large Vision Models" by T. Fel et al., ICML 2025 (https://arxiv.org/abs/2502.12892). + + Parameters + ---------- + input_shape : int + Dimensionality of the input data (excluding the batch dimension). + nb_concepts : int + Number of dictionary atoms (concepts). + points : torch.Tensor + The data points used to initialize/define the archetypes. + Shape should be (num_points, input_shape). + top_k : int + Number of top activations to keep in the latent representation. + By default, 10% sparsity is used. + delta : float, optional + Delta parameter for the archetypal dictionary, by default 1.0. + use_multiplier : bool, optional + Whether to use a learnable multiplier that parametrize the ball (e.g. if this parameter + is 3 then the dictionary atoms are all on the ball of radius 3). By default True. + **kwargs : dict, optional + Additional arguments passed to the parent TopKSAE (e.g., encoder_module, device). + """ + + def __init__(self, input_shape, nb_concepts, points, top_k=None, delta=1.0, use_multiplier=True, **kwargs): + assert isinstance(input_shape, int), "RATopKSAE input_shape must be an integer." + + super().__init__(input_shape=input_shape, nb_concepts=nb_concepts, + top_k=top_k, **kwargs) + + # enforce archetypal dictionary after the init of the parent class + self.dictionary = RelaxedArchetypalDictionary( + in_dimensions=input_shape, + nb_concepts=nb_concepts, + points=points, + delta=delta, + use_multiplier=use_multiplier, + device=self.device + ) + + +class RAJumpSAE(JumpSAE): + """ + Relaxed Archetypal Jump SAE. + + This class implements a Jump SAE that utilizes a Relaxed Archetypal Dictionary. + The dictionary atoms are initialized and constrained to be convex combinations + of data points. + + For more information, see: + - "Archetypal SAE: Adaptive and Stable Dictionary Learning for Concept Extraction in + Large Vision Models" by T. Fel et al., ICML 2025 (https://arxiv.org/abs/2502.12892). + + Parameters + ---------- + input_shape : int + Dimensionality of the input data (excluding the batch dimension). + nb_concepts : int + Number of dictionary atoms (concepts). + points : torch.Tensor + The data points used to initialize/define the archetypes. + Shape should be (num_points, input_shape). + bandwidth : float, optional + Bandwidth parameter for the Jump SAE kernel, by default 1e-3. + delta : float, optional + Delta parameter for the archetypal dictionary, by default 1.0. + use_multiplier : bool, optional + Whether to use a learnable multiplier that parametrize the ball (e.g. if this parameter + is 3 then the dictionary atoms are all on the ball of radius 3). By default True. + **kwargs : dict, optional + Additional arguments passed to the parent JumpSAE (e.g., encoder_module, device). + """ + + def __init__(self, input_shape, nb_concepts, points, bandwidth=1e-3, delta=1.0, use_multiplier=True, **kwargs): + assert isinstance(input_shape, int), "RAJumpSAE input_shape must be an integer." + + super().__init__(input_shape=input_shape, nb_concepts=nb_concepts, + bandwidth=bandwidth, **kwargs) + + # enforce archetypal dictionary after the init of the parent class + self.dictionary = RelaxedArchetypalDictionary( + in_dimensions=input_shape, + nb_concepts=nb_concepts, + points=points, + delta=delta, + use_multiplier=use_multiplier, + device=self.device + ) diff --git a/tests/sae/test_base_sae.py b/tests/sae/test_base_sae.py index fc7482c..4089053 100644 --- a/tests/sae/test_base_sae.py +++ b/tests/sae/test_base_sae.py @@ -1,12 +1,21 @@ import pytest import torch -from overcomplete.sae import SAE, DictionaryLayer, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE +from overcomplete.sae import SAE, DictionaryLayer, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE from overcomplete.sae.modules import TieableEncoder from ..utils import epsilon_equal -all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE] +all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE] + + +def get_sae_kwargs(sae_class, input_size, nb_concepts, device): + """Return specific kwargs required for certain SAE classes.""" + kwargs = {} + # archetypal SAEs require 'points' + if sae_class in [RATopKSAE, RAJumpSAE]: + kwargs['points'] = torch.randn(nb_concepts * 2, input_size, device=device) + return kwargs def test_dictionary_layer(): @@ -23,7 +32,9 @@ def test_dictionary_layer(): def test_sae(sae_class): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) x = torch.randn(3, input_size) output = model(x) @@ -43,13 +54,15 @@ def test_sae_device(sae_class): input_size = 10 nb_components = 5 - model = sae_class(input_size, nb_components, device='meta') + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_components, device='meta') + model = sae_class(input_size, nb_components, device='meta', **extra_kwargs) # ensure dictionary is on the meta device dictionary = model.get_dictionary() assert dictionary.device.type == 'meta' - model = sae_class(input_size, nb_components, device='cpu') + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_components, device='cpu') + model = sae_class(input_size, nb_components, device='cpu', **extra_kwargs) # ensure dictionary is on the cpu device dictionary = model.get_dictionary() @@ -111,7 +124,8 @@ def test_sae_tied_untied(sae_class): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) # Tie weights model.tied() @@ -130,7 +144,8 @@ def test_sae_tied_forward(sae_class): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) model.tied() x = torch.randn(3, input_size) @@ -146,7 +161,8 @@ def test_sae_untied_copy_weights(sae_class): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) model.tied() # Get dictionary before untying diff --git a/tests/sae/test_ra_sae.py b/tests/sae/test_ra_sae.py new file mode 100644 index 0000000..4d2d263 --- /dev/null +++ b/tests/sae/test_ra_sae.py @@ -0,0 +1,55 @@ +import pytest +import torch + +from overcomplete.sae import RATopKSAE, RAJumpSAE + +INPUT_SIZE = 128 +NB_CONCEPTS = 20 +POINTS_COUNT = 100 +BATCH_SIZE = 32 +TOP_K = 5 +BANDWIDTH = 0.001 + + +@pytest.mark.parametrize("device", ['cpu']) +@pytest.mark.parametrize("ra_class", [RATopKSAE, RAJumpSAE]) +def test_ra_sae_device_propagation(device, ra_class): + # points need to be on the same device as the module for initialization + points = torch.randn(POINTS_COUNT, INPUT_SIZE, device=device) + + if ra_class == RATopKSAE: + model = ra_class(INPUT_SIZE, NB_CONCEPTS, points=points, top_k=TOP_K, device=device) + else: + model = ra_class(INPUT_SIZE, NB_CONCEPTS, points=points, bandwidth=BANDWIDTH, device=device) + + # check encoder parameters + for param in model.encoder.parameters(): + assert param.device.type == device + + # check dictionary parameters + for param in model.dictionary.parameters(): + assert param.device.type == device + + # check all parameters + for param in model.parameters(): + assert param.device.type == device + + +@pytest.mark.parametrize("ra_class", [RATopKSAE, RAJumpSAE]) +def test_ra_sae_forward_shape(ra_class): + # run forward pass on cpu to ensure shape correctness + device = 'cpu' + points = torch.randn(POINTS_COUNT, INPUT_SIZE, device=device) + input_data = torch.randn(BATCH_SIZE, INPUT_SIZE, device=device) + + if ra_class == RATopKSAE: + model = ra_class(INPUT_SIZE, NB_CONCEPTS, points=points, top_k=TOP_K, device=device) + else: + model = ra_class(INPUT_SIZE, NB_CONCEPTS, points=points, bandwidth=BANDWIDTH, device=device) + + z_pre, z, x_hat = model(input_data) + + # check output shapes + assert x_hat.shape == input_data.shape + assert z.shape == (BATCH_SIZE, NB_CONCEPTS) + assert z_pre.shape == (BATCH_SIZE, NB_CONCEPTS) diff --git a/tests/sae/test_sae_cuda.py b/tests/sae/test_sae_cuda.py index eb295ad..e262f64 100644 --- a/tests/sae/test_sae_cuda.py +++ b/tests/sae/test_sae_cuda.py @@ -1,7 +1,8 @@ import pytest - +import torch from overcomplete.sae import (MLPEncoder, AttentionEncoder, ResNetEncoder, - EncoderFactory, SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE) + EncoderFactory, SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, + RATopKSAE, RAJumpSAE) INPUT_SIZE = 32 @@ -11,7 +12,16 @@ HEIGHT = 7 WIDTH = 7 -all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE] +all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE] + + +def get_sae_kwargs(sae_class, input_size, nb_concepts, device): + """Return specific kwargs required for certain SAE classes.""" + kwargs = {} + # archetypal SAEs require 'points' + if sae_class in [RATopKSAE, RAJumpSAE]: + kwargs['points'] = torch.randn(nb_concepts * 2, input_size, device=device) + return kwargs @pytest.mark.parametrize("device", ['cpu', 'meta']) @@ -41,7 +51,8 @@ def test_resnet_encoder_device_propagation(device): @pytest.mark.parametrize("device, ", ['cpu', 'meta']) @pytest.mark.parametrize("sae_class", all_sae) def test_default_sae_device_propagation(device, sae_class): - model = sae_class(32, 5, encoder_module=None, device=device) + extra_kwargs = get_sae_kwargs(sae_class, INPUT_SIZE, N_COMPONENTS, device=device) + model = sae_class(32, 5, encoder_module=None, device=device, **extra_kwargs) for param in model.encoder.parameters(): assert param.device.type == device diff --git a/tests/sae/test_sae_dictionary.py b/tests/sae/test_sae_dictionary.py index bc2d8f3..66e5649 100644 --- a/tests/sae/test_sae_dictionary.py +++ b/tests/sae/test_sae_dictionary.py @@ -1,7 +1,7 @@ import torch import pytest -from overcomplete.sae import DictionaryLayer, SAE, QSAE, TopKSAE, JumpSAE, BatchTopKSAE, MpSAE, OMPSAE +from overcomplete.sae import DictionaryLayer, SAE, QSAE, TopKSAE, JumpSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE from overcomplete.sae.modules import TieableEncoder from ..utils import epsilon_equal diff --git a/tests/sae/test_save_and_load.py b/tests/sae/test_save_and_load.py index f61c21d..5ae9219 100644 --- a/tests/sae/test_save_and_load.py +++ b/tests/sae/test_save_and_load.py @@ -2,18 +2,27 @@ import pytest import torch -from overcomplete.sae import SAE, DictionaryLayer, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE +from overcomplete.sae import SAE, DictionaryLayer, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE from overcomplete.sae.modules import TieableEncoder from ..utils import epsilon_equal -all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE] +all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE] def _load(path): return torch.load(path, map_location="cpu", weights_only=False) +def get_sae_kwargs(sae_class, input_size, nb_concepts, device): + """Return specific kwargs required for certain SAE classes.""" + kwargs = {} + # archetypal SAEs require 'points' + if sae_class in [RATopKSAE, RAJumpSAE]: + kwargs['points'] = torch.randn(nb_concepts * 2, input_size, device=device) + return kwargs + + @pytest.mark.parametrize("nb_concepts, dimensions", [(5, 10)]) def test_save_and_load_dictionary_layer(nb_concepts, dimensions, tmp_path): # Initialize and run layer @@ -41,7 +50,9 @@ def test_save_and_load_dictionary_layer(nb_concepts, dimensions, tmp_path): def test_save_and_load_sae(sae_class, tmp_path): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) x = torch.randn(3, input_size) output = model(x) @@ -69,7 +80,9 @@ def test_save_and_load_sae(sae_class, tmp_path): def test_eval_and_save_sae(sae_class, tmp_path): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) x = torch.randn(3, input_size) output = model(x) @@ -99,7 +112,8 @@ def test_save_and_load_tied_sae(sae_class, tmp_path): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) model.tied() x = torch.randn(3, input_size) @@ -128,7 +142,9 @@ def test_save_and_load_untied_with_copy(sae_class, tmp_path): input_size = 10 nb_concepts = 5 - model = sae_class(input_size, nb_concepts) + extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu') + model = sae_class(input_size, nb_concepts, **extra_kwargs) + model.tied() model.untied(copy_from_dictionary=True) diff --git a/tests/sae/test_train_sae.py b/tests/sae/test_train_sae.py index 4359aff..bdfbf88 100644 --- a/tests/sae/test_train_sae.py +++ b/tests/sae/test_train_sae.py @@ -8,15 +8,24 @@ from overcomplete.sae.train import train_sae, train_sae_amp from overcomplete.sae.losses import mse_l1 -from overcomplete.sae import SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE +from overcomplete.sae import SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE from overcomplete.sae.modules import TieableEncoder from ..utils import epsilon_equal -all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE] +all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE] saes_attention_conv_format = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE] +def get_sae_kwargs(sae_class, input_size, nb_concepts, device): + """Return specific kwargs required for certain SAE classes.""" + kwargs = {} + # archetypal SAEs require 'points' + if sae_class in [RATopKSAE, RAJumpSAE]: + kwargs['points'] = torch.randn(nb_concepts * 2, input_size, device=device) + return kwargs + + @pytest.mark.parametrize( "module_name", [ @@ -38,7 +47,8 @@ def test_train_mlp_sae(module_name, sae_class): criterion = mse_l1 n_components = 2 - model = sae_class(data.shape[1], n_components, encoder_module=module_name) + extra_kwargs = get_sae_kwargs(sae_class, data.shape[1], n_components, device='cpu') + model = sae_class(data.shape[1], n_components, encoder_module=module_name, **extra_kwargs) optimizer = optim.SGD(model.parameters(), lr=0.001) scheduler = None @@ -90,7 +100,8 @@ def criterion(x, x_hat, z_pre, z, dictionary): dataloader = DataLoader(dataset, batch_size=10) n_components = 2 - model = sae_class(data.shape[1:], n_components, encoder_module="resnet_3b") + extra_kwargs = get_sae_kwargs(sae_class, data.shape[1:], n_components, device='cpu') + model = sae_class(data.shape[1:], n_components, encoder_module="resnet_3b", **extra_kwargs) optimizer = optim.SGD(model.parameters(), lr=0.001) scheduler = None @@ -121,7 +132,8 @@ def criterion(x, x_hat, z_pre, z, dictionary): dataloader = DataLoader(dataset, batch_size=10) n_components = 2 - model = sae_class(data.shape[1:], n_components, encoder_module="attention_3b") + extra_kwargs = get_sae_kwargs(sae_class, data.shape[1:], n_components, device='cpu') + model = sae_class(data.shape[1:], n_components, encoder_module="attention_3b", **extra_kwargs) optimizer = optim.SGD(model.parameters(), lr=0.001) scheduler = None @@ -159,7 +171,8 @@ def test_train_without_amp(module_name, sae_class): criterion = mse_l1 n_components = 2 - model = sae_class(data.shape[1], n_components, encoder_module=module_name) + extra_kwargs = get_sae_kwargs(sae_class, data.shape[1], n_components, device='cpu') + model = sae_class(data.shape[1], n_components, encoder_module=module_name, **extra_kwargs) optimizer = optim.SGD(model.parameters(), lr=0.001) scheduler = None @@ -333,7 +346,8 @@ def test_train_tied_sae(sae_class): criterion = mse_l1 n_components = 2 - model = sae_class(data.shape[1], n_components) + extra_kwargs = get_sae_kwargs(sae_class, data.shape[1], n_components, device='cpu') + model = sae_class(data.shape[1], n_components, **extra_kwargs) model.tied() optimizer = optim.SGD(model.parameters(), lr=0.001) @@ -361,7 +375,8 @@ def test_train_untied_after_tied(sae_class): criterion = mse_l1 n_components = 2 - model = sae_class(data.shape[1], n_components) + extra_kwargs = get_sae_kwargs(sae_class, data.shape[1], n_components, device='cpu') + model = sae_class(data.shape[1], n_components, **extra_kwargs) model.tied() # Train tied