55package server
66
77import (
8+ "bytes"
89 "context"
910 "crypto/sha256"
1011 "encoding/hex"
@@ -18,10 +19,13 @@ import (
1819 "strings"
1920
2021 "golang.org/x/mod/modfile"
22+ "golang.org/x/mod/semver"
23+ "golang.org/x/tools/gopls/internal/cache"
2124 "golang.org/x/tools/gopls/internal/filecache"
2225 "golang.org/x/tools/gopls/internal/progress"
2326 "golang.org/x/tools/gopls/internal/protocol"
2427 "golang.org/x/tools/gopls/internal/settings"
28+ "golang.org/x/tools/gopls/internal/vulncheck/govulncheck"
2529 "golang.org/x/tools/internal/event"
2630 "golang.org/x/tools/internal/xcontext"
2731)
@@ -38,60 +42,60 @@ const (
3842func computeGoModHash (file * modfile.File ) (string , error ) {
3943 h := sha256 .New ()
4044 for _ , req := range file .Require {
41- if _ , err := h .Write ([]byte (req .Mod .Path + req .Mod .Version )); err != nil {
45+ if _ , err := h .Write ([]byte (req .Mod .Path + " \x00 " + req .Mod .Version )); err != nil {
4246 return "" , err
4347 }
4448 }
4549 for _ , exc := range file .Exclude {
46- if _ , err := h .Write ([]byte (exc .Mod .Path + exc .Mod .Version )); err != nil {
50+ if _ , err := h .Write ([]byte (exc .Mod .Path + " \x00 " + exc .Mod .Version )); err != nil {
4751 return "" , err
4852 }
4953 }
5054 for _ , rep := range file .Replace {
51- if _ , err := h .Write ([]byte (rep .Old .Path + rep .Old .Version + rep .New .Path + rep .New .Version )); err != nil {
55+ if _ , err := h .Write ([]byte (rep .Old .Path + " \x00 " + rep .Old .Version + " \x00 " + rep .New .Path + " \x00 " + rep .New .Version )); err != nil {
5256 return "" , err
5357 }
5458 }
5559 return hex .EncodeToString (h .Sum (nil )), nil
5660}
5761
62+ func getModFileHashes (uri protocol.DocumentURI ) (contentHash string , pathHash [32 ]byte , err error ) {
63+ content , err := os .ReadFile (uri .Path ())
64+ if err != nil {
65+ return "" , [32 ]byte {}, err
66+ }
67+ newModFile , err := modfile .Parse ("go.mod" , content , nil )
68+ if err != nil {
69+ return "" , [32 ]byte {}, err
70+ }
71+ contentHash , err = computeGoModHash (newModFile )
72+ if err != nil {
73+ return "" , [32 ]byte {}, err
74+ }
75+ pathHash = sha256 .Sum256 ([]byte (uri .Path ()))
76+ return contentHash , pathHash , nil
77+ }
78+
5879func (s * server ) checkGoModDeps (ctx context.Context , uri protocol.DocumentURI ) {
5980 if s .Options ().Vulncheck != settings .ModeVulncheckPrompt {
6081 return
6182 }
6283 ctx , done := event .Start (ctx , "server.CheckGoModDeps" )
6384 defer done ()
6485
65- var (
66- newHash , oldHash string
67- pathHash [32 ]byte
68- )
69- {
70- newContent , err := os .ReadFile (uri .Path ())
71- if err != nil {
72- event .Error (ctx , "reading new go.mod content failed" , err )
73- return
74- }
75- newModFile , err := modfile .Parse ("go.mod" , newContent , nil )
76- if err != nil {
77- event .Error (ctx , "parsing new go.mod failed" , err )
78- return
79- }
80- hash , err := computeGoModHash (newModFile )
81- if err != nil {
82- event .Error (ctx , "computing new go.mod hash failed" , err )
83- return
84- }
85- newHash = hash
86+ newHash , pathHash , err := getModFileHashes (uri )
87+ if err != nil {
88+ event .Error (ctx , "getting go.mod hashes failed" , err )
89+ return
90+ }
8691
87- pathHash = sha256 .Sum256 ([]byte (uri .Path ()))
88- oldHashBytes , err := filecache .Get (goModHashKind , pathHash )
89- if err != nil && err != filecache .ErrNotFound {
90- event .Error (ctx , "reading old go.mod hash from filecache failed" , err )
91- return
92- }
93- oldHash = string (oldHashBytes )
92+ oldHashBytes , err := filecache .Get (goModHashKind , pathHash )
93+ if err != nil && err != filecache .ErrNotFound {
94+ event .Error (ctx , "reading old go.mod hash from filecache failed" , err )
95+ return
9496 }
97+ oldHash := string (oldHashBytes )
98+
9599 if oldHash != newHash {
96100 fileLink := fmt .Sprintf ("[%s](%s)" , uri .Path (), string (uri ))
97101 govulncheckLink := "[govulncheck](https://pkg.go.dev/golang.org/x/vuln/cmd/govulncheck)"
@@ -153,22 +157,14 @@ func (s *server) handleVulncheck(ctx context.Context, uri protocol.DocumentURI)
153157 workDoneWriter := progress .NewWorkDoneWriter (ctx , work )
154158 result , err := s .runVulncheck (ctx , snapshot , uri , "./..." , workDoneWriter )
155159 if err != nil {
156- event .Error (ctx , "vulncheck failed" , err )
160+ event .Error (ctx , "govulncheck failed" , err )
161+ showMessage (ctx , s .client , protocol .Error , fmt .Sprintf ("govulncheck failed: %v" , err ))
157162 return
158163 }
159164
160- affecting := make (map [string ]struct {})
161- upgrades := make (map [string ]string )
162- for _ , f := range result .Findings {
163- if len (f .Trace ) > 1 {
164- affecting [f .OSV ] = struct {}{}
165- if f .FixedVersion != "" && f .Trace [0 ].Module != "stdlib" {
166- upgrades [f .Trace [0 ].Module ] = f .FixedVersion
167- }
168- }
169- }
170-
171- if len (affecting ) == 0 {
165+ affecting , stdLibVulns , modulesToUpgrade := computeModulesToUpgrade (result .Findings )
166+ numStdLib := len (stdLibVulns )
167+ if len (affecting ) == 0 && numStdLib == 0 {
172168 showMessage (ctx , s .client , protocol .Info , "No vulnerabilities found." )
173169 return
174170 }
@@ -177,7 +173,7 @@ func (s *server) handleVulncheck(ctx context.Context, uri protocol.DocumentURI)
177173 sort .Strings (affectingOSVs )
178174
179175 var b strings.Builder
180- fmt .Fprintf (& b , "Found %d vulnerabilities affecting your dependencies:\n \n " , len (affectingOSVs ))
176+ fmt .Fprintf (& b , "Found %d actionable vulnerabilities and %d standard library vulnerabilities affecting your dependencies:\n \n " , len (affectingOSVs ), numStdLib )
181177
182178 for i , id := range affectingOSVs {
183179 if i >= maxVulnsToShow {
@@ -197,16 +193,132 @@ func (s *server) handleVulncheck(ctx context.Context, uri protocol.DocumentURI)
197193 fmt .Fprintf (& b , "\n \n ...and %d more." , len (affectingOSVs )- maxVulnsToShow )
198194 }
199195
200- action , err := showMessageRequest (ctx , s .client , protocol .Warning , b .String (), "Upgrade All" , "Ignore" )
196+ if numStdLib > 0 {
197+ if b .Len () > 0 {
198+ b .WriteString ("\n \n " )
199+ }
200+ b .WriteString ("Upgrading your Go version may address vulnerabilities in the standard library." )
201+ }
202+
203+ actions := []string {"Ignore" }
204+ if len (modulesToUpgrade ) > 0 {
205+ actions = append ([]string {"Upgrade All" }, actions ... )
206+ }
207+
208+ action , err := showMessageRequest (ctx , s .client , protocol .Warning , b .String (), actions ... )
201209 if err != nil {
202210 event .Error (ctx , "vulncheck remediation failed" , err )
203211 return
204212 }
205213
206214 if action == "Upgrade All" {
207- // TODO: Add dependency upgrade functionality.
208- showMessage (ctx , s .client , protocol .Info , "Upgrading all modules... (not yet implemented)" )
215+ if err := s .upgradeModules (ctx , snapshot , uri , modulesToUpgrade ); err != nil {
216+ event .Error (ctx , "upgrading modules failed" , err )
217+ }
218+ }
219+ }
220+
221+ func computeModulesToUpgrade (findings []* govulncheck.Finding ) (affecting map [string ]bool , stdLibVulns map [string ]bool , modulesToUpgrade map [string ]string ) {
222+ affecting = make (map [string ]bool )
223+ stdLibVulns = make (map [string ]bool )
224+ modulesToUpgrade = make (map [string ]string )
225+
226+ for _ , f := range findings {
227+ if len (f .Trace ) == 0 {
228+ continue
229+ }
230+ mod := f .Trace [0 ].Module
231+ // An empty module path or "stdlib" indicates a standard library package.
232+ // These vulnerabilities cannot be remediated via module upgrades and
233+ // instead require updating the Go toolchain.
234+ if mod == "stdlib" || mod == "" {
235+ stdLibVulns [f .OSV ] = true
236+ } else {
237+ affecting [f .OSV ] = true
238+ if f .FixedVersion != "" {
239+ current , ok := modulesToUpgrade [mod ]
240+ if ! ok || current == "latest" || semver .Compare (f .FixedVersion , current ) > 0 {
241+ modulesToUpgrade [mod ] = f .FixedVersion
242+ }
243+ } else if _ , ok := modulesToUpgrade [mod ]; ! ok {
244+ modulesToUpgrade [mod ] = "latest"
245+ }
246+ }
247+ }
248+ return affecting , stdLibVulns , modulesToUpgrade
249+ }
250+
251+ func (s * server ) upgradeModules (ctx context.Context , snapshot * cache.Snapshot , uri protocol.DocumentURI , modulesToUpgrade map [string ]string ) error {
252+ if err := s .runGoGet (ctx , snapshot , uri , modulesToUpgrade ); err != nil {
253+ return err
254+ }
255+ if err := s .runGoModTidy (ctx , snapshot , uri ); err != nil {
256+ return err
257+ }
258+
259+ var (
260+ upgradedStrs []string
261+ upgrades []string
262+ )
263+ for module , version := range modulesToUpgrade {
264+ upgrades = append (upgrades , module + "@" + version )
265+ upgradedStrs = append (upgradedStrs , fmt .Sprintf ("%s to %s" , module , version ))
266+ }
267+ sort .Strings (upgradedStrs )
268+
269+ msg := fmt .Sprintf ("Successfully upgraded vulnerable modules:\n %s" , strings .Join (upgradedStrs , ",\n " ))
270+ showMessage (ctx , s .client , protocol .Info , msg )
271+ if hash , pathHash , err := getModFileHashes (uri ); err == nil {
272+ if err := filecache .Set (goModHashKind , pathHash , []byte (hash )); err != nil {
273+ event .Error (ctx , "failed to update go.mod hash after upgrade" , err )
274+ }
275+ } else {
276+ event .Error (ctx , "failed to get go.mod hash after upgrade" , err )
277+ }
278+ return nil
279+ }
280+
281+ func (s * server ) runGoGet (ctx context.Context , snapshot * cache.Snapshot , uri protocol.DocumentURI , modulesToUpgrade map [string ]string ) error {
282+ work := s .progress .Start (ctx , "Upgrading Modules" , "Running go get..." , nil , nil )
283+ defer work .End (ctx , "Done." )
284+
285+ var upgrades []string
286+ for module , version := range modulesToUpgrade {
287+ upgrades = append (upgrades , module + "@" + version )
288+ }
289+
290+ if err := runGoCommand (ctx , snapshot , uri , "get" , upgrades ); err != nil {
291+ msg := fmt .Sprintf ("Failed to upgrade modules: %v" , err )
292+ showMessage (ctx , s .client , protocol .Error , msg )
293+ return err
294+ }
295+ return nil
296+ }
297+
298+ func (s * server ) runGoModTidy (ctx context.Context , snapshot * cache.Snapshot , uri protocol.DocumentURI ) error {
299+ work := s .progress .Start (ctx , "Upgrading Modules" , "Running go mod tidy..." , nil , nil )
300+ defer work .End (ctx , "Done." )
301+
302+ if err := runGoCommand (ctx , snapshot , uri , "mod" , []string {"tidy" }); err != nil {
303+ event .Error (ctx , "go mod tidy failed" , err )
304+ showMessage (ctx , s .client , protocol .Error , fmt .Sprintf ("go mod tidy failed: %v" , err ))
305+ return err
306+ }
307+ return nil
308+ }
309+
310+ func runGoCommand (ctx context.Context , snapshot * cache.Snapshot , uri protocol.DocumentURI , verb string , args []string ) error {
311+ dir := uri .DirPath ()
312+ inv , cleanup , err := snapshot .GoCommandInvocation (cache .NetworkOK , dir , verb , args )
313+ if err != nil {
314+ return err
315+ }
316+ defer cleanup ()
317+ var stdout , stderr bytes.Buffer
318+ if err := snapshot .View ().GoCommandRunner ().RunPiped (ctx , * inv , & stdout , & stderr ); err != nil {
319+ return fmt .Errorf ("go %s %s failed: %v\n -- stdout --\n %s\n -- stderr --\n %s" , verb , strings .Join (args , " " ), err , stdout .String (), stderr .String ())
209320 }
321+ return nil
210322}
211323
212324type vulncheckConfig struct {
0 commit comments