diff --git a/go.mod b/go.mod index fa2c3e0f8..5a1c68df2 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,6 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect - github.com/kr/pretty v0.3.1 // indirect github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect github.com/lestrrat-go/blackmagic v1.0.2 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect @@ -31,7 +30,6 @@ require ( github.com/lestrrat-go/option v1.0.1 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/stretchr/objx v0.5.2 // indirect golang.org/x/crypto v0.35.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/go.sum b/go.sum index c0b10b6fd..f4841fb1b 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,5 @@ cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -13,8 +12,6 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/mux v1.6.2 h1:Pgr17XVTNXAk3q/r4CpKzC5xBM/qW1uVLV+IhRZpIIk= @@ -27,13 +24,11 @@ github.com/gorilla/sessions v1.1.1 h1:YMDmfaK68mUixINzY/XjscuJ47uXFWSSHzFbBQM0Pr github.com/gorilla/sessions v1.1.1/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= github.com/jarcoal/httpmock v0.0.0-20180424175123-9c70cfe4a1da h1:FjHUJJ7oBW4G/9j1KzlHaXL09LyMVM9rupS39lncbXk= github.com/jarcoal/httpmock v0.0.0-20180424175123-9c70cfe4a1da/go.mod h1:ks+b9deReOc7jgqp+e7LuFiCBH6Rm5hL32cLcEAArb4= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lestrrat-go/backoff/v2 v2.0.8 h1:oNb5E5isby2kiro9AgdHLv5N5tint1AnDVVf2E2un5A= github.com/lestrrat-go/backoff/v2 v2.0.8/go.mod h1:rHP/q/r9aT27n24JQLa7JhSQZCKBBOiM/uP402WwN8Y= github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= @@ -51,14 +46,10 @@ github.com/markbates/going v1.0.0 h1:DQw0ZP7NbNlFGcKbcE/IVSOAFzScxRtLpd0rLMzLhq0 github.com/markbates/going v1.0.0/go.mod h1:I6mnB4BPnEeqo85ynXIx1ZFLLbtiLHNXVgWeFO9OGOA= github.com/mrjones/oauth v0.0.0-20180629183705-f4e24b6d100c h1:3wkDRdxK92dF+c1ke2dtj7ZzemFWBHB9plnJOtlwdFA= github.com/mrjones/oauth v0.0.0-20180629183705-f4e24b6d100c/go.mod h1:skjdDftzkFALcuGzYSklqYd8gvat6F1gZJ4YPVbkZpM= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -85,8 +76,6 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M= -golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/providers/linkedin/linkedin.go b/providers/linkedin/linkedin.go index 5719911d7..70b28718e 100644 --- a/providers/linkedin/linkedin.go +++ b/providers/linkedin/linkedin.go @@ -2,6 +2,7 @@ package linkedin import ( + "context" "encoding/json" "errors" "fmt" @@ -82,9 +83,10 @@ func (p *Provider) BeginAuth(state string) (goth.Session, error) { func (p *Provider) FetchUser(session goth.Session) (goth.User, error) { s := session.(*Session) user := goth.User{ - AccessToken: s.AccessToken, - Provider: p.Name(), - ExpiresAt: s.ExpiresAt, + AccessToken: s.AccessToken, + Provider: p.Name(), + ExpiresAt: s.ExpiresAt, + RefreshToken: s.RefreshToken, } if user.AccessToken == "" { @@ -267,12 +269,18 @@ func newConfig(provider *Provider, scopes []string) *oauth2.Config { return c } -// RefreshToken refresh token is not provided by linkedin -func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) { - return nil, errors.New("Refresh token is not provided by linkedin") +// RefreshTokenAvailable tells whether a refresh token is provided by the auth provider or not +func (p *Provider) RefreshTokenAvailable() bool { + return true } -// RefreshTokenAvailable refresh token is not provided by linkedin -func (p *Provider) RefreshTokenAvailable() bool { - return false +// RefreshToken gets a new access token using the refresh token +func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) { + token := &oauth2.Token{RefreshToken: refreshToken} + ts := p.config.TokenSource(context.Background(), token) + newToken, err := ts.Token() + if err != nil { + return nil, err + } + return newToken, err } diff --git a/providers/linkedin/linkedin_test.go b/providers/linkedin/linkedin_test.go index a67644eca..a8e3d630b 100644 --- a/providers/linkedin/linkedin_test.go +++ b/providers/linkedin/linkedin_test.go @@ -47,11 +47,12 @@ func Test_SessionFromJSON(t *testing.T) { provider := linkedinProvider() - s, err := provider.UnmarshalSession(`{"AuthURL":"http://linkedin.com/auth_url","AccessToken":"1234567890"}`) + s, err := provider.UnmarshalSession(`{"AuthURL":"http://linkedin.com/auth_url","AccessToken":"1234567890","RefreshToken":"987654321"}`) a.NoError(err) session := s.(*linkedin.Session) a.Equal(session.AuthURL, "http://linkedin.com/auth_url") a.Equal(session.AccessToken, "1234567890") + a.Equal(session.RefreshToken, "987654321") } func linkedinProvider() *linkedin.Provider { diff --git a/providers/linkedin/session.go b/providers/linkedin/session.go index 51dee95d1..e53479872 100644 --- a/providers/linkedin/session.go +++ b/providers/linkedin/session.go @@ -10,9 +10,10 @@ import ( // Session stores data during the auth process with LinkedIn. type Session struct { - AuthURL string - AccessToken string - ExpiresAt time.Time + AuthURL string + AccessToken string + ExpiresAt time.Time + RefreshToken string } // GetAuthURL will return the URL set by calling the `BeginAuth` function on the LinkedIn provider. @@ -37,6 +38,7 @@ func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, s.AccessToken = token.AccessToken s.ExpiresAt = token.Expiry + s.RefreshToken = token.RefreshToken return token.AccessToken, err } diff --git a/providers/linkedin/session_test.go b/providers/linkedin/session_test.go index 4cd49c22a..962309cd2 100644 --- a/providers/linkedin/session_test.go +++ b/providers/linkedin/session_test.go @@ -1,13 +1,37 @@ package linkedin_test import ( + "errors" + "io" + "net/http" + "strings" "testing" + "time" "github.com/markbates/goth" "github.com/markbates/goth/providers/linkedin" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) +type MockParams struct { + params map[string]string +} + +func (m *MockParams) Get(key string) string { + return m.params[key] +} + +type MockedHTTPClient struct { + mock.Mock +} + +func (m *MockedHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) { + args := m.Mock.Called(req) + return args.Get(0).(*http.Response), args.Error(1) +} + func Test_Implements_Session(t *testing.T) { t.Parallel() a := assert.New(t) @@ -46,3 +70,73 @@ func Test_String(t *testing.T) { a.Equal(s.String(), s.Marshal()) } + +func Test_Authorize(t *testing.T) { + session := &linkedin.Session{} + params := &MockParams{ + params: map[string]string{ + "code": "authorization_code", + }, + } + + t.Run("happy path", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := linkedinProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"access_token":"test_token","expires_in":3600, "refresh_token":"refresh_token"}`)), + }, nil) + token, err := session.Authorize(p, params) + require.NoError(t, err) + assert.Equal(t, "test_token", token) + assert.Equal(t, session.AccessToken, "test_token") + assert.WithinDuration(t, session.ExpiresAt, time.Now().Add(3600*time.Second), 1*time.Second) + assert.Equal(t, session.RefreshToken, "refresh_token") + }) + + t.Run("error on request", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := linkedinProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{}, errors.New("request error")) + _, err := session.Authorize(p, params) + require.Error(t, err) + }) + + t.Run("non-200 status code", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := linkedinProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusForbidden, + Body: io.NopCloser(strings.NewReader(``)), + }, nil) + _, err := session.Authorize(p, params) + require.Error(t, err) + }) + + t.Run("error on response decode", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := linkedinProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`not a json`)), + }, nil) + _, err := session.Authorize(p, params) + require.Error(t, err) + }) + + t.Run("error code in response", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := linkedinProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"code":1,"msg":"error message"}`)), + }, nil) + _, err := session.Authorize(p, params) + require.Error(t, err) + }) +}