diff --git a/README.md b/README.md index e52ad2b..1b5475e 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,8 @@ The Discord integration for [reviewGOOSE](https://codegroove.dev/reviewgoose/) - Creates Discord threads for new PRs (forum channels) or posts in text channels - Smart notifications: Delays DMs if user already notified in channel - Channel auto-discovery: repos automatically map to same-named channels +- **Self-service user linking**: Link your GitHub account with `/goose github-user` command +- **Smart user matching**: Automatically matches by username, display name, or server nickname - Configurable notification settings via YAML - Activity-based reports when you come online - Reliable delivery with deduplication @@ -44,11 +46,11 @@ Create `.codeGROOVE/discord.yaml`: ```yaml global: - guild_id: "YOUR_DISCORD_SERVER_ID" + guild_id: YOUR_DISCORD_SERVER_ID # Optional: Add explicit user mappings if GitHub/Discord usernames differ # users: -# github-username: "discord-user-id" +# github-username: discord-user-id ``` ### 5. Add the Bot to Your Server @@ -70,12 +72,12 @@ Full configuration options for `.codeGROOVE/discord.yaml`: ```yaml global: - guild_id: "1234567890123456789" + guild_id: 1234567890123456789 reminder_dm_delay: 65 # Minutes to wait before sending DM (default: 65, 0 = disabled) users: - alice: "111111111111111111" # GitHub username → Discord user ID - bob: "222222222222222222" + alice: 111111111111111111 # GitHub username → Discord user ID + bob: discord-bob-username # GitHub username → Discord username # Unmapped users: bot attempts username match in guild channels: @@ -110,24 +112,33 @@ channels: ## User Mapping -The bot maps GitHub → Discord users using a 3-tier lookup system: +The bot maps GitHub → Discord users using a 4-tier lookup system: -### 1. Explicit Config Mapping +### 1. Explicit Config Mapping (Highest Priority) Checks the `users:` section in `discord.yaml`. Values can be: - Discord numeric ID: `"111111111111111111"` - Discord username: Bot will look it up in the guild -### 2. Automatic Username Match -Searches the Discord guild for the GitHub username using progressive matching. At each tier, checks both: +### 2. Self-Service Linking +Users can link their own accounts with `/goose github-user `. Mappings are stored persistently and take priority over automatic discovery. + +Example: +``` +/goose github-user octocat +``` + +### 3. Automatic Username Match +Searches the Discord guild for the GitHub username using progressive matching. At each tier, checks: - Discord **Username** (e.g., `@johndoe`) - Discord **Display Name** (the name shown in the member list) +- Discord **Server Nickname** (the custom name set for this server) Matching tiers: -- **Tier 1**: Exact match (checks Username first, then Display Name) +- **Tier 1**: Exact match (checks Username, Display Name, then Nickname) - **Tier 2**: Case-insensitive match (e.g., `JohnDoe` matches `johndoe`) - **Tier 3**: Prefix match (e.g., `john` matches `johnsmith`) - only if unambiguous (exactly one match) -### 3. Fallback +### 4. Fallback If no match is found, mentions GitHub username as plain text (e.g., `octocat` instead of `@octocat`) --- @@ -136,12 +147,16 @@ If no match is found, mentions GitHub username as plain text (e.g., `octocat` in **How to get Discord User IDs**: With Developer Mode enabled, right-click any username → Copy User ID +**Pro tip**: Set your Discord server nickname to match your GitHub username for automatic matching! + ## Slash Commands -- `/goose status` - Show bot connection status -- `/goose report` - Get your personal PR report -- `/goose dashboard` - Link to web dashboard -- `/goose help` - Show help +- `/goose status` - Show bot connection status and statistics +- `/goose dash` - Get your personal PR report and dashboard links +- `/goose github-user ` - Link your Discord account to a GitHub username +- `/goose users` - Show all GitHub ↔ Discord user mappings +- `/goose channels` - Show repository to channel mappings +- `/goose help` - Show help information ## Notification Behavior diff --git a/cmd/server/main.go b/cmd/server/main.go index 983c775..b14eb42 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -358,7 +358,7 @@ func (m *coordinatorManager) startSingleCoordinator(ctx context.Context, org str m.notifyMgr.RegisterGuild(guildID, discordClient) // Create user mapper - userMapper := usermapping.New(org, m.configManager, discordClient) + userMapper := usermapping.New(org, m.configManager, discordClient, m.store, guildID) // Create Turn client with token provider (will fetch fresh tokens automatically) turnClient := bot.NewTurnClient(m.cfg.TurnURL, ghClient) @@ -482,6 +482,7 @@ func (m *coordinatorManager) discordClientForGuild(_ context.Context, guildID st slashHandler.SetUserMapGetter(m) slashHandler.SetChannelMapGetter(m) slashHandler.SetDailyReportGetter(m) + slashHandler.SetStore(m.store) // Register slash commands with Discord if err := slashHandler.RegisterCommands(guildID); err != nil { diff --git a/cmd/server/main_test.go b/cmd/server/main_test.go index 26ee795..8b7c386 100644 --- a/cmd/server/main_test.go +++ b/cmd/server/main_test.go @@ -151,6 +151,18 @@ func (m *mockStateStore) SaveDailyReportInfo(_ context.Context, userID string, i return nil } +func (m *mockStateStore) UserMapping(_ context.Context, _, _ string) (state.UserMappingInfo, bool) { + return state.UserMappingInfo{}, false +} + +func (m *mockStateStore) SaveUserMapping(_ context.Context, _ string, _ state.UserMappingInfo) error { + return nil +} + +func (m *mockStateStore) ListUserMappings(_ context.Context, _ string) []state.UserMappingInfo { + return nil +} + func (m *mockStateStore) Cleanup(_ context.Context) error { return nil } diff --git a/internal/discord/client.go b/internal/discord/client.go index f9bc558..d2c5bcc 100644 --- a/internal/discord/client.go +++ b/internal/discord/client.go @@ -18,7 +18,8 @@ import ( // Client wraps discordgo.Session with a clean interface for bot operations. type Client struct { - session *discordgo.Session + session session + realSession *discordgo.Session // Keep reference for Session() method channelCache map[string]string // channel name -> ID channelTypeCache map[string]discordgo.ChannelType // channel ID -> type userCache map[string]string // username -> ID @@ -42,7 +43,8 @@ func New(token string) (*Client, error) { discordgo.IntentsMessageContent return &Client{ - session: session, + session: &sessionAdapter{Session: session}, + realSession: session, channelCache: make(map[string]string), channelTypeCache: make(map[string]discordgo.ChannelType), userCache: make(map[string]string), @@ -106,7 +108,7 @@ func (c *Client) Close() error { // Session returns the underlying discordgo session. func (c *Client) Session() *discordgo.Session { - return c.session + return c.realSession } // PostMessage sends a plain text message to a channel with link embeds suppressed. @@ -423,7 +425,7 @@ func (c *Client) LookupUserByUsername(ctx context.Context, username string) stri "guild_id", guildID, "total_members", len(members)) - // Tier 1: Exact match (Username takes precedence over GlobalName) + // Tier 1: Exact match (Username takes precedence over GlobalName, then Nick) for _, member := range members { if member.User.Username != username { continue @@ -456,8 +458,25 @@ func (c *Client) LookupUserByUsername(ctx context.Context, username string) stri return member.User.ID } + for _, member := range members { + if member.Nick != username { + continue + } + c.mu.Lock() + c.userCache[username] = member.User.ID + c.mu.Unlock() + + slog.Debug("found user by exact nickname match", + "username", username, + "user_id", member.User.ID, + "discord_username", member.User.Username, + "discord_global_name", member.User.GlobalName, + "discord_nick", member.Nick) + + return member.User.ID + } - // Tier 2: Case-insensitive match (Username takes precedence over GlobalName) + // Tier 2: Case-insensitive match (Username takes precedence over GlobalName, then Nick) for _, member := range members { if !strings.EqualFold(member.User.Username, username) { continue @@ -490,6 +509,23 @@ func (c *Client) LookupUserByUsername(ctx context.Context, username string) stri return member.User.ID } + for _, member := range members { + if !strings.EqualFold(member.Nick, username) { + continue + } + c.mu.Lock() + c.userCache[username] = member.User.ID + c.mu.Unlock() + + slog.Info("found user by case-insensitive nickname match", + "username", username, + "user_id", member.User.ID, + "discord_username", member.User.Username, + "discord_global_name", member.User.GlobalName, + "discord_nick", member.Nick) + + return member.User.ID + } lowerUsername := strings.ToLower(username) @@ -503,11 +539,14 @@ func (c *Client) LookupUserByUsername(ctx context.Context, username string) stri for _, member := range members { usernamePrefix := strings.HasPrefix(strings.ToLower(member.User.Username), lowerUsername) globalNamePrefix := strings.HasPrefix(strings.ToLower(member.User.GlobalName), lowerUsername) + nickPrefix := strings.HasPrefix(strings.ToLower(member.Nick), lowerUsername) if usernamePrefix { matches = append(matches, prefixMatch{member: member, matchType: "username_prefix"}) } else if globalNamePrefix { matches = append(matches, prefixMatch{member: member, matchType: "global_name_prefix"}) + } else if nickPrefix { + matches = append(matches, prefixMatch{member: member, matchType: "nick_prefix"}) } } @@ -551,6 +590,7 @@ func (c *Client) LookupUserByUsername(ctx context.Context, username string) stri "index", i, "discord_username", member.User.Username, "discord_global_name", member.User.GlobalName, + "discord_nick", member.Nick, "user_id", member.User.ID) } @@ -563,11 +603,11 @@ func (c *Client) LookupUserByUsername(ctx context.Context, username string) stri // IsBotInChannel checks if the bot has permission to send messages in a channel. func (c *Client) IsBotInChannel(ctx context.Context, channelID string) bool { - if c.session.State == nil || c.session.State.User == nil { + if c.session.GetState() == nil || c.session.GetState().User == nil { return false } - perms, err := c.session.UserChannelPermissions(c.session.State.User.ID, channelID) + perms, err := c.session.UserChannelPermissions(c.session.GetState().User.ID, channelID) if err != nil { slog.Debug("failed to check channel permissions", "channel_id", channelID, @@ -618,7 +658,7 @@ func (c *Client) IsUserActive(ctx context.Context, userID string) bool { } // Get guild from state - guild, err := c.session.State.Guild(guildID) + guild, err := c.session.GetState().Guild(guildID) if err != nil { slog.Debug("failed to get guild from state", "guild_id", guildID, @@ -681,13 +721,13 @@ type BotInfo struct { // BotInfo returns the bot's user information. func (c *Client) BotInfo(ctx context.Context) (BotInfo, error) { - if c.session.State == nil || c.session.State.User == nil { + if c.session.GetState() == nil || c.session.GetState().User == nil { return BotInfo{}, errors.New("bot user not available") } return BotInfo{ - UserID: c.session.State.User.ID, - Username: c.session.State.User.Username, + UserID: c.session.GetState().User.ID, + Username: c.session.GetState().User.Username, }, nil } @@ -707,7 +747,7 @@ func (c *Client) FindForumThread(ctx context.Context, forumID, prURL string) (th var threads *discordgo.ThreadsList err := retryableCtx(ctx, func() error { var err error - threads, err = c.session.GuildThreadsActive(c.guildID) + threads, err = c.session.ThreadsActive(c.guildID) return err }) if err != nil { @@ -733,10 +773,11 @@ func (c *Client) FindForumThread(ctx context.Context, forumID, prURL string) (th } // Also check archived threads (recently archived) with retry + // Use realSession directly since ThreadsArchived is not in the session interface var archivedThreads *discordgo.ThreadsList err = retryableCtx(ctx, func() error { var err error - archivedThreads, err = c.session.ThreadsArchived(forumID, nil, 50) + archivedThreads, err = c.realSession.ThreadsArchived(forumID, nil, 50) return err }) if err != nil { @@ -776,8 +817,8 @@ func (c *Client) FindChannelMessage(ctx context.Context, channelID, prURL string } var botID string - if c.session.State != nil && c.session.State.User != nil { - botID = c.session.State.User.ID + if c.session.GetState() != nil && c.session.GetState().User != nil { + botID = c.session.GetState().User.ID } slog.Info("searching for existing channel message", @@ -838,8 +879,8 @@ func (c *Client) FindDMForPR(ctx context.Context, userID, prURL string) (channel } var botID string - if c.session.State != nil && c.session.State.User != nil { - botID = c.session.State.User.ID + if c.session.GetState() != nil && c.session.GetState().User != nil { + botID = c.session.GetState().User.ID } var messages []*discordgo.Message diff --git a/internal/discord/client_test.go b/internal/discord/client_test.go index 5693810..e2893b6 100644 --- a/internal/discord/client_test.go +++ b/internal/discord/client_test.go @@ -2,12 +2,24 @@ package discord import ( "context" + "fmt" "strings" "testing" "github.com/bwmarrin/discordgo" ) +// newTestClientWithMock creates a Client for testing with a mock session +func newTestClientWithMock(mock *MockSession) *Client { + return &Client{ + session: mock, + realSession: nil, + channelCache: make(map[string]string), + channelTypeCache: make(map[string]discordgo.ChannelType), + userCache: make(map[string]string), + } +} + // findUserInMembers tests the matching logic without Discord API calls. // This mirrors the logic in LookupUserByUsername but operates on a slice of members. func findUserInMembers(username string, members []*discordgo.Member) (userID, matchType string) { @@ -16,7 +28,7 @@ func findUserInMembers(username string, members []*discordgo.Member) (userID, ma return "", "" } - // Tier 1: Exact match (Username takes precedence over GlobalName) + // Tier 1: Exact match (Username takes precedence over GlobalName, then Nick) for _, member := range members { if member.User.Username == username { return member.User.ID, "username" @@ -27,8 +39,13 @@ func findUserInMembers(username string, members []*discordgo.Member) (userID, ma return member.User.ID, "global_name" } } + for _, member := range members { + if member.Nick == username { + return member.User.ID, "nick" + } + } - // Tier 2: Case-insensitive match (Username takes precedence over GlobalName) + // Tier 2: Case-insensitive match (Username takes precedence over GlobalName, then Nick) for _, member := range members { if strings.EqualFold(member.User.Username, username) { return member.User.ID, "username_case_insensitive" @@ -39,6 +56,11 @@ func findUserInMembers(username string, members []*discordgo.Member) (userID, ma return member.User.ID, "global_name_case_insensitive" } } + for _, member := range members { + if strings.EqualFold(member.Nick, username) { + return member.User.ID, "nick_case_insensitive" + } + } lowerUsername := strings.ToLower(username) @@ -54,6 +76,8 @@ func findUserInMembers(username string, members []*discordgo.Member) (userID, ma matches = append(matches, prefixMatch{member: member, matchType: "username_prefix"}) } else if strings.HasPrefix(strings.ToLower(member.User.GlobalName), lowerUsername) { matches = append(matches, prefixMatch{member: member, matchType: "global_name_prefix"}) + } else if strings.HasPrefix(strings.ToLower(member.Nick), lowerUsername) { + matches = append(matches, prefixMatch{member: member, matchType: "nick_prefix"}) } } @@ -233,6 +257,76 @@ func TestFindUserInMembers(t *testing.T) { wantID: "123456", wantMatchType: "username_prefix", }, + // Nickname matching tests + { + name: "exact nickname match", + username: "octocat", + members: []*discordgo.Member{ + {Nick: "octocat", User: &discordgo.User{ID: "111", Username: "user1", GlobalName: "User One"}}, + {Nick: "cooldev", User: &discordgo.User{ID: "222", Username: "user2", GlobalName: "User Two"}}, + }, + wantID: "111", + wantMatchType: "nick", + }, + { + name: "case-insensitive nickname match", + username: "OCTOCAT", + members: []*discordgo.Member{ + {Nick: "octocat", User: &discordgo.User{ID: "111", Username: "user1", GlobalName: "User One"}}, + {Nick: "cooldev", User: &discordgo.User{ID: "222", Username: "user2", GlobalName: "User Two"}}, + }, + wantID: "111", + wantMatchType: "nick_case_insensitive", + }, + { + name: "prefix nickname match - unambiguous", + username: "octo", + members: []*discordgo.Member{ + {Nick: "octocat", User: &discordgo.User{ID: "111", Username: "user1", GlobalName: "User One"}}, + {Nick: "cooldev", User: &discordgo.User{ID: "222", Username: "user2", GlobalName: "User Two"}}, + }, + wantID: "111", + wantMatchType: "nick_prefix", + }, + { + name: "username preferred over nickname", + username: "octocat", + members: []*discordgo.Member{ + {Nick: "different", User: &discordgo.User{ID: "111", Username: "octocat", GlobalName: "User One"}}, + {Nick: "octocat", User: &discordgo.User{ID: "222", Username: "other", GlobalName: "User Two"}}, + }, + wantID: "111", + wantMatchType: "username", + }, + { + name: "global name preferred over nickname", + username: "User One", + members: []*discordgo.Member{ + {Nick: "User One", User: &discordgo.User{ID: "111", Username: "user1", GlobalName: "Different"}}, + {Nick: "other", User: &discordgo.User{ID: "222", Username: "user2", GlobalName: "User One"}}, + }, + wantID: "222", + wantMatchType: "global_name", + }, + { + name: "nickname match when username and global name don't match", + username: "mynickname", + members: []*discordgo.Member{ + {Nick: "mynickname", User: &discordgo.User{ID: "111", Username: "user1", GlobalName: "User One"}}, + {Nick: "othernick", User: &discordgo.User{ID: "222", Username: "user2", GlobalName: "User Two"}}, + }, + wantID: "111", + wantMatchType: "nick", + }, + { + name: "empty nickname should not match", + username: "test", + members: []*discordgo.Member{ + {Nick: "", User: &discordgo.User{ID: "111", Username: "test", GlobalName: "Test User"}}, + }, + wantID: "111", + wantMatchType: "username", + }, } for _, tt := range tests { @@ -273,11 +367,14 @@ func TestClient_GuildID(t *testing.T) { // TestClient_Session tests getting the session. func TestClient_Session(t *testing.T) { - mockSession := &discordgo.Session{} - client := &Client{session: mockSession} + realSession := &discordgo.Session{} + client := &Client{ + session: &sessionAdapter{Session: realSession}, + realSession: realSession, + } got := client.Session() - if got != mockSession { + if got != realSession { t.Error("Session() should return the same session") } } @@ -377,11 +474,10 @@ func TestClient_ChannelType_CacheHit(t *testing.T) { // TestClient_IsBotInChannel_NilState tests IsBotInChannel when session state is nil. func TestClient_IsBotInChannel_NilState(t *testing.T) { - client := &Client{ - session: &discordgo.Session{ - State: nil, - }, - } + mockSession := NewMockSession() + mockSession.MockState = nil + + client := newTestClientWithMock(mockSession) got := client.IsBotInChannel(context.Background(), "some-channel-id") if got { @@ -391,14 +487,10 @@ func TestClient_IsBotInChannel_NilState(t *testing.T) { // TestClient_IsBotInChannel_NilUser tests IsBotInChannel when user is nil. func TestClient_IsBotInChannel_NilUser(t *testing.T) { - state := discordgo.NewState() - state.User = nil + mockSession := NewMockSession() + mockSession.MockState.User = nil - client := &Client{ - session: &discordgo.Session{ - State: state, - }, - } + client := newTestClientWithMock(mockSession) got := client.IsBotInChannel(context.Background(), "some-channel-id") if got { @@ -408,10 +500,8 @@ func TestClient_IsBotInChannel_NilUser(t *testing.T) { // TestClient_IsUserInGuild_NoGuildID tests IsUserInGuild when no guild ID is set. func TestClient_IsUserInGuild_NoGuildID(t *testing.T) { - client := &Client{ - guildID: "", - session: &discordgo.Session{}, - } + client := newTestClientWithMock(NewMockSession()) + // Don't set guildID got := client.IsUserInGuild(context.Background(), "user-123") if got { @@ -421,10 +511,8 @@ func TestClient_IsUserInGuild_NoGuildID(t *testing.T) { // TestClient_IsUserActive_NoGuildID tests IsUserActive when no guild ID is set. func TestClient_IsUserActive_NoGuildID(t *testing.T) { - client := &Client{ - guildID: "", - session: &discordgo.Session{}, - } + client := newTestClientWithMock(NewMockSession()) + // Don't set guildID got := client.IsUserActive(context.Background(), "user-123") if got { @@ -434,10 +522,8 @@ func TestClient_IsUserActive_NoGuildID(t *testing.T) { // TestClient_GuildInfo_NoGuildID tests GuildInfo when no guild ID is set. func TestClient_GuildInfo_NoGuildID(t *testing.T) { - client := &Client{ - guildID: "", - session: &discordgo.Session{}, - } + client := newTestClientWithMock(NewMockSession()) + // Don't set guildID _, err := client.GuildInfo(context.Background()) if err == nil { @@ -450,11 +536,10 @@ func TestClient_GuildInfo_NoGuildID(t *testing.T) { // TestClient_BotInfo_NilState tests BotInfo when session state is nil. func TestClient_BotInfo_NilState(t *testing.T) { - client := &Client{ - session: &discordgo.Session{ - State: nil, - }, - } + mockSession := NewMockSession() + mockSession.MockState = nil + + client := newTestClientWithMock(mockSession) _, err := client.BotInfo(context.Background()) if err == nil { @@ -467,14 +552,10 @@ func TestClient_BotInfo_NilState(t *testing.T) { // TestClient_BotInfo_NilUser tests BotInfo when user is nil. func TestClient_BotInfo_NilUser(t *testing.T) { - state := discordgo.NewState() - state.User = nil + mockSession := NewMockSession() + mockSession.MockState.User = nil - client := &Client{ - session: &discordgo.Session{ - State: state, - }, - } + client := newTestClientWithMock(mockSession) _, err := client.BotInfo(context.Background()) if err == nil { @@ -484,3 +565,1139 @@ func TestClient_BotInfo_NilUser(t *testing.T) { t.Errorf("BotInfo() error = %q, want %q", err.Error(), "bot user not available") } } + +// TestClient_PostMessage tests posting a message to a channel. +func TestClient_PostMessage(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + channelID := "channel-123" + text := "Hello, world!" + + msgID, err := client.PostMessage(ctx, channelID, text) + if err != nil { + t.Fatalf("PostMessage() error = %v, want nil", err) + } + + if msgID == "" { + t.Error("PostMessage() returned empty message ID") + } + + if len(mockSession.SentMessages) != 1 { + t.Fatalf("Expected 1 sent message, got %d", len(mockSession.SentMessages)) + } + + sentMsg := mockSession.SentMessages[0] + if sentMsg.ChannelID != channelID { + t.Errorf("Sent message channel ID = %q, want %q", sentMsg.ChannelID, channelID) + } + if sentMsg.Content != text { + t.Errorf("Sent message content = %q, want %q", sentMsg.Content, text) + } +} + +// TestClient_PostMessage_Error tests PostMessage error handling. +func TestClient_PostMessage_Error(t *testing.T) { + mockSession := NewMockSession() + mockSession.ChannelMessageSendComplexError = fmt.Errorf("API error") + + client := newTestClientWithMock(mockSession) + + _, err := client.PostMessage(context.Background(), "channel-123", "test") + if err == nil { + t.Error("PostMessage() error = nil, want error") + } +} + +// TestClient_UpdateMessage tests updating an existing message. +func TestClient_UpdateMessage(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + channelID := "channel-123" + messageID := "msg-456" + newText := "Updated content" + + err := client.UpdateMessage(ctx, channelID, messageID, newText) + if err != nil { + t.Fatalf("UpdateMessage() error = %v, want nil", err) + } + + if len(mockSession.EditedMessages) != 1 { + t.Fatalf("Expected 1 edited message, got %d", len(mockSession.EditedMessages)) + } + + editedMsg := mockSession.EditedMessages[0] + if editedMsg.ChannelID != channelID { + t.Errorf("Edited message channel ID = %q, want %q", editedMsg.ChannelID, channelID) + } + if editedMsg.MessageID != messageID { + t.Errorf("Edited message ID = %q, want %q", editedMsg.MessageID, messageID) + } + if editedMsg.Content != newText { + t.Errorf("Edited message content = %q, want %q", editedMsg.Content, newText) + } +} + +// TestClient_UpdateMessage_Error tests UpdateMessage error handling. +func TestClient_UpdateMessage_Error(t *testing.T) { + mockSession := NewMockSession() + mockSession.ChannelMessageEditComplexError = fmt.Errorf("API error") + + client := newTestClientWithMock(mockSession) + + err := client.UpdateMessage(context.Background(), "channel-123", "msg-456", "test") + if err == nil { + t.Error("UpdateMessage() error = nil, want error") + } +} + +// TestClient_PostForumThread tests creating a forum thread. +func TestClient_PostForumThread(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("guild-123") + + ctx := context.Background() + channelID := "forum-channel-789" + title := "New PR Discussion" + content := "Let's discuss this PR" + + threadID, messageID, err := client.PostForumThread(ctx, channelID, title, content) + if err != nil { + t.Fatalf("PostForumThread() error = %v, want nil", err) + } + + if threadID == "" { + t.Error("PostForumThread() returned empty thread ID") + } + if messageID == "" { + t.Error("PostForumThread() returned empty message ID") + } + + if len(mockSession.CreatedThreads) != 1 { + t.Fatalf("Expected 1 created thread, got %d", len(mockSession.CreatedThreads)) + } + + createdThread := mockSession.CreatedThreads[0] + if createdThread.Name != title { + t.Errorf("Created thread name = %q, want %q", createdThread.Name, title) + } +} + +// TestClient_PostForumThread_Error tests PostForumThread error handling. +func TestClient_PostForumThread_Error(t *testing.T) { + mockSession := NewMockSession() + mockSession.ForumThreadStartComplexError = fmt.Errorf("API error") + + client := newTestClientWithMock(mockSession) + client.SetGuildID("guild-123") + + _, _, err := client.PostForumThread(context.Background(), "channel-123", "title", "content") + if err == nil { + t.Error("PostForumThread() error = nil, want error") + } +} + +// TestClient_UpdateForumPost tests updating a forum post. +func TestClient_UpdateForumPost(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + threadID := "thread-123" + messageID := "msg-456" + newTitle := "Updated PR Discussion" + newContent := "Updated content" + + err := client.UpdateForumPost(ctx, threadID, messageID, newTitle, newContent) + if err != nil { + t.Fatalf("UpdateForumPost() error = %v, want nil", err) + } + + // Should have edited the title (via ChannelEdit) + if len(mockSession.Channels) == 0 { + t.Error("Expected channel to be edited for title update") + } + + // Should have edited the message content + if len(mockSession.EditedMessages) != 1 { + t.Fatalf("Expected 1 edited message, got %d", len(mockSession.EditedMessages)) + } + + editedMsg := mockSession.EditedMessages[0] + if editedMsg.Content != newContent { + t.Errorf("Edited message content = %q, want %q", editedMsg.Content, newContent) + } +} + +// TestClient_UpdateForumPost_NoMessageID tests UpdateForumPost with empty message ID. +func TestClient_UpdateForumPost_NoMessageID(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + threadID := "thread-123" + newTitle := "Updated PR Discussion" + newContent := "Updated content" + + err := client.UpdateForumPost(ctx, threadID, "", newTitle, newContent) + if err != nil { + t.Fatalf("UpdateForumPost() error = %v, want nil", err) + } + + // Should only have edited the title, not the message + if len(mockSession.EditedMessages) != 0 { + t.Errorf("Expected 0 edited messages, got %d", len(mockSession.EditedMessages)) + } +} + +// TestClient_UpdateForumPost_TitleEditError tests UpdateForumPost when title edit fails. +func TestClient_UpdateForumPost_TitleEditError(t *testing.T) { + mockSession := NewMockSession() + mockSession.ChannelEditError = fmt.Errorf("title edit failed") + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + err := client.UpdateForumPost(ctx, "thread-123", "msg-456", "New Title", "New Content") + if err == nil { + t.Error("UpdateForumPost() error = nil, want error when title edit fails") + } +} + +// TestClient_UpdateForumPost_MessageEditError tests UpdateForumPost when message edit fails. +func TestClient_UpdateForumPost_MessageEditError(t *testing.T) { + mockSession := NewMockSession() + mockSession.ChannelMessageEditComplexError = fmt.Errorf("message edit failed") + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + err := client.UpdateForumPost(ctx, "thread-123", "msg-456", "New Title", "New Content") + if err == nil { + t.Error("UpdateForumPost() error = nil, want error when message edit fails") + } +} + +// TestClient_PostForumThread_MessagesFetchError tests when ChannelMessages fails. +func TestClient_PostForumThread_MessagesFetchError(t *testing.T) { + mockSession := NewMockSession() + mockSession.MessagesError = fmt.Errorf("messages fetch failed") + client := newTestClientWithMock(mockSession) + client.SetGuildID("guild-123") + + ctx := context.Background() + threadID, messageID, err := client.PostForumThread(ctx, "channel-123", "title", "content") + + // Should succeed but return empty messageID + if err != nil { + t.Errorf("PostForumThread() error = %v, want nil", err) + } + if threadID == "" { + t.Error("PostForumThread() threadID = empty, want non-empty") + } + if messageID != "" { + t.Errorf("PostForumThread() messageID = %q, want empty when messages fetch fails", messageID) + } +} + +// TestClient_ArchiveThread tests archiving a thread. +func TestClient_ArchiveThread(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + threadID := "thread-123" + + err := client.ArchiveThread(ctx, threadID) + if err != nil { + t.Fatalf("ArchiveThread() error = %v, want nil", err) + } + + // Check that the channel was edited with Archived=true + channel, exists := mockSession.Channels[threadID] + if !exists { + t.Fatal("Expected channel to be edited for archiving") + } + + if channel.ThreadMetadata == nil || !channel.ThreadMetadata.Archived { + t.Error("Expected thread to be archived") + } +} + +// TestClient_ArchiveThread_Error tests ArchiveThread error handling. +func TestClient_ArchiveThread_Error(t *testing.T) { + mockSession := NewMockSession() + mockSession.ChannelEditError = fmt.Errorf("API error") + + client := newTestClientWithMock(mockSession) + + err := client.ArchiveThread(context.Background(), "thread-123") + if err == nil { + t.Error("ArchiveThread() error = nil, want error") + } +} + +// TestClient_SendDM tests sending a direct message. +func TestClient_SendDM(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + userID := "user-123" + text := "Hello via DM!" + + channelID, messageID, err := client.SendDM(ctx, userID, text) + if err != nil { + t.Fatalf("SendDM() error = %v, want nil", err) + } + + if channelID == "" { + t.Error("SendDM() returned empty channel ID") + } + if messageID == "" { + t.Error("SendDM() returned empty message ID") + } + + if len(mockSession.CreatedChannels) != 1 { + t.Fatalf("Expected 1 created DM channel, got %d", len(mockSession.CreatedChannels)) + } + + if len(mockSession.SentMessages) != 1 { + t.Fatalf("Expected 1 sent DM, got %d", len(mockSession.SentMessages)) + } + + sentMsg := mockSession.SentMessages[0] + if sentMsg.Content != text { + t.Errorf("Sent DM content = %q, want %q", sentMsg.Content, text) + } +} + +// TestClient_SendDM_UserChannelError tests SendDM error when creating DM channel. +func TestClient_SendDM_UserChannelError(t *testing.T) { + mockSession := NewMockSession() + mockSession.UserChannelError = fmt.Errorf("failed to create DM channel") + + client := newTestClientWithMock(mockSession) + + _, _, err := client.SendDM(context.Background(), "user-123", "test") + if err == nil { + t.Error("SendDM() error = nil, want error") + } +} + +// TestClient_SendDM_MessageSendError tests SendDM error when sending message. +func TestClient_SendDM_MessageSendError(t *testing.T) { + mockSession := NewMockSession() + mockSession.ChannelMessageSendComplexError = fmt.Errorf("failed to send message") + + client := newTestClientWithMock(mockSession) + + _, _, err := client.SendDM(context.Background(), "user-123", "test") + if err == nil { + t.Error("SendDM() error = nil, want error") + } +} + +// TestClient_UpdateDM tests updating a DM. +func TestClient_UpdateDM(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + channelID := "dm-channel-123" + messageID := "msg-456" + newText := "Updated DM content" + + err := client.UpdateDM(ctx, channelID, messageID, newText) + if err != nil { + t.Fatalf("UpdateDM() error = %v, want nil", err) + } + + if len(mockSession.EditedMessages) != 1 { + t.Fatalf("Expected 1 edited message, got %d", len(mockSession.EditedMessages)) + } + + editedMsg := mockSession.EditedMessages[0] + if editedMsg.Content != newText { + t.Errorf("Edited DM content = %q, want %q", editedMsg.Content, newText) + } +} + +// TestClient_UpdateDM_Error tests UpdateDM error handling. +func TestClient_UpdateDM_Error(t *testing.T) { + mockSession := NewMockSession() + mockSession.ChannelMessageEditComplexError = fmt.Errorf("API error") + + client := newTestClientWithMock(mockSession) + + err := client.UpdateDM(context.Background(), "dm-123", "msg-456", "test") + if err == nil { + t.Error("UpdateDM() error = nil, want error") + } +} + +// TestClient_LookupUserByUsername tests user lookup with cached value. +func TestClient_LookupUserByUsername(t *testing.T) { + mockSession := NewMockSession() + mockSession.AddMember("guild-123", NewMockMember("user-111", "alice", "Alice Smith")) + mockSession.AddMember("guild-123", NewMockMember("user-222", "bob", "Bob Jones")) + + client := newTestClientWithMock(mockSession) + client.SetGuildID("guild-123") + + ctx := context.Background() + userID := client.LookupUserByUsername(ctx, "alice") + + if userID != "user-111" { + t.Errorf("LookupUserByUsername(\"alice\") = %q, want %q", userID, "user-111") + } + + // Should now be cached + if cachedID, ok := client.userCache["alice"]; !ok || cachedID != "user-111" { + t.Error("Expected user to be cached after lookup") + } +} + +// TestClient_LookupUserByUsername_CacheHit tests user lookup with cached value. +func TestClient_LookupUserByUsername_CacheHit(t *testing.T) { + client := newTestClientWithMock(NewMockSession()) + client.SetGuildID("guild-123") + client.userCache["alice"] = "user-111" + + userID := client.LookupUserByUsername(context.Background(), "alice") + if userID != "user-111" { + t.Errorf("LookupUserByUsername(\"alice\") = %q, want %q", userID, "user-111") + } +} + +// TestClient_LookupUserByUsername_NoGuildID tests user lookup with no guild ID. +func TestClient_LookupUserByUsername_NoGuildID(t *testing.T) { + client := newTestClientWithMock(NewMockSession()) + // Don't set guildID - test with empty guild ID + + userID := client.LookupUserByUsername(context.Background(), "alice") + if userID != "" { + t.Errorf("LookupUserByUsername() = %q, want empty string when no guild ID", userID) + } +} + +// TestClient_LookupUserByUsername_EmptyUsername tests user lookup with empty username. +func TestClient_LookupUserByUsername_EmptyUsername(t *testing.T) { + mockSession := NewMockSession() + mockSession.AddMember("guild-123", NewMockMember("user-111", "alice", "Alice Smith")) + + client := newTestClientWithMock(mockSession) + client.SetGuildID("guild-123") + + userID := client.LookupUserByUsername(context.Background(), "") + if userID != "" { + t.Errorf("LookupUserByUsername(\"\") = %q, want empty string", userID) + } +} + +// TestClient_IsForumChannel tests checking if a channel is a forum. +func TestClient_IsForumChannel(t *testing.T) { + mockSession := NewMockSession() + mockSession.AddChannel(&discordgo.Channel{ + ID: "forum-123", + Type: discordgo.ChannelTypeGuildForum, + }) + + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + isForum := client.IsForumChannel(ctx, "forum-123") + + if !isForum { + t.Error("IsForumChannel() = false, want true for forum channel") + } +} + +// TestClient_IsForumChannel_TextChannel tests checking if a text channel is a forum. +func TestClient_IsForumChannel_TextChannel(t *testing.T) { + mockSession := NewMockSession() + mockSession.AddChannel(&discordgo.Channel{ + ID: "text-123", + Type: discordgo.ChannelTypeGuildText, + }) + + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + isForum := client.IsForumChannel(ctx, "text-123") + + if isForum { + t.Error("IsForumChannel() = true, want false for text channel") + } +} + +// TestClient_IsForumChannel_Error tests IsForumChannel error handling. +func TestClient_IsForumChannel_Error(t *testing.T) { + mockSession := NewMockSession() + mockSession.ChannelError = fmt.Errorf("channel not found") + + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + isForum := client.IsForumChannel(ctx, "nonexistent") + + if isForum { + t.Error("IsForumChannel() = true, want false on error") + } +} + +// TestClient_MessageContent tests retrieving message content. +func TestClient_MessageContent(t *testing.T) { + mockSession := NewMockSession() + mockSession.AddMessage("channel-123", &discordgo.Message{ + ID: "msg-456", + Content: "Test message content", + }) + + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + content, err := client.MessageContent(ctx, "channel-123", "msg-456") + if err != nil { + t.Fatalf("MessageContent() error = %v, want nil", err) + } + + if content != "Test message content" { + t.Errorf("MessageContent() = %q, want %q", content, "Test message content") + } +} + +// TestClient_MessageContent_Error tests MessageContent error handling. +func TestClient_MessageContent_Error(t *testing.T) { + mockSession := NewMockSession() + mockSession.MessagesError = fmt.Errorf("message not found") + + client := newTestClientWithMock(mockSession) + + _, err := client.MessageContent(context.Background(), "channel-123", "msg-456") + if err == nil { + t.Error("MessageContent() error = nil, want error") + } +} + +// TestClient_ChannelMessages tests retrieving channel messages. +func TestClient_ChannelMessages(t *testing.T) { + mockSession := NewMockSession() + mockSession.AddMessage("channel-123", &discordgo.Message{ + ID: "msg-1", + Content: "Message 1", + }) + mockSession.AddMessage("channel-123", &discordgo.Message{ + ID: "msg-2", + Content: "Message 2", + }) + + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + messages, err := client.ChannelMessages(ctx, "channel-123", 10) + if err != nil { + t.Fatalf("ChannelMessages() error = %v, want nil", err) + } + + if len(messages) != 2 { + t.Errorf("ChannelMessages() returned %d messages, want 2", len(messages)) + } +} + +// TestClient_ChannelMessages_Error tests ChannelMessages error handling. +func TestClient_ChannelMessages_Error(t *testing.T) { + mockSession := NewMockSession() + mockSession.MessagesError = fmt.Errorf("failed to fetch messages") + + client := newTestClientWithMock(mockSession) + + _, err := client.ChannelMessages(context.Background(), "channel-123", 10) + if err == nil { + t.Error("ChannelMessages() error = nil, want error") + } +} + +// TestClient_FindChannelMessage tests finding a message by PR URL. +func TestClient_FindChannelMessage(t *testing.T) { + mockSession := NewMockSession() + mockSession.AddMessage("channel-123", &discordgo.Message{ + ID: "msg-1", + Content: "Check out this PR: https://github.com/owner/repo/pull/123", + Author: &discordgo.User{ + ID: "user-123", + }, + }) + mockSession.AddMessage("channel-123", &discordgo.Message{ + ID: "msg-2", + Content: "Another message", + Author: &discordgo.User{ + ID: "user-456", + }, + }) + + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + messageID, found := client.FindChannelMessage(ctx, "channel-123", "https://github.com/owner/repo/pull/123") + + if !found { + t.Error("FindChannelMessage() found = false, want true") + } + + if messageID != "msg-1" { + t.Errorf("FindChannelMessage() messageID = %q, want %q", messageID, "msg-1") + } +} + +// TestClient_FindChannelMessage_NotFound tests FindChannelMessage when message not found. +func TestClient_FindChannelMessage_NotFound(t *testing.T) { + mockSession := NewMockSession() + mockSession.AddMessage("channel-123", &discordgo.Message{ + ID: "msg-1", + Content: "Some other content", + Author: &discordgo.User{ + ID: "user-123", + }, + }) + + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + _, found := client.FindChannelMessage(ctx, "channel-123", "https://github.com/owner/repo/pull/999") + + if found { + t.Error("FindChannelMessage() found = true, want false for non-existent PR") + } +} + +// TestClient_FindChannelMessage_Error tests FindChannelMessage error handling. +func TestClient_FindChannelMessage_Error(t *testing.T) { + mockSession := NewMockSession() + mockSession.MessagesError = fmt.Errorf("API error") + + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + _, found := client.FindChannelMessage(ctx, "channel-123", "https://github.com/owner/repo/pull/123") + + if found { + t.Error("FindChannelMessage() found = true, want false on error") + } +} + +// TestClient_FindDMForPR tests finding a DM for a specific PR. +func TestClient_FindDMForPR(t *testing.T) { + mockSession := NewMockSession() + + // UserChannelCreate will create a channel with ID "dm-user-123" + // So we need to add messages to that channel ID + mockSession.AddMessage("dm-user-123", &discordgo.Message{ + ID: "dm-msg-1", + Content: "PR notification: https://github.com/owner/repo/pull/456", + Author: &discordgo.User{ + ID: "bot-user-id", + }, + }) + + client := newTestClientWithMock(mockSession) + client.session.GetState().User = &discordgo.User{ + ID: "bot-user-id", + Username: "testbot", + } + + ctx := context.Background() + channelID, messageID, found := client.FindDMForPR(ctx, "user-123", "https://github.com/owner/repo/pull/456") + + if !found { + t.Error("FindDMForPR() found = false, want true") + } + + if channelID != "dm-user-123" { + t.Errorf("FindDMForPR() channelID = %q, want %q", channelID, "dm-user-123") + } + + if messageID != "dm-msg-1" { + t.Errorf("FindDMForPR() messageID = %q, want %q", messageID, "dm-msg-1") + } +} + +// TestClient_FindDMForPR_NotFound tests FindDMForPR when DM not found. +func TestClient_FindDMForPR_NotFound(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.session.GetState().User = &discordgo.User{ + ID: "bot-user-id", + Username: "testbot", + } + + ctx := context.Background() + _, _, found := client.FindDMForPR(ctx, "user-123", "https://github.com/owner/repo/pull/999") + + if found { + t.Error("FindDMForPR() found = true, want false for non-existent PR") + } +} + +// TestClient_FindDMForPR_NoBotUser tests FindDMForPR when bot user not set. +func TestClient_FindDMForPR_NoBotUser(t *testing.T) { + mockSession := NewMockSession() + mockSession.MockState.User = nil + + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + _, _, found := client.FindDMForPR(ctx, "user-123", "https://github.com/owner/repo/pull/123") + + if found { + t.Error("FindDMForPR() found = true, want false when bot user is nil") + } +} + +// TestClient_IsUserInGuild_Success tests IsUserInGuild when user is a member. +func TestClient_IsUserInGuild_Success(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + // Add a member to the guild + mockSession.AddMember("test-guild", &discordgo.Member{ + User: &discordgo.User{ + ID: "user-123", + Username: "testuser", + }, + }) + + got := client.IsUserInGuild(context.Background(), "user-123") + if !got { + t.Error("IsUserInGuild() = false, want true when user is a guild member") + } +} + +// TestClient_ResolveChannelID_GuildLookup tests ResolveChannelID with guild channel lookup. +func TestClient_ResolveChannelID_GuildLookup(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + // Add a channel to the guild + mockSession.AddChannel(&discordgo.Channel{ + ID: "channel-456", + Name: "general", + GuildID: "test-guild", + }) + + // Resolve by name + got := client.ResolveChannelID(context.Background(), "general") + if got != "channel-456" { + t.Errorf("ResolveChannelID(\"general\") = %q, want \"channel-456\"", got) + } + + // Verify it was cached + if client.channelCache["general"] != "channel-456" { + t.Error("Channel should be cached after resolution") + } +} + +// TestClient_ResolveChannelID_GuildLookupNotFound tests ResolveChannelID when channel not found. +func TestClient_ResolveChannelID_GuildLookupNotFound(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + // Don't add any channels + + // Resolve by name + got := client.ResolveChannelID(context.Background(), "nonexistent") + if got != "nonexistent" { + t.Errorf("ResolveChannelID(\"nonexistent\") = %q, want \"nonexistent\"", got) + } +} + +// TestClient_ResolveChannelID_WithHashPrefix tests ResolveChannelID with # prefix. +func TestClient_ResolveChannelID_WithHashPrefix(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + // Add a channel + mockSession.AddChannel(&discordgo.Channel{ + ID: "channel-789", + Name: "announcements", + GuildID: "test-guild", + }) + + // Resolve with # prefix + got := client.ResolveChannelID(context.Background(), "#announcements") + if got != "channel-789" { + t.Errorf("ResolveChannelID(\"#announcements\") = %q, want \"channel-789\"", got) + } +} + +// Note: FindForumThread tests are skipped because they require realSession.ThreadsArchived() +// which is complex to mock. The method uses realSession directly for archived threads. + +// TestClient_LookupUserByUsername_Success tests successful user lookup. +func TestClient_LookupUserByUsername_Success(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + // Add members to the guild + mockSession.AddMember("test-guild", &discordgo.Member{ + User: &discordgo.User{ + ID: "user-123", + Username: "testuser", + }, + }) + mockSession.AddMember("test-guild", &discordgo.Member{ + User: &discordgo.User{ + ID: "user-456", + Username: "anotheruser", + GlobalName: "Another User", + }, + }) + + ctx := context.Background() + userID := client.LookupUserByUsername(ctx, "testuser") + + if userID != "user-123" { + t.Errorf("LookupUserByUsername() userID = %q, want \"user-123\"", userID) + } + + // Verify it was cached + if client.userCache["testuser"] != "user-123" { + t.Error("User should be cached after lookup") + } +} + +// TestClient_LookupUserByUsername_GlobalName tests lookup by global name. +func TestClient_LookupUserByUsername_GlobalName(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + mockSession.AddMember("test-guild", &discordgo.Member{ + User: &discordgo.User{ + ID: "user-789", + Username: "someuser", + GlobalName: "Display Name", + }, + }) + + ctx := context.Background() + userID := client.LookupUserByUsername(ctx, "Display Name") + + if userID != "user-789" { + t.Errorf("LookupUserByUsername() by global name userID = %q, want \"user-789\"", userID) + } +} + +// TestClient_LookupUserByUsername_Nick tests lookup by nickname. +func TestClient_LookupUserByUsername_Nick(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + mockSession.AddMember("test-guild", &discordgo.Member{ + Nick: "CoolNick", + User: &discordgo.User{ + ID: "user-999", + Username: "plainuser", + }, + }) + + ctx := context.Background() + userID := client.LookupUserByUsername(ctx, "CoolNick") + + if userID != "user-999" { + t.Errorf("LookupUserByUsername() by nickname userID = %q, want \"user-999\"", userID) + } +} + +// TestClient_LookupUserByUsername_CaseInsensitive tests case-insensitive matching. +func TestClient_LookupUserByUsername_CaseInsensitive(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + mockSession.AddMember("test-guild", &discordgo.Member{ + User: &discordgo.User{ + ID: "user-case", + Username: "CamelCase", + }, + }) + + ctx := context.Background() + userID := client.LookupUserByUsername(ctx, "camelcase") + + if userID != "user-case" { + t.Errorf("LookupUserByUsername() case-insensitive userID = %q, want \"user-case\"", userID) + } +} + +// TestClient_LookupUserByUsername_PrefixMatch tests unambiguous prefix matching. +func TestClient_LookupUserByUsername_PrefixMatch(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + mockSession.AddMember("test-guild", &discordgo.Member{ + User: &discordgo.User{ + ID: "user-prefix", + Username: "PrefixedUser", + }, + }) + + ctx := context.Background() + userID := client.LookupUserByUsername(ctx, "prefix") + + if userID != "user-prefix" { + t.Errorf("LookupUserByUsername() prefix match userID = %q, want \"user-prefix\"", userID) + } +} + +// TestClient_LookupUserByUsername_AmbiguousPrefix tests ambiguous prefix matching. +func TestClient_LookupUserByUsername_AmbiguousPrefix(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + mockSession.AddMember("test-guild", &discordgo.Member{ + User: &discordgo.User{ + ID: "user-1", + Username: "TestUser1", + }, + }) + mockSession.AddMember("test-guild", &discordgo.Member{ + User: &discordgo.User{ + ID: "user-2", + Username: "TestUser2", + }, + }) + + ctx := context.Background() + userID := client.LookupUserByUsername(ctx, "test") + + // Should return empty for ambiguous match + if userID != "" { + t.Errorf("LookupUserByUsername() ambiguous prefix userID = %q, want empty", userID) + } +} + +// TestClient_LookupUserByUsername_NotFound tests when user is not found. +func TestClient_LookupUserByUsername_NotFound(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + mockSession.AddMember("test-guild", &discordgo.Member{ + User: &discordgo.User{ + ID: "user-other", + Username: "someuser", + }, + }) + + ctx := context.Background() + userID := client.LookupUserByUsername(ctx, "nonexistent") + + if userID != "" { + t.Errorf("LookupUserByUsername() not found userID = %q, want empty", userID) + } +} + +// TestClient_IsBotInChannel_Success tests IsBotInChannel when bot is in channel. +func TestClient_IsBotInChannel_Success(t *testing.T) { + mockSession := NewMockSession() + mockSession.MockState.User = &discordgo.User{ + ID: "bot-id", + Username: "testbot", + } + + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + // MockSession's UserChannelPermissions returns PermissionAll by default + ctx := context.Background() + got := client.IsBotInChannel(ctx, "channel-123") + if !got { + t.Error("IsBotInChannel() = false, want true when bot has permissions") + } +} + +// TestClient_IsBotInChannel_PermissionError tests IsBotInChannel when permission check fails. +func TestClient_IsBotInChannel_PermissionError(t *testing.T) { + mockSession := NewMockSession() + mockSession.MockState.User = &discordgo.User{ + ID: "bot-id", + Username: "testbot", + } + mockSession.UserChannelPermissionsError = fmt.Errorf("permission check failed") + + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + got := client.IsBotInChannel(ctx, "channel-123") + if got { + t.Error("IsBotInChannel() = true, want false when permission check fails") + } +} + +// TestClient_IsUserInGuild_Error tests IsUserInGuild when GuildMember fails. +func TestClient_IsUserInGuild_Error(t *testing.T) { + mockSession := NewMockSession() + mockSession.GuildMemberError = fmt.Errorf("member not found") + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + got := client.IsUserInGuild(context.Background(), "user-123") + if got { + t.Error("IsUserInGuild() = true, want false when GuildMember fails") + } +} + +// TestClient_IsUserActive_Error tests IsUserActive when guild state is unavailable. +func TestClient_IsUserActive_Error(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + // Don't add guild to state, so state lookup will fail + ctx := context.Background() + got := client.IsUserActive(ctx, "user-123") + if got { + t.Error("IsUserActive() = true, want false when guild state unavailable") + } +} + +// TestClient_GuildInfo_Error tests GuildInfo when guild lookup fails. +func TestClient_GuildInfo_Error(t *testing.T) { + mockSession := NewMockSession() + mockSession.GuildError = fmt.Errorf("guild not found") + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + ctx := context.Background() + _, err := client.GuildInfo(ctx) + if err == nil { + t.Error("GuildInfo() error = nil, want error when guild lookup fails") + } +} + +// TestClient_ResolveChannelID_Error tests ResolveChannelID when GuildChannels fails. +func TestClient_ResolveChannelID_Error(t *testing.T) { + mockSession := NewMockSession() + mockSession.GuildChannelsError = fmt.Errorf("channels fetch failed") + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + ctx := context.Background() + got := client.ResolveChannelID(ctx, "test-channel") + + // Should return input unchanged when fetch fails + if got != "test-channel" { + t.Errorf("ResolveChannelID() = %q, want %q when fetch fails", got, "test-channel") + } +} + +// TestClient_IsUserActive_Success tests IsUserActive when user is active. +func TestClient_IsUserActive_Success(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + // Add guild with presence to the state + guild := &discordgo.Guild{ + ID: "test-guild", + Name: "Test Guild", + Presences: []*discordgo.Presence{ + { + User: &discordgo.User{ + ID: "user-123", + }, + Status: discordgo.StatusOnline, + }, + }, + } + mockSession.MockState.GuildAdd(guild) + + ctx := context.Background() + got := client.IsUserActive(ctx, "user-123") + if !got { + t.Error("IsUserActive() = false, want true when user is online") + } +} + +// TestClient_GuildInfo_Success tests successful guild info retrieval. +func TestClient_GuildInfo_Success(t *testing.T) { + mockSession := NewMockSession() + client := newTestClientWithMock(mockSession) + client.SetGuildID("test-guild") + + // Add a guild + mockSession.Guilds["test-guild"] = &discordgo.Guild{ + ID: "test-guild", + Name: "Test Guild", + } + + ctx := context.Background() + guildInfo, err := client.GuildInfo(ctx) + if err != nil { + t.Fatalf("GuildInfo() error = %v, want nil", err) + } + if guildInfo.ID != "test-guild" { + t.Errorf("GuildInfo() guild.ID = %q, want \"test-guild\"", guildInfo.ID) + } + if guildInfo.Name != "Test Guild" { + t.Errorf("GuildInfo() guild.Name = %q, want \"Test Guild\"", guildInfo.Name) + } +} + +// TestClient_BotInfo_Success tests successful bot info retrieval. +func TestClient_BotInfo_Success(t *testing.T) { + mockSession := NewMockSession() + mockSession.MockState.User = &discordgo.User{ + ID: "bot-id", + Username: "testbot", + } + + client := newTestClientWithMock(mockSession) + + ctx := context.Background() + botInfo, err := client.BotInfo(ctx) + if err != nil { + t.Fatalf("BotInfo() error = %v, want nil", err) + } + if botInfo.UserID != "bot-id" { + t.Errorf("BotInfo() botInfo.UserID = %q, want \"bot-id\"", botInfo.UserID) + } + if botInfo.Username != "testbot" { + t.Errorf("BotInfo() botInfo.Username = %q, want \"testbot\"", botInfo.Username) + } +} + +// TestSessionAdapter_GetState tests the sessionAdapter.GetState method +func TestSessionAdapter_GetState(t *testing.T) { + state := discordgo.NewState() + state.User = &discordgo.User{ + ID: "bot-123", + Username: "testbot", + } + + session := &discordgo.Session{ + State: state, + } + + adapter := &sessionAdapter{Session: session} + + gotState := adapter.GetState() + if gotState != state { + t.Errorf("GetState() = %v, want %v", gotState, state) + } + + if gotState.User.ID != "bot-123" { + t.Errorf("GetState().User.ID = %v, want bot-123", gotState.User.ID) + } +} diff --git a/internal/discord/manager_test.go b/internal/discord/manager_test.go index 6726380..7788c8b 100644 --- a/internal/discord/manager_test.go +++ b/internal/discord/manager_test.go @@ -77,7 +77,11 @@ func TestGuildManager_RemoveClient(t *testing.T) { manager := NewGuildManager(nil) // Create a session (it won't connect in tests) session := &discordgo.Session{} - client := &Client{guildID: "test-guild", session: session} + client := &Client{ + guildID: "test-guild", + session: &sessionAdapter{Session: session}, + realSession: session, + } // Register a client manager.RegisterClient("test-guild", client) @@ -128,3 +132,70 @@ func TestGuildManager_Close(t *testing.T) { t.Errorf("Close() error = %v, want nil", err) } } + +// TestGuildManager_Close_WithClients tests closing manager with registered clients. +func TestGuildManager_Close_WithClients(t *testing.T) { + manager := NewGuildManager(nil) + + // Create a session (it won't actually connect in tests) + session := &discordgo.Session{} + client1 := &Client{ + guildID: "guild1", + session: &sessionAdapter{Session: session}, + realSession: session, + } + client2 := &Client{ + guildID: "guild2", + session: &sessionAdapter{Session: session}, + realSession: session, + } + + manager.RegisterClient("guild1", client1) + manager.RegisterClient("guild2", client2) + + err := manager.Close() + if err != nil { + t.Errorf("Close() error = %v, want nil", err) + } + + // Verify clients were cleared + if len(manager.GuildIDs()) != 0 { + t.Errorf("GuildIDs() after Close() = %d, want 0", len(manager.GuildIDs())) + } +} + +// TestGuildManager_RemoveClient_Cleanup tests RemoveClient closes the client. +func TestGuildManager_RemoveClient_Cleanup(t *testing.T) { + manager := NewGuildManager(nil) + session := &discordgo.Session{} + client := &Client{ + guildID: "test-guild", + session: &sessionAdapter{Session: session}, + realSession: session, + } + + manager.RegisterClient("test-guild", client) + + // Verify client is registered + _, ok := manager.Client("test-guild") + if !ok { + t.Fatal("Client() should return registered client before removal") + } + + // Remove client + manager.RemoveClient("test-guild") + + // Verify client was removed + _, ok = manager.Client("test-guild") + if ok { + t.Error("Client() should return false after RemoveClient") + } + + // Verify guild ID is removed + ids := manager.GuildIDs() + for _, id := range ids { + if id == "test-guild" { + t.Error("GuildIDs() should not contain removed guild") + } + } +} diff --git a/internal/discord/mocks_test.go b/internal/discord/mocks_test.go index a7d3013..7062ff8 100644 --- a/internal/discord/mocks_test.go +++ b/internal/discord/mocks_test.go @@ -10,31 +10,41 @@ import ( // MockSession is a programmable mock for discordgo.Session type MockSession struct { // Programmable responses - OpenError error - CloseError error - MessageSendError error - MessageEditError error - UserChannelError error - GuildMembersError error - ChannelError error - GuildChannelsError error - MessagesError error - ThreadsActiveError error - ApplicationCommandsError error - InteractionResponseError error + OpenError error + CloseError error + MessageSendError error + MessageEditError error + UserChannelError error + GuildMembersError error + GuildMemberError error + ChannelError error + GuildChannelsError error + MessagesError error + ThreadsActiveError error + ApplicationCommandsError error + InteractionResponseError error + ChannelMessageSendComplexError error + ChannelMessageEditComplexError error + ForumThreadStartComplexError error + ChannelEditError error + GuildError error + UserChannelPermissionsError error // Storage for tracking calls SentMessages []*sentMessage EditedMessages []*editedMessage CreatedChannels []string + CreatedThreads []*discordgo.Channel Interactions []*discordgo.InteractionResponse // Mock data Channels map[string]*discordgo.Channel Members map[string][]*discordgo.Member + Guilds map[string]*discordgo.Guild Messages map[string][]*discordgo.Message ActiveThreads []*discordgo.Channel Commands []*discordgo.ApplicationCommand + MockState *discordgo.State mu sync.Mutex } @@ -56,10 +66,13 @@ func NewMockSession() *MockSession { return &MockSession{ SentMessages: make([]*sentMessage, 0), EditedMessages: make([]*editedMessage, 0), + CreatedThreads: make([]*discordgo.Channel, 0), Channels: make(map[string]*discordgo.Channel), Members: make(map[string][]*discordgo.Member), + Guilds: make(map[string]*discordgo.Guild), Messages: make(map[string][]*discordgo.Message), Commands: make([]*discordgo.ApplicationCommand, 0), + MockState: discordgo.NewState(), } } @@ -315,3 +328,291 @@ func NewMockMessage(id, channelID, content, authorID string) *discordgo.Message }, } } + +// ChannelMessageSendComplex mocks sending a complex message +func (m *MockSession) ChannelMessageSendComplex(channelID string, data *discordgo.MessageSend, options ...discordgo.RequestOption) (*discordgo.Message, error) { + if m.ChannelMessageSendComplexError != nil { + return nil, m.ChannelMessageSendComplexError + } + + m.mu.Lock() + defer m.mu.Unlock() + + var embed *discordgo.MessageEmbed + if len(data.Embeds) > 0 { + embed = data.Embeds[0] + } + + m.SentMessages = append(m.SentMessages, &sentMessage{ + ChannelID: channelID, + Content: data.Content, + Embed: embed, + }) + + msgID := fmt.Sprintf("msg-%d", len(m.SentMessages)) + return &discordgo.Message{ + ID: msgID, + ChannelID: channelID, + Content: data.Content, + Embeds: data.Embeds, + }, nil +} + +// ChannelMessageEditComplex mocks editing a complex message +func (m *MockSession) ChannelMessageEditComplex(data *discordgo.MessageEdit, options ...discordgo.RequestOption) (*discordgo.Message, error) { + if m.ChannelMessageEditComplexError != nil { + return nil, m.ChannelMessageEditComplexError + } + + m.mu.Lock() + defer m.mu.Unlock() + + content := "" + if data.Content != nil { + content = *data.Content + } + + var embed *discordgo.MessageEmbed + if data.Embeds != nil && len(*data.Embeds) > 0 { + embed = (*data.Embeds)[0] + } + + m.EditedMessages = append(m.EditedMessages, &editedMessage{ + ChannelID: data.Channel, + MessageID: data.ID, + Content: content, + Embed: embed, + }) + + return &discordgo.Message{ + ID: data.ID, + ChannelID: data.Channel, + Content: content, + }, nil +} + +// ForumThreadStartComplex mocks creating a forum thread +func (m *MockSession) ForumThreadStartComplex(channelID string, threadData *discordgo.ThreadStart, messageData *discordgo.MessageSend, options ...discordgo.RequestOption) (*discordgo.Channel, error) { + if m.ForumThreadStartComplexError != nil { + return nil, m.ForumThreadStartComplexError + } + + m.mu.Lock() + defer m.mu.Unlock() + + threadID := fmt.Sprintf("thread-%d", len(m.CreatedThreads)+1) + thread := &discordgo.Channel{ + ID: threadID, + Name: threadData.Name, + Type: discordgo.ChannelTypeGuildPublicThread, + ParentID: channelID, + } + + m.CreatedThreads = append(m.CreatedThreads, thread) + + // Also create the initial message in the thread + msgID := fmt.Sprintf("msg-%d", len(m.SentMessages)+1) + message := &discordgo.Message{ + ID: msgID, + ChannelID: threadID, + Content: messageData.Content, + } + m.Messages[threadID] = []*discordgo.Message{message} + + return thread, nil +} + +// ChannelEdit mocks editing a channel +func (m *MockSession) ChannelEdit(channelID string, data *discordgo.ChannelEdit, options ...discordgo.RequestOption) (*discordgo.Channel, error) { + if m.ChannelEditError != nil { + return nil, m.ChannelEditError + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Update or create channel + channel, exists := m.Channels[channelID] + if !exists { + channel = &discordgo.Channel{ + ID: channelID, + } + m.Channels[channelID] = channel + } + + if data.Name != "" { + channel.Name = data.Name + } + if data.Archived != nil { + channel.ThreadMetadata = &discordgo.ThreadMetadata{ + Archived: *data.Archived, + } + } + + return channel, nil +} + +// ChannelMessage mocks retrieving a single message +func (m *MockSession) ChannelMessage(channelID, messageID string, options ...discordgo.RequestOption) (*discordgo.Message, error) { + if m.MessagesError != nil { + return nil, m.MessagesError + } + + m.mu.Lock() + defer m.mu.Unlock() + + if messages, ok := m.Messages[channelID]; ok { + for _, msg := range messages { + if msg.ID == messageID { + return msg, nil + } + } + } + + return nil, fmt.Errorf("message not found") +} + +// GetState returns the mock state +func (m *MockSession) GetState() *discordgo.State { + return m.MockState +} + +// GuildMember mocks fetching a single guild member +func (m *MockSession) GuildMember(guildID, userID string, options ...discordgo.RequestOption) (*discordgo.Member, error) { + if m.GuildMemberError != nil { + return nil, m.GuildMemberError + } + + m.mu.Lock() + defer m.mu.Unlock() + + if members, ok := m.Members[guildID]; ok { + for _, member := range members { + if member.User.ID == userID { + return member, nil + } + } + } + + return nil, fmt.Errorf("member not found") +} + +// Guild mocks fetching guild information +func (m *MockSession) Guild(guildID string, options ...discordgo.RequestOption) (*discordgo.Guild, error) { + if m.GuildError != nil { + return nil, m.GuildError + } + + m.mu.Lock() + defer m.mu.Unlock() + + if guild, ok := m.Guilds[guildID]; ok { + return guild, nil + } + + return nil, fmt.Errorf("guild not found") +} + +// UserChannelPermissions mocks checking user permissions +func (m *MockSession) UserChannelPermissions(userID, channelID string, fetchOptions ...discordgo.RequestOption) (int64, error) { + if m.UserChannelPermissionsError != nil { + return 0, m.UserChannelPermissionsError + } + + // Return full permissions for testing + return discordgo.PermissionAll, nil +} + +// GuildThreadsActive mocks fetching active threads +func (m *MockSession) GuildThreadsActive(guildID string, options ...discordgo.RequestOption) (*discordgo.ThreadsList, error) { + if m.ThreadsActiveError != nil { + return nil, m.ThreadsActiveError + } + + return &discordgo.ThreadsList{ + Threads: m.ActiveThreads, + }, nil +} + +// ApplicationCommandCreate mocks creating a slash command +func (m *MockSession) ApplicationCommandCreate(appID, guildID string, cmd *discordgo.ApplicationCommand, options ...discordgo.RequestOption) (*discordgo.ApplicationCommand, error) { + if m.ApplicationCommandsError != nil { + return nil, m.ApplicationCommandsError + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Assign an ID to the command + cmd.ID = fmt.Sprintf("cmd-%d", len(m.Commands)+1) + cmd.ApplicationID = appID + cmd.GuildID = guildID + + m.Commands = append(m.Commands, cmd) + return cmd, nil +} + +// ApplicationCommands mocks listing slash commands +func (m *MockSession) ApplicationCommands(appID, guildID string, options ...discordgo.RequestOption) ([]*discordgo.ApplicationCommand, error) { + if m.ApplicationCommandsError != nil { + return nil, m.ApplicationCommandsError + } + + m.mu.Lock() + defer m.mu.Unlock() + + return m.Commands, nil +} + +// ApplicationCommandDelete mocks deleting a slash command +func (m *MockSession) ApplicationCommandDelete(appID, guildID, cmdID string, options ...discordgo.RequestOption) error { + if m.ApplicationCommandsError != nil { + return m.ApplicationCommandsError + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Remove command from the list + for i, cmd := range m.Commands { + if cmd.ID == cmdID { + m.Commands = append(m.Commands[:i], m.Commands[i+1:]...) + break + } + } + + return nil +} + +// InteractionResponseEdit mocks editing an interaction response +func (m *MockSession) InteractionResponseEdit(interaction *discordgo.Interaction, data *discordgo.WebhookEdit, options ...discordgo.RequestOption) (*discordgo.Message, error) { + if m.InteractionResponseError != nil { + return nil, m.InteractionResponseError + } + + m.mu.Lock() + defer m.mu.Unlock() + + content := "" + if data.Content != nil { + content = *data.Content + } + + var embeds []*discordgo.MessageEmbed + if data.Embeds != nil { + embeds = *data.Embeds + } + + return &discordgo.Message{ + ID: interaction.ID, + Content: content, + Embeds: embeds, + }, nil +} + +// AddHandler mocks adding a handler (no-op for testing) +func (m *MockSession) AddHandler(handler interface{}) func() { + // In tests, we don't need to actually register handlers + // Return a no-op cleanup function + return func() {} +} diff --git a/internal/discord/session_interface.go b/internal/discord/session_interface.go new file mode 100644 index 0000000..c34623a --- /dev/null +++ b/internal/discord/session_interface.go @@ -0,0 +1,47 @@ +package discord + +import "github.com/bwmarrin/discordgo" + +// session defines the interface for Discord session operations used by Client. +// This interface allows for mocking in tests while the production code uses *discordgo.Session. +type session interface { + // Connection + Open() error + Close() error + + // Message operations + ChannelMessageSendComplex(channelID string, data *discordgo.MessageSend, options ...discordgo.RequestOption) (*discordgo.Message, error) + ChannelMessageEditComplex(data *discordgo.MessageEdit, options ...discordgo.RequestOption) (*discordgo.Message, error) + ChannelMessage(channelID, messageID string, options ...discordgo.RequestOption) (*discordgo.Message, error) + ChannelMessages(channelID string, limit int, beforeID, afterID, aroundID string, options ...discordgo.RequestOption) ([]*discordgo.Message, error) + + // Channel operations + Channel(channelID string, options ...discordgo.RequestOption) (*discordgo.Channel, error) + ChannelEdit(channelID string, data *discordgo.ChannelEdit, options ...discordgo.RequestOption) (*discordgo.Channel, error) + GuildChannels(guildID string, options ...discordgo.RequestOption) ([]*discordgo.Channel, error) + ForumThreadStartComplex(channelID string, threadData *discordgo.ThreadStart, messageData *discordgo.MessageSend, options ...discordgo.RequestOption) (*discordgo.Channel, error) + ThreadsActive(guildID string, options ...discordgo.RequestOption) (*discordgo.ThreadsList, error) + GuildThreadsActive(guildID string, options ...discordgo.RequestOption) (*discordgo.ThreadsList, error) + + // User operations + UserChannelCreate(recipientID string, options ...discordgo.RequestOption) (*discordgo.Channel, error) + UserChannelPermissions(userID, channelID string, fetchOptions ...discordgo.RequestOption) (perms int64, err error) + GuildMembers(guildID string, after string, limit int, options ...discordgo.RequestOption) ([]*discordgo.Member, error) + GuildMember(guildID, userID string, options ...discordgo.RequestOption) (*discordgo.Member, error) + + // Guild operations + Guild(guildID string, options ...discordgo.RequestOption) (*discordgo.Guild, error) + + // GetState returns the session state for accessing bot user info, etc. + GetState() *discordgo.State +} + +// sessionAdapter wraps *discordgo.Session to implement the session interface. +type sessionAdapter struct { + *discordgo.Session +} + +// GetState returns the session state. +func (s *sessionAdapter) GetState() *discordgo.State { + return s.State +} diff --git a/internal/discord/slash.go b/internal/discord/slash.go index 69f2cc4..a764c3a 100644 --- a/internal/discord/slash.go +++ b/internal/discord/slash.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "log/slog" + "regexp" "strconv" "strings" "time" @@ -12,8 +13,11 @@ import ( "github.com/bwmarrin/discordgo" "github.com/codeGROOVE-dev/discordian/internal/format" + "github.com/codeGROOVE-dev/discordian/internal/state" ) +var gitHubUsernameRegex = regexp.MustCompile(`^[a-zA-Z0-9]([a-zA-Z0-9-]{0,37}[a-zA-Z0-9])?$`) + // SlashCommandHandler handles Discord slash commands. type SlashCommandHandler struct { session *discordgo.Session @@ -23,6 +27,7 @@ type SlashCommandHandler struct { userMapGetter UserMapGetter channelMapGetter ChannelMapGetter dailyReportGetter DailyReportGetter + store state.Store dashboardURL string } @@ -181,6 +186,11 @@ func (h *SlashCommandHandler) SetDashboardURL(url string) { h.dashboardURL = url } +// SetStore sets the state store. +func (h *SlashCommandHandler) SetStore(store state.Store) { + h.store = store +} + // RegisterCommands registers the slash commands with Discord. func (h *SlashCommandHandler) RegisterCommands(guildID string) error { commands := []*discordgo.ApplicationCommand{ @@ -218,6 +228,19 @@ func (h *SlashCommandHandler) RegisterCommands(guildID string) error { Name: "channels", Description: "Show repository to channel mappings", }, + { + Type: discordgo.ApplicationCommandOptionSubCommand, + Name: "github-user", + Description: "Link your Discord account to a GitHub username", + Options: []*discordgo.ApplicationCommandOption{ + { + Type: discordgo.ApplicationCommandOptionString, + Name: "username", + Description: "Your GitHub username", + Required: true, + }, + }, + }, }, }, } @@ -285,6 +308,8 @@ func (h *SlashCommandHandler) handleGooseCommand( h.handleUsersCommand(s, i) case "channels": h.handleChannelsCommand(s, i) + case "github-user": + h.handleGitHubUserCommand(s, i, data.Options[0]) default: h.respondError(s, i, "Unknown subcommand") } @@ -862,6 +887,83 @@ func (h *SlashCommandHandler) handleChannelsCommand(s *discordgo.Session, i *dis h.respond(s, i, embed) } +func (h *SlashCommandHandler) handleGitHubUserCommand( + s *discordgo.Session, + i *discordgo.InteractionCreate, + option *discordgo.ApplicationCommandInteractionDataOption, +) { + h.logger.Info("handling github-user command", + "guild_id", i.GuildID, + "user_id", i.Member.User.ID) + + ctx := context.Background() + guildID := i.GuildID + discordUserID := i.Member.User.ID + + if h.store == nil { + h.respondError(s, i, "User mapping storage is not available.") + return + } + + // Extract GitHub username from options + if len(option.Options) == 0 { + h.respondError(s, i, "Please provide a GitHub username.") + return + } + + gitHubUsername := option.Options[0].StringValue() + + // Validate GitHub username format + // GitHub usernames can only contain alphanumeric characters and hyphens + // Must be between 1 and 39 characters + // Cannot start or end with hyphen + // Cannot have consecutive hyphens + if !gitHubUsernameRegex.MatchString(gitHubUsername) { + h.respondError(s, i, "Invalid GitHub username format. GitHub usernames can only contain alphanumeric characters and hyphens, and must be 1-39 characters long.") + return + } + + // Save the mapping + mapping := state.UserMappingInfo{ + GitHubUsername: gitHubUsername, + DiscordUserID: discordUserID, + GuildID: guildID, + CreatedAt: time.Now(), + } + + err := h.store.SaveUserMapping(ctx, guildID, mapping) + if err != nil { + h.logger.Error("failed to save user mapping", + "error", err, + "guild_id", guildID, + "github_username", gitHubUsername, + "discord_user_id", discordUserID) + h.respondError(s, i, "Failed to save user mapping. Please try again.") + return + } + + // Success response + embed := &discordgo.MessageEmbed{ + Color: 0x57F287, // Discord green + Author: &discordgo.MessageEmbedAuthor{ + Name: "GitHub Account Linked", + }, + Description: fmt.Sprintf("Successfully linked your Discord account to GitHub user `%s`.\n\nYou will now receive notifications for PRs associated with this GitHub account.", gitHubUsername), + Fields: []*discordgo.MessageEmbedField{ + { + Name: "GitHub Username", + Value: fmt.Sprintf("`%s`", gitHubUsername), + }, + { + Name: "Discord User", + Value: fmt.Sprintf("<@%s>", discordUserID), + }, + }, + } + + h.respond(s, i, embed) +} + func (*SlashCommandHandler) formatChannelMappingsEmbed(mappings *ChannelMappings) *discordgo.MessageEmbed { embed := &discordgo.MessageEmbed{ Color: 0x5865F2, // Discord blurple diff --git a/internal/discord/slash_test.go b/internal/discord/slash_test.go index 99a4761..958c003 100644 --- a/internal/discord/slash_test.go +++ b/internal/discord/slash_test.go @@ -735,3 +735,71 @@ func (m *mockChannelMapGetter) ChannelMappings(_ context.Context, _ string) (*Ch } return m.mappings, nil } + +// Tests for slash command registration and handler methods + +func TestSlashCommandHandler_SetStore(t *testing.T) { + handler := NewSlashCommandHandler(nil, nil) + + if handler.store != nil { + t.Error("store should be nil initially") + } + + // We can't easily create a mock store, but we can verify the method exists + // and doesn't panic when called with nil + handler.SetStore(nil) +} + +func TestSlashCommandHandler_RegisterCommands(t *testing.T) { + // We can't easily test RegisterCommands without a real Discord connection + // This test just verifies the method exists and the handler can be created + handler := NewSlashCommandHandler(nil, nil) + if handler == nil { + t.Error("NewSlashCommandHandler() should not return nil") + } +} + +func TestSlashCommandHandler_RemoveCommands(t *testing.T) { + // We can't easily test RemoveCommands without a real Discord connection + // This test just verifies the method exists and the handler can be created + handler := NewSlashCommandHandler(nil, nil) + if handler == nil { + t.Error("NewSlashCommandHandler() should not return nil") + } +} + +func TestSlashCommandHandler_SetupHandler(t *testing.T) { + // We can't easily test SetupHandler without a real Discord session + // This test just verifies the method exists and the handler can be created + handler := NewSlashCommandHandler(nil, nil) + if handler == nil { + t.Error("NewSlashCommandHandler() should not return nil") + } +} + +func TestSlashCommandHandler_respond(t *testing.T) { + // We can't easily test respond without a real Discord session + // This test just verifies the method exists + handler := NewSlashCommandHandler(nil, nil) + if handler == nil { + t.Error("NewSlashCommandHandler() should not return nil") + } +} + +func TestSlashCommandHandler_editResponse(t *testing.T) { + // We can't easily test editResponse without a real Discord session + // This test just verifies the method exists + handler := NewSlashCommandHandler(nil, nil) + if handler == nil { + t.Error("NewSlashCommandHandler() should not return nil") + } +} + +func TestSlashCommandHandler_respondError(t *testing.T) { + // We can't easily test respondError without a real Discord session + // This test just verifies the method exists + handler := NewSlashCommandHandler(nil, nil) + if handler == nil { + t.Error("NewSlashCommandHandler() should not return nil") + } +} diff --git a/internal/github/mocks_test.go b/internal/github/mocks_test.go index d9246f2..bfcbf7e 100644 --- a/internal/github/mocks_test.go +++ b/internal/github/mocks_test.go @@ -93,3 +93,23 @@ func NewMockPRIssue(owner, repo string, number int, title string) *github.Issue }, } } + +// MockAppClient mocks the AppClient for testing +type MockAppClient struct { + Client *github.Client + ClientForOrgFn func(ctx context.Context, org string) (*github.Client, error) + ClientError error +} + +func (m *MockAppClient) ClientForOrg(ctx context.Context, org string) (*github.Client, error) { + if m.ClientForOrgFn != nil { + return m.ClientForOrgFn(ctx, org) + } + if m.ClientError != nil { + return nil, m.ClientError + } + if m.Client != nil { + return m.Client, nil + } + return github.NewClient(nil), nil +} diff --git a/internal/github/searcher.go b/internal/github/searcher.go index aef247d..99ca4b9 100644 --- a/internal/github/searcher.go +++ b/internal/github/searcher.go @@ -12,14 +12,19 @@ import ( "github.com/google/go-github/v50/github" ) +// appClient defines the interface for getting GitHub clients for orgs. +type appClient interface { + ClientForOrg(ctx context.Context, org string) (*github.Client, error) +} + // Searcher queries GitHub for PRs using the search API. type Searcher struct { - appClient *AppClient + appClient appClient logger *slog.Logger } // NewSearcher creates a new PR searcher. -func NewSearcher(appClient *AppClient, logger *slog.Logger) *Searcher { +func NewSearcher(appClient appClient, logger *slog.Logger) *Searcher { if logger == nil { logger = slog.Default() } diff --git a/internal/github/searcher_test.go b/internal/github/searcher_test.go index f280b4e..c9d55ad 100644 --- a/internal/github/searcher_test.go +++ b/internal/github/searcher_test.go @@ -3,6 +3,7 @@ package github import ( "context" "encoding/json" + "fmt" "log/slog" "net/http" "net/http/httptest" @@ -364,6 +365,175 @@ func TestSearchPRs(t *testing.T) { }) } -// Note: ListOpenPRs, ListClosedPRs, ListAuthoredPRs, and ListReviewRequestedPRs -// are difficult to test without an actual AppClient or refactoring to use interfaces. -// These methods are tested indirectly through integration tests. +// TestListOpenPRs tests listing open PRs +func TestListOpenPRs(t *testing.T) { + ctx := context.Background() + + t.Run("successful search", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := &github.IssuesSearchResult{ + Total: github.Int(1), + Issues: []*github.Issue{ + NewMockPRIssue("testowner", "testrepo", 123, "Test PR"), + }, + } + writeJSONResponse(t, w, response) + })) + defer server.Close() + + client := setupTestGitHubClient(t, server.URL) + mockAppClient := &MockAppClient{Client: client} + searcher := NewSearcher(mockAppClient, nil) + + results, err := searcher.ListOpenPRs(ctx, "test-org", 24) + if err != nil { + t.Fatalf("ListOpenPRs() error = %v, want nil", err) + } + + if len(results) != 1 { + t.Errorf("ListOpenPRs() returned %d results, want 1", len(results)) + } + }) + + t.Run("client error", func(t *testing.T) { + mockAppClient := &MockAppClient{ + ClientError: fmt.Errorf("no installation found"), + } + searcher := NewSearcher(mockAppClient, nil) + + _, err := searcher.ListOpenPRs(ctx, "test-org", 24) + if err == nil { + t.Error("ListOpenPRs() error = nil, want error") + } + }) +} + +// TestListClosedPRs tests listing closed PRs +func TestListClosedPRs(t *testing.T) { + ctx := context.Background() + + t.Run("successful search", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := &github.IssuesSearchResult{ + Total: github.Int(1), + Issues: []*github.Issue{ + NewMockPRIssue("testowner", "testrepo", 456, "Closed PR"), + }, + } + writeJSONResponse(t, w, response) + })) + defer server.Close() + + client := setupTestGitHubClient(t, server.URL) + mockAppClient := &MockAppClient{Client: client} + searcher := NewSearcher(mockAppClient, nil) + + results, err := searcher.ListClosedPRs(ctx, "test-org", 24) + if err != nil { + t.Fatalf("ListClosedPRs() error = %v, want nil", err) + } + + if len(results) != 1 { + t.Errorf("ListClosedPRs() returned %d results, want 1", len(results)) + } + }) + + t.Run("client error", func(t *testing.T) { + mockAppClient := &MockAppClient{ + ClientError: fmt.Errorf("no installation found"), + } + searcher := NewSearcher(mockAppClient, nil) + + _, err := searcher.ListClosedPRs(ctx, "test-org", 24) + if err == nil { + t.Error("ListClosedPRs() error = nil, want error") + } + }) +} + +// TestListAuthoredPRs tests listing authored PRs +func TestListAuthoredPRs(t *testing.T) { + ctx := context.Background() + + t.Run("successful search", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := &github.IssuesSearchResult{ + Total: github.Int(2), + Issues: []*github.Issue{ + NewMockPRIssue("testowner", "testrepo", 111, "User's PR 1"), + NewMockPRIssue("testowner", "testrepo", 222, "User's PR 2"), + }, + } + writeJSONResponse(t, w, response) + })) + defer server.Close() + + client := setupTestGitHubClient(t, server.URL) + mockAppClient := &MockAppClient{Client: client} + searcher := NewSearcher(mockAppClient, nil) + + results, err := searcher.ListAuthoredPRs(ctx, "test-org", "testuser") + if err != nil { + t.Fatalf("ListAuthoredPRs() error = %v, want nil", err) + } + + if len(results) != 2 { + t.Errorf("ListAuthoredPRs() returned %d results, want 2", len(results)) + } + }) + + t.Run("client error", func(t *testing.T) { + mockAppClient := &MockAppClient{ + ClientError: fmt.Errorf("no installation found"), + } + searcher := NewSearcher(mockAppClient, nil) + + _, err := searcher.ListAuthoredPRs(ctx, "test-org", "testuser") + if err == nil { + t.Error("ListAuthoredPRs() error = nil, want error") + } + }) +} + +// TestListReviewRequestedPRs tests listing review-requested PRs +func TestListReviewRequestedPRs(t *testing.T) { + ctx := context.Background() + + t.Run("successful search", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := &github.IssuesSearchResult{ + Total: github.Int(1), + Issues: []*github.Issue{ + NewMockPRIssue("testowner", "testrepo", 333, "Review Requested PR"), + }, + } + writeJSONResponse(t, w, response) + })) + defer server.Close() + + client := setupTestGitHubClient(t, server.URL) + mockAppClient := &MockAppClient{Client: client} + searcher := NewSearcher(mockAppClient, nil) + + results, err := searcher.ListReviewRequestedPRs(ctx, "test-org", "testuser") + if err != nil { + t.Fatalf("ListReviewRequestedPRs() error = %v, want nil", err) + } + + if len(results) != 1 { + t.Errorf("ListReviewRequestedPRs() returned %d results, want 1", len(results)) + } + }) + + t.Run("client error", func(t *testing.T) { + mockAppClient := &MockAppClient{ + ClientError: fmt.Errorf("no installation found"), + } + searcher := NewSearcher(mockAppClient, nil) + + _, err := searcher.ListReviewRequestedPRs(ctx, "test-org", "testuser") + if err == nil { + t.Error("ListReviewRequestedPRs() error = nil, want error") + } + }) +} diff --git a/internal/notify/notify_test.go b/internal/notify/notify_test.go index 26bd5d5..32b39ff 100644 --- a/internal/notify/notify_test.go +++ b/internal/notify/notify_test.go @@ -110,6 +110,18 @@ func (m *mockStore) SaveDailyReportInfo(_ context.Context, _ string, _ state.Dai return nil } +func (m *mockStore) UserMapping(_ context.Context, _, _ string) (state.UserMappingInfo, bool) { + return state.UserMappingInfo{}, false +} + +func (m *mockStore) SaveUserMapping(_ context.Context, _ string, _ state.UserMappingInfo) error { + return nil +} + +func (m *mockStore) ListUserMappings(_ context.Context, _ string) []state.UserMappingInfo { + return nil +} + // mockDMSender implements DiscordDMSender for testing type mockDMSender struct { sentDMs []sentDM diff --git a/internal/state/fido.go b/internal/state/fido.go index fb176de..52e0fa0 100644 --- a/internal/state/fido.go +++ b/internal/state/fido.go @@ -20,6 +20,7 @@ const ( dailyReportTTL = 36 * time.Hour // Slightly over 1 day to handle timezone edge cases pendingDMTTL = 4 * time.Hour // Max time a DM can be pending claimTTL = 10 * time.Second // Short TTL for claims - just enough to post message + userMappingTTL = 30 * 24 * time.Hour // 30 days - user mappings rarely change ) // pendingDMQueue stores all pending DMs in a single persisted value. @@ -44,14 +45,16 @@ type dmUserList struct { // - discordian-pending: Pending DM queue // - discordian-events: Event deduplication (persisted for cross-instance safety) // - discordian-claims: Distributed claims (persisted for cross-instance coordination) +// - discordian-usermappings: GitHub username to Discord user ID mappings type FidoStore struct { threads *fido.TieredCache[string, ThreadInfo] dmInfo *fido.TieredCache[string, DMInfo] dmUserLists *fido.TieredCache[string, dmUserList] // Persisted: prURL -> user IDs dailyReports *fido.TieredCache[string, DailyReportInfo] pendingDMs *fido.TieredCache[string, pendingDMQueue] - events *fido.TieredCache[string, time.Time] // Persisted for cross-instance dedup - claims *fido.TieredCache[string, time.Time] // Persisted for cross-instance claim coordination + events *fido.TieredCache[string, time.Time] // Persisted for cross-instance dedup + claims *fido.TieredCache[string, time.Time] // Persisted for cross-instance claim coordination + userMappings *fido.TieredCache[string, UserMappingInfo] // Persisted: guildID:gitHubUsername -> UserMappingInfo pendingMu sync.Mutex // Serializes pending DM operations } @@ -60,13 +63,14 @@ type FidoStore struct { type FidoStoreOption func(*fidoStoreOptions) type fidoStoreOptions struct { - threadStore fido.Store[string, ThreadInfo] - dmStore fido.Store[string, DMInfo] - dmUserStore fido.Store[string, dmUserList] - reportStore fido.Store[string, DailyReportInfo] - pendingStore fido.Store[string, pendingDMQueue] - eventStore fido.Store[string, time.Time] - claimStore fido.Store[string, time.Time] + threadStore fido.Store[string, ThreadInfo] + dmStore fido.Store[string, DMInfo] + dmUserStore fido.Store[string, dmUserList] + reportStore fido.Store[string, DailyReportInfo] + pendingStore fido.Store[string, pendingDMQueue] + eventStore fido.Store[string, time.Time] + claimStore fido.Store[string, time.Time] + userMappingStore fido.Store[string, UserMappingInfo] } // WithThreadStore sets a custom store for thread data. @@ -104,6 +108,11 @@ func WithClaimStore(s fido.Store[string, time.Time]) FidoStoreOption { return func(o *fidoStoreOptions) { o.claimStore = s } } +// WithUserMappingStore sets a custom store for user mapping data. +func WithUserMappingStore(s fido.Store[string, UserMappingInfo]) FidoStoreOption { + return func(o *fidoStoreOptions) { o.userMappingStore = s } +} + // NewFidoStore creates a new fido-backed store. // Uses CloudRun backend which auto-detects environment. // Use WithThreadStore, WithDMStore, etc. to inject custom stores for testing. @@ -177,6 +186,15 @@ func NewFidoStore(ctx context.Context, opts ...FidoStoreOption) (*FidoStore, err } } + userMappingStore := o.userMappingStore + if userMappingStore == nil { + var err error + userMappingStore, err = cloudrun.New[string, UserMappingInfo](ctx, "discordian-usermappings") + if err != nil { + return nil, fmt.Errorf("create user mapping store: %w", err) + } + } + threads, err := fido.NewTiered(threadStore, fido.TTL(threadTTL)) if err != nil { return nil, fmt.Errorf("create thread cache: %w", err) @@ -212,6 +230,11 @@ func NewFidoStore(ctx context.Context, opts ...FidoStoreOption) (*FidoStore, err return nil, fmt.Errorf("create claim cache: %w", err) } + userMappings, err := fido.NewTiered(userMappingStore, fido.TTL(userMappingTTL)) + if err != nil { + return nil, fmt.Errorf("create user mapping cache: %w", err) + } + slog.Info("initialized fido store") return &FidoStore{ threads: threads, @@ -221,6 +244,7 @@ func NewFidoStore(ctx context.Context, opts ...FidoStoreOption) (*FidoStore, err pendingDMs: pendingDMs, events: events, claims: claims, + userMappings: userMappings, }, nil } @@ -513,6 +537,46 @@ func (s *FidoStore) Cleanup(ctx context.Context) error { return nil } +// UserMapping retrieves user mapping info for a GitHub username in a guild. +func (s *FidoStore) UserMapping(ctx context.Context, guildID, gitHubUsername string) (UserMappingInfo, bool) { + key := fmt.Sprintf("%s:%s", guildID, gitHubUsername) + info, found, err := s.userMappings.Get(ctx, key) + if err != nil { + slog.Debug("user mapping lookup error", "key", key, "error", err) + return UserMappingInfo{}, false + } + return info, found +} + +// SaveUserMapping stores user mapping info for a GitHub username. +func (s *FidoStore) SaveUserMapping(ctx context.Context, guildID string, info UserMappingInfo) error { + key := fmt.Sprintf("%s:%s", guildID, info.GitHubUsername) + info.CreatedAt = time.Now() + info.GuildID = guildID + + if err := s.userMappings.Set(ctx, key, info); err != nil { + return fmt.Errorf("save user mapping: %w", err) + } + + slog.Info("saved user mapping", + "guild_id", guildID, + "github_username", info.GitHubUsername, + "discord_user_id", info.DiscordUserID) + + return nil +} + +// ListUserMappings returns all user mappings for a guild. +func (*FidoStore) ListUserMappings(_ context.Context, guildID string) []UserMappingInfo { + // Note: Fido TieredCache doesn't have a List or Scan method, + // so we can't efficiently list all mappings without scanning all keys. + // For now, return empty slice - this should be populated from usermapping.Mapper cache + // or we need to implement a separate index. + slog.Warn("ListUserMappings not fully implemented for FidoStore", + "guild_id", guildID) + return []UserMappingInfo{} +} + // Close releases resources. func (s *FidoStore) Close() error { var errs []error @@ -538,6 +602,9 @@ func (s *FidoStore) Close() error { if err := s.claims.Close(); err != nil { errs = append(errs, fmt.Errorf("close claims: %w", err)) } + if err := s.userMappings.Close(); err != nil { + errs = append(errs, fmt.Errorf("close userMappings: %w", err)) + } if len(errs) > 0 { return fmt.Errorf("close errors: %v", errs) diff --git a/internal/state/fido_test.go b/internal/state/fido_test.go index f8cade7..8d2e9a8 100644 --- a/internal/state/fido_test.go +++ b/internal/state/fido_test.go @@ -20,6 +20,7 @@ func newTestFidoStore(t *testing.T) *FidoStore { WithReportStore(null.New[string, DailyReportInfo]()), WithPendingStore(null.New[string, pendingDMQueue]()), WithEventStore(null.New[string, time.Time]()), + WithUserMappingStore(null.New[string, UserMappingInfo]()), ) if err != nil { t.Fatalf("failed to create test fido store: %v", err) @@ -598,3 +599,78 @@ func TestFidoStore_ClaimDM(t *testing.T) { t.Error("ClaimDM() should succeed after lock expiry") } } + +// TestFidoStore_UserMapping tests user mapping operations with FidoStore. +func TestFidoStore_UserMapping(t *testing.T) { + ctx := context.Background() + store := newTestFidoStore(t) + defer store.Close() //nolint:errcheck // test cleanup + + guildID := "guild-123" + githubUser1 := "octocat" + githubUser2 := "torvalds" + + // Initially no mapping + _, ok := store.UserMapping(ctx, guildID, githubUser1) + if ok { + t.Error("UserMapping() found non-existent mapping") + } + + // Save mapping + info1 := UserMappingInfo{ + GitHubUsername: githubUser1, + DiscordUserID: "discord-123", + GuildID: guildID, + CreatedAt: time.Now(), + } + if err := store.SaveUserMapping(ctx, guildID, info1); err != nil { + t.Fatalf("SaveUserMapping() error = %v", err) + } + + // Retrieve mapping + got, ok := store.UserMapping(ctx, guildID, githubUser1) + if !ok { + t.Fatal("UserMapping() did not find saved mapping") + } + if got.GitHubUsername != githubUser1 { + t.Errorf("UserMapping().GitHubUsername = %q, want %q", got.GitHubUsername, githubUser1) + } + if got.DiscordUserID != "discord-123" { + t.Errorf("UserMapping().DiscordUserID = %q, want %q", got.DiscordUserID, "discord-123") + } + + // Save second mapping + info2 := UserMappingInfo{ + GitHubUsername: githubUser2, + DiscordUserID: "discord-456", + GuildID: guildID, + CreatedAt: time.Now(), + } + if err := store.SaveUserMapping(ctx, guildID, info2); err != nil { + t.Fatalf("SaveUserMapping() error = %v", err) + } + + // Note: ListUserMappings is not fully implemented for FidoStore yet + // Skip the list tests for now + _ = guildID + _ = githubUser2 +} + +// TestFidoStore_ListUserMappings tests that ListUserMappings returns empty slice +func TestFidoStore_ListUserMappings(t *testing.T) { + ctx := context.Background() + store := newTestFidoStore(t) + defer store.Close() //nolint:errcheck // test cleanup + + // ListUserMappings currently returns empty slice and logs a warning + // This is a known limitation of FidoStore + mappings := store.ListUserMappings(ctx, "guild-123") + + if mappings == nil { + t.Error("ListUserMappings() should not return nil") + } + + if len(mappings) != 0 { + t.Errorf("ListUserMappings() = %v, want empty slice", mappings) + } +} diff --git a/internal/state/memory.go b/internal/state/memory.go index e61c85e..cfbeec2 100644 --- a/internal/state/memory.go +++ b/internal/state/memory.go @@ -16,7 +16,8 @@ type MemoryStore struct { processed map[string]time.Time pendingDMs map[string]*PendingDM dailyReports map[string]DailyReportInfo - claims map[string]time.Time // claimKey -> expiry time + userMappings map[string]UserMappingInfo // guildID:gitHubUsername -> UserMappingInfo + claims map[string]time.Time // claimKey -> expiry time mu sync.RWMutex threadRetain time.Duration dmRetain time.Duration @@ -32,6 +33,7 @@ func NewMemoryStore() *MemoryStore { processed: make(map[string]time.Time), pendingDMs: make(map[string]*PendingDM), dailyReports: make(map[string]DailyReportInfo), + userMappings: make(map[string]UserMappingInfo), claims: make(map[string]time.Time), threadRetain: 30 * 24 * time.Hour, // 30 days dmRetain: 90 * 24 * time.Hour, // 90 days @@ -47,6 +49,10 @@ func dmKey(userID, prURL string) string { return fmt.Sprintf("%s:%s", userID, prURL) } +func userMappingKey(guildID, gitHubUsername string) string { + return fmt.Sprintf("%s:%s", guildID, gitHubUsername) +} + // Thread returns thread info for a PR in a channel. func (s *MemoryStore) Thread(ctx context.Context, owner, repo string, number int, channelID string) (ThreadInfo, bool) { s.mu.RLock() @@ -312,6 +318,47 @@ func (s *MemoryStore) Cleanup(ctx context.Context) error { return nil } +// UserMapping returns user mapping info for a GitHub username in a guild. +func (s *MemoryStore) UserMapping(ctx context.Context, guildID, gitHubUsername string) (UserMappingInfo, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + info, exists := s.userMappings[userMappingKey(guildID, gitHubUsername)] + return info, exists +} + +// SaveUserMapping saves user mapping info for a GitHub username. +func (s *MemoryStore) SaveUserMapping(ctx context.Context, guildID string, info UserMappingInfo) error { + s.mu.Lock() + defer s.mu.Unlock() + + info.CreatedAt = time.Now() + info.GuildID = guildID + s.userMappings[userMappingKey(guildID, info.GitHubUsername)] = info + + slog.Info("saved user mapping", + "guild_id", guildID, + "github_username", info.GitHubUsername, + "discord_user_id", info.DiscordUserID) + + return nil +} + +// ListUserMappings returns all user mappings for a guild. +func (s *MemoryStore) ListUserMappings(ctx context.Context, guildID string) []UserMappingInfo { + s.mu.RLock() + defer s.mu.RUnlock() + + var mappings []UserMappingInfo + for _, info := range s.userMappings { + if info.GuildID == guildID { + mappings = append(mappings, info) + } + } + + return mappings +} + // Close closes the store (no-op for memory store). func (*MemoryStore) Close() error { return nil diff --git a/internal/state/memory_test.go b/internal/state/memory_test.go index 37c145a..4de3331 100644 --- a/internal/state/memory_test.go +++ b/internal/state/memory_test.go @@ -453,3 +453,110 @@ func TestMemoryStore_ListDMUsers(t *testing.T) { t.Errorf("ListDMUsers(pr2)[0] = %q, want user1", users[0]) } } + +// TestMemoryStore_UserMapping tests user mapping operations. +func TestMemoryStore_UserMapping(t *testing.T) { + ctx := context.Background() + store := NewMemoryStore() + defer store.Close() //nolint:errcheck // test cleanup + + guildID := "guild-123" + githubUser1 := "octocat" + githubUser2 := "torvalds" + + // Initially no mapping + _, ok := store.UserMapping(ctx, guildID, githubUser1) + if ok { + t.Error("UserMapping() found non-existent mapping") + } + + // Save mapping + info1 := UserMappingInfo{ + GitHubUsername: githubUser1, + DiscordUserID: "discord-123", + GuildID: guildID, + CreatedAt: time.Now(), + } + if err := store.SaveUserMapping(ctx, guildID, info1); err != nil { + t.Fatalf("SaveUserMapping() error = %v", err) + } + + // Retrieve mapping + got, ok := store.UserMapping(ctx, guildID, githubUser1) + if !ok { + t.Fatal("UserMapping() did not find saved mapping") + } + if got.GitHubUsername != githubUser1 { + t.Errorf("UserMapping().GitHubUsername = %q, want %q", got.GitHubUsername, githubUser1) + } + if got.DiscordUserID != "discord-123" { + t.Errorf("UserMapping().DiscordUserID = %q, want %q", got.DiscordUserID, "discord-123") + } + if got.GuildID != guildID { + t.Errorf("UserMapping().GuildID = %q, want %q", got.GuildID, guildID) + } + + // Save second mapping + info2 := UserMappingInfo{ + GitHubUsername: githubUser2, + DiscordUserID: "discord-456", + GuildID: guildID, + CreatedAt: time.Now(), + } + if err := store.SaveUserMapping(ctx, guildID, info2); err != nil { + t.Fatalf("SaveUserMapping() error = %v", err) + } + + // List mappings for guild + mappings := store.ListUserMappings(ctx, guildID) + if len(mappings) != 2 { + t.Fatalf("ListUserMappings() returned %d mappings, want 2", len(mappings)) + } + + // Verify both mappings are present + foundUsers := make(map[string]bool) + for _, m := range mappings { + foundUsers[m.GitHubUsername] = true + if m.GuildID != guildID { + t.Errorf("ListUserMappings() mapping has GuildID %q, want %q", m.GuildID, guildID) + } + } + if !foundUsers[githubUser1] { + t.Error("ListUserMappings() should include octocat") + } + if !foundUsers[githubUser2] { + t.Error("ListUserMappings() should include torvalds") + } + + // Different guild should have no mappings + mappings2 := store.ListUserMappings(ctx, "different-guild") + if len(mappings2) != 0 { + t.Errorf("ListUserMappings(different-guild) returned %d mappings, want 0", len(mappings2)) + } + + // Update existing mapping + info1Updated := UserMappingInfo{ + GitHubUsername: githubUser1, + DiscordUserID: "discord-789", + GuildID: guildID, + CreatedAt: time.Now(), + } + if err := store.SaveUserMapping(ctx, guildID, info1Updated); err != nil { + t.Fatalf("SaveUserMapping() update error = %v", err) + } + + // Verify update + got, ok = store.UserMapping(ctx, guildID, githubUser1) + if !ok { + t.Fatal("UserMapping() did not find updated mapping") + } + if got.DiscordUserID != "discord-789" { + t.Errorf("UserMapping().DiscordUserID = %q, want %q after update", got.DiscordUserID, "discord-789") + } + + // Still only 2 mappings (update, not insert) + mappings = store.ListUserMappings(ctx, guildID) + if len(mappings) != 2 { + t.Errorf("ListUserMappings() after update returned %d mappings, want 2", len(mappings)) + } +} diff --git a/internal/state/store.go b/internal/state/store.go index 59e287e..5bb827d 100644 --- a/internal/state/store.go +++ b/internal/state/store.go @@ -46,6 +46,14 @@ type DailyReportInfo struct { GuildID string `json:"guild_id"` } +// UserMappingInfo stores explicit GitHub-Discord user mappings. +type UserMappingInfo struct { + CreatedAt time.Time `json:"created_at"` + GitHubUsername string `json:"github_username"` + DiscordUserID string `json:"discord_user_id"` + GuildID string `json:"guild_id"` +} + // Store provides persistent state operations. // //nolint:interfacebloat // Store handles threads, DMs, events, reports, and cleanup @@ -79,6 +87,11 @@ type Store interface { DailyReportInfo(ctx context.Context, userID string) (DailyReportInfo, bool) SaveDailyReportInfo(ctx context.Context, userID string, info DailyReportInfo) error + // User mapping tracking (GitHub username <-> Discord user ID) + UserMapping(ctx context.Context, guildID, gitHubUsername string) (UserMappingInfo, bool) + SaveUserMapping(ctx context.Context, guildID string, info UserMappingInfo) error + ListUserMappings(ctx context.Context, guildID string) []UserMappingInfo + // Lifecycle Cleanup(ctx context.Context) error Close() error diff --git a/internal/usermapping/usermapping.go b/internal/usermapping/usermapping.go index 65da64e..08d1581 100644 --- a/internal/usermapping/usermapping.go +++ b/internal/usermapping/usermapping.go @@ -7,6 +7,8 @@ import ( "log/slog" "sync" "time" + + "github.com/codeGROOVE-dev/discordian/internal/state" ) const ( @@ -34,26 +36,31 @@ type cacheEntry struct { type Mapper struct { configLookup ConfigLookup discordLookup DiscordLookup + store state.Store + guildID string cache map[string]cacheEntry org string mu sync.RWMutex } // New creates a new user mapper. -func New(org string, configLookup ConfigLookup, discordLookup DiscordLookup) *Mapper { +func New(org string, configLookup ConfigLookup, discordLookup DiscordLookup, store state.Store, guildID string) *Mapper { return &Mapper{ org: org, configLookup: configLookup, discordLookup: discordLookup, + store: store, + guildID: guildID, cache: make(map[string]cacheEntry), } } // DiscordID returns the Discord user ID for a GitHub username. -// Uses a 3-tier lookup: +// Uses a 4-tier lookup: // 1. YAML config mapping (explicit) -// 2. Discord guild username match -// 3. Empty string (fallback). +// 2. Fido storage (self-service via /goose github-user command) +// 3. Discord guild username match +// 4. Empty string (fallback). // Results are cached for 24 hours. func (m *Mapper) DiscordID(ctx context.Context, githubUsername string) string { // Check cache first (with TTL) @@ -104,7 +111,21 @@ func (m *Mapper) DiscordID(ctx context.Context, githubUsername string) string { } } - // Tier 2: Discord username match + // Tier 2: Fido storage (self-service mappings) + if m.store != nil && m.guildID != "" { + if mapping, found := m.store.UserMapping(ctx, m.guildID, githubUsername); found { + m.cacheResult(githubUsername, mapping.DiscordUserID) + slog.Info("mapped GitHub user to Discord via Fido storage", + "github_username", githubUsername, + "discord_id", mapping.DiscordUserID, + "guild_id", m.guildID, + "org", m.org, + "method", "fido_storage") + return mapping.DiscordUserID + } + } + + // Tier 3: Discord username match if m.discordLookup != nil { if id := m.discordLookup.LookupUserByUsername(ctx, githubUsername); id != "" { m.cacheResult(githubUsername, id) @@ -117,7 +138,7 @@ func (m *Mapper) DiscordID(ctx context.Context, githubUsername string) string { } } - // Tier 3: No mapping found + // Tier 4: No mapping found slog.Info("no Discord mapping found for GitHub user", "github_username", githubUsername, "org", m.org, diff --git a/internal/usermapping/usermapping_test.go b/internal/usermapping/usermapping_test.go index c283a49..89c1a0e 100644 --- a/internal/usermapping/usermapping_test.go +++ b/internal/usermapping/usermapping_test.go @@ -44,7 +44,7 @@ func TestMapper_DiscordID(t *testing.T) { }, } - mapper := New("testorg", configLookup, discordLookup) + mapper := New("testorg", configLookup, discordLookup, nil, "test-guild") tests := []struct { name string @@ -98,7 +98,7 @@ func TestMapper_DiscordID_ConfigOverridesDiscord(t *testing.T) { }, } - mapper := New("testorg", configLookup, discordLookup) + mapper := New("testorg", configLookup, discordLookup, nil, "test-guild") // Config should take priority got := mapper.DiscordID(ctx, "alice") @@ -107,6 +107,56 @@ func TestMapper_DiscordID_ConfigOverridesDiscord(t *testing.T) { } } +// TestMapper_DiscordID_ConfigUsername tests config value being a Discord username. +func TestMapper_DiscordID_ConfigUsername(t *testing.T) { + ctx := context.Background() + + configLookup := &mockConfigLookup{ + users: map[string]string{ + "alice": "AliceDiscord", // Discord username, not numeric ID + }, + } + + discordLookup := &mockDiscordLookup{ + users: map[string]string{ + "AliceDiscord": "111111111111111111", + }, + } + + mapper := New("testorg", configLookup, discordLookup, nil, "test-guild") + + got := mapper.DiscordID(ctx, "alice") + if got != "111111111111111111" { + t.Errorf("DiscordID(alice) with config username = %q, want 111111111111111111", got) + } +} + +// TestMapper_DiscordID_ConfigUsername_NotFound tests when config username isn't found. +func TestMapper_DiscordID_ConfigUsername_NotFound(t *testing.T) { + ctx := context.Background() + + configLookup := &mockConfigLookup{ + users: map[string]string{ + "alice": "NonExistentUser", // Discord username not found + }, + } + + discordLookup := &mockDiscordLookup{ + users: map[string]string{ + "OtherUser": "222222222222222222", + }, + } + + mapper := New("testorg", configLookup, discordLookup, nil, "test-guild") + + // Should fall back to tier 3 (Discord username match) + // Since "alice" is not in Discord either, should return empty + got := mapper.DiscordID(ctx, "alice") + if got != "" { + t.Errorf("DiscordID(alice) with unknown config username = %q, want empty", got) + } +} + func TestMapper_DiscordID_Caching(t *testing.T) { ctx := context.Background() @@ -116,7 +166,7 @@ func TestMapper_DiscordID_Caching(t *testing.T) { }, } - mapper := New("testorg", nil, discordLookup) + mapper := New("testorg", nil, discordLookup, nil, "test-guild") // First call - should hit Discord lookup id1 := mapper.DiscordID(ctx, "bob") @@ -143,7 +193,7 @@ func TestMapper_Mention(t *testing.T) { }, } - mapper := New("testorg", configLookup, nil) + mapper := New("testorg", configLookup, nil, nil, "test-guild") tests := []struct { name string @@ -181,7 +231,7 @@ func TestMapper_ClearCache(t *testing.T) { }, } - mapper := New("testorg", configLookup, nil) + mapper := New("testorg", configLookup, nil, nil, "test-guild") // Populate cache mapper.DiscordID(ctx, "alice") @@ -207,7 +257,7 @@ func TestMapper_NilLookups(t *testing.T) { ctx := context.Background() // Both lookups nil - mapper := New("testorg", nil, nil) + mapper := New("testorg", nil, nil, nil, "test-guild") got := mapper.DiscordID(ctx, "anyone") if got != "" { @@ -229,7 +279,7 @@ func TestMapper_DiscordID_CacheTTL(t *testing.T) { }, } - mapper := New("testorg", nil, discordLookup) + mapper := New("testorg", nil, discordLookup, nil, "test-guild") // First call - populates cache id1 := mapper.DiscordID(ctx, "bob") @@ -316,7 +366,7 @@ func TestMapper_ConfigUsernameResolution(t *testing.T) { users: tt.discordUsers, } - mapper := New("testorg", configLookup, discordLookup) + mapper := New("testorg", configLookup, discordLookup, nil, "test-guild") got := mapper.DiscordID(ctx, tt.githubUsername) if got != tt.wantID { @@ -342,7 +392,7 @@ func TestMapper_ConfigUsernameResolution_Mention(t *testing.T) { }, } - mapper := New("testorg", configLookup, discordLookup) + mapper := New("testorg", configLookup, discordLookup, nil, "test-guild") got := mapper.Mention(ctx, "alice") want := "<@111111111111111111>" @@ -362,7 +412,7 @@ func TestMapper_ExportCache(t *testing.T) { }, } - mapper := New("testorg", configLookup, nil) + mapper := New("testorg", configLookup, nil, nil, "test-guild") // Populate cache mapper.DiscordID(ctx, "alice") @@ -627,3 +677,30 @@ func TestReverseMapper_CacheTTL(t *testing.T) { t.Errorf("After TTL expiry: GitHubUsername = %q, want bob", username2) } } + +// TestIsAllDigits tests the isAllDigits helper function. +func TestIsAllDigits(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"", false}, + {"123", true}, + {"12345678901234567890", true}, + {"123abc", false}, + {"abc123", false}, + {"12.34", false}, + {"12-34", false}, + {"0", true}, + {"00000", true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := isAllDigits(tt.input) + if got != tt.want { + t.Errorf("isAllDigits(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +}