diff --git a/buf.go b/buf.go index 9f417a1e..608150a3 100644 --- a/buf.go +++ b/buf.go @@ -48,6 +48,12 @@ func (b *readBuf) byte() byte { type writeBuf []byte +func newWriteBuf(c byte) *writeBuf { + b := make(writeBuf, 5) + b[0] = c + return &b +} + func (b *writeBuf) int32(n int) { x := make([]byte, 4) binary.BigEndian.PutUint32(x, uint32(n)) diff --git a/conn.go b/conn.go index 6e8a5751..13d09c3c 100644 --- a/conn.go +++ b/conn.go @@ -35,16 +35,10 @@ func init() { } type conn struct { - c net.Conn - buf *bufio.Reader - namei int - scratch [512]byte -} - -func (c *conn) writeBuf(b byte) *writeBuf { - c.scratch[0] = b - w := writeBuf(c.scratch[:5]) - return &w + c net.Conn + buf *bufio.Reader + namei int + singlerowmode bool } func Open(name string) (_ driver.Conn, err error) { @@ -87,6 +81,9 @@ func Open(name string) (_ driver.Conn, err error) { cn := &conn{c: c} cn.ssl(o) cn.buf = bufio.NewReader(cn.c) + if o.Get("singlerowmode") == "true" { + cn.singlerowmode = true + } cn.startup(o) return cn, nil } @@ -155,7 +152,7 @@ func (cn *conn) gname() string { func (cn *conn) simpleQuery(q string) (res driver.Result, err error) { defer errRecover(&err) - b := cn.writeBuf('Q') + b := newWriteBuf('Q') b.string(q) cn.send(b) @@ -183,18 +180,18 @@ func (cn *conn) prepareTo(q, stmtName string) (_ driver.Stmt, err error) { st := &stmt{cn: cn, name: stmtName, query: q} - b := cn.writeBuf('P') + b := newWriteBuf('P') b.string(st.name) b.string(q) b.int16(0) cn.send(b) - b = cn.writeBuf('D') + b = newWriteBuf('D') b.byte('S') b.string(st.name) cn.send(b) - cn.send(cn.writeBuf('S')) + cn.send(newWriteBuf('S')) for { t, r := cn.recv1() @@ -240,7 +237,7 @@ func (cn *conn) Prepare(q string) (driver.Stmt, error) { func (cn *conn) Close() (err error) { defer errRecover(&err) - cn.send(cn.writeBuf('X')) + cn.send(newWriteBuf('X')) return cn.c.Close() } @@ -303,27 +300,20 @@ func (cn *conn) recv() (t byte, r *readBuf) { } func (cn *conn) recv1() (byte, *readBuf) { - x := cn.scratch[:5] + x := make([]byte, 5) _, err := io.ReadFull(cn.buf, x) if err != nil { panic(err) } - c := x[0] b := readBuf(x[1:]) - n := b.int32() - 4 - var y []byte - if n <= len(cn.scratch) { - y = cn.scratch[:n] - } else { - y = make([]byte, n) - } + y := make([]byte, b.int32()-4) _, err = io.ReadFull(cn.buf, y) if err != nil { panic(err) } - return c, (*readBuf)(&y) + return x[0], (*readBuf)(&y) } func (cn *conn) ssl(o Values) { @@ -339,11 +329,11 @@ func (cn *conn) ssl(o Values) { errorf(`unsupported sslmode %q; only "require" (default), "verify-full", and "disable" supported`, mode) } - w := cn.writeBuf(0) + w := newWriteBuf(0) w.int32(80877103) cn.send(w) - b := cn.scratch[:1] + b := make([]byte, 1) _, err := io.ReadFull(cn.c, b) if err != nil { panic(err) @@ -357,7 +347,7 @@ func (cn *conn) ssl(o Values) { } func (cn *conn) startup(o Values) { - w := cn.writeBuf(0) + w := newWriteBuf(0) w.int32(196608) w.string("user") w.string(o.Get("user")) @@ -385,7 +375,7 @@ func (cn *conn) auth(r *readBuf, o Values) { case 0: // OK case 3: - w := cn.writeBuf('p') + w := newWriteBuf('p') w.string(o.Get("password")) cn.send(w) @@ -399,7 +389,7 @@ func (cn *conn) auth(r *readBuf, o Values) { } case 5: s := string(r.next(4)) - w := cn.writeBuf('p') + w := newWriteBuf('p') w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s)) cn.send(w) @@ -433,12 +423,12 @@ func (st *stmt) Close() (err error) { defer errRecover(&err) - w := st.cn.writeBuf('C') + w := newWriteBuf('C') w.byte('S') w.string(st.name) st.cn.send(w) - st.cn.send(st.cn.writeBuf('S')) + st.cn.send(newWriteBuf('S')) t, _ := st.cn.recv() if t != '3' { @@ -457,7 +447,15 @@ func (st *stmt) Close() (err error) { func (st *stmt) Query(v []driver.Value) (_ driver.Rows, err error) { defer errRecover(&err) st.exec(v) - return &rows{st: st}, nil + r := &rows{st: st} + if st.cn.singlerowmode { + return r, nil + } + + // fetch all rows + rc := &rowscomplete{} + err = rc.load(r) + return rc, err } func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { @@ -489,7 +487,7 @@ func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { } func (st *stmt) exec(v []driver.Value) { - w := st.cn.writeBuf('B') + w := newWriteBuf('B') w.string("") w.string(st.name) w.int16(0) @@ -506,12 +504,12 @@ func (st *stmt) exec(v []driver.Value) { w.int16(0) st.cn.send(w) - w = st.cn.writeBuf('E') + w = newWriteBuf('E') w.string("") w.int32(0) st.cn.send(w) - st.cn.send(st.cn.writeBuf('S')) + st.cn.send(newWriteBuf('S')) var err error for { @@ -612,6 +610,62 @@ func (rs *rows) Next(dest []driver.Value) (err error) { panic("not reached") } +type rowscomplete struct { + rows [][]driver.Value + current int + done bool + cols []string +} + +func (rs *rowscomplete) load(r *rows) (err error) { + defer r.Close() + + rs.rows = make([][]driver.Value, 0) + rs.cols = r.st.cols + rs.current = -1 + + // fetch all records + for { + dest := make([]driver.Value, len(rs.cols)) + if err = r.Next(dest); err != nil { + break + } + rs.rows = append(rs.rows, dest) + } + + if err == io.EOF { + return nil + } + + return +} + +func (rs *rowscomplete) Close() error { + return nil +} + +func (rs *rowscomplete) Columns() []string { + return rs.cols +} + +func (rs *rowscomplete) Next(dest []driver.Value) (err error) { + if rs.done { + return io.EOF + } + + rs.current++ + if rs.current > len(rs.rows)-1 { + rs.done = true + return io.EOF + } + + for i, v := range rs.rows[rs.current] { + dest[i] = v + } + + return nil +} + func md5s(s string) string { h := md5.New() h.Write([]byte(s))