From 3a3ee2d2f7b6d9e13671db06934f71945f794aac Mon Sep 17 00:00:00 2001 From: Abhishek Sah Date: Wed, 13 May 2026 09:16:19 +0530 Subject: [PATCH] refactor(group): migrate group creation to membership package group.Create now delegates SpiceDB wiring (org<->group hierarchy + creator-as-owner) to membership.OnGroupCreated. The three helpers (addAsOrgMember, addOrgToGroup, addOwner) are removed from the group service. MembershipService is injected via setter to break the circular init order with the membership service (same pattern as organization and serviceuser). group.Create's behavior is unchanged at the SpiceDB layer; this is a pure refactor that consolidates the writes through the membership package so the rollback/compensation logic added in #1596 applies to group creation as well. Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/serve.go | 5 +- core/group/mocks/membership_service.go | 86 ++++++++++++++++++++ core/group/service.go | 106 ++++--------------------- core/group/service_test.go | 83 ++++++++----------- 4 files changed, 140 insertions(+), 140 deletions(-) create mode 100644 core/group/mocks/membership_service.go diff --git a/cmd/serve.go b/cmd/serve.go index e51ad7ae2..61dd1a2c8 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -432,10 +432,11 @@ func buildAPIDependencies( authnService, serviceUserService, groupService, roleService) membershipService := membership.NewService(logger, policyService, relationService, roleService, organizationService, userService, projectService, groupService, serviceUserService, auditRecordRepository) - // Setter injection: org → membership is circular (membership needs org for validation, - // org needs membership for Create/AdminCreate). Break the cycle with a post-init setter. + // Setter injection: org/group → membership is circular (membership needs them + // for validation; they need membership for Create). Break the cycle post-init. organizationService.SetMembershipService(membershipService) serviceUserService.SetMembershipService(membershipService) + groupService.SetMembershipService(membershipService) orgKycRepository := postgres.NewOrgKycRepository(dbc) orgKycService := kyc.NewService(orgKycRepository) diff --git a/core/group/mocks/membership_service.go b/core/group/mocks/membership_service.go new file mode 100644 index 000000000..b425239ae --- /dev/null +++ b/core/group/mocks/membership_service.go @@ -0,0 +1,86 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// MembershipService is an autogenerated mock type for the MembershipService type +type MembershipService struct { + mock.Mock +} + +type MembershipService_Expecter struct { + mock *mock.Mock +} + +func (_m *MembershipService) EXPECT() *MembershipService_Expecter { + return &MembershipService_Expecter{mock: &_m.Mock} +} + +// OnGroupCreated provides a mock function with given fields: ctx, groupID, orgID, creatorID, creatorType +func (_m *MembershipService) OnGroupCreated(ctx context.Context, groupID string, orgID string, creatorID string, creatorType string) error { + ret := _m.Called(ctx, groupID, orgID, creatorID, creatorType) + + if len(ret) == 0 { + panic("no return value specified for OnGroupCreated") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, groupID, orgID, creatorID, creatorType) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_OnGroupCreated_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OnGroupCreated' +type MembershipService_OnGroupCreated_Call struct { + *mock.Call +} + +// OnGroupCreated is a helper method to define mock.On call +// - ctx context.Context +// - groupID string +// - orgID string +// - creatorID string +// - creatorType string +func (_e *MembershipService_Expecter) OnGroupCreated(ctx interface{}, groupID interface{}, orgID interface{}, creatorID interface{}, creatorType interface{}) *MembershipService_OnGroupCreated_Call { + return &MembershipService_OnGroupCreated_Call{Call: _e.mock.On("OnGroupCreated", ctx, groupID, orgID, creatorID, creatorType)} +} + +func (_c *MembershipService_OnGroupCreated_Call) Run(run func(ctx context.Context, groupID string, orgID string, creatorID string, creatorType string)) *MembershipService_OnGroupCreated_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MembershipService_OnGroupCreated_Call) Return(_a0 error) *MembershipService_OnGroupCreated_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_OnGroupCreated_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *MembershipService_OnGroupCreated_Call { + _c.Call.Return(run) + return _c +} + +// NewMembershipService creates a new instance of MembershipService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMembershipService(t interface { + mock.TestingT + Cleanup(func()) +}) *MembershipService { + mock := &MembershipService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/group/service.go b/core/group/service.go index cad606489..54cea3876 100644 --- a/core/group/service.go +++ b/core/group/service.go @@ -38,11 +38,16 @@ type PolicyService interface { GroupMemberCount(ctx context.Context, ids []string) ([]policy.MemberCount, error) } +type MembershipService interface { + OnGroupCreated(ctx context.Context, groupID, orgID, creatorID, creatorType string) error +} + type Service struct { - repository Repository - relationService RelationService - authnService AuthnService - policyService PolicyService + repository Repository + relationService RelationService + authnService AuthnService + policyService PolicyService + membershipService MembershipService } func NewService(repository Repository, relationService RelationService, @@ -55,6 +60,12 @@ func NewService(repository Repository, relationService RelationService, } } +// SetMembershipService sets the membership dependency after construction to break +// the circular init order between group and membership services. +func (s *Service) SetMembershipService(ms MembershipService) { + s.membershipService = ms +} + func (s Service) Create(ctx context.Context, grp Group) (Group, error) { principal, err := s.authnService.GetPrincipal(ctx) if err != nil { @@ -66,17 +77,7 @@ func (s Service) Create(ctx context.Context, grp Group) (Group, error) { return Group{}, err } - // attach group to org - if err = s.addAsOrgMember(ctx, newGroup); err != nil { - return Group{}, err - } - // add relationship between group to org - if err = s.addOrgToGroup(ctx, newGroup); err != nil { - return Group{}, err - } - - // attach current user to group as owner - if err = s.addOwner(ctx, newGroup.ID, principal); err != nil { + if err = s.membershipService.OnGroupCreated(ctx, newGroup.ID, newGroup.OrganizationID, principal.ID, principal.Type); err != nil { return Group{}, err } @@ -190,36 +191,6 @@ func (s Service) AddMember(ctx context.Context, groupID string, principal authen return nil } -// addOwner adds a user as an owner of group by creating a policy of owner role and an owner relation -func (s Service) addOwner(ctx context.Context, groupID string, principal authenticate.Principal) error { - pol := policy.Policy{ - RoleID: schema.GroupOwnerRole, - ResourceID: groupID, - ResourceType: schema.GroupNamespace, - PrincipalID: principal.ID, - PrincipalType: principal.Type, - } - if _, err := s.policyService.Create(ctx, pol); err != nil { - return err - } - // then create a relation between group and user - rel := relation.Relation{ - Object: relation.Object{ - ID: groupID, - Namespace: schema.GroupNamespace, - }, - Subject: relation.Subject{ - ID: principal.ID, - Namespace: principal.Type, - }, - RelationName: schema.OwnerRelationName, - } - if _, err := s.relationService.Create(ctx, rel); err != nil { - return err - } - return nil -} - // add a policy to user as member of group func (s Service) addMemberPolicy(ctx context.Context, groupID string, principal authenticate.Principal) error { pol := policy.Policy{ @@ -235,51 +206,6 @@ func (s Service) addMemberPolicy(ctx context.Context, groupID string, principal return nil } -// addOrgToGroup creates an inverse relation that connects group to org -func (s Service) addOrgToGroup(ctx context.Context, team Group) error { - rel := relation.Relation{ - Object: relation.Object{ - ID: team.ID, - Namespace: schema.GroupNamespace, - }, - Subject: relation.Subject{ - ID: team.OrganizationID, - Namespace: schema.OrganizationNamespace, - }, - RelationName: schema.OrganizationRelationName, - } - - _, err := s.relationService.Create(ctx, rel) - if err != nil { - return err - } - - return nil -} - -// addAsOrgMember connects group as a member to org -func (s Service) addAsOrgMember(ctx context.Context, team Group) error { - rel := relation.Relation{ - Object: relation.Object{ - ID: team.OrganizationID, - Namespace: schema.OrganizationNamespace, - }, - Subject: relation.Subject{ - ID: team.ID, - Namespace: schema.GroupNamespace, - SubRelationName: schema.MemberRelationName, - }, - RelationName: schema.MemberRelationName, - } - - _, err := s.relationService.Create(ctx, rel) - if err != nil { - return err - } - - return nil -} - // ListByOrganization will be useful for nested groups but we don't do that at the moment // so it will not be directly used func (s Service) ListByOrganization(ctx context.Context, id string) ([]Group, error) { diff --git a/core/group/service_test.go b/core/group/service_test.go index 1df317482..130477dd3 100644 --- a/core/group/service_test.go +++ b/core/group/service_test.go @@ -21,21 +21,21 @@ import ( ) func TestService_Create(t *testing.T) { - t.Run("should create group successfully by adding member to org, adding relation between group and org, and making current user owner", func(t *testing.T) { + t.Run("should create group and delegate hierarchy + owner wiring to membership", func(t *testing.T) { mockRepo := mocks.NewRepository(t) mockAuthnSvc := mocks.NewAuthnService(t) mockRelationSvc := mocks.NewRelationService(t) mockPolicySvc := mocks.NewPolicyService(t) + mockMembershipSvc := mocks.NewMembershipService(t) svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc) + svc.SetMembershipService(mockMembershipSvc) - mockUserID := uuid.New() + mockUserID := uuid.New().String() mockAuthnSvc.On("GetPrincipal", mock.Anything).Return(authenticate.Principal{ - ID: mockUserID.String(), - Type: "user", - User: &user.User{ - ID: mockUserID.String(), - }, + ID: mockUserID, + Type: schema.UserPrincipal, + User: &user.User{ID: mockUserID}, }, nil) groupParam := group.Group{ @@ -43,53 +43,13 @@ func TestService_Create(t *testing.T) { Title: "Test Group", OrganizationID: uuid.New().String(), } - groupInRepo := groupParam groupInRepo.ID = uuid.New().String() mockRepo.On("Create", mock.Anything, groupParam).Return(groupInRepo, nil) - // when adding group as org member - mockRelationSvc.On("Create", mock.Anything, mock.AnythingOfType("relation.Relation")).Run(func(args mock.Arguments) { - arg := args.Get(1) - r := arg.(relation.Relation) - assert.Equal(t, r.Object.ID, groupInRepo.OrganizationID) - assert.Equal(t, r.Subject.ID, groupInRepo.ID) - assert.Equal(t, r.RelationName, schema.MemberRelationName) - }).Return(relation.Relation{}, nil).Once() - - // when adding group to org - mockRelationSvc.On("Create", mock.Anything, mock.AnythingOfType("relation.Relation")).Run(func(args mock.Arguments) { - arg := args.Get(1) - r := arg.(relation.Relation) - assert.Equal(t, r.Object.ID, groupInRepo.ID) - assert.Equal(t, r.Subject.ID, groupInRepo.OrganizationID) - assert.Equal(t, r.RelationName, schema.OrganizationRelationName) - }).Return(relation.Relation{}, nil).Once() - - // when adding current user as group owner - mockPolicySvc.On("Create", mock.Anything, mock.AnythingOfType("policy.Policy")).Run(func(args mock.Arguments) { - arg := args.Get(1) - r := arg.(policy.Policy) - assert.Equal(t, r.RoleID, schema.GroupOwnerRole) - assert.Equal(t, r.ResourceID, groupInRepo.ID) - assert.Equal(t, r.ResourceType, schema.GroupNamespace) - assert.Equal(t, r.PrincipalID, mockUserID.String()) - assert.Equal(t, r.PrincipalType, "user") - }).Return(policy.Policy{}, nil).Once() - - // adding relation between group and user - mockRelationSvc.On("Create", mock.Anything, mock.AnythingOfType("relation.Relation")).Run(func(args mock.Arguments) { - arg := args.Get(1) - r := arg.(relation.Relation) - assert.Equal(t, r.Object.ID, groupInRepo.ID) - assert.Equal(t, r.Object.Namespace, schema.GroupNamespace) - assert.Equal(t, r.Subject.ID, mockUserID.String()) - assert.Equal(t, r.Subject.Namespace, "user") - assert.Equal(t, r.RelationName, schema.OwnerRelationName) - }).Return(relation.Relation{}, nil).Once() + mockMembershipSvc.EXPECT().OnGroupCreated(mock.Anything, groupInRepo.ID, groupInRepo.OrganizationID, mockUserID, schema.UserPrincipal).Return(nil) grp, err := svc.Create(context.Background(), groupParam) - assert.Nil(t, err) assert.Equal(t, grp.Name, groupParam.Name) }) @@ -108,6 +68,33 @@ func TestService_Create(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, strings.Contains(err.Error(), authenticate.ErrInvalidID.Error()), true) }) + + t.Run("should propagate error from membership.OnGroupCreated", func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockAuthnSvc := mocks.NewAuthnService(t) + mockRelationSvc := mocks.NewRelationService(t) + mockPolicySvc := mocks.NewPolicyService(t) + mockMembershipSvc := mocks.NewMembershipService(t) + + svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc) + svc.SetMembershipService(mockMembershipSvc) + + mockUserID := uuid.New().String() + mockAuthnSvc.On("GetPrincipal", mock.Anything).Return(authenticate.Principal{ + ID: mockUserID, + Type: schema.UserPrincipal, + User: &user.User{ID: mockUserID}, + }, nil) + + groupParam := group.Group{Name: "g", OrganizationID: uuid.New().String()} + groupInRepo := groupParam + groupInRepo.ID = uuid.New().String() + mockRepo.On("Create", mock.Anything, groupParam).Return(groupInRepo, nil) + mockMembershipSvc.EXPECT().OnGroupCreated(mock.Anything, groupInRepo.ID, groupInRepo.OrganizationID, mockUserID, schema.UserPrincipal).Return(errors.New("spicedb down")) + + _, err := svc.Create(context.Background(), groupParam) + assert.ErrorContains(t, err, "spicedb down") + }) } func TestService_Get(t *testing.T) {