2222import threading
2323import logging
2424import uuid
25+ from contextlib import contextmanager
2526
2627from google .protobuf .struct_pb2 import ListValue
2728from google .protobuf .struct_pb2 import Value
3435from google .cloud .spanner_v1 .types import ExecuteSqlRequest
3536from google .cloud .spanner_v1 .types import TransactionOptions
3637from google .cloud .spanner_v1 .data_types import JsonObject , Interval
37- from google .cloud .spanner_v1 .request_id_header import with_request_id
38+ from google .cloud .spanner_v1 .request_id_header import (
39+ with_request_id ,
40+ with_request_id_metadata_only ,
41+ )
3842from google .cloud .spanner_v1 .types import TypeCode
43+ from google .cloud .spanner_v1 .exceptions import wrap_with_request_id
3944
4045from google .rpc .error_details_pb2 import RetryInfo
4146
@@ -568,7 +573,10 @@ def _retry_on_aborted_exception(
568573):
569574 """
570575 Handles retry logic for Aborted exceptions, considering the deadline.
576+ Also handles SpannerError that wraps Aborted exceptions.
571577 """
578+ from google .cloud .spanner_v1 .exceptions import SpannerError
579+
572580 attempts = 0
573581 while True :
574582 try :
@@ -582,6 +590,17 @@ def _retry_on_aborted_exception(
582590 default_retry_delay = default_retry_delay ,
583591 )
584592 continue
593+ except SpannerError as exc :
594+ # Check if the wrapped error is Aborted
595+ if isinstance (exc ._error , Aborted ):
596+ _delay_until_retry (
597+ exc ._error ,
598+ deadline = deadline ,
599+ attempts = attempts ,
600+ default_retry_delay = default_retry_delay ,
601+ )
602+ continue
603+ raise
585604
586605
587606def _retry (
@@ -600,10 +619,13 @@ def _retry(
600619 delay: The delay in seconds between retries.
601620 allowed_exceptions: A tuple of exceptions that are allowed to occur without triggering a retry.
602621 Passing allowed_exceptions as None will lead to retrying for all exceptions.
622+ Also handles SpannerError wrapping allowed exceptions.
603623
604624 Returns:
605625 The result of the function if it is successful, or raises the last exception if all retries fail.
606626 """
627+ from google .cloud .spanner_v1 .exceptions import SpannerError
628+
607629 retries = 0
608630 while retries <= retry_count :
609631 if retries > 0 and before_next_retry :
@@ -612,14 +634,23 @@ def _retry(
612634 try :
613635 return func ()
614636 except Exception as exc :
615- if (
616- allowed_exceptions is None or exc .__class__ in allowed_exceptions
617- ) and retries < retry_count :
637+ # Check if exception is allowed directly or wrapped in SpannerError
638+ exc_to_check = exc
639+ if isinstance (exc , SpannerError ):
640+ exc_to_check = exc ._error
641+
642+ is_allowed = (
643+ allowed_exceptions is None
644+ or exc_to_check .__class__ in allowed_exceptions
645+ )
646+
647+ if is_allowed and retries < retry_count :
618648 if (
619649 allowed_exceptions is not None
620- and allowed_exceptions [exc .__class__ ] is not None
650+ and exc_to_check .__class__ in allowed_exceptions
651+ and allowed_exceptions [exc_to_check .__class__ ] is not None
621652 ):
622- allowed_exceptions [exc .__class__ ](exc )
653+ allowed_exceptions [exc_to_check .__class__ ](exc_to_check )
623654 time .sleep (delay )
624655 delay = delay * 2
625656 retries = retries + 1
@@ -767,9 +798,67 @@ def reset(self):
767798
768799
769800def _metadata_with_request_id (* args , ** kwargs ):
801+ """Return metadata with request ID header.
802+
803+ This function returns only the metadata list (not a tuple),
804+ maintaining backward compatibility with existing code.
805+
806+ Args:
807+ *args: Arguments to pass to with_request_id
808+ **kwargs: Keyword arguments to pass to with_request_id
809+
810+ Returns:
811+ list: gRPC metadata with request ID header
812+ """
813+ return with_request_id_metadata_only (* args , ** kwargs )
814+
815+
816+ def _metadata_with_request_id_and_req_id (* args , ** kwargs ):
817+ """Return both metadata and request ID string.
818+
819+ This is used when we need to augment errors with the request ID.
820+
821+ Args:
822+ *args: Arguments to pass to with_request_id
823+ **kwargs: Keyword arguments to pass to with_request_id
824+
825+ Returns:
826+ tuple: (metadata, request_id)
827+ """
770828 return with_request_id (* args , ** kwargs )
771829
772830
831+ def _augment_error_with_request_id (error , request_id = None ):
832+ """Augment an error with request ID information.
833+
834+ Args:
835+ error: The error to augment (typically GoogleAPICallError)
836+ request_id (str): The request ID to include
837+
838+ Returns:
839+ The augmented error with request ID information
840+ """
841+ return wrap_with_request_id (error , request_id )
842+
843+
844+ @contextmanager
845+ def _augment_errors_with_request_id (request_id ):
846+ """Context manager to augment exceptions with request ID.
847+
848+ Args:
849+ request_id (str): The request ID to include in exceptions
850+
851+ Yields:
852+ None
853+ """
854+ try :
855+ yield
856+ except Exception as exc :
857+ augmented = _augment_error_with_request_id (exc , request_id )
858+ # Use exception chaining to preserve the original exception
859+ raise augmented from exc
860+
861+
773862def _merge_Transaction_Options (
774863 defaultTransactionOptions : TransactionOptions ,
775864 mergeTransactionOptions : TransactionOptions ,
0 commit comments