parse.go 16 KB


  1. //go:build !windows && cgo
  2. // +build !windows,cgo
  3. package postgresql
  4. import (
  5. "errors"
  6. "fmt"
  7. "io"
  8. "strings"
  9. nodes "github.com/pganalyze/pg_query_go/v4"
  10. "github.com/pganalyze/pg_query_go/v4/parser"
  11. "github.com/sqlc-dev/sqlc/internal/source"
  12. "github.com/sqlc-dev/sqlc/internal/sql/ast"
  13. "github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
  14. )
  15. func stringSlice(list *nodes.List) []string {
  16. items := []string{}
  17. for _, item := range list.Items {
  18. if n, ok := item.Node.(*nodes.Node_String_); ok {
  19. items = append(items, n.String_.Sval)
  20. }
  21. }
  22. return items
  23. }
  24. func stringSliceFromNodes(s []*nodes.Node) []string {
  25. var items []string
  26. for _, item := range s {
  27. if n, ok := item.Node.(*nodes.Node_String_); ok {
  28. items = append(items, n.String_.Sval)
  29. }
  30. }
  31. return items
  32. }
  33. type relation struct {
  34. Catalog string
  35. Schema string
  36. Name string
  37. }
  38. func (r relation) TableName() *ast.TableName {
  39. return &ast.TableName{
  40. Catalog: r.Catalog,
  41. Schema: r.Schema,
  42. Name: r.Name,
  43. }
  44. }
  45. func (r relation) TypeName() *ast.TypeName {
  46. return &ast.TypeName{
  47. Catalog: r.Catalog,
  48. Schema: r.Schema,
  49. Name: r.Name,
  50. }
  51. }
  52. func (r relation) FuncName() *ast.FuncName {
  53. return &ast.FuncName{
  54. Catalog: r.Catalog,
  55. Schema: r.Schema,
  56. Name: r.Name,
  57. }
  58. }
  59. func parseRelationFromNodes(list []*nodes.Node) (*relation, error) {
  60. parts := stringSliceFromNodes(list)
  61. switch len(parts) {
  62. case 1:
  63. return &relation{
  64. Name: parts[0],
  65. }, nil
  66. case 2:
  67. return &relation{
  68. Schema: parts[0],
  69. Name: parts[1],
  70. }, nil
  71. case 3:
  72. return &relation{
  73. Catalog: parts[0],
  74. Schema: parts[1],
  75. Name: parts[2],
  76. }, nil
  77. default:
  78. return nil, fmt.Errorf("invalid name: %s", joinNodes(list, "."))
  79. }
  80. }
  81. func parseRelationFromRangeVar(rv *nodes.RangeVar) *relation {
  82. return &relation{
  83. Catalog: rv.Catalogname,
  84. Schema: rv.Schemaname,
  85. Name: rv.Relname,
  86. }
  87. }
  88. func parseRelation(in *nodes.Node) (*relation, error) {
  89. switch n := in.Node.(type) {
  90. case *nodes.Node_List:
  91. return parseRelationFromNodes(n.List.Items)
  92. case *nodes.Node_RangeVar:
  93. return parseRelationFromRangeVar(n.RangeVar), nil
  94. case *nodes.Node_TypeName:
  95. return parseRelationFromNodes(n.TypeName.Names)
  96. default:
  97. return nil, fmt.Errorf("unexpected node type: %T", n)
  98. }
  99. }
  100. func parseColName(node *nodes.Node) (*ast.ColumnRef, *ast.TableName, error) {
  101. switch n := node.Node.(type) {
  102. case *nodes.Node_List:
  103. parts := stringSlice(n.List)
  104. var tbl *ast.TableName
  105. var ref *ast.ColumnRef
  106. switch len(parts) {
  107. case 2:
  108. tbl = &ast.TableName{Name: parts[0]}
  109. ref = &ast.ColumnRef{Name: parts[1]}
  110. case 3:
  111. tbl = &ast.TableName{Schema: parts[0], Name: parts[1]}
  112. ref = &ast.ColumnRef{Name: parts[2]}
  113. case 4:
  114. tbl = &ast.TableName{Catalog: parts[0], Schema: parts[1], Name: parts[2]}
  115. ref = &ast.ColumnRef{Name: parts[3]}
  116. default:
  117. return nil, nil, fmt.Errorf("column specifier %q is not the proper format, expected '[catalog.][schema.]colname.tablename'", strings.Join(parts, "."))
  118. }
  119. return ref, tbl, nil
  120. default:
  121. return nil, nil, fmt.Errorf("parseColName: unexpected node type: %T", n)
  122. }
  123. }
  124. func joinNodes(list []*nodes.Node, sep string) string {
  125. return strings.Join(stringSliceFromNodes(list), sep)
  126. }
  127. func NewParser() *Parser {
  128. return &Parser{}
  129. }
  130. type Parser struct {
  131. }
// errSkip is the sentinel translate returns for statements whose object
// type it does not handle; Parse drops such statements instead of failing.
var errSkip = errors.New("skip stmt")
  133. func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) {
  134. contents, err := io.ReadAll(r)
  135. if err != nil {
  136. return nil, err
  137. }
  138. tree, err := nodes.Parse(string(contents))
  139. if err != nil {
  140. pErr := normalizeErr(err)
  141. return nil, pErr
  142. }
  143. var stmts []ast.Statement
  144. for _, raw := range tree.Stmts {
  145. n, err := translate(raw.Stmt)
  146. if err == errSkip {
  147. continue
  148. }
  149. if err != nil {
  150. return nil, err
  151. }
  152. if n == nil {
  153. return nil, fmt.Errorf("unexpected nil node")
  154. }
  155. stmts = append(stmts, ast.Statement{
  156. Raw: &ast.RawStmt{
  157. Stmt: n,
  158. StmtLocation: int(raw.StmtLocation),
  159. StmtLen: int(raw.StmtLen),
  160. },
  161. })
  162. }
  163. return stmts, nil
  164. }
  165. func normalizeErr(err error) error {
  166. //TODO: errors.As complains that *parser.Error does not implement error
  167. if pErr, ok := err.(*parser.Error); ok {
  168. sErr := &sqlerr.Error{
  169. Message: pErr.Message,
  170. //Err: pErr,
  171. Line: pErr.Lineno,
  172. Location: pErr.Cursorpos,
  173. }
  174. return sErr
  175. }
  176. return err
  177. }
  178. // https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-COMMENTS
  179. func (p *Parser) CommentSyntax() source.CommentSyntax {
  180. return source.CommentSyntax{
  181. Dash: true,
  182. SlashStar: true,
  183. }
  184. }
  185. func translate(node *nodes.Node) (ast.Node, error) {
  186. switch inner := node.Node.(type) {
  187. case *nodes.Node_AlterEnumStmt:
  188. n := inner.AlterEnumStmt
  189. rel, err := parseRelationFromNodes(n.TypeName)
  190. if err != nil {
  191. return nil, err
  192. }
  193. if n.OldVal != "" {
  194. return &ast.AlterTypeRenameValueStmt{
  195. Type: rel.TypeName(),
  196. OldValue: makeString(n.OldVal),
  197. NewValue: makeString(n.NewVal),
  198. }, nil
  199. } else {
  200. return &ast.AlterTypeAddValueStmt{
  201. Type: rel.TypeName(),
  202. NewValue: makeString(n.NewVal),
  203. NewValHasNeighbor: len(n.NewValNeighbor) > 0,
  204. NewValNeighbor: makeString(n.NewValNeighbor),
  205. NewValIsAfter: n.NewValIsAfter,
  206. SkipIfNewValExists: n.SkipIfNewValExists,
  207. }, nil
  208. }
  209. case *nodes.Node_AlterObjectSchemaStmt:
  210. n := inner.AlterObjectSchemaStmt
  211. switch n.ObjectType {
  212. case nodes.ObjectType_OBJECT_TABLE, nodes.ObjectType_OBJECT_VIEW, nodes.ObjectType_OBJECT_MATVIEW:
  213. rel := parseRelationFromRangeVar(n.Relation)
  214. return &ast.AlterTableSetSchemaStmt{
  215. Table: rel.TableName(),
  216. NewSchema: makeString(n.Newschema),
  217. MissingOk: n.MissingOk,
  218. }, nil
  219. case nodes.ObjectType_OBJECT_TYPE:
  220. rel, err := parseRelation(n.Object)
  221. if err != nil {
  222. return nil, err
  223. }
  224. return &ast.AlterTypeSetSchemaStmt{
  225. Type: rel.TypeName(),
  226. NewSchema: makeString(n.Newschema),
  227. }, nil
  228. }
  229. return nil, errSkip
  230. case *nodes.Node_AlterTableStmt:
  231. n := inner.AlterTableStmt
  232. rel := parseRelationFromRangeVar(n.Relation)
  233. at := &ast.AlterTableStmt{
  234. Table: rel.TableName(),
  235. Cmds: &ast.List{},
  236. MissingOk: n.MissingOk,
  237. }
  238. for _, cmd := range n.Cmds {
  239. switch cmdOneOf := cmd.Node.(type) {
  240. case *nodes.Node_AlterTableCmd:
  241. altercmd := cmdOneOf.AlterTableCmd
  242. item := &ast.AlterTableCmd{Name: &altercmd.Name, MissingOk: altercmd.MissingOk}
  243. switch altercmd.Subtype {
  244. case nodes.AlterTableType_AT_AddColumn:
  245. d, ok := altercmd.Def.Node.(*nodes.Node_ColumnDef)
  246. if !ok {
  247. return nil, fmt.Errorf("expected alter table defintion to be a ColumnDef")
  248. }
  249. rel, err := parseRelationFromNodes(d.ColumnDef.TypeName.Names)
  250. if err != nil {
  251. return nil, err
  252. }
  253. item.Subtype = ast.AT_AddColumn
  254. item.Def = &ast.ColumnDef{
  255. Colname: d.ColumnDef.Colname,
  256. TypeName: rel.TypeName(),
  257. IsNotNull: isNotNull(d.ColumnDef),
  258. IsArray: isArray(d.ColumnDef.TypeName),
  259. ArrayDims: len(d.ColumnDef.TypeName.ArrayBounds),
  260. }
  261. case nodes.AlterTableType_AT_AlterColumnType:
  262. d, ok := altercmd.Def.Node.(*nodes.Node_ColumnDef)
  263. if !ok {
  264. return nil, fmt.Errorf("expected alter table defintion to be a ColumnDef")
  265. }
  266. col := ""
  267. if altercmd.Name != "" {
  268. col = altercmd.Name
  269. } else if d.ColumnDef.Colname != "" {
  270. col = d.ColumnDef.Colname
  271. } else {
  272. return nil, fmt.Errorf("unknown name for alter column type")
  273. }
  274. rel, err := parseRelationFromNodes(d.ColumnDef.TypeName.Names)
  275. if err != nil {
  276. return nil, err
  277. }
  278. item.Subtype = ast.AT_AlterColumnType
  279. item.Def = &ast.ColumnDef{
  280. Colname: col,
  281. TypeName: rel.TypeName(),
  282. IsNotNull: isNotNull(d.ColumnDef),
  283. IsArray: isArray(d.ColumnDef.TypeName),
  284. ArrayDims: len(d.ColumnDef.TypeName.ArrayBounds),
  285. }
  286. case nodes.AlterTableType_AT_DropColumn:
  287. item.Subtype = ast.AT_DropColumn
  288. case nodes.AlterTableType_AT_DropNotNull:
  289. item.Subtype = ast.AT_DropNotNull
  290. case nodes.AlterTableType_AT_SetNotNull:
  291. item.Subtype = ast.AT_SetNotNull
  292. default:
  293. continue
  294. }
  295. at.Cmds.Items = append(at.Cmds.Items, item)
  296. }
  297. }
  298. return at, nil
  299. case *nodes.Node_CommentStmt:
  300. n := inner.CommentStmt
  301. switch n.Objtype {
  302. case nodes.ObjectType_OBJECT_COLUMN:
  303. col, tbl, err := parseColName(n.Object)
  304. if err != nil {
  305. return nil, fmt.Errorf("COMMENT ON COLUMN: %w", err)
  306. }
  307. return &ast.CommentOnColumnStmt{
  308. Col: col,
  309. Table: tbl,
  310. Comment: makeString(n.Comment),
  311. }, nil
  312. case nodes.ObjectType_OBJECT_SCHEMA:
  313. o, ok := n.Object.Node.(*nodes.Node_String_)
  314. if !ok {
  315. return nil, fmt.Errorf("COMMENT ON SCHEMA: unexpected node type: %T", n.Object)
  316. }
  317. return &ast.CommentOnSchemaStmt{
  318. Schema: &ast.String{Str: o.String_.Sval},
  319. Comment: makeString(n.Comment),
  320. }, nil
  321. case nodes.ObjectType_OBJECT_TABLE:
  322. rel, err := parseRelation(n.Object)
  323. if err != nil {
  324. return nil, fmt.Errorf("COMMENT ON TABLE: %w", err)
  325. }
  326. return &ast.CommentOnTableStmt{
  327. Table: rel.TableName(),
  328. Comment: makeString(n.Comment),
  329. }, nil
  330. case nodes.ObjectType_OBJECT_TYPE:
  331. rel, err := parseRelation(n.Object)
  332. if err != nil {
  333. return nil, err
  334. }
  335. return &ast.CommentOnTypeStmt{
  336. Type: rel.TypeName(),
  337. Comment: makeString(n.Comment),
  338. }, nil
  339. case nodes.ObjectType_OBJECT_VIEW:
  340. rel, err := parseRelation(n.Object)
  341. if err != nil {
  342. return nil, fmt.Errorf("COMMENT ON VIEW: %w", err)
  343. }
  344. return &ast.CommentOnViewStmt{
  345. View: rel.TableName(),
  346. Comment: makeString(n.Comment),
  347. }, nil
  348. }
  349. return nil, errSkip
  350. case *nodes.Node_CompositeTypeStmt:
  351. n := inner.CompositeTypeStmt
  352. rel := parseRelationFromRangeVar(n.Typevar)
  353. return &ast.CompositeTypeStmt{
  354. TypeName: rel.TypeName(),
  355. }, nil
  356. case *nodes.Node_CreateStmt:
  357. n := inner.CreateStmt
  358. rel := parseRelationFromRangeVar(n.Relation)
  359. create := &ast.CreateTableStmt{
  360. Name: rel.TableName(),
  361. IfNotExists: n.IfNotExists,
  362. }
  363. for _, node := range n.InhRelations {
  364. switch item := node.Node.(type) {
  365. case *nodes.Node_RangeVar:
  366. if item.RangeVar.Inh {
  367. rel := parseRelationFromRangeVar(item.RangeVar)
  368. create.Inherits = append(create.Inherits, rel.TableName())
  369. }
  370. }
  371. }
  372. primaryKey := make(map[string]bool)
  373. for _, elt := range n.TableElts {
  374. switch item := elt.Node.(type) {
  375. case *nodes.Node_Constraint:
  376. if item.Constraint.Contype == nodes.ConstrType_CONSTR_PRIMARY {
  377. for _, key := range item.Constraint.Keys {
  378. // FIXME: Possible nil pointer dereference
  379. primaryKey[key.Node.(*nodes.Node_String_).String_.Sval] = true
  380. }
  381. }
  382. case *nodes.Node_TableLikeClause:
  383. rel := parseRelationFromRangeVar(item.TableLikeClause.Relation)
  384. create.ReferTable = rel.TableName()
  385. }
  386. }
  387. for _, elt := range n.TableElts {
  388. switch item := elt.Node.(type) {
  389. case *nodes.Node_ColumnDef:
  390. rel, err := parseRelationFromNodes(item.ColumnDef.TypeName.Names)
  391. if err != nil {
  392. return nil, err
  393. }
  394. primary := false
  395. for _, con := range item.ColumnDef.Constraints {
  396. if constraint, ok := con.Node.(*nodes.Node_Constraint); ok {
  397. primary = constraint.Constraint.Contype == nodes.ConstrType_CONSTR_PRIMARY
  398. }
  399. }
  400. create.Cols = append(create.Cols, &ast.ColumnDef{
  401. Colname: item.ColumnDef.Colname,
  402. TypeName: rel.TypeName(),
  403. IsNotNull: isNotNull(item.ColumnDef) || primaryKey[item.ColumnDef.Colname],
  404. IsArray: isArray(item.ColumnDef.TypeName),
  405. ArrayDims: len(item.ColumnDef.TypeName.ArrayBounds),
  406. PrimaryKey: primary,
  407. })
  408. }
  409. }
  410. return create, nil
  411. case *nodes.Node_CreateEnumStmt:
  412. n := inner.CreateEnumStmt
  413. rel, err := parseRelationFromNodes(n.TypeName)
  414. if err != nil {
  415. return nil, err
  416. }
  417. stmt := &ast.CreateEnumStmt{
  418. TypeName: rel.TypeName(),
  419. Vals: &ast.List{},
  420. }
  421. for _, val := range n.Vals {
  422. switch v := val.Node.(type) {
  423. case *nodes.Node_String_:
  424. stmt.Vals.Items = append(stmt.Vals.Items, &ast.String{
  425. Str: v.String_.Sval,
  426. })
  427. }
  428. }
  429. return stmt, nil
  430. case *nodes.Node_CreateFunctionStmt:
  431. n := inner.CreateFunctionStmt
  432. fn, err := parseRelationFromNodes(n.Funcname)
  433. if err != nil {
  434. return nil, err
  435. }
  436. var rt *ast.TypeName
  437. if n.ReturnType != nil {
  438. rel, err := parseRelationFromNodes(n.ReturnType.Names)
  439. if err != nil {
  440. return nil, err
  441. }
  442. rt = rel.TypeName()
  443. }
  444. stmt := &ast.CreateFunctionStmt{
  445. Func: fn.FuncName(),
  446. ReturnType: rt,
  447. Replace: n.Replace,
  448. Params: &ast.List{},
  449. }
  450. for _, item := range n.Parameters {
  451. arg := item.Node.(*nodes.Node_FunctionParameter).FunctionParameter
  452. rel, err := parseRelationFromNodes(arg.ArgType.Names)
  453. if err != nil {
  454. return nil, err
  455. }
  456. mode, err := convertFuncParamMode(arg.Mode)
  457. if err != nil {
  458. return nil, err
  459. }
  460. fp := &ast.FuncParam{
  461. Name: &arg.Name,
  462. Type: rel.TypeName(),
  463. Mode: mode,
  464. }
  465. if arg.Defexpr != nil {
  466. fp.DefExpr = &ast.TODO{}
  467. }
  468. stmt.Params.Items = append(stmt.Params.Items, fp)
  469. }
  470. return stmt, nil
  471. case *nodes.Node_CreateSchemaStmt:
  472. n := inner.CreateSchemaStmt
  473. return &ast.CreateSchemaStmt{
  474. Name: makeString(n.Schemaname),
  475. IfNotExists: n.IfNotExists,
  476. }, nil
  477. case *nodes.Node_DropStmt:
  478. n := inner.DropStmt
  479. switch n.RemoveType {
  480. case nodes.ObjectType_OBJECT_FUNCTION:
  481. drop := &ast.DropFunctionStmt{
  482. MissingOk: n.MissingOk,
  483. }
  484. for _, obj := range n.Objects {
  485. nowa, ok := obj.Node.(*nodes.Node_ObjectWithArgs)
  486. if !ok {
  487. return nil, fmt.Errorf("nodes.DropStmt: FUNCTION: unknown type in objects list: %T", obj)
  488. }
  489. owa := nowa.ObjectWithArgs
  490. fn, err := parseRelationFromNodes(owa.Objname)
  491. if err != nil {
  492. return nil, fmt.Errorf("nodes.DropStmt: FUNCTION: %w", err)
  493. }
  494. args := make([]*ast.TypeName, len(owa.Objargs))
  495. for i, objarg := range owa.Objargs {
  496. tn, ok := objarg.Node.(*nodes.Node_TypeName)
  497. if !ok {
  498. return nil, fmt.Errorf("nodes.DropStmt: FUNCTION: unknown type in objargs list: %T", objarg)
  499. }
  500. at, err := parseRelationFromNodes(tn.TypeName.Names)
  501. if err != nil {
  502. return nil, fmt.Errorf("nodes.DropStmt: FUNCTION: %w", err)
  503. }
  504. args[i] = at.TypeName()
  505. }
  506. drop.Funcs = append(drop.Funcs, &ast.FuncSpec{
  507. Name: fn.FuncName(),
  508. Args: args,
  509. HasArgs: !owa.ArgsUnspecified,
  510. })
  511. }
  512. return drop, nil
  513. case nodes.ObjectType_OBJECT_SCHEMA:
  514. drop := &ast.DropSchemaStmt{
  515. MissingOk: n.MissingOk,
  516. }
  517. for _, obj := range n.Objects {
  518. val, ok := obj.Node.(*nodes.Node_String_)
  519. if !ok {
  520. return nil, fmt.Errorf("nodes.DropStmt: SCHEMA: unknown type in objects list: %T", obj)
  521. }
  522. drop.Schemas = append(drop.Schemas, &ast.String{Str: val.String_.Sval})
  523. }
  524. return drop, nil
  525. case nodes.ObjectType_OBJECT_TABLE, nodes.ObjectType_OBJECT_VIEW, nodes.ObjectType_OBJECT_MATVIEW:
  526. drop := &ast.DropTableStmt{
  527. IfExists: n.MissingOk,
  528. }
  529. for _, obj := range n.Objects {
  530. name, err := parseRelation(obj)
  531. if err != nil {
  532. return nil, fmt.Errorf("nodes.DropStmt: TABLE: %w", err)
  533. }
  534. drop.Tables = append(drop.Tables, name.TableName())
  535. }
  536. return drop, nil
  537. case nodes.ObjectType_OBJECT_TYPE:
  538. drop := &ast.DropTypeStmt{
  539. IfExists: n.MissingOk,
  540. }
  541. for _, obj := range n.Objects {
  542. name, err := parseRelation(obj)
  543. if err != nil {
  544. return nil, fmt.Errorf("nodes.DropStmt: TYPE: %w", err)
  545. }
  546. drop.Types = append(drop.Types, name.TypeName())
  547. }
  548. return drop, nil
  549. }
  550. return nil, errSkip
  551. case *nodes.Node_RenameStmt:
  552. n := inner.RenameStmt
  553. switch n.RenameType {
  554. case nodes.ObjectType_OBJECT_COLUMN:
  555. rel := parseRelationFromRangeVar(n.Relation)
  556. return &ast.RenameColumnStmt{
  557. Table: rel.TableName(),
  558. Col: &ast.ColumnRef{Name: n.Subname},
  559. NewName: makeString(n.Newname),
  560. MissingOk: n.MissingOk,
  561. }, nil
  562. case nodes.ObjectType_OBJECT_TABLE:
  563. rel := parseRelationFromRangeVar(n.Relation)
  564. return &ast.RenameTableStmt{
  565. Table: rel.TableName(),
  566. NewName: makeString(n.Newname),
  567. MissingOk: n.MissingOk,
  568. }, nil
  569. case nodes.ObjectType_OBJECT_TYPE:
  570. rel, err := parseRelation(n.Object)
  571. if err != nil {
  572. return nil, fmt.Errorf("nodes.RenameStmt: TYPE: %w", err)
  573. }
  574. return &ast.RenameTypeStmt{
  575. Type: rel.TypeName(),
  576. NewName: makeString(n.Newname),
  577. }, nil
  578. }
  579. return nil, errSkip
  580. default:
  581. return convert(node)
  582. }
  583. }