Skip to content

Commit c68a6ed

Browse files
authored
[1.16] Fix token renewal for pulsar oauth2 (#4080)
2 parents d3d319c + 5e5e701 commit c68a6ed

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

common/authentication/oauth2/clientcredentials.go

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ type ClientCredentials struct {
5959
httpClient *http.Client
6060
fetchTokenFn func(context.Context) (*oauth2.Token, error)
6161

62-
lock sync.RWMutex
62+
lock sync.Mutex
6363
}
6464

6565
func NewClientCredentials(ctx context.Context, opts ClientCredentialsOptions) (*ClientCredentials, error) {
@@ -126,37 +126,28 @@ func (c *ClientCredentialsOptions) toConfig() (*ccreds.Config, *http.Client, err
126126
}
127127

128128
func (c *ClientCredentials) Token() (string, error) {
129-
c.lock.RLock()
130-
defer c.lock.RUnlock()
131-
132-
if !c.currentToken.Valid() {
133-
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
134-
defer cancel()
135-
if err := c.renewToken(ctx); err != nil {
136-
return "", err
137-
}
138-
}
139-
140-
return c.currentToken.AccessToken, nil
141-
}
142-
143-
func (c *ClientCredentials) renewToken(ctx context.Context) error {
144129
c.lock.Lock()
145130
defer c.lock.Unlock()
146131

147-
// We need to check if the current token is valid because we might have lost
148-
// the mutex lock race from the caller and we don't want to double-fetch a
149-
// token unnecessarily!
150132
if c.currentToken.Valid() {
151-
return nil
133+
return c.currentToken.AccessToken, nil
134+
}
135+
136+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
137+
defer cancel()
138+
if err := c.renewToken(ctx); err != nil {
139+
return "", err
152140
}
141+
return c.currentToken.AccessToken, nil
142+
}
153143

144+
func (c *ClientCredentials) renewToken(ctx context.Context) error {
154145
token, err := c.fetchTokenFn(context.WithValue(ctx, oauth2.HTTPClient, c.httpClient))
155146
if err != nil {
156147
return err
157148
}
158149

159-
if !c.currentToken.Valid() {
150+
if !token.Valid() {
160151
return errors.New("oauth2 client_credentials token source returned an invalid token")
161152
}
162153

common/authentication/oauth2/clientcredentials_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@ limitations under the License.
1414
package oauth2
1515

1616
import (
17+
"context"
1718
"net/url"
1819
"testing"
20+
"time"
1921

2022
"github.com/stretchr/testify/assert"
23+
"github.com/stretchr/testify/require"
24+
"golang.org/x/oauth2"
2125
ccreds "golang.org/x/oauth2/clientcredentials"
2226
)
2327

@@ -93,3 +97,19 @@ func Test_toConfig(t *testing.T) {
9397
})
9498
}
9599
}
100+
101+
func Test_TokenRenewal(t *testing.T) {
102+
expired := &oauth2.Token{AccessToken: "old-token", Expiry: time.Now().Add(-1 * time.Minute)}
103+
renewed := &oauth2.Token{AccessToken: "new-token", Expiry: time.Now().Add(1 * time.Hour)}
104+
105+
c := &ClientCredentials{
106+
currentToken: expired,
107+
fetchTokenFn: func(ctx context.Context) (*oauth2.Token, error) {
108+
return renewed, nil
109+
},
110+
}
111+
112+
tok, err := c.Token()
113+
require.NoError(t, err)
114+
assert.Equal(t, "new-token", tok)
115+
}

0 commit comments

Comments
 (0)