diff --git a/assetdb.go b/assetdb.go index 689a450..f46f261 100644 --- a/assetdb.go +++ b/assetdb.go @@ -17,27 +17,6 @@ type AssetDB struct { repository repository.Repository } -// New creates a new assetDB instance. -// It initializes the asset database with the specified database type and DSN. -func New(dbtype, dsn string) *AssetDB { - if db, err := repository.New(dbtype, dsn); err == nil && db != nil { - return &AssetDB{ - repository: db, - } - } - return nil -} - -// Close will close the assetdb and return any errors. -func (as *AssetDB) Close() error { - return as.repository.Close() -} - -// GetDBType returns the type of the underlying database. -func (as *AssetDB) GetDBType() string { - return as.repository.GetDBType() -} - // Create creates a new entity in the database. // If the edge is provided, the entity is created and linked to the source entity using the specified edge. // It returns the newly created entity and an error, if any. diff --git a/db.go b/db.go new file mode 100644 index 0000000..4c4cadd --- /dev/null +++ b/db.go @@ -0,0 +1,87 @@ +// Copyright © by Jeff Foley 2017-2024. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. +// SPDX-License-Identifier: Apache-2.0 + +package assetdb + +import ( + "embed" + "fmt" + "math/rand" + + pgmigrations "github.com/owasp-amass/asset-db/migrations/postgres" + sqlitemigrations "github.com/owasp-amass/asset-db/migrations/sqlite3" + "github.com/owasp-amass/asset-db/repository" + "github.com/owasp-amass/asset-db/repository/sqlrepo" + migrate "github.com/rubenv/sql-migrate" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +// New creates a new assetDB instance. +// It initializes the asset database with the specified database type and DSN. +func New(dbtype, dsn string) *AssetDB { + if dbtype == sqlrepo.SQLiteMemory { + dsn = fmt.Sprintf("file:sqlite%d?mode=memory&cache=shared", rand.Int31n(1000)) + } + + if db, err := repository.New(dbtype, dsn); err == nil && db != nil { + if err := migrateDatabase(dbtype, dsn); err == nil { + return &AssetDB{ + repository: db, + } + } + } + return nil +} + +// Close will close the assetdb and return any errors. +func (as *AssetDB) Close() error { + return as.repository.Close() +} + +// GetDBType returns the type of the underlying database. +func (as *AssetDB) GetDBType() string { + return as.repository.GetDBType() +} + +func migrateDatabase(dbtype, dsn string) error { + var name string + var fs embed.FS + var database gorm.Dialector + + switch dbtype { + case sqlrepo.SQLite: + fallthrough + case sqlrepo.SQLiteMemory: + name = "sqlite3" + fs = sqlitemigrations.Migrations() + database = sqlite.Open(dsn) + case sqlrepo.Postgres: + name = "postgres" + fs = pgmigrations.Migrations() + database = postgres.Open(dsn) + } + + sql, err := gorm.Open(database, &gorm.Config{}) + if err != nil { + return err + } + + migrationsSource := migrate.EmbedFileSystemMigrationSource{ + FileSystem: fs, + Root: "/", + } + + sqlDb, err := sql.DB() + if err != nil { + return err + } + + _, err = migrate.Exec(sqlDb, name, migrationsSource, migrate.Up) + if err != nil { + return err + } + return nil +} diff --git a/go.mod b/go.mod index 5d089a8..0b942dc 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/stretchr/testify v1.9.0 gorm.io/datatypes v1.2.4 gorm.io/driver/postgres v1.5.9 + gorm.io/driver/sqlite v1.5.4 gorm.io/gorm v1.25.12 ) @@ -28,6 +29,7 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.19 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect @@ -39,7 +41,6 @@ require ( golang.org/x/text v0.20.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect gorm.io/driver/mysql v1.5.7 // indirect - gorm.io/driver/sqlite v1.5.4 // indirect modernc.org/libc v1.61.0 // indirect modernc.org/mathutil v1.6.0 // indirect modernc.org/memory v1.8.0 // indirect diff --git a/repository/sqlrepo/db.go b/repository/sqlrepo/db.go new file mode 100644 index 0000000..8aa9cf5 --- /dev/null +++ b/repository/sqlrepo/db.go @@ -0,0 +1,75 @@ +// Copyright © by Jeff Foley 2017-2024. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. +// SPDX-License-Identifier: Apache-2.0 + +package sqlrepo + +import ( + "errors" + + "github.com/glebarez/sqlite" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +const ( + Postgres string = "postgres" + SQLite string = "sqlite" + SQLiteMemory string = "sqlite_memory" +) + +// sqlRepository is a repository implementation using GORM as the underlying ORM. +type sqlRepository struct { + db *gorm.DB + dbtype string +} + +// New creates a new instance of the asset database repository. +func New(dbtype, dsn string) (*sqlRepository, error) { + db, err := newDatabase(dbtype, dsn) + if err != nil { + return nil, err + } + + return &sqlRepository{ + db: db, + dbtype: dbtype, + }, nil +} + +// newDatabase creates a new GORM database connection based on the provided database type and data source name (dsn). +func newDatabase(dbtype, dsn string) (*gorm.DB, error) { + switch dbtype { + case Postgres: + return postgresDatabase(dsn) + case SQLite: + fallthrough + case SQLiteMemory: + return sqliteDatabase(dsn) + } + return nil, errors.New("unknown DB type") +} + +// postgresDatabase creates a new PostgreSQL database connection using the provided data source name (dsn). +func postgresDatabase(dsn string) (*gorm.DB, error) { + return gorm.Open(postgres.Open(dsn), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}) +} + +// sqliteDatabase creates a new SQLite database connection using the provided data source name (dsn). +func sqliteDatabase(dsn string) (*gorm.DB, error) { + return gorm.Open(sqlite.Open(dsn), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}) +} + +// Close implements the Repository interface. +func (sql *sqlRepository) Close() error { + if db, err := sql.db.DB(); err == nil { + return db.Close() + } + return errors.New("failed to obtain access to the database handle") +} + +// GetDBType returns the type of the database. +func (sql *sqlRepository) GetDBType() string { + return string(sql.dbtype) +} diff --git a/repository/sqlrepo/entity.go b/repository/sqlrepo/entity.go index 8f58516..2f2b60f 100644 --- a/repository/sqlrepo/entity.go +++ b/repository/sqlrepo/entity.go @@ -9,72 +9,11 @@ import ( "strconv" "time" - "github.com/glebarez/sqlite" "github.com/owasp-amass/asset-db/types" oam "github.com/owasp-amass/open-asset-model" - "gorm.io/driver/postgres" "gorm.io/gorm" - "gorm.io/gorm/logger" ) -const ( - Postgres string = "postgres" - SQLite string = "sqlite" -) - -// sqlRepository is a repository implementation using GORM as the underlying ORM. -type sqlRepository struct { - db *gorm.DB - dbtype string -} - -// New creates a new instance of the asset database repository. -func New(dbtype, dsn string) (*sqlRepository, error) { - db, err := newDatabase(dbtype, dsn) - if err != nil { - return nil, err - } - - return &sqlRepository{ - db: db, - dbtype: dbtype, - }, nil -} - -// newDatabase creates a new GORM database connection based on the provided database type and data source name (dsn). -func newDatabase(dbtype, dsn string) (*gorm.DB, error) { - switch dbtype { - case Postgres: - return postgresDatabase(dsn) - case SQLite: - return sqliteDatabase(dsn) - } - return nil, errors.New("unknown DB type") -} - -// postgresDatabase creates a new PostgreSQL database connection using the provided data source name (dsn). -func postgresDatabase(dsn string) (*gorm.DB, error) { - return gorm.Open(postgres.Open(dsn), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}) -} - -// sqliteDatabase creates a new SQLite database connection using the provided data source name (dsn). -func sqliteDatabase(dsn string) (*gorm.DB, error) { - return gorm.Open(sqlite.Open(dsn), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}) -} - -// Close implements the Repository interface. -func (sql *sqlRepository) Close() error { - if db, err := sql.db.DB(); err == nil { - return db.Close() - } - return errors.New("failed to obtain access to the database handle") -} - -// GetDBType returns the type of the database. -func (sql *sqlRepository) GetDBType() string { - return string(sql.dbtype) -} - // CreateEntity creates a new entity in the database. // It takes an oam.Asset as input and persists it in the database. // The entity is serialized to JSON and stored in the Content field of the Entity struct.