1
0

parse.go 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. package sqlite
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "github.com/antlr/antlr4/runtime/Go/antlr/v4"
  7. "github.com/sqlc-dev/sqlc/internal/engine/sqlite/parser"
  8. "github.com/sqlc-dev/sqlc/internal/source"
  9. "github.com/sqlc-dev/sqlc/internal/sql/ast"
  10. )
  11. type errorListener struct {
  12. *antlr.DefaultErrorListener
  13. err string
  14. }
  15. func (el *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) {
  16. el.err = msg
  17. }
  18. // func (el *errorListener) ReportAmbiguity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, exact bool, ambigAlts *antlr.BitSet, configs antlr.ATNConfigSet) {
  19. // }
  20. //
  21. // func (el *errorListener) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs antlr.ATNConfigSet) {
  22. // }
  23. //
  24. // func (el *errorListener) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs antlr.ATNConfigSet) {
  25. // }
  26. func NewParser() *Parser {
  27. return &Parser{}
  28. }
  29. type Parser struct {
  30. }
  31. func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) {
  32. blob, err := io.ReadAll(r)
  33. if err != nil {
  34. return nil, err
  35. }
  36. input := antlr.NewInputStream(string(blob))
  37. lexer := parser.NewSQLiteLexer(input)
  38. stream := antlr.NewCommonTokenStream(lexer, 0)
  39. pp := parser.NewSQLiteParser(stream)
  40. el := &errorListener{}
  41. pp.AddErrorListener(el)
  42. // pp.BuildParseTrees = true
  43. tree := pp.Parse()
  44. if el.err != "" {
  45. return nil, errors.New(el.err)
  46. }
  47. pctx, ok := tree.(*parser.ParseContext)
  48. if !ok {
  49. return nil, fmt.Errorf("expected ParserContext; got %T\n", tree)
  50. }
  51. var stmts []ast.Statement
  52. for _, istmt := range pctx.AllSql_stmt_list() {
  53. list, ok := istmt.(*parser.Sql_stmt_listContext)
  54. if !ok {
  55. return nil, fmt.Errorf("expected Sql_stmt_listContext; got %T\n", istmt)
  56. }
  57. loc := 0
  58. for _, stmt := range list.AllSql_stmt() {
  59. converter := &cc{}
  60. out := converter.convert(stmt)
  61. if _, ok := out.(*ast.TODO); ok {
  62. continue
  63. }
  64. len := (stmt.GetStop().GetStop() + 1) - loc
  65. stmts = append(stmts, ast.Statement{
  66. Raw: &ast.RawStmt{
  67. Stmt: out,
  68. StmtLocation: loc,
  69. StmtLen: len,
  70. },
  71. })
  72. loc = stmt.GetStop().GetStop() + 2
  73. }
  74. }
  75. return stmts, nil
  76. }
  77. func (p *Parser) CommentSyntax() source.CommentSyntax {
  78. return source.CommentSyntax{
  79. Dash: true,
  80. Hash: false,
  81. SlashStar: true,
  82. }
  83. }