diff --git a/builtin/credential/aws/path_config_client.go b/builtin/credential/aws/path_config_client.go index 979fac11a9d8..04f8f238d709 100644 --- a/builtin/credential/aws/path_config_client.go +++ b/builtin/credential/aws/path_config_client.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/hashicorp/go-secure-stdlib/strutil" + "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/logical" ) @@ -61,6 +62,12 @@ func (b *backend) pathConfigClient() *framework.Path { Description: "The region ID for the sts_endpoint, if set.", }, + "use_sts_region_from_client": { + Type: framework.TypeBool, + Default: false, + Description: "Uses the STS region from client requests for making AWS STS API calls.", + }, + "iam_server_id_header_value": { Type: framework.TypeString, Default: "", @@ -168,6 +175,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request "iam_endpoint": clientConfig.IAMEndpoint, "sts_endpoint": clientConfig.STSEndpoint, "sts_region": clientConfig.STSRegion, + "use_sts_region_from_client": clientConfig.UseSTSRegionFromClient, "iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue, "max_retries": clientConfig.MaxRetries, "allowed_sts_header_values": clientConfig.AllowedSTSHeaderValues, @@ -281,6 +289,14 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical } } + useSTSRegionFromClientRaw, ok := data.GetOk("use_sts_region_from_client") + if ok { + if configEntry.UseSTSRegionFromClient != useSTSRegionFromClientRaw.(bool) { + changedCreds = true + configEntry.UseSTSRegionFromClient = useSTSRegionFromClientRaw.(bool) + } + } + headerValStr, ok := data.GetOk("iam_server_id_header_value") if ok { if configEntry.IAMServerIdHeaderValue != headerValStr.(string) { @@ -363,6 +379,7 @@ type clientConfig struct { IAMEndpoint string `json:"iam_endpoint"` STSEndpoint string `json:"sts_endpoint"` STSRegion string `json:"sts_region"` + UseSTSRegionFromClient bool `json:"use_sts_region_from_client"` IAMServerIdHeaderValue string `json:"iam_server_id_header_value"` AllowedSTSHeaderValues []string `json:"allowed_sts_header_values"` MaxRetries int `json:"max_retries"` diff --git a/builtin/credential/aws/path_login.go b/builtin/credential/aws/path_login.go index 1e23500dc538..cd7c736766fa 100644 --- a/builtin/credential/aws/path_login.go +++ b/builtin/credential/aws/path_login.go @@ -21,8 +21,10 @@ import ( "github.com/aws/aws-sdk-go/aws" awsClient "github.com/aws/aws-sdk-go/aws/client" + "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/sts" "github.com/hashicorp/errwrap" cleanhttp "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-retryablehttp" @@ -30,6 +32,7 @@ import ( "github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-secure-stdlib/strutil" uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/builtin/credential/aws/pkcs7" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/cidrutil" @@ -318,6 +321,24 @@ func (b *backend) pathLoginIamGetRoleNameCallerIdAndEntity(ctx context.Context, } } + // Extract and use a regional STS endpoint + // based on the region set in the Authorization header. + if config.UseSTSRegionFromClient { + clientSpecifiedRegion, err := awsRegionFromHeader(headers.Get("Authorization")) + if err != nil { + return "", nil, nil, logical.ErrorResponse("region missing from Authorization header"), nil + } + + url, err := stsRegionalEndpoint(clientSpecifiedRegion) + if err != nil { + return "", nil, nil, logical.ErrorResponse(err.Error()), nil + } + + b.Logger().Debug("use_sts_region_from_client set; using region specified from header", "region", clientSpecifiedRegion) + endpoint = url + } + + b.Logger().Debug("submitting caller identity request", "endpoint", endpoint) callerID, err := submitCallerIdentityRequest(ctx, maxRetries, method, endpoint, parsedUrl, body, headers) if err != nil { return "", nil, nil, logical.ErrorResponse(fmt.Sprintf("error making upstream request: %v", err)), nil @@ -1884,6 +1905,43 @@ func getMetadataValue(fromAuth *logical.Auth, forKey string) (string, error) { return "", fmt.Errorf("%q not found in auth metadata", forKey) } +func awsRegionFromHeader(authorizationHeader string) (string, error) { + // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html + // The Authorization header takes the following form. + // Authorization: AWS4-HMAC-SHA256 + // Credential=AKIAIOSFODNN7EXAMPLE/20230719/us-east-1/sts/aws4_request, + // SignedHeaders=content-length;content-type;host;x-amz-date, + // Signature=fe5f80f77d5fa3beca038a248ff027d0445342fe2855ddc963176630326f1024 + // + // The credential is in the form of "////aws4_request" + fields := strings.Split(authorizationHeader, " ") + for _, field := range fields { + if strings.HasPrefix(field, "Credential=") { + fields := strings.Split(field, "/") + if len(fields) < 3 { + return "", fmt.Errorf("invalid header format") + } + + region := fields[2] + return region, nil + } + } + + return "", fmt.Errorf("invalid header format") +} + +func stsRegionalEndpoint(region string) (string, error) { + stsService := sts.EndpointsID + resolver := endpoints.DefaultResolver() + resolvedEndpoint, err := resolver.EndpointFor(stsService, region, + endpoints.STSRegionalEndpointOption, + endpoints.StrictMatchingOption) + if err != nil { + return "", fmt.Errorf("unable to get regional STS endpoint for region: %v", region) + } + return resolvedEndpoint.URL, nil +} + const iamServerIdHeader = "X-Vault-AWS-IAM-Server-ID" const pathLoginSyn = ` diff --git a/builtin/credential/aws/path_login_test.go b/builtin/credential/aws/path_login_test.go index 2c0262075ad3..3fbc090f8477 100644 --- a/builtin/credential/aws/path_login_test.go +++ b/builtin/credential/aws/path_login_test.go @@ -16,6 +16,8 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sts" + "github.com/stretchr/testify/assert" + "github.com/hashicorp/vault/sdk/logical" ) @@ -625,6 +627,58 @@ func TestBackend_defaultAliasMetadata(t *testing.T) { } } +func TestRegionFromHeader(t *testing.T) { + tcs := map[string]struct { + header string + expectedRegion string + expectedSTSEndpoint string + }{ + "us-east-1": { + header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + expectedRegion: "us-east-1", + expectedSTSEndpoint: "https://sts.us-east-1.amazonaws.com", + }, + "us-west-2": { + header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-west-2/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + expectedRegion: "us-west-2", + expectedSTSEndpoint: "https://sts.us-west-2.amazonaws.com", + }, + "ap-northeast-3": { + header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/ap-northeast-3/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + expectedRegion: "ap-northeast-3", + expectedSTSEndpoint: "https://sts.ap-northeast-3.amazonaws.com", + }, + "us-gov-east-1": { + header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-gov-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + expectedRegion: "us-gov-east-1", + expectedSTSEndpoint: "https://sts.us-gov-east-1.amazonaws.com", + }, + } + for name, tc := range tcs { + t.Run(name, func(t *testing.T) { + region, err := awsRegionFromHeader(tc.header) + assert.NoError(t, err) + assert.Equal(t, tc.expectedRegion, region) + + stsEndpoint, err := stsRegionalEndpoint(region) + assert.NoError(t, err) + assert.Equal(t, tc.expectedSTSEndpoint, stsEndpoint) + }) + } + + t.Run("invalid-header", func(t *testing.T) { + region, err := awsRegionFromHeader("this-is-an-invalid-header/foobar") + assert.EqualError(t, err, "invalid header format") + assert.Empty(t, region) + }) + + t.Run("invalid-region", func(t *testing.T) { + endpoint, err := stsRegionalEndpoint("fake-region-1") + assert.EqualError(t, err, "unable to get regional STS endpoint for region: fake-region-1") + assert.Empty(t, endpoint) + }) +} + func defaultLoginData() (map[string]interface{}, error) { awsSession, err := session.NewSession() if err != nil { diff --git a/changelog/21960.txt b/changelog/21960.txt new file mode 100644 index 000000000000..cab19fab96f3 --- /dev/null +++ b/changelog/21960.txt @@ -0,0 +1,3 @@ +```release-note:improvement +aws/auth: Adds a new config field `use_sts_region_from_client` which allows for using dynamic regional sts endpoints based on Authorization header when using IAM-based authentication. +``` diff --git a/website/content/api-docs/auth/aws.mdx b/website/content/api-docs/auth/aws.mdx index b7d45b506cb5..2beb84c18bc1 100644 --- a/website/content/api-docs/auth/aws.mdx +++ b/website/content/api-docs/auth/aws.mdx @@ -65,6 +65,10 @@ capabilities, the credentials are fetched automatically. - `sts_region` `(string: "")` - Region to override the default region for making AWS STS API calls. Should only be set if `sts_endpoint` is set. If so, should be set to the region in which the custom `sts_endpoint` resides. +- `use_sts_region_from_client` `(boolean: false)` - If set, overrides both `sts_endpoint` + and `sts_region` to instead use the region specified in the client request headers for + IAM-based authentication . This can be useful when you have client requests coming from + different regions and want flexibility in which regional STS API is used. - `iam_server_id_header_value` `(string: "")` - The value to require in the `X-Vault-AWS-IAM-Server-ID` header as part of GetCallerIdentity requests that are used in the iam auth method. If not set, then no value is required or @@ -123,6 +127,7 @@ $ curl \ "iam_endpoint": "", "sts_endpoint": "", "sts_region": "", + "use_sts_region_from_client": false, "iam_server_id_header_value": "" } }