mirror of https://github.com/lqs/sqlingo
291 lines
6.0 KiB
Go
291 lines
6.0 KiB
Go
package sqlingo
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"reflect"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Scanner is the interface that wraps the Scan method.
type Scanner interface {
	// Scan reads the current row into the values pointed to by dest and
	// reports any error encountered while doing so.
	Scan(dest ...interface{}) error
}
|
|
|
|
// Cursor is the interface of a row cursor.
|
|
type Cursor interface {
|
|
Next() bool
|
|
Scan(dest ...interface{}) error
|
|
GetMap() (map[string]value, error)
|
|
Close() error
|
|
}
|
|
|
|
// cursor is the row cursor backed by a *sql.Rows result set; its methods in
// this file provide Next, Scan, GetMap and Close on top of it.
type cursor struct {
	// rows is the underlying database/sql result set being iterated.
	rows *sql.Rows
}
|
|
|
|
func (c cursor) Next() bool {
|
|
return c.rows.Next()
|
|
}
|
|
|
|
var (
	// timeType is the reflect.Type of time.Time; preparePointers compares
	// struct types against it so a time.Time is handled as one scan target
	// instead of being walked field by field.
	timeType = reflect.TypeOf(time.Time{})

	// simpleTimeLayoutRegexp matches "YYYY-MM-DD hh:mm:ss" with an optional
	// fractional-seconds suffix. Submatch 1 is the fraction including its
	// leading dot, submatch 2 is the fraction digits alone.
	simpleTimeLayoutRegexp = regexp.MustCompile(`^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.(\d+))?$`)
)
|
|
|
|
func guessTimeLayout(s string) string {
|
|
matches := simpleTimeLayoutRegexp.FindStringSubmatch(s)
|
|
if len(matches) > 0 {
|
|
var sb strings.Builder
|
|
sb.Grow(32)
|
|
sb.WriteString("2006-01-02 15:04:05")
|
|
if matches[1] != "" {
|
|
sb.WriteString(".")
|
|
for i := 0; i < len(matches[2]); i++ {
|
|
sb.WriteByte('0')
|
|
}
|
|
}
|
|
return sb.String()
|
|
}
|
|
return time.RFC3339Nano
|
|
}
|
|
|
|
func parseTime(s string) (time.Time, error) {
|
|
if strings.HasPrefix(s, "0000-00-00") {
|
|
// MySQL zero date
|
|
return time.Time{}, nil
|
|
}
|
|
layout := guessTimeLayout(s)
|
|
t, err := time.Parse(layout, s)
|
|
if err != nil {
|
|
return time.Time{}, fmt.Errorf("unknown time format %s: %w", s, err)
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
// isScanner reports whether a pointer to val implements sql.Scanner.
//
// It returns false instead of panicking when val is not addressable, or when
// its address cannot be turned into an interface value (which reflect refuses
// for values reached through unexported struct fields).
func isScanner(val reflect.Value) bool {
	if !val.CanAddr() {
		return false
	}
	addr := val.Addr()
	if !addr.CanInterface() {
		return false
	}
	_, ok := addr.Interface().(sql.Scanner)
	return ok
}
|
|
|
|
func preparePointers(val reflect.Value, scans *[]interface{}) error {
|
|
kind := val.Kind()
|
|
switch kind {
|
|
case reflect.Bool,
|
|
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
|
reflect.Uint, reflect.Uint8, reflect.Uint32, reflect.Uint64,
|
|
reflect.Float32, reflect.Float64,
|
|
reflect.String:
|
|
if addr := val.Addr(); addr.CanInterface() {
|
|
*scans = append(*scans, addr.Interface())
|
|
}
|
|
case reflect.Struct:
|
|
if canScan := val.Type() == timeType || isScanner(val); canScan {
|
|
*scans = append(*scans, val.Addr().Interface())
|
|
return nil
|
|
}
|
|
for j := 0; j < val.NumField(); j++ {
|
|
field := val.Field(j)
|
|
if field.Kind() == reflect.Interface {
|
|
continue
|
|
}
|
|
if err := preparePointers(field, scans); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
case reflect.Ptr:
|
|
toType := val.Type().Elem()
|
|
switch toType.Kind() {
|
|
case reflect.Bool,
|
|
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
|
reflect.Uint, reflect.Uint8, reflect.Uint32, reflect.Uint64,
|
|
reflect.Float32, reflect.Float64,
|
|
reflect.String:
|
|
*scans = append(*scans, val.Addr().Interface())
|
|
case reflect.Struct:
|
|
if toType == reflect.TypeOf(time.Time{}) {
|
|
*scans = append(*scans, val.Addr().Interface())
|
|
} else {
|
|
to := reflect.New(toType).Elem()
|
|
val.Set(to.Addr())
|
|
err := preparePointers(to, scans)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
default:
|
|
to := reflect.New(toType).Elem()
|
|
val.Set(to.Addr())
|
|
err := preparePointers(to, scans)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
case reflect.Slice:
|
|
if _, ok := (val.Interface()).([]byte); ok {
|
|
*scans = append(*scans, val.Addr().Interface())
|
|
} else {
|
|
return fmt.Errorf("unknown type []%s", val.Type().Elem().Kind().String())
|
|
}
|
|
default:
|
|
return fmt.Errorf("unknown type %s", kind.String())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// parseBool interprets s as a boolean.
//
// A single raw byte 0x00 or 0x01 maps to false or true respectively. Any other
// input is handed to strconv.ParseBool as text, whose error is returned
// unchanged when the text is not a recognised boolean.
func parseBool(s []byte) (bool, error) {
	if len(s) == 1 {
		switch s[0] {
		case 0:
			return false, nil
		case 1:
			return true, nil
		}
	}
	return strconv.ParseBool(string(s))
}
|
|
|
|
func (c cursor) Scan(dest ...interface{}) error {
|
|
columns, err := c.rows.Columns()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
values := make([]interface{}, len(columns))
|
|
pointers := make([]interface{}, len(columns))
|
|
for i := range columns {
|
|
pointers[i] = &values[i]
|
|
}
|
|
if err := c.rows.Scan(pointers...); err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(dest) == 0 {
|
|
// dry run
|
|
return nil
|
|
}
|
|
|
|
var scans []interface{}
|
|
for i, item := range dest {
|
|
if reflect.ValueOf(item).Kind() != reflect.Ptr {
|
|
return fmt.Errorf("argument %d is not pointer", i)
|
|
}
|
|
|
|
val := reflect.Indirect(reflect.ValueOf(item))
|
|
|
|
err := preparePointers(val, &scans)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
pbs := make(map[int]*bool)
|
|
ppbs := make(map[int]**bool)
|
|
pts := make(map[int]*time.Time)
|
|
ppts := make(map[int]**time.Time)
|
|
|
|
for i, scan := range scans {
|
|
switch scan.(type) {
|
|
case *bool:
|
|
var s []uint8
|
|
scans[i] = &s
|
|
pbs[i] = scan.(*bool)
|
|
case **bool:
|
|
var s *[]uint8
|
|
scans[i] = &s
|
|
ppbs[i] = scan.(**bool)
|
|
case *time.Time:
|
|
var s string
|
|
scans[i] = &s
|
|
pts[i] = scan.(*time.Time)
|
|
case **time.Time:
|
|
var s sql.NullString
|
|
scans[i] = &s
|
|
ppts[i] = scan.(**time.Time)
|
|
}
|
|
}
|
|
|
|
if err := c.rows.Scan(scans...); err != nil {
|
|
return err
|
|
}
|
|
|
|
for i, pb := range pbs {
|
|
if *(scans[i].(*[]byte)) == nil {
|
|
return fmt.Errorf("field %d is null", i)
|
|
}
|
|
b, err := parseBool(*(scans[i].(*[]byte)))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*pb = b
|
|
}
|
|
for i, ppb := range ppbs {
|
|
if *(scans[i].(**[]uint8)) == nil {
|
|
*ppb = nil
|
|
} else {
|
|
b, err := parseBool(**(scans[i].(**[]byte)))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*ppb = &b
|
|
}
|
|
}
|
|
for i := range pts {
|
|
s := scans[i].(*string)
|
|
if s == nil {
|
|
return fmt.Errorf("field %d is null", i)
|
|
}
|
|
t, err := parseTime(*s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*pts[i] = t
|
|
|
|
}
|
|
for i := range ppts {
|
|
nullString := scans[i].(*sql.NullString)
|
|
if nullString == nil {
|
|
return fmt.Errorf("field %d is null", i)
|
|
}
|
|
if !nullString.Valid {
|
|
*ppts[i] = nil
|
|
} else {
|
|
t, err := parseTime(nullString.String)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*ppts[i] = &t
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (c cursor) GetMap() (result map[string]value, err error) {
|
|
columns, err := c.rows.Columns()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
columnCount := len(columns)
|
|
values := make([]interface{}, columnCount)
|
|
for i := 0; i < columnCount; i++ {
|
|
var value *string
|
|
values[i] = &value
|
|
}
|
|
if err = c.rows.Scan(values...); err != nil {
|
|
return
|
|
}
|
|
|
|
result = make(map[string]value, columnCount)
|
|
for i, column := range columns {
|
|
result[column] = value{stringValue: *values[i].(**string)}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (c cursor) Close() error {
|
|
return c.rows.Close()
|
|
}
|