diff --git a/src/main.cpp b/src/main.cpp index a2c46cb..6f36bc4 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -296,48 +296,65 @@ int whisper_ctx_init_openvino_encoder_wrapper(struct whisper_context_wrapper * c return whisper_ctx_init_openvino_encoder(ctx->ptr, model_path, device, cache_dir); } -class WhisperFullParamsWrapper : public whisper_full_params { - std::string initial_prompt_str; - std::string suppress_regex_str; +struct WhisperFullParamsWrapper : public whisper_full_params { + std::string initial_prompt_str; + std::string suppress_regex_str; public: - py::function py_progress_callback; - WhisperFullParamsWrapper(const whisper_full_params& params = whisper_full_params()) - : whisper_full_params(params), - initial_prompt_str(params.initial_prompt ? params.initial_prompt : ""), - suppress_regex_str(params.suppress_regex ? params.suppress_regex : "") { - initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str(); - suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str(); - // progress callback - progress_callback_user_data = this; - progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) { - auto* self = static_cast(user_data); - if(self && self->print_progress){ - if (self->py_progress_callback) { - // call the python callback - py::gil_scoped_acquire gil; - self->py_progress_callback(progress); // Call Python callback - } - else { - fprintf(stderr, "Progress: %3d%%\n", progress); - } // Default message - } - } ; - } - - WhisperFullParamsWrapper(const WhisperFullParamsWrapper& other) - : WhisperFullParamsWrapper(static_cast(other)) {} - - void set_initial_prompt(const std::string& prompt) { - initial_prompt_str = prompt; - initial_prompt = initial_prompt_str.c_str(); - } - - void set_suppress_regex(const std::string& regex) { - suppress_regex_str = regex; - suppress_regex = suppress_regex_str.c_str(); - } + py::function py_progress_callback; + WhisperFullParamsWrapper(const whisper_full_params& params = whisper_full_params()) + : whisper_full_params(params), + initial_prompt_str(params.initial_prompt ? params.initial_prompt : ""), + suppress_regex_str(params.suppress_regex ? params.suppress_regex : "") { + initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str(); + suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str(); + // progress callback + progress_callback_user_data = this; + progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) { + auto* self = static_cast(user_data); + if(self && self->print_progress){ + if (self->py_progress_callback) { + // call the python callback + py::gil_scoped_acquire gil; + self->py_progress_callback(progress); // Call Python callback + } + else { + fprintf(stderr, "Progress: %3d%%\n", progress); + } // Default message + } + } ; + } + WhisperFullParamsWrapper(const WhisperFullParamsWrapper& other) + : whisper_full_params(static_cast(other)), // Copy base struct + initial_prompt_str(other.initial_prompt_str), + suppress_regex_str(other.suppress_regex_str), + py_progress_callback(other.py_progress_callback) { + // Reset pointers to new string copies + initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str(); + suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str(); + progress_callback_user_data = this; + progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) { + auto* self = static_cast(user_data); + if(self && self->print_progress){ + if (self->py_progress_callback) { + // call the python callback + py::gil_scoped_acquire gil; + self->py_progress_callback(progress); // Call Python callback + } + else { + fprintf(stderr, "Progress: %3d%%\n", progress); + } // Default message + } + }; + } + void set_initial_prompt(const std::string& prompt) { + initial_prompt_str = prompt; + initial_prompt = initial_prompt_str.c_str(); + } + void set_suppress_regex(const std::string& regex) { + suppress_regex_str = regex; + suppress_regex = suppress_regex_str.c_str(); + } }; - WhisperFullParamsWrapper whisper_full_default_params_wrapper(enum whisper_sampling_strategy strategy) { return WhisperFullParamsWrapper(whisper_full_default_params(strategy)); }