diff --git a/vertica_python/tests/integration_tests/test_authentication.py b/vertica_python/tests/integration_tests/test_authentication.py index 85503b54..1da6318c 100644 --- a/vertica_python/tests/integration_tests/test_authentication.py +++ b/vertica_python/tests/integration_tests/test_authentication.py @@ -217,6 +217,17 @@ def totp_invalid_format_scenario(self): cur.execute("DROP USER IF EXISTS totp_user") cur.execute("DROP AUTHENTICATION IF EXISTS totp_auth CASCADE") + def test_totp_invalid_alphanumeric_code(self): + # Verify alphanumeric TOTP inputs return explicit client-side validation error + try: + # Provide alphanumeric TOTP via connection options; should fail locally + self._conn_info['totp'] = "ot123" + err_msg = "Invalid TOTP: Please enter a valid 6-digit numeric code." + self.assertConnectionFail(err_msg=err_msg) + finally: + # Clean up connection options + self._conn_info.pop('totp', None) + # Negative Test: Wrong TOTP (Valid format, wrong value) def totp_wrong_code_scenario(self): with self._connect() as conn: diff --git a/vertica_python/vertica/connection.py b/vertica_python/vertica/connection.py index 0d0e6a54..cf474ae4 100644 --- a/vertica_python/vertica/connection.py +++ b/vertica_python/vertica/connection.py @@ -49,6 +49,7 @@ import signal import select import sys +import unicodedata from collections import deque from struct import unpack @@ -88,6 +89,50 @@ warnings.warn(f"Cannot get the login user name: {str(e)}") +# TOTP validation utilities (client-side) +class TotpValidationResult(NamedTuple): + ok: bool + code: str + message: str + + +INVALID_TOTP_MSG = 'Invalid TOTP: Please enter a valid 6-digit numeric code.' + + +def validate_totp_code(raw_code: str) -> TotpValidationResult: + """Validate and normalize a user-supplied TOTP value. + + Precedence: + 1) Trim & normalize input (normalize full-width digits; strip leading/trailing whitespace only) + 2) Empty check + 3) Length check (must be exactly 6) + 4) Numeric-only check (digits 0–9 only; do not remove internal separators) + + Returns TotpValidationResult(ok, code, message). + - Success: `ok=True`, `code` is a 6-digit ASCII string, `message=''`. + - Failure: `ok=False`, `code=''`, `message` is always the generic INVALID_TOTP_MSG. + """ + try: + s = raw_code if raw_code is not None else '' + # Normalize Unicode (convert full-width digits etc. to ASCII) + s = unicodedata.normalize('NFKC', s) + # Strip leading/trailing whitespace + s = s.strip() + # Empty / length / numeric checks (do not remove internal separators) + if s == '': + return TotpValidationResult(False, '', INVALID_TOTP_MSG) + if len(s) != 6: + return TotpValidationResult(False, '', INVALID_TOTP_MSG) + if not s.isdigit(): + return TotpValidationResult(False, '', INVALID_TOTP_MSG) + + # All good + return TotpValidationResult(True, s, '') + except Exception: + # Fallback generic error + return TotpValidationResult(False, '', INVALID_TOTP_MSG) + + def connect(**kwargs: Any) -> Connection: """Opens a new connection to a Vertica database.""" return Connection(kwargs) @@ -313,6 +358,14 @@ def __init__(self, options: Optional[Dict[str, Any]] = None) -> None: if self.totp is not None: if not isinstance(self.totp, str): raise TypeError('The value of connection option "totp" should be a string') + # Validate using local validator + result = validate_totp_code(self.totp) + if not result.ok: + msg = INVALID_TOTP_MSG + self._logger.error(msg) + raise errors.ConnectionError(msg) + # normalized digits-only code + self.totp = result.code self._logger.info('TOTP received in connection options') # OAuth authentication setup @@ -974,13 +1027,11 @@ def send_startup(totp_value=None): short_msg = match.group(1).strip() if match else error_msg.strip() if "Invalid TOTP" in short_msg: - print("Authentication failed: Invalid TOTP token.") - self._logger.error("Authentication failed: Invalid TOTP token.") + self._logger.error(INVALID_TOTP_MSG) self.close_socket() - raise errors.ConnectionError("Authentication failed: Invalid TOTP token.") + raise errors.ConnectionError(INVALID_TOTP_MSG) # Generic error fallback - print(f"Authentication failed: {short_msg}") self._logger.error(short_msg) raise errors.ConnectionError(f"Authentication failed: {short_msg}") else: @@ -993,23 +1044,20 @@ def send_startup(totp_value=None): # ✅ If TOTP not provided initially, prompt only once if not totp: - timeout_seconds = 30 # 5 minutes timeout + timeout_seconds = 300 # 5 minutes timeout try: print("Enter TOTP: ", end="", flush=True) ready, _, _ = select.select([sys.stdin], [], [], timeout_seconds) if ready: totp_input = sys.stdin.readline().strip() - # ❌ Blank TOTP entered - if not totp_input: - self._logger.error("Invalid TOTP: Cannot be empty.") - raise errors.ConnectionError("Invalid TOTP: Cannot be empty.") - - # ❌ Validate TOTP format (must be 6 digits) - if not totp_input.isdigit() or len(totp_input) != 6: - print("Invalid TOTP format. Please enter a 6-digit code.") - self._logger.error("Invalid TOTP format entered.") - raise errors.ConnectionError("Invalid TOTP format: Must be a 6-digit number.") + # Validate using local precedence-based validator + result = validate_totp_code(totp_input) + if not result.ok: + msg = INVALID_TOTP_MSG + self._logger.error(msg) + raise errors.ConnectionError(msg) + totp_input = result.code # ✅ Valid TOTP — retry connection totp = totp_input self.close_socket()