parse.go 13 KB


  1. package mysql
  2. import (
  3. "fmt"
  4. "io"
  5. "io/ioutil"
  6. "path/filepath"
  7. "github.com/davecgh/go-spew/spew"
  8. "vitess.io/vitess/go/vt/sqlparser"
  9. "github.com/kyleconroy/sqlc/internal/config"
  10. "github.com/kyleconroy/sqlc/internal/metadata"
  11. "github.com/kyleconroy/sqlc/internal/migrations"
  12. "github.com/kyleconroy/sqlc/internal/multierr"
  13. "github.com/kyleconroy/sqlc/internal/sql/sqlpath"
  14. )
  15. // Query holds the data for walking and validating mysql querys
  16. type Query struct {
  17. SQL string // the string representation of the parsed query
  18. Columns []Column
  19. Params []*Param // "?" params in the query string
  20. Name string // the Go function name
  21. Cmd string // TODO: Pick a better name. One of: one, many, exec, execrows
  22. DefaultTableName string // for columns that are not qualified
  23. Filename string
  24. }
  25. type Column struct {
  26. *sqlparser.ColumnDefinition
  27. Table string
  28. }
  29. func parsePath(sqlPath []string, generator PackageGenerator) (*Result, error) {
  30. files, err := sqlpath.Glob(sqlPath)
  31. if err != nil {
  32. return nil, err
  33. }
  34. parseErrors := multierr.New()
  35. parsedQueries := []*Query{}
  36. for _, filename := range files {
  37. blob, err := ioutil.ReadFile(filename)
  38. if err != nil {
  39. parseErrors.Add(filename, "", 0, err)
  40. }
  41. contents := migrations.RemoveRollbackStatements(string(blob))
  42. if err != nil {
  43. parseErrors.Add(filename, "", 0, err)
  44. continue
  45. }
  46. t := sqlparser.NewStringTokenizer(contents)
  47. var start int
  48. for {
  49. q, err := sqlparser.ParseNextStrictDDL(t)
  50. if err == io.EOF {
  51. break
  52. } else if err != nil {
  53. if posErr, ok := err.(sqlparser.PositionedErr); ok {
  54. message := fmt.Errorf(posErr.Err)
  55. if posErr.Near != nil {
  56. message = fmt.Errorf("%s at or near \"%s\"", posErr.Err, posErr.Near)
  57. }
  58. parseErrors.Add(filename, contents, posErr.Pos, message)
  59. } else {
  60. parseErrors.Add(filename, contents, start, err)
  61. }
  62. continue
  63. }
  64. query := contents[start : t.Position-1]
  65. result, err := generator.parseQueryString(q, query)
  66. if err != nil {
  67. parseErrors.Add(filename, contents, start, err)
  68. start = t.Position
  69. continue
  70. }
  71. start = t.Position
  72. if result == nil {
  73. continue
  74. }
  75. result.Filename = filepath.Base(filename)
  76. parsedQueries = append(parsedQueries, result)
  77. }
  78. }
  79. if len(parseErrors.Errs()) > 0 {
  80. return nil, parseErrors
  81. }
  82. return &Result{
  83. Queries: parsedQueries,
  84. PackageGenerator: generator,
  85. }, nil
  86. }
  87. func (pGen PackageGenerator) parseQueryString(tree sqlparser.Statement, query string) (*Query, error) {
  88. var parsedQuery *Query
  89. switch tree := tree.(type) {
  90. case *sqlparser.Select:
  91. selectQuery, err := pGen.parseSelect(tree, query)
  92. if err != nil {
  93. return nil, err
  94. }
  95. parsedQuery = selectQuery
  96. case *sqlparser.Insert:
  97. insert, err := pGen.parseInsert(tree, query)
  98. if err != nil {
  99. return nil, err
  100. }
  101. parsedQuery = insert
  102. case *sqlparser.Update:
  103. update, err := pGen.parseUpdate(tree, query)
  104. if err != nil {
  105. return nil, err
  106. }
  107. parsedQuery = update
  108. case *sqlparser.Delete:
  109. delete, err := pGen.parseDelete(tree, query)
  110. if err != nil {
  111. return nil, err
  112. }
  113. parsedQuery = delete
  114. case *sqlparser.DDL:
  115. pGen.Schema.Add(tree)
  116. return nil, nil
  117. default:
  118. // panic("Unsupported SQL statement type")
  119. return nil, nil
  120. }
  121. paramsReplacedQuery, err := replaceParamStrs(sqlparser.String(tree), parsedQuery.Params)
  122. if err != nil {
  123. return nil, fmt.Errorf("failed to replace param variables in query string: %w", err)
  124. }
  125. parsedQuery.SQL = paramsReplacedQuery
  126. return parsedQuery, nil
  127. }
  128. func (q *Query) parseNameAndCmd() error {
  129. if q == nil {
  130. return fmt.Errorf("cannot parse name and cmd from null query")
  131. }
  132. _, comments := sqlparser.SplitMarginComments(q.SQL)
  133. name, cmd, err := metadata.Parse(comments.Leading, metadata.CommentSyntaxStar)
  134. if err != nil {
  135. return err
  136. } else if name == "" || cmd == "" {
  137. return fmt.Errorf("failed to parse query leading comment")
  138. }
  139. q.Name = name
  140. q.Cmd = cmd
  141. return nil
  142. }
  143. func (pGen PackageGenerator) parseSelect(tree *sqlparser.Select, query string) (*Query, error) {
  144. tableAliasMap, defaultTableName, err := parseFrom(tree.From, false)
  145. if err != nil {
  146. return nil, fmt.Errorf("failed to parse table name alias's: %w", err)
  147. }
  148. // handle * expressions first by expanding all columns of the default table
  149. _, ok := tree.SelectExprs[0].(*sqlparser.StarExpr)
  150. if ok {
  151. colNames := []sqlparser.SelectExpr{}
  152. colDfns := pGen.Schema.tables[defaultTableName]
  153. for _, col := range colDfns {
  154. colNames = append(colNames, &sqlparser.AliasedExpr{
  155. Expr: &sqlparser.ColName{
  156. Name: col.Name,
  157. }},
  158. )
  159. }
  160. tree.SelectExprs = colNames
  161. }
  162. parsedQuery := Query{
  163. SQL: query,
  164. DefaultTableName: defaultTableName,
  165. }
  166. cols, err := pGen.parseSelectAliasExpr(tree.SelectExprs, tableAliasMap, defaultTableName)
  167. if err != nil {
  168. return nil, err
  169. }
  170. parsedQuery.Columns = cols
  171. whereParams, err := pGen.paramsInWhereExpr(tree.Where, tableAliasMap, defaultTableName)
  172. if err != nil {
  173. return nil, err
  174. }
  175. limitParams, err := pGen.paramsInLimitExpr(tree.Limit, tableAliasMap)
  176. if err != nil {
  177. return nil, err
  178. }
  179. parsedQuery.Params = append(whereParams, limitParams...)
  180. err = parsedQuery.parseNameAndCmd()
  181. if err != nil {
  182. return nil, err
  183. }
  184. return &parsedQuery, nil
  185. }
  // Types describing the table references found in the FROM clause of a query.
type (
	// FromTable describes one table reference in the FROM clause of a query.
	FromTable struct {
		// TrueName is the table's name as declared in the schema, regardless
		// of any alias the query gives it.
		TrueName string
		// IsLeftJoined is set when the table is joined such that its columns
		// may be NULL (e.g. the right-hand side of a LEFT JOIN).
		IsLeftJoined bool
	}

	// FromTables maps each table alias used in a FROM clause (or the bare
	// table name when no alias is given) to the table it refers to.
	FromTables map[string]FromTable
)
  194. func parseFrom(from sqlparser.TableExprs, isLeftJoined bool) (FromTables, string, error) {
  195. tables := make(map[string]FromTable)
  196. var defaultTableName string
  197. for _, expr := range from {
  198. switch v := expr.(type) {
  199. case *sqlparser.AliasedTableExpr:
  200. name, ok := v.Expr.(sqlparser.TableName)
  201. if !ok {
  202. return nil, "", fmt.Errorf("failed to parse AliasedTableExpr name: %v", spew.Sdump(v))
  203. }
  204. t := FromTable{
  205. TrueName: name.Name.String(),
  206. IsLeftJoined: isLeftJoined,
  207. }
  208. if v.As.String() != "" {
  209. tables[v.As.String()] = t
  210. } else {
  211. tables[name.Name.String()] = t
  212. }
  213. defaultTableName = name.Name.String()
  214. case *sqlparser.JoinTableExpr:
  215. isLeftJoin := v.Join == "left join"
  216. left, leftMostTableName, err := parseFrom([]sqlparser.TableExpr{v.LeftExpr}, false)
  217. if err != nil {
  218. return nil, "", err
  219. }
  220. right, _, err := parseFrom([]sqlparser.TableExpr{v.RightExpr}, isLeftJoin)
  221. if err != nil {
  222. return nil, "", err
  223. }
  224. // merge the left and right maps
  225. for k, v := range left {
  226. right[k] = v
  227. }
  228. return right, leftMostTableName, nil
  229. default:
  230. return nil, "", fmt.Errorf("failed to parse table expr: %v", spew.Sdump(v))
  231. }
  232. }
  233. return tables, defaultTableName, nil
  234. }
  235. func (pGen PackageGenerator) parseUpdate(node *sqlparser.Update, query string) (*Query, error) {
  236. tableAliasMap, defaultTable, err := parseFrom(node.TableExprs, false)
  237. if err != nil {
  238. return nil, fmt.Errorf("failed to parse table name alias's: %w", err)
  239. }
  240. params := []*Param{}
  241. for _, updateExpr := range node.Exprs {
  242. newValue, isValue := updateExpr.Expr.(*sqlparser.SQLVal)
  243. if !isValue {
  244. continue
  245. } else if isParam := newValue.Type == sqlparser.ValArg; !isParam {
  246. continue
  247. }
  248. col, err := pGen.getColType(updateExpr.Name, tableAliasMap, defaultTable)
  249. if err != nil {
  250. return nil, fmt.Errorf("failed to determine type of a parameter's column: %w", err)
  251. }
  252. originalParamName := string(newValue.Val)
  253. param := Param{
  254. OriginalName: originalParamName,
  255. Name: paramName(col.Name, originalParamName),
  256. Typ: pGen.goTypeCol(*col),
  257. }
  258. params = append(params, &param)
  259. }
  260. whereParams, err := pGen.paramsInWhereExpr(node.Where, tableAliasMap, defaultTable)
  261. if err != nil {
  262. return nil, fmt.Errorf("failed to parse params from WHERE expression: %w", err)
  263. }
  264. parsedQuery := Query{
  265. SQL: query,
  266. Columns: nil,
  267. Params: append(params, whereParams...),
  268. DefaultTableName: defaultTable,
  269. }
  270. err = parsedQuery.parseNameAndCmd()
  271. if err != nil {
  272. return nil, err
  273. }
  274. return &parsedQuery, nil
  275. }
  276. func (pGen PackageGenerator) parseInsert(node *sqlparser.Insert, query string) (*Query, error) {
  277. params := []*Param{}
  278. cols := node.Columns
  279. tableName := node.Table.Name.String()
  280. switch rows := node.Rows.(type) {
  281. case *sqlparser.Select:
  282. selectQuery, err := pGen.parseSelect(rows, query)
  283. if err != nil {
  284. return nil, err
  285. }
  286. params = append(params, selectQuery.Params...)
  287. case sqlparser.Values:
  288. for _, row := range rows {
  289. for colIx, item := range row {
  290. switch v := item.(type) {
  291. case *sqlparser.SQLVal:
  292. if v.Type == sqlparser.ValArg {
  293. colName := cols[colIx].String()
  294. col, err := pGen.schemaLookup(tableName, colName)
  295. varName := string(v.Val)
  296. param := &Param{OriginalName: varName}
  297. if err == nil {
  298. param.Name = paramName(col.Name, varName)
  299. param.Typ = pGen.goTypeCol(*col)
  300. } else {
  301. param.Name = "Unknown"
  302. param.Typ = "interface{}"
  303. }
  304. params = append(params, param)
  305. }
  306. case *sqlparser.FuncExpr:
  307. name, raw, err := matchFuncExpr(v)
  308. if err != nil {
  309. return nil, err
  310. }
  311. if name == "" || raw == "" {
  312. continue
  313. }
  314. colName := cols[colIx].String()
  315. col, err := pGen.schemaLookup(tableName, colName)
  316. param := &Param{
  317. OriginalName: raw,
  318. }
  319. if err == nil {
  320. param.Name = name
  321. param.Typ = pGen.goTypeCol(*col)
  322. } else {
  323. param.Name = "Unknown"
  324. param.Typ = "interface{}"
  325. }
  326. params = append(params, param)
  327. default:
  328. return nil, fmt.Errorf("failed to parse insert query value")
  329. }
  330. }
  331. }
  332. default:
  333. return nil, fmt.Errorf("Unknown insert row type of %T", node.Rows)
  334. }
  335. parsedQuery := &Query{
  336. SQL: query,
  337. Params: params,
  338. Columns: nil,
  339. DefaultTableName: tableName,
  340. }
  341. err := parsedQuery.parseNameAndCmd()
  342. if err != nil {
  343. return nil, err
  344. }
  345. return parsedQuery, nil
  346. }
  347. func (pGen PackageGenerator) parseDelete(node *sqlparser.Delete, query string) (*Query, error) {
  348. tableAliasMap, defaultTableName, err := parseFrom(node.TableExprs, false)
  349. if err != nil {
  350. return nil, fmt.Errorf("failed to parse table name alias's: %w", err)
  351. }
  352. whereParams, err := pGen.paramsInWhereExpr(node.Where, tableAliasMap, defaultTableName)
  353. if err != nil {
  354. return nil, err
  355. }
  356. limitParams, err := pGen.paramsInLimitExpr(node.Limit, tableAliasMap)
  357. if err != nil {
  358. return nil, err
  359. }
  360. parsedQuery := &Query{
  361. SQL: query,
  362. Params: append(whereParams, limitParams...),
  363. Columns: nil,
  364. DefaultTableName: defaultTableName,
  365. }
  366. err = parsedQuery.parseNameAndCmd()
  367. if err != nil {
  368. return nil, err
  369. }
  370. return parsedQuery, nil
  371. }
  372. func (pGen PackageGenerator) parseSelectAliasExpr(exprs sqlparser.SelectExprs, tableAliasMap FromTables, defaultTable string) ([]Column, error) {
  373. cols := []Column{}
  374. for _, col := range exprs {
  375. switch expr := col.(type) {
  376. case *sqlparser.AliasedExpr:
  377. hasAlias := !expr.As.IsEmpty()
  378. switch v := expr.Expr.(type) {
  379. case *sqlparser.ColName:
  380. res, err := pGen.getColType(v, tableAliasMap, defaultTable)
  381. if err != nil {
  382. return nil, err
  383. }
  384. if hasAlias {
  385. res.Name = expr.As // applys the alias
  386. }
  387. cols = append(cols, *res)
  388. case *sqlparser.GroupConcatExpr:
  389. cols = append(cols, Column{
  390. ColumnDefinition: &sqlparser.ColumnDefinition{
  391. Name: sqlparser.NewColIdent(expr.As.String()),
  392. Type: sqlparser.ColumnType{
  393. Type: "varchar",
  394. NotNull: true,
  395. },
  396. },
  397. Table: "", // group concat expressions don't originate from a table schema
  398. },
  399. )
  400. case *sqlparser.FuncExpr:
  401. funcName := v.Name.Lowered()
  402. funcType := functionReturnType(funcName)
  403. var returnVal sqlparser.ColIdent
  404. if hasAlias {
  405. returnVal = expr.As
  406. } else {
  407. returnVal = sqlparser.NewColIdent(funcName)
  408. }
  409. colDfn := &sqlparser.ColumnDefinition{
  410. Name: returnVal,
  411. Type: sqlparser.ColumnType{
  412. Type: funcType,
  413. NotNull: true,
  414. },
  415. }
  416. cols = append(cols, Column{colDfn, ""}) // func returns types don't originate from a table schema
  417. }
  418. default:
  419. return nil, fmt.Errorf("Failed to handle select expr of type : %T", expr)
  420. }
  421. }
  422. return cols, nil
  423. }
  424. // GeneratePkg is the main entry to mysql generator package
  425. func GeneratePkg(pkgName string, schemaPath, querysPath []string, settings config.CombinedSettings) (*Result, error) {
  426. s := NewSchema()
  427. generator := PackageGenerator{
  428. Schema: s,
  429. CombinedSettings: settings,
  430. packageName: pkgName,
  431. }
  432. _, err := parsePath(schemaPath, generator)
  433. if err != nil {
  434. return nil, err
  435. }
  436. result, err := parsePath(querysPath, generator)
  437. if err != nil {
  438. return nil, err
  439. }
  440. return result, nil
  441. }