|
1 | 1 | # pipeline.py |
2 | | - |
3 | 2 | from collections.abc import Callable |
4 | 3 | from collections.abc import Iterable |
5 | 4 | from collections.abc import Iterator |
| 5 | +from concurrent.futures import ThreadPoolExecutor |
| 6 | +from concurrent.futures import as_completed |
6 | 7 | import itertools |
7 | 8 | import multiprocessing as mp |
| 9 | +from queue import Queue |
8 | 10 | from typing import Any |
9 | 11 | from typing import TypeVar |
10 | 12 | from typing import overload |
11 | 13 |
|
12 | 14 | from laygo.helpers import PipelineContext |
13 | 15 | from laygo.helpers import is_context_aware |
14 | | -from laygo.transformers.threaded import ThreadedTransformer |
15 | 16 | from laygo.transformers.transformer import Transformer |
| 17 | +from laygo.transformers.transformer import passthrough_chunks |
16 | 18 |
|
17 | 19 | T = TypeVar("T") |
| 20 | +U = TypeVar("U") |
18 | 21 | PipelineFunction = Callable[[T], Any] |
19 | 22 |
|
20 | 23 |
|
@@ -147,16 +150,109 @@ def apply[U]( |
147 | 150 |
|
148 | 151 | return self # type: ignore |
149 | 152 |
|
150 | | - def buffer(self, size: int) -> "Pipeline[T]": |
151 | | - """Buffer the pipeline using threaded processing. |
| 153 | + def branch( |
| 154 | + self, |
| 155 | + branches: dict[str, Transformer[T, Any]], |
| 156 | + batch_size: int = 1000, |
| 157 | + max_batch_buffer: int = 1, |
| 158 | + use_queue_chunks: bool = True, |
| 159 | + ) -> dict[str, list[Any]]: |
| 160 | + """Forks the pipeline into multiple branches for concurrent, parallel processing.""" |
| 161 | + if not branches: |
| 162 | + self.consume() |
| 163 | + return {} |
| 164 | + |
| 165 | + source_iterator = self.processed_data |
| 166 | + branch_items = list(branches.items()) |
| 167 | + num_branches = len(branch_items) |
| 168 | + final_results: dict[str, list[Any]] = {} |
| 169 | + |
| 170 | + queues = [Queue(maxsize=max_batch_buffer) for _ in range(num_branches)] |
| 171 | + |
| 172 | + def producer() -> None: |
| 173 | + """Reads from the source and distributes batches to ALL branch queues.""" |
| 174 | + # Use itertools.batched for clean and efficient batch creation. |
| 175 | + for batch_tuple in itertools.batched(source_iterator, batch_size): |
| 176 | + # The batch is a tuple; convert to a list for consumers. |
| 177 | + batch_list = list(batch_tuple) |
| 178 | + for q in queues: |
| 179 | + q.put(batch_list) |
| 180 | + |
| 181 | + # Signal to all consumers that the stream is finished. |
| 182 | + for q in queues: |
| 183 | + q.put(None) |
| 184 | + |
| 185 | + def consumer(transformer: Transformer, queue: Queue) -> list[Any]: |
| 186 | + """Consumes batches from a queue and runs them through a transformer.""" |
| 187 | + |
| 188 | + def stream_from_queue() -> Iterator[T]: |
| 189 | + while (batch := queue.get()) is not None: |
| 190 | + yield batch |
| 191 | + |
| 192 | + if use_queue_chunks: |
| 193 | + transformer = transformer.set_chunker(passthrough_chunks) |
| 194 | + |
| 195 | + result_iterator = transformer(stream_from_queue(), self.ctx) # type: ignore |
| 196 | + return list(result_iterator) |
| 197 | + |
| 198 | + with ThreadPoolExecutor(max_workers=num_branches + 1) as executor: |
| 199 | + executor.submit(producer) |
| 200 | + |
| 201 | + future_to_name = { |
| 202 | + executor.submit(consumer, transformer, queues[i]): name for i, (name, transformer) in enumerate(branch_items) |
| 203 | + } |
| 204 | + |
| 205 | + for future in as_completed(future_to_name): |
| 206 | + name = future_to_name[future] |
| 207 | + try: |
| 208 | + final_results[name] = future.result() |
| 209 | + except Exception as e: |
| 210 | + print(f"Branch '{name}' raised an exception: {e}") |
| 211 | + final_results[name] = [] |
| 212 | + |
| 213 | + return final_results |
| 214 | + |
| 215 | + def buffer(self, size: int, batch_size: int = 1000) -> "Pipeline[T]": |
| 216 | + """Inserts a buffer in the pipeline to allow downstream processing to read ahead. |
| 217 | +
|
| 218 | + This creates a background thread that reads from the upstream data source |
| 219 | + and fills a queue, decoupling the upstream and downstream stages. |
152 | 220 |
|
153 | 221 | Args: |
154 | | - size: The number of worker threads to use for buffering. |
| 222 | + size: The number of **batches** to hold in the buffer. |
| 223 | + batch_size: The number of items to accumulate per batch. |
155 | 224 |
|
156 | 225 | Returns: |
157 | 226 | The pipeline instance for method chaining. |
158 | 227 | """ |
159 | | - self.apply(ThreadedTransformer(max_workers=size)) |
| 228 | + source_iterator = self.processed_data |
| 229 | + |
| 230 | + def _buffered_stream() -> Iterator[T]: |
| 231 | + queue = Queue(maxsize=size) |
| 232 | + # We only need one background thread for the producer. |
| 233 | + executor = ThreadPoolExecutor(max_workers=1) |
| 234 | + |
| 235 | + def _producer() -> None: |
| 236 | + """The producer reads from the source and fills the queue.""" |
| 237 | + try: |
| 238 | + for batch_tuple in itertools.batched(source_iterator, batch_size): |
| 239 | + queue.put(list(batch_tuple)) |
| 240 | + finally: |
| 241 | + # Always put the sentinel value to signal the end of the stream. |
| 242 | + queue.put(None) |
| 243 | + |
| 244 | + # Start the producer in the background thread. |
| 245 | + executor.submit(_producer) |
| 246 | + |
| 247 | + try: |
| 248 | + # The main thread becomes the consumer. |
| 249 | + while (batch := queue.get()) is not None: |
| 250 | + yield from batch |
| 251 | + finally: |
| 252 | + # Ensure the background thread is cleaned up. |
| 253 | + executor.shutdown(wait=False, cancel_futures=True) |
| 254 | + |
| 255 | + self.processed_data = _buffered_stream() |
160 | 256 | return self |
161 | 257 |
|
162 | 258 | def __iter__(self) -> Iterator[T]: |
|
0 commit comments