From 050a660fa4e1bed4a583a1513e0704a4bdb064a9 Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Fri, 31 Jan 2025 16:12:31 +0000 Subject: [PATCH 01/26] wip --- internal/manager/provider_setting.go | 24 ++++- internal/server/web/proxy/proxy.go | 3 + internal/server/web/proxy/x_custom.go | 135 ++++++++++++++++++++++++++ 3 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 internal/server/web/proxy/x_custom.go diff --git a/internal/manager/provider_setting.go b/internal/manager/provider_setting.go index d02e699..5fb9762 100644 --- a/internal/manager/provider_setting.go +++ b/internal/manager/provider_setting.go @@ -3,6 +3,7 @@ package manager import ( "encoding/json" "fmt" + "slices" "strconv" "strings" "time" @@ -39,6 +40,8 @@ type ProviderSettingsManager struct { Encryptor Encryptor } +var nativelySupportedProviders = []string{"openai", "anthropic", "azure", "vllm", "deepinfra", "bedrock", "xCustom"} + func NewProviderSettingsManager(s ProviderSettingsStorage, cache ProviderSettingsCache, encryptor Encryptor) *ProviderSettingsManager { return &ProviderSettingsManager{ Storage: s, @@ -48,7 +51,7 @@ func NewProviderSettingsManager(s ProviderSettingsStorage, cache ProviderSetting } func isProviderNativelySupported(provider string) bool { - return provider == "openai" || provider == "anthropic" || provider == "azure" || provider == "vllm" || provider == "deepinfra" || provider == "bedrock" + return slices.Contains(nativelySupportedProviders, provider) } func findMissingAuthParams(providerName string, params map[string]string) string { @@ -99,6 +102,25 @@ func findMissingAuthParams(providerName string, params map[string]string) string } } + if providerName == "xCustom" { + val := params["apikey"] + if len(val) == 0 { + missingFields = append(missingFields, "apikey") + } + val = params["endpoint"] + if len(val) == 0 { + missingFields = append(missingFields, "endpoint") + } + val = params["header"] + if len(val) == 0 { + missingFields = append(missingFields, "header") + } + val = params["maskAuth"] + if !strings.Contains(val, "{{apikey}}") { + missingFields = append(missingFields, "maskAuth") + } + } + return strings.Join(missingFields, ",") } diff --git a/internal/server/web/proxy/proxy.go b/internal/server/web/proxy/proxy.go index e6078f4..d41c4e1 100644 --- a/internal/server/web/proxy/proxy.go +++ b/internal/server/web/proxy/proxy.go @@ -220,6 +220,9 @@ func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyMan router.POST("/api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id/cancel", getCancelVectorStoreFileBatchHandler(prod, client)) router.GET("/api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id/files", getListVectorStoreFileBatchFilesHandler(prod, client)) + // codio xCustom + router.Any("/api/providers/xCustom/:x_provider_id/*wildcard", getXCustomHandler(prod, client)) + srv := &http.Server{ Addr: ":8002", Handler: router, diff --git a/internal/server/web/proxy/x_custom.go b/internal/server/web/proxy/x_custom.go new file mode 100644 index 0000000..1648a0d --- /dev/null +++ b/internal/server/web/proxy/x_custom.go @@ -0,0 +1,135 @@ +package proxy + +import ( + "bytes" + "context" + "fmt" + "github.com/bricks-cloud/bricksllm/internal/provider" + "github.com/bricks-cloud/bricksllm/internal/util" + "github.com/gin-gonic/gin" + "github.com/go-viper/mapstructure/v2" + "io" + "net/http" + "net/http/httputil" + "strings" +) + +type XCustomSettings struct { + Apikey string `json:"apikey"` + Endpoint string `json:"endpoint"` + Header string `json:"header"` + MaskAuth string `json:"maskAuth"` +} + +func getXCustomHandler(prod bool, client http.Client) gin.HandlerFunc { + return func(c *gin.Context) { + logWithCid := util.GetLogFromCtx(c) + ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) + defer cancel() + + providerId := c.Param("x_provider_id") + rawProviderSettings, exists := c.Get("settings") + if !exists { + //fmt.Println("[BricksLLM] no settings found") + c.JSON(http.StatusInternalServerError, "[BricksLLM] no settings found") + return + } + settings, ok := rawProviderSettings.([]*provider.Setting) + if !ok { + //fmt.Println("[BricksLLM] no settings found") + c.JSON(http.StatusInternalServerError, "[BricksLLM] no settings found") + return + } + var providerSetting *provider.Setting + for _, setting := range settings { + if setting.Id == providerId { + providerSetting = setting + } + } + if providerSetting == nil { + c.JSON(http.StatusInternalServerError, "[BricksLLM] no settings found") + return + } + + var setting *XCustomSettings + err := mapstructure.Decode(providerSetting.Setting, &setting) + if err != nil { + //logError(logWithCid, "error when unmarshalling settings", prod, err) + c.JSON(http.StatusInternalServerError, "[BricksLLM] no settings found") + return + } + authHeaderVal := strings.Replace(setting.MaskAuth, "{{apikey}}", setting.Apikey, -1) + wildcard := c.Param("wildcard") + targetUrl := fmt.Sprintf("%s%s", setting.Endpoint, wildcard) + + body, err := io.ReadAll(c.Request.Body) + if err != nil { + logError(logWithCid, "error when reading request body", prod, err) + return + } + + req, err := http.NewRequestWithContext(ctx, c.Request.Method, targetUrl, io.NopCloser(bytes.NewReader(body))) + if err != nil { + logError(logWithCid, "error when creating custom provider http request", prod, err) + JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create custom provider http request") + return + } + + copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent")) + // remove connection ??? + req.Header.Del("Connection") + req.Header.Del("Authorization") + req.Header.Del("Api-Key") + + req.Header.Set(setting.Header, authHeaderVal) + + dumpRequest(req) + + res, err := client.Do(req) + dumpResponse(res) + if err != nil { + //telemetry.Incr("bricksllm.proxy.get_custom_provider_handler.http_client_error", tags, 1) + //logError(logWithCid, "error when sending custom provider request", prod, err) + JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to send custom provider request") + return + } + defer res.Body.Close() + + responseBody, err := io.ReadAll(res.Body) + if err != nil { + fmt.Printf("[BricksLLM] failed to read custom provider response body") + } + fmt.Println(string(responseBody)) + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + } +} + +func dumpRequest(r *http.Request) { + if r == nil { + fmt.Println("[BricksLLM] dumpRequest called with nil request") + return + } + dump, err := httputil.DumpRequest(r, true) + if err != nil { + fmt.Println("error dumping request", err) + return + } + fmt.Println("-----------DUMP REQUEST----------") + fmt.Println(string(dump)) + fmt.Println("===========DUMP REQUEST =============") +} + +func dumpResponse(r *http.Response) { + if r == nil { + fmt.Println("[BricksLLM] dumpResponse called with nil request") + return + } + dump, err := httputil.DumpResponse(r, true) + if err != nil { + fmt.Println("error dumping response", err) + return + } + fmt.Println("-----------DUMP RESPONSE ========---") + fmt.Println(string(dump)) + fmt.Println("============DUMP RESPONSE ============") +} From 97a801b2ceb1d843732b974cd53033edb1bd856b Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Thu, 6 Feb 2025 18:03:31 +0000 Subject: [PATCH 02/26] xcustom --- internal/authenticator/authenticator.go | 47 ++++++++- internal/provider/xcustom/xcustom.go | 33 ++++++ internal/server/web/proxy/middleware.go | 5 +- internal/server/web/proxy/proxy.go | 2 +- internal/server/web/proxy/x_custom.go | 127 +++++++----------------- 5 files changed, 116 insertions(+), 98 deletions(-) create mode 100644 internal/provider/xcustom/xcustom.go diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go index bec413c..bc915a5 100644 --- a/internal/authenticator/authenticator.go +++ b/internal/authenticator/authenticator.go @@ -3,6 +3,8 @@ package auth import ( "errors" "fmt" + "github.com/bricks-cloud/bricksllm/internal/provider/xcustom" + "github.com/go-viper/mapstructure/v2" "math/rand" "net/http" "strconv" @@ -79,6 +81,33 @@ func getApiKey(req *http.Request) (string, error) { return "", internal_errors.NewAuthError("api key not found in header") } +func (a *Authenticator) getApiKeyByXCustomProvider(req *http.Request, xCustomProviderId string) (key string, pSettings []*provider.Setting, err error) { + providerSetting, err := a.psm.GetSettingViaCache(xCustomProviderId) + if err != nil { + return + } + pSettings = []*provider.Setting{providerSetting} + setting := providerSetting.Setting + if setting == nil { + err = internal_errors.NewAuthError("provider settings not found") + return + } + fmt.Printf("xCustomProviderId: %s, setting: %+v\n", xCustomProviderId, setting) + var xCustomSetting *xcustom.XCustomSettings + err = mapstructure.Decode(setting, &xCustomSetting) + if err != nil { + err = internal_errors.NewAuthError("provider settings error") + return + } + header := req.Header.Get(xCustomSetting.Header) + key, err = xcustom.ExtractBricksKey(header, xCustomSetting.MaskAuth) + if err != nil { + err = internal_errors.NewAuthError("provider settings error") + return + } + return +} + func rewriteHttpAuthHeader(req *http.Request, setting *provider.Setting) error { uri := req.URL.RequestURI() if strings.HasPrefix(uri, "/api/routes") { @@ -204,8 +233,15 @@ func anonymize(input string) string { return string(input[0:5]) + "**********************************************" } -func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.ResponseKey, []*provider.Setting, error) { - raw, err := getApiKey(req) +func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, xCustomProviderId string) (*key.ResponseKey, []*provider.Setting, error) { + var raw string + var err error + var settings []*provider.Setting + if xcustom.IsXCustomRequest(req) { + raw, settings, err = a.getApiKeyByXCustomProvider(req, xCustomProviderId) + } else { + raw, err = getApiKey(req) + } if err != nil { return nil, nil, err } @@ -238,6 +274,13 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.Respons return nil, nil, internal_errors.NewAuthError(fmt.Sprintf("key %s has been revoked", anonymize(raw))) } + if xcustom.IsXCustomRequest(req) { + authHeaderK := settings[0].GetParam("header") + authHeaderV := strings.Replace(settings[0].GetParam("maskAuth"), "{{apikey}}", settings[0].GetParam("apikey"), -1) + req.Header.Set(authHeaderK, authHeaderV) + return key, settings, nil + } + if strings.HasPrefix(req.URL.Path, "/api/routes") { err = a.canKeyAccessCustomRoute(req.URL.Path, key.KeyId) if err != nil { diff --git a/internal/provider/xcustom/xcustom.go b/internal/provider/xcustom/xcustom.go new file mode 100644 index 0000000..2bf1965 --- /dev/null +++ b/internal/provider/xcustom/xcustom.go @@ -0,0 +1,33 @@ +package xcustom + +import ( + "fmt" + "net/http" + "regexp" + "strings" +) + +type XCustomSettings struct { + Apikey string `json:"apikey"` + Endpoint string `json:"endpoint"` + Header string `json:"header"` + MaskAuth string `json:"maskAuth"` +} + +const XProviderIdParam = "x_provider_id" + +func IsXCustomRequest(req *http.Request) bool { + return strings.HasPrefix(req.URL.RequestURI(), "/api/providers/xCustom/") +} + +func ExtractBricksKey(header, mask string) (key string, err error) { + regexStr := strings.Replace(mask, "{{apikey}}", "(?P.*)", -1) + regex := regexp.MustCompile(regexStr) + matches := regex.FindStringSubmatch(header) + if len(matches) < 2 { + err = fmt.Errorf("unable to extract bricks key") + return + } + key = strings.TrimSpace(matches[1]) + return +} diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go index 6e2a75e..5a215f5 100644 --- a/internal/server/web/proxy/middleware.go +++ b/internal/server/web/proxy/middleware.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/bricks-cloud/bricksllm/internal/provider/xcustom" "io" "net/http" "strconv" @@ -67,7 +68,7 @@ type deepinfraEstimator interface { } type authenticator interface { - AuthenticateHttpRequest(req *http.Request) (*key.ResponseKey, []*provider.Setting, error) + AuthenticateHttpRequest(req *http.Request, xCustomProviderId string) (*key.ResponseKey, []*provider.Setting, error) } type validator interface { @@ -304,7 +305,7 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag return } - kc, settings, err := a.AuthenticateHttpRequest(c.Request) + kc, settings, err := a.AuthenticateHttpRequest(c.Request, c.Param(xcustom.XProviderIdParam)) enrichedEvent.Key = kc _, ok := err.(notAuthorizedError) if ok { diff --git a/internal/server/web/proxy/proxy.go b/internal/server/web/proxy/proxy.go index d41c4e1..8bf70b9 100644 --- a/internal/server/web/proxy/proxy.go +++ b/internal/server/web/proxy/proxy.go @@ -221,7 +221,7 @@ func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyMan router.GET("/api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id/files", getListVectorStoreFileBatchFilesHandler(prod, client)) // codio xCustom - router.Any("/api/providers/xCustom/:x_provider_id/*wildcard", getXCustomHandler(prod, client)) + router.Any("/api/providers/xCustom/:x_provider_id/*wildcard", getXCustomHandler(prod)) srv := &http.Server{ Addr: ":8002", diff --git a/internal/server/web/proxy/x_custom.go b/internal/server/web/proxy/x_custom.go index 1648a0d..6d15cdd 100644 --- a/internal/server/web/proxy/x_custom.go +++ b/internal/server/web/proxy/x_custom.go @@ -1,43 +1,44 @@ package proxy import ( - "bytes" "context" + "errors" "fmt" "github.com/bricks-cloud/bricksllm/internal/provider" + "github.com/bricks-cloud/bricksllm/internal/provider/xcustom" + "github.com/bricks-cloud/bricksllm/internal/telemetry" "github.com/bricks-cloud/bricksllm/internal/util" "github.com/gin-gonic/gin" - "github.com/go-viper/mapstructure/v2" - "io" "net/http" "net/http/httputil" + "net/url" "strings" ) -type XCustomSettings struct { - Apikey string `json:"apikey"` - Endpoint string `json:"endpoint"` - Header string `json:"header"` - MaskAuth string `json:"maskAuth"` -} - -func getXCustomHandler(prod bool, client http.Client) gin.HandlerFunc { +func getXCustomHandler(prod bool) gin.HandlerFunc { return func(c *gin.Context) { - logWithCid := util.GetLogFromCtx(c) + log := util.GetLogFromCtx(c) + telemetry.Incr("bricksllm.proxy.get_x_custom_handler.requests", nil, 1) + + if c == nil || c.Request == nil { + JSON(c, http.StatusInternalServerError, "[BricksLLM] context is empty") + return + } + ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) defer cancel() - providerId := c.Param("x_provider_id") + providerId := c.Param(xcustom.XProviderIdParam) rawProviderSettings, exists := c.Get("settings") if !exists { - //fmt.Println("[BricksLLM] no settings found") + logError(log, "error provider setting", prod, errors.New("provider setting not found")) c.JSON(http.StatusInternalServerError, "[BricksLLM] no settings found") return } settings, ok := rawProviderSettings.([]*provider.Setting) if !ok { - //fmt.Println("[BricksLLM] no settings found") - c.JSON(http.StatusInternalServerError, "[BricksLLM] no settings found") + logError(log, "error provider setting", prod, errors.New("incorrect setting")) + c.JSON(http.StatusInternalServerError, "[BricksLLM] incorrect provider setting") return } var providerSetting *provider.Setting @@ -47,89 +48,29 @@ func getXCustomHandler(prod bool, client http.Client) gin.HandlerFunc { } } if providerSetting == nil { + logError(log, "error provider setting", prod, errors.New("provider setting not found")) c.JSON(http.StatusInternalServerError, "[BricksLLM] no settings found") return } - - var setting *XCustomSettings - err := mapstructure.Decode(providerSetting.Setting, &setting) - if err != nil { - //logError(logWithCid, "error when unmarshalling settings", prod, err) - c.JSON(http.StatusInternalServerError, "[BricksLLM] no settings found") - return - } - authHeaderVal := strings.Replace(setting.MaskAuth, "{{apikey}}", setting.Apikey, -1) wildcard := c.Param("wildcard") - targetUrl := fmt.Sprintf("%s%s", setting.Endpoint, wildcard) - - body, err := io.ReadAll(c.Request.Body) - if err != nil { - logError(logWithCid, "error when reading request body", prod, err) - return - } - - req, err := http.NewRequestWithContext(ctx, c.Request.Method, targetUrl, io.NopCloser(bytes.NewReader(body))) - if err != nil { - logError(logWithCid, "error when creating custom provider http request", prod, err) - JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create custom provider http request") + endpoint := strings.TrimSuffix(providerSetting.GetParam("endpoint"), "/") + targetUrl := fmt.Sprintf("%s%s", endpoint, wildcard) + target, e := url.Parse(targetUrl) + if e != nil { + logError(log, "error parsing target url", prod, e) + c.JSON(http.StatusInternalServerError, "[BricksLLM] invalid endpoint") return } - - copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent")) - // remove connection ??? - req.Header.Del("Connection") - req.Header.Del("Authorization") - req.Header.Del("Api-Key") - - req.Header.Set(setting.Header, authHeaderVal) - - dumpRequest(req) - - res, err := client.Do(req) - dumpResponse(res) - if err != nil { - //telemetry.Incr("bricksllm.proxy.get_custom_provider_handler.http_client_error", tags, 1) - //logError(logWithCid, "error when sending custom provider request", prod, err) - JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to send custom provider request") - return - } - defer res.Body.Close() - - responseBody, err := io.ReadAll(res.Body) - if err != nil { - fmt.Printf("[BricksLLM] failed to read custom provider response body") + proxy := &httputil.ReverseProxy{ + Director: func(r *http.Request) { + r.URL.Scheme = target.Scheme + r.URL.Host = target.Host + r.URL.Path, r.URL.RawPath = target.Path, target.RawPath + r.RequestURI = target.RequestURI() + r.Host = target.Host + r = r.WithContext(ctx) + }, } - fmt.Println(string(responseBody)) - c.JSON(http.StatusOK, gin.H{"status": "ok"}) - } -} - -func dumpRequest(r *http.Request) { - if r == nil { - fmt.Println("[BricksLLM] dumpRequest called with nil request") - return - } - dump, err := httputil.DumpRequest(r, true) - if err != nil { - fmt.Println("error dumping request", err) - return - } - fmt.Println("-----------DUMP REQUEST----------") - fmt.Println(string(dump)) - fmt.Println("===========DUMP REQUEST =============") -} - -func dumpResponse(r *http.Response) { - if r == nil { - fmt.Println("[BricksLLM] dumpResponse called with nil request") - return - } - dump, err := httputil.DumpResponse(r, true) - if err != nil { - fmt.Println("error dumping response", err) - return + proxy.ServeHTTP(c.Writer, c.Request) } - fmt.Println("-----------DUMP RESPONSE ========---") - fmt.Println(string(dump)) - fmt.Println("============DUMP RESPONSE ============") } From c024be944d34298f6de4afbc12b639a426a09d31 Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Wed, 12 Feb 2025 10:50:35 +0000 Subject: [PATCH 03/26] naming --- internal/authenticator/authenticator.go | 4 ++-- internal/manager/provider_setting.go | 8 ++++---- internal/provider/xcustom/xcustom.go | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go index bc915a5..212264b 100644 --- a/internal/authenticator/authenticator.go +++ b/internal/authenticator/authenticator.go @@ -99,8 +99,8 @@ func (a *Authenticator) getApiKeyByXCustomProvider(req *http.Request, xCustomPro err = internal_errors.NewAuthError("provider settings error") return } - header := req.Header.Get(xCustomSetting.Header) - key, err = xcustom.ExtractBricksKey(header, xCustomSetting.MaskAuth) + header := req.Header.Get(xCustomSetting.AuthHeader) + key, err = xcustom.ExtractBricksKey(header, xCustomSetting.AuthTemplate) if err != nil { err = internal_errors.NewAuthError("provider settings error") return diff --git a/internal/manager/provider_setting.go b/internal/manager/provider_setting.go index 5fb9762..82b42b2 100644 --- a/internal/manager/provider_setting.go +++ b/internal/manager/provider_setting.go @@ -111,13 +111,13 @@ func findMissingAuthParams(providerName string, params map[string]string) string if len(val) == 0 { missingFields = append(missingFields, "endpoint") } - val = params["header"] + val = params["authHeader"] if len(val) == 0 { - missingFields = append(missingFields, "header") + missingFields = append(missingFields, "authHeader") } - val = params["maskAuth"] + val = params["authTemplate"] if !strings.Contains(val, "{{apikey}}") { - missingFields = append(missingFields, "maskAuth") + missingFields = append(missingFields, "authTemplate") } } diff --git a/internal/provider/xcustom/xcustom.go b/internal/provider/xcustom/xcustom.go index 2bf1965..3eb5ec0 100644 --- a/internal/provider/xcustom/xcustom.go +++ b/internal/provider/xcustom/xcustom.go @@ -8,10 +8,10 @@ import ( ) type XCustomSettings struct { - Apikey string `json:"apikey"` - Endpoint string `json:"endpoint"` - Header string `json:"header"` - MaskAuth string `json:"maskAuth"` + Apikey string `json:"apikey"` + Endpoint string `json:"endpoint"` + AuthHeader string `json:"authHeader"` + AuthTemplate string `json:"authTemplate"` } const XProviderIdParam = "x_provider_id" From 931cb2b08f37b73d857490c0f9a3dd6086fa384b Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Thu, 13 Feb 2025 21:00:53 +0000 Subject: [PATCH 04/26] reqLimit --- cmd/bricksllm/main.go | 14 +++++++-- internal/authenticator/authenticator.go | 1 - internal/errors/requests_limit_err.go | 17 +++++++++++ internal/key/key.go | 6 ++++ internal/recorder/recorder.go | 34 ++++++++++++--------- internal/storage/postgresql/key.go | 12 +++++++- internal/validator/validator.go | 40 ++++++++++++++++++++----- 7 files changed, 98 insertions(+), 26 deletions(-) create mode 100644 internal/errors/requests_limit_err.go diff --git a/cmd/bricksllm/main.go b/cmd/bricksllm/main.go index 388e4e9..fd8da73 100644 --- a/cmd/bricksllm/main.go +++ b/cmd/bricksllm/main.go @@ -272,6 +272,14 @@ func main() { log.Sugar().Fatalf("error connecting to keys redis storage: %v", err) } + requestsLimitRedisStorage := redis.NewClient(defaultRedisOption(cfg, 11)) + + ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := requestsLimitRedisStorage.Ping(ctx).Err(); err != nil { + log.Sugar().Fatalf("error connecting to requests limit redis storage: %v", err) + } + rateLimitCache := redisStorage.NewCache(rateLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) costLimitCache := redisStorage.NewCache(costLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) costStorage := redisStorage.NewStore(costRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) @@ -285,13 +293,13 @@ func main() { psCache := redisStorage.NewProviderSettingsCache(providerSettingsRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) keysCache := redisStorage.NewKeysCache(keysRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) + requestsLimitStorage := redisStorage.NewStore(requestsLimitRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) encryptor, err := encryptor.NewEncryptor(cfg.DecryptionEndpoint, cfg.EncryptionEndpoint, cfg.EnableEncrytion, cfg.EncryptionTimeout, cfg.Audience) if cfg.EnableEncrytion && err != nil { log.Sugar().Fatalf("error creating encryption client: %v", err) } - v := validator.NewValidator(costLimitCache, rateLimitCache, costStorage) - + v := validator.NewValidator(costLimitCache, rateLimitCache, costStorage, requestsLimitStorage) m := manager.NewManager(store, costLimitCache, rateLimitCache, accessCache, keysCache) krm := manager.NewReportingManager(costStorage, store, store, v) @@ -330,7 +338,7 @@ func main() { uv := validator.NewUserValidator(userCostLimitCache, userRateLimitCache, userCostStorage) - rec := recorder.NewRecorder(costStorage, userCostStorage, costLimitCache, userCostLimitCache, ce, store) + rec := recorder.NewRecorder(costStorage, userCostStorage, costLimitCache, userCostLimitCache, ce, store, requestsLimitStorage) rlm := manager.NewRateLimitManager(rateLimitCache, userRateLimitCache) a := auth.NewAuthenticator(psm, m, rm, store, encryptor) diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go index 212264b..2ebb0d4 100644 --- a/internal/authenticator/authenticator.go +++ b/internal/authenticator/authenticator.go @@ -92,7 +92,6 @@ func (a *Authenticator) getApiKeyByXCustomProvider(req *http.Request, xCustomPro err = internal_errors.NewAuthError("provider settings not found") return } - fmt.Printf("xCustomProviderId: %s, setting: %+v\n", xCustomProviderId, setting) var xCustomSetting *xcustom.XCustomSettings err = mapstructure.Decode(setting, &xCustomSetting) if err != nil { diff --git a/internal/errors/requests_limit_err.go b/internal/errors/requests_limit_err.go new file mode 100644 index 0000000..212e158 --- /dev/null +++ b/internal/errors/requests_limit_err.go @@ -0,0 +1,17 @@ +package errors + +type RequestsLimitError struct { + message string +} + +func NewRequestsLimitError(msg string) *RequestsLimitError { + return &RequestsLimitError{ + message: msg, + } +} + +func (rle *RequestsLimitError) Error() string { + return rle.message +} + +func (rle *RequestsLimitError) RequestsLimit() {} diff --git a/internal/key/key.go b/internal/key/key.go index 59c7b18..eca6a22 100644 --- a/internal/key/key.go +++ b/internal/key/key.go @@ -163,6 +163,7 @@ type RequestKey struct { CostLimitInUsdUnit TimeUnit `json:"costLimitInUsdUnit"` RateLimitOverTime int `json:"rateLimitOverTime"` RateLimitUnit TimeUnit `json:"rateLimitUnit"` + RequestsLimit int `json:"requestsLimit"` Ttl string `json:"ttl"` KeyRing string `json:"keyRing"` SettingId string `json:"settingId"` @@ -237,6 +238,10 @@ func (rk *RequestKey) Validate() error { invalid = append(invalid, "rateLimitOverTime") } + if rk.RequestsLimit < 0 { + invalid = append(invalid, "requestsLimit") + } + if len(rk.Ttl) != 0 { _, err := time.ParseDuration(rk.Ttl) if err != nil { @@ -317,6 +322,7 @@ type ResponseKey struct { CostLimitInUsdUnit TimeUnit `json:"costLimitInUsdUnit"` RateLimitOverTime int `json:"rateLimitOverTime"` RateLimitUnit TimeUnit `json:"rateLimitUnit"` + RequestsLimit int `json:"requestsLimit"` Ttl string `json:"ttl"` KeyRing string `json:"keyRing"` SettingId string `json:"settingId"` diff --git a/internal/recorder/recorder.go b/internal/recorder/recorder.go index a48f210..6e59d5d 100644 --- a/internal/recorder/recorder.go +++ b/internal/recorder/recorder.go @@ -6,12 +6,13 @@ import ( ) type Recorder struct { - s Store - c Cache - us Store - uc Cache - ce CostEstimator - es EventsStore + s Store + c Cache + us Store + uc Cache + ce CostEstimator + es EventsStore + reqLimitStore Store } type EventsStore interface { @@ -31,14 +32,15 @@ type CostEstimator interface { EstimateCompletionCost(model string, tks int) (float64, error) } -func NewRecorder(s, us Store, c, uc Cache, ce CostEstimator, es EventsStore) *Recorder { +func NewRecorder(s, us Store, c, uc Cache, ce CostEstimator, es EventsStore, reqLimitStore Store) *Recorder { return &Recorder{ - s: s, - c: c, - us: us, - uc: uc, - ce: ce, - es: es, + s: s, + c: c, + us: us, + uc: uc, + ce: ce, + es: es, + reqLimitStore: reqLimitStore, } } @@ -59,7 +61,11 @@ func (r *Recorder) RecordUserSpend(userId string, micros int64, costLimitUnit ke } func (r *Recorder) RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error { - err := r.s.IncrementCounter(keyId, micros) + err := r.reqLimitStore.IncrementCounter(keyId, 1) + if err != nil { + return err + } + err = r.s.IncrementCounter(keyId, micros) if err != nil { return err } diff --git a/internal/storage/postgresql/key.go b/internal/storage/postgresql/key.go index c129aa9..cf4085c 100644 --- a/internal/storage/postgresql/key.go +++ b/internal/storage/postgresql/key.go @@ -28,6 +28,7 @@ func (s *Store) CreateKeysTable() error { cost_limit_in_usd_unit VARCHAR(255), rate_limit_over_time INT, rate_limit_unit VARCHAR(255), + requests_limit INT, ttl VARCHAR(255), key_ring VARCHAR(255) )` @@ -56,7 +57,7 @@ func (s *Store) AlterKeysTable() error { END IF; END $$; - ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[], ADD COLUMN IF NOT EXISTS should_log_request BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS should_log_response BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS rotation_enabled BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS policy_id VARCHAR(255) NOT NULL DEFAULT '', ADD COLUMN IF NOT EXISTS is_key_not_hashed BOOLEAN NOT NULL DEFAULT FALSE; + ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[], ADD COLUMN IF NOT EXISTS should_log_request BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS should_log_response BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS rotation_enabled BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS policy_id VARCHAR(255) NOT NULL DEFAULT '', ADD COLUMN IF NOT EXISTS is_key_not_hashed BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS requests_limit INT; ` ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt) @@ -179,6 +180,7 @@ func (s *Store) GetKeys(tags, keyIds []string, provider string) ([]*key.Response &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, + &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -304,6 +306,7 @@ func (s *Store) GetKeysV2(tags, keyIds []string, revoked *bool, limit, offset in &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, + &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -383,6 +386,7 @@ func (s *Store) GetKeyByHash(hash string) (*key.ResponseKey, error) { &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, + &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -447,6 +451,7 @@ func (s *Store) GetKey(keyId string) (*key.ResponseKey, error) { &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, + &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -543,6 +548,7 @@ func (s *Store) GetSpentKeyRings(tags []string, order string, limit, offset int, &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, + &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -605,6 +611,7 @@ func (s *Store) GetAllKeys() ([]*key.ResponseKey, error) { &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, + &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -665,6 +672,7 @@ func (s *Store) GetUpdatedKeys(updatedAt int64) ([]*key.ResponseKey, error) { &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, + &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -843,6 +851,7 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, + &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -934,6 +943,7 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) { &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, + &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, diff --git a/internal/validator/validator.go b/internal/validator/validator.go index 7efdfb9..b6bfaac 100644 --- a/internal/validator/validator.go +++ b/internal/validator/validator.go @@ -21,21 +21,28 @@ type costLimitStorage interface { GetCounter(keyId string) (int64, error) } +type requestsLimitStorage interface { + GetCounter(keyId string) (int64, error) +} + type Validator struct { - clc costLimitCache - rlc rateLimitCache - cls costLimitStorage + clc costLimitCache + rlc rateLimitCache + cls costLimitStorage + rqls requestsLimitStorage } func NewValidator( clc costLimitCache, rlc rateLimitCache, cls costLimitStorage, + rqls requestsLimitStorage, ) *Validator { return &Validator{ - clc: clc, - rlc: rlc, - cls: cls, + clc: clc, + rlc: rlc, + cls: cls, + rqls: rqls, } } @@ -53,7 +60,12 @@ func (v *Validator) Validate(k *key.ResponseKey, promptCost float64) error { return internal_errors.NewExpirationError("api key expired", internal_errors.TtlExpiration) } - err := v.validateRateLimitOverTime(k.KeyId, k.RateLimitOverTime, k.RateLimitUnit) + err := v.validateRequestsLimit(k.KeyId, k.RequestsLimit) + if err != nil { + return err + } + + err = v.validateRateLimitOverTime(k.KeyId, k.RateLimitOverTime, k.RateLimitUnit) if err != nil { return err } @@ -136,3 +148,17 @@ func (v *Validator) validateCostLimit(keyId string, costLimit float64) error { return nil } + +func (v *Validator) validateRequestsLimit(keyId string, requestsLimit int) error { + if requestsLimit == 0 { + return nil + } + existingTotalRequests, err := v.rqls.GetCounter(keyId) + if err != nil { + return errors.New("failed to get total requests") + } + if existingTotalRequests >= int64(requestsLimit) { + return internal_errors.NewRequestsLimitError(fmt.Sprintf("total requests limit: %d, has been reached", requestsLimit)) + } + return nil +} From e6645045bc122fd00946a88ca8284e805ffbdf7b Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Mon, 17 Feb 2025 11:02:06 +0000 Subject: [PATCH 05/26] linked key --- internal/event/key_reporting.go | 7 ++++++- internal/manager/reporting.go | 6 +++--- internal/storage/postgresql/key.go | 10 +++++++--- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/internal/event/key_reporting.go b/internal/event/key_reporting.go index 9550861..1e0b3a2 100644 --- a/internal/event/key_reporting.go +++ b/internal/event/key_reporting.go @@ -47,8 +47,13 @@ type SpentKeyReportingRequest struct { Offset int `json:"offset"` } +type SpentKey struct { + KeyRing string `json:"keyRing"` + LinkedKeyId string `json:"linkedKeyId"` +} + type SpentKeyReportingResponse struct { - KeyRings []string `json:"keyRings"` + Keys []SpentKey `json:"keys"` } type UsageReportingRequest struct { diff --git a/internal/manager/reporting.go b/internal/manager/reporting.go index 068fcfc..f046bb2 100644 --- a/internal/manager/reporting.go +++ b/internal/manager/reporting.go @@ -14,7 +14,7 @@ type costStorage interface { type keyStorage interface { GetKey(keyId string) (*key.ResponseKey, error) - GetSpentKeyRings(tags []string, order string, limit, offset int, validator func(*key.ResponseKey) bool) ([]string, error) + GetSpentKeys(tags []string, order string, limit, offset int, validator func(*key.ResponseKey) bool) ([]event.SpentKey, error) } type keyValidator interface { @@ -160,12 +160,12 @@ func (rm *ReportingManager) GetSpentKeyReporting(r *event.SpentKeyReportingReque return true } - spentKeys, err := rm.ks.GetSpentKeyRings(r.Tags, r.Order, r.Limit, r.Offset, validator) + spentKeys, err := rm.ks.GetSpentKeys(r.Tags, r.Order, r.Limit, r.Offset, validator) if err != nil { return nil, err } return &event.SpentKeyReportingResponse{ - KeyRings: spentKeys, + Keys: spentKeys, }, nil } diff --git a/internal/storage/postgresql/key.go b/internal/storage/postgresql/key.go index cf4085c..b402afd 100644 --- a/internal/storage/postgresql/key.go +++ b/internal/storage/postgresql/key.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" + "github.com/bricks-cloud/bricksllm/internal/event" "strings" internal_errors "github.com/bricks-cloud/bricksllm/internal/errors" @@ -488,7 +489,7 @@ func (s *Store) GetKey(keyId string) (*key.ResponseKey, error) { return keys[0], nil } -func (s *Store) GetSpentKeyRings(tags []string, order string, limit, offset int, validator func(*key.ResponseKey) bool) ([]string, error) { +func (s *Store) GetSpentKeys(tags []string, order string, limit, offset int, validator func(*key.ResponseKey) bool) ([]event.SpentKey, error) { args := []any{} condition := "" @@ -529,7 +530,7 @@ func (s *Store) GetSpentKeyRings(tags []string, order string, limit, offset int, } defer rows.Close() - invalidKeyRings := []string{} + var invalidKeyRings []event.SpentKey for rows.Next() { var k key.ResponseKey var settingId sql.NullString @@ -575,7 +576,10 @@ func (s *Store) GetSpentKeyRings(tags []string, order string, limit, offset int, } if !validator(pk) { - invalidKeyRings = append(invalidKeyRings, pk.KeyRing) + invalidKeyRings = append(invalidKeyRings, event.SpentKey{ + KeyRing: pk.KeyRing, + LinkedKeyId: pk.KeyId, + }) } } From 05a75dab06898d3bddf44e25451239dab45f7a6f Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Mon, 17 Feb 2025 15:41:37 +0000 Subject: [PATCH 06/26] fix create key. refactor --- internal/key/key.go | 2 +- internal/storage/postgresql/key.go | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/internal/key/key.go b/internal/key/key.go index eca6a22..4a40baf 100644 --- a/internal/key/key.go +++ b/internal/key/key.go @@ -163,7 +163,6 @@ type RequestKey struct { CostLimitInUsdUnit TimeUnit `json:"costLimitInUsdUnit"` RateLimitOverTime int `json:"rateLimitOverTime"` RateLimitUnit TimeUnit `json:"rateLimitUnit"` - RequestsLimit int `json:"requestsLimit"` Ttl string `json:"ttl"` KeyRing string `json:"keyRing"` SettingId string `json:"settingId"` @@ -174,6 +173,7 @@ type RequestKey struct { RotationEnabled bool `json:"rotationEnabled"` PolicyId string `json:"policyId"` IsKeyNotHashed bool `json:"isKeyNotHashed"` + RequestsLimit int `json:"requestsLimit"` } func (rk *RequestKey) Validate() error { diff --git a/internal/storage/postgresql/key.go b/internal/storage/postgresql/key.go index b402afd..8c45017 100644 --- a/internal/storage/postgresql/key.go +++ b/internal/storage/postgresql/key.go @@ -29,7 +29,6 @@ func (s *Store) CreateKeysTable() error { cost_limit_in_usd_unit VARCHAR(255), rate_limit_over_time INT, rate_limit_unit VARCHAR(255), - requests_limit INT, ttl VARCHAR(255), key_ring VARCHAR(255) )` @@ -181,7 +180,6 @@ func (s *Store) GetKeys(tags, keyIds []string, provider string) ([]*key.Response &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, - &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -192,6 +190,7 @@ func (s *Store) GetKeys(tags, keyIds []string, provider string) ([]*key.Response &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { return nil, err } @@ -307,7 +306,6 @@ func (s *Store) GetKeysV2(tags, keyIds []string, revoked *bool, limit, offset in &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, - &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -318,6 +316,7 @@ func (s *Store) GetKeysV2(tags, keyIds []string, revoked *bool, limit, offset in &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { return nil, err } @@ -387,7 +386,6 @@ func (s *Store) GetKeyByHash(hash string) (*key.ResponseKey, error) { &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, - &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -398,6 +396,7 @@ func (s *Store) GetKeyByHash(hash string) (*key.ResponseKey, error) { &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ) if err != nil { @@ -452,7 +451,6 @@ func (s *Store) GetKey(keyId string) (*key.ResponseKey, error) { &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, - &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -463,6 +461,7 @@ func (s *Store) GetKey(keyId string) (*key.ResponseKey, error) { &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { return nil, err } @@ -549,7 +548,6 @@ func (s *Store) GetSpentKeys(tags []string, order string, limit, offset int, val &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, - &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -560,6 +558,7 @@ func (s *Store) GetSpentKeys(tags []string, order string, limit, offset int, val &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { return nil, err } @@ -615,7 +614,6 @@ func (s *Store) GetAllKeys() ([]*key.ResponseKey, error) { &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, - &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -626,6 +624,7 @@ func (s *Store) GetAllKeys() ([]*key.ResponseKey, error) { &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { return nil, err } @@ -676,7 +675,6 @@ func (s *Store) GetUpdatedKeys(updatedAt int64) ([]*key.ResponseKey, error) { &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, - &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -687,6 +685,7 @@ func (s *Store) GetUpdatedKeys(updatedAt int64) ([]*key.ResponseKey, error) { &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { return nil, err } @@ -855,7 +854,6 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, - &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -866,6 +864,7 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { if err == sql.ErrNoRows { return nil, internal_errors.NewNotFoundError(fmt.Sprintf("key not found for id: %s", id)) @@ -890,8 +889,8 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) { query := ` - INSERT INTO keys (name, created_at, updated_at, tags, revoked, key_id, key, revoked_reason, cost_limit_in_usd, cost_limit_in_usd_over_time, cost_limit_in_usd_unit, rate_limit_over_time, rate_limit_unit, ttl, key_ring, setting_id, allowed_paths, setting_ids, should_log_request, should_log_response, rotation_enabled, policy_id, is_key_not_hashed) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23) + INSERT INTO keys (name, created_at, updated_at, tags, revoked, key_id, key, revoked_reason, cost_limit_in_usd, cost_limit_in_usd_over_time, cost_limit_in_usd_unit, rate_limit_over_time, rate_limit_unit, ttl, key_ring, setting_id, allowed_paths, setting_ids, should_log_request, should_log_response, rotation_enabled, policy_id, is_key_not_hashed, requests_limit) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24) RETURNING *; ` @@ -924,6 +923,7 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) { rk.RotationEnabled, rk.PolicyId, rk.IsKeyNotHashed, + rk.RequestsLimit, } ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt) @@ -947,7 +947,6 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) { &k.CostLimitInUsdUnit, &k.RateLimitOverTime, &k.RateLimitUnit, - &k.RequestsLimit, &k.Ttl, &k.KeyRing, &settingId, @@ -958,6 +957,7 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) { &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { return nil, err } From fb3528a2e9d55e06dbe417beb631bb27a87edfb8 Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Tue, 18 Feb 2025 16:56:52 +0000 Subject: [PATCH 07/26] return empty --- internal/storage/postgresql/key.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/storage/postgresql/key.go b/internal/storage/postgresql/key.go index 8c45017..caf2c72 100644 --- a/internal/storage/postgresql/key.go +++ b/internal/storage/postgresql/key.go @@ -529,7 +529,7 @@ func (s *Store) GetSpentKeys(tags []string, order string, limit, offset int, val } defer rows.Close() - var invalidKeyRings []event.SpentKey + invalidKeyRings := []event.SpentKey{} for rows.Next() { var k key.ResponseKey var settingId sql.NullString From cf582447cfa3386824460157cc62659c1a2fe496 Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Wed, 19 Feb 2025 08:27:28 +0000 Subject: [PATCH 08/26] ALTER TABLE --- internal/storage/postgresql/key.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/storage/postgresql/key.go b/internal/storage/postgresql/key.go index caf2c72..e92f114 100644 --- a/internal/storage/postgresql/key.go +++ b/internal/storage/postgresql/key.go @@ -57,7 +57,7 @@ func (s *Store) AlterKeysTable() error { END IF; END $$; - ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[], ADD COLUMN IF NOT EXISTS should_log_request BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS should_log_response BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS rotation_enabled BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS policy_id VARCHAR(255) NOT NULL DEFAULT '', ADD COLUMN IF NOT EXISTS is_key_not_hashed BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS requests_limit INT; + ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[], ADD COLUMN IF NOT EXISTS should_log_request BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS should_log_response BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS rotation_enabled BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS policy_id VARCHAR(255) NOT NULL DEFAULT '', ADD COLUMN IF NOT EXISTS is_key_not_hashed BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS requests_limit INT NOT NULL DEFAULT 0; ` ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt) From 5755f6bf4bcd30187c851f585a032f1866868c1c Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Wed, 19 Feb 2025 12:52:25 +0000 Subject: [PATCH 09/26] fix auth --- internal/authenticator/authenticator.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go index 2ebb0d4..0490d35 100644 --- a/internal/authenticator/authenticator.go +++ b/internal/authenticator/authenticator.go @@ -274,8 +274,8 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, xCustomProvid } if xcustom.IsXCustomRequest(req) { - authHeaderK := settings[0].GetParam("header") - authHeaderV := strings.Replace(settings[0].GetParam("maskAuth"), "{{apikey}}", settings[0].GetParam("apikey"), -1) + authHeaderK := settings[0].GetParam("authHeader") + authHeaderV := strings.Replace(settings[0].GetParam("authTemplate"), "{{apikey}}", settings[0].GetParam("apikey"), -1) req.Header.Set(authHeaderK, authHeaderV) return key, settings, nil } From bd42ab3176c3c410619608db5f6d76eb92dbc5e1 Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Wed, 19 Feb 2025 15:30:12 +0000 Subject: [PATCH 10/26] metrics --- internal/event/key_reporting.go | 14 ++++++++++---- internal/manager/reporting.go | 4 ++-- internal/storage/postgresql/event.go | 26 ++++++++++++++++++++------ 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/internal/event/key_reporting.go b/internal/event/key_reporting.go index 1e0b3a2..6ebf3f4 100644 --- a/internal/event/key_reporting.go +++ b/internal/event/key_reporting.go @@ -29,11 +29,13 @@ type KeyRingReportingRequest struct { Limit int `json:"limit"` Offset int `json:"offset"` Revoked *bool `json:"revoked"` + TopBy string `json:"topBy"` } type KeyRingDataPoint struct { KeyRing string `json:"keyRing"` CostInUsd float64 `json:"costInUsd"` + Requests int `json:"requests"` } type KeyRingReportingResponse struct { @@ -61,10 +63,14 @@ type UsageReportingRequest struct { } type UsageData struct { - LastDayUsage float64 `json:"lastDayUsage"` - LastWeekUsage float64 `json:"lastWeekUsage"` - LastMonthUsage float64 `json:"lastMonthUsage"` - TotalUsage float64 `json:"totalUsage"` + LastDayUsage float64 `json:"lastDayUsage"` + LastWeekUsage float64 `json:"lastWeekUsage"` + LastMonthUsage float64 `json:"lastMonthUsage"` + TotalUsage float64 `json:"totalUsage"` + LastDayUsageRequests int `json:"lastDayUsageRequests"` + LastWeekUsageRequests int `json:"lastWeekUsageRequests"` + LastMonthUsageRequests int `json:"lastMonthUsageRequests"` + TotalUsageRequests int `json:"totalUsageRequests"` } type UsageReportingResponse struct { diff --git a/internal/manager/reporting.go b/internal/manager/reporting.go index f046bb2..2560a37 100644 --- a/internal/manager/reporting.go +++ b/internal/manager/reporting.go @@ -31,7 +31,7 @@ type eventStorage interface { GetCustomIds(keyId string) ([]string, error) GetTopKeyDataPoints(start, end int64, tags, keyIds []string, order string, limit, offset int, name string, revoked *bool) ([]*event.KeyDataPoint, error) - GetTopKeyRingDataPoints(start, end int64, tags []string, order string, limit, offset int, revoked *bool) ([]*event.KeyRingDataPoint, error) + GetTopKeyRingDataPoints(start, end int64, tags []string, order string, limit, offset int, revoked *bool, topBy string) ([]*event.KeyRingDataPoint, error) GetUsageData(tags []string) (*event.UsageData, error) } @@ -130,7 +130,7 @@ func (rm *ReportingManager) GetTopKeyRingReporting(r *event.KeyRingReportingRequ return nil, internal_errors.NewValidationError("key reporting request order can only be desc or asc") } - dataPoints, err := rm.es.GetTopKeyRingDataPoints(r.Start, r.End, r.Tags, r.Order, r.Limit, r.Offset, r.Revoked) + dataPoints, err := rm.es.GetTopKeyRingDataPoints(r.Start, r.End, r.Tags, r.Order, r.Limit, r.Offset, r.Revoked, r.TopBy) if err != nil { return nil, err } diff --git a/internal/storage/postgresql/event.go b/internal/storage/postgresql/event.go index 50d5a41..37e67ad 100644 --- a/internal/storage/postgresql/event.go +++ b/internal/storage/postgresql/event.go @@ -471,7 +471,7 @@ func (s *Store) GetTopKeyDataPoints(start, end int64, tags, keyIds []string, ord return data, nil } -func (s *Store) GetTopKeyRingDataPoints(start, end int64, tags []string, order string, limit, offset int, revoked *bool) ([]*event.KeyRingDataPoint, error) { +func (s *Store) GetTopKeyRingDataPoints(start, end int64, tags []string, order string, limit, offset int, revoked *bool, topBy string) ([]*event.KeyRingDataPoint, error) { args := []any{} condition := "" condition2 := "" @@ -517,7 +517,8 @@ func (s *Store) GetTopKeyRingDataPoints(start, end int64, tags []string, order s ( SELECT key_ring, - SUM(cost_in_usd) AS total_cost_in_usd + SUM(cost_in_usd) AS total_cost_in_usd, + COUNT(*) AS total_requests FROM events LEFT JOIN keys ON keys.key_id = events.key_id @@ -531,9 +532,13 @@ func (s *Store) GetTopKeyRingDataPoints(start, end int64, tags []string, order s qorder = "ASC" } + qtopBy := "total_cost_in_usd" + if topBy != "" { + qtopBy = topBy + } query += fmt.Sprintf(` - ORDER BY total_cost_in_usd %s -`, qorder) + ORDER BY %s %s +`, qtopBy, qorder) if limit != 0 { query += fmt.Sprintf(` @@ -558,6 +563,7 @@ func (s *Store) GetTopKeyRingDataPoints(start, end int64, tags []string, order s additional := []any{ &keyRing, &e.CostInUsd, + &e.Requests, } if err := rows.Scan( @@ -597,10 +603,14 @@ func (s *Store) GetUsageData(tags []string) (*event.UsageData, error) { COALESCE(SUM(cost_in_usd), 0) AS total_cost_in_usd, COALESCE(SUM(CASE WHEN created_at > %d THEN cost_in_usd ELSE 0 END), 0) AS total_cost_in_usd_last_day, COALESCE(SUM(CASE WHEN created_at > %d THEN cost_in_usd ELSE 0 END), 0) AS total_cost_in_usd_last_week, - COALESCE(SUM(CASE WHEN created_at > %d THEN cost_in_usd ELSE 0 END), 0) AS total_cost_in_usd_last_month + COALESCE(SUM(CASE WHEN created_at > %d THEN cost_in_usd ELSE 0 END), 0) AS total_cost_in_usd_last_month, + COALESCE(SUM(1), 0) AS total_requests, + COALESCE(SUM(CASE WHEN created_at > %d THEN 1 ELSE 0 END), 0) AS total_requests_last_day, + COALESCE(SUM(CASE WHEN created_at > %d THEN 1 ELSE 0 END), 0) AS total_requests_last_week, + COALESCE(SUM(CASE WHEN created_at > %d THEN 1 ELSE 0 END), 0) AS total_requests_last_month FROM events WHERE %s - `, dayAgo, weekAgo, monthAgo, condition) + `, dayAgo, weekAgo, monthAgo, dayAgo, weekAgo, monthAgo, condition) ctx, cancel := context.WithTimeout(context.Background(), s.rt) defer cancel() @@ -611,6 +621,10 @@ func (s *Store) GetUsageData(tags []string) (*event.UsageData, error) { &data.LastDayUsage, &data.LastWeekUsage, &data.LastMonthUsage, + &data.TotalUsageRequests, + &data.LastDayUsageRequests, + &data.LastWeekUsageRequests, + &data.LastMonthUsageRequests, ); err != nil { if err == sql.ErrNoRows { return nil, nil From 7ab7bd1e15c31f8ff36e6d4ca7a50746d925d44f Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Wed, 19 Feb 2025 15:36:16 +0000 Subject: [PATCH 11/26] allowedTopBy --- internal/storage/postgresql/event.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/internal/storage/postgresql/event.go b/internal/storage/postgresql/event.go index 37e67ad..de2bb20 100644 --- a/internal/storage/postgresql/event.go +++ b/internal/storage/postgresql/event.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "slices" "strings" "time" @@ -13,6 +14,8 @@ import ( "github.com/lib/pq" ) +var allowedTopBy = []string{"total_cost_in_usd", "total_requests"} + func (s *Store) CreateEventsByDayTable() error { createTableQuery := ` CREATE TABLE IF NOT EXISTS event_agg_by_day ( @@ -533,7 +536,7 @@ func (s *Store) GetTopKeyRingDataPoints(start, end int64, tags []string, order s } qtopBy := "total_cost_in_usd" - if topBy != "" { + if topBy != "" && slices.Contains(allowedTopBy, topBy) { qtopBy = topBy } query += fmt.Sprintf(` From e31457bdfbdb85f9e19ad04d7632817481bbe0b2 Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Wed, 19 Feb 2025 22:37:36 +0000 Subject: [PATCH 12/26] refactor --- internal/message/consumer.go | 1 + internal/message/handler.go | 6 ++++++ internal/recorder/recorder.go | 10 +++++----- internal/storage/postgresql/event.go | 11 +---------- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/internal/message/consumer.go b/internal/message/consumer.go index 115f590..5d9dcd2 100644 --- a/internal/message/consumer.go +++ b/internal/message/consumer.go @@ -18,6 +18,7 @@ type recorder interface { RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error RecordUserSpend(userId string, micros int64, costLimitUnit key.TimeUnit) error RecordEvent(e *event.Event) error + RecordKeyRequestSpent(keyId string) error } func NewConsumer(mc <-chan Message, log *zap.Logger, num int, handle func(Message) error) *Consumer { diff --git a/internal/message/handler.go b/internal/message/handler.go index 1e578c1..f03f3b8 100644 --- a/internal/message/handler.go +++ b/internal/message/handler.go @@ -331,6 +331,12 @@ func (h *Handler) HandleEventWithRequestAndResponse(m Message) error { var u *user.User + err = h.recorder.RecordKeyRequestSpent(e.Event.KeyId) + if err != nil { + telemetry.Incr("bricksllm.message.handler.handle_event_with_request_and_response.record_key_request_spend_error", nil, 1) + h.log.Debug("error when recording key request spend", zap.Error(err)) + } + if e.Event.CostInUsd != 0 { micros := int64(e.Event.CostInUsd * 1000000) err = h.recorder.RecordKeySpend(e.Event.KeyId, micros, e.Key.CostLimitInUsdUnit) diff --git a/internal/recorder/recorder.go b/internal/recorder/recorder.go index 6e59d5d..5c8b73f 100644 --- a/internal/recorder/recorder.go +++ b/internal/recorder/recorder.go @@ -61,11 +61,7 @@ func (r *Recorder) RecordUserSpend(userId string, micros int64, costLimitUnit ke } func (r *Recorder) RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error { - err := r.reqLimitStore.IncrementCounter(keyId, 1) - if err != nil { - return err - } - err = r.s.IncrementCounter(keyId, micros) + err := r.s.IncrementCounter(keyId, micros) if err != nil { return err } @@ -80,6 +76,10 @@ func (r *Recorder) RecordKeySpend(keyId string, micros int64, costLimitUnit key. return nil } +func (r *Recorder) RecordKeyRequestSpent(keyId string) error { + return r.reqLimitStore.IncrementCounter(keyId, 1) +} + func (r *Recorder) RecordEvent(e *event.Event) error { return r.es.InsertEvent(e) } diff --git a/internal/storage/postgresql/event.go b/internal/storage/postgresql/event.go index de2bb20..274d19b 100644 --- a/internal/storage/postgresql/event.go +++ b/internal/storage/postgresql/event.go @@ -497,21 +497,12 @@ func (s *Store) GetTopKeyRingDataPoints(start, end int64, tags []string, order s } if len(tags) > 0 { - condition2 += fmt.Sprintf("AND keys.tags @> $%d", index) + condition2 += fmt.Sprintf("AND events.tags @> $%d", index) args = append(args, pq.Array(tags)) index++ } - if revoked != nil { - bools := "False" - if *revoked { - bools = "True" - } - - condition2 += fmt.Sprintf(" AND keys.revoked = %s", bools) - } - query := fmt.Sprintf(` WITH keys_table AS ( From ac29f4cfff780e57e6ea1cde51c22d204ea6ab8c Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Thu, 20 Feb 2025 14:56:39 +0000 Subject: [PATCH 13/26] logs --- internal/server/web/proxy/x_custom.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/internal/server/web/proxy/x_custom.go b/internal/server/web/proxy/x_custom.go index 6d15cdd..d6418a5 100644 --- a/internal/server/web/proxy/x_custom.go +++ b/internal/server/web/proxy/x_custom.go @@ -61,6 +61,9 @@ func getXCustomHandler(prod bool) gin.HandlerFunc { c.JSON(http.StatusInternalServerError, "[BricksLLM] invalid endpoint") return } + dumpA, _ := httputil.DumpRequest(c.Request, true) + fmt.Println("=======dumpA===========") + fmt.Println(string(dumpA)) proxy := &httputil.ReverseProxy{ Director: func(r *http.Request) { r.URL.Scheme = target.Scheme @@ -69,6 +72,9 @@ func getXCustomHandler(prod bool) gin.HandlerFunc { r.RequestURI = target.RequestURI() r.Host = target.Host r = r.WithContext(ctx) + dumpB, _ := httputil.DumpRequest(r, true) + fmt.Println("=======dumpB===========") + fmt.Println(string(dumpB)) }, } proxy.ServeHTTP(c.Writer, c.Request) From f4534b179771df63d9e8f36a69a370dd7fbebb46 Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Thu, 20 Feb 2025 15:35:16 +0000 Subject: [PATCH 14/26] wip --- internal/server/web/proxy/x_custom.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/internal/server/web/proxy/x_custom.go b/internal/server/web/proxy/x_custom.go index d6418a5..3cde475 100644 --- a/internal/server/web/proxy/x_custom.go +++ b/internal/server/web/proxy/x_custom.go @@ -72,6 +72,12 @@ func getXCustomHandler(prod bool) gin.HandlerFunc { r.RequestURI = target.RequestURI() r.Host = target.Host r = r.WithContext(ctx) + + r.Header.Del("X-Amzn-Trace-Id") + r.Header.Del("X-Forwarded-For") + r.Header.Del("X-Forwarded-Port") + r.Header.Del("X-Forwarded-Proto") + dumpB, _ := httputil.DumpRequest(r, true) fmt.Println("=======dumpB===========") fmt.Println(string(dumpB)) From bfc2632059d2dfa0cdaa5daf90c8ac7ed8bf76d0 Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Thu, 20 Feb 2025 15:39:44 +0000 Subject: [PATCH 15/26] wip --- internal/server/web/proxy/x_custom.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/server/web/proxy/x_custom.go b/internal/server/web/proxy/x_custom.go index 3cde475..cb465ef 100644 --- a/internal/server/web/proxy/x_custom.go +++ b/internal/server/web/proxy/x_custom.go @@ -62,6 +62,10 @@ func getXCustomHandler(prod bool) gin.HandlerFunc { return } dumpA, _ := httputil.DumpRequest(c.Request, true) + c.Request.Header.Del("X-Amzn-Trace-Id") + c.Request.Header.Del("X-Forwarded-For") + c.Request.Header.Del("X-Forwarded-Port") + c.Request.Header.Del("X-Forwarded-Proto") fmt.Println("=======dumpA===========") fmt.Println(string(dumpA)) proxy := &httputil.ReverseProxy{ From d55faa751ed99e298af984af2e1e9c393e43d3fe Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Thu, 20 Feb 2025 15:53:03 +0000 Subject: [PATCH 16/26] wip --- internal/server/web/proxy/x_custom.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/server/web/proxy/x_custom.go b/internal/server/web/proxy/x_custom.go index cb465ef..a633c5e 100644 --- a/internal/server/web/proxy/x_custom.go +++ b/internal/server/web/proxy/x_custom.go @@ -66,6 +66,8 @@ func getXCustomHandler(prod bool) gin.HandlerFunc { c.Request.Header.Del("X-Forwarded-For") c.Request.Header.Del("X-Forwarded-Port") c.Request.Header.Del("X-Forwarded-Proto") + fmt.Println("=========HEADERS==============") + fmt.Println(c.Request.Header) fmt.Println("=======dumpA===========") fmt.Println(string(dumpA)) proxy := &httputil.ReverseProxy{ From f48455c10d3e22c3e79cd5a55a270128c3178fcb Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Thu, 20 Feb 2025 16:53:42 +0000 Subject: [PATCH 17/26] wip --- internal/server/web/proxy/x_custom.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/server/web/proxy/x_custom.go b/internal/server/web/proxy/x_custom.go index a633c5e..b07b61e 100644 --- a/internal/server/web/proxy/x_custom.go +++ b/internal/server/web/proxy/x_custom.go @@ -66,7 +66,8 @@ func getXCustomHandler(prod bool) gin.HandlerFunc { c.Request.Header.Del("X-Forwarded-For") c.Request.Header.Del("X-Forwarded-Port") c.Request.Header.Del("X-Forwarded-Proto") - fmt.Println("=========HEADERS==============") + c.Request.Header.Del("Accept") + fmt.Println("=========HEADERS2==============") fmt.Println(c.Request.Header) fmt.Println("=======dumpA===========") fmt.Println(string(dumpA)) @@ -83,6 +84,7 @@ func getXCustomHandler(prod bool) gin.HandlerFunc { r.Header.Del("X-Forwarded-For") r.Header.Del("X-Forwarded-Port") r.Header.Del("X-Forwarded-Proto") + r.Header.Del("Accept") dumpB, _ := httputil.DumpRequest(r, true) fmt.Println("=======dumpB===========") From 1a326f74fe9d041adc03283983a3f2e2659105ad Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Thu, 20 Feb 2025 17:08:46 +0000 Subject: [PATCH 18/26] wip --- internal/server/web/proxy/x_custom.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/internal/server/web/proxy/x_custom.go b/internal/server/web/proxy/x_custom.go index b07b61e..49134f3 100644 --- a/internal/server/web/proxy/x_custom.go +++ b/internal/server/web/proxy/x_custom.go @@ -66,8 +66,7 @@ func getXCustomHandler(prod bool) gin.HandlerFunc { c.Request.Header.Del("X-Forwarded-For") c.Request.Header.Del("X-Forwarded-Port") c.Request.Header.Del("X-Forwarded-Proto") - c.Request.Header.Del("Accept") - fmt.Println("=========HEADERS2==============") + fmt.Println("=========HEADERS3==============") fmt.Println(c.Request.Header) fmt.Println("=======dumpA===========") fmt.Println(string(dumpA)) @@ -84,7 +83,6 @@ func getXCustomHandler(prod bool) gin.HandlerFunc { r.Header.Del("X-Forwarded-For") r.Header.Del("X-Forwarded-Port") r.Header.Del("X-Forwarded-Proto") - r.Header.Del("Accept") dumpB, _ := httputil.DumpRequest(r, true) fmt.Println("=======dumpB===========") From d29bbd1b35209d027db60048b695a69ef80ea01a Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Thu, 20 Feb 2025 17:16:38 +0000 Subject: [PATCH 19/26] validator --- internal/validator/validator.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/validator/validator.go b/internal/validator/validator.go index b6bfaac..4b32d98 100644 --- a/internal/validator/validator.go +++ b/internal/validator/validator.go @@ -60,12 +60,12 @@ func (v *Validator) Validate(k *key.ResponseKey, promptCost float64) error { return internal_errors.NewExpirationError("api key expired", internal_errors.TtlExpiration) } - err := v.validateRequestsLimit(k.KeyId, k.RequestsLimit) - if err != nil { - return err - } + //err := v.validateRequestsLimit(k.KeyId, k.RequestsLimit) + //if err != nil { + // return err + //} - err = v.validateRateLimitOverTime(k.KeyId, k.RateLimitOverTime, k.RateLimitUnit) + err := v.validateRateLimitOverTime(k.KeyId, k.RateLimitOverTime, k.RateLimitUnit) if err != nil { return err } From 7a1032eb5caade8e7d0b5b091f10649c7628b63c Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Thu, 20 Feb 2025 17:46:15 +0000 Subject: [PATCH 20/26] wip --- internal/server/web/proxy/x_custom.go | 43 +++++++++++++++------------ 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/internal/server/web/proxy/x_custom.go b/internal/server/web/proxy/x_custom.go index 49134f3..fe19088 100644 --- a/internal/server/web/proxy/x_custom.go +++ b/internal/server/web/proxy/x_custom.go @@ -1,7 +1,6 @@ package proxy import ( - "context" "errors" "fmt" "github.com/bricks-cloud/bricksllm/internal/provider" @@ -25,8 +24,8 @@ func getXCustomHandler(prod bool) gin.HandlerFunc { return } - ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) - defer cancel() + //ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) + //defer cancel() providerId := c.Param(xcustom.XProviderIdParam) rawProviderSettings, exists := c.Get("settings") @@ -70,23 +69,29 @@ func getXCustomHandler(prod bool) gin.HandlerFunc { fmt.Println(c.Request.Header) fmt.Println("=======dumpA===========") fmt.Println(string(dumpA)) + //proxy := &httputil.ReverseProxy{ + // Director: func(r *http.Request) { + // r.URL.Scheme = target.Scheme + // r.URL.Host = target.Host + // r.URL.Path, r.URL.RawPath = target.Path, target.RawPath + // r.RequestURI = target.RequestURI() + // r.Host = target.Host + // r = r.WithContext(ctx) + // + // r.Header.Del("X-Amzn-Trace-Id") + // r.Header.Del("X-Forwarded-For") + // r.Header.Del("X-Forwarded-Port") + // r.Header.Del("X-Forwarded-Proto") + // + // dumpB, _ := httputil.DumpRequest(r, true) + // fmt.Println("=======dumpB===========") + // fmt.Println(string(dumpB)) + // }, + //} proxy := &httputil.ReverseProxy{ - Director: func(r *http.Request) { - r.URL.Scheme = target.Scheme - r.URL.Host = target.Host - r.URL.Path, r.URL.RawPath = target.Path, target.RawPath - r.RequestURI = target.RequestURI() - r.Host = target.Host - r = r.WithContext(ctx) - - r.Header.Del("X-Amzn-Trace-Id") - r.Header.Del("X-Forwarded-For") - r.Header.Del("X-Forwarded-Port") - r.Header.Del("X-Forwarded-Proto") - - dumpB, _ := httputil.DumpRequest(r, true) - fmt.Println("=======dumpB===========") - fmt.Println(string(dumpB)) + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(target) + r.Out.URL.Path, r.Out.URL.RawPath = target.Path, target.RawPath }, } proxy.ServeHTTP(c.Writer, c.Request) From 95f618b5ba9bde0b3ef1ffcf81ee23d7f9db91aa Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Fri, 21 Feb 2025 14:33:50 +0000 Subject: [PATCH 21/26] refactor --- cmd/bricksllm/main.go | 2 +- internal/authenticator/authenticator.go | 2 +- internal/errors/expiration_err.go | 5 ++-- internal/errors/requests_limit_err.go | 17 ------------- internal/key/key.go | 5 ++++ internal/manager/key.go | 34 +++++++++++++++++-------- internal/manager/provider_setting.go | 4 +-- internal/provider/xcustom/xcustom.go | 2 +- internal/server/web/proxy/x_custom.go | 31 ++++------------------ internal/validator/validator.go | 12 ++++----- 10 files changed, 47 insertions(+), 67 deletions(-) delete mode 100644 internal/errors/requests_limit_err.go diff --git a/cmd/bricksllm/main.go b/cmd/bricksllm/main.go index fd8da73..0bdd4b3 100644 --- a/cmd/bricksllm/main.go +++ b/cmd/bricksllm/main.go @@ -301,7 +301,7 @@ func main() { } v := validator.NewValidator(costLimitCache, rateLimitCache, costStorage, requestsLimitStorage) - m := manager.NewManager(store, costLimitCache, rateLimitCache, accessCache, keysCache) + m := manager.NewManager(store, costLimitCache, rateLimitCache, accessCache, keysCache, requestsLimitStorage) krm := manager.NewReportingManager(costStorage, store, store, v) psm := manager.NewProviderSettingsManager(store, psCache, encryptor) cpm := manager.NewCustomProvidersManager(store, cpMemStore) diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go index 0490d35..4101ec3 100644 --- a/internal/authenticator/authenticator.go +++ b/internal/authenticator/authenticator.go @@ -98,7 +98,7 @@ func (a *Authenticator) getApiKeyByXCustomProvider(req *http.Request, xCustomPro err = internal_errors.NewAuthError("provider settings error") return } - header := req.Header.Get(xCustomSetting.AuthHeader) + header := req.Header.Get(xCustomSetting.AuthLocation) key, err = xcustom.ExtractBricksKey(header, xCustomSetting.AuthTemplate) if err != nil { err = internal_errors.NewAuthError("provider settings error") diff --git a/internal/errors/expiration_err.go b/internal/errors/expiration_err.go index cce8380..59aec34 100644 --- a/internal/errors/expiration_err.go +++ b/internal/errors/expiration_err.go @@ -1,8 +1,9 @@ package errors const ( - TtlExpiration string = "ttl" - CostLimitExpiration string = "cost-limit" + TtlExpiration string = "ttl" + CostLimitExpiration string = "cost-limit" + RequestsLimitExpiration string = "requests-limit" ) type ExpirationError struct { diff --git a/internal/errors/requests_limit_err.go b/internal/errors/requests_limit_err.go deleted file mode 100644 index 212e158..0000000 --- a/internal/errors/requests_limit_err.go +++ /dev/null @@ -1,17 +0,0 @@ -package errors - -type RequestsLimitError struct { - message string -} - -func NewRequestsLimitError(msg string) *RequestsLimitError { - return &RequestsLimitError{ - message: msg, - } -} - -func (rle *RequestsLimitError) Error() string { - return rle.message -} - -func (rle *RequestsLimitError) RequestsLimit() {} diff --git a/internal/key/key.go b/internal/key/key.go index 4a40baf..94eb537 100644 --- a/internal/key/key.go +++ b/internal/key/key.go @@ -25,6 +25,7 @@ type UpdateKey struct { CostLimitInUsdUnit *TimeUnit `json:"costLimitInUsdUnit"` RateLimitOverTime *int `json:"rateLimitOverTime"` RateLimitUnit *TimeUnit `json:"rateLimitUnit"` + RequestsLimit *int `json:"requestsLimit"` AllowedPaths *[]PathConfig `json:"allowedPaths,omitempty"` ShouldLogRequest *bool `json:"shouldLogRequest"` ShouldLogResponse *bool `json:"shouldLogResponse"` @@ -51,6 +52,10 @@ func (uk *UpdateKey) Validate() error { invalid = append(invalid, "costLimitInUsd") } + if uk.RequestsLimit != nil && *uk.RequestsLimit < 0 { + invalid = append(invalid, "requestsLimit") + } + if uk.UpdatedAt <= 0 { invalid = append(invalid, "updatedAt") } diff --git a/internal/manager/key.go b/internal/manager/key.go index 3401db4..f4f5b2d 100644 --- a/internal/manager/key.go +++ b/internal/manager/key.go @@ -46,21 +46,27 @@ type keyCache interface { Get(keyId string) (*key.ResponseKey, error) } +type requestsLimitStorage interface { + DeleteCounter(keyId string) error +} + type Manager struct { - s Storage - clc costLimitCache - rlc rateLimitCache - ac accessCache - kc keyCache + s Storage + clc costLimitCache + rlc rateLimitCache + ac accessCache + kc keyCache + rqls requestsLimitStorage } -func NewManager(s Storage, clc costLimitCache, rlc rateLimitCache, ac accessCache, kc keyCache) *Manager { +func NewManager(s Storage, clc costLimitCache, rlc rateLimitCache, ac accessCache, kc keyCache, rqls requestsLimitStorage) *Manager { return &Manager{ - s: s, - clc: clc, - rlc: rlc, - ac: ac, - kc: kc, + s: s, + clc: clc, + rlc: rlc, + ac: ac, + kc: kc, + rqls: rqls, } } @@ -175,6 +181,12 @@ func (m *Manager) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, err return nil, err } } + if uk.RequestsLimit != nil { + err := m.rqls.DeleteCounter(id) + if err != nil { + return nil, err + } + } if uk.PolicyId != nil { if len(*uk.PolicyId) != 0 { diff --git a/internal/manager/provider_setting.go b/internal/manager/provider_setting.go index 82b42b2..75a84da 100644 --- a/internal/manager/provider_setting.go +++ b/internal/manager/provider_setting.go @@ -111,9 +111,9 @@ func findMissingAuthParams(providerName string, params map[string]string) string if len(val) == 0 { missingFields = append(missingFields, "endpoint") } - val = params["authHeader"] + val = params["authLocation"] if len(val) == 0 { - missingFields = append(missingFields, "authHeader") + missingFields = append(missingFields, "authLocation") } val = params["authTemplate"] if !strings.Contains(val, "{{apikey}}") { diff --git a/internal/provider/xcustom/xcustom.go b/internal/provider/xcustom/xcustom.go index 3eb5ec0..a9443da 100644 --- a/internal/provider/xcustom/xcustom.go +++ b/internal/provider/xcustom/xcustom.go @@ -10,7 +10,7 @@ import ( type XCustomSettings struct { Apikey string `json:"apikey"` Endpoint string `json:"endpoint"` - AuthHeader string `json:"authHeader"` + AuthLocation string `json:"authLocation"` AuthTemplate string `json:"authTemplate"` } diff --git a/internal/server/web/proxy/x_custom.go b/internal/server/web/proxy/x_custom.go index fe19088..fea585c 100644 --- a/internal/server/web/proxy/x_custom.go +++ b/internal/server/web/proxy/x_custom.go @@ -1,6 +1,7 @@ package proxy import ( + "context" "errors" "fmt" "github.com/bricks-cloud/bricksllm/internal/provider" @@ -24,8 +25,8 @@ func getXCustomHandler(prod bool) gin.HandlerFunc { return } - //ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) - //defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) + defer cancel() providerId := c.Param(xcustom.XProviderIdParam) rawProviderSettings, exists := c.Get("settings") @@ -60,38 +61,16 @@ func getXCustomHandler(prod bool) gin.HandlerFunc { c.JSON(http.StatusInternalServerError, "[BricksLLM] invalid endpoint") return } - dumpA, _ := httputil.DumpRequest(c.Request, true) c.Request.Header.Del("X-Amzn-Trace-Id") c.Request.Header.Del("X-Forwarded-For") c.Request.Header.Del("X-Forwarded-Port") c.Request.Header.Del("X-Forwarded-Proto") - fmt.Println("=========HEADERS3==============") - fmt.Println(c.Request.Header) - fmt.Println("=======dumpA===========") - fmt.Println(string(dumpA)) - //proxy := &httputil.ReverseProxy{ - // Director: func(r *http.Request) { - // r.URL.Scheme = target.Scheme - // r.URL.Host = target.Host - // r.URL.Path, r.URL.RawPath = target.Path, target.RawPath - // r.RequestURI = target.RequestURI() - // r.Host = target.Host - // r = r.WithContext(ctx) - // - // r.Header.Del("X-Amzn-Trace-Id") - // r.Header.Del("X-Forwarded-For") - // r.Header.Del("X-Forwarded-Port") - // r.Header.Del("X-Forwarded-Proto") - // - // dumpB, _ := httputil.DumpRequest(r, true) - // fmt.Println("=======dumpB===========") - // fmt.Println(string(dumpB)) - // }, - //} + proxy := &httputil.ReverseProxy{ Rewrite: func(r *httputil.ProxyRequest) { r.SetURL(target) r.Out.URL.Path, r.Out.URL.RawPath = target.Path, target.RawPath + r.Out.WithContext(ctx) }, } proxy.ServeHTTP(c.Writer, c.Request) diff --git a/internal/validator/validator.go b/internal/validator/validator.go index 4b32d98..c3df819 100644 --- a/internal/validator/validator.go +++ b/internal/validator/validator.go @@ -60,12 +60,12 @@ func (v *Validator) Validate(k *key.ResponseKey, promptCost float64) error { return internal_errors.NewExpirationError("api key expired", internal_errors.TtlExpiration) } - //err := v.validateRequestsLimit(k.KeyId, k.RequestsLimit) - //if err != nil { - // return err - //} + err := v.validateRequestsLimit(k.KeyId, k.RequestsLimit) + if err != nil { + return err + } - err := v.validateRateLimitOverTime(k.KeyId, k.RateLimitOverTime, k.RateLimitUnit) + err = v.validateRateLimitOverTime(k.KeyId, k.RateLimitOverTime, k.RateLimitUnit) if err != nil { return err } @@ -158,7 +158,7 @@ func (v *Validator) validateRequestsLimit(keyId string, requestsLimit int) error return errors.New("failed to get total requests") } if existingTotalRequests >= int64(requestsLimit) { - return internal_errors.NewRequestsLimitError(fmt.Sprintf("total requests limit: %d, has been reached", requestsLimit)) + return internal_errors.NewExpirationError(fmt.Sprintf("total requests limit: %d, has been reached", requestsLimit), internal_errors.RequestsLimitExpiration) } return nil } From 479891d5c370a7c669b5c46ae918bfb9f8c28c2f Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Mon, 24 Feb 2025 16:50:40 +0000 Subject: [PATCH 22/26] extend auth --- internal/authenticator/authenticator.go | 37 ++++++++----- internal/provider/xcustom/xcustom.go | 71 ++++++++++++++++++++++--- 2 files changed, 90 insertions(+), 18 deletions(-) diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go index 4101ec3..6ed81c5 100644 --- a/internal/authenticator/authenticator.go +++ b/internal/authenticator/authenticator.go @@ -81,12 +81,7 @@ func getApiKey(req *http.Request) (string, error) { return "", internal_errors.NewAuthError("api key not found in header") } -func (a *Authenticator) getApiKeyByXCustomProvider(req *http.Request, xCustomProviderId string) (key string, pSettings []*provider.Setting, err error) { - providerSetting, err := a.psm.GetSettingViaCache(xCustomProviderId) - if err != nil { - return - } - pSettings = []*provider.Setting{providerSetting} +func getXCustomAuth(req *http.Request, providerSetting *provider.Setting) (key *xcustom.XCustomAuth, pSettings []*provider.Setting, err error) { setting := providerSetting.Setting if setting == nil { err = internal_errors.NewAuthError("provider settings not found") @@ -98,8 +93,7 @@ func (a *Authenticator) getApiKeyByXCustomProvider(req *http.Request, xCustomPro err = internal_errors.NewAuthError("provider settings error") return } - header := req.Header.Get(xCustomSetting.AuthLocation) - key, err = xcustom.ExtractBricksKey(header, xCustomSetting.AuthTemplate) + key, err = xcustom.GetXCustomAuth(req, xCustomSetting) if err != nil { err = internal_errors.NewAuthError("provider settings error") return @@ -236,8 +230,18 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, xCustomProvid var raw string var err error var settings []*provider.Setting + var xCustomAuth *xcustom.XCustomAuth if xcustom.IsXCustomRequest(req) { - raw, settings, err = a.getApiKeyByXCustomProvider(req, xCustomProviderId) + providerSetting, er := a.psm.GetSettingViaCache(xCustomProviderId) + if er != nil { + return nil, nil, er + } + settings = []*provider.Setting{providerSetting} + xCustomAuth, settings, err = getXCustomAuth(req, providerSetting) + if err != nil { + return nil, nil, err + } + raw = xCustomAuth.Apikey } else { raw, err = getApiKey(req) } @@ -274,9 +278,18 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, xCustomProvid } if xcustom.IsXCustomRequest(req) { - authHeaderK := settings[0].GetParam("authHeader") - authHeaderV := strings.Replace(settings[0].GetParam("authTemplate"), "{{apikey}}", settings[0].GetParam("apikey"), -1) - req.Header.Set(authHeaderK, authHeaderV) + if xCustomAuth == nil { + return nil, nil, errors.New("xCustomAuth invalid") + } + authString := strings.Replace(xCustomAuth.Mask, "{{apikey}}", settings[0].GetParam("apikey"), -1) + switch xCustomAuth.Location { + case xcustom.AuthLocations.Query: + req.URL.Query().Set(xCustomAuth.Target, authString) + case xcustom.AuthLocations.Header: + req.Header.Set(xCustomAuth.Target, authString) + default: + return nil, nil, errors.New("invalid xCustomAuth location") + } return key, settings, nil } diff --git a/internal/provider/xcustom/xcustom.go b/internal/provider/xcustom/xcustom.go index a9443da..009a6a2 100644 --- a/internal/provider/xcustom/xcustom.go +++ b/internal/provider/xcustom/xcustom.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "regexp" + "strings" ) @@ -14,20 +15,78 @@ type XCustomSettings struct { AuthTemplate string `json:"authTemplate"` } +type AuthLocation string + +var AuthLocations = struct { + Header AuthLocation + Query AuthLocation +}{ + Header: AuthLocation("header"), + Query: AuthLocation("query"), +} + +type XCustomAuth struct { + Apikey string + Location AuthLocation + Target string + Mask string +} + const XProviderIdParam = "x_provider_id" func IsXCustomRequest(req *http.Request) bool { return strings.HasPrefix(req.URL.RequestURI(), "/api/providers/xCustom/") } -func ExtractBricksKey(header, mask string) (key string, err error) { +func GetXCustomAuth(req *http.Request, xSettings *XCustomSettings) (*XCustomAuth, error) { + authLocation := getAuthLocation(xSettings.AuthLocation) + var templateSeparator string + switch authLocation { + case AuthLocations.Header: + templateSeparator = ":" + case AuthLocations.Query: + templateSeparator = "=" + default: + return nil, fmt.Errorf("unknown auth location: %s", authLocation) + } + templateArr := strings.Split(xSettings.AuthTemplate, templateSeparator) + if len(templateArr) != 2 { + return nil, fmt.Errorf("invalid auth template: %s", xSettings.AuthTemplate) + } + target := strings.TrimSpace(templateArr[0]) + mask := strings.TrimSpace(templateArr[1]) + var reqAuthStr string + switch authLocation { + case AuthLocations.Header: + reqAuthStr = req.Header.Get(target) + case AuthLocations.Query: + reqAuthStr = req.URL.Query().Get(target) + default: + return nil, fmt.Errorf("unknown auth location: %s", authLocation) + } regexStr := strings.Replace(mask, "{{apikey}}", "(?P.*)", -1) regex := regexp.MustCompile(regexStr) - matches := regex.FindStringSubmatch(header) + matches := regex.FindStringSubmatch(reqAuthStr) if len(matches) < 2 { - err = fmt.Errorf("unable to extract bricks key") - return + return nil, fmt.Errorf("invalid auth template: %s", xSettings.AuthTemplate) + } + key := strings.TrimSpace(matches[1]) + + return &XCustomAuth{ + Apikey: key, + Location: authLocation, + Target: target, + Mask: mask, + }, nil +} + +func getAuthLocation(raw string) AuthLocation { + switch raw { + case "header": + return AuthLocations.Header + case "query": + return AuthLocations.Query + default: + return AuthLocations.Header } - key = strings.TrimSpace(matches[1]) - return } From 7219711b7384bb18a65eb35d4c3fee264e2a43c0 Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Tue, 25 Feb 2025 11:57:00 +0000 Subject: [PATCH 23/26] refactor --- internal/authenticator/authenticator.go | 46 ++++---------- internal/manager/provider_setting.go | 23 +++++++ internal/provider/xcustom/xcustom.go | 79 ++++++++++++++----------- internal/server/web/proxy/x_custom.go | 4 -- 4 files changed, 79 insertions(+), 73 deletions(-) diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go index 6ed81c5..149cb34 100644 --- a/internal/authenticator/authenticator.go +++ b/internal/authenticator/authenticator.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "github.com/bricks-cloud/bricksllm/internal/provider/xcustom" - "github.com/go-viper/mapstructure/v2" "math/rand" "net/http" "strconv" @@ -81,26 +80,6 @@ func getApiKey(req *http.Request) (string, error) { return "", internal_errors.NewAuthError("api key not found in header") } -func getXCustomAuth(req *http.Request, providerSetting *provider.Setting) (key *xcustom.XCustomAuth, pSettings []*provider.Setting, err error) { - setting := providerSetting.Setting - if setting == nil { - err = internal_errors.NewAuthError("provider settings not found") - return - } - var xCustomSetting *xcustom.XCustomSettings - err = mapstructure.Decode(setting, &xCustomSetting) - if err != nil { - err = internal_errors.NewAuthError("provider settings error") - return - } - key, err = xcustom.GetXCustomAuth(req, xCustomSetting) - if err != nil { - err = internal_errors.NewAuthError("provider settings error") - return - } - return -} - func rewriteHttpAuthHeader(req *http.Request, setting *provider.Setting) error { uri := req.URL.RequestURI() if strings.HasPrefix(uri, "/api/routes") { @@ -230,18 +209,13 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, xCustomProvid var raw string var err error var settings []*provider.Setting - var xCustomAuth *xcustom.XCustomAuth if xcustom.IsXCustomRequest(req) { providerSetting, er := a.psm.GetSettingViaCache(xCustomProviderId) if er != nil { return nil, nil, er } settings = []*provider.Setting{providerSetting} - xCustomAuth, settings, err = getXCustomAuth(req, providerSetting) - if err != nil { - return nil, nil, err - } - raw = xCustomAuth.Apikey + raw, err = xcustom.ExtractApiKey(req, providerSetting) } else { raw, err = getApiKey(req) } @@ -278,15 +252,19 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, xCustomProvid } if xcustom.IsXCustomRequest(req) { - if xCustomAuth == nil { - return nil, nil, errors.New("xCustomAuth invalid") - } - authString := strings.Replace(xCustomAuth.Mask, "{{apikey}}", settings[0].GetParam("apikey"), -1) - switch xCustomAuth.Location { + pSetting := settings[0] + authString := strings.Replace( + pSetting.GetParam(xcustom.XCustomSettingFields.AuthMask), + "{{apikey}}", + pSetting.GetParam(xcustom.XCustomSettingFields.ApiKey), -1, + ) + location := xcustom.GetAuthLocation(pSetting.GetParam(xcustom.XCustomSettingFields.AuthLocation)) + target := pSetting.GetParam(xcustom.XCustomSettingFields.AuthTarget) + switch location { case xcustom.AuthLocations.Query: - req.URL.Query().Set(xCustomAuth.Target, authString) + req.URL.Query().Set(target, authString) case xcustom.AuthLocations.Header: - req.Header.Set(xCustomAuth.Target, authString) + req.Header.Set(target, authString) default: return nil, nil, errors.New("invalid xCustomAuth location") } diff --git a/internal/manager/provider_setting.go b/internal/manager/provider_setting.go index 75a84da..6bc2cd2 100644 --- a/internal/manager/provider_setting.go +++ b/internal/manager/provider_setting.go @@ -3,6 +3,7 @@ package manager import ( "encoding/json" "fmt" + "github.com/bricks-cloud/bricksllm/internal/provider/xcustom" "slices" "strconv" "strings" @@ -182,6 +183,18 @@ func (m *ProviderSettingsManager) CreateSetting(setting *provider.Setting) (*pro setting.CreatedAt = time.Now().Unix() setting.UpdatedAt = time.Now().Unix() + if setting.Provider == "xCustom" { + advancedSetting, err := xcustom.AdvancedXCustomSetting(setting.Setting) + if err != nil { + return nil, err + } + merged := setting.Setting + for k, v := range advancedSetting { + merged[k] = v + } + setting.Setting = merged + } + if m.Encryptor.Enabled() { params, err := m.EncryptParams(setting.UpdatedAt, setting.Provider, setting.Setting) if err != nil { @@ -214,6 +227,16 @@ func (m *ProviderSettingsManager) UpdateSetting(id string, setting *provider.Upd merged[k] = v } + if existing.Provider == "xCustom" { + advancedSetting, err := xcustom.AdvancedXCustomSetting(setting.Setting) + if err != nil { + return nil, err + } + for k, v := range advancedSetting { + merged[k] = v + } + } + setting.Setting = merged } diff --git a/internal/provider/xcustom/xcustom.go b/internal/provider/xcustom/xcustom.go index 009a6a2..746ce86 100644 --- a/internal/provider/xcustom/xcustom.go +++ b/internal/provider/xcustom/xcustom.go @@ -2,34 +2,39 @@ package xcustom import ( "fmt" + "github.com/bricks-cloud/bricksllm/internal/provider" "net/http" "regexp" "strings" ) -type XCustomSettings struct { - Apikey string `json:"apikey"` - Endpoint string `json:"endpoint"` - AuthLocation string `json:"authLocation"` - AuthTemplate string `json:"authTemplate"` +var XCustomSettingFields = struct { + ApiKey string + Endpoint string + AuthLocation string + AuthTemplate string + AuthTarget string + AuthMask string +}{ + ApiKey: "apikey", + Endpoint: "endpoint", + AuthLocation: "authLocation", + AuthTemplate: "authTemplate", + AuthTarget: "authTarget", + AuthMask: "authMask", } type AuthLocation string var AuthLocations = struct { - Header AuthLocation - Query AuthLocation + Header AuthLocation + Query AuthLocation + Unknown AuthLocation }{ - Header: AuthLocation("header"), - Query: AuthLocation("query"), -} - -type XCustomAuth struct { - Apikey string - Location AuthLocation - Target string - Mask string + Header: AuthLocation("header"), + Query: AuthLocation("query"), + Unknown: AuthLocation("unknown"), } const XProviderIdParam = "x_provider_id" @@ -38,55 +43,59 @@ func IsXCustomRequest(req *http.Request) bool { return strings.HasPrefix(req.URL.RequestURI(), "/api/providers/xCustom/") } -func GetXCustomAuth(req *http.Request, xSettings *XCustomSettings) (*XCustomAuth, error) { - authLocation := getAuthLocation(xSettings.AuthLocation) +func AdvancedXCustomSetting(src map[string]string) (map[string]string, error) { + rawLocation := src[XCustomSettingFields.AuthLocation] + location := GetAuthLocation(rawLocation) var templateSeparator string - switch authLocation { + switch location { case AuthLocations.Header: templateSeparator = ":" case AuthLocations.Query: templateSeparator = "=" default: - return nil, fmt.Errorf("unknown auth location: %s", authLocation) + return nil, fmt.Errorf("unknown auth location: %s", location) } - templateArr := strings.Split(xSettings.AuthTemplate, templateSeparator) + templateArr := strings.Split(src[XCustomSettingFields.AuthTemplate], templateSeparator) if len(templateArr) != 2 { - return nil, fmt.Errorf("invalid auth template: %s", xSettings.AuthTemplate) + return nil, fmt.Errorf("invalid auth template: %s", src[XCustomSettingFields.AuthTemplate]) } target := strings.TrimSpace(templateArr[0]) mask := strings.TrimSpace(templateArr[1]) + return map[string]string{ + XCustomSettingFields.AuthTarget: target, + XCustomSettingFields.AuthMask: mask, + }, nil +} + +func ExtractApiKey(req *http.Request, pSetting *provider.Setting) (string, error) { + location := GetAuthLocation(pSetting.GetParam(XCustomSettingFields.AuthLocation)) + target := strings.TrimSpace(pSetting.GetParam(XCustomSettingFields.AuthTarget)) var reqAuthStr string - switch authLocation { + switch location { case AuthLocations.Header: reqAuthStr = req.Header.Get(target) case AuthLocations.Query: reqAuthStr = req.URL.Query().Get(target) default: - return nil, fmt.Errorf("unknown auth location: %s", authLocation) + return "", fmt.Errorf("unknown auth location: %s", location) } + mask := strings.TrimSpace(pSetting.GetParam(XCustomSettingFields.AuthMask)) regexStr := strings.Replace(mask, "{{apikey}}", "(?P.*)", -1) regex := regexp.MustCompile(regexStr) matches := regex.FindStringSubmatch(reqAuthStr) if len(matches) < 2 { - return nil, fmt.Errorf("invalid auth template: %s", xSettings.AuthTemplate) + return "", fmt.Errorf("error extracting apikey: %s", pSetting.Id) } - key := strings.TrimSpace(matches[1]) - - return &XCustomAuth{ - Apikey: key, - Location: authLocation, - Target: target, - Mask: mask, - }, nil + return strings.TrimSpace(matches[1]), nil } -func getAuthLocation(raw string) AuthLocation { +func GetAuthLocation(raw string) AuthLocation { switch raw { case "header": return AuthLocations.Header case "query": return AuthLocations.Query default: - return AuthLocations.Header + return AuthLocations.Unknown } } diff --git a/internal/server/web/proxy/x_custom.go b/internal/server/web/proxy/x_custom.go index fea585c..c0132f6 100644 --- a/internal/server/web/proxy/x_custom.go +++ b/internal/server/web/proxy/x_custom.go @@ -61,10 +61,6 @@ func getXCustomHandler(prod bool) gin.HandlerFunc { c.JSON(http.StatusInternalServerError, "[BricksLLM] invalid endpoint") return } - c.Request.Header.Del("X-Amzn-Trace-Id") - c.Request.Header.Del("X-Forwarded-For") - c.Request.Header.Del("X-Forwarded-Port") - c.Request.Header.Del("X-Forwarded-Proto") proxy := &httputil.ReverseProxy{ Rewrite: func(r *httputil.ProxyRequest) { From 54d0917170497b280033afca2442d4547f02b532 Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Tue, 25 Feb 2025 15:14:32 +0000 Subject: [PATCH 24/26] query --- internal/authenticator/authenticator.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go index 149cb34..638a9ee 100644 --- a/internal/authenticator/authenticator.go +++ b/internal/authenticator/authenticator.go @@ -262,7 +262,9 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, xCustomProvid target := pSetting.GetParam(xcustom.XCustomSettingFields.AuthTarget) switch location { case xcustom.AuthLocations.Query: - req.URL.Query().Set(target, authString) + params := req.URL.Query() + params.Set(target, authString) + req.URL.RawQuery = params.Encode() case xcustom.AuthLocations.Header: req.Header.Set(target, authString) default: From 53ba3a69cae515fc0c16eafba6bb8a015d5aa8f9 Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Fri, 28 Feb 2025 13:42:39 +0000 Subject: [PATCH 25/26] fix update settings --- internal/manager/provider_setting.go | 44 +++++++++++++++++----------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/internal/manager/provider_setting.go b/internal/manager/provider_setting.go index 6bc2cd2..5952214 100644 --- a/internal/manager/provider_setting.go +++ b/internal/manager/provider_setting.go @@ -218,25 +218,10 @@ func (m *ProviderSettingsManager) UpdateSetting(id string, setting *provider.Upd } if len(setting.Setting) != 0 { - if err := m.validateSettings(existing.Provider, setting.Setting); err != nil { + merged, err := m.getMergedSettings(existing, setting.Setting) + if err != nil { return nil, err } - - merged := existing.Setting - for k, v := range setting.Setting { - merged[k] = v - } - - if existing.Provider == "xCustom" { - advancedSetting, err := xcustom.AdvancedXCustomSetting(setting.Setting) - if err != nil { - return nil, err - } - for k, v := range advancedSetting { - merged[k] = v - } - } - setting.Setting = merged } @@ -259,6 +244,31 @@ func (m *ProviderSettingsManager) UpdateSetting(id string, setting *provider.Upd return m.Storage.UpdateProviderSetting(id, setting) } +func (m *ProviderSettingsManager) getMergedSettings(existing *provider.Setting, setting map[string]string) (map[string]string, error) { + merged := existing.Setting + apikey, ok := setting["apikey"] + if ok && apikey == "revoked" { + merged["apikey"] = apikey + return merged, nil + } + if err := m.validateSettings(existing.Provider, setting); err != nil { + return nil, err + } + for k, v := range setting { + merged[k] = v + } + if existing.Provider == "xCustom" { + advancedSetting, err := xcustom.AdvancedXCustomSetting(setting) + if err != nil { + return nil, err + } + for k, v := range advancedSetting { + merged[k] = v + } + } + return merged, nil +} + func (m *ProviderSettingsManager) GetSettingViaCache(id string) (*provider.Setting, error) { setting, _ := m.Cache.Get(id) From a6f99e8c607caa7465db76fd436c96c66ccf504e Mon Sep 17 00:00:00 2001 From: Sergei Bronnikov Date: Fri, 28 Feb 2025 16:42:53 +0000 Subject: [PATCH 26/26] validate --- internal/manager/provider_setting.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/manager/provider_setting.go b/internal/manager/provider_setting.go index 5952214..c67fc0a 100644 --- a/internal/manager/provider_setting.go +++ b/internal/manager/provider_setting.go @@ -251,9 +251,6 @@ func (m *ProviderSettingsManager) getMergedSettings(existing *provider.Setting, merged["apikey"] = apikey return merged, nil } - if err := m.validateSettings(existing.Provider, setting); err != nil { - return nil, err - } for k, v := range setting { merged[k] = v } @@ -266,6 +263,9 @@ func (m *ProviderSettingsManager) getMergedSettings(existing *provider.Setting, merged[k] = v } } + if err := m.validateSettings(existing.Provider, merged); err != nil { + return nil, err + } return merged, nil }