printer.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. package printer
  2. import (
  3. "strconv"
  4. "strings"
  5. "github.com/kyleconroy/sqlc/internal/python/ast"
  6. )
  type (
  	// writer accumulates generated Python source text.
  	writer struct {
  		// options is the configuration passed to Print.
  		options Options
  		// src is the output buffer every print* method appends to.
  		src []byte
  	}

  	// Options configures Print. It currently has no fields.
  	Options struct{}

  	// PrintResult holds the output of Print.
  	PrintResult struct {
  		// Python is the rendered Python source.
  		Python []byte
  	}
  )
  16. func Print(node *ast.Node, options Options) PrintResult {
  17. w := writer{options: options}
  18. w.printNode(node, 0)
  19. return PrintResult{
  20. Python: w.src,
  21. }
  22. }
  23. func (w *writer) print(text string) {
  24. w.src = append(w.src, text...)
  25. }
  26. func (w *writer) printIndent(indent int32) {
  27. for i, n := 0, int(indent); i < n; i++ {
  28. w.src = append(w.src, " "...)
  29. }
  30. }
  31. func (w *writer) printNode(node *ast.Node, indent int32) {
  32. switch n := node.Node.(type) {
  33. case *ast.Node_Alias:
  34. w.print(n.Alias.Name)
  35. case *ast.Node_AnnAssign:
  36. w.printAnnAssign(n.AnnAssign, indent)
  37. case *ast.Node_Assign:
  38. w.printAssign(n.Assign, indent)
  39. case *ast.Node_AsyncFor:
  40. w.printAsyncFor(n.AsyncFor, indent)
  41. case *ast.Node_AsyncFunctionDef:
  42. w.printAsyncFunctionDef(n.AsyncFunctionDef, indent)
  43. case *ast.Node_Attribute:
  44. w.printAttribute(n.Attribute, indent)
  45. case *ast.Node_Await:
  46. w.printAwait(n.Await, indent)
  47. case *ast.Node_Call:
  48. w.printCall(n.Call, indent)
  49. case *ast.Node_ClassDef:
  50. w.printClassDef(n.ClassDef, indent)
  51. case *ast.Node_Comment:
  52. w.printComment(n.Comment, indent)
  53. case *ast.Node_Compare:
  54. w.printCompare(n.Compare, indent)
  55. case *ast.Node_Constant:
  56. w.printConstant(n.Constant, indent)
  57. case *ast.Node_Dict:
  58. w.printDict(n.Dict, indent)
  59. case *ast.Node_Expr:
  60. w.printNode(n.Expr.Value, indent)
  61. case *ast.Node_For:
  62. w.printFor(n.For, indent)
  63. case *ast.Node_FunctionDef:
  64. w.printFunctionDef(n.FunctionDef, indent)
  65. case *ast.Node_If:
  66. w.printIf(n.If, indent)
  67. case *ast.Node_Import:
  68. w.printImport(n.Import, indent)
  69. case *ast.Node_ImportFrom:
  70. w.printImportFrom(n.ImportFrom, indent)
  71. case *ast.Node_ImportGroup:
  72. w.printImportGroup(n.ImportGroup, indent)
  73. case *ast.Node_Is:
  74. w.print("is")
  75. case *ast.Node_Keyword:
  76. w.printKeyword(n.Keyword, indent)
  77. case *ast.Node_Module:
  78. w.printModule(n.Module, indent)
  79. case *ast.Node_Name:
  80. w.print(n.Name.Id)
  81. case *ast.Node_Pass:
  82. w.print("pass")
  83. case *ast.Node_Return:
  84. w.printReturn(n.Return, indent)
  85. case *ast.Node_Subscript:
  86. w.printSubscript(n.Subscript, indent)
  87. case *ast.Node_Yield:
  88. w.printYield(n.Yield, indent)
  89. default:
  90. panic(n)
  91. }
  92. }
  93. func (w *writer) printAnnAssign(aa *ast.AnnAssign, indent int32) {
  94. if aa.Comment != "" {
  95. w.print("# ")
  96. w.print(aa.Comment)
  97. w.print("\n")
  98. w.printIndent(indent)
  99. }
  100. w.printName(aa.Target, indent)
  101. w.print(": ")
  102. w.printNode(aa.Annotation, indent)
  103. }
  104. func (w *writer) printArg(a *ast.Arg, indent int32) {
  105. w.print(a.Arg)
  106. if a.Annotation != nil {
  107. w.print(": ")
  108. w.printNode(a.Annotation, indent)
  109. }
  110. }
  111. func (w *writer) printAssign(a *ast.Assign, indent int32) {
  112. for i, name := range a.Targets {
  113. w.printNode(name, indent)
  114. if i != len(a.Targets)-1 {
  115. w.print(", ")
  116. }
  117. }
  118. w.print(" = ")
  119. w.printNode(a.Value, indent)
  120. }
  121. func (w *writer) printAsyncFor(n *ast.AsyncFor, indent int32) {
  122. w.print("async ")
  123. w.printFor(&ast.For{
  124. Target: n.Target,
  125. Iter: n.Iter,
  126. Body: n.Body,
  127. }, indent)
  128. }
  129. func (w *writer) printAsyncFunctionDef(afd *ast.AsyncFunctionDef, indent int32) {
  130. w.print("async ")
  131. w.printFunctionDef(&ast.FunctionDef{
  132. Name: afd.Name,
  133. Args: afd.Args,
  134. Body: afd.Body,
  135. Returns: afd.Returns,
  136. }, indent)
  137. }
  138. func (w *writer) printAttribute(a *ast.Attribute, indent int32) {
  139. if _, ok := a.Value.Node.(*ast.Node_Await); ok {
  140. w.print("(")
  141. w.printNode(a.Value, indent)
  142. w.print(")")
  143. } else {
  144. w.printNode(a.Value, indent)
  145. }
  146. w.print(".")
  147. w.print(a.Attr)
  148. }
  149. func (w *writer) printAwait(n *ast.Await, indent int32) {
  150. w.print("await ")
  151. w.printNode(n.Value, indent)
  152. }
  153. func (w *writer) printCall(c *ast.Call, indent int32) {
  154. w.printNode(c.Func, indent)
  155. w.print("(")
  156. for i, a := range c.Args {
  157. w.printNode(a, indent)
  158. if i != len(c.Args)-1 {
  159. w.print(", ")
  160. }
  161. }
  162. for _, kw := range c.Keywords {
  163. w.print("\n")
  164. w.printIndent(indent + 1)
  165. w.printKeyword(kw, indent+1)
  166. w.print(",")
  167. }
  168. if len(c.Keywords) > 0 {
  169. w.print("\n")
  170. w.printIndent(indent)
  171. }
  172. w.print(")")
  173. }
  174. func (w *writer) printClassDef(cd *ast.ClassDef, indent int32) {
  175. for _, node := range cd.DecoratorList {
  176. w.print("@")
  177. w.printNode(node, indent)
  178. w.print("\n")
  179. }
  180. w.print("class ")
  181. w.print(cd.Name)
  182. if len(cd.Bases) > 0 {
  183. w.print("(")
  184. for i, node := range cd.Bases {
  185. w.printNode(node, indent)
  186. if i != len(cd.Bases)-1 {
  187. w.print(", ")
  188. }
  189. }
  190. w.print(")")
  191. }
  192. w.print(":\n")
  193. for i, node := range cd.Body {
  194. if i != 0 {
  195. if _, ok := node.Node.(*ast.Node_FunctionDef); ok {
  196. w.print("\n")
  197. }
  198. if _, ok := node.Node.(*ast.Node_AsyncFunctionDef); ok {
  199. w.print("\n")
  200. }
  201. }
  202. w.printIndent(indent + 1)
  203. // A docstring is a string literal that occurs as the first
  204. // statement in a module, function, class, or method
  205. // definition. Such a docstring becomes the __doc__ special
  206. // attribute of that object.
  207. if i == 0 {
  208. if e, ok := node.Node.(*ast.Node_Expr); ok {
  209. if c, ok := e.Expr.Value.Node.(*ast.Node_Constant); ok {
  210. w.print(`""`)
  211. w.printConstant(c.Constant, indent)
  212. w.print(`""`)
  213. w.print("\n")
  214. continue
  215. }
  216. }
  217. }
  218. w.printNode(node, indent+1)
  219. w.print("\n")
  220. }
  221. }
  222. func (w *writer) printConstant(c *ast.Constant, indent int32) {
  223. switch n := c.Value.(type) {
  224. case *ast.Constant_Int:
  225. w.print(strconv.Itoa(int(n.Int)))
  226. case *ast.Constant_None:
  227. w.print("None")
  228. case *ast.Constant_Str:
  229. str := `"`
  230. if strings.Contains(n.Str, "\n") {
  231. str = `"""`
  232. }
  233. w.print(str)
  234. w.print(n.Str)
  235. w.print(str)
  236. default:
  237. panic(n)
  238. }
  239. }
  240. func (w *writer) printComment(c *ast.Comment, indent int32) {
  241. w.print("# ")
  242. w.print(c.Text)
  243. w.print("\n")
  244. }
  245. func (w *writer) printCompare(c *ast.Compare, indent int32) {
  246. w.printNode(c.Left, indent)
  247. w.print(" ")
  248. for _, node := range c.Ops {
  249. w.printNode(node, indent)
  250. w.print(" ")
  251. }
  252. for _, node := range c.Comparators {
  253. w.printNode(node, indent)
  254. }
  255. }
  256. func (w *writer) printDict(d *ast.Dict, indent int32) {
  257. if len(d.Keys) != len(d.Values) {
  258. panic(`dict keys and values are not the same length`)
  259. }
  260. w.print("{")
  261. split := len(d.Keys) > 3
  262. keyIndent := indent
  263. if split {
  264. keyIndent += 1
  265. }
  266. for i, _ := range d.Keys {
  267. if split {
  268. w.print("\n")
  269. w.printIndent(keyIndent)
  270. }
  271. w.printNode(d.Keys[i], keyIndent)
  272. w.print(": ")
  273. w.printNode(d.Values[i], keyIndent)
  274. if i != len(d.Keys)-1 || split {
  275. if split {
  276. w.print(",")
  277. } else {
  278. w.print(", ")
  279. }
  280. }
  281. }
  282. if split {
  283. w.print("\n")
  284. w.printIndent(indent)
  285. }
  286. w.print("}")
  287. }
  288. func (w *writer) printFor(n *ast.For, indent int32) {
  289. w.print("for ")
  290. w.printNode(n.Target, indent)
  291. w.print(" in ")
  292. w.printNode(n.Iter, indent)
  293. w.print(":\n")
  294. for i, node := range n.Body {
  295. w.printIndent(indent + 1)
  296. w.printNode(node, indent+1)
  297. if i != len(n.Body)-1 {
  298. w.print("\n")
  299. }
  300. }
  301. }
  302. func (w *writer) printIf(i *ast.If, indent int32) {
  303. w.print("if ")
  304. w.printNode(i.Test, indent)
  305. w.print(":\n")
  306. for j, node := range i.Body {
  307. w.printIndent(indent + 1)
  308. w.printNode(node, indent+1)
  309. if j != len(i.Body)-1 {
  310. w.print("\n")
  311. }
  312. }
  313. }
  314. func (w *writer) printFunctionDef(fd *ast.FunctionDef, indent int32) {
  315. w.print("def ")
  316. w.print(fd.Name)
  317. w.print("(")
  318. if fd.Args != nil {
  319. for i, arg := range fd.Args.Args {
  320. w.printArg(arg, indent)
  321. if i != len(fd.Args.Args)-1 {
  322. w.print(", ")
  323. }
  324. }
  325. if len(fd.Args.KwOnlyArgs) > 0 {
  326. w.print(", *, ")
  327. for i, arg := range fd.Args.KwOnlyArgs {
  328. w.printArg(arg, indent)
  329. if i != len(fd.Args.KwOnlyArgs)-1 {
  330. w.print(", ")
  331. }
  332. }
  333. }
  334. }
  335. w.print(")")
  336. if fd.Returns != nil {
  337. w.print(" -> ")
  338. w.printNode(fd.Returns, indent)
  339. }
  340. w.print(":\n")
  341. for i, node := range fd.Body {
  342. w.printIndent(indent + 1)
  343. w.printNode(node, indent+1)
  344. if i != len(fd.Body)-1 {
  345. w.print("\n")
  346. }
  347. }
  348. }
  349. func (w *writer) printImport(imp *ast.Import, indent int32) {
  350. w.print("import ")
  351. for i, node := range imp.Names {
  352. w.printNode(node, indent)
  353. if i != len(imp.Names)-1 {
  354. w.print(", ")
  355. }
  356. }
  357. w.print("\n")
  358. }
  359. func (w *writer) printImportFrom(imp *ast.ImportFrom, indent int32) {
  360. w.print("from ")
  361. w.print(imp.Module)
  362. w.print(" import ")
  363. for i, node := range imp.Names {
  364. w.printNode(node, indent)
  365. if i != len(imp.Names)-1 {
  366. w.print(", ")
  367. }
  368. }
  369. w.print("\n")
  370. }
  371. func (w *writer) printImportGroup(n *ast.ImportGroup, indent int32) {
  372. if len(n.Imports) == 0 {
  373. return
  374. }
  375. for _, node := range n.Imports {
  376. w.printNode(node, indent)
  377. }
  378. w.print("\n")
  379. }
  380. func (w *writer) printIs(i *ast.Is, indent int32) {
  381. w.print("is")
  382. }
  383. func (w *writer) printKeyword(k *ast.Keyword, indent int32) {
  384. w.print(k.Arg)
  385. w.print("=")
  386. w.printNode(k.Value, indent)
  387. }
  388. func (w *writer) printModule(mod *ast.Module, indent int32) {
  389. for i, node := range mod.Body {
  390. prevIsImport := false
  391. if i > 0 {
  392. _, isImport := mod.Body[i-1].Node.(*ast.Node_ImportGroup)
  393. prevIsImport = isImport
  394. }
  395. _, isClassDef := node.Node.(*ast.Node_ClassDef)
  396. _, isAssign := node.Node.(*ast.Node_Assign)
  397. if isClassDef || isAssign {
  398. if prevIsImport {
  399. w.print("\n")
  400. } else {
  401. w.print("\n\n")
  402. }
  403. }
  404. w.printNode(node, indent)
  405. if isAssign {
  406. w.print("\n")
  407. }
  408. }
  409. }
  410. func (w *writer) printName(n *ast.Name, indent int32) {
  411. w.print(n.Id)
  412. }
  413. func (w *writer) printReturn(r *ast.Return, indent int32) {
  414. w.print("return ")
  415. w.printNode(r.Value, indent)
  416. }
  417. func (w *writer) printSubscript(ss *ast.Subscript, indent int32) {
  418. w.printName(ss.Value, indent)
  419. w.print("[")
  420. w.printNode(ss.Slice, indent)
  421. w.print("]")
  422. }
  423. func (w *writer) printYield(n *ast.Yield, indent int32) {
  424. w.print("yield ")
  425. w.printNode(n.Value, indent)
  426. }