convert.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. package sqlite
  2. import (
  3. "github.com/antlr/antlr4/runtime/Go/antlr"
  4. "strings"
  5. "github.com/kyleconroy/sqlc/internal/engine/sqlite/parser"
  6. "github.com/kyleconroy/sqlc/internal/sql/ast"
  7. )
  8. type node interface {
  9. GetParser() antlr.Parser
  10. }
  11. func convertAlter_table_stmtContext(c *parser.Alter_table_stmtContext) ast.Node {
  12. if c.RENAME_() != nil {
  13. if newTable, ok := c.New_table_name().(*parser.New_table_nameContext); ok {
  14. name := newTable.Any_name().GetText()
  15. return &ast.RenameTableStmt{
  16. Table: parseTableName(c),
  17. NewName: &name,
  18. }
  19. }
  20. if newCol, ok := c.GetNew_column_name().(*parser.Column_nameContext); ok {
  21. name := newCol.Any_name().GetText()
  22. return &ast.RenameColumnStmt{
  23. Table: parseTableName(c),
  24. Col: &ast.ColumnRef{
  25. Name: c.GetOld_column_name().GetText(),
  26. },
  27. NewName: &name,
  28. }
  29. }
  30. }
  31. if c.ADD_() != nil {
  32. if def, ok := c.Column_def().(*parser.Column_defContext); ok {
  33. stmt := &ast.AlterTableStmt{
  34. Table: parseTableName(c),
  35. Cmds: &ast.List{},
  36. }
  37. name := def.Column_name().GetText()
  38. stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{
  39. Name: &name,
  40. Subtype: ast.AT_AddColumn,
  41. Def: &ast.ColumnDef{
  42. Colname: name,
  43. TypeName: &ast.TypeName{
  44. Name: def.Type_name().GetText(),
  45. },
  46. },
  47. })
  48. return stmt
  49. }
  50. }
  51. if c.DROP_() != nil {
  52. stmt := &ast.AlterTableStmt{
  53. Table: parseTableName(c),
  54. Cmds: &ast.List{},
  55. }
  56. name := c.Column_name(0).GetText()
  57. //fmt.Printf("column: %s", name)
  58. stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{
  59. Name: &name,
  60. Subtype: ast.AT_DropColumn,
  61. })
  62. return stmt
  63. }
  64. return &ast.TODO{}
  65. }
  66. func convertAttach_stmtContext(c *parser.Attach_stmtContext) ast.Node {
  67. name := c.Schema_name().GetText()
  68. return &ast.CreateSchemaStmt{
  69. Name: &name,
  70. }
  71. }
  72. func convertCreate_table_stmtContext(c *parser.Create_table_stmtContext) ast.Node {
  73. stmt := &ast.CreateTableStmt{
  74. Name: parseTableName(c),
  75. IfNotExists: c.EXISTS_() != nil,
  76. }
  77. for _, idef := range c.AllColumn_def() {
  78. if def, ok := idef.(*parser.Column_defContext); ok {
  79. stmt.Cols = append(stmt.Cols, &ast.ColumnDef{
  80. Colname: def.Column_name().GetText(),
  81. IsNotNull: hasNotNullConstraint(def.AllColumn_constraint()),
  82. TypeName: &ast.TypeName{Name: def.Type_name().GetText()},
  83. })
  84. }
  85. }
  86. return stmt
  87. }
  88. func convertDrop_stmtContext(c *parser.Drop_stmtContext) ast.Node {
  89. // TODO confirm that this logic does what it looks like it should
  90. if tableName, ok := c.TABLE_().(antlr.TerminalNode); ok {
  91. name := ast.TableName{
  92. Name: tableName.GetText(),
  93. }
  94. if c.Schema_name() != nil {
  95. name.Schema = c.Schema_name().GetText()
  96. }
  97. return &ast.DropTableStmt{
  98. IfExists: c.EXISTS_() != nil,
  99. Tables: []*ast.TableName{&name},
  100. }
  101. } else {
  102. return &ast.TODO{}
  103. }
  104. }
  // identifier normalizes an SQL identifier by folding it to lower case.
  func identifier(id string) string {
  	lowered := strings.ToLower(id)
  	return lowered
  }
  108. func NewIdentifer(t string) *ast.String {
  109. return &ast.String{Str: identifier(t)}
  110. }
  111. func convertExprContext(c *parser.ExprContext) ast.Node {
  112. if name, ok := c.Function_name().(*parser.Function_nameContext); ok {
  113. funcName := strings.ToLower(name.GetText())
  114. fn := &ast.FuncCall{
  115. Func: &ast.FuncName{
  116. Name: funcName,
  117. },
  118. Funcname: &ast.List{
  119. Items: []ast.Node{
  120. NewIdentifer(funcName),
  121. },
  122. },
  123. AggStar: c.STAR() != nil,
  124. Args: &ast.List{},
  125. AggOrder: &ast.List{},
  126. AggDistinct: c.DISTINCT_() != nil,
  127. }
  128. return fn
  129. }
  130. return &ast.Expr{}
  131. if c.Column_name().(*parser.Column_nameContext) != nil {
  132. return convertColumnNameExpr(c)
  133. }
  134. return &ast.TODO{}
  135. }
  136. func convertColumnNameExpr(c *parser.ExprContext) *ast.ColumnRef {
  137. var items []ast.Node
  138. if schema, ok := c.Schema_name().(*parser.Schema_nameContext); ok {
  139. schemaText := schema.GetText()
  140. if schemaText != "" {
  141. items = append(items, NewIdentifer(schemaText))
  142. }
  143. }
  144. if table, ok := c.Table_name().(*parser.Table_nameContext); ok {
  145. tableName := table.GetText()
  146. if tableName != "" {
  147. items = append(items, NewIdentifer(tableName))
  148. }
  149. }
  150. items = append(items, NewIdentifer(c.Column_name().GetText()))
  151. return &ast.ColumnRef{
  152. Fields: &ast.List{
  153. Items: items,
  154. },
  155. }
  156. }
  157. func convertSimpleSelect_stmtContext(c *parser.Simple_select_stmtContext) ast.Node {
  158. if core, ok := c.Select_core().(*parser.Select_coreContext); ok {
  159. cols := getCols(core)
  160. tables := getTables(core)
  161. return &ast.SelectStmt{
  162. FromClause: &ast.List{Items: tables},
  163. TargetList: &ast.List{Items: cols},
  164. }
  165. }
  166. return &ast.TODO{}
  167. }
  168. func convertMultiSelect_stmtContext(c multiselect) ast.Node {
  169. var tables []ast.Node
  170. var cols []ast.Node
  171. for _, icore := range c.AllSelect_core() {
  172. core, ok := icore.(*parser.Select_coreContext)
  173. if !ok {
  174. continue
  175. }
  176. cols = append(cols, getCols(core)...)
  177. tables = append(tables, getTables(core)...)
  178. }
  179. return &ast.SelectStmt{
  180. FromClause: &ast.List{Items: tables},
  181. TargetList: &ast.List{Items: cols},
  182. }
  183. }
  184. func getTables(core *parser.Select_coreContext) []ast.Node {
  185. var tables []ast.Node
  186. for _, ifrom := range core.AllTable_or_subquery() {
  187. from, ok := ifrom.(*parser.Table_or_subqueryContext)
  188. if !ok {
  189. continue
  190. }
  191. rel := from.Table_name().GetText()
  192. name := ast.RangeVar{
  193. Relname: &rel,
  194. Location: from.GetStart().GetStart(),
  195. }
  196. if from.Schema_name() != nil {
  197. text := from.Schema_name().GetText()
  198. name.Schemaname = &text
  199. }
  200. tables = append(tables, &name)
  201. }
  202. return tables
  203. }
  204. func getCols(core *parser.Select_coreContext) []ast.Node {
  205. var cols []ast.Node
  206. for _, icol := range core.AllResult_column() {
  207. col, ok := icol.(*parser.Result_columnContext)
  208. if !ok {
  209. continue
  210. }
  211. var val ast.Node
  212. iexpr := col.Expr()
  213. switch {
  214. case col.STAR() != nil:
  215. val = &ast.ColumnRef{
  216. Fields: &ast.List{
  217. Items: []ast.Node{
  218. &ast.A_Star{},
  219. },
  220. },
  221. Location: col.GetStart().GetStart(),
  222. }
  223. case iexpr != nil:
  224. val = convert(iexpr)
  225. }
  226. if val == nil {
  227. continue
  228. }
  229. cols = append(cols, &ast.ResTarget{
  230. Val: val,
  231. Location: col.GetStart().GetStart(),
  232. })
  233. }
  234. return cols
  235. }
  236. func convertSql_stmtContext(n *parser.Sql_stmtContext) ast.Node {
  237. if stmt := n.Alter_table_stmt(); stmt != nil {
  238. return convert(stmt)
  239. }
  240. if stmt := n.Analyze_stmt(); stmt != nil {
  241. return convert(stmt)
  242. }
  243. if stmt := n.Attach_stmt(); stmt != nil {
  244. return convert(stmt)
  245. }
  246. if stmt := n.Begin_stmt(); stmt != nil {
  247. return convert(stmt)
  248. }
  249. if stmt := n.Commit_stmt(); stmt != nil {
  250. return convert(stmt)
  251. }
  252. if stmt := n.Create_index_stmt(); stmt != nil {
  253. return convert(stmt)
  254. }
  255. if stmt := n.Create_table_stmt(); stmt != nil {
  256. return convert(stmt)
  257. }
  258. if stmt := n.Create_trigger_stmt(); stmt != nil {
  259. return convert(stmt)
  260. }
  261. if stmt := n.Create_view_stmt(); stmt != nil {
  262. return convert(stmt)
  263. }
  264. if stmt := n.Create_virtual_table_stmt(); stmt != nil {
  265. return convert(stmt)
  266. }
  267. if stmt := n.Delete_stmt(); stmt != nil {
  268. return convert(stmt)
  269. }
  270. if stmt := n.Delete_stmt_limited(); stmt != nil {
  271. return convert(stmt)
  272. }
  273. if stmt := n.Detach_stmt(); stmt != nil {
  274. return convert(stmt)
  275. }
  276. if stmt := n.Drop_stmt(); stmt != nil {
  277. return convert(stmt)
  278. }
  279. if stmt := n.Insert_stmt(); stmt != nil {
  280. return convert(stmt)
  281. }
  282. if stmt := n.Pragma_stmt(); stmt != nil {
  283. return convert(stmt)
  284. }
  285. if stmt := n.Reindex_stmt(); stmt != nil {
  286. return convert(stmt)
  287. }
  288. if stmt := n.Release_stmt(); stmt != nil {
  289. return convert(stmt)
  290. }
  291. if stmt := n.Rollback_stmt(); stmt != nil {
  292. return convert(stmt)
  293. }
  294. if stmt := n.Savepoint_stmt(); stmt != nil {
  295. return convert(stmt)
  296. }
  297. if stmt := n.Select_stmt(); stmt != nil {
  298. return convert(stmt)
  299. }
  300. if stmt := n.Update_stmt(); stmt != nil {
  301. return convert(stmt)
  302. }
  303. if stmt := n.Update_stmt_limited(); stmt != nil {
  304. return convert(stmt)
  305. }
  306. if stmt := n.Vacuum_stmt(); stmt != nil {
  307. return convert(stmt)
  308. }
  309. return nil
  310. }
  311. func convert(node node) ast.Node {
  312. switch n := node.(type) {
  313. case *parser.Alter_table_stmtContext:
  314. return convertAlter_table_stmtContext(n)
  315. case *parser.Attach_stmtContext:
  316. return convertAttach_stmtContext(n)
  317. case *parser.Create_table_stmtContext:
  318. return convertCreate_table_stmtContext(n)
  319. case *parser.Drop_stmtContext:
  320. return convertDrop_stmtContext(n)
  321. case *parser.ExprContext:
  322. return convertExprContext(n)
  323. case *parser.Factored_select_stmtContext:
  324. // TODO: need to handle this
  325. return &ast.TODO{}
  326. case *parser.Select_stmtContext:
  327. return convertMultiSelect_stmtContext(n)
  328. case *parser.Sql_stmtContext:
  329. return convertSql_stmtContext(n)
  330. case *parser.Simple_select_stmtContext:
  331. return convertSimpleSelect_stmtContext(n)
  332. case *parser.Compound_select_stmtContext:
  333. return convertMultiSelect_stmtContext(n)
  334. default:
  335. return &ast.TODO{}
  336. }
  337. }