Merge pull request 'feat: 支持加密数据库' (#1) from awalol/go-mmkv:pr into master
Reviewed-on: #1
This commit is contained in:
commit
52ac92c3e0
@ -5,6 +5,7 @@ type Manager interface {
|
|||||||
// If the vault does not exist, it will be created.
|
// If the vault does not exist, it will be created.
|
||||||
// If id is empty, DefaultVaultID will be used.
|
// If id is empty, DefaultVaultID will be used.
|
||||||
OpenVault(id string) (Vault, error)
|
OpenVault(id string) (Vault, error)
|
||||||
|
OpenVaultCrypto(id string, cryptoKey string) (Vault, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Vault interface {
|
type Vault interface {
|
||||||
|
24
manager.go
24
manager.go
@ -42,7 +42,7 @@ func (m *manager) OpenVault(id string) (Vault, error) {
|
|||||||
return v, nil
|
return v, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
vault, err := m.openVault(id)
|
vault, err := m.openVault(id, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open vault: %w", err)
|
return nil, fmt.Errorf("failed to open vault: %w", err)
|
||||||
}
|
}
|
||||||
@ -51,7 +51,25 @@ func (m *manager) OpenVault(id string) (Vault, error) {
|
|||||||
return vault, nil
|
return vault, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *manager) openVault(id string) (Vault, error) {
|
func (m *manager) OpenVaultCrypto(id string, cryptoKey string) (Vault, error) {
|
||||||
|
if id == "" {
|
||||||
|
id = DefaultVaultID
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := m.vaults[id]; ok {
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
vault, err := m.openVault(id, cryptoKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open vault: %w", err)
|
||||||
|
}
|
||||||
|
m.vaults[id] = vault
|
||||||
|
|
||||||
|
return vault, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *manager) openVault(id string, cryptoKey string) (Vault, error) {
|
||||||
metaFile, err := os.Open(path.Join(m.dir, id+".crc"))
|
metaFile, err := os.Open(path.Join(m.dir, id+".crc"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open metadata file: %w", err)
|
return nil, fmt.Errorf("failed to open metadata file: %w", err)
|
||||||
@ -69,7 +87,7 @@ func (m *manager) openVault(id string) (Vault, error) {
|
|||||||
return nil, fmt.Errorf("failed to load metadata: %w", err)
|
return nil, fmt.Errorf("failed to load metadata: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
v, err := loadVault(vaultFile, meta)
|
v, err := loadVault(vaultFile, meta, cryptoKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load vault: %w", err)
|
return nil, fmt.Errorf("failed to load vault: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestNewManager(t *testing.T) {
|
func TestNewManager(t *testing.T) {
|
||||||
|
t.Run("Default", func(t *testing.T) {
|
||||||
mgr, err := NewManager("./testdata")
|
mgr, err := NewManager("./testdata")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NotNil(t, mgr)
|
assert.NotNil(t, mgr)
|
||||||
@ -14,4 +15,20 @@ func TestNewManager(t *testing.T) {
|
|||||||
vault, err := mgr.OpenVault("")
|
vault, err := mgr.OpenVault("")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NotNil(t, vault)
|
assert.NotNil(t, vault)
|
||||||
|
})
|
||||||
|
t.Run("Crypto", func(t *testing.T) {
|
||||||
|
mgr, err := NewManager("./testdata")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, mgr)
|
||||||
|
|
||||||
|
vault, err := mgr.OpenVaultCrypto("crypto", "123456")
|
||||||
|
val, err := vault.GetString("world")
|
||||||
|
assert.NotNil(t, vault)
|
||||||
|
|
||||||
|
assert.Equal(t, "hello", val)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = vault.GetBytes("foo")
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
BIN
testdata/crypto
vendored
Normal file
BIN
testdata/crypto
vendored
Normal file
Binary file not shown.
BIN
testdata/crypto.crc
vendored
Normal file
BIN
testdata/crypto.crc
vendored
Normal file
Binary file not shown.
19
vault.go
19
vault.go
@ -1,6 +1,8 @@
|
|||||||
package mmkv
|
package mmkv
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash/crc32"
|
"hash/crc32"
|
||||||
@ -44,8 +46,8 @@ func (v vault) GetString(key string) (string, error) {
|
|||||||
return string(val), nil
|
return string(val), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// metadata is optional. but if it exists, validate with it.
|
// metadata and cryptoKey are optional. but if it exists, validate with them.
|
||||||
func loadVault(src io.Reader, m *metadata) (Vault, error) {
|
func loadVault(src io.Reader, m *metadata, cryptoKey string) (Vault, error) {
|
||||||
fileSizeBuf := make([]byte, 4)
|
fileSizeBuf := make([]byte, 4)
|
||||||
_, err := io.ReadFull(src, fileSizeBuf)
|
_, err := io.ReadFull(src, fileSizeBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -67,6 +69,19 @@ func loadVault(src io.Reader, m *metadata) (Vault, error) {
|
|||||||
return nil, fmt.Errorf("metadata and vault payload crc32 mismatch")
|
return nil, fmt.Errorf("metadata and vault payload crc32 mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 将数据库完整解密
|
||||||
|
if len(cryptoKey) > 0 {
|
||||||
|
m_key := make([]byte, aes.BlockSize) // 16 bytes key
|
||||||
|
copy(m_key, cryptoKey)
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(m_key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create aes cipher")
|
||||||
|
}
|
||||||
|
stream := cipher.NewCFBDecrypter(block, m.aesVector)
|
||||||
|
stream.XORKeyStream(buf, buf)
|
||||||
|
}
|
||||||
|
|
||||||
v := make(vault)
|
v := make(vault)
|
||||||
|
|
||||||
// mmkv is not really protobuf compatible,
|
// mmkv is not really protobuf compatible,
|
||||||
|
@ -12,7 +12,7 @@ func Test_loadVault(t *testing.T) {
|
|||||||
file, err := os.Open("./testdata/mmkv.default")
|
file, err := os.Open("./testdata/mmkv.default")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
v, err := loadVault(file, nil)
|
v, err := loadVault(file, nil, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, 2, len(v.Keys()))
|
assert.Equal(t, 2, len(v.Keys()))
|
||||||
|
Loading…
Reference in New Issue
Block a user