generated from MinishLab/watertemplate
-
Notifications
You must be signed in to change notification settings - Fork 55
Expand file tree
/
Copy pathutils.py
More file actions
149 lines (125 loc) · 4.99 KB
/
utils.py
File metadata and controls
149 lines (125 loc) · 4.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import hashlib
from collections.abc import Sequence
from typing import Any, Protocol, TypeAlias, TypeVar
import numpy as np
from frozendict import frozendict
# Type definitions
Record = TypeVar("Record", str, dict[str, Any])
DuplicateList: TypeAlias = list[tuple[Record, float]]
class Encoder(Protocol):
"""An encoder protocol for SemHash. Supports text, images, or any encodable data."""
def encode(
self,
inputs: Sequence[Any] | Any,
**kwargs: Any,
) -> np.ndarray:
"""
Encode a list of inputs into embeddings.
:param inputs: A list of inputs to encode (strings, images, etc.).
:param **kwargs: Additional keyword arguments.
:return: The embeddings of the inputs.
"""
... # pragma: no cover
def make_hashable(value: Any) -> Any:
"""
Convert a value to a hashable representation for use as dict keys.
Strings and other hashable types are returned as-is.
Non-hashable types (like PIL images, numpy arrays) are hashed to a string.
:param value: The value to make hashable.
:return: A hashable representation of the value.
"""
# Fast path: most values are strings or already hashable
if isinstance(value, (str, int, float, bool, type(None))):
return value
# Handle objects with tobytes() (PIL Image, numpy array, etc.)
if hasattr(value, "tobytes"):
return hashlib.md5(value.tobytes()).hexdigest()
# Fallback: try to hash, otherwise stringify
try:
hash(value)
return value
except TypeError:
return str(value)
def coerce_value(value: Any) -> Any:
"""
Coerce a value for encoding: stringify primitives, keep complex types raw.
This ensures primitives (int, float, bool) work with text encoders,
while complex types (PIL images, tensors, etc.) are passed through for multimodal encoders.
:param value: The value to coerce.
:return: The coerced value.
"""
if isinstance(value, (str, bytes)):
return value
if isinstance(value, (int, float, bool)):
return str(value)
return value # Complex types (images, tensors, etc.)
def to_frozendict(record: dict[str, Any], columns: Sequence[str] | set[str]) -> frozendict[str, Any]:
"""
Convert a record to a frozendict with hashable values.
:param record: The record to convert.
:param columns: The columns to include.
:return: A frozendict with only the specified columns (values made hashable).
:raises ValueError: If a column is missing from the record.
"""
try:
return frozendict({k: make_hashable(record[k]) for k in columns})
except KeyError as e:
missing = e.args[0]
raise ValueError(f"Missing column '{missing}' in record {record}") from e
def compute_candidate_limit(
total: int,
selection_size: int,
fraction: float = 0.1,
min_candidates: int = 100,
max_candidates: int = 1000,
) -> int:
"""
Compute the 'auto' candidate limit based on the total number of records.
:param total: Total number of records.
:param selection_size: Number of representatives to select.
:param fraction: Fraction of total records to consider as candidates.
:param min_candidates: Minimum number of candidates.
:param max_candidates: Maximum number of candidates.
:return: Computed candidate limit.
"""
# 1) fraction of total
limit = int(total * fraction)
# 2) ensure enough to pick selection_size
limit = max(limit, selection_size)
# 3) enforce lower bound
limit = max(limit, min_candidates)
# 4) enforce upper bound (and never exceed the dataset)
limit = min(limit, max_candidates, total)
return limit
def featurize(
records: Sequence[dict[str, Any]],
columns: Sequence[str],
model: Encoder,
) -> np.ndarray:
"""
Featurize a list of records using the model.
:param records: A list of records.
:param columns: Columns to featurize.
:param model: An Encoder model.
:return: The embeddings of the records.
:raises ValueError: If a column is missing from one or more records.
:raises TypeError: If encoding fails due to incompatible data types.
"""
# Extract the embeddings for each column across all records
embeddings_per_col = []
for col in columns:
try:
col_texts = [r[col] for r in records]
except KeyError as e:
raise ValueError(f"Missing column '{col}' in one or more records") from e
try:
col_emb = model.encode(col_texts)
except TypeError as e:
sample_type = type(col_texts[0]).__name__ if col_texts else "unknown"
raise TypeError(
f"Failed to encode column '{col}' (data type: {sample_type}). "
f"If encoding non-text data, provide a compatible encoder via the `model` parameter. "
f"See the SemHash documentation for more info."
) from e
embeddings_per_col.append(np.asarray(col_emb))
return np.concatenate(embeddings_per_col, axis=1)