fix bytes & string protocol

This commit is contained in:
Unlock Music Dev 2022-12-05 07:14:32 +08:00
parent 629b839482
commit 41a75bd299
Signed by: um-dev
GPG Key ID: 95202E10D3413A1D
3 changed files with 32 additions and 7 deletions

View File

@ -9,5 +9,7 @@ type Manager interface {
type Vault interface { type Vault interface {
Keys() []string Keys() []string
Get(key string) ([]byte, bool) GetRaw(key string) ([]byte, bool)
GetBytes(key string) ([]byte, error)
GetString(key string) (string, error)
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"google.golang.org/protobuf/encoding/protowire"
) )
type vault map[string][]byte type vault map[string][]byte
@ -16,11 +17,33 @@ func (v vault) Keys() []string {
return maps.Keys(v) return maps.Keys(v)
} }
func (v vault) Get(key string) ([]byte, bool) { func (v vault) GetRaw(key string) ([]byte, bool) {
val, ok := v[key] val, ok := v[key]
return val, ok return val, ok
} }
func (v vault) GetBytes(key string) ([]byte, error) {
raw, ok := v[key]
if !ok {
return nil, fmt.Errorf("key not found: %s", key)
}
val, n := protowire.ConsumeBytes(raw)
if n < 0 {
return nil, fmt.Errorf("invalid protobuf bytes")
}
return val, nil
}
func (v vault) GetString(key string) (string, error) {
val, err := v.GetBytes(key)
if err != nil {
return "", err
}
return string(val), nil
}
// metadata is optional. but if it exists, validate with it. // metadata is optional. but if it exists, validate with it.
func loadVault(src io.Reader, m *metadata) (Vault, error) { func loadVault(src io.Reader, m *metadata) (Vault, error) {
fileSizeBuf := make([]byte, 4) fileSizeBuf := make([]byte, 4)

View File

@ -17,10 +17,10 @@ func Test_loadVault(t *testing.T) {
assert.Equal(t, 2, len(v.Keys())) assert.Equal(t, 2, len(v.Keys()))
val, ok := v.Get("world") val, err := v.GetString("world")
assert.Equal(t, "hello", string(val)) assert.Equal(t, "hello", val)
assert.True(t, ok) assert.NoError(t, err)
val, ok = v.Get("foo") _, err = v.GetBytes("foo")
assert.False(t, ok) assert.Error(t, err)
} }