55 "fmt"
66 "io"
77 "os"
8+ "sync"
89
910 "github.com/dagu-org/dagu/internal/common/cmdutil"
1011 "github.com/dagu-org/dagu/internal/common/logger"
@@ -33,6 +34,7 @@ func getSSHClientFromContext(ctx context.Context) *Client {
3334}
3435
3536type sshExecutor struct {
37+ mu sync.Mutex
3638 step core.Step
3739 client * Client
3840 stdout io.Writer
@@ -76,8 +78,12 @@ func (e *sshExecutor) SetStderr(out io.Writer) {
7678}
7779
7880func (e * sshExecutor ) Kill (_ os.Signal ) error {
79- if e .session != nil {
80- return e .session .Close ()
81+ e .mu .Lock ()
82+ session := e .session
83+ e .mu .Unlock ()
84+
85+ if session != nil {
86+ return session .Close ()
8187 }
8288 return nil
8389}
@@ -98,28 +104,59 @@ func (e *sshExecutor) Run(ctx context.Context) error {
98104 default :
99105 }
100106
101- session , err := e .client .NewSession ()
102- if err != nil {
103- return fmt .Errorf ("command %d: failed to create session: %w" , i + 1 , err )
107+ if err := e .runCommand (ctx , i , cmdEntry ); err != nil {
108+ return err
104109 }
105- e .session = session
110+ }
111+
112+ return nil
113+ }
106114
107- session .Stdout = e .stdout
108- session .Stderr = e .stderr
115+ // runCommand executes a single command with context cancellation support.
116+ // Since session.Run() blocks without context awareness, we run it in a goroutine
117+ // and select on both completion and context cancellation for responsiveness.
118+ func (e * sshExecutor ) runCommand (ctx context.Context , index int , cmdEntry core.CommandEntry ) error {
119+ session , err := e .client .NewSession ()
120+ if err != nil {
121+ return fmt .Errorf ("command %d: failed to create session: %w" , index + 1 , err )
122+ }
109123
110- command := cmdutil .ShellQuote (cmdEntry .Command )
111- if len (cmdEntry .Args ) > 0 {
112- command += " " + cmdutil .ShellQuoteArgs (cmdEntry .Args )
113- }
124+ e .mu .Lock ()
125+ e .session = session
126+ e .mu .Unlock ()
114127
115- err = session .Run (command )
128+ session .Stdout = e .stdout
129+ session .Stderr = e .stderr
130+
131+ command := cmdutil .ShellQuote (cmdEntry .Command )
132+ if len (cmdEntry .Args ) > 0 {
133+ command += " " + cmdutil .ShellQuoteArgs (cmdEntry .Args )
134+ }
135+
136+ // Run command in goroutine to enable context cancellation
137+ done := make (chan error , 1 )
138+ go func () {
139+ done <- session .Run (command )
140+ }()
141+
142+ // Wait for either command completion or context cancellation
143+ select {
144+ case err = <- done :
145+ // Command completed (success or failure)
146+ case <- ctx .Done ():
147+ // Context cancelled - close session to terminate the command
116148 if closeErr := session .Close (); closeErr != nil {
117- logger .Warn (ctx , "SSH session close error" , tag .Error (closeErr ))
149+ logger .Warn (ctx , "SSH session close error during cancellation " , tag .Error (closeErr ))
118150 }
151+ return ctx .Err ()
152+ }
119153
120- if err != nil {
121- return fmt .Errorf ("command %d failed: %w" , i + 1 , err )
122- }
154+ if closeErr := session .Close (); closeErr != nil {
155+ logger .Warn (ctx , "SSH session close error" , tag .Error (closeErr ))
156+ }
157+
158+ if err != nil {
159+ return fmt .Errorf ("command %d failed: %w" , index + 1 , err )
123160 }
124161
125162 return nil
0 commit comments