Skip to content

Commit 03ffa12

Browse files
Use bind filter for mounts
The bind filter supports bind-like mounts and volume mounts. It also allows us to have read-only mounts. Signed-off-by: Gabriel Adrian Samfira <[email protected]>
1 parent e9b50df commit 03ffa12

4 files changed

Lines changed: 490 additions & 70 deletions

File tree

mount/bind_filter_windows.go

Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
package mount
2+
3+
import (
4+
"bytes"
5+
"encoding/binary"
6+
"errors"
7+
"fmt"
8+
"os"
9+
"path/filepath"
10+
"strings"
11+
"syscall"
12+
"unicode/utf16"
13+
"unsafe"
14+
15+
"golang.org/x/sys/windows"
16+
)
17+
18+
//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./*.go
19+
//sys BfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) = bindfltapi.BfSetupFilter?
20+
//sys BfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) = bindfltapi.BfRemoveMapping?
21+
//sys BfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer uintptr) (hr error) = bindfltapi.BfGetMappings?
22+
23+
// BfSetupFilter flags. See:
// https://github.com/microsoft/BuildXL/blob/a6dce509f0d4f774255e5fbfb75fa6d5290ed163/Public/Src/Utilities/Native/Processes/Windows/NativeContainerUtilities.cs#L193-L240
//
// The values are bit flags and are OR-ed together (ApplyFileBinding combines
// BINDFLT_FLAG_NO_MULTIPLE_TARGETS with BINDFLT_FLAG_READ_ONLY_MAPPING).
// Flags that carry no comment below are not described in this file; consult
// the reference above before relying on them.
const (
	// Requests a read-only mapping. ApplyFileBinding sets it when readOnly is true.
	BINDFLT_FLAG_READ_ONLY_MAPPING uint32 = 0x00000001
	// Generates a merged binding, mapping target entries to the virtualization root.
	BINDFLT_FLAG_MERGED_BIND_MAPPING uint32 = 0x00000002
	// Use the binding mapping attached to the mapped-in job object (silo) instead of the default global mapping.
	BINDFLT_FLAG_USE_CURRENT_SILO_MAPPING uint32 = 0x00000004
	BINDFLT_FLAG_REPARSE_ON_FILES         uint32 = 0x00000008
	// Skips checks on file/dir creation inside a non-merged, read-only mapping.
	// Only usable when READ_ONLY_MAPPING is set.
	BINDFLT_FLAG_SKIP_SHARING_CHECK uint32 = 0x00000010
	BINDFLT_FLAG_CLOUD_FILES_ECPS   uint32 = 0x00000020
	// Tells bindflt to fail mapping with STATUS_INVALID_PARAMETER if a mapping produces
	// multiple targets.
	BINDFLT_FLAG_NO_MULTIPLE_TARGETS uint32 = 0x00000040
	// Turns on caching by asserting that the backing store for name mappings is immutable.
	BINDFLT_FLAG_IMMUTABLE_BACKING              uint32 = 0x00000080
	BINDFLT_FLAG_PREVENT_CASE_SENSITIVE_BINDING uint32 = 0x00000100
	// Tells bindflt to fail with STATUS_OBJECT_PATH_NOT_FOUND when a mapping is being added
	// but its parent paths (ancestors) have not already been added.
	BINDFLT_FLAG_EMPTY_VIRT_ROOT         uint32 = 0x00000200
	BINDFLT_FLAG_NO_REPARSE_ON_ROOT      uint32 = 0x10000000
	BINDFLT_FLAG_BATCHED_REMOVE_MAPPINGS uint32 = 0x20000000
)
48+
49+
// Flags passed as the first argument of BfGetMappings. GetBindMappings uses
// BINDFLT_GET_MAPPINGS_FLAG_VOLUME to list the mappings rooted on a volume; the
// SILO and USER variants are not used in this file.
// NOTE(review): presumably they select mappings by silo or by user — confirm
// against the bindflt documentation before using them.
const (
	BINDFLT_GET_MAPPINGS_FLAG_VOLUME uint32 = 0x00000001
	BINDFLT_GET_MAPPINGS_FLAG_SILO   uint32 = 0x00000002
	BINDFLT_GET_MAPPINGS_FLAG_USER   uint32 = 0x00000004
)
54+
55+
// ApplyFileBinding creates a global mount of the source in root, with an optional
56+
// read only flag.
57+
// The bind filter allows us to create mounts of directories and volumes. By default it allows
58+
// us to mount multiple sources inside a single root, acting as an overlay. Files from the
59+
// second source will superscede the first source that was mounted.
60+
// This function disables this behavior and sets the BINDFLT_FLAG_NO_MULTIPLE_TARGETS flag
61+
// on the mount.
62+
func ApplyFileBinding(root, source string, readOnly bool) error {
63+
// The parent directory needs to exist for the bind to work. MkdirAll stats and
64+
// returns nil if the directory exists internally so we should be fine to mkdirall
65+
// every time.
66+
if err := os.MkdirAll(filepath.Dir(root), 0); err != nil {
67+
return err
68+
}
69+
70+
if strings.Contains(source, "Volume{") && !strings.HasSuffix(source, "\\") {
71+
// Add trailing slash to volumes, otherwise we get an error when binding it to
72+
// a folder.
73+
source = source + "\\"
74+
}
75+
76+
rootPtr, err := windows.UTF16PtrFromString(root)
77+
if err != nil {
78+
return err
79+
}
80+
81+
targetPtr, err := windows.UTF16PtrFromString(source)
82+
if err != nil {
83+
return err
84+
}
85+
flags := BINDFLT_FLAG_NO_MULTIPLE_TARGETS
86+
if readOnly {
87+
flags |= BINDFLT_FLAG_READ_ONLY_MAPPING
88+
}
89+
90+
// Set the job handle to 0 to create a global mount.
91+
if err := BfSetupFilter(
92+
0,
93+
flags,
94+
rootPtr,
95+
targetPtr,
96+
nil,
97+
0,
98+
); err != nil {
99+
return fmt.Errorf("failed to bind target %q to root %q: %w", source, root, err)
100+
}
101+
return nil
102+
}
103+
104+
func RemoveFileBinding(root string) error {
105+
rootPtr, err := windows.UTF16PtrFromString(root)
106+
if err != nil {
107+
return fmt.Errorf("converting path to utf-16: %w", err)
108+
}
109+
110+
if err := BfRemoveMapping(0, rootPtr); err != nil {
111+
return fmt.Errorf("removing file binding: %w", err)
112+
}
113+
return nil
114+
}
115+
116+
// mappingEntry holds information about where in the response buffer we can
// find information about the virtual root (the mount point) and the targets (sources)
// that get mounted, as well as the flags used to bind the targets to the virtual root.
//
// GetBindMappings reinterprets raw response bytes as this struct, so the field
// order and widths must not be changed.
type mappingEntry struct {
	// Length in bytes of the virtual root name inside the response buffer.
	VirtRootLength uint32
	// Byte offset of the virtual root name inside the response buffer.
	VirtRootOffset uint32
	// Flags of the mapping; copied verbatim into BindMapping.Flags.
	Flags uint32
	// Number of mappingTargetEntry records belonging to this mapping.
	NumberOfTargets uint32
	// Byte offset inside the response buffer of the first mappingTargetEntry.
	TargetEntriesOffset uint32
}
126+
127+
// mappingTargetEntry locates the name of a single target (source) of a mapping
// inside the response buffer. getTargetsFromBuffer reinterprets raw response
// bytes as this struct, so the field order and widths must not be changed.
type mappingTargetEntry struct {
	// Length in bytes of the target name inside the response buffer.
	TargetRootLength uint32
	// Byte offset of the target name inside the response buffer.
	TargetRootOffset uint32
}
131+
132+
// getMappingsResponseHeader represents the first 12 bytes of the BfGetMappings() response.
// It gives us the size of the buffer, the status of the call and the number of mappings.
// GetBindMappings reinterprets the start of the response as this struct, so the
// field order and widths must not be changed.
type getMappingsResponseHeader struct {
	// Size of the buffer. Not read by the code in this file.
	Size uint32
	// Status of the call. Not read by the code in this file.
	Status uint32
	// Number of mappingEntry records that follow the header.
	MappingCount uint32
}
140+
141+
// BindMapping describes one bind filter mapping as returned by GetBindMappings.
type BindMapping struct {
	// MountPoint is the virtual root of the mapping, as resolved by getFinalPath.
	MountPoint string
	// Flags holds the raw flags reported for the mapping.
	Flags uint32
	// Targets lists the sources mounted at MountPoint, each resolved by getFinalPath.
	Targets []string
}
146+
147+
// decodeEntry interprets buffer as little-endian UTF-16 code units and returns
// the decoded string. A trailing odd byte, if any, is ignored. An error from
// reading the code units is wrapped and returned with an empty string.
func decodeEntry(buffer []byte) (string, error) {
	units := make([]uint16, len(buffer)/2)
	if err := binary.Read(bytes.NewReader(buffer), binary.LittleEndian, units); err != nil {
		return "", fmt.Errorf("decoding name: %w", err)
	}

	runes := utf16.Decode(units)
	return string(runes), nil
}
155+
156+
func getTargetsFromBuffer(buffer []byte, offset, count int) ([]string, error) {
157+
if len(buffer) < offset+count*6 {
158+
return nil, fmt.Errorf("invalid buffer")
159+
}
160+
161+
targets := make([]string, count)
162+
for i := 0; i < count; i++ {
163+
entryBuf := buffer[offset+i*8 : offset+i*8+8]
164+
tgt := *(*mappingTargetEntry)(unsafe.Pointer(&entryBuf[0]))
165+
if len(buffer) < int(tgt.TargetRootOffset)+int(tgt.TargetRootLength) {
166+
return nil, fmt.Errorf("invalid buffer")
167+
}
168+
decoded, err := decodeEntry(buffer[tgt.TargetRootOffset : uint32(tgt.TargetRootOffset)+uint32(tgt.TargetRootLength)])
169+
if err != nil {
170+
return nil, fmt.Errorf("decoding name: %w", err)
171+
}
172+
decoded, err = getFinalPath(decoded)
173+
if err != nil {
174+
return nil, fmt.Errorf("fetching final path: %w", err)
175+
}
176+
177+
targets[i] = decoded
178+
}
179+
return targets, nil
180+
}
181+
182+
func getFinalPath(pth string) (string, error) {
183+
// BfGetMappings returns VOLUME_NAME_NT paths like \Device\HarddiskVolume2\ProgramData.
184+
// These can be accessed by prepending \\.\GLOBALROOT to the path. We use this to get the
185+
// DOS paths for these files.
186+
if strings.HasPrefix(pth, `\Device`) {
187+
pth = `\\.\GLOBALROOT` + pth
188+
}
189+
190+
han, err := getFileHandle(pth)
191+
if err != nil {
192+
return "", fmt.Errorf("fetching file handle: %w", err)
193+
}
194+
195+
buf := make([]uint16, 100)
196+
var flags uint32 = 0x0
197+
for {
198+
n, err := windows.GetFinalPathNameByHandle(windows.Handle(han), &buf[0], uint32(len(buf)), flags)
199+
if err != nil {
200+
// if we mounted a volume that does not also have a drive letter assigned, attempting to
201+
// fetch the VOLUME_NAME_DOS will fail with os.ErrNotExist. Attempt to get the VOLUME_NAME_GUID.
202+
if errors.Is(err, os.ErrNotExist) && flags != 0x1 {
203+
flags = 0x1
204+
continue
205+
}
206+
return "", fmt.Errorf("getting final path name: %w", err)
207+
}
208+
if n < uint32(len(buf)) {
209+
break
210+
}
211+
buf = make([]uint16, n)
212+
}
213+
finalPath := syscall.UTF16ToString(buf)
214+
// We got VOLUME_NAME_DOS, we need to strip away some leading slashes.
215+
// Leave unchanged if we ended up requesting VOLUME_NAME_GUID
216+
if len(finalPath) > 4 && finalPath[:4] == `\\?\` && flags == 0x0 {
217+
finalPath = finalPath[4:]
218+
if len(finalPath) > 3 && finalPath[:3] == `UNC` {
219+
// return path like \\server\share\...
220+
finalPath = `\` + finalPath[3:]
221+
}
222+
}
223+
224+
return finalPath, nil
225+
}
226+
227+
func getBindMappingFromBuffer(buffer []byte, entry mappingEntry) (BindMapping, error) {
228+
if len(buffer) < int(entry.VirtRootOffset)+int(entry.VirtRootLength) {
229+
return BindMapping{}, fmt.Errorf("invalid buffer")
230+
}
231+
232+
src, err := decodeEntry(buffer[entry.VirtRootOffset : entry.VirtRootOffset+uint32(entry.VirtRootLength)])
233+
if err != nil {
234+
return BindMapping{}, fmt.Errorf("decoding entry: %w", err)
235+
}
236+
targets, err := getTargetsFromBuffer(buffer, int(entry.TargetEntriesOffset), int(entry.NumberOfTargets))
237+
if err != nil {
238+
return BindMapping{}, fmt.Errorf("fetching targets: %w", err)
239+
}
240+
241+
src, err = getFinalPath(src)
242+
if err != nil {
243+
return BindMapping{}, fmt.Errorf("fetching final path: %w", err)
244+
}
245+
246+
return BindMapping{
247+
Flags: entry.Flags,
248+
Targets: targets,
249+
MountPoint: src,
250+
}, nil
251+
}
252+
253+
func getFileHandle(pth string) (syscall.Handle, error) {
254+
info, err := os.Lstat(pth)
255+
if err != nil {
256+
return 0, fmt.Errorf("accessing file: %w", err)
257+
}
258+
p, err := syscall.UTF16PtrFromString(pth)
259+
if err != nil {
260+
return 0, err
261+
}
262+
attrs := uint32(syscall.FILE_FLAG_BACKUP_SEMANTICS)
263+
if info.Mode()&os.ModeSymlink != 0 {
264+
attrs |= syscall.FILE_FLAG_OPEN_REPARSE_POINT
265+
}
266+
h, err := syscall.CreateFile(p, 0, 0, nil, syscall.OPEN_EXISTING, attrs, 0)
267+
if err != nil {
268+
return 0, err
269+
}
270+
return h, nil
271+
}
272+
273+
// GetBindMappings returns a list of bind mappings that have their root on a
274+
// particular volume. The volumePath parameter can be any path that exists on
275+
// a volume. For example, if a number of mappings are created in C:\ProgramData\test,
276+
// to get a list of those mappings, the volumePath parameter would have to be set to
277+
// C:\ or the VOLUME_NAME_GUID notation of C:\ (\\?\Volume{GUID}\), or any child
278+
// path that exists.
279+
func GetBindMappings(volumePath string) ([]BindMapping, error) {
280+
rootPtr, err := windows.UTF16PtrFromString(volumePath)
281+
if err != nil {
282+
return nil, err
283+
}
284+
285+
var flags uint32 = BINDFLT_GET_MAPPINGS_FLAG_VOLUME
286+
// allocate a large buffer for results
287+
var outBuffSize uint32 = 256 * 1024
288+
buf := make([]byte, outBuffSize)
289+
290+
if err := BfGetMappings(flags, 0, rootPtr, nil, &outBuffSize, uintptr(unsafe.Pointer(&buf[0]))); err != nil {
291+
return nil, err
292+
}
293+
294+
if outBuffSize < 12 {
295+
return nil, fmt.Errorf("invalid buffer returned")
296+
}
297+
298+
result := buf[:outBuffSize]
299+
300+
// The first 12 bytes are the three uint32 fields in getMappingsResponseHeader{}
301+
headerBuffer := result[:12]
302+
// The alternative to using unsafe and casting it to the above defined structures, is to manually
303+
// parse the fields. Not too terrible, but not sure it'd worth the trouble.
304+
header := *(*getMappingsResponseHeader)(unsafe.Pointer(&headerBuffer[0]))
305+
306+
if header.MappingCount == 0 {
307+
// no mappings
308+
return []BindMapping{}, nil
309+
}
310+
311+
mappingsBuffer := result[12 : int(unsafe.Sizeof(mappingEntry{}))*int(header.MappingCount)]
312+
// Get a pointer to the first mapping in the slice
313+
mappingsPointer := (*mappingEntry)(unsafe.Pointer(&mappingsBuffer[0]))
314+
// Get slice of mappings
315+
mappings := unsafe.Slice(mappingsPointer, header.MappingCount)
316+
317+
mappingEntries := make([]BindMapping, header.MappingCount)
318+
for i := 0; i < int(header.MappingCount); i++ {
319+
bindMapping, err := getBindMappingFromBuffer(result, mappings[i])
320+
if err != nil {
321+
return nil, fmt.Errorf("fetching bind mappings: %w", err)
322+
}
323+
mappingEntries[i] = bindMapping
324+
}
325+
326+
return mappingEntries, nil
327+
}

0 commit comments

Comments
 (0)