123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484 |
- package mysql
- import (
- "fmt"
- "io"
- "io/ioutil"
- "path/filepath"
- "github.com/davecgh/go-spew/spew"
- "vitess.io/vitess/go/vt/sqlparser"
- "github.com/kyleconroy/sqlc/internal/config"
- "github.com/kyleconroy/sqlc/internal/metadata"
- "github.com/kyleconroy/sqlc/internal/migrations"
- "github.com/kyleconroy/sqlc/internal/multierr"
- "github.com/kyleconroy/sqlc/internal/sql/sqlpath"
- )
- // Query holds the data for walking and validating mysql querys
- type Query struct {
- SQL string // the string representation of the parsed query
- Columns []Column
- Params []*Param // "?" params in the query string
- Name string // the Go function name
- Cmd string // TODO: Pick a better name. One of: one, many, exec, execrows
- DefaultTableName string // for columns that are not qualified
- Filename string
- }
- type Column struct {
- *sqlparser.ColumnDefinition
- Table string
- }
- func parsePath(sqlPath []string, generator PackageGenerator) (*Result, error) {
- files, err := sqlpath.Glob(sqlPath)
- if err != nil {
- return nil, err
- }
- parseErrors := multierr.New()
- parsedQueries := []*Query{}
- for _, filename := range files {
- blob, err := ioutil.ReadFile(filename)
- if err != nil {
- parseErrors.Add(filename, "", 0, err)
- }
- contents := migrations.RemoveRollbackStatements(string(blob))
- if err != nil {
- parseErrors.Add(filename, "", 0, err)
- continue
- }
- t := sqlparser.NewStringTokenizer(contents)
- var start int
- for {
- q, err := sqlparser.ParseNextStrictDDL(t)
- if err == io.EOF {
- break
- } else if err != nil {
- if posErr, ok := err.(sqlparser.PositionedErr); ok {
- message := fmt.Errorf(posErr.Err)
- if posErr.Near != nil {
- message = fmt.Errorf("%s at or near \"%s\"", posErr.Err, posErr.Near)
- }
- parseErrors.Add(filename, contents, posErr.Pos, message)
- } else {
- parseErrors.Add(filename, contents, start, err)
- }
- continue
- }
- query := contents[start : t.Position-1]
- result, err := generator.parseQueryString(q, query)
- if err != nil {
- parseErrors.Add(filename, contents, start, err)
- start = t.Position
- continue
- }
- start = t.Position
- if result == nil {
- continue
- }
- result.Filename = filepath.Base(filename)
- parsedQueries = append(parsedQueries, result)
- }
- }
- if len(parseErrors.Errs()) > 0 {
- return nil, parseErrors
- }
- return &Result{
- Queries: parsedQueries,
- PackageGenerator: generator,
- }, nil
- }
- func (pGen PackageGenerator) parseQueryString(tree sqlparser.Statement, query string) (*Query, error) {
- var parsedQuery *Query
- switch tree := tree.(type) {
- case *sqlparser.Select:
- selectQuery, err := pGen.parseSelect(tree, query)
- if err != nil {
- return nil, err
- }
- parsedQuery = selectQuery
- case *sqlparser.Insert:
- insert, err := pGen.parseInsert(tree, query)
- if err != nil {
- return nil, err
- }
- parsedQuery = insert
- case *sqlparser.Update:
- update, err := pGen.parseUpdate(tree, query)
- if err != nil {
- return nil, err
- }
- parsedQuery = update
- case *sqlparser.Delete:
- delete, err := pGen.parseDelete(tree, query)
- if err != nil {
- return nil, err
- }
- parsedQuery = delete
- case *sqlparser.DDL:
- pGen.Schema.Add(tree)
- return nil, nil
- default:
- // panic("Unsupported SQL statement type")
- return nil, nil
- }
- paramsReplacedQuery, err := replaceParamStrs(sqlparser.String(tree), parsedQuery.Params)
- if err != nil {
- return nil, fmt.Errorf("failed to replace param variables in query string: %w", err)
- }
- parsedQuery.SQL = paramsReplacedQuery
- return parsedQuery, nil
- }
- func (q *Query) parseNameAndCmd() error {
- if q == nil {
- return fmt.Errorf("cannot parse name and cmd from null query")
- }
- _, comments := sqlparser.SplitMarginComments(q.SQL)
- name, cmd, err := metadata.Parse(comments.Leading, metadata.CommentSyntaxStar)
- if err != nil {
- return err
- } else if name == "" || cmd == "" {
- return fmt.Errorf("failed to parse query leading comment")
- }
- q.Name = name
- q.Cmd = cmd
- return nil
- }
- func (pGen PackageGenerator) parseSelect(tree *sqlparser.Select, query string) (*Query, error) {
- tableAliasMap, defaultTableName, err := parseFrom(tree.From, false)
- if err != nil {
- return nil, fmt.Errorf("failed to parse table name alias's: %w", err)
- }
- // handle * expressions first by expanding all columns of the default table
- _, ok := tree.SelectExprs[0].(*sqlparser.StarExpr)
- if ok {
- colNames := []sqlparser.SelectExpr{}
- colDfns := pGen.Schema.tables[defaultTableName]
- for _, col := range colDfns {
- colNames = append(colNames, &sqlparser.AliasedExpr{
- Expr: &sqlparser.ColName{
- Name: col.Name,
- }},
- )
- }
- tree.SelectExprs = colNames
- }
- parsedQuery := Query{
- SQL: query,
- DefaultTableName: defaultTableName,
- }
- cols, err := pGen.parseSelectAliasExpr(tree.SelectExprs, tableAliasMap, defaultTableName)
- if err != nil {
- return nil, err
- }
- parsedQuery.Columns = cols
- whereParams, err := pGen.paramsInWhereExpr(tree.Where, tableAliasMap, defaultTableName)
- if err != nil {
- return nil, err
- }
- limitParams, err := pGen.paramsInLimitExpr(tree.Limit, tableAliasMap)
- if err != nil {
- return nil, err
- }
- parsedQuery.Params = append(whereParams, limitParams...)
- err = parsedQuery.parseNameAndCmd()
- if err != nil {
- return nil, err
- }
- return &parsedQuery, nil
- }
// FromTable is a single table referenced by a query's FROM clause.
type FromTable struct {
	// TrueName is the table's real name as declared in the schema.
	TrueName string

	// IsLeftJoined marks a table reached through an outer join, whose
	// columns may therefore be NULL.
	IsLeftJoined bool
}

// FromTables maps each table reference of a FROM clause (its alias, or its
// real name when it has no alias) to the table it denotes.
type FromTables map[string]FromTable
- func parseFrom(from sqlparser.TableExprs, isLeftJoined bool) (FromTables, string, error) {
- tables := make(map[string]FromTable)
- var defaultTableName string
- for _, expr := range from {
- switch v := expr.(type) {
- case *sqlparser.AliasedTableExpr:
- name, ok := v.Expr.(sqlparser.TableName)
- if !ok {
- return nil, "", fmt.Errorf("failed to parse AliasedTableExpr name: %v", spew.Sdump(v))
- }
- t := FromTable{
- TrueName: name.Name.String(),
- IsLeftJoined: isLeftJoined,
- }
- if v.As.String() != "" {
- tables[v.As.String()] = t
- } else {
- tables[name.Name.String()] = t
- }
- defaultTableName = name.Name.String()
- case *sqlparser.JoinTableExpr:
- isLeftJoin := v.Join == "left join"
- left, leftMostTableName, err := parseFrom([]sqlparser.TableExpr{v.LeftExpr}, false)
- if err != nil {
- return nil, "", err
- }
- right, _, err := parseFrom([]sqlparser.TableExpr{v.RightExpr}, isLeftJoin)
- if err != nil {
- return nil, "", err
- }
- // merge the left and right maps
- for k, v := range left {
- right[k] = v
- }
- return right, leftMostTableName, nil
- default:
- return nil, "", fmt.Errorf("failed to parse table expr: %v", spew.Sdump(v))
- }
- }
- return tables, defaultTableName, nil
- }
- func (pGen PackageGenerator) parseUpdate(node *sqlparser.Update, query string) (*Query, error) {
- tableAliasMap, defaultTable, err := parseFrom(node.TableExprs, false)
- if err != nil {
- return nil, fmt.Errorf("failed to parse table name alias's: %w", err)
- }
- params := []*Param{}
- for _, updateExpr := range node.Exprs {
- newValue, isValue := updateExpr.Expr.(*sqlparser.SQLVal)
- if !isValue {
- continue
- } else if isParam := newValue.Type == sqlparser.ValArg; !isParam {
- continue
- }
- col, err := pGen.getColType(updateExpr.Name, tableAliasMap, defaultTable)
- if err != nil {
- return nil, fmt.Errorf("failed to determine type of a parameter's column: %w", err)
- }
- originalParamName := string(newValue.Val)
- param := Param{
- OriginalName: originalParamName,
- Name: paramName(col.Name, originalParamName),
- Typ: pGen.goTypeCol(*col),
- }
- params = append(params, ¶m)
- }
- whereParams, err := pGen.paramsInWhereExpr(node.Where, tableAliasMap, defaultTable)
- if err != nil {
- return nil, fmt.Errorf("failed to parse params from WHERE expression: %w", err)
- }
- parsedQuery := Query{
- SQL: query,
- Columns: nil,
- Params: append(params, whereParams...),
- DefaultTableName: defaultTable,
- }
- err = parsedQuery.parseNameAndCmd()
- if err != nil {
- return nil, err
- }
- return &parsedQuery, nil
- }
- func (pGen PackageGenerator) parseInsert(node *sqlparser.Insert, query string) (*Query, error) {
- params := []*Param{}
- cols := node.Columns
- tableName := node.Table.Name.String()
- switch rows := node.Rows.(type) {
- case *sqlparser.Select:
- selectQuery, err := pGen.parseSelect(rows, query)
- if err != nil {
- return nil, err
- }
- params = append(params, selectQuery.Params...)
- case sqlparser.Values:
- for _, row := range rows {
- for colIx, item := range row {
- switch v := item.(type) {
- case *sqlparser.SQLVal:
- if v.Type == sqlparser.ValArg {
- colName := cols[colIx].String()
- col, err := pGen.schemaLookup(tableName, colName)
- varName := string(v.Val)
- param := &Param{OriginalName: varName}
- if err == nil {
- param.Name = paramName(col.Name, varName)
- param.Typ = pGen.goTypeCol(*col)
- } else {
- param.Name = "Unknown"
- param.Typ = "interface{}"
- }
- params = append(params, param)
- }
- case *sqlparser.FuncExpr:
- name, raw, err := matchFuncExpr(v)
- if err != nil {
- return nil, err
- }
- if name == "" || raw == "" {
- continue
- }
- colName := cols[colIx].String()
- col, err := pGen.schemaLookup(tableName, colName)
- param := &Param{
- OriginalName: raw,
- }
- if err == nil {
- param.Name = name
- param.Typ = pGen.goTypeCol(*col)
- } else {
- param.Name = "Unknown"
- param.Typ = "interface{}"
- }
- params = append(params, param)
- default:
- return nil, fmt.Errorf("failed to parse insert query value")
- }
- }
- }
- default:
- return nil, fmt.Errorf("Unknown insert row type of %T", node.Rows)
- }
- parsedQuery := &Query{
- SQL: query,
- Params: params,
- Columns: nil,
- DefaultTableName: tableName,
- }
- err := parsedQuery.parseNameAndCmd()
- if err != nil {
- return nil, err
- }
- return parsedQuery, nil
- }
- func (pGen PackageGenerator) parseDelete(node *sqlparser.Delete, query string) (*Query, error) {
- tableAliasMap, defaultTableName, err := parseFrom(node.TableExprs, false)
- if err != nil {
- return nil, fmt.Errorf("failed to parse table name alias's: %w", err)
- }
- whereParams, err := pGen.paramsInWhereExpr(node.Where, tableAliasMap, defaultTableName)
- if err != nil {
- return nil, err
- }
- limitParams, err := pGen.paramsInLimitExpr(node.Limit, tableAliasMap)
- if err != nil {
- return nil, err
- }
- parsedQuery := &Query{
- SQL: query,
- Params: append(whereParams, limitParams...),
- Columns: nil,
- DefaultTableName: defaultTableName,
- }
- err = parsedQuery.parseNameAndCmd()
- if err != nil {
- return nil, err
- }
- return parsedQuery, nil
- }
- func (pGen PackageGenerator) parseSelectAliasExpr(exprs sqlparser.SelectExprs, tableAliasMap FromTables, defaultTable string) ([]Column, error) {
- cols := []Column{}
- for _, col := range exprs {
- switch expr := col.(type) {
- case *sqlparser.AliasedExpr:
- hasAlias := !expr.As.IsEmpty()
- switch v := expr.Expr.(type) {
- case *sqlparser.ColName:
- res, err := pGen.getColType(v, tableAliasMap, defaultTable)
- if err != nil {
- return nil, err
- }
- if hasAlias {
- res.Name = expr.As // applys the alias
- }
- cols = append(cols, *res)
- case *sqlparser.GroupConcatExpr:
- cols = append(cols, Column{
- ColumnDefinition: &sqlparser.ColumnDefinition{
- Name: sqlparser.NewColIdent(expr.As.String()),
- Type: sqlparser.ColumnType{
- Type: "varchar",
- NotNull: true,
- },
- },
- Table: "", // group concat expressions don't originate from a table schema
- },
- )
- case *sqlparser.FuncExpr:
- funcName := v.Name.Lowered()
- funcType := functionReturnType(funcName)
- var returnVal sqlparser.ColIdent
- if hasAlias {
- returnVal = expr.As
- } else {
- returnVal = sqlparser.NewColIdent(funcName)
- }
- colDfn := &sqlparser.ColumnDefinition{
- Name: returnVal,
- Type: sqlparser.ColumnType{
- Type: funcType,
- NotNull: true,
- },
- }
- cols = append(cols, Column{colDfn, ""}) // func returns types don't originate from a table schema
- }
- default:
- return nil, fmt.Errorf("Failed to handle select expr of type : %T", expr)
- }
- }
- return cols, nil
- }
- // GeneratePkg is the main entry to mysql generator package
- func GeneratePkg(pkgName string, schemaPath, querysPath []string, settings config.CombinedSettings) (*Result, error) {
- s := NewSchema()
- generator := PackageGenerator{
- Schema: s,
- CombinedSettings: settings,
- packageName: pkgName,
- }
- _, err := parsePath(schemaPath, generator)
- if err != nil {
- return nil, err
- }
- result, err := parsePath(querysPath, generator)
- if err != nil {
- return nil, err
- }
- return result, nil
- }
|