diff --git a/cmd/bricksllm/main.go b/cmd/bricksllm/main.go index 388e4e9..0bdd4b3 100644 --- a/cmd/bricksllm/main.go +++ b/cmd/bricksllm/main.go @@ -272,6 +272,14 @@ func main() { log.Sugar().Fatalf("error connecting to keys redis storage: %v", err) } + requestsLimitRedisStorage := redis.NewClient(defaultRedisOption(cfg, 11)) + + ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := requestsLimitRedisStorage.Ping(ctx).Err(); err != nil { + log.Sugar().Fatalf("error connecting to requests limit redis storage: %v", err) + } + rateLimitCache := redisStorage.NewCache(rateLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) costLimitCache := redisStorage.NewCache(costLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) costStorage := redisStorage.NewStore(costRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) @@ -285,15 +293,15 @@ func main() { psCache := redisStorage.NewProviderSettingsCache(providerSettingsRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) keysCache := redisStorage.NewKeysCache(keysRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) + requestsLimitStorage := redisStorage.NewStore(requestsLimitRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) encryptor, err := encryptor.NewEncryptor(cfg.DecryptionEndpoint, cfg.EncryptionEndpoint, cfg.EnableEncrytion, cfg.EncryptionTimeout, cfg.Audience) if cfg.EnableEncrytion && err != nil { log.Sugar().Fatalf("error creating encryption client: %v", err) } - v := validator.NewValidator(costLimitCache, rateLimitCache, costStorage) - + v := validator.NewValidator(costLimitCache, rateLimitCache, costStorage, requestsLimitStorage) - m := manager.NewManager(store, costLimitCache, rateLimitCache, accessCache, keysCache) + m := manager.NewManager(store, costLimitCache, rateLimitCache, accessCache, keysCache, requestsLimitStorage) krm := manager.NewReportingManager(costStorage, store, store, v) psm := manager.NewProviderSettingsManager(store, psCache, encryptor) cpm := manager.NewCustomProvidersManager(store, cpMemStore) @@ -330,7 +338,7 @@ func main() { uv := validator.NewUserValidator(userCostLimitCache, userRateLimitCache, userCostStorage) - rec := recorder.NewRecorder(costStorage, userCostStorage, costLimitCache, userCostLimitCache, ce, store) + rec := recorder.NewRecorder(costStorage, userCostStorage, costLimitCache, userCostLimitCache, ce, store, requestsLimitStorage) rlm := manager.NewRateLimitManager(rateLimitCache, userRateLimitCache) a := auth.NewAuthenticator(psm, m, rm, store, encryptor) diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go index bec413c..638a9ee 100644 --- a/internal/authenticator/authenticator.go +++ b/internal/authenticator/authenticator.go @@ -3,6 +3,7 @@ package auth import ( "errors" "fmt" + "github.com/bricks-cloud/bricksllm/internal/provider/xcustom" "math/rand" "net/http" "strconv" @@ -204,8 +205,20 @@ func anonymize(input string) string { return string(input[0:5]) + "**********************************************" } -func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.ResponseKey, []*provider.Setting, error) { - raw, err := getApiKey(req) +func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, xCustomProviderId string) (*key.ResponseKey, []*provider.Setting, error) { + var raw string + var err error + var settings []*provider.Setting + if xcustom.IsXCustomRequest(req) { + providerSetting, er := a.psm.GetSettingViaCache(xCustomProviderId) + if er != nil { + return nil, nil, er + } + settings = []*provider.Setting{providerSetting} + raw, err = xcustom.ExtractApiKey(req, providerSetting) + } else { + raw, err = getApiKey(req) + } if err != nil { return nil, nil, err } @@ -238,6 +251,28 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.Respons return nil, nil, internal_errors.NewAuthError(fmt.Sprintf("key %s has been revoked", anonymize(raw))) } + if xcustom.IsXCustomRequest(req) { + pSetting := settings[0] + authString := strings.Replace( + pSetting.GetParam(xcustom.XCustomSettingFields.AuthMask), + "{{apikey}}", + pSetting.GetParam(xcustom.XCustomSettingFields.ApiKey), -1, + ) + location := xcustom.GetAuthLocation(pSetting.GetParam(xcustom.XCustomSettingFields.AuthLocation)) + target := pSetting.GetParam(xcustom.XCustomSettingFields.AuthTarget) + switch location { + case xcustom.AuthLocations.Query: + params := req.URL.Query() + params.Set(target, authString) + req.URL.RawQuery = params.Encode() + case xcustom.AuthLocations.Header: + req.Header.Set(target, authString) + default: + return nil, nil, errors.New("invalid xCustomAuth location") + } + return key, settings, nil + } + if strings.HasPrefix(req.URL.Path, "/api/routes") { err = a.canKeyAccessCustomRoute(req.URL.Path, key.KeyId) if err != nil { diff --git a/internal/errors/expiration_err.go b/internal/errors/expiration_err.go index cce8380..59aec34 100644 --- a/internal/errors/expiration_err.go +++ b/internal/errors/expiration_err.go @@ -1,8 +1,9 @@ package errors const ( - TtlExpiration string = "ttl" - CostLimitExpiration string = "cost-limit" + TtlExpiration string = "ttl" + CostLimitExpiration string = "cost-limit" + RequestsLimitExpiration string = "requests-limit" ) type ExpirationError struct { diff --git a/internal/event/key_reporting.go b/internal/event/key_reporting.go index 9550861..6ebf3f4 100644 --- a/internal/event/key_reporting.go +++ b/internal/event/key_reporting.go @@ -29,11 +29,13 @@ type KeyRingReportingRequest struct { Limit int `json:"limit"` Offset int `json:"offset"` Revoked *bool `json:"revoked"` + TopBy string `json:"topBy"` } type KeyRingDataPoint struct { KeyRing string `json:"keyRing"` CostInUsd float64 `json:"costInUsd"` + Requests int `json:"requests"` } type KeyRingReportingResponse struct { @@ -47,8 +49,13 @@ type SpentKeyReportingRequest struct { Offset int `json:"offset"` } +type SpentKey struct { + KeyRing string `json:"keyRing"` + LinkedKeyId string `json:"linkedKeyId"` +} + type SpentKeyReportingResponse struct { - KeyRings []string `json:"keyRings"` + Keys []SpentKey `json:"keys"` } type UsageReportingRequest struct { @@ -56,10 +63,14 @@ type UsageReportingRequest struct { } type UsageData struct { - LastDayUsage float64 `json:"lastDayUsage"` - LastWeekUsage float64 `json:"lastWeekUsage"` - LastMonthUsage float64 `json:"lastMonthUsage"` - TotalUsage float64 `json:"totalUsage"` + LastDayUsage float64 `json:"lastDayUsage"` + LastWeekUsage float64 `json:"lastWeekUsage"` + LastMonthUsage float64 `json:"lastMonthUsage"` + TotalUsage float64 `json:"totalUsage"` + LastDayUsageRequests int `json:"lastDayUsageRequests"` + LastWeekUsageRequests int `json:"lastWeekUsageRequests"` + LastMonthUsageRequests int `json:"lastMonthUsageRequests"` + TotalUsageRequests int `json:"totalUsageRequests"` } type UsageReportingResponse struct { diff --git a/internal/key/key.go b/internal/key/key.go index 59c7b18..94eb537 100644 --- a/internal/key/key.go +++ b/internal/key/key.go @@ -25,6 +25,7 @@ type UpdateKey struct { CostLimitInUsdUnit *TimeUnit `json:"costLimitInUsdUnit"` RateLimitOverTime *int `json:"rateLimitOverTime"` RateLimitUnit *TimeUnit `json:"rateLimitUnit"` + RequestsLimit *int `json:"requestsLimit"` AllowedPaths *[]PathConfig `json:"allowedPaths,omitempty"` ShouldLogRequest *bool `json:"shouldLogRequest"` ShouldLogResponse *bool `json:"shouldLogResponse"` @@ -51,6 +52,10 @@ func (uk *UpdateKey) Validate() error { invalid = append(invalid, "costLimitInUsd") } + if uk.RequestsLimit != nil && *uk.RequestsLimit < 0 { + invalid = append(invalid, "requestsLimit") + } + if uk.UpdatedAt <= 0 { invalid = append(invalid, "updatedAt") } @@ -173,6 +178,7 @@ type RequestKey struct { RotationEnabled bool `json:"rotationEnabled"` PolicyId string `json:"policyId"` IsKeyNotHashed bool `json:"isKeyNotHashed"` + RequestsLimit int `json:"requestsLimit"` } func (rk *RequestKey) Validate() error { @@ -237,6 +243,10 @@ func (rk *RequestKey) Validate() error { invalid = append(invalid, "rateLimitOverTime") } + if rk.RequestsLimit < 0 { + invalid = append(invalid, "requestsLimit") + } + if len(rk.Ttl) != 0 { _, err := time.ParseDuration(rk.Ttl) if err != nil { @@ -317,6 +327,7 @@ type ResponseKey struct { CostLimitInUsdUnit TimeUnit `json:"costLimitInUsdUnit"` RateLimitOverTime int `json:"rateLimitOverTime"` RateLimitUnit TimeUnit `json:"rateLimitUnit"` + RequestsLimit int `json:"requestsLimit"` Ttl string `json:"ttl"` KeyRing string `json:"keyRing"` SettingId string `json:"settingId"` diff --git a/internal/manager/key.go b/internal/manager/key.go index 3401db4..f4f5b2d 100644 --- a/internal/manager/key.go +++ b/internal/manager/key.go @@ -46,21 +46,27 @@ type keyCache interface { Get(keyId string) (*key.ResponseKey, error) } +type requestsLimitStorage interface { + DeleteCounter(keyId string) error +} + type Manager struct { - s Storage - clc costLimitCache - rlc rateLimitCache - ac accessCache - kc keyCache + s Storage + clc costLimitCache + rlc rateLimitCache + ac accessCache + kc keyCache + rqls requestsLimitStorage } -func NewManager(s Storage, clc costLimitCache, rlc rateLimitCache, ac accessCache, kc keyCache) *Manager { +func NewManager(s Storage, clc costLimitCache, rlc rateLimitCache, ac accessCache, kc keyCache, rqls requestsLimitStorage) *Manager { return &Manager{ - s: s, - clc: clc, - rlc: rlc, - ac: ac, - kc: kc, + s: s, + clc: clc, + rlc: rlc, + ac: ac, + kc: kc, + rqls: rqls, } } @@ -175,6 +181,12 @@ func (m *Manager) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, err return nil, err } } + if uk.RequestsLimit != nil { + err := m.rqls.DeleteCounter(id) + if err != nil { + return nil, err + } + } if uk.PolicyId != nil { if len(*uk.PolicyId) != 0 { diff --git a/internal/manager/provider_setting.go b/internal/manager/provider_setting.go index d02e699..c67fc0a 100644 --- a/internal/manager/provider_setting.go +++ b/internal/manager/provider_setting.go @@ -3,6 +3,8 @@ package manager import ( "encoding/json" "fmt" + "github.com/bricks-cloud/bricksllm/internal/provider/xcustom" + "slices" "strconv" "strings" "time" @@ -39,6 +41,8 @@ type ProviderSettingsManager struct { Encryptor Encryptor } +var nativelySupportedProviders = []string{"openai", "anthropic", "azure", "vllm", "deepinfra", "bedrock", "xCustom"} + func NewProviderSettingsManager(s ProviderSettingsStorage, cache ProviderSettingsCache, encryptor Encryptor) *ProviderSettingsManager { return &ProviderSettingsManager{ Storage: s, @@ -48,7 +52,7 @@ func NewProviderSettingsManager(s ProviderSettingsStorage, cache ProviderSetting } func isProviderNativelySupported(provider string) bool { - return provider == "openai" || provider == "anthropic" || provider == "azure" || provider == "vllm" || provider == "deepinfra" || provider == "bedrock" + return slices.Contains(nativelySupportedProviders, provider) } func findMissingAuthParams(providerName string, params map[string]string) string { @@ -99,6 +103,25 @@ func findMissingAuthParams(providerName string, params map[string]string) string } } + if providerName == "xCustom" { + val := params["apikey"] + if len(val) == 0 { + missingFields = append(missingFields, "apikey") + } + val = params["endpoint"] + if len(val) == 0 { + missingFields = append(missingFields, "endpoint") + } + val = params["authLocation"] + if len(val) == 0 { + missingFields = append(missingFields, "authLocation") + } + val = params["authTemplate"] + if !strings.Contains(val, "{{apikey}}") { + missingFields = append(missingFields, "authTemplate") + } + } + return strings.Join(missingFields, ",") } @@ -160,6 +183,18 @@ func (m *ProviderSettingsManager) CreateSetting(setting *provider.Setting) (*pro setting.CreatedAt = time.Now().Unix() setting.UpdatedAt = time.Now().Unix() + if setting.Provider == "xCustom" { + advancedSetting, err := xcustom.AdvancedXCustomSetting(setting.Setting) + if err != nil { + return nil, err + } + merged := setting.Setting + for k, v := range advancedSetting { + merged[k] = v + } + setting.Setting = merged + } + if m.Encryptor.Enabled() { params, err := m.EncryptParams(setting.UpdatedAt, setting.Provider, setting.Setting) if err != nil { @@ -183,15 +218,10 @@ func (m *ProviderSettingsManager) UpdateSetting(id string, setting *provider.Upd } if len(setting.Setting) != 0 { - if err := m.validateSettings(existing.Provider, setting.Setting); err != nil { + merged, err := m.getMergedSettings(existing, setting.Setting) + if err != nil { return nil, err } - - merged := existing.Setting - for k, v := range setting.Setting { - merged[k] = v - } - setting.Setting = merged } @@ -214,6 +244,31 @@ func (m *ProviderSettingsManager) UpdateSetting(id string, setting *provider.Upd return m.Storage.UpdateProviderSetting(id, setting) } +func (m *ProviderSettingsManager) getMergedSettings(existing *provider.Setting, setting map[string]string) (map[string]string, error) { + merged := existing.Setting + apikey, ok := setting["apikey"] + if ok && apikey == "revoked" { + merged["apikey"] = apikey + return merged, nil + } + for k, v := range setting { + merged[k] = v + } + if existing.Provider == "xCustom" { + advancedSetting, err := xcustom.AdvancedXCustomSetting(setting) + if err != nil { + return nil, err + } + for k, v := range advancedSetting { + merged[k] = v + } + } + if err := m.validateSettings(existing.Provider, merged); err != nil { + return nil, err + } + return merged, nil +} + func (m *ProviderSettingsManager) GetSettingViaCache(id string) (*provider.Setting, error) { setting, _ := m.Cache.Get(id) diff --git a/internal/manager/reporting.go b/internal/manager/reporting.go index 068fcfc..2560a37 100644 --- a/internal/manager/reporting.go +++ b/internal/manager/reporting.go @@ -14,7 +14,7 @@ type costStorage interface { type keyStorage interface { GetKey(keyId string) (*key.ResponseKey, error) - GetSpentKeyRings(tags []string, order string, limit, offset int, validator func(*key.ResponseKey) bool) ([]string, error) + GetSpentKeys(tags []string, order string, limit, offset int, validator func(*key.ResponseKey) bool) ([]event.SpentKey, error) } type keyValidator interface { @@ -31,7 +31,7 @@ type eventStorage interface { GetCustomIds(keyId string) ([]string, error) GetTopKeyDataPoints(start, end int64, tags, keyIds []string, order string, limit, offset int, name string, revoked *bool) ([]*event.KeyDataPoint, error) - GetTopKeyRingDataPoints(start, end int64, tags []string, order string, limit, offset int, revoked *bool) ([]*event.KeyRingDataPoint, error) + GetTopKeyRingDataPoints(start, end int64, tags []string, order string, limit, offset int, revoked *bool, topBy string) ([]*event.KeyRingDataPoint, error) GetUsageData(tags []string) (*event.UsageData, error) } @@ -130,7 +130,7 @@ func (rm *ReportingManager) GetTopKeyRingReporting(r *event.KeyRingReportingRequ return nil, internal_errors.NewValidationError("key reporting request order can only be desc or asc") } - dataPoints, err := rm.es.GetTopKeyRingDataPoints(r.Start, r.End, r.Tags, r.Order, r.Limit, r.Offset, r.Revoked) + dataPoints, err := rm.es.GetTopKeyRingDataPoints(r.Start, r.End, r.Tags, r.Order, r.Limit, r.Offset, r.Revoked, r.TopBy) if err != nil { return nil, err } @@ -160,12 +160,12 @@ func (rm *ReportingManager) GetSpentKeyReporting(r *event.SpentKeyReportingReque return true } - spentKeys, err := rm.ks.GetSpentKeyRings(r.Tags, r.Order, r.Limit, r.Offset, validator) + spentKeys, err := rm.ks.GetSpentKeys(r.Tags, r.Order, r.Limit, r.Offset, validator) if err != nil { return nil, err } return &event.SpentKeyReportingResponse{ - KeyRings: spentKeys, + Keys: spentKeys, }, nil } diff --git a/internal/message/consumer.go b/internal/message/consumer.go index 115f590..5d9dcd2 100644 --- a/internal/message/consumer.go +++ b/internal/message/consumer.go @@ -18,6 +18,7 @@ type recorder interface { RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error RecordUserSpend(userId string, micros int64, costLimitUnit key.TimeUnit) error RecordEvent(e *event.Event) error + RecordKeyRequestSpent(keyId string) error } func NewConsumer(mc <-chan Message, log *zap.Logger, num int, handle func(Message) error) *Consumer { diff --git a/internal/message/handler.go b/internal/message/handler.go index 1e578c1..f03f3b8 100644 --- a/internal/message/handler.go +++ b/internal/message/handler.go @@ -331,6 +331,12 @@ func (h *Handler) HandleEventWithRequestAndResponse(m Message) error { var u *user.User + err = h.recorder.RecordKeyRequestSpent(e.Event.KeyId) + if err != nil { + telemetry.Incr("bricksllm.message.handler.handle_event_with_request_and_response.record_key_request_spend_error", nil, 1) + h.log.Debug("error when recording key request spend", zap.Error(err)) + } + if e.Event.CostInUsd != 0 { micros := int64(e.Event.CostInUsd * 1000000) err = h.recorder.RecordKeySpend(e.Event.KeyId, micros, e.Key.CostLimitInUsdUnit) diff --git a/internal/provider/xcustom/xcustom.go b/internal/provider/xcustom/xcustom.go new file mode 100644 index 0000000..746ce86 --- /dev/null +++ b/internal/provider/xcustom/xcustom.go @@ -0,0 +1,101 @@ +package xcustom + +import ( + "fmt" + "github.com/bricks-cloud/bricksllm/internal/provider" + "net/http" + "regexp" + + "strings" +) + +var XCustomSettingFields = struct { + ApiKey string + Endpoint string + AuthLocation string + AuthTemplate string + AuthTarget string + AuthMask string +}{ + ApiKey: "apikey", + Endpoint: "endpoint", + AuthLocation: "authLocation", + AuthTemplate: "authTemplate", + AuthTarget: "authTarget", + AuthMask: "authMask", +} + +type AuthLocation string + +var AuthLocations = struct { + Header AuthLocation + Query AuthLocation + Unknown AuthLocation +}{ + Header: AuthLocation("header"), + Query: AuthLocation("query"), + Unknown: AuthLocation("unknown"), +} + +const XProviderIdParam = "x_provider_id" + +func IsXCustomRequest(req *http.Request) bool { + return strings.HasPrefix(req.URL.RequestURI(), "/api/providers/xCustom/") +} + +func AdvancedXCustomSetting(src map[string]string) (map[string]string, error) { + rawLocation := src[XCustomSettingFields.AuthLocation] + location := GetAuthLocation(rawLocation) + var templateSeparator string + switch location { + case AuthLocations.Header: + templateSeparator = ":" + case AuthLocations.Query: + templateSeparator = "=" + default: + return nil, fmt.Errorf("unknown auth location: %s", location) + } + templateArr := strings.Split(src[XCustomSettingFields.AuthTemplate], templateSeparator) + if len(templateArr) != 2 { + return nil, fmt.Errorf("invalid auth template: %s", src[XCustomSettingFields.AuthTemplate]) + } + target := strings.TrimSpace(templateArr[0]) + mask := strings.TrimSpace(templateArr[1]) + return map[string]string{ + XCustomSettingFields.AuthTarget: target, + XCustomSettingFields.AuthMask: mask, + }, nil +} + +func ExtractApiKey(req *http.Request, pSetting *provider.Setting) (string, error) { + location := GetAuthLocation(pSetting.GetParam(XCustomSettingFields.AuthLocation)) + target := strings.TrimSpace(pSetting.GetParam(XCustomSettingFields.AuthTarget)) + var reqAuthStr string + switch location { + case AuthLocations.Header: + reqAuthStr = req.Header.Get(target) + case AuthLocations.Query: + reqAuthStr = req.URL.Query().Get(target) + default: + return "", fmt.Errorf("unknown auth location: %s", location) + } + mask := strings.TrimSpace(pSetting.GetParam(XCustomSettingFields.AuthMask)) + regexStr := strings.Replace(mask, "{{apikey}}", "(?P.*)", -1) + regex := regexp.MustCompile(regexStr) + matches := regex.FindStringSubmatch(reqAuthStr) + if len(matches) < 2 { + return "", fmt.Errorf("error extracting apikey: %s", pSetting.Id) + } + return strings.TrimSpace(matches[1]), nil +} + +func GetAuthLocation(raw string) AuthLocation { + switch raw { + case "header": + return AuthLocations.Header + case "query": + return AuthLocations.Query + default: + return AuthLocations.Unknown + } +} diff --git a/internal/recorder/recorder.go b/internal/recorder/recorder.go index a48f210..5c8b73f 100644 --- a/internal/recorder/recorder.go +++ b/internal/recorder/recorder.go @@ -6,12 +6,13 @@ import ( ) type Recorder struct { - s Store - c Cache - us Store - uc Cache - ce CostEstimator - es EventsStore + s Store + c Cache + us Store + uc Cache + ce CostEstimator + es EventsStore + reqLimitStore Store } type EventsStore interface { @@ -31,14 +32,15 @@ type CostEstimator interface { EstimateCompletionCost(model string, tks int) (float64, error) } -func NewRecorder(s, us Store, c, uc Cache, ce CostEstimator, es EventsStore) *Recorder { +func NewRecorder(s, us Store, c, uc Cache, ce CostEstimator, es EventsStore, reqLimitStore Store) *Recorder { return &Recorder{ - s: s, - c: c, - us: us, - uc: uc, - ce: ce, - es: es, + s: s, + c: c, + us: us, + uc: uc, + ce: ce, + es: es, + reqLimitStore: reqLimitStore, } } @@ -74,6 +76,10 @@ func (r *Recorder) RecordKeySpend(keyId string, micros int64, costLimitUnit key. return nil } +func (r *Recorder) RecordKeyRequestSpent(keyId string) error { + return r.reqLimitStore.IncrementCounter(keyId, 1) +} + func (r *Recorder) RecordEvent(e *event.Event) error { return r.es.InsertEvent(e) } diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go index 6e2a75e..5a215f5 100644 --- a/internal/server/web/proxy/middleware.go +++ b/internal/server/web/proxy/middleware.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/bricks-cloud/bricksllm/internal/provider/xcustom" "io" "net/http" "strconv" @@ -67,7 +68,7 @@ type deepinfraEstimator interface { } type authenticator interface { - AuthenticateHttpRequest(req *http.Request) (*key.ResponseKey, []*provider.Setting, error) + AuthenticateHttpRequest(req *http.Request, xCustomProviderId string) (*key.ResponseKey, []*provider.Setting, error) } type validator interface { @@ -304,7 +305,7 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag return } - kc, settings, err := a.AuthenticateHttpRequest(c.Request) + kc, settings, err := a.AuthenticateHttpRequest(c.Request, c.Param(xcustom.XProviderIdParam)) enrichedEvent.Key = kc _, ok := err.(notAuthorizedError) if ok { diff --git a/internal/server/web/proxy/proxy.go b/internal/server/web/proxy/proxy.go index e6078f4..8bf70b9 100644 --- a/internal/server/web/proxy/proxy.go +++ b/internal/server/web/proxy/proxy.go @@ -220,6 +220,9 @@ func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyMan router.POST("/api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id/cancel", getCancelVectorStoreFileBatchHandler(prod, client)) router.GET("/api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id/files", getListVectorStoreFileBatchFilesHandler(prod, client)) + // codio xCustom + router.Any("/api/providers/xCustom/:x_provider_id/*wildcard", getXCustomHandler(prod)) + srv := &http.Server{ Addr: ":8002", Handler: router, diff --git a/internal/server/web/proxy/x_custom.go b/internal/server/web/proxy/x_custom.go new file mode 100644 index 0000000..c0132f6 --- /dev/null +++ b/internal/server/web/proxy/x_custom.go @@ -0,0 +1,74 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "github.com/bricks-cloud/bricksllm/internal/provider" + "github.com/bricks-cloud/bricksllm/internal/provider/xcustom" + "github.com/bricks-cloud/bricksllm/internal/telemetry" + "github.com/bricks-cloud/bricksllm/internal/util" + "github.com/gin-gonic/gin" + "net/http" + "net/http/httputil" + "net/url" + "strings" +) + +func getXCustomHandler(prod bool) gin.HandlerFunc { + return func(c *gin.Context) { + log := util.GetLogFromCtx(c) + telemetry.Incr("bricksllm.proxy.get_x_custom_handler.requests", nil, 1) + + if c == nil || c.Request == nil { + JSON(c, http.StatusInternalServerError, "[BricksLLM] context is empty") + return + } + + ctx, cancel := context.WithTimeout(context.Background(), c.GetDuration("requestTimeout")) + defer cancel() + + providerId := c.Param(xcustom.XProviderIdParam) + rawProviderSettings, exists := c.Get("settings") + if !exists { + logError(log, "error provider setting", prod, errors.New("provider setting not found")) + c.JSON(http.StatusInternalServerError, "[BricksLLM] no settings found") + return + } + settings, ok := rawProviderSettings.([]*provider.Setting) + if !ok { + logError(log, "error provider setting", prod, errors.New("incorrect setting")) + c.JSON(http.StatusInternalServerError, "[BricksLLM] incorrect provider setting") + return + } + var providerSetting *provider.Setting + for _, setting := range settings { + if setting.Id == providerId { + providerSetting = setting + } + } + if providerSetting == nil { + logError(log, "error provider setting", prod, errors.New("provider setting not found")) + c.JSON(http.StatusInternalServerError, "[BricksLLM] no settings found") + return + } + wildcard := c.Param("wildcard") + endpoint := strings.TrimSuffix(providerSetting.GetParam("endpoint"), "/") + targetUrl := fmt.Sprintf("%s%s", endpoint, wildcard) + target, e := url.Parse(targetUrl) + if e != nil { + logError(log, "error parsing target url", prod, e) + c.JSON(http.StatusInternalServerError, "[BricksLLM] invalid endpoint") + return + } + + proxy := &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(target) + r.Out.URL.Path, r.Out.URL.RawPath = target.Path, target.RawPath + r.Out.WithContext(ctx) + }, + } + proxy.ServeHTTP(c.Writer, c.Request) + } +} diff --git a/internal/storage/postgresql/event.go b/internal/storage/postgresql/event.go index 50d5a41..274d19b 100644 --- a/internal/storage/postgresql/event.go +++ b/internal/storage/postgresql/event.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "slices" "strings" "time" @@ -13,6 +14,8 @@ import ( "github.com/lib/pq" ) +var allowedTopBy = []string{"total_cost_in_usd", "total_requests"} + func (s *Store) CreateEventsByDayTable() error { createTableQuery := ` CREATE TABLE IF NOT EXISTS event_agg_by_day ( @@ -471,7 +474,7 @@ func (s *Store) GetTopKeyDataPoints(start, end int64, tags, keyIds []string, ord return data, nil } -func (s *Store) GetTopKeyRingDataPoints(start, end int64, tags []string, order string, limit, offset int, revoked *bool) ([]*event.KeyRingDataPoint, error) { +func (s *Store) GetTopKeyRingDataPoints(start, end int64, tags []string, order string, limit, offset int, revoked *bool, topBy string) ([]*event.KeyRingDataPoint, error) { args := []any{} condition := "" condition2 := "" @@ -494,21 +497,12 @@ func (s *Store) GetTopKeyRingDataPoints(start, end int64, tags []string, order s } if len(tags) > 0 { - condition2 += fmt.Sprintf("AND keys.tags @> $%d", index) + condition2 += fmt.Sprintf("AND events.tags @> $%d", index) args = append(args, pq.Array(tags)) index++ } - if revoked != nil { - bools := "False" - if *revoked { - bools = "True" - } - - condition2 += fmt.Sprintf(" AND keys.revoked = %s", bools) - } - query := fmt.Sprintf(` WITH keys_table AS ( @@ -517,7 +511,8 @@ func (s *Store) GetTopKeyRingDataPoints(start, end int64, tags []string, order s ( SELECT key_ring, - SUM(cost_in_usd) AS total_cost_in_usd + SUM(cost_in_usd) AS total_cost_in_usd, + COUNT(*) AS total_requests FROM events LEFT JOIN keys ON keys.key_id = events.key_id @@ -531,9 +526,13 @@ func (s *Store) GetTopKeyRingDataPoints(start, end int64, tags []string, order s qorder = "ASC" } + qtopBy := "total_cost_in_usd" + if topBy != "" && slices.Contains(allowedTopBy, topBy) { + qtopBy = topBy + } query += fmt.Sprintf(` - ORDER BY total_cost_in_usd %s -`, qorder) + ORDER BY %s %s +`, qtopBy, qorder) if limit != 0 { query += fmt.Sprintf(` @@ -558,6 +557,7 @@ func (s *Store) GetTopKeyRingDataPoints(start, end int64, tags []string, order s additional := []any{ &keyRing, &e.CostInUsd, + &e.Requests, } if err := rows.Scan( @@ -597,10 +597,14 @@ func (s *Store) GetUsageData(tags []string) (*event.UsageData, error) { COALESCE(SUM(cost_in_usd), 0) AS total_cost_in_usd, COALESCE(SUM(CASE WHEN created_at > %d THEN cost_in_usd ELSE 0 END), 0) AS total_cost_in_usd_last_day, COALESCE(SUM(CASE WHEN created_at > %d THEN cost_in_usd ELSE 0 END), 0) AS total_cost_in_usd_last_week, - COALESCE(SUM(CASE WHEN created_at > %d THEN cost_in_usd ELSE 0 END), 0) AS total_cost_in_usd_last_month + COALESCE(SUM(CASE WHEN created_at > %d THEN cost_in_usd ELSE 0 END), 0) AS total_cost_in_usd_last_month, + COALESCE(SUM(1), 0) AS total_requests, + COALESCE(SUM(CASE WHEN created_at > %d THEN 1 ELSE 0 END), 0) AS total_requests_last_day, + COALESCE(SUM(CASE WHEN created_at > %d THEN 1 ELSE 0 END), 0) AS total_requests_last_week, + COALESCE(SUM(CASE WHEN created_at > %d THEN 1 ELSE 0 END), 0) AS total_requests_last_month FROM events WHERE %s - `, dayAgo, weekAgo, monthAgo, condition) + `, dayAgo, weekAgo, monthAgo, dayAgo, weekAgo, monthAgo, condition) ctx, cancel := context.WithTimeout(context.Background(), s.rt) defer cancel() @@ -611,6 +615,10 @@ func (s *Store) GetUsageData(tags []string) (*event.UsageData, error) { &data.LastDayUsage, &data.LastWeekUsage, &data.LastMonthUsage, + &data.TotalUsageRequests, + &data.LastDayUsageRequests, + &data.LastWeekUsageRequests, + &data.LastMonthUsageRequests, ); err != nil { if err == sql.ErrNoRows { return nil, nil diff --git a/internal/storage/postgresql/key.go b/internal/storage/postgresql/key.go index c129aa9..e92f114 100644 --- a/internal/storage/postgresql/key.go +++ b/internal/storage/postgresql/key.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" + "github.com/bricks-cloud/bricksllm/internal/event" "strings" internal_errors "github.com/bricks-cloud/bricksllm/internal/errors" @@ -56,7 +57,7 @@ func (s *Store) AlterKeysTable() error { END IF; END $$; - ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[], ADD COLUMN IF NOT EXISTS should_log_request BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS should_log_response BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS rotation_enabled BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS policy_id VARCHAR(255) NOT NULL DEFAULT '', ADD COLUMN IF NOT EXISTS is_key_not_hashed BOOLEAN NOT NULL DEFAULT FALSE; + ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[], ADD COLUMN IF NOT EXISTS should_log_request BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS should_log_response BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS rotation_enabled BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS policy_id VARCHAR(255) NOT NULL DEFAULT '', ADD COLUMN IF NOT EXISTS is_key_not_hashed BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS requests_limit INT NOT NULL DEFAULT 0; ` ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt) @@ -189,6 +190,7 @@ func (s *Store) GetKeys(tags, keyIds []string, provider string) ([]*key.Response &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { return nil, err } @@ -314,6 +316,7 @@ func (s *Store) GetKeysV2(tags, keyIds []string, revoked *bool, limit, offset in &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { return nil, err } @@ -393,6 +396,7 @@ func (s *Store) GetKeyByHash(hash string) (*key.ResponseKey, error) { &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ) if err != nil { @@ -457,6 +461,7 @@ func (s *Store) GetKey(keyId string) (*key.ResponseKey, error) { &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { return nil, err } @@ -483,7 +488,7 @@ func (s *Store) GetKey(keyId string) (*key.ResponseKey, error) { return keys[0], nil } -func (s *Store) GetSpentKeyRings(tags []string, order string, limit, offset int, validator func(*key.ResponseKey) bool) ([]string, error) { +func (s *Store) GetSpentKeys(tags []string, order string, limit, offset int, validator func(*key.ResponseKey) bool) ([]event.SpentKey, error) { args := []any{} condition := "" @@ -524,7 +529,7 @@ func (s *Store) GetSpentKeyRings(tags []string, order string, limit, offset int, } defer rows.Close() - invalidKeyRings := []string{} + invalidKeyRings := []event.SpentKey{} for rows.Next() { var k key.ResponseKey var settingId sql.NullString @@ -553,6 +558,7 @@ func (s *Store) GetSpentKeyRings(tags []string, order string, limit, offset int, &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { return nil, err } @@ -569,7 +575,10 @@ func (s *Store) GetSpentKeyRings(tags []string, order string, limit, offset int, } if !validator(pk) { - invalidKeyRings = append(invalidKeyRings, pk.KeyRing) + invalidKeyRings = append(invalidKeyRings, event.SpentKey{ + KeyRing: pk.KeyRing, + LinkedKeyId: pk.KeyId, + }) } } @@ -615,6 +624,7 @@ func (s *Store) GetAllKeys() ([]*key.ResponseKey, error) { &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { return nil, err } @@ -675,6 +685,7 @@ func (s *Store) GetUpdatedKeys(updatedAt int64) ([]*key.ResponseKey, error) { &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { return nil, err } @@ -853,6 +864,7 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { if err == sql.ErrNoRows { return nil, internal_errors.NewNotFoundError(fmt.Sprintf("key not found for id: %s", id)) @@ -877,8 +889,8 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) { query := ` - INSERT INTO keys (name, created_at, updated_at, tags, revoked, key_id, key, revoked_reason, cost_limit_in_usd, cost_limit_in_usd_over_time, cost_limit_in_usd_unit, rate_limit_over_time, rate_limit_unit, ttl, key_ring, setting_id, allowed_paths, setting_ids, should_log_request, should_log_response, rotation_enabled, policy_id, is_key_not_hashed) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23) + INSERT INTO keys (name, created_at, updated_at, tags, revoked, key_id, key, revoked_reason, cost_limit_in_usd, cost_limit_in_usd_over_time, cost_limit_in_usd_unit, rate_limit_over_time, rate_limit_unit, ttl, key_ring, setting_id, allowed_paths, setting_ids, should_log_request, should_log_response, rotation_enabled, policy_id, is_key_not_hashed, requests_limit) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24) RETURNING *; ` @@ -911,6 +923,7 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) { rk.RotationEnabled, rk.PolicyId, rk.IsKeyNotHashed, + rk.RequestsLimit, } ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt) @@ -944,6 +957,7 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) { &k.RotationEnabled, &k.PolicyId, &k.IsKeyNotHashed, + &k.RequestsLimit, ); err != nil { return nil, err } diff --git a/internal/validator/validator.go b/internal/validator/validator.go index 7efdfb9..c3df819 100644 --- a/internal/validator/validator.go +++ b/internal/validator/validator.go @@ -21,21 +21,28 @@ type costLimitStorage interface { GetCounter(keyId string) (int64, error) } +type requestsLimitStorage interface { + GetCounter(keyId string) (int64, error) +} + type Validator struct { - clc costLimitCache - rlc rateLimitCache - cls costLimitStorage + clc costLimitCache + rlc rateLimitCache + cls costLimitStorage + rqls requestsLimitStorage } func NewValidator( clc costLimitCache, rlc rateLimitCache, cls costLimitStorage, + rqls requestsLimitStorage, ) *Validator { return &Validator{ - clc: clc, - rlc: rlc, - cls: cls, + clc: clc, + rlc: rlc, + cls: cls, + rqls: rqls, } } @@ -53,7 +60,12 @@ func (v *Validator) Validate(k *key.ResponseKey, promptCost float64) error { return internal_errors.NewExpirationError("api key expired", internal_errors.TtlExpiration) } - err := v.validateRateLimitOverTime(k.KeyId, k.RateLimitOverTime, k.RateLimitUnit) + err := v.validateRequestsLimit(k.KeyId, k.RequestsLimit) + if err != nil { + return err + } + + err = v.validateRateLimitOverTime(k.KeyId, k.RateLimitOverTime, k.RateLimitUnit) if err != nil { return err } @@ -136,3 +148,17 @@ func (v *Validator) validateCostLimit(keyId string, costLimit float64) error { return nil } + +func (v *Validator) validateRequestsLimit(keyId string, requestsLimit int) error { + if requestsLimit == 0 { + return nil + } + existingTotalRequests, err := v.rqls.GetCounter(keyId) + if err != nil { + return errors.New("failed to get total requests") + } + if existingTotalRequests >= int64(requestsLimit) { + return internal_errors.NewExpirationError(fmt.Sprintf("total requests limit: %d, has been reached", requestsLimit), internal_errors.RequestsLimitExpiration) + } + return nil +}