diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp index 3137d274b..dffcca848 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp @@ -10,11 +10,19 @@ using namespace types; OnlineASRProcessor::OnlineASRProcessor(const ASR *asr) : asr(asr) {} void OnlineASRProcessor::insertAudioChunk(std::span audio) { + std::scoped_lock lock(audioMutex); audioBuffer.insert(audioBuffer.end(), audio.begin(), audio.end()); } ProcessResult OnlineASRProcessor::processIter(const DecodingOptions &options) { - std::vector res = asr->transcribe(audioBuffer, options); + std::vector snapshot; + + { + std::scoped_lock lock(audioMutex); + snapshot = audioBuffer; + } + + std::vector res = asr->transcribe(snapshot, options); std::vector tsw; for (const auto &segment : res) { @@ -28,20 +36,19 @@ ProcessResult OnlineASRProcessor::processIter(const DecodingOptions &options) { this->committed.insert(this->committed.end(), flushed.begin(), flushed.end()); constexpr int32_t chunkThresholdSec = 15; - if (static_cast(audioBuffer.size()) / - OnlineASRProcessor::kSamplingRate > + if (static_cast(snapshot.size()) / OnlineASRProcessor::kSamplingRate > chunkThresholdSec) { chunkCompletedSegment(res); } - auto move_to_vector = [](auto& container) { - return std::vector(std::make_move_iterator(container.begin()), - std::make_move_iterator(container.end())); + auto move_to_vector = [](auto &container) { + return std::vector(std::make_move_iterator(container.begin()), + std::make_move_iterator(container.end())); }; std::deque nonCommittedWords = this->hypothesisBuffer.complete(); - return { move_to_vector(flushed), move_to_vector(nonCommittedWords) }; + return {move_to_vector(flushed), move_to_vector(nonCommittedWords)}; } void OnlineASRProcessor::chunkCompletedSegment(std::span res) { @@ -74,6 +81,7 @@ void OnlineASRProcessor::chunkAt(float time) { auto startIndex = static_cast(cutSeconds * OnlineASRProcessor::kSamplingRate); + std::scoped_lock lock(audioMutex); if (startIndex < audioBuffer.size()) { audioBuffer.erase(audioBuffer.begin(), audioBuffer.begin() + startIndex); } else { @@ -88,8 +96,11 @@ std::vector OnlineASRProcessor::finish() { std::vector buffer(std::make_move_iterator(bufferDeq.begin()), std::make_move_iterator(bufferDeq.end())); - this->bufferTimeOffset += static_cast(audioBuffer.size()) / - OnlineASRProcessor::kSamplingRate; + { + std::scoped_lock lock(audioMutex); + this->bufferTimeOffset += static_cast(audioBuffer.size()) / + OnlineASRProcessor::kSamplingRate; + } return buffer; } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h index 98944bdbe..2fa0bdd8b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h @@ -5,6 +5,8 @@ #include "rnexecutorch/models/speech_to_text/types/ProcessResult.h" #include "rnexecutorch/models/speech_to_text/types/Word.h" +#include + namespace rnexecutorch::models::speech_to_text::stream { class OnlineASRProcessor { @@ -16,6 +18,7 @@ class OnlineASRProcessor { std::vector finish(); std::vector audioBuffer; + mutable std::mutex audioMutex; private: const asr::ASR *asr;