automatically find tables from fields if "From" is not specified

This commit is contained in:
lqs 2020-07-28 12:43:07 +08:00
parent 2d43ad3c78
commit fdd2b3d4ef
9 changed files with 122 additions and 57 deletions

View File

@ -6,6 +6,11 @@ import (
"strings"
)
const (
// SqlingoRuntimeVersion is the the runtime version of sqlingo
SqlingoRuntimeVersion = 2
)
// Model is the interface of generated model struct
type Model interface {
GetTable() Table

View File

@ -104,6 +104,10 @@ type expression struct {
isFalse bool
}
func (e expression) GetTable() Table {
return nil
}
type scope struct {
Database *database
Tables []Table

View File

@ -5,24 +5,38 @@ import "strings"
// Field is the interface of a generated field.
type Field interface {
Expression
GetTable() Table
}
// NumberField is the interface of a generated field of number type.
type NumberField interface {
Field
NumberExpression
}
// BooleanField is the interface of a generated field of boolean type.
type BooleanField interface {
Field
BooleanExpression
}
// StringField is the interface of a generated field of string type.
type StringField interface {
Field
StringExpression
}
func newFieldExpression(tableName string, fieldName string) expression {
type actualField struct {
expression
table Table
}
func (f actualField) GetTable() Table {
return f.table
}
func newField(table Table, fieldName string) actualField {
tableName := table.GetName()
tableNameSqlArray := quoteIdentifier(tableName)
fieldNameSqlArray := quoteIdentifier(fieldName)
@ -31,33 +45,36 @@ func newFieldExpression(tableName string, fieldName string) expression {
fullFieldNameSqlArray[dialect] = tableNameSqlArray[dialect] + "." + fieldNameSqlArray[dialect]
}
return expression{
builder: func(scope scope) (string, error) {
dialect := dialectUnknown
if scope.Database != nil {
dialect = scope.Database.dialect
}
if len(scope.Tables) != 1 || scope.lastJoin != nil || scope.Tables[0].GetName() != tableName {
return fullFieldNameSqlArray[dialect], nil
}
return fieldNameSqlArray[dialect], nil
return actualField{
expression: expression{
builder: func(scope scope) (string, error) {
dialect := dialectUnknown
if scope.Database != nil {
dialect = scope.Database.dialect
}
if len(scope.Tables) != 1 || scope.lastJoin != nil || scope.Tables[0].GetName() != tableName {
return fullFieldNameSqlArray[dialect], nil
}
return fieldNameSqlArray[dialect], nil
},
},
table: table,
}
}
// NewNumberField creates a reference to a number field. It should only be called from generated code.
func NewNumberField(tableName string, fieldName string) NumberField {
return newFieldExpression(tableName, fieldName)
func NewNumberField(table Table, fieldName string) NumberField {
return newField(table, fieldName)
}
// NewBooleanField creates a reference to a boolean field. It should only be called from generated code.
func NewBooleanField(tableName string, fieldName string) BooleanField {
return newFieldExpression(tableName, fieldName)
func NewBooleanField(table Table, fieldName string) BooleanField {
return newField(table, fieldName)
}
// NewStringField creates a reference to a string field. It should only be called from generated code.
func NewStringField(tableName string, fieldName string) StringField {
return newFieldExpression(tableName, fieldName)
func NewStringField(table Table, fieldName string) StringField {
return newField(table, fieldName)
}
type fieldList []Field

View File

@ -33,9 +33,10 @@ func (d dummyTable) GetFullFieldsSQL() string {
}
func TestField(t *testing.T) {
assertValue(t, NewNumberField("t1", "f1").Equals(1), "`t1`.`f1` = 1")
assertValue(t, NewBooleanField("t1", "f1").Equals(true), "`t1`.`f1` = 1")
assertValue(t, NewStringField("t1", "f1").Equals("x"), "`t1`.`f1` = 'x'")
t1 := NewTable("t1")
assertValue(t, NewNumberField(t1, "f1").Equals(1), "`t1`.`f1` = 1")
assertValue(t, NewBooleanField(t1, "f1").Equals(true), "`t1`.`f1` = 1")
assertValue(t, NewStringField(t1, "f1").Equals("x"), "`t1`.`f1` = 'x'")
sql, _ := fieldList{}.GetSQL(scope{
Tables: []Table{

View File

@ -24,10 +24,12 @@ type TestModel struct {
F2 string
}
var tTestTable = NewTable("test")
var Test = tTest{
Table: NewTable("test"),
F1: fTestF1{NewNumberField("test", "f1")},
F2: fTestF2{NewStringField("test", "f2")},
F1: fTestF1{NewNumberField(tTestTable, "f1")},
F2: fTestF2{NewStringField(tTestTable, "f2")},
}
func (m TestModel) GetTable() Table {

View File

@ -384,6 +384,21 @@ func (s selectBase) buildSelectBase(sb *strings.Builder) error {
sb.WriteString("DISTINCT ")
}
// find tables from fields if "From" is not specified
if len(s.scope.Tables) == 0 && len(s.fields) > 0 {
tableMap := make(map[string]Table)
for _, field := range s.fields {
table := field.GetTable()
if table == nil {
continue
}
tableMap[table.GetName()] = table
}
for _, table := range tableMap {
s.scope.Tables = append(s.scope.Tables, table)
}
}
fieldsSql, err := s.fields.GetSQL(s.scope)
if err != nil {
return err

View File

@ -5,25 +5,30 @@ import (
"testing"
)
type table1 struct {
type tTable1 struct {
Table
}
var Table1 = table1{
var Table1 = tTable1{
NewTable("table1"),
}
var field1 = NewNumberField("table1", "field1")
var field2 = NewNumberField("table1", "field2")
func (t table1) GetFields() []Field {
var table1 = NewTable("table1")
var field1 = NewNumberField(table1, "field1")
var field2 = NewNumberField(table1, "field2")
var table2 = NewTable("table2")
var field3 = NewNumberField(table2, "field3")
func (t tTable1) GetFields() []Field {
return []Field{field1, field2}
}
func (t table1) GetFieldsSQL() string {
func (t tTable1) GetFieldsSQL() string {
return "<fields sql>"
}
func (t table1) GetFullFieldsSQL() string {
func (t tTable1) GetFullFieldsSQL() string {
return "<full fields sql>"
}
@ -31,9 +36,6 @@ func TestSelect(t *testing.T) {
db := newMockDatabase()
assertValue(t, db.Select(1), "(SELECT 1)")
table2 := NewTable("table1")
field3 := NewNumberField("table2", "field1")
db.Select(field1).From(Table1).Where(field1.Equals(42)).Limit(10).GetSQL()
db.Select(field1, field2, field3, Count(1).As("count")).
@ -83,6 +85,16 @@ func TestCount(t *testing.T) {
assertLastSql(t, "SELECT EXISTS (SELECT `f1` FROM `test`)")
}
func TestSelectAutoFrom(t *testing.T) {
db := newMockDatabase()
_, _ = db.Select(field1, field2, 123).FetchFirst()
assertLastSql(t, "SELECT `field1`, `field2`, 123 FROM `table1`")
_, _ = db.Select(field1, field2, 123, field3).FetchFirst()
assertLastSql(t, "SELECT `table1`.`field1`, `table1`.`field2`, 123, `table2`.`field3` FROM `table1`, `table2`")
}
func TestFetchAll(t *testing.T) {
db := newMockDatabase()
@ -137,12 +149,10 @@ func TestUnion(t *testing.T) {
_, _ = db.SelectFrom(table1).UnionSelectFrom(table2).Where(cond1).FetchAll()
assertLastSql(t, "SELECT * FROM `table1` UNION SELECT * FROM `table2` WHERE <condition 1>")
_, _ = db.SelectFrom(table1).Where(cond1).
UnionSelectFrom(table2).Where(cond2).FetchAll()
assertLastSql(t, "SELECT * FROM `table1` WHERE <condition 1> UNION SELECT * FROM `table2` WHERE <condition 2>")
_, _ = db.SelectFrom(table1).Where(Raw("C1")).
UnionSelectFrom(table2).Where(Raw("C2")).
UnionSelect(3).From(table2).Where(Raw("C3")).
@ -151,15 +161,14 @@ func TestUnion(t *testing.T) {
UnionAllSelect(6).From(table2).Where(Raw("C6")).
UnionAllSelectDistinct(7).From(table2).Where(Raw("C7")).
FetchAll()
assertLastSql(t, "SELECT * FROM `table1` WHERE C1 " +
"UNION SELECT * FROM `table2` WHERE C2 " +
"UNION SELECT 3 FROM `table2` WHERE C3 " +
"UNION SELECT DISTINCT 4 FROM `table2` WHERE C4 " +
"UNION ALL SELECT * FROM `table2` WHERE C5 " +
"UNION ALL SELECT 6 FROM `table2` WHERE C6 " +
assertLastSql(t, "SELECT * FROM `table1` WHERE C1 "+
"UNION SELECT * FROM `table2` WHERE C2 "+
"UNION SELECT 3 FROM `table2` WHERE C3 "+
"UNION SELECT DISTINCT 4 FROM `table2` WHERE C4 "+
"UNION ALL SELECT * FROM `table2` WHERE C5 "+
"UNION ALL SELECT 6 FROM `table2` WHERE C6 "+
"UNION ALL SELECT DISTINCT 7 FROM `table2` WHERE C7")
_, _ = db.SelectFrom(table1).Where(Raw("C1")).
UnionSelectFrom(table2).Where(Raw("C2")).
UnionSelect(3).From(table2).Where(Raw("C3")).
@ -168,13 +177,13 @@ func TestUnion(t *testing.T) {
UnionAllSelect(6).From(table2).Where(Raw("C6")).
UnionAllSelectDistinct(7).From(table2).Where(Raw("C7")).
Count()
assertLastSql(t, "SELECT COUNT(1) FROM (" +
"SELECT 1 FROM `table1` WHERE C1 " +
"UNION SELECT * FROM `table2` WHERE C2 " +
"UNION SELECT 3 FROM `table2` WHERE C3 " +
"UNION SELECT DISTINCT 4 FROM `table2` WHERE C4 " +
"UNION ALL SELECT * FROM `table2` WHERE C5 " +
"UNION ALL SELECT 6 FROM `table2` WHERE C6 " +
"UNION ALL SELECT DISTINCT 7 FROM `table2` WHERE C7" +
assertLastSql(t, "SELECT COUNT(1) FROM ("+
"SELECT 1 FROM `table1` WHERE C1 "+
"UNION SELECT * FROM `table2` WHERE C2 "+
"UNION SELECT 3 FROM `table2` WHERE C3 "+
"UNION SELECT DISTINCT 4 FROM `table2` WHERE C4 "+
"UNION ALL SELECT * FROM `table2` WHERE C5 "+
"UNION ALL SELECT 6 FROM `table2` WHERE C6 "+
"UNION ALL SELECT DISTINCT 7 FROM `table2` WHERE C7"+
") AS t")
}

View File

@ -11,6 +11,10 @@ import (
"unicode"
)
const (
sqlingoGeneratorVersion = 2
)
type schemaFetcher interface {
GetDatabaseName() (dbName string, err error)
GetTableNames() (tableNames []string, err error)
@ -128,6 +132,12 @@ func generate(driverName string, dataSourceName string, tableNames []string) (st
code += "package " + dbName + "_dsl\n"
code += "import . \"github.com/lqs/sqlingo\"\n\n"
code += "type sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame uint32\n\n"
sqlingoGeneratorVersionString := strconv.Itoa(sqlingoGeneratorVersion)
code += "const _ = sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame(SqlingoRuntimeVersion - " + sqlingoGeneratorVersionString + ")\n"
code += "const _ = sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame(" + sqlingoGeneratorVersionString + " - SqlingoRuntimeVersion)\n\n"
code += "type table interface {\n"
code += "\tTable\n"
code += "}\n\n"
@ -185,17 +195,18 @@ func generateTable(schemaFetcher schemaFetcher, tableName string) (string, error
return "", err
}
tableLines := ""
modelLines := ""
objectLines := "\ttable: NewTable(" + strconv.Quote(tableName) + "),\n\n"
fieldCaseLines := ""
classLines := ""
className := convertCase(tableName)
tableStructName := "t" + className
tableObjectName := "o" + className
modelClassName := className + "Model"
tableLines := ""
modelLines := ""
objectLines := "\ttable: " + tableObjectName + ",\n\n"
fieldCaseLines := ""
classLines := ""
fields := ""
fieldsSQL := ""
fullFieldsSQL := ""
@ -226,7 +237,7 @@ func generateTable(schemaFetcher schemaFetcher, tableName string) (string, error
objectLines += commentLine
objectLines += "\t" + goName + ": " + fieldStructName + "{"
objectLines += "New" + fieldClass + "(" + strconv.Quote(tableName) + ", " + strconv.Quote(fieldDescriptor.Name) + ")},\n"
objectLines += "New" + fieldClass + "(" + tableObjectName + ", " + strconv.Quote(fieldDescriptor.Name) + ")},\n"
fieldCaseLines += "\tcase " + strconv.Quote(fieldDescriptor.Name) + ": return t." + goName + "\n"
@ -253,6 +264,7 @@ func generateTable(schemaFetcher schemaFetcher, tableName string) (string, error
code += classLines
code += "var " + tableObjectName + " = NewTable(" + strconv.Quote(tableName) + ")\n"
code += "var " + className + " = " + tableStructName + "{\n"
code += objectLines
code += "}\n\n"

View File

@ -10,7 +10,7 @@ func TestTable(t *testing.T) {
}
func TestDerivedTable(t *testing.T) {
dummyFields := []Field{NewNumberField("table", "field")}
dummyFields := []Field{NewNumberField(NewTable("table"), "field")}
dt := derivedTable{
name: "t",
selectStatus: selectStatus{