Skip to content

Commit

Permalink
added a method to close the database connection
Browse files Browse the repository at this point in the history
  • Loading branch information
caffix committed Aug 17, 2024
1 parent 4ac4ba9 commit d3663a9
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 0 deletions.
5 changes: 5 additions & 0 deletions assetdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ func New(dbType repository.DBType, dsn string) *AssetDB {
}
}

// Close will close the assetdb and return any errors.
func (as *AssetDB) Close() error {
return as.repository.Close()
}

// GetDBType returns the type of the underlying database.
func (as *AssetDB) GetDBType() string {
return as.repository.GetDBType()
Expand Down
4 changes: 4 additions & 0 deletions assetdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,10 @@ type mockAssetDB struct {
mock.Mock
}

func (m *mockAssetDB) Close() error {
return nil
}

func (m *mockAssetDB) GetDBType() string {
args := m.Called()
return args.String(0)
Expand Down
1 change: 1 addition & 0 deletions repository/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ type Repository interface {
RawQuery(sqlstr string, results interface{}) error
AssetQuery(constraints string) ([]*types.Asset, error)
RelationQuery(constraints string) ([]*types.Relation, error)
Close() error
}
8 changes: 8 additions & 0 deletions repository/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ func sqliteDatabase(dsn string) (*gorm.DB, error) {
return gorm.Open(sqlite.Open(dsn), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
}

// Close implements the Repository interface.
func (sql *sqlRepository) Close() error {
if db, err := sql.db.DB(); err == nil {
return db.Close()
}
return errors.New("failed to obtain access to the database handle")
}

// GetDBType returns the type of the database.
func (sql *sqlRepository) GetDBType() string {
return string(sql.dbType)
Expand Down

0 comments on commit d3663a9

Please sign in to comment.