1
0

analyze.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. package analyzer
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "sync"
  8. "github.com/jackc/pgx/v5"
  9. "github.com/jackc/pgx/v5/pgconn"
  10. "github.com/jackc/pgx/v5/pgxpool"
  11. core "github.com/sqlc-dev/sqlc/internal/analysis"
  12. "github.com/sqlc-dev/sqlc/internal/config"
  13. "github.com/sqlc-dev/sqlc/internal/opts"
  14. pb "github.com/sqlc-dev/sqlc/internal/quickdb/v1"
  15. "github.com/sqlc-dev/sqlc/internal/shfmt"
  16. "github.com/sqlc-dev/sqlc/internal/sql/ast"
  17. "github.com/sqlc-dev/sqlc/internal/sql/named"
  18. "github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
  19. )
// Analyzer holds the database configuration, the client and connection pool
// used to reach the database, and per-instance caches for catalog lookups.
type Analyzer struct {
	db       config.Database
	client   pb.QuickClient
	pool     *pgxpool.Pool // nil until the first Analyze call creates it
	dbg      opts.Debug
	replacer *shfmt.Replacer
	formats  sync.Map // formatKey -> string, filled by formatType
	columns  sync.Map // columnKey -> *pgColumn, filled by columnInfo
	tables   sync.Map // table OID (uint32) -> *pgTable, filled by tableInfo
}
  30. func New(client pb.QuickClient, db config.Database) *Analyzer {
  31. return &Analyzer{
  32. db: db,
  33. dbg: opts.DebugFromEnv(),
  34. client: client,
  35. replacer: shfmt.NewReplacer(nil),
  36. }
  37. }
// columnQuery selects the formatted data type, the NOT NULL flag and the
// declared array dimensions of a single pg_attribute row, identified by the
// owning relation's OID ($1) and the attribute number ($2).
const columnQuery = `
SELECT
pg_catalog.format_type(pg_attribute.atttypid, pg_attribute.atttypmod) AS data_type,
pg_attribute.attnotnull as not_null,
pg_attribute.attndims as array_dims
FROM
pg_catalog.pg_attribute
WHERE
attrelid = $1
AND attnum = $2;
`
// tableQuery selects the relation name and the name of its namespace
// (schema) for the pg_class row whose OID is $1.
const tableQuery = `
SELECT
pg_class.relname as table_name,
pg_namespace.nspname as schema_name
FROM
pg_catalog.pg_class
JOIN
pg_catalog.pg_namespace ON pg_namespace.oid = pg_class.relnamespace
WHERE
pg_class.oid = $1;
`
// pgTable is one row of tableQuery; the db tags match the column aliases
// used in that query.
type pgTable struct {
	TableName  string `db:"table_name"`
	SchemaName string `db:"schema_name"`
}
  64. // Cache these types in memory
  65. func (a *Analyzer) tableInfo(ctx context.Context, oid uint32) (*pgTable, error) {
  66. ctbl, ok := a.tables.Load(oid)
  67. if ok {
  68. return ctbl.(*pgTable), nil
  69. }
  70. rows, err := a.pool.Query(ctx, tableQuery, oid)
  71. if err != nil {
  72. return nil, err
  73. }
  74. tbl, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[pgTable])
  75. if err != nil {
  76. return nil, err
  77. }
  78. a.tables.Store(oid, &tbl)
  79. return &tbl, nil
  80. }
// pgColumn is one row of columnQuery; the db tags match the column aliases
// used in that query.
type pgColumn struct {
	DataType  string `db:"data_type"`
	NotNull   bool   `db:"not_null"`
	ArrayDims int    `db:"array_dims"`
}
// columnKey is the cache key used by columnInfo: the table OID and the
// attribute number taken from a result field description.
type columnKey struct {
	OID  uint32
	Attr uint16
}
  90. // Cache these types in memory
  91. func (a *Analyzer) columnInfo(ctx context.Context, field pgconn.FieldDescription) (*pgColumn, error) {
  92. key := columnKey{field.TableOID, field.TableAttributeNumber}
  93. cinfo, ok := a.columns.Load(key)
  94. if ok {
  95. return cinfo.(*pgColumn), nil
  96. }
  97. rows, err := a.pool.Query(ctx, columnQuery, field.TableOID, int16(field.TableAttributeNumber))
  98. if err != nil {
  99. return nil, err
  100. }
  101. col, err := pgx.CollectOneRow(rows, pgx.RowToStructByName[pgColumn])
  102. if err != nil {
  103. return nil, err
  104. }
  105. a.columns.Store(key, &col)
  106. return &col, nil
  107. }
// formatKey is the cache key used by formatType: a type OID together with
// the type modifier (stored in Modified) that was passed to format_type.
type formatKey struct {
	OID      uint32
	Modified int32
}
  112. // TODO: Use PGX to do the lookup for basic OID types
  113. func (a *Analyzer) formatType(ctx context.Context, oid uint32, modifier int32) (string, error) {
  114. key := formatKey{oid, modifier}
  115. ftyp, ok := a.formats.Load(key)
  116. if ok {
  117. return ftyp.(string), nil
  118. }
  119. rows, err := a.pool.Query(ctx, `SELECT format_type($1, $2)`, oid, modifier)
  120. if err != nil {
  121. return "", err
  122. }
  123. dt, err := pgx.CollectOneRow(rows, pgx.RowTo[string])
  124. if err != nil {
  125. return "", err
  126. }
  127. a.formats.Store(key, dt)
  128. return dt, err
  129. }
  130. // TODO: This is bad
  131. func rewriteType(dt string) string {
  132. switch {
  133. case strings.HasPrefix(dt, "character("):
  134. return "pg_catalog.bpchar"
  135. case strings.HasPrefix(dt, "character varying"):
  136. return "pg_catalog.varchar"
  137. case strings.HasPrefix(dt, "bit varying"):
  138. return "pg_catalog.varbit"
  139. case strings.HasPrefix(dt, "bit("):
  140. return "pg_catalog.bit"
  141. }
  142. switch dt {
  143. case "bpchar":
  144. return "pg_catalog.bpchar"
  145. case "timestamp without time zone":
  146. return "pg_catalog.timestamp"
  147. case "timestamp with time zone":
  148. return "pg_catalog.timestamptz"
  149. case "time without time zone":
  150. return "pg_catalog.time"
  151. case "time with time zone":
  152. return "pg_catalog.timetz"
  153. }
  154. return dt
  155. }
  156. func parseType(dt string) (string, bool, int) {
  157. size := 0
  158. for {
  159. trimmed := strings.TrimSuffix(dt, "[]")
  160. if trimmed == dt {
  161. return rewriteType(dt), size > 0, size
  162. }
  163. size += 1
  164. dt = trimmed
  165. }
  166. }
// Analyze prepares query on the database and converts the resulting
// statement description into a core.Analysis: one column per result field
// and one parameter per parameter OID.
//
// The connection pool is created on the first call and reused afterwards, so
// a database is not created per query. For a managed database an ephemeral
// database is requested through a.client using migrations; otherwise
// a.db.URI is passed through a.replacer and used as the connection string,
// unless the debug options restrict connections to managed databases.
//
// A *pgconn.PgError returned while preparing the query is converted into a
// *sqlerr.Error whose Location is derived from n.Pos() and the error
// position. When ps is non-nil it supplies the parameter names.
func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrations []string, ps *named.ParamSet) (*core.Analysis, error) {
	// extractSqlErr turns a PostgreSQL error into a sqlerr.Error; any other
	// error is returned as-is.
	extractSqlErr := func(e error) error {
		var pgErr *pgconn.PgError
		if errors.As(e, &pgErr) {
			return &sqlerr.Error{
				Code:    pgErr.Code,
				Message: pgErr.Message,
				// Offset the reported position by the node position and
				// clamp at zero. NOTE(review): the -1 presumably accounts
				// for a 1-based Position — confirm.
				Location: max(n.Pos()+int(pgErr.Position)-1, 0),
			}
		}
		return e
	}

	// Lazily create the pool on first use.
	if a.pool == nil {
		var uri string
		if a.db.Managed {
			if a.client == nil {
				return nil, fmt.Errorf("client is nil")
			}
			edb, err := a.client.CreateEphemeralDatabase(ctx, &pb.CreateEphemeralDatabaseRequest{
				Engine:     "postgresql",
				Migrations: migrations,
			})
			if err != nil {
				return nil, err
			}
			uri = edb.Uri
		} else if a.dbg.OnlyManagedDatabases {
			return nil, fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed")
		} else {
			uri = a.replacer.Replace(a.db.URI)
		}
		conf, err := pgxpool.ParseConfig(uri)
		if err != nil {
			return nil, err
		}
		pool, err := pgxpool.NewWithConfig(ctx, conf)
		if err != nil {
			return nil, err
		}
		a.pool = pool
	}

	c, err := a.pool.Acquire(ctx)
	if err != nil {
		return nil, err
	}
	defer c.Release()

	// Prepare the statement only to obtain its description, then drop it
	// again so the fixed name can be reused by the next call.
	// TODO: Pick a random name
	desc, err := c.Conn().Prepare(ctx, "foo", query)
	if err != nil {
		return nil, extractSqlErr(err)
	}
	if err := c.Conn().Deallocate(ctx, "foo"); err != nil {
		return nil, err
	}

	var result core.Analysis
	for _, field := range desc.Fields {
		if field.TableOID > 0 {
			// The field has a table OID: take type and nullability from
			// the catalog entry of the underlying column.
			col, err := a.columnInfo(ctx, field)
			if err != nil {
				return nil, err
			}
			// debug.Dump(i, field, col)
			tbl, err := a.tableInfo(ctx, field.TableOID)
			if err != nil {
				return nil, err
			}
			// The dimension count parsed from the type name is discarded
			// here; ArrayDims comes from the catalog value instead.
			// TODO: Why are these dims different?
			dt, isArray, _ := parseType(col.DataType)
			notNull := col.NotNull
			name := field.Name
			result.Columns = append(result.Columns, &core.Column{
				Name:         name,
				OriginalName: field.Name,
				DataType:     dt,
				NotNull:      notNull,
				IsArray:      isArray,
				ArrayDims:    int32(col.ArrayDims),
				Table: &core.Identifier{
					Schema: tbl.SchemaName,
					Name:   tbl.TableName,
				},
			})
		} else {
			// No table OID: format the field's own type OID and modifier,
			// and report the column as nullable.
			dataType, err := a.formatType(ctx, field.DataTypeOID, field.TypeModifier)
			if err != nil {
				return nil, err
			}
			// debug.Dump(i, field, dataType)
			notNull := false
			name := field.Name
			dt, isArray, dims := parseType(dataType)
			result.Columns = append(result.Columns, &core.Column{
				Name:         name,
				OriginalName: field.Name,
				DataType:     dt,
				NotNull:      notNull,
				IsArray:      isArray,
				ArrayDims:    int32(dims),
			})
		}
	}

	for i, oid := range desc.ParamOIDs {
		// Parameters are formatted with a type modifier of -1.
		dataType, err := a.formatType(ctx, oid, -1)
		if err != nil {
			return nil, err
		}
		notNull := false
		dt, isArray, dims := parseType(dataType)
		name := ""
		if ps != nil {
			// Parameter numbers are 1-based; the second result of NameFor
			// is ignored, so name is whatever NameFor returns.
			name, _ = ps.NameFor(i + 1)
		}
		result.Params = append(result.Params, &core.Parameter{
			Number: int32(i + 1),
			Column: &core.Column{
				Name:      name,
				DataType:  dt,
				IsArray:   isArray,
				ArrayDims: int32(dims),
				NotNull:   notNull,
			},
		})
	}
	return &result, nil
}
  293. func (a *Analyzer) Close(_ context.Context) error {
  294. if a.pool != nil {
  295. a.pool.Close()
  296. }
  297. return nil
  298. }