mirror of https://github.com/lqs/sqlingo
automatically find tables from fields if "From" is not specified
This commit is contained in:
parent
2d43ad3c78
commit
fdd2b3d4ef
|
@ -6,6 +6,11 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// SqlingoRuntimeVersion is the the runtime version of sqlingo
|
||||
SqlingoRuntimeVersion = 2
|
||||
)
|
||||
|
||||
// Model is the interface of generated model struct
|
||||
type Model interface {
|
||||
GetTable() Table
|
||||
|
|
|
@ -104,6 +104,10 @@ type expression struct {
|
|||
isFalse bool
|
||||
}
|
||||
|
||||
func (e expression) GetTable() Table {
|
||||
return nil
|
||||
}
|
||||
|
||||
type scope struct {
|
||||
Database *database
|
||||
Tables []Table
|
||||
|
|
51
field.go
51
field.go
|
@ -5,24 +5,38 @@ import "strings"
|
|||
// Field is the interface of a generated field.
|
||||
type Field interface {
|
||||
Expression
|
||||
GetTable() Table
|
||||
}
|
||||
|
||||
// NumberField is the interface of a generated field of number type.
|
||||
type NumberField interface {
|
||||
Field
|
||||
NumberExpression
|
||||
}
|
||||
|
||||
// BooleanField is the interface of a generated field of boolean type.
|
||||
type BooleanField interface {
|
||||
Field
|
||||
BooleanExpression
|
||||
}
|
||||
|
||||
// StringField is the interface of a generated field of string type.
|
||||
type StringField interface {
|
||||
Field
|
||||
StringExpression
|
||||
}
|
||||
|
||||
func newFieldExpression(tableName string, fieldName string) expression {
|
||||
type actualField struct {
|
||||
expression
|
||||
table Table
|
||||
}
|
||||
|
||||
func (f actualField) GetTable() Table {
|
||||
return f.table
|
||||
}
|
||||
|
||||
func newField(table Table, fieldName string) actualField {
|
||||
tableName := table.GetName()
|
||||
tableNameSqlArray := quoteIdentifier(tableName)
|
||||
fieldNameSqlArray := quoteIdentifier(fieldName)
|
||||
|
||||
|
@ -31,33 +45,36 @@ func newFieldExpression(tableName string, fieldName string) expression {
|
|||
fullFieldNameSqlArray[dialect] = tableNameSqlArray[dialect] + "." + fieldNameSqlArray[dialect]
|
||||
}
|
||||
|
||||
return expression{
|
||||
builder: func(scope scope) (string, error) {
|
||||
dialect := dialectUnknown
|
||||
if scope.Database != nil {
|
||||
dialect = scope.Database.dialect
|
||||
}
|
||||
if len(scope.Tables) != 1 || scope.lastJoin != nil || scope.Tables[0].GetName() != tableName {
|
||||
return fullFieldNameSqlArray[dialect], nil
|
||||
}
|
||||
return fieldNameSqlArray[dialect], nil
|
||||
return actualField{
|
||||
expression: expression{
|
||||
builder: func(scope scope) (string, error) {
|
||||
dialect := dialectUnknown
|
||||
if scope.Database != nil {
|
||||
dialect = scope.Database.dialect
|
||||
}
|
||||
if len(scope.Tables) != 1 || scope.lastJoin != nil || scope.Tables[0].GetName() != tableName {
|
||||
return fullFieldNameSqlArray[dialect], nil
|
||||
}
|
||||
return fieldNameSqlArray[dialect], nil
|
||||
},
|
||||
},
|
||||
table: table,
|
||||
}
|
||||
}
|
||||
|
||||
// NewNumberField creates a reference to a number field. It should only be called from generated code.
|
||||
func NewNumberField(tableName string, fieldName string) NumberField {
|
||||
return newFieldExpression(tableName, fieldName)
|
||||
func NewNumberField(table Table, fieldName string) NumberField {
|
||||
return newField(table, fieldName)
|
||||
}
|
||||
|
||||
// NewBooleanField creates a reference to a boolean field. It should only be called from generated code.
|
||||
func NewBooleanField(tableName string, fieldName string) BooleanField {
|
||||
return newFieldExpression(tableName, fieldName)
|
||||
func NewBooleanField(table Table, fieldName string) BooleanField {
|
||||
return newField(table, fieldName)
|
||||
}
|
||||
|
||||
// NewStringField creates a reference to a string field. It should only be called from generated code.
|
||||
func NewStringField(tableName string, fieldName string) StringField {
|
||||
return newFieldExpression(tableName, fieldName)
|
||||
func NewStringField(table Table, fieldName string) StringField {
|
||||
return newField(table, fieldName)
|
||||
}
|
||||
|
||||
type fieldList []Field
|
||||
|
|
|
@ -33,9 +33,10 @@ func (d dummyTable) GetFullFieldsSQL() string {
|
|||
}
|
||||
|
||||
func TestField(t *testing.T) {
|
||||
assertValue(t, NewNumberField("t1", "f1").Equals(1), "`t1`.`f1` = 1")
|
||||
assertValue(t, NewBooleanField("t1", "f1").Equals(true), "`t1`.`f1` = 1")
|
||||
assertValue(t, NewStringField("t1", "f1").Equals("x"), "`t1`.`f1` = 'x'")
|
||||
t1 := NewTable("t1")
|
||||
assertValue(t, NewNumberField(t1, "f1").Equals(1), "`t1`.`f1` = 1")
|
||||
assertValue(t, NewBooleanField(t1, "f1").Equals(true), "`t1`.`f1` = 1")
|
||||
assertValue(t, NewStringField(t1, "f1").Equals("x"), "`t1`.`f1` = 'x'")
|
||||
|
||||
sql, _ := fieldList{}.GetSQL(scope{
|
||||
Tables: []Table{
|
||||
|
|
|
@ -24,10 +24,12 @@ type TestModel struct {
|
|||
F2 string
|
||||
}
|
||||
|
||||
var tTestTable = NewTable("test")
|
||||
|
||||
var Test = tTest{
|
||||
Table: NewTable("test"),
|
||||
F1: fTestF1{NewNumberField("test", "f1")},
|
||||
F2: fTestF2{NewStringField("test", "f2")},
|
||||
F1: fTestF1{NewNumberField(tTestTable, "f1")},
|
||||
F2: fTestF2{NewStringField(tTestTable, "f2")},
|
||||
}
|
||||
|
||||
func (m TestModel) GetTable() Table {
|
||||
|
|
15
select.go
15
select.go
|
@ -384,6 +384,21 @@ func (s selectBase) buildSelectBase(sb *strings.Builder) error {
|
|||
sb.WriteString("DISTINCT ")
|
||||
}
|
||||
|
||||
// find tables from fields if "From" is not specified
|
||||
if len(s.scope.Tables) == 0 && len(s.fields) > 0 {
|
||||
tableMap := make(map[string]Table)
|
||||
for _, field := range s.fields {
|
||||
table := field.GetTable()
|
||||
if table == nil {
|
||||
continue
|
||||
}
|
||||
tableMap[table.GetName()] = table
|
||||
}
|
||||
for _, table := range tableMap {
|
||||
s.scope.Tables = append(s.scope.Tables, table)
|
||||
}
|
||||
}
|
||||
|
||||
fieldsSql, err := s.fields.GetSQL(s.scope)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -5,25 +5,30 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
type table1 struct {
|
||||
type tTable1 struct {
|
||||
Table
|
||||
}
|
||||
|
||||
var Table1 = table1{
|
||||
var Table1 = tTable1{
|
||||
NewTable("table1"),
|
||||
}
|
||||
var field1 = NewNumberField("table1", "field1")
|
||||
var field2 = NewNumberField("table1", "field2")
|
||||
|
||||
func (t table1) GetFields() []Field {
|
||||
var table1 = NewTable("table1")
|
||||
var field1 = NewNumberField(table1, "field1")
|
||||
var field2 = NewNumberField(table1, "field2")
|
||||
|
||||
var table2 = NewTable("table2")
|
||||
var field3 = NewNumberField(table2, "field3")
|
||||
|
||||
func (t tTable1) GetFields() []Field {
|
||||
return []Field{field1, field2}
|
||||
}
|
||||
|
||||
func (t table1) GetFieldsSQL() string {
|
||||
func (t tTable1) GetFieldsSQL() string {
|
||||
return "<fields sql>"
|
||||
}
|
||||
|
||||
func (t table1) GetFullFieldsSQL() string {
|
||||
func (t tTable1) GetFullFieldsSQL() string {
|
||||
return "<full fields sql>"
|
||||
}
|
||||
|
||||
|
@ -31,9 +36,6 @@ func TestSelect(t *testing.T) {
|
|||
db := newMockDatabase()
|
||||
assertValue(t, db.Select(1), "(SELECT 1)")
|
||||
|
||||
table2 := NewTable("table1")
|
||||
field3 := NewNumberField("table2", "field1")
|
||||
|
||||
db.Select(field1).From(Table1).Where(field1.Equals(42)).Limit(10).GetSQL()
|
||||
|
||||
db.Select(field1, field2, field3, Count(1).As("count")).
|
||||
|
@ -83,6 +85,16 @@ func TestCount(t *testing.T) {
|
|||
assertLastSql(t, "SELECT EXISTS (SELECT `f1` FROM `test`)")
|
||||
}
|
||||
|
||||
func TestSelectAutoFrom(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
|
||||
_, _ = db.Select(field1, field2, 123).FetchFirst()
|
||||
assertLastSql(t, "SELECT `field1`, `field2`, 123 FROM `table1`")
|
||||
|
||||
_, _ = db.Select(field1, field2, 123, field3).FetchFirst()
|
||||
assertLastSql(t, "SELECT `table1`.`field1`, `table1`.`field2`, 123, `table2`.`field3` FROM `table1`, `table2`")
|
||||
}
|
||||
|
||||
func TestFetchAll(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
|
||||
|
@ -137,12 +149,10 @@ func TestUnion(t *testing.T) {
|
|||
_, _ = db.SelectFrom(table1).UnionSelectFrom(table2).Where(cond1).FetchAll()
|
||||
assertLastSql(t, "SELECT * FROM `table1` UNION SELECT * FROM `table2` WHERE <condition 1>")
|
||||
|
||||
|
||||
_, _ = db.SelectFrom(table1).Where(cond1).
|
||||
UnionSelectFrom(table2).Where(cond2).FetchAll()
|
||||
assertLastSql(t, "SELECT * FROM `table1` WHERE <condition 1> UNION SELECT * FROM `table2` WHERE <condition 2>")
|
||||
|
||||
|
||||
_, _ = db.SelectFrom(table1).Where(Raw("C1")).
|
||||
UnionSelectFrom(table2).Where(Raw("C2")).
|
||||
UnionSelect(3).From(table2).Where(Raw("C3")).
|
||||
|
@ -151,15 +161,14 @@ func TestUnion(t *testing.T) {
|
|||
UnionAllSelect(6).From(table2).Where(Raw("C6")).
|
||||
UnionAllSelectDistinct(7).From(table2).Where(Raw("C7")).
|
||||
FetchAll()
|
||||
assertLastSql(t, "SELECT * FROM `table1` WHERE C1 " +
|
||||
"UNION SELECT * FROM `table2` WHERE C2 " +
|
||||
"UNION SELECT 3 FROM `table2` WHERE C3 " +
|
||||
"UNION SELECT DISTINCT 4 FROM `table2` WHERE C4 " +
|
||||
"UNION ALL SELECT * FROM `table2` WHERE C5 " +
|
||||
"UNION ALL SELECT 6 FROM `table2` WHERE C6 " +
|
||||
assertLastSql(t, "SELECT * FROM `table1` WHERE C1 "+
|
||||
"UNION SELECT * FROM `table2` WHERE C2 "+
|
||||
"UNION SELECT 3 FROM `table2` WHERE C3 "+
|
||||
"UNION SELECT DISTINCT 4 FROM `table2` WHERE C4 "+
|
||||
"UNION ALL SELECT * FROM `table2` WHERE C5 "+
|
||||
"UNION ALL SELECT 6 FROM `table2` WHERE C6 "+
|
||||
"UNION ALL SELECT DISTINCT 7 FROM `table2` WHERE C7")
|
||||
|
||||
|
||||
_, _ = db.SelectFrom(table1).Where(Raw("C1")).
|
||||
UnionSelectFrom(table2).Where(Raw("C2")).
|
||||
UnionSelect(3).From(table2).Where(Raw("C3")).
|
||||
|
@ -168,13 +177,13 @@ func TestUnion(t *testing.T) {
|
|||
UnionAllSelect(6).From(table2).Where(Raw("C6")).
|
||||
UnionAllSelectDistinct(7).From(table2).Where(Raw("C7")).
|
||||
Count()
|
||||
assertLastSql(t, "SELECT COUNT(1) FROM (" +
|
||||
"SELECT 1 FROM `table1` WHERE C1 " +
|
||||
"UNION SELECT * FROM `table2` WHERE C2 " +
|
||||
"UNION SELECT 3 FROM `table2` WHERE C3 " +
|
||||
"UNION SELECT DISTINCT 4 FROM `table2` WHERE C4 " +
|
||||
"UNION ALL SELECT * FROM `table2` WHERE C5 " +
|
||||
"UNION ALL SELECT 6 FROM `table2` WHERE C6 " +
|
||||
"UNION ALL SELECT DISTINCT 7 FROM `table2` WHERE C7" +
|
||||
assertLastSql(t, "SELECT COUNT(1) FROM ("+
|
||||
"SELECT 1 FROM `table1` WHERE C1 "+
|
||||
"UNION SELECT * FROM `table2` WHERE C2 "+
|
||||
"UNION SELECT 3 FROM `table2` WHERE C3 "+
|
||||
"UNION SELECT DISTINCT 4 FROM `table2` WHERE C4 "+
|
||||
"UNION ALL SELECT * FROM `table2` WHERE C5 "+
|
||||
"UNION ALL SELECT 6 FROM `table2` WHERE C6 "+
|
||||
"UNION ALL SELECT DISTINCT 7 FROM `table2` WHERE C7"+
|
||||
") AS t")
|
||||
}
|
||||
|
|
|
@ -11,6 +11,10 @@ import (
|
|||
"unicode"
|
||||
)
|
||||
|
||||
const (
|
||||
sqlingoGeneratorVersion = 2
|
||||
)
|
||||
|
||||
type schemaFetcher interface {
|
||||
GetDatabaseName() (dbName string, err error)
|
||||
GetTableNames() (tableNames []string, err error)
|
||||
|
@ -128,6 +132,12 @@ func generate(driverName string, dataSourceName string, tableNames []string) (st
|
|||
code += "package " + dbName + "_dsl\n"
|
||||
code += "import . \"github.com/lqs/sqlingo\"\n\n"
|
||||
|
||||
code += "type sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame uint32\n\n"
|
||||
|
||||
sqlingoGeneratorVersionString := strconv.Itoa(sqlingoGeneratorVersion)
|
||||
code += "const _ = sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame(SqlingoRuntimeVersion - " + sqlingoGeneratorVersionString + ")\n"
|
||||
code += "const _ = sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame(" + sqlingoGeneratorVersionString + " - SqlingoRuntimeVersion)\n\n"
|
||||
|
||||
code += "type table interface {\n"
|
||||
code += "\tTable\n"
|
||||
code += "}\n\n"
|
||||
|
@ -185,17 +195,18 @@ func generateTable(schemaFetcher schemaFetcher, tableName string) (string, error
|
|||
return "", err
|
||||
}
|
||||
|
||||
tableLines := ""
|
||||
modelLines := ""
|
||||
objectLines := "\ttable: NewTable(" + strconv.Quote(tableName) + "),\n\n"
|
||||
fieldCaseLines := ""
|
||||
classLines := ""
|
||||
|
||||
className := convertCase(tableName)
|
||||
tableStructName := "t" + className
|
||||
tableObjectName := "o" + className
|
||||
|
||||
modelClassName := className + "Model"
|
||||
|
||||
tableLines := ""
|
||||
modelLines := ""
|
||||
objectLines := "\ttable: " + tableObjectName + ",\n\n"
|
||||
fieldCaseLines := ""
|
||||
classLines := ""
|
||||
|
||||
fields := ""
|
||||
fieldsSQL := ""
|
||||
fullFieldsSQL := ""
|
||||
|
@ -226,7 +237,7 @@ func generateTable(schemaFetcher schemaFetcher, tableName string) (string, error
|
|||
|
||||
objectLines += commentLine
|
||||
objectLines += "\t" + goName + ": " + fieldStructName + "{"
|
||||
objectLines += "New" + fieldClass + "(" + strconv.Quote(tableName) + ", " + strconv.Quote(fieldDescriptor.Name) + ")},\n"
|
||||
objectLines += "New" + fieldClass + "(" + tableObjectName + ", " + strconv.Quote(fieldDescriptor.Name) + ")},\n"
|
||||
|
||||
fieldCaseLines += "\tcase " + strconv.Quote(fieldDescriptor.Name) + ": return t." + goName + "\n"
|
||||
|
||||
|
@ -253,6 +264,7 @@ func generateTable(schemaFetcher schemaFetcher, tableName string) (string, error
|
|||
|
||||
code += classLines
|
||||
|
||||
code += "var " + tableObjectName + " = NewTable(" + strconv.Quote(tableName) + ")\n"
|
||||
code += "var " + className + " = " + tableStructName + "{\n"
|
||||
code += objectLines
|
||||
code += "}\n\n"
|
||||
|
|
|
@ -10,7 +10,7 @@ func TestTable(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDerivedTable(t *testing.T) {
|
||||
dummyFields := []Field{NewNumberField("table", "field")}
|
||||
dummyFields := []Field{NewNumberField(NewTable("table"), "field")}
|
||||
dt := derivedTable{
|
||||
name: "t",
|
||||
selectStatus: selectStatus{
|
||||
|
|
Loading…
Reference in New Issue