1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- package sqlite
- import (
- "errors"
- "fmt"
- "io"
- "github.com/antlr/antlr4/runtime/Go/antlr/v4"
- "github.com/sqlc-dev/sqlc/internal/engine/sqlite/parser"
- "github.com/sqlc-dev/sqlc/internal/source"
- "github.com/sqlc-dev/sqlc/internal/sql/ast"
- )
- type errorListener struct {
- *antlr.DefaultErrorListener
- err string
- }
- func (el *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) {
- el.err = msg
- }
- // func (el *errorListener) ReportAmbiguity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, exact bool, ambigAlts *antlr.BitSet, configs antlr.ATNConfigSet) {
- // }
- //
- // func (el *errorListener) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs antlr.ATNConfigSet) {
- // }
- //
- // func (el *errorListener) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs antlr.ATNConfigSet) {
- // }
// Parser is the SQLite implementation of the sqlc parser. It has no fields.
type Parser struct{}

// NewParser returns a pointer to a fresh Parser.
func NewParser() *Parser {
	return new(Parser)
}
- func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) {
- blob, err := io.ReadAll(r)
- if err != nil {
- return nil, err
- }
- input := antlr.NewInputStream(string(blob))
- lexer := parser.NewSQLiteLexer(input)
- stream := antlr.NewCommonTokenStream(lexer, 0)
- pp := parser.NewSQLiteParser(stream)
- el := &errorListener{}
- pp.AddErrorListener(el)
- // pp.BuildParseTrees = true
- tree := pp.Parse()
- if el.err != "" {
- return nil, errors.New(el.err)
- }
- pctx, ok := tree.(*parser.ParseContext)
- if !ok {
- return nil, fmt.Errorf("expected ParserContext; got %T\n", tree)
- }
- var stmts []ast.Statement
- for _, istmt := range pctx.AllSql_stmt_list() {
- list, ok := istmt.(*parser.Sql_stmt_listContext)
- if !ok {
- return nil, fmt.Errorf("expected Sql_stmt_listContext; got %T\n", istmt)
- }
- loc := 0
- for _, stmt := range list.AllSql_stmt() {
- converter := &cc{}
- out := converter.convert(stmt)
- if _, ok := out.(*ast.TODO); ok {
- continue
- }
- len := (stmt.GetStop().GetStop() + 1) - loc
- stmts = append(stmts, ast.Statement{
- Raw: &ast.RawStmt{
- Stmt: out,
- StmtLocation: loc,
- StmtLen: len,
- },
- })
- loc = stmt.GetStop().GetStop() + 2
- }
- }
- return stmts, nil
- }
- func (p *Parser) CommentSyntax() source.CommentSyntax {
- return source.CommentSyntax{
- Dash: true,
- Hash: false,
- SlashStar: true,
- }
- }
|