Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 53 additions & 15 deletions asr/asr_connector.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,53 @@
#!/usr/bin/env python3

import threading

import rclpy
from rclpy.node import Node
from std_msgs.msg import String
from rclpy.action import ActionClient
from custom_interfaces.action import Prompt

from TTS.api import TTS
import sounddevice as sd


class HighLevelPromptClient(Node):
def __init__(self):
    """Set up the ROS interfaces and load the speech synthesizer.

    Registers the node as 'high_level_prompt_client', subscribes to
    '/high_level_prompt', creates the action client for
    '/prompt_high_level', loads the Coqui TTS model, and creates the lock
    that serializes speech output.
    """
    super().__init__('high_level_prompt_client')

    # Every String received on the prompt topic is handed to prompt_callback.
    prompt_topic = '/high_level_prompt'
    qos = 10
    self.subscription = self.create_subscription(
        String, prompt_topic, self.prompt_callback, qos
    )
    self.subscription  # prevent unused variable warning

    # Client used to send Prompt goals.
    action_name = '/prompt_high_level'
    self.action_client = ActionClient(self, Prompt, action_name)

    # The model is loaded here, once, and reused for every utterance.
    self.get_logger().info('Loading Coqui TTS model...')
    self.tts = TTS("tts_models/en/ljspeech/tacotron2-DDC")
    self.sample_rate = self.tts.synthesizer.output_sample_rate
    self.get_logger().info(f'TTS loaded. sample_rate={self.sample_rate}')

    # Held while an utterance is synthesized and played so speech never overlaps.
    self._speak_lock = threading.Lock()

def prompt_callback(self, msg: String):
    """Log a received prompt message and forward its text as an action goal."""
    prompt_text = msg.data
    logger = self.get_logger()
    logger.info(f'Received prompt: "{prompt_text}"')
    self.send_prompt_action(prompt_text)

def send_prompt_action(self, prompt_text: str):
# Wait until the action server is available
if not self.action_client.wait_for_server(timeout_sec=5.0):
self.get_logger().error('Action server /prompt_high_level not available!')
return

# Create goal

goal_msg = Prompt.Goal()
goal_msg.prompt = prompt_text

# Send goal asynchronously

self._send_goal_future = self.action_client.send_goal_async(
goal_msg,
feedback_callback=self.feedback_callback
Expand All @@ -55,22 +66,49 @@ def goal_response_callback(self, future):

def feedback_callback(self, feedback_msg):
    """Log the progress feedback delivered for the active goal.

    Args:
        feedback_msg: Object passed to the feedback callback; its
            ``feedback.tools_called`` field is written to the log.
    """
    feedback = feedback_msg.feedback
    # Exactly one log entry per feedback message.
    self.get_logger().info(
        f'Feedback received: tools_called={feedback.tools_called}'
    )

def get_result_callback(self, future):
    """Log the action outcome and speak the final response on success.

    Args:
        future: Completed future; ``future.result().result`` is read for
            the ``success`` flag and the ``final_response`` text.
    """
    response = future.result()
    if response is None:
        # Without a result there is nothing to report or speak.
        self.get_logger().error('Action finished without a result')
        return

    result = response.result
    # Strip surrounding whitespace so a blank response is never spoken.
    final_response = result.final_response.strip()

    # Exactly one log entry per finished action, using the stripped text.
    self.get_logger().info(
        f'Action finished. Success: {result.success}, Final Response: "{final_response}"'
    )

    if result.success and final_response:
        self.speak_text(final_response)

def speak_text(self, text: str):
    """Start speaking *text* on a background thread and return immediately."""
    # Synthesis and playback run off the calling thread so ROS callbacks
    # stay responsive; the thread is a daemon.
    worker = threading.Thread(
        target=self._speak_text_worker, args=(text,), daemon=True
    )
    worker.start()

def _speak_text_worker(self, text: str):
    """Synthesize *text* and play it while holding the speech lock.

    Any exception raised during synthesis or playback is logged as an
    error and not propagated.
    """
    with self._speak_lock:
        try:
            self.get_logger().info(f'Speaking: "{text}"')
            samples = self.tts.tts(text=text)
            # Stop any current playback, start the new audio, then wait
            # for it to finish before releasing the lock.
            sd.stop()
            sd.play(samples, samplerate=self.sample_rate)
            sd.wait()
        except Exception as exc:
            self.get_logger().error(f'TTS playback failed: {exc}')


def main(args=None):
    """Initialize rclpy, spin the prompt client node, and clean up on exit.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    node = None

    try:
        # Construct inside the try so a failure while building the node
        # (for example while loading the TTS model) still reaches the
        # cleanup below.
        node = HighLevelPromptClient()
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        if node is not None:
            node.destroy_node()
        # Only shut down a context that is still active; the context may
        # already have been shut down by the time Ctrl-C unwinds to here.
        if rclpy.ok():
            rclpy.shutdown()

if __name__ == '__main__':
    # main() performs the rclpy shutdown itself in its finally block, so no
    # further shutdown call belongs here.
    main()
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
(os.path.join('share', package_name, 'launch'),
glob('launch/*.launch.py') + glob('launch/*.sh')),
],
install_requires=['setuptools'],
install_requires=[
'setuptools',
'coqui-tts[codec]',
'transformers==5.0.0',
'sounddevice',
],
zip_safe=True,
maintainer='final-project',
maintainer_email='karamahati@gmail.com',
Expand Down