Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 57 additions & 40 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,48 +296,65 @@ int whisper_ctx_init_openvino_encoder_wrapper(struct whisper_context_wrapper * c
return whisper_ctx_init_openvino_encoder(ctx->ptr, model_path, device, cache_dir);
}

class WhisperFullParamsWrapper : public whisper_full_params {
std::string initial_prompt_str;
std::string suppress_regex_str;
struct WhisperFullParamsWrapper : public whisper_full_params {
std::string initial_prompt_str;
std::string suppress_regex_str;
public:
py::function py_progress_callback;
WhisperFullParamsWrapper(const whisper_full_params& params = whisper_full_params())
: whisper_full_params(params),
initial_prompt_str(params.initial_prompt ? params.initial_prompt : ""),
suppress_regex_str(params.suppress_regex ? params.suppress_regex : "") {
initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str();
suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str();
// progress callback
progress_callback_user_data = this;
progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
auto* self = static_cast<WhisperFullParamsWrapper*>(user_data);
if(self && self->print_progress){
if (self->py_progress_callback) {
// call the python callback
py::gil_scoped_acquire gil;
self->py_progress_callback(progress); // Call Python callback
}
else {
fprintf(stderr, "Progress: %3d%%\n", progress);
} // Default message
}
} ;
}

WhisperFullParamsWrapper(const WhisperFullParamsWrapper& other)
: WhisperFullParamsWrapper(static_cast<const whisper_full_params&>(other)) {}

void set_initial_prompt(const std::string& prompt) {
initial_prompt_str = prompt;
initial_prompt = initial_prompt_str.c_str();
}

void set_suppress_regex(const std::string& regex) {
suppress_regex_str = regex;
suppress_regex = suppress_regex_str.c_str();
}
py::function py_progress_callback;
WhisperFullParamsWrapper(const whisper_full_params& params = whisper_full_params())
: whisper_full_params(params),
initial_prompt_str(params.initial_prompt ? params.initial_prompt : ""),
suppress_regex_str(params.suppress_regex ? params.suppress_regex : "") {
initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str();
suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str();
// progress callback
progress_callback_user_data = this;
progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
auto* self = static_cast<WhisperFullParamsWrapper*>(user_data);
if(self && self->print_progress){
if (self->py_progress_callback) {
// call the python callback
py::gil_scoped_acquire gil;
self->py_progress_callback(progress); // Call Python callback
}
else {
fprintf(stderr, "Progress: %3d%%\n", progress);
} // Default message
}
} ;
}
WhisperFullParamsWrapper(const WhisperFullParamsWrapper& other)
: whisper_full_params(static_cast<whisper_full_params>(other)), // Copy base struct
initial_prompt_str(other.initial_prompt_str),
suppress_regex_str(other.suppress_regex_str),
py_progress_callback(other.py_progress_callback) {
// Reset pointers to new string copies
initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str();
suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str();
progress_callback_user_data = this;
progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
auto* self = static_cast<WhisperFullParamsWrapper*>(user_data);
if(self && self->print_progress){
if (self->py_progress_callback) {
// call the python callback
py::gil_scoped_acquire gil;
self->py_progress_callback(progress); // Call Python callback
}
else {
fprintf(stderr, "Progress: %3d%%\n", progress);
} // Default message
}
};
}
void set_initial_prompt(const std::string& prompt) {
initial_prompt_str = prompt;
initial_prompt = initial_prompt_str.c_str();
}
void set_suppress_regex(const std::string& regex) {
suppress_regex_str = regex;
suppress_regex = suppress_regex_str.c_str();
}
};

WhisperFullParamsWrapper whisper_full_default_params_wrapper(enum whisper_sampling_strategy strategy) {
return WhisperFullParamsWrapper(whisper_full_default_params(strategy));
}
Expand Down
Loading