diff --git a/github.go b/github.go index 5a5b080..c3f01c3 100644 --- a/github.go +++ b/github.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "net/url" + "strings" "time" ) @@ -101,6 +102,16 @@ func (c *githubClient) withEnterpriseURL(baseURL string) (*githubClient, error) return nil, fmt.Errorf("failed to parse base URL: %w", err) } + if !strings.HasSuffix(base.Path, "/") { + base.Path += "/" + } + + if !strings.HasSuffix(base.Path, "/api/v3/") && + !strings.HasPrefix(base.Host, "api.") && + !strings.Contains(base.Host, ".api.") { + base.Path += "api/v3/" + } + c.baseURL = base return c, nil diff --git a/github_test.go b/github_test.go index 1c4f20b..5f02937 100644 --- a/github_test.go +++ b/github_test.go @@ -12,33 +12,84 @@ import ( func Test_githubClient_withEnterpriseURL(t *testing.T) { tests := []struct { - name string - baseURL string - wantErr bool + name string + baseURL string + wantErr bool + expectedBaseURL string }{ { - name: "valid URL", - baseURL: "https://github.example.com", - wantErr: false, + name: "valid URL with first subdomain", + baseURL: "https://api.github.example.com", + wantErr: false, + expectedBaseURL: "https://api.github.example.com/", + }, + { + name: "valid URL with first subdomain and trailing slash", + baseURL: "https://api.github.example.com/", + wantErr: false, + expectedBaseURL: "https://api.github.example.com/", + }, + { + name: "valid URL with second subdomain", + baseURL: "https://ghes.api.example.com", + wantErr: false, + expectedBaseURL: "https://ghes.api.example.com/", + }, + { + name: "valid URL with second subdomain and trailing slash", + baseURL: "https://ghes.api.example.com/", + wantErr: false, + expectedBaseURL: "https://ghes.api.example.com/", + }, + { + name: "valid URL with path", + baseURL: "https://github.example.com/api/v3", + wantErr: false, + expectedBaseURL: "https://github.example.com/api/v3/", }, { - name: "invalid URL with control characters", - baseURL: "ht\ntp://invalid", - wantErr: true, + name: "valid URL with path and trailing slash", + baseURL: "https://github.example.com/api/v3/", + wantErr: false, + expectedBaseURL: "https://github.example.com/api/v3/", }, { - name: "URL with spaces", - baseURL: "http://invalid url with spaces", - wantErr: true, + name: "valid URL without path", + baseURL: "https://github.example.com", + wantErr: false, + expectedBaseURL: "https://github.example.com/api/v3/", + }, + { + name: "valid URL without path but with trailing slash", + baseURL: "https://github.example.com/", + wantErr: false, + expectedBaseURL: "https://github.example.com/api/v3/", + }, + { + name: "invalid URL with control characters", + baseURL: "ht\ntp://invalid", + wantErr: true, + expectedBaseURL: "", + }, + { + name: "URL with spaces", + baseURL: "http://invalid url with spaces", + wantErr: true, + expectedBaseURL: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client := newGitHubClient(&http.Client{}) - _, err := client.withEnterpriseURL(tt.baseURL) + githubClient, err := client.withEnterpriseURL(tt.baseURL) + if (err != nil) != tt.wantErr { - t.Errorf("withEnterpriseURL() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("withEnterpriseURL(%v) error = %v", tt.baseURL, err) + } + + if err == nil && githubClient.baseURL.String() != tt.expectedBaseURL { + t.Errorf("withEnterpriseURL(%v) expected = %v, received = %v", tt.baseURL, tt.expectedBaseURL, githubClient.baseURL) } }) }