diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 4d16e6ec..41b243e7 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -1301,21 +1301,21 @@ def getinfo(self, info_type: int) -> Union[str, int, bool, None]: # Make sure we use the correct amount of data based on length actual_data = data[:length] - # Now decode the string data - try: - return actual_data.decode("utf-8").rstrip("\0") - except UnicodeDecodeError: + # SQLGetInfoW returns UTF-16LE encoded strings (wide-character ODBC API) + # Try UTF-16LE first (expected), then UTF-8 as fallback + for encoding in ("utf-16-le", "utf-8"): try: - return actual_data.decode("latin1").rstrip("\0") - except Exception as e: - logger.debug( - "error", - "Failed to decode string in getinfo: %s. " - "Returning None to avoid silent corruption.", - e, - ) - # Explicitly return None to signal decoding failure - return None + return actual_data.decode(encoding).rstrip("\0") + except UnicodeDecodeError: + continue + + # All decodings failed + logger.debug( + "Failed to decode string in getinfo (info_type=%d) with supported encodings. " + "Returning None to avoid silent corruption.", + info_type, + ) + return None else: # If it's not bytes, return as is return data diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 97edff2a..289b9631 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -2187,6 +2187,132 @@ def test_getinfo_basic_driver_info(db_connection): pytest.fail(f"getinfo failed for basic driver info: {e}") +def test_getinfo_string_encoding_utf16(db_connection): + """Test that string values from getinfo are properly decoded from UTF-16.""" + + # Test string info types that should not contain null bytes + string_info_types = [ + ("SQL_DRIVER_VER", sql_const.SQL_DRIVER_VER.value), + ("SQL_DRIVER_NAME", sql_const.SQL_DRIVER_NAME.value), + ("SQL_DRIVER_ODBC_VER", sql_const.SQL_DRIVER_ODBC_VER.value), + ("SQL_SERVER_NAME", sql_const.SQL_SERVER_NAME.value), + ] + + for name, info_type in string_info_types: + result = db_connection.getinfo(info_type) + + if result is not None: + # Verify it's a string + assert isinstance(result, str), f"{name}: Expected str, got {type(result).__name__}" + + # Verify no null bytes (indicates UTF-16 decoded as UTF-8 bug) + assert ( + "\x00" not in result + ), f"{name} contains null bytes, likely UTF-16/UTF-8 encoding mismatch: {repr(result)}" + + # Verify it's not empty (optional, but good sanity check) + assert len(result) > 0, f"{name} returned empty string" + + +def test_getinfo_string_decoding_utf8_fallback(db_connection): + """Test that getinfo falls back to UTF-8 when UTF-16LE decoding fails. + + This test verifies the fallback path in the encoding loop where + UTF-16LE fails but UTF-8 succeeds. + """ + from unittest.mock import MagicMock + + # UTF-8 encoded "Hello" - this is valid UTF-8 but NOT valid UTF-16LE + # (odd number of bytes would fail UTF-16LE decode) + utf8_data = "Hello".encode("utf-8") # b'Hello' - 5 bytes, odd length + + mock_result = {"data": utf8_data, "length": len(utf8_data)} + + # Use a string-type info_type (SQL_DRIVER_NAME = 6 is in string_type_constants) + info_type = sql_const.SQL_DRIVER_NAME.value + + # Save the original _conn and replace with a mock + original_conn = db_connection._conn + try: + mock_conn = MagicMock() + mock_conn.get_info.return_value = mock_result + db_connection._conn = mock_conn + + result = db_connection.getinfo(info_type) + + assert result == "Hello", f"Expected 'Hello', got {repr(result)}" + assert isinstance(result, str), f"Expected str, got {type(result).__name__}" + finally: + # Restore the original connection + db_connection._conn = original_conn + + +def test_getinfo_string_decoding_all_fail_returns_none(db_connection): + """Test that getinfo returns None when all decoding attempts fail. + + This test verifies that when both UTF-16LE and UTF-8 decoding fail, + the method returns None to avoid silent data corruption. + """ + from unittest.mock import MagicMock + + # Invalid byte sequence that cannot be decoded as UTF-16LE or UTF-8 + # 0xFF 0xFE is a BOM, but followed by invalid continuation bytes for UTF-8 + # and odd length makes it invalid UTF-16LE + invalid_data = bytes([0x80, 0x81, 0x82]) # Invalid for both encodings + + mock_result = {"data": invalid_data, "length": len(invalid_data)} + + # Use a string-type info_type (SQL_DRIVER_NAME = 6 is in string_type_constants) + info_type = sql_const.SQL_DRIVER_NAME.value + + # Save the original _conn and replace with a mock + original_conn = db_connection._conn + try: + mock_conn = MagicMock() + mock_conn.get_info.return_value = mock_result + db_connection._conn = mock_conn + + result = db_connection.getinfo(info_type) + + # Should return None when all decoding fails + assert result is None, f"Expected None for invalid encoding, got {repr(result)}" + finally: + # Restore the original connection + db_connection._conn = original_conn + + +def test_getinfo_string_encoding_utf16_primary(db_connection): + """Test that getinfo correctly decodes valid UTF-16LE data. + + This test verifies the primary (expected) encoding path where + UTF-16LE decoding succeeds on first try. + """ + from unittest.mock import MagicMock + + # Valid UTF-16LE encoded "Test" with null terminator + utf16_data = "Test".encode("utf-16-le") + b"\x00\x00" + + mock_result = {"data": utf16_data, "length": len(utf16_data)} + + # Use a string-type info_type + info_type = sql_const.SQL_DRIVER_NAME.value + + # Save the original _conn and replace with a mock + original_conn = db_connection._conn + try: + mock_conn = MagicMock() + mock_conn.get_info.return_value = mock_result + db_connection._conn = mock_conn + + result = db_connection.getinfo(info_type) + + assert result == "Test", f"Expected 'Test', got {repr(result)}" + assert "\x00" not in result, f"Result contains null bytes: {repr(result)}" + finally: + # Restore the original connection + db_connection._conn = original_conn + + def test_getinfo_sql_support(db_connection): """Test SQL support and conformance info types."""