Skip to content

Commit

Permalink
New now performs the db migrations
Browse files Browse the repository at this point in the history
  • Loading branch information
caffix committed Nov 15, 2024
1 parent 268916f commit 2046261
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 83 deletions.
21 changes: 0 additions & 21 deletions assetdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,6 @@ type AssetDB struct {
repository repository.Repository
}

// New creates a new assetDB instance.
// It initializes the asset database with the specified database type and DSN.
func New(dbtype, dsn string) *AssetDB {
if db, err := repository.New(dbtype, dsn); err == nil && db != nil {
return &AssetDB{
repository: db,
}
}
return nil
}

// Close will close the assetdb and return any errors.
func (as *AssetDB) Close() error {
return as.repository.Close()
}

// GetDBType returns the type of the underlying database.
func (as *AssetDB) GetDBType() string {
return as.repository.GetDBType()
}

// Create creates a new entity in the database.
// If the edge is provided, the entity is created and linked to the source entity using the specified edge.
// It returns the newly created entity and an error, if any.
Expand Down
87 changes: 87 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright © by Jeff Foley 2017-2024. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
// SPDX-License-Identifier: Apache-2.0

package assetdb

import (
"embed"
"fmt"
"math/rand"

pgmigrations "github.com/owasp-amass/asset-db/migrations/postgres"
sqlitemigrations "github.com/owasp-amass/asset-db/migrations/sqlite3"
"github.com/owasp-amass/asset-db/repository"
"github.com/owasp-amass/asset-db/repository/sqlrepo"
migrate "github.com/rubenv/sql-migrate"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)

// New creates a new assetDB instance.
// It initializes the asset database with the specified database type and DSN.
func New(dbtype, dsn string) *AssetDB {
if dbtype == sqlrepo.SQLiteMemory {
dsn = fmt.Sprintf("file:sqlite%d?mode=memory&cache=shared", rand.Int31n(1000))
}

if db, err := repository.New(dbtype, dsn); err == nil && db != nil {
if err := migrateDatabase(dbtype, dsn); err == nil {
return &AssetDB{
repository: db,
}
}
}
return nil
}

// Close will close the assetdb and return any errors.
func (as *AssetDB) Close() error {
return as.repository.Close()
}

// GetDBType returns the type of the underlying database.
func (as *AssetDB) GetDBType() string {
return as.repository.GetDBType()
}

func migrateDatabase(dbtype, dsn string) error {
var name string
var fs embed.FS
var database gorm.Dialector

switch dbtype {
case sqlrepo.SQLite:
fallthrough
case sqlrepo.SQLiteMemory:
name = "sqlite3"
fs = sqlitemigrations.Migrations()
database = sqlite.Open(dsn)
case sqlrepo.Postgres:
name = "postgres"
fs = pgmigrations.Migrations()
database = postgres.Open(dsn)
}

sql, err := gorm.Open(database, &gorm.Config{})
if err != nil {
return err
}

migrationsSource := migrate.EmbedFileSystemMigrationSource{
FileSystem: fs,
Root: "/",
}

sqlDb, err := sql.DB()
if err != nil {
return err
}

_, err = migrate.Exec(sqlDb, name, migrationsSource, migrate.Up)
if err != nil {
return err
}
return nil
}
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
github.com/stretchr/testify v1.9.0
gorm.io/datatypes v1.2.4
gorm.io/driver/postgres v1.5.9
gorm.io/driver/sqlite v1.5.4
gorm.io/gorm v1.25.12
)

Expand All @@ -28,6 +29,7 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.19 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
Expand All @@ -39,7 +41,6 @@ require (
golang.org/x/text v0.20.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gorm.io/driver/mysql v1.5.7 // indirect
gorm.io/driver/sqlite v1.5.4 // indirect
modernc.org/libc v1.61.0 // indirect
modernc.org/mathutil v1.6.0 // indirect
modernc.org/memory v1.8.0 // indirect
Expand Down
75 changes: 75 additions & 0 deletions repository/sqlrepo/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright © by Jeff Foley 2017-2024. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
// SPDX-License-Identifier: Apache-2.0

package sqlrepo

import (
"errors"

"github.com/glebarez/sqlite"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)

const (
Postgres string = "postgres"
SQLite string = "sqlite"
SQLiteMemory string = "sqlite_memory"
)

// sqlRepository is a repository implementation using GORM as the underlying ORM.
type sqlRepository struct {
db *gorm.DB
dbtype string
}

// New creates a new instance of the asset database repository.
func New(dbtype, dsn string) (*sqlRepository, error) {
db, err := newDatabase(dbtype, dsn)
if err != nil {
return nil, err
}

return &sqlRepository{
db: db,
dbtype: dbtype,
}, nil
}

// newDatabase creates a new GORM database connection based on the provided database type and data source name (dsn).
func newDatabase(dbtype, dsn string) (*gorm.DB, error) {
switch dbtype {
case Postgres:
return postgresDatabase(dsn)
case SQLite:
fallthrough
case SQLiteMemory:
return sqliteDatabase(dsn)
}
return nil, errors.New("unknown DB type")
}

// postgresDatabase creates a new PostgreSQL database connection using the provided data source name (dsn).
func postgresDatabase(dsn string) (*gorm.DB, error) {
return gorm.Open(postgres.Open(dsn), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
}

// sqliteDatabase creates a new SQLite database connection using the provided data source name (dsn).
func sqliteDatabase(dsn string) (*gorm.DB, error) {
return gorm.Open(sqlite.Open(dsn), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
}

// Close implements the Repository interface.
func (sql *sqlRepository) Close() error {
if db, err := sql.db.DB(); err == nil {
return db.Close()
}
return errors.New("failed to obtain access to the database handle")
}

// GetDBType returns the type of the database.
func (sql *sqlRepository) GetDBType() string {
return string(sql.dbtype)
}
61 changes: 0 additions & 61 deletions repository/sqlrepo/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,72 +9,11 @@ import (
"strconv"
"time"

"github.com/glebarez/sqlite"
"github.com/owasp-amass/asset-db/types"
oam "github.com/owasp-amass/open-asset-model"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)

const (
Postgres string = "postgres"
SQLite string = "sqlite"
)

// sqlRepository is a repository implementation using GORM as the underlying ORM.
type sqlRepository struct {
db *gorm.DB
dbtype string
}

// New creates a new instance of the asset database repository.
func New(dbtype, dsn string) (*sqlRepository, error) {
db, err := newDatabase(dbtype, dsn)
if err != nil {
return nil, err
}

return &sqlRepository{
db: db,
dbtype: dbtype,
}, nil
}

// newDatabase creates a new GORM database connection based on the provided database type and data source name (dsn).
func newDatabase(dbtype, dsn string) (*gorm.DB, error) {
switch dbtype {
case Postgres:
return postgresDatabase(dsn)
case SQLite:
return sqliteDatabase(dsn)
}
return nil, errors.New("unknown DB type")
}

// postgresDatabase creates a new PostgreSQL database connection using the provided data source name (dsn).
func postgresDatabase(dsn string) (*gorm.DB, error) {
return gorm.Open(postgres.Open(dsn), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
}

// sqliteDatabase creates a new SQLite database connection using the provided data source name (dsn).
func sqliteDatabase(dsn string) (*gorm.DB, error) {
return gorm.Open(sqlite.Open(dsn), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
}

// Close implements the Repository interface.
func (sql *sqlRepository) Close() error {
if db, err := sql.db.DB(); err == nil {
return db.Close()
}
return errors.New("failed to obtain access to the database handle")
}

// GetDBType returns the type of the database.
func (sql *sqlRepository) GetDBType() string {
return string(sql.dbtype)
}

// CreateEntity creates a new entity in the database.
// It takes an oam.Asset as input and persists it in the database.
// The entity is serialized to JSON and stored in the Content field of the Entity struct.
Expand Down

0 comments on commit 2046261

Please sign in to comment.