From 41a75bd2993992c23a0c8b48df8dad23517d6c82 Mon Sep 17 00:00:00 2001 From: Unlock Music Dev Date: Mon, 5 Dec 2022 07:14:32 +0800 Subject: [PATCH] fix bytes & string protocol --- interface.go | 4 +++- vault.go | 25 ++++++++++++++++++++++++- vault_test.go | 10 +++++----- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/interface.go b/interface.go index 84b3c84..2c32c5e 100644 --- a/interface.go +++ b/interface.go @@ -9,5 +9,7 @@ type Manager interface { type Vault interface { Keys() []string - Get(key string) ([]byte, bool) + GetRaw(key string) ([]byte, bool) + GetBytes(key string) ([]byte, error) + GetString(key string) (string, error) } diff --git a/vault.go b/vault.go index c7da235..131b5d8 100644 --- a/vault.go +++ b/vault.go @@ -8,6 +8,7 @@ import ( "github.com/golang/protobuf/proto" "golang.org/x/exp/maps" + "google.golang.org/protobuf/encoding/protowire" ) type vault map[string][]byte @@ -16,11 +17,33 @@ func (v vault) Keys() []string { return maps.Keys(v) } -func (v vault) Get(key string) ([]byte, bool) { +func (v vault) GetRaw(key string) ([]byte, bool) { val, ok := v[key] return val, ok } +func (v vault) GetBytes(key string) ([]byte, error) { + raw, ok := v[key] + if !ok { + return nil, fmt.Errorf("key not found: %s", key) + } + + val, n := protowire.ConsumeBytes(raw) + if n < 0 { + return nil, fmt.Errorf("invalid protobuf bytes") + } + + return val, nil +} + +func (v vault) GetString(key string) (string, error) { + val, err := v.GetBytes(key) + if err != nil { + return "", err + } + return string(val), nil +} + // metadata is optional. but if it exists, validate with it. func loadVault(src io.Reader, m *metadata) (Vault, error) { fileSizeBuf := make([]byte, 4) diff --git a/vault_test.go b/vault_test.go index 5732410..d3e79fc 100644 --- a/vault_test.go +++ b/vault_test.go @@ -17,10 +17,10 @@ func Test_loadVault(t *testing.T) { assert.Equal(t, 2, len(v.Keys())) - val, ok := v.Get("world") - assert.Equal(t, "hello", string(val)) - assert.True(t, ok) + val, err := v.GetString("world") + assert.Equal(t, "hello", val) + assert.NoError(t, err) - val, ok = v.Get("foo") - assert.False(t, ok) + _, err = v.GetBytes("foo") + assert.Error(t, err) }