generate.go 5.8 KB


  1. package cmd
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "os"
  9. "path/filepath"
  10. "strings"
  11. "github.com/kyleconroy/sqlc/internal/codegen/golang"
  12. "github.com/kyleconroy/sqlc/internal/codegen/kotlin"
  13. "github.com/kyleconroy/sqlc/internal/compiler"
  14. "github.com/kyleconroy/sqlc/internal/config"
  15. "github.com/kyleconroy/sqlc/internal/multierr"
  16. "github.com/kyleconroy/sqlc/internal/mysql"
  17. "github.com/kyleconroy/sqlc/internal/opts"
  18. )
  19. const errMessageNoVersion = `The configuration file must have a version number.
  20. Set the version to 1 at the top of sqlc.json:
  21. {
  22. "version": "1"
  23. ...
  24. }
  25. `
  26. const errMessageUnknownVersion = `The configuration file has an invalid version number.
  27. The only supported version is "1".
  28. `
  29. const errMessageNoPackages = `No packages are configured`
  30. func printFileErr(stderr io.Writer, dir string, fileErr *multierr.FileError) {
  31. filename := strings.TrimPrefix(fileErr.Filename, dir+"/")
  32. fmt.Fprintf(stderr, "%s:%d:%d: %s\n", filename, fileErr.Line, fileErr.Column, fileErr.Err)
  33. }
// outPair is one (package, target language) unit of work. Generate builds
// one outPair per language requested by a config.SQL entry, setting Gen to
// hold only that single language (Go or Kotlin).
//
// The explicit Gen field sits at a shallower depth than the Gen field of
// the embedded config.SQL, so pair.Gen refers to this per-language copy,
// not to the package's full generator configuration.
type outPair struct {
	Gen config.SQLGen
	config.SQL
}
  38. func Generate(e Env, dir string, stderr io.Writer) (map[string]string, error) {
  39. var yamlMissing, jsonMissing bool
  40. yamlPath := filepath.Join(dir, "sqlc.yaml")
  41. jsonPath := filepath.Join(dir, "sqlc.json")
  42. if _, err := os.Stat(yamlPath); os.IsNotExist(err) {
  43. yamlMissing = true
  44. }
  45. if _, err := os.Stat(jsonPath); os.IsNotExist(err) {
  46. jsonMissing = true
  47. }
  48. if yamlMissing && jsonMissing {
  49. fmt.Fprintln(stderr, "error parsing sqlc.json: file does not exist")
  50. return nil, errors.New("config file missing")
  51. }
  52. if !yamlMissing && !jsonMissing {
  53. fmt.Fprintln(stderr, "error parsing sqlc.json: both files present")
  54. return nil, errors.New("sqlc.json and sqlc.yaml present")
  55. }
  56. configPath := yamlPath
  57. if yamlMissing {
  58. configPath = jsonPath
  59. }
  60. blob, err := ioutil.ReadFile(configPath)
  61. if err != nil {
  62. fmt.Fprintln(stderr, "error parsing sqlc.json: file does not exist")
  63. return nil, err
  64. }
  65. conf, err := config.ParseConfig(bytes.NewReader(blob))
  66. if err != nil {
  67. switch err {
  68. case config.ErrMissingVersion:
  69. fmt.Fprintf(stderr, errMessageNoVersion)
  70. case config.ErrUnknownVersion:
  71. fmt.Fprintf(stderr, errMessageUnknownVersion)
  72. case config.ErrNoPackages:
  73. fmt.Fprintf(stderr, errMessageNoPackages)
  74. }
  75. fmt.Fprintf(stderr, "error parsing sqlc.json: %s\n", err)
  76. return nil, err
  77. }
  78. output := map[string]string{}
  79. errored := false
  80. var pairs []outPair
  81. for _, sql := range conf.SQL {
  82. if sql.Gen.Go != nil {
  83. pairs = append(pairs, outPair{
  84. SQL: sql,
  85. Gen: config.SQLGen{Go: sql.Gen.Go},
  86. })
  87. }
  88. if sql.Gen.Kotlin != nil {
  89. pairs = append(pairs, outPair{
  90. SQL: sql,
  91. Gen: config.SQLGen{Kotlin: sql.Gen.Kotlin},
  92. })
  93. }
  94. }
  95. for _, sql := range pairs {
  96. combo := config.Combine(conf, sql.SQL)
  97. // TODO: This feels like a hack that will bite us later
  98. joined := make([]string, 0, len(sql.Schema))
  99. for _, s := range sql.Schema {
  100. joined = append(joined, filepath.Join(dir, s))
  101. }
  102. sql.Schema = joined
  103. joined = make([]string, 0, len(sql.Queries))
  104. for _, q := range sql.Queries {
  105. joined = append(joined, filepath.Join(dir, q))
  106. }
  107. sql.Queries = joined
  108. var name string
  109. parseOpts := opts.Parser{}
  110. if sql.Gen.Go != nil {
  111. name = combo.Go.Package
  112. } else if sql.Gen.Kotlin != nil {
  113. parseOpts.UsePositionalParameters = true
  114. name = combo.Kotlin.Package
  115. }
  116. var files map[string]string
  117. var out string
  118. // TODO: Note about how this will be going away
  119. if sql.Engine == config.EngineMySQL {
  120. result, errored := parseMySQL(e, name, dir, sql.SQL, combo, parseOpts, stderr)
  121. if errored {
  122. break
  123. }
  124. out = combo.Go.Out
  125. files, err = golang.DeprecatedGenerate(result, combo)
  126. } else {
  127. result, errored := parse(e, name, dir, sql.SQL, combo, parseOpts, stderr)
  128. if errored {
  129. break
  130. }
  131. switch {
  132. case sql.Gen.Go != nil:
  133. out = combo.Go.Out
  134. files, err = golang.Generate(result, combo)
  135. case sql.Gen.Kotlin != nil:
  136. out = combo.Kotlin.Out
  137. files, err = kotlin.Generate(result, combo)
  138. default:
  139. panic("missing language backend")
  140. }
  141. }
  142. if err != nil {
  143. fmt.Fprintf(stderr, "# package %s\n", name)
  144. fmt.Fprintf(stderr, "error generating code: %s\n", err)
  145. errored = true
  146. continue
  147. }
  148. for n, source := range files {
  149. filename := filepath.Join(dir, out, n)
  150. output[filename] = source
  151. }
  152. }
  153. if errored {
  154. return nil, fmt.Errorf("errored")
  155. }
  156. return output, nil
  157. }
  158. // Experimental MySQL support
  159. func parseMySQL(e Env, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (golang.Generateable, bool) {
  160. q, err := mysql.GeneratePkg(name, sql.Schema, sql.Queries, combo)
  161. if err != nil {
  162. fmt.Fprintf(stderr, "# package %s\n", name)
  163. if parserErr, ok := err.(*multierr.Error); ok {
  164. for _, fileErr := range parserErr.Errs() {
  165. printFileErr(stderr, dir, fileErr)
  166. }
  167. } else {
  168. fmt.Fprintf(stderr, "error parsing schema: %s\n", err)
  169. }
  170. return nil, true
  171. }
  172. return q, false
  173. }
  174. func parse(e Env, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) {
  175. eng := compiler.NewEngine(sql, combo)
  176. if err := eng.ParseCatalog(sql.Schema); err != nil {
  177. fmt.Fprintf(stderr, "# package %s\n", name)
  178. if parserErr, ok := err.(*multierr.Error); ok {
  179. for _, fileErr := range parserErr.Errs() {
  180. printFileErr(stderr, dir, fileErr)
  181. }
  182. } else {
  183. fmt.Fprintf(stderr, "error parsing schema: %s\n", err)
  184. }
  185. return nil, true
  186. }
  187. if err := eng.ParseQueries(sql.Queries, parserOpts); err != nil {
  188. fmt.Fprintf(stderr, "# package %s\n", name)
  189. if parserErr, ok := err.(*multierr.Error); ok {
  190. for _, fileErr := range parserErr.Errs() {
  191. printFileErr(stderr, dir, fileErr)
  192. }
  193. } else {
  194. fmt.Fprintf(stderr, "error parsing queries: %s\n", err)
  195. }
  196. return nil, true
  197. }
  198. return eng.Result(), false
  199. }