Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 66 additions & 15 deletions vertica_python/vertica/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import signal
import select
import sys
import unicodedata
from collections import deque
from struct import unpack

Expand Down Expand Up @@ -88,6 +89,53 @@
warnings.warn(f"Cannot get the login user name: {str(e)}")


# TOTP validation utilities (client-side)
class TotpValidationResult(NamedTuple):
ok: bool
code: str
message: str


INVALID_TOTP_MSG = 'Invalid TOTP: Please enter a valid 6-digit numeric code.'


def validate_totp_code(raw_code: str, totp_is_valid=None) -> TotpValidationResult:
"""Validate and normalize a user-supplied TOTP value.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

totp_is_valid is written as it is reserved for server side checks. but validate_totp_code function itself is called to do the checks before sending to the server. Something is not right, can you please verify?


Precedence:
1) Trim & normalize input (strip spaces and separators; normalize full-width digits)
2) Check emptiness, length == 6, and numeric-only

Returns TotpValidationResult(ok, code, message).
- Success: `ok=True`, `code` is a 6-digit ASCII string, `message=''`.
- Failure: `ok=False`, `code=''`, `message` is always the generic INVALID_TOTP_MSG.
`totp_is_valid` is reserved for optional server-side checks and ignored here.
"""
try:
s = raw_code if raw_code is not None else ''
# Normalize Unicode (convert full-width digits etc. to ASCII)
s = unicodedata.normalize('NFKC', s)
# Strip leading/trailing whitespace
s = s.strip()
# Remove common separators inside the code
# Spaces, hyphens, underscores, dots, and common dash-like characters
separators = {' ', '\t', '\n', '\r', '\f', '\v', '-', '_', '.',
'\u2012', '\u2013', '\u2014', '\u2212', '\u00B7', '\u2027', '\u30FB'}
Copy link
Collaborator

@sivaalamp sivaalamp Jan 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If user enters the TOTP with - or _ or ., your code is removing them without giving error message. Removing the trailing or leading white spaces is acceptable, but if user deliberately enters these special characters, should we remove them or should we warn user that invalid characters are not allowed.

# Replace all occurrences of separators
for sep in list(separators):
s = s.replace(sep, '')

# Empty / length / numeric checks
if s == '' or len(s) != 6 or not s.isdigit():
return TotpValidationResult(False, '', INVALID_TOTP_MSG)

# All good
return TotpValidationResult(True, s, '')
except Exception:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the tests are not updated to test these new error message?

# Fallback generic error
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no exception handling, it is just like function return. Handle the exception when we use try catch block

return TotpValidationResult(False, '', INVALID_TOTP_MSG)


def connect(**kwargs: Any) -> Connection:
"""Opens a new connection to a Vertica database."""
return Connection(kwargs)
Expand Down Expand Up @@ -313,6 +361,14 @@ def __init__(self, options: Optional[Dict[str, Any]] = None) -> None:
if self.totp is not None:
if not isinstance(self.totp, str):
raise TypeError('The value of connection option "totp" should be a string')
# Validate using local validator
result = validate_totp_code(self.totp, totp_is_valid=None)
if not result.ok:
msg = result.message or INVALID_TOTP_MSG
self._logger.error(f'Authentication failed: {msg}')
raise errors.ConnectionError(f'Authentication failed: {msg}')
# normalized digits-only code
self.totp = result.code
self._logger.info('TOTP received in connection options')

# OAuth authentication setup
Expand Down Expand Up @@ -974,13 +1030,11 @@ def send_startup(totp_value=None):
short_msg = match.group(1).strip() if match else error_msg.strip()

if "Invalid TOTP" in short_msg:
print("Authentication failed: Invalid TOTP token.")
self._logger.error("Authentication failed: Invalid TOTP token.")
self._logger.error(f"Authentication failed: {INVALID_TOTP_MSG}")
self.close_socket()
raise errors.ConnectionError("Authentication failed: Invalid TOTP token.")
raise errors.ConnectionError(f"Authentication failed: {INVALID_TOTP_MSG}")

# Generic error fallback
print(f"Authentication failed: {short_msg}")
self._logger.error(short_msg)
raise errors.ConnectionError(f"Authentication failed: {short_msg}")
else:
Expand All @@ -993,23 +1047,20 @@ def send_startup(totp_value=None):

# ✅ If TOTP not provided initially, prompt only once
if not totp:
timeout_seconds = 30 # 5 minutes timeout
timeout_seconds = 300 # 5 minutes timeout
try:
print("Enter TOTP: ", end="", flush=True)
ready, _, _ = select.select([sys.stdin], [], [], timeout_seconds)
if ready:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we wait 5mins for the user to enter TOTP?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Siva, if the user does not provide the TOTP within 5 minutes, the session should expire—meaning the connection between the driver and the server should be terminated.

totp_input = sys.stdin.readline().strip()

# ❌ Blank TOTP entered
if not totp_input:
self._logger.error("Invalid TOTP: Cannot be empty.")
raise errors.ConnectionError("Invalid TOTP: Cannot be empty.")

# ❌ Validate TOTP format (must be 6 digits)
if not totp_input.isdigit() or len(totp_input) != 6:
print("Invalid TOTP format. Please enter a 6-digit code.")
self._logger.error("Invalid TOTP format entered.")
raise errors.ConnectionError("Invalid TOTP format: Must be a 6-digit number.")
# Validate using local precedence-based validator
result = validate_totp_code(totp_input, totp_is_valid=None)
if not result.ok:
msg = INVALID_TOTP_MSG
self._logger.error(f"Authentication failed: {msg}")
raise errors.ConnectionError(f"Authentication failed: {msg}")
totp_input = result.code
# ✅ Valid TOTP — retry connection
totp = totp_input
self.close_socket()
Expand Down