gen.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. package mysql
  2. import (
  3. "fmt"
  4. "sort"
  5. "strings"
  6. "github.com/jinzhu/inflection"
  7. "vitess.io/vitess/go/vt/sqlparser"
  8. "github.com/kyleconroy/sqlc/internal/codegen"
  9. "github.com/kyleconroy/sqlc/internal/codegen/golang"
  10. "github.com/kyleconroy/sqlc/internal/config"
  11. "github.com/kyleconroy/sqlc/internal/core"
  12. )
// PackageGenerator bundles a parsed schema with the combined sqlc settings
// used while generating code for one package.
type PackageGenerator struct {
	// Schema is embedded; its tables are what Enums and Structs iterate over.
	*Schema
	// CombinedSettings is embedded and handed to golang.StructName when enum
	// type names are derived.
	config.CombinedSettings
	// packageName is presumably the name of the generated package; it is not
	// read by any code visible in this file — TODO confirm against callers.
	packageName string
}
// Result couples a PackageGenerator with the queries that code is generated
// for.
type Result struct {
	PackageGenerator
	// Queries is consumed by GoQueries, which panics on a nil entry.
	Queries []*Query
}
  22. // Enums generates parser-agnostic GoEnum types
  23. func (r *Result) Enums(settings config.CombinedSettings) []golang.Enum {
  24. var enums []golang.Enum
  25. for _, table := range r.Schema.tables {
  26. for _, col := range table {
  27. if strings.ToLower(col.Type.Type) == "enum" {
  28. constants := []golang.Constant{}
  29. enumName := r.enumNameFromColDef(col)
  30. for _, c := range col.Type.EnumValues {
  31. stripped := stripInnerQuotes(c)
  32. constants = append(constants, golang.Constant{
  33. // TODO: maybe add the struct name call to capitalize the name here
  34. Name: stripped,
  35. Value: stripped,
  36. Type: enumName,
  37. })
  38. }
  39. goEnum := golang.Enum{
  40. Name: enumName,
  41. Comment: "",
  42. Constants: constants,
  43. }
  44. enums = append(enums, goEnum)
  45. }
  46. }
  47. }
  48. return enums
  49. }
  50. func stripInnerQuotes(identifier string) string {
  51. return strings.Replace(identifier, "'", "", 2)
  52. }
  53. func (pGen PackageGenerator) enumNameFromColDef(col *sqlparser.ColumnDefinition) string {
  54. return fmt.Sprintf("%sType",
  55. golang.StructName(col.Name.String(), pGen.CombinedSettings))
  56. }
  57. // Structs marshels each query into a go struct for generation
  58. func (r *Result) Structs(settings config.CombinedSettings) []golang.Struct {
  59. var structs []golang.Struct
  60. for tableName, cols := range r.Schema.tables {
  61. structName := golang.StructName(tableName, settings)
  62. if !(settings.Go.EmitExactTableNames || settings.Kotlin.EmitExactTableNames) {
  63. structName = inflection.Singular(structName)
  64. }
  65. s := golang.Struct{
  66. Name: structName,
  67. Table: core.FQN{tableName, "", ""}, // TODO: Complete hack. Only need for equality check to see if struct can be reused between queries
  68. }
  69. for _, col := range cols {
  70. s.Fields = append(s.Fields, golang.Field{
  71. Name: golang.StructName(col.Name.String(), settings),
  72. Type: r.goTypeCol(Column{col, tableName}),
  73. Tags: map[string]string{"json:": col.Name.String()},
  74. Comment: "",
  75. })
  76. }
  77. structs = append(structs, s)
  78. }
  79. sort.Slice(structs, func(i, j int) bool { return structs[i].Name < structs[j].Name })
  80. return structs
  81. }
  82. // GoQueries generates parser-agnostic query information for code generation
  83. func (r *Result) GoQueries(settings config.CombinedSettings) []golang.Query {
  84. structs := r.Structs(settings)
  85. qs := make([]golang.Query, 0, len(r.Queries))
  86. for ix, query := range r.Queries {
  87. if query == nil {
  88. panic(fmt.Sprintf("query is nil on index: %v, len: %v", ix, len(r.Queries)))
  89. }
  90. if query.Name == "" {
  91. continue
  92. }
  93. if query.Cmd == "" {
  94. continue
  95. }
  96. gq := golang.Query{
  97. Cmd: query.Cmd,
  98. ConstantName: codegen.LowerTitle(query.Name),
  99. FieldName: codegen.LowerTitle(query.Name) + "Stmt",
  100. MethodName: query.Name,
  101. SourceName: query.Filename,
  102. SQL: query.SQL,
  103. // Comments: query.Comments,
  104. }
  105. if len(query.Params) == 1 {
  106. p := query.Params[0]
  107. gq.Arg = golang.QueryValue{
  108. Name: p.Name,
  109. Typ: p.Typ,
  110. }
  111. } else if len(query.Params) > 1 {
  112. structInfo := make([]structParams, len(query.Params))
  113. for i := range query.Params {
  114. structInfo[i] = structParams{
  115. originalName: query.Params[i].Name,
  116. goType: query.Params[i].Typ,
  117. }
  118. }
  119. gq.Arg = golang.QueryValue{
  120. Emit: true,
  121. Name: "arg",
  122. Struct: r.columnsToStruct(gq.MethodName+"Params", structInfo, settings),
  123. }
  124. }
  125. if len(query.Columns) == 1 {
  126. c := query.Columns[0]
  127. gq.Ret = golang.QueryValue{
  128. Name: columnName(c.ColumnDefinition, 0),
  129. Typ: r.goTypeCol(c),
  130. }
  131. } else if len(query.Columns) > 1 {
  132. var gs *golang.Struct
  133. var emit bool
  134. for _, s := range structs {
  135. if len(s.Fields) != len(query.Columns) {
  136. continue
  137. }
  138. same := true
  139. for i, f := range s.Fields {
  140. c := query.Columns[i]
  141. sameName := f.Name == golang.StructName(columnName(c.ColumnDefinition, i), settings)
  142. sameType := f.Type == r.goTypeCol(c)
  143. hackedFQN := core.FQN{c.Table, "", ""} // TODO: only check needed here is equality to see if struct can be reused, this type should be removed or properly used
  144. sameTable := s.Table.Catalog == hackedFQN.Catalog && s.Table.Schema == hackedFQN.Schema && s.Table.Rel == hackedFQN.Rel
  145. if !sameName || !sameType || !sameTable {
  146. same = false
  147. }
  148. }
  149. if same {
  150. gs = &s
  151. break
  152. }
  153. }
  154. if gs == nil {
  155. structInfo := make([]structParams, len(query.Columns))
  156. for i := range query.Columns {
  157. structInfo[i] = structParams{
  158. originalName: query.Columns[i].Name.String(),
  159. goType: r.goTypeCol(query.Columns[i]),
  160. }
  161. }
  162. gs = r.columnsToStruct(gq.MethodName+"Row", structInfo, settings)
  163. emit = true
  164. }
  165. gq.Ret = golang.QueryValue{
  166. Emit: emit,
  167. Name: "i",
  168. Struct: gs,
  169. }
  170. }
  171. qs = append(qs, gq)
  172. }
  173. sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName })
  174. return qs
  175. }
  176. type structParams struct {
  177. originalName string
  178. goType string
  179. }
  180. func (r *Result) columnsToStruct(name string, items []structParams, settings config.CombinedSettings) *golang.Struct {
  181. gs := golang.Struct{
  182. Name: name,
  183. }
  184. seen := map[string]int{}
  185. for _, item := range items {
  186. name := item.originalName
  187. typ := item.goType
  188. tagName := name
  189. fieldName := golang.StructName(name, settings)
  190. if v := seen[name]; v > 0 {
  191. tagName = fmt.Sprintf("%s_%d", tagName, v+1)
  192. fieldName = fmt.Sprintf("%s_%d", fieldName, v+1)
  193. }
  194. gs.Fields = append(gs.Fields, golang.Field{
  195. Name: fieldName,
  196. Type: typ,
  197. Tags: map[string]string{"json:": tagName},
  198. })
  199. seen[name]++
  200. }
  201. return &gs
  202. }
  203. func (pGen PackageGenerator) goTypeCol(col Column) string {
  204. mySQLType := strings.ToLower(col.ColumnDefinition.Type.Type)
  205. notNull := bool(col.Type.NotNull)
  206. colName := col.Name.String()
  207. for _, oride := range pGen.Overrides {
  208. shouldOverride := (oride.DBType != "" && oride.DBType == mySQLType && oride.Null != notNull) ||
  209. (oride.ColumnName != "" && oride.ColumnName == colName && oride.Table.Rel == col.Table)
  210. if shouldOverride {
  211. return oride.GoTypeName
  212. }
  213. }
  214. switch t := mySQLType; {
  215. case "varchar" == t, "text" == t, "char" == t,
  216. "tinytext" == t, "mediumtext" == t, "longtext" == t:
  217. if col.Type.NotNull {
  218. return "string"
  219. }
  220. return "sql.NullString"
  221. case "int" == t, "integer" == t, t == "smallint",
  222. "mediumint" == t, "bigint" == t, "year" == t:
  223. if col.Type.NotNull {
  224. return "int"
  225. }
  226. return "sql.NullInt64"
  227. case "blob" == t, "binary" == t, "varbinary" == t, "tinyblob" == t,
  228. "mediumblob" == t, "longblob" == t:
  229. return "[]byte"
  230. case "float" == t, strings.HasPrefix(strings.ToLower(t), "decimal"):
  231. if col.Type.NotNull {
  232. return "float64"
  233. }
  234. return "sql.NullFloat64"
  235. case "enum" == t:
  236. return pGen.enumNameFromColDef(col.ColumnDefinition)
  237. case "date" == t, "timestamp" == t, "datetime" == t, "time" == t:
  238. if col.Type.NotNull {
  239. return "time.Time"
  240. }
  241. return "sql.NullTime"
  242. case "boolean" == t, "bool" == t, "tinyint" == t:
  243. if col.Type.NotNull {
  244. return "bool"
  245. }
  246. return "sql.NullBool"
  247. default:
  248. fmt.Printf("unknown MySQL type: %s\n", t)
  249. return "interface{}"
  250. }
  251. }
  252. func columnName(c *sqlparser.ColumnDefinition, pos int) string {
  253. if !c.Name.IsEmpty() {
  254. return c.Name.String()
  255. }
  256. return fmt.Sprintf("column_%d", pos+1)
  257. }
  258. func argName(name string) string {
  259. out := ""
  260. for i, p := range strings.Split(name, "_") {
  261. if i == 0 {
  262. out += strings.ToLower(p)
  263. } else if p == "id" {
  264. out += "ID"
  265. } else {
  266. out += strings.Title(p)
  267. }
  268. }
  269. return out
  270. }