diff --git a/language/protobuf/BUILD.bazel b/language/protobuf/BUILD.bazel index fc88ea09..650368c7 100644 --- a/language/protobuf/BUILD.bazel +++ b/language/protobuf/BUILD.bazel @@ -17,6 +17,9 @@ go_library( "//pkg/plugin/grpc/grpcnode", "//pkg/plugin/grpc/grpcweb", "//pkg/plugin/grpcecosystem/grpcgateway", + "//pkg/plugin/neoeinstein/prost", + "//pkg/plugin/neoeinstein/prost_serde", + "//pkg/plugin/neoeinstein/tonic", "//pkg/plugin/scalapb/scalapb", "//pkg/plugin/scalapb/zio_grpc", "//pkg/plugin/stackb/grpc_js", @@ -27,6 +30,7 @@ go_library( "//pkg/rule/rules_java", "//pkg/rule/rules_nodejs", "//pkg/rule/rules_python", + "//pkg/rule/rules_rust", "//pkg/rule/rules_scala", "@bazel_gazelle//language", ], diff --git a/language/protobuf/protobuf.go b/language/protobuf/protobuf.go index 0edc5d54..00363cbd 100644 --- a/language/protobuf/protobuf.go +++ b/language/protobuf/protobuf.go @@ -5,6 +5,7 @@ import ( "github.com/stackb/rules_proto/v4/pkg/language/protobuf" + _ "github.com/stackb/rules_proto/v4/pkg/plugin/bufbuild" _ "github.com/stackb/rules_proto/v4/pkg/plugin/builtin" _ "github.com/stackb/rules_proto/v4/pkg/plugin/gogo/protobuf" _ "github.com/stackb/rules_proto/v4/pkg/plugin/golang/protobuf" @@ -14,10 +15,12 @@ import ( _ "github.com/stackb/rules_proto/v4/pkg/plugin/grpc/grpcnode" _ "github.com/stackb/rules_proto/v4/pkg/plugin/grpc/grpcweb" _ "github.com/stackb/rules_proto/v4/pkg/plugin/grpcecosystem/grpcgateway" + _ "github.com/stackb/rules_proto/v4/pkg/plugin/neoeinstein/prost" + _ "github.com/stackb/rules_proto/v4/pkg/plugin/neoeinstein/prost_serde" + _ "github.com/stackb/rules_proto/v4/pkg/plugin/neoeinstein/tonic" _ "github.com/stackb/rules_proto/v4/pkg/plugin/scalapb/scalapb" _ "github.com/stackb/rules_proto/v4/pkg/plugin/scalapb/zio_grpc" _ "github.com/stackb/rules_proto/v4/pkg/plugin/stackb/grpc_js" - _ "github.com/stackb/rules_proto/v4/pkg/plugin/bufbuild" _ "github.com/stackb/rules_proto/v4/pkg/plugin/stephenh/ts-proto" _ 
"github.com/stackb/rules_proto/v4/pkg/rule/rules_cc" _ "github.com/stackb/rules_proto/v4/pkg/rule/rules_closure" @@ -25,6 +28,7 @@ import ( _ "github.com/stackb/rules_proto/v4/pkg/rule/rules_java" _ "github.com/stackb/rules_proto/v4/pkg/rule/rules_nodejs" _ "github.com/stackb/rules_proto/v4/pkg/rule/rules_python" + _ "github.com/stackb/rules_proto/v4/pkg/rule/rules_rust" _ "github.com/stackb/rules_proto/v4/pkg/rule/rules_scala" ) diff --git a/pkg/plugin/neoeinstein/prost/BUILD.bazel b/pkg/plugin/neoeinstein/prost/BUILD.bazel new file mode 100644 index 00000000..5ece9faa --- /dev/null +++ b/pkg/plugin/neoeinstein/prost/BUILD.bazel @@ -0,0 +1,40 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "prost", + srcs = [ + "extern_paths.go", + "protoc-gen-prost.go", + ], + importpath = "github.com/stackb/rules_proto/v4/pkg/plugin/neoeinstein/prost", + visibility = ["//visibility:public"], + deps = [ + "//pkg/protoc", + "@bazel_gazelle//label", + "@bazel_gazelle//rule", + ], +) + +go_test( + name = "prost_test", + srcs = [ + "extern_paths_test.go", + "protoc-gen-prost_test.go", + ], + deps = [ + ":prost", + "//pkg/plugintest", + "//pkg/protoc", + "@bazel_gazelle//label", + "@bazel_gazelle//rule", + ], +) + +filegroup( + name = "all_files", + testonly = True, + srcs = [ + "BUILD.bazel", + ] + glob(["*.go"]), + visibility = ["//pkg:__pkg__"], +) diff --git a/pkg/plugin/neoeinstein/prost/extern_paths.go b/pkg/plugin/neoeinstein/prost/extern_paths.go new file mode 100644 index 00000000..d014575e --- /dev/null +++ b/pkg/plugin/neoeinstein/prost/extern_paths.go @@ -0,0 +1,284 @@ +package prost + +import ( + "container/list" + "path" + "sort" + "strings" + + "github.com/bazelbuild/bazel-gazelle/label" + "github.com/bazelbuild/bazel-gazelle/rule" + + "github.com/stackb/rules_proto/v4/pkg/protoc" +) + +const ( + // TransitiveExternPathsKey caches the dependency-only extern_path option + // strings on the library rule's private attrs. 
+ TransitiveExternPathsKey = "_transitive_extern_paths" + // OwnProtoPackagesKey caches the set of proto packages the library + // itself contributes, used to compute self-extern overrides for + // reference-emitting plugins (serde, tonic). + OwnProtoPackagesKey = "_own_proto_packages" +) + +// ResolveExternPathOptions filters existing extern_path= options from +// cfg.Options, resolves transitive dependency extern paths, and returns the +// combined options list. +// +// This variant is used by protoc-gen-prost. It does NOT add self-extern +// overrides for the library's own packages because prost interprets such an +// entry as "this package is external — skip generating types for it" and +// emits an empty stub. +// +// It also drops any dependency extern_path whose proto package is a strict +// prefix-parent of one of the library's own packages, for the same reason: +// prost's prefix-matching extern_path semantics treat a sub-package as +// matched and skip generation. Cross-crate references that would otherwise +// have used those filtered extern_paths emerge from prost as relative +// super::... paths; the proto_rust_library macro's generated lib.rs adds +// re-export shims to satisfy them. +func ResolveExternPathOptions(cfg *protoc.PluginConfiguration, r *rule.Rule, from label.Label) []string { + parents := ResolveTransitiveExternPaths(r, from) + owns := ownProtoPackages(r, from) + if len(owns) > 0 { + filtered := make([]string, 0, len(parents)) + for _, ep := range parents { + pkg := externPathPackage(ep) + if pkg != "" && isParentOfAnyOwn(pkg, owns) { + continue + } + filtered = append(filtered, ep) + } + parents = filtered + } + return mergeExternPathOptions(cfg, parents) +} + +// externPathPackage extracts the proto package from an "extern_path=.{pkg}=..." +// option string, or returns "" if the input doesn't match the expected +// format. +func externPathPackage(opt string) string { + const prefix = "extern_path=." 
+ if !strings.HasPrefix(opt, prefix) { + return "" + } + rest := opt[len(prefix):] + eq := strings.IndexByte(rest, '=') + if eq < 0 { + return "" + } + return rest[:eq] +} + +// isParentOfAnyOwn reports whether pkg equals, or is a strict +// proto-package-prefix parent of, any package in ownPackages. +func isParentOfAnyOwn(pkg string, ownPackages map[string]bool) bool { + for own := range ownPackages { + if own == pkg || strings.HasPrefix(own, pkg+".") { + return true + } + } + return false +} + +// ResolveExternPathOptionsForReferences returns ResolveExternPathOptions plus +// self extern_path entries for the library's own proto packages whenever any +// of those packages is a strict sub-package of an imported (parent) package. +// +// Used by protoc-gen-prost-serde and protoc-gen-tonic. Both emit Rust code at +// crate-root using absolute crate-qualified paths; without a self-extern +// override prost's longest-prefix-wins matching would route a reference like +// ".pkg.sub.MyType" through the parent's external crate instead of resolving +// it to crate::pkg::sub::MyType. +func ResolveExternPathOptionsForReferences(cfg *protoc.PluginConfiguration, r *rule.Rule, from label.Label) []string { + parents := ResolveTransitiveExternPaths(r, from) + owns := ownProtoPackages(r, from) + selves := selfExternPathsForOverride(owns, parents) + + all := make([]string, 0, len(parents)+len(selves)) + all = append(all, parents...) + all = append(all, selves...) + sort.Strings(all) + return mergeExternPathOptions(cfg, all) +} + +// ResolveTransitiveExternPaths walks the transitive dependency graph of +// proto files and builds an extern_path option string for each dependency +// package. Self-extern overrides are NOT included — see +// ResolveExternPathOptionsForReferences for the variant that adds them. 
+func ResolveTransitiveExternPaths(r *rule.Rule, from label.Label) []string { + lib := r.PrivateAttr(protoc.ProtoLibraryKey) + if lib == nil { + return nil + } + library := lib.(protoc.ProtoLibrary) + libRule := library.Rule() + + if cached, ok := libRule.PrivateAttr(TransitiveExternPathsKey).([]string); ok { + return cached + } + + resolver := protoc.GlobalResolver() + + ownFiles := make(map[string]bool) + for _, src := range library.Srcs() { + ownFiles[path.Join(from.Pkg, src)] = true + } + + seen := make(map[string]bool) + stack := list.New() + for _, src := range library.Srcs() { + stack.PushBack(path.Join(from.Pkg, src)) + } + + externPathsByPackage := make(map[string]string) + + for stack.Len() > 0 { + current := stack.Front() + stack.Remove(current) + + protofile := current.Value.(string) + if seen[protofile] { + continue + } + seen[protofile] = true + + depends := resolver.Resolve("proto", "depends", protofile) + for _, dep := range depends { + depFile := path.Join(dep.Label.Pkg, dep.Label.Name) + stack.PushBack(depFile) + } + + if ownFiles[protofile] { + continue + } + + // Skip well-known types — prost ships these built-in. + if strings.HasPrefix(protofile, "google/protobuf/") { + continue + } + + results := resolver.Resolve("proto", "prost_extern", protofile) + if len(results) == 0 { + continue + } + + first := results[0] + protoPackage := first.Label.Pkg + crateName := first.Label.Name + if protoPackage == "" { + continue + } + if _, exists := externPathsByPackage[protoPackage]; exists { + continue + } + + // extern_path=.{proto_package}=::{crate_name}::{rust_module_path} + rustModulePath := strings.ReplaceAll(protoPackage, ".", "::") + externPathsByPackage[protoPackage] = "extern_path=." 
+ protoPackage + "=::" + crateName + "::" + rustModulePath + } + + result := make([]string, 0, len(externPathsByPackage)) + for _, ep := range externPathsByPackage { + result = append(result, ep) + } + sort.Strings(result) + + libRule.SetPrivateAttr(TransitiveExternPathsKey, result) + return result +} + +// mergeExternPathOptions strips any pre-existing extern_path= entries from +// cfg.Options and returns the remainder concatenated with the supplied +// extern_path strings. +func mergeExternPathOptions(cfg *protoc.PluginConfiguration, externPaths []string) []string { + options := make([]string, 0, len(cfg.Options)+len(externPaths)) + for _, opt := range cfg.Options { + if !strings.HasPrefix(opt, "extern_path=") { + options = append(options, opt) + } + } + options = append(options, externPaths...) + return options +} + +// ownProtoPackages returns the set of proto packages the library itself +// contributes, computed from prost_extern resolver entries for each own +// proto file. Cached on the library rule. +func ownProtoPackages(r *rule.Rule, from label.Label) map[string]bool { + lib := r.PrivateAttr(protoc.ProtoLibraryKey) + if lib == nil { + return nil + } + library := lib.(protoc.ProtoLibrary) + libRule := library.Rule() + + if cached, ok := libRule.PrivateAttr(OwnProtoPackagesKey).(map[string]bool); ok { + return cached + } + + resolver := protoc.GlobalResolver() + out := make(map[string]bool) + for _, src := range library.Srcs() { + ownFile := path.Join(from.Pkg, src) + for _, ext := range resolver.Resolve("proto", "prost_extern", ownFile) { + if ext.Label.Pkg != "" { + out[ext.Label.Pkg] = true + } + } + } + + libRule.SetPrivateAttr(OwnProtoPackagesKey, out) + return out +} + +// selfExternPathsForOverride returns "extern_path=.{ownPkg}=crate::..." +// entries for every own proto package whose path is a strict sub-package of +// any package present in parents. 
parents is the slice of dependency +// extern_path option strings (as returned by ResolveTransitiveExternPaths). +func selfExternPathsForOverride(ownPackages map[string]bool, parents []string) []string { + if len(ownPackages) == 0 || len(parents) == 0 { + return nil + } + parentPkgs := parentExternPackages(parents) + out := make([]string, 0) + for ownPkg := range ownPackages { + if !hasParentInImports(ownPkg, parentPkgs) { + continue + } + rustModulePath := strings.ReplaceAll(ownPkg, ".", "::") + out = append(out, "extern_path=."+ownPkg+"=crate::"+rustModulePath) + } + return out +} + +// parentExternPackages parses a slice of "extern_path=.{pkg}=..." strings +// and returns the set of proto packages they cover. +func parentExternPackages(opts []string) map[string]bool { + out := make(map[string]bool, len(opts)) + const prefix = "extern_path=." + for _, opt := range opts { + if !strings.HasPrefix(opt, prefix) { + continue + } + rest := opt[len(prefix):] + eq := strings.IndexByte(rest, '=') + if eq < 0 { + continue + } + out[rest[:eq]] = true + } + return out +} + +// hasParentInImports reports whether any of importedPackages is a proto- +// package-prefix parent of ownPkg (e.g. "a.b" is a parent of "a.b.c"). 
+func hasParentInImports(ownPkg string, importedPackages map[string]bool) bool { + for imp := range importedPackages { + if strings.HasPrefix(ownPkg, imp+".") { + return true + } + } + return false +} diff --git a/pkg/plugin/neoeinstein/prost/extern_paths_test.go b/pkg/plugin/neoeinstein/prost/extern_paths_test.go new file mode 100644 index 00000000..fad87ee0 --- /dev/null +++ b/pkg/plugin/neoeinstein/prost/extern_paths_test.go @@ -0,0 +1,215 @@ +package prost_test + +import ( + "reflect" + "sort" + "testing" + + "github.com/bazelbuild/bazel-gazelle/label" + "github.com/bazelbuild/bazel-gazelle/rule" + + "github.com/stackb/rules_proto/v4/pkg/plugin/neoeinstein/prost" + "github.com/stackb/rules_proto/v4/pkg/protoc" +) + +// makeLibraryRule constructs a proto_library rule with the given srcs and a +// ProtoLibraryKey private attr backed by a stub ProtoLibrary so that +// ResolveTransitiveExternPaths can read it. +func makeLibraryRule(name, pkg string, srcs []string) *rule.Rule { + r := rule.NewRule("proto_library", name) + r.SetAttr("srcs", srcs) + files := make([]*protoc.File, len(srcs)) + for i, s := range srcs { + files[i] = protoc.NewFile(pkg, s) + } + lib := protoc.NewOtherProtoLibrary(nil, r, files...) + r.SetPrivateAttr(protoc.ProtoLibraryKey, lib) + return r +} + +func TestResolveTransitiveExternPaths(t *testing.T) { + resolver := protoc.GlobalResolver() + + // Register prost_extern entries for two upstream libraries. + resolver.Provide("proto", "prost_extern", + "externtest/depA/a.proto", + label.New("", "extern.dep_a", "depA_rs")) + resolver.Provide("proto", "prost_extern", + "externtest/depB/b.proto", + label.New("", "extern.dep_b", "depB_rs")) + + // Set up the depends graph: own.proto -> depA -> depB, plus a WKT skip. 
+ resolver.Provide("proto", "depends", + "externtest/own/own.proto", + label.New("", "externtest/depA", "a.proto")) + resolver.Provide("proto", "depends", + "externtest/own/own.proto", + label.New("", "google/protobuf", "duration.proto")) + resolver.Provide("proto", "depends", + "externtest/depA/a.proto", + label.New("", "externtest/depB", "b.proto")) + + r := makeLibraryRule("own_proto", "externtest/own", []string{"own.proto"}) + + from := label.New("", "externtest/own", "own_proto") + got := prost.ResolveTransitiveExternPaths(r, from) + sort.Strings(got) + + want := []string{ + "extern_path=.extern.dep_a=::depA_rs::extern::dep_a", + "extern_path=.extern.dep_b=::depB_rs::extern::dep_b", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("ResolveTransitiveExternPaths:\n got: %v\nwant: %v", got, want) + } + + // Second call should hit the cache and return the same slice. + got2 := prost.ResolveTransitiveExternPaths(r, from) + if !reflect.DeepEqual(got2, got) { + t.Errorf("cached call differs:\n got: %v\nwant: %v", got2, got) + } +} + +func TestResolveTransitiveExternPaths_OwnFilesSkipped(t *testing.T) { + resolver := protoc.GlobalResolver() + + // Register the library's own proto file as if it had been registered. + // The function must NOT include own files in the result. 
+ resolver.Provide("proto", "prost_extern", + "selftest/me/m.proto", + label.New("", "selftest.me", "me_rs")) + + r := makeLibraryRule("me_proto", "selftest/me", []string{"m.proto"}) + from := label.New("", "selftest/me", "me_proto") + + got := prost.ResolveTransitiveExternPaths(r, from) + if len(got) != 0 { + t.Errorf("expected empty extern paths for own files, got %v", got) + } +} + +// TestResolveTransitiveExternPaths_SubpackageOfImport verifies that when the +// current library's proto package is a sub-package of an imported library's +// proto package, ResolveTransitiveExternPaths emits the imported package's +// extern_path entry (this is the prost variant — no self-extern override is +// added; that's the job of ResolveExternPathOptionsForReferences). +func TestResolveTransitiveExternPaths_SubpackageOfImport(t *testing.T) { + resolver := protoc.GlobalResolver() + + resolver.Provide("proto", "prost_extern", + "subpkg/parent/p.proto", + label.New("", "subpkg.parent", "parent_rs")) + + resolver.Provide("proto", "prost_extern", + "subpkg/parent/child/c.proto", + label.New("", "subpkg.parent.child", "child_rs")) + + resolver.Provide("proto", "depends", + "subpkg/parent/child/c.proto", + label.New("", "subpkg/parent", "p.proto")) + + r := makeLibraryRule("child_proto", "subpkg/parent/child", []string{"c.proto"}) + from := label.New("", "subpkg/parent/child", "child_proto") + + got := prost.ResolveTransitiveExternPaths(r, from) + want := []string{ + "extern_path=.subpkg.parent=::parent_rs::subpkg::parent", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("ResolveTransitiveExternPaths:\n got: %v\nwant: %v", got, want) + } +} + +// TestResolveExternPathOptionsForReferences_SubpackageOfImport verifies the +// reference-emitting variant (used by prost-serde and tonic) DOES add a self +// extern_path override for the current sub-package, so prost's longest- +// prefix-wins matching routes own-package references to crate::... rather +// than the parent extern crate. 
+func TestResolveExternPathOptionsForReferences_SubpackageOfImport(t *testing.T) { + resolver := protoc.GlobalResolver() + + resolver.Provide("proto", "prost_extern", + "refs/parent/p.proto", + label.New("", "refs.parent", "parent_rs")) + + resolver.Provide("proto", "prost_extern", + "refs/parent/child/c.proto", + label.New("", "refs.parent.child", "child_rs")) + + resolver.Provide("proto", "depends", + "refs/parent/child/c.proto", + label.New("", "refs/parent", "p.proto")) + + r := makeLibraryRule("child_proto", "refs/parent/child", []string{"c.proto"}) + from := label.New("", "refs/parent/child", "child_proto") + + cfg := &protoc.PluginConfiguration{Options: nil} + got := prost.ResolveExternPathOptionsForReferences(cfg, r, from) + want := []string{ + "extern_path=.refs.parent.child=crate::refs::parent::child", + "extern_path=.refs.parent=::parent_rs::refs::parent", + } + sort.Strings(want) + sort.Strings(got) + if !reflect.DeepEqual(got, want) { + t.Errorf("ResolveExternPathOptionsForReferences:\n got: %v\nwant: %v", got, want) + } +} + +// TestResolveTransitiveExternPaths_SiblingNotFiltered ensures the filter is +// not over-aggressive: a sibling package (one that shares a common prefix but +// is neither equal to nor an ancestor of the current package) must still +// produce an extern_path entry. +func TestResolveTransitiveExternPaths_SiblingNotFiltered(t *testing.T) { + resolver := protoc.GlobalResolver() + + // Sibling package "sibling.a.x" — shares prefix "sibling.a" with our own + // "sibling.a.y" but neither is a parent of the other. + resolver.Provide("proto", "prost_extern", + "sibling/a/x/x.proto", + label.New("", "sibling.a.x", "x_rs")) + + // Own package "sibling.a.y". 
+ resolver.Provide("proto", "prost_extern", + "sibling/a/y/y.proto", + label.New("", "sibling.a.y", "y_rs")) + + resolver.Provide("proto", "depends", + "sibling/a/y/y.proto", + label.New("", "sibling/a/x", "x.proto")) + + r := makeLibraryRule("y_proto", "sibling/a/y", []string{"y.proto"}) + from := label.New("", "sibling/a/y", "y_proto") + + got := prost.ResolveTransitiveExternPaths(r, from) + want := []string{"extern_path=.sibling.a.x=::x_rs::sibling::a::x"} + if !reflect.DeepEqual(got, want) { + t.Errorf("ResolveTransitiveExternPaths:\n got: %v\nwant: %v", got, want) + } +} + +func TestResolveExternPathOptions_FiltersExisting(t *testing.T) { + // Library with no transitive deps — extern paths come only from cfg.Options + // after filtering out any pre-existing extern_path= entries. + r := makeLibraryRule("noop_proto", "exfilter/noop", []string{"n.proto"}) + from := label.New("", "exfilter/noop", "noop_proto") + + cfg := &protoc.PluginConfiguration{ + Options: []string{ + "compile_well_known_types=true", + "extern_path=.stale.pkg=::stale_rs::stale::pkg", + }, + } + + got := prost.ResolveExternPathOptions(cfg, r, from) + for _, opt := range got { + if opt == "extern_path=.stale.pkg=::stale_rs::stale::pkg" { + t.Errorf("stale extern_path option was not filtered: %v", got) + } + } + + want := []string{"compile_well_known_types=true"} + if !reflect.DeepEqual(got, want) { + t.Errorf("ResolveExternPathOptions:\n got: %v\nwant: %v", got, want) + } +} diff --git a/pkg/plugin/neoeinstein/prost/protoc-gen-prost.go b/pkg/plugin/neoeinstein/prost/protoc-gen-prost.go new file mode 100644 index 00000000..72b260cc --- /dev/null +++ b/pkg/plugin/neoeinstein/prost/protoc-gen-prost.go @@ -0,0 +1,121 @@ +package prost + +import ( + "path" + "sort" + + "github.com/bazelbuild/bazel-gazelle/label" + "github.com/bazelbuild/bazel-gazelle/rule" + + "github.com/stackb/rules_proto/v4/pkg/protoc" +) + +const ( + ProtocGenProstPluginName = "neoeinstein:prost:protoc-gen-prost" +) + +func 
init() { + protoc.Plugins().MustRegisterPlugin(&ProtocGenProstPlugin{}) +} + +// ProtocGenProstPlugin implements Plugin for protoc-gen-prost. +type ProtocGenProstPlugin struct{} + +// Name implements part of the Plugin interface. +func (p *ProtocGenProstPlugin) Name() string { + return ProtocGenProstPluginName +} + +// Configure implements part of the Plugin interface. +func (p *ProtocGenProstPlugin) Configure(ctx *protoc.PluginContext) *protoc.PluginConfiguration { + if !p.shouldApply(ctx.ProtoLibrary) { + return nil + } + + outputs := p.outputs(ctx.ProtoLibrary) + if len(outputs) == 0 { + return nil + } + + p.registerExternPaths(ctx.ProtoLibrary) + + return &protoc.PluginConfiguration{ + Label: label.New("build_stack_rules_proto", "plugin/neoeinstein/prost", "protoc-gen-prost"), + Outputs: outputs, + Options: ctx.PluginConfig.GetOptions(), + } +} + +// ResolvePluginOptions implements the PluginOptionsResolver interface. +// It computes extern_path options based on transitive proto file dependencies. +func (p *ProtocGenProstPlugin) ResolvePluginOptions(cfg *protoc.PluginConfiguration, r *rule.Rule, from label.Label) []string { + return ResolveExternPathOptions(cfg, r, from) +} + +// shouldApply returns true if the library has files with messages or enums. +func (p *ProtocGenProstPlugin) shouldApply(lib protoc.ProtoLibrary) bool { + for _, f := range lib.Files() { + if f.HasMessages() || f.HasEnums() { + return true + } + } + return false +} + +// outputs computes the output files for the plugin. Prost generates one .rs +// file per proto package, named {proto_package}.rs. The path includes the +// file's directory so that mergeSources can handle the rel stripping. 
+func (p *ProtocGenProstPlugin) outputs(lib protoc.ProtoLibrary) []string { + seen := make(map[string]bool) + outputs := make([]string, 0) + + for _, f := range lib.Files() { + if !(f.HasMessages() || f.HasEnums()) { + continue + } + pkg := f.Package() + if pkg.Name == "" { + continue + } + if seen[pkg.Name] { + continue + } + seen[pkg.Name] = true + + filename := pkg.Name + ".rs" + if f.Dir != "" { + filename = path.Join(f.Dir, filename) + } + outputs = append(outputs, filename) + } + + sort.Strings(outputs) + return outputs +} + +// registerExternPaths records prost extern_path data in the GlobalResolver for +// each proto file in the library. This data is later consumed by +// ResolveTransitiveExternPaths when computing extern_path options for dependent +// packages. +// +// The label encodes: Pkg = proto package name, Name = crate name. The crate +// name comes from protoc.RustCrateName so it matches the rust_library target +// name produced by RustLibrary.Name() — without this alignment, downstream +// extern_path entries would point at a non-existent crate and rustc would +// fail to resolve types. 
+func (p *ProtocGenProstPlugin) registerExternPaths(lib protoc.ProtoLibrary) { + for _, f := range lib.Files() { + pkg := f.Package() + if pkg.Name == "" { + continue + } + + protoFile := path.Join(f.Dir, f.Basename) + protoc.GlobalResolver().Provide( + "proto", + "prost_extern", + protoFile, + label.New("", pkg.Name, protoc.RustCrateName(pkg.Name)), + ) + } +} diff --git a/pkg/plugin/neoeinstein/prost/protoc-gen-prost_test.go b/pkg/plugin/neoeinstein/prost/protoc-gen-prost_test.go new file mode 100644 index 00000000..01a796f6 --- /dev/null +++ b/pkg/plugin/neoeinstein/prost/protoc-gen-prost_test.go @@ -0,0 +1,69 @@ +package prost_test + +import ( + "testing" + + "github.com/stackb/rules_proto/v4/pkg/plugin/neoeinstein/prost" + "github.com/stackb/rules_proto/v4/pkg/plugintest" +) + +func TestProtocGenProstPlugin(t *testing.T) { + plugintest.Cases(t, &prost.ProtocGenProstPlugin{}, map[string]plugintest.Case{ + "empty - no messages or enums": { + Input: "", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-prost implementation neoeinstein:prost:protoc-gen-prost", + ), + PluginName: "protoc-gen-prost", + Configuration: nil, + SkipIntegration: true, + }, + "simple message": { + Input: "package example.v1;\nmessage Foo {}", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-prost implementation neoeinstein:prost:protoc-gen-prost", + ), + PluginName: "protoc-gen-prost", + Configuration: plugintest.WithConfiguration( + plugintest.WithLabel(t, "@build_stack_rules_proto//plugin/neoeinstein/prost:protoc-gen-prost"), + plugintest.WithOutputs("example.v1.rs"), + ), + SkipIntegration: true, + }, + "simple enum": { + Input: "package example.v1;\nenum Color { RED = 0; }", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-prost implementation neoeinstein:prost:protoc-gen-prost", + ), + PluginName: "protoc-gen-prost", + Configuration: plugintest.WithConfiguration( + plugintest.WithLabel(t, 
"@build_stack_rules_proto//plugin/neoeinstein/prost:protoc-gen-prost"), + plugintest.WithOutputs("example.v1.rs"), + ), + SkipIntegration: true, + }, + "no package - skipped": { + Input: "message Foo {}", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-prost implementation neoeinstein:prost:protoc-gen-prost", + ), + PluginName: "protoc-gen-prost", + Configuration: nil, + SkipIntegration: true, + }, + "with options": { + Input: "package example.v1;\nmessage Foo {}", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-prost implementation neoeinstein:prost:protoc-gen-prost", + "proto_plugin", "protoc-gen-prost option type_attribute=.example.v1.Foo=#[derive(Eq)]", + ), + PluginName: "protoc-gen-prost", + Configuration: plugintest.WithConfiguration( + plugintest.WithLabel(t, "@build_stack_rules_proto//plugin/neoeinstein/prost:protoc-gen-prost"), + plugintest.WithOutputs("example.v1.rs"), + plugintest.WithOptions("type_attribute=.example.v1.Foo=#[derive(Eq)]"), + ), + SkipIntegration: true, + }, + }) +} diff --git a/pkg/plugin/neoeinstein/prost_serde/BUILD.bazel b/pkg/plugin/neoeinstein/prost_serde/BUILD.bazel new file mode 100644 index 00000000..58ddda94 --- /dev/null +++ b/pkg/plugin/neoeinstein/prost_serde/BUILD.bazel @@ -0,0 +1,32 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "prost_serde", + srcs = ["protoc-gen-prost-serde.go"], + importpath = "github.com/stackb/rules_proto/v4/pkg/plugin/neoeinstein/prost_serde", + visibility = ["//visibility:public"], + deps = [ + "//pkg/plugin/neoeinstein/prost", + "//pkg/protoc", + "@bazel_gazelle//label", + "@bazel_gazelle//rule", + ], +) + +go_test( + name = "prost_serde_test", + srcs = ["protoc-gen-prost-serde_test.go"], + deps = [ + ":prost_serde", + "//pkg/plugintest", + ], +) + +filegroup( + name = "all_files", + testonly = True, + srcs = [ + "BUILD.bazel", + ] + glob(["*.go"]), + visibility = ["//pkg:__pkg__"], +) diff --git 
a/pkg/plugin/neoeinstein/prost_serde/protoc-gen-prost-serde.go b/pkg/plugin/neoeinstein/prost_serde/protoc-gen-prost-serde.go new file mode 100644 index 00000000..e39fdda3 --- /dev/null +++ b/pkg/plugin/neoeinstein/prost_serde/protoc-gen-prost-serde.go @@ -0,0 +1,95 @@ +package prost_serde + +import ( + "path" + "sort" + + "github.com/bazelbuild/bazel-gazelle/label" + "github.com/bazelbuild/bazel-gazelle/rule" + + "github.com/stackb/rules_proto/v4/pkg/plugin/neoeinstein/prost" + "github.com/stackb/rules_proto/v4/pkg/protoc" +) + +const ProtocGenProstSerdePluginName = "neoeinstein:prost:protoc-gen-prost-serde" + +func init() { + protoc.Plugins().MustRegisterPlugin(&ProtocGenProstSerdePlugin{}) +} + +// ProtocGenProstSerdePlugin implements Plugin for protoc-gen-prost-serde. +type ProtocGenProstSerdePlugin struct{} + +// Name implements part of the Plugin interface. +func (p *ProtocGenProstSerdePlugin) Name() string { + return ProtocGenProstSerdePluginName +} + +// Configure implements part of the Plugin interface. +func (p *ProtocGenProstSerdePlugin) Configure(ctx *protoc.PluginContext) *protoc.PluginConfiguration { + if !p.shouldApply(ctx.ProtoLibrary) { + return nil + } + + outputs := p.outputs(ctx.ProtoLibrary) + if len(outputs) == 0 { + return nil + } + + return &protoc.PluginConfiguration{ + Label: label.New("build_stack_rules_proto", "plugin/neoeinstein/prost-serde", "protoc-gen-prost-serde"), + Outputs: outputs, + Options: ctx.PluginConfig.GetOptions(), + } +} + +// ResolvePluginOptions implements the PluginOptionsResolver interface. +// It computes extern_path options based on transitive proto file dependencies +// AND emits self-extern overrides for the library's own packages — needed +// because prost-serde generates impl blocks at crate-root using absolute +// crate-qualified paths and would otherwise be shadowed by parent extern +// crate references through prost's longest-prefix matching. 
+func (p *ProtocGenProstSerdePlugin) ResolvePluginOptions(cfg *protoc.PluginConfiguration, r *rule.Rule, from label.Label) []string { + return prost.ResolveExternPathOptionsForReferences(cfg, r, from) +} + +// shouldApply returns true if the library has files with messages or enums. +func (p *ProtocGenProstSerdePlugin) shouldApply(lib protoc.ProtoLibrary) bool { + for _, f := range lib.Files() { + if f.HasMessages() || f.HasEnums() { + return true + } + } + return false +} + +// outputs computes the output files for the plugin. Prost-serde generates one +// .serde.rs file per proto package. The path includes the file's directory so +// that mergeSources can handle the rel stripping. +func (p *ProtocGenProstSerdePlugin) outputs(lib protoc.ProtoLibrary) []string { + seen := make(map[string]bool) + outputs := make([]string, 0) + + for _, f := range lib.Files() { + if !(f.HasMessages() || f.HasEnums()) { + continue + } + pkg := f.Package() + if pkg.Name == "" { + continue + } + if seen[pkg.Name] { + continue + } + seen[pkg.Name] = true + + filename := pkg.Name + ".serde.rs" + if f.Dir != "" { + filename = path.Join(f.Dir, filename) + } + outputs = append(outputs, filename) + } + + sort.Strings(outputs) + return outputs +} diff --git a/pkg/plugin/neoeinstein/prost_serde/protoc-gen-prost-serde_test.go b/pkg/plugin/neoeinstein/prost_serde/protoc-gen-prost-serde_test.go new file mode 100644 index 00000000..3b3e400a --- /dev/null +++ b/pkg/plugin/neoeinstein/prost_serde/protoc-gen-prost-serde_test.go @@ -0,0 +1,69 @@ +package prost_serde_test + +import ( + "testing" + + "github.com/stackb/rules_proto/v4/pkg/plugin/neoeinstein/prost_serde" + "github.com/stackb/rules_proto/v4/pkg/plugintest" +) + +func TestProtocGenProstSerdePlugin(t *testing.T) { + plugintest.Cases(t, &prost_serde.ProtocGenProstSerdePlugin{}, map[string]plugintest.Case{ + "empty - no messages or enums": { + Input: "", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-prost-serde 
implementation neoeinstein:prost:protoc-gen-prost-serde", + ), + PluginName: "protoc-gen-prost-serde", + Configuration: nil, + SkipIntegration: true, + }, + "simple message": { + Input: "package example.v1;\nmessage Foo {}", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-prost-serde implementation neoeinstein:prost:protoc-gen-prost-serde", + ), + PluginName: "protoc-gen-prost-serde", + Configuration: plugintest.WithConfiguration( + plugintest.WithLabel(t, "@build_stack_rules_proto//plugin/neoeinstein/prost-serde:protoc-gen-prost-serde"), + plugintest.WithOutputs("example.v1.serde.rs"), + ), + SkipIntegration: true, + }, + "simple enum": { + Input: "package example.v1;\nenum Color { RED = 0; }", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-prost-serde implementation neoeinstein:prost:protoc-gen-prost-serde", + ), + PluginName: "protoc-gen-prost-serde", + Configuration: plugintest.WithConfiguration( + plugintest.WithLabel(t, "@build_stack_rules_proto//plugin/neoeinstein/prost-serde:protoc-gen-prost-serde"), + plugintest.WithOutputs("example.v1.serde.rs"), + ), + SkipIntegration: true, + }, + "no package - skipped": { + Input: "message Foo {}", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-prost-serde implementation neoeinstein:prost:protoc-gen-prost-serde", + ), + PluginName: "protoc-gen-prost-serde", + Configuration: nil, + SkipIntegration: true, + }, + "with options": { + Input: "package example.v1;\nmessage Foo {}", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-prost-serde implementation neoeinstein:prost:protoc-gen-prost-serde", + "proto_plugin", "protoc-gen-prost-serde option compile_well_known_types=true", + ), + PluginName: "protoc-gen-prost-serde", + Configuration: plugintest.WithConfiguration( + plugintest.WithLabel(t, "@build_stack_rules_proto//plugin/neoeinstein/prost-serde:protoc-gen-prost-serde"), + plugintest.WithOutputs("example.v1.serde.rs"), + 
plugintest.WithOptions("compile_well_known_types=true"), + ), + SkipIntegration: true, + }, + }) +} diff --git a/pkg/plugin/neoeinstein/tonic/BUILD.bazel b/pkg/plugin/neoeinstein/tonic/BUILD.bazel new file mode 100644 index 00000000..ab6401a4 --- /dev/null +++ b/pkg/plugin/neoeinstein/tonic/BUILD.bazel @@ -0,0 +1,32 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "tonic", + srcs = ["protoc-gen-tonic.go"], + importpath = "github.com/stackb/rules_proto/v4/pkg/plugin/neoeinstein/tonic", + visibility = ["//visibility:public"], + deps = [ + "//pkg/plugin/neoeinstein/prost", + "//pkg/protoc", + "@bazel_gazelle//label", + "@bazel_gazelle//rule", + ], +) + +go_test( + name = "tonic_test", + srcs = ["protoc-gen-tonic_test.go"], + deps = [ + ":tonic", + "//pkg/plugintest", + ], +) + +filegroup( + name = "all_files", + testonly = True, + srcs = [ + "BUILD.bazel", + ] + glob(["*.go"]), + visibility = ["//pkg:__pkg__"], +) diff --git a/pkg/plugin/neoeinstein/tonic/protoc-gen-tonic.go b/pkg/plugin/neoeinstein/tonic/protoc-gen-tonic.go new file mode 100644 index 00000000..7caefef8 --- /dev/null +++ b/pkg/plugin/neoeinstein/tonic/protoc-gen-tonic.go @@ -0,0 +1,95 @@ +package tonic + +import ( + "path" + "sort" + + "github.com/bazelbuild/bazel-gazelle/label" + "github.com/bazelbuild/bazel-gazelle/rule" + + "github.com/stackb/rules_proto/v4/pkg/plugin/neoeinstein/prost" + "github.com/stackb/rules_proto/v4/pkg/protoc" +) + +const ProtocGenTonicPluginName = "neoeinstein:prost:protoc-gen-tonic" + +func init() { + protoc.Plugins().MustRegisterPlugin(&ProtocGenTonicPlugin{}) +} + +// ProtocGenTonicPlugin implements Plugin for protoc-gen-tonic. +type ProtocGenTonicPlugin struct{} + +// Name implements part of the Plugin interface. +func (p *ProtocGenTonicPlugin) Name() string { + return ProtocGenTonicPluginName +} + +// Configure implements part of the Plugin interface. 
+func (p *ProtocGenTonicPlugin) Configure(ctx *protoc.PluginContext) *protoc.PluginConfiguration { + if !p.shouldApply(ctx.ProtoLibrary) { + return nil + } + + outputs := p.outputs(ctx.ProtoLibrary) + if len(outputs) == 0 { + return nil + } + + return &protoc.PluginConfiguration{ + Label: label.New("build_stack_rules_proto", "plugin/neoeinstein/tonic", "protoc-gen-tonic"), + Outputs: outputs, + Options: ctx.PluginConfig.GetOptions(), + } +} + +// ResolvePluginOptions implements the PluginOptionsResolver interface. +// It computes extern_path options based on transitive proto file dependencies +// AND emits self-extern overrides for the library's own packages — needed +// because tonic-generated client/server code references prost types via +// crate-qualified paths and would otherwise be shadowed by parent extern +// crate references through prost's longest-prefix matching. +func (p *ProtocGenTonicPlugin) ResolvePluginOptions(cfg *protoc.PluginConfiguration, r *rule.Rule, from label.Label) []string { + return prost.ResolveExternPathOptionsForReferences(cfg, r, from) +} + +// shouldApply returns true if the library has files with services. +func (p *ProtocGenTonicPlugin) shouldApply(lib protoc.ProtoLibrary) bool { + for _, f := range lib.Files() { + if f.HasServices() { + return true + } + } + return false +} + +// outputs computes the output files for the plugin. Tonic generates one +// .tonic.rs file per proto package that has services. The path includes the +// file's directory so that mergeSources can handle the rel stripping. 
+func (p *ProtocGenTonicPlugin) outputs(lib protoc.ProtoLibrary) []string { + seen := make(map[string]bool) + outputs := make([]string, 0) + + for _, f := range lib.Files() { + if !f.HasServices() { + continue + } + pkg := f.Package() + if pkg.Name == "" { + continue + } + if seen[pkg.Name] { + continue + } + seen[pkg.Name] = true + + filename := pkg.Name + ".tonic.rs" + if f.Dir != "" { + filename = path.Join(f.Dir, filename) + } + outputs = append(outputs, filename) + } + + sort.Strings(outputs) + return outputs +} diff --git a/pkg/plugin/neoeinstein/tonic/protoc-gen-tonic_test.go b/pkg/plugin/neoeinstein/tonic/protoc-gen-tonic_test.go new file mode 100644 index 00000000..c4db3c7c --- /dev/null +++ b/pkg/plugin/neoeinstein/tonic/protoc-gen-tonic_test.go @@ -0,0 +1,87 @@ +package tonic_test + +import ( + "testing" + + "github.com/stackb/rules_proto/v4/pkg/plugin/neoeinstein/tonic" + "github.com/stackb/rules_proto/v4/pkg/plugintest" +) + +func TestProtocGenTonicPlugin(t *testing.T) { + plugintest.Cases(t, &tonic.ProtocGenTonicPlugin{}, map[string]plugintest.Case{ + "empty - no services": { + Input: "", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-tonic implementation neoeinstein:prost:protoc-gen-tonic", + ), + PluginName: "protoc-gen-tonic", + Configuration: nil, + SkipIntegration: true, + }, + "only messages - no output": { + Input: "package example.v1;\nmessage Foo {}", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-tonic implementation neoeinstein:prost:protoc-gen-tonic", + ), + PluginName: "protoc-gen-tonic", + Configuration: nil, + SkipIntegration: true, + }, + "only enums - no output": { + Input: "package example.v1;\nenum Color { RED = 0; }", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-tonic implementation neoeinstein:prost:protoc-gen-tonic", + ), + PluginName: "protoc-gen-tonic", + Configuration: nil, + SkipIntegration: true, + }, + "simple service": { + Input: "package 
example.v1;\nservice Greeter {}", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-tonic implementation neoeinstein:prost:protoc-gen-tonic", + ), + PluginName: "protoc-gen-tonic", + Configuration: plugintest.WithConfiguration( + plugintest.WithLabel(t, "@build_stack_rules_proto//plugin/neoeinstein/tonic:protoc-gen-tonic"), + plugintest.WithOutputs("example.v1.tonic.rs"), + ), + SkipIntegration: true, + }, + "no package - skipped": { + Input: "service Greeter {}", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-tonic implementation neoeinstein:prost:protoc-gen-tonic", + ), + PluginName: "protoc-gen-tonic", + Configuration: nil, + SkipIntegration: true, + }, + "with options": { + Input: "package example.v1;\nservice Greeter {}", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-tonic implementation neoeinstein:prost:protoc-gen-tonic", + "proto_plugin", "protoc-gen-tonic option compile_well_known_types=true", + ), + PluginName: "protoc-gen-tonic", + Configuration: plugintest.WithConfiguration( + plugintest.WithLabel(t, "@build_stack_rules_proto//plugin/neoeinstein/tonic:protoc-gen-tonic"), + plugintest.WithOutputs("example.v1.tonic.rs"), + plugintest.WithOptions("compile_well_known_types=true"), + ), + SkipIntegration: true, + }, + "message and service": { + Input: "package example.v1;\nmessage Foo {}\nservice Greeter {}", + Directives: plugintest.WithDirectives( + "proto_plugin", "protoc-gen-tonic implementation neoeinstein:prost:protoc-gen-tonic", + ), + PluginName: "protoc-gen-tonic", + Configuration: plugintest.WithConfiguration( + plugintest.WithLabel(t, "@build_stack_rules_proto//plugin/neoeinstein/tonic:protoc-gen-tonic"), + plugintest.WithOutputs("example.v1.tonic.rs"), + ), + SkipIntegration: true, + }, + }) +} diff --git a/pkg/protoc/BUILD.bazel b/pkg/protoc/BUILD.bazel index 7752307e..126c3231 100644 --- a/pkg/protoc/BUILD.bazel +++ b/pkg/protoc/BUILD.bazel @@ -28,6 +28,7 @@ go_library( 
"rule_provider.go", "rule_registry.go", "ruleindex.go", + "rust_keywords.go", "starlark_plugin.go", "starlark_rule.go", "starlark_util.go", @@ -61,10 +62,12 @@ go_test( "other_proto_library_test.go", "package_config_test.go", "package_test.go", + "proto_compile_test.go", "proto_plugin_config_test.go", "protoc_configuration_test.go", "resolver_test.go", "rewrite_test.go", + "rust_keywords_test.go", "starlark_plugin_test.go", "starlark_rule_test.go", "starlark_util_test.go", diff --git a/pkg/protoc/package.go b/pkg/protoc/package.go index 8465db82..e481bf0c 100644 --- a/pkg/protoc/package.go +++ b/pkg/protoc/package.go @@ -200,6 +200,47 @@ func (s *Package) getProvidedRules(providers []RuleProvider, shouldResolve bool) continue } + // Detect merge: if r is already in the rules slice (same pointer), + // the rule was merged (and possibly renamed). Update bookkeeping + // and accumulate imports, but don't add a duplicate entry. + merged := false + for i, existing := range rules { + if existing != r { + continue + } + merged = true + + // Update ruleIndexes and providers if the rule was renamed + newFrom := label.New("", s.rel, r.Name()) + if _, ok := ruleIndexes[newFrom]; !ok { + for oldFrom, idx := range ruleIndexes { + if idx == i { + delete(ruleIndexes, oldFrom) + ruleIndexes[newFrom] = idx + if prov, ok := s.providers[oldFrom.Name]; ok { + delete(s.providers, oldFrom.Name) + s.providers[r.Name()] = prov + } + break + } + } + } + + // Accumulate imports from the merged library + if shouldResolve { + lib := s.ruleLibs[p] + imports := lib.Imports() + if existingImports, ok := r.PrivateAttr(config.GazelleImportsKey).([]string); ok { + imports = append(imports, existingImports...) 
+ } + r.SetPrivateAttr(config.GazelleImportsKey, imports) + } + break + } + if merged { + continue + } + if shouldResolve { lib := s.ruleLibs[p] r.SetPrivateAttr(ProtoLibraryKey, lib) diff --git a/pkg/protoc/proto_compile.go b/pkg/protoc/proto_compile.go index 48cd050c..bb0b7bcd 100644 --- a/pkg/protoc/proto_compile.go +++ b/pkg/protoc/proto_compile.go @@ -3,7 +3,9 @@ package protoc import ( "fmt" "log" + "path" "sort" + "strings" "github.com/bazelbuild/bazel-gazelle/config" "github.com/bazelbuild/bazel-gazelle/label" @@ -34,6 +36,7 @@ func (s *protoCompile) KindInfo() rule.KindInfo { "plugins": true, "output_mappings": true, "options": true, + "protos": true, }, SubstituteAttrs: map[string]bool{ "out": true, @@ -100,9 +103,62 @@ func (s *protoCompileRule) Outputs() []string { // Rule implements part of the ruleProvider interface. func (s *protoCompileRule) Rule(otherGen ...*rule.Rule) *rule.Rule { - newRule := rule.NewRule(s.Kind(), s.Name()) outputs := s.Outputs() + // Check for output overlap with existing proto_compile rules of the same + // kind. When a package-level plugin (e.g. protoc-gen-prost) produces the + // same output file from multiple proto_library rules, merge them into a + // single proto_compile rule using the "protos" attribute. 
+ for _, other := range otherGen { + if other.Kind() != s.Kind() { + continue + } + otherOutputs := other.AttrStrings(s.outputsAttrName) + if !hasOverlap(outputs, otherOutputs) { + continue + } + + // Merge outputs + other.SetAttr(s.outputsAttrName, DeduplicateAndSort(append(otherOutputs, outputs...))) + + // Convert singular "proto" to list "protos" if needed, then append + existingProtos := other.AttrStrings("protos") + if len(existingProtos) == 0 { + if p := other.AttrString("proto"); p != "" { + existingProtos = []string{p} + other.DelAttr("proto") + } + } + existingProtos = append(existingProtos, s.config.Library.Name()) + other.SetAttr("protos", DeduplicateAndSort(existingProtos)) + + // Merge plugins + otherPlugins := other.AttrStrings("plugins") + otherPlugins = append(otherPlugins, GetPluginLabels(s.config.Plugins)...) + other.SetAttr("plugins", DeduplicateAndSort(otherPlugins)) + + // Merge output_mappings + if len(s.config.Mappings) > 0 { + existing := other.AttrStrings("output_mappings") + for k, v := range s.config.Mappings { + existing = append(existing, k+"="+v) + } + other.SetAttr("output_mappings", DeduplicateAndSort(existing)) + } + + // Rename merged rule based on output content (proto package) rather + // than the first library's arbitrary name. + mergedOutputs := DeduplicateAndSort(append(otherOutputs, outputs...)) + if name := mergedRuleName(mergedOutputs, s.config.Prefix, s.nameSuffix); name != "" { + other.SetName(name) + } + + return other + } + + // No overlap found — create new rule + newRule := rule.NewRule(s.Kind(), s.Name()) + newRule.SetAttr(s.outputsAttrName, outputs) newRule.SetAttr("plugins", GetPluginLabels(s.config.Plugins)) newRule.SetAttr("proto", s.config.Library.Name()) @@ -156,6 +212,37 @@ func (s *protoCompileRule) Rule(otherGen ...*rule.Rule) *rule.Rule { return newRule } +// mergedRuleName derives a rule name from the output filenames for a merged +// proto_compile rule. 
It takes the first output (sorted), strips the file +// extension, replaces dots with underscores, and formats as +// {sanitized}_{prefix}_{suffix}. +func mergedRuleName(outputs []string, prefix, suffix string) string { + if len(outputs) == 0 { + return "" + } + base := outputs[0] + ext := path.Ext(base) + if ext != "" { + base = base[:len(base)-len(ext)] + } + sanitized := strings.ReplaceAll(base, ".", "_") + return fmt.Sprintf("%s_%s_%s", sanitized, prefix, suffix) +} + +// hasOverlap returns true if two string slices share any common element. +func hasOverlap(a, b []string) bool { + set := make(map[string]bool, len(a)) + for _, s := range a { + set[s] = true + } + for _, s := range b { + if set[s] { + return true + } + } + return false +} + // Imports implements part of the RuleProvider interface. func (s *protoCompileRule) Imports(c *config.Config, r *rule.Rule, file *rule.File) []resolve.ImportSpec { return nil diff --git a/pkg/protoc/proto_compile_test.go b/pkg/protoc/proto_compile_test.go new file mode 100644 index 00000000..0bd697a5 --- /dev/null +++ b/pkg/protoc/proto_compile_test.go @@ -0,0 +1,155 @@ +package protoc + +import ( + "testing" + + "github.com/bazelbuild/bazel-gazelle/config" + "github.com/bazelbuild/bazel-gazelle/rule" + "github.com/emicklei/proto" +) + +func TestMergedRuleName(t *testing.T) { + for name, tc := range map[string]struct { + outputs []string + prefix, suffix string + want string + }{ + "empty outputs": { + outputs: nil, prefix: "rust", suffix: "compile", want: "", + }, + "single dotted name": { + outputs: []string{"my.package.rs"}, prefix: "rust", suffix: "compile", + want: "my_package_rust_compile", + }, + "no extension": { + outputs: []string{"my_package"}, prefix: "rust", suffix: "compile", + want: "my_package_rust_compile", + }, + "sorted outputs picks first": { + outputs: []string{"a.b.rs", "z.y.rs"}, prefix: "rust", suffix: "compile", + want: "a_b_rust_compile", + }, + "package with multiple dots": { + outputs: 
[]string{"google.protobuf.compiler.rs"}, prefix: "rust", suffix: "compile", + want: "google_protobuf_compiler_rust_compile", + }, + } { + t.Run(name, func(t *testing.T) { + if got := mergedRuleName(tc.outputs, tc.prefix, tc.suffix); got != tc.want { + t.Errorf("mergedRuleName(%v, %q, %q) = %q, want %q", + tc.outputs, tc.prefix, tc.suffix, got, tc.want) + } + }) + } +} + +func TestHasOverlap(t *testing.T) { + for name, tc := range map[string]struct { + a, b []string + want bool + }{ + "both empty": { + a: nil, b: nil, want: false, + }, + "no overlap": { + a: []string{"a", "b"}, b: []string{"c", "d"}, want: false, + }, + "overlap": { + a: []string{"a", "b"}, b: []string{"b", "c"}, want: true, + }, + "identical": { + a: []string{"a"}, b: []string{"a"}, want: true, + }, + } { + t.Run(name, func(t *testing.T) { + if got := hasOverlap(tc.a, tc.b); got != tc.want { + t.Errorf("hasOverlap(%v, %v) = %v, want %v", tc.a, tc.b, got, tc.want) + } + }) + } +} + +// packageLevelPlugin is a plugin that produces a single output per proto +// package, regardless of which proto file it's configured with. This simulates +// plugins like protoc-gen-prost. 
+type packageLevelPlugin struct{} + +func (p *packageLevelPlugin) Name() string { + return "protoc:package_level" +} + +func (p *packageLevelPlugin) Configure(ctx *PluginContext) *PluginConfiguration { + return &PluginConfiguration{ + Label: ctx.PluginConfig.Label, + Outputs: []string{"my_package.rs"}, + } +} + +func init() { + Plugins().MustRegisterPlugin(&packageLevelPlugin{}) +} + +func makeProtoLibrary(name, filename string) ProtoLibrary { + r := rule.NewRule("proto_library", name) + f := NewFile("pkg", filename) + f.pkg = proto.Package{Name: "my.package"} + f.messages = append(f.messages, proto.Message{Name: "Msg"}) + return NewOtherProtoLibrary(nil, r, f) +} + +func aggregationPackageConfig() *PackageConfig { + c := NewPackageConfig(&config.Config{}) + if err := c.ParseDirectives("pkg", withDirectives( + "proto_rule", "proto_compile implementation stackb:rules_proto:proto_compile", + "proto_plugin", "pkg_plugin implementation protoc:package_level", + "proto_plugin", "pkg_plugin enabled true", + "proto_language", "rust plugin pkg_plugin", + "proto_language", "rust enabled true", + "proto_language", "rust rule proto_compile", + )); err != nil { + panic("bad config: " + err.Error()) + } + return c +} + +// ExamplePackageAggregation demonstrates that when two proto_library rules +// produce the same output file, their proto_compile rules are merged into a +// single rule using the "protos" attribute. +func ExamplePackage_aggregation() { + pkg := NewPackage( + "pkg", + aggregationPackageConfig(), + makeProtoLibrary("a_proto", "a.proto"), + makeProtoLibrary("b_proto", "b.proto"), + ) + formaatRules(pkg.Rules()...) 
+ // Output: + // proto_compile( + // name = "my_package_rust_compile", + // output_mappings = ["my_package.rs=my_package.rs"], + // outputs = ["my_package.rs"], + // plugins = ["//:"], + // protos = [ + // "a_proto", + // "b_proto", + // ], + // ) +} + +// ExamplePackageNoAggregation demonstrates that when two proto_library rules +// produce different output files, separate proto_compile rules are emitted. +func ExamplePackage_noAggregation() { + pkg := NewPackage( + exampleDir, + examplePackageConfig(), + exampleProtoLibrary(), + ) + formaatRules(pkg.Rules()...) + // Output: + // proto_compile( + // name = "test_fake_compile", + // outputs = ["test_fake.pb.go"], + // plugins = ["@build_stack_rules_proto//plugin/builtin:fake"], + // proto = "test_proto", + // ) +} diff --git a/pkg/protoc/rust_keywords.go b/pkg/protoc/rust_keywords.go new file mode 100644 index 00000000..f6202952 --- /dev/null +++ b/pkg/protoc/rust_keywords.go @@ -0,0 +1,122 @@ +package protoc + +import ( + "path" + "strings" +) + +var rustKeywords = map[string]bool{ + // Strict keywords + "as": true, + "break": true, + "const": true, + "continue": true, + "crate": true, + "else": true, + "enum": true, + "extern": true, + "false": true, + "fn": true, + "for": true, + "if": true, + "impl": true, + "in": true, + "let": true, + "loop": true, + "match": true, + "mod": true, + "move": true, + "mut": true, + "pub": true, + "ref": true, + "return": true, + "self": true, + "Self": true, + "static": true, + "struct": true, + "super": true, + "trait": true, + "true": true, + "type": true, + "unsafe": true, + "use": true, + "where": true, + "while": true, + // Async keywords (edition 2018+) + "async": true, + "await": true, + "dyn": true, + // Reserved keywords + "abstract": true, + "become": true, + "box": true, + "do": true, + "final": true, + "macro": true, + "override": true, + "priv": true, + "try": true, + "typeof": true, + "unsized": true, + "virtual": true, + "yield": true, +} + +// 
RustKeywordEscapeMappings computes output mappings needed when +// protoc-gen-prost escapes Rust keywords with the r# prefix in directory paths. +// +// For example, proto package "google.type" causes prost to write files to +// "google/r#type/" instead of "google/type/". This function returns a mapping +// from each declared output filename to the actual prost output path. +// +// Returns an empty map if no package segments are Rust keywords. +func RustKeywordEscapeMappings(pkg string, outputs []string) map[string]string { + if pkg == "" || len(outputs) == 0 { + return nil + } + + segments := strings.Split(pkg, ".") + + // Check if any segment is a Rust keyword. + needsEscape := false + for _, seg := range segments { + if rustKeywords[seg] { + needsEscape = true + break + } + } + if !needsEscape { + return nil + } + + // Build the escaped directory path. + escaped := make([]string, len(segments)) + for i, seg := range segments { + if rustKeywords[seg] { + escaped[i] = "r#" + seg + } else { + escaped[i] = seg + } + } + escapedDir := strings.Join(escaped, "/") + + mappings := make(map[string]string, len(outputs)) + for _, output := range outputs { + base := path.Base(output) + mappings[base] = path.Join(escapedDir, base) + } + return mappings +} + +// RustCrateName returns the canonical Rust crate name for a proto package. +// The proto package's dots are replaced with underscores and an "_rs" suffix +// is appended so the resulting identifier is unambiguously a Rust target, +// not a proto_library (e.g. "trumid.common.utils.state.snapshot.proto" → +// "trumid_common_utils_state_snapshot_proto_rs"). Returns the empty string +// for an empty input. 
+func RustCrateName(protoPackage string) string { + if protoPackage == "" { + return "" + } + return strings.ReplaceAll(protoPackage, ".", "_") + "_rs" +} diff --git a/pkg/protoc/rust_keywords_test.go b/pkg/protoc/rust_keywords_test.go new file mode 100644 index 00000000..8b3b396c --- /dev/null +++ b/pkg/protoc/rust_keywords_test.go @@ -0,0 +1,100 @@ +package protoc + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestRustKeywordEscapeMappings(t *testing.T) { + for name, tc := range map[string]struct { + pkg string + outputs []string + want map[string]string + }{ + "empty package": { + pkg: "", + outputs: []string{"foo.rs"}, + want: nil, + }, + "empty outputs": { + pkg: "google.type", + outputs: nil, + want: nil, + }, + "no keywords": { + pkg: "google.api", + outputs: []string{"google.api.rs", "google.api.serde.rs"}, + want: nil, + }, + "type keyword": { + pkg: "google.type", + outputs: []string{"google.type.rs", "google.type.serde.rs"}, + want: map[string]string{ + "google.type.rs": "google/r#type/google.type.rs", + "google.type.serde.rs": "google/r#type/google.type.serde.rs", + }, + }, + "keyword at start": { + pkg: "type.example", + outputs: []string{"type.example.rs"}, + want: map[string]string{ + "type.example.rs": "r#type/example/type.example.rs", + }, + }, + "multiple keywords": { + pkg: "self.type", + outputs: []string{"self.type.rs"}, + want: map[string]string{ + "self.type.rs": "r#self/r#type/self.type.rs", + }, + }, + "single segment keyword": { + pkg: "type", + outputs: []string{"type.rs"}, + want: map[string]string{ + "type.rs": "r#type/type.rs", + }, + }, + "single segment no keyword": { + pkg: "example", + outputs: []string{"example.rs"}, + want: nil, + }, + "full path outputs": { + pkg: "google.type", + outputs: []string{"google/type/google.type.rs", "google/type/google.type.serde.rs"}, + want: map[string]string{ + "google.type.rs": "google/r#type/google.type.rs", + "google.type.serde.rs": "google/r#type/google.type.serde.rs", + 
}, + }, + } { + t.Run(name, func(t *testing.T) { + got := RustKeywordEscapeMappings(tc.pkg, tc.outputs) + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("(-want +got):\n%s", diff) + } + }) + } +} + +func TestRustCrateName(t *testing.T) { + for name, tc := range map[string]struct { + pkg string + want string + }{ + "empty": {pkg: "", want: ""}, + "single segment": {pkg: "foo", want: "foo_rs"}, + "trailing proto": {pkg: "trumid.common.utils.state.snapshot.proto", want: "trumid_common_utils_state_snapshot_proto_rs"}, + "sub-package": {pkg: "trumid.common.utils.state.snapshot.proto.example", + want: "trumid_common_utils_state_snapshot_proto_example_rs"}, + "keywords are not escaped here": {pkg: "google.type", want: "google_type_rs"}, + } { + t.Run(name, func(t *testing.T) { + if got := RustCrateName(tc.pkg); got != tc.want { + t.Errorf("RustCrateName(%q) = %q, want %q", tc.pkg, got, tc.want) + } + }) + } +} diff --git a/pkg/protoc/starlark_plugin.go b/pkg/protoc/starlark_plugin.go index 0d640d34..d34dc428 100644 --- a/pkg/protoc/starlark_plugin.go +++ b/pkg/protoc/starlark_plugin.go @@ -135,11 +135,34 @@ func (p *starlarkPlugin) Configure(ctx *PluginContext) *PluginConfiguration { out = outString.GoString() } + var mappings map[string]string + mappingsValue, err := value.Attr("mappings") + if err == nil { + if dict, ok := mappingsValue.(*starlark.Dict); ok && dict.Len() > 0 { + mappings = make(map[string]string, dict.Len()) + for _, key := range dict.Keys() { + k, ok := key.(starlark.String) + if !ok { + p.errorReporter("PluginConfiguration.mappings: key is not a string (%T)", key) + continue + } + if v, found, err := dict.Get(key); found && err == nil { + if s, ok := v.(starlark.String); ok { + mappings[k.GoString()] = s.GoString() + } else { + p.errorReporter("PluginConfiguration.mappings: value for %q is not a string (%T)", k.GoString(), v) + } + } + } + } + } + result = &PluginConfiguration{ - Label: lbl, - Outputs: outputs, - Out: out, - Options: 
options, + Label: lbl, + Outputs: outputs, + Out: out, + Options: options, + Mappings: mappings, } default: p.errorReporter("plugin %q configure returned invalid type: %T", p.name, value) @@ -155,12 +178,14 @@ func newStarlarkPluginConfiguration() goStarlarkFunction { var out string outputs := &starlark.List{} options := &starlark.List{} + mappings := &starlark.Dict{} if err := starlark.UnpackArgs("PluginConfiguration", args, kwargs, "label", &labelStr, "outputs", &outputs, "out?", &out, "options?", &options, + "mappings?", &mappings, ); err != nil { return nil, err } @@ -168,10 +193,11 @@ func newStarlarkPluginConfiguration() goStarlarkFunction { return starlarkstruct.FromStringDict( Symbol("PluginConfiguration"), starlark.StringDict{ - "label": starlark.String(labelStr), - "outputs": outputs, - "out": starlark.String(out), - "options": options, + "label": starlark.String(labelStr), + "outputs": outputs, + "out": starlark.String(out), + "options": options, + "mappings": mappings, }, ), nil } diff --git a/pkg/protoc/starlark_plugin_test.go b/pkg/protoc/starlark_plugin_test.go index 2af9c09f..63338485 100644 --- a/pkg/protoc/starlark_plugin_test.go +++ b/pkg/protoc/starlark_plugin_test.go @@ -54,9 +54,9 @@ def configure(ctx): label = "//%s:python_plugin" % ctx.rel, outputs = ["foo.py", "bar.py"], ) - + protoc.Plugin( - name = "test", + name = "test", configure = configure, ) `, @@ -70,6 +70,33 @@ protoc.Plugin( }, wantPrinted: `PluginContext(package_config = PackageConfig(config = Config(repo_name = "", repo_root = "", work_dir = "")), plugin_config = LanguagePluginConfig(deps = [], enabled = False, implementation = "", label = "", name = "", options = []), proto_library = ProtoLibrary(base_name = "", deps = [], files = [], imports = [], name = "", srcs = [], strip_import_prefix = ""), rel = "mypkg")` + "\n", }, + "with mappings": { + code: ` +def configure(ctx): + return protoc.PluginConfiguration( + label = "//%s:python_plugin" % ctx.rel, + outputs = ["foo.py", 
"bar.py"], + mappings = {"foo.py": "com/example/foo.py", "bar.py": "com/example/bar.py"}, + ) + +protoc.Plugin( + name = "test", + configure = configure, +) +`, + ctx: &PluginContext{ + Rel: "mypkg", + }, + want: &PluginConfiguration{ + Label: label.New("", "mypkg", "python_plugin"), + Outputs: []string{"foo.py", "bar.py"}, + Options: []string{}, + Mappings: map[string]string{ + "foo.py": "com/example/foo.py", + "bar.py": "com/example/bar.py", + }, + }, + }, } { t.Run(name, func(t *testing.T) { var err error diff --git a/pkg/protoc/starlark_util.go b/pkg/protoc/starlark_util.go index baf20b84..2ecfd48a 100644 --- a/pkg/protoc/starlark_util.go +++ b/pkg/protoc/starlark_util.go @@ -61,9 +61,10 @@ func newPredeclared(plugins, rules map[string]*starlarkstruct.Struct) starlark.S protoc := &starlarkstruct.Module{ Name: "protoc", Members: starlark.StringDict{ - "Plugin": starlark.NewBuiltin("Plugin", newStarlarkPluginFunction(plugins)), - "Rule": starlark.NewBuiltin("Rule", newStarlarkLanguageRuleFunction(rules)), - "PluginConfiguration": starlark.NewBuiltin("PluginConfiguration", newStarlarkPluginConfiguration()), + "Plugin": starlark.NewBuiltin("Plugin", newStarlarkPluginFunction(plugins)), + "Rule": starlark.NewBuiltin("Rule", newStarlarkLanguageRuleFunction(rules)), + "PluginConfiguration": starlark.NewBuiltin("PluginConfiguration", newStarlarkPluginConfiguration()), + "rust_keyword_mappings": starlark.NewBuiltin("rust_keyword_mappings", newRustKeywordMappingsFunction()), }, } @@ -287,6 +288,38 @@ func structAttrString(in *starlarkstruct.Struct, name string, errorReporter erro } } +func newRustKeywordMappingsFunction() goStarlarkFunction { + return func(thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { + var pkg string + outputsList := &starlark.List{} + + if err := starlark.UnpackArgs("rust_keyword_mappings", args, kwargs, + "pkg", &pkg, + "outputs", &outputsList, + ); err != nil { + return nil, err 
+ } + + outputs := make([]string, outputsList.Len()) + for i := 0; i < outputsList.Len(); i++ { + s, ok := outputsList.Index(i).(starlark.String) + if !ok { + return nil, fmt.Errorf("rust_keyword_mappings: outputs[%d] is not a string", i) + } + outputs[i] = s.GoString() + } + + mappings := RustKeywordEscapeMappings(pkg, outputs) + dict := &starlark.Dict{} + for k, v := range mappings { + if err := dict.SetKey(starlark.String(k), starlark.String(v)); err != nil { + return nil, err + } + } + return dict, nil + } +} + func structAttrMapStringBool(in *starlarkstruct.Struct, name string, errorReporter errorReporter) (out map[string]bool) { value, err := in.Attr(name) if err != nil { diff --git a/pkg/rule/rules_rust/BUILD.bazel b/pkg/rule/rules_rust/BUILD.bazel new file mode 100644 index 00000000..8f8ebe5e --- /dev/null +++ b/pkg/rule/rules_rust/BUILD.bazel @@ -0,0 +1,40 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "rules_rust", + srcs = [ + "proto_rust_library.go", + "rust_library.go", + ], + importpath = "github.com/stackb/rules_proto/v4/pkg/rule/rules_rust", + visibility = ["//visibility:public"], + deps = [ + "//pkg/protoc", + "@bazel_gazelle//config", + "@bazel_gazelle//label", + "@bazel_gazelle//resolve", + "@bazel_gazelle//rule", + ], +) + +go_test( + name = "rules_rust_test", + srcs = ["proto_rust_library_test.go"], + embed = [":rules_rust"], + deps = [ + "//pkg/protoc", + "@bazel_gazelle//config", + "@bazel_gazelle//label", + "@bazel_gazelle//rule", + "@com_github_google_go_cmp//cmp", + ], +) + +filegroup( + name = "all_files", + testonly = True, + srcs = [ + "BUILD.bazel", + ] + glob(["*.go"]), + visibility = ["//pkg:__pkg__"], +) diff --git a/pkg/rule/rules_rust/proto_rust_library.go b/pkg/rule/rules_rust/proto_rust_library.go new file mode 100644 index 00000000..448590d2 --- /dev/null +++ b/pkg/rule/rules_rust/proto_rust_library.go @@ -0,0 +1,84 @@ +package rules_rust + +import ( + "strings" + + 
"github.com/bazelbuild/bazel-gazelle/label" + "github.com/bazelbuild/bazel-gazelle/rule" + + "github.com/stackb/rules_proto/v4/pkg/protoc" +) + +const ( + ProtoRustLibraryRuleName = "proto_rust_library" + ProtoRustLibraryRuleSuffix = "_rust_library" +) + +func init() { + protoc.Rules().MustRegisterRule("stackb:rules_proto:proto_rust_library", &protoRustLibrary{ + protoLibrariesByRule: make(map[label.Label][]protoc.ProtoLibrary), + }) +} + +// protoRustLibrary implements LanguageRule for the 'proto_rust_library' rule. +type protoRustLibrary struct { + protoLibrariesByRule map[label.Label][]protoc.ProtoLibrary +} + +// Name implements part of the LanguageRule interface. +func (s *protoRustLibrary) Name() string { + return ProtoRustLibraryRuleName +} + +// KindInfo implements part of the LanguageRule interface. +func (s *protoRustLibrary) KindInfo() rule.KindInfo { + return rustLibraryKindInfo +} + +// LoadInfo implements part of the LanguageRule interface. +func (s *protoRustLibrary) LoadInfo() rule.LoadInfo { + return rule.LoadInfo{ + Name: "@build_stack_rules_proto//rules/rust:proto_rust_library.bzl", + Symbols: []string{ProtoRustLibraryRuleName}, + } +} + +// ProvideRule implements part of the LanguageRule interface. +func (s *protoRustLibrary) ProvideRule(cfg *protoc.LanguageRuleConfig, pc *protoc.ProtocConfiguration) protoc.RuleProvider { + outputs := make([]string, 0) + for _, plugin := range pc.Plugins { + for _, out := range plugin.Outputs { + if strings.HasSuffix(out, ".rs") { + outputs = append(outputs, out) + } + } + } + if len(outputs) == 0 { + return nil + } + + // Compute Rust keyword escape mappings for proto packages containing + // Rust reserved keywords (e.g., "google.type" → prost writes to + // "google/r#type/" instead of "google/type/"). 
+ if files := pc.Library.Files(); len(files) > 0 { + pkg := files[0].Package().Name + for output, escapedPath := range protoc.RustKeywordEscapeMappings(pkg, outputs) { + if pc.Mappings == nil { + pc.Mappings = make(map[string]string) + } + pc.Mappings[output] = escapedPath + } + } + + rl := &RustLibrary{ + KindName: ProtoRustLibraryRuleName, + RuleNameSuffix: ProtoRustLibraryRuleSuffix, + Outputs: outputs, + RuleConfig: cfg, + Config: pc, + Resolver: protoc.ResolveDepsAttr("deps", false), + protoLibrariesByRule: s.protoLibrariesByRule, + } + rl.id = label.New("", pc.Rel, rl.Name()) + return rl +} diff --git a/pkg/rule/rules_rust/proto_rust_library_test.go b/pkg/rule/rules_rust/proto_rust_library_test.go new file mode 100644 index 00000000..0798992c --- /dev/null +++ b/pkg/rule/rules_rust/proto_rust_library_test.go @@ -0,0 +1,229 @@ +package rules_rust + +import ( + "strings" + "testing" + + "github.com/bazelbuild/bazel-gazelle/config" + "github.com/bazelbuild/bazel-gazelle/label" + "github.com/bazelbuild/bazel-gazelle/rule" + "github.com/google/go-cmp/cmp" + "github.com/stackb/rules_proto/v4/pkg/protoc" +) + +func makeTestProtoLibrary(files ...*protoc.File) protoc.ProtoLibrary { + r := rule.NewRule("proto_library", "test_proto") + return protoc.NewOtherProtoLibrary(nil, r, files...) 
+} + +func makeFile(dir, basename, protoContent string) *protoc.File { + f := protoc.NewFile(dir, basename) + if err := f.ParseReader(strings.NewReader(protoContent)); err != nil { + panic("bad proto: " + err.Error()) + } + return f +} + +func TestProtoRustLibraryRule(t *testing.T) { + for name, tc := range map[string]struct { + cfg protoc.LanguageRuleConfig + pc protoc.ProtocConfiguration + want string + }{ + "degenerate": { + cfg: *protoc.NewLanguageRuleConfig(config.New(), "rust"), + pc: protoc.ProtocConfiguration{ + Library: makeTestProtoLibrary(), + }, + }, + "simple": { + cfg: *protoc.NewLanguageRuleConfig(config.New(), "rust"), + pc: protoc.ProtocConfiguration{ + Library: makeTestProtoLibrary( + makeFile("google/api", "http.proto", `syntax = "proto3"; package google.api; message HttpRule {}`), + ), + Plugins: []*protoc.PluginConfiguration{ + { + Config: &protoc.LanguagePluginConfig{}, + Outputs: []string{"google.api.rs"}, + }, + }, + }, + want: ` +proto_rust_library( + name = "google_api_rs", + srcs = ["google.api.rs"], + pkg = "google.api", + deps = [ + "@crates//:pbjson", + "@crates//:prost", + "@crates//:serde", + ], +) +`, + }, + "multiple srcs": { + cfg: *protoc.NewLanguageRuleConfig(config.New(), "rust"), + pc: protoc.ProtocConfiguration{ + Library: makeTestProtoLibrary( + makeFile("trumid/common/proto", "date_range.proto", `syntax = "proto3"; package trumid.common.proto; message DateRange {}`), + ), + Plugins: []*protoc.PluginConfiguration{ + { + Config: &protoc.LanguagePluginConfig{}, + Outputs: []string{"trumid.common.proto.rs", "trumid.common.proto.serde.rs"}, + }, + }, + }, + want: ` +proto_rust_library( + name = "trumid_common_proto_rs", + srcs = [ + "trumid.common.proto.rs", + "trumid.common.proto.serde.rs", + ], + pkg = "trumid.common.proto", + deps = [ + "@crates//:pbjson", + "@crates//:prost", + "@crates//:serde", + ], +) +`, + }, + "with well-known types": { + cfg: *protoc.NewLanguageRuleConfig(config.New(), "rust"), + pc: 
protoc.ProtocConfiguration{ + Library: makeTestProtoLibrary( + makeFile("example/wkt", "thing.proto", + `syntax = "proto3"; package example.wkt; import "google/protobuf/duration.proto"; message Thing { google.protobuf.Duration d = 1; }`), + ), + Plugins: []*protoc.PluginConfiguration{ + { + Config: &protoc.LanguagePluginConfig{}, + Outputs: []string{"example.wkt.rs"}, + }, + }, + }, + want: ` +proto_rust_library( + name = "example_wkt_rs", + srcs = ["example.wkt.rs"], + pkg = "example.wkt", + deps = [ + "@crates//:pbjson", + "@crates//:prost", + "@crates//:prost-types", + "@crates//:serde", + ], +) +`, + }, + "with services": { + cfg: *protoc.NewLanguageRuleConfig(config.New(), "rust"), + pc: protoc.ProtocConfiguration{ + Library: makeTestProtoLibrary( + makeFile("example/grpc", "greeter.proto", `syntax = "proto3"; package example.grpc; message HelloRequest {} service Greeter { rpc SayHello (HelloRequest) returns (HelloRequest); }`), + ), + Plugins: []*protoc.PluginConfiguration{ + { + Config: &protoc.LanguagePluginConfig{}, + Outputs: []string{"example.grpc.rs", "example.grpc.tonic.rs"}, + }, + }, + }, + want: ` +proto_rust_library( + name = "example_grpc_rs", + srcs = [ + "example.grpc.rs", + "example.grpc.tonic.rs", + ], + pkg = "example.grpc", + deps = [ + "@crates//:pbjson", + "@crates//:prost", + "@crates//:serde", + "@crates//:tonic", + ], +) +`, + }, + } { + t.Run(name, func(t *testing.T) { + lib := protoRustLibrary{ + protoLibrariesByRule: make(map[label.Label][]protoc.ProtoLibrary), + } + impl := lib.ProvideRule(&tc.cfg, &tc.pc) + var got string + if impl != nil { + r := impl.Rule() + got = formatRules(r) + } + if diff := cmp.Diff(strings.TrimSpace(tc.want), strings.TrimSpace(got)); diff != "" { + t.Errorf("(-want +got):\n%s", diff) + } + }) + } +} + +// TestProtoRustLibraryRuleMerge verifies that when Rule() is called with an +// existing rule of the same kind/name (otherGen), the new srcs/deps/imports are +// merged into it instead of creating a 
duplicate rule. +func TestProtoRustLibraryRuleMerge(t *testing.T) { + cfg := protoc.NewLanguageRuleConfig(config.New(), "rust") + pc1 := &protoc.ProtocConfiguration{ + Library: makeTestProtoLibrary( + makeFile("merge/pkg", "first.proto", + `syntax = "proto3"; package merge.pkg; message First {}`), + ), + Plugins: []*protoc.PluginConfiguration{ + { + Config: &protoc.LanguagePluginConfig{}, + Outputs: []string{"merge.pkg.rs"}, + }, + }, + } + pc2 := &protoc.ProtocConfiguration{ + Library: makeTestProtoLibrary( + makeFile("merge/pkg", "second.proto", + `syntax = "proto3"; package merge.pkg; message Second {}`), + ), + Plugins: []*protoc.PluginConfiguration{ + { + Config: &protoc.LanguagePluginConfig{}, + Outputs: []string{"merge.pkg.serde.rs"}, + }, + }, + } + + lib := protoRustLibrary{ + protoLibrariesByRule: make(map[label.Label][]protoc.ProtoLibrary), + } + + // First library generates a fresh rule. + first := lib.ProvideRule(cfg, pc1).Rule() + if first == nil { + t.Fatal("first ProvideRule returned nil") + } + + // Second library should merge into the first. 
+ merged := lib.ProvideRule(cfg, pc2).Rule(first) + if merged != first { + t.Errorf("expected second Rule() to return the same *Rule as the first (merge), got a different pointer") + } + + gotSrcs := merged.AttrStrings("srcs") + wantSrcs := []string{"merge.pkg.rs", "merge.pkg.serde.rs"} + if diff := cmp.Diff(wantSrcs, gotSrcs); diff != "" { + t.Errorf("merged srcs mismatch (-want +got):\n%s", diff) + } +} + +func formatRules(rules ...*rule.Rule) string { + file := rule.EmptyFile("", "") + for _, r := range rules { + r.Insert(file) + } + return string(file.Format()) +} diff --git a/pkg/rule/rules_rust/rust_library.go b/pkg/rule/rules_rust/rust_library.go new file mode 100644 index 00000000..6335b73d --- /dev/null +++ b/pkg/rule/rules_rust/rust_library.go @@ -0,0 +1,236 @@ +package rules_rust + +import ( + "sort" + "strings" + + "github.com/bazelbuild/bazel-gazelle/config" + "github.com/bazelbuild/bazel-gazelle/label" + "github.com/bazelbuild/bazel-gazelle/resolve" + "github.com/bazelbuild/bazel-gazelle/rule" + + "github.com/stackb/rules_proto/v4/pkg/protoc" +) + +var rustLibraryKindInfo = rule.KindInfo{ + MergeableAttrs: map[string]bool{ + "srcs": true, + "deps": true, + "reexports": true, + }, + NonEmptyAttrs: map[string]bool{ + "srcs": true, + }, + ResolveAttrs: map[string]bool{ + "deps": true, + }, +} + +// RustLibrary implements RuleProvider for 'proto_rust_library'-derived rules. +type RustLibrary struct { + KindName string + RuleNameSuffix string + Outputs []string + Config *protoc.ProtocConfiguration + RuleConfig *protoc.LanguageRuleConfig + Resolver protoc.DepsResolver + id label.Label + protoLibrariesByRule map[label.Label][]protoc.ProtoLibrary +} + +// Kind implements part of the RuleProvider interface. +func (s *RustLibrary) Kind() string { + return s.KindName +} + +// Name implements part of the RuleProvider interface. 
+func (s *RustLibrary) Name() string { + if pkg := s.Pkg(); pkg != "" { + return protoc.RustCrateName(pkg) + } + return s.Config.Library.BaseName() + s.RuleNameSuffix +} + +// Pkg returns the proto package name from the first file in the library. +func (s *RustLibrary) Pkg() string { + files := s.Config.Library.Files() + if len(files) == 0 { + return "" + } + return files[0].Package().Name +} + +// Srcs computes the srcs list for the rule. +func (s *RustLibrary) Srcs() []string { + srcs := make([]string, 0, len(s.Outputs)) + for _, output := range s.Outputs { + if strings.HasSuffix(output, ".rs") { + srcs = append(srcs, protoc.StripRel(s.Config.Rel, output)) + } + } + sort.Strings(srcs) + return srcs +} + +// Deps computes the deps list for the rule. +func (s *RustLibrary) Deps() []string { + deps := s.RuleConfig.GetDeps() + deps = append(deps, s.fixedDeps()...) + return protoc.DeduplicateAndSort(deps) +} + +// hasServices returns true if any proto file in the library defines services. +func (s *RustLibrary) hasServices() bool { + for _, f := range s.Config.Library.Files() { + if f.HasServices() { + return true + } + } + return false +} + +// hasWellKnownTypes returns true if any proto file imports a well-known type +// (google/protobuf/*), which requires the prost-types crate at runtime. +func (s *RustLibrary) hasWellKnownTypes() bool { + for _, f := range s.Config.Library.Files() { + for _, imp := range f.Imports() { + if strings.HasPrefix(imp.Filename, "google/protobuf/") { + return true + } + } + } + return false +} + +// fixedDeps returns the crate dependencies that are always needed. +func (s *RustLibrary) fixedDeps() []string { + deps := []string{ + "@crates//:prost", + "@crates//:serde", + "@crates//:pbjson", + } + if s.hasServices() { + deps = append(deps, "@crates//:tonic") + } + if s.hasWellKnownTypes() { + deps = append(deps, "@crates//:prost-types") + } + return deps +} + +// Visibility provides visibility labels. 
+func (s *RustLibrary) Visibility() []string { + return s.RuleConfig.GetVisibility() +} + +// Rule implements part of the RuleProvider interface. +func (s *RustLibrary) Rule(otherGen ...*rule.Rule) *rule.Rule { + srcs := s.Srcs() + deps := s.Deps() + visibility := s.Visibility() + imports := s.Config.Library.Imports() + + // Check if an existing rule with the same kind and name has already been + // generated. If so, merge into it rather than creating a new rule. + for _, other := range otherGen { + if other.Kind() == s.Kind() && other.Name() == s.Name() { + otherLabel := label.New("", s.Config.Rel, other.Name()) + otherSrcs := other.AttrStrings("srcs") + otherDeps := other.AttrStrings("deps") + otherVis := other.AttrStrings("visibility") + otherImports, _ := other.PrivateAttr(config.GazelleImportsKey).([]string) + + other.SetAttr("srcs", protoc.DeduplicateAndSort(append(otherSrcs, srcs...))) + other.SetAttr("deps", protoc.DeduplicateAndSort(append(otherDeps, deps...))) + other.SetAttr("visibility", protoc.DeduplicateAndSort(append(otherVis, visibility...))) + other.SetPrivateAttr(config.GazelleImportsKey, protoc.DeduplicateAndSort(append(otherImports, imports...))) + + s.protoLibrariesByRule[otherLabel] = append(s.protoLibrariesByRule[otherLabel], s.Config.Library) + + return other + } + } + + newRule := rule.NewRule(s.Kind(), s.Name()) + newRule.SetAttr("srcs", srcs) + newRule.SetPrivateAttr(config.GazelleImportsKey, imports) + s.protoLibrariesByRule[s.id] = []protoc.ProtoLibrary{s.Config.Library} + + if pkg := s.Pkg(); pkg != "" { + newRule.SetAttr("pkg", pkg) + } + if len(deps) > 0 { + newRule.SetAttr("deps", deps) + } + if len(visibility) > 0 { + newRule.SetAttr("visibility", visibility) + } + + return newRule +} + +// Reexports returns "crate_name=proto.package" entries identifying every +// imported package whose proto path is a strict prefix-parent of any of our +// own proto packages. 
The proto_rust_library Starlark macro uses these to +// generate "pub use ::crate_name::path::*;" re-exports inside the local +// lib.rs at the parent module, which lets prost's relative super::... paths +// for cross-crate references resolve. See the matching prost-side filter in +// extern_paths.ResolveExternPathOptions for context: that filter drops the +// dependency's extern_path entry (which would otherwise make prost skip +// generating the local sub-package), and these re-exports replace what the +// extern_path would have provided for cross-crate type resolution. +func (s *RustLibrary) Reexports() []string { + ownPkg := s.Pkg() + if ownPkg == "" { + return nil + } + + resolver := protoc.GlobalResolver() + out := make([]string, 0) + seen := make(map[string]bool) + + for _, f := range s.Config.Library.Files() { + for _, imp := range f.Imports() { + results := resolver.Resolve("proto", "prost_extern", imp.Filename) + if len(results) == 0 { + continue + } + impPkg := results[0].Label.Pkg + impCrate := results[0].Label.Name + if impPkg == "" || impCrate == "" { + continue + } + if !strings.HasPrefix(ownPkg, impPkg+".") { + // Not a strict prefix-parent of our own package — handled + // via the regular extern_path mechanism. + continue + } + entry := impCrate + "=" + impPkg + if seen[entry] { + continue + } + seen[entry] = true + out = append(out, entry) + } + } + + sort.Strings(out) + return out +} + +// Imports implements part of the RuleProvider interface. +func (s *RustLibrary) Imports(c *config.Config, r *rule.Rule, file *rule.File) []resolve.ImportSpec { + libs, ok := s.protoLibrariesByRule[s.id] + if !ok { + return nil + } + return protoc.ProtoLibraryImportSpecsForKind(r.Kind(), libs...) +} + +// Resolve implements part of the RuleProvider interface. 
+func (s *RustLibrary) Resolve(c *config.Config, ix *resolve.RuleIndex, r *rule.Rule, imports []string, from label.Label) { + s.Resolver(c, ix, r, imports, from) + if reexports := s.Reexports(); len(reexports) > 0 { + r.SetAttr("reexports", reexports) + } +} diff --git a/rules/private/proto_repository_tools.bzl b/rules/private/proto_repository_tools.bzl index 1e4a1ae4..957e01c9 100644 --- a/rules/private/proto_repository_tools.bzl +++ b/rules/private/proto_repository_tools.bzl @@ -68,11 +68,11 @@ def _proto_repository_tools_impl(ctx): ctx.path(ctx.attr._list_repository_tools_srcs), "-dir=src/github.com/stackb/rules_proto/v4", # Run it under 'check' to assert file is up-to-date - "-check=rules/private/proto_repository_tools_srcs.bzl", + "-check=rules/private/proto_repository_tools_srcs.bzl", # Run it under 'skip' to not check (only for internal testing) # "-skip=rules/private/proto_repository_tools_srcs.bzl", # Run it under 'generate' to recreate the list - # "-generate=rules/private/proto_repository_tools_srcs.bzl", + # "-generate=rules/private/proto_repository_tools_srcs.bzl", ], environment = env, ) diff --git a/rules/private/proto_repository_tools_srcs.bzl b/rules/private/proto_repository_tools_srcs.bzl index 9810725a..74675fce 100644 --- a/rules/private/proto_repository_tools_srcs.bzl +++ b/rules/private/proto_repository_tools_srcs.bzl @@ -63,6 +63,7 @@ PROTO_REPOSITORY_TOOLS_SRCS = [ "@build_stack_rules_proto//pkg/plugin/bufbuild:connect_es_plugin.go", "@build_stack_rules_proto//pkg/plugin/bufbuild:es_plugin.go", "@build_stack_rules_proto//pkg/plugin/builtin:BUILD.bazel", + "@build_stack_rules_proto//pkg/plugin/builtin/a/b/c:BUILD.bazel", "@build_stack_rules_proto//pkg/plugin/builtin:cpp_plugin.go", "@build_stack_rules_proto//pkg/plugin/builtin:csharp_plugin.go", "@build_stack_rules_proto//pkg/plugin/builtin:doc.go", @@ -74,7 +75,9 @@ PROTO_REPOSITORY_TOOLS_SRCS = [ "@build_stack_rules_proto//pkg/plugin/builtin:php_plugin.go",
"@build_stack_rules_proto//pkg/plugin/builtin:pyi_plugin.go", "@build_stack_rules_proto//pkg/plugin/builtin:python_plugin.go", + "@build_stack_rules_proto//pkg/plugin/builtin/rel:BUILD.bazel", "@build_stack_rules_proto//pkg/plugin/builtin:ruby_plugin.go", + "@build_stack_rules_proto//pkg/plugin/builtin/src/main/java/foo:BUILD.bazel", "@build_stack_rules_proto//pkg/plugin/gogo/protobuf:BUILD.bazel", "@build_stack_rules_proto//pkg/plugin/gogo/protobuf:protoc-gen-gogo.go", "@build_stack_rules_proto//pkg/plugin/golang/protobuf:BUILD.bazel", @@ -91,6 +94,13 @@ PROTO_REPOSITORY_TOOLS_SRCS = [ "@build_stack_rules_proto//pkg/plugin/grpc/grpcweb:protoc-gen-grpc-web.go", "@build_stack_rules_proto//pkg/plugin/grpcecosystem/grpcgateway:BUILD.bazel", "@build_stack_rules_proto//pkg/plugin/grpcecosystem/grpcgateway:protoc-gen-grpc-gateway.go", + "@build_stack_rules_proto//pkg/plugin/neoeinstein/prost:BUILD.bazel", + "@build_stack_rules_proto//pkg/plugin/neoeinstein/prost:extern_paths.go", + "@build_stack_rules_proto//pkg/plugin/neoeinstein/prost:protoc-gen-prost.go", + "@build_stack_rules_proto//pkg/plugin/neoeinstein/prost_serde:BUILD.bazel", + "@build_stack_rules_proto//pkg/plugin/neoeinstein/prost_serde:protoc-gen-prost-serde.go", + "@build_stack_rules_proto//pkg/plugin/neoeinstein/tonic:BUILD.bazel", + "@build_stack_rules_proto//pkg/plugin/neoeinstein/tonic:protoc-gen-tonic.go", "@build_stack_rules_proto//pkg/plugin/scalapb/scalapb:BUILD.bazel", "@build_stack_rules_proto//pkg/plugin/scalapb/scalapb:protoc_gen_scala.go", "@build_stack_rules_proto//pkg/plugin/scalapb/zio_grpc:BUILD.bazel", @@ -129,6 +139,7 @@ PROTO_REPOSITORY_TOOLS_SRCS = [ "@build_stack_rules_proto//pkg/protoc:rule_provider.go", "@build_stack_rules_proto//pkg/protoc:rule_registry.go", "@build_stack_rules_proto//pkg/protoc:ruleindex.go", + "@build_stack_rules_proto//pkg/protoc:rust_keywords.go", "@build_stack_rules_proto//pkg/protoc:starlark_plugin.go", "@build_stack_rules_proto//pkg/protoc:starlark_rule.go", 
"@build_stack_rules_proto//pkg/protoc:starlark_util.go", @@ -158,6 +169,9 @@ PROTO_REPOSITORY_TOOLS_SRCS = [ "@build_stack_rules_proto//pkg/rule/rules_python:grpc_py_library.go", "@build_stack_rules_proto//pkg/rule/rules_python:proto_py_library.go", "@build_stack_rules_proto//pkg/rule/rules_python:py_library.go", + "@build_stack_rules_proto//pkg/rule/rules_rust:BUILD.bazel", + "@build_stack_rules_proto//pkg/rule/rules_rust:proto_rust_library.go", + "@build_stack_rules_proto//pkg/rule/rules_rust:rust_library.go", "@build_stack_rules_proto//pkg/rule/rules_scala:BUILD.bazel", "@build_stack_rules_proto//pkg/rule/rules_scala:scala_library.go", "@build_stack_rules_proto//pkg/rule/rules_scala:scala_proto_library.go", diff --git a/rules/proto_compile.bzl b/rules/proto_compile.bzl index 84129df2..793341ed 100644 --- a/rules/proto_compile.bzl +++ b/rules/proto_compile.bzl @@ -28,7 +28,12 @@ def _ctx_replace_arg(ctx, arg): arg = arg.replace("{NAME}", ctx.label.name) if arg.find("{PROTO_LIBRARY_BASENAME}") != -1: - basename = ctx.attr.proto.label.name + if ctx.attr.proto: + basename = ctx.attr.proto.label.name + elif ctx.attr.protos: + basename = ctx.attr.protos[0].label.name + else: + basename = ctx.label.name if basename.endswith("_proto"): basename = basename[:len(basename) - len("_proto")] arg = arg.replace("{PROTO_LIBRARY_BASENAME}", basename) @@ -142,8 +147,17 @@ def _proto_compile_impl(ctx): # const the protoc file from the toolchain protoc = get_protoc_executable(ctx) - # const proto provider - proto_info = ctx.attr.proto[ProtoInfo] + # const > proto providers (from proto or protos attr) + proto_infos = [] + if ctx.attr.proto: + proto_infos.append(ctx.attr.proto[ProtoInfo]) + for p in ctx.attr.protos: + proto_infos.append(p[ProtoInfo]) + if len(proto_infos) == 0: + fail("proto_compile requires either 'proto' or 'protos' attribute") + + # const primary proto provider (for descriptor path resolution) + primary_proto_info = proto_infos[0] # const > plugins to be applied 
plugins = [plugin[ProtoPluginInfo] for plugin in ctx.attr.plugins] @@ -152,7 +166,9 @@ def _proto_compile_impl(ctx): outs = {_plugin_label_key(Label(k)): v for k, v in ctx.attr.outs.items()} # mut > set of descriptors for the compile action - descriptors = proto_info.transitive_descriptor_sets.to_list() + descriptors = [] + for pi in proto_infos: + descriptors += pi.transitive_descriptor_sets.to_list() # mut > tools for the compile action tools = [protoc] @@ -179,16 +195,26 @@ def _proto_compile_impl(ctx): for plugin in plugins: ### Part 2.1: build protos list + # When using protos (plural), all ProtoInfo providers share the + # same package (that's why their outputs overlap). Only pass files + # from the first provider to protoc as file_to_generate — the + # descriptor sets from ALL providers are already included, giving + # the plugin full type information to generate the complete + # package output. This avoids duplicate CodeGeneratorResponse.File + # entries from package-level plugins like protoc-gen-prost. + gen_infos = [proto_infos[0]] if len(proto_infos) > 1 else proto_infos + # add all protos unless excluded - for proto in proto_info.direct_sources: - if any([ - proto.dirname.endswith(exclusion) or proto.path.endswith(exclusion) - for exclusion in plugin.exclusions - ]) or proto in protos: # TODO: When using import_prefix, the ProtoInfo.direct_sources list appears to contain duplicate records, this line removes these. https://github.com/bazelbuild/bazel/issues/9127 - continue + for pi in gen_infos: + for proto in pi.direct_sources: + if any([ + proto.dirname.endswith(exclusion) or proto.path.endswith(exclusion) + for exclusion in plugin.exclusions + ]) or proto in protos: # TODO: When using import_prefix, the ProtoInfo.direct_sources list appears to contain duplicate records, this line removes these. 
https://github.com/bazelbuild/bazel/issues/9127 + continue - # Proto not excluded - protos.append(proto) + # Proto not excluded + protos.append(proto) # augment proto list with those attached to plugin for info in plugin.supplementary_proto_deps: @@ -275,7 +301,7 @@ def _proto_compile_impl(ctx): protos = _uniq(protos) for proto in protos: - args.append(_descriptor_proto_path(proto, proto_info)) + args.append(_descriptor_proto_path(proto, primary_proto_info)) ### Step 3.3: build args object @@ -436,7 +462,10 @@ proto_compile = rule( ), "proto": attr.label( doc = "The single ProtoInfo provider", - mandatory = True, + providers = [ProtoInfo], + ), + "protos": attr.label_list( + doc = "List of ProtoInfo providers (use instead of proto for aggregated compilation)", providers = [ProtoInfo], ), "protoc": attr.label( diff --git a/rules/proto_compile_gencopy.bzl b/rules/proto_compile_gencopy.bzl index 0290280b..f6d5c4b6 100644 --- a/rules/proto_compile_gencopy.bzl +++ b/rules/proto_compile_gencopy.bzl @@ -62,6 +62,9 @@ def _proto_compile_gencopy_test_impl(ctx): source_file_map = {f.short_path: f for f in ctx.files.srcs} + # for k, v in source_file_map.items(): + # print("source file map:", k, v) + for info in [dep[ProtoCompileInfo] for dep in ctx.attr.deps]: # List[String]: names of files that represent the source files. In a # test, these are the file paths of actual source files that are in the