@@ -181,6 +181,7 @@ func (d *Downloader) DownloadDirectory(ctx context.Context, input *DownloadDirec
181181
182182 outs := make (chan DownloadOutput , len (objectsToQueue ))
183183 inputs := make ([]DownloadObjectInput , 0 , len (objectsToQueue ))
184+ illegalPathObjects := make ([]DownloadObjectInput , 0 , len (objectsToQueue ))
184185
185186 for _ , object := range objectsToQueue {
186187 objectWithoutPrefix := object
@@ -191,8 +192,26 @@ func (d *Downloader) DownloadDirectory(ctx context.Context, input *DownloadDirec
191192 objDirectory := filepath .Join (input .LocalDirectory , filepath .Dir (objectWithoutPrefix ))
192193 filePath := filepath .Join (input .LocalDirectory , objectWithoutPrefix )
193194
195+ // Prevent directory traversal attacks.
196+ isUnder , err := isSubPath (input .LocalDirectory , filePath )
197+ if err != nil {
198+ cleanFiles (inputs )
199+ return fmt .Errorf ("transfermanager: DownloadDirectory failed to verify path: %w" , err )
200+ }
201+ if ! isUnder {
202+ // skipped files will later be added in the results
203+ illegalPathObjects = append (illegalPathObjects , DownloadObjectInput {
204+ Bucket : input .Bucket ,
205+ Object : object ,
206+ Callback : input .OnObjectDownload ,
207+ ctx : ctx ,
208+ directoryObjectOutputs : outs ,
209+ directory : true ,
210+ })
211+ continue
212+ }
194213 // Make sure all directories in the object path exist.
195- err : = os .MkdirAll (objDirectory , fs .ModeDir | fs .ModePerm )
214+ err = os .MkdirAll (objDirectory , fs .ModeDir | fs .ModePerm )
196215 if err != nil {
197216 cleanFiles (inputs )
198217 return fmt .Errorf ("transfermanager: DownloadDirectory failed to make directory(%q): %w" , objDirectory , err )
@@ -218,9 +237,21 @@ func (d *Downloader) DownloadDirectory(ctx context.Context, input *DownloadDirec
218237
219238 if d .config .asynchronous {
220239 d .downloadsInProgress .Add (1 )
221- go d .gatherObjectOutputs (input , outs , len (inputs ))
240+ allObjectsCount := len (inputs ) + len (illegalPathObjects )
241+ go d .gatherObjectOutputs (input , outs , allObjectsCount )
222242 }
223243 d .addNewInputs (inputs )
244+
245+ for _ , file := range illegalPathObjects {
246+ // the waitgroup is further decremented in addResult method
247+ d .downloadsInProgress .Add (1 )
248+ d .addResult (& file , & DownloadOutput {
249+ Bucket : file .Bucket ,
250+ Object : file .Object ,
251+ Err : fmt .Errorf ("skipping download of object with unsafe path %q" , file .Object ),
252+ skipped : true ,
253+ })
254+ }
224255 return nil
225256}
226257
@@ -311,7 +342,7 @@ func (d *Downloader) addNewInputs(inputs []DownloadObjectInput) {
311342func (d * Downloader ) addResult (input * DownloadObjectInput , result * DownloadOutput ) {
312343 copiedResult := * result // make a copy so that callbacks do not affect the result
313344
314- if input .directory {
345+ if input .directory && ! result . skipped {
315346 f := input .Destination .(* os.File )
316347 if err := f .Close (); err != nil && result .Err == nil {
317348 result .Err = fmt .Errorf ("closing file(%q): %w" , f .Name (), err )
@@ -321,10 +352,9 @@ func (d *Downloader) addResult(input *DownloadObjectInput, result *DownloadOutpu
321352 if result .Err != nil {
322353 os .Remove (f .Name ())
323354 }
324-
325- if d .config .asynchronous {
326- input .directoryObjectOutputs <- copiedResult
327- }
355+ }
356+ if input .directory && d .config .asynchronous {
357+ input .directoryObjectOutputs <- copiedResult
328358 }
329359
330360 if d .config .asynchronous || input .directory {
@@ -810,6 +840,7 @@ type DownloadOutput struct {
810840 Err error // error occurring during download
811841 Attrs * storage.ReaderObjectAttrs // attributes of downloaded object, if successful
812842
843+ skipped bool
813844 shard int
814845 shardLength int64
815846 crc32c uint32
@@ -948,3 +979,28 @@ func checksumObject(got, want uint32) error {
948979 }
949980 return nil
950981}
982+
983+ func isSubPath (localDirectory , filePath string ) (bool , error ) {
984+ // Validate if paths can be converted to absolute paths.
985+ absLocalDirectory , err := filepath .Abs (localDirectory )
986+ if err != nil {
987+ return false , fmt .Errorf ("cannot convert local directory to absolute path: %w" , err )
988+ }
989+ absFilePath , err := filepath .Abs (filePath )
990+ if err != nil {
991+ return false , fmt .Errorf ("cannot convert file path to absolute path: %w" , err )
992+ }
993+
994+ // The relative path from the local directory to the file path.
995+ // ex: if localDirectory is /tmp/foo and filePath is /tmp/foo/bar, rel will be "bar".
996+ rel , err := filepath .Rel (absLocalDirectory , absFilePath )
997+ if err != nil {
998+ return false , err
999+ }
1000+
1001+ // rel should not start with ".." to escape target directory
1002+ prevDir := ".." + string (filepath .Separator )
1003+ isUnder := ! strings .HasPrefix (rel , prevDir ) && rel != ".."
1004+
1005+ return isUnder , nil
1006+ }
0 commit comments