graph.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. package ast
  2. import (
  3. "fmt"
  4. "os"
  5. "sync"
  6. "github.com/dominikbraun/graph"
  7. "github.com/dominikbraun/graph/draw"
  8. "golang.org/x/sync/errgroup"
  9. )
  10. type TaskfileGraph struct {
  11. sync.Mutex
  12. graph.Graph[string, *TaskfileVertex]
  13. }
  14. // A TaskfileVertex is a vertex on the Taskfile DAG.
  15. type TaskfileVertex struct {
  16. URI string
  17. Taskfile *Taskfile
  18. }
  19. func taskfileHash(vertex *TaskfileVertex) string {
  20. return vertex.URI
  21. }
  22. func NewTaskfileGraph() *TaskfileGraph {
  23. return &TaskfileGraph{
  24. sync.Mutex{},
  25. graph.New(taskfileHash,
  26. graph.Directed(),
  27. graph.PreventCycles(),
  28. graph.Rooted(),
  29. ),
  30. }
  31. }
  32. func (tfg *TaskfileGraph) Visualize(filename string) error {
  33. f, err := os.Create(filename)
  34. if err != nil {
  35. return err
  36. }
  37. defer f.Close()
  38. return draw.DOT(tfg.Graph, f)
  39. }
  40. func (tfg *TaskfileGraph) Merge() (*Taskfile, error) {
  41. hashes, err := graph.TopologicalSort(tfg.Graph)
  42. if err != nil {
  43. return nil, err
  44. }
  45. predecessorMap, err := tfg.PredecessorMap()
  46. if err != nil {
  47. return nil, err
  48. }
  49. // Loop over each vertex in reverse topological order except for the root vertex.
  50. // This gives us a loop over every included Taskfile in an order which is safe to merge.
  51. for i := len(hashes) - 1; i > 0; i-- {
  52. hash := hashes[i]
  53. // Get the included vertex
  54. includedVertex, err := tfg.Vertex(hash)
  55. if err != nil {
  56. return nil, err
  57. }
  58. // Create an error group to wait for all the included Taskfiles to be merged with all its parents
  59. var g errgroup.Group
  60. // Loop over edge that leads to a vertex that includes the current vertex
  61. for _, edge := range predecessorMap[hash] {
  62. // Start a goroutine to process each included Taskfile
  63. g.Go(func() error {
  64. // Get the base vertex
  65. vertex, err := tfg.Vertex(edge.Source)
  66. if err != nil {
  67. return err
  68. }
  69. // Get the merge options
  70. includes, ok := edge.Properties.Data.([]*Include)
  71. if !ok {
  72. return fmt.Errorf("task: Failed to get merge options")
  73. }
  74. // Merge the included Taskfiles into the parent Taskfile
  75. for _, include := range includes {
  76. if err := vertex.Taskfile.Merge(
  77. includedVertex.Taskfile,
  78. include,
  79. ); err != nil {
  80. return err
  81. }
  82. }
  83. return nil
  84. })
  85. if err := g.Wait(); err != nil {
  86. return nil, err
  87. }
  88. }
  89. // Wait for all the go routines to finish
  90. if err := g.Wait(); err != nil {
  91. return nil, err
  92. }
  93. }
  94. // Get the root vertex
  95. rootVertex, err := tfg.Vertex(hashes[0])
  96. if err != nil {
  97. return nil, err
  98. }
  99. _ = rootVertex.Taskfile.Tasks.Range(func(name string, task *Task) error {
  100. if task == nil {
  101. task = &Task{}
  102. rootVertex.Taskfile.Tasks.Set(name, task)
  103. }
  104. task.Task = name
  105. return nil
  106. })
  107. return rootVertex.Taskfile, nil
  108. }