support encrypted database

This commit is contained in:
awalol 2024-02-05 00:47:24 +08:00
parent 41a75bd299
commit a0c5a497c7
7 changed files with 63 additions and 12 deletions

View File

@ -5,6 +5,7 @@ type Manager interface {
// If the vault does not exist, it will be created. // If the vault does not exist, it will be created.
// If id is empty, DefaultVaultID will be used. // If id is empty, DefaultVaultID will be used.
OpenVault(id string) (Vault, error) OpenVault(id string) (Vault, error)
OpenVaultCrypto(id string, cryptoKey string) (Vault, error)
} }
type Vault interface { type Vault interface {

View File

@ -42,7 +42,7 @@ func (m *manager) OpenVault(id string) (Vault, error) {
return v, nil return v, nil
} }
vault, err := m.openVault(id) vault, err := m.openVault(id, "")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open vault: %w", err) return nil, fmt.Errorf("failed to open vault: %w", err)
} }
@ -51,7 +51,25 @@ func (m *manager) OpenVault(id string) (Vault, error) {
return vault, nil return vault, nil
} }
func (m *manager) openVault(id string) (Vault, error) { func (m *manager) OpenVaultCrypto(id string, cryptoKey string) (Vault, error) {
if id == "" {
id = DefaultVaultID
}
if v, ok := m.vaults[id]; ok {
return v, nil
}
vault, err := m.openVault(id, cryptoKey)
if err != nil {
return nil, fmt.Errorf("failed to open vault: %w", err)
}
m.vaults[id] = vault
return vault, nil
}
func (m *manager) openVault(id string, cryptoKey string) (Vault, error) {
metaFile, err := os.Open(path.Join(m.dir, id+".crc")) metaFile, err := os.Open(path.Join(m.dir, id+".crc"))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open metadata file: %w", err) return nil, fmt.Errorf("failed to open metadata file: %w", err)
@ -69,7 +87,7 @@ func (m *manager) openVault(id string) (Vault, error) {
return nil, fmt.Errorf("failed to load metadata: %w", err) return nil, fmt.Errorf("failed to load metadata: %w", err)
} }
v, err := loadVault(vaultFile, meta) v, err := loadVault(vaultFile, meta, cryptoKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load vault: %w", err) return nil, fmt.Errorf("failed to load vault: %w", err)
} }

View File

@ -7,11 +7,28 @@ import (
) )
func TestNewManager(t *testing.T) { func TestNewManager(t *testing.T) {
mgr, err := NewManager("./testdata") t.Run("Default", func(t *testing.T) {
assert.NoError(t, err) mgr, err := NewManager("./testdata")
assert.NotNil(t, mgr) assert.NoError(t, err)
assert.NotNil(t, mgr)
vault, err := mgr.OpenVault("") vault, err := mgr.OpenVault("")
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, vault) assert.NotNil(t, vault)
})
t.Run("Crypto", func(t *testing.T) {
mgr, err := NewManager("./testdata")
assert.NoError(t, err)
assert.NotNil(t, mgr)
vault, err := mgr.OpenVaultCrypto("crypto", "123456")
val, err := vault.GetString("world")
assert.NotNil(t, vault)
assert.Equal(t, "hello", val)
assert.NoError(t, err)
_, err = vault.GetBytes("foo")
assert.Error(t, err)
})
} }

BIN
testdata/crypto vendored Normal file

Binary file not shown.

BIN
testdata/crypto.crc vendored Normal file

Binary file not shown.

View File

@ -1,6 +1,8 @@
package mmkv package mmkv
import ( import (
"crypto/aes"
"crypto/cipher"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"hash/crc32" "hash/crc32"
@ -44,8 +46,8 @@ func (v vault) GetString(key string) (string, error) {
return string(val), nil return string(val), nil
} }
// metadata is optional. but if it exists, validate with it. // metadata and cryptoKey are optional. but if it exists, validate with them.
func loadVault(src io.Reader, m *metadata) (Vault, error) { func loadVault(src io.Reader, m *metadata, cryptoKey string) (Vault, error) {
fileSizeBuf := make([]byte, 4) fileSizeBuf := make([]byte, 4)
_, err := io.ReadFull(src, fileSizeBuf) _, err := io.ReadFull(src, fileSizeBuf)
if err != nil { if err != nil {
@ -67,6 +69,19 @@ func loadVault(src io.Reader, m *metadata) (Vault, error) {
return nil, fmt.Errorf("metadata and vault payload crc32 mismatch") return nil, fmt.Errorf("metadata and vault payload crc32 mismatch")
} }
// 将数据库完整解密
if len(cryptoKey) > 0 {
m_key := make([]byte, aes.BlockSize) // 16 bytes key
copy(m_key, cryptoKey)
block, err := aes.NewCipher(m_key)
if err != nil {
return nil, fmt.Errorf("failed to create aes cipher")
}
stream := cipher.NewCFBDecrypter(block, m.aesVector)
stream.XORKeyStream(buf, buf)
}
v := make(vault) v := make(vault)
// mmkv is not really protobuf compatible, // mmkv is not really protobuf compatible,

View File

@ -12,7 +12,7 @@ func Test_loadVault(t *testing.T) {
file, err := os.Open("./testdata/mmkv.default") file, err := os.Open("./testdata/mmkv.default")
require.NoError(t, err) require.NoError(t, err)
v, err := loadVault(file, nil) v, err := loadVault(file, nil, "")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, len(v.Keys())) assert.Equal(t, 2, len(v.Keys()))