Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for a dedicated HMAC type in Transit. #16668

Merged
merged 12 commits into from
Sep 6, 2022
322 changes: 169 additions & 153 deletions builtin/logical/transit/path_hmac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,188 +14,204 @@ import (
func TestTransit_HMAC(t *testing.T) {
b, storage := createBackendWithSysView(t)

// First create a key
req := &logical.Request{
Storage: storage,
Operation: logical.UpdateOperation,
Path: "keys/foo",
}
_, err := b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}
cases := []struct {
name string
typ string
}{
{
name: "foo",
typ: "",
},
{
name: "dedicated",
typ: "hmac",
},
}

for _, c := range cases {
req := &logical.Request{
Storage: storage,
Operation: logical.UpdateOperation,
Path: "keys/" + c.name,
}
_, err := b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}

// Now, change the key value to something we control
p, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{
Storage: storage,
Name: "foo",
}, b.GetRandomReader())
if err != nil {
t.Fatal(err)
}
// We don't care as we're the only one using this
latestVersion := strconv.Itoa(p.LatestVersion)
keyEntry := p.Keys[latestVersion]
keyEntry.HMACKey = []byte("01234567890123456789012345678901")
p.Keys[latestVersion] = keyEntry
if err = p.Persist(context.Background(), storage); err != nil {
t.Fatal(err)
}
// Now, change the key value to something we control
p, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{
Storage: storage,
Name: c.name,
}, b.GetRandomReader())
if err != nil {
t.Fatal(err)
}
// We don't care as we're the only one using this
latestVersion := strconv.Itoa(p.LatestVersion)
keyEntry := p.Keys[latestVersion]
keyEntry.HMACKey = []byte("01234567890123456789012345678901")
keyEntry.Key = []byte("01234567890123456789012345678901")
p.Keys[latestVersion] = keyEntry
if err = p.Persist(context.Background(), storage); err != nil {
t.Fatal(err)
}

req.Path = "hmac/foo"
req.Data = map[string]interface{}{
"input": "dGhlIHF1aWNrIGJyb3duIGZveA==",
}
req.Path = "hmac/" + c.name
req.Data = map[string]interface{}{
"input": "dGhlIHF1aWNrIGJyb3duIGZveA==",
}

doRequest := func(req *logical.Request, errExpected bool, expected string) {
path := req.Path
defer func() { req.Path = path }()
doRequest := func(req *logical.Request, errExpected bool, expected string) {
path := req.Path
defer func() { req.Path = path }()

resp, err := b.HandleRequest(context.Background(), req)
if err != nil && !errExpected {
panic(fmt.Sprintf("%v", err))
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if errExpected {
if !resp.IsError() {
resp, err := b.HandleRequest(context.Background(), req)
if err != nil && !errExpected {
panic(fmt.Sprintf("%v", err))
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if errExpected {
if !resp.IsError() {
t.Fatalf("bad: got error response: %#v", *resp)
}
return
}
if resp.IsError() {
t.Fatalf("bad: got error response: %#v", *resp)
}
return
}
if resp.IsError() {
t.Fatalf("bad: got error response: %#v", *resp)
}
value, ok := resp.Data["hmac"]
if !ok {
t.Fatalf("no hmac key found in returned data, got resp data %#v", resp.Data)
}
if value.(string) != expected {
panic(fmt.Sprintf("mismatched hashes; expected %s, got resp data %#v", expected, resp.Data))
}
value, ok := resp.Data["hmac"]
if !ok {
t.Fatalf("no hmac key found in returned data, got resp data %#v", resp.Data)
}
if value.(string) != expected {
panic(fmt.Sprintf("mismatched hashes; expected %s, got resp data %#v", expected, resp.Data))
}

// Now verify
req.Path = strings.ReplaceAll(req.Path, "hmac", "verify")
req.Data["hmac"] = value.(string)
resp, err = b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatalf("%v: %v", err, resp)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if resp.Data["valid"].(bool) == false {
panic(fmt.Sprintf("error validating hmac;\nreq:\n%#v\nresp:\n%#v", *req, *resp))
// Now verify
req.Path = strings.ReplaceAll(req.Path, "hmac", "verify")
req.Data["hmac"] = value.(string)
resp, err = b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatalf("%v: %v", err, resp)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if resp.Data["valid"].(bool) == false {
panic(fmt.Sprintf("error validating hmac;\nreq:\n%#v\nresp:\n%#v", *req, *resp))
}
}
}

// Comparisons are against values generated via openssl
// Comparisons are against values generated via openssl

// Test defaults -- sha2-256
doRequest(req, false, "vault:v1:UcBvm5VskkukzZHlPgm3p5P/Yr/PV6xpuOGZISya3A4=")
// Test defaults -- sha2-256
doRequest(req, false, "vault:v1:UcBvm5VskkukzZHlPgm3p5P/Yr/PV6xpuOGZISya3A4=")

// Test algorithm selection in the path
req.Path = "hmac/foo/sha2-224"
doRequest(req, false, "vault:v1:3p+ZWVquYDvu2dSTCa65Y3fgoMfIAc6fNaBbtg==")
// Test algorithm selection in the path
req.Path = "hmac/" + c.name + "/sha2-224"
doRequest(req, false, "vault:v1:3p+ZWVquYDvu2dSTCa65Y3fgoMfIAc6fNaBbtg==")

// Reset and test algorithm selection in the data
req.Path = "hmac/foo"
req.Data["algorithm"] = "sha2-224"
doRequest(req, false, "vault:v1:3p+ZWVquYDvu2dSTCa65Y3fgoMfIAc6fNaBbtg==")
// Reset and test algorithm selection in the data
req.Path = "hmac/" + c.name
req.Data["algorithm"] = "sha2-224"
doRequest(req, false, "vault:v1:3p+ZWVquYDvu2dSTCa65Y3fgoMfIAc6fNaBbtg==")

req.Data["algorithm"] = "sha2-384"
doRequest(req, false, "vault:v1:jDB9YXdPjpmr29b1JCIEJO93IydlKVfD9mA2EO9OmJtJQg3QAV5tcRRRb7IQGW9p")
req.Data["algorithm"] = "sha2-384"
doRequest(req, false, "vault:v1:jDB9YXdPjpmr29b1JCIEJO93IydlKVfD9mA2EO9OmJtJQg3QAV5tcRRRb7IQGW9p")

req.Data["algorithm"] = "sha2-512"
doRequest(req, false, "vault:v1:PSXLXvkvKF4CpU65e2bK1tGBZQpcpCEM32fq2iUoiTyQQCfBcGJJItQ+60tMwWXAPQrC290AzTrNJucGrr4GFA==")
req.Data["algorithm"] = "sha2-512"
doRequest(req, false, "vault:v1:PSXLXvkvKF4CpU65e2bK1tGBZQpcpCEM32fq2iUoiTyQQCfBcGJJItQ+60tMwWXAPQrC290AzTrNJucGrr4GFA==")

// Test returning as base64
req.Data["format"] = "base64"
doRequest(req, false, "vault:v1:PSXLXvkvKF4CpU65e2bK1tGBZQpcpCEM32fq2iUoiTyQQCfBcGJJItQ+60tMwWXAPQrC290AzTrNJucGrr4GFA==")
// Test returning as base64
req.Data["format"] = "base64"
doRequest(req, false, "vault:v1:PSXLXvkvKF4CpU65e2bK1tGBZQpcpCEM32fq2iUoiTyQQCfBcGJJItQ+60tMwWXAPQrC290AzTrNJucGrr4GFA==")

// Test SHA3
req.Path = "hmac/foo"
req.Data["algorithm"] = "sha3-224"
doRequest(req, false, "vault:v1:TGipmKH8LR/BkMolYpDYy0BJCIhTtGPDhV2VkQ==")
// Test SHA3
req.Path = "hmac/" + c.name
req.Data["algorithm"] = "sha3-224"
doRequest(req, false, "vault:v1:TGipmKH8LR/BkMolYpDYy0BJCIhTtGPDhV2VkQ==")

req.Data["algorithm"] = "sha3-256"
doRequest(req, false, "vault:v1:+px9V/7QYLfdK808zPESC2T/L33uFf4Blzsn9Jy838o=")
req.Data["algorithm"] = "sha3-256"
doRequest(req, false, "vault:v1:+px9V/7QYLfdK808zPESC2T/L33uFf4Blzsn9Jy838o=")

req.Data["algorithm"] = "sha3-384"
doRequest(req, false, "vault:v1:YGoRwN4UdTRYZeOER86jsQOB8piWenzLDzJ2wJQK/Jq59rAsY8lh7SCdqqCyFg70")
req.Data["algorithm"] = "sha3-384"
doRequest(req, false, "vault:v1:YGoRwN4UdTRYZeOER86jsQOB8piWenzLDzJ2wJQK/Jq59rAsY8lh7SCdqqCyFg70")

req.Data["algorithm"] = "sha3-512"
doRequest(req, false, "vault:v1:GrNA8sU88naMPEQ7UZGj9EJl7YJhl03AFHfxcEURFrtvnobdea9ZlZHePpxAx/oCaC7R2HkrAO+Tu3uXPIl3lg==")
req.Data["algorithm"] = "sha3-512"
doRequest(req, false, "vault:v1:GrNA8sU88naMPEQ7UZGj9EJl7YJhl03AFHfxcEURFrtvnobdea9ZlZHePpxAx/oCaC7R2HkrAO+Tu3uXPIl3lg==")

// Test returning SHA3 as base64
req.Data["format"] = "base64"
doRequest(req, false, "vault:v1:GrNA8sU88naMPEQ7UZGj9EJl7YJhl03AFHfxcEURFrtvnobdea9ZlZHePpxAx/oCaC7R2HkrAO+Tu3uXPIl3lg==")
// Test returning SHA3 as base64
req.Data["format"] = "base64"
doRequest(req, false, "vault:v1:GrNA8sU88naMPEQ7UZGj9EJl7YJhl03AFHfxcEURFrtvnobdea9ZlZHePpxAx/oCaC7R2HkrAO+Tu3uXPIl3lg==")

req.Data["algorithm"] = "foobar"
doRequest(req, true, "")
req.Data["algorithm"] = "foobar"
doRequest(req, true, "")

req.Data["algorithm"] = "sha2-256"
req.Data["input"] = "foobar"
doRequest(req, true, "")
req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
req.Data["algorithm"] = "sha2-256"
req.Data["input"] = "foobar"
doRequest(req, true, "")
req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="

// Rotate
err = p.Rotate(context.Background(), storage, b.GetRandomReader())
if err != nil {
t.Fatal(err)
}
keyEntry = p.Keys["2"]
// Set to another value we control
keyEntry.HMACKey = []byte("12345678901234567890123456789012")
p.Keys["2"] = keyEntry
if err = p.Persist(context.Background(), storage); err != nil {
t.Fatal(err)
}
// Rotate
err = p.Rotate(context.Background(), storage, b.GetRandomReader())
if err != nil {
t.Fatal(err)
}
keyEntry = p.Keys["2"]
// Set to another value we control
keyEntry.HMACKey = []byte("12345678901234567890123456789012")
p.Keys["2"] = keyEntry
if err = p.Persist(context.Background(), storage); err != nil {
t.Fatal(err)
}

doRequest(req, false, "vault:v2:Dt+mO/B93kuWUbGMMobwUNX5Wodr6dL3JH4DMfpQ0kw=")
doRequest(req, false, "vault:v2:Dt+mO/B93kuWUbGMMobwUNX5Wodr6dL3JH4DMfpQ0kw=")

// Verify a previous version
req.Path = "verify/foo"
// Verify a previous version
req.Path = "verify/" + c.name

req.Data["hmac"] = "vault:v1:UcBvm5VskkukzZHlPgm3p5P/Yr/PV6xpuOGZISya3A4="
resp, err := b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatalf("%v: %v", err, resp)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if resp.Data["valid"].(bool) == false {
t.Fatalf("error validating hmac\nreq\n%#v\nresp\n%#v", *req, *resp)
}
req.Data["hmac"] = "vault:v1:UcBvm5VskkukzZHlPgm3p5P/Yr/PV6xpuOGZISya3A4="
resp, err := b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatalf("%v: %v", err, resp)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if resp.Data["valid"].(bool) == false {
t.Fatalf("error validating hmac\nreq\n%#v\nresp\n%#v", *req, *resp)
}

// Try a bad value
req.Data["hmac"] = "vault:v1:UcBvm4VskkukzZHlPgm3p5P/Yr/PV6xpuOGZISya3A4="
resp, err = b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatalf("%v: %v", err, resp)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if resp.Data["valid"].(bool) {
t.Fatalf("expected error validating hmac")
}
// Try a bad value
req.Data["hmac"] = "vault:v1:UcBvm4VskkukzZHlPgm3p5P/Yr/PV6xpuOGZISya3A4="
resp, err = b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatalf("%v: %v", err, resp)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if resp.Data["valid"].(bool) {
t.Fatalf("expected error validating hmac")
}

// Set min decryption version, attempt to verify
p.MinDecryptionVersion = 2
if err = p.Persist(context.Background(), storage); err != nil {
t.Fatal(err)
}
// Set min decryption version, attempt to verify
p.MinDecryptionVersion = 2
if err = p.Persist(context.Background(), storage); err != nil {
t.Fatal(err)
}

req.Data["hmac"] = "vault:v1:UcBvm5VskkukzZHlPgm3p5P/Yr/PV6xpuOGZISya3A4="
resp, err = b.HandleRequest(context.Background(), req)
if err == nil {
t.Fatalf("expected an error, got response %#v", resp)
}
if err != logical.ErrInvalidRequest {
t.Fatalf("expected invalid request error, got %v", err)
req.Data["hmac"] = "vault:v1:UcBvm5VskkukzZHlPgm3p5P/Yr/PV6xpuOGZISya3A4="
resp, err = b.HandleRequest(context.Background(), req)
if err == nil {
t.Fatalf("expected an error, got response %#v", resp)
}
if err != logical.ErrInvalidRequest {
t.Fatalf("expected invalid request error, got %v", err)
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions builtin/logical/transit/path_import.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ func (b *backend) pathImportWrite(ctx context.Context, req *logical.Request, d *
polReq.KeyType = keysutil.KeyType_RSA3072
case "rsa-4096":
polReq.KeyType = keysutil.KeyType_RSA4096
case "hmac":
polReq.KeyType = keysutil.KeyType_HMAC
default:
return logical.ErrorResponse(fmt.Sprintf("unknown key type: %v", keyType)), logical.ErrInvalidRequest
}
Expand Down
5 changes: 3 additions & 2 deletions builtin/logical/transit/path_import_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ var keyTypes = []string{
"rsa-2048",
"rsa-3072",
"rsa-4096",
"hmac",
}

var hashFns = []string{
Expand Down Expand Up @@ -543,7 +544,7 @@ func wrapTargetKeyForImport(t *testing.T, wrappingKey *rsa.PublicKey, targetKey
var ok bool
var err error
switch targetKeyType {
case "aes128-gcm96", "aes256-gcm96", "chacha20-poly1305":
case "aes128-gcm96", "aes256-gcm96", "chacha20-poly1305", "hmac":
preppedTargetKey, ok = targetKey.([]byte)
if !ok {
t.Fatal("failed to wrap target key for import: symmetric key not provided in byte format")
Expand Down Expand Up @@ -600,7 +601,7 @@ func generateKey(keyType string) (interface{}, error) {
switch keyType {
case "aes128-gcm96":
return uuid.GenerateRandomBytes(16)
case "aes256-gcm96":
case "aes256-gcm96", "hmac":
Copy link
Contributor

@cipherboy cipherboy Aug 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want to actually generate 64 bytes for hmac?

My 2c., but since SHA-256 uses a 512-bit block (which I think is what HMAC is using under the covers, but it isn't clear from my quick glance) -- a 32-byte key would be padded with 32 bytes of zeros, I'd rather we use all the bits since we can.

I'm also generally interested in extensibility. Do you see us adding SHA-512? SHA-3? If so, where/how?

Definitely like this though!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could add hash function as a sub param. I think I like that rather than a plethora of distinct types.

return uuid.GenerateRandomBytes(32)
case "chacha20-poly1305":
return uuid.GenerateRandomBytes(32)
Expand Down
Loading