diff --git a/go/cmd/importer/main.go b/go/cmd/importer/main.go index 6fa4f82e913..b48f57e0166 100644 --- a/go/cmd/importer/main.go +++ b/go/cmd/importer/main.go @@ -81,7 +81,7 @@ func main() { } config.SourceRepoStore = db.NewSourceRepositoryStore(datastoreClient) // Needed for deletions only - config.VulnerabilityStore = db.NewVulnerabilityStore(datastoreClient) + config.VulnerabilityStore = db.NewVulnerabilityStore(datastoreClient, nil) psClient, err := pubsub.NewClient(ctx, project) if err != nil { diff --git a/go/cmd/relations/update.go b/go/cmd/relations/update.go index 6f330cea0e4..a1b3a07b677 100644 --- a/go/cmd/relations/update.go +++ b/go/cmd/relations/update.go @@ -26,6 +26,7 @@ import ( "cloud.google.com/go/datastore" "cloud.google.com/go/pubsub/v2" + osvdatastore "github.com/google/osv.dev/go/internal/database/datastore" "github.com/google/osv.dev/go/logger" "github.com/google/osv.dev/go/osv/clients" "github.com/google/osv.dev/go/osv/models" @@ -199,7 +200,7 @@ ConsumerLoop: continue } - listedVuln := models.NewListedVulnerabilityFromProto(v) + listedVuln := osvdatastore.NewListedVulnerabilityFromProto(v) listedKey := datastore.NameKey("ListedVulnerability", id, nil) if _, err := tx.Put(listedKey, listedVuln); err != nil { logger.Error("failed to put listed vuln to Datastore", slog.String("id", id), slog.Any("err", err)) diff --git a/go/internal/database/datastore/affected_versions.go b/go/internal/database/datastore/affected_versions.go new file mode 100644 index 00000000000..f5797837e13 --- /dev/null +++ b/go/internal/database/datastore/affected_versions.go @@ -0,0 +1,218 @@ +package datastore + +import ( + "net/url" + "slices" + "strings" + + "github.com/google/osv.dev/go/osv/ecosystem" + "github.com/ossf/osv-schema/bindings/go/osvschema" +) + +const ( + minCoarseVersion = "00:00000000.00000000.00000000" + maxCoarseVersion = "99:99999999.99999999.99999999" +) + +func normalizeRepo(repoURL string) string { + // Normalize the repo_url for use with GIT AffectedVersions entities. + // Removes the scheme/protocol, the .git extension, and trailing slashes. + if repoURL == "" { + return "" + } + parsed, err := url.Parse(repoURL) + if err != nil { + return repoURL + } + normalized := parsed.Host + parsed.Path + normalized = strings.TrimRight(normalized, "/") + normalized = strings.TrimSuffix(normalized, ".git") + + return normalized +} + +func computeAffectedVersions(vuln *osvschema.Vulnerability) []AffectedVersions { + var res []AffectedVersions + + for _, affected := range vuln.GetAffected() { + pkgEcosystem := affected.GetPackage().GetEcosystem() + if pkgEcosystem == "" { + continue + } + + allPkgEcosystems := []string{pkgEcosystem} + normalized, _, _ := strings.Cut(pkgEcosystem, ":") + if normalized != pkgEcosystem { + allPkgEcosystems = append(allPkgEcosystems, normalized) + } + if v := removeVariants(pkgEcosystem); v != "" { + allPkgEcosystems = append(allPkgEcosystems, v) + } + + slices.Sort(allPkgEcosystems) + allPkgEcosystems = slices.Compact(allPkgEcosystems) + + pkgName := affected.GetPackage().GetName() + eHelper, exists := ecosystem.DefaultProvider.Get(pkgEcosystem) + + // TODO(michaelkedar): Matching the current behavior of the API, + // where GIT tags match to the first git repo in the ranges list, even if + // there are non-git ranges or multiple git repos in a range. + repoURL := "" + hasAffected := false + + for _, r := range affected.GetRanges() { + if r.GetType() == osvschema.Range_GIT && repoURL == "" { + repoURL = r.GetRepo() + } + if r.GetType() != osvschema.Range_ECOSYSTEM && r.GetType() != osvschema.Range_SEMVER { + continue + } + if len(r.GetEvents()) == 0 { + continue + } + + hasAffected = true + var rangeEvents []AffectedEvent + for _, e := range r.GetEvents() { + if e.GetIntroduced() != "" { + rangeEvents = append(rangeEvents, AffectedEvent{Type: "introduced", Value: e.GetIntroduced()}) + } else if e.GetFixed() != "" { + rangeEvents = append(rangeEvents, AffectedEvent{Type: "fixed", Value: e.GetFixed()}) + } else if e.GetLimit() != "" { + rangeEvents = append(rangeEvents, AffectedEvent{Type: "limit", Value: e.GetLimit()}) + } else if e.GetLastAffected() != "" { + rangeEvents = append(rangeEvents, AffectedEvent{Type: "last_affected", Value: e.GetLastAffected()}) + } + } + + var eventsMap = map[string]int{ + "introduced": 0, + "last_affected": 1, + "fixed": 2, + "limit": 3, + } + + if exists { + // If we have an ecosystem helper, sort the events to help with querying. + slices.SortFunc(rangeEvents, func(a, b AffectedEvent) int { + pa, errA := eHelper.Parse(a.Value) + pb, errB := eHelper.Parse(b.Value) + if errA != nil || errB != nil { + if a.Value != b.Value { + return strings.Compare(a.Value, b.Value) + } + + return eventsMap[a.Type] - eventsMap[b.Type] + } + res, errC := pa.Compare(pb) + if errC != nil { + if a.Value != b.Value { + return strings.Compare(a.Value, b.Value) + } + + return eventsMap[a.Type] - eventsMap[b.Type] + } + if res != 0 { + return res + } + + return eventsMap[a.Type] - eventsMap[b.Type] + }) + } + + coarseMin := minCoarseVersion + coarseMax := maxCoarseVersion + + if exists { + for _, ev := range rangeEvents { + if ev.Type == "introduced" { + if cm, err := eHelper.Coarse(ev.Value); err == nil { + coarseMin = cm + } + last := rangeEvents[len(rangeEvents)-1] + if last.Type != "introduced" { + if cm, err := eHelper.Coarse(last.Value); err == nil { + coarseMax = cm + } + } + + break + } + } + } + + for _, e := range allPkgEcosystems { + res = append(res, AffectedVersions{ + VulnID: vuln.GetId(), + Ecosystem: e, + Name: pkgName, + Events: rangeEvents, + CoarseMin: coarseMin, + CoarseMax: coarseMax, + }) + } + } + + if pkgName != "" && len(affected.GetVersions()) > 0 { + hasAffected = true + coarseMin := minCoarseVersion + coarseMax := maxCoarseVersion + + if exists { + var allCoarse []string + for _, v := range affected.GetVersions() { + if cm, err := eHelper.Coarse(v); err == nil { + allCoarse = append(allCoarse, cm) + } + } + if len(allCoarse) > 0 { + slices.Sort(allCoarse) + coarseMin = allCoarse[0] + coarseMax = allCoarse[len(allCoarse)-1] + } + } + + for _, e := range allPkgEcosystems { + res = append(res, AffectedVersions{ + VulnID: vuln.GetId(), + Ecosystem: e, + Name: pkgName, + Versions: affected.GetVersions(), + CoarseMin: coarseMin, + CoarseMax: coarseMax, + }) + } + } + + if pkgName != "" && !hasAffected { + // We have a package that does not have any affected ranges or versions, + // which doesn't really make sense. + // Add an empty AffectedVersions entry so that this vuln is returned when + // querying the API with no version specified. + for _, e := range allPkgEcosystems { + res = append(res, AffectedVersions{ + VulnID: vuln.GetId(), + Ecosystem: e, + Name: pkgName, + CoarseMin: minCoarseVersion, + CoarseMax: maxCoarseVersion, + }) + } + } + + if repoURL != "" { + // If we have a repository, always add a GIT entry. + // Even if affected.versions is empty, we still want to return this vuln + // for the API queries with no versions specified. + res = append(res, AffectedVersions{ + VulnID: vuln.GetId(), + Ecosystem: "GIT", + Name: normalizeRepo(repoURL), + Versions: affected.GetVersions(), + }) + } + } + + return res +} diff --git a/go/internal/database/datastore/affected_versions_test.go b/go/internal/database/datastore/affected_versions_test.go new file mode 100644 index 00000000000..01eb6e8230d --- /dev/null +++ b/go/internal/database/datastore/affected_versions_test.go @@ -0,0 +1,193 @@ +package datastore + +import ( + "cmp" + "testing" + + gocmp "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/ossf/osv-schema/bindings/go/osvschema" +) + +func TestComputeAffectedVersions(t *testing.T) { + vuln := &osvschema.Vulnerability{ + Id: "TEST-123", + Affected: []*osvschema.Affected{ + { + Package: &osvschema.Package{ + Name: "testjs", + Ecosystem: "npm", + }, + Versions: []string{"0.1.0", "0.2.0", "0.3.0", "2.0.0", "2.1.0", "2.2.0"}, + Ranges: []*osvschema.Range{ + { + Type: osvschema.Range_ECOSYSTEM, + Events: []*osvschema.Event{ + {Introduced: "0"}, + {Fixed: "1.0.0"}, + }, + }, + { + Type: osvschema.Range_ECOSYSTEM, + Events: []*osvschema.Event{ + {Introduced: "2.0.0"}, + {LastAffected: "2.2.0"}, + }, + }, + }, + }, + { + Package: &osvschema.Package{ + Name: "test", + Ecosystem: "Ubuntu:24.04:LTS", + }, + Versions: []string{"1.0.0-1", "1.0.0-2"}, + Ranges: []*osvschema.Range{ + { + Type: osvschema.Range_ECOSYSTEM, + Events: []*osvschema.Event{ + {Introduced: "0"}, + {Fixed: "1.0.0-3"}, + }, + }, + }, + }, + }, + } + + got := computeAffectedVersions(vuln) + + want := []AffectedVersions{ + { + VulnID: "TEST-123", + Ecosystem: "npm", + Name: "testjs", + Events: []AffectedEvent{ + {Type: "introduced", Value: "0"}, + {Type: "fixed", Value: "1.0.0"}, + }, + CoarseMin: "00:00000000.00000000.00000000", + CoarseMax: "00:00000001.00000000.00000000", + }, + { + VulnID: "TEST-123", + Ecosystem: "npm", + Name: "testjs", + Events: []AffectedEvent{ + {Type: "introduced", Value: "2.0.0"}, + {Type: "last_affected", Value: "2.2.0"}, + }, + CoarseMin: "00:00000002.00000000.00000000", + CoarseMax: "00:00000002.00000002.00000000", + }, + { + VulnID: "TEST-123", + Ecosystem: "npm", + Name: "testjs", + Versions: []string{"0.1.0", "0.2.0", "0.3.0", "2.0.0", "2.1.0", "2.2.0"}, + CoarseMin: "00:00000000.00000001.00000000", + CoarseMax: "00:00000002.00000002.00000000", + }, + { + VulnID: "TEST-123", + Ecosystem: "Ubuntu:24.04:LTS", + Name: "test", + Events: []AffectedEvent{ + {Type: "introduced", Value: "0"}, + {Type: "fixed", Value: "1.0.0-3"}, + }, + CoarseMin: "00:00000000.00000000.00000000", + CoarseMax: "00:00000001.00000000.00000000", + }, + { + VulnID: "TEST-123", + Ecosystem: "Ubuntu:24.04", + Name: "test", + Events: []AffectedEvent{ + {Type: "introduced", Value: "0"}, + {Type: "fixed", Value: "1.0.0-3"}, + }, + CoarseMin: "00:00000000.00000000.00000000", + CoarseMax: "00:00000001.00000000.00000000", + }, + { + VulnID: "TEST-123", + Ecosystem: "Ubuntu", + Name: "test", + Events: []AffectedEvent{ + {Type: "introduced", Value: "0"}, + {Type: "fixed", Value: "1.0.0-3"}, + }, + CoarseMin: "00:00000000.00000000.00000000", + CoarseMax: "00:00000001.00000000.00000000", + }, + { + VulnID: "TEST-123", + Ecosystem: "Ubuntu:24.04:LTS", + Name: "test", + Versions: []string{"1.0.0-1", "1.0.0-2"}, + CoarseMin: "00:00000001.00000000.00000000", + CoarseMax: "00:00000001.00000000.00000000", + }, + { + VulnID: "TEST-123", + Ecosystem: "Ubuntu:24.04", + Name: "test", + Versions: []string{"1.0.0-1", "1.0.0-2"}, + CoarseMin: "00:00000001.00000000.00000000", + CoarseMax: "00:00000001.00000000.00000000", + }, + { + VulnID: "TEST-123", + Ecosystem: "Ubuntu", + Name: "test", + Versions: []string{"1.0.0-1", "1.0.0-2"}, + CoarseMin: "00:00000001.00000000.00000000", + CoarseMax: "00:00000001.00000000.00000000", + }, + } + + sortOpt := cmpopts.SortSlices(func(a, b AffectedVersions) bool { + return cmp.Or( + cmp.Compare(a.Ecosystem, b.Ecosystem), + cmp.Compare(len(a.Versions), len(b.Versions)), + cmp.Compare(len(a.Events), len(b.Events)), + cmp.Compare(a.CoarseMin, b.CoarseMin), + cmp.Compare(a.CoarseMax, b.CoarseMax), + ) < 0 + }) + + if diff := gocmp.Diff(want, got, cmpopts.EquateEmpty(), sortOpt); diff != "" { + t.Errorf("computeAffectedVersions mismatch (-want +got):\n%s", diff) + } +} + +func TestNormalizeRepo(t *testing.T) { + testCases := []struct { + repoURL string + expected string + }{ + {"http://git.musl-libc.org/git/musl", "git.musl-libc.org/git/musl"}, + {"https://git.musl-libc.org/git/musl", "git.musl-libc.org/git/musl"}, + {"git://git.musl-libc.org/git/musl", "git.musl-libc.org/git/musl"}, + {"http://github.com/user/repo", "github.com/user/repo"}, + {"https://github.com/user/repo", "github.com/user/repo"}, + {"git://github.com/user/repo", "github.com/user/repo"}, + {"https://github.com/user/repo/", "github.com/user/repo"}, + {"http://git.example.com/path/", "git.example.com/path"}, + {"https://github.com/user/repo.git", "github.com/user/repo"}, + {"http://git.example.com/repo.git", "git.example.com/repo"}, + {"", ""}, + {"http://", ""}, + {"https://hostname", "hostname"}, + } + + for _, tc := range testCases { + t.Run(tc.repoURL, func(t *testing.T) { + got := normalizeRepo(tc.repoURL) + if got != tc.expected { + t.Errorf("normalizeRepo(%q) = %q, want %q", tc.repoURL, got, tc.expected) + } + }) + } +} diff --git a/go/internal/database/datastore/internal/validate/validate.go b/go/internal/database/datastore/internal/validate/validate.go index d0142fbb6e2..37a2752df70 100644 --- a/go/internal/database/datastore/internal/validate/validate.go +++ b/go/internal/database/datastore/internal/validate/validate.go @@ -92,6 +92,22 @@ func readRecords(ctx context.Context, client *datastore.Client) { fmt.Printf("(Go) Failed getting SourceRepository: %v\n", err) os.Exit(1) } + + fmt.Println("(Go) Getting AffectedCommits") + key = datastore.NameKey("AffectedCommits", "CVE-123-456", nil) + var affectedCommits db.AffectedCommits + if err := client.Get(ctx, key, &affectedCommits); err != nil { + fmt.Printf("(Go) Failed getting AffectedCommits: %v\n", err) + os.Exit(1) + } + + fmt.Println("(Go) Getting AffectedVersions") + key = datastore.NameKey("AffectedVersions", "1", nil) + var affectedVersions db.AffectedVersions + if err := client.Get(ctx, key, &affectedVersions); err != nil { + fmt.Printf("(Go) Failed getting AffectedVersions: %v\n", err) + os.Exit(1) + } } func writeRecords(ctx context.Context, client *datastore.Client) { @@ -217,4 +233,33 @@ func writeRecords(ctx context.Context, client *datastore.Client) { fmt.Printf("(Go) Failed writing SourceRepository %v: %v\n", key, err) os.Exit(1) } + + fmt.Println("(Go) Writing AffectedCommits") + key = datastore.NameKey("AffectedCommits", "CVE-987-654", nil) + affectedCommits := db.AffectedCommits{ + VulnID: "CVE-987-654", + Commits: [][]byte{[]byte("hash1"), []byte("hash2")}, + Public: true, + Page: 1, + } + if _, err := client.Put(ctx, key, &affectedCommits); err != nil { + fmt.Printf("(Go) Failed writing AffectedCommits %v: %v\n", key, err) + os.Exit(1) + } + + fmt.Println("(Go) Writing AffectedVersions") + key = datastore.NameKey("AffectedVersions", "2", nil) + affectedVersions := db.AffectedVersions{ + VulnID: "CVE-987-654", + Ecosystem: "Go", + Name: "stdlib", + Versions: []string{"v1.0.0", "v1.1.0"}, + Events: []db.AffectedEvent{{Type: "introduced", Value: "v1.0.0"}}, + CoarseMin: "00:00000001.00000000.00000000", + CoarseMax: "00:00000001.00000001.00000000", + } + if _, err := client.Put(ctx, key, &affectedVersions); err != nil { + fmt.Printf("(Go) Failed writing AffectedVersions %v: %v\n", key, err) + os.Exit(1) + } } diff --git a/go/internal/database/datastore/internal/validate/validate.py b/go/internal/database/datastore/internal/validate/validate.py index 99a4b9a63d2..42270a2651a 100644 --- a/go/internal/database/datastore/internal/validate/validate.py +++ b/go/internal/database/datastore/internal/validate/validate.py @@ -21,7 +21,8 @@ import osv.tests from osv import Vulnerability, AliasGroup, AliasAllowListEntry, \ AliasDenyListEntry, ListedVulnerability, Severity, UpstreamGroup, \ - RelatedGroup, SourceRepository, SourceRepositoryType + RelatedGroup, SourceRepository, SourceRepositoryType, AffectedCommits, \ + AffectedVersions, AffectedEvent def main() -> int: @@ -91,6 +92,27 @@ def main() -> int: modified=datetime.datetime(2025, 6, 7, 8, 9, 10, tzinfo=datetime.UTC), ).put() + print('(Python) Putting AffectedCommits') + AffectedCommits( + id='CVE-123-456', + bug_id='CVE-123-456', + commits=[b'hash1', b'hash2'], + public=True, + page=1, + ).put() + + print('(Python) Putting AffectedVersions') + AffectedVersions( + id='1', + vuln_id='CVE-123-456', + ecosystem='Go', + name='stdlib', + versions=['v1.0.0', 'v1.1.0'], + events=[AffectedEvent(type='introduced', value='v1.0.0')], + coarse_min='00:00000001.00000000.00000000', + coarse_max='00:00000001.00000001.00000000', + ).put() + print('(Python) Putting SourceRepository') SourceRepository( id='oss-fuzz', @@ -149,8 +171,11 @@ def main() -> int: print('(Python) Getting RelatedGroup') if RelatedGroup.get_by_id('CVE-987-654') is None: return 1 - print('(Python) Getting SourceRepository') - if SourceRepository.get_by_id('go-source') is None: + print('(Python) Getting AffectedCommits') + if AffectedCommits.get_by_id('CVE-987-654') is None: + return 1 + print('(Python) Getting AffectedVersions') + if AffectedVersions.get_by_id('2') is None: return 1 return 0 diff --git a/go/osv/models/listedvulnerability.go b/go/internal/database/datastore/listed_vulnerability.go similarity index 97% rename from go/osv/models/listedvulnerability.go rename to go/internal/database/datastore/listed_vulnerability.go index 353757921a2..5398798db1e 100644 --- a/go/osv/models/listedvulnerability.go +++ b/go/internal/database/datastore/listed_vulnerability.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package models +package datastore import ( "regexp" @@ -22,6 +22,9 @@ import ( "github.com/ossf/osv-schema/bindings/go/osvschema" ) +// The maximum number of search indices and autocomplete tags to store (each). +// Datastore has a limit of 20,000 indexed properties + composite indexes. +// We use 2,000 to be conservative. const maxIndices = 2000 var nonAlphanumericRegex = regexp.MustCompile(`[^a-zA-Z0-9]+`) diff --git a/go/internal/database/datastore/listed_vulnerability_test.go b/go/internal/database/datastore/listed_vulnerability_test.go new file mode 100644 index 00000000000..578cf4dada9 --- /dev/null +++ b/go/internal/database/datastore/listed_vulnerability_test.go @@ -0,0 +1,108 @@ +package datastore + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/ossf/osv-schema/bindings/go/osvschema" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestNewListedVulnerabilityFromProto(t *testing.T) { + vuln := &osvschema.Vulnerability{ + Id: "TEST-123", + Published: timestamppb.New(time.Date(2025, time.January, 1, 0, 0, 0, 0, time.UTC)), + Summary: "This is a vuln", + Severity: []*osvschema.Severity{ + { + Type: osvschema.Severity_CVSS_V2, + Score: "AV:N/AC:L/Au:S/C:P/I:P/A:N", + }, + }, + Affected: []*osvschema.Affected{ + { + Package: &osvschema.Package{ + Name: "testjs", + Ecosystem: "npm", + }, + Versions: []string{"1.0.0"}, + Ranges: []*osvschema.Range{ + { + Type: osvschema.Range_GIT, + Repo: "https://github.com/test/test", + Events: []*osvschema.Event{ + {Introduced: "0"}, + {Fixed: "1.2.3"}, + }, + }, + }, + }, + { + Package: &osvschema.Package{ + Name: "test", + Ecosystem: "Ubuntu:24.04:LTS", + }, + Ranges: []*osvschema.Range{ + { + Type: osvschema.Range_GIT, + Repo: "https://github.com/test/test2", + Events: []*osvschema.Event{ + {Introduced: "0"}, + {Fixed: "1.2.3"}, + }, + }, + }, + Severity: []*osvschema.Severity{ + { + Type: osvschema.Severity_Ubuntu, + Score: "High", + }, + }, + }, + { + Package: &osvschema.Package{ + Name: "test", + Ecosystem: "Ubuntu:25.04", + }, + Severity: []*osvschema.Severity{ + { + Type: osvschema.Severity_Ubuntu, + Score: "Low", + }, + }, + }, + }, + } + + got := NewListedVulnerabilityFromProto(vuln) + + want := &ListedVulnerability{ + Published: time.Date(2025, time.January, 1, 0, 0, 0, 0, time.UTC), + Ecosystems: []string{"GIT", "Ubuntu", "npm"}, + Packages: []string{ + "Ubuntu:24.04:LTS/test", + "Ubuntu:25.04/test", + "github.com/test/test", + "github.com/test/test2", + "npm/testjs", + }, + Summary: "This is a vuln", + IsFixed: true, + Severities: []Severity{ + {Type: "CVSS_V2", Score: "AV:N/AC:L/Au:S/C:P/I:P/A:N"}, + {Type: "Ubuntu", Score: "High"}, + {Type: "Ubuntu", Score: "Low"}, + }, + } + + ignoreOpts := cmpopts.IgnoreFields(ListedVulnerability{}, "AutocompleteTags", "SearchIndices") + sortOpt := cmpopts.SortSlices(func(a, b Severity) bool { + return a.Score < b.Score + }) + + if diff := cmp.Diff(want, got, cmpopts.EquateEmpty(), ignoreOpts, sortOpt); diff != "" { + t.Errorf("NewListedVulnerabilityFromProto mismatch (-want +got):\n%s", diff) + } +} diff --git a/go/internal/database/datastore/models.go b/go/internal/database/datastore/models.go index 330e78e76fb..d7d105f89dc 100644 --- a/go/internal/database/datastore/models.go +++ b/go/internal/database/datastore/models.go @@ -16,6 +16,8 @@ package datastore import ( + "slices" + "strings" "time" "cloud.google.com/go/datastore" @@ -33,6 +35,53 @@ type Vulnerability struct { UpstreamRaw []string `datastore:"upstream_raw"` } +type AffectedEvent struct { + Type string `datastore:"type"` + Value string `datastore:"value"` +} + +type AffectedVersions struct { + VulnID string `datastore:"vuln_id"` + Ecosystem string `datastore:"ecosystem"` + Name string `datastore:"name"` + Versions []string `datastore:"versions,noindex"` + Events []AffectedEvent `datastore:"events"` + CoarseMin string `datastore:"coarse_min"` + CoarseMax string `datastore:"coarse_max"` +} + +func (av AffectedVersions) sortKey() string { + // Serializes all fields using the ASCII Unit Separator (\x1f) as a delimiter. + // This provides a stable unique string hash for diffing old vs new entities. + var b strings.Builder + b.WriteString(av.VulnID) + b.WriteString("\x1f") + b.WriteString(av.Ecosystem) + b.WriteString("\x1f") + b.WriteString(av.Name) + b.WriteString("\x1f") + sortedVersions := make([]string, len(av.Versions)) + copy(sortedVersions, av.Versions) + slices.Sort(sortedVersions) + b.WriteString(strings.Join(sortedVersions, ",")) + b.WriteString("\x1f") + for _, e := range av.Events { + b.WriteString(e.Type) + b.WriteString(":") + b.WriteString(e.Value) + b.WriteString(",") + } + + return b.String() +} + +type AffectedCommits struct { + VulnID string `datastore:"bug_id"` + Commits [][]byte `datastore:"commits"` + Public bool `datastore:"public"` + Page int `datastore:"page,noindex"` +} + type AliasGroup struct { VulnIDs []string `datastore:"bug_ids"` Modified time.Time `datastore:"last_modified"` diff --git a/go/internal/database/datastore/vulnerability.go b/go/internal/database/datastore/vulnerability.go index 3bae3dd6b28..a0a07e3475e 100644 --- a/go/internal/database/datastore/vulnerability.go +++ b/go/internal/database/datastore/vulnerability.go @@ -1,27 +1,32 @@ package datastore import ( + "bytes" "context" "errors" "fmt" "iter" + "slices" "strings" "time" "cloud.google.com/go/datastore" "github.com/google/osv.dev/go/internal/models" + "github.com/google/osv.dev/go/osv/clients" "github.com/ossf/osv-schema/bindings/go/osvschema" "google.golang.org/api/iterator" + "google.golang.org/protobuf/proto" ) type VulnerabilityStore struct { - client *datastore.Client + client *datastore.Client + gcsStore clients.CloudStorage } var _ models.VulnerabilityStore = (*VulnerabilityStore)(nil) -func NewVulnerabilityStore(client *datastore.Client) *VulnerabilityStore { - return &VulnerabilityStore{client: client} +func NewVulnerabilityStore(client *datastore.Client, gcsStore clients.CloudStorage) *VulnerabilityStore { + return &VulnerabilityStore{client: client, gcsStore: gcsStore} } func (s *VulnerabilityStore) ListBySource(ctx context.Context, source string, skipWithdrawn bool) iter.Seq2[*models.VulnSourceRef, error] { @@ -82,14 +87,254 @@ func (s *VulnerabilityStore) GetSourceModified(ctx context.Context, id string) ( return v.ModifiedRaw, nil } -func (s *VulnerabilityStore) Get(_ context.Context, _ string) (*osvschema.Vulnerability, error) { - panic("not implemented") +func (s *VulnerabilityStore) Get(ctx context.Context, id string) (*osvschema.Vulnerability, error) { + path := fmt.Sprintf("all/pb/%s.pb", id) + data, err := s.gcsStore.ReadObject(ctx, path) + if err != nil { + if errors.Is(err, clients.ErrNotFound) { + return nil, models.ErrNotFound + } + + return nil, err + } + + v := &osvschema.Vulnerability{} + if err := proto.Unmarshal(data, v); err != nil { + return nil, err + } + + return v, nil } -func (s *VulnerabilityStore) GetWithMetadata(_ context.Context, _ string) (*osvschema.Vulnerability, *models.VulnSourceRef, error) { - panic("not implemented") +func (s *VulnerabilityStore) GetWithMetadata(ctx context.Context, id string) (*osvschema.Vulnerability, *models.VulnSourceRef, error) { + key := datastore.NameKey("Vulnerability", id, nil) + var dv Vulnerability + if err := s.client.Get(ctx, key, &dv); err != nil { + if errors.Is(err, datastore.ErrNoSuchEntity) { + return nil, nil, models.ErrNotFound + } + + return nil, nil, err + } + v, err := s.Get(ctx, id) + if errors.Is(err, models.ErrNotFound) { + return nil, nil, errors.New("vulnerability found in Datastore but not in GCS") + } + if err != nil { + return nil, nil, err + } + + source, path, _ := strings.Cut(dv.SourceID, ":") + + ref := &models.VulnSourceRef{ + ID: id, + Source: source, + Path: path, + ModifiedRaw: dv.ModifiedRaw, + } + + return v, ref, nil } -func (s *VulnerabilityStore) Write(_ context.Context, _ models.WriteRequest) error { - panic("not implemented") +func (s *VulnerabilityStore) Write(ctx context.Context, req models.WriteRequest) error { + // Updates the base records (Vulnerability, ListedVulnerability, AffectedVersions) + // in a single atomic Datastore transaction. Does not write to GCS. + if err := s.writeBaseEntities(ctx, req); err != nil { + return err + } + + if !req.AffectedCommits.Skip { + if err := s.updateAffectedCommits(ctx, req.ID, req.AffectedCommits.Commits); err != nil { + return err + } + } + + return s.uploadToGCS(ctx, req.ID, req.Enriched) +} + +func (s *VulnerabilityStore) writeBaseEntities(ctx context.Context, req models.WriteRequest) error { + tx, err := s.client.NewTransaction(ctx) + if err != nil { + return fmt.Errorf("failed to start transaction: %w", err) + } + // Use a deferred rollback as a safety net for early returns or panics. + // Discard error because it returns an error if transaction was already committed. + defer func() { + _ = tx.Rollback() + }() + + dsVulnKey := datastore.NameKey("Vulnerability", req.ID, nil) + dsVuln := new(Vulnerability) + if err := tx.Get(dsVulnKey, dsVuln); err != nil && !errors.Is(err, datastore.ErrNoSuchEntity) { + return fmt.Errorf("failed to get old Vulnerability: %w", err) + } + + dsVuln.SourceID = req.Source + ":" + req.Path + dsVuln.Modified = req.Enriched.GetModified().AsTime() + dsVuln.IsWithdrawn = req.Enriched.GetWithdrawn() != nil + + if req.Raw != nil { + dsVuln.ModifiedRaw = req.Raw.GetModified().AsTime() + dsVuln.AliasRaw = req.Raw.GetAliases() + dsVuln.RelatedRaw = req.Raw.GetRelated() + dsVuln.UpstreamRaw = req.Raw.GetUpstream() + } + + listedVulnKey := datastore.NameKey("ListedVulnerability", req.ID, nil) + listedVuln := NewListedVulnerabilityFromProto(req.Enriched) + + if dsVuln.IsWithdrawn { + if _, err := tx.Put(dsVulnKey, dsVuln); err != nil { + return fmt.Errorf("failed to put Vulnerability: %w", err) + } + if err := tx.Delete(listedVulnKey); err != nil && !errors.Is(err, datastore.ErrNoSuchEntity) { + return fmt.Errorf("failed to delete ListedVulnerability: %w", err) + } + } else { + if _, err := tx.Put(dsVulnKey, dsVuln); err != nil { + return fmt.Errorf("failed to put Vulnerability: %w", err) + } + if _, err := tx.Put(listedVulnKey, listedVuln); err != nil { + return fmt.Errorf("failed to put ListedVulnerability: %w", err) + } + } + + if err := s.updateAffectedVersions(ctx, tx, dsVuln.IsWithdrawn, req.Enriched); err != nil { + return err + } + + if _, err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + return nil +} + +func (s *VulnerabilityStore) updateAffectedVersions(ctx context.Context, tx *datastore.Transaction, isWithdrawn bool, enriched *osvschema.Vulnerability) error { + var oldAffected []AffectedVersions + q := datastore.NewQuery("AffectedVersions").FilterField("vuln_id", "=", enriched.GetId()).Transaction(tx) + keys, err := s.client.GetAll(ctx, q, &oldAffected) + if err != nil { + return fmt.Errorf("failed to query old AffectedVersions: %w", err) + } + + if isWithdrawn { + for _, k := range keys { + if err := tx.Delete(k); err != nil && !errors.Is(err, datastore.ErrNoSuchEntity) { + return fmt.Errorf("failed to delete old AffectedVersions: %w", err) + } + } + } else { + newAffected := computeAffectedVersions(enriched) + + oldKeysMap := make(map[string]*datastore.Key) + for i, av := range oldAffected { + oldKeysMap[av.sortKey()] = keys[i] + } + + newKeys := make(map[string]bool) + var added []AffectedVersions + for _, nav := range newAffected { + newKeys[nav.sortKey()] = true + if _, exists := oldKeysMap[nav.sortKey()]; !exists { + added = append(added, nav) + } + } + + var removed []*datastore.Key + for i, av := range oldAffected { + if !newKeys[av.sortKey()] { + removed = append(removed, keys[i]) + } + } + + for _, av := range added { + incompleteKey := datastore.IncompleteKey("AffectedVersions", nil) + if _, err := tx.Put(incompleteKey, &av); err != nil { + return fmt.Errorf("failed to put new AffectedVersions: %w", err) + } + } + + for _, k := range removed { + if err := tx.Delete(k); err != nil && !errors.Is(err, datastore.ErrNoSuchEntity) { + return fmt.Errorf("failed to delete removed AffectedVersions: %w", err) + } + } + } + + return nil +} + +func (s *VulnerabilityStore) updateAffectedCommits(ctx context.Context, id string, commits [][]byte) error { + // Write batched commit indexes. + // Sort the commits for some determinism. + sortedCommits := make([][]byte, len(commits)) + copy(sortedCommits, commits) + slices.SortFunc(sortedCommits, bytes.Compare) + + var toPut []*AffectedCommits + var keys []*datastore.Key + numPages := 0 + const batchSize = 10000 + for i := 0; i < len(sortedCommits); i += batchSize { + end := i + batchSize + if end > len(sortedCommits) { + end = len(sortedCommits) + } + batch := sortedCommits[i:end] + + acKey := datastore.NameKey("AffectedCommits", fmt.Sprintf("%s-%d", id, numPages), nil) + ac := &AffectedCommits{ + VulnID: id, + Commits: batch, + Public: true, + Page: numPages, + } + toPut = append(toPut, ac) + keys = append(keys, acKey) + numPages++ + } + + if len(toPut) > 0 { + if _, err := s.client.PutMulti(ctx, keys, toPut); err != nil { + return fmt.Errorf("failed to write AffectedCommits: %w", err) + } + } + + // Clear any previously written pages above our current page count. + q := datastore.NewQuery("AffectedCommits").FilterField("bug_id", "=", id) + var existing []AffectedCommits + keys, err := s.client.GetAll(ctx, q, &existing) + if err != nil { + return fmt.Errorf("failed to query AffectedCommits: %w", err) + } + + for idx, k := range keys { + if existing[idx].Page >= numPages { + if err := s.client.Delete(ctx, k); err != nil { + return fmt.Errorf("failed to delete AffectedCommits: %w", err) + } + } + } + + return nil +} + +func (s *VulnerabilityStore) uploadToGCS(ctx context.Context, id string, enriched *osvschema.Vulnerability) error { + path := fmt.Sprintf("all/pb/%s.pb", id) + newData, err := proto.Marshal(enriched) + if err != nil { + return fmt.Errorf("failed to marshal enriched vulnerability: %w", err) + } + + customTime := enriched.GetModified().AsTime() + opts := &clients.WriteOptions{ + CustomTime: &customTime, + } + + if err := s.gcsStore.WriteObject(ctx, path, newData, opts); err != nil { + return fmt.Errorf("failed to write vulnerability object to GCS: %w", err) + } + + return nil } diff --git a/go/internal/database/datastore/vulnerability_test.go b/go/internal/database/datastore/vulnerability_test.go index add5e2b510d..752f93d44e2 100644 --- a/go/internal/database/datastore/vulnerability_test.go +++ b/go/internal/database/datastore/vulnerability_test.go @@ -2,6 +2,7 @@ package datastore import ( "context" + "errors" "testing" "cloud.google.com/go/datastore" @@ -9,12 +10,15 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/osv.dev/go/internal/models" "github.com/google/osv.dev/go/testutils" + "github.com/ossf/osv-schema/bindings/go/osvschema" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" ) func TestVulnerabilityStore_ListBySource(t *testing.T) { ctx := context.Background() dsClient := testutils.MustNewDatastoreClientForTesting(t) - store := NewVulnerabilityStore(dsClient) + store := NewVulnerabilityStore(dsClient, testutils.NewMockStorage()) vulns := []Vulnerability{ { @@ -100,3 +104,198 @@ func TestVulnerabilityStore_ListBySource(t *testing.T) { }) } } + +func TestVulnerabilityStore_Get(t *testing.T) { + ctx := context.Background() + dsClient := testutils.MustNewDatastoreClientForTesting(t) + gcsStore := testutils.NewMockStorage() + store := NewVulnerabilityStore(dsClient, gcsStore) + + v := &osvschema.Vulnerability{ + Id: "GHSA-xxxx-yyyy-zzzz", + } + data, err := proto.Marshal(v) + if err != nil { + t.Fatalf("failed to marshal vulnerability: %v", err) + } + + if err := gcsStore.WriteObject(ctx, "all/pb/GHSA-xxxx-yyyy-zzzz.pb", data, nil); err != nil { + t.Fatalf("failed to write mock GCS object: %v", err) + } + + got, err := store.Get(ctx, "GHSA-xxxx-yyyy-zzzz") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + + if got.GetId() != v.GetId() { + t.Errorf("Get() id mismatch, got %s, want %s", got.GetId(), v.GetId()) + } +} + +func TestVulnerabilityStore_Get_NotFound(t *testing.T) { + ctx := context.Background() + dsClient := testutils.MustNewDatastoreClientForTesting(t) + gcsStore := testutils.NewMockStorage() + store := NewVulnerabilityStore(dsClient, gcsStore) + + _, err := store.Get(ctx, "GHSA-not-exists") + if !errors.Is(err, models.ErrNotFound) { + t.Errorf("want ErrNotFound, got %v", err) + } +} + +func TestVulnerabilityStore_GetWithMetadata(t *testing.T) { + ctx := context.Background() + dsClient := testutils.MustNewDatastoreClientForTesting(t) + gcsStore := testutils.NewMockStorage() + store := NewVulnerabilityStore(dsClient, gcsStore) + + v := &osvschema.Vulnerability{ + Id: "GHSA-xxxx-yyyy-zzzz", + } + data, err := proto.Marshal(v) + if err != nil { + t.Fatalf("failed to marshal vulnerability: %v", err) + } + + if err := gcsStore.WriteObject(ctx, "all/pb/GHSA-xxxx-yyyy-zzzz.pb", data, nil); err != nil { + t.Fatalf("failed to write mock GCS object: %v", err) + } + + key := datastore.NameKey("Vulnerability", "GHSA-xxxx-yyyy-zzzz", nil) + dv := Vulnerability{ + SourceID: "source-x:path/1.json", + } + if _, err := dsClient.Put(ctx, key, &dv); err != nil { + t.Fatalf("failed to setup datastore test data: %v", err) + } + + gotV, gotRef, err := store.GetWithMetadata(ctx, "GHSA-xxxx-yyyy-zzzz") + if err != nil { + t.Fatalf("GetWithMetadata() error = %v", err) + } + + if gotV.GetId() != v.GetId() { + t.Errorf("GetWithMetadata() id mismatch, got %s, want %s", gotV.GetId(), v.GetId()) + } + + if gotRef.Source != "source-x" { + t.Errorf("gotRef.Source mismatch, got %s, want source-x", gotRef.Source) + } + + if gotRef.Path != "path/1.json" { + t.Errorf("gotRef.Path mismatch, got %s, want path/1.json", gotRef.Path) + } +} + +func TestVulnerabilityStore_GetWithMetadata_NotFound(t *testing.T) { + ctx := context.Background() + dsClient := testutils.MustNewDatastoreClientForTesting(t) + gcsStore := testutils.NewMockStorage() + store := NewVulnerabilityStore(dsClient, gcsStore) + + _, _, err := store.GetWithMetadata(ctx, "GHSA-not-exists") + if !errors.Is(err, models.ErrNotFound) { + t.Errorf("want ErrNotFound, got %v", err) + } +} + +func TestVulnerabilityStore_Write(t *testing.T) { + ctx := context.Background() + dsClient := testutils.MustNewDatastoreClientForTesting(t) + gcsStore := testutils.NewMockStorage() + store := NewVulnerabilityStore(dsClient, gcsStore) + + v := &osvschema.Vulnerability{ + Id: "TEST-WRITE-123", + Published: timestamppb.Now(), + Affected: []*osvschema.Affected{ + { + Package: &osvschema.Package{ + Name: "test-pkg", + Ecosystem: "npm", + }, + Ranges: []*osvschema.Range{ + { + Type: osvschema.Range_ECOSYSTEM, + Events: []*osvschema.Event{ + {Introduced: "1.0.0"}, + {Fixed: "2.0.0"}, + }, + }, + }, + }, + }, + } + + req := models.WriteRequest{ + ID: "TEST-WRITE-123", + Source: "test-source", + Path: "TEST-WRITE-123.json", + Raw: v, + Enriched: v, + AffectedCommits: models.AffectedCommitsResult{ + Skip: false, + Commits: [][]byte{[]byte("test-commit")}, + }, + } + + if err := store.Write(ctx, req); err != nil { + t.Fatalf("Write() failed: %v", err) + } + + // 1. Verify Vulnerability was written + var dv Vulnerability + vKey := datastore.NameKey("Vulnerability", req.ID, nil) + if err := dsClient.Get(ctx, vKey, &dv); err != nil { + t.Fatalf("failed to fetch Vulnerability after Write: %v", err) + } + if dv.SourceID != "test-source:TEST-WRITE-123.json" { + t.Errorf("Vulnerability.SourceID mismatch, got %s", dv.SourceID) + } + + // 2. Verify ListedVulnerability was written + var lv ListedVulnerability + lvKey := datastore.NameKey("ListedVulnerability", req.ID, nil) + if err := dsClient.Get(ctx, lvKey, &lv); err != nil { + t.Fatalf("failed to fetch ListedVulnerability after Write: %v", err) + } + if lv.Summary != v.GetSummary() { + t.Errorf("ListedVulnerability Summary mismatch") + } + + // 3. Verify AffectedCommits was written + var ac AffectedCommits + acKey := datastore.NameKey("AffectedCommits", "TEST-WRITE-123-0", nil) + if err := dsClient.Get(ctx, acKey, &ac); err != nil { + t.Fatalf("failed to fetch AffectedCommits after Write: %v", err) + } + if len(ac.Commits) != 1 || string(ac.Commits[0]) != "test-commit" { + t.Errorf("AffectedCommits mismatch, got %v", ac.Commits) + } + + // 4. Verify AffectedVersions was written + var avs []AffectedVersions + q := datastore.NewQuery("AffectedVersions").FilterField("vuln_id", "=", req.ID) + if _, err := dsClient.GetAll(ctx, q, &avs); err != nil { + t.Fatalf("failed to query AffectedVersions after Write: %v", err) + } + if len(avs) == 0 { + t.Errorf("Expected AffectedVersions to be written, got 0") + } + + // 5. Verify GCS object was written + path := "all/pb/TEST-WRITE-123.pb" + gcsData, err := gcsStore.ReadObject(ctx, path) + if err != nil { + t.Fatalf("failed to read GCS object after Write: %v", err) + } + gotV := new(osvschema.Vulnerability) + if err := proto.Unmarshal(gcsData, gotV); err != nil { + t.Fatalf("failed to unmarshal vulnerability from GCS: %v", err) + } + if gotV.GetId() != v.GetId() { + t.Errorf("GCS Vulnerability ID mismatch, got %s, want %s", gotV.GetId(), v.GetId()) + } +}