123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
- package compiler
- import (
- "errors"
- "fmt"
- "io"
- "os"
- "path/filepath"
- "regexp"
- "strings"
- "github.com/kyleconroy/sqlc/internal/metadata"
- "github.com/kyleconroy/sqlc/internal/migrations"
- "github.com/kyleconroy/sqlc/internal/multierr"
- "github.com/kyleconroy/sqlc/internal/opts"
- "github.com/kyleconroy/sqlc/internal/sql/ast"
- "github.com/kyleconroy/sqlc/internal/sql/sqlerr"
- "github.com/kyleconroy/sqlc/internal/sql/sqlpath"
- )
- // TODO: Rename this interface Engine
- type Parser interface {
- Parse(io.Reader) ([]ast.Statement, error)
- CommentSyntax() metadata.CommentSyntax
- IsReservedKeyword(string) bool
- }
- // copied over from gen.go
- func structName(name string) string {
- out := ""
- for _, p := range strings.Split(name, "_") {
- if p == "id" {
- out += "ID"
- } else {
- out += strings.Title(p)
- }
- }
- return out
- }
- var identPattern = regexp.MustCompile("[^a-zA-Z0-9_]+")
- func enumValueName(value string) string {
- name := ""
- id := strings.Replace(value, "-", "_", -1)
- id = strings.Replace(id, ":", "_", -1)
- id = strings.Replace(id, "/", "_", -1)
- id = identPattern.ReplaceAllString(id, "")
- for _, part := range strings.Split(id, "_") {
- name += strings.Title(part)
- }
- return name
- }
- // end copypasta
- func (c *Compiler) parseCatalog(schemas []string) error {
- files, err := sqlpath.Glob(schemas)
- if err != nil {
- return err
- }
- merr := multierr.New()
- for _, filename := range files {
- blob, err := os.ReadFile(filename)
- if err != nil {
- merr.Add(filename, "", 0, err)
- continue
- }
- contents := migrations.RemoveRollbackStatements(string(blob))
- stmts, err := c.parser.Parse(strings.NewReader(contents))
- if err != nil {
- merr.Add(filename, contents, 0, err)
- continue
- }
- for i := range stmts {
- if err := c.catalog.Update(stmts[i], c); err != nil {
- merr.Add(filename, contents, stmts[i].Pos(), err)
- continue
- }
- }
- }
- if len(merr.Errs()) > 0 {
- return merr
- }
- return nil
- }
- func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) {
- var q []*Query
- merr := multierr.New()
- set := map[string]struct{}{}
- files, err := sqlpath.Glob(c.conf.Queries)
- if err != nil {
- return nil, err
- }
- for _, filename := range files {
- blob, err := os.ReadFile(filename)
- if err != nil {
- merr.Add(filename, "", 0, err)
- continue
- }
- src := string(blob)
- stmts, err := c.parser.Parse(strings.NewReader(src))
- if err != nil {
- merr.Add(filename, src, 0, err)
- continue
- }
- for _, stmt := range stmts {
- query, err := c.parseQuery(stmt.Raw, src, o)
- if err == ErrUnsupportedStatementType {
- continue
- }
- if err != nil {
- var e *sqlerr.Error
- loc := stmt.Raw.Pos()
- if errors.As(err, &e) && e.Location != 0 {
- loc = e.Location
- }
- merr.Add(filename, src, loc, err)
- continue
- }
- if query.Name != "" {
- if _, exists := set[query.Name]; exists {
- merr.Add(filename, src, stmt.Raw.Pos(), fmt.Errorf("duplicate query name: %s", query.Name))
- continue
- }
- set[query.Name] = struct{}{}
- }
- query.Filename = filepath.Base(filename)
- if query != nil {
- q = append(q, query)
- }
- }
- }
- if len(merr.Errs()) > 0 {
- return nil, merr
- }
- if len(q) == 0 {
- return nil, fmt.Errorf("no queries contained in paths %s", strings.Join(c.conf.Queries, ","))
- }
- return &Result{
- Catalog: c.catalog,
- Queries: q,
- }, nil
- }
|