diff --git a/core/audit/audit.go b/core/audit/audit.go index c0ac6b898..8b5a6018e 100644 --- a/core/audit/audit.go +++ b/core/audit/audit.go @@ -55,10 +55,12 @@ const ( ServiceUserCreatedEvent EventName = "app.serviceuser.created" ServiceUserDeletedEvent EventName = "app.serviceuser.deleted" - GroupCreatedEvent EventName = "app.group.created" - GroupUpdatedEvent EventName = "app.group.updated" - GroupDeletedEvent EventName = "app.group.deleted" - GroupMemberRemovedEvent EventName = "app.group.members.removed" + GroupCreatedEvent EventName = "app.group.created" + GroupUpdatedEvent EventName = "app.group.updated" + GroupDeletedEvent EventName = "app.group.deleted" + GroupMemberCreatedEvent EventName = "app.group.member.created" + GroupMemberRoleChangedEvent EventName = "app.group.member.role_changed" + GroupMemberRemovedEvent EventName = "app.group.members.removed" RoleCreatedEvent EventName = "app.role.created" RoleUpdatedEvent EventName = "app.role.updated" diff --git a/core/membership/errors.go b/core/membership/errors.go index 49666425e..ce835ae6c 100644 --- a/core/membership/errors.go +++ b/core/membership/errors.go @@ -13,4 +13,6 @@ var ( ErrNotOrgMember = errors.New("principal is not a member of the organization") ErrInvalidProjectRole = errors.New("role is not valid for project scope") ErrInvalidResourceType = errors.New("unsupported resource type") + ErrInvalidGroupRole = errors.New("role is not valid for group scope") + ErrLastGroupOwnerRole = errors.New("cannot change role: this is the last owner of the group") ) diff --git a/core/membership/service.go b/core/membership/service.go index fab850fbe..bb81be5c3 100644 --- a/core/membership/service.go +++ b/core/membership/service.go @@ -1019,3 +1019,362 @@ func (s *Service) ListPrincipalsByResource(ctx context.Context, resourceID, reso return members, nil } + +// AddGroupMember adds a principal as a member of a group with an explicit role. +// Returns ErrAlreadyMember if the principal already has a policy on this group. +// The principal must be a member of the group's parent organization. +func (s *Service) AddGroupMember(ctx context.Context, groupID, principalID, principalType, roleID string) error { + grp, err := s.groupService.Get(ctx, groupID) + if err != nil { + return err + } + + principal, err := s.validateGroupPrincipal(ctx, principalID, principalType) + if err != nil { + return err + } + + fetchedRole, err := s.validateGroupRole(ctx, roleID, grp.OrganizationID) + if err != nil { + return err + } + + if err := s.validateOrgMembership(ctx, grp.OrganizationID, principalID, principalType); err != nil { + return err + } + + existing, err := s.policyService.List(ctx, policy.Filter{ + GroupID: groupID, + PrincipalID: principalID, + PrincipalType: principalType, + }) + if err != nil { + return fmt.Errorf("list existing policies: %w", err) + } + if len(existing) > 0 { + return ErrAlreadyMember + } + + createdPolicy, err := s.createPolicy(ctx, groupID, schema.GroupNamespace, principalID, principalType, fetchedRole.ID) + if err != nil { + return err + } + + relationName := groupRoleToRelation(fetchedRole) + if err := s.createRelation(ctx, groupID, schema.GroupNamespace, principalID, principalType, relationName); err != nil { + if deleteErr := s.policyService.Delete(ctx, createdPolicy.ID); deleteErr != nil { + s.log.WarnContext(ctx, "orphaned policy: relation creation failed and policy cleanup also failed", + "policy_id", createdPolicy.ID, + "group_id", groupID, + "principal_id", principalID, + "policy_delete_error", deleteErr, + ) + } + return err + } + + s.auditGroupMemberAdded(ctx, grp, principal, fetchedRole.ID) + return nil +} + +// SetGroupMemberRole changes an existing member's role in a group. +// Returns ErrNotMember if the principal has no existing policy on the group. +// Enforces the min-owner constraint: demoting the last owner returns ErrLastGroupOwnerRole. +func (s *Service) SetGroupMemberRole(ctx context.Context, groupID, principalID, principalType, roleID string) error { + grp, err := s.groupService.Get(ctx, groupID) + if err != nil { + return err + } + + principal, err := s.validateGroupPrincipal(ctx, principalID, principalType) + if err != nil { + return err + } + + fetchedRole, err := s.validateGroupRole(ctx, roleID, grp.OrganizationID) + if err != nil { + return err + } + resolvedRoleID := fetchedRole.ID + + existing, err := s.policyService.List(ctx, policy.Filter{ + GroupID: groupID, + PrincipalID: principalID, + PrincipalType: principalType, + }) + if err != nil { + return fmt.Errorf("list existing policies: %w", err) + } + if len(existing) == 0 { + return ErrNotMember + } + + // skip if the user already has exactly this role + if len(existing) == 1 && existing[0].RoleID == resolvedRoleID { + return nil + } + + if err := s.validateMinGroupOwnerConstraint(ctx, groupID, resolvedRoleID, existing); err != nil { + return err + } + + if err := s.replacePolicy(ctx, groupID, schema.GroupNamespace, principalID, principalType, resolvedRoleID, existing); err != nil { + return err + } + + newRelation := groupRoleToRelation(fetchedRole) + oldRelations := []string{schema.OwnerRelationName, schema.MemberRelationName} + if err := s.replaceRelation(ctx, groupID, schema.GroupNamespace, principalID, principalType, oldRelations, newRelation); err != nil { + s.log.ErrorContext(ctx, "membership state inconsistent: policy replaced but group relation update failed, needs manual fix", + "group_id", groupID, + "principal_id", principalID, + "principal_type", principalType, + "new_role_id", resolvedRoleID, + "expected_relation", newRelation, + "error", err, + ) + return err + } + + s.auditGroupMemberRoleChanged(ctx, grp, principal, resolvedRoleID) + return nil +} + +// OnGroupCreated wires up SpiceDB relations for a newly-created group: +// links the group to its parent organization (both directions) and adds the +// creator as owner via AddGroupMember. If the owner add fails, hierarchy +// relations are best-effort rolled back to avoid an unowned, half-linked group. +func (s *Service) OnGroupCreated(ctx context.Context, groupID, orgID, creatorID, creatorType string) error { + if err := s.linkGroupToOrg(ctx, groupID, orgID); err != nil { + return err + } + if err := s.AddGroupMember(ctx, groupID, creatorID, creatorType, schema.GroupOwnerRole); err != nil { + if cleanupErr := s.unlinkGroupFromOrg(ctx, groupID, orgID); cleanupErr != nil { + s.log.WarnContext(ctx, "group hierarchy cleanup failed after owner add failure", + "group_id", groupID, + "org_id", orgID, + "error", cleanupErr, + ) + } + return err + } + return nil +} + +// linkGroupToOrg creates the two hierarchy relations between a group and its org: +// - group#org@organization (identity link from group to org) +// - organization#member@group#member (lets org#member traverse to group members) +// +// If the second relation fails, the first is best-effort rolled back so we +// don't leave a one-way link. +func (s *Service) linkGroupToOrg(ctx context.Context, groupID, orgID string) error { + groupOrg := relation.Relation{ + Object: relation.Object{ID: groupID, Namespace: schema.GroupNamespace}, + Subject: relation.Subject{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.OrganizationRelationName, + } + if _, err := s.relationService.Create(ctx, groupOrg); err != nil { + return fmt.Errorf("link group to org: %w", err) + } + + if _, err := s.relationService.Create(ctx, relation.Relation{ + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + Subject: relation.Subject{ + ID: groupID, + Namespace: schema.GroupNamespace, + SubRelationName: schema.MemberRelationName, + }, + RelationName: schema.MemberRelationName, + }); err != nil { + if delErr := s.relationService.Delete(ctx, groupOrg); delErr != nil && !errors.Is(delErr, relation.ErrNotExist) { + s.log.WarnContext(ctx, "group->org rollback failed after org member relation failure", + "group_id", groupID, + "org_id", orgID, + "error", delErr, + ) + } + return fmt.Errorf("add group as org member: %w", err) + } + + return nil +} + +// unlinkGroupFromOrg removes both hierarchy relations between a group and its +// org. Used as best-effort cleanup when group-create wiring fails partway. +// relation.ErrNotExist is ignored; any other error is returned. +func (s *Service) unlinkGroupFromOrg(ctx context.Context, groupID, orgID string) error { + if err := s.relationService.Delete(ctx, relation.Relation{ + Object: relation.Object{ID: groupID, Namespace: schema.GroupNamespace}, + Subject: relation.Subject{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.OrganizationRelationName, + }); err != nil && !errors.Is(err, relation.ErrNotExist) { + return err + } + + if err := s.relationService.Delete(ctx, relation.Relation{ + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + Subject: relation.Subject{ + ID: groupID, + Namespace: schema.GroupNamespace, + SubRelationName: schema.MemberRelationName, + }, + RelationName: schema.MemberRelationName, + }); err != nil && !errors.Is(err, relation.ErrNotExist) { + return err + } + return nil +} + +// validateGroupRole checks that the role is valid for group scope: +// - a platform-wide role scoped to groups, or +// - a custom role created for the group's parent organization. +func (s *Service) validateGroupRole(ctx context.Context, roleID, orgID string) (role.Role, error) { + fetchedRole, err := s.roleService.Get(ctx, roleID) + if err != nil { + return role.Role{}, err + } + if !slices.Contains(fetchedRole.Scopes, schema.GroupNamespace) { + return role.Role{}, ErrInvalidGroupRole + } + if fetchedRole.OrgID == orgID { + return fetchedRole, nil + } + if utils.IsNullUUID(fetchedRole.OrgID) { + return fetchedRole, nil + } + return role.Role{}, ErrInvalidGroupRole +} + +// validateGroupPrincipal fetches and validates the principal for group operations. +// Currently only app/user is supported; the switch is structured so future principal +// types (e.g. serviceuser) can be enabled here without touching call sites. +func (s *Service) validateGroupPrincipal(ctx context.Context, principalID, principalType string) (principalInfo, error) { + switch principalType { + case schema.UserPrincipal: + usr, err := s.userService.GetByID(ctx, principalID) + if err != nil { + return principalInfo{}, err + } + if usr.State == user.Disabled { + return principalInfo{}, user.ErrDisabled + } + return principalInfo{ + ID: usr.ID, + Type: schema.UserPrincipal, + Name: usr.Title, + Email: usr.Email, + }, nil + default: + return principalInfo{}, ErrInvalidPrincipalType + } +} + +// validateMinGroupOwnerConstraint ensures the group keeps at least one owner +// after the role change. Mirrors the org-level constraint. +func (s *Service) validateMinGroupOwnerConstraint(ctx context.Context, groupID, newRoleID string, existing []policy.Policy) error { + ownerRole, err := s.roleService.Get(ctx, schema.GroupOwnerRole) + if err != nil { + return fmt.Errorf("get group owner role: %w", err) + } + + if newRoleID == ownerRole.ID { + return nil + } + + isCurrentlyOwner := false + for _, p := range existing { + if p.RoleID == ownerRole.ID { + isCurrentlyOwner = true + break + } + } + if !isCurrentlyOwner { + return nil + } + + ownerPolicies, err := s.policyService.List(ctx, policy.Filter{ + GroupID: groupID, + RoleID: ownerRole.ID, + }) + if err != nil { + return fmt.Errorf("list group owner policies: %w", err) + } + if len(ownerPolicies) <= 1 { + return ErrLastGroupOwnerRole + } + return nil +} + +// groupRoleToRelation maps a group role to the matching SpiceDB relation name. +func groupRoleToRelation(r role.Role) string { + if r.Name == schema.GroupOwnerRole { + return schema.OwnerRelationName + } + return schema.MemberRelationName +} + +func (s *Service) auditGroupMemberAdded(ctx context.Context, grp group.Group, p principalInfo, roleID string) { + targetType, _ := principalTypeToAuditType(p.Type) + meta := map[string]any{"role_id": roleID} + if p.Email != "" { + meta["email"] = p.Email + } + + s.auditRecordRepository.Create(ctx, auditrecord.AuditRecord{ + Event: pkgAuditRecord.GroupMemberAddedEvent, + Resource: auditrecord.Resource{ + ID: grp.ID, + Type: pkgAuditRecord.GroupType, + Name: grp.Title, + }, + Target: &auditrecord.Target{ + ID: p.ID, + Type: targetType, + Name: p.Name, + Metadata: meta, + }, + OrgID: grp.OrganizationID, + OccurredAt: time.Now(), + }) + + audit.GetAuditor(ctx, grp.OrganizationID).LogWithAttrs(audit.GroupMemberCreatedEvent, audit.Target{ + ID: p.ID, + Type: p.Type, + }, map[string]string{ + "role_id": roleID, + "group_id": grp.ID, + }) +} + +func (s *Service) auditGroupMemberRoleChanged(ctx context.Context, grp group.Group, p principalInfo, roleID string) { + targetType, _ := principalTypeToAuditType(p.Type) + meta := map[string]any{"role_id": roleID} + if p.Email != "" { + meta["email"] = p.Email + } + + s.auditRecordRepository.Create(ctx, auditrecord.AuditRecord{ + Event: pkgAuditRecord.GroupMemberRoleChangedEvent, + Resource: auditrecord.Resource{ + ID: grp.ID, + Type: pkgAuditRecord.GroupType, + Name: grp.Title, + }, + Target: &auditrecord.Target{ + ID: p.ID, + Type: targetType, + Name: p.Name, + Metadata: meta, + }, + OrgID: grp.OrganizationID, + OccurredAt: time.Now(), + }) + + audit.GetAuditor(ctx, grp.OrganizationID).LogWithAttrs(audit.GroupMemberRoleChangedEvent, audit.Target{ + ID: p.ID, + Type: p.Type, + }, map[string]string{ + "role_id": roleID, + "group_id": grp.ID, + }) +} diff --git a/core/membership/service_test.go b/core/membership/service_test.go index 2b5e37feb..186dd1203 100644 --- a/core/membership/service_test.go +++ b/core/membership/service_test.go @@ -1217,3 +1217,425 @@ func TestService_ListPrincipalsByResource(t *testing.T) { }) } } + +func TestService_AddGroupMember(t *testing.T) { + ctx := context.Background() + orgID := uuid.New().String() + groupID := uuid.New().String() + userID := uuid.New().String() + ownerRoleID := uuid.New().String() + memberRoleID := uuid.New().String() + + enabledUser := user.User{ID: userID, Title: "test-user", Email: "test@acme.dev", State: user.Enabled} + grp := group.Group{ID: groupID, OrganizationID: orgID, Title: "Test Group"} + + groupMemberRelation := func(name string) relation.Relation { + return relation.Relation{ + Object: relation.Object{ID: groupID, Namespace: schema.GroupNamespace}, + Subject: relation.Subject{ID: userID, Namespace: schema.UserPrincipal}, + RelationName: name, + } + } + + tests := []struct { + name string + setup func(*mocks.PolicyService, *mocks.RelationService, *mocks.RoleService, *mocks.GroupService, *mocks.UserService, *mocks.AuditRecordRepository) + principalType string + roleID string + wantErr error + wantErrContain string + }{ + { + name: "should return error if group does not exist", + setup: func(_ *mocks.PolicyService, _ *mocks.RelationService, _ *mocks.RoleService, grpSvc *mocks.GroupService, _ *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(group.Group{}, group.ErrNotExist) + }, + roleID: memberRoleID, + wantErr: group.ErrNotExist, + }, + { + name: "should return error if principal type is unsupported", + setup: func(_ *mocks.PolicyService, _ *mocks.RelationService, _ *mocks.RoleService, grpSvc *mocks.GroupService, _ *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + }, + principalType: schema.ServiceUserPrincipal, + roleID: memberRoleID, + wantErr: membership.ErrInvalidPrincipalType, + }, + { + name: "should return error if user is disabled", + setup: func(_ *mocks.PolicyService, _ *mocks.RelationService, _ *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(user.User{ID: userID, State: user.Disabled}, nil) + }, + roleID: memberRoleID, + wantErr: user.ErrDisabled, + }, + { + name: "should return error if role is not valid for group scope", + setup: func(_ *mocks.PolicyService, _ *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Scopes: []string{schema.OrganizationNamespace}}, nil) + }, + roleID: memberRoleID, + wantErr: membership.ErrInvalidGroupRole, + }, + { + name: "should return error if user is not a member of the org", + setup: func(policySvc *mocks.PolicyService, _ *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + }, + roleID: memberRoleID, + wantErr: membership.ErrNotOrgMember, + }, + { + name: "should return ErrAlreadyMember if principal has existing group policy", + setup: func(policySvc *mocks.PolicyService, _ *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "org-p1"}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "g-p1"}}, nil) + }, + roleID: memberRoleID, + wantErr: membership.ErrAlreadyMember, + }, + { + name: "should succeed adding a member with member role", + setup: func(policySvc *mocks.PolicyService, relSvc *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, auditRepo *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "org-p1"}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + policySvc.EXPECT().Create(ctx, policy.Policy{ + RoleID: memberRoleID, ResourceID: groupID, ResourceType: schema.GroupNamespace, + PrincipalID: userID, PrincipalType: schema.UserPrincipal, + }).Return(policy.Policy{ID: "new-p"}, nil) + relSvc.EXPECT().Create(ctx, groupMemberRelation(schema.MemberRelationName)).Return(relation.Relation{}, nil) + auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + }, + roleID: memberRoleID, + wantErr: nil, + }, + { + name: "should succeed adding a member with owner role", + setup: func(policySvc *mocks.PolicyService, relSvc *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, auditRepo *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, ownerRoleID).Return(role.Role{ID: ownerRoleID, Name: schema.GroupOwnerRole, Scopes: []string{schema.GroupNamespace}}, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "org-p1"}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + policySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{ID: "new-p"}, nil) + relSvc.EXPECT().Create(ctx, groupMemberRelation(schema.OwnerRelationName)).Return(relation.Relation{}, nil) + auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + }, + roleID: ownerRoleID, + wantErr: nil, + }, + { + name: "should cleanup policy if relation creation fails", + setup: func(policySvc *mocks.PolicyService, relSvc *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "org-p1"}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + policySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{ID: "created-p"}, nil) + relSvc.EXPECT().Create(ctx, mock.Anything).Return(relation.Relation{}, errors.New("spicedb unavailable")) + policySvc.EXPECT().Delete(ctx, "created-p").Return(nil) + }, + roleID: memberRoleID, + wantErrContain: "spicedb unavailable", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockGrpSvc := mocks.NewGroupService(t) + mockUserSvc := mocks.NewUserService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + if tt.setup != nil { + tt.setup(mockPolicySvc, mockRelSvc, mockRoleSvc, mockGrpSvc, mockUserSvc, mockAuditRepo) + } + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mockRoleSvc, mocks.NewOrgService(t), mockUserSvc, mocks.NewProjectService(t), mockGrpSvc, mocks.NewServiceuserService(t), mockAuditRepo) + + principalType := tt.principalType + if principalType == "" { + principalType = schema.UserPrincipal + } + err := svc.AddGroupMember(ctx, groupID, userID, principalType, tt.roleID) + + if tt.wantErr != nil { + assert.ErrorIs(t, err, tt.wantErr) + } else if tt.wantErrContain != "" { + assert.ErrorContains(t, err, tt.wantErrContain) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestService_SetGroupMemberRole(t *testing.T) { + ctx := context.Background() + orgID := uuid.New().String() + groupID := uuid.New().String() + userID := uuid.New().String() + ownerRoleID := uuid.New().String() + memberRoleID := uuid.New().String() + + enabledUser := user.User{ID: userID, Title: "test-user", Email: "test@acme.dev", State: user.Enabled} + grp := group.Group{ID: groupID, OrganizationID: orgID, Title: "Test Group"} + + groupMemberRelation := func(name string) relation.Relation { + return relation.Relation{ + Object: relation.Object{ID: groupID, Namespace: schema.GroupNamespace}, + Subject: relation.Subject{ID: userID, Namespace: schema.UserPrincipal}, + RelationName: name, + } + } + + tests := []struct { + name string + setup func(*mocks.PolicyService, *mocks.RelationService, *mocks.RoleService, *mocks.GroupService, *mocks.UserService, *mocks.AuditRecordRepository) + principalType string + roleID string + wantErr error + wantErrContain string + }{ + { + name: "should return error if user is not a group member", + setup: func(policySvc *mocks.PolicyService, _ *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + }, + roleID: memberRoleID, + wantErr: membership.ErrNotMember, + }, + { + name: "should skip write when role is unchanged", + setup: func(policySvc *mocks.PolicyService, _ *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: memberRoleID}}, nil) + }, + roleID: memberRoleID, + wantErr: nil, + }, + { + name: "should return error if demoting last owner", + setup: func(policySvc *mocks.PolicyService, _ *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) + roleSvc.EXPECT().Get(ctx, schema.GroupOwnerRole).Return(role.Role{ID: ownerRoleID, Name: schema.GroupOwnerRole}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, RoleID: ownerRoleID}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) + }, + roleID: memberRoleID, + wantErr: membership.ErrLastGroupOwnerRole, + }, + { + name: "should succeed demoting owner to member with multiple owners (relation flips owner->member)", + setup: func(policySvc *mocks.PolicyService, relSvc *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, auditRepo *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, memberRoleID).Return(role.Role{ID: memberRoleID, Name: schema.GroupMemberRole, Scopes: []string{schema.GroupNamespace}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) + roleSvc.EXPECT().Get(ctx, schema.GroupOwnerRole).Return(role.Role{ID: ownerRoleID, Name: schema.GroupOwnerRole}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, RoleID: ownerRoleID}).Return([]policy.Policy{{ID: "p1"}, {ID: "p2"}}, nil) + policySvc.EXPECT().Delete(ctx, "p1").Return(nil) + policySvc.EXPECT().Create(ctx, policy.Policy{ + RoleID: memberRoleID, ResourceID: groupID, ResourceType: schema.GroupNamespace, + PrincipalID: userID, PrincipalType: schema.UserPrincipal, + }).Return(policy.Policy{ID: "new-p"}, nil) + relSvc.EXPECT().Delete(ctx, groupMemberRelation(schema.OwnerRelationName)).Return(nil) + relSvc.EXPECT().Delete(ctx, groupMemberRelation(schema.MemberRelationName)).Return(relation.ErrNotExist) + relSvc.EXPECT().Create(ctx, groupMemberRelation(schema.MemberRelationName)).Return(relation.Relation{}, nil) + auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + }, + roleID: memberRoleID, + wantErr: nil, + }, + { + name: "should succeed promoting member to owner (relation flips member->owner)", + setup: func(policySvc *mocks.PolicyService, relSvc *mocks.RelationService, roleSvc *mocks.RoleService, grpSvc *mocks.GroupService, userSvc *mocks.UserService, auditRepo *mocks.AuditRecordRepository) { + grpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, ownerRoleID).Return(role.Role{ID: ownerRoleID, Name: schema.GroupOwnerRole, Scopes: []string{schema.GroupNamespace}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: memberRoleID}}, nil) + roleSvc.EXPECT().Get(ctx, schema.GroupOwnerRole).Return(role.Role{ID: ownerRoleID, Name: schema.GroupOwnerRole}, nil) + policySvc.EXPECT().Delete(ctx, "p1").Return(nil) + policySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{ID: "new-p"}, nil) + relSvc.EXPECT().Delete(ctx, groupMemberRelation(schema.OwnerRelationName)).Return(relation.ErrNotExist) + relSvc.EXPECT().Delete(ctx, groupMemberRelation(schema.MemberRelationName)).Return(nil) + relSvc.EXPECT().Create(ctx, groupMemberRelation(schema.OwnerRelationName)).Return(relation.Relation{}, nil) + auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + }, + roleID: ownerRoleID, + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockGrpSvc := mocks.NewGroupService(t) + mockUserSvc := mocks.NewUserService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + if tt.setup != nil { + tt.setup(mockPolicySvc, mockRelSvc, mockRoleSvc, mockGrpSvc, mockUserSvc, mockAuditRepo) + } + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mockRoleSvc, mocks.NewOrgService(t), mockUserSvc, mocks.NewProjectService(t), mockGrpSvc, mocks.NewServiceuserService(t), mockAuditRepo) + + principalType := tt.principalType + if principalType == "" { + principalType = schema.UserPrincipal + } + err := svc.SetGroupMemberRole(ctx, groupID, userID, principalType, tt.roleID) + + if tt.wantErr != nil { + assert.ErrorIs(t, err, tt.wantErr) + } else if tt.wantErrContain != "" { + assert.ErrorContains(t, err, tt.wantErrContain) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestService_OnGroupCreated(t *testing.T) { + ctx := context.Background() + orgID := uuid.New().String() + groupID := uuid.New().String() + creatorID := uuid.New().String() + ownerRoleID := uuid.New().String() + + enabledUser := user.User{ID: creatorID, Title: "creator", Email: "creator@acme.dev", State: user.Enabled} + grp := group.Group{ID: groupID, OrganizationID: orgID, Title: "Test Group"} + + groupOrgRelation := relation.Relation{ + Object: relation.Object{ID: groupID, Namespace: schema.GroupNamespace}, + Subject: relation.Subject{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.OrganizationRelationName, + } + orgGroupMemberRelation := relation.Relation{ + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + Subject: relation.Subject{ + ID: groupID, + Namespace: schema.GroupNamespace, + SubRelationName: schema.MemberRelationName, + }, + RelationName: schema.MemberRelationName, + } + creatorOwnerRelation := relation.Relation{ + Object: relation.Object{ID: groupID, Namespace: schema.GroupNamespace}, + Subject: relation.Subject{ID: creatorID, Namespace: schema.UserPrincipal}, + RelationName: schema.OwnerRelationName, + } + + t.Run("should link group<->org and add creator as owner", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockGrpSvc := mocks.NewGroupService(t) + mockUserSvc := mocks.NewUserService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + mockRelSvc.EXPECT().Create(ctx, groupOrgRelation).Return(relation.Relation{}, nil) + mockRelSvc.EXPECT().Create(ctx, orgGroupMemberRelation).Return(relation.Relation{}, nil) + + mockGrpSvc.EXPECT().Get(ctx, groupID).Return(grp, nil) + mockUserSvc.EXPECT().GetByID(ctx, creatorID).Return(enabledUser, nil) + mockRoleSvc.EXPECT().Get(ctx, schema.GroupOwnerRole).Return(role.Role{ID: ownerRoleID, Name: schema.GroupOwnerRole, Scopes: []string{schema.GroupNamespace}}, nil) + mockUserSvc.EXPECT().GetByID(ctx, creatorID).Return(enabledUser, nil) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: creatorID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "org-p1"}}, nil) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{GroupID: groupID, PrincipalID: creatorID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + mockPolicySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{ID: "new-p"}, nil) + mockRelSvc.EXPECT().Create(ctx, creatorOwnerRelation).Return(relation.Relation{}, nil) + mockAuditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mockRoleSvc, mocks.NewOrgService(t), mockUserSvc, mocks.NewProjectService(t), mockGrpSvc, mocks.NewServiceuserService(t), mockAuditRepo) + + err := svc.OnGroupCreated(ctx, groupID, orgID, creatorID, schema.UserPrincipal) + assert.NoError(t, err) + }) + + t.Run("should return error if hierarchy relation creation fails", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + + mockRelSvc.EXPECT().Create(ctx, groupOrgRelation).Return(relation.Relation{}, errors.New("spicedb unavailable")) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mocks.NewRoleService(t), mocks.NewOrgService(t), mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + + err := svc.OnGroupCreated(ctx, groupID, orgID, creatorID, schema.UserPrincipal) + assert.ErrorContains(t, err, "link group to org") + }) + + t.Run("should rollback first hierarchy relation if second fails", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + + mockRelSvc.EXPECT().Create(ctx, groupOrgRelation).Return(relation.Relation{}, nil) + mockRelSvc.EXPECT().Create(ctx, orgGroupMemberRelation).Return(relation.Relation{}, errors.New("spicedb unavailable")) + // rollback: delete the first hierarchy relation + mockRelSvc.EXPECT().Delete(ctx, groupOrgRelation).Return(nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mocks.NewRoleService(t), mocks.NewOrgService(t), mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + + err := svc.OnGroupCreated(ctx, groupID, orgID, creatorID, schema.UserPrincipal) + assert.ErrorContains(t, err, "add group as org member") + }) + + t.Run("should rollback both hierarchy relations if owner add fails", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockGrpSvc := mocks.NewGroupService(t) + mockUserSvc := mocks.NewUserService(t) + + // linkGroupToOrg succeeds + mockRelSvc.EXPECT().Create(ctx, groupOrgRelation).Return(relation.Relation{}, nil) + mockRelSvc.EXPECT().Create(ctx, orgGroupMemberRelation).Return(relation.Relation{}, nil) + + // AddGroupMember fails before policy creation (group fetch fails) + mockGrpSvc.EXPECT().Get(ctx, groupID).Return(group.Group{}, errors.New("db down")) + + // unused mocks: only here for completeness, won't be called + _ = mockRoleSvc + _ = mockUserSvc + + // rollback: delete both hierarchy relations + mockRelSvc.EXPECT().Delete(ctx, groupOrgRelation).Return(nil) + mockRelSvc.EXPECT().Delete(ctx, orgGroupMemberRelation).Return(nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mockRoleSvc, mocks.NewOrgService(t), mockUserSvc, mocks.NewProjectService(t), mockGrpSvc, mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + + err := svc.OnGroupCreated(ctx, groupID, orgID, creatorID, schema.UserPrincipal) + assert.ErrorContains(t, err, "db down") + }) +} diff --git a/internal/api/v1beta1connect/errors.go b/internal/api/v1beta1connect/errors.go index 18a45c72c..42c870ea8 100644 --- a/internal/api/v1beta1connect/errors.go +++ b/internal/api/v1beta1connect/errors.go @@ -37,6 +37,9 @@ var ( ErrNotMember = errors.New("principal is not a member of the resource") ErrInvalidOrgRole = errors.New("role is not valid for organization scope") ErrInvalidProjectRole = errors.New("role is not valid for project scope") + ErrInvalidGroupRole = errors.New("role is not valid for group scope") + ErrLastGroupOwnerRole = errors.New("group must have at least one owner, consider assigning another owner before changing this user's role") + ErrNotOrgMember = errors.New("principal is not a member of the organization") ErrEmptyEmailID = errors.New("email id is empty") ErrEmailConflict = errors.New("user email can't be updated") ErrCustomerNotFound = errors.New("customer doesn't exist") diff --git a/internal/api/v1beta1connect/group.go b/internal/api/v1beta1connect/group.go index 6685eaa77..ad39a29cc 100644 --- a/internal/api/v1beta1connect/group.go +++ b/internal/api/v1beta1connect/group.go @@ -462,6 +462,60 @@ func (h *ConnectHandler) RemoveGroupUser(ctx context.Context, request *connect.R return connect.NewResponse(&frontierv1beta1.RemoveGroupUserResponse{}), nil } +func (h *ConnectHandler) SetGroupMemberRole(ctx context.Context, request *connect.Request[frontierv1beta1.SetGroupMemberRoleRequest]) (*connect.Response[frontierv1beta1.SetGroupMemberRoleResponse], error) { + errorLogger := NewErrorLogger() + + orgID := request.Msg.GetOrgId() + groupID := request.Msg.GetGroupId() + principalID := request.Msg.GetPrincipalId() + principalType := request.Msg.GetPrincipalType() + roleID := request.Msg.GetRoleId() + + if _, err := h.orgService.Get(ctx, orgID); err != nil { + switch { + case errors.Is(err, organization.ErrDisabled): + return nil, connect.NewError(connect.CodeNotFound, ErrOrgDisabled) + case errors.Is(err, organization.ErrNotExist): + return nil, connect.NewError(connect.CodeNotFound, ErrOrgNotFound) + default: + errorLogger.LogServiceError(ctx, request, "SetGroupMemberRole.GetOrg", err, + "org_id", orgID) + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + } + + if err := h.membershipService.SetGroupMemberRole(ctx, groupID, principalID, principalType, roleID); err != nil { + errorLogger.LogServiceError(ctx, request, "SetGroupMemberRole", err, + "group_id", groupID, + "principal_id", principalID, + "principal_type", principalType, + "role_id", roleID) + + switch { + case errors.Is(err, group.ErrNotExist), errors.Is(err, group.ErrInvalidID), errors.Is(err, group.ErrInvalidUUID): + return nil, connect.NewError(connect.CodeNotFound, ErrGroupNotFound) + case errors.Is(err, user.ErrNotExist): + return nil, connect.NewError(connect.CodeNotFound, ErrUserNotExist) + case errors.Is(err, user.ErrDisabled): + return nil, connect.NewError(connect.CodeFailedPrecondition, err) + case errors.Is(err, role.ErrNotExist), errors.Is(err, role.ErrInvalidID): + return nil, connect.NewError(connect.CodeNotFound, ErrInvalidRoleID) + case errors.Is(err, membership.ErrInvalidPrincipalType), errors.Is(err, membership.ErrInvalidPrincipal): + return nil, connect.NewError(connect.CodeInvalidArgument, err) + case errors.Is(err, membership.ErrInvalidGroupRole): + return nil, connect.NewError(connect.CodeInvalidArgument, ErrInvalidGroupRole) + case errors.Is(err, membership.ErrNotMember): + return nil, connect.NewError(connect.CodeFailedPrecondition, ErrNotMember) + case errors.Is(err, membership.ErrLastGroupOwnerRole): + return nil, connect.NewError(connect.CodeFailedPrecondition, ErrLastGroupOwnerRole) + default: + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + } + + return connect.NewResponse(&frontierv1beta1.SetGroupMemberRoleResponse{}), nil +} + func (h *ConnectHandler) EnableGroup(ctx context.Context, request *connect.Request[frontierv1beta1.EnableGroupRequest]) (*connect.Response[frontierv1beta1.EnableGroupResponse], error) { errorLogger := NewErrorLogger() diff --git a/internal/api/v1beta1connect/group_test.go b/internal/api/v1beta1connect/group_test.go index c17b6f00e..bd12739e1 100644 --- a/internal/api/v1beta1connect/group_test.go +++ b/internal/api/v1beta1connect/group_test.go @@ -1717,3 +1717,155 @@ func TestConnectHandler_DeleteGroup(t *testing.T) { }) } } + +func TestConnectHandler_SetGroupMemberRole(t *testing.T) { + someGroupID := utils.NewString() + somePrincipalID := utils.NewString() + someRoleID := utils.NewString() + + baseRequest := func() *connect.Request[frontierv1beta1.SetGroupMemberRoleRequest] { + return connect.NewRequest(&frontierv1beta1.SetGroupMemberRoleRequest{ + OrgId: testOrgID, + GroupId: someGroupID, + PrincipalId: somePrincipalID, + PrincipalType: schema.UserPrincipal, + RoleId: someRoleID, + }) + } + + tests := []struct { + name string + setup func(ms *mocks.MembershipService, os *mocks.OrganizationService) + request *connect.Request[frontierv1beta1.SetGroupMemberRoleRequest] + want *connect.Response[frontierv1beta1.SetGroupMemberRoleResponse] + wantErr error + }{ + { + name: "should return not found if org does not exist", + setup: func(_ *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(organization.Organization{}, organization.ErrNotExist) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeNotFound, ErrOrgNotFound), + }, + { + name: "should return not found if org is disabled", + setup: func(_ *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(organization.Organization{}, organization.ErrDisabled) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeNotFound, ErrOrgDisabled), + }, + { + name: "should return not found if group does not exist", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(group.ErrNotExist) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeNotFound, ErrGroupNotFound), + }, + { + name: "should return not found if user does not exist", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(user.ErrNotExist) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeNotFound, ErrUserNotExist), + }, + { + name: "should return not found if role does not exist", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(role.ErrNotExist) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeNotFound, ErrInvalidRoleID), + }, + { + name: "should return invalid argument if role is not valid for group scope", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(membership.ErrInvalidGroupRole) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeInvalidArgument, ErrInvalidGroupRole), + }, + { + name: "should return invalid argument if principal type is unsupported", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + // handler must forward the unsupported principal_type to the service unchanged + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.ServiceUserPrincipal, someRoleID).Return(membership.ErrInvalidPrincipalType) + }, + request: connect.NewRequest(&frontierv1beta1.SetGroupMemberRoleRequest{ + OrgId: testOrgID, + GroupId: someGroupID, + PrincipalId: somePrincipalID, + PrincipalType: schema.ServiceUserPrincipal, + RoleId: someRoleID, + }), + wantErr: connect.NewError(connect.CodeInvalidArgument, membership.ErrInvalidPrincipalType), + }, + { + name: "should return failed precondition if principal is not a group member", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(membership.ErrNotMember) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeFailedPrecondition, ErrNotMember), + }, + { + name: "should return failed precondition if demoting last group owner", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(membership.ErrLastGroupOwnerRole) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeFailedPrecondition, ErrLastGroupOwnerRole), + }, + { + name: "should return internal error for unknown errors", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(errors.New("unknown")) + }, + request: baseRequest(), + wantErr: connect.NewError(connect.CodeInternal, ErrInternalServerError), + }, + { + name: "should return success on valid request", + setup: func(ms *mocks.MembershipService, os *mocks.OrganizationService) { + os.EXPECT().Get(mock.Anything, testOrgID).Return(testOrgMap[testOrgID], nil) + ms.EXPECT().SetGroupMemberRole(mock.Anything, someGroupID, somePrincipalID, schema.UserPrincipal, someRoleID).Return(nil) + }, + request: baseRequest(), + want: connect.NewResponse(&frontierv1beta1.SetGroupMemberRoleResponse{}), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockMembershipSvc := new(mocks.MembershipService) + mockOrgSvc := new(mocks.OrganizationService) + if tt.setup != nil { + tt.setup(mockMembershipSvc, mockOrgSvc) + } + h := ConnectHandler{ + membershipService: mockMembershipSvc, + orgService: mockOrgSvc, + } + got, err := h.SetGroupMemberRole(context.Background(), tt.request) + if tt.wantErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.wantErr.(*connect.Error).Code(), err.(*connect.Error).Code()) + assert.Equal(t, tt.wantErr.(*connect.Error).Message(), err.(*connect.Error).Message()) + } else { + assert.NoError(t, err) + assert.EqualValues(t, tt.want, got) + } + }) + } +} diff --git a/internal/api/v1beta1connect/interfaces.go b/internal/api/v1beta1connect/interfaces.go index 6d1f7e24e..a37004a5c 100644 --- a/internal/api/v1beta1connect/interfaces.go +++ b/internal/api/v1beta1connect/interfaces.go @@ -406,6 +406,9 @@ type MembershipService interface { SetProjectMemberRole(ctx context.Context, projectID, principalID, principalType, roleID string) error RemoveProjectMember(ctx context.Context, projectID, principalID, principalType string) error ListPrincipalsByResource(ctx context.Context, resourceID, resourceType string, filter membership.MemberFilter) ([]membership.Member, error) + AddGroupMember(ctx context.Context, groupID, principalID, principalType, roleID string) error + SetGroupMemberRole(ctx context.Context, groupID, principalID, principalType, roleID string) error + OnGroupCreated(ctx context.Context, groupID, orgID, creatorID, creatorType string) error } type UserPATService interface { diff --git a/internal/api/v1beta1connect/mocks/membership_service.go b/internal/api/v1beta1connect/mocks/membership_service.go index cbe58da4a..ca0b3f620 100644 --- a/internal/api/v1beta1connect/mocks/membership_service.go +++ b/internal/api/v1beta1connect/mocks/membership_service.go @@ -22,6 +22,56 @@ func (_m *MembershipService) EXPECT() *MembershipService_Expecter { return &MembershipService_Expecter{mock: &_m.Mock} } +// AddGroupMember provides a mock function with given fields: ctx, groupID, principalID, principalType, roleID +func (_m *MembershipService) AddGroupMember(ctx context.Context, groupID string, principalID string, principalType string, roleID string) error { + ret := _m.Called(ctx, groupID, principalID, principalType, roleID) + + if len(ret) == 0 { + panic("no return value specified for AddGroupMember") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, groupID, principalID, principalType, roleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_AddGroupMember_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddGroupMember' +type MembershipService_AddGroupMember_Call struct { + *mock.Call +} + +// AddGroupMember is a helper method to define mock.On call +// - ctx context.Context +// - groupID string +// - principalID string +// - principalType string +// - roleID string +func (_e *MembershipService_Expecter) AddGroupMember(ctx interface{}, groupID interface{}, principalID interface{}, principalType interface{}, roleID interface{}) *MembershipService_AddGroupMember_Call { + return &MembershipService_AddGroupMember_Call{Call: _e.mock.On("AddGroupMember", ctx, groupID, principalID, principalType, roleID)} +} + +func (_c *MembershipService_AddGroupMember_Call) Run(run func(ctx context.Context, groupID string, principalID string, principalType string, roleID string)) *MembershipService_AddGroupMember_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MembershipService_AddGroupMember_Call) Return(_a0 error) *MembershipService_AddGroupMember_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_AddGroupMember_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *MembershipService_AddGroupMember_Call { + _c.Call.Return(run) + return _c +} + // AddOrganizationMember provides a mock function with given fields: ctx, orgID, principalID, principalType, roleID func (_m *MembershipService) AddOrganizationMember(ctx context.Context, orgID string, principalID string, principalType string, roleID string) error { ret := _m.Called(ctx, orgID, principalID, principalType, roleID) @@ -133,6 +183,56 @@ func (_c *MembershipService_ListPrincipalsByResource_Call) RunAndReturn(run func return _c } +// OnGroupCreated provides a mock function with given fields: ctx, groupID, orgID, creatorID, creatorType +func (_m *MembershipService) OnGroupCreated(ctx context.Context, groupID string, orgID string, creatorID string, creatorType string) error { + ret := _m.Called(ctx, groupID, orgID, creatorID, creatorType) + + if len(ret) == 0 { + panic("no return value specified for OnGroupCreated") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, groupID, orgID, creatorID, creatorType) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_OnGroupCreated_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OnGroupCreated' +type MembershipService_OnGroupCreated_Call struct { + *mock.Call +} + +// OnGroupCreated is a helper method to define mock.On call +// - ctx context.Context +// - groupID string +// - orgID string +// - creatorID string +// - creatorType string +func (_e *MembershipService_Expecter) OnGroupCreated(ctx interface{}, groupID interface{}, orgID interface{}, creatorID interface{}, creatorType interface{}) *MembershipService_OnGroupCreated_Call { + return &MembershipService_OnGroupCreated_Call{Call: _e.mock.On("OnGroupCreated", ctx, groupID, orgID, creatorID, creatorType)} +} + +func (_c *MembershipService_OnGroupCreated_Call) Run(run func(ctx context.Context, groupID string, orgID string, creatorID string, creatorType string)) *MembershipService_OnGroupCreated_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MembershipService_OnGroupCreated_Call) Return(_a0 error) *MembershipService_OnGroupCreated_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_OnGroupCreated_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *MembershipService_OnGroupCreated_Call { + _c.Call.Return(run) + return _c +} + // RemoveOrganizationMember provides a mock function with given fields: ctx, orgID, principalID, principalType func (_m *MembershipService) RemoveOrganizationMember(ctx context.Context, orgID string, principalID string, principalType string) error { ret := _m.Called(ctx, orgID, principalID, principalType) @@ -231,6 +331,56 @@ func (_c *MembershipService_RemoveProjectMember_Call) RunAndReturn(run func(cont return _c } +// SetGroupMemberRole provides a mock function with given fields: ctx, groupID, principalID, principalType, roleID +func (_m *MembershipService) SetGroupMemberRole(ctx context.Context, groupID string, principalID string, principalType string, roleID string) error { + ret := _m.Called(ctx, groupID, principalID, principalType, roleID) + + if len(ret) == 0 { + panic("no return value specified for SetGroupMemberRole") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, groupID, principalID, principalType, roleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_SetGroupMemberRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetGroupMemberRole' +type MembershipService_SetGroupMemberRole_Call struct { + *mock.Call +} + +// SetGroupMemberRole is a helper method to define mock.On call +// - ctx context.Context +// - groupID string +// - principalID string +// - principalType string +// - roleID string +func (_e *MembershipService_Expecter) SetGroupMemberRole(ctx interface{}, groupID interface{}, principalID interface{}, principalType interface{}, roleID interface{}) *MembershipService_SetGroupMemberRole_Call { + return &MembershipService_SetGroupMemberRole_Call{Call: _e.mock.On("SetGroupMemberRole", ctx, groupID, principalID, principalType, roleID)} +} + +func (_c *MembershipService_SetGroupMemberRole_Call) Run(run func(ctx context.Context, groupID string, principalID string, principalType string, roleID string)) *MembershipService_SetGroupMemberRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MembershipService_SetGroupMemberRole_Call) Return(_a0 error) *MembershipService_SetGroupMemberRole_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_SetGroupMemberRole_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *MembershipService_SetGroupMemberRole_Call { + _c.Call.Return(run) + return _c +} + // SetOrganizationMemberRole provides a mock function with given fields: ctx, orgID, principalID, principalType, roleID func (_m *MembershipService) SetOrganizationMemberRole(ctx context.Context, orgID string, principalID string, principalType string, roleID string) error { ret := _m.Called(ctx, orgID, principalID, principalType, roleID) diff --git a/pkg/auditrecord/consts.go b/pkg/auditrecord/consts.go index ca8a09bf2..4045d2eca 100644 --- a/pkg/auditrecord/consts.go +++ b/pkg/auditrecord/consts.go @@ -43,6 +43,10 @@ const ( ProjectMemberRoleChangedEvent Event = "project.member_role_changed" ProjectMemberRemovedEvent Event = "project.member_removed" + // Group Member Events + GroupMemberAddedEvent Event = "group.member_added" + GroupMemberRoleChangedEvent Event = "group.member_role_changed" + // KYC Events KYCVerifiedEvent Event = "kyc.verified" KYCUnverifiedEvent Event = "kyc.unverified" diff --git a/pkg/server/connect_interceptors/authorization.go b/pkg/server/connect_interceptors/authorization.go index 39b65e81e..9a836e027 100644 --- a/pkg/server/connect_interceptors/authorization.go +++ b/pkg/server/connect_interceptors/authorization.go @@ -474,6 +474,10 @@ var authorizationValidationMap = map[string]func(ctx context.Context, handler *v pbreq := req.(*connect.Request[frontierv1beta1.RemoveGroupUserRequest]) return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.GroupNamespace, ID: pbreq.Msg.GetId()}, schema.UpdatePermission, req) }, + "/raystack.frontier.v1beta1.FrontierService/SetGroupMemberRole": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error { + pbreq := req.(*connect.Request[frontierv1beta1.SetGroupMemberRoleRequest]) + return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.GroupNamespace, ID: pbreq.Msg.GetGroupId()}, schema.UpdatePermission, req) + }, "/raystack.frontier.v1beta1.FrontierService/EnableGroup": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error { pbreq := req.(*connect.Request[frontierv1beta1.EnableGroupRequest]) return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.GroupNamespace, ID: pbreq.Msg.GetId()}, schema.DeletePermission, req)