diff --git a/chain/db.go b/chain/db.go index 0dc03d5..39453bc 100644 --- a/chain/db.go +++ b/chain/db.go @@ -187,7 +187,7 @@ func (b *dbBucket) getRaw(key []byte) []byte { } func (b *dbBucket) get(key []byte, v types.DecoderFrom) bool { - val := b.getRaw(key) + val := b.getRaw(b.db.vkey(key)) if val == nil { return false } @@ -210,13 +210,17 @@ func (b *dbBucket) put(key []byte, v types.EncoderTo) { b.db.enc.Reset(&buf) v.EncodeTo(&b.db.enc) b.db.enc.Flush() - b.putRaw(key, buf.Bytes()) + b.putRaw(b.db.vkey(key), buf.Bytes()) } -func (b *dbBucket) delete(key []byte) { +func (b *dbBucket) deleteRaw(key []byte) { check(b.b.Delete(key)) } +func (b *dbBucket) delete(key []byte) { + b.deleteRaw(b.db.vkey(key)) +} + var ( bVersion = []byte("Version") bMainChain = []byte("MainChain") @@ -232,14 +236,21 @@ var ( // DBStore implements Store using a key-value database. type DBStore struct { - db DB - n *consensus.Network // for getState - enc types.Encoder + db DB + n *consensus.Network // for getState + version uint8 + keyBuf []byte + enc types.Encoder unflushed int lastFlush time.Time } +func (db *DBStore) vkey(key []byte) []byte { + db.keyBuf = append(db.keyBuf[:0], key...) + return append(db.keyBuf, db.version) +} + func (db *DBStore) bucket(name []byte) *dbBucket { return &dbBucket{db.db.Bucket(name), db} } @@ -258,14 +269,14 @@ func (db *DBStore) deleteBestIndex(height uint64) { } func (db *DBStore) getHeight() (height uint64) { - if val := db.bucket(bMainChain).getRaw(keyHeight); len(val) == 8 { + if val := db.bucket(bMainChain).getRaw(db.vkey(keyHeight)); len(val) == 8 { height = binary.BigEndian.Uint64(val) } return } func (db *DBStore) putHeight(height uint64) { - db.bucket(bMainChain).putRaw(keyHeight, db.encHeight(height)) + db.bucket(bMainChain).putRaw(db.vkey(keyHeight), db.encHeight(height)) } func (db *DBStore) getState(id types.BlockID) (cs consensus.State, ok bool) { @@ -378,13 +389,13 @@ func (db *DBStore) deleteFileContractElement(id types.FileContractID) { func (db *DBStore) putFileContractExpiration(id types.FileContractID, windowEnd uint64) { b := db.bucket(bFileContractElements) - key := db.encHeight(windowEnd) + key := db.vkey(db.encHeight(windowEnd)) b.putRaw(key, append(b.getRaw(key), id[:]...)) } func (db *DBStore) deleteFileContractExpiration(id types.FileContractID, windowEnd uint64) { b := db.bucket(bFileContractElements) - key := db.encHeight(windowEnd) + key := db.vkey(db.encHeight(windowEnd)) val := append([]byte(nil), b.getRaw(key)...) for i := 0; i < len(val); i += 32 { if *(*types.FileContractID)(val[i:]) == id { @@ -704,13 +715,12 @@ func NewDBStore(db DB, n *consensus.Network, genesisBlock types.Block) (_ *DBSto } dbs := &DBStore{ - db: db, - n: n, + db: db, + n: n, + version: 2, } - - // if the db is empty, initialize it; otherwise, check that the genesis - // block is correct - if dbGenesis, ok := dbs.BestIndex(0); !ok { + if version := dbs.bucket(bVersion).getRaw(dbs.vkey(bVersion)); len(version) != 1 { + // initialize empty database for _, bucket := range [][]byte{ bVersion, bMainChain, @@ -725,7 +735,7 @@ func NewDBStore(db DB, n *consensus.Network, genesisBlock types.Block) (_ *DBSto panic(err) } } - dbs.bucket(bVersion).putRaw(bVersion, []byte{1}) + dbs.bucket(bVersion).putRaw(dbs.vkey(bVersion), []byte{dbs.version}) // store genesis state and apply genesis block to it genesisState := n.GenesisState() @@ -738,16 +748,21 @@ func NewDBStore(db DB, n *consensus.Network, genesisBlock types.Block) (_ *DBSto if err := dbs.Flush(); err != nil { return nil, consensus.State{}, err } - } else if dbGenesis.ID != genesisBlock.ID() { - // try to detect network so we can provide a more helpful error message - _, mainnetGenesis := Mainnet() - _, zenGenesis := TestnetZen() - if genesisBlock.ID() == mainnetGenesis.ID() && dbGenesis.ID == zenGenesis.ID() { - return nil, consensus.State{}, errors.New("cannot use Zen testnet database on mainnet") - } else if genesisBlock.ID() == zenGenesis.ID() && dbGenesis.ID == mainnetGenesis.ID() { - return nil, consensus.State{}, errors.New("cannot use mainnet database on Zen testnet") - } else { - return nil, consensus.State{}, errors.New("database previously initialized with different genesis block") + } else if version[0] != dbs.version { + return nil, consensus.State{}, errors.New("incompatible version; please migrate the database") + } else { + // verify the genesis block + if dbGenesis, ok := dbs.BestIndex(0); !ok || dbGenesis.ID != genesisBlock.ID() { + // try to detect network so we can provide a more helpful error message + _, mainnetGenesis := Mainnet() + _, zenGenesis := TestnetZen() + if genesisBlock.ID() == mainnetGenesis.ID() && dbGenesis.ID == zenGenesis.ID() { + return nil, consensus.State{}, errors.New("cannot use Zen testnet database on mainnet") + } else if genesisBlock.ID() == zenGenesis.ID() && dbGenesis.ID == mainnetGenesis.ID() { + return nil, consensus.State{}, errors.New("cannot use mainnet database on Zen testnet") + } else { + return nil, consensus.State{}, errors.New("database previously initialized with different genesis block") + } } } diff --git a/chain/migrate.go b/chain/migrate.go index 5a46f1b..0350db6 100644 --- a/chain/migrate.go +++ b/chain/migrate.go @@ -1,7 +1,6 @@ package chain import ( - "errors" "fmt" "go.sia.tech/core/consensus" @@ -91,66 +90,96 @@ func MigrateDB(db DB, n *consensus.Network) error { return nil // nothing to migrate } dbs := &DBStore{ - db: db, - n: n, - } - var err error - rewrite := func(bucket []byte, key []byte, from types.DecoderFrom, to types.EncoderTo) { - if err != nil { - return - } - b := dbs.bucket(bucket) - val := b.getRaw(key) - if val == nil { - return - } - d := types.NewBufDecoder(val) - from.DecodeFrom(d) - if d.Err() != nil { - err = d.Err() - return - } - b.put(key, to) - if dbs.shouldFlush() { - dbs.Flush() - } + db: db, + n: n, + version: 2, } - version := dbs.bucket(bVersion).getRaw(bVersion) - if len(version) != 1 { - return errors.New("invalid version") + version := dbs.bucket(bVersion).getRaw(dbs.vkey(bVersion)) + if version == nil { + version = []byte{1} } switch version[0] { case 1: + var err error + addVersion := func(bucket []byte, key []byte) { + if err != nil { + return + } + b := dbs.bucket(bucket) + b.putRaw(dbs.vkey(key), b.getRaw(key)) + b.deleteRaw(key) + } + rewrite := func(bucket []byte, key []byte, from types.DecoderFrom, to types.EncoderTo) { + if err != nil { + return + } + b := dbs.bucket(bucket) + val := b.getRaw(key) + d := types.NewBufDecoder(val) + from.DecodeFrom(d) + if d.Err() != nil { + err = d.Err() + return + } + b.deleteRaw(key) + b.put(key, to) + if dbs.shouldFlush() { + dbs.Flush() + } + } + var sb supplementedBlock for _, key := range db.BucketKeys(bBlocks) { - rewrite(bBlocks, key, (*oldSupplementedBlock)(&sb), &sb) + if len(key) == 32 { + rewrite(bBlocks, key, (*oldSupplementedBlock)(&sb), &sb) + } } var cs consensus.State for _, key := range db.BucketKeys(bStates) { - rewrite(bStates, key, (*versionedState)(&cs), &cs) + if len(key) == 32 { + rewrite(bStates, key, (*versionedState)(&cs), &cs) + } } var sce types.SiacoinElement for _, key := range db.BucketKeys(bSiacoinElements) { - rewrite(bSiacoinElements, key, (*oldSiacoinElement)(&sce), &sce) + if len(key) == 32 { + rewrite(bSiacoinElements, key, (*oldSiacoinElement)(&sce), &sce) + } } var sfe types.SiafundElement for _, key := range db.BucketKeys(bSiafundElements) { - rewrite(bSiafundElements, key, (*oldSiafundElement)(&sfe), &sfe) + if len(key) == 32 { + rewrite(bSiafundElements, key, (*oldSiafundElement)(&sfe), &sfe) + } } var fce types.FileContractElement for _, key := range db.BucketKeys(bFileContractElements) { if len(key) == 32 { rewrite(bFileContractElements, key, (*oldFileContractElement)(&fce), &fce) + } else if len(key) == 8 { + addVersion(bFileContractElements, key) + } + } + for _, key := range db.BucketKeys(bMainChain) { + if len(key) == 8 || len(key) == 5 { + addVersion(bMainChain, key) } } + for _, key := range db.BucketKeys(bTree) { + if len(key) == 4 { + addVersion(bTree, key) + } + } + dbs.bucket(bVersion).deleteRaw(bVersion) + dbs.bucket(bVersion).putRaw(dbs.vkey(bVersion), []byte{2}) + if err != nil { return err } - dbs.bucket(bVersion).putRaw(bVersion, []byte{2}) dbs.Flush() fallthrough - case 2: + case dbs.version: // up-to-date return nil default: diff --git a/db.go b/db.go index 0578085..1dd3999 100644 --- a/db.go +++ b/db.go @@ -42,6 +42,18 @@ func (db *BoltChainDB) CreateBucket(name []byte) (chain.DBBucket, error) { return db.tx.CreateBucket(name) } +func (db *BoltChainDB) BucketKeys(name []byte) [][]byte { + if err := db.newTx(); err != nil { + panic(err) + } + var keys [][]byte + c := db.tx.Bucket(name).Cursor() + for k, _ := c.First(); k != nil; k, _ = c.Next() { + keys = append(keys, append([]byte(nil), k...)) + } + return keys +} + // Flush implements chain.DB. func (db *BoltChainDB) Flush() error { if db.tx == nil {