From e281d84a44dd3d75e73c9267f4b9c0e9d74b58d8 Mon Sep 17 00:00:00 2001 From: Justin Li Date: Fri, 20 Oct 2017 15:27:00 -0400 Subject: [PATCH] Implement NamedValueChecker for mysqlConn * Also add conversions for additional types in ConvertValue ref https://github.com/golang/go/commit/d7c0de98a96893e5608358f7578c85be7ba12b25 --- AUTHORS | 1 + connection_go18.go | 5 ++ connection_go18_test.go | 30 ++++++++++ statement.go | 8 +++ statement_test.go | 119 +++++++++++++++++++++++++++++++++++++--- 5 files changed, 156 insertions(+), 7 deletions(-) create mode 100644 connection_go18_test.go diff --git a/AUTHORS b/AUTHORS index ca6a1daeb..c405b8912 100644 --- a/AUTHORS +++ b/AUTHORS @@ -40,6 +40,7 @@ Jian Zhen Joshua Prunier Julien Lefevre Julien Schmidt +Justin Li Justin Nuß Kamil Dziedzic Kevin Malachowski diff --git a/connection_go18.go b/connection_go18.go index 48a9cca64..1306b70b7 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -195,3 +195,8 @@ func (mc *mysqlConn) startWatcher() { } }() } + +func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} diff --git a/connection_go18_test.go b/connection_go18_test.go new file mode 100644 index 000000000..2719ab3b7 --- /dev/null +++ b/connection_go18_test.go @@ -0,0 +1,30 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.8 + +package mysql + +import ( + "database/sql/driver" + "testing" +) + +func TestCheckNamedValue(t *testing.T) { + value := driver.NamedValue{Value: ^uint64(0)} + x := &mysqlConn{} + err := x.CheckNamedValue(&value) + + if err != nil { + t.Fatal("uint64 high-bit not convertible", err) + } + + if value.Value != "18446744073709551615" { + t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value) + } +} diff --git a/statement.go b/statement.go index 628174b64..4870a307c 100644 --- a/statement.go +++ b/statement.go @@ -157,6 +157,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { return int64(u64), nil case reflect.Float32, reflect.Float64: return rv.Float(), nil + case reflect.Bool: + return rv.Bool(), nil + case reflect.Slice: + ek := rv.Type().Elem().Kind() + if ek == reflect.Uint8 { + return rv.Bytes(), nil + } + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) case reflect.String: return rv.String(), nil } diff --git a/statement_test.go b/statement_test.go index 8de4a8b26..98a6c1933 100644 --- a/statement_test.go +++ b/statement_test.go @@ -8,14 +8,119 @@ package mysql -import "testing" +import ( + "bytes" + "testing" +) -type customString string +func TestConvertDerivedString(t *testing.T) { + type derived string -func TestConvertValueCustomTypes(t *testing.T) { - var cstr customString = "string" - c := converter{} - if _, err := c.ConvertValue(cstr); err != nil { - t.Errorf("custom string type should be valid") + output, err := converter{}.ConvertValue(derived("value")) + if err != nil { + t.Fatal("Derived string type not convertible", err) + } + + if output != "value" { + t.Fatalf("Derived string type not converted, got %#v %T", output, output) + } +} + +func TestConvertDerivedByteSlice(t *testing.T) { + type derived []uint8 + + output, err := converter{}.ConvertValue(derived("value")) + if err != nil { + t.Fatal("Byte slice not convertible", err) + } + + if bytes.Compare(output.([]byte), []byte("value")) != 0 { + t.Fatalf("Byte slice not converted, got %#v %T", output, output) + } +} + +func TestConvertDerivedUnsupportedSlice(t *testing.T) { + type derived []int + + _, err := converter{}.ConvertValue(derived{1}) + if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" { + t.Fatal("Unexpected error", err) + } +} + +func TestConvertDerivedBool(t *testing.T) { + type derived bool + + output, err := converter{}.ConvertValue(derived(true)) + if err != nil { + t.Fatal("Derived bool type not convertible", err) + } + + if output != true { + t.Fatalf("Derived bool type not converted, got %#v %T", output, output) + } +} + +func TestConvertPointer(t *testing.T) { + str := "value" + + output, err := converter{}.ConvertValue(&str) + if err != nil { + t.Fatal("Pointer type not convertible", err) + } + + if output != "value" { + t.Fatalf("Pointer type not converted, got %#v %T", output, output) + } +} + +func TestConvertSignedIntegers(t *testing.T) { + values := []interface{}{ + int8(-42), + int16(-42), + int32(-42), + int64(-42), + int(-42), + } + + for _, value := range values { + output, err := converter{}.ConvertValue(value) + if err != nil { + t.Fatalf("%T type not convertible %s", value, err) + } + + if output != int64(-42) { + t.Fatalf("%T type not converted, got %#v %T", value, output, output) + } + } +} + +func TestConvertUnsignedIntegers(t *testing.T) { + values := []interface{}{ + uint8(42), + uint16(42), + uint32(42), + uint64(42), + uint(42), + } + + for _, value := range values { + output, err := converter{}.ConvertValue(value) + if err != nil { + t.Fatalf("%T type not convertible %s", value, err) + } + + if output != int64(42) { + t.Fatalf("%T type not converted, got %#v %T", value, output, output) + } + } + + output, err := converter{}.ConvertValue(^uint64(0)) + if err != nil { + t.Fatal("uint64 high-bit not convertible", err) + } + + if output != "18446744073709551615" { + t.Fatalf("uint64 high-bit not converted, got %#v %T", output, output) } }