diff --git a/ADOPTERS.md b/ADOPTERS.md index 46ce154bb..3d43bbbb7 100644 --- a/ADOPTERS.md +++ b/ADOPTERS.md @@ -7,6 +7,7 @@ - [Container Solutions](http://container-solutions.com/) - [DaangnPay](https://www.daangnpay.com/) - [Epidemic Sound](https://www.epidemicsound.com/) +- [Elastic](https://www.elastic.co/) - [Fivetran](https://www.fivetran.com) - [Form3](https://www.form3.tech/) - [GoTo](https://www.goto.com/) diff --git a/pkg/provider/vault/auth.go b/pkg/provider/vault/auth.go index ffced5085..c924b1a3d 100644 --- a/pkg/provider/vault/auth.go +++ b/pkg/provider/vault/auth.go @@ -16,6 +16,7 @@ package vault import ( "context" + "encoding/json" "errors" "fmt" @@ -160,6 +161,24 @@ func checkToken(ctx context.Context, token util.Token) (bool, error) { if tokenType == "batch" { return false, nil } + ttl, ok := resp.Data["ttl"] + if !ok { + return false, fmt.Errorf("no TTL found in response") + } + ttlInt, err := ttl.(json.Number).Int64() + if err != nil { + return false, fmt.Errorf("invalid token TTL: %v: %w", ttl, err) + } + expireTime, ok := resp.Data["expire_time"] + if !ok { + return false, fmt.Errorf("no expiration time found in response") + } + if ttlInt < 60 && expireTime != nil { + // Treat expirable tokens that are about to expire as already expired. + // This ensures that the token won't expire in between this check and + // performing the actual operation. + return false, nil + } return true, nil } diff --git a/pkg/provider/vault/auth_test.go b/pkg/provider/vault/auth_test.go index 5431332da..af5dcc23d 100644 --- a/pkg/provider/vault/auth_test.go +++ b/pkg/provider/vault/auth_test.go @@ -16,6 +16,7 @@ package vault import ( "context" + "encoding/json" "errors" "testing" @@ -208,3 +209,71 @@ func TestCheckTokenErrors(t *testing.T) { }) } } + +func TestCheckTokenTtl(t *testing.T) { + cases := map[string]struct { + message string + secret *vault.Secret + cache bool + }{ + "LongTTLExpirable": { + message: "should cache if expirable token expires far into the future", + secret: &vault.Secret{ + Data: map[string]interface{}{ + "expire_time": "2024-01-01T00:00:00.000000000Z", + "ttl": json.Number("3600"), + "type": "service", + }, + }, + cache: true, + }, + "ShortTTLExpirable": { + message: "should not cache if expirable token is about to expire", + secret: &vault.Secret{ + Data: map[string]interface{}{ + "expire_time": "2024-01-01T00:00:00.000000000Z", + "ttl": json.Number("5"), + "type": "service", + }, + }, + cache: false, + }, + "ZeroTTLExpirable": { + message: "should not cache if expirable token has TTL of 0", + secret: &vault.Secret{ + Data: map[string]interface{}{ + "expire_time": "2024-01-01T00:00:00.000000000Z", + "ttl": json.Number("0"), + "type": "service", + }, + }, + cache: false, + }, + "NonExpirable": { + message: "should cache if token is non-expirable", + secret: &vault.Secret{ + Data: map[string]interface{}{ + "expire_time": nil, + "ttl": json.Number("0"), + "type": "service", + }, + }, + cache: true, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + token := fake.Token{ + LookupSelfWithContextFn: func(ctx context.Context) (*vault.Secret, error) { + return tc.secret, nil + }, + } + + cached, err := checkToken(context.Background(), token) + if cached != tc.cache || err != nil { + t.Errorf("%v: err = %v", tc.message, err) + } + }) + } +}