From c8b2f7dd04adce3bff0ac0252b9cb86c28bd31dc Mon Sep 17 00:00:00 2001 From: cooronx <2197083441@qq.com> Date: Fri, 8 May 2026 19:37:43 +0800 Subject: [PATCH] feat: add optional type_definition for user and role prefixes --- errors/rbac_errors.go | 1 + management_api.go | 87 +++++++++++++++++++++++- model/model.go | 11 ++++ model/policy.go | 11 ++++ model/type.go | 150 ++++++++++++++++++++++++++++++++++++++++++ rbac_api.go | 22 +++++++ type_system_test.go | 144 ++++++++++++++++++++++++++++++++++++++++ 7 files changed, 424 insertions(+), 2 deletions(-) create mode 100644 model/type.go create mode 100644 type_system_test.go diff --git a/errors/rbac_errors.go b/errors/rbac_errors.go index 2f358b372..29f558730 100644 --- a/errors/rbac_errors.go +++ b/errors/rbac_errors.go @@ -23,6 +23,7 @@ var ( ErrLinkNotFound = errors.New("error: link between name1 and name2 does not exist") ErrUseDomainParameter = errors.New("error: useDomain should be 1 parameter") ErrInvalidFieldValuesParameter = errors.New("fieldValues requires at least one parameter") + ErrInvalidTypeDefinition = errors.New("error: invalid type definition") // GetAllowedObjectConditions errors. ErrObjCondition = errors.New("need to meet the prefix required by the object condition") diff --git a/management_api.go b/management_api.go index d18536cc2..1fb8d9d0b 100644 --- a/management_api.go +++ b/management_api.go @@ -17,6 +17,7 @@ package casbin import ( "errors" "fmt" + "sort" "strings" "github.com/casbin/casbin/v3/constant" @@ -24,6 +25,73 @@ import ( "github.com/casbin/govaluate" ) +func (e *Enforcer) getTypedPrincipals() ([]string, []string, bool, error) { + userSet := map[string]struct{}{} + roleSet := map[string]struct{}{} + + appendByType := func(values []string) error { + for _, value := range values { + entityType, enabled, err := e.model.GetEntityType(value) + if err != nil { + return err + } + if !enabled { + return nil + } + switch entityType { + case "user": + userSet[value] = struct{}{} + case "role": + roleSet[value] = struct{}{} + } + } + return nil + } + + values, err := e.model.GetValuesForFieldInPolicyAllTypesByName("p", constant.SubjectIndex) + if err != nil { + return nil, nil, false, err + } + if err := appendByType(values); err != nil { + return nil, nil, false, err + } + _, enabled, err := e.model.GetEntityType("") + if err != nil { + return nil, nil, false, err + } + if !enabled { + return nil, nil, false, nil + } + + if _, err := e.model.GetAssertion("g", "g"); err == nil { + groupingPolicy, err := e.GetNamedGroupingPolicy("g") + if err != nil { + return nil, nil, false, err + } + for _, rule := range groupingPolicy { + limit := 2 + if len(rule) < limit { + limit = len(rule) + } + if err := appendByType(rule[:limit]); err != nil { + return nil, nil, false, err + } + } + } + + users := make([]string, 0, len(userSet)) + for user := range userSet { + users = append(users, user) + } + roles := make([]string, 0, len(roleSet)) + for role := range roleSet { + roles = append(roles, role) + } + sort.Strings(users) + sort.Strings(roles) + return users, roles, true, nil +} + // GetAllSubjects gets the list of subjects that show up in the current policy. func (e *Enforcer) GetAllSubjects() ([]string, error) { return e.model.GetValuesForFieldInPolicyAllTypesByName("p", constant.SubjectIndex) @@ -68,6 +136,13 @@ func (e *Enforcer) GetAllNamedActions(ptype string) ([]string, error) { // GetAllRoles gets the list of roles that show up in the current policy. func (e *Enforcer) GetAllRoles() ([]string, error) { + _, roles, enabled, err := e.getTypedPrincipals() + if err != nil { + return nil, err + } + if enabled { + return roles, nil + } return e.model.GetValuesForFieldInPolicyAllTypes("g", 1) } @@ -79,6 +154,14 @@ func (e *Enforcer) GetAllNamedRoles(ptype string) ([]string, error) { // GetAllUsers gets the list of users that show up in the current policy. // Users are subjects that are not roles (i.e., subjects that do not appear as the second element in any grouping policy). func (e *Enforcer) GetAllUsers() ([]string, error) { + users, _, enabled, err := e.getTypedPrincipals() + if err != nil { + return nil, err + } + if enabled { + return users, nil + } + subjects, err := e.GetAllSubjects() if err != nil { return nil, err @@ -89,8 +172,8 @@ func (e *Enforcer) GetAllUsers() ([]string, error) { return nil, err } - users := util.SetSubtract(subjects, roles) - return users, nil + result := util.SetSubtract(subjects, roles) + return result, nil } // GetPolicy gets all the authorization rules in the policy. diff --git a/model/model.go b/model/model.go index b541e1b84..7f51867cc 100644 --- a/model/model.go +++ b/model/model.go @@ -44,6 +44,7 @@ var sectionNameMap = map[string]string{ "e": "policy_effect", "m": "matchers", "c": "constraint_definition", + "t": "type_definition", } // Minimal required sections for a model to be valid. @@ -118,6 +119,12 @@ func getKeySuffix(i int) string { } func loadSection(model Model, cfg config.ConfigInterface, sec string) { + if sec == "t" { + loadAssertion(model, cfg, sec, userTypeKey) + loadAssertion(model, cfg, sec, roleTypeKey) + return + } + i := 1 for { if !loadAssertion(model, cfg, sec, sec+getKeySuffix(i)) { @@ -203,6 +210,10 @@ func (model Model) loadModelFromConfig(cfg config.ConfigInterface) error { return err } + if err := model.ValidateTypeDefinitions(); err != nil { + return err + } + return nil } diff --git a/model/policy.go b/model/policy.go index e55bf4105..ef28786c1 100644 --- a/model/policy.go +++ b/model/policy.go @@ -199,6 +199,9 @@ func (model Model) AddPolicy(sec string, ptype string, rule []string) error { if err != nil { return err } + if err := model.ValidatePolicyTypes(sec, ptype, rule); err != nil { + return err + } assertion.Policy = append(assertion.Policy, rule) assertion.PolicyMap[strings.Join(rule, DefaultSep)] = len(model[sec][ptype].Policy) - 1 @@ -282,6 +285,9 @@ func (model Model) UpdatePolicy(sec string, ptype string, oldRule []string, newR if err != nil { return false, err } + if err := model.ValidatePolicyTypes(sec, ptype, newRule); err != nil { + return false, err + } oldPolicy := strings.Join(oldRule, DefaultSep) index, ok := model[sec][ptype].PolicyMap[oldPolicy] if !ok { @@ -301,6 +307,11 @@ func (model Model) UpdatePolicies(sec string, ptype string, oldRules, newRules [ if err != nil { return false, err } + for _, newRule := range newRules { + if err := model.ValidatePolicyTypes(sec, ptype, newRule); err != nil { + return false, err + } + } rollbackFlag := false // index -> []{oldIndex, newIndex} modifiedRuleIndex := make(map[int][]int) diff --git a/model/type.go b/model/type.go new file mode 100644 index 000000000..e2140510f --- /dev/null +++ b/model/type.go @@ -0,0 +1,150 @@ +// Copyright 2026 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +import ( + "fmt" + "strings" + + "github.com/casbin/casbin/v3/constant" + Err "github.com/casbin/casbin/v3/errors" +) + +const ( + typeDefinitionSection = "t" + userTypeKey = "user" + roleTypeKey = "role" +) + +type entityType string + +const ( + entityTypeUnknown entityType = "" + entityTypeUser entityType = "user" + entityTypeRole entityType = "role" +) + +type typeDefinition struct { + userPrefix string + rolePrefix string +} + +func (model Model) getTypeDefinition() (*typeDefinition, bool, error) { + section := model[typeDefinitionSection] + if len(section) == 0 { + return nil, false, nil + } + + userAssertion, hasUser := section[userTypeKey] + roleAssertion, hasRole := section[roleTypeKey] + if !hasUser || !hasRole { + return nil, false, fmt.Errorf("%w: type_definition must define both user and role", Err.ErrInvalidTypeDefinition) + } + + userPrefix := strings.TrimSpace(userAssertion.Value) + rolePrefix := strings.TrimSpace(roleAssertion.Value) + if userPrefix == "" || rolePrefix == "" { + return nil, false, fmt.Errorf("%w: user and role prefixes cannot be empty", Err.ErrInvalidTypeDefinition) + } + if userPrefix == rolePrefix { + return nil, false, fmt.Errorf("%w: user and role prefixes must be different", Err.ErrInvalidTypeDefinition) + } + if strings.HasPrefix(userPrefix, rolePrefix) || strings.HasPrefix(rolePrefix, userPrefix) { + return nil, false, fmt.Errorf("%w: user and role prefixes must not overlap", Err.ErrInvalidTypeDefinition) + } + + return &typeDefinition{userPrefix: userPrefix, rolePrefix: rolePrefix}, true, nil +} + +func (model Model) ValidateTypeDefinitions() error { + _, _, err := model.getTypeDefinition() + return err +} + +func (model Model) GetEntityType(name string) (string, bool, error) { + def, enabled, err := model.getTypeDefinition() + if err != nil || !enabled { + return "", enabled, err + } + + switch { + case strings.HasPrefix(name, def.userPrefix): + return string(entityTypeUser), true, nil + case strings.HasPrefix(name, def.rolePrefix): + return string(entityTypeRole), true, nil + default: + return "", true, nil + } +} + +func (model Model) ValidatePolicyTypes(sec string, ptype string, rule []string) error { + def, enabled, err := model.getTypeDefinition() + if err != nil || !enabled { + return err + } + + switch sec { + case "p": + index, err := model.GetFieldIndex(ptype, constant.SubjectIndex) + if err != nil { + return err + } + if index >= len(rule) { + return nil + } + return validateEntityType(rule[index], ptype+".sub", def, entityTypeUser, entityTypeRole) + case "g": + if ptype != "g" || len(rule) < 2 { + return nil + } + if err := validateEntityType(rule[0], ptype+"[0]", def, entityTypeUser, entityTypeRole); err != nil { + return err + } + return validateEntityType(rule[1], ptype+"[1]", def, entityTypeRole) + default: + return nil + } +} + +func validateEntityType(name string, field string, def *typeDefinition, allowed ...entityType) error { + actual := getEntityType(name, def) + if actual == entityTypeUnknown { + return fmt.Errorf("type mismatch for %s: %q does not match any configured user/role prefix", field, name) + } + + for _, allowedType := range allowed { + if actual == allowedType { + return nil + } + } + + expected := make([]string, 0, len(allowed)) + for _, allowedType := range allowed { + expected = append(expected, string(allowedType)) + } + + return fmt.Errorf("type mismatch for %s: %q is %s, expected %s", field, name, actual, strings.Join(expected, " or ")) +} + +func getEntityType(name string, def *typeDefinition) entityType { + switch { + case strings.HasPrefix(name, def.userPrefix): + return entityTypeUser + case strings.HasPrefix(name, def.rolePrefix): + return entityTypeRole + default: + return entityTypeUnknown + } +} diff --git a/rbac_api.go b/rbac_api.go index c1ca4f7a3..6569fd045 100644 --- a/rbac_api.go +++ b/rbac_api.go @@ -371,6 +371,28 @@ func (e *Enforcer) GetNamedImplicitPermissionsForUser(ptype string, gtype string // GetImplicitUsersForPermission("data1", "read") will get: ["alice", "bob"]. // Note: only users will be returned, roles (2nd arg in "g") will be excluded. func (e *Enforcer) GetImplicitUsersForPermission(permission ...string) ([]string, error) { + if _, _, enabled, err := e.getTypedPrincipals(); err != nil { + return nil, err + } else if enabled { + subjects, err := e.GetAllUsers() + if err != nil { + return nil, err + } + + res := []string{} + for _, user := range subjects { + req := util.JoinSliceAny(user, permission...) + allowed, err := e.Enforce(req...) + if err != nil { + return nil, err + } + if allowed { + res = append(res, user) + } + } + return res, nil + } + pSubjects, err := e.GetAllSubjects() if err != nil { return nil, err diff --git a/type_system_test.go b/type_system_test.go new file mode 100644 index 000000000..9406e86be --- /dev/null +++ b/type_system_test.go @@ -0,0 +1,144 @@ +// Copyright 2026 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +import ( + "strings" + "testing" + + "github.com/casbin/casbin/v3/model" + "github.com/casbin/casbin/v3/persist" + stringadapter "github.com/casbin/casbin/v3/persist/string-adapter" +) + +const typedRBACModel = ` +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +[type_definition] +user = user: +role = role: + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act +` + +func TestTypedRoleListsAndEnforce(t *testing.T) { + policy := strings.TrimSpace(` +p, role:data2_admin, data2, read +p, role:data2_admin, data2, write +p, user:bob, data2, write +g, user:alice, role:data2_admin +`) + + m, err := model.NewModelFromString(typedRBACModel) + if err != nil { + t.Fatalf("load model failed: %v", err) + } + + e, err := NewEnforcer(m, stringadapter.NewAdapter(policy)) + if err != nil { + t.Fatalf("new enforcer failed: %v", err) + } + + testStringList(t, "Roles", e.GetAllRoles, []string{"role:data2_admin"}) + testStringList(t, "Users", e.GetAllUsers, []string{"user:alice", "user:bob"}) + + testEnforce(t, e, "user:alice", "data2", "read", true) + testEnforce(t, e, "user:bob", "data2", "write", true) + + users, err := e.GetImplicitUsersForPermission("data2", "write") + if err != nil { + t.Fatalf("GetImplicitUsersForPermission failed: %v", err) + } + if len(users) != 2 || users[0] != "user:alice" || users[1] != "user:bob" { + t.Fatalf("GetImplicitUsersForPermission got %v", users) + } +} + +func TestTypedRoleValidationOnPolicyMutation(t *testing.T) { + m, err := model.NewModelFromString(typedRBACModel) + if err != nil { + t.Fatalf("load model failed: %v", err) + } + + e, err := NewEnforcer(m) + if err != nil { + t.Fatalf("new enforcer failed: %v", err) + } + + if _, err := e.AddRoleForUser("user:alice", "role:admin"); err != nil { + t.Fatalf("AddRoleForUser should succeed: %v", err) + } + + if _, err := e.AddRoleForUser("user:alice", "user:bob"); err == nil || !strings.Contains(err.Error(), "expected role") { + t.Fatalf("expected role type error, got %v", err) + } + + if _, err := e.AddPermissionForUser("group:ops", "data1", "read"); err == nil || !strings.Contains(err.Error(), "does not match any configured user/role prefix") { + t.Fatalf("expected typed subject error, got %v", err) + } +} + +func TestTypedRoleValidationOnPolicyLoad(t *testing.T) { + m, err := model.NewModelFromString(typedRBACModel) + if err != nil { + t.Fatalf("load model failed: %v", err) + } + + if err := persist.LoadPolicyLine("g, user:alice, user:bob", m); err == nil || !strings.Contains(err.Error(), "expected role") { + t.Fatalf("expected invalid grouping policy error, got %v", err) + } + + if err := persist.LoadPolicyLine("p, team:ops, data1, read", m); err == nil || !strings.Contains(err.Error(), "does not match any configured user/role prefix") { + t.Fatalf("expected invalid subject policy error, got %v", err) + } +} + +func TestTypedRoleDefinitionValidation(t *testing.T) { + invalidModel := ` +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +[type_definition] +user = actor: +role = actor: + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act +` + + if _, err := model.NewModelFromString(invalidModel); err == nil || !strings.Contains(err.Error(), "prefixes must be different") { + t.Fatalf("expected invalid type definition error, got %v", err) + } +}