Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions astrbot/core/message/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from astrbot.core import astrbot_config, file_token_service, logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
from astrbot.core.utils.io import download_audio_by_url, download_file, download_image_by_url, file_to_base64


class ComponentType(str, Enum):
Expand Down Expand Up @@ -157,7 +157,7 @@ async def convert_to_file_path(self) -> str:
if self.file.startswith("file:///"):
return self.file[8:]
if self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
file_path = await download_audio_by_url(self.file)
return os.path.abspath(file_path)
if self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
Expand Down
38 changes: 38 additions & 0 deletions astrbot/core/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ def save_temp_img(img: Image.Image | bytes) -> str:
return p


def save_temp_audio(audio_data: bytes) -> str:
"""Save audio data to a temporary file with a proper extension."""
temp_dir = get_astrbot_temp_path()
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
p = os.path.join(temp_dir, f"recordseg_{timestamp}.audio")
Comment on lines +66 to +70
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue: Function docstring claims a proper extension is used, but the implementation always writes .audio.

Please either (a) change the implementation to accept/use an extension or filename hint so it truly uses a “proper” extension, or (b) update the docstring to state that the file is always saved with a .audio extension, so callers don’t rely on the extension to infer the actual audio type.

with open(p, "wb") as f:
f.write(audio_data)
return p


async def download_image_by_url(
url: str,
post: bool = False,
Expand Down Expand Up @@ -123,6 +133,34 @@ async def download_image_by_url(
raise e


async def download_audio_by_url(url: str) -> str:
"""Download audio from URL, preserving extension. Returns local file path."""
try:
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(
trust_env=True,
connector=connector,
) as session:
async with session.get(url) as resp:
data = await resp.read()
return save_temp_audio(data)
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
logger.warning(
f"SSL certificate verification failed for {url}. "
"Disabling SSL verification (CERT_NONE) as a fallback."
)
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
async with aiohttp.ClientSession() as session:
async with session.get(url, ssl=ssl_context) as resp:
data = await resp.read()
return save_temp_audio(data)
except Exception as e:
raise e
Comment on lines +136 to +161
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This new function download_audio_by_url contains a significant amount of duplicated code for handling SSL fallbacks, which is also present in other download functions like download_image_by_url. This makes the code harder to maintain. Additionally, the docstring is misleading as it claims to "preserve extension", but the implementation saves the file with a hardcoded .audio extension via save_temp_audio.

To improve maintainability and correctness, you could refactor this function to reduce duplication and update the docstring. Here's a suggested refactoring that uses a nested helper function to avoid repeating the download logic:

async def download_audio_by_url(url: str) -> str:
    """Download audio from URL. Returns local file path."""

    async def _download_and_save(session, **kwargs):
        async with session.get(url, **kwargs) as resp:
            resp.raise_for_status()
            data = await resp.read()
            return save_temp_audio(data)

    try:
        ssl_context = ssl.create_default_context(cafile=certifi.where())
        connector = aiohttp.TCPConnector(ssl=ssl_context)
        async with aiohttp.ClientSession(
            trust_env=True,
            connector=connector,
        ) as session:
            return await _download_and_save(session)
    except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
        logger.warning(
            f"SSL certificate verification failed for {url}. "
            "Disabling SSL verification (CERT_NONE) as a fallback."
        )
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
        async with aiohttp.ClientSession() as session:
            return await _download_and_save(session, ssl=ssl_context)



async def download_file(url: str, path: str, show_progress: bool = False) -> None:
"""从指定 url 下载文件到指定路径 path"""
try:
Expand Down
Loading