diff --git a/named_context.go b/named_context.go new file mode 100644 index 00000000..9405007e --- /dev/null +++ b/named_context.go @@ -0,0 +1,132 @@ +// +build go1.8 + +package sqlx + +import ( + "context" + "database/sql" +) + +// A union interface of contextPreparer and binder, required to be able to +// prepare named statements with context (as the bindtype must be determined). +type namedPreparerContext interface { + PreparerContext + binder +} + +func prepareNamedContext(ctx context.Context, p namedPreparerContext, query string) (*NamedStmt, error) { + bindType := BindType(p.DriverName()) + q, args, err := compileNamedQuery([]byte(query), bindType) + if err != nil { + return nil, err + } + stmt, err := PreparexContext(ctx, p, q) + if err != nil { + return nil, err + } + return &NamedStmt{ + QueryString: q, + Params: args, + Stmt: stmt, + }, nil +} + +// ExecContext executes a named statement using the struct passed. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) ExecContext(ctx context.Context, arg interface{}) (sql.Result, error) { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return *new(sql.Result), err + } + return n.Stmt.ExecContext(ctx, args...) +} + +// QueryContext executes a named statement using the struct argument, returning rows. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryContext(ctx context.Context, arg interface{}) (*sql.Rows, error) { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return nil, err + } + return n.Stmt.QueryContext(ctx, args...) +} + +// QueryRowContext executes a named statement against the database. Because sqlx cannot +// create a *sql.Row with an error condition pre-set for binding errors, sqlx +// returns a *sqlx.Row instead. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryRowContext(ctx context.Context, arg interface{}) *Row { + args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) + if err != nil { + return &Row{err: err} + } + return n.Stmt.QueryRowxContext(ctx, args...) +} + +// MustExecContext execs a NamedStmt, panicing on error +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) MustExecContext(ctx context.Context, arg interface{}) sql.Result { + res, err := n.ExecContext(ctx, arg) + if err != nil { + panic(err) + } + return res +} + +// QueryxContext using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryxContext(ctx context.Context, arg interface{}) (*Rows, error) { + r, err := n.QueryContext(ctx, arg) + if err != nil { + return nil, err + } + return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err +} + +// QueryRowxContext this NamedStmt. Because of limitations with QueryRow, this is +// an alias for QueryRow. +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) QueryRowxContext(ctx context.Context, arg interface{}) *Row { + return n.QueryRowContext(ctx, arg) +} + +// SelectContext using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) SelectContext(ctx context.Context, dest interface{}, arg interface{}) error { + rows, err := n.QueryxContext(ctx, arg) + if err != nil { + return err + } + // if something happens here, we want to make sure the rows are Closed + defer rows.Close() + return scanAll(rows, dest, false) +} + +// GetContext using this NamedStmt +// Any named placeholder parameters are replaced with fields from arg. +func (n *NamedStmt) GetContext(ctx context.Context, dest interface{}, arg interface{}) error { + r := n.QueryRowxContext(ctx, arg) + return r.scanAny(dest, false) +} + +// NamedQueryContext binds a named query and then runs Query on the result using the +// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with +// map[string]interface{} types. +func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg interface{}) (*Rows, error) { + q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) + if err != nil { + return nil, err + } + return e.QueryxContext(ctx, q, args...) +} + +// NamedExecContext uses BindStruct to get a query executable by the driver and +// then runs Exec on the result. Returns an error from the binding +// or the query excution itself. +func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) { + q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) + if err != nil { + return nil, err + } + return e.ExecContext(ctx, q, args...) +} diff --git a/named_context_test.go b/named_context_test.go new file mode 100644 index 00000000..87e94ac2 --- /dev/null +++ b/named_context_test.go @@ -0,0 +1,136 @@ +// +build go1.8 + +package sqlx + +import ( + "context" + "database/sql" + "testing" +) + +func TestNamedContextQueries(t *testing.T) { + RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) { + loadDefaultFixture(db, t) + test := Test{t} + var ns *NamedStmt + var err error + + ctx := context.Background() + + // Check that invalid preparations fail + ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first:name") + if err == nil { + t.Error("Expected an error with invalid prepared statement.") + } + + ns, err = db.PrepareNamedContext(ctx, "invalid sql") + if err == nil { + t.Error("Expected an error with invalid prepared statement.") + } + + // Check closing works as anticipated + ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first_name") + test.Error(err) + err = ns.Close() + test.Error(err) + + ns, err = db.PrepareNamedContext(ctx, ` + SELECT first_name, last_name, email + FROM person WHERE first_name=:first_name AND email=:email`) + test.Error(err) + + // test Queryx w/ uses Query + p := Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"} + + rows, err := ns.QueryxContext(ctx, p) + test.Error(err) + for rows.Next() { + var p2 Person + rows.StructScan(&p2) + if p.FirstName != p2.FirstName { + t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName) + } + if p.LastName != p2.LastName { + t.Errorf("got %s, expected %s", p.LastName, p2.LastName) + } + if p.Email != p2.Email { + t.Errorf("got %s, expected %s", p.Email, p2.Email) + } + } + + // test Select + people := make([]Person, 0, 5) + err = ns.SelectContext(ctx, &people, p) + test.Error(err) + + if len(people) != 1 { + t.Errorf("got %d results, expected %d", len(people), 1) + } + if p.FirstName != people[0].FirstName { + t.Errorf("got %s, expected %s", p.FirstName, people[0].FirstName) + } + if p.LastName != people[0].LastName { + t.Errorf("got %s, expected %s", p.LastName, people[0].LastName) + } + if p.Email != people[0].Email { + t.Errorf("got %s, expected %s", p.Email, people[0].Email) + } + + // test Exec + ns, err = db.PrepareNamedContext(ctx, ` + INSERT INTO person (first_name, last_name, email) + VALUES (:first_name, :last_name, :email)`) + test.Error(err) + + js := Person{ + FirstName: "Julien", + LastName: "Savea", + Email: "jsavea@ab.co.nz", + } + _, err = ns.ExecContext(ctx, js) + test.Error(err) + + // Make sure we can pull him out again + p2 := Person{} + db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email) + if p2.Email != js.Email { + t.Errorf("expected %s, got %s", js.Email, p2.Email) + } + + // test Txn NamedStmts + tx := db.MustBeginTx(ctx, nil) + txns := tx.NamedStmtContext(ctx, ns) + + // We're going to add Steven in this txn + sl := Person{ + FirstName: "Steven", + LastName: "Luatua", + Email: "sluatua@ab.co.nz", + } + + _, err = txns.ExecContext(ctx, sl) + test.Error(err) + // then rollback... + tx.Rollback() + // looking for Steven after a rollback should fail + err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) + if err != sql.ErrNoRows { + t.Errorf("expected no rows error, got %v", err) + } + + // now do the same, but commit + tx = db.MustBeginTx(ctx, nil) + txns = tx.NamedStmtContext(ctx, ns) + _, err = txns.ExecContext(ctx, sl) + test.Error(err) + tx.Commit() + + // looking for Steven after a Commit should succeed + err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) + test.Error(err) + if p2.Email != sl.Email { + t.Errorf("expected %s, got %s", sl.Email, p2.Email) + } + + }) +} diff --git a/sqlx_context.go b/sqlx_context.go new file mode 100644 index 00000000..aa0e95cd --- /dev/null +++ b/sqlx_context.go @@ -0,0 +1,299 @@ +// +build go1.8 + +package sqlx + +import ( + "context" + "database/sql" + "fmt" + "io/ioutil" + "path/filepath" + "reflect" +) + +// ConnectContext to a database and verify with a ping. +func ConnectContext(ctx context.Context, driverName, dataSourceName string) (*DB, error) { + db, err := Open(driverName, dataSourceName) + if err != nil { + return db, err + } + err = db.PingContext(ctx) + return db, err +} + +// QueryerContext is an interface used by GetContext and SelectContext +type QueryerContext interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) + QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row +} + +// PreparerContext is an interface used by PreparexContext. +type PreparerContext interface { + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) +} + +// ExecerContext is an interface used by MustExecContext and LoadFileContext +type ExecerContext interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + +// ExtContext is a union interface which can bind, query, and exec, with Context +// used by NamedQueryContext and NamedExecContext. +type ExtContext interface { + binder + QueryerContext + ExecerContext +} + +// SelectContext executes a query using the provided Queryer, and StructScans +// each row into dest, which must be a slice. If the slice elements are +// scannable, then the result set must have only one column. Otherwise, +// StructScan is used. The *sql.Rows are closed automatically. +// Any placeholder parameters are replaced with supplied args. +func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { + rows, err := q.QueryxContext(ctx, query, args...) + if err != nil { + return err + } + // if something happens here, we want to make sure the rows are Closed + defer rows.Close() + return scanAll(rows, dest, false) +} + +// PreparexContext prepares a statement. +// +// The provided context is used for the preparation of the statement, not for +// the execution of the statement. +func PreparexContext(ctx context.Context, p PreparerContext, query string) (*Stmt, error) { + s, err := p.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err +} + +// GetContext does a QueryRow using the provided Queryer, and scans the +// resulting row to dest. If dest is scannable, the result must only have one +// column. Otherwise, StructScan is used. Get will return sql.ErrNoRows like +// row.Scan would. Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func GetContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { + r := q.QueryRowxContext(ctx, query, args...) + return r.scanAny(dest, false) +} + +// LoadFileContext exec's every statement in a file (as a single call to Exec). +// LoadFileContext may return a nil *sql.Result if errors are encountered +// locating or reading the file at path. LoadFile reads the entire file into +// memory, so it is not suitable for loading large data dumps, but can be useful +// for initializing schemas or loading indexes. +// +// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3 +// or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting +// this by requiring something with DriverName() and then attempting to split the +// queries will be difficult to get right, and its current driver-specific behavior +// is deemed at least not complex in its incorrectness. +func LoadFileContext(ctx context.Context, e ExecerContext, path string) (*sql.Result, error) { + realpath, err := filepath.Abs(path) + if err != nil { + return nil, err + } + contents, err := ioutil.ReadFile(realpath) + if err != nil { + return nil, err + } + res, err := e.ExecContext(ctx, string(contents)) + return &res, err +} + +// MustExecContext execs the query using e and panics if there was an error. +// Any placeholder parameters are replaced with supplied args. +func MustExecContext(ctx context.Context, e ExecerContext, query string, args ...interface{}) sql.Result { + res, err := e.ExecContext(ctx, query, args...) + if err != nil { + panic(err) + } + return res +} + +// PrepareNamedContext returns an sqlx.NamedStmt +func (db *DB) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) { + return prepareNamedContext(ctx, db, query) +} + +// NamedQueryContext using this DB. +// Any named placeholder parameters are replaced with fields from arg. +func (db *DB) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*Rows, error) { + return NamedQueryContext(ctx, db, query, arg) +} + +// NamedExecContext using this DB. +// Any named placeholder parameters are replaced with fields from arg. +func (db *DB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { + return NamedExecContext(ctx, db, query, arg) +} + +// SelectContext using this DB. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return SelectContext(ctx, db, dest, query, args...) +} + +// GetContext using this DB. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func (db *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return GetContext(ctx, db, dest, query, args...) +} + +// PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. +// +// The provided context is used for the preparation of the statement, not for +// the execution of the statement. +func (db *DB) PreparexContext(ctx context.Context, query string) (*Stmt, error) { + return PreparexContext(ctx, db, query) +} + +// QueryxContext queries the database and returns an *sqlx.Rows. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + r, err := db.DB.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err +} + +// QueryRowxContext queries the database and returns an *sqlx.Row. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := db.DB.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} +} + +// MustBeginTx starts a transaction, and panics on error. Returns an *sqlx.Tx instead +// of an *sql.Tx. +// +// The provided context is used until the transaction is committed or rolled +// back. If the context is canceled, the sql package will roll back the +// transaction. Tx.Commit will return an error if the context provided to +// MustBeginContext is canceled. +func (db *DB) MustBeginTx(ctx context.Context, opts *sql.TxOptions) *Tx { + tx, err := db.BeginTxx(ctx, opts) + if err != nil { + panic(err) + } + return tx +} + +// MustExecContext (panic) runs MustExec using this database. +// Any placeholder parameters are replaced with supplied args. +func (db *DB) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { + return MustExecContext(ctx, db, query, args...) +} + +// BeginTxx begins a transaction and returns an *sqlx.Tx instead of an +// *sql.Tx. +// +// The provided context is used until the transaction is committed or rolled +// back. If the context is canceled, the sql package will roll back the +// transaction. Tx.Commit will return an error if the context provided to +// BeginxContext is canceled. +func (db *DB) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { + tx, err := db.DB.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err +} + +// StmtxContext returns a version of the prepared statement which runs within a +// transaction. Provided stmt can be either *sql.Stmt or *sqlx.Stmt. +func (tx *Tx) StmtxContext(ctx context.Context, stmt interface{}) *Stmt { + var s *sql.Stmt + switch v := stmt.(type) { + case Stmt: + s = v.Stmt + case *Stmt: + s = v.Stmt + case sql.Stmt: + s = &v + case *sql.Stmt: + s = v + default: + panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type())) + } + return &Stmt{Stmt: tx.StmtContext(ctx, s), Mapper: tx.Mapper} +} + +// NamedStmtContext returns a version of the prepared statement which runs +// within a transaction. +func (tx *Tx) NamedStmtContext(ctx context.Context, stmt *NamedStmt) *NamedStmt { + return &NamedStmt{ + QueryString: stmt.QueryString, + Params: stmt.Params, + Stmt: tx.StmtxContext(ctx, stmt.Stmt), + } +} + +// MustExecContext runs MustExecContext within a transaction. +// Any placeholder parameters are replaced with supplied args. +func (tx *Tx) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { + return MustExecContext(ctx, tx, query, args...) +} + +// SelectContext using the prepared statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) SelectContext(ctx context.Context, dest interface{}, args ...interface{}) error { + return SelectContext(ctx, &qStmt{s}, dest, "", args...) +} + +// GetContext using the prepared statement. +// Any placeholder parameters are replaced with supplied args. +// An error is returned if the result set is empty. +func (s *Stmt) GetContext(ctx context.Context, dest interface{}, args ...interface{}) error { + return GetContext(ctx, &qStmt{s}, dest, "", args...) +} + +// MustExecContext (panic) using this statement. Note that the query portion of +// the error output will be blank, as Stmt does not expose its query. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) MustExecContext(ctx context.Context, args ...interface{}) sql.Result { + return MustExecContext(ctx, &qStmt{s}, "", args...) +} + +// QueryRowxContext using this statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) QueryRowxContext(ctx context.Context, args ...interface{}) *Row { + qs := &qStmt{s} + return qs.QueryRowxContext(ctx, "", args...) +} + +// QueryxContext using this statement. +// Any placeholder parameters are replaced with supplied args. +func (s *Stmt) QueryxContext(ctx context.Context, args ...interface{}) (*Rows, error) { + qs := &qStmt{s} + return qs.QueryxContext(ctx, "", args...) +} + +func (q *qStmt) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return q.Stmt.QueryContext(ctx, args...) +} + +func (q *qStmt) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + r, err := q.Stmt.QueryContext(ctx, args...) + if err != nil { + return nil, err + } + return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err +} + +func (q *qStmt) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := q.Stmt.QueryContext(ctx, args...) + return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} +} + +func (q *qStmt) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return q.Stmt.ExecContext(ctx, args...) +} diff --git a/sqlx_context_test.go b/sqlx_context_test.go new file mode 100644 index 00000000..40ae5ca1 --- /dev/null +++ b/sqlx_context_test.go @@ -0,0 +1,1344 @@ +// +build go1.8 + +// The following environment variables, if set, will be used: +// +// * SQLX_SQLITE_DSN +// * SQLX_POSTGRES_DSN +// * SQLX_MYSQL_DSN +// +// Set any of these variables to 'skip' to skip them. Note that for MySQL, +// the string '?parseTime=True' will be appended to the DSN if it's not there +// already. +// +package sqlx + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "log" + "strings" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + "github.com/jmoiron/sqlx/reflectx" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" +) + +func MultiExecContext(ctx context.Context, e ExecerContext, query string) { + stmts := strings.Split(query, ";\n") + if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 { + stmts = stmts[:len(stmts)-1] + } + for _, s := range stmts { + _, err := e.ExecContext(ctx, s) + if err != nil { + fmt.Println(err, s) + } + } +} + +func RunWithSchemaContext(ctx context.Context, schema Schema, t *testing.T, test func(ctx context.Context, db *DB, t *testing.T)) { + runner := func(ctx context.Context, db *DB, t *testing.T, create, drop string) { + defer func() { + MultiExecContext(ctx, db, drop) + }() + + MultiExecContext(ctx, db, create) + test(ctx, db, t) + } + + if TestPostgres { + create, drop := schema.Postgres() + runner(ctx, pgdb, t, create, drop) + } + if TestSqlite { + create, drop := schema.Sqlite3() + runner(ctx, sldb, t, create, drop) + } + if TestMysql { + create, drop := schema.MySQL() + runner(ctx, mysqldb, t, create, drop) + } +} + +func loadDefaultFixtureContext(ctx context.Context, db *DB, t *testing.T) { + tx := db.MustBeginTx(ctx, nil) + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Jason", "Moiron", "jmoiron@jmoiron.net") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "John", "Doe", "johndoeDNE@gmail.net") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") + if db.DriverName() == "mysql" { + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO capplace (`COUNTRY`, `TELCODE`) VALUES (?, ?)"), "Sarf Efrica", "27") + } else { + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO capplace (\"COUNTRY\", \"TELCODE\") VALUES (?, ?)"), "Sarf Efrica", "27") + } + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id) VALUES (?, ?)"), "Peter", "4444") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Joe", "1", "4444") + tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Martin", "2", "4444") + tx.Commit() +} + +// Test a new backwards compatible feature, that missing scan destinations +// will silently scan into sql.RawText rather than failing/panicing +func TestMissingNamesContextContext(t *testing.T) { + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + type PersonPlus struct { + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Email string + //AddedAt time.Time `db:"added_at"` + } + + // test Select first + pps := []PersonPlus{} + // pps lacks added_at destination + err := db.SelectContext(ctx, &pps, "SELECT * FROM person") + if err == nil { + t.Error("Expected missing name from Select to fail, but it did not.") + } + + // test Get + pp := PersonPlus{} + err = db.GetContext(ctx, &pp, "SELECT * FROM person LIMIT 1") + if err == nil { + t.Error("Expected missing name Get to fail, but it did not.") + } + + // test naked StructScan + pps = []PersonPlus{} + rows, err := db.QueryContext(ctx, "SELECT * FROM person LIMIT 1") + if err != nil { + t.Fatal(err) + } + rows.Next() + err = StructScan(rows, &pps) + if err == nil { + t.Error("Expected missing name in StructScan to fail, but it did not.") + } + rows.Close() + + // now try various things with unsafe set. + db = db.Unsafe() + pps = []PersonPlus{} + err = db.SelectContext(ctx, &pps, "SELECT * FROM person") + if err != nil { + t.Error(err) + } + + // test Get + pp = PersonPlus{} + err = db.GetContext(ctx, &pp, "SELECT * FROM person LIMIT 1") + if err != nil { + t.Error(err) + } + + // test naked StructScan + pps = []PersonPlus{} + rowsx, err := db.QueryxContext(ctx, "SELECT * FROM person LIMIT 1") + if err != nil { + t.Fatal(err) + } + rowsx.Next() + err = StructScan(rowsx, &pps) + if err != nil { + t.Error(err) + } + rowsx.Close() + + // test Named stmt + if !isUnsafe(db) { + t.Error("Expected db to be unsafe, but it isn't") + } + nstmt, err := db.PrepareNamedContext(ctx, `SELECT * FROM person WHERE first_name != :name`) + if err != nil { + t.Fatal(err) + } + // its internal stmt should be marked unsafe + if !nstmt.Stmt.unsafe { + t.Error("expected NamedStmt to be unsafe but its underlying stmt did not inherit safety") + } + pps = []PersonPlus{} + err = nstmt.SelectContext(ctx, &pps, map[string]interface{}{"name": "Jason"}) + if err != nil { + t.Fatal(err) + } + if len(pps) != 1 { + t.Errorf("Expected 1 person back, got %d", len(pps)) + } + + // test it with a safe db + db.unsafe = false + if isUnsafe(db) { + t.Error("expected db to be safe but it isn't") + } + nstmt, err = db.PrepareNamedContext(ctx, `SELECT * FROM person WHERE first_name != :name`) + if err != nil { + t.Fatal(err) + } + // it should be safe + if isUnsafe(nstmt) { + t.Error("NamedStmt did not inherit safety") + } + nstmt.Unsafe() + if !isUnsafe(nstmt) { + t.Error("expected newly unsafed NamedStmt to be unsafe") + } + pps = []PersonPlus{} + err = nstmt.SelectContext(ctx, &pps, map[string]interface{}{"name": "Jason"}) + if err != nil { + t.Fatal(err) + } + if len(pps) != 1 { + t.Errorf("Expected 1 person back, got %d", len(pps)) + } + + }) +} + +func TestEmbeddedStructsContextContext(t *testing.T) { + type Loop1 struct{ Person } + type Loop2 struct{ Loop1 } + type Loop3 struct{ Loop2 } + + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + peopleAndPlaces := []PersonPlace{} + err := db.SelectContext( + ctx, + &peopleAndPlaces, + `SELECT person.*, place.* FROM + person natural join place`) + if err != nil { + t.Fatal(err) + } + for _, pp := range peopleAndPlaces { + if len(pp.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + if len(pp.Place.Country) == 0 { + t.Errorf("Expected non zero lengthed country.") + } + } + + // test embedded structs with StructScan + rows, err := db.QueryxContext( + ctx, + `SELECT person.*, place.* FROM + person natural join place`) + if err != nil { + t.Error(err) + } + + perp := PersonPlace{} + rows.Next() + err = rows.StructScan(&perp) + if err != nil { + t.Error(err) + } + + if len(perp.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + if len(perp.Place.Country) == 0 { + t.Errorf("Expected non zero lengthed country.") + } + + rows.Close() + + // test the same for embedded pointer structs + peopleAndPlacesPtrs := []PersonPlacePtr{} + err = db.SelectContext( + ctx, + &peopleAndPlacesPtrs, + `SELECT person.*, place.* FROM + person natural join place`) + if err != nil { + t.Fatal(err) + } + for _, pp := range peopleAndPlacesPtrs { + if len(pp.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + if len(pp.Place.Country) == 0 { + t.Errorf("Expected non zero lengthed country.") + } + } + + // test "deep nesting" + l3s := []Loop3{} + err = db.SelectContext(ctx, &l3s, `select * from person`) + if err != nil { + t.Fatal(err) + } + for _, l3 := range l3s { + if len(l3.Loop2.Loop1.Person.FirstName) == 0 { + t.Errorf("Expected non zero lengthed first name.") + } + } + + // test "embed conflicts" + ec := []EmbedConflict{} + err = db.SelectContext(ctx, &ec, `select * from person`) + // I'm torn between erroring here or having some kind of working behavior + // in order to allow for more flexibility in destination structs + if err != nil { + t.Errorf("Was not expecting an error on embed conflicts.") + } + }) +} + +func TestJoinQueryContext(t *testing.T) { + type Employee struct { + Name string + ID int64 + // BossID is an id into the employee table + BossID sql.NullInt64 `db:"boss_id"` + } + type Boss Employee + + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + + var employees []struct { + Employee + Boss `db:"boss"` + } + + err := db.SelectContext(ctx, + &employees, + `SELECT employees.*, boss.id "boss.id", boss.name "boss.name" FROM employees + JOIN employees AS boss ON employees.boss_id = boss.id`) + if err != nil { + t.Fatal(err) + } + + for _, em := range employees { + if len(em.Employee.Name) == 0 { + t.Errorf("Expected non zero lengthed name.") + } + if em.Employee.BossID.Int64 != em.Boss.ID { + t.Errorf("Expected boss ids to match") + } + } + }) +} + +func TestJoinQueryNamedPointerStructsContext(t *testing.T) { + type Employee struct { + Name string + ID int64 + // BossID is an id into the employee table + BossID sql.NullInt64 `db:"boss_id"` + } + type Boss Employee + + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + + var employees []struct { + Emp1 *Employee `db:"emp1"` + Emp2 *Employee `db:"emp2"` + *Boss `db:"boss"` + } + + err := db.SelectContext(ctx, + &employees, + `SELECT emp.name "emp1.name", emp.id "emp1.id", emp.boss_id "emp1.boss_id", + emp.name "emp2.name", emp.id "emp2.id", emp.boss_id "emp2.boss_id", + boss.id "boss.id", boss.name "boss.name" FROM employees AS emp + JOIN employees AS boss ON emp.boss_id = boss.id + `) + if err != nil { + t.Fatal(err) + } + + for _, em := range employees { + if len(em.Emp1.Name) == 0 || len(em.Emp2.Name) == 0 { + t.Errorf("Expected non zero lengthed name.") + } + if em.Emp1.BossID.Int64 != em.Boss.ID || em.Emp2.BossID.Int64 != em.Boss.ID { + t.Errorf("Expected boss ids to match") + } + } + }) +} + +func TestSelectSliceMapTimeContext(t *testing.T) { + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + rows, err := db.QueryxContext(ctx, "SELECT * FROM person") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + _, err := rows.SliceScan() + if err != nil { + t.Error(err) + } + } + + rows, err = db.QueryxContext(ctx, "SELECT * FROM person") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + m := map[string]interface{}{} + err := rows.MapScan(m) + if err != nil { + t.Error(err) + } + } + + }) +} + +func TestNilReceiverContext(t *testing.T) { + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + var p *Person + err := db.GetContext(ctx, p, "SELECT * FROM person LIMIT 1") + if err == nil { + t.Error("Expected error when getting into nil struct ptr.") + } + var pp *[]Person + err = db.SelectContext(ctx, pp, "SELECT * FROM person") + if err == nil { + t.Error("Expected an error when selecting into nil slice ptr.") + } + }) +} + +func TestNamedQueryContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE place ( + id integer PRIMARY KEY, + name text NULL + ); + CREATE TABLE person ( + first_name text NULL, + last_name text NULL, + email text NULL + ); + CREATE TABLE placeperson ( + first_name text NULL, + last_name text NULL, + email text NULL, + place_id integer NULL + ); + CREATE TABLE jsperson ( + "FIRST" text NULL, + last_name text NULL, + "EMAIL" text NULL + );`, + drop: ` + drop table person; + drop table jsperson; + drop table place; + drop table placeperson; + `, + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + type Person struct { + FirstName sql.NullString `db:"first_name"` + LastName sql.NullString `db:"last_name"` + Email sql.NullString + } + + p := Person{ + FirstName: sql.NullString{String: "ben", Valid: true}, + LastName: sql.NullString{String: "doe", Valid: true}, + Email: sql.NullString{String: "ben@doe.com", Valid: true}, + } + + q1 := `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)` + _, err := db.NamedExecContext(ctx, q1, p) + if err != nil { + log.Fatal(err) + } + + p2 := &Person{} + rows, err := db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", p) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(p2) + if err != nil { + t.Error(err) + } + if p2.FirstName.String != "ben" { + t.Error("Expected first name of `ben`, got " + p2.FirstName.String) + } + if p2.LastName.String != "doe" { + t.Error("Expected first name of `doe`, got " + p2.LastName.String) + } + } + + // these are tests for #73; they verify that named queries work if you've + // changed the db mapper. This code checks both NamedQuery "ad-hoc" style + // queries and NamedStmt queries, which use different code paths internally. + old := *db.Mapper + + type JSONPerson struct { + FirstName sql.NullString `json:"FIRST"` + LastName sql.NullString `json:"last_name"` + Email sql.NullString + } + + jp := JSONPerson{ + FirstName: sql.NullString{String: "ben", Valid: true}, + LastName: sql.NullString{String: "smith", Valid: true}, + Email: sql.NullString{String: "ben@smith.com", Valid: true}, + } + + db.Mapper = reflectx.NewMapperFunc("json", strings.ToUpper) + + // prepare queries for case sensitivity to test our ToUpper function. + // postgres and sqlite accept "", but mysql uses ``; since Go's multi-line + // strings are `` we use "" by default and swap out for MySQL + pdb := func(s string, db *DB) string { + if db.DriverName() == "mysql" { + return strings.Replace(s, `"`, "`", -1) + } + return s + } + + q1 = `INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)` + _, err = db.NamedExecContext(ctx, pdb(q1, db), jp) + if err != nil { + t.Fatal(err, db.DriverName()) + } + + // Checks that a person pulled out of the db matches the one we put in + check := func(t *testing.T, rows *Rows) { + jp = JSONPerson{} + for rows.Next() { + err = rows.StructScan(&jp) + if err != nil { + t.Error(err) + } + if jp.FirstName.String != "ben" { + t.Errorf("Expected first name of `ben`, got `%s` (%s) ", jp.FirstName.String, db.DriverName()) + } + if jp.LastName.String != "smith" { + t.Errorf("Expected LastName of `smith`, got `%s` (%s)", jp.LastName.String, db.DriverName()) + } + if jp.Email.String != "ben@smith.com" { + t.Errorf("Expected first name of `doe`, got `%s` (%s)", jp.Email.String, db.DriverName()) + } + } + } + + ns, err := db.PrepareNamed(pdb(` + SELECT * FROM jsperson + WHERE + "FIRST"=:FIRST AND + last_name=:last_name AND + "EMAIL"=:EMAIL + `, db)) + + if err != nil { + t.Fatal(err) + } + rows, err = ns.QueryxContext(ctx, jp) + if err != nil { + t.Fatal(err) + } + + check(t, rows) + + // Check exactly the same thing, but with db.NamedQuery, which does not go + // through the PrepareNamed/NamedStmt path. + rows, err = db.NamedQueryContext(ctx, pdb(` + SELECT * FROM jsperson + WHERE + "FIRST"=:FIRST AND + last_name=:last_name AND + "EMAIL"=:EMAIL + `, db), jp) + if err != nil { + t.Fatal(err) + } + + check(t, rows) + + db.Mapper = &old + + // Test nested structs + type Place struct { + ID int `db:"id"` + Name sql.NullString `db:"name"` + } + type PlacePerson struct { + FirstName sql.NullString `db:"first_name"` + LastName sql.NullString `db:"last_name"` + Email sql.NullString + Place Place `db:"place"` + } + + pl := Place{ + Name: sql.NullString{String: "myplace", Valid: true}, + } + + pp := PlacePerson{ + FirstName: sql.NullString{String: "ben", Valid: true}, + LastName: sql.NullString{String: "doe", Valid: true}, + Email: sql.NullString{String: "ben@doe.com", Valid: true}, + } + + q2 := `INSERT INTO place (id, name) VALUES (1, :name)` + _, err = db.NamedExecContext(ctx, q2, pl) + if err != nil { + log.Fatal(err) + } + + id := 1 + pp.Place.ID = id + + q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` + _, err = db.NamedExecContext(ctx, q3, pp) + if err != nil { + log.Fatal(err) + } + + pp2 := &PlacePerson{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + first_name, + last_name, + email, + place.id AS "place.id", + place.name AS "place.name" + FROM placeperson + INNER JOIN place ON place.id = placeperson.place_id + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp2) + if err != nil { + t.Error(err) + } + if pp2.FirstName.String != "ben" { + t.Error("Expected first name of `ben`, got " + pp2.FirstName.String) + } + if pp2.LastName.String != "doe" { + t.Error("Expected first name of `doe`, got " + pp2.LastName.String) + } + if pp2.Place.Name.String != "myplace" { + t.Error("Expected place name of `myplace`, got " + pp2.Place.Name.String) + } + if pp2.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) + } + } + }) +} + +func TestNilInsertsContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE tt ( + id integer, + value text NULL DEFAULT NULL + );`, + drop: "drop table tt;", + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + type TT struct { + ID int + Value *string + } + var v, v2 TT + r := db.Rebind + + db.MustExecContext(ctx, r(`INSERT INTO tt (id) VALUES (1)`)) + db.GetContext(ctx, &v, r(`SELECT * FROM tt`)) + if v.ID != 1 { + t.Errorf("Expecting id of 1, got %v", v.ID) + } + if v.Value != nil { + t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) + } + + v.ID = 2 + // NOTE: this incidentally uncovered a bug which was that named queries with + // pointer destinations would not work if the passed value here was not addressable, + // as reflectx.FieldByIndexes attempts to allocate nil pointer receivers for + // writing. This was fixed by creating & using the reflectx.FieldByIndexesReadOnly + // function. This next line is important as it provides the only coverage for this. + db.NamedExecContext(ctx, `INSERT INTO tt (id, value) VALUES (:id, :value)`, v) + + db.GetContext(ctx, &v2, r(`SELECT * FROM tt WHERE id=2`)) + if v.ID != v2.ID { + t.Errorf("%v != %v", v.ID, v2.ID) + } + if v2.Value != nil { + t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) + } + }) +} + +func TestScanErrorContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE kv ( + k text, + v integer + );`, + drop: `drop table kv;`, + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + type WrongTypes struct { + K int + V string + } + _, err := db.Exec(db.Rebind("INSERT INTO kv (k, v) VALUES (?, ?)"), "hi", 1) + if err != nil { + t.Error(err) + } + + rows, err := db.QueryxContext(ctx, "SELECT * FROM kv") + if err != nil { + t.Error(err) + } + for rows.Next() { + var wt WrongTypes + err := rows.StructScan(&wt) + if err == nil { + t.Errorf("%s: Scanning wrong types into keys should have errored.", db.DriverName()) + } + } + }) +} + +// FIXME: this function is kinda big but it slows things down to be constantly +// loading and reloading the schema.. + +func TestUsageContext(t *testing.T) { + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + slicemembers := []SliceMember{} + err := db.SelectContext(ctx, &slicemembers, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + + people := []Person{} + + err = db.SelectContext(ctx, &people, "SELECT * FROM person ORDER BY first_name ASC") + if err != nil { + t.Fatal(err) + } + + jason, john := people[0], people[1] + if jason.FirstName != "Jason" { + t.Errorf("Expecting FirstName of Jason, got %s", jason.FirstName) + } + if jason.LastName != "Moiron" { + t.Errorf("Expecting LastName of Moiron, got %s", jason.LastName) + } + if jason.Email != "jmoiron@jmoiron.net" { + t.Errorf("Expecting Email of jmoiron@jmoiron.net, got %s", jason.Email) + } + if john.FirstName != "John" || john.LastName != "Doe" || john.Email != "johndoeDNE@gmail.net" { + t.Errorf("John Doe's person record not what expected: Got %v\n", john) + } + + jason = Person{} + err = db.GetContext(ctx, &jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Jason") + + if err != nil { + t.Fatal(err) + } + if jason.FirstName != "Jason" { + t.Errorf("Expecting to get back Jason, but got %v\n", jason.FirstName) + } + + err = db.GetContext(ctx, &jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Foobar") + if err == nil { + t.Errorf("Expecting an error, got nil\n") + } + if err != sql.ErrNoRows { + t.Errorf("Expected sql.ErrNoRows, got %v\n", err) + } + + // The following tests check statement reuse, which was actually a problem + // due to copying being done when creating Stmt's which was eventually removed + stmt1, err := db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Fatal(err) + } + jason = Person{} + + row := stmt1.QueryRowx("DoesNotExist") + row.Scan(&jason) + row = stmt1.QueryRowx("DoesNotExist") + row.Scan(&jason) + + err = stmt1.GetContext(ctx, &jason, "DoesNotExist User") + if err == nil { + t.Error("Expected an error") + } + err = stmt1.GetContext(ctx, &jason, "DoesNotExist User 2") + if err != nil { + t.Fatal(err) + } + + stmt2, err := db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Fatal(err) + } + jason = Person{} + tx, err := db.Beginx() + if err != nil { + t.Fatal(err) + } + tstmt2 := tx.Stmtx(stmt2) + row2 := tstmt2.QueryRowx("Jason") + err = row2.StructScan(&jason) + if err != nil { + t.Error(err) + } + tx.Commit() + + places := []*Place{} + err = db.SelectContext(ctx, &places, "SELECT telcode FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + + usa, singsing, honkers := places[0], places[1], places[2] + + if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { + t.Errorf("Expected integer telcodes to work, got %#v", places) + } + + placesptr := []PlacePtr{} + err = db.SelectContext(ctx, &placesptr, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + t.Error(err) + } + //fmt.Printf("%#v\n%#v\n%#v\n", placesptr[0], placesptr[1], placesptr[2]) + + // if you have null fields and use SELECT *, you must use sql.Null* in your struct + // this test also verifies that you can use either a []Struct{} or a []*Struct{} + places2 := []Place{} + err = db.SelectContext(ctx, &places2, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + + usa, singsing, honkers = &places2[0], &places2[1], &places2[2] + + // this should return a type error that &p is not a pointer to a struct slice + p := Place{} + err = db.SelectContext(ctx, &p, "SELECT * FROM place ORDER BY telcode ASC") + if err == nil { + t.Errorf("Expected an error, argument to select should be a pointer to a struct slice") + } + + // this should be an error + pl := []Place{} + err = db.SelectContext(ctx, pl, "SELECT * FROM place ORDER BY telcode ASC") + if err == nil { + t.Errorf("Expected an error, argument to select should be a pointer to a struct slice, not a slice.") + } + + if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { + t.Errorf("Expected integer telcodes to work, got %#v", places) + } + + stmt, err := db.PreparexContext(ctx, db.Rebind("SELECT country, telcode FROM place WHERE telcode > ? ORDER BY telcode ASC")) + if err != nil { + t.Error(err) + } + + places = []*Place{} + err = stmt.SelectContext(ctx, &places, 10) + if len(places) != 2 { + t.Error("Expected 2 places, got 0.") + } + if err != nil { + t.Fatal(err) + } + singsing, honkers = places[0], places[1] + if singsing.TelCode != 65 || honkers.TelCode != 852 { + t.Errorf("Expected the right telcodes, got %#v", places) + } + + rows, err := db.QueryxContext(ctx, "SELECT * FROM place") + if err != nil { + t.Fatal(err) + } + place := Place{} + for rows.Next() { + err = rows.StructScan(&place) + if err != nil { + t.Fatal(err) + } + } + + rows, err = db.QueryxContext(ctx, "SELECT * FROM place") + if err != nil { + t.Fatal(err) + } + m := map[string]interface{}{} + for rows.Next() { + err = rows.MapScan(m) + if err != nil { + t.Fatal(err) + } + _, ok := m["country"] + if !ok { + t.Errorf("Expected key `country` in map but could not find it (%#v)\n", m) + } + } + + rows, err = db.QueryxContext(ctx, "SELECT * FROM place") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + s, err := rows.SliceScan() + if err != nil { + t.Error(err) + } + if len(s) != 3 { + t.Errorf("Expected 3 columns in result, got %d\n", len(s)) + } + } + + // test advanced querying + // test that NamedExec works with a map as well as a struct + _, err = db.NamedExecContext(ctx, "INSERT INTO person (first_name, last_name, email) VALUES (:first, :last, :email)", map[string]interface{}{ + "first": "Bin", + "last": "Smuth", + "email": "bensmith@allblacks.nz", + }) + if err != nil { + t.Fatal(err) + } + + // ensure that if the named param happens right at the end it still works + // ensure that NamedQuery works with a map[string]interface{} + rows, err = db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first", map[string]interface{}{"first": "Bin"}) + if err != nil { + t.Fatal(err) + } + + ben := &Person{} + for rows.Next() { + err = rows.StructScan(ben) + if err != nil { + t.Fatal(err) + } + if ben.FirstName != "Bin" { + t.Fatal("Expected first name of `Bin`, got " + ben.FirstName) + } + if ben.LastName != "Smuth" { + t.Fatal("Expected first name of `Smuth`, got " + ben.LastName) + } + } + + ben.FirstName = "Ben" + ben.LastName = "Smith" + ben.Email = "binsmuth@allblacks.nz" + + // Insert via a named query using the struct + _, err = db.NamedExecContext(ctx, "INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", ben) + + if err != nil { + t.Fatal(err) + } + + rows, err = db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", ben) + if err != nil { + t.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(ben) + if err != nil { + t.Fatal(err) + } + if ben.FirstName != "Ben" { + t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) + } + if ben.LastName != "Smith" { + t.Fatal("Expected first name of `Smith`, got " + ben.LastName) + } + } + // ensure that Get does not panic on emppty result set + person := &Person{} + err = db.GetContext(ctx, person, "SELECT * FROM person WHERE first_name=$1", "does-not-exist") + if err == nil { + t.Fatal("Should have got an error for Get on non-existant row.") + } + + // lets test prepared statements some more + + stmt, err = db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Fatal(err) + } + rows, err = stmt.QueryxContext(ctx, "Ben") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(ben) + if err != nil { + t.Fatal(err) + } + if ben.FirstName != "Ben" { + t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) + } + if ben.LastName != "Smith" { + t.Fatal("Expected first name of `Smith`, got " + ben.LastName) + } + } + + john = Person{} + stmt, err = db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) + if err != nil { + t.Error(err) + } + err = stmt.GetContext(ctx, &john, "John") + if err != nil { + t.Error(err) + } + + // test name mapping + // THIS USED TO WORK BUT WILL NO LONGER WORK. + db.MapperFunc(strings.ToUpper) + rsa := CPlace{} + err = db.GetContext(ctx, &rsa, "SELECT * FROM capplace;") + if err != nil { + t.Error(err, "in db:", db.DriverName()) + } + db.MapperFunc(strings.ToLower) + + // create a copy and change the mapper, then verify the copy behaves + // differently from the original. + dbCopy := NewDb(db.DB, db.DriverName()) + dbCopy.MapperFunc(strings.ToUpper) + err = dbCopy.GetContext(ctx, &rsa, "SELECT * FROM capplace;") + if err != nil { + fmt.Println(db.DriverName()) + t.Error(err) + } + + err = db.GetContext(ctx, &rsa, "SELECT * FROM cappplace;") + if err == nil { + t.Error("Expected no error, got ", err) + } + + // test base type slices + var sdest []string + rows, err = db.QueryxContext(ctx, "SELECT email FROM person ORDER BY email ASC;") + if err != nil { + t.Error(err) + } + err = scanAll(rows, &sdest, false) + if err != nil { + t.Error(err) + } + + // test Get with base types + var count int + err = db.GetContext(ctx, &count, "SELECT count(*) FROM person;") + if err != nil { + t.Error(err) + } + if count != len(sdest) { + t.Errorf("Expected %d == %d (count(*) vs len(SELECT ..)", count, len(sdest)) + } + + // test Get and Select with time.Time, #84 + var addedAt time.Time + err = db.GetContext(ctx, &addedAt, "SELECT added_at FROM person LIMIT 1;") + if err != nil { + t.Error(err) + } + + var addedAts []time.Time + err = db.SelectContext(ctx, &addedAts, "SELECT added_at FROM person;") + if err != nil { + t.Error(err) + } + + // test it on a double pointer + var pcount *int + err = db.GetContext(ctx, &pcount, "SELECT count(*) FROM person;") + if err != nil { + t.Error(err) + } + if *pcount != count { + t.Errorf("expected %d = %d", *pcount, count) + } + + // test Select... + sdest = []string{} + err = db.SelectContext(ctx, &sdest, "SELECT first_name FROM person ORDER BY first_name ASC;") + if err != nil { + t.Error(err) + } + expected := []string{"Ben", "Bin", "Jason", "John"} + for i, got := range sdest { + if got != expected[i] { + t.Errorf("Expected %d result to be %s, but got %s", i, expected[i], got) + } + } + + var nsdest []sql.NullString + err = db.SelectContext(ctx, &nsdest, "SELECT city FROM place ORDER BY city ASC") + if err != nil { + t.Error(err) + } + for _, val := range nsdest { + if val.Valid && val.String != "New York" { + t.Errorf("expected single valid result to be `New York`, but got %s", val.String) + } + } + }) +} + +// tests that sqlx will not panic when the wrong driver is passed because +// of an automatic nil dereference in sqlx.Open(), which was fixed. +func TestDoNotPanicOnConnectContext(t *testing.T) { + _, err := ConnectContext(context.Background(), "bogus", "hehe") + if err == nil { + t.Errorf("Should return error when using bogus driverName") + } +} + +func TestEmbeddedMapsContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE message ( + string text, + properties text + );`, + drop: `drop table message;`, + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + messages := []Message{ + {"Hello, World", PropertyMap{"one": "1", "two": "2"}}, + {"Thanks, Joy", PropertyMap{"pull": "request"}}, + } + q1 := `INSERT INTO message (string, properties) VALUES (:string, :properties);` + for _, m := range messages { + _, err := db.NamedExecContext(ctx, q1, m) + if err != nil { + t.Fatal(err) + } + } + var count int + err := db.GetContext(ctx, &count, "SELECT count(*) FROM message") + if err != nil { + t.Fatal(err) + } + if count != len(messages) { + t.Fatalf("Expected %d messages in DB, found %d", len(messages), count) + } + + var m Message + err = db.GetContext(ctx, &m, "SELECT * FROM message LIMIT 1;") + if err != nil { + t.Fatal(err) + } + if m.Properties == nil { + t.Fatal("Expected m.Properties to not be nil, but it was.") + } + }) +} + +func TestIssue197Context(t *testing.T) { + // this test actually tests for a bug in database/sql: + // https://github.com/golang/go/issues/13905 + // this potentially makes _any_ named type that is an alias for []byte + // unsafe to use in a lot of different ways (basically, unsafe to hold + // onto after loading from the database). + t.Skip() + + type mybyte []byte + type Var struct{ Raw json.RawMessage } + type Var2 struct{ Raw []byte } + type Var3 struct{ Raw mybyte } + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + var err error + var v, q Var + if err = db.GetContext(ctx, &v, `SELECT '{"a": "b"}' AS raw`); err != nil { + t.Fatal(err) + } + if err = db.GetContext(ctx, &q, `SELECT 'null' AS raw`); err != nil { + t.Fatal(err) + } + + var v2, q2 Var2 + if err = db.GetContext(ctx, &v2, `SELECT '{"a": "b"}' AS raw`); err != nil { + t.Fatal(err) + } + if err = db.GetContext(ctx, &q2, `SELECT 'null' AS raw`); err != nil { + t.Fatal(err) + } + + var v3, q3 Var3 + if err = db.QueryRowContext(ctx, `SELECT '{"a": "b"}' AS raw`).Scan(&v3.Raw); err != nil { + t.Fatal(err) + } + if err = db.QueryRowContext(ctx, `SELECT '{"c": "d"}' AS raw`).Scan(&q3.Raw); err != nil { + t.Fatal(err) + } + t.Fail() + }) +} + +func TestInContext(t *testing.T) { + // some quite normal situations + type tr struct { + q string + args []interface{} + c int + } + tests := []tr{ + {"SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?", + []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}, + 7}, + {"SELECT * FROM foo WHERE x in (?)", + []interface{}{[]int{1, 2, 3, 4, 5, 6, 7, 8}}, + 8}, + } + for _, test := range tests { + q, a, err := In(test.q, test.args...) + if err != nil { + t.Error(err) + } + if len(a) != test.c { + t.Errorf("Expected %d args, but got %d (%+v)", test.c, len(a), a) + } + if strings.Count(q, "?") != test.c { + t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q, "?")) + } + } + + // too many bindVars, but no slices, so short circuits parsing + // i'm not sure if this is the right behavior; this query/arg combo + // might not work, but we shouldn't parse if we don't need to + { + orig := "SELECT * FROM foo WHERE x = ? AND y = ?" + q, a, err := In(orig, "foo", "bar", "baz") + if err != nil { + t.Error(err) + } + if len(a) != 3 { + t.Errorf("Expected 3 args, but got %d (%+v)", len(a), a) + } + if q != orig { + t.Error("Expected unchanged query.") + } + } + + tests = []tr{ + // too many bindvars; slice present so should return error during parse + {"SELECT * FROM foo WHERE x = ? and y = ?", + []interface{}{"foo", []int{1, 2, 3}, "bar"}, + 0}, + // empty slice, should return error before parse + {"SELECT * FROM foo WHERE x = ?", + []interface{}{[]int{}}, + 0}, + // too *few* bindvars, should return an error + {"SELECT * FROM foo WHERE x = ? AND y in (?)", + []interface{}{[]int{1, 2, 3}}, + 0}, + } + for _, test := range tests { + _, _, err := In(test.q, test.args...) + if err == nil { + t.Error("Expected an error, but got nil.") + } + } + RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { + loadDefaultFixtureContext(ctx, db, t) + //tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") + //tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") + //tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") + telcodes := []int{852, 65} + q := "SELECT * FROM place WHERE telcode IN(?) ORDER BY telcode" + query, args, err := In(q, telcodes) + if err != nil { + t.Error(err) + } + query = db.Rebind(query) + places := []Place{} + err = db.SelectContext(ctx, &places, query, args...) + if err != nil { + t.Error(err) + } + if len(places) != 2 { + t.Fatalf("Expecting 2 results, got %d", len(places)) + } + if places[0].TelCode != 65 { + t.Errorf("Expecting singapore first, but got %#v", places[0]) + } + if places[1].TelCode != 852 { + t.Errorf("Expecting hong kong second, but got %#v", places[1]) + } + }) +} + +func TestEmbeddedLiteralsContext(t *testing.T) { + var schema = Schema{ + create: ` + CREATE TABLE x ( + k text + );`, + drop: `drop table x;`, + } + + RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { + type t1 struct { + K *string + } + type t2 struct { + Inline struct { + F string + } + K *string + } + + db.MustExecContext(ctx, db.Rebind("INSERT INTO x (k) VALUES (?), (?), (?);"), "one", "two", "three") + + target := t1{} + err := db.GetContext(ctx, &target, db.Rebind("SELECT * FROM x WHERE k=?"), "one") + if err != nil { + t.Error(err) + } + if *target.K != "one" { + t.Error("Expected target.K to be `one`, got ", target.K) + } + + target2 := t2{} + err = db.GetContext(ctx, &target2, db.Rebind("SELECT * FROM x WHERE k=?"), "one") + if err != nil { + t.Error(err) + } + if *target2.K != "one" { + t.Errorf("Expected target2.K to be `one`, got `%v`", target2.K) + } + }) +} diff --git a/sqlx_test.go b/sqlx_test.go index 436d34b3..f87769db 100644 --- a/sqlx_test.go +++ b/sqlx_test.go @@ -791,7 +791,7 @@ func TestNamedQuery(t *testing.T) { email, place.id AS "place.id", place.name AS "place.name" - FROM placeperson + FROM placeperson INNER JOIN place ON place.id = placeperson.place_id WHERE place.id=:place.id`, pp) @@ -967,6 +967,9 @@ func TestUsage(t *testing.T) { t.Error("Expected an error") } err = stmt1.Get(&jason, "DoesNotExist User 2") + if err != nil { + t.Fatal(err) + } stmt2, err := db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) if err != nil { @@ -987,6 +990,10 @@ func TestUsage(t *testing.T) { places := []*Place{} err = db.Select(&places, "SELECT telcode FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + usa, singsing, honkers := places[0], places[1], places[2] if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { @@ -1004,6 +1011,10 @@ func TestUsage(t *testing.T) { // this test also verifies that you can use either a []Struct{} or a []*Struct{} places2 := []Place{} err = db.Select(&places2, "SELECT * FROM place ORDER BY telcode ASC") + if err != nil { + t.Fatal(err) + } + usa, singsing, honkers = &places2[0], &places2[1], &places2[2] // this should return a type error that &p is not a pointer to a struct slice