diff --git a/database.go b/database.go new file mode 100644 index 0000000..fce8697 --- /dev/null +++ b/database.go @@ -0,0 +1,17 @@ +package abutil + +import ( + "database/sql" +) + +// Rollback does a rollback on the transaction and returns either the error +// from the rollback if there was one or the alternative. +// This is useful if you have multiple statments in a row but don't want to +// call rollback and check for errors every time. +func Rollback(tx *sql.Tx, alt error) error { + if err := tx.Rollback(); err != nil { + return err + } + + return alt +} diff --git a/database_test.go b/database_test.go new file mode 100644 index 0000000..82c687e --- /dev/null +++ b/database_test.go @@ -0,0 +1,58 @@ +package abutil + +import ( + "database/sql" + "errors" + "testing" + + "github.com/DATA-DOG/go-sqlmock" +) + +func mockDBContext(t *testing.T, fn func(*sql.DB)) { + db, err := sqlmock.New() + if err != nil { + t.Error(err) + } + defer db.Close() + + fn(db) +} + +func TestRollback(t *testing.T) { + mockDBContext(t, func(db *sql.DB) { + sqlmock.ExpectBegin() + sqlmock.ExpectRollback() + + tx, err := db.Begin() + if err != nil { + t.Error(err) + } + + alt := errors.New("Some alternative error") + err = Rollback(tx, alt) + + if err != alt { + t.Errorf("Expected Rollback to return %v, but got %v", alt, err) + } + }) +} + +func TestRollbackFailing(t *testing.T) { + mockDBContext(t, func(db *sql.DB) { + rberr := errors.New("Some rollback error") + + sqlmock.ExpectBegin() + sqlmock.ExpectRollback(). + WillReturnError(rberr) + + tx, err := db.Begin() + if err != nil { + t.Error(err) + } + + err = Rollback(tx, errors.New("This should not be used")) + if err != rberr { + t.Errorf("Expected Rollback to return %v, but got %v", rberr, err) + } + }) +}