Skip to content

Commit 5cbf9d7

Browse files
committed
feat(api-key-auth): unify session and API key auth into AuthSessionOrAPIKey middleware
1 parent d124ce8 commit 5cbf9d7

3 files changed

Lines changed: 356 additions & 7 deletions

File tree

internal/base/middleware/api_key_auth.go

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,32 +20,108 @@
2020
package middleware
2121

2222
import (
23+
"strings"
24+
2325
"github.com/apache/answer/internal/base/handler"
2426
"github.com/apache/answer/internal/base/reason"
27+
"github.com/apache/answer/internal/entity"
2528
"github.com/gin-gonic/gin"
2629
"github.com/segmentfault/pacman/errors"
30+
"github.com/segmentfault/pacman/log"
2731
)
2832

29-
// AuthAPIKey middleware to authenticate API key
30-
func (am *AuthUserMiddleware) AuthAPIKey() gin.HandlerFunc {
33+
// apiKeyAllowedPrefixes lists the URL path prefixes accessible via API key.
34+
// Routes not matching any prefix require a session token.
35+
var apiKeyAllowedPrefixes = []string{
36+
"/answer/api/v1/question",
37+
"/answer/api/v1/answer",
38+
"/answer/api/v1/comment",
39+
"/answer/api/v1/tag",
40+
"/answer/api/v1/search",
41+
"/answer/api/v1/collection",
42+
"/answer/api/v1/vote",
43+
"/answer/api/v1/follow",
44+
"/answer/api/v1/revisions",
45+
"/answer/api/v1/chat/completions",
46+
"/answer/api/v1/ai/conversation",
47+
"/answer/api/v1/mcp",
48+
}
49+
50+
func isAPIKeyAllowed(path string) bool {
51+
for _, prefix := range apiKeyAllowedPrefixes {
52+
if strings.HasPrefix(path, prefix) {
53+
return true
54+
}
55+
}
56+
return false
57+
}
58+
59+
// AuthSessionOrAPIKey tries session-based auth first, then falls back to API key auth.
60+
// In both cases it injects a UserCacheInfo into the Gin context so that downstream
61+
// handlers can use GetLoginUserIDFromContext() as usual.
62+
func (am *AuthUserMiddleware) AuthSessionOrAPIKey() gin.HandlerFunc {
3163
return func(ctx *gin.Context) {
3264
token := ExtractToken(ctx)
3365
if len(token) == 0 {
3466
handler.HandleResponse(ctx, errors.Unauthorized(reason.UnauthorizedError), nil)
3567
ctx.Abort()
3668
return
3769
}
38-
pass, err := am.authService.AuthAPIKey(ctx, ctx.Request.Method == "GET", token)
39-
if err != nil {
70+
71+
// 1. Try session-based auth
72+
userInfo, err := am.authService.GetUserCacheInfo(ctx, token)
73+
if err == nil && userInfo != nil {
74+
if !am.validateUserStatus(ctx, userInfo) {
75+
return
76+
}
77+
ctx.Set(ctxUUIDKey, userInfo)
78+
ctx.Next()
79+
return
80+
}
81+
82+
// 2. Fallback to API key auth (only for whitelisted routes)
83+
if !isAPIKeyAllowed(ctx.Request.URL.Path) {
4084
handler.HandleResponse(ctx, errors.Unauthorized(reason.UnauthorizedError), nil)
4185
ctx.Abort()
4286
return
4387
}
44-
if !pass {
88+
89+
isRead := ctx.Request.Method == "GET"
90+
apiKeyInfo, err := am.authService.GetAPIKeyInfo(ctx, isRead, token)
91+
if err != nil || apiKeyInfo == nil {
4592
handler.HandleResponse(ctx, errors.Unauthorized(reason.UnauthorizedError), nil)
4693
ctx.Abort()
4794
return
4895
}
96+
97+
// Resolve user from the API key's UserID
98+
userEntity, exist, err := am.userRepo.GetByUserID(ctx, apiKeyInfo.UserID)
99+
if err != nil || !exist {
100+
log.Errorf("API key %s references unknown user %s", apiKeyInfo.AccessKey, apiKeyInfo.UserID)
101+
handler.HandleResponse(ctx, errors.Unauthorized(reason.UnauthorizedError), nil)
102+
ctx.Abort()
103+
return
104+
}
105+
if userEntity.Status == entity.UserStatusDeleted || userEntity.Status == entity.UserStatusSuspended {
106+
handler.HandleResponse(ctx, errors.Unauthorized(reason.UnauthorizedError), nil)
107+
ctx.Abort()
108+
return
109+
}
110+
111+
roleID, err := am.userRoleService.GetUserRole(ctx, userEntity.ID)
112+
if err != nil {
113+
log.Errorf("failed to get role for user %s: %v", userEntity.ID, err)
114+
handler.HandleResponse(ctx, errors.Unauthorized(reason.UnauthorizedError), nil)
115+
ctx.Abort()
116+
return
117+
}
118+
119+
ctx.Set(ctxUUIDKey, &entity.UserCacheInfo{
120+
UserID: userEntity.ID,
121+
UserStatus: userEntity.Status,
122+
EmailStatus: userEntity.MailStatus,
123+
RoleID: roleID,
124+
})
49125
ctx.Next()
50126
}
51127
}
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package middleware
21+
22+
import (
23+
"context"
24+
"net/http"
25+
"net/http/httptest"
26+
"testing"
27+
28+
"github.com/apache/answer/internal/entity"
29+
"github.com/apache/answer/internal/service/auth"
30+
"github.com/apache/answer/internal/service/role"
31+
"github.com/gin-gonic/gin"
32+
"github.com/stretchr/testify/assert"
33+
)
34+
35+
// --- Mock repos for AuthService ---
36+
37+
type mockAuthRepo struct {
38+
userCache *entity.UserCacheInfo
39+
err error
40+
}
41+
42+
func (m *mockAuthRepo) GetUserCacheInfo(_ context.Context, _ string) (*entity.UserCacheInfo, error) {
43+
return m.userCache, m.err
44+
}
45+
func (m *mockAuthRepo) SetUserCacheInfo(_ context.Context, _, _ string, _ *entity.UserCacheInfo) error {
46+
return nil
47+
}
48+
func (m *mockAuthRepo) GetUserVisitCacheInfo(_ context.Context, _ string) (string, error) {
49+
return "", nil
50+
}
51+
func (m *mockAuthRepo) RemoveUserCacheInfo(_ context.Context, _ string) error { return nil }
52+
func (m *mockAuthRepo) RemoveUserVisitCacheInfo(_ context.Context, _ string) error { return nil }
53+
func (m *mockAuthRepo) SetUserStatus(_ context.Context, _ string, _ *entity.UserCacheInfo) error {
54+
return nil
55+
}
56+
func (m *mockAuthRepo) GetUserStatus(_ context.Context, _ string) (*entity.UserCacheInfo, error) {
57+
return nil, nil
58+
}
59+
func (m *mockAuthRepo) RemoveUserStatus(_ context.Context, _ string) error { return nil }
60+
func (m *mockAuthRepo) GetAdminUserCacheInfo(_ context.Context, _ string) (*entity.UserCacheInfo, error) {
61+
return nil, nil
62+
}
63+
func (m *mockAuthRepo) SetAdminUserCacheInfo(_ context.Context, _ string, _ *entity.UserCacheInfo) error {
64+
return nil
65+
}
66+
func (m *mockAuthRepo) RemoveAdminUserCacheInfo(_ context.Context, _ string) error { return nil }
67+
func (m *mockAuthRepo) AddUserTokenMapping(_ context.Context, _, _ string) error { return nil }
68+
func (m *mockAuthRepo) RemoveUserTokens(_ context.Context, _ string, _ string) {}
69+
70+
type mockAPIKeyRepo struct {
71+
key *entity.APIKey
72+
exist bool
73+
err error
74+
}
75+
76+
func (m *mockAPIKeyRepo) GetAPIKeyList(_ context.Context) ([]*entity.APIKey, error) { return nil, nil }
77+
func (m *mockAPIKeyRepo) GetAPIKey(_ context.Context, _ string) (*entity.APIKey, bool, error) {
78+
return m.key, m.exist, m.err
79+
}
80+
func (m *mockAPIKeyRepo) UpdateAPIKey(_ context.Context, _ entity.APIKey) error { return nil }
81+
func (m *mockAPIKeyRepo) AddAPIKey(_ context.Context, _ entity.APIKey) error { return nil }
82+
func (m *mockAPIKeyRepo) DeleteAPIKey(_ context.Context, _ int) error { return nil }
83+
84+
type mockUserRepo struct {
85+
user *entity.User
86+
exist bool
87+
err error
88+
}
89+
90+
func (m *mockUserRepo) AddUser(_ context.Context, _ *entity.User) error { return nil }
91+
func (m *mockUserRepo) IncreaseAnswerCount(_ context.Context, _ string, _ int) error { return nil }
92+
func (m *mockUserRepo) IncreaseQuestionCount(_ context.Context, _ string, _ int) error { return nil }
93+
func (m *mockUserRepo) UpdateQuestionCount(_ context.Context, _ string, _ int64) error { return nil }
94+
func (m *mockUserRepo) UpdateAnswerCount(_ context.Context, _ string, _ int) error { return nil }
95+
func (m *mockUserRepo) UpdateLastLoginDate(_ context.Context, _ string) error { return nil }
96+
func (m *mockUserRepo) UpdateEmailStatus(_ context.Context, _ string, _ int) error { return nil }
97+
func (m *mockUserRepo) UpdateNoticeStatus(_ context.Context, _ string, _ int) error { return nil }
98+
func (m *mockUserRepo) UpdateEmail(_ context.Context, _, _ string) error { return nil }
99+
func (m *mockUserRepo) UpdateUserInterface(_ context.Context, _, _, _ string) error { return nil }
100+
func (m *mockUserRepo) UpdatePass(_ context.Context, _, _ string) error { return nil }
101+
func (m *mockUserRepo) UpdateInfo(_ context.Context, _ *entity.User) error { return nil }
102+
func (m *mockUserRepo) UpdateUserProfile(_ context.Context, _ *entity.User) error { return nil }
103+
func (m *mockUserRepo) BatchGetByID(_ context.Context, _ []string) ([]*entity.User, error) { return nil, nil }
104+
func (m *mockUserRepo) GetByUsername(_ context.Context, _ string) (*entity.User, bool, error) {
105+
return nil, false, nil
106+
}
107+
func (m *mockUserRepo) GetByUsernames(_ context.Context, _ []string) ([]*entity.User, error) {
108+
return nil, nil
109+
}
110+
func (m *mockUserRepo) GetByEmail(_ context.Context, _ string) (*entity.User, bool, error) {
111+
return nil, false, nil
112+
}
113+
func (m *mockUserRepo) GetUserCount(_ context.Context) (int64, error) { return 0, nil }
114+
func (m *mockUserRepo) SearchUserListByName(_ context.Context, _ string, _ int, _ bool) ([]*entity.User, error) {
115+
return nil, nil
116+
}
117+
func (m *mockUserRepo) IsAvatarFileUsed(_ context.Context, _ string) (bool, error) {
118+
return false, nil
119+
}
120+
func (m *mockUserRepo) GetByUserID(_ context.Context, _ string) (*entity.User, bool, error) {
121+
return m.user, m.exist, m.err
122+
}
123+
124+
// mockUserRoleRelRepo implements role.UserRoleRelRepo for testing
125+
type mockUserRoleRelRepo struct {
126+
roleID int
127+
exist bool
128+
}
129+
130+
func (m *mockUserRoleRelRepo) SaveUserRoleRel(_ context.Context, _ string, _ int) error { return nil }
131+
func (m *mockUserRoleRelRepo) GetUserRoleRelList(_ context.Context, _ []string) ([]*entity.UserRoleRel, error) {
132+
return nil, nil
133+
}
134+
func (m *mockUserRoleRelRepo) GetUserRoleRelListByRoleID(_ context.Context, _ []int) ([]*entity.UserRoleRel, error) {
135+
return nil, nil
136+
}
137+
func (m *mockUserRoleRelRepo) GetUserRoleRel(_ context.Context, _ string) (*entity.UserRoleRel, bool, error) {
138+
if !m.exist {
139+
return nil, false, nil
140+
}
141+
return &entity.UserRoleRel{RoleID: m.roleID}, true, nil
142+
}
143+
144+
// --- Helper ---
145+
146+
func newTestMiddleware(
147+
authRepo *mockAuthRepo,
148+
apiKeyRepo *mockAPIKeyRepo,
149+
userRepo *mockUserRepo,
150+
roleID int,
151+
) *AuthUserMiddleware {
152+
svc := auth.NewAuthService(authRepo, apiKeyRepo)
153+
userRoleRelService := role.NewUserRoleRelService(&mockUserRoleRelRepo{roleID: roleID, exist: true}, nil)
154+
return NewAuthUserMiddleware(svc, nil, userRepo, userRoleRelService)
155+
}
156+
157+
func performRequest(mw gin.HandlerFunc, method, path string) *httptest.ResponseRecorder {
158+
gin.SetMode(gin.TestMode)
159+
w := httptest.NewRecorder()
160+
_, engine := gin.CreateTestContext(w)
161+
engine.Use(mw)
162+
engine.Handle(method, path, func(c *gin.Context) {
163+
c.String(http.StatusOK, "ok")
164+
})
165+
req, _ := http.NewRequest(method, path, nil)
166+
req.Header.Set("Authorization", "Bearer test-token")
167+
engine.ServeHTTP(w, req)
168+
return w
169+
}
170+
171+
func performRequestNoToken(mw gin.HandlerFunc) *httptest.ResponseRecorder {
172+
gin.SetMode(gin.TestMode)
173+
w := httptest.NewRecorder()
174+
_, engine := gin.CreateTestContext(w)
175+
engine.Use(mw)
176+
engine.Handle("GET", "/test", func(c *gin.Context) {
177+
c.String(http.StatusOK, "ok")
178+
})
179+
req, _ := http.NewRequest("GET", "/test", nil)
180+
engine.ServeHTTP(w, req)
181+
return w
182+
}
183+
184+
// --- Tests ---
185+
186+
func TestAuthSessionOrAPIKey_ValidSession(t *testing.T) {
187+
m := newTestMiddleware(
188+
&mockAuthRepo{userCache: &entity.UserCacheInfo{
189+
UserID: "100",
190+
UserStatus: entity.UserStatusAvailable,
191+
EmailStatus: entity.EmailStatusAvailable,
192+
RoleID: 1,
193+
}},
194+
&mockAPIKeyRepo{exist: false},
195+
&mockUserRepo{exist: false},
196+
1,
197+
)
198+
w := performRequest(m.AuthSessionOrAPIKey(), "GET", "/test")
199+
assert.Equal(t, http.StatusOK, w.Code)
200+
}
201+
202+
func TestAuthSessionOrAPIKey_InvalidSessionFallbackValidAPIKey(t *testing.T) {
203+
m := newTestMiddleware(
204+
&mockAuthRepo{userCache: nil}, // session fails
205+
&mockAPIKeyRepo{
206+
key: &entity.APIKey{AccessKey: "sk_test", Scope: "read-write", UserID: "200"},
207+
exist: true,
208+
},
209+
&mockUserRepo{
210+
user: &entity.User{ID: "200", Status: entity.UserStatusAvailable, MailStatus: entity.EmailStatusAvailable},
211+
exist: true,
212+
},
213+
1,
214+
)
215+
w := performRequest(m.AuthSessionOrAPIKey(), "GET", "/answer/api/v1/question")
216+
assert.Equal(t, http.StatusOK, w.Code)
217+
}
218+
219+
func TestAuthSessionOrAPIKey_BothFail(t *testing.T) {
220+
m := newTestMiddleware(
221+
&mockAuthRepo{userCache: nil},
222+
&mockAPIKeyRepo{exist: false},
223+
&mockUserRepo{exist: false},
224+
1,
225+
)
226+
w := performRequest(m.AuthSessionOrAPIKey(), "GET", "/answer/api/v1/question")
227+
assert.Equal(t, http.StatusUnauthorized, w.Code)
228+
}
229+
230+
func TestAuthSessionOrAPIKey_NoToken(t *testing.T) {
231+
m := newTestMiddleware(
232+
&mockAuthRepo{userCache: nil},
233+
&mockAPIKeyRepo{exist: false},
234+
&mockUserRepo{exist: false},
235+
1,
236+
)
237+
w := performRequestNoToken(m.AuthSessionOrAPIKey())
238+
assert.Equal(t, http.StatusUnauthorized, w.Code)
239+
}
240+
241+
func TestAuthSessionOrAPIKey_ReadOnlyKeyPostRequest(t *testing.T) {
242+
m := newTestMiddleware(
243+
&mockAuthRepo{userCache: nil}, // session fails
244+
&mockAPIKeyRepo{
245+
key: &entity.APIKey{AccessKey: "sk_ro", Scope: "read-only", UserID: "300"},
246+
exist: true,
247+
},
248+
&mockUserRepo{
249+
user: &entity.User{ID: "300", Status: entity.UserStatusAvailable, MailStatus: entity.EmailStatusAvailable},
250+
exist: true,
251+
},
252+
1,
253+
)
254+
w := performRequest(m.AuthSessionOrAPIKey(), "POST", "/answer/api/v1/question")
255+
assert.Equal(t, http.StatusUnauthorized, w.Code)
256+
}
257+
258+
func TestAuthSessionOrAPIKey_APIKeyBlockedOnNonWhitelistedRoute(t *testing.T) {
259+
m := newTestMiddleware(
260+
&mockAuthRepo{userCache: nil},
261+
&mockAPIKeyRepo{
262+
key: &entity.APIKey{AccessKey: "sk_test", Scope: "read-write", UserID: "400"},
263+
exist: true,
264+
},
265+
&mockUserRepo{
266+
user: &entity.User{ID: "400", Status: entity.UserStatusAvailable, MailStatus: entity.EmailStatusAvailable},
267+
exist: true,
268+
},
269+
1,
270+
)
271+
w := performRequest(m.AuthSessionOrAPIKey(), "PUT", "/answer/api/v1/user/password")
272+
assert.Equal(t, http.StatusUnauthorized, w.Code)
273+
}

0 commit comments

Comments
 (0)