style: add type hinting to variational distributions #261
leahaeusel wants to merge 2 commits into queens-py:main
Pull request overview
This PR adds comprehensive type hinting to the variational distributions module in QUEENS, introducing two new type aliases (Array1xN and ArrayNxM) for better shape indication in numpy arrays. The PR also refactors FullRankNormal.reconstruct_distribution_parameters() to satisfy mypy's requirement for consistent return types by creating a separate method for cases where the Cholesky decomposition is needed, and enables mypy checking for the variational distributions module by removing it from exclusion lists.
Key Changes
- Introduced new type aliases `Array1xN` and `ArrayNxM` in `src/queens/utils/type_hinting.py` to indicate array shapes
- Added type hints throughout the variational distributions hierarchy, including the base class and all concrete implementations
- Split `FullRankNormal.reconstruct_distribution_parameters()` into two methods to resolve mypy's varying return count issue
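The split described in the last bullet can be illustrated with a minimal sketch. It is not the code from this PR: the function names, the `dimension` argument, and the parameter layout (mean followed by the lower-triangular Cholesky entries) are assumptions made for the example. The point is only that each function has one fixed return arity, so each can carry a single precise return annotation.

```python
import numpy as np


def _unpack(variational_parameters: np.ndarray, dimension: int) -> tuple[np.ndarray, np.ndarray]:
    """Split a flat parameter vector into mean and Cholesky factor (assumed layout)."""
    mean = variational_parameters[:dimension]
    cholesky = np.zeros((dimension, dimension))
    # Fill the lower triangle row by row with the remaining parameters.
    cholesky[np.tril_indices(dimension)] = variational_parameters[dimension:]
    return mean, cholesky


def reconstruct_mean_and_covariance(
    variational_parameters: np.ndarray, dimension: int
) -> tuple[np.ndarray, np.ndarray]:
    """Always returns exactly two arrays: mean and covariance."""
    mean, cholesky = _unpack(variational_parameters, dimension)
    return mean, cholesky @ cholesky.T


def reconstruct_mean_covariance_and_cholesky(
    variational_parameters: np.ndarray, dimension: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Always returns exactly three arrays: mean, covariance, Cholesky factor."""
    mean, cholesky = _unpack(variational_parameters, dimension)
    return mean, cholesky @ cholesky.T, cholesky
```

A single function with a `return_cholesky` flag would instead need a union of two tuple types (or overloads) as its return annotation, which callers then have to narrow before unpacking.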
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 27 comments.
| File | Description |
|---|---|
| `src/queens/utils/type_hinting.py` | New file introducing `Array1xN` and `ArrayNxM` type aliases for shaped numpy arrays |
| `src/queens/variational_distributions/_variational_distribution.py` | Added type hints to base class including `Array1xN` for variational parameters and proper return types |
| `src/queens/variational_distributions/particle.py` | Added type hints to all methods, updated docstrings, and modified initialization to pass `n_parameters` |
| `src/queens/variational_distributions/mixture_model.py` | Added type hints, improved variable naming (`logpdf_lst`, `parameters`), and initialized FIM properly |
| `src/queens/variational_distributions/mean_field_normal.py` | Added type hints to all methods and updated docstrings for consistency |
| `src/queens/variational_distributions/full_rank_normal.py` | Added type hints, split `reconstruct_distribution_parameters()` into two methods for mypy compatibility |
| `src/queens/variational_distributions/joint.py` | Added type hints, fixed `logpdf` initialization to use proper numpy array instead of scalar |
| `src/queens/variational_distributions/__init__.py` | Added type hints to `__getattr__` function and imported base class for type checking |
| `pyproject.toml` | Removed `variational_distributions` from mypy exclusion list to enable type checking |
| `.pre-commit-config.yaml` | Removed `variational_distributions` from mypy exclusion pattern in pre-commit hooks |
Force-pushed from b16cf83 to b69c2b1
Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.
style: add type hinting to variational distributions
Force-pushed from b69c2b1 to 8dec669
gilrrei left a comment
I think the return statements could be more precise :)
Force-pushed from d161ffc to 87ad5eb
Copilot reviewed 9 out of 9 changed files in this pull request and generated 23 comments.
Force-pushed from 87ad5eb to 79692a8
danielwolff1 left a comment
Thanks a lot for providing the additional type hints in the variational distribution classes. That must have been quite some effort, and sorry for taking so long to review it. I left a few comments to mark points I'm unsure about because I'm not too familiar with variational distributions 🙈
    NDimsComponent = TypeVar("NDimsComponent", bound=int)
    V = TypeVar("V", bound=Variational)
    ArrayNSamplesXNDimsComponent: TypeAlias = np.ndarray[  # pylint: disable=invalid-name
        tuple[NSamples, NDimsComponent], np.dtype[np.floating]
    ]
Since I am not familiar with the math here: is this specific to a joint distribution, or would it make sense to move it to `_variational_distribution` as well?
    within one sample. (Third dimension is empty and just added to keep slices two
    dimensional.)
Is the remark here about the third dimension still true? 🤔
Just wondering, because then the type annotation would technically not be correct: we define `ArrayNSamplesXNDims: TypeAlias = np.ndarray[tuple[NSamples, NDims], np.dtype[np.floating]]`, i.e., an array with exactly two dimensions.
    Gradients of the log-pdf w.r.t. the sample *x*. The first dimension of the array
    corresponds to the different samples. The second dimension to different dimensions
    within one sample. (Third dimension is empty and just added to keep slices
    two-dimensional.)
Same question as below / above (I never know in which order GitHub will display my comments 🙈): Is the remark here about the third dimension still true?
        Variational,
    )

    NComponents = TypeVar("NComponents", bound=int)
Is this fundamentally different in meaning from the `NDimsComponent` type variable defined in `joint.py`? Or could they be unified and moved to `_variational_distribution.py`?
Description and Context:
What and Why?
This PR adds type hinting to our variational distributions.
I have introduced new type aliases to indicate the shapes of numpy arrays, since the docstrings provided this information for some of the variables. In some cases, the shapes were outdated, though, as I realized when I checked the shapes with our tests. In my opinion, it would be really helpful to eventually replace all or most of the `np.ndarray`s in our code base with these shape-indicating aliases.
FYI: I had to adapt `FullRankNormal.reconstruct_distribution_parameters()` because mypy doesn't like functions with a varying number of returned variables.
Related Issues and Pull Requests
Interested Parties
@gilrrei