Skip to content

Commit bea04dd

Browse files
feat: Support graceful job step cancellation (#2714)
* feat: Support graceful job step cancellation * for gh-act-runner * act-cli support as well * respecting always() and cancelled() of steps * change main * cancel startContainer / gh cli / bugreport early * add to watch as well --------- Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 517c3ac commit bea04dd

File tree

8 files changed

+241
-28
lines changed

8 files changed

+241
-28
lines changed

cmd/root.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,8 @@ func newRunCommand(ctx context.Context, input *Input) func(*cobra.Command, []str
391391
}
392392

393393
if ok, _ := cmd.Flags().GetBool("bug-report"); ok {
394+
ctx, cancel := common.EarlyCancelContext(ctx)
395+
defer cancel()
394396
return bugReport(ctx, cmd.Version)
395397
}
396398
if ok, _ := cmd.Flags().GetBool("man-page"); ok {
@@ -430,6 +432,8 @@ func newRunCommand(ctx context.Context, input *Input) func(*cobra.Command, []str
430432
_ = readEnvsEx(input.Secretfile(), secrets, true)
431433

432434
if _, hasGitHubToken := secrets["GITHUB_TOKEN"]; !hasGitHubToken {
435+
ctx, cancel := common.EarlyCancelContext(ctx)
436+
defer cancel()
433437
secrets["GITHUB_TOKEN"], _ = gh.GetToken(ctx, "")
434438
}
435439

@@ -772,10 +776,13 @@ func watchAndRun(ctx context.Context, fn common.Executor) error {
772776
return err
773777
}
774778

779+
earlyCancelCtx, cancel := common.EarlyCancelContext(ctx)
780+
defer cancel()
781+
775782
for folderWatcher.IsRunning() {
776783
log.Debugf("Watching %s for changes", dir)
777784
select {
778-
case <-ctx.Done():
785+
case <-earlyCancelCtx.Done():
779786
return nil
780787
case changes := <-folderWatcher.ChangeDetails():
781788
log.Debugf("%s", changes.String())

main.go

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,18 @@
11
package main
22

33
import (
4-
"context"
54
_ "embed"
6-
"os"
7-
"os/signal"
8-
"syscall"
95

106
"github.com/nektos/act/cmd"
7+
"github.com/nektos/act/pkg/common"
118
)
129

1310
//go:embed VERSION
1411
var version string
1512

1613
func main() {
17-
ctx := context.Background()
18-
ctx, cancel := context.WithCancel(ctx)
19-
20-
// trap Ctrl+C and call cancel on the context
21-
c := make(chan os.Signal, 1)
22-
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
23-
defer func() {
24-
signal.Stop(c)
25-
cancel()
26-
}()
27-
go func() {
28-
select {
29-
case <-c:
30-
cancel()
31-
case <-ctx.Done():
32-
}
33-
}()
14+
ctx, cancel := common.CreateGracefulJobCancellationContext()
15+
defer cancel()
3416

3517
// run the command
3618
cmd.Execute(ctx, version)

main_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package main
2+
3+
import (
4+
"os"
5+
"testing"
6+
)
7+
8+
func TestMain(_ *testing.T) {
9+
os.Args = []string{"act", "--help"}
10+
main()
11+
}

pkg/common/context.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package common
2+
3+
import (
4+
"context"
5+
"os"
6+
"os/signal"
7+
"syscall"
8+
)
9+
10+
func createGracefulJobCancellationContext() (context.Context, func(), chan os.Signal) {
11+
ctx := context.Background()
12+
ctx, forceCancel := context.WithCancel(ctx)
13+
cancelCtx, cancel := context.WithCancel(ctx)
14+
ctx = WithJobCancelContext(ctx, cancelCtx)
15+
16+
// trap Ctrl+C and call cancel on the context
17+
c := make(chan os.Signal, 1)
18+
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
19+
go func() {
20+
select {
21+
case sig := <-c:
22+
if sig == os.Interrupt {
23+
cancel()
24+
select {
25+
case <-c:
26+
forceCancel()
27+
case <-ctx.Done():
28+
}
29+
} else {
30+
forceCancel()
31+
}
32+
case <-ctx.Done():
33+
}
34+
}()
35+
return ctx, func() {
36+
signal.Stop(c)
37+
forceCancel()
38+
cancel()
39+
}, c
40+
}
41+
42+
func CreateGracefulJobCancellationContext() (context.Context, func()) {
43+
ctx, cancel, _ := createGracefulJobCancellationContext()
44+
return ctx, cancel
45+
}

pkg/common/context_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package common
2+
3+
import (
4+
"context"
5+
"os"
6+
"syscall"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestGracefulJobCancellationViaSigint(t *testing.T) {
14+
ctx, cancel, channel := createGracefulJobCancellationContext()
15+
defer cancel()
16+
assert.NotNil(t, ctx)
17+
assert.NotNil(t, cancel)
18+
assert.NotNil(t, channel)
19+
cancelCtx := JobCancelContext(ctx)
20+
assert.NotNil(t, cancelCtx)
21+
assert.NoError(t, ctx.Err())
22+
assert.NoError(t, cancelCtx.Err())
23+
channel <- os.Interrupt
24+
select {
25+
case <-time.After(1 * time.Second):
26+
t.Fatal("context not canceled")
27+
case <-cancelCtx.Done():
28+
case <-ctx.Done():
29+
}
30+
if assert.Error(t, cancelCtx.Err(), "context canceled") {
31+
assert.Equal(t, context.Canceled, cancelCtx.Err())
32+
}
33+
assert.NoError(t, ctx.Err())
34+
channel <- os.Interrupt
35+
select {
36+
case <-time.After(1 * time.Second):
37+
t.Fatal("context not canceled")
38+
case <-ctx.Done():
39+
}
40+
if assert.Error(t, ctx.Err(), "context canceled") {
41+
assert.Equal(t, context.Canceled, ctx.Err())
42+
}
43+
}
44+
45+
func TestForceCancellationViaSigterm(t *testing.T) {
46+
ctx, cancel, channel := createGracefulJobCancellationContext()
47+
defer cancel()
48+
assert.NotNil(t, ctx)
49+
assert.NotNil(t, cancel)
50+
assert.NotNil(t, channel)
51+
cancelCtx := JobCancelContext(ctx)
52+
assert.NotNil(t, cancelCtx)
53+
assert.NoError(t, ctx.Err())
54+
assert.NoError(t, cancelCtx.Err())
55+
channel <- syscall.SIGTERM
56+
select {
57+
case <-time.After(1 * time.Second):
58+
t.Fatal("context not canceled")
59+
case <-cancelCtx.Done():
60+
}
61+
select {
62+
case <-time.After(1 * time.Second):
63+
t.Fatal("context not canceled")
64+
case <-ctx.Done():
65+
}
66+
if assert.Error(t, ctx.Err(), "context canceled") {
67+
assert.Equal(t, context.Canceled, ctx.Err())
68+
}
69+
if assert.Error(t, cancelCtx.Err(), "context canceled") {
70+
assert.Equal(t, context.Canceled, cancelCtx.Err())
71+
}
72+
}
73+
74+
func TestCreateGracefulJobCancellationContext(t *testing.T) {
75+
ctx, cancel := CreateGracefulJobCancellationContext()
76+
defer cancel()
77+
assert.NotNil(t, ctx)
78+
assert.NotNil(t, cancel)
79+
cancelCtx := JobCancelContext(ctx)
80+
assert.NotNil(t, cancelCtx)
81+
assert.NoError(t, cancelCtx.Err())
82+
}
83+
84+
func TestCreateGracefulJobCancellationContextCancelFunc(t *testing.T) {
85+
ctx, cancel := CreateGracefulJobCancellationContext()
86+
assert.NotNil(t, ctx)
87+
assert.NotNil(t, cancel)
88+
cancelCtx := JobCancelContext(ctx)
89+
assert.NotNil(t, cancelCtx)
90+
assert.NoError(t, cancelCtx.Err())
91+
cancel()
92+
if assert.Error(t, ctx.Err(), "context canceled") {
93+
assert.Equal(t, context.Canceled, ctx.Err())
94+
}
95+
if assert.Error(t, cancelCtx.Err(), "context canceled") {
96+
assert.Equal(t, context.Canceled, cancelCtx.Err())
97+
}
98+
}

pkg/common/job_error.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ type jobErrorContextKey string
88

99
const jobErrorContextKeyVal = jobErrorContextKey("job.error")
1010

11+
type jobCancelCtx string
12+
13+
const JobCancelCtxVal = jobCancelCtx("job.cancel")
14+
1115
// JobError returns the job error for current context if any
1216
func JobError(ctx context.Context) error {
1317
val := ctx.Value(jobErrorContextKeyVal)
@@ -28,3 +32,35 @@ func WithJobErrorContainer(ctx context.Context) context.Context {
2832
container := map[string]error{}
2933
return context.WithValue(ctx, jobErrorContextKeyVal, container)
3034
}
35+
36+
func WithJobCancelContext(ctx context.Context, cancelContext context.Context) context.Context {
37+
return context.WithValue(ctx, JobCancelCtxVal, cancelContext)
38+
}
39+
40+
func JobCancelContext(ctx context.Context) context.Context {
41+
val := ctx.Value(JobCancelCtxVal)
42+
if val != nil {
43+
if container, ok := val.(context.Context); ok {
44+
return container
45+
}
46+
}
47+
return nil
48+
}
49+
50+
// EarlyCancelContext returns a new context based on ctx that is canceled when the first of the provided contexts is canceled.
51+
func EarlyCancelContext(ctx context.Context) (context.Context, context.CancelFunc) {
52+
val := JobCancelContext(ctx)
53+
if val != nil {
54+
context, cancel := context.WithCancel(ctx)
55+
go func() {
56+
defer cancel()
57+
select {
58+
case <-context.Done():
59+
case <-ctx.Done():
60+
case <-val.Done():
61+
}
62+
}()
63+
return context, cancel
64+
}
65+
return ctx, func() {}
66+
}

pkg/runner/run_context.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ type RunContext struct {
5151
Masks []string
5252
cleanUpJobContainer common.Executor
5353
caller *caller // job calling this RunContext (reusable workflows)
54+
Cancelled bool
5455
nodeToolFullPath string
5556
}
5657

@@ -435,6 +436,8 @@ func (rc *RunContext) execJobContainer(cmd []string, env map[string]string, user
435436

436437
func (rc *RunContext) InitializeNodeTool() common.Executor {
437438
return func(ctx context.Context) error {
439+
ctx, cancel := common.EarlyCancelContext(ctx)
440+
defer cancel()
438441
rc.GetNodeToolFullPath(ctx)
439442
return nil
440443
}
@@ -651,6 +654,8 @@ func (rc *RunContext) interpolateOutputs() common.Executor {
651654

652655
func (rc *RunContext) startContainer() common.Executor {
653656
return func(ctx context.Context) error {
657+
ctx, cancel := common.EarlyCancelContext(ctx)
658+
defer cancel()
654659
if rc.IsHostEnv(ctx) {
655660
return rc.startHostEnvironment()(ctx)
656661
}
@@ -845,10 +850,14 @@ func trimToLen(s string, l int) string {
845850

846851
func (rc *RunContext) getJobContext() *model.JobContext {
847852
jobStatus := "success"
848-
for _, stepStatus := range rc.StepResults {
849-
if stepStatus.Conclusion == model.StepStatusFailure {
850-
jobStatus = "failure"
851-
break
853+
if rc.Cancelled {
854+
jobStatus = "cancelled"
855+
} else {
856+
for _, stepStatus := range rc.StepResults {
857+
if stepStatus.Conclusion == model.StepStatusFailure {
858+
jobStatus = "failure"
859+
break
860+
}
852861
}
853862
}
854863
return &model.JobContext{

pkg/runner/step.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ func runStepExecutor(step step, stage stepStage, executor common.Executor) commo
8585
return err
8686
}
8787

88+
cctx := common.JobCancelContext(ctx)
89+
rc.Cancelled = cctx != nil && cctx.Err() != nil
90+
8891
runStep, err := isStepEnabled(ctx, ifExpression, step, stage)
8992
if err != nil {
9093
stepResult.Conclusion = model.StepStatusFailure
@@ -140,10 +143,14 @@ func runStepExecutor(step step, stage stepStage, executor common.Executor) commo
140143
Mode: 0o666,
141144
})(ctx)
142145

143-
timeoutctx, cancelTimeOut := evaluateStepTimeout(ctx, rc.ExprEval, stepModel)
146+
stepCtx, cancelStepCtx := context.WithCancel(ctx)
147+
defer cancelStepCtx()
148+
var cancelTimeOut context.CancelFunc
149+
stepCtx, cancelTimeOut = evaluateStepTimeout(stepCtx, rc.ExprEval, stepModel)
144150
defer cancelTimeOut()
151+
monitorJobCancellation(ctx, stepCtx, cctx, rc, logger, ifExpression, step, stage, cancelStepCtx)
145152
startTime := time.Now()
146-
err = executor(timeoutctx)
153+
err = executor(stepCtx)
147154
executionTime := time.Since(startTime)
148155

149156
if err == nil {
@@ -192,6 +199,24 @@ func runStepExecutor(step step, stage stepStage, executor common.Executor) commo
192199
}
193200
}
194201

202+
func monitorJobCancellation(ctx context.Context, stepCtx context.Context, jobCancellationCtx context.Context, rc *RunContext, logger logrus.FieldLogger, ifExpression string, step step, stage stepStage, cancelStepCtx context.CancelFunc) {
203+
if !rc.Cancelled && jobCancellationCtx != nil {
204+
go func() {
205+
select {
206+
case <-jobCancellationCtx.Done():
207+
rc.Cancelled = true
208+
logger.Infof("Reevaluate condition %v due to cancellation", ifExpression)
209+
keepStepRunning, err := isStepEnabled(ctx, ifExpression, step, stage)
210+
logger.Infof("Result condition keepStepRunning=%v", keepStepRunning)
211+
if !keepStepRunning || err != nil {
212+
cancelStepCtx()
213+
}
214+
case <-stepCtx.Done():
215+
}
216+
}()
217+
}
218+
}
219+
195220
func evaluateStepTimeout(ctx context.Context, exprEval ExpressionEvaluator, stepModel *model.Step) (context.Context, context.CancelFunc) {
196221
timeout := exprEval.Interpolate(ctx, stepModel.TimeoutMinutes)
197222
if timeout != "" {

0 commit comments

Comments
 (0)