Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,19 @@ def on_partition_generation_completed(
if status_message:
yield status_message

def on_partition(self, partition: Partition) -> Iterable[AirbyteMessage]:
    """
    Handle a newly generated partition.

    1. Track the partition as running for its stream.
    2. If slice logging is enabled, yield the slice log message directly to the
       caller instead of emitting it through the shared message repository —
       yielding avoids a queue.put() on a potentially full shared queue, which
       could deadlock the main thread.
    3. Submit the partition to the thread pool manager for processing.
    """
    stream_name = partition.stream_name()
    self._streams_to_running_partitions[stream_name].add(partition)
    cursor = self._stream_name_to_instance[stream_name].cursor
    if self._slice_logger.should_log_slice_message(self._logger):
        yield self._slice_logger.create_slice_log_message(partition.to_slice())
    self._thread_pool_manager.submit(
        self._partition_reader.process_partition, partition, cursor
    )
Expand Down Expand Up @@ -426,7 +425,7 @@ def _on_stream_is_done(self, stream_name: str) -> Iterable[AirbyteMessage]:
)
self._logger.info(f"Marking stream {stream_name} as STOPPED")
stream = self._stream_name_to_instance[stream_name]
stream.cursor.ensure_at_least_one_state_emitted()
yield from stream.cursor.ensure_at_least_one_state_emitted()
yield from self._message_repository.consume_queue()
self._logger.info(f"Finished syncing {stream.name}")
self._streams_done.add(stream_name)
Expand Down
2 changes: 1 addition & 1 deletion airbyte_cdk/sources/concurrent_source/concurrent_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _handle_item(
elif isinstance(queue_item, PartitionGenerationCompletedSentinel):
yield from concurrent_stream_processor.on_partition_generation_completed(queue_item)
elif isinstance(queue_item, Partition):
concurrent_stream_processor.on_partition(queue_item)
yield from concurrent_stream_processor.on_partition(queue_item)
elif isinstance(queue_item, PartitionCompleteSentinel):
yield from concurrent_stream_processor.on_partition_complete_sentinel(queue_item)
elif isinstance(queue_item, Record):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, TypeVar

from airbyte_cdk.models import (
AirbyteMessage,
AirbyteStateBlob,
AirbyteStateMessage,
AirbyteStateType,
Expand Down Expand Up @@ -268,10 +269,11 @@ def _check_and_update_parent_state(self) -> None:
if last_closed_state is not None:
self._parent_state = last_closed_state

def ensure_at_least_one_state_emitted(self) -> None:
def ensure_at_least_one_state_emitted(self) -> Iterable[AirbyteMessage]:
"""
The platform expects at least one state message on successful syncs. Hence, whatever happens, we expect this method to be
called.
Returns the state message directly instead of putting it on the shared queue.
"""
if not any(
semaphore_item[1]._value for semaphore_item in self._semaphore_per_partition.items()
Expand All @@ -281,7 +283,7 @@ def ensure_at_least_one_state_emitted(self) -> None:
self._global_cursor = self._new_global_cursor
self._lookback_window = self._timer.finish()
self._parent_state = self._partition_router.get_stream_state()
self._emit_state_message(throttle=False)
yield from self._create_state_message(throttle=False)

def _throttle_state_message(self) -> Optional[float]:
"""
Expand All @@ -292,7 +294,33 @@ def _throttle_state_message(self) -> Optional[float]:
return None
return current_time

def _create_state_message(self, throttle: bool = True) -> Iterable[AirbyteMessage]:
    """
    Build and yield the state message directly instead of emitting it through the
    message repository. Used by ensure_at_least_one_state_emitted() so the main
    thread never calls queue.put() on a full shared queue (deadlock prevention).

    :param throttle: when True, respect the state-emission throttle and yield
        nothing if the throttle window has not yet elapsed.
    """
    if throttle:
        current_time = self._throttle_state_message()
        if current_time is None:
            # Throttled: nothing to yield this time around.
            return
        self._last_emission_time = current_time
    # Skip state emit for global cursor if parent state is empty
    if self._use_global_cursor and not self._parent_state:
        return

    self._connector_state_manager.update_state_for_stream(
        self._stream_name,
        self._stream_namespace,
        self.state,
    )
    state_message = self._connector_state_manager.create_state_message(
        self._stream_name, self._stream_namespace
    )
    yield state_message

def _emit_state_message(self, throttle: bool = True) -> None:
"""Emit state message via message repository. Used by close_partition() on worker threads."""
if throttle:
current_time = self._throttle_state_message()
if current_time is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,16 @@ def stream_slices(self) -> Iterable[StreamSlice]:
if is_last_record_in_slice:
parent_stream.cursor.close_partition(partition)
if is_last_slice:
parent_stream.cursor.ensure_at_least_one_state_emitted()
# ensure_at_least_one_state_emitted now returns messages directly.
# On this worker thread we need to consume the returned iterator
# so the cursor's internal state updates happen, but the messages
# themselves are discarded — the parent cursor's close_partition()
# above already emitted state through the queue. This call just
# ensures internal bookkeeping is finalized.
for (
_msg
) in parent_stream.cursor.ensure_at_least_one_state_emitted():
pass

if emit_slice:
yield StreamSlice(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from datetime import datetime
from typing import TYPE_CHECKING, Any, Iterable, List, MutableMapping

from airbyte_cdk.models import AirbyteMessage
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
from airbyte_cdk.sources.file_based.types import StreamState
Expand Down Expand Up @@ -56,4 +57,4 @@ def get_start_time(self) -> datetime: ...
def emit_state_message(self) -> None: ...

@abstractmethod
def ensure_at_least_one_state_emitted(self) -> Iterable[AirbyteMessage]:
    """Yield the final state message(s) directly to the caller rather than emitting them to the shared queue."""
    ...
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,19 @@ def get_state(self) -> MutableMapping[str, Any]:
def set_initial_state(self, value: StreamState) -> None:
pass

def ensure_at_least_one_state_emitted(self) -> Iterable[AirbyteMessage]:
    """Yield the state message directly instead of putting it on the shared queue."""
    # NOTE(review): this block was reconstructed from a flattened diff; the
    # original indentation is lost, so it is ambiguous whether only get_state()
    # or the whole update/create sequence runs under the lock — confirm against
    # the upstream PR before relying on the narrower scope used here.
    with self._state_lock:
        new_state = self.get_state()
    self._connector_state_manager.update_state_for_stream(
        self._stream_name,
        self._stream_namespace,
        new_state,
    )
    state_message = self._connector_state_manager.create_state_message(
        self._stream_name, self._stream_namespace
    )
    yield state_message

def should_be_synced(self, record: Record) -> bool:
return True
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datetime import datetime
from typing import TYPE_CHECKING, Any, Iterable, List, MutableMapping, Optional

from airbyte_cdk.models import AirbyteMessage
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
Expand Down Expand Up @@ -73,14 +74,15 @@ def get_start_time(self) -> datetime:
def emit_state_message(self) -> None:
pass

def ensure_at_least_one_state_emitted(self) -> Iterable[AirbyteMessage]:
    """
    Yield the final state message directly to the caller instead of putting it on
    the shared queue (the platform expects at least one state message per sync).
    """
    self._connector_state_manager.update_state_for_stream(
        self._stream_name, self._stream_namespace, self.state
    )
    state_message = self._connector_state_manager.create_state_message(
        self._stream_name, self._stream_namespace
    )
    yield state_message

def should_be_synced(self, record: Record) -> bool:
return True
26 changes: 20 additions & 6 deletions airbyte_cdk/sources/streams/concurrent/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Union,
)

from airbyte_cdk.models import AirbyteMessage
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
from airbyte_cdk.sources.streams import NO_CURSOR_STATE_KEY
Expand Down Expand Up @@ -72,10 +73,13 @@ def close_partition(self, partition: Partition) -> None:
raise NotImplementedError()

@abstractmethod
def ensure_at_least_one_state_emitted(self) -> None:
def ensure_at_least_one_state_emitted(self) -> Iterable[AirbyteMessage]:
"""
State messages are emitted when a partition is closed. However, the platform expects at least one state to be emitted per sync per
stream. Hence, if no partitions are generated, this method needs to be called.

Returns the state messages directly instead of putting them on the shared queue,
so the caller (main thread) can yield them without risk of deadlock.
"""
raise NotImplementedError()

Expand Down Expand Up @@ -140,9 +144,10 @@ def observe(self, record: Record) -> None:
def close_partition(self, partition: Partition) -> None:
pass

def ensure_at_least_one_state_emitted(self) -> None:
def ensure_at_least_one_state_emitted(self) -> Iterable[AirbyteMessage]:
"""
Used primarily for full refresh syncs that do not have a valid cursor value to emit at the end of a sync
Used primarily for full refresh syncs that do not have a valid cursor value to emit at the end of a sync.
Returns the state message directly instead of putting it on the shared queue.
"""

self._connector_state_manager.update_state_for_stream(
Expand All @@ -151,7 +156,7 @@ def ensure_at_least_one_state_emitted(self) -> None:
state_message = self._connector_state_manager.create_state_message(
self._stream_name, self._stream_namespace
)
self._message_repository.emit_message(state_message)
yield state_message

def should_be_synced(self, record: Record) -> bool:
return True
Expand Down Expand Up @@ -397,12 +402,21 @@ def _extract_from_slice(self, partition: Partition, key: str) -> CursorValueType
f"Partition is expected to have key `{key}` but could not be found"
) from exception

def ensure_at_least_one_state_emitted(self) -> Iterable[AirbyteMessage]:
    """
    The platform expects at least one state message on successful syncs, so this
    method is called regardless of whether any partitions were closed.
    Yields the state message directly instead of putting it on the shared queue,
    so the caller (main thread) can forward it without risk of deadlock.
    """
    self._connector_state_manager.update_state_for_stream(
        self._stream_name,
        self._stream_namespace,
        self.state,
    )
    state_message = self._connector_state_manager.create_state_message(
        self._stream_name, self._stream_namespace
    )
    yield state_message

def stream_slices(self) -> Iterable[StreamSlice]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3739,7 +3739,7 @@ def test_given_unfinished_first_parent_partition_no_parent_state_update():
record_counter=RecordCounter(),
)
)
cursor.ensure_at_least_one_state_emitted()
list(cursor.ensure_at_least_one_state_emitted())

state = cursor.state
assert state == {
Expand Down Expand Up @@ -3835,7 +3835,7 @@ def test_given_unfinished_last_parent_partition_with_partial_parent_state_update
record_counter=RecordCounter(),
)
)
cursor.ensure_at_least_one_state_emitted()
list(cursor.ensure_at_least_one_state_emitted())

state = cursor.state
assert state == {
Expand Down Expand Up @@ -3927,15 +3927,18 @@ def test_given_all_partitions_finished_when_close_partition_then_final_state_emi
)
)

cursor.ensure_at_least_one_state_emitted()
state_messages = list(cursor.ensure_at_least_one_state_emitted())

final_state = cursor.state
assert final_state["use_global_cursor"] is False
assert len(final_state["states"]) == 2
assert final_state["state"]["updated_at"] == "2024-01-02T00:00:00Z"
assert final_state["parent_state"] == {"posts": {"updated_at": "2024-01-06T00:00:00Z"}}
assert final_state["lookback_window"] == 86400
assert cursor._message_repository.emit_message.call_count == 2
# close_partition() emits 1 state via message_repository (second is throttled)
# ensure_at_least_one_state_emitted() returns 1 state directly (no longer uses message_repository)
assert cursor._message_repository.emit_message.call_count == 1
assert len(state_messages) == 1
assert mock_cursor.stream_slices.call_count == 2 # Called once for each partition

# Checks that all internal variables are cleaned up
Expand Down Expand Up @@ -4001,7 +4004,7 @@ def test_given_partition_limit_exceeded_when_close_partition_then_switch_to_glob
record_counter=RecordCounter(),
)
)
cursor.ensure_at_least_one_state_emitted()
list(cursor.ensure_at_least_one_state_emitted())

final_state = cursor.state
assert len(slices) == 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,7 @@ def test_substream_partition_router_closes_all_partitions_even_when_no_records()

mock_cursor = Mock()
mock_cursor.stream_slices.return_value = []
mock_cursor.ensure_at_least_one_state_emitted.return_value = []

partition_router = SubstreamPartitionRouter(
parent_stream_configs=[
Expand Down Expand Up @@ -1270,6 +1271,7 @@ def test_substream_partition_router_closes_partition_even_when_parent_key_missin

mock_cursor = Mock()
mock_cursor.stream_slices.return_value = []
mock_cursor.ensure_at_least_one_state_emitted.return_value = []

partition_router = SubstreamPartitionRouter(
parent_stream_configs=[
Expand Down
Loading
Loading