-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy pathDistributedSampler.cpp
More file actions
62 lines (51 loc) · 1.8 KB
/
DistributedSampler.cpp
File metadata and controls
62 lines (51 loc) · 1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
/*
* TinyTorch
* @author : [email protected]
*
*/
#include "DistributedSampler.h"

#include <algorithm>
#include <numeric>
#include <stdexcept>
#include <vector>

#include "DistributedProcessGroup.h"
namespace tinytorch::distributed {
/**
 * Builds a sampler that gives this replica (rank) its share of the dataset
 * indices and precomputes the index list via generateIndices().
 *
 * @param datasetSize  number of samples in the underlying dataset
 * @param numReplicas  number of participating replicas; when empty, taken from
 *                     DistributedProcessGroup::getInstance()->getWorldSize()
 * @param rank         index of this replica, must be in [0, numReplicas); when
 *                     empty, taken from DistributedProcessGroup::getInstance()->getRank()
 * @param shuffle      whether generateIndices() shuffles the indices (seeded by seed)
 * @param seed         base seed for the shuffle
 * @param dropLast     true:  every replica gets floor(datasetSize / numReplicas) samples
 *                            and the tail of the index list is dropped;
 *                     false: every replica gets ceil(datasetSize / numReplicas) samples
 *                            and the index list is padded by repeating indices
 * @throws std::invalid_argument if the resolved numReplicas is 0, or the resolved
 *         rank is not smaller than numReplicas
 */
DistributedSampler::DistributedSampler(size_t datasetSize, std::optional<size_t> numReplicas,
                                       std::optional<size_t> rank, bool shuffle, unsigned long seed, bool dropLast)
    : datasetSize_(datasetSize), shuffle_(shuffle), seed_(seed), dropLast_(dropLast) {
  auto dpg = DistributedProcessGroup::getInstance();
  if (numReplicas.has_value()) {
    numReplicas_ = numReplicas.value();
  } else {
    numReplicas_ = dpg->getWorldSize();
  }
  if (rank.has_value()) {
    rank_ = rank.value();
  } else {
    rank_ = dpg->getRank();
  }

  // Both checks guard later arithmetic: a zero replica count would divide by
  // zero below, and an out-of-range rank would make generateIndices() slice
  // past the end of the index list.
  if (numReplicas_ == 0) {
    throw std::invalid_argument("DistributedSampler: numReplicas must be greater than 0");
  }
  if (rank_ >= numReplicas_) {
    throw std::invalid_argument("DistributedSampler: rank must be in [0, numReplicas)");
  }

  // Pure integer arithmetic: a float round-trip cannot represent every size_t
  // above 2^24 and would miscount samples for large datasets. The ceil is
  // computed as quotient + (remainder != 0) to avoid overflow in (a + b - 1) / b.
  numSamples_ = datasetSize_ / numReplicas_;
  if (!dropLast_ && datasetSize_ % numReplicas_ != 0) {
    ++numSamples_;
  }
  totalSize_ = numSamples_ * numReplicas_;
  generateIndices();
}
/**
 * Rebuilds indices_: the dataset indices assigned to this rank.
 *
 * Every rank builds the same global index list: 0..datasetSize_-1, shuffled
 * when shuffle_ is set with a generator seeded by seed_ + epoch_. The list is
 * then trimmed (dropLast_) or padded by cycling from its front (!dropLast_) to
 * exactly totalSize_ entries. This rank keeps the contiguous slice
 * [rank_ * numSamples_, (rank_ + 1) * numSamples_).
 */
void DistributedSampler::generateIndices() {
  using Diff = std::vector<size_t>::difference_type;

  std::vector<size_t> allIndices(datasetSize_);
  std::iota(allIndices.begin(), allIndices.end(), size_t{0});

  if (shuffle_) {
    // getGenerator() hands out a reference to a generator owned elsewhere;
    // seeding a private copy yields the same permutation without resetting
    // that shared generator's state as a side effect.
    auto gen = RandomGeneratorCPU::getGenerator();
    gen.seed(static_cast<unsigned int>(seed_ + epoch_));
    std::shuffle(allIndices.begin(), allIndices.end(), gen);
  }

  if (dropLast_) {
    allIndices.resize(totalSize_);
  } else if (!allIndices.empty()) {
    // Pad by cycling through the list from its front. Inserting a range of a
    // vector into itself is undefined behaviour (insert may reallocate and
    // invalidate the source iterators). The padding can also exceed the
    // dataset size when datasetSize_ < numReplicas_. So append element-wise.
    const size_t originalSize = allIndices.size();
    allIndices.reserve(totalSize_);
    for (size_t i = 0; allIndices.size() < totalSize_; ++i) {
      allIndices.push_back(allIndices[i % originalSize]);
    }
  }

  const auto sliceBegin = allIndices.begin() + static_cast<Diff>(numSamples_ * rank_);
  indices_.assign(sliceBegin, sliceBegin + static_cast<Diff>(numSamples_));
}
} // namespace tinytorch::distributed