fbox/utils.go
2021-07-07 00:49:50 +10:00

533 lines
12 KiB
Go

package main
import (
"context"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"sync"
"time"
"github.com/klauspost/reedsolomon"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/blake2b"
)
// requestTimeout is the client timeout applied to every HTTP request issued
// through request.
const requestTimeout = 30 * time.Second
// receiveFile copies everything readable from r into a new temporary file and
// returns that file rewound to its start.
//
// On success the file is left open and on disk; this function does not clean
// it up. On failure the temporary file is closed and removed, and a nil file
// is returned together with the error, so no error path hands back a file.
func receiveFile(r io.Reader) (*os.File, error) {
	tf, err := ioutil.TempFile("", "fbox-receive-*")
	if err != nil {
		log.WithError(err).Error("error creating temporary file")
		return nil, err
	}

	// discard releases a partially written temporary file. Its errors are
	// ignored on purpose: the failure that triggered it is the one reported.
	discard := func() {
		_ = tf.Close()
		_ = os.Remove(tf.Name())
	}

	if _, err := io.Copy(tf, r); err != nil {
		log.WithError(err).Error("error writing temporary file")
		discard()
		return nil, err
	}

	if _, err := tf.Seek(0, io.SeekStart); err != nil {
		log.WithError(err).Error("error seeking temporary file")
		discard()
		return nil, err
	}

	return tf, nil
}
// createShards splits size bytes read from r into dataShards data shards,
// encodes parityShards parity shards for them, and returns one reader per
// shard: data shards first, then parity shards. Each reader is a freshly
// opened read-only handle on the shard's temporary file.
//
// On success the temporary files stay on disk behind the returned readers.
// On failure every temporary file created so far is closed and removed, and a
// nil slice is returned with the error.
func createShards(r io.ReadSeeker, size int64) ([]io.ReadSeeker, error) {
	enc, err := reedsolomon.NewStream(dataShards, parityShards)
	if err != nil {
		log.WithError(err).Error("error creating reedsolomon stream")
		return nil, fmt.Errorf("error creating reedsolomon stream: %w", err)
	}

	shards := dataShards + parityShards

	// out holds the write handles (one temporary file per shard), inputs the
	// read handles fed to the parity encoder, and results the read handles
	// handed back to the caller.
	out := make([]*os.File, 0, shards)
	var inputs, results []*os.File

	succeeded := false
	defer func() {
		// The encoder inputs are only needed while this function runs.
		for _, f := range inputs {
			_ = f.Close()
		}
		if succeeded {
			return
		}
		// Failure: release everything so neither descriptors nor temporary
		// files leak. Close errors are ignored because some of these handles
		// have already been closed and the original error is being reported.
		for _, f := range results {
			_ = f.Close()
		}
		for _, f := range out {
			_ = f.Close()
			_ = os.Remove(f.Name())
		}
	}()

	// Create the resulting files.
	for i := 0; i < shards; i++ {
		tf, err := ioutil.TempFile("", "fbox-receive-*")
		if err != nil {
			log.WithError(err).Error("error creating temporary output file")
			return nil, err
		}
		out = append(out, tf)
	}

	// Split the input across the data shard files.
	data := make([]io.Writer, dataShards)
	for i := range data {
		data[i] = out[i]
	}
	if err := enc.Split(r, data, size); err != nil {
		log.WithError(err).Error("error splitting input")
		return nil, fmt.Errorf("error splitting input: %w", err)
	}

	// Close the data shard files and reopen them read-only so the encoder
	// reads each one from its beginning.
	input := make([]io.Reader, dataShards)
	for i := range input {
		if err := out[i].Close(); err != nil {
			log.WithError(err).Error("error closing output")
			return nil, fmt.Errorf("error closing output: %w", err)
		}
		f, err := os.Open(out[i].Name())
		if err != nil {
			log.WithError(err).Error("error reopening output")
			return nil, fmt.Errorf("error reopening output: %w", err)
		}
		inputs = append(inputs, f)
		input[i] = f
	}

	// Encode parity into the remaining files.
	parity := make([]io.Writer, parityShards)
	for i := range parity {
		parity[i] = out[dataShards+i]
	}
	if err := enc.Encode(input, parity); err != nil {
		log.WithError(err).Error("error encoding parity shards")
		return nil, fmt.Errorf("error encoding parity shards: %w", err)
	}

	// Close the parity files exactly once and check the result: a failed
	// close can mean the parity data never reached the disk.
	for i := range parity {
		if err := out[dataShards+i].Close(); err != nil {
			log.WithError(err).Error("error closing output")
			return nil, fmt.Errorf("error closing output: %w", err)
		}
	}

	// Reopen every shard file and return the handles as readers.
	rs := make([]io.ReadSeeker, shards)
	for i, tf := range out {
		f, err := os.Open(tf.Name())
		if err != nil {
			log.WithError(err).Error("error reopening output")
			return nil, fmt.Errorf("error reopening output: %w", err)
		}
		results = append(results, f)
		rs[i] = f
	}

	succeeded = true
	return rs, nil
}
func getShard(uri string) (io.Reader, error) {
tf, err := ioutil.TempFile("", "fbox-get-*")
if err != nil {
log.WithError(err).Error("error creating temporary shard file")
return nil, err
}
res, err := request(http.MethodGet, uri, nil, nil)
if err != nil {
log.WithError(err).Error("error making shard request")
return nil, fmt.Errorf("error making shard request: %w", err)
}
if res.StatusCode != 200 {
log.WithField("Status", res.Status).Error("non-200 shard response")
return nil, fmt.Errorf("non-200 shard response")
}
defer res.Body.Close()
if _, err := io.Copy(tf, res.Body); err != nil {
log.WithError(err).Error("error reading shard")
return nil, fmt.Errorf("error reading sahrd: %w", err)
}
if err := tf.Close(); err != nil {
log.WithError(err).Error("error closing tempoary shard file")
return nil, fmt.Errorf("error closing temporary shard file: %w", err)
}
return os.Open(tf.Name())
}
// getShards fetches every shard in shardMap concurrently via getShard and
// returns the readers in shardMap order.
//
// A shard whose node is not present in nodes, or whose download fails, is
// logged and left as a nil entry in the result instead of failing the call;
// readShards hands such slices to repairShards, which treats nil entries as
// shards to rebuild. The returned error is always nil.
func getShards(shardMap ShardMap) ([]io.Reader, error) {
	readers := make([]io.Reader, len(shardMap))

	var wg sync.WaitGroup
	for i, shard := range shardMap {
		node, ok := nodes[shard.NodeID]
		if !ok {
			log.Warnf("error locating shard with node id %s", shard.NodeID)
			continue
		}

		wg.Add(1)
		// Each goroutine gets a pointer to its own slot, so the writes never
		// overlap and wg.Wait orders them before the slice is returned.
		go func(dst *io.Reader, node Node, shard Shard) {
			defer wg.Done()

			reader, err := getShard(shard.URI(node))
			if err != nil {
				log.WithError(err).WithField("node", node).Error("error getting shard")
				return
			}
			*dst = reader
		}(&readers[i], node, shard)
	}
	wg.Wait()

	return readers, nil
}
// repairShards rebuilds the missing (nil) entries of shards with enc, then
// verifies the complete set.
//
// Each missing shard is reconstructed into its own temporary file, which is
// then reopened read-only and stored in shards in place of the nil entry. All
// shards are rewound before verification, so every non-nil input must
// implement io.Seeker. The same slice is returned on success.
//
// On failure the temporary files created here are closed and removed, the
// entries they had replaced are reset to nil, and a nil slice is returned
// with the error.
func repairShards(enc reedsolomon.StreamEncoder, shards []io.Reader) ([]io.Reader, error) {
	// out has a writer only where a shard is missing; the other slots stay
	// nil so that existing shards are not rebuilt.
	out := make([]io.Writer, len(shards))
	// files are the temporary files receiving rebuilt shards.
	files := make([]*os.File, len(shards))
	// reopened are the read handles swapped into shards.
	reopened := make([]*os.File, len(shards))

	repaired := false
	defer func() {
		if repaired {
			return
		}
		// Best-effort cleanup; the error being returned is the one that
		// matters, and some handles may already be closed.
		for i, tf := range files {
			if tf == nil {
				continue
			}
			_ = tf.Close()
			_ = os.Remove(tf.Name())
			if reopened[i] != nil {
				_ = reopened[i].Close()
				shards[i] = nil
			}
		}
	}()

	// Create destination writers for the missing shards.
	for i := range shards {
		if shards[i] != nil {
			continue
		}
		tf, err := ioutil.TempFile("", "fbox-repair-*")
		if err != nil {
			log.WithError(err).Error("error creating temporary shard file")
			return nil, err
		}
		files[i] = tf
		out[i] = tf
	}

	if err := enc.Reconstruct(shards, out); err != nil {
		log.WithError(err).Error("error reconstructing shards")
		return nil, fmt.Errorf("error reconstructing shards: %w", err)
	}

	// Close each rebuilt file and reopen it for reading.
	for i, tf := range files {
		if tf == nil {
			continue
		}
		if err := tf.Close(); err != nil {
			log.WithError(err).Error("error closing shard file")
			return nil, fmt.Errorf("error closing shard file: %w", err)
		}
		f, err := os.Open(tf.Name())
		if err != nil {
			log.WithError(err).Error("error reopening repaired shard file")
			return nil, fmt.Errorf("error reopening repaired shard file: %w", err)
		}
		reopened[i] = f
		shards[i] = f
	}

	// Rewind every shard so verification reads from the beginning. A shard
	// that cannot be rewound is an error rather than a panic.
	for i, shard := range shards {
		seeker, ok := shard.(io.Seeker)
		if !ok {
			err := fmt.Errorf("shard %d is not seekable", i)
			log.WithError(err).Error("error seeking shard")
			return nil, fmt.Errorf("error seeking shard: %w", err)
		}
		if _, err := seeker.Seek(0, io.SeekStart); err != nil {
			log.WithError(err).Error("error seeking shard")
			return nil, fmt.Errorf("error seeking shard: %w", err)
		}
	}

	ok, err := enc.Verify(shards)
	if err != nil {
		log.WithError(err).Error("error verifying repaired shards")
		return nil, fmt.Errorf("error verifying repaired shards: %w", err)
	}
	if !ok {
		log.Error("error verifying repaired shards")
		return nil, fmt.Errorf("error verifying repaired shards")
	}

	repaired = true
	return shards, nil
}
// deleteShard issues an HTTP DELETE for the shard at uri.
//
// It returns an error if the request fails or the response status is not 200.
// The response body is closed on every path.
func deleteShard(uri string) error {
	res, err := request(http.MethodDelete, uri, nil, nil)
	if err != nil {
		log.WithError(err).Error("error making shard request")
		return fmt.Errorf("error making shard request: %w", err)
	}
	// Registered before the status check so a non-200 response does not leak
	// its body.
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		log.WithField("Status", res.Status).Error("non-200 shard response")
		return fmt.Errorf("non-200 shard response")
	}
	return nil
}
// deleteShards asks the owning node of every shard in metadata.Shards to
// delete it.
//
// Deletion is best effort: a shard whose node is not present in nodes, or
// whose delete request fails, is logged and skipped. The function always
// returns nil.
func deleteShards(metadata *Metadata) error {
	// TODO: Do this concurrently
	for _, shard := range metadata.Shards {
		node, ok := nodes[shard.NodeID]
		if !ok {
			// TODO: Deal with garbage collection of shards we weren't able to remove.
			log.Warnf("error locating shard with node id %s", shard.NodeID)
			continue
		}

		if err := deleteShard(shard.URI(node)); err != nil {
			log.WithError(err).WithField("node", node).Error("error deleting shard")
		}
	}
	return nil
}
// readShards fetches the shards described by metadata, repairs them if they
// do not verify, joins them into a new temporary file (passing metadata.Size
// as the output size) and returns that file rewound to its start.
//
// The shard files fetched or rebuilt along the way are closed and removed
// before returning, on success and on failure alike. On failure the output
// file is removed too and a nil reader is returned with the error. On success
// the output file is left open and on disk behind the returned reader.
func readShards(metadata *Metadata) (io.Reader, error) {
	// Create matrix
	enc, err := reedsolomon.NewStream(dataShards, parityShards)
	if err != nil {
		log.WithError(err).Error("error creating reedsolomon stream")
		return nil, fmt.Errorf("error creating reedsolomon stream: %w", err)
	}

	// Open the inputs
	shards, err := getShards(metadata.Shards)
	if err != nil {
		log.WithError(err).Error("error getting shards")
		return nil, fmt.Errorf("error getting shards: %w", err)
	}

	// Once joined (or once the call has failed) the shard files have no
	// further use, so release them. Errors are ignored: this is best-effort
	// cleanup and some entries may already be closed.
	defer func() {
		for _, shard := range shards {
			f, ok := shard.(*os.File)
			if !ok || f == nil {
				continue
			}
			_ = f.Close()
			_ = os.Remove(f.Name())
		}
	}()

	// rewind seeks every present shard back to its start so it can be
	// re-read; missing (nil) shards are skipped.
	rewind := func() error {
		for i, shard := range shards {
			if shard == nil {
				continue
			}
			seeker, ok := shard.(io.Seeker)
			if !ok {
				return fmt.Errorf("shard %d is not seekable", i)
			}
			if _, err := seeker.Seek(0, io.SeekStart); err != nil {
				return fmt.Errorf("shard %d: %w", i, err)
			}
		}
		return nil
	}

	// Verify the shards. A verification error is only a warning here; the
	// failed verification below triggers reconstruction.
	ok, err := enc.Verify(shards)
	if err != nil {
		log.WithError(err).Warn("error verifying shards")
	}

	if !ok {
		log.Warn("shard verification failed, reconstructing shards...")

		// Reset shards so we can re-read
		if err := rewind(); err != nil {
			log.WithError(err).Error("error seeking shard")
			return nil, fmt.Errorf("error seeking shard: %w", err)
		}

		// Kept in a separate variable so a failed repair does not drop the
		// references the deferred cleanup needs.
		repaired, err := repairShards(enc, shards)
		if err != nil {
			log.WithError(err).Error("error repairing shards")
			return nil, fmt.Errorf("error repairing shards: %w", err)
		}
		shards = repaired
	}

	// Reset shards so we can re-read
	if err := rewind(); err != nil {
		log.WithError(err).Error("error seeking shard")
		return nil, fmt.Errorf("error seeking shard: %w", err)
	}

	tf, err := ioutil.TempFile("", "fbox-*")
	if err != nil {
		log.WithError(err).Error("error creating temporary file")
		return nil, err
	}

	// Join the shards and write them
	if err := enc.Join(tf, shards, metadata.Size); err != nil {
		_ = tf.Close()           // best effort
		_ = os.Remove(tf.Name()) // best effort
		log.WithError(err).Error("error joining shards")
		return nil, fmt.Errorf("error joining shards: %w", err)
	}

	// Reset output file so we can re-read
	if _, err := tf.Seek(0, io.SeekStart); err != nil {
		_ = tf.Close()           // best effort
		_ = os.Remove(tf.Name()) // best effort
		log.WithError(err).Error("error seeking temporary file")
		return nil, fmt.Errorf("error seeking temporary file: %w", err)
	}

	return tf, nil
}
// storeShard uploads the contents of r to uri with an HTTP PUT.
//
// It returns an error if the request fails or the response status is not 200.
// The response body is closed on every path.
func storeShard(uri string, r io.Reader) error {
	res, err := request(http.MethodPut, uri, nil, r)
	if err != nil {
		log.WithError(err).Error("error making blob request")
		return fmt.Errorf("error making blob request: %w", err)
	}
	// Registered before the status check so a non-200 response does not leak
	// its body.
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		log.WithField("Status", res.Status).Error("error making blob request")
		return fmt.Errorf("error making blob request: %s", res.Status)
	}
	return nil
}
// storeShards hashes every shard in rs, assigns each one to a node chosen by
// selectNode, and uploads all of them concurrently.
//
// It returns the resulting ShardMap, whose entry i carries the node ID and
// the hash (as BlobID) of rs[i]. If hashing fails or any upload fails, a nil
// map and an error are returned.
func storeShards(rs []io.ReadSeeker) (ShardMap, error) {
	hashes, err := hashShards(rs)
	if err != nil {
		log.WithError(err).Error("error calculating shard hashes")
		return nil, fmt.Errorf("error calculating shard hashes: %w", err)
	}

	shardMap := make(ShardMap, len(rs))
	for i := range rs {
		shardMap[i] = Shard{NodeID: selectNode().ID, BlobID: hashes[i]}
	}

	// Cancelling on return is a safety net that stops the collector below
	// should it still be running when this function exits.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	errChan := make(chan error)

	// The collector is the only goroutine that touches nErrors while it runs.
	// It closes collected when it exits; receiving from collected below is
	// what makes the final read of nErrors safe.
	nErrors := 0
	collected := make(chan struct{})
	go func(ctx context.Context) {
		defer close(collected)
		for {
			select {
			case err, ok := <-errChan:
				if !ok {
					return
				}
				nErrors++
				log.WithError(err).Errorf("storeShards() error")
			case <-ctx.Done():
				return
			}
		}
	}(ctx)

	var wg sync.WaitGroup
	for i, shard := range shardMap {
		node := nodes[shard.NodeID]

		wg.Add(1)
		go func(node Node, shard Shard, r io.Reader) {
			defer wg.Done()

			if err := storeShard(shard.URI(node), r); err != nil {
				log.WithError(err).Error("error storing shard")
				errChan <- fmt.Errorf("error storing shard: %w", err)
			}
		}(node, shard, rs[i])
	}

	// All senders are done once wg.Wait returns, so the channel can be closed
	// and the collector drained before the count is inspected.
	wg.Wait()
	close(errChan)
	<-collected

	if nErrors > 0 {
		log.Error("storeShards() too many errors")
		return nil, fmt.Errorf("error storing shards")
	}

	return shardMap, nil
}
// hashShards computes the hashReader digest of every shard in rs, in order,
// and seeks each shard back to its start after it has been hashed.
//
// It stops at the first hashing or seek failure and returns a nil slice with
// the wrapped error.
func hashShards(rs []io.ReadSeeker) ([]string, error) {
	digests := make([]string, 0, len(rs))

	for _, shard := range rs {
		digest, hashErr := hashReader(shard)
		if hashErr != nil {
			log.WithError(hashErr).Error("error hashing shard")
			return nil, fmt.Errorf("error hashing shard: %w", hashErr)
		}

		// Hashing consumed the shard; rewind it for the next reader.
		_, seekErr := shard.Seek(0, io.SeekStart)
		if seekErr != nil {
			log.WithError(seekErr).Error("error seeking shard")
			return nil, fmt.Errorf("error seeking shard: %w", seekErr)
		}

		digests = append(digests, digest)
	}

	return digests, nil
}
// hashReader reads r to the end and returns the hex-encoded BLAKE2b-256
// digest of its contents.
//
// r is left wherever the read stopped; it is not rewound.
func hashReader(r io.Reader) (string, error) {
	hasher, err := blake2b.New256(nil)
	if err != nil {
		log.WithError(err).Error("error creating hasher interface")
		// Wrapped with %w, like the other errors in this file, so callers can
		// inspect the cause.
		return "", fmt.Errorf("error creating hasher interface: %w", err)
	}

	if _, err := io.Copy(hasher, r); err != nil {
		log.WithError(err).Error("error hashing reader")
		return "", fmt.Errorf("error hashing reader: %w", err)
	}

	return hex.EncodeToString(hasher.Sum(nil)), nil
}
// request performs an HTTP request with the given method, url, headers and
// body, using a client whose timeout is requestTimeout.
//
// headers may be nil. If no User-Agent is set, "fbox/<FullVersion()>" is
// used. The caller's headers map is copied, never modified. On success the
// response is returned with its body still open; this function does not
// close it.
func request(method, url string, headers http.Header, body io.Reader) (*http.Response, error) {
	req, err := http.NewRequest(method, url, body)
	if err != nil {
		log.WithError(err).Errorf("%s: http.NewRequest fail: %s", url, err)
		return nil, err
	}

	// Work on a private copy so the default User-Agent below is not written
	// into the caller's map. Clone returns nil for a nil Header.
	headers = headers.Clone()
	if headers == nil {
		headers = make(http.Header)
	}

	// Set a default User-Agent (if none set)
	if headers.Get("User-Agent") == "" {
		headers.Set("User-Agent", fmt.Sprintf("fbox/%s", FullVersion()))
	}

	req.Header = headers

	client := http.Client{
		Timeout: requestTimeout,
	}

	res, err := client.Do(req)
	if err != nil {
		log.WithError(err).Errorf("%s: client.Do fail: %s", url, err)
		return nil, err
	}

	return res, nil
}
// resourceExists reports whether a HEAD request to url is answered with a
// 2xx status. A request failure is logged and reported as false.
func resourceExists(url string) bool {
	resp, reqErr := request(http.MethodHead, url, nil, nil)
	if reqErr != nil {
		log.WithError(reqErr).Errorf("error checking if %s exists", url)
		return false
	}
	defer resp.Body.Close()

	// Any status in [200, 300) counts as success.
	code := resp.StatusCode
	return code >= http.StatusOK && code < http.StatusMultipleChoices
}
// fileExists reports whether name exists according to os.Stat.
//
// Only a "does not exist" error yields false; any other Stat error (or no
// error at all) yields true.
func fileExists(name string) bool {
	_, statErr := os.Stat(name)
	return !os.IsNotExist(statErr)
}