diff --git a/README.md b/README.md
index 3a3cf0d..ee99949 100644
--- a/README.md
+++ b/README.md
@@ -118,3 +118,4 @@ Scikit-Mol has been developed as a community effort with contributions from peop
 - [@mikemhenry](https://github.com/mikemhenry)
 - [@c-feldmann](https://github.com/c-feldmann)
 - Mieczyslaw Torchala [@mieczyslaw](https://github.com/mieczyslaw)
+- Kyle Barbary [@kbarbary](https://github.com/kbarbary)
diff --git a/docs/index.md b/docs/index.md
index 3a3cf0d..ee99949 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -118,3 +118,4 @@ Scikit-Mol has been developed as a community effort with contributions from peop
 - [@mikemhenry](https://github.com/mikemhenry)
 - [@c-feldmann](https://github.com/c-feldmann)
 - Mieczyslaw Torchala [@mieczyslaw](https://github.com/mieczyslaw)
+- Kyle Barbary [@kbarbary](https://github.com/kbarbary)
diff --git a/scikit_mol/parallel.py b/scikit_mol/parallel.py
index b1d642e..75bc234 100644
--- a/scikit_mol/parallel.py
+++ b/scikit_mol/parallel.py
@@ -1,13 +1,14 @@
 from collections.abc import Sequence
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
 
 import numpy as np
+import pandas as pd
 from joblib import Parallel, delayed, effective_n_jobs
 
 
 def parallelized_with_batches(
     fn: Callable[[Sequence[Any]], Any],
-    inputs_list: np.ndarray,
+    inputs_list: Union[np.ndarray, pd.DataFrame],
     n_jobs: Optional[int] = None,
     **job_kwargs: Any,
 ) -> Sequence[Optional[Any]]:
@@ -33,6 +34,11 @@ def parallelized_with_batches(
         n_jobs = len(inputs_list)
 
     pool = Parallel(n_jobs=n_jobs, **job_kwargs)
-    input_chunks = np.array_split(inputs_list, n_jobs)
+    if isinstance(inputs_list, (pd.DataFrame, pd.Series)):
+        indexes = np.array_split(range(len(inputs_list)), n_jobs)
+        input_chunks = [inputs_list.iloc[idx] for idx in indexes]
+    else:
+        input_chunks = np.array_split(inputs_list, n_jobs)
+
     results = pool(delayed(fn)(chunk) for chunk in input_chunks)
     return results
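
The parallel.py change splits pandas inputs by integer position and slices them with `.iloc`, so each chunk handed to a worker remains a DataFrame/Series with its original index, while array-like inputs keep going through `np.array_split` as before. Below is a minimal, self-contained sketch of that chunking behavior; the batch function `count_rows`, the example data, and the `n_jobs=3` setting are illustrative and not part of the PR.

```python
import numpy as np
import pandas as pd
from joblib import Parallel, delayed


def count_rows(chunk: pd.DataFrame) -> int:
    # Stand-in batch function; a real caller would e.g. compute descriptors per chunk.
    return len(chunk)


df = pd.DataFrame({"smiles": ["CCO", "c1ccccc1", "CC(=O)O", "CCN", "O"]})
n_jobs = 3

# Mirrors the patched branch: split row positions, then slice with .iloc so each
# chunk stays a pandas object (with its index) instead of relying on NumPy to
# split the frame directly.
indexes = np.array_split(range(len(df)), n_jobs)
input_chunks = [df.iloc[idx] for idx in indexes]

results = Parallel(n_jobs=n_jobs)(delayed(count_rows)(chunk) for chunk in input_chunks)
print(results)  # e.g. [2, 2, 1]
```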