From 940d0fa01a9dee3fa3c0ef5553f3147e6b0af564 Mon Sep 17 00:00:00 2001 From: Abhishek Sah Date: Tue, 12 May 2026 13:03:42 +0530 Subject: [PATCH 1/3] feat(membership): add group membership management MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces the SetGroupMemberRole RPC and three service methods on the membership package: AddGroupMember, SetGroupMemberRole, and OnGroupCreated. These manage policy + SpiceDB relation atomically and keep them in sync, fixing the leaky-relation pattern at the group layer. - AddGroupMember validates org membership of the principal and rejects duplicates with ErrAlreadyMember (service-only, no proto). - SetGroupMemberRole rejects non-members with ErrNotMember and enforces a min-owner constraint (ErrLastGroupOwnerRole) on demotion. - OnGroupCreated bundles the group<->org hierarchy relations with the initial owner add, so group.Create can wire SpiceDB with one call. - Principal validation is restricted to app/user; the switch is kept extensible for future principal types. Audit events are added for both the added and role-changed cases. No call sites are migrated yet — group.Create, AddGroupUsers, and the deletion of legacy group service methods will follow in subsequent PRs. PROTON_COMMIT is temporarily pinned to the feature-branch SHA on raystack/proton#485; it will be re-pinned to the merge commit once that PR lands. Co-Authored-By: Claude Opus 4.7 (1M context) --- core/audit/audit.go | 10 +- core/membership/errors.go | 2 + core/membership/service.go | 314 +++++++++++++++ core/membership/service_test.go | 379 ++++++++++++++++++ internal/api/v1beta1connect/errors.go | 3 + internal/api/v1beta1connect/group.go | 54 +++ internal/api/v1beta1connect/group_test.go | 145 +++++++ internal/api/v1beta1connect/interfaces.go | 3 + .../mocks/membership_service.go | 150 +++++++ pkg/auditrecord/consts.go | 4 + 10 files changed, 1060 insertions(+), 4 deletions(-) diff --git a/core/audit/audit.go b/core/audit/audit.go index c0ac6b898..8b5a6018e 100644 --- a/core/audit/audit.go +++ b/core/audit/audit.go @@ -55,10 +55,12 @@ const ( ServiceUserCreatedEvent EventName = "app.serviceuser.created" ServiceUserDeletedEvent EventName = "app.serviceuser.deleted" - GroupCreatedEvent EventName = "app.group.created" - GroupUpdatedEvent EventName = "app.group.updated" - GroupDeletedEvent EventName = "app.group.deleted" - GroupMemberRemovedEvent EventName = "app.group.members.removed" + GroupCreatedEvent EventName = "app.group.created" + GroupUpdatedEvent EventName = "app.group.updated" + GroupDeletedEvent EventName = "app.group.deleted" + GroupMemberCreatedEvent EventName = "app.group.member.created" + GroupMemberRoleChangedEvent EventName = "app.group.member.role_changed" + GroupMemberRemovedEvent EventName = "app.group.members.removed" RoleCreatedEvent EventName = "app.role.created" RoleUpdatedEvent EventName = "app.role.updated" diff --git a/core/membership/errors.go b/core/membership/errors.go index 49666425e..ce835ae6c 100644 --- a/core/membership/errors.go +++ b/core/membership/errors.go @@ -13,4 +13,6 @@ var ( ErrNotOrgMember = errors.New("principal is not a member of the organization") ErrInvalidProjectRole = errors.New("role is not valid for project scope") ErrInvalidResourceType = errors.New("unsupported resource type") + ErrInvalidGroupRole = errors.New("role is not valid for group scope") + ErrLastGroupOwnerRole = errors.New("cannot change role: this is the last owner of the group") ) diff --git a/core/membership/service.go b/core/membership/service.go index fab850fbe..8aa33f73b 100644 --- a/core/membership/service.go +++ b/core/membership/service.go @@ -1019,3 +1019,317 @@ func (s *Service) ListPrincipalsByResource(ctx context.Context, resourceID, reso return members, nil } + +// AddGroupMember adds a principal as a member of a group with an explicit role. +// Returns ErrAlreadyMember if the principal already has a policy on this group. +// The principal must be a member of the group's parent organization. +func (s *Service) AddGroupMember(ctx context.Context, groupID, principalID, principalType, roleID string) error { + grp, err := s.groupService.Get(ctx, groupID) + if err != nil { + return err + } + + principal, err := s.validateGroupPrincipal(ctx, principalID, principalType) + if err != nil { + return err + } + + fetchedRole, err := s.validateGroupRole(ctx, roleID, grp.OrganizationID) + if err != nil { + return err + } + + if err := s.validateOrgMembership(ctx, grp.OrganizationID, principalID, principalType); err != nil { + return err + } + + existing, err := s.policyService.List(ctx, policy.Filter{ + GroupID: groupID, + PrincipalID: principalID, + PrincipalType: principalType, + }) + if err != nil { + return fmt.Errorf("list existing policies: %w", err) + } + if len(existing) > 0 { + return ErrAlreadyMember + } + + createdPolicy, err := s.createPolicy(ctx, groupID, schema.GroupNamespace, principalID, principalType, fetchedRole.ID) + if err != nil { + return err + } + + relationName := groupRoleToRelation(fetchedRole) + if err := s.createRelation(ctx, groupID, schema.GroupNamespace, principalID, principalType, relationName); err != nil { + if deleteErr := s.policyService.Delete(ctx, createdPolicy.ID); deleteErr != nil { + s.log.WarnContext(ctx, "orphaned policy: relation creation failed and policy cleanup also failed", + "policy_id", createdPolicy.ID, + "group_id", groupID, + "principal_id", principalID, + "policy_delete_error", deleteErr, + ) + } + return err + } + + s.auditGroupMemberAdded(ctx, grp, principal, fetchedRole.ID) + return nil +} + +// SetGroupMemberRole changes an existing member's role in a group. +// Returns ErrNotMember if the principal has no existing policy on the group. +// Enforces the min-owner constraint: demoting the last owner returns ErrLastGroupOwnerRole. +func (s *Service) SetGroupMemberRole(ctx context.Context, groupID, principalID, principalType, roleID string) error { + grp, err := s.groupService.Get(ctx, groupID) + if err != nil { + return err + } + + principal, err := s.validateGroupPrincipal(ctx, principalID, principalType) + if err != nil { + return err + } + + fetchedRole, err := s.validateGroupRole(ctx, roleID, grp.OrganizationID) + if err != nil { + return err + } + resolvedRoleID := fetchedRole.ID + + existing, err := s.policyService.List(ctx, policy.Filter{ + GroupID: groupID, + PrincipalID: principalID, + PrincipalType: principalType, + }) + if err != nil { + return fmt.Errorf("list existing policies: %w", err) + } + if len(existing) == 0 { + return ErrNotMember + } + + // skip if the user already has exactly this role + if len(existing) == 1 && existing[0].RoleID == resolvedRoleID { + return nil + } + + if err := s.validateMinGroupOwnerConstraint(ctx, groupID, resolvedRoleID, existing); err != nil { + return err + } + + if err := s.replacePolicy(ctx, groupID, schema.GroupNamespace, principalID, principalType, resolvedRoleID, existing); err != nil { + return err + } + + newRelation := groupRoleToRelation(fetchedRole) + oldRelations := []string{schema.OwnerRelationName, schema.MemberRelationName} + if err := s.replaceRelation(ctx, groupID, schema.GroupNamespace, principalID, principalType, oldRelations, newRelation); err != nil { + s.log.ErrorContext(ctx, "membership state inconsistent: policy replaced but group relation update failed, needs manual fix", + "group_id", groupID, + "principal_id", principalID, + "principal_type", principalType, + "new_role_id", resolvedRoleID, + "expected_relation", newRelation, + "error", err, + ) + return err + } + + s.auditGroupMemberRoleChanged(ctx, grp, principal, resolvedRoleID) + return nil +} + +// OnGroupCreated wires up SpiceDB relations for a newly-created group: +// links the group to its parent organization (both directions) and adds the +// creator as owner via AddGroupMember. +func (s *Service) OnGroupCreated(ctx context.Context, groupID, orgID, creatorID, creatorType string) error { + if err := s.linkGroupToOrg(ctx, groupID, orgID); err != nil { + return err + } + if err := s.AddGroupMember(ctx, groupID, creatorID, creatorType, schema.GroupOwnerRole); err != nil { + return err + } + return nil +} + +// linkGroupToOrg creates the two hierarchy relations between a group and its org: +// - group#org@organization (identity link from group to org) +// - organization#member@group#member (lets org#member traverse to group members) +func (s *Service) linkGroupToOrg(ctx context.Context, groupID, orgID string) error { + if _, err := s.relationService.Create(ctx, relation.Relation{ + Object: relation.Object{ID: groupID, Namespace: schema.GroupNamespace}, + Subject: relation.Subject{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.OrganizationRelationName, + }); err != nil { + return fmt.Errorf("link group to org: %w", err) + } + + if _, err := s.relationService.Create(ctx, relation.Relation{ + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + Subject: relation.Subject{ + ID: groupID, + Namespace: schema.GroupNamespace, + SubRelationName: schema.MemberRelationName, + }, + RelationName: schema.MemberRelationName, + }); err != nil { + return fmt.Errorf("add group as org member: %w", err) + } + + return nil +} + +// validateGroupRole checks that the role is valid for group scope: +// - a platform-wide role scoped to groups, or +// - a custom role created for the group's parent organization. +func (s *Service) validateGroupRole(ctx context.Context, roleID, orgID string) (role.Role, error) { + fetchedRole, err := s.roleService.Get(ctx, roleID) + if err != nil { + return role.Role{}, err + } + if !slices.Contains(fetchedRole.Scopes, schema.GroupNamespace) { + return role.Role{}, ErrInvalidGroupRole + } + if fetchedRole.OrgID == orgID { + return fetchedRole, nil + } + if utils.IsNullUUID(fetchedRole.OrgID) { + return fetchedRole, nil + } + return role.Role{}, ErrInvalidGroupRole +} + +// validateGroupPrincipal fetches and validates the principal for group operations. +// Currently only app/user is supported; the switch is structured so future principal +// types (e.g. serviceuser) can be enabled here without touching call sites. +func (s *Service) validateGroupPrincipal(ctx context.Context, principalID, principalType string) (principalInfo, error) { + switch principalType { + case schema.UserPrincipal: + usr, err := s.userService.GetByID(ctx, principalID) + if err != nil { + return principalInfo{}, err + } + if usr.State == user.Disabled { + return principalInfo{}, user.ErrDisabled + } + return principalInfo{ + ID: usr.ID, + Type: schema.UserPrincipal, + Name: usr.Title, + Email: usr.Email, + }, nil + default: + return principalInfo{}, ErrInvalidPrincipalType + } +} + +// validateMinGroupOwnerConstraint ensures the group keeps at least one owner +// after the role change. Mirrors the org-level constraint. +func (s *Service) validateMinGroupOwnerConstraint(ctx context.Context, groupID, newRoleID string, existing []policy.Policy) error { + ownerRole, err := s.roleService.Get(ctx, schema.GroupOwnerRole) + if err != nil { + return fmt.Errorf("get group owner role: %w", err) + } + + if newRoleID == ownerRole.ID { + return nil + } + + isCurrentlyOwner := false + for _, p := range existing { + if p.RoleID == ownerRole.ID { + isCurrentlyOwner = true + break + } + } + if !isCurrentlyOwner { + return nil + } + + ownerPolicies, err := s.policyService.List(ctx, policy.Filter{ + GroupID: groupID, + RoleID: ownerRole.ID, + }) + if err != nil { + return fmt.Errorf("list group owner policies: %w", err) + } + if len(ownerPolicies) <= 1 { + return ErrLastGroupOwnerRole + } + return nil +} + +// groupRoleToRelation maps a group role to the matching SpiceDB relation name. +func groupRoleToRelation(r role.Role) string { + if r.Name == schema.GroupOwnerRole { + return schema.OwnerRelationName + } + return schema.MemberRelationName +} + +func (s *Service) auditGroupMemberAdded(ctx context.Context, grp group.Group, p principalInfo, roleID string) { + targetType, _ := principalTypeToAuditType(p.Type) + meta := map[string]any{"role_id": roleID} + if p.Email != "" { + meta["email"] = p.Email + } + + s.auditRecordRepository.Create(ctx, auditrecord.AuditRecord{ + Event: pkgAuditRecord.GroupMemberAddedEvent, + Resource: auditrecord.Resource{ + ID: grp.ID, + Type: pkgAuditRecord.GroupType, + Name: grp.Title, + }, + Target: &auditrecord.Target{ + ID: p.ID, + Type: targetType, + Name: p.Name, + Metadata: meta, + }, + OrgID: grp.OrganizationID, + OccurredAt: time.Now(), + }) + + audit.GetAuditor(ctx, grp.OrganizationID).LogWithAttrs(audit.GroupMemberCreatedEvent, audit.Target{ + ID: p.ID, + Type: p.Type, + }, map[string]string{ + "role_id": roleID, + "group_id": grp.ID, + }) +} + +func (s *Service) auditGroupMemberRoleChanged(ctx context.Context, grp group.Group, p principalInfo, roleID string) { + targetType, _ := principalTypeToAuditType(p.Type) + meta := map[string]any{"role_id": roleID} + if p.Email != "" { + meta["email"] = p.Email + } + + s.auditRecordRepository.Create(ctx, auditrecord.AuditRecord{ + Event: pkgAuditRecord.GroupMemberRoleChangedEvent, + Resource: auditrecord.Resource{ + ID: grp.ID, + Type: pkgAuditRecord.GroupType, + Name: grp.Title, + }, + Target: &auditrecord.Target{ + ID: p.ID, + Type: targetType, + Name: p.Name, + Metadata: meta, + }, + OrgID: grp.OrganizationID, + OccurredAt: time.Now(), + }) + + audit.GetAuditor(ctx, grp.OrganizationID).LogWithAttrs(audit.GroupMemberRoleChangedEvent, audit.Target{ + ID: p.ID, + Type: p.Type, + }, map[string]string{ + "role_id": roleID, + "group_id": grp.ID, + }) +} diff --git a/core/membership/service_test.go b/core/membership/service_test.go index 2b5e37feb..02ac2f430 100644 --- a/core/membership/service_test.go +++ b/core/membership/service_test.go @@ -1217,3 +1217,382 @@ func TestService_ListPrincipalsByResource(t *testing.T) { }) } } + +func TestService_AddGroupMember(t *testing.T) { + ctx := context.Background() + orgID := uuid.New().String() + groupID := uuid.New().String() + userID := uuid.New().String() + ownerRoleID := uuid.New().String() + memberRoleID := uuid.New().String() + + enabledUser := user.User{ID: userID, Title: "test-user", Email: "test@acme.dev", State: user.Enabled} + grp := group.Group{ID: groupID, OrganizationID: orgID, Title: "Test Group"} + + groupMemberRelation := func(name string) relation.Relation { + return relation.Relation{ + Object: relation.Object{ID: groupID, Namespace: schema.GroupNamespace}, + Subject: relation.Subject{ID: userID, Namespace: schema.UserPrincipal}, + RelationName: name, + } + } + + tests := []struct { + name string + setup func(*mocks.PolicyService, *mocks.RelationService, *mocks.RoleService, *mocks.GroupService, *mocks.UserService, *mocks.AuditRecordRepository) + principalType string + roleID string + wantErr error + wantErrContain string + }{ + { + name: "should return error if group does not exist", + setup: func(_ *mocks.PolicyService, _ *mocks.RelationService, _ *mocks.RoleService, grpSvc *mocks.GroupService, _ *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(group.Group{}, group.ErrNotExist) + }, + roleID: memberRoleID, + wantErr: group.ErrNotExist, + }, + { + name: "should return error if principal type is unsupported", + setup: func(_ *mocks.PolicyService, _ *mocks.RelationService, _ *mocks.RoleService, grpSvc *mocks.GroupService, _ *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + }, + principalType: schema.ServiceUserPrincipal, + roleID: memberRoleID, + wantErr: membership.ErrInvalidPrincipalType, + }, + { + name: "should return error if user is disabled", + setup: func(_ *mocks.PolicyService, _ *mocks.RelationService, _ *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(user.User{ID: userID, State: user.Disabled}, nil) + }, + roleID: memberRoleID, + wantErr: user.ErrDisabled, + }, + { + name: "should return error if role is not valid for group scope", + setup: func(_ *mocks.PolicyService, _ *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Scopes: []string{schema.OrganizationNamespace}}, nil) + }, + roleID: memberRoleID, + wantErr: membership.ErrInvalidGroupRole, + }, + { + name: "should return error if user is not a member of the org", + setup: func(policySvc *mocks.PolicyService, _ *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + }, + roleID: memberRoleID, + wantErr: membership.ErrNotOrgMember, + }, + { + name: "should return ErrAlreadyMember if principal has existing group policy", + setup: func(policySvc *mocks.PolicyService, _ *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "org-p1"}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "g-p1"}}, nil) + }, + roleID: memberRoleID, + wantErr: membership.ErrAlreadyMember, + }, + { + name: "should succeed adding a member with member role", + setup: func(policySvc *mocks.PolicyService, relSvc *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, auditRepo *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "org-p1"}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + policySvc.EXPECT().Create(ctx, policy.Policy{ + RoleID: memberRoleID, ResourceID: groupID, ResourceType: schema.GroupNamespace, + PrincipalID: userID, PrincipalType: schema.UserPrincipal, + }).Return(policy.Policy{ID: "new-p"}, nil) + relSvc.EXPECT().Create(ctx, groupMemberRelation(schema.MemberRelationName)).Return(relation.Relation{}, nil) + auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + }, + roleID: memberRoleID, + wantErr: nil, + }, + { + name: "should succeed adding a member with owner role", + setup: func(policySvc *mocks.PolicyService, relSvc *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, auditRepo *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, ownerRoleID).Return(role.Role{ID: ownerRoleID, Name: schema.GroupOwnerRole, Scopes: []string{schema.GroupNamespace}}, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "org-p1"}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + policySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{ID: "new-p"}, nil) + relSvc.EXPECT().Create(ctx, groupMemberRelation(schema.OwnerRelationName)).Return(relation.Relation{}, nil) + auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + }, + roleID: ownerRoleID, + wantErr: nil, + }, + { + name: "should cleanup policy if relation creation fails", + setup: func(policySvc *mocks.PolicyService, relSvc *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "org-p1"}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + policySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{ID: "created-p"}, nil) + relSvc.EXPECT().Create(ctx, mock.Anything).Return(relation.Relation{}, errors.New("spicedb unavailable")) + policySvc.EXPECT().Delete(ctx, "created-p").Return(nil) + }, + roleID: memberRoleID, + wantErrContain: "spicedb unavailable", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockGrpSvc := mocks.NewGroupService(t) + mockUserSvc := mocks.NewUserService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + if tt.setup != nil { + tt.setup(mockPolicySvc, mockRelSvc, mockRoleSvc, mockGrpSvc, mockUserSvc, mockAuditRepo) + } + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mockRoleSvc, mocks.NewOrgService(t), mockUserSvc, mocks.NewProjectService(t), mockGrpSvc, mocks.NewServiceuserService(t), mockAuditRepo) + + principalType := tt.principalType + if principalType == "" { + principalType = schema.UserPrincipal + } + err := svc.AddGroupMember(ctx, groupID, userID, principalType, tt.roleID) + + if tt.wantErr != nil { + assert.ErrorIs(t, err, tt.wantErr) + } else if tt.wantErrContain != "" { + assert.ErrorContains(t, err, tt.wantErrContain) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestService_SetGroupMemberRole(t *testing.T) { + ctx := context.Background() + orgID := uuid.New().String() + groupID := uuid.New().String() + userID := uuid.New().String() + ownerRoleID := uuid.New().String() + memberRoleID := uuid.New().String() + + enabledUser := user.User{ID: userID, Title: "test-user", Email: "test@acme.dev", State: user.Enabled} + grp := group.Group{ID: groupID, OrganizationID: orgID, Title: "Test Group"} + + groupMemberRelation := func(name string) relation.Relation { + return relation.Relation{ + Object: relation.Object{ID: groupID, Namespace: schema.GroupNamespace}, + Subject: relation.Subject{ID: userID, Namespace: schema.UserPrincipal}, + RelationName: name, + } + } + + tests := []struct { + name string + setup func(*mocks.PolicyService, *mocks.RelationService, *mocks.RoleService, *mocks.GroupService, *mocks.UserService, *mocks.AuditRecordRepository) + principalType string + roleID string + wantErr error + wantErrContain string + }{ + { + name: "should return error if user is not a group member", + setup: func(policySvc *mocks.PolicyService, _ *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + }, + roleID: memberRoleID, + wantErr: membership.ErrNotMember, + }, + { + name: "should skip write when role is unchanged", + setup: func(policySvc *mocks.PolicyService, _ *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: memberRoleID}}, nil) + }, + roleID: memberRoleID, + wantErr: nil, + }, + { + name: "should return error if demoting last owner", + setup: func(policySvc *mocks.PolicyService, _ *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) + roleSvc.EXPECT().Get(ctx, schema.GroupOwnerRole).Return(role.Role{ID: ownerRoleID, Name: schema.GroupOwnerRole}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, RoleID: ownerRoleID}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) + }, + roleID: memberRoleID, + wantErr: membership.ErrLastGroupOwnerRole, + }, + { + name: "should succeed demoting owner to member with multiple owners (relation flips owner->member)", + setup: func(policySvc *mocks.PolicyService, relSvc *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, auditRepo *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) + roleSvc.EXPECT().Get(ctx, schema.GroupOwnerRole).Return(role.Role{ID: ownerRoleID, Name: schema.GroupOwnerRole}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, RoleID: ownerRoleID}).Return([]policy.Policy{{ID: "p1"}, {ID: "p2"}}, nil) + policySvc.EXPECT().Delete(ctx, "p1").Return(nil) + policySvc.EXPECT().Create(ctx, policy.Policy{ + RoleID: memberRoleID, ResourceID: groupID, ResourceType: schema.GroupNamespace, + PrincipalID: userID, PrincipalType: schema.UserPrincipal, + }).Return(policy.Policy{ID: "new-p"}, nil) + relSvc.EXPECT().Delete(ctx, groupMemberRelation(schema.OwnerRelationName)).Return(nil) + relSvc.EXPECT().Delete(ctx, groupMemberRelation(schema.MemberRelationName)).Return(relation.ErrNotExist) + relSvc.EXPECT().Create(ctx, groupMemberRelation(schema.MemberRelationName)).Return(relation.Relation{}, nil) + auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + }, + roleID: memberRoleID, + wantErr: nil, + }, + { + name: "should succeed promoting member to owner (relation flips member->owner)", + setup: func(policySvc *mocks.PolicyService, relSvc *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, auditRepo *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, ownerRoleID).Return(role.Role{ID: ownerRoleID, Name: schema.GroupOwnerRole, Scopes: []string{schema.GroupNamespace}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: memberRoleID}}, nil) + roleSvc.EXPECT().Get(ctx, schema.GroupOwnerRole).Return(role.Role{ID: ownerRoleID, Name: schema.GroupOwnerRole}, nil) + policySvc.EXPECT().Delete(ctx, "p1").Return(nil) + policySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{ID: "new-p"}, nil) + relSvc.EXPECT().Delete(ctx, groupMemberRelation(schema.OwnerRelationName)).Return(relation.ErrNotExist) + relSvc.EXPECT().Delete(ctx, groupMemberRelation(schema.MemberRelationName)).Return(nil) + relSvc.EXPECT().Create(ctx, groupMemberRelation(schema.OwnerRelationName)).Return(relation.Relation{}, nil) + auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + }, + roleID: ownerRoleID, + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockGrpSvc := mocks.NewGroupService(t) + mockUserSvc := mocks.NewUserService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + if tt.setup != nil { + tt.setup(mockPolicySvc, mockRelSvc, mockRoleSvc, mockGrpSvc, mockUserSvc, mockAuditRepo) + } + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mockRoleSvc, mocks.NewOrgService(t), mockUserSvc, mocks.NewProjectService(t), mockGrpSvc, mocks.NewServiceuserService(t), mockAuditRepo) + + principalType := tt.principalType + if principalType == "" { + principalType = schema.UserPrincipal + } + err := svc.SetGroupMemberRole(ctx, groupID, userID, principalType, tt.roleID) + + if tt.wantErr != nil { + assert.ErrorIs(t, err, tt.wantErr) + } else if tt.wantErrContain != "" { + assert.ErrorContains(t, err, tt.wantErrContain) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestService_OnGroupCreated(t *testing.T) { + ctx := context.Background() + orgID := uuid.New().String() + groupID := uuid.New().String() + creatorID := uuid.New().String() + ownerRoleID := uuid.New().String() + + enabledUser := user.User{ID: creatorID, Title: "creator", Email: "creator@acme.dev", State: user.Enabled} + grp := group.Group{ID: groupID, OrganizationID: orgID, Title: "Test Group"} + + groupOrgRelation := relation.Relation{ + Object: relation.Object{ID: groupID, Namespace: schema.GroupNamespace}, + Subject: relation.Subject{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.OrganizationRelationName, + } + orgGroupMemberRelation := relation.Relation{ + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + Subject: relation.Subject{ + ID: groupID, + Namespace: schema.GroupNamespace, + SubRelationName: schema.MemberRelationName, + }, + RelationName: schema.MemberRelationName, + } + creatorOwnerRelation := relation.Relation{ + Object: relation.Object{ID: groupID, Namespace: schema.GroupNamespace}, + Subject: relation.Subject{ID: creatorID, Namespace: schema.UserPrincipal}, + RelationName: schema.OwnerRelationName, + } + + t.Run("should link group<->org and add creator as owner", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockGrpSvc := mocks.NewGroupService(t) + mockUserSvc := mocks.NewUserService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + mockRelSvc.EXPECT().Create(ctx, groupOrgRelation).Return(relation.Relation{}, nil) + mockRelSvc.EXPECT().Create(ctx, orgGroupMemberRelation).Return(relation.Relation{}, nil) + + mockGrpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + mockUserSvc.EXPECT().GetByID(ctx, creatorID).Return(enabledUser, nil) + mockRoleSvc.EXPECT().Get(ctx, schema.GroupOwnerRole).Return(role.Role{ID: ownerRoleID, Name: schema.GroupOwnerRole, Scopes: []string{schema.GroupNamespace}}, nil) + mockUserSvc.EXPECT().GetByID(ctx, creatorID).Return(enabledUser, nil) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: creatorID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "org-p1"}}, nil) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: creatorID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + mockPolicySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{ID: "new-p"}, nil) + mockRelSvc.EXPECT().Create(ctx, creatorOwnerRelation).Return(relation.Relation{}, nil) + mockAuditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mockRoleSvc, mocks.NewOrgService(t), mockUserSvc, mocks.NewProjectService(t), mockGrpSvc, mocks.NewServiceuserService(t), mockAuditRepo) + + err := svc.OnGroupCreated(ctx, groupID, orgID, creatorID, schema.UserPrincipal) + assert.NoError(t, err) + }) + + t.Run("should return error if hierarchy relation creation fails", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + + mockRelSvc.EXPECT().Create(ctx, groupOrgRelation).Return(relation.Relation{}, errors.New("spicedb unavailable")) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mocks.NewRoleService(t), mocks.NewOrgService(t), mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + + err := svc.OnGroupCreated(ctx, groupID, orgID, creatorID, schema.UserPrincipal) + assert.ErrorContains(t, err, "link group to org") + }) +} diff --git a/internal/api/v1beta1connect/errors.go b/internal/api/v1beta1connect/errors.go index 18a45c72c..42c870ea8 100644 --- a/internal/api/v1beta1connect/errors.go +++ b/internal/api/v1beta1connect/errors.go @@ -37,6 +37,9 @@ var ( ErrNotMember = errors.New("principal is not a member of the resource") ErrInvalidOrgRole = errors.New("role is not valid for organization scope") ErrInvalidProjectRole = errors.New("role is not valid for project scope") + ErrInvalidGroupRole = errors.New("role is not valid for group scope") + ErrLastGroupOwnerRole = errors.New("group must have at least one owner, consider assigning another owner before changing this user's role") + ErrNotOrgMember = errors.New("principal is not a member of the organization") ErrEmptyEmailID = errors.New("email id is empty") ErrEmailConflict = errors.New("user email can't be updated") ErrCustomerNotFound = errors.New("customer doesn't exist") diff --git a/internal/api/v1beta1connect/group.go b/internal/api/v1beta1connect/group.go index 6685eaa77..ad39a29cc 100644 --- a/internal/api/v1beta1connect/group.go +++ b/internal/api/v1beta1connect/group.go @@ -462,6 +462,60 @@ func (h *ConnectHandler) RemoveGroupUser(ctx context.Context, request *connect.R return connect.NewResponse(&frontierv1beta1.RemoveGroupUserResponse{}), nil } +func (h *ConnectHandler) SetGroupMemberRole(ctx context.Context, request *connect.Request[frontierv1beta1.SetGroupMemberRoleRequest]) (*connect.Response[frontierv1beta1.SetGroupMemberRoleResponse], error) { + errorLogger := NewErrorLogger() + + orgID := request.Msg.GetOrgId() + groupID := request.Msg.GetGroupId() + principalID := request.Msg.GetPrincipalId() + principalType := request.Msg.GetPrincipalType() + roleID := request.Msg.GetRoleId() + + if _, err := h.orgService.Get(ctx, orgID); err != nil { + switch { + case errors.Is(err, organization.ErrDisabled): + return nil, connect.NewError(connect.CodeNotFound, ErrOrgDisabled) + case errors.Is(err, organization.ErrNotExist): + return nil, connect.NewError(connect.CodeNotFound, ErrOrgNotFound) + default: + errorLogger.LogServiceError(ctx, request, "SetGroupMemberRole.GetOrg", err, + "org_id", orgID) + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + } + + if err := h.membershipService.SetGroupMemberRole(ctx, groupID, principalID, principalType, roleID); err != nil { + errorLogger.LogServiceError(ctx, request, "SetGroupMemberRole", err, + "group_id", groupID, + "principal_id", principalID, + "principal_type", principalType, + "role_id", roleID) + + switch { + case errors.Is(err, group.ErrNotExist), errors.Is(err, group.ErrInvalidID), errors.Is(err, group.ErrInvalidUUID): + return nil, connect.NewError(connect.CodeNotFound, ErrGroupNotFound) + case errors.Is(err, user.ErrNotExist): + return nil, connect.NewError(connect.CodeNotFound, ErrUserNotExist) + case errors.Is(err, user.ErrDisabled): + return nil, connect.NewError(connect.CodeFailedPrecondition, err) + case errors.Is(err, role.ErrNotExist), errors.Is(err, role.ErrInvalidID): + return nil, connect.NewError(connect.CodeNotFound, ErrInvalidRoleID) + case errors.Is(err, membership.ErrInvalidPrincipalType), errors.Is(err, membership.ErrInvalidPrincipal): + return nil, connect.NewError(connect.CodeInvalidArgument, err) + case errors.Is(err, membership.ErrInvalidGroupRole): + return nil, connect.NewError(connect.CodeInvalidArgument, ErrInvalidGroupRole) + case errors.Is(err, membership.ErrNotMember): + return nil, connect.NewError(connect.CodeFailedPrecondition, ErrNotMember) + case errors.Is(err, membership.ErrLastGroupOwnerRole): + return nil, connect.NewError(connect.CodeFailedPrecondition, ErrLastGroupOwnerRole) + default: + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + } + + return connect.NewResponse(&frontierv1beta1.SetGroupMemberRoleResponse{}), nil +} + func (h *ConnectHandler) EnableGroup(ctx context.Context, request *connect.Request[frontierv1beta1.EnableGroupRequest]) (*connect.Response[frontierv1beta1.EnableGroupResponse], error) { errorLogger := NewErrorLogger() diff --git a/internal/api/v1beta1connect/group_test.go b/internal/api/v1beta1connect/group_test.go index c17b6f00e..84df5a19b 100644 --- a/internal/api/v1beta1connect/group_test.go +++ b/internal/api/v1beta1connect/group_test.go @@ -1717,3 +1717,148 @@ func TestConnectHandler_DeleteGroup(t *testing.T) { }) } } + +func TestConnectHandler_SetGroupMemberRole(t *testing.T) { + someGroupID := utils.NewString() + somePrincipalID := utils.NewString() + someRoleID := utils.NewString() + + baseRequest := func() *connect.Request[frontierv1beta1.SetGroupMemberRoleRequest] { + return connect.NewRequest(&frontierv1beta1.SetGroupMemberRoleRequest{ + OrgId: testOrgID, + GroupId: someGroupID, + PrincipalId: somePrincipalID, + PrincipalType: schema.UserPrincipal, + RoleId: someRoleID, + }) + } + + tests := []struct { + name string + setup func(ms *mocks.MembershipService, os *mocks.OrganizationService) + request *connect.Request[frontierv1beta1.SetGroupMemberRoleRequest] + want *connect.Response[frontierv1beta1.SetGroupMemberRoleResponse] + wantErr error + }{ + { + name: "should return not found if org does not exist", + setup: func(_ *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(organization.Organization{}, organization.ErrNotExist) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeNotFound, ErrOrgNotFound), + }, + { + name: "should return not found if org is disabled", + setup: func(_ *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(organization.Organization{}, organization.ErrDisabled) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeNotFound, ErrOrgDisabled), + }, + { + name: "should return not found if group does not exist", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(group.ErrNotExist) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeNotFound, ErrGroupNotFound), + }, + { + name: "should return not found if user does not exist", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(user.ErrNotExist) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeNotFound, ErrUserNotExist), + }, + { + name: "should return not found if role does not exist", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(role.ErrNotExist) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeNotFound, ErrInvalidRoleID), + }, + { + name: "should return invalid argument if role is not valid for group scope", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(membership.ErrInvalidGroupRole) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeInvalidArgument, ErrInvalidGroupRole), + }, + { + name: "should return invalid argument if principal type is unsupported", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(membership.ErrInvalidPrincipalType) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeInvalidArgument, membership.ErrInvalidPrincipalType), + }, + { + name: "should return failed precondition if principal is not a group member", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(membership.ErrNotMember) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeFailedPrecondition, ErrNotMember), + }, + { + name: "should return failed precondition if demoting last group owner", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(membership.ErrLastGroupOwnerRole) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeFailedPrecondition, ErrLastGroupOwnerRole), + }, + { + name: "should return internal error for unknown errors", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(errors.New("unknown")) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeInternal, ErrInternalServerError), + }, + { + name: "should return success on valid request", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(nil) + }, + request: baseRequest(), + want: connect.NewResponse(&frontierv1beta1.SetGroupMemberRoleResponse{}), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockMembershipSvc := new(mocks.MembershipService) + mockOrgSvc := new(mocks.OrganizationService) + if tt.setup != nil { + tt.setup(mockMembershipSvc, mockOrgSvc) + } + h := ConnectHandler{ + membershipService: mockMembershipSvc, + orgService: mockOrgSvc, + } + got, err := h.SetGroupMemberRole(context.Background(), tt.request) + if tt.wantErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.wantErr.(*connect.Error).Code(), err.(*connect.Error).Code()) + assert.Equal(t, tt.wantErr.(*connect.Error).Message(), err.(*connect.Error).Message()) + } else { + assert.NoError(t, err) + assert.EqualValues(t, tt.want, got) + } + }) + } +} diff --git a/internal/api/v1beta1connect/interfaces.go b/internal/api/v1beta1connect/interfaces.go index 6d1f7e24e..a37004a5c 100644 --- a/internal/api/v1beta1connect/interfaces.go +++ b/internal/api/v1beta1connect/interfaces.go @@ -406,6 +406,9 @@ type MembershipService interface { SetProjectMemberRole(ctx context.Context, projectID, principalID, principalType, roleID string) error RemoveProjectMember(ctx context.Context, projectID, principalID, principalType string) error ListPrincipalsByResource(ctx context.Context, resourceID, resourceType string, filter membership.MemberFilter) ([]membership.Member, error) + AddGroupMember(ctx context.Context, groupID, principalID, principalType, roleID string) error + SetGroupMemberRole(ctx context.Context, groupID, principalID, principalType, roleID string) error + OnGroupCreated(ctx context.Context, groupID, orgID, creatorID, creatorType string) error } type UserPATService interface { diff --git a/internal/api/v1beta1connect/mocks/membership_service.go b/internal/api/v1beta1connect/mocks/membership_service.go index cbe58da4a..ca0b3f620 100644 --- a/internal/api/v1beta1connect/mocks/membership_service.go +++ b/internal/api/v1beta1connect/mocks/membership_service.go @@ -22,6 +22,56 @@ func (_m *MembershipService) EXPECT() *MembershipService_Expecter { return &MembershipService_Expecter{mock: &_m.Mock} } +// AddGroupMember provides a mock function with given fields: ctx, groupID, principalID, principalType, roleID +func (_m *MembershipService) AddGroupMember(ctx context.Context, groupID string, principalID string, principalType string, roleID string) error { + ret := _m.Called(ctx, groupID, principalID, principalType, roleID) + + if len(ret) == 0 { + panic("no return value specified for AddGroupMember") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, groupID, principalID, principalType, roleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_AddGroupMember_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddGroupMember' +type MembershipService_AddGroupMember_Call struct { + *mock.Call +} + +// AddGroupMember is a helper method to define mock.On call +// - ctx context.Context +// - groupID string +// - principalID string +// - principalType string +// - roleID string +func (_e *MembershipService_Expecter) AddGroupMember(ctx interface{}, groupID interface{}, principalID interface{}, principalType interface{}, roleID interface{}) *MembershipService_AddGroupMember_Call { + return &MembershipService_AddGroupMember_Call{Call: _e.mock.On("AddGroupMember", ctx, groupID, principalID, principalType, roleID)} +} + +func (_c *MembershipService_AddGroupMember_Call) Run(run func(ctx context.Context, groupID string, principalID string, principalType string, roleID string)) *MembershipService_AddGroupMember_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MembershipService_AddGroupMember_Call) Return(_a0 error) *MembershipService_AddGroupMember_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_AddGroupMember_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *MembershipService_AddGroupMember_Call { + _c.Call.Return(run) + return _c +} + // AddOrganizationMember provides a mock function with given fields: ctx, orgID, principalID, principalType, roleID func (_m *MembershipService) AddOrganizationMember(ctx context.Context, orgID string, principalID string, principalType string, roleID string) error { ret := _m.Called(ctx, orgID, principalID, principalType, roleID) @@ -133,6 +183,56 @@ func (_c *MembershipService_ListPrincipalsByResource_Call) RunAndReturn(run func return _c } +// OnGroupCreated provides a mock function with given fields: ctx, groupID, orgID, creatorID, creatorType +func (_m *MembershipService) OnGroupCreated(ctx context.Context, groupID string, orgID string, creatorID string, creatorType string) error { + ret := _m.Called(ctx, groupID, orgID, creatorID, creatorType) + + if len(ret) == 0 { + panic("no return value specified for OnGroupCreated") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, groupID, orgID, creatorID, creatorType) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_OnGroupCreated_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OnGroupCreated' +type MembershipService_OnGroupCreated_Call struct { + *mock.Call +} + +// OnGroupCreated is a helper method to define mock.On call +// - ctx context.Context +// - groupID string +// - orgID string +// - creatorID string +// - creatorType string +func (_e *MembershipService_Expecter) OnGroupCreated(ctx interface{}, groupID interface{}, orgID interface{}, creatorID interface{}, creatorType interface{}) *MembershipService_OnGroupCreated_Call { + return &MembershipService_OnGroupCreated_Call{Call: _e.mock.On("OnGroupCreated", ctx, groupID, orgID, creatorID, creatorType)} +} + +func (_c *MembershipService_OnGroupCreated_Call) Run(run func(ctx context.Context, groupID string, orgID string, creatorID string, creatorType string)) *MembershipService_OnGroupCreated_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MembershipService_OnGroupCreated_Call) Return(_a0 error) *MembershipService_OnGroupCreated_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_OnGroupCreated_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *MembershipService_OnGroupCreated_Call { + _c.Call.Return(run) + return _c +} + // RemoveOrganizationMember provides a mock function with given fields: ctx, orgID, principalID, principalType func (_m *MembershipService) RemoveOrganizationMember(ctx context.Context, orgID string, principalID string, principalType string) error { ret := _m.Called(ctx, orgID, principalID, principalType) @@ -231,6 +331,56 @@ func (_c *MembershipService_RemoveProjectMember_Call) RunAndReturn(run func(cont return _c } +// SetGroupMemberRole provides a mock function with given fields: ctx, groupID, principalID, principalType, roleID +func (_m *MembershipService) SetGroupMemberRole(ctx context.Context, groupID string, principalID string, principalType string, roleID string) error { + ret := _m.Called(ctx, groupID, principalID, principalType, roleID) + + if len(ret) == 0 { + panic("no return value specified for SetGroupMemberRole") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, groupID, principalID, principalType, roleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_SetGroupMemberRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetGroupMemberRole' +type MembershipService_SetGroupMemberRole_Call struct { + *mock.Call +} + +// SetGroupMemberRole is a helper method to define mock.On call +// - ctx context.Context +// - groupID string +// - principalID string +// - principalType string +// - roleID string +func (_e *MembershipService_Expecter) SetGroupMemberRole(ctx interface{}, groupID interface{}, principalID interface{}, principalType interface{}, roleID interface{}) *MembershipService_SetGroupMemberRole_Call { + return &MembershipService_SetGroupMemberRole_Call{Call: _e.mock.On("SetGroupMemberRole", ctx, groupID, principalID, principalType, roleID)} +} + +func (_c *MembershipService_SetGroupMemberRole_Call) Run(run func(ctx context.Context, groupID string, principalID string, principalType string, roleID string)) *MembershipService_SetGroupMemberRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MembershipService_SetGroupMemberRole_Call) Return(_a0 error) *MembershipService_SetGroupMemberRole_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_SetGroupMemberRole_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *MembershipService_SetGroupMemberRole_Call { + _c.Call.Return(run) + return _c +} + // SetOrganizationMemberRole provides a mock function with given fields: ctx, orgID, principalID, principalType, roleID func (_m *MembershipService) SetOrganizationMemberRole(ctx context.Context, orgID string, principalID string, principalType string, roleID string) error { ret := _m.Called(ctx, orgID, principalID, principalType, roleID) diff --git a/pkg/auditrecord/consts.go b/pkg/auditrecord/consts.go index ca8a09bf2..4045d2eca 100644 --- a/pkg/auditrecord/consts.go +++ b/pkg/auditrecord/consts.go @@ -43,6 +43,10 @@ const ( ProjectMemberRoleChangedEvent Event = "project.member_role_changed" ProjectMemberRemovedEvent Event = "project.member_removed" + // Group Member Events + GroupMemberAddedEvent Event = "group.member_added" + GroupMemberRoleChangedEvent Event = "group.member_role_changed" + // KYC Events KYCVerifiedEvent Event = "kyc.verified" KYCUnverifiedEvent Event = "kyc.unverified" From 291b0fc027b84fa570e7b766c3d26856fa6bf067 Mon Sep 17 00:00:00 2001 From: Abhishek Sah Date: Tue, 12 May 2026 13:03:42 +0530 Subject: [PATCH 2/3] fix(authz): register SetGroupMemberRole in authorization map The authorization interceptor denies any procedure not in its map. Without this entry the new RPC returned PermissionDenied at the interceptor before reaching the handler. Uses GroupNamespace + UpdatePermission, matching AddGroupUsers/RemoveGroupUser. Co-Authored-By: Claude Opus 4.7 (1M context) --- pkg/server/connect_interceptors/authorization.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/server/connect_interceptors/authorization.go b/pkg/server/connect_interceptors/authorization.go index 39b65e81e..9a836e027 100644 --- a/pkg/server/connect_interceptors/authorization.go +++ b/pkg/server/connect_interceptors/authorization.go @@ -474,6 +474,10 @@ var authorizationValidationMap = map[string]func(ctx context.Context, handler *v pbreq := req.(*connect.Request[frontierv1beta1.RemoveGroupUserRequest]) return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.GroupNamespace, ID: pbreq.Msg.GetId()}, schema.UpdatePermission, req) }, + "/raystack.frontier.v1beta1.FrontierService/SetGroupMemberRole": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error { + pbreq := req.(*connect.Request[frontierv1beta1.SetGroupMemberRoleRequest]) + return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.GroupNamespace, ID: pbreq.Msg.GetGroupId()}, schema.UpdatePermission, req) + }, "/raystack.frontier.v1beta1.FrontierService/EnableGroup": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error { pbreq := req.(*connect.Request[frontierv1beta1.EnableGroupRequest]) return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.GroupNamespace, ID: pbreq.Msg.GetId()}, schema.DeletePermission, req) From f4d66d2335185ac96ce30e95f0e69a432a2941bf Mon Sep 17 00:00:00 2001 From: Abhishek Sah Date: Wed, 13 May 2026 09:04:16 +0530 Subject: [PATCH 3/3] fix(membership): rollback group hierarchy on OnGroupCreated failure When the second hierarchy relation or the owner add fails, the partially written relations are best-effort cleaned up so a group isn't left in a half-linked or unowned state. Adds unlinkGroupFromOrg helper. Also fixes the SetGroupMemberRole handler test for unsupported principal type: previously it sent app/user and forced the service to return an error, which only exercised error mapping. Now it sends app/serviceuser and asserts the handler forwards that value unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) --- core/membership/service.go | 51 +++++++++++++++++++++-- core/membership/service_test.go | 43 +++++++++++++++++++ internal/api/v1beta1connect/group_test.go | 11 ++++- 3 files changed, 100 insertions(+), 5 deletions(-) diff --git a/core/membership/service.go b/core/membership/service.go index 8aa33f73b..bb81be5c3 100644 --- a/core/membership/service.go +++ b/core/membership/service.go @@ -1142,12 +1142,20 @@ func (s *Service) SetGroupMemberRole(ctx context.Context, groupID, principalID, // OnGroupCreated wires up SpiceDB relations for a newly-created group: // links the group to its parent organization (both directions) and adds the -// creator as owner via AddGroupMember. +// creator as owner via AddGroupMember. If the owner add fails, hierarchy +// relations are best-effort rolled back to avoid an unowned, half-linked group. func (s *Service) OnGroupCreated(ctx context.Context, groupID, orgID, creatorID, creatorType string) error { if err := s.linkGroupToOrg(ctx, groupID, orgID); err != nil { return err } if err := s.AddGroupMember(ctx, groupID, creatorID, creatorType, schema.GroupOwnerRole); err != nil { + if cleanupErr := s.unlinkGroupFromOrg(ctx, groupID, orgID); cleanupErr != nil { + s.log.WarnContext(ctx, "group hierarchy cleanup failed after owner add failure", + "group_id", groupID, + "org_id", orgID, + "error", cleanupErr, + ) + } return err } return nil @@ -1156,12 +1164,16 @@ func (s *Service) OnGroupCreated(ctx context.Context, groupID, orgID, creatorID, // linkGroupToOrg creates the two hierarchy relations between a group and its org: // - group#org@organization (identity link from group to org) // - organization#member@group#member (lets org#member traverse to group members) +// +// If the second relation fails, the first is best-effort rolled back so we +// don't leave a one-way link. func (s *Service) linkGroupToOrg(ctx context.Context, groupID, orgID string) error { - if _, err := s.relationService.Create(ctx, relation.Relation{ + groupOrg := relation.Relation{ Object: relation.Object{ID: groupID, Namespace: schema.GroupNamespace}, Subject: relation.Subject{ID: orgID, Namespace: schema.OrganizationNamespace}, RelationName: schema.OrganizationRelationName, - }); err != nil { + } + if _, err := s.relationService.Create(ctx, groupOrg); err != nil { return fmt.Errorf("link group to org: %w", err) } @@ -1174,12 +1186,45 @@ func (s *Service) linkGroupToOrg(ctx context.Context, groupID, orgID string) err }, RelationName: schema.MemberRelationName, }); err != nil { + if delErr := s.relationService.Delete(ctx, groupOrg); delErr != nil && !errors.Is(delErr, relation.ErrNotExist) { + s.log.WarnContext(ctx, "group->org rollback failed after org member relation failure", + "group_id", groupID, + "org_id", orgID, + "error", delErr, + ) + } return fmt.Errorf("add group as org member: %w", err) } return nil } +// unlinkGroupFromOrg removes both hierarchy relations between a group and its +// org. Used as best-effort cleanup when group-create wiring fails partway. +// relation.ErrNotExist is ignored; any other error is returned. +func (s *Service) unlinkGroupFromOrg(ctx context.Context, groupID, orgID string) error { + if err := s.relationService.Delete(ctx, relation.Relation{ + Object: relation.Object{ID: groupID, Namespace: schema.GroupNamespace}, + Subject: relation.Subject{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.OrganizationRelationName, + }); err != nil && !errors.Is(err, relation.ErrNotExist) { + return err + } + + if err := s.relationService.Delete(ctx, relation.Relation{ + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + Subject: relation.Subject{ + ID: groupID, + Namespace: schema.GroupNamespace, + SubRelationName: schema.MemberRelationName, + }, + RelationName: schema.MemberRelationName, + }); err != nil && !errors.Is(err, relation.ErrNotExist) { + return err + } + return nil +} + // validateGroupRole checks that the role is valid for group scope: // - a platform-wide role scoped to groups, or // - a custom role created for the group's parent organization. diff --git a/core/membership/service_test.go b/core/membership/service_test.go index 02ac2f430..186dd1203 100644 --- a/core/membership/service_test.go +++ b/core/membership/service_test.go @@ -1595,4 +1595,47 @@ func TestService_OnGroupCreated(t *testing.T) { err := svc.OnGroupCreated(ctx, groupID, orgID, creatorID, schema.UserPrincipal) assert.ErrorContains(t, err, "link group to org") }) + + t.Run("should rollback first hierarchy relation if second fails", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + + mockRelSvc.EXPECT().Create(ctx, groupOrgRelation).Return(relation.Relation{}, nil) + mockRelSvc.EXPECT().Create(ctx, orgGroupMemberRelation).Return(relation.Relation{}, errors.New("spicedb unavailable")) + // rollback: delete the first hierarchy relation + mockRelSvc.EXPECT().Delete(ctx, groupOrgRelation).Return(nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mocks.NewRoleService(t), mocks.NewOrgService(t), mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + + err := svc.OnGroupCreated(ctx, groupID, orgID, creatorID, schema.UserPrincipal) + assert.ErrorContains(t, err, "add group as org member") + }) + + t.Run("should rollback both hierarchy relations if owner add fails", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockGrpSvc := mocks.NewGroupService(t) + mockUserSvc := mocks.NewUserService(t) + + // linkGroupToOrg succeeds + mockRelSvc.EXPECT().Create(ctx, groupOrgRelation).Return(relation.Relation{}, nil) + mockRelSvc.EXPECT().Create(ctx, orgGroupMemberRelation).Return(relation.Relation{}, nil) + + // AddGroupMember fails before policy creation (group fetch fails) + mockGrpSvc.EXPECT().Get(ctx, groupID).Return(group.Group{}, errors.New("db down")) + + // unused mocks: only here for completeness, won't be called + _ = mockRoleSvc + _ = mockUserSvc + + // rollback: delete both hierarchy relations + mockRelSvc.EXPECT().Delete(ctx, groupOrgRelation).Return(nil) + mockRelSvc.EXPECT().Delete(ctx, orgGroupMemberRelation).Return(nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mockRoleSvc, mocks.NewOrgService(t), mockUserSvc, mocks.NewProjectService(t), mockGrpSvc, mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + + err := svc.OnGroupCreated(ctx, groupID, orgID, creatorID, schema.UserPrincipal) + assert.ErrorContains(t, err, "db down") + }) } diff --git a/internal/api/v1beta1connect/group_test.go b/internal/api/v1beta1connect/group_test.go index 84df5a19b..bd12739e1 100644 --- a/internal/api/v1beta1connect/group_test.go +++ b/internal/api/v1beta1connect/group_test.go @@ -1796,9 +1796,16 @@ func TestConnectHandler_SetGroupMemberRole(t *testing.T) { name: "should return invalid argument if principal type is unsupported", setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) - ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(membership.ErrInvalidPrincipalType) + // handler must forward the unsupported principal_type to the service unchanged + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.ServiceUserPrincipal, someRoleID).Return(membership.ErrInvalidPrincipalType) }, - request: baseRequest(), + request: connect.NewRequest(&frontierv1beta1.SetGroupMemberRoleRequest{ + OrgId: testOrgID, + GroupId: someGroupID, + PrincipalId: somePrincipalID, + PrincipalType: schema.ServiceUserPrincipal, + RoleId: someRoleID, + }), wantErr: connect.NewError(connect.CodeInvalidArgument, membership.ErrInvalidPrincipalType), }, {