Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,4 @@ Scikit-Mol has been developed as a community effort with contributions from peop
- [@mikemhenry](https://github.com/mikemhenry)
- [@c-feldmann](https://github.com/c-feldmann)
- Mieczyslaw Torchala [@mieczyslaw](https://github.com/mieczyslaw)
- Kyle Barbary [@kbarbary](https://github.com/kbarbary)
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,4 @@ Scikit-Mol has been developed as a community effort with contributions from peop
- [@mikemhenry](https://github.com/mikemhenry)
- [@c-feldmann](https://github.com/c-feldmann)
- Mieczyslaw Torchala [@mieczyslaw](https://github.com/mieczyslaw)
- Kyle Barbary [@kbarbary](https://github.com/kbarbary)
12 changes: 9 additions & 3 deletions scikit_mol/parallel.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from collections.abc import Sequence
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union

import numpy as np
import pandas as pd
from joblib import Parallel, delayed, effective_n_jobs


def parallelized_with_batches(
fn: Callable[[Sequence[Any]], Any],
inputs_list: np.ndarray,
inputs_list: Union[np.ndarray, pd.DataFrame],
n_jobs: Optional[int] = None,
**job_kwargs: Any,
) -> Sequence[Optional[Any]]:
Expand All @@ -33,6 +34,11 @@ def parallelized_with_batches(
n_jobs = len(inputs_list)
pool = Parallel(n_jobs=n_jobs, **job_kwargs)

input_chunks = np.array_split(inputs_list, n_jobs)
if isinstance(inputs_list, (pd.DataFrame, pd.Series)):
indexes = np.array_split(range(len(inputs_list)), n_jobs)
input_chunks = [inputs_list.iloc[idx] for idx in indexes]
else:
input_chunks = np.array_split(inputs_list, n_jobs)

results = pool(delayed(fn)(chunk) for chunk in input_chunks)
return results