Skip to content

Commit b595bdd

Browse files
committed
fix bug
1 parent 4a16049 commit b595bdd

1 file changed

Lines changed: 54 additions & 17 deletions

File tree

  • internal/runtime/builtin/ssh

internal/runtime/builtin/ssh/ssh.go

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"io"
77
"os"
8+
"sync"
89

910
"github.com/dagu-org/dagu/internal/common/cmdutil"
1011
"github.com/dagu-org/dagu/internal/common/logger"
@@ -33,6 +34,7 @@ func getSSHClientFromContext(ctx context.Context) *Client {
3334
}
3435

3536
type sshExecutor struct {
37+
mu sync.Mutex
3638
step core.Step
3739
client *Client
3840
stdout io.Writer
@@ -76,8 +78,12 @@ func (e *sshExecutor) SetStderr(out io.Writer) {
7678
}
7779

7880
func (e *sshExecutor) Kill(_ os.Signal) error {
79-
if e.session != nil {
80-
return e.session.Close()
81+
e.mu.Lock()
82+
session := e.session
83+
e.mu.Unlock()
84+
85+
if session != nil {
86+
return session.Close()
8187
}
8288
return nil
8389
}
@@ -98,28 +104,59 @@ func (e *sshExecutor) Run(ctx context.Context) error {
98104
default:
99105
}
100106

101-
session, err := e.client.NewSession()
102-
if err != nil {
103-
return fmt.Errorf("command %d: failed to create session: %w", i+1, err)
107+
if err := e.runCommand(ctx, i, cmdEntry); err != nil {
108+
return err
104109
}
105-
e.session = session
110+
}
111+
112+
return nil
113+
}
106114

107-
session.Stdout = e.stdout
108-
session.Stderr = e.stderr
115+
// runCommand executes a single command with context cancellation support.
116+
// Since session.Run() blocks without context awareness, we run it in a goroutine
117+
// and select on both completion and context cancellation for responsiveness.
118+
func (e *sshExecutor) runCommand(ctx context.Context, index int, cmdEntry core.CommandEntry) error {
119+
session, err := e.client.NewSession()
120+
if err != nil {
121+
return fmt.Errorf("command %d: failed to create session: %w", index+1, err)
122+
}
109123

110-
command := cmdutil.ShellQuote(cmdEntry.Command)
111-
if len(cmdEntry.Args) > 0 {
112-
command += " " + cmdutil.ShellQuoteArgs(cmdEntry.Args)
113-
}
124+
e.mu.Lock()
125+
e.session = session
126+
e.mu.Unlock()
114127

115-
err = session.Run(command)
128+
session.Stdout = e.stdout
129+
session.Stderr = e.stderr
130+
131+
command := cmdutil.ShellQuote(cmdEntry.Command)
132+
if len(cmdEntry.Args) > 0 {
133+
command += " " + cmdutil.ShellQuoteArgs(cmdEntry.Args)
134+
}
135+
136+
// Run command in goroutine to enable context cancellation
137+
done := make(chan error, 1)
138+
go func() {
139+
done <- session.Run(command)
140+
}()
141+
142+
// Wait for either command completion or context cancellation
143+
select {
144+
case err = <-done:
145+
// Command completed (success or failure)
146+
case <-ctx.Done():
147+
// Context cancelled - close session to terminate the command
116148
if closeErr := session.Close(); closeErr != nil {
117-
logger.Warn(ctx, "SSH session close error", tag.Error(closeErr))
149+
logger.Warn(ctx, "SSH session close error during cancellation", tag.Error(closeErr))
118150
}
151+
return ctx.Err()
152+
}
119153

120-
if err != nil {
121-
return fmt.Errorf("command %d failed: %w", i+1, err)
122-
}
154+
if closeErr := session.Close(); closeErr != nil {
155+
logger.Warn(ctx, "SSH session close error", tag.Error(closeErr))
156+
}
157+
158+
if err != nil {
159+
return fmt.Errorf("command %d failed: %w", index+1, err)
123160
}
124161

125162
return nil

0 commit comments

Comments
 (0)