package compiler import ( "errors" "fmt" "io" "io/ioutil" "path/filepath" "regexp" "strings" "github.com/kyleconroy/sqlc/internal/metadata" "github.com/kyleconroy/sqlc/internal/migrations" "github.com/kyleconroy/sqlc/internal/multierr" "github.com/kyleconroy/sqlc/internal/opts" "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/catalog" "github.com/kyleconroy/sqlc/internal/sql/sqlerr" "github.com/kyleconroy/sqlc/internal/sql/sqlpath" ) type Parser interface { Parse(io.Reader) ([]ast.Statement, error) CommentSyntax() metadata.CommentSyntax } // copied over from gen.go func structName(name string) string { out := "" for _, p := range strings.Split(name, "_") { if p == "id" { out += "ID" } else { out += strings.Title(p) } } return out } var identPattern = regexp.MustCompile("[^a-zA-Z0-9_]+") func enumValueName(value string) string { name := "" id := strings.Replace(value, "-", "_", -1) id = strings.Replace(id, ":", "_", -1) id = strings.Replace(id, "/", "_", -1) id = identPattern.ReplaceAllString(id, "") for _, part := range strings.Split(id, "_") { name += strings.Title(part) } return name } // end copypasta func parseCatalog(p Parser, c *catalog.Catalog, schemas []string) error { files, err := sqlpath.Glob(schemas) if err != nil { return err } merr := multierr.New() for _, filename := range files { blob, err := ioutil.ReadFile(filename) if err != nil { merr.Add(filename, "", 0, err) continue } contents := migrations.RemoveRollbackStatements(string(blob)) stmts, err := p.Parse(strings.NewReader(contents)) if err != nil { merr.Add(filename, contents, 0, err) continue } for i := range stmts { if err := c.Update(stmts[i]); err != nil { merr.Add(filename, contents, stmts[i].Pos(), err) continue } } } if len(merr.Errs()) > 0 { return merr } return nil } func parseQueries(p Parser, c *catalog.Catalog, queries []string, o opts.Parser) (*Result, error) { var q []*Query merr := multierr.New() set := map[string]struct{}{} files, err := sqlpath.Glob(queries) if err != nil { return nil, err } for _, filename := range files { blob, err := ioutil.ReadFile(filename) if err != nil { merr.Add(filename, "", 0, err) continue } src := string(blob) stmts, err := p.Parse(strings.NewReader(src)) if err != nil { merr.Add(filename, src, 0, err) continue } for _, stmt := range stmts { query, err := parseQuery(p, c, stmt.Raw, src, o.UsePositionalParameters) if err == ErrUnsupportedStatementType { continue } if err != nil { var e *sqlerr.Error loc := stmt.Raw.Pos() if errors.As(err, &e) && e.Location != 0 { loc = e.Location } merr.Add(filename, src, loc, err) continue } if query.Name != "" { if _, exists := set[query.Name]; exists { merr.Add(filename, src, stmt.Raw.Pos(), fmt.Errorf("duplicate query name: %s", query.Name)) continue } set[query.Name] = struct{}{} } query.Filename = filepath.Base(filename) if query != nil { q = append(q, query) } } } if len(merr.Errs()) > 0 { return nil, merr } if len(q) == 0 { return nil, fmt.Errorf("no queries contained in paths %s", strings.Join(queries, ",")) } return &Result{ Catalog: c, Queries: q, }, nil }