Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VAULT-19233 First part of caching static secrets work #23054

Merged
merged 7 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 52 additions & 29 deletions command/agentproxyshared/cache/lease_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,22 +189,39 @@ func (c *LeaseCache) PersistentStorage() *cacheboltdb.BoltStorage {
// checkCacheForDynamicSecretRequest checks the cache for a particular request based on its
// computed ID. It returns a non-nil *SendResponse if an entry is found.
func (c *LeaseCache) checkCacheForDynamicSecretRequest(id string) (*SendResponse, error) {
return c.checkCacheForRequest(id, "")
return c.checkCacheForRequest(id, nil)
}

// checkCacheForStaticSecretRequest checks the cache for a particular request based on its
// computed ID. It returns a non-nil *SendResponse if an entry is found.
// If a token is provided, it will validate that the token is allowed to retrieve this
// cache entry, and return nil if it isn't.
func (c *LeaseCache) checkCacheForStaticSecretRequest(id, token string) (*SendResponse, error) {
return c.checkCacheForRequest(id, token)
// If a request is provided, it will validate that the token is allowed to retrieve this
// cache entry, and return nil if it isn't. It will also evict the cache if this is a non-GET
// request.
func (c *LeaseCache) checkCacheForStaticSecretRequest(id string, req *SendRequest) (*SendResponse, error) {
return c.checkCacheForRequest(id, req)
}

// checkCacheForRequest checks the cache for a particular request based on its
// computed ID. It returns a non-nil *SendResponse if an entry is found.
// If a token is provided, it will validate that the token is allowed to retrieve this
// cache entry, and return nil if it isn't.
func (c *LeaseCache) checkCacheForRequest(id, token string) (*SendResponse, error) {
func (c *LeaseCache) checkCacheForRequest(id string, req *SendRequest) (*SendResponse, error) {
var token string
if req != nil {
token = req.Token
if req.Request.Method != http.MethodGet {
// This isn't a GET, so we should short-circuit and invalidate the cache
// as we know the cache is now stale.
c.logger.Debug("evicting index from cache, as non-GET received", "id", id, "method", req.Request.Method, "path", req.Request.URL.Path)
err := c.db.Evict(cachememdb.IndexNameID, id)
if err != nil {
return nil, err
}

return nil, nil
}
}

index, err := c.db.Get(cachememdb.IndexNameID, id)
if err != nil {
return nil, err
Expand Down Expand Up @@ -266,6 +283,8 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
// Check the inflight cache to see if there are other inflight requests
// of the same kind, based on the computed ID. If so, we increment a counter

// Note: we lock both the dynamic secret cache ID and the static secret cache ID
// as at this stage, we don't know what kind of secret it is.
var inflight *inflightRequest

defer func() {
Expand Down Expand Up @@ -299,15 +318,15 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
case <-inflight.ch:
}
} else {
inflight = newInflightRequest()
inflight.remaining.Inc()
defer inflight.remaining.Dec()
if inflight == nil {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the purpose of this if? When do we want to re-use an existing inflight?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's so that we have one inflight irrespective of if it's a dynamic or static request, essentially. It prevents double defer close(inflight.ch) and defer inflight.remaining.Dec() etc.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How could we have two? How could the same request give rise to a cache hit for both keys?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It won't result in an inflight cache hit for both keys, but we take the else branch if we get a cache miss. We also don't know at this stage if it's a cache hit for either the static secret ID or the dynamic secret ID

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To put it another way, it's likely for most requests to result in a cache miss for both of the inflight cache IDs

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This particular else could never see a non-nil inflight though, right? For the second case, the static if/else, might it be clearer to wrap the entire thing in an if inflight == nil instead? If we have already found the dynamic cache entry, there's no point in doing idLockStaticSecret.Lock and consulting the cache again, is there?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need to set both of the IDs in the cache (we don't know what kind of request it is at this stage) and unlock the inflight cache lock, so that behaviour can't be conditional. I might be misunderstanding, though?

We only have one request inflight (the stuff protected by the if inflight == nil) but we need to lock the IDs as it could be dynamic or static (or neither)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I'm the one who was missing something, I agree. Though I still feel the first if inflight == nil is a no-op, the one this comment is tied to.

inflight = newInflightRequest()
inflight.remaining.Inc()
defer inflight.remaining.Dec()
defer close(inflight.ch)
}

c.inflightCache.Set(dynamicSecretCacheId, inflight, gocache.NoExpiration)
idLockDynamicSecret.Unlock()

// Signal that the processing request is done
defer close(inflight.ch)
}

idLockStaticSecret := locksutil.LockForKey(c.idLocks, staticSecretCacheId)
Expand All @@ -332,15 +351,15 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
case <-inflight.ch:
}
} else {
inflight = newInflightRequest()
inflight.remaining.Inc()
defer inflight.remaining.Dec()
if inflight == nil {
inflight = newInflightRequest()
inflight.remaining.Inc()
defer inflight.remaining.Dec()
defer close(inflight.ch)
}

c.inflightCache.Set(staticSecretCacheId, inflight, gocache.NoExpiration)
idLockStaticSecret.Unlock()

// Signal that the processing request is done
defer close(inflight.ch)
}

// Check if the response for this request is already in the dynamic secret cache
Expand All @@ -354,12 +373,12 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
}

// Check if the response for this request is already in the static secret cache
cachedResp, err = c.checkCacheForStaticSecretRequest(staticSecretCacheId, req.Token)
cachedResp, err = c.checkCacheForStaticSecretRequest(staticSecretCacheId, req)
if err != nil {
return nil, err
}
if cachedResp != nil {
c.logger.Debug("returning cached response", "path", req.Request.URL.Path)
c.logger.Debug("returning cached response", "id", staticSecretCacheId, "path", req.Request.URL.Path)
return cachedResp, nil
}

Expand Down Expand Up @@ -387,7 +406,6 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,

// Build the index to cache based on the response received
index := &cachememdb.Index{
ID: dynamicSecretCacheId,
Namespace: namespace,
RequestPath: req.Request.URL.Path,
LastRenewed: time.Now().UTC(),
Expand Down Expand Up @@ -418,11 +436,16 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,

// TODO: if secret.MountType == "kvv1" || secret.MountType == "kvv2"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #23047

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition to that condition, remember that we'll have to think more about method type: DELETE, PATCH, etc.

if c.cacheStaticSecrets && secret != nil {
index.Type = cacheboltdb.StaticSecretType
index.ID = staticSecretCacheId
err := c.cacheStaticSecret(ctx, req, resp, index)
if err != nil {
return nil, err
}
return resp, nil
} else {
// Since it's not a static secret, set the ID to be the dynamic id
index.ID = dynamicSecretCacheId
}

// Short-circuit if the secret is not renewable
Expand Down Expand Up @@ -528,15 +551,15 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
index.RequestToken = req.Token
index.RequestHeader = req.Request.Header

// Store the index in the cache
c.logger.Debug("storing response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path)
err = c.Set(ctx, index)
if err != nil {
c.logger.Error("failed to cache the proxied response", "error", err)
return nil, err
}

if index.Type != cacheboltdb.StaticSecretType {
// Store the index in the cache
c.logger.Debug("storing response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path)
err = c.Set(ctx, index)
if err != nil {
c.logger.Error("failed to cache the proxied response", "error", err)
return nil, err
}

// Start renewing the secret in the response
go c.startRenewing(renewCtx, index, req, secret)
}
Expand Down
161 changes: 161 additions & 0 deletions command/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,167 @@ log_level = "trace"
wg.Wait()
}

// TestProxy_Cache_StaticSecretInvalidation Tests that the cache successfully caches a static secret
// going through the Proxy, and that it gets invalidated by a POST.
func TestProxy_Cache_StaticSecretInvalidation(t *testing.T) {
logger := logging.NewVaultLogger(hclog.Trace)
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()

serverClient := cluster.Cores[0].Client

// Unset the environment variable so that proxy picks up the right test
// cluster address
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
os.Unsetenv(api.EnvVaultAddress)

cacheConfig := `
cache {
cache_static_secrets = true
}
`
listenAddr := generateListenerAddress(t)
listenConfig := fmt.Sprintf(`
listener "tcp" {
address = "%s"
tls_disable = true
}
`, listenAddr)

config := fmt.Sprintf(`
vault {
address = "%s"
tls_skip_verify = true
}
%s
%s
log_level = "trace"
`, serverClient.Address(), cacheConfig, listenConfig)
configPath := makeTempFile(t, "config.hcl", config)
defer os.Remove(configPath)

// Start proxy
_, cmd := testProxyCommand(t, logger)
cmd.startedCh = make(chan struct{})

wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
cmd.Run([]string{"-config", configPath})
wg.Done()
}()

select {
case <-cmd.startedCh:
case <-time.After(5 * time.Second):
t.Errorf("timeout")
}

proxyClient, err := api.NewClient(api.DefaultConfig())
if err != nil {
t.Fatal(err)
}
proxyClient.SetToken(serverClient.Token())
proxyClient.SetMaxRetries(0)
err = proxyClient.SetAddress("http://" + listenAddr)
if err != nil {
t.Fatal(err)
}

secretData := map[string]interface{}{
"foo": "bar",
}

secretData2 := map[string]interface{}{
"bar": "baz",
}

// Create kvv1 secret
err = serverClient.KVv1("secret").Put(context.Background(), "my-secret", secretData)
if err != nil {
t.Fatal(err)
}

// We use raw requests so we can check the headers for cache hit/miss.
req := proxyClient.NewRequest(http.MethodGet, "/v1/secret/my-secret")
resp1, err := proxyClient.RawRequest(req)
if err != nil {
t.Fatal(err)
}

cacheValue := resp1.Header.Get("X-Cache")
require.Equal(t, "MISS", cacheValue)

// Update the secret using the proxy client
err = proxyClient.KVv1("secret").Put(context.Background(), "my-secret", secretData2)
if err != nil {
t.Fatal(err)
}

resp2, err := proxyClient.RawRequest(req)
if err != nil {
t.Fatal(err)
}

cacheValue = resp2.Header.Get("X-Cache")
// This should miss too, as we just updated it
require.Equal(t, "MISS", cacheValue)

resp3, err := proxyClient.RawRequest(req)
if err != nil {
t.Fatal(err)
}

cacheValue = resp3.Header.Get("X-Cache")
// This should hit, as the third request should get the cached value
require.Equal(t, "HIT", cacheValue)

// Lastly, we check to make sure the actual data we received is
// as we expect. It's a little more awkward because of raw requests,
// but we make do.
resp1Map := map[string]interface{}{}
body, err := io.ReadAll(resp1.Body)
if err != nil {
t.Fatal(err)
}
err = json.Unmarshal(body, &resp1Map)
if err != nil {
t.Fatal(err)
}
resp1MapData := resp1Map["data"]
require.Equal(t, secretData, resp1MapData)

resp2Map := map[string]interface{}{}
body, err = io.ReadAll(resp2.Body)
if err != nil {
t.Fatal(err)
}
err = json.Unmarshal(body, &resp2Map)
if err != nil {
t.Fatal(err)
}
resp2MapData := resp2Map["data"]
require.Equal(t, secretData2, resp2MapData)

resp3Map := map[string]interface{}{}
body, err = io.ReadAll(resp3.Body)
if err != nil {
t.Fatal(err)
}
err = json.Unmarshal(body, &resp3Map)
if err != nil {
t.Fatal(err)
}
resp3MapData := resp2Map["data"]
require.Equal(t, secretData2, resp3MapData)

close(cmd.ShutdownCh)
wg.Wait()
}

// TestProxy_ApiProxy_Retry Tests the retry functionalities of Vault Proxy's API Proxy
func TestProxy_ApiProxy_Retry(t *testing.T) {
//----------------------------------------------------
Expand Down