1
0

parse.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594
  1. //go:build !windows
  2. // +build !windows
  3. package postgresql
  4. import (
  5. "errors"
  6. "fmt"
  7. "io"
  8. "strings"
  9. nodes "github.com/pganalyze/pg_query_go/v2"
  10. "github.com/kyleconroy/sqlc/internal/metadata"
  11. "github.com/kyleconroy/sqlc/internal/sql/ast"
  12. )
  13. func stringSlice(list *nodes.List) []string {
  14. items := []string{}
  15. for _, item := range list.Items {
  16. if n, ok := item.Node.(*nodes.Node_String_); ok {
  17. items = append(items, n.String_.Str)
  18. }
  19. }
  20. return items
  21. }
  22. func stringSliceFromNodes(s []*nodes.Node) []string {
  23. var items []string
  24. for _, item := range s {
  25. if n, ok := item.Node.(*nodes.Node_String_); ok {
  26. items = append(items, n.String_.Str)
  27. }
  28. }
  29. return items
  30. }
  31. type relation struct {
  32. Catalog string
  33. Schema string
  34. Name string
  35. }
  36. func (r relation) TableName() *ast.TableName {
  37. return &ast.TableName{
  38. Catalog: r.Catalog,
  39. Schema: r.Schema,
  40. Name: r.Name,
  41. }
  42. }
  43. func (r relation) TypeName() *ast.TypeName {
  44. return &ast.TypeName{
  45. Catalog: r.Catalog,
  46. Schema: r.Schema,
  47. Name: r.Name,
  48. }
  49. }
  50. func (r relation) FuncName() *ast.FuncName {
  51. return &ast.FuncName{
  52. Catalog: r.Catalog,
  53. Schema: r.Schema,
  54. Name: r.Name,
  55. }
  56. }
  57. func parseRelationFromNodes(list []*nodes.Node) (*relation, error) {
  58. parts := stringSliceFromNodes(list)
  59. switch len(parts) {
  60. case 1:
  61. return &relation{
  62. Name: parts[0],
  63. }, nil
  64. case 2:
  65. return &relation{
  66. Schema: parts[0],
  67. Name: parts[1],
  68. }, nil
  69. case 3:
  70. return &relation{
  71. Catalog: parts[0],
  72. Schema: parts[1],
  73. Name: parts[2],
  74. }, nil
  75. default:
  76. return nil, fmt.Errorf("invalid name: %s", joinNodes(list, "."))
  77. }
  78. }
  79. func parseRelationFromRangeVar(rv *nodes.RangeVar) *relation {
  80. return &relation{
  81. Catalog: rv.Catalogname,
  82. Schema: rv.Schemaname,
  83. Name: rv.Relname,
  84. }
  85. }
  86. func parseRelation(in *nodes.Node) (*relation, error) {
  87. switch n := in.Node.(type) {
  88. case *nodes.Node_List:
  89. return parseRelationFromNodes(n.List.Items)
  90. case *nodes.Node_RangeVar:
  91. return parseRelationFromRangeVar(n.RangeVar), nil
  92. case *nodes.Node_TypeName:
  93. return parseRelationFromNodes(n.TypeName.Names)
  94. default:
  95. return nil, fmt.Errorf("unexpected node type: %T", n)
  96. }
  97. }
  98. func parseColName(node *nodes.Node) (*ast.ColumnRef, *ast.TableName, error) {
  99. switch n := node.Node.(type) {
  100. case *nodes.Node_List:
  101. parts := stringSlice(n.List)
  102. var tbl *ast.TableName
  103. var ref *ast.ColumnRef
  104. switch len(parts) {
  105. case 2:
  106. tbl = &ast.TableName{Name: parts[0]}
  107. ref = &ast.ColumnRef{Name: parts[1]}
  108. case 3:
  109. tbl = &ast.TableName{Schema: parts[0], Name: parts[1]}
  110. ref = &ast.ColumnRef{Name: parts[2]}
  111. case 4:
  112. tbl = &ast.TableName{Catalog: parts[0], Schema: parts[1], Name: parts[2]}
  113. ref = &ast.ColumnRef{Name: parts[3]}
  114. default:
  115. return nil, nil, fmt.Errorf("column specifier %q is not the proper format, expected '[catalog.][schema.]colname.tablename'", strings.Join(parts, "."))
  116. }
  117. return ref, tbl, nil
  118. default:
  119. return nil, nil, fmt.Errorf("parseColName: unexpected node type: %T", n)
  120. }
  121. }
  122. func join(list *nodes.List, sep string) string {
  123. return strings.Join(stringSlice(list), sep)
  124. }
  125. func joinNodes(list []*nodes.Node, sep string) string {
  126. return strings.Join(stringSliceFromNodes(list), sep)
  127. }
  128. func NewParser() *Parser {
  129. return &Parser{}
  130. }
  131. type Parser struct {
  132. }
  133. var errSkip = errors.New("skip stmt")
  134. func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) {
  135. contents, err := io.ReadAll(r)
  136. if err != nil {
  137. return nil, err
  138. }
  139. tree, err := nodes.Parse(string(contents))
  140. if err != nil {
  141. return nil, err
  142. }
  143. var stmts []ast.Statement
  144. for _, raw := range tree.Stmts {
  145. n, err := translate(raw.Stmt)
  146. if err == errSkip {
  147. continue
  148. }
  149. if err != nil {
  150. return nil, err
  151. }
  152. if n == nil {
  153. return nil, fmt.Errorf("unexpected nil node")
  154. }
  155. stmts = append(stmts, ast.Statement{
  156. Raw: &ast.RawStmt{
  157. Stmt: n,
  158. StmtLocation: int(raw.StmtLocation),
  159. StmtLen: int(raw.StmtLen),
  160. },
  161. })
  162. }
  163. return stmts, nil
  164. }
  165. // https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-COMMENTS
  166. func (p *Parser) CommentSyntax() metadata.CommentSyntax {
  167. return metadata.CommentSyntax{
  168. Dash: true,
  169. SlashStar: true,
  170. }
  171. }
  172. func translate(node *nodes.Node) (ast.Node, error) {
  173. switch inner := node.Node.(type) {
  174. case *nodes.Node_AlterEnumStmt:
  175. n := inner.AlterEnumStmt
  176. rel, err := parseRelationFromNodes(n.TypeName)
  177. if err != nil {
  178. return nil, err
  179. }
  180. if n.OldVal != "" {
  181. return &ast.AlterTypeRenameValueStmt{
  182. Type: rel.TypeName(),
  183. OldValue: makeString(n.OldVal),
  184. NewValue: makeString(n.NewVal),
  185. }, nil
  186. } else {
  187. return &ast.AlterTypeAddValueStmt{
  188. Type: rel.TypeName(),
  189. NewValue: makeString(n.NewVal),
  190. SkipIfNewValExists: n.SkipIfNewValExists,
  191. }, nil
  192. }
  193. case *nodes.Node_AlterObjectSchemaStmt:
  194. n := inner.AlterObjectSchemaStmt
  195. switch n.ObjectType {
  196. case nodes.ObjectType_OBJECT_TABLE:
  197. rel := parseRelationFromRangeVar(n.Relation)
  198. return &ast.AlterTableSetSchemaStmt{
  199. Table: rel.TableName(),
  200. NewSchema: makeString(n.Newschema),
  201. }, nil
  202. }
  203. return nil, errSkip
  204. case *nodes.Node_AlterTableStmt:
  205. n := inner.AlterTableStmt
  206. rel := parseRelationFromRangeVar(n.Relation)
  207. at := &ast.AlterTableStmt{
  208. Table: rel.TableName(),
  209. Cmds: &ast.List{},
  210. }
  211. for _, cmd := range n.Cmds {
  212. switch cmdOneOf := cmd.Node.(type) {
  213. case *nodes.Node_AlterTableCmd:
  214. altercmd := cmdOneOf.AlterTableCmd
  215. item := &ast.AlterTableCmd{Name: &altercmd.Name, MissingOk: altercmd.MissingOk}
  216. switch altercmd.Subtype {
  217. case nodes.AlterTableType_AT_AddColumn:
  218. d, ok := altercmd.Def.Node.(*nodes.Node_ColumnDef)
  219. if !ok {
  220. return nil, fmt.Errorf("expected alter table defintion to be a ColumnDef")
  221. }
  222. rel, err := parseRelationFromNodes(d.ColumnDef.TypeName.Names)
  223. if err != nil {
  224. return nil, err
  225. }
  226. item.Subtype = ast.AT_AddColumn
  227. item.Def = &ast.ColumnDef{
  228. Colname: d.ColumnDef.Colname,
  229. TypeName: rel.TypeName(),
  230. IsNotNull: isNotNull(d.ColumnDef),
  231. IsArray: isArray(d.ColumnDef.TypeName),
  232. }
  233. case nodes.AlterTableType_AT_AlterColumnType:
  234. d, ok := altercmd.Def.Node.(*nodes.Node_ColumnDef)
  235. if !ok {
  236. return nil, fmt.Errorf("expected alter table defintion to be a ColumnDef")
  237. }
  238. col := ""
  239. if altercmd.Name != "" {
  240. col = altercmd.Name
  241. } else if d.ColumnDef.Colname != "" {
  242. col = d.ColumnDef.Colname
  243. } else {
  244. return nil, fmt.Errorf("unknown name for alter column type")
  245. }
  246. rel, err := parseRelationFromNodes(d.ColumnDef.TypeName.Names)
  247. if err != nil {
  248. return nil, err
  249. }
  250. item.Subtype = ast.AT_AlterColumnType
  251. item.Def = &ast.ColumnDef{
  252. Colname: col,
  253. TypeName: rel.TypeName(),
  254. IsNotNull: isNotNull(d.ColumnDef),
  255. IsArray: isArray(d.ColumnDef.TypeName),
  256. }
  257. case nodes.AlterTableType_AT_DropColumn:
  258. item.Subtype = ast.AT_DropColumn
  259. case nodes.AlterTableType_AT_DropNotNull:
  260. item.Subtype = ast.AT_DropNotNull
  261. case nodes.AlterTableType_AT_SetNotNull:
  262. item.Subtype = ast.AT_SetNotNull
  263. default:
  264. continue
  265. }
  266. at.Cmds.Items = append(at.Cmds.Items, item)
  267. }
  268. }
  269. return at, nil
  270. case *nodes.Node_CommentStmt:
  271. n := inner.CommentStmt
  272. switch n.Objtype {
  273. case nodes.ObjectType_OBJECT_COLUMN:
  274. col, tbl, err := parseColName(n.Object)
  275. if err != nil {
  276. return nil, fmt.Errorf("COMMENT ON COLUMN: %w", err)
  277. }
  278. return &ast.CommentOnColumnStmt{
  279. Col: col,
  280. Table: tbl,
  281. Comment: makeString(n.Comment),
  282. }, nil
  283. case nodes.ObjectType_OBJECT_SCHEMA:
  284. o, ok := n.Object.Node.(*nodes.Node_String_)
  285. if !ok {
  286. return nil, fmt.Errorf("COMMENT ON SCHEMA: unexpected node type: %T", n.Object)
  287. }
  288. return &ast.CommentOnSchemaStmt{
  289. Schema: &ast.String{Str: o.String_.Str},
  290. Comment: makeString(n.Comment),
  291. }, nil
  292. case nodes.ObjectType_OBJECT_TABLE:
  293. rel, err := parseRelation(n.Object)
  294. if err != nil {
  295. return nil, fmt.Errorf("COMMENT ON TABLE: %w", err)
  296. }
  297. return &ast.CommentOnTableStmt{
  298. Table: rel.TableName(),
  299. Comment: makeString(n.Comment),
  300. }, nil
  301. case nodes.ObjectType_OBJECT_TYPE:
  302. rel, err := parseRelation(n.Object)
  303. if err != nil {
  304. return nil, err
  305. }
  306. return &ast.CommentOnTypeStmt{
  307. Type: rel.TypeName(),
  308. Comment: makeString(n.Comment),
  309. }, nil
  310. }
  311. return nil, errSkip
  312. case *nodes.Node_CompositeTypeStmt:
  313. n := inner.CompositeTypeStmt
  314. rel := parseRelationFromRangeVar(n.Typevar)
  315. return &ast.CompositeTypeStmt{
  316. TypeName: rel.TypeName(),
  317. }, nil
  318. case *nodes.Node_CreateStmt:
  319. n := inner.CreateStmt
  320. rel := parseRelationFromRangeVar(n.Relation)
  321. create := &ast.CreateTableStmt{
  322. Name: rel.TableName(),
  323. IfNotExists: n.IfNotExists,
  324. }
  325. for _, node := range n.InhRelations {
  326. switch item := node.Node.(type) {
  327. case *nodes.Node_RangeVar:
  328. if item.RangeVar.Inh {
  329. rel := parseRelationFromRangeVar(item.RangeVar)
  330. create.Inherits = append(create.Inherits, rel.TableName())
  331. }
  332. }
  333. }
  334. primaryKey := make(map[string]bool)
  335. for _, elt := range n.TableElts {
  336. switch item := elt.Node.(type) {
  337. case *nodes.Node_Constraint:
  338. if item.Constraint.Contype == nodes.ConstrType_CONSTR_PRIMARY {
  339. for _, key := range item.Constraint.Keys {
  340. // FIXME: Possible nil pointer dereference
  341. primaryKey[key.Node.(*nodes.Node_String_).String_.Str] = true
  342. }
  343. }
  344. case *nodes.Node_TableLikeClause:
  345. rel := parseRelationFromRangeVar(item.TableLikeClause.Relation)
  346. create.ReferTable = rel.TableName()
  347. }
  348. }
  349. for _, elt := range n.TableElts {
  350. switch item := elt.Node.(type) {
  351. case *nodes.Node_ColumnDef:
  352. rel, err := parseRelationFromNodes(item.ColumnDef.TypeName.Names)
  353. if err != nil {
  354. return nil, err
  355. }
  356. create.Cols = append(create.Cols, &ast.ColumnDef{
  357. Colname: item.ColumnDef.Colname,
  358. TypeName: rel.TypeName(),
  359. IsNotNull: isNotNull(item.ColumnDef) || primaryKey[item.ColumnDef.Colname],
  360. IsArray: isArray(item.ColumnDef.TypeName),
  361. })
  362. }
  363. }
  364. return create, nil
  365. case *nodes.Node_CreateEnumStmt:
  366. n := inner.CreateEnumStmt
  367. rel, err := parseRelationFromNodes(n.TypeName)
  368. if err != nil {
  369. return nil, err
  370. }
  371. stmt := &ast.CreateEnumStmt{
  372. TypeName: rel.TypeName(),
  373. Vals: &ast.List{},
  374. }
  375. for _, val := range n.Vals {
  376. switch v := val.Node.(type) {
  377. case *nodes.Node_String_:
  378. stmt.Vals.Items = append(stmt.Vals.Items, &ast.String{
  379. Str: v.String_.Str,
  380. })
  381. }
  382. }
  383. return stmt, nil
  384. case *nodes.Node_CreateFunctionStmt:
  385. n := inner.CreateFunctionStmt
  386. fn, err := parseRelationFromNodes(n.Funcname)
  387. if err != nil {
  388. return nil, err
  389. }
  390. var rt *ast.TypeName
  391. if n.ReturnType != nil {
  392. rel, err := parseRelationFromNodes(n.ReturnType.Names)
  393. if err != nil {
  394. return nil, err
  395. }
  396. rt = rel.TypeName()
  397. }
  398. stmt := &ast.CreateFunctionStmt{
  399. Func: fn.FuncName(),
  400. ReturnType: rt,
  401. Replace: n.Replace,
  402. Params: &ast.List{},
  403. }
  404. for _, item := range n.Parameters {
  405. arg := item.Node.(*nodes.Node_FunctionParameter).FunctionParameter
  406. rel, err := parseRelationFromNodes(arg.ArgType.Names)
  407. if err != nil {
  408. return nil, err
  409. }
  410. mode, err := convertFuncParamMode(arg.Mode)
  411. if err != nil {
  412. return nil, err
  413. }
  414. fp := &ast.FuncParam{
  415. Name: &arg.Name,
  416. Type: rel.TypeName(),
  417. Mode: mode,
  418. }
  419. if arg.Defexpr != nil {
  420. fp.DefExpr = &ast.TODO{}
  421. }
  422. stmt.Params.Items = append(stmt.Params.Items, fp)
  423. }
  424. return stmt, nil
  425. case *nodes.Node_CreateSchemaStmt:
  426. n := inner.CreateSchemaStmt
  427. return &ast.CreateSchemaStmt{
  428. Name: makeString(n.Schemaname),
  429. IfNotExists: n.IfNotExists,
  430. }, nil
  431. case *nodes.Node_DropStmt:
  432. n := inner.DropStmt
  433. switch n.RemoveType {
  434. case nodes.ObjectType_OBJECT_FUNCTION:
  435. drop := &ast.DropFunctionStmt{
  436. MissingOk: n.MissingOk,
  437. }
  438. for _, obj := range n.Objects {
  439. nowa, ok := obj.Node.(*nodes.Node_ObjectWithArgs)
  440. if !ok {
  441. return nil, fmt.Errorf("nodes.DropStmt: FUNCTION: unknown type in objects list: %T", obj)
  442. }
  443. owa := nowa.ObjectWithArgs
  444. fn, err := parseRelationFromNodes(owa.Objname)
  445. if err != nil {
  446. return nil, fmt.Errorf("nodes.DropStmt: FUNCTION: %w", err)
  447. }
  448. args := make([]*ast.TypeName, len(owa.Objargs))
  449. for i, objarg := range owa.Objargs {
  450. tn, ok := objarg.Node.(*nodes.Node_TypeName)
  451. if !ok {
  452. return nil, fmt.Errorf("nodes.DropStmt: FUNCTION: unknown type in objargs list: %T", objarg)
  453. }
  454. at, err := parseRelationFromNodes(tn.TypeName.Names)
  455. if err != nil {
  456. return nil, fmt.Errorf("nodes.DropStmt: FUNCTION: %w", err)
  457. }
  458. args[i] = at.TypeName()
  459. }
  460. drop.Funcs = append(drop.Funcs, &ast.FuncSpec{
  461. Name: fn.FuncName(),
  462. Args: args,
  463. HasArgs: !owa.ArgsUnspecified,
  464. })
  465. }
  466. return drop, nil
  467. case nodes.ObjectType_OBJECT_SCHEMA:
  468. drop := &ast.DropSchemaStmt{
  469. MissingOk: n.MissingOk,
  470. }
  471. for _, obj := range n.Objects {
  472. val, ok := obj.Node.(*nodes.Node_String_)
  473. if !ok {
  474. return nil, fmt.Errorf("nodes.DropStmt: SCHEMA: unknown type in objects list: %T", obj)
  475. }
  476. drop.Schemas = append(drop.Schemas, &ast.String{Str: val.String_.Str})
  477. }
  478. return drop, nil
  479. case nodes.ObjectType_OBJECT_TABLE, nodes.ObjectType_OBJECT_VIEW, nodes.ObjectType_OBJECT_MATVIEW:
  480. drop := &ast.DropTableStmt{
  481. IfExists: n.MissingOk,
  482. }
  483. for _, obj := range n.Objects {
  484. name, err := parseRelation(obj)
  485. if err != nil {
  486. return nil, fmt.Errorf("nodes.DropStmt: TABLE: %w", err)
  487. }
  488. drop.Tables = append(drop.Tables, name.TableName())
  489. }
  490. return drop, nil
  491. case nodes.ObjectType_OBJECT_TYPE:
  492. drop := &ast.DropTypeStmt{
  493. IfExists: n.MissingOk,
  494. }
  495. for _, obj := range n.Objects {
  496. name, err := parseRelation(obj)
  497. if err != nil {
  498. return nil, fmt.Errorf("nodes.DropStmt: TYPE: %w", err)
  499. }
  500. drop.Types = append(drop.Types, name.TypeName())
  501. }
  502. return drop, nil
  503. }
  504. return nil, errSkip
  505. case *nodes.Node_RenameStmt:
  506. n := inner.RenameStmt
  507. switch n.RenameType {
  508. case nodes.ObjectType_OBJECT_COLUMN:
  509. rel := parseRelationFromRangeVar(n.Relation)
  510. return &ast.RenameColumnStmt{
  511. Table: rel.TableName(),
  512. Col: &ast.ColumnRef{Name: n.Subname},
  513. NewName: makeString(n.Newname),
  514. }, nil
  515. case nodes.ObjectType_OBJECT_TABLE:
  516. rel := parseRelationFromRangeVar(n.Relation)
  517. return &ast.RenameTableStmt{
  518. Table: rel.TableName(),
  519. NewName: makeString(n.Newname),
  520. }, nil
  521. case nodes.ObjectType_OBJECT_TYPE:
  522. rel, err := parseRelation(n.Object)
  523. if err != nil {
  524. return nil, fmt.Errorf("nodes.RenameStmt: TYPE: %w", err)
  525. }
  526. return &ast.RenameTypeStmt{
  527. Type: rel.TypeName(),
  528. NewName: makeString(n.Newname),
  529. }, nil
  530. }
  531. return nil, errSkip
  532. default:
  533. return convert(node)
  534. }
  535. }