123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379 |
- package sqlite
- import (
- "github.com/antlr/antlr4/runtime/Go/antlr"
- "strings"
- "github.com/kyleconroy/sqlc/internal/engine/sqlite/parser"
- "github.com/kyleconroy/sqlc/internal/sql/ast"
- )
- type node interface {
- GetParser() antlr.Parser
- }
- func convertAlter_table_stmtContext(c *parser.Alter_table_stmtContext) ast.Node {
- if c.RENAME_() != nil {
- if newTable, ok := c.New_table_name().(*parser.New_table_nameContext); ok {
- name := newTable.Any_name().GetText()
- return &ast.RenameTableStmt{
- Table: parseTableName(c),
- NewName: &name,
- }
- }
- if newCol, ok := c.GetNew_column_name().(*parser.Column_nameContext); ok {
- name := newCol.Any_name().GetText()
- return &ast.RenameColumnStmt{
- Table: parseTableName(c),
- Col: &ast.ColumnRef{
- Name: c.GetOld_column_name().GetText(),
- },
- NewName: &name,
- }
- }
- }
- if c.ADD_() != nil {
- if def, ok := c.Column_def().(*parser.Column_defContext); ok {
- stmt := &ast.AlterTableStmt{
- Table: parseTableName(c),
- Cmds: &ast.List{},
- }
- name := def.Column_name().GetText()
- stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{
- Name: &name,
- Subtype: ast.AT_AddColumn,
- Def: &ast.ColumnDef{
- Colname: name,
- TypeName: &ast.TypeName{
- Name: def.Type_name().GetText(),
- },
- },
- })
- return stmt
- }
- }
- if c.DROP_() != nil {
- stmt := &ast.AlterTableStmt{
- Table: parseTableName(c),
- Cmds: &ast.List{},
- }
- name := c.Column_name(0).GetText()
- //fmt.Printf("column: %s", name)
- stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{
- Name: &name,
- Subtype: ast.AT_DropColumn,
- })
- return stmt
- }
- return &ast.TODO{}
- }
- func convertAttach_stmtContext(c *parser.Attach_stmtContext) ast.Node {
- name := c.Schema_name().GetText()
- return &ast.CreateSchemaStmt{
- Name: &name,
- }
- }
- func convertCreate_table_stmtContext(c *parser.Create_table_stmtContext) ast.Node {
- stmt := &ast.CreateTableStmt{
- Name: parseTableName(c),
- IfNotExists: c.EXISTS_() != nil,
- }
- for _, idef := range c.AllColumn_def() {
- if def, ok := idef.(*parser.Column_defContext); ok {
- stmt.Cols = append(stmt.Cols, &ast.ColumnDef{
- Colname: def.Column_name().GetText(),
- IsNotNull: hasNotNullConstraint(def.AllColumn_constraint()),
- TypeName: &ast.TypeName{Name: def.Type_name().GetText()},
- })
- }
- }
- return stmt
- }
- func convertDrop_stmtContext(c *parser.Drop_stmtContext) ast.Node {
- // TODO confirm that this logic does what it looks like it should
- if tableName, ok := c.TABLE_().(antlr.TerminalNode); ok {
- name := ast.TableName{
- Name: tableName.GetText(),
- }
- if c.Schema_name() != nil {
- name.Schema = c.Schema_name().GetText()
- }
- return &ast.DropTableStmt{
- IfExists: c.EXISTS_() != nil,
- Tables: []*ast.TableName{&name},
- }
- } else {
- return &ast.TODO{}
- }
- }
- func identifier(id string) string {
- return strings.ToLower(id)
- }
- func NewIdentifer(t string) *ast.String {
- return &ast.String{Str: identifier(t)}
- }
- func convertExprContext(c *parser.ExprContext) ast.Node {
- if name, ok := c.Function_name().(*parser.Function_nameContext); ok {
- funcName := strings.ToLower(name.GetText())
- fn := &ast.FuncCall{
- Func: &ast.FuncName{
- Name: funcName,
- },
- Funcname: &ast.List{
- Items: []ast.Node{
- NewIdentifer(funcName),
- },
- },
- AggStar: c.STAR() != nil,
- Args: &ast.List{},
- AggOrder: &ast.List{},
- AggDistinct: c.DISTINCT_() != nil,
- }
- return fn
- }
- return &ast.Expr{}
- if c.Column_name().(*parser.Column_nameContext) != nil {
- return convertColumnNameExpr(c)
- }
- return &ast.TODO{}
- }
- func convertColumnNameExpr(c *parser.ExprContext) *ast.ColumnRef {
- var items []ast.Node
- if schema, ok := c.Schema_name().(*parser.Schema_nameContext); ok {
- schemaText := schema.GetText()
- if schemaText != "" {
- items = append(items, NewIdentifer(schemaText))
- }
- }
- if table, ok := c.Table_name().(*parser.Table_nameContext); ok {
- tableName := table.GetText()
- if tableName != "" {
- items = append(items, NewIdentifer(tableName))
- }
- }
- items = append(items, NewIdentifer(c.Column_name().GetText()))
- return &ast.ColumnRef{
- Fields: &ast.List{
- Items: items,
- },
- }
- }
- func convertSimpleSelect_stmtContext(c *parser.Simple_select_stmtContext) ast.Node {
- if core, ok := c.Select_core().(*parser.Select_coreContext); ok {
- cols := getCols(core)
- tables := getTables(core)
- return &ast.SelectStmt{
- FromClause: &ast.List{Items: tables},
- TargetList: &ast.List{Items: cols},
- }
- }
- return &ast.TODO{}
- }
- func convertMultiSelect_stmtContext(c multiselect) ast.Node {
- var tables []ast.Node
- var cols []ast.Node
- for _, icore := range c.AllSelect_core() {
- core, ok := icore.(*parser.Select_coreContext)
- if !ok {
- continue
- }
- cols = append(cols, getCols(core)...)
- tables = append(tables, getTables(core)...)
- }
- return &ast.SelectStmt{
- FromClause: &ast.List{Items: tables},
- TargetList: &ast.List{Items: cols},
- }
- }
- func getTables(core *parser.Select_coreContext) []ast.Node {
- var tables []ast.Node
- for _, ifrom := range core.AllTable_or_subquery() {
- from, ok := ifrom.(*parser.Table_or_subqueryContext)
- if !ok {
- continue
- }
- rel := from.Table_name().GetText()
- name := ast.RangeVar{
- Relname: &rel,
- Location: from.GetStart().GetStart(),
- }
- if from.Schema_name() != nil {
- text := from.Schema_name().GetText()
- name.Schemaname = &text
- }
- tables = append(tables, &name)
- }
- return tables
- }
- func getCols(core *parser.Select_coreContext) []ast.Node {
- var cols []ast.Node
- for _, icol := range core.AllResult_column() {
- col, ok := icol.(*parser.Result_columnContext)
- if !ok {
- continue
- }
- var val ast.Node
- iexpr := col.Expr()
- switch {
- case col.STAR() != nil:
- val = &ast.ColumnRef{
- Fields: &ast.List{
- Items: []ast.Node{
- &ast.A_Star{},
- },
- },
- Location: col.GetStart().GetStart(),
- }
- case iexpr != nil:
- val = convert(iexpr)
- }
- if val == nil {
- continue
- }
- cols = append(cols, &ast.ResTarget{
- Val: val,
- Location: col.GetStart().GetStart(),
- })
- }
- return cols
- }
- func convertSql_stmtContext(n *parser.Sql_stmtContext) ast.Node {
- if stmt := n.Alter_table_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Analyze_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Attach_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Begin_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Commit_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Create_index_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Create_table_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Create_trigger_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Create_view_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Create_virtual_table_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Delete_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Delete_stmt_limited(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Detach_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Drop_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Insert_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Pragma_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Reindex_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Release_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Rollback_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Savepoint_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Select_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Update_stmt(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Update_stmt_limited(); stmt != nil {
- return convert(stmt)
- }
- if stmt := n.Vacuum_stmt(); stmt != nil {
- return convert(stmt)
- }
- return nil
- }
- func convert(node node) ast.Node {
- switch n := node.(type) {
- case *parser.Alter_table_stmtContext:
- return convertAlter_table_stmtContext(n)
- case *parser.Attach_stmtContext:
- return convertAttach_stmtContext(n)
- case *parser.Create_table_stmtContext:
- return convertCreate_table_stmtContext(n)
- case *parser.Drop_stmtContext:
- return convertDrop_stmtContext(n)
- case *parser.ExprContext:
- return convertExprContext(n)
- case *parser.Factored_select_stmtContext:
- // TODO: need to handle this
- return &ast.TODO{}
- case *parser.Select_stmtContext:
- return convertMultiSelect_stmtContext(n)
- case *parser.Sql_stmtContext:
- return convertSql_stmtContext(n)
- case *parser.Simple_select_stmtContext:
- return convertSimpleSelect_stmtContext(n)
- case *parser.Compound_select_stmtContext:
- return convertMultiSelect_stmtContext(n)
- default:
- return &ast.TODO{}
- }
- }
|