1
0

convert.go 29 KB


  1. package sqlite
  2. import (
  3. "fmt"
  4. "log"
  5. "strconv"
  6. "strings"
  7. "github.com/antlr/antlr4/runtime/Go/antlr/v4"
  8. "github.com/sqlc-dev/sqlc/internal/debug"
  9. "github.com/sqlc-dev/sqlc/internal/engine/sqlite/parser"
  10. "github.com/sqlc-dev/sqlc/internal/sql/ast"
  11. )
  // cc holds the mutable state shared while one ANTLR SQLite parse tree is
// converted into sqlc's generic AST.
type cc struct {
	// paramCount is read and advanced by convertParam when it numbers
	// positional bind parameters.
	paramCount int
}
  15. type node interface {
  16. GetParser() antlr.Parser
  17. }
  18. func todo(funcname string, n node) *ast.TODO {
  19. if debug.Active {
  20. log.Printf("sqlite.%s: Unknown node type %T\n", funcname, n)
  21. }
  22. return &ast.TODO{}
  23. }
  // identifier normalizes an SQLite identifier token as written in the source.
//
// A double-quoted identifier keeps its case. The surrounding quotes are
// removed, and every doubled quote ("") inside is collapsed to a single quote,
// which is how SQLite escapes a quote within a quoted identifier. Any other
// identifier is folded to lower case.
func identifier(id string) string {
	if len(id) >= 2 && id[0] == '"' && id[len(id)-1] == '"' {
		// SQLite escapes quotes by doubling them, not with backslashes.
		// strconv.Unquote (Go string syntax) is the wrong tool here: it rejects
		// `"a""b"` (the error was ignored, yielding "") and would also
		// reinterpret backslash sequences such as `"a\b"`.
		return strings.ReplaceAll(id[1:len(id)-1], `""`, `"`)
	}
	return strings.ToLower(id)
}
  31. func NewIdentifier(t string) *ast.String {
  32. return &ast.String{Str: identifier(t)}
  33. }
  34. func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) ast.Node {
  35. if n.RENAME_() != nil {
  36. if newTable, ok := n.New_table_name().(*parser.New_table_nameContext); ok {
  37. name := newTable.Any_name().GetText()
  38. return &ast.RenameTableStmt{
  39. Table: parseTableName(n),
  40. NewName: &name,
  41. }
  42. }
  43. if newCol, ok := n.GetNew_column_name().(*parser.Column_nameContext); ok {
  44. name := newCol.Any_name().GetText()
  45. return &ast.RenameColumnStmt{
  46. Table: parseTableName(n),
  47. Col: &ast.ColumnRef{
  48. Name: n.GetOld_column_name().GetText(),
  49. },
  50. NewName: &name,
  51. }
  52. }
  53. }
  54. if n.ADD_() != nil {
  55. if def, ok := n.Column_def().(*parser.Column_defContext); ok {
  56. stmt := &ast.AlterTableStmt{
  57. Table: parseTableName(n),
  58. Cmds: &ast.List{},
  59. }
  60. name := def.Column_name().GetText()
  61. stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{
  62. Name: &name,
  63. Subtype: ast.AT_AddColumn,
  64. Def: &ast.ColumnDef{
  65. Colname: name,
  66. TypeName: &ast.TypeName{
  67. Name: def.Type_name().GetText(),
  68. },
  69. IsNotNull: hasNotNullConstraint(def.AllColumn_constraint()),
  70. },
  71. })
  72. return stmt
  73. }
  74. }
  75. if n.DROP_() != nil {
  76. stmt := &ast.AlterTableStmt{
  77. Table: parseTableName(n),
  78. Cmds: &ast.List{},
  79. }
  80. name := n.Column_name(0).GetText()
  81. stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{
  82. Name: &name,
  83. Subtype: ast.AT_DropColumn,
  84. })
  85. return stmt
  86. }
  87. return todo("convertAlter_table_stmtContext", n)
  88. }
  89. func (c *cc) convertAttach_stmtContext(n *parser.Attach_stmtContext) ast.Node {
  90. name := n.Schema_name().GetText()
  91. return &ast.CreateSchemaStmt{
  92. Name: &name,
  93. }
  94. }
  95. func (c *cc) convertCreate_table_stmtContext(n *parser.Create_table_stmtContext) ast.Node {
  96. stmt := &ast.CreateTableStmt{
  97. Name: parseTableName(n),
  98. IfNotExists: n.EXISTS_() != nil,
  99. }
  100. for _, idef := range n.AllColumn_def() {
  101. if def, ok := idef.(*parser.Column_defContext); ok {
  102. typeName := "any"
  103. if def.Type_name() != nil {
  104. typeName = def.Type_name().GetText()
  105. }
  106. stmt.Cols = append(stmt.Cols, &ast.ColumnDef{
  107. Colname: identifier(def.Column_name().GetText()),
  108. IsNotNull: hasNotNullConstraint(def.AllColumn_constraint()),
  109. TypeName: &ast.TypeName{Name: typeName},
  110. })
  111. }
  112. }
  113. return stmt
  114. }
  115. func (c *cc) convertCreate_virtual_table_stmtContext(n *parser.Create_virtual_table_stmtContext) ast.Node {
  116. switch moduleName := n.Module_name().GetText(); moduleName {
  117. case "fts5":
  118. // https://www.sqlite.org/fts5.html
  119. return c.convertCreate_virtual_table_fts5(n)
  120. default:
  121. return todo(
  122. fmt.Sprintf("create_virtual_table. unsupported module name: %q", moduleName),
  123. n,
  124. )
  125. }
  126. }
  127. func (c *cc) convertCreate_virtual_table_fts5(n *parser.Create_virtual_table_stmtContext) ast.Node {
  128. stmt := &ast.CreateTableStmt{
  129. Name: parseTableName(n),
  130. IfNotExists: n.EXISTS_() != nil,
  131. }
  132. for _, arg := range n.AllModule_argument() {
  133. var columnName string
  134. // For example: CREATE VIRTUAL TABLE tbl_ft USING fts5(b, c UNINDEXED)
  135. // * the 'b' column is parsed like Expr_qualified_column_nameContext
  136. // * the 'c' column is parsed like Column_defContext
  137. if columnExpr, ok := arg.Expr().(*parser.Expr_qualified_column_nameContext); ok {
  138. columnName = columnExpr.Column_name().GetText()
  139. } else if columnDef, ok := arg.Column_def().(*parser.Column_defContext); ok {
  140. columnName = columnDef.Column_name().GetText()
  141. }
  142. if columnName != "" {
  143. stmt.Cols = append(stmt.Cols, &ast.ColumnDef{
  144. Colname: identifier(columnName),
  145. // you can not specify any column constraints in fts5, so we pass them manually
  146. IsNotNull: true,
  147. TypeName: &ast.TypeName{Name: "text"},
  148. })
  149. }
  150. }
  151. return stmt
  152. }
  153. func (c *cc) convertCreate_view_stmtContext(n *parser.Create_view_stmtContext) ast.Node {
  154. viewName := n.View_name().GetText()
  155. relation := &ast.RangeVar{
  156. Relname: &viewName,
  157. }
  158. if n.Schema_name() != nil {
  159. schemaName := n.Schema_name().GetText()
  160. relation.Schemaname = &schemaName
  161. }
  162. return &ast.ViewStmt{
  163. View: relation,
  164. Aliases: &ast.List{},
  165. Query: c.convert(n.Select_stmt()),
  166. Replace: false,
  167. Options: &ast.List{},
  168. WithCheckOption: ast.ViewCheckOption(0),
  169. }
  170. }
  171. type Delete_stmt interface {
  172. node
  173. Qualified_table_name() parser.IQualified_table_nameContext
  174. WHERE_() antlr.TerminalNode
  175. Expr() parser.IExprContext
  176. }
  177. func (c *cc) convertDelete_stmtContext(n Delete_stmt) ast.Node {
  178. if qualifiedName, ok := n.Qualified_table_name().(*parser.Qualified_table_nameContext); ok {
  179. tableName := qualifiedName.Table_name().GetText()
  180. relation := &ast.RangeVar{
  181. Relname: &tableName,
  182. }
  183. if qualifiedName.Schema_name() != nil {
  184. schemaName := qualifiedName.Schema_name().GetText()
  185. relation.Schemaname = &schemaName
  186. }
  187. if qualifiedName.Alias() != nil {
  188. alias := qualifiedName.Alias().GetText()
  189. relation.Alias = &ast.Alias{Aliasname: &alias}
  190. }
  191. relations := &ast.List{}
  192. relations.Items = append(relations.Items, relation)
  193. delete := &ast.DeleteStmt{
  194. Relations: relations,
  195. WithClause: nil,
  196. }
  197. if n.WHERE_() != nil && n.Expr() != nil {
  198. delete.WhereClause = c.convert(n.Expr())
  199. }
  200. if n, ok := n.(interface {
  201. Returning_clause() parser.IReturning_clauseContext
  202. }); ok {
  203. delete.ReturningList = c.convertReturning_caluseContext(n.Returning_clause())
  204. } else {
  205. delete.ReturningList = c.convertReturning_caluseContext(nil)
  206. }
  207. if n, ok := n.(interface {
  208. Limit_stmt() parser.ILimit_stmtContext
  209. }); ok {
  210. limitCount, _ := c.convertLimit_stmtContext(n.Limit_stmt())
  211. delete.LimitCount = limitCount
  212. }
  213. return delete
  214. }
  215. return todo("convertDelete_stmtContext", n)
  216. }
  217. func (c *cc) convertDrop_stmtContext(n *parser.Drop_stmtContext) ast.Node {
  218. if n.TABLE_() != nil || n.VIEW_() != nil {
  219. name := ast.TableName{
  220. Name: n.Any_name().GetText(),
  221. }
  222. if n.Schema_name() != nil {
  223. name.Schema = n.Schema_name().GetText()
  224. }
  225. return &ast.DropTableStmt{
  226. IfExists: n.EXISTS_() != nil,
  227. Tables: []*ast.TableName{&name},
  228. }
  229. }
  230. return todo("convertDrop_stmtContext", n)
  231. }
  232. func (c *cc) convertFuncContext(n *parser.Expr_functionContext) ast.Node {
  233. if name, ok := n.Qualified_function_name().(*parser.Qualified_function_nameContext); ok {
  234. funcName := strings.ToLower(name.Function_name().GetText())
  235. schema := ""
  236. if name.Schema_name() != nil {
  237. schema = name.Schema_name().GetText()
  238. }
  239. var argNodes []ast.Node
  240. for _, exp := range n.AllExpr() {
  241. argNodes = append(argNodes, c.convert(exp))
  242. }
  243. args := &ast.List{Items: argNodes}
  244. if funcName == "coalesce" {
  245. return &ast.CoalesceExpr{
  246. Args: args,
  247. Location: name.GetStart().GetStart(),
  248. }
  249. } else {
  250. return &ast.FuncCall{
  251. Func: &ast.FuncName{
  252. Schema: schema,
  253. Name: funcName,
  254. },
  255. Funcname: &ast.List{
  256. Items: []ast.Node{
  257. NewIdentifier(funcName),
  258. },
  259. },
  260. AggStar: n.STAR() != nil,
  261. Args: args,
  262. AggOrder: &ast.List{},
  263. AggDistinct: n.DISTINCT_() != nil,
  264. Location: name.GetStart().GetStart(),
  265. }
  266. }
  267. }
  268. return todo("convertFuncContext", n)
  269. }
  270. func (c *cc) convertExprContext(n *parser.ExprContext) ast.Node {
  271. return &ast.Expr{}
  272. }
  273. func (c *cc) convertColumnNameExpr(n *parser.Expr_qualified_column_nameContext) *ast.ColumnRef {
  274. var items []ast.Node
  275. if schema, ok := n.Schema_name().(*parser.Schema_nameContext); ok {
  276. schemaText := schema.GetText()
  277. if schemaText != "" {
  278. items = append(items, NewIdentifier(schemaText))
  279. }
  280. }
  281. if table, ok := n.Table_name().(*parser.Table_nameContext); ok {
  282. tableName := table.GetText()
  283. if tableName != "" {
  284. items = append(items, NewIdentifier(tableName))
  285. }
  286. }
  287. items = append(items, NewIdentifier(n.Column_name().GetText()))
  288. return &ast.ColumnRef{
  289. Fields: &ast.List{
  290. Items: items,
  291. },
  292. Location: n.GetStart().GetStart(),
  293. }
  294. }
  295. func (c *cc) convertComparison(n *parser.Expr_comparisonContext) ast.Node {
  296. lexpr := c.convert(n.Expr(0))
  297. if n.IN_() != nil {
  298. rexprs := []ast.Node{}
  299. for _, expr := range n.AllExpr()[1:] {
  300. e := c.convert(expr)
  301. switch t := e.(type) {
  302. case *ast.List:
  303. rexprs = append(rexprs, t.Items...)
  304. default:
  305. rexprs = append(rexprs, t)
  306. }
  307. }
  308. return &ast.In{
  309. Expr: lexpr,
  310. List: rexprs,
  311. Not: false,
  312. Sel: nil,
  313. Location: n.GetStart().GetStart(),
  314. }
  315. }
  316. return &ast.A_Expr{
  317. Name: &ast.List{
  318. Items: []ast.Node{
  319. &ast.String{Str: "="}, // TODO: add actual comparison
  320. },
  321. },
  322. Lexpr: lexpr,
  323. Rexpr: c.convert(n.Expr(1)),
  324. }
  325. }
  326. func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.Node {
  327. var ctes ast.List
  328. if ct := n.Common_table_stmt(); ct != nil {
  329. recursive := ct.RECURSIVE_() != nil
  330. for _, cte := range ct.AllCommon_table_expression() {
  331. tableName := identifier(cte.Table_name().GetText())
  332. var cteCols ast.List
  333. for _, col := range cte.AllColumn_name() {
  334. cteCols.Items = append(cteCols.Items, NewIdentifier(col.GetText()))
  335. }
  336. ctes.Items = append(ctes.Items, &ast.CommonTableExpr{
  337. Ctename: &tableName,
  338. Ctequery: c.convert(cte.Select_stmt()),
  339. Location: cte.GetStart().GetStart(),
  340. Cterecursive: recursive,
  341. Ctecolnames: &cteCols,
  342. })
  343. }
  344. }
  345. var selectStmt *ast.SelectStmt
  346. for s, icore := range n.AllSelect_core() {
  347. core, ok := icore.(*parser.Select_coreContext)
  348. if !ok {
  349. continue
  350. }
  351. cols := c.getCols(core)
  352. tables := c.getTables(core)
  353. var where ast.Node
  354. i := 0
  355. if core.WHERE_() != nil {
  356. where = c.convert(core.Expr(i))
  357. i++
  358. }
  359. var groups ast.List
  360. var having ast.Node
  361. if core.GROUP_() != nil {
  362. l := len(core.AllExpr()) - i
  363. if core.HAVING_() != nil {
  364. having = c.convert(core.Expr(l))
  365. l--
  366. }
  367. for i < l {
  368. groups.Items = append(groups.Items, c.convert(core.Expr(i)))
  369. i++
  370. }
  371. }
  372. var window ast.List
  373. if core.WINDOW_() != nil {
  374. for w, windowNameCtx := range core.AllWindow_name() {
  375. windowName := identifier(windowNameCtx.GetText())
  376. windowDef := core.Window_defn(w)
  377. _ = windowDef.Base_window_name()
  378. var partitionBy ast.List
  379. if windowDef.PARTITION_() != nil {
  380. for _, e := range windowDef.AllExpr() {
  381. partitionBy.Items = append(partitionBy.Items, c.convert(e))
  382. }
  383. }
  384. var orderBy ast.List
  385. if windowDef.ORDER_() != nil {
  386. for _, e := range windowDef.AllOrdering_term() {
  387. oterm := e.(*parser.Ordering_termContext)
  388. sortByDir := ast.SortByDirDefault
  389. if ad := oterm.Asc_desc(); ad != nil {
  390. if ad.ASC_() != nil {
  391. sortByDir = ast.SortByDirAsc
  392. } else {
  393. sortByDir = ast.SortByDirDesc
  394. }
  395. }
  396. sortByNulls := ast.SortByNullsDefault
  397. if oterm.NULLS_() != nil {
  398. if oterm.FIRST_() != nil {
  399. sortByNulls = ast.SortByNullsFirst
  400. } else {
  401. sortByNulls = ast.SortByNullsLast
  402. }
  403. }
  404. orderBy.Items = append(orderBy.Items, &ast.SortBy{
  405. Node: c.convert(oterm.Expr()),
  406. SortbyDir: sortByDir,
  407. SortbyNulls: sortByNulls,
  408. UseOp: &ast.List{},
  409. })
  410. }
  411. }
  412. window.Items = append(window.Items, &ast.WindowDef{
  413. Name: &windowName,
  414. PartitionClause: &partitionBy,
  415. OrderClause: &orderBy,
  416. FrameOptions: 0, // todo
  417. StartOffset: &ast.TODO{},
  418. EndOffset: &ast.TODO{},
  419. Location: windowNameCtx.GetStart().GetStart(),
  420. })
  421. }
  422. }
  423. sel := &ast.SelectStmt{
  424. FromClause: &ast.List{Items: tables},
  425. TargetList: &ast.List{Items: cols},
  426. WhereClause: where,
  427. GroupClause: &groups,
  428. HavingClause: having,
  429. WindowClause: &window,
  430. ValuesLists: &ast.List{},
  431. }
  432. if selectStmt == nil {
  433. selectStmt = sel
  434. } else {
  435. co := n.Compound_operator(s - 1)
  436. so := ast.None
  437. all := false
  438. switch {
  439. case co.UNION_() != nil:
  440. so = ast.Union
  441. all = co.ALL_() != nil
  442. case co.INTERSECT_() != nil:
  443. so = ast.Intersect
  444. case co.EXCEPT_() != nil:
  445. so = ast.Except
  446. }
  447. selectStmt = &ast.SelectStmt{
  448. TargetList: &ast.List{},
  449. FromClause: &ast.List{},
  450. Op: so,
  451. All: all,
  452. Larg: selectStmt,
  453. Rarg: sel,
  454. }
  455. }
  456. }
  457. limitCount, limitOffset := c.convertLimit_stmtContext(n.Limit_stmt())
  458. selectStmt.LimitCount = limitCount
  459. selectStmt.LimitOffset = limitOffset
  460. selectStmt.WithClause = &ast.WithClause{Ctes: &ctes}
  461. return selectStmt
  462. }
  463. func (c *cc) convertExprListContext(n *parser.Expr_listContext) ast.Node {
  464. list := &ast.List{Items: []ast.Node{}}
  465. for _, e := range n.AllExpr() {
  466. list.Items = append(list.Items, c.convert(e))
  467. }
  468. return list
  469. }
  470. func (c *cc) getTables(core *parser.Select_coreContext) []ast.Node {
  471. if core.Join_clause() != nil {
  472. join := core.Join_clause().(*parser.Join_clauseContext)
  473. tables := c.convertTablesOrSubquery(join.AllTable_or_subquery())
  474. table := tables[0]
  475. for i, t := range tables[1:] {
  476. joinExpr := &ast.JoinExpr{
  477. Larg: table,
  478. Rarg: t,
  479. }
  480. jo := join.Join_operator(i)
  481. if jo.NATURAL_() != nil {
  482. joinExpr.IsNatural = true
  483. }
  484. switch {
  485. case jo.CROSS_() != nil || jo.INNER_() != nil:
  486. joinExpr.Jointype = ast.JoinTypeInner
  487. case jo.LEFT_() != nil:
  488. joinExpr.Jointype = ast.JoinTypeLeft
  489. case jo.RIGHT_() != nil:
  490. joinExpr.Jointype = ast.JoinTypeRight
  491. case jo.FULL_() != nil:
  492. joinExpr.Jointype = ast.JoinTypeFull
  493. }
  494. jc := join.Join_constraint(i)
  495. switch {
  496. case jc.ON_() != nil:
  497. joinExpr.Quals = c.convert(jc.Expr())
  498. case jc.USING_() != nil:
  499. var using ast.List
  500. for _, cn := range jc.AllColumn_name() {
  501. using.Items = append(using.Items, NewIdentifier(cn.GetText()))
  502. }
  503. joinExpr.UsingClause = &using
  504. }
  505. table = joinExpr
  506. }
  507. return []ast.Node{table}
  508. } else {
  509. return c.convertTablesOrSubquery(core.AllTable_or_subquery())
  510. }
  511. }
  512. func (c *cc) getCols(core *parser.Select_coreContext) []ast.Node {
  513. var cols []ast.Node
  514. for _, icol := range core.AllResult_column() {
  515. col, ok := icol.(*parser.Result_columnContext)
  516. if !ok {
  517. continue
  518. }
  519. target := &ast.ResTarget{
  520. Location: col.GetStart().GetStart(),
  521. }
  522. var val ast.Node
  523. iexpr := col.Expr()
  524. switch {
  525. case col.STAR() != nil:
  526. val = c.convertWildCardField(col)
  527. case iexpr != nil:
  528. val = c.convert(iexpr)
  529. }
  530. if val == nil {
  531. continue
  532. }
  533. if col.Column_alias() != nil {
  534. name := identifier(col.Column_alias().GetText())
  535. target.Name = &name
  536. }
  537. target.Val = val
  538. cols = append(cols, target)
  539. }
  540. return cols
  541. }
  542. func (c *cc) convertWildCardField(n *parser.Result_columnContext) *ast.ColumnRef {
  543. items := []ast.Node{}
  544. if n.Table_name() != nil {
  545. items = append(items, NewIdentifier(n.Table_name().GetText()))
  546. }
  547. items = append(items, &ast.A_Star{})
  548. return &ast.ColumnRef{
  549. Fields: &ast.List{
  550. Items: items,
  551. },
  552. Location: n.GetStart().GetStart(),
  553. }
  554. }
  555. func (c *cc) convertOrderby_stmtContext(n parser.IOrder_by_stmtContext) ast.Node {
  556. if orderBy, ok := n.(*parser.Order_by_stmtContext); ok {
  557. list := &ast.List{Items: []ast.Node{}}
  558. for _, o := range orderBy.AllOrdering_term() {
  559. term, ok := o.(*parser.Ordering_termContext)
  560. if !ok {
  561. continue
  562. }
  563. list.Items = append(list.Items, &ast.CaseExpr{
  564. Xpr: c.convert(term.Expr()),
  565. Location: term.Expr().GetStart().GetStart(),
  566. })
  567. }
  568. return list
  569. }
  570. return todo("convertOrderby_stmtContext", n)
  571. }
  572. func (c *cc) convertLimit_stmtContext(n parser.ILimit_stmtContext) (ast.Node, ast.Node) {
  573. if n == nil {
  574. return nil, nil
  575. }
  576. var limitCount, limitOffset ast.Node
  577. if limit, ok := n.(*parser.Limit_stmtContext); ok {
  578. limitCount = c.convert(limit.Expr(0))
  579. if limit.OFFSET_() != nil {
  580. limitOffset = c.convert(limit.Expr(1))
  581. }
  582. }
  583. return limitCount, limitOffset
  584. }
  585. func (c *cc) convertSql_stmtContext(n *parser.Sql_stmtContext) ast.Node {
  586. if stmt := n.Alter_table_stmt(); stmt != nil {
  587. return c.convert(stmt)
  588. }
  589. if stmt := n.Analyze_stmt(); stmt != nil {
  590. return c.convert(stmt)
  591. }
  592. if stmt := n.Attach_stmt(); stmt != nil {
  593. return c.convert(stmt)
  594. }
  595. if stmt := n.Begin_stmt(); stmt != nil {
  596. return c.convert(stmt)
  597. }
  598. if stmt := n.Commit_stmt(); stmt != nil {
  599. return c.convert(stmt)
  600. }
  601. if stmt := n.Create_index_stmt(); stmt != nil {
  602. return c.convert(stmt)
  603. }
  604. if stmt := n.Create_table_stmt(); stmt != nil {
  605. return c.convert(stmt)
  606. }
  607. if stmt := n.Create_trigger_stmt(); stmt != nil {
  608. return c.convert(stmt)
  609. }
  610. if stmt := n.Create_view_stmt(); stmt != nil {
  611. return c.convert(stmt)
  612. }
  613. if stmt := n.Create_virtual_table_stmt(); stmt != nil {
  614. return c.convert(stmt)
  615. }
  616. if stmt := n.Delete_stmt(); stmt != nil {
  617. return c.convert(stmt)
  618. }
  619. if stmt := n.Delete_stmt_limited(); stmt != nil {
  620. return c.convert(stmt)
  621. }
  622. if stmt := n.Detach_stmt(); stmt != nil {
  623. return c.convert(stmt)
  624. }
  625. if stmt := n.Drop_stmt(); stmt != nil {
  626. return c.convert(stmt)
  627. }
  628. if stmt := n.Insert_stmt(); stmt != nil {
  629. return c.convert(stmt)
  630. }
  631. if stmt := n.Pragma_stmt(); stmt != nil {
  632. return c.convert(stmt)
  633. }
  634. if stmt := n.Reindex_stmt(); stmt != nil {
  635. return c.convert(stmt)
  636. }
  637. if stmt := n.Release_stmt(); stmt != nil {
  638. return c.convert(stmt)
  639. }
  640. if stmt := n.Rollback_stmt(); stmt != nil {
  641. return c.convert(stmt)
  642. }
  643. if stmt := n.Savepoint_stmt(); stmt != nil {
  644. return c.convert(stmt)
  645. }
  646. if stmt := n.Select_stmt(); stmt != nil {
  647. return c.convert(stmt)
  648. }
  649. if stmt := n.Update_stmt(); stmt != nil {
  650. return c.convert(stmt)
  651. }
  652. if stmt := n.Update_stmt_limited(); stmt != nil {
  653. return c.convert(stmt)
  654. }
  655. if stmt := n.Vacuum_stmt(); stmt != nil {
  656. return c.convert(stmt)
  657. }
  658. return nil
  659. }
  660. func (c *cc) convertLiteral(n *parser.Expr_literalContext) ast.Node {
  661. if literal, ok := n.Literal_value().(*parser.Literal_valueContext); ok {
  662. if literal.NUMERIC_LITERAL() != nil {
  663. i, _ := strconv.ParseInt(literal.GetText(), 10, 64)
  664. return &ast.A_Const{
  665. Val: &ast.Integer{Ival: i},
  666. Location: n.GetStart().GetStart(),
  667. }
  668. }
  669. if literal.STRING_LITERAL() != nil {
  670. // remove surrounding single quote
  671. text := literal.GetText()
  672. return &ast.A_Const{
  673. Val: &ast.String{Str: text[1 : len(text)-1]},
  674. Location: n.GetStart().GetStart(),
  675. }
  676. }
  677. if literal.TRUE_() != nil || literal.FALSE_() != nil {
  678. var i int64
  679. if literal.TRUE_() != nil {
  680. i = 1
  681. }
  682. return &ast.A_Const{
  683. Val: &ast.Integer{Ival: i},
  684. Location: n.GetStart().GetStart(),
  685. }
  686. }
  687. }
  688. return todo("convertLiteral", n)
  689. }
  690. func (c *cc) convertBinaryNode(n *parser.Expr_binaryContext) ast.Node {
  691. return &ast.A_Expr{
  692. Name: &ast.List{
  693. Items: []ast.Node{
  694. &ast.String{Str: n.GetChild(1).(antlr.TerminalNode).GetText()},
  695. },
  696. },
  697. Lexpr: c.convert(n.Expr(0)),
  698. Rexpr: c.convert(n.Expr(1)),
  699. }
  700. }
  701. func (c *cc) convertBoolNode(n *parser.Expr_boolContext) ast.Node {
  702. return &ast.BoolExpr{
  703. // TODO: Set op
  704. Args: &ast.List{
  705. Items: []ast.Node{
  706. c.convert(n.Expr(0)),
  707. c.convert(n.Expr(1)),
  708. },
  709. },
  710. }
  711. }
  712. func (c *cc) convertParam(n *parser.Expr_bindContext) ast.Node {
  713. if n.NUMBERED_BIND_PARAMETER() != nil {
  714. // Parameter numbers start at one
  715. c.paramCount += 1
  716. text := n.GetText()
  717. number := c.paramCount
  718. if len(text) > 1 {
  719. number, _ = strconv.Atoi(text[1:])
  720. }
  721. return &ast.ParamRef{
  722. Number: number,
  723. Location: n.GetStart().GetStart(),
  724. Dollar: len(text) > 1,
  725. }
  726. }
  727. if n.NAMED_BIND_PARAMETER() != nil {
  728. return &ast.A_Expr{
  729. Name: &ast.List{Items: []ast.Node{&ast.String{Str: "@"}}},
  730. Rexpr: &ast.String{Str: n.GetText()[1:]},
  731. Location: n.GetStart().GetStart(),
  732. }
  733. }
  734. return todo("convertParam", n)
  735. }
  736. func (c *cc) convertInSelectNode(n *parser.Expr_in_selectContext) ast.Node {
  737. return c.convert(n.Select_stmt())
  738. }
  739. func (c *cc) convertReturning_caluseContext(n parser.IReturning_clauseContext) *ast.List {
  740. list := &ast.List{Items: []ast.Node{}}
  741. if n == nil {
  742. return list
  743. }
  744. r, ok := n.(*parser.Returning_clauseContext)
  745. if !ok {
  746. return list
  747. }
  748. for _, exp := range r.AllExpr() {
  749. list.Items = append(list.Items, &ast.ResTarget{
  750. Indirection: &ast.List{},
  751. Val: c.convert(exp),
  752. })
  753. }
  754. for _, star := range r.AllSTAR() {
  755. list.Items = append(list.Items, &ast.ResTarget{
  756. Indirection: &ast.List{},
  757. Val: &ast.ColumnRef{
  758. Fields: &ast.List{
  759. Items: []ast.Node{&ast.A_Star{}},
  760. },
  761. Location: star.GetSymbol().GetStart(),
  762. },
  763. Location: star.GetSymbol().GetStart(),
  764. })
  765. }
  766. return list
  767. }
  768. func (c *cc) convertInsert_stmtContext(n *parser.Insert_stmtContext) ast.Node {
  769. tableName := n.Table_name().GetText()
  770. rel := &ast.RangeVar{
  771. Relname: &tableName,
  772. }
  773. if n.Schema_name() != nil {
  774. schemaName := n.Schema_name().GetText()
  775. rel.Schemaname = &schemaName
  776. }
  777. if n.Table_alias() != nil {
  778. tableAlias := identifier(n.Table_alias().GetText())
  779. rel.Alias = &ast.Alias{
  780. Aliasname: &tableAlias,
  781. }
  782. }
  783. insert := &ast.InsertStmt{
  784. Relation: rel,
  785. Cols: c.convertColumnNames(n.AllColumn_name()),
  786. ReturningList: c.convertReturning_caluseContext(n.Returning_clause()),
  787. }
  788. if n.Select_stmt() != nil {
  789. if ss, ok := c.convert(n.Select_stmt()).(*ast.SelectStmt); ok {
  790. ss.ValuesLists = &ast.List{}
  791. insert.SelectStmt = ss
  792. }
  793. } else {
  794. var valuesLists ast.List
  795. var values *ast.List
  796. for _, cn := range n.GetChildren() {
  797. switch cn := cn.(type) {
  798. case antlr.TerminalNode:
  799. switch cn.GetSymbol().GetTokenType() {
  800. case parser.SQLiteParserVALUES_:
  801. values = &ast.List{}
  802. case parser.SQLiteParserOPEN_PAR:
  803. if values != nil {
  804. values = &ast.List{}
  805. }
  806. case parser.SQLiteParserCOMMA:
  807. case parser.SQLiteParserCLOSE_PAR:
  808. if values != nil {
  809. valuesLists.Items = append(valuesLists.Items, values)
  810. }
  811. }
  812. case parser.IExprContext:
  813. if values != nil {
  814. values.Items = append(values.Items, c.convert(cn))
  815. }
  816. }
  817. }
  818. insert.SelectStmt = &ast.SelectStmt{
  819. FromClause: &ast.List{},
  820. TargetList: &ast.List{},
  821. ValuesLists: &valuesLists,
  822. }
  823. }
  824. return insert
  825. }
  826. func (c *cc) convertColumnNames(cols []parser.IColumn_nameContext) *ast.List {
  827. list := &ast.List{Items: []ast.Node{}}
  828. for _, c := range cols {
  829. name := identifier(c.GetText())
  830. list.Items = append(list.Items, &ast.ResTarget{
  831. Name: &name,
  832. })
  833. }
  834. return list
  835. }
  836. func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast.Node {
  837. var tables []ast.Node
  838. for _, ifrom := range n {
  839. from, ok := ifrom.(*parser.Table_or_subqueryContext)
  840. if !ok {
  841. continue
  842. }
  843. if from.Table_name() != nil {
  844. rel := from.Table_name().GetText()
  845. rv := &ast.RangeVar{
  846. Relname: &rel,
  847. Location: from.GetStart().GetStart(),
  848. }
  849. if from.Schema_name() != nil {
  850. schema := from.Schema_name().GetText()
  851. rv.Schemaname = &schema
  852. }
  853. if from.Table_alias() != nil {
  854. alias := identifier(from.Table_alias().GetText())
  855. rv.Alias = &ast.Alias{Aliasname: &alias}
  856. }
  857. if from.Table_alias_fallback() != nil {
  858. alias := identifier(from.Table_alias_fallback().GetText())
  859. rv.Alias = &ast.Alias{Aliasname: &alias}
  860. }
  861. tables = append(tables, rv)
  862. } else if from.Table_function_name() != nil {
  863. rel := from.Table_function_name().GetText()
  864. rf := &ast.RangeFunction{
  865. Functions: &ast.List{
  866. Items: []ast.Node{
  867. &ast.FuncCall{
  868. Func: &ast.FuncName{
  869. Name: rel,
  870. },
  871. Funcname: &ast.List{
  872. Items: []ast.Node{
  873. NewIdentifier(rel),
  874. },
  875. },
  876. Args: &ast.List{
  877. Items: []ast.Node{&ast.TODO{}},
  878. },
  879. Location: from.GetStart().GetStart(),
  880. },
  881. },
  882. },
  883. }
  884. if from.Table_alias() != nil {
  885. alias := identifier(from.Table_alias().GetText())
  886. rf.Alias = &ast.Alias{Aliasname: &alias}
  887. }
  888. tables = append(tables, rf)
  889. } else if from.Select_stmt() != nil {
  890. rs := &ast.RangeSubselect{
  891. Subquery: c.convert(from.Select_stmt()),
  892. }
  893. if from.Table_alias() != nil {
  894. alias := identifier(from.Table_alias().GetText())
  895. rs.Alias = &ast.Alias{Aliasname: &alias}
  896. }
  897. tables = append(tables, rs)
  898. }
  899. }
  900. return tables
  901. }
  902. type Update_stmt interface {
  903. Qualified_table_name() parser.IQualified_table_nameContext
  904. GetStart() antlr.Token
  905. AllColumn_name() []parser.IColumn_nameContext
  906. WHERE_() antlr.TerminalNode
  907. Expr(i int) parser.IExprContext
  908. AllExpr() []parser.IExprContext
  909. }
  910. func (c *cc) convertUpdate_stmtContext(n Update_stmt) ast.Node {
  911. if n == nil {
  912. return nil
  913. }
  914. relations := &ast.List{}
  915. tableName := n.Qualified_table_name().GetText()
  916. rel := ast.RangeVar{
  917. Relname: &tableName,
  918. Location: n.GetStart().GetStart(),
  919. }
  920. relations.Items = append(relations.Items, &rel)
  921. list := &ast.List{}
  922. for i, col := range n.AllColumn_name() {
  923. colName := identifier(col.GetText())
  924. target := &ast.ResTarget{
  925. Name: &colName,
  926. Val: c.convert(n.Expr(i)),
  927. }
  928. list.Items = append(list.Items, target)
  929. }
  930. var where ast.Node = nil
  931. if n.WHERE_() != nil {
  932. where = c.convert(n.Expr(len(n.AllExpr()) - 1))
  933. }
  934. stmt := &ast.UpdateStmt{
  935. Relations: relations,
  936. TargetList: list,
  937. WhereClause: where,
  938. FromClause: &ast.List{},
  939. WithClause: nil, // TODO: support with clause
  940. }
  941. if n, ok := n.(interface {
  942. Returning_clause() parser.IReturning_clauseContext
  943. }); ok {
  944. stmt.ReturningList = c.convertReturning_caluseContext(n.Returning_clause())
  945. } else {
  946. stmt.ReturningList = c.convertReturning_caluseContext(nil)
  947. }
  948. if n, ok := n.(interface {
  949. Limit_stmt() parser.ILimit_stmtContext
  950. }); ok {
  951. limitCount, _ := c.convertLimit_stmtContext(n.Limit_stmt())
  952. stmt.LimitCount = limitCount
  953. }
  954. return stmt
  955. }
  956. func (c *cc) convertBetweenExpr(n *parser.Expr_betweenContext) ast.Node {
  957. return &ast.BetweenExpr{
  958. Expr: c.convert(n.Expr(0)),
  959. Left: c.convert(n.Expr(1)),
  960. Right: c.convert(n.Expr(2)),
  961. Location: n.GetStart().GetStart(),
  962. Not: n.NOT_() != nil,
  963. }
  964. }
  965. func (c *cc) convertCastExpr(n *parser.Expr_castContext) ast.Node {
  966. name := n.Type_name().GetText()
  967. return &ast.TypeCast{
  968. Arg: c.convert(n.Expr()),
  969. TypeName: &ast.TypeName{
  970. Name: name,
  971. Names: &ast.List{Items: []ast.Node{
  972. NewIdentifier(name),
  973. }},
  974. ArrayBounds: &ast.List{},
  975. },
  976. Location: n.GetStart().GetStart(),
  977. }
  978. }
  979. func (c *cc) convertCollateExpr(n *parser.Expr_collateContext) ast.Node {
  980. return &ast.CollateExpr{
  981. Xpr: c.convert(n.Expr()),
  982. Arg: NewIdentifier(n.Collation_name().GetText()),
  983. Location: n.GetStart().GetStart(),
  984. }
  985. }
  986. func (c *cc) convertCase(n *parser.Expr_caseContext) ast.Node {
  987. e := &ast.CaseExpr{
  988. Args: &ast.List{},
  989. }
  990. es := n.AllExpr()
  991. if n.ELSE_() != nil {
  992. e.Defresult = c.convert(es[len(es)-1])
  993. es = es[:len(es)-1]
  994. }
  995. if len(es)%2 == 1 {
  996. e.Arg = c.convert(es[0])
  997. es = es[1:]
  998. }
  999. for i := 0; i < len(es); i += 2 {
  1000. e.Args.Items = append(e.Args.Items, &ast.CaseWhen{
  1001. Expr: c.convert(es[i+0]),
  1002. Result: c.convert(es[i+1]),
  1003. })
  1004. }
  1005. return e
  1006. }
  1007. func (c *cc) convert(node node) ast.Node {
  1008. switch n := node.(type) {
  1009. case *parser.Alter_table_stmtContext:
  1010. return c.convertAlter_table_stmtContext(n)
  1011. case *parser.Attach_stmtContext:
  1012. return c.convertAttach_stmtContext(n)
  1013. case *parser.Create_table_stmtContext:
  1014. return c.convertCreate_table_stmtContext(n)
  1015. case *parser.Create_virtual_table_stmtContext:
  1016. return c.convertCreate_virtual_table_stmtContext(n)
  1017. case *parser.Create_view_stmtContext:
  1018. return c.convertCreate_view_stmtContext(n)
  1019. case *parser.Drop_stmtContext:
  1020. return c.convertDrop_stmtContext(n)
  1021. case *parser.Delete_stmtContext:
  1022. return c.convertDelete_stmtContext(n)
  1023. case *parser.Delete_stmt_limitedContext:
  1024. return c.convertDelete_stmtContext(n)
  1025. case *parser.ExprContext:
  1026. return c.convertExprContext(n)
  1027. case *parser.Expr_functionContext:
  1028. return c.convertFuncContext(n)
  1029. case *parser.Expr_qualified_column_nameContext:
  1030. return c.convertColumnNameExpr(n)
  1031. case *parser.Expr_comparisonContext:
  1032. return c.convertComparison(n)
  1033. case *parser.Expr_bindContext:
  1034. return c.convertParam(n)
  1035. case *parser.Expr_literalContext:
  1036. return c.convertLiteral(n)
  1037. case *parser.Expr_boolContext:
  1038. return c.convertBoolNode(n)
  1039. case *parser.Expr_listContext:
  1040. return c.convertExprListContext(n)
  1041. case *parser.Expr_binaryContext:
  1042. return c.convertBinaryNode(n)
  1043. case *parser.Expr_in_selectContext:
  1044. return c.convertInSelectNode(n)
  1045. case *parser.Expr_betweenContext:
  1046. return c.convertBetweenExpr(n)
  1047. case *parser.Expr_collateContext:
  1048. return c.convertCollateExpr(n)
  1049. case *parser.Factored_select_stmtContext:
  1050. // TODO: need to handle this
  1051. return todo("convert(case=parser.Factored_select_stmtContext)", n)
  1052. case *parser.Insert_stmtContext:
  1053. return c.convertInsert_stmtContext(n)
  1054. case *parser.Order_by_stmtContext:
  1055. return c.convertOrderby_stmtContext(n)
  1056. case *parser.Select_stmtContext:
  1057. return c.convertMultiSelect_stmtContext(n)
  1058. case *parser.Sql_stmtContext:
  1059. return c.convertSql_stmtContext(n)
  1060. case *parser.Update_stmtContext:
  1061. return c.convertUpdate_stmtContext(n)
  1062. case *parser.Update_stmt_limitedContext:
  1063. return c.convertUpdate_stmtContext(n)
  1064. case *parser.Expr_castContext:
  1065. return c.convertCastExpr(n)
  1066. case *parser.Expr_caseContext:
  1067. return c.convertCase(n)
  1068. default:
  1069. return todo("convert(case=default)", n)
  1070. }
  1071. }