diff --git a/asr/asr_connector.py b/asr/asr_connector.py index 1c52960..93d6e00 100644 --- a/asr/asr_connector.py +++ b/asr/asr_connector.py @@ -1,42 +1,53 @@ #!/usr/bin/env python3 +import threading + import rclpy from rclpy.node import Node from std_msgs.msg import String from rclpy.action import ActionClient from custom_interfaces.action import Prompt +from TTS.api import TTS +import sounddevice as sd + + class HighLevelPromptClient(Node): def __init__(self): super().__init__('high_level_prompt_client') - - # Subscriber to /high_level_prompt (std_msgs/String) + + # Subscriber to /high_level_prompt self.subscription = self.create_subscription( String, '/high_level_prompt', self.prompt_callback, 10 ) - self.subscription # prevent unused variable warning - # Action client for /prompt_high_level (custom_interfaces/Prompt) + # Action client for /prompt_high_level self.action_client = ActionClient(self, Prompt, '/prompt_high_level') + # Load TTS once at startup + self.get_logger().info('Loading Coqui TTS model...') + self.tts = TTS("tts_models/en/ljspeech/tacotron2-DDC") + self.sample_rate = self.tts.synthesizer.output_sample_rate + self.get_logger().info(f'TTS loaded. sample_rate={self.sample_rate}') + + # Prevent overlapping speech + self._speak_lock = threading.Lock() + def prompt_callback(self, msg: String): self.get_logger().info(f'Received prompt: "{msg.data}"') self.send_prompt_action(msg.data) def send_prompt_action(self, prompt_text: str): - # Wait until the action server is available if not self.action_client.wait_for_server(timeout_sec=5.0): self.get_logger().error('Action server /prompt_high_level not available!') return - - # Create goal + goal_msg = Prompt.Goal() goal_msg.prompt = prompt_text - - # Send goal asynchronously + self._send_goal_future = self.action_client.send_goal_async( goal_msg, feedback_callback=self.feedback_callback @@ -55,22 +66,49 @@ def goal_response_callback(self, future): def feedback_callback(self, feedback_msg): feedback = feedback_msg.feedback - self.get_logger().info(f'Feedback received: tools_called={feedback.tools_called}') + self.get_logger().info( + f'Feedback received: tools_called={feedback.tools_called}' + ) def get_result_callback(self, future): result = future.result().result - self.get_logger().info(f'Action finished. Success: {result.success}, Final Response: "{result.final_response}"') + final_response = result.final_response.strip() + + self.get_logger().info( + f'Action finished. Success: {result.success}, Final Response: "{final_response}"' + ) + + if result.success and final_response: + self.speak_text(final_response) + + def speak_text(self, text: str): + # Run TTS in a background thread so ROS callbacks stay responsive + threading.Thread( + target=self._speak_text_worker, + args=(text,), + daemon=True + ).start() + + def _speak_text_worker(self, text: str): + with self._speak_lock: + try: + self.get_logger().info(f'Speaking: "{text}"') + wav = self.tts.tts(text=text) + sd.stop() + sd.play(wav, samplerate=self.sample_rate) + sd.wait() + except Exception as e: + self.get_logger().error(f'TTS playback failed: {e}') + def main(args=None): rclpy.init(args=args) node = HighLevelPromptClient() + try: rclpy.spin(node) except KeyboardInterrupt: pass finally: node.destroy_node() - rclpy.shutdown() - -if __name__ == '__main__': - main() + rclpy.shutdown() \ No newline at end of file diff --git a/setup.py b/setup.py index f636474..9cdbc98 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,12 @@ (os.path.join('share', package_name, 'launch'), glob('launch/*.launch.py') + glob('launch/*.sh')), ], - install_requires=['setuptools'], + install_requires=[ + 'setuptools', + 'coqui-tts[codec]', + 'transformers==5.0.0', + 'sounddevice', +], zip_safe=True, maintainer='final-project', maintainer_email='karamahati@gmail.com',