123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299 |
- package mysql
- import (
- "fmt"
- "sort"
- "strings"
- "github.com/jinzhu/inflection"
- "vitess.io/vitess/go/vt/sqlparser"
- "github.com/kyleconroy/sqlc/internal/codegen"
- "github.com/kyleconroy/sqlc/internal/codegen/golang"
- "github.com/kyleconroy/sqlc/internal/config"
- "github.com/kyleconroy/sqlc/internal/core"
- )
- type PackageGenerator struct {
- *Schema
- config.CombinedSettings
- packageName string
- }
- type Result struct {
- PackageGenerator
- Queries []*Query
- }
- // Enums generates parser-agnostic GoEnum types
- func (r *Result) Enums(settings config.CombinedSettings) []golang.Enum {
- var enums []golang.Enum
- for _, table := range r.Schema.tables {
- for _, col := range table {
- if strings.ToLower(col.Type.Type) == "enum" {
- constants := []golang.Constant{}
- enumName := r.enumNameFromColDef(col)
- for _, c := range col.Type.EnumValues {
- stripped := stripInnerQuotes(c)
- constants = append(constants, golang.Constant{
- // TODO: maybe add the struct name call to capitalize the name here
- Name: stripped,
- Value: stripped,
- Type: enumName,
- })
- }
- goEnum := golang.Enum{
- Name: enumName,
- Comment: "",
- Constants: constants,
- }
- enums = append(enums, goEnum)
- }
- }
- }
- return enums
- }
// stripInnerQuotes removes at most the first two single-quote characters from
// identifier, scanning left to right. For a quoted enum value such as 'draft'
// this yields draft; any quotes after the first two are left untouched.
func stripInnerQuotes(identifier string) string {
	result := identifier
	for removed := 0; removed < 2; removed++ {
		at := strings.IndexByte(result, '\'')
		if at < 0 {
			break
		}
		result = result[:at] + result[at+1:]
	}
	return result
}
- func (pGen PackageGenerator) enumNameFromColDef(col *sqlparser.ColumnDefinition) string {
- return fmt.Sprintf("%sType",
- golang.StructName(col.Name.String(), pGen.CombinedSettings))
- }
- // Structs marshels each query into a go struct for generation
- func (r *Result) Structs(settings config.CombinedSettings) []golang.Struct {
- var structs []golang.Struct
- for tableName, cols := range r.Schema.tables {
- structName := golang.StructName(tableName, settings)
- if !(settings.Go.EmitExactTableNames || settings.Kotlin.EmitExactTableNames) {
- structName = inflection.Singular(structName)
- }
- s := golang.Struct{
- Name: structName,
- Table: core.FQN{tableName, "", ""}, // TODO: Complete hack. Only need for equality check to see if struct can be reused between queries
- }
- for _, col := range cols {
- s.Fields = append(s.Fields, golang.Field{
- Name: golang.StructName(col.Name.String(), settings),
- Type: r.goTypeCol(Column{col, tableName}),
- Tags: map[string]string{"json:": col.Name.String()},
- Comment: "",
- })
- }
- structs = append(structs, s)
- }
- sort.Slice(structs, func(i, j int) bool { return structs[i].Name < structs[j].Name })
- return structs
- }
- // GoQueries generates parser-agnostic query information for code generation
- func (r *Result) GoQueries(settings config.CombinedSettings) []golang.Query {
- structs := r.Structs(settings)
- qs := make([]golang.Query, 0, len(r.Queries))
- for ix, query := range r.Queries {
- if query == nil {
- panic(fmt.Sprintf("query is nil on index: %v, len: %v", ix, len(r.Queries)))
- }
- if query.Name == "" {
- continue
- }
- if query.Cmd == "" {
- continue
- }
- gq := golang.Query{
- Cmd: query.Cmd,
- ConstantName: codegen.LowerTitle(query.Name),
- FieldName: codegen.LowerTitle(query.Name) + "Stmt",
- MethodName: query.Name,
- SourceName: query.Filename,
- SQL: query.SQL,
- // Comments: query.Comments,
- }
- if len(query.Params) == 1 {
- p := query.Params[0]
- gq.Arg = golang.QueryValue{
- Name: p.Name,
- Typ: p.Typ,
- }
- } else if len(query.Params) > 1 {
- structInfo := make([]structParams, len(query.Params))
- for i := range query.Params {
- structInfo[i] = structParams{
- originalName: query.Params[i].Name,
- goType: query.Params[i].Typ,
- }
- }
- gq.Arg = golang.QueryValue{
- Emit: true,
- Name: "arg",
- Struct: r.columnsToStruct(gq.MethodName+"Params", structInfo, settings),
- }
- }
- if len(query.Columns) == 1 {
- c := query.Columns[0]
- gq.Ret = golang.QueryValue{
- Name: columnName(c.ColumnDefinition, 0),
- Typ: r.goTypeCol(c),
- }
- } else if len(query.Columns) > 1 {
- var gs *golang.Struct
- var emit bool
- for _, s := range structs {
- if len(s.Fields) != len(query.Columns) {
- continue
- }
- same := true
- for i, f := range s.Fields {
- c := query.Columns[i]
- sameName := f.Name == golang.StructName(columnName(c.ColumnDefinition, i), settings)
- sameType := f.Type == r.goTypeCol(c)
- hackedFQN := core.FQN{c.Table, "", ""} // TODO: only check needed here is equality to see if struct can be reused, this type should be removed or properly used
- sameTable := s.Table.Catalog == hackedFQN.Catalog && s.Table.Schema == hackedFQN.Schema && s.Table.Rel == hackedFQN.Rel
- if !sameName || !sameType || !sameTable {
- same = false
- }
- }
- if same {
- gs = &s
- break
- }
- }
- if gs == nil {
- structInfo := make([]structParams, len(query.Columns))
- for i := range query.Columns {
- structInfo[i] = structParams{
- originalName: query.Columns[i].Name.String(),
- goType: r.goTypeCol(query.Columns[i]),
- }
- }
- gs = r.columnsToStruct(gq.MethodName+"Row", structInfo, settings)
- emit = true
- }
- gq.Ret = golang.QueryValue{
- Emit: emit,
- Name: "i",
- Struct: gs,
- }
- }
- qs = append(qs, gq)
- }
- sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName })
- return qs
- }
// structParams describes one field of a generated struct: the name as it
// appears in SQL (originalName) and the Go type to use for it (goType).
type structParams struct {
	originalName, goType string
}
- func (r *Result) columnsToStruct(name string, items []structParams, settings config.CombinedSettings) *golang.Struct {
- gs := golang.Struct{
- Name: name,
- }
- seen := map[string]int{}
- for _, item := range items {
- name := item.originalName
- typ := item.goType
- tagName := name
- fieldName := golang.StructName(name, settings)
- if v := seen[name]; v > 0 {
- tagName = fmt.Sprintf("%s_%d", tagName, v+1)
- fieldName = fmt.Sprintf("%s_%d", fieldName, v+1)
- }
- gs.Fields = append(gs.Fields, golang.Field{
- Name: fieldName,
- Type: typ,
- Tags: map[string]string{"json:": tagName},
- })
- seen[name]++
- }
- return &gs
- }
- func (pGen PackageGenerator) goTypeCol(col Column) string {
- mySQLType := strings.ToLower(col.ColumnDefinition.Type.Type)
- notNull := bool(col.Type.NotNull)
- colName := col.Name.String()
- for _, oride := range pGen.Overrides {
- shouldOverride := (oride.DBType != "" && oride.DBType == mySQLType && oride.Null != notNull) ||
- (oride.ColumnName != "" && oride.ColumnName == colName && oride.Table.Rel == col.Table)
- if shouldOverride {
- return oride.GoTypeName
- }
- }
- switch t := mySQLType; {
- case "varchar" == t, "text" == t, "char" == t,
- "tinytext" == t, "mediumtext" == t, "longtext" == t:
- if col.Type.NotNull {
- return "string"
- }
- return "sql.NullString"
- case "int" == t, "integer" == t, t == "smallint",
- "mediumint" == t, "bigint" == t, "year" == t:
- if col.Type.NotNull {
- return "int"
- }
- return "sql.NullInt64"
- case "blob" == t, "binary" == t, "varbinary" == t, "tinyblob" == t,
- "mediumblob" == t, "longblob" == t:
- return "[]byte"
- case "float" == t, strings.HasPrefix(strings.ToLower(t), "decimal"):
- if col.Type.NotNull {
- return "float64"
- }
- return "sql.NullFloat64"
- case "enum" == t:
- return pGen.enumNameFromColDef(col.ColumnDefinition)
- case "date" == t, "timestamp" == t, "datetime" == t, "time" == t:
- if col.Type.NotNull {
- return "time.Time"
- }
- return "sql.NullTime"
- case "boolean" == t, "bool" == t, "tinyint" == t:
- if col.Type.NotNull {
- return "bool"
- }
- return "sql.NullBool"
- default:
- fmt.Printf("unknown MySQL type: %s\n", t)
- return "interface{}"
- }
- }
- func columnName(c *sqlparser.ColumnDefinition, pos int) string {
- if !c.Name.IsEmpty() {
- return c.Name.String()
- }
- return fmt.Sprintf("column_%d", pos+1)
- }
// argName converts a snake_case name to lowerCamelCase. The first
// underscore-separated segment is lower-cased. Each later segment that is
// exactly "id" becomes "ID", and every other later segment goes through
// strings.Title.
func argName(name string) string {
	parts := strings.Split(name, "_")

	var b strings.Builder
	b.WriteString(strings.ToLower(parts[0]))
	for _, part := range parts[1:] {
		switch part {
		case "id":
			b.WriteString("ID")
		default:
			b.WriteString(strings.Title(part))
		}
	}
	return b.String()
}
|