diff --git a/core/membership/mocks/policy_service.go b/core/membership/mocks/policy_service.go index ab81d7b31..c2e1b4dac 100644 --- a/core/membership/mocks/policy_service.go +++ b/core/membership/mocks/policy_service.go @@ -127,6 +127,54 @@ func (_c *PolicyService_Delete_Call) RunAndReturn(run func(context.Context, stri return _c } +// DeleteWithMinRoleGuard provides a mock function with given fields: ctx, id, guardRoleID +func (_m *PolicyService) DeleteWithMinRoleGuard(ctx context.Context, id string, guardRoleID string) error { + ret := _m.Called(ctx, id, guardRoleID) + + if len(ret) == 0 { + panic("no return value specified for DeleteWithMinRoleGuard") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, id, guardRoleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// PolicyService_DeleteWithMinRoleGuard_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteWithMinRoleGuard' +type PolicyService_DeleteWithMinRoleGuard_Call struct { + *mock.Call +} + +// DeleteWithMinRoleGuard is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - guardRoleID string +func (_e *PolicyService_Expecter) DeleteWithMinRoleGuard(ctx interface{}, id interface{}, guardRoleID interface{}) *PolicyService_DeleteWithMinRoleGuard_Call { + return &PolicyService_DeleteWithMinRoleGuard_Call{Call: _e.mock.On("DeleteWithMinRoleGuard", ctx, id, guardRoleID)} +} + +func (_c *PolicyService_DeleteWithMinRoleGuard_Call) Run(run func(ctx context.Context, id string, guardRoleID string)) *PolicyService_DeleteWithMinRoleGuard_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *PolicyService_DeleteWithMinRoleGuard_Call) Return(_a0 error) *PolicyService_DeleteWithMinRoleGuard_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *PolicyService_DeleteWithMinRoleGuard_Call) RunAndReturn(run func(context.Context, string, string) error) *PolicyService_DeleteWithMinRoleGuard_Call { + _c.Call.Return(run) + return _c +} + // List provides a mock function with given fields: ctx, flt func (_m *PolicyService) List(ctx context.Context, flt policy.Filter) ([]policy.Policy, error) { ret := _m.Called(ctx, flt) diff --git a/core/membership/service.go b/core/membership/service.go index bb81be5c3..25bea4edc 100644 --- a/core/membership/service.go +++ b/core/membership/service.go @@ -28,6 +28,7 @@ type PolicyService interface { Create(ctx context.Context, pol policy.Policy) (policy.Policy, error) List(ctx context.Context, flt policy.Filter) ([]policy.Policy, error) Delete(ctx context.Context, id string) error + DeleteWithMinRoleGuard(ctx context.Context, id string, guardRoleID string) error } type RelationService interface { @@ -218,11 +219,12 @@ func (s *Service) SetOrganizationMemberRole(ctx context.Context, orgID, principa return nil } - if err := s.validateMinOwnerConstraint(ctx, orgID, resolvedRoleID, existing); err != nil { + ownerRoleID, err := s.validateMinOwnerConstraint(ctx, orgID, resolvedRoleID, existing) + if err != nil { return err } - if err := s.replacePolicy(ctx, orgID, schema.OrganizationNamespace, principalID, principalType, resolvedRoleID, existing); err != nil { + if err := s.replacePolicy(ctx, orgID, schema.OrganizationNamespace, principalID, principalType, resolvedRoleID, existing, ownerRoleID); err != nil { return err } @@ -272,11 +274,31 @@ func (s *Service) RemoveOrganizationMember(ctx context.Context, orgID, principal return ErrNotMember } - if err = s.validateMinOwnerConstraint(ctx, orgID, "", orgPolicies); err != nil { + ownerRoleID, err := s.validateMinOwnerConstraint(ctx, orgID, "", orgPolicies) + if err != nil { return err } - // pre-compute org project and group ID sets for O(1) lookups + if err := s.cascadeRemovePrincipal(ctx, org, principalID, principalType, ownerRoleID); err != nil { + return err + } + + s.auditOrgMemberRemoved(ctx, org, principalID, targetAuditType) + audit.GetAuditor(ctx, org.ID).Log(audit.OrgMemberDeletedEvent, audit.Target{ + ID: principalID, + Type: principalType, + }) + + return nil +} + +// cascadeRemovePrincipal deletes all policies and SpiceDB relations for a principal +// being removed from an organization, including cascaded project/group sub-resources. +// Owner-role org policies are deleted with the atomic guard first; if the guard rejects +// (last owner), the method returns ErrLastOwnerRole before any other mutation. +func (s *Service) cascadeRemovePrincipal(ctx context.Context, org organization.Organization, principalID, principalType, ownerRoleID string) error { + orgID := org.ID + orgProjects, err := s.projectService.List(ctx, project.Filter{OrgID: orgID}) if err != nil { return fmt.Errorf("list org projects: %w", err) @@ -295,7 +317,6 @@ func (s *Service) RemoveOrganizationMember(ctx context.Context, orgID, principal orgGroupIDSet[g.ID] = struct{}{} } - // list all policies for the principal across all resources allPolicies, err := s.policyService.List(ctx, policy.Filter{ PrincipalID: principalID, PrincipalType: principalType, @@ -304,28 +325,40 @@ func (s *Service) RemoveOrganizationMember(ctx context.Context, orgID, principal return fmt.Errorf("list all principal policies: %w", err) } - // delete sub-resource policies first (projects, groups), then relations, - // then org policies last — so a retry after partial failure won't hit ErrNotMember - var orgPolicyIDs []string - var errs error + // classify policies by scope + var orgPolicies, subResourcePolicies []policy.Policy for _, pol := range allPolicies { switch pol.ResourceType { case schema.OrganizationNamespace: if pol.ResourceID == orgID { - orgPolicyIDs = append(orgPolicyIDs, pol.ID) + orgPolicies = append(orgPolicies, pol) } case schema.ProjectNamespace: if _, ok := orgProjectIDSet[pol.ResourceID]; ok { - if err := s.policyService.Delete(ctx, pol.ID); err != nil { - errs = errors.Join(errs, fmt.Errorf("delete project policy %s: %w", pol.ID, err)) - } + subResourcePolicies = append(subResourcePolicies, pol) } case schema.GroupNamespace: if _, ok := orgGroupIDSet[pol.ResourceID]; ok { - if err := s.policyService.Delete(ctx, pol.ID); err != nil { - errs = errors.Join(errs, fmt.Errorf("delete group policy %s: %w", pol.ID, err)) - } + subResourcePolicies = append(subResourcePolicies, pol) + } + } + } + + // guarded owner delete first — returns early if this is the last owner + for _, pol := range orgPolicies { + if err := s.deletePolicy(ctx, pol, ownerRoleID); err != nil { + if errors.Is(err, policy.ErrLastRoleGuard) { + return ErrLastOwnerRole } + return fmt.Errorf("delete org policy %s: %w", pol.ID, err) + } + } + + // guard passed — delete sub-resource policies + var errs error + for _, pol := range subResourcePolicies { + if err := s.policyService.Delete(ctx, pol.ID); err != nil { + errs = errors.Join(errs, fmt.Errorf("delete sub-resource policy %s: %w", pol.ID, err)) } } if errs != nil { @@ -338,7 +371,7 @@ func (s *Service) RemoveOrganizationMember(ctx context.Context, orgID, principal return errs } - // remove relations at group level + // clean up SpiceDB relations for _, g := range orgGroups { if err := s.removeRelations(ctx, g.ID, schema.GroupNamespace, principalID, principalType); err != nil { s.log.Error("partial failure removing member: group relation cleanup failed, manual cleanup may be needed", @@ -351,8 +384,6 @@ func (s *Service) RemoveOrganizationMember(ctx context.Context, orgID, principal return fmt.Errorf("remove group %s relations: %w", g.ID, err) } } - - // remove relations at org level if err := s.removeRelations(ctx, orgID, schema.OrganizationNamespace, principalID, principalType); err != nil { s.log.Error("partial failure removing member: org relation cleanup failed, manual cleanup may be needed", "org_id", orgID, @@ -363,7 +394,7 @@ func (s *Service) RemoveOrganizationMember(ctx context.Context, orgID, principal return fmt.Errorf("remove org relations: %w", err) } - // remove identity link for service users (serviceuser#org@organization) + // remove identity link for service users if principalType == schema.ServiceUserPrincipal { err := s.relationService.Delete(ctx, relation.Relation{ Object: relation.Object{ID: principalID, Namespace: schema.ServiceUserPrincipal}, @@ -375,26 +406,6 @@ func (s *Service) RemoveOrganizationMember(ctx context.Context, orgID, principal } } - // delete org-level policies last - for _, policyID := range orgPolicyIDs { - if err := s.policyService.Delete(ctx, policyID); err != nil { - s.log.Error("partial failure removing member: org policy deletion failed, manual cleanup may be needed", - "org_id", orgID, - "policy_id", policyID, - "principal_id", principalID, - "principal_type", principalType, - "error", err, - ) - return fmt.Errorf("delete org policy %s: %w", policyID, err) - } - } - - s.auditOrgMemberRemoved(ctx, org, principalID, targetAuditType) - audit.GetAuditor(ctx, org.ID).Log(audit.OrgMemberDeletedEvent, audit.Target{ - ID: principalID, - Type: principalType, - }) - return nil } @@ -412,15 +423,16 @@ func (s *Service) removeRelations(ctx context.Context, resourceID, resourceType, } // validateMinOwnerConstraint ensures the org always has at least one owner after a role change. -func (s *Service) validateMinOwnerConstraint(ctx context.Context, orgID, newRoleID string, existing []policy.Policy) error { +// Returns the resolved owner role ID for reuse by callers. +func (s *Service) validateMinOwnerConstraint(ctx context.Context, orgID, newRoleID string, existing []policy.Policy) (string, error) { ownerRole, err := s.roleService.Get(ctx, schema.RoleOrganizationOwner) if err != nil { - return fmt.Errorf("get owner role: %w", err) + return "", fmt.Errorf("get owner role: %w", err) } // no constraint if promoting to owner if newRoleID == ownerRole.ID { - return nil + return ownerRole.ID, nil } // no constraint if user is not currently an owner @@ -432,7 +444,7 @@ func (s *Service) validateMinOwnerConstraint(ctx context.Context, orgID, newRole } } if !isCurrentlyOwner { - return nil + return ownerRole.ID, nil } // user is owner, being demoted — make sure at least one other owner remains @@ -441,19 +453,23 @@ func (s *Service) validateMinOwnerConstraint(ctx context.Context, orgID, newRole RoleID: ownerRole.ID, }) if err != nil { - return fmt.Errorf("list owner policies: %w", err) + return "", fmt.Errorf("list owner policies: %w", err) } if len(ownerPolicies) <= 1 { - return ErrLastOwnerRole + return "", ErrLastOwnerRole } - return nil + return ownerRole.ID, nil } // replacePolicy deletes the given existing policies and creates a new one with the new role. -func (s *Service) replacePolicy(ctx context.Context, resourceID, resourceType, principalID, principalType, roleID string, existing []policy.Policy) error { +// When ownerRoleID is non-empty, owner-role policies are deleted atomically via +// DeleteWithMinRoleGuard to prevent the last-owner TOCTOU race. +func (s *Service) replacePolicy(ctx context.Context, resourceID, resourceType, principalID, principalType, roleID string, existing []policy.Policy, ownerRoleID string) error { for _, p := range existing { - err := s.policyService.Delete(ctx, p.ID) - if err != nil { + if err := s.deletePolicy(ctx, p, ownerRoleID); err != nil { + if errors.Is(err, policy.ErrLastRoleGuard) { + return ErrLastOwnerRole + } return fmt.Errorf("delete policy %s: %w", p.ID, err) } } @@ -473,6 +489,13 @@ func (s *Service) replacePolicy(ctx context.Context, resourceID, resourceType, p return nil } +func (s *Service) deletePolicy(ctx context.Context, pol policy.Policy, ownerRoleID string) error { + if ownerRoleID != "" && pol.RoleID == ownerRoleID { + return s.policyService.DeleteWithMinRoleGuard(ctx, pol.ID, ownerRoleID) + } + return s.policyService.Delete(ctx, pol.ID) +} + // replaceRelation deletes the given old relations for the principal on the resource, // then creates a new relation with the given name. // Only relation.ErrNotExist is ignored on delete — any other error is returned. @@ -743,7 +766,7 @@ func (s *Service) SetProjectMemberRole(ctx context.Context, projectID, principal return nil } - if err := s.replacePolicy(ctx, projectID, schema.ProjectNamespace, principalID, principalType, resolvedRoleID, existing); err != nil { + if err := s.replacePolicy(ctx, projectID, schema.ProjectNamespace, principalID, principalType, resolvedRoleID, existing, ""); err != nil { return err } @@ -1118,7 +1141,7 @@ func (s *Service) SetGroupMemberRole(ctx context.Context, groupID, principalID, return err } - if err := s.replacePolicy(ctx, groupID, schema.GroupNamespace, principalID, principalType, resolvedRoleID, existing); err != nil { + if err := s.replacePolicy(ctx, groupID, schema.GroupNamespace, principalID, principalType, resolvedRoleID, existing, ""); err != nil { return err } diff --git a/core/membership/service_test.go b/core/membership/service_test.go index 186dd1203..175773583 100644 --- a/core/membership/service_test.go +++ b/core/membership/service_test.go @@ -405,6 +405,22 @@ func TestService_SetOrganizationMemberRole(t *testing.T) { roleID: managerRoleID, wantErr: membership.ErrLastOwnerRole, }, + { + name: "should return ErrLastOwnerRole when DB guard rejects concurrent demotion", + setup: func(policySvc *mocks.PolicyService, _ *mocks.RelationService, roleSvc *mocks.RoleService, orgSvc *mocks.OrgService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + orgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, viewerRoleID).Return(role.Role{ID: viewerRoleID, Scopes: []string{schema.OrganizationNamespace}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) + roleSvc.EXPECT().Get(ctx, schema.RoleOrganizationOwner).Return(role.Role{ID: ownerRoleID, Name: schema.RoleOrganizationOwner}, nil) + // app-level check passes (sees 2 owners) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, RoleID: ownerRoleID}).Return([]policy.Policy{{ID: "p1"}, {ID: "p2"}}, nil) + // DB-level guard rejects (concurrent request already deleted the other owner) + policySvc.EXPECT().DeleteWithMinRoleGuard(ctx, "p1", ownerRoleID).Return(policy.ErrLastRoleGuard) + }, + roleID: viewerRoleID, + wantErr: membership.ErrLastOwnerRole, + }, { name: "should succeed demoting owner to viewer with multiple owners", setup: func(policySvc *mocks.PolicyService, relSvc *mocks.RelationService, roleSvc *mocks.RoleService, orgSvc *mocks.OrgService, userSvc *mocks.UserService, auditRepo *mocks.AuditRecordRepository) { @@ -414,8 +430,8 @@ func TestService_SetOrganizationMemberRole(t *testing.T) { policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) roleSvc.EXPECT().Get(ctx, schema.RoleOrganizationOwner).Return(role.Role{ID: ownerRoleID, Name: schema.RoleOrganizationOwner}, nil) policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, RoleID: ownerRoleID}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}, {ID: "p2", RoleID: ownerRoleID}}, nil) - // replace policy - policySvc.EXPECT().Delete(ctx, "p1").Return(nil) + // replace policy with owner guard + policySvc.EXPECT().DeleteWithMinRoleGuard(ctx, "p1", ownerRoleID).Return(nil) policySvc.EXPECT().Create(ctx, policy.Policy{ RoleID: viewerRoleID, ResourceID: orgID, ResourceType: schema.OrganizationNamespace, PrincipalID: userID, PrincipalType: schema.UserPrincipal, @@ -438,7 +454,7 @@ func TestService_SetOrganizationMemberRole(t *testing.T) { policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: viewerRoleID}}, nil) // promoting to owner — min-owner constraint doesn't apply roleSvc.EXPECT().Get(ctx, schema.RoleOrganizationOwner).Return(role.Role{ID: ownerRoleID, Name: schema.RoleOrganizationOwner}, nil) - // replace policy + // existing policy is viewer (non-owner), uses plain Delete policySvc.EXPECT().Delete(ctx, "p1").Return(nil) policySvc.EXPECT().Create(ctx, policy.Policy{ RoleID: ownerRoleID, ResourceID: orgID, ResourceType: schema.OrganizationNamespace, @@ -462,7 +478,7 @@ func TestService_SetOrganizationMemberRole(t *testing.T) { policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) roleSvc.EXPECT().Get(ctx, schema.RoleOrganizationOwner).Return(role.Role{ID: ownerRoleID, Name: schema.RoleOrganizationOwner}, nil) policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, RoleID: ownerRoleID}).Return([]policy.Policy{{ID: "p1"}, {ID: "p2"}}, nil) - policySvc.EXPECT().Delete(ctx, "p1").Return(nil) + policySvc.EXPECT().DeleteWithMinRoleGuard(ctx, "p1", ownerRoleID).Return(nil) policySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{}, nil) // relation delete fails with a real error — logged, no rollback relSvc.EXPECT().Delete(ctx, orgRelation(schema.OwnerRelationName)).Return(errors.New("spicedb connection error")) @@ -478,6 +494,7 @@ func TestService_SetOrganizationMemberRole(t *testing.T) { roleSvc.EXPECT().Get(ctx, viewerRoleID).Return(role.Role{ID: viewerRoleID, Scopes: []string{schema.OrganizationNamespace}}, nil) policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: managerRoleID}}, nil) roleSvc.EXPECT().Get(ctx, schema.RoleOrganizationOwner).Return(role.Role{ID: ownerRoleID, Name: schema.RoleOrganizationOwner}, nil) + // existing policy is manager (non-owner), uses plain Delete policySvc.EXPECT().Delete(ctx, "p1").Return(nil) policySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{}, nil) // both relation deletes return not-found — that's fine, should continue @@ -546,7 +563,7 @@ func TestService_SetOrganizationMemberRole_ServiceUser(t *testing.T) { mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: suID, PrincipalType: schema.ServiceUserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) mockRoleSvc.EXPECT().Get(ctx, schema.RoleOrganizationOwner).Return(role.Role{ID: ownerRoleID, Name: schema.RoleOrganizationOwner}, nil) mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, RoleID: ownerRoleID}).Return([]policy.Policy{{ID: "p1"}, {ID: "p2"}}, nil) - mockPolicySvc.EXPECT().Delete(ctx, "p1").Return(nil) + mockPolicySvc.EXPECT().DeleteWithMinRoleGuard(ctx, "p1", ownerRoleID).Return(nil) mockPolicySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{}, nil) mockRelSvc.EXPECT().Delete(ctx, mock.Anything).Return(relation.ErrNotExist).Times(2) mockRelSvc.EXPECT().Create(ctx, mock.Anything).Return(relation.Relation{}, nil) @@ -752,10 +769,10 @@ func TestService_RemoveOrganizationMember(t *testing.T) { }, nil) d.policySvc.EXPECT().Delete(ctx, "proj-p1").Return(errors.New("delete failed")) }, - wantErrContain: "delete project policy", + wantErrContain: "delete sub-resource policy", }, { - name: "should return error if org relation removal fails without deleting org policies", + name: "should return error if org relation removal fails after org policies deleted", setup: func(d testDeps) { d.orgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) d.policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "org-p1", RoleID: viewerRoleID}}, nil) @@ -763,9 +780,11 @@ func TestService_RemoveOrganizationMember(t *testing.T) { d.projSvc.EXPECT().List(ctx, project.Filter{OrgID: orgID}).Return([]project.Project{}, nil) d.grpSvc.EXPECT().List(ctx, group.Filter{OrganizationID: orgID}).Return([]group.Group{}, nil) d.policySvc.EXPECT().List(ctx, policy.Filter{PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{ - {ID: "org-p1", ResourceType: schema.OrganizationNamespace, ResourceID: orgID}, + {ID: "org-p1", ResourceType: schema.OrganizationNamespace, ResourceID: orgID, RoleID: viewerRoleID}, }, nil) - // org policy Delete should NOT be called — relations fail first, org policies are last + // org policy deleted first (viewer, plain Delete) + d.policySvc.EXPECT().Delete(ctx, "org-p1").Return(nil) + // then relation removal fails d.relSvc.EXPECT().Delete(ctx, relation.Relation{Object: orgObj, Subject: userSub, RelationName: schema.OwnerRelationName}).Return(errors.New("spicedb down")) }, wantErrContain: "remove org relations", diff --git a/core/policy/errors.go b/core/policy/errors.go index e33649d08..236305b23 100644 --- a/core/policy/errors.go +++ b/core/policy/errors.go @@ -8,4 +8,5 @@ var ( ErrInvalidID = errors.New("policy id is invalid") ErrConflict = errors.New("policy already exist") ErrInvalidDetail = errors.New("invalid policy detail") + ErrLastRoleGuard = errors.New("cannot delete: this is the last policy with the guarded role for this resource") ) diff --git a/core/policy/mocks/repository.go b/core/policy/mocks/repository.go index fdd9172ba..7486bbebe 100644 --- a/core/policy/mocks/repository.go +++ b/core/policy/mocks/repository.go @@ -126,6 +126,54 @@ func (_c *Repository_Delete_Call) RunAndReturn(run func(context.Context, string) return _c } +// DeleteWithMinRoleGuard provides a mock function with given fields: ctx, id, guardRoleID +func (_m *Repository) DeleteWithMinRoleGuard(ctx context.Context, id string, guardRoleID string) error { + ret := _m.Called(ctx, id, guardRoleID) + + if len(ret) == 0 { + panic("no return value specified for DeleteWithMinRoleGuard") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, id, guardRoleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Repository_DeleteWithMinRoleGuard_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteWithMinRoleGuard' +type Repository_DeleteWithMinRoleGuard_Call struct { + *mock.Call +} + +// DeleteWithMinRoleGuard is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - guardRoleID string +func (_e *Repository_Expecter) DeleteWithMinRoleGuard(ctx interface{}, id interface{}, guardRoleID interface{}) *Repository_DeleteWithMinRoleGuard_Call { + return &Repository_DeleteWithMinRoleGuard_Call{Call: _e.mock.On("DeleteWithMinRoleGuard", ctx, id, guardRoleID)} +} + +func (_c *Repository_DeleteWithMinRoleGuard_Call) Run(run func(ctx context.Context, id string, guardRoleID string)) *Repository_DeleteWithMinRoleGuard_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *Repository_DeleteWithMinRoleGuard_Call) Return(_a0 error) *Repository_DeleteWithMinRoleGuard_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Repository_DeleteWithMinRoleGuard_Call) RunAndReturn(run func(context.Context, string, string) error) *Repository_DeleteWithMinRoleGuard_Call { + _c.Call.Return(run) + return _c +} + // Get provides a mock function with given fields: ctx, id func (_m *Repository) Get(ctx context.Context, id string) (policy.Policy, error) { ret := _m.Called(ctx, id) diff --git a/core/policy/policy.go b/core/policy/policy.go index ce2aaecf6..7287d10b6 100644 --- a/core/policy/policy.go +++ b/core/policy/policy.go @@ -13,6 +13,7 @@ type Repository interface { Count(ctx context.Context, f Filter) (int64, error) Upsert(ctx context.Context, pol Policy) (Policy, error) Delete(ctx context.Context, id string) error + DeleteWithMinRoleGuard(ctx context.Context, id string, guardRoleID string) error GroupMemberCount(ctx context.Context, IDs []string) ([]MemberCount, error) ProjectMemberCount(ctx context.Context, IDs []string) ([]MemberCount, error) OrgMemberCount(ctx context.Context, ID string) (MemberCount, error) diff --git a/core/policy/service.go b/core/policy/service.go index 93075d4c6..e1a432d31 100644 --- a/core/policy/service.go +++ b/core/policy/service.go @@ -89,6 +89,18 @@ func (s Service) Delete(ctx context.Context, id string) error { return s.repository.Delete(ctx, id) } +func (s Service) DeleteWithMinRoleGuard(ctx context.Context, id string, guardRoleID string) error { + if err := s.repository.DeleteWithMinRoleGuard(ctx, id, guardRoleID); err != nil { + return err + } + return s.relationService.Delete(ctx, relation.Relation{ + Object: relation.Object{ + ID: id, + Namespace: schema.RoleBindingNamespace, + }, + }) +} + // AssignRole Note: ideally this should be in a single transaction // read more about how user defined roles work in spicedb https://authzed.com/blog/user-defined-roles func (s Service) AssignRole(ctx context.Context, pol Policy) error { diff --git a/internal/store/postgres/policy_repository.go b/internal/store/postgres/policy_repository.go index cb834a479..aba751860 100644 --- a/internal/store/postgres/policy_repository.go +++ b/internal/store/postgres/policy_repository.go @@ -363,6 +363,84 @@ func (r PolicyRepository) Delete(ctx context.Context, id string) error { return nil } +// DeleteWithMinRoleGuard atomically deletes a policy only if at least one other +// policy with the same guarded role remains for the resource. Uses SELECT FOR UPDATE +// to serialize concurrent deletions under READ COMMITTED isolation, preventing the +// TOCTOU race where two concurrent requests both pass a count check then both delete. +// Resource ID and type are derived from the existing policy, not from caller input. +func (r PolicyRepository) DeleteWithMinRoleGuard(ctx context.Context, id string, guardRoleID string) error { + existingPolicy, err := r.Get(ctx, id) + if err != nil { + return err + } + + if err := r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error { + return r.dbc.WithTimeout(ctx, TABLE_POLICIES, "DeleteWithMinRoleGuard", func(ctx context.Context) error { + query := `WITH locked AS ( + SELECT id FROM ` + TABLE_POLICIES + ` + WHERE resource_id = $2 + AND resource_type = $3 + AND role_id = $4 + ORDER BY id + FOR UPDATE + ) + DELETE FROM ` + TABLE_POLICIES + ` WHERE id = $1 AND ( + (SELECT role_id FROM ` + TABLE_POLICIES + ` WHERE id = $1) != $4 + OR (SELECT COUNT(*) FROM locked WHERE id != $1) > 0 + )` + result, err := tx.ExecContext(ctx, query, + id, + existingPolicy.ResourceID, + existingPolicy.ResourceType, + guardRoleID, + ) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + var existingID string + err := tx.QueryRowContext(ctx, + `SELECT id FROM `+TABLE_POLICIES+` WHERE id = $1`, id, + ).Scan(&existingID) + if errors.Is(err, sql.ErrNoRows) { + return sql.ErrNoRows + } + if err != nil { + return err + } + return policy.ErrLastRoleGuard + } + + policyDB := Policy{ + ID: existingPolicy.ID, + RoleID: existingPolicy.RoleID, + ResourceID: existingPolicy.ResourceID, + ResourceType: existingPolicy.ResourceType, + PrincipalID: existingPolicy.PrincipalID, + PrincipalType: existingPolicy.PrincipalType, + } + auditRecord := r.buildPolicyAuditRecord(ctx, tx, auditrecord.PolicyDeletedEvent, policyDB, time.Now(), nil) + return InsertAuditRecordInTx(ctx, tx, auditRecord) + }) + }); err != nil { + if errors.Is(err, policy.ErrLastRoleGuard) { + return err + } + err = checkPostgresError(err) + switch { + case errors.Is(err, sql.ErrNoRows): + return policy.ErrNotExist + default: + return err + } + } + return nil +} + func (r PolicyRepository) GroupMemberCount(ctx context.Context, groupIDs []string) ([]policy.MemberCount, error) { if len(groupIDs) == 0 { return nil, policy.ErrInvalidID