Skip to content

Commit 27020ac

Browse files
ethanalee-workgopherbot
authored andcommitted
internal/server: add module upgrade pathway after vulncheck scanning
- Provide prompt messages indicating vulnerabilities found and actions that users can take to update dependencies and `go mod tidy`. Change-Id: I7d18fb48ac53ee3e4857fa3ff4be2968260e33db Reviewed-on: https://go-review.googlesource.com/c/tools/+/736122 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Hongxiang Jiang <[email protected]> Auto-Submit: Ethan Lee <[email protected]>
1 parent c4ec0f5 commit 27020ac

File tree

2 files changed

+164
-52
lines changed

2 files changed

+164
-52
lines changed

gopls/internal/server/vulncheck_prompt.go

Lines changed: 160 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package server
66

77
import (
8+
"bytes"
89
"context"
910
"crypto/sha256"
1011
"encoding/hex"
@@ -18,10 +19,13 @@ import (
1819
"strings"
1920

2021
"golang.org/x/mod/modfile"
22+
"golang.org/x/mod/semver"
23+
"golang.org/x/tools/gopls/internal/cache"
2124
"golang.org/x/tools/gopls/internal/filecache"
2225
"golang.org/x/tools/gopls/internal/progress"
2326
"golang.org/x/tools/gopls/internal/protocol"
2427
"golang.org/x/tools/gopls/internal/settings"
28+
"golang.org/x/tools/gopls/internal/vulncheck/govulncheck"
2529
"golang.org/x/tools/internal/event"
2630
"golang.org/x/tools/internal/xcontext"
2731
)
@@ -38,60 +42,60 @@ const (
3842
func computeGoModHash(file *modfile.File) (string, error) {
3943
h := sha256.New()
4044
for _, req := range file.Require {
41-
if _, err := h.Write([]byte(req.Mod.Path + req.Mod.Version)); err != nil {
45+
if _, err := h.Write([]byte(req.Mod.Path + "\x00" + req.Mod.Version)); err != nil {
4246
return "", err
4347
}
4448
}
4549
for _, exc := range file.Exclude {
46-
if _, err := h.Write([]byte(exc.Mod.Path + exc.Mod.Version)); err != nil {
50+
if _, err := h.Write([]byte(exc.Mod.Path + "\x00" + exc.Mod.Version)); err != nil {
4751
return "", err
4852
}
4953
}
5054
for _, rep := range file.Replace {
51-
if _, err := h.Write([]byte(rep.Old.Path + rep.Old.Version + rep.New.Path + rep.New.Version)); err != nil {
55+
if _, err := h.Write([]byte(rep.Old.Path + "\x00" + rep.Old.Version + "\x00" + rep.New.Path + "\x00" + rep.New.Version)); err != nil {
5256
return "", err
5357
}
5458
}
5559
return hex.EncodeToString(h.Sum(nil)), nil
5660
}
5761

62+
func getModFileHashes(uri protocol.DocumentURI) (contentHash string, pathHash [32]byte, err error) {
63+
content, err := os.ReadFile(uri.Path())
64+
if err != nil {
65+
return "", [32]byte{}, err
66+
}
67+
newModFile, err := modfile.Parse("go.mod", content, nil)
68+
if err != nil {
69+
return "", [32]byte{}, err
70+
}
71+
contentHash, err = computeGoModHash(newModFile)
72+
if err != nil {
73+
return "", [32]byte{}, err
74+
}
75+
pathHash = sha256.Sum256([]byte(uri.Path()))
76+
return contentHash, pathHash, nil
77+
}
78+
5879
func (s *server) checkGoModDeps(ctx context.Context, uri protocol.DocumentURI) {
5980
if s.Options().Vulncheck != settings.ModeVulncheckPrompt {
6081
return
6182
}
6283
ctx, done := event.Start(ctx, "server.CheckGoModDeps")
6384
defer done()
6485

65-
var (
66-
newHash, oldHash string
67-
pathHash [32]byte
68-
)
69-
{
70-
newContent, err := os.ReadFile(uri.Path())
71-
if err != nil {
72-
event.Error(ctx, "reading new go.mod content failed", err)
73-
return
74-
}
75-
newModFile, err := modfile.Parse("go.mod", newContent, nil)
76-
if err != nil {
77-
event.Error(ctx, "parsing new go.mod failed", err)
78-
return
79-
}
80-
hash, err := computeGoModHash(newModFile)
81-
if err != nil {
82-
event.Error(ctx, "computing new go.mod hash failed", err)
83-
return
84-
}
85-
newHash = hash
86+
newHash, pathHash, err := getModFileHashes(uri)
87+
if err != nil {
88+
event.Error(ctx, "getting go.mod hashes failed", err)
89+
return
90+
}
8691

87-
pathHash = sha256.Sum256([]byte(uri.Path()))
88-
oldHashBytes, err := filecache.Get(goModHashKind, pathHash)
89-
if err != nil && err != filecache.ErrNotFound {
90-
event.Error(ctx, "reading old go.mod hash from filecache failed", err)
91-
return
92-
}
93-
oldHash = string(oldHashBytes)
92+
oldHashBytes, err := filecache.Get(goModHashKind, pathHash)
93+
if err != nil && err != filecache.ErrNotFound {
94+
event.Error(ctx, "reading old go.mod hash from filecache failed", err)
95+
return
9496
}
97+
oldHash := string(oldHashBytes)
98+
9599
if oldHash != newHash {
96100
fileLink := fmt.Sprintf("[%s](%s)", uri.Path(), string(uri))
97101
govulncheckLink := "[govulncheck](https://pkg.go.dev/golang.org/x/vuln/cmd/govulncheck)"
@@ -153,22 +157,14 @@ func (s *server) handleVulncheck(ctx context.Context, uri protocol.DocumentURI)
153157
workDoneWriter := progress.NewWorkDoneWriter(ctx, work)
154158
result, err := s.runVulncheck(ctx, snapshot, uri, "./...", workDoneWriter)
155159
if err != nil {
156-
event.Error(ctx, "vulncheck failed", err)
160+
event.Error(ctx, "govulncheck failed", err)
161+
showMessage(ctx, s.client, protocol.Error, fmt.Sprintf("govulncheck failed: %v", err))
157162
return
158163
}
159164

160-
affecting := make(map[string]struct{})
161-
upgrades := make(map[string]string)
162-
for _, f := range result.Findings {
163-
if len(f.Trace) > 1 {
164-
affecting[f.OSV] = struct{}{}
165-
if f.FixedVersion != "" && f.Trace[0].Module != "stdlib" {
166-
upgrades[f.Trace[0].Module] = f.FixedVersion
167-
}
168-
}
169-
}
170-
171-
if len(affecting) == 0 {
165+
affecting, stdLibVulns, modulesToUpgrade := computeModulesToUpgrade(result.Findings)
166+
numStdLib := len(stdLibVulns)
167+
if len(affecting) == 0 && numStdLib == 0 {
172168
showMessage(ctx, s.client, protocol.Info, "No vulnerabilities found.")
173169
return
174170
}
@@ -177,7 +173,7 @@ func (s *server) handleVulncheck(ctx context.Context, uri protocol.DocumentURI)
177173
sort.Strings(affectingOSVs)
178174

179175
var b strings.Builder
180-
fmt.Fprintf(&b, "Found %d vulnerabilities affecting your dependencies:\n\n", len(affectingOSVs))
176+
fmt.Fprintf(&b, "Found %d actionable vulnerabilities and %d standard library vulnerabilities affecting your dependencies:\n\n", len(affectingOSVs), numStdLib)
181177

182178
for i, id := range affectingOSVs {
183179
if i >= maxVulnsToShow {
@@ -197,16 +193,132 @@ func (s *server) handleVulncheck(ctx context.Context, uri protocol.DocumentURI)
197193
fmt.Fprintf(&b, "\n\n...and %d more.", len(affectingOSVs)-maxVulnsToShow)
198194
}
199195

200-
action, err := showMessageRequest(ctx, s.client, protocol.Warning, b.String(), "Upgrade All", "Ignore")
196+
if numStdLib > 0 {
197+
if b.Len() > 0 {
198+
b.WriteString("\n\n")
199+
}
200+
b.WriteString("Upgrading your Go version may address vulnerabilities in the standard library.")
201+
}
202+
203+
actions := []string{"Ignore"}
204+
if len(modulesToUpgrade) > 0 {
205+
actions = append([]string{"Upgrade All"}, actions...)
206+
}
207+
208+
action, err := showMessageRequest(ctx, s.client, protocol.Warning, b.String(), actions...)
201209
if err != nil {
202210
event.Error(ctx, "vulncheck remediation failed", err)
203211
return
204212
}
205213

206214
if action == "Upgrade All" {
207-
// TODO: Add dependency upgrade functionality.
208-
showMessage(ctx, s.client, protocol.Info, "Upgrading all modules... (not yet implemented)")
215+
if err := s.upgradeModules(ctx, snapshot, uri, modulesToUpgrade); err != nil {
216+
event.Error(ctx, "upgrading modules failed", err)
217+
}
218+
}
219+
}
220+
221+
func computeModulesToUpgrade(findings []*govulncheck.Finding) (affecting map[string]bool, stdLibVulns map[string]bool, modulesToUpgrade map[string]string) {
222+
affecting = make(map[string]bool)
223+
stdLibVulns = make(map[string]bool)
224+
modulesToUpgrade = make(map[string]string)
225+
226+
for _, f := range findings {
227+
if len(f.Trace) == 0 {
228+
continue
229+
}
230+
mod := f.Trace[0].Module
231+
// An empty module path or "stdlib" indicates a standard library package.
232+
// These vulnerabilities cannot be remediated via module upgrades and
233+
// instead require updating the Go toolchain.
234+
if mod == "stdlib" || mod == "" {
235+
stdLibVulns[f.OSV] = true
236+
} else {
237+
affecting[f.OSV] = true
238+
if f.FixedVersion != "" {
239+
current, ok := modulesToUpgrade[mod]
240+
if !ok || current == "latest" || semver.Compare(f.FixedVersion, current) > 0 {
241+
modulesToUpgrade[mod] = f.FixedVersion
242+
}
243+
} else if _, ok := modulesToUpgrade[mod]; !ok {
244+
modulesToUpgrade[mod] = "latest"
245+
}
246+
}
247+
}
248+
return affecting, stdLibVulns, modulesToUpgrade
249+
}
250+
251+
func (s *server) upgradeModules(ctx context.Context, snapshot *cache.Snapshot, uri protocol.DocumentURI, modulesToUpgrade map[string]string) error {
252+
if err := s.runGoGet(ctx, snapshot, uri, modulesToUpgrade); err != nil {
253+
return err
254+
}
255+
if err := s.runGoModTidy(ctx, snapshot, uri); err != nil {
256+
return err
257+
}
258+
259+
var (
260+
upgradedStrs []string
261+
upgrades []string
262+
)
263+
for module, version := range modulesToUpgrade {
264+
upgrades = append(upgrades, module+"@"+version)
265+
upgradedStrs = append(upgradedStrs, fmt.Sprintf("%s to %s", module, version))
266+
}
267+
sort.Strings(upgradedStrs)
268+
269+
msg := fmt.Sprintf("Successfully upgraded vulnerable modules:\n %s", strings.Join(upgradedStrs, ",\n "))
270+
showMessage(ctx, s.client, protocol.Info, msg)
271+
if hash, pathHash, err := getModFileHashes(uri); err == nil {
272+
if err := filecache.Set(goModHashKind, pathHash, []byte(hash)); err != nil {
273+
event.Error(ctx, "failed to update go.mod hash after upgrade", err)
274+
}
275+
} else {
276+
event.Error(ctx, "failed to get go.mod hash after upgrade", err)
277+
}
278+
return nil
279+
}
280+
281+
func (s *server) runGoGet(ctx context.Context, snapshot *cache.Snapshot, uri protocol.DocumentURI, modulesToUpgrade map[string]string) error {
282+
work := s.progress.Start(ctx, "Upgrading Modules", "Running go get...", nil, nil)
283+
defer work.End(ctx, "Done.")
284+
285+
var upgrades []string
286+
for module, version := range modulesToUpgrade {
287+
upgrades = append(upgrades, module+"@"+version)
288+
}
289+
290+
if err := runGoCommand(ctx, snapshot, uri, "get", upgrades); err != nil {
291+
msg := fmt.Sprintf("Failed to upgrade modules: %v", err)
292+
showMessage(ctx, s.client, protocol.Error, msg)
293+
return err
294+
}
295+
return nil
296+
}
297+
298+
func (s *server) runGoModTidy(ctx context.Context, snapshot *cache.Snapshot, uri protocol.DocumentURI) error {
299+
work := s.progress.Start(ctx, "Upgrading Modules", "Running go mod tidy...", nil, nil)
300+
defer work.End(ctx, "Done.")
301+
302+
if err := runGoCommand(ctx, snapshot, uri, "mod", []string{"tidy"}); err != nil {
303+
event.Error(ctx, "go mod tidy failed", err)
304+
showMessage(ctx, s.client, protocol.Error, fmt.Sprintf("go mod tidy failed: %v", err))
305+
return err
306+
}
307+
return nil
308+
}
309+
310+
func runGoCommand(ctx context.Context, snapshot *cache.Snapshot, uri protocol.DocumentURI, verb string, args []string) error {
311+
dir := uri.DirPath()
312+
inv, cleanup, err := snapshot.GoCommandInvocation(cache.NetworkOK, dir, verb, args)
313+
if err != nil {
314+
return err
315+
}
316+
defer cleanup()
317+
var stdout, stderr bytes.Buffer
318+
if err := snapshot.View().GoCommandRunner().RunPiped(ctx, *inv, &stdout, &stderr); err != nil {
319+
return fmt.Errorf("go %s %s failed: %v\n-- stdout --\n%s\n-- stderr --\n%s", verb, strings.Join(args, " "), err, stdout.String(), stderr.String())
209320
}
321+
return nil
210322
}
211323

212324
type vulncheckConfig struct {

gopls/internal/server/vulncheck_prompt_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ func TestComputeGoModHash(t *testing.T) {
4545
`,
4646
want: func() string {
4747
h := sha256.New()
48-
h.Write([]byte("golang.org/x/toolsv0.1.0"))
49-
h.Write([]byte("golang.org/x/vulnv0.2.0"))
48+
h.Write([]byte("golang.org/x/tools\x00v0.1.0"))
49+
h.Write([]byte("golang.org/x/vuln\x00v0.2.0"))
5050
return hex.EncodeToString(h.Sum(nil))
5151
}(),
5252
},
@@ -60,7 +60,7 @@ func TestComputeGoModHash(t *testing.T) {
6060
`,
6161
want: func() string {
6262
h := sha256.New()
63-
h.Write([]byte("golang.org/x/toolsv0.1.0"))
63+
h.Write([]byte("golang.org/x/tools\x00v0.1.0"))
6464
return hex.EncodeToString(h.Sum(nil))
6565
}(),
6666
},
@@ -74,7 +74,7 @@ func TestComputeGoModHash(t *testing.T) {
7474
`,
7575
want: func() string {
7676
h := sha256.New()
77-
h.Write([]byte("golang.org/x/toolsv0.1.0golang.org/x/toolsv0.2.0"))
77+
h.Write([]byte("golang.org/x/tools\x00v0.1.0\x00golang.org/x/tools\x00v0.2.0"))
7878
return hex.EncodeToString(h.Sum(nil))
7979
}(),
8080
},

0 commit comments

Comments
 (0)