diff --git a/internal/cmd/import-test/with-import-common.zed b/internal/cmd/import-test/with-import-common.zed new file mode 100644 index 00000000..b002e033 --- /dev/null +++ b/internal/cmd/import-test/with-import-common.zed @@ -0,0 +1,6 @@ +definition user {} + +caveat mycaveat(day_of_week string) { + day_of_week == "friday" +} + diff --git a/internal/cmd/import-test/with-import-root.zed b/internal/cmd/import-test/with-import-root.zed new file mode 100644 index 00000000..b57c62fe --- /dev/null +++ b/internal/cmd/import-test/with-import-root.zed @@ -0,0 +1,10 @@ +use import + +import "with-import-common.zed" + +definition resource { + relation writer: user + relation reader: user with mycaveat + permission write = writer + permission view = reader + write +} \ No newline at end of file diff --git a/internal/cmd/import-test/with-import-validation-file.yaml b/internal/cmd/import-test/with-import-validation-file.yaml new file mode 100644 index 00000000..7c372fae --- /dev/null +++ b/internal/cmd/import-test/with-import-validation-file.yaml @@ -0,0 +1,5 @@ +--- +schemaFile: "./with-import-root.zed" +relationships: |- + resource:1#reader@user:1[mycaveat] + resource:2#writer@user:1 diff --git a/internal/cmd/import.go b/internal/cmd/import.go index 8ce71fb9..0e4de0db 100644 --- a/internal/cmd/import.go +++ b/internal/cmd/import.go @@ -82,7 +82,7 @@ func importCmdFunc(cmd *cobra.Command, schemaClient v1.SchemaServiceClient, rela } if cobrautil.MustGetBool(cmd, "schema") { - if err := importSchema(cmd.Context(), schemaClient, p.Schema.Schema, prefix); err != nil { + if err := importSchema(cmd.Context(), schemaClient, p.Schema.Schema, prefix, p.RootSchemaDir); err != nil { return fmt.Errorf("error importing schema: %w", err) } } @@ -98,11 +98,12 @@ func importCmdFunc(cmd *cobra.Command, schemaClient v1.SchemaServiceClient, rela return err } -func importSchema(ctx context.Context, client v1.SchemaServiceClient, schema string, definitionPrefix string) error { +func importSchema(ctx context.Context, client v1.SchemaServiceClient, schema string, definitionPrefix string, rootSchemaDir string) error { log.Info().Msg("importing schema") - // Recompile the schema with the specified prefix. - schemaText, err := rewriteSchema(ctx, schema, definitionPrefix) + // Compile with the schema's root directory so any `import` statements resolve, and + // (optionally) apply the definition prefix. + schemaText, err := rewriteSchema(ctx, schema, definitionPrefix, rootSchemaDir) if err != nil { return err } diff --git a/internal/cmd/import_test.go b/internal/cmd/import_test.go index 11c76fab..d6847022 100644 --- a/internal/cmd/import_test.go +++ b/internal/cmd/import_test.go @@ -103,6 +103,53 @@ func TestImportCmd(t *testing.T) { } } +func TestImportCmdSchemaWithImports(t *testing.T) { + require := require.New(t) + cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t, + zedtesting.StringFlag{FlagName: "schema-definition-prefix"}, + zedtesting.BoolFlag{FlagName: "schema", FlagValue: true}, + zedtesting.BoolFlag{FlagName: "relationships", FlagValue: true}, + zedtesting.IntFlag{FlagName: "batch-size", FlagValue: 100}, + zedtesting.IntFlag{FlagName: "workers", FlagValue: 1}, + ) + f := filepath.Join("import-test", "with-import-validation-file.yaml") + + ctx := t.Context() + srv := zedtesting.NewTestServer(ctx, t) + go func() { + assert.NoError(t, srv.Run(ctx)) + }() + conn, err := srv.GRPCDialContext(ctx) + require.NoError(err) + t.Cleanup(func() { + conn.Close() + }) + + c, err := zedtesting.ClientFromConn(conn)(cmd) + require.NoError(err) + + // The YAML points to a .zed file that uses `import "with-import-common.zed"`. WriteSchema + // rejects `import` statements, so this exercises that the client flattens the schema + // (via rewriteSchema + SourceFolder) before sending. + err = importCmdFunc(cmd, c, c, "", f) + require.NoError(err) + + rel := tuple.MustParse(`resource:1#view@user:1[mycaveat]`) + resp, err := c.CheckPermission(ctx, &v1.CheckPermissionRequest{ + Consistency: fullyConsistent, + Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: rel.Subject.ObjectType, ObjectId: rel.Subject.ObjectID}}, + Permission: "view", + Resource: &v1.ObjectReference{ObjectType: rel.Resource.ObjectType, ObjectId: rel.Resource.ObjectID}, + Context: &structpb.Struct{ + Fields: map[string]*structpb.Value{ + "day_of_week": structpb.NewStringValue("friday"), + }, + }, + }) + require.NoError(err) + require.Equal(v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION, resp.Permissionship) +} + func TestImportCmdRelationsOnly(t *testing.T) { cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t, zedtesting.StringFlag{FlagName: "schema-definition-prefix"}, diff --git a/internal/cmd/schema.go b/internal/cmd/schema.go index 031117fd..b3a75b54 100644 --- a/internal/cmd/schema.go +++ b/internal/cmd/schema.go @@ -235,7 +235,7 @@ func schemaCopyInner(ctx context.Context, srcClient, destClient v1.SchemaService return nil, err } - schemaText, err := rewriteSchema(ctx, readResp.SchemaText, prefix) + schemaText, err := rewriteSchema(ctx, readResp.SchemaText, prefix, "") if err != nil { return nil, err } @@ -263,12 +263,14 @@ func schemaWriteCmdImpl(cmd *cobra.Command, args []string, client v1.SchemaServi } var schemaBytes []byte + var sourceFolder string switch len(args) { case 1: schemaBytes, err = os.ReadFile(args[0]) if err != nil { return fmt.Errorf("failed to read schema file: %w", err) } + sourceFolder = filepath.Dir(args[0]) log.Trace().Str("schema", string(schemaBytes)).Str("file", args[0]).Msg("read schema from file") case 0: schemaBytes, err = io.ReadAll(os.Stdin) @@ -289,7 +291,7 @@ func schemaWriteCmdImpl(cmd *cobra.Command, args []string, client v1.SchemaServi return err } - schemaText, err := rewriteSchema(cmd.Context(), string(schemaBytes), prefix) + schemaText, err := rewriteSchema(cmd.Context(), string(schemaBytes), prefix, sourceFolder) if err != nil { return err } @@ -316,8 +318,8 @@ func schemaWriteCmdImpl(cmd *cobra.Command, args []string, client v1.SchemaServi } // rewriteSchema rewrites the given existing schema to include the specified prefix on all definitions and caveats. -func rewriteSchema(ctx context.Context, existingSchemaText string, definitionPrefix string) (string, error) { - if definitionPrefix == "" { +func rewriteSchema(ctx context.Context, existingSchemaText string, definitionPrefix string, sourceFolder string) (string, error) { + if definitionPrefix == "" && sourceFolder == "" { return existingSchemaText, nil } @@ -325,6 +327,7 @@ func rewriteSchema(ctx context.Context, existingSchemaText string, definitionPre compiler.InputSchema{Source: input.Source("schema"), SchemaString: existingSchemaText}, compiler.ObjectTypePrefix(definitionPrefix), compiler.SkipValidation(), + compiler.SourceFolder(sourceFolder), ) if err != nil { return "", err diff --git a/internal/cmd/schema_test.go b/internal/cmd/schema_test.go index ac83ce07..bb4ea34e 100644 --- a/internal/cmd/schema_test.go +++ b/internal/cmd/schema_test.go @@ -123,7 +123,7 @@ caveat test/some_caveat(someCondition int) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - found, err := rewriteSchema(t.Context(), test.existingSchema, test.definitionPrefix) + found, err := rewriteSchema(t.Context(), test.existingSchema, test.definitionPrefix, "") require.NoError(t, err) require.Equal(t, test.expectedSchema, found) }) @@ -375,9 +375,10 @@ func TestSchemaWrite(t *testing.T) { }, nil }, expectSchemaWritten: `definition user {} + definition resource { - relation view: user - permission viewer = view + relation view: user + permission viewer = view }`, terminalChecker: &mockTermChecker{returnVal: false}, },