-
-
Notifications
You must be signed in to change notification settings - Fork 228
Expand file tree
/
Copy pathtrain-clm.rs
More file actions
144 lines (120 loc) · 4.38 KB
/
train-clm.rs
File metadata and controls
144 lines (120 loc) · 4.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
use std::{
fs::File,
io::{Read, Seek, SeekFrom, Write},
path::Path
};
use kdam::BarExt;
use ort::{
ep,
memory::Allocator,
session::{Session, builder::SessionBuilder},
training::{Checkpoint, Trainer},
value::{Tensor, TensorRef}
};
use rand::RngCore;
use tokenizers::Tokenizer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
// Number of independent sequences packed into each training step.
const BATCH_SIZE: usize = 16;
// Number of tokens per training sequence (context window used for this mini CLM).
const SEQUENCE_LENGTH: usize = 256;
// Include common code for `ort` examples that allows using the various feature flags to enable different EPs and
// backends.
#[path = "../common/mod.rs"]
mod common;
/// Trains a miniature causal language model for 5000 steps on a binary file of
/// 16-bit token ids (`train-clm-dataset.bin`), exports the trained model to
/// `trained-clm.onnx`, then greedily samples 50 tokens from it and prints them.
///
/// Expects the following project-local resources to exist relative to the
/// working directory / manifest:
/// - `tools/train-data/mini-clm/{checkpoint, training_model.onnx, eval_model.onnx, optimizer_model.onnx}`
/// - `../gpt2/data/tokenizer.json` (relative to `CARGO_MANIFEST_DIR`)
/// - `train-clm-dataset.bin` — raw `u16` token ids, no header
fn main() -> ort::Result<()> {
    // Initialize tracing to receive log messages from `ort`
    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "warn".into()))
        .with(tracing_subscriber::fmt::layer())
        .init();

    // Register EPs based on feature flags - this isn't crucial for usage and can be removed.
    common::init()?;

    // Set up the kdam progress bar; the cursor is re-shown after training below.
    kdam::term::init(true);
    let _ = kdam::term::hide_cursor();

    // Build the trainer from the pre-generated training artifacts. CUDA is requested
    // here; `ort` falls back to other EPs if it is unavailable.
    let trainer = Trainer::new(
        SessionBuilder::new()?.with_execution_providers([ep::CUDA::default().build()])?,
        Allocator::default(),
        Checkpoint::load("tools/train-data/mini-clm/checkpoint")?,
        "tools/train-data/mini-clm/training_model.onnx",
        "tools/train-data/mini-clm/eval_model.onnx",
        "tools/train-data/mini-clm/optimizer_model.onnx"
    )?;

    // Load the GPT-2 tokenizer shipped with the sibling `gpt2` example; it is only
    // used for the generation demo after training.
    let tokenizer = Tokenizer::from_file(
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .unwrap()
            .join("gpt2")
            .join("data")
            .join("tokenizer.json")
    )
    .unwrap();

    // Fix the learning rate once up front; subsequent `trainer.optimizer()` handles
    // operate on the same underlying optimizer state.
    let mut optimizer = trainer.optimizer();
    optimizer.set_lr(7e-5)?;

    // The dataset is a flat array of little-structured 16-bit token ids, so the
    // token count is simply half the file size in bytes.
    let mut dataset = File::open("train-clm-dataset.bin").unwrap();
    let file_size = dataset.metadata().unwrap().len();
    let num_tokens = (file_size / 2) as usize; // 16-bit tokens
    let mut rng = rand::rng();

    // Reused scratch buffers for one batch of inputs and (shifted) labels.
    let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE];
    let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE];

    let mut pb = kdam::tqdm!(total = 5000);
    for _ in 0..5000 {
        for batch in 0..BATCH_SIZE {
            // Sample a random window start; the `- SEQUENCE_LENGTH - 1` headroom leaves
            // space for both the input window and the one-token-shifted label window.
            let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64;
            // Byte offset = token index * 2 (u16 tokens).
            dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap();
            dataset
                .read_exact(unsafe {
                    // SAFETY: viewing this batch's `u16` sub-slice as `SEQUENCE_LENGTH * 2`
                    // bytes; the byte length exactly matches the sub-slice and `u16` has no
                    // invalid bit patterns. NOTE(review): this assumes the dataset's byte
                    // order matches the host (little-endian on typical targets) — confirm
                    // against the dataset generator.
                    std::slice::from_raw_parts_mut(
                        input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH]
                            .as_mut_ptr()
                            .cast::<u8>(),
                        SEQUENCE_LENGTH * 2
                    )
                })
                .unwrap();
            // Labels are the same window shifted forward by one token — the standard
            // next-token-prediction target for a causal LM.
            dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap();
            dataset
                .read_exact(unsafe {
                    // SAFETY: same reasoning as the input-buffer cast above.
                    std::slice::from_raw_parts_mut(
                        label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH]
                            .as_mut_ptr()
                            .cast::<u8>(),
                        SEQUENCE_LENGTH * 2
                    )
                })
                .unwrap();
        }

        // Inputs keep the [batch, seq] shape; labels are flattened to [batch * seq] —
        // presumably the layout the training graph's loss expects (defined in
        // training_model.onnx, not visible here). Tokens widen u16 -> i64 for the graph.
        let inputs = Tensor::from_array(([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect::<Vec<i64>>()))?;
        let labels = Tensor::from_array(([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect::<Vec<i64>>()))?;

        let outputs = trainer.step(ort::inputs![inputs], ort::inputs![labels])?;
        // The training graph's first output is the scalar loss.
        let loss = outputs[0].try_extract_scalar::<f32>()?;
        pb.set_postfix(format!("loss={loss:.3}"));
        pb.update(1).unwrap();

        // Bail out early if training diverged; continuing would only propagate NaNs.
        if loss.is_nan() {
            return Ok(());
        }
        // Apply the accumulated gradients, then clear them for the next step.
        let mut optimizer = trainer.optimizer();
        optimizer.step()?;
        optimizer.reset_grad()?;
    }

    eprintln!();
    let _ = kdam::term::show_cursor();

    // Export the trained weights as a plain inference model exposing the `probs` output.
    trainer.export("trained-clm.onnx", ["probs"])?;

    let mut session = Session::builder()?.commit_from_file("trained-clm.onnx")?;

    let mut stdout = std::io::stdout();

    // Seed generation with the GPT-2 end-of-text token and widen ids to i64.
    let tokens = tokenizer.encode("<|endoftext|>", false).unwrap();
    let mut tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();

    // Greedy (argmax) decoding: 50 tokens, feeding the full sequence back each step.
    for _ in 0..50 {
        // Input shape is [1, 1, seq_len] — matching the 4-D `probs` output indexed
        // below; presumably fixed by the exported model's graph.
        let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?;
        let outputs = session.run(ort::inputs![input])?;
        let (dim, probabilities) = outputs["probs"].try_extract_tensor()?;

        // `probs` dims: [.., .., seq_len, vocab_size]; take the distribution for the
        // last position only.
        let (seq_len, vocab_size) = (dim[2] as usize, dim[3] as usize);
        // Sort token ids by probability, descending; NaNs (if any) sink via the
        // `Ordering::Less` fallback.
        let mut probabilities: Vec<(usize, f32)> = probabilities[(seq_len - 1) * vocab_size..].iter().copied().enumerate().collect();
        probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));

        // Pick the single most likely token (greedy decoding, no sampling).
        let token = probabilities[0].0 as i64;
        tokens.push(token);

        let token_str = tokenizer.decode(&[token as _], false).unwrap();
        print!("{}", token_str);
        stdout.flush().unwrap();
    }

    println!();

    Ok(())
}