|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | + |
| 20 | +package middleware |
| 21 | + |
| 22 | +import ( |
| 23 | + "context" |
| 24 | + "net/http" |
| 25 | + "net/http/httptest" |
| 26 | + "testing" |
| 27 | + |
| 28 | + "github.com/apache/answer/internal/entity" |
| 29 | + "github.com/apache/answer/internal/service/auth" |
| 30 | + "github.com/apache/answer/internal/service/role" |
| 31 | + "github.com/gin-gonic/gin" |
| 32 | + "github.com/stretchr/testify/assert" |
| 33 | +) |
| 34 | + |
| 35 | +// --- Mock repos for AuthService --- |
| 36 | + |
| 37 | +type mockAuthRepo struct { |
| 38 | + userCache *entity.UserCacheInfo |
| 39 | + err error |
| 40 | +} |
| 41 | + |
| 42 | +func (m *mockAuthRepo) GetUserCacheInfo(_ context.Context, _ string) (*entity.UserCacheInfo, error) { |
| 43 | + return m.userCache, m.err |
| 44 | +} |
| 45 | +func (m *mockAuthRepo) SetUserCacheInfo(_ context.Context, _, _ string, _ *entity.UserCacheInfo) error { |
| 46 | + return nil |
| 47 | +} |
| 48 | +func (m *mockAuthRepo) GetUserVisitCacheInfo(_ context.Context, _ string) (string, error) { |
| 49 | + return "", nil |
| 50 | +} |
| 51 | +func (m *mockAuthRepo) RemoveUserCacheInfo(_ context.Context, _ string) error { return nil } |
| 52 | +func (m *mockAuthRepo) RemoveUserVisitCacheInfo(_ context.Context, _ string) error { return nil } |
| 53 | +func (m *mockAuthRepo) SetUserStatus(_ context.Context, _ string, _ *entity.UserCacheInfo) error { |
| 54 | + return nil |
| 55 | +} |
| 56 | +func (m *mockAuthRepo) GetUserStatus(_ context.Context, _ string) (*entity.UserCacheInfo, error) { |
| 57 | + return nil, nil |
| 58 | +} |
| 59 | +func (m *mockAuthRepo) RemoveUserStatus(_ context.Context, _ string) error { return nil } |
| 60 | +func (m *mockAuthRepo) GetAdminUserCacheInfo(_ context.Context, _ string) (*entity.UserCacheInfo, error) { |
| 61 | + return nil, nil |
| 62 | +} |
| 63 | +func (m *mockAuthRepo) SetAdminUserCacheInfo(_ context.Context, _ string, _ *entity.UserCacheInfo) error { |
| 64 | + return nil |
| 65 | +} |
| 66 | +func (m *mockAuthRepo) RemoveAdminUserCacheInfo(_ context.Context, _ string) error { return nil } |
| 67 | +func (m *mockAuthRepo) AddUserTokenMapping(_ context.Context, _, _ string) error { return nil } |
| 68 | +func (m *mockAuthRepo) RemoveUserTokens(_ context.Context, _ string, _ string) {} |
| 69 | + |
| 70 | +type mockAPIKeyRepo struct { |
| 71 | + key *entity.APIKey |
| 72 | + exist bool |
| 73 | + err error |
| 74 | +} |
| 75 | + |
| 76 | +func (m *mockAPIKeyRepo) GetAPIKeyList(_ context.Context) ([]*entity.APIKey, error) { return nil, nil } |
| 77 | +func (m *mockAPIKeyRepo) GetAPIKey(_ context.Context, _ string) (*entity.APIKey, bool, error) { |
| 78 | + return m.key, m.exist, m.err |
| 79 | +} |
| 80 | +func (m *mockAPIKeyRepo) UpdateAPIKey(_ context.Context, _ entity.APIKey) error { return nil } |
| 81 | +func (m *mockAPIKeyRepo) AddAPIKey(_ context.Context, _ entity.APIKey) error { return nil } |
| 82 | +func (m *mockAPIKeyRepo) DeleteAPIKey(_ context.Context, _ int) error { return nil } |
| 83 | + |
| 84 | +type mockUserRepo struct { |
| 85 | + user *entity.User |
| 86 | + exist bool |
| 87 | + err error |
| 88 | +} |
| 89 | + |
| 90 | +func (m *mockUserRepo) AddUser(_ context.Context, _ *entity.User) error { return nil } |
| 91 | +func (m *mockUserRepo) IncreaseAnswerCount(_ context.Context, _ string, _ int) error { return nil } |
| 92 | +func (m *mockUserRepo) IncreaseQuestionCount(_ context.Context, _ string, _ int) error { return nil } |
| 93 | +func (m *mockUserRepo) UpdateQuestionCount(_ context.Context, _ string, _ int64) error { return nil } |
| 94 | +func (m *mockUserRepo) UpdateAnswerCount(_ context.Context, _ string, _ int) error { return nil } |
| 95 | +func (m *mockUserRepo) UpdateLastLoginDate(_ context.Context, _ string) error { return nil } |
| 96 | +func (m *mockUserRepo) UpdateEmailStatus(_ context.Context, _ string, _ int) error { return nil } |
| 97 | +func (m *mockUserRepo) UpdateNoticeStatus(_ context.Context, _ string, _ int) error { return nil } |
| 98 | +func (m *mockUserRepo) UpdateEmail(_ context.Context, _, _ string) error { return nil } |
| 99 | +func (m *mockUserRepo) UpdateUserInterface(_ context.Context, _, _, _ string) error { return nil } |
| 100 | +func (m *mockUserRepo) UpdatePass(_ context.Context, _, _ string) error { return nil } |
| 101 | +func (m *mockUserRepo) UpdateInfo(_ context.Context, _ *entity.User) error { return nil } |
| 102 | +func (m *mockUserRepo) UpdateUserProfile(_ context.Context, _ *entity.User) error { return nil } |
| 103 | +func (m *mockUserRepo) BatchGetByID(_ context.Context, _ []string) ([]*entity.User, error) { return nil, nil } |
| 104 | +func (m *mockUserRepo) GetByUsername(_ context.Context, _ string) (*entity.User, bool, error) { |
| 105 | + return nil, false, nil |
| 106 | +} |
| 107 | +func (m *mockUserRepo) GetByUsernames(_ context.Context, _ []string) ([]*entity.User, error) { |
| 108 | + return nil, nil |
| 109 | +} |
| 110 | +func (m *mockUserRepo) GetByEmail(_ context.Context, _ string) (*entity.User, bool, error) { |
| 111 | + return nil, false, nil |
| 112 | +} |
| 113 | +func (m *mockUserRepo) GetUserCount(_ context.Context) (int64, error) { return 0, nil } |
| 114 | +func (m *mockUserRepo) SearchUserListByName(_ context.Context, _ string, _ int, _ bool) ([]*entity.User, error) { |
| 115 | + return nil, nil |
| 116 | +} |
| 117 | +func (m *mockUserRepo) IsAvatarFileUsed(_ context.Context, _ string) (bool, error) { |
| 118 | + return false, nil |
| 119 | +} |
| 120 | +func (m *mockUserRepo) GetByUserID(_ context.Context, _ string) (*entity.User, bool, error) { |
| 121 | + return m.user, m.exist, m.err |
| 122 | +} |
| 123 | + |
| 124 | +// mockUserRoleRelRepo implements role.UserRoleRelRepo for testing |
| 125 | +type mockUserRoleRelRepo struct { |
| 126 | + roleID int |
| 127 | + exist bool |
| 128 | +} |
| 129 | + |
| 130 | +func (m *mockUserRoleRelRepo) SaveUserRoleRel(_ context.Context, _ string, _ int) error { return nil } |
| 131 | +func (m *mockUserRoleRelRepo) GetUserRoleRelList(_ context.Context, _ []string) ([]*entity.UserRoleRel, error) { |
| 132 | + return nil, nil |
| 133 | +} |
| 134 | +func (m *mockUserRoleRelRepo) GetUserRoleRelListByRoleID(_ context.Context, _ []int) ([]*entity.UserRoleRel, error) { |
| 135 | + return nil, nil |
| 136 | +} |
| 137 | +func (m *mockUserRoleRelRepo) GetUserRoleRel(_ context.Context, _ string) (*entity.UserRoleRel, bool, error) { |
| 138 | + if !m.exist { |
| 139 | + return nil, false, nil |
| 140 | + } |
| 141 | + return &entity.UserRoleRel{RoleID: m.roleID}, true, nil |
| 142 | +} |
| 143 | + |
| 144 | +// --- Helper --- |
| 145 | + |
| 146 | +func newTestMiddleware( |
| 147 | + authRepo *mockAuthRepo, |
| 148 | + apiKeyRepo *mockAPIKeyRepo, |
| 149 | + userRepo *mockUserRepo, |
| 150 | + roleID int, |
| 151 | +) *AuthUserMiddleware { |
| 152 | + svc := auth.NewAuthService(authRepo, apiKeyRepo) |
| 153 | + userRoleRelService := role.NewUserRoleRelService(&mockUserRoleRelRepo{roleID: roleID, exist: true}, nil) |
| 154 | + return NewAuthUserMiddleware(svc, nil, userRepo, userRoleRelService) |
| 155 | +} |
| 156 | + |
| 157 | +func performRequest(mw gin.HandlerFunc, method, path string) *httptest.ResponseRecorder { |
| 158 | + gin.SetMode(gin.TestMode) |
| 159 | + w := httptest.NewRecorder() |
| 160 | + _, engine := gin.CreateTestContext(w) |
| 161 | + engine.Use(mw) |
| 162 | + engine.Handle(method, path, func(c *gin.Context) { |
| 163 | + c.String(http.StatusOK, "ok") |
| 164 | + }) |
| 165 | + req, _ := http.NewRequest(method, path, nil) |
| 166 | + req.Header.Set("Authorization", "Bearer test-token") |
| 167 | + engine.ServeHTTP(w, req) |
| 168 | + return w |
| 169 | +} |
| 170 | + |
| 171 | +func performRequestNoToken(mw gin.HandlerFunc) *httptest.ResponseRecorder { |
| 172 | + gin.SetMode(gin.TestMode) |
| 173 | + w := httptest.NewRecorder() |
| 174 | + _, engine := gin.CreateTestContext(w) |
| 175 | + engine.Use(mw) |
| 176 | + engine.Handle("GET", "/test", func(c *gin.Context) { |
| 177 | + c.String(http.StatusOK, "ok") |
| 178 | + }) |
| 179 | + req, _ := http.NewRequest("GET", "/test", nil) |
| 180 | + engine.ServeHTTP(w, req) |
| 181 | + return w |
| 182 | +} |
| 183 | + |
| 184 | +// --- Tests --- |
| 185 | + |
| 186 | +func TestAuthSessionOrAPIKey_ValidSession(t *testing.T) { |
| 187 | + m := newTestMiddleware( |
| 188 | + &mockAuthRepo{userCache: &entity.UserCacheInfo{ |
| 189 | + UserID: "100", |
| 190 | + UserStatus: entity.UserStatusAvailable, |
| 191 | + EmailStatus: entity.EmailStatusAvailable, |
| 192 | + RoleID: 1, |
| 193 | + }}, |
| 194 | + &mockAPIKeyRepo{exist: false}, |
| 195 | + &mockUserRepo{exist: false}, |
| 196 | + 1, |
| 197 | + ) |
| 198 | + w := performRequest(m.AuthSessionOrAPIKey(), "GET", "/test") |
| 199 | + assert.Equal(t, http.StatusOK, w.Code) |
| 200 | +} |
| 201 | + |
| 202 | +func TestAuthSessionOrAPIKey_InvalidSessionFallbackValidAPIKey(t *testing.T) { |
| 203 | + m := newTestMiddleware( |
| 204 | + &mockAuthRepo{userCache: nil}, // session fails |
| 205 | + &mockAPIKeyRepo{ |
| 206 | + key: &entity.APIKey{AccessKey: "sk_test", Scope: "read-write", UserID: "200"}, |
| 207 | + exist: true, |
| 208 | + }, |
| 209 | + &mockUserRepo{ |
| 210 | + user: &entity.User{ID: "200", Status: entity.UserStatusAvailable, MailStatus: entity.EmailStatusAvailable}, |
| 211 | + exist: true, |
| 212 | + }, |
| 213 | + 1, |
| 214 | + ) |
| 215 | + w := performRequest(m.AuthSessionOrAPIKey(), "GET", "/answer/api/v1/question") |
| 216 | + assert.Equal(t, http.StatusOK, w.Code) |
| 217 | +} |
| 218 | + |
| 219 | +func TestAuthSessionOrAPIKey_BothFail(t *testing.T) { |
| 220 | + m := newTestMiddleware( |
| 221 | + &mockAuthRepo{userCache: nil}, |
| 222 | + &mockAPIKeyRepo{exist: false}, |
| 223 | + &mockUserRepo{exist: false}, |
| 224 | + 1, |
| 225 | + ) |
| 226 | + w := performRequest(m.AuthSessionOrAPIKey(), "GET", "/answer/api/v1/question") |
| 227 | + assert.Equal(t, http.StatusUnauthorized, w.Code) |
| 228 | +} |
| 229 | + |
| 230 | +func TestAuthSessionOrAPIKey_NoToken(t *testing.T) { |
| 231 | + m := newTestMiddleware( |
| 232 | + &mockAuthRepo{userCache: nil}, |
| 233 | + &mockAPIKeyRepo{exist: false}, |
| 234 | + &mockUserRepo{exist: false}, |
| 235 | + 1, |
| 236 | + ) |
| 237 | + w := performRequestNoToken(m.AuthSessionOrAPIKey()) |
| 238 | + assert.Equal(t, http.StatusUnauthorized, w.Code) |
| 239 | +} |
| 240 | + |
| 241 | +func TestAuthSessionOrAPIKey_ReadOnlyKeyPostRequest(t *testing.T) { |
| 242 | + m := newTestMiddleware( |
| 243 | + &mockAuthRepo{userCache: nil}, // session fails |
| 244 | + &mockAPIKeyRepo{ |
| 245 | + key: &entity.APIKey{AccessKey: "sk_ro", Scope: "read-only", UserID: "300"}, |
| 246 | + exist: true, |
| 247 | + }, |
| 248 | + &mockUserRepo{ |
| 249 | + user: &entity.User{ID: "300", Status: entity.UserStatusAvailable, MailStatus: entity.EmailStatusAvailable}, |
| 250 | + exist: true, |
| 251 | + }, |
| 252 | + 1, |
| 253 | + ) |
| 254 | + w := performRequest(m.AuthSessionOrAPIKey(), "POST", "/answer/api/v1/question") |
| 255 | + assert.Equal(t, http.StatusUnauthorized, w.Code) |
| 256 | +} |
| 257 | + |
| 258 | +func TestAuthSessionOrAPIKey_APIKeyBlockedOnNonWhitelistedRoute(t *testing.T) { |
| 259 | + m := newTestMiddleware( |
| 260 | + &mockAuthRepo{userCache: nil}, |
| 261 | + &mockAPIKeyRepo{ |
| 262 | + key: &entity.APIKey{AccessKey: "sk_test", Scope: "read-write", UserID: "400"}, |
| 263 | + exist: true, |
| 264 | + }, |
| 265 | + &mockUserRepo{ |
| 266 | + user: &entity.User{ID: "400", Status: entity.UserStatusAvailable, MailStatus: entity.EmailStatusAvailable}, |
| 267 | + exist: true, |
| 268 | + }, |
| 269 | + 1, |
| 270 | + ) |
| 271 | + w := performRequest(m.AuthSessionOrAPIKey(), "PUT", "/answer/api/v1/user/password") |
| 272 | + assert.Equal(t, http.StatusUnauthorized, w.Code) |
| 273 | +} |
0 commit comments