parse.go 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. package dolphin
  2. import (
  3. "errors"
  4. "io"
  5. "io/ioutil"
  6. "regexp"
  7. "strconv"
  8. "strings"
  9. "github.com/pingcap/parser"
  10. _ "github.com/pingcap/tidb/types/parser_driver"
  11. "github.com/kyleconroy/sqlc/internal/metadata"
  12. "github.com/kyleconroy/sqlc/internal/sql/ast"
  13. "github.com/kyleconroy/sqlc/internal/sql/sqlerr"
  14. )
  15. func NewParser() *Parser {
  16. return &Parser{parser.New()}
  17. }
  18. type Parser struct {
  19. pingcap *parser.Parser
  20. }
  21. var lineColumn = regexp.MustCompile(`^line (\d+) column (\d+) (.*)`)
  22. func normalizeErr(err error) error {
  23. if err == nil {
  24. return err
  25. }
  26. parts := strings.Split(err.Error(), "\n")
  27. msg := strings.TrimSpace(parts[0] + "\"")
  28. out := lineColumn.FindStringSubmatch(msg)
  29. if len(out) == 4 {
  30. line, lineErr := strconv.Atoi(out[1])
  31. col, colErr := strconv.Atoi(out[2])
  32. if lineErr != nil || colErr != nil {
  33. return errors.New(msg)
  34. }
  35. return &sqlerr.Error{
  36. Message: "syntax error",
  37. Err: errors.New(out[3]),
  38. Line: line,
  39. Column: col,
  40. }
  41. }
  42. return errors.New(msg)
  43. }
  44. func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) {
  45. blob, err := ioutil.ReadAll(r)
  46. if err != nil {
  47. return nil, err
  48. }
  49. stmtNodes, _, err := p.pingcap.Parse(string(blob), "", "")
  50. if err != nil {
  51. return nil, normalizeErr(err)
  52. }
  53. var stmts []ast.Statement
  54. for i := range stmtNodes {
  55. out := convert(stmtNodes[i])
  56. if _, ok := out.(*ast.TODO); ok {
  57. continue
  58. }
  59. // TODO: Attach the text directly to the ast.Statement node
  60. text := stmtNodes[i].Text()
  61. loc := strings.Index(string(blob), text)
  62. stmts = append(stmts, ast.Statement{
  63. Raw: &ast.RawStmt{
  64. Stmt: out,
  65. StmtLocation: loc,
  66. StmtLen: len(text) - 1, // Subtract one to remove semicolon
  67. },
  68. })
  69. }
  70. return stmts, nil
  71. }
  72. func (p *Parser) CommentSyntax() metadata.CommentSyntax {
  73. return metadata.CommentSyntaxStar
  74. }