-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathextract.py
More file actions
399 lines (342 loc) · 14.6 KB
/
extract.py
File metadata and controls
399 lines (342 loc) · 14.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
import dataclasses
import os
import typing
import warnings
import gguf
import numpy as np
from sklearn.decomposition import PCA
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase
import tqdm
from .control import ControlModel, model_layer_list
from .saes import Sae
@dataclasses.dataclass()
class DatasetEntry:
    """One contrastive training pair consumed by `read_representations`."""

    # example string for the "positive" side of the contrast
    positive: str
    # matching example string for the "negative" side of the contrast
    negative: str
@dataclasses.dataclass
class ControlVector:
model_type: str
directions: dict[int, np.ndarray]
@classmethod
def train(
cls,
model: "PreTrainedModel | ControlModel",
tokenizer: PreTrainedTokenizerBase,
dataset: list[DatasetEntry],
**kwargs,
) -> "ControlVector":
"""
Train a ControlVector for a given model and tokenizer using the provided dataset.
Args:
model (PreTrainedModel | ControlModel): The model to train against.
tokenizer (PreTrainedTokenizerBase): The tokenizer to tokenize the dataset.
dataset (list[DatasetEntry]): The dataset used for training.
**kwargs: Additional keyword arguments.
max_batch_size (int, optional): The maximum batch size for training.
Defaults to 32. Try reducing this if you're running out of memory.
method (str, optional): The training method to use. Can be either
"pca_diff" or "pca_center". Defaults to "pca_diff".
compute_hiddens (Callable, optional): Override hidden state computation.
See signature of `read_representations`.
transform_hiddens (Callable, optional): Transform the hidden states after
they are computed. See signature of `read_representations`.
Returns:
ControlVector: The trained vector.
"""
with torch.inference_mode():
dirs = read_representations(
model,
tokenizer,
dataset,
**kwargs,
)
return cls(model_type=model.config.model_type, directions=dirs)
@classmethod
def train_with_sae(
cls,
model: "PreTrainedModel | ControlModel",
tokenizer: PreTrainedTokenizerBase,
sae: Sae,
dataset: list[DatasetEntry],
*,
decode: bool = True,
method: typing.Literal["pca_diff", "pca_center", "umap"] = "pca_center",
**kwargs,
) -> "ControlVector":
"""
Like ControlVector.train, but using an SAE. It's better! WIP.
Args:
model (PreTrainedModel | ControlModel): The model to train against.
tokenizer (PreTrainedTokenizerBase): The tokenizer to tokenize the dataset.
sae (saes.Sae): See the `saes` module for how to load this.
dataset (list[DatasetEntry]): The dataset used for training.
**kwargs: Additional keyword arguments.
decode (bool, optional): Whether to decode the vector to make it immediately usable.
If not, keeps it as monosemantic SAE features for introspection, but you will need to decode it manually
to use it. Defaults to True.
max_batch_size (int, optional): The maximum batch size for training.
Defaults to 32. Try reducing this if you're running out of memory.
method (str, optional): The training method to use. Can be either
"pca_diff" or "pca_center". Defaults to "pca_center"! This is different
than ControlVector.train, which defaults to "pca_diff".
Returns:
ControlVector: The trained vector.
"""
def transform_hiddens(hiddens: dict[int, np.ndarray]) -> dict[int, np.ndarray]:
sae_hiddens = {}
for k, v in tqdm.tqdm(hiddens.items(), desc="sae encoding"):
sae_hiddens[k] = sae.layers[k].encode(v)
return sae_hiddens
with torch.inference_mode():
dirs = read_representations(
model,
tokenizer,
dataset,
transform_hiddens=transform_hiddens,
method=method,
**kwargs,
)
final_dirs = {}
if decode:
for k, v in tqdm.tqdm(dirs.items(), desc="sae decoding"):
final_dirs[k] = sae.layers[k].decode(v)
else:
final_dirs = dirs
return cls(model_type=model.config.model_type, directions=final_dirs)
def export_gguf(self, path: os.PathLike[str] | str):
"""
Export a trained ControlVector to a llama.cpp .gguf file.
Note: This file can't be used with llama.cpp yet. WIP!
```python
vector = ControlVector.train(...)
vector.export_gguf("path/to/write/vector.gguf")
```
```
"""
arch = "controlvector"
writer = gguf.GGUFWriter(path, arch)
writer.add_string(f"{arch}.model_hint", self.model_type)
writer.add_uint32(f"{arch}.layer_count", len(self.directions))
for layer in self.directions.keys():
writer.add_tensor(f"direction.{layer}", self.directions[layer])
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()
@classmethod
def import_gguf(cls, path: os.PathLike[str] | str) -> "ControlVector":
reader = gguf.GGUFReader(path)
archf = reader.get_field("general.architecture")
if not archf or not len(archf.parts):
warnings.warn(".gguf file missing architecture field")
else:
arch = str(bytes(archf.parts[-1]), encoding="utf-8", errors="replace")
if arch != "controlvector":
warnings.warn(
f".gguf file with architecture {arch!r} does not appear to be a control vector!"
)
modelf = reader.get_field("controlvector.model_hint")
if not modelf or not len(modelf.parts):
raise ValueError(".gguf file missing controlvector.model_hint field")
model_hint = str(bytes(modelf.parts[-1]), encoding="utf-8")
directions = {}
for tensor in reader.tensors:
if not tensor.name.startswith("direction."):
continue
try:
layer = int(tensor.name.split(".")[1])
except (IndexError, ValueError):
raise ValueError(
f".gguf file has invalid direction field name: {tensor.name}"
)
directions[layer] = tensor.data
return cls(model_type=model_hint, directions=directions)
def _helper_combine(
self, other: "ControlVector", other_coeff: float
) -> "ControlVector":
if self.model_type != other.model_type:
warnings.warn(
"Trying to add vectors with mismatched model_types together, this may produce unexpected results."
)
model_type = self.model_type
directions: dict[int, np.ndarray] = {}
for layer in self.directions:
directions[layer] = self.directions[layer]
for layer in other.directions:
other_layer = other_coeff * other.directions[layer]
if layer in directions:
directions[layer] = directions[layer] + other_layer
else:
directions[layer] = other_layer
return ControlVector(model_type=model_type, directions=directions)
def __eq__(self, other: "ControlVector") -> bool:
if self is other:
return True
if self.model_type != other.model_type:
return False
if self.directions.keys() != other.directions.keys():
return False
for k in self.directions.keys():
if (self.directions[k] != other.directions[k]).any():
return False
return True
def __add__(self, other: "ControlVector") -> "ControlVector":
if not isinstance(other, ControlVector):
raise TypeError(
f"Unsupported operand type(s) for +: 'ControlVector' and '{type(other).__name__}'"
)
return self._helper_combine(other, 1)
def __sub__(self, other: "ControlVector") -> "ControlVector":
if not isinstance(other, ControlVector):
raise TypeError(
f"Unsupported operand type(s) for -: 'ControlVector' and '{type(other).__name__}'"
)
return self._helper_combine(other, -1)
def __neg__(self) -> "ControlVector":
directions: dict[int, np.ndarray] = {}
for layer in self.directions:
directions[layer] = -self.directions[layer]
return ControlVector(model_type=self.model_type, directions=directions)
def __mul__(self, other: int | float | np.number) -> "ControlVector":
directions: dict[int, np.ndarray] = {}
for layer in self.directions:
directions[layer] = other * self.directions[layer]
return ControlVector(model_type=self.model_type, directions=directions)
def __rmul__(self, other: int | float | np.number) -> "ControlVector":
return self.__mul__(other)
def __truediv__(self, other: int | float | np.number) -> "ControlVector":
return self.__mul__(1 / other)
class ComputeHiddens(typing.Protocol):
    """
    Callable protocol for a custom hidden-state extractor accepted by
    `read_representations` (its ``compute_hiddens`` argument).

    `read_representations` invokes it with keyword arguments only, passing the
    interleaved positive/negative strings as ``train_strs`` and non-negative layer
    indices as ``hidden_layers``. For every requested layer, the returned array must
    have one row per entry of ``train_strs``.
    """

    def __call__(
        self,
        model: "PreTrainedModel | ControlModel",
        tokenizer: PreTrainedTokenizerBase,
        train_strs: list[str],
        hidden_layers: list[int],
        batch_size: int,
    ) -> dict[int, np.ndarray]:
        """Return a ``{layer: hidden_states}`` mapping covering ``hidden_layers``."""
def read_representations(
    model: "PreTrainedModel | ControlModel",
    tokenizer: PreTrainedTokenizerBase,
    inputs: list[DatasetEntry],
    hidden_layers: typing.Iterable[int] | None = None,
    batch_size: int = 32,
    method: typing.Literal["pca_diff", "pca_center", "umap"] = "pca_diff",
    compute_hiddens: ComputeHiddens | None = None,
    transform_hiddens: (
        typing.Callable[[dict[int, np.ndarray]], dict[int, np.ndarray]] | None
    ) = None,
) -> dict[int, np.ndarray]:
    """
    Extract the representations based on the contrast dataset: one direction per layer.

    Each entry contributes its positive and negative string, interleaved as
    ``[pos, neg, pos, neg, ...]``. The per-layer hidden states of those strings are
    reduced to a single direction. That direction is then oriented so that at least
    as many pairs have the positive example projecting further along it as the
    negative one.

    Args:
        model: Model whose hidden states are read.
        tokenizer: Tokenizer forwarded to the hidden-state extractor.
        inputs: Contrastive pairs.
        hidden_layers: Layer indices to extract; negative indices count back from
            ``len(model_layer_list(model))``. Defaults to
            ``range(-1, -model.config.num_hidden_layers, -1)`` (the stop bound is
            exclusive, so the first layer is not included). Layers that normalize to
            the same index are extracted once.
        batch_size: Batch size forwarded to the hidden-state extractor.
        method:
            ``"pca_diff"``: first PCA component of the (positive - negative) differences.
            ``"pca_center"``: first PCA component of the hidden states after
            subtracting each pair's midpoint.
            ``"umap"``: experimental; fits a 1-D UMAP embedding (requires the optional
            ``umap`` package) and uses the embedding-weighted mean of the hidden states.
        compute_hiddens: Replacement for `batched_get_hiddens`; called with keyword
            arguments.
        transform_hiddens: Applied to the ``{layer: hidden_states}`` mapping before the
            reduction (e.g. SAE encoding).

    Returns:
        Mapping from normalized (non-negative) layer index to its direction array.

    Raises:
        ValueError: If ``method`` is not one of the supported names. This is checked
            before any forward pass runs.
    """
    # Fail fast: an unknown method would otherwise only surface after the forward
    # passes over the whole dataset.
    if method not in ("pca_diff", "pca_center", "umap"):
        raise ValueError("unknown method " + method)

    if not hidden_layers:
        hidden_layers = range(-1, -model.config.num_hidden_layers, -1)

    # normalize the layer indexes if they're negative, then drop duplicates (keeping
    # order) so e.g. -1 and n_layers - 1 don't produce two stacked copies of one layer
    n_layers = len(model_layer_list(model))
    hidden_layers = list(
        dict.fromkeys(i if i >= 0 else n_layers + i for i in hidden_layers)
    )

    # the order is [positive, negative, positive, negative, ...]
    train_strs = [s for ex in inputs for s in (ex.positive, ex.negative)]

    if compute_hiddens is None:
        layer_hiddens = batched_get_hiddens(
            model, tokenizer, train_strs, hidden_layers, batch_size
        )
    else:
        layer_hiddens = compute_hiddens(
            model=model,
            tokenizer=tokenizer,
            train_strs=train_strs,
            hidden_layers=hidden_layers,
            batch_size=batch_size,
        )

    if transform_hiddens is not None:
        layer_hiddens = transform_hiddens(layer_hiddens)

    # get directions for each layer using PCA (or UMAP)
    directions: dict[int, np.ndarray] = {}
    for layer in tqdm.tqdm(hidden_layers):
        h = layer_hiddens[layer]
        assert h.shape[0] == len(inputs) * 2

        if method == "pca_diff":
            train = h[::2] - h[1::2]
        elif method == "pca_center":
            center = (h[::2] + h[1::2]) / 2
            # work on a copy: centering in place would silently modify the caller's
            # (or compute_hiddens/transform_hiddens') hidden-state arrays
            train = h.copy()
            train[::2] -= center
            train[1::2] -= center
        else:  # "umap"
            train = h

        if method == "umap":
            # still experimental so don't want to add this as a real dependency yet
            import umap  # type: ignore

            umap_model = umap.UMAP(n_components=1)
            embedding = umap_model.fit_transform(train).astype(np.float32)
            directions[layer] = np.sum(train * embedding, axis=0) / np.sum(embedding)
        else:
            # components_ has shape (1, n_features); squeeze to (n_features,)
            pca_model = PCA(n_components=1, whiten=False).fit(train)
            directions[layer] = pca_model.components_.astype(np.float32).squeeze(axis=0)

        # calculate sign: flip when, more often than not, the positive example
        # projects lower than its paired negative example
        projected_hiddens = project_onto_direction(h, directions[layer])
        positive_proj = projected_hiddens[::2]
        negative_proj = projected_hiddens[1::2]
        positive_smaller_mean = np.mean(positive_proj < negative_proj)
        positive_larger_mean = np.mean(positive_proj > negative_proj)
        if positive_smaller_mean > positive_larger_mean:  # type: ignore
            directions[layer] *= -1

    return directions
def batched_get_hiddens(
    model,
    tokenizer,
    inputs: list[str],
    hidden_layers: list[int],
    batch_size: int,
) -> dict[int, np.ndarray]:
    """
    Using the given model and tokenizer, pass the inputs through the model and get the
    hidden states for each layer in `hidden_layers` at the last non-padding token.

    Returns a dictionary from `hidden_layers` layer id to a float32 numpy array of shape
    `(n_inputs, hidden_dim)`, rows in the same order as `inputs`.

    Raises:
        ValueError: if an input tokenizes to no non-padding tokens (there is then no
            last token to read).
    """
    batched_inputs = [
        inputs[p : p + batch_size] for p in range(0, len(inputs), batch_size)
    ]
    hidden_states: dict[int, list[np.ndarray]] = {layer: [] for layer in hidden_layers}

    with torch.no_grad():
        for batch in tqdm.tqdm(batched_inputs):
            encoded_batch = tokenizer(batch, padding=True, return_tensors="pt")
            encoded_batch = encoded_batch.to(model.device)
            out = model(**encoded_batch, output_hidden_states=True)

            mask = encoded_batch["attention_mask"] != 0
            if not bool(mask.any(dim=1).all()):
                raise ValueError(
                    "batched_get_hiddens: an input has no non-padding tokens"
                )
            # The last token is the highest position whose mask entry is set. This
            # handles both left and right padding.
            positions = torch.arange(mask.shape[1], device=mask.device)
            last_token_idx = (mask.long() * positions).max(dim=1).values.cpu()
            row_idx = torch.arange(len(batch))

            for layer in hidden_layers:
                # non-negative layer ids are offset by one (index 0 of hidden_states is
                # presumably the embedding output in HF models)
                hidden_idx = layer + 1 if layer >= 0 else layer
                layer_out = out.hidden_states[hidden_idx]
                # Gather every row's last token at once. This gives one device->host
                # copy per layer instead of one per (input, layer) pair.
                last_tokens = layer_out[
                    row_idx.to(layer_out.device), last_token_idx.to(layer_out.device)
                ]
                hidden_states[layer].append(last_tokens.cpu().float().numpy())
            del out

    return {k: np.vstack(v) for k, v in hidden_states.items()}
def project_onto_direction(H, direction):
    """
    Project the rows of matrix H (n, d) onto the direction vector (d,).

    Returns the (n,) array of scalar projections ``H @ direction / ||direction||``.

    Raises:
        ValueError: if ``direction`` has zero or non-finite norm. The projection is
            undefined then, and dividing would silently produce NaN/inf values.
    """
    mag = np.linalg.norm(direction)
    if not np.isfinite(mag) or mag == 0:
        raise ValueError(f"cannot project onto a direction with norm {mag}")
    return (H @ direction) / mag