Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions cmd/bricksllm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,14 @@ func main() {
log.Sugar().Fatalf("error connecting to keys redis storage: %v", err)
}

requestsLimitRedisStorage := redis.NewClient(defaultRedisOption(cfg, 11))

ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := requestsLimitRedisStorage.Ping(ctx).Err(); err != nil {
log.Sugar().Fatalf("error connecting to requests limit redis storage: %v", err)
}

rateLimitCache := redisStorage.NewCache(rateLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
costLimitCache := redisStorage.NewCache(costLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
costStorage := redisStorage.NewStore(costRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
Expand All @@ -285,15 +293,15 @@ func main() {

psCache := redisStorage.NewProviderSettingsCache(providerSettingsRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
keysCache := redisStorage.NewKeysCache(keysRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
requestsLimitStorage := redisStorage.NewStore(requestsLimitRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)

encryptor, err := encryptor.NewEncryptor(cfg.DecryptionEndpoint, cfg.EncryptionEndpoint, cfg.EnableEncrytion, cfg.EncryptionTimeout, cfg.Audience)
if cfg.EnableEncrytion && err != nil {
log.Sugar().Fatalf("error creating encryption client: %v", err)
}
v := validator.NewValidator(costLimitCache, rateLimitCache, costStorage)

v := validator.NewValidator(costLimitCache, rateLimitCache, costStorage, requestsLimitStorage)

m := manager.NewManager(store, costLimitCache, rateLimitCache, accessCache, keysCache)
m := manager.NewManager(store, costLimitCache, rateLimitCache, accessCache, keysCache, requestsLimitStorage)
krm := manager.NewReportingManager(costStorage, store, store, v)
psm := manager.NewProviderSettingsManager(store, psCache, encryptor)
cpm := manager.NewCustomProvidersManager(store, cpMemStore)
Expand Down Expand Up @@ -330,7 +338,7 @@ func main() {

uv := validator.NewUserValidator(userCostLimitCache, userRateLimitCache, userCostStorage)

rec := recorder.NewRecorder(costStorage, userCostStorage, costLimitCache, userCostLimitCache, ce, store)
rec := recorder.NewRecorder(costStorage, userCostStorage, costLimitCache, userCostLimitCache, ce, store, requestsLimitStorage)
rlm := manager.NewRateLimitManager(rateLimitCache, userRateLimitCache)
a := auth.NewAuthenticator(psm, m, rm, store, encryptor)

Expand Down
39 changes: 37 additions & 2 deletions internal/authenticator/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package auth
import (
"errors"
"fmt"
"github.com/bricks-cloud/bricksllm/internal/provider/xcustom"
"math/rand"
"net/http"
"strconv"
Expand Down Expand Up @@ -204,8 +205,20 @@ func anonymize(input string) string {
return string(input[0:5]) + "**********************************************"
}

func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.ResponseKey, []*provider.Setting, error) {
raw, err := getApiKey(req)
func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, xCustomProviderId string) (*key.ResponseKey, []*provider.Setting, error) {
var raw string
var err error
var settings []*provider.Setting
if xcustom.IsXCustomRequest(req) {
providerSetting, er := a.psm.GetSettingViaCache(xCustomProviderId)
if er != nil {
return nil, nil, er
}
settings = []*provider.Setting{providerSetting}
raw, err = xcustom.ExtractApiKey(req, providerSetting)
} else {
raw, err = getApiKey(req)
}
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -238,6 +251,28 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.Respons
return nil, nil, internal_errors.NewAuthError(fmt.Sprintf("key %s has been revoked", anonymize(raw)))
}

if xcustom.IsXCustomRequest(req) {
pSetting := settings[0]
authString := strings.Replace(
pSetting.GetParam(xcustom.XCustomSettingFields.AuthMask),
"{{apikey}}",
pSetting.GetParam(xcustom.XCustomSettingFields.ApiKey), -1,
)
location := xcustom.GetAuthLocation(pSetting.GetParam(xcustom.XCustomSettingFields.AuthLocation))
target := pSetting.GetParam(xcustom.XCustomSettingFields.AuthTarget)
switch location {
case xcustom.AuthLocations.Query:
params := req.URL.Query()
params.Set(target, authString)
req.URL.RawQuery = params.Encode()
case xcustom.AuthLocations.Header:
req.Header.Set(target, authString)
default:
return nil, nil, errors.New("invalid xCustomAuth location")
}
return key, settings, nil
}

if strings.HasPrefix(req.URL.Path, "/api/routes") {
err = a.canKeyAccessCustomRoute(req.URL.Path, key.KeyId)
if err != nil {
Expand Down
5 changes: 3 additions & 2 deletions internal/errors/expiration_err.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package errors

const (
TtlExpiration string = "ttl"
CostLimitExpiration string = "cost-limit"
TtlExpiration string = "ttl"
CostLimitExpiration string = "cost-limit"
RequestsLimitExpiration string = "requests-limit"
)

type ExpirationError struct {
Expand Down
21 changes: 16 additions & 5 deletions internal/event/key_reporting.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ type KeyRingReportingRequest struct {
Limit int `json:"limit"`
Offset int `json:"offset"`
Revoked *bool `json:"revoked"`
TopBy string `json:"topBy"`
}

type KeyRingDataPoint struct {
KeyRing string `json:"keyRing"`
CostInUsd float64 `json:"costInUsd"`
Requests int `json:"requests"`
}

type KeyRingReportingResponse struct {
Expand All @@ -47,19 +49,28 @@ type SpentKeyReportingRequest struct {
Offset int `json:"offset"`
}

type SpentKey struct {
KeyRing string `json:"keyRing"`
LinkedKeyId string `json:"linkedKeyId"`
}

type SpentKeyReportingResponse struct {
KeyRings []string `json:"keyRings"`
Keys []SpentKey `json:"keys"`
}

type UsageReportingRequest struct {
Tags []string `json:"tags"`
}

type UsageData struct {
LastDayUsage float64 `json:"lastDayUsage"`
LastWeekUsage float64 `json:"lastWeekUsage"`
LastMonthUsage float64 `json:"lastMonthUsage"`
TotalUsage float64 `json:"totalUsage"`
LastDayUsage float64 `json:"lastDayUsage"`
LastWeekUsage float64 `json:"lastWeekUsage"`
LastMonthUsage float64 `json:"lastMonthUsage"`
TotalUsage float64 `json:"totalUsage"`
LastDayUsageRequests int `json:"lastDayUsageRequests"`
LastWeekUsageRequests int `json:"lastWeekUsageRequests"`
LastMonthUsageRequests int `json:"lastMonthUsageRequests"`
TotalUsageRequests int `json:"totalUsageRequests"`
}

type UsageReportingResponse struct {
Expand Down
11 changes: 11 additions & 0 deletions internal/key/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type UpdateKey struct {
CostLimitInUsdUnit *TimeUnit `json:"costLimitInUsdUnit"`
RateLimitOverTime *int `json:"rateLimitOverTime"`
RateLimitUnit *TimeUnit `json:"rateLimitUnit"`
RequestsLimit *int `json:"requestsLimit"`
AllowedPaths *[]PathConfig `json:"allowedPaths,omitempty"`
ShouldLogRequest *bool `json:"shouldLogRequest"`
ShouldLogResponse *bool `json:"shouldLogResponse"`
Expand All @@ -51,6 +52,10 @@ func (uk *UpdateKey) Validate() error {
invalid = append(invalid, "costLimitInUsd")
}

if uk.RequestsLimit != nil && *uk.RequestsLimit < 0 {
invalid = append(invalid, "requestsLimit")
}

if uk.UpdatedAt <= 0 {
invalid = append(invalid, "updatedAt")
}
Expand Down Expand Up @@ -173,6 +178,7 @@ type RequestKey struct {
RotationEnabled bool `json:"rotationEnabled"`
PolicyId string `json:"policyId"`
IsKeyNotHashed bool `json:"isKeyNotHashed"`
RequestsLimit int `json:"requestsLimit"`
}

func (rk *RequestKey) Validate() error {
Expand Down Expand Up @@ -237,6 +243,10 @@ func (rk *RequestKey) Validate() error {
invalid = append(invalid, "rateLimitOverTime")
}

if rk.RequestsLimit < 0 {
invalid = append(invalid, "requestsLimit")
}

if len(rk.Ttl) != 0 {
_, err := time.ParseDuration(rk.Ttl)
if err != nil {
Expand Down Expand Up @@ -317,6 +327,7 @@ type ResponseKey struct {
CostLimitInUsdUnit TimeUnit `json:"costLimitInUsdUnit"`
RateLimitOverTime int `json:"rateLimitOverTime"`
RateLimitUnit TimeUnit `json:"rateLimitUnit"`
RequestsLimit int `json:"requestsLimit"`
Ttl string `json:"ttl"`
KeyRing string `json:"keyRing"`
SettingId string `json:"settingId"`
Expand Down
34 changes: 23 additions & 11 deletions internal/manager/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,27 @@ type keyCache interface {
Get(keyId string) (*key.ResponseKey, error)
}

type requestsLimitStorage interface {
DeleteCounter(keyId string) error
}

type Manager struct {
s Storage
clc costLimitCache
rlc rateLimitCache
ac accessCache
kc keyCache
s Storage
clc costLimitCache
rlc rateLimitCache
ac accessCache
kc keyCache
rqls requestsLimitStorage
}

func NewManager(s Storage, clc costLimitCache, rlc rateLimitCache, ac accessCache, kc keyCache) *Manager {
func NewManager(s Storage, clc costLimitCache, rlc rateLimitCache, ac accessCache, kc keyCache, rqls requestsLimitStorage) *Manager {
return &Manager{
s: s,
clc: clc,
rlc: rlc,
ac: ac,
kc: kc,
s: s,
clc: clc,
rlc: rlc,
ac: ac,
kc: kc,
rqls: rqls,
}
}

Expand Down Expand Up @@ -175,6 +181,12 @@ func (m *Manager) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, err
return nil, err
}
}
if uk.RequestsLimit != nil {
err := m.rqls.DeleteCounter(id)
if err != nil {
return nil, err
}
}

if uk.PolicyId != nil {
if len(*uk.PolicyId) != 0 {
Expand Down
71 changes: 63 additions & 8 deletions internal/manager/provider_setting.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package manager
import (
"encoding/json"
"fmt"
"github.com/bricks-cloud/bricksllm/internal/provider/xcustom"
"slices"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -39,6 +41,8 @@ type ProviderSettingsManager struct {
Encryptor Encryptor
}

var nativelySupportedProviders = []string{"openai", "anthropic", "azure", "vllm", "deepinfra", "bedrock", "xCustom"}

func NewProviderSettingsManager(s ProviderSettingsStorage, cache ProviderSettingsCache, encryptor Encryptor) *ProviderSettingsManager {
return &ProviderSettingsManager{
Storage: s,
Expand All @@ -48,7 +52,7 @@ func NewProviderSettingsManager(s ProviderSettingsStorage, cache ProviderSetting
}

func isProviderNativelySupported(provider string) bool {
return provider == "openai" || provider == "anthropic" || provider == "azure" || provider == "vllm" || provider == "deepinfra" || provider == "bedrock"
return slices.Contains(nativelySupportedProviders, provider)
}

func findMissingAuthParams(providerName string, params map[string]string) string {
Expand Down Expand Up @@ -99,6 +103,25 @@ func findMissingAuthParams(providerName string, params map[string]string) string
}
}

if providerName == "xCustom" {
val := params["apikey"]
if len(val) == 0 {
missingFields = append(missingFields, "apikey")
}
val = params["endpoint"]
if len(val) == 0 {
missingFields = append(missingFields, "endpoint")
}
val = params["authLocation"]
if len(val) == 0 {
missingFields = append(missingFields, "authLocation")
}
val = params["authTemplate"]
if !strings.Contains(val, "{{apikey}}") {
missingFields = append(missingFields, "authTemplate")
}
}

return strings.Join(missingFields, ",")
}

Expand Down Expand Up @@ -160,6 +183,18 @@ func (m *ProviderSettingsManager) CreateSetting(setting *provider.Setting) (*pro
setting.CreatedAt = time.Now().Unix()
setting.UpdatedAt = time.Now().Unix()

if setting.Provider == "xCustom" {
advancedSetting, err := xcustom.AdvancedXCustomSetting(setting.Setting)
if err != nil {
return nil, err
}
merged := setting.Setting
for k, v := range advancedSetting {
merged[k] = v
}
setting.Setting = merged
}

if m.Encryptor.Enabled() {
params, err := m.EncryptParams(setting.UpdatedAt, setting.Provider, setting.Setting)
if err != nil {
Expand All @@ -183,15 +218,10 @@ func (m *ProviderSettingsManager) UpdateSetting(id string, setting *provider.Upd
}

if len(setting.Setting) != 0 {
if err := m.validateSettings(existing.Provider, setting.Setting); err != nil {
merged, err := m.getMergedSettings(existing, setting.Setting)
if err != nil {
return nil, err
}

merged := existing.Setting
for k, v := range setting.Setting {
merged[k] = v
}

setting.Setting = merged
}

Expand All @@ -214,6 +244,31 @@ func (m *ProviderSettingsManager) UpdateSetting(id string, setting *provider.Upd
return m.Storage.UpdateProviderSetting(id, setting)
}

func (m *ProviderSettingsManager) getMergedSettings(existing *provider.Setting, setting map[string]string) (map[string]string, error) {
merged := existing.Setting
apikey, ok := setting["apikey"]
if ok && apikey == "revoked" {
merged["apikey"] = apikey
return merged, nil
}
for k, v := range setting {
merged[k] = v
}
if existing.Provider == "xCustom" {
advancedSetting, err := xcustom.AdvancedXCustomSetting(setting)
if err != nil {
return nil, err
}
for k, v := range advancedSetting {
merged[k] = v
}
}
if err := m.validateSettings(existing.Provider, merged); err != nil {
return nil, err
}
return merged, nil
}

func (m *ProviderSettingsManager) GetSettingViaCache(id string) (*provider.Setting, error) {
setting, _ := m.Cache.Get(id)

Expand Down
Loading
Loading