|
1 | 1 | #include "mlxsharp/api.h" |
2 | | -#include "mlxsharp/llm_model_runner.h" |
3 | 2 |
|
4 | 3 | #include <algorithm> |
5 | 4 | #include <atomic> |
@@ -44,8 +43,6 @@ struct mlxsharp_session { |
44 | 43 | std::string chat_model; |
45 | 44 | std::string embedding_model; |
46 | 45 | std::string image_model; |
47 | | - std::unique_ptr<mlxsharp::llm::ModelRunner> model_runner; |
48 | | - |
49 | 46 | mlxsharp_session(mlxsharp_context_t* ctx, std::string chat, std::string embed, std::string image) |
50 | 47 | : context(ctx), |
51 | 48 | chat_model(std::move(chat)), |
@@ -565,104 +562,6 @@ void mlxsharp_free_buffer(unsigned char* data) { |
565 | 562 | std::free(data); |
566 | 563 | } |
567 | 564 |
|
568 | | -int mlxsharp_session_load_model( |
569 | | - void* session_ptr, |
570 | | - const char* model_directory, |
571 | | - const char* tokenizer_path) { |
572 | | - if (session_ptr == nullptr) { |
573 | | - return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Session pointer is null."); |
574 | | - } |
575 | | - |
576 | | - if (model_directory == nullptr || tokenizer_path == nullptr) { |
577 | | - return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Model directory or tokenizer path is null."); |
578 | | - } |
579 | | - |
580 | | - auto* session = static_cast<mlxsharp_session_t*>(session_ptr); |
581 | | - |
582 | | - return invoke([&]() -> int { |
583 | | - auto model = mlxsharp::llm::ModelRunner::Create(model_directory, tokenizer_path); |
584 | | - session->model_runner = std::move(model); |
585 | | - return MLXSHARP_STATUS_SUCCESS; |
586 | | - }); |
587 | | -} |
588 | | - |
589 | | -int mlxsharp_session_generate_tokens( |
590 | | - void* session_ptr, |
591 | | - const int32_t* prompt_tokens, |
592 | | - size_t prompt_token_count, |
593 | | - const mlxsharp_generation_options* options, |
594 | | - mlxsharp_token_buffer* output_tokens, |
595 | | - mlx_usage* usage) { |
596 | | - if (session_ptr == nullptr) { |
597 | | - return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Session pointer is null."); |
598 | | - } |
599 | | - |
600 | | - if (output_tokens == nullptr) { |
601 | | - return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, kNullOutParameter); |
602 | | - } |
603 | | - |
604 | | - output_tokens->tokens = nullptr; |
605 | | - output_tokens->length = 0; |
606 | | - |
607 | | - auto* session = static_cast<mlxsharp_session_t*>(session_ptr); |
608 | | - |
609 | | - if (session->model_runner == nullptr) { |
610 | | - return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Model is not loaded. Call mlxsharp_session_load_model first."); |
611 | | - } |
612 | | - |
613 | | - if (prompt_token_count > 0 && prompt_tokens == nullptr) { |
614 | | - return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Prompt tokens pointer is null."); |
615 | | - } |
616 | | - |
617 | | - if (options == nullptr) { |
618 | | - return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Generation options pointer is null."); |
619 | | - } |
620 | | - |
621 | | - return invoke([&]() -> int { |
622 | | - std::vector<int32_t> prompt; |
623 | | - prompt.reserve(prompt_token_count); |
624 | | - for (size_t i = 0; i < prompt_token_count; ++i) { |
625 | | - prompt.push_back(prompt_tokens[i]); |
626 | | - } |
627 | | - |
628 | | - mlxsharp::llm::GenerationOptions native_options{ |
629 | | - options->max_tokens, |
630 | | - options->temperature, |
631 | | - options->top_p, |
632 | | - options->top_k, |
633 | | - }; |
634 | | - |
635 | | - auto generated = session->model_runner->Generate(prompt, native_options); |
636 | | - output_tokens->length = generated.size(); |
637 | | - |
638 | | - if (generated.empty()) { |
639 | | - assign_usage(usage, static_cast<int>(prompt_token_count), 0); |
640 | | - return MLXSHARP_STATUS_SUCCESS; |
641 | | - } |
642 | | - |
643 | | - auto* buffer = static_cast<int32_t*>(std::malloc(generated.size() * sizeof(int32_t))); |
644 | | - if (buffer == nullptr) { |
645 | | - return set_error(MLXSHARP_STATUS_OUT_OF_MEMORY, "Failed to allocate output token buffer."); |
646 | | - } |
647 | | - |
648 | | - std::memcpy(buffer, generated.data(), generated.size() * sizeof(int32_t)); |
649 | | - output_tokens->tokens = buffer; |
650 | | - |
651 | | - assign_usage(usage, static_cast<int>(prompt_token_count), static_cast<int>(generated.size())); |
652 | | - return MLXSHARP_STATUS_SUCCESS; |
653 | | - }); |
654 | | -} |
655 | | - |
656 | | -void mlxsharp_release_tokens(mlxsharp_token_buffer* buffer) { |
657 | | - if (buffer == nullptr || buffer->tokens == nullptr) { |
658 | | - return; |
659 | | - } |
660 | | - |
661 | | - std::free(buffer->tokens); |
662 | | - buffer->tokens = nullptr; |
663 | | - buffer->length = 0; |
664 | | -} |
665 | | - |
666 | 565 | void mlxsharp_release_session(void* session_ptr) { |
667 | 566 | if (session_ptr == nullptr) { |
668 | 567 | return; |
|
0 commit comments