Skip to content

Commit

Permalink
AWS auth login with multi region STS support (#21960)
Browse files Browse the repository at this point in the history
  • Loading branch information
raymonstah authored Jul 28, 2023
1 parent 194e8cd commit 4f7a8fb
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 0 deletions.
17 changes: 17 additions & 0 deletions builtin/credential/aws/path_config_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/hashicorp/go-secure-stdlib/strutil"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
Expand Down Expand Up @@ -61,6 +62,12 @@ func (b *backend) pathConfigClient() *framework.Path {
Description: "The region ID for the sts_endpoint, if set.",
},

"use_sts_region_from_client": {
Type: framework.TypeBool,
Default: false,
Description: "Uses the STS region from client requests for making AWS STS API calls.",
},

"iam_server_id_header_value": {
Type: framework.TypeString,
Default: "",
Expand Down Expand Up @@ -168,6 +175,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request
"iam_endpoint": clientConfig.IAMEndpoint,
"sts_endpoint": clientConfig.STSEndpoint,
"sts_region": clientConfig.STSRegion,
"use_sts_region_from_client": clientConfig.UseSTSRegionFromClient,
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
"max_retries": clientConfig.MaxRetries,
"allowed_sts_header_values": clientConfig.AllowedSTSHeaderValues,
Expand Down Expand Up @@ -281,6 +289,14 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
}
}

useSTSRegionFromClientRaw, ok := data.GetOk("use_sts_region_from_client")
if ok {
if configEntry.UseSTSRegionFromClient != useSTSRegionFromClientRaw.(bool) {
changedCreds = true
configEntry.UseSTSRegionFromClient = useSTSRegionFromClientRaw.(bool)
}
}

headerValStr, ok := data.GetOk("iam_server_id_header_value")
if ok {
if configEntry.IAMServerIdHeaderValue != headerValStr.(string) {
Expand Down Expand Up @@ -363,6 +379,7 @@ type clientConfig struct {
IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"`
STSRegion string `json:"sts_region"`
UseSTSRegionFromClient bool `json:"use_sts_region_from_client"`
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
AllowedSTSHeaderValues []string `json:"allowed_sts_header_values"`
MaxRetries int `json:"max_retries"`
Expand Down
58 changes: 58 additions & 0 deletions builtin/credential/aws/path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@ import (

"github.com/aws/aws-sdk-go/aws"
awsClient "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/errwrap"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-retryablehttp"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/go-secure-stdlib/strutil"
uuid "github.com/hashicorp/go-uuid"

"github.com/hashicorp/vault/builtin/credential/aws/pkcs7"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/cidrutil"
Expand Down Expand Up @@ -318,6 +321,24 @@ func (b *backend) pathLoginIamGetRoleNameCallerIdAndEntity(ctx context.Context,
}
}

// Extract and use a regional STS endpoint
// based on the region set in the Authorization header.
if config.UseSTSRegionFromClient {
clientSpecifiedRegion, err := awsRegionFromHeader(headers.Get("Authorization"))
if err != nil {
return "", nil, nil, logical.ErrorResponse("region missing from Authorization header"), nil
}

url, err := stsRegionalEndpoint(clientSpecifiedRegion)
if err != nil {
return "", nil, nil, logical.ErrorResponse(err.Error()), nil
}

b.Logger().Debug("use_sts_region_from_client set; using region specified from header", "region", clientSpecifiedRegion)
endpoint = url
}

b.Logger().Debug("submitting caller identity request", "endpoint", endpoint)
callerID, err := submitCallerIdentityRequest(ctx, maxRetries, method, endpoint, parsedUrl, body, headers)
if err != nil {
return "", nil, nil, logical.ErrorResponse(fmt.Sprintf("error making upstream request: %v", err)), nil
Expand Down Expand Up @@ -1884,6 +1905,43 @@ func getMetadataValue(fromAuth *logical.Auth, forKey string) (string, error) {
return "", fmt.Errorf("%q not found in auth metadata", forKey)
}

func awsRegionFromHeader(authorizationHeader string) (string, error) {
// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html
// The Authorization header takes the following form.
// Authorization: AWS4-HMAC-SHA256
// Credential=AKIAIOSFODNN7EXAMPLE/20230719/us-east-1/sts/aws4_request,
// SignedHeaders=content-length;content-type;host;x-amz-date,
// Signature=fe5f80f77d5fa3beca038a248ff027d0445342fe2855ddc963176630326f1024
//
// The credential is in the form of "<your-access-key-id>/<date>/<aws-region>/<aws-service>/aws4_request"
fields := strings.Split(authorizationHeader, " ")
for _, field := range fields {
if strings.HasPrefix(field, "Credential=") {
fields := strings.Split(field, "/")
if len(fields) < 3 {
return "", fmt.Errorf("invalid header format")
}

region := fields[2]
return region, nil
}
}

return "", fmt.Errorf("invalid header format")
}

func stsRegionalEndpoint(region string) (string, error) {
stsService := sts.EndpointsID
resolver := endpoints.DefaultResolver()
resolvedEndpoint, err := resolver.EndpointFor(stsService, region,
endpoints.STSRegionalEndpointOption,
endpoints.StrictMatchingOption)
if err != nil {
return "", fmt.Errorf("unable to get regional STS endpoint for region: %v", region)
}
return resolvedEndpoint.URL, nil
}

const iamServerIdHeader = "X-Vault-AWS-IAM-Server-ID"

const pathLoginSyn = `
Expand Down
54 changes: 54 additions & 0 deletions builtin/credential/aws/path_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (

"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"

"github.com/hashicorp/vault/sdk/logical"
)

Expand Down Expand Up @@ -625,6 +627,58 @@ func TestBackend_defaultAliasMetadata(t *testing.T) {
}
}

func TestRegionFromHeader(t *testing.T) {
tcs := map[string]struct {
header string
expectedRegion string
expectedSTSEndpoint string
}{
"us-east-1": {
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
expectedRegion: "us-east-1",
expectedSTSEndpoint: "https://sts.us-east-1.amazonaws.com",
},
"us-west-2": {
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-west-2/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
expectedRegion: "us-west-2",
expectedSTSEndpoint: "https://sts.us-west-2.amazonaws.com",
},
"ap-northeast-3": {
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/ap-northeast-3/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
expectedRegion: "ap-northeast-3",
expectedSTSEndpoint: "https://sts.ap-northeast-3.amazonaws.com",
},
"us-gov-east-1": {
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-gov-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
expectedRegion: "us-gov-east-1",
expectedSTSEndpoint: "https://sts.us-gov-east-1.amazonaws.com",
},
}
for name, tc := range tcs {
t.Run(name, func(t *testing.T) {
region, err := awsRegionFromHeader(tc.header)
assert.NoError(t, err)
assert.Equal(t, tc.expectedRegion, region)

stsEndpoint, err := stsRegionalEndpoint(region)
assert.NoError(t, err)
assert.Equal(t, tc.expectedSTSEndpoint, stsEndpoint)
})
}

t.Run("invalid-header", func(t *testing.T) {
region, err := awsRegionFromHeader("this-is-an-invalid-header/foobar")
assert.EqualError(t, err, "invalid header format")
assert.Empty(t, region)
})

t.Run("invalid-region", func(t *testing.T) {
endpoint, err := stsRegionalEndpoint("fake-region-1")
assert.EqualError(t, err, "unable to get regional STS endpoint for region: fake-region-1")
assert.Empty(t, endpoint)
})
}

func defaultLoginData() (map[string]interface{}, error) {
awsSession, err := session.NewSession()
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions changelog/21960.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
aws/auth: Adds a new config field `use_sts_region_from_client` which allows for using dynamic regional sts endpoints based on Authorization header when using IAM-based authentication.
```
5 changes: 5 additions & 0 deletions website/content/api-docs/auth/aws.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ capabilities, the credentials are fetched automatically.
- `sts_region` `(string: "")` - Region to override the default region for making
AWS STS API calls. Should only be set if `sts_endpoint` is set. If so, should
be set to the region in which the custom `sts_endpoint` resides.
- `use_sts_region_from_client` `(boolean: false)` - If set, overrides both `sts_endpoint`
and `sts_region` to instead use the region specified in the client request headers for
IAM-based authentication . This can be useful when you have client requests coming from
different regions and want flexibility in which regional STS API is used.
- `iam_server_id_header_value` `(string: "")` - The value to require in the
`X-Vault-AWS-IAM-Server-ID` header as part of GetCallerIdentity requests that
are used in the iam auth method. If not set, then no value is required or
Expand Down Expand Up @@ -123,6 +127,7 @@ $ curl \
"iam_endpoint": "",
"sts_endpoint": "",
"sts_region": "",
"use_sts_region_from_client": false,
"iam_server_id_header_value": ""
}
}
Expand Down

0 comments on commit 4f7a8fb

Please sign in to comment.