@@ -18,11 +18,15 @@ package exec
18
18
19
19
import (
20
20
"bytes"
21
+ "context"
22
+ "crypto/tls"
21
23
"fmt"
22
24
"io"
25
+ "net"
23
26
"net/http"
24
27
"os"
25
28
"os/exec"
29
+ "reflect"
26
30
"sync"
27
31
"time"
28
32
@@ -35,6 +39,8 @@ import (
35
39
"k8s.io/client-go/pkg/apis/clientauthentication"
36
40
"k8s.io/client-go/pkg/apis/clientauthentication/v1alpha1"
37
41
"k8s.io/client-go/tools/clientcmd/api"
42
+ "k8s.io/client-go/transport"
43
+ "k8s.io/client-go/util/connrotation"
38
44
)
39
45
40
46
const execInfoEnv = "KUBERNETES_EXEC_INFO"
@@ -147,14 +153,55 @@ type Authenticator struct {
147
153
// The mutex also guards calling the plugin. Since the plugin could be
148
154
// interactive we want to make sure it's only called once.
149
155
mu sync.Mutex
150
- cachedToken string
156
+ cachedCreds * credentials
151
157
exp time.Time
158
+
159
+ onRotate func ()
152
160
}
153
161
154
- // WrapTransport instruments an existing http.RoundTripper with credentials returned
155
- // by the plugin.
156
- func (a * Authenticator ) WrapTransport (rt http.RoundTripper ) http.RoundTripper {
157
- return & roundTripper {a , rt }
162
+ type credentials struct {
163
+ token string
164
+ cert * tls.Certificate
165
+ }
166
+
167
+ // UpdateTransportConfig updates the transport.Config to use credentials
168
+ // returned by the plugin.
169
+ func (a * Authenticator ) UpdateTransportConfig (c * transport.Config ) error {
170
+ wt := c .WrapTransport
171
+ c .WrapTransport = func (rt http.RoundTripper ) http.RoundTripper {
172
+ if wt != nil {
173
+ rt = wt (rt )
174
+ }
175
+ return & roundTripper {a , rt }
176
+ }
177
+
178
+ getCert := c .TLS .GetCert
179
+ c .TLS .GetCert = func () (* tls.Certificate , error ) {
180
+ // If previous GetCert is present and returns a valid non-nil
181
+ // certificate, use that. Otherwise use cert from exec plugin.
182
+ if getCert != nil {
183
+ cert , err := getCert ()
184
+ if err != nil {
185
+ return nil , err
186
+ }
187
+ if cert != nil {
188
+ return cert , nil
189
+ }
190
+ }
191
+ return a .cert ()
192
+ }
193
+
194
+ var dial func (ctx context.Context , network , addr string ) (net.Conn , error )
195
+ if c .Dial != nil {
196
+ dial = c .Dial
197
+ } else {
198
+ dial = (& net.Dialer {Timeout : 30 * time .Second , KeepAlive : 30 * time .Second }).DialContext
199
+ }
200
+ d := connrotation .NewDialer (dial )
201
+ a .onRotate = d .CloseAll
202
+ c .Dial = d .DialContext
203
+
204
+ return nil
158
205
}
159
206
160
207
type roundTripper struct {
@@ -169,11 +216,13 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
169
216
return r .base .RoundTrip (req )
170
217
}
171
218
172
- token , err := r .a .token ()
219
+ creds , err := r .a .getCreds ()
173
220
if err != nil {
174
- return nil , fmt .Errorf ("getting token: %v" , err )
221
+ return nil , fmt .Errorf ("getting credentials: %v" , err )
222
+ }
223
+ if creds .token != "" {
224
+ req .Header .Set ("Authorization" , "Bearer " + creds .token )
175
225
}
176
- req .Header .Set ("Authorization" , "Bearer " + token )
177
226
178
227
res , err := r .base .RoundTrip (req )
179
228
if err != nil {
@@ -184,47 +233,60 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
184
233
Header : res .Header ,
185
234
Code : int32 (res .StatusCode ),
186
235
}
187
- if err := r .a .refresh ( token , resp ); err != nil {
188
- glog .Errorf ("refreshing token : %v" , err )
236
+ if err := r .a .maybeRefreshCreds ( creds , resp ); err != nil {
237
+ glog .Errorf ("refreshing credentials : %v" , err )
189
238
}
190
239
}
191
240
return res , nil
192
241
}
193
242
194
- func (a * Authenticator ) tokenExpired () bool {
243
+ func (a * Authenticator ) credsExpired () bool {
195
244
if a .exp .IsZero () {
196
245
return false
197
246
}
198
247
return a .now ().After (a .exp )
199
248
}
200
249
201
- func (a * Authenticator ) token () (string , error ) {
250
+ func (a * Authenticator ) cert () (* tls.Certificate , error ) {
251
+ creds , err := a .getCreds ()
252
+ if err != nil {
253
+ return nil , err
254
+ }
255
+ return creds .cert , nil
256
+ }
257
+
258
+ func (a * Authenticator ) getCreds () (* credentials , error ) {
202
259
a .mu .Lock ()
203
260
defer a .mu .Unlock ()
204
- if a .cachedToken != "" && ! a .tokenExpired () {
205
- return a .cachedToken , nil
261
+ if a .cachedCreds != nil && ! a .credsExpired () {
262
+ return a .cachedCreds , nil
206
263
}
207
264
208
- return a .getToken (nil )
265
+ if err := a .refreshCredsLocked (nil ); err != nil {
266
+ return nil , err
267
+ }
268
+ return a .cachedCreds , nil
209
269
}
210
270
211
- // refresh executes the plugin to force a rotation of the token.
212
- func (a * Authenticator ) refresh (token string , r * clientauthentication.Response ) error {
271
+ // maybeRefreshCreds executes the plugin to force a rotation of the
272
+ // credentials, unless they were rotated already.
273
+ func (a * Authenticator ) maybeRefreshCreds (creds * credentials , r * clientauthentication.Response ) error {
213
274
a .mu .Lock ()
214
275
defer a .mu .Unlock ()
215
276
216
- if token != a .cachedToken {
217
- // Token already rotated.
277
+ // Since we're not making a new pointer to a.cachedCreds in getCreds, no
278
+ // need to do deep comparison.
279
+ if creds != a .cachedCreds {
280
+ // Credentials already rotated.
218
281
return nil
219
282
}
220
283
221
- _ , err := a .getToken (r )
222
- return err
284
+ return a .refreshCredsLocked (r )
223
285
}
224
286
225
- // getToken executes the plugin and reads the credentials from stdout. It must be
226
- // called while holding the Authenticator's mutex.
227
- func (a * Authenticator ) getToken (r * clientauthentication.Response ) ( string , error ) {
287
+ // refreshCredsLocked executes the plugin and reads the credentials from
288
+ // stdout. It must be called while holding the Authenticator's mutex.
289
+ func (a * Authenticator ) refreshCredsLocked (r * clientauthentication.Response ) error {
228
290
cred := & clientauthentication.ExecCredential {
229
291
Spec : clientauthentication.ExecCredentialSpec {
230
292
Response : r ,
@@ -234,7 +296,7 @@ func (a *Authenticator) getToken(r *clientauthentication.Response) (string, erro
234
296
235
297
data , err := runtime .Encode (codecs .LegacyCodec (a .group ), cred )
236
298
if err != nil {
237
- return "" , fmt .Errorf ("encode ExecCredentials: %v" , err )
299
+ return fmt .Errorf ("encode ExecCredentials: %v" , err )
238
300
}
239
301
240
302
env := append (a .environ (), a .env ... )
@@ -250,31 +312,51 @@ func (a *Authenticator) getToken(r *clientauthentication.Response) (string, erro
250
312
}
251
313
252
314
if err := cmd .Run (); err != nil {
253
- return "" , fmt .Errorf ("exec: %v" , err )
315
+ return fmt .Errorf ("exec: %v" , err )
254
316
}
255
317
256
318
_ , gvk , err := codecs .UniversalDecoder (a .group ).Decode (stdout .Bytes (), nil , cred )
257
319
if err != nil {
258
- return "" , fmt .Errorf ("decode stdout: %v" , err )
320
+ return fmt .Errorf ("decoding stdout: %v" , err )
259
321
}
260
322
if gvk .Group != a .group .Group || gvk .Version != a .group .Version {
261
- return "" , fmt .Errorf ("exec plugin is configured to use API version %s, plugin returned version %s" ,
323
+ return fmt .Errorf ("exec plugin is configured to use API version %s, plugin returned version %s" ,
262
324
a .group , schema.GroupVersion {Group : gvk .Group , Version : gvk .Version })
263
325
}
264
326
265
327
if cred .Status == nil {
266
- return "" , fmt .Errorf ("exec plugin didn't return a status field" )
328
+ return fmt .Errorf ("exec plugin didn't return a status field" )
267
329
}
268
- if cred .Status .Token == "" {
269
- return "" , fmt .Errorf ("exec plugin didn't return a token" )
330
+ if cred .Status .Token == "" && cred .Status .ClientCertificateData == "" && cred .Status .ClientKeyData == "" {
331
+ return fmt .Errorf ("exec plugin didn't return a token or cert/key pair" )
332
+ }
333
+ if (cred .Status .ClientCertificateData == "" ) != (cred .Status .ClientKeyData == "" ) {
334
+ return fmt .Errorf ("exec plugin returned only certificate or key, not both" )
270
335
}
271
336
272
337
if cred .Status .ExpirationTimestamp != nil {
273
338
a .exp = cred .Status .ExpirationTimestamp .Time
274
339
} else {
275
340
a .exp = time.Time {}
276
341
}
277
- a .cachedToken = cred .Status .Token
278
342
279
- return a .cachedToken , nil
343
+ newCreds := & credentials {
344
+ token : cred .Status .Token ,
345
+ }
346
+ if cred .Status .ClientKeyData != "" && cred .Status .ClientCertificateData != "" {
347
+ cert , err := tls .X509KeyPair ([]byte (cred .Status .ClientCertificateData ), []byte (cred .Status .ClientKeyData ))
348
+ if err != nil {
349
+ return fmt .Errorf ("failed parsing client key/certificate: %v" , err )
350
+ }
351
+ newCreds .cert = & cert
352
+ }
353
+
354
+ oldCreds := a .cachedCreds
355
+ a .cachedCreds = newCreds
356
+ // Only close all connections when TLS cert rotates. Token rotation doesn't
357
+ // need the extra noise.
358
+ if a .onRotate != nil && oldCreds != nil && ! reflect .DeepEqual (oldCreds .cert , a .cachedCreds .cert ) {
359
+ a .onRotate ()
360
+ }
361
+ return nil
280
362
}
0 commit comments