Improve legacy model testing #293

@stes

Description

Right now, we have the following test:

import pathlib
import urllib.request

import numpy as np
import pytest

from cebra import CEBRA


@pytest.mark.parametrize("model_variant", MODEL_VARIANTS)
def test_load_legacy_model(model_variant):
    """Test loading a legacy CEBRA model."""

    X = np.random.normal(0, 1, (1000, 30))

    model_path = pathlib.Path(
        __file__
    ).parent / "_build_legacy_model" / f"cebra_model_{model_variant}.pt"

    if not model_path.exists():
        url = f"https://cebra.fra1.digitaloceanspaces.com/cebra_model_{model_variant}.pt"
        model_path.parent.mkdir(parents=True, exist_ok=True)
        urllib.request.urlretrieve(url, model_path)

    loaded_model = CEBRA.load(model_path)

    assert loaded_model.model_architecture == "offset10-model"
    assert loaded_model.output_dimension == 8
    assert loaded_model.num_hidden_units == 16
    assert loaded_model.time_offsets == 10

    output = loaded_model.transform(X)
    assert isinstance(output, np.ndarray)
    assert output.shape[1] == loaded_model.output_dimension

    assert hasattr(loaded_model, "state_dict_")
    assert hasattr(loaded_model, "n_features_")

This test checks that the models can be loaded, but not that they produce the same output. To improve this test, let's

  • compute reference outputs of the legacy models and store them in the S3 bucket
  • adapt the test to include assert_close checks between the re-computed and the original model embeddings
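The two steps above could be sketched as follows. This is a hypothetical sketch, not the final implementation: the helper names, the `.npy` reference format, and the tolerance values are assumptions, and a synthetic array stands in for `loaded_model.transform(X)` so the sketch runs without CEBRA installed.

```python
# Hypothetical sketch of the proposed reference-embedding check.
# The helper names, file format, and tolerances are assumptions; the
# synthetic array below stands in for `loaded_model.transform(X)`.
import pathlib
import tempfile

import numpy as np


def save_reference_embedding(embedding: np.ndarray, path: pathlib.Path) -> None:
    """Store a reference embedding to disk (this file would then be
    uploaded to the S3 bucket next to the legacy model checkpoint)."""
    np.save(path, embedding)


def assert_embedding_close(computed: np.ndarray,
                           reference: np.ndarray,
                           rtol: float = 1e-5,
                           atol: float = 1e-8) -> None:
    """Fail if the re-computed embedding drifts from the stored reference."""
    assert computed.shape == reference.shape, (
        f"shape mismatch: {computed.shape} vs {reference.shape}")
    np.testing.assert_allclose(computed, reference, rtol=rtol, atol=atol)


# Demo with a synthetic embedding standing in for the model output:
rng = np.random.default_rng(0)
embedding = rng.normal(0, 1, (1000, 8))

with tempfile.TemporaryDirectory() as tmp:
    ref_path = pathlib.Path(tmp) / "cebra_embedding_demo.npy"
    save_reference_embedding(embedding, ref_path)
    reference = np.load(ref_path)
    assert_embedding_close(embedding, reference)  # identical arrays pass
```

In the adapted test, the reference file would be downloaded the same way the checkpoint is (mirroring the existing `urlretrieve` call) and compared against the freshly computed `loaded_model.transform(X)`; `torch.testing.assert_close` could be used instead of `np.testing.assert_allclose` if the embeddings are kept as tensors.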
