Skip to content

Commit 94e5f97

Browse files
authored
hotfix: tag generation (#994)
1 parent 2909abd commit 94e5f97

File tree

3 files changed

+55
-50
lines changed

3 files changed

+55
-50
lines changed

apps/desktop/src/utils/tag-generation.ts

Lines changed: 52 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,63 +6,68 @@ import { commands as templateCommands } from "@hypr/plugin-template";
66
import { generateText, localProviderName, modelProvider } from "@hypr/utils/ai";
77

88
export async function generateTagsForSession(sessionId: string): Promise<string[]> {
9-
const { type: connectionType } = await connectorCommands.getLlmConnection();
9+
try {
10+
const { type: connectionType } = await connectorCommands.getLlmConnection();
1011

11-
const config = await dbCommands.getConfig();
12-
const session = await dbCommands.getSession({ id: sessionId });
13-
if (!session) {
14-
throw new Error("Session not found");
15-
}
12+
const config = await dbCommands.getConfig();
13+
const session = await dbCommands.getSession({ id: sessionId });
14+
if (!session) {
15+
throw new Error("Session not found");
16+
}
1617

17-
const historicalTags = await dbCommands.listAllTags();
18-
const currentTags = await dbCommands.listSessionTags(sessionId);
18+
const historicalTags = await dbCommands.listAllTags();
19+
const currentTags = await dbCommands.listSessionTags(sessionId);
1920

20-
const extractHashtags = (text: string): string[] => {
21-
const hashtagRegex = /#(\w+)/g;
22-
return Array.from(text.matchAll(hashtagRegex), match => match[1]);
23-
};
21+
const extractHashtags = (text: string): string[] => {
22+
const hashtagRegex = /#(\w+)/g;
23+
return Array.from(text.matchAll(hashtagRegex), match => match[1]);
24+
};
2425

25-
const existingHashtags = extractHashtags(session.raw_memo_html);
26+
const existingHashtags = extractHashtags(session.raw_memo_html);
2627

27-
const systemPrompt = await templateCommands.render(
28-
"suggest_tags.system",
29-
{ config, type: connectionType },
30-
);
28+
const systemPrompt = await templateCommands.render(
29+
"suggest_tags.system",
30+
{ config, type: connectionType },
31+
);
3132

32-
const userPrompt = await templateCommands.render(
33-
"suggest_tags.user",
34-
{
35-
title: session.title,
36-
content: session.raw_memo_html,
37-
existing_hashtags: existingHashtags,
38-
formal_tags: currentTags.map(t => t.name),
39-
historical_tags: historicalTags.slice(0, 20).map(t => t.name),
40-
},
41-
);
33+
const userPrompt = await templateCommands.render(
34+
"suggest_tags.user",
35+
{
36+
title: session.title,
37+
content: session.raw_memo_html,
38+
existing_hashtags: existingHashtags,
39+
formal_tags: currentTags.map(t => t.name),
40+
historical_tags: historicalTags.slice(0, 20).map(t => t.name),
41+
},
42+
);
4243

43-
const provider = await modelProvider();
44-
const model = provider.languageModel("defaultModel");
44+
const provider = await modelProvider();
45+
const model = provider.languageModel("defaultModel");
4546

46-
const result = await generateText({
47-
model,
48-
messages: [
49-
{ role: "system", content: systemPrompt },
50-
{ role: "user", content: userPrompt },
51-
],
52-
providerOptions: {
53-
[localProviderName]: {
54-
metadata: {
55-
grammar: "tags",
47+
const result = await generateText({
48+
model,
49+
messages: [
50+
{ role: "system", content: systemPrompt },
51+
{ role: "user", content: userPrompt },
52+
],
53+
providerOptions: {
54+
[localProviderName]: {
55+
metadata: {
56+
grammar: "tags",
57+
},
5658
},
5759
},
58-
},
59-
});
60+
});
6061

61-
const schema = z.preprocess(
62-
(val) => (typeof val === "string" ? JSON.parse(val) : val),
63-
z.array(z.string().min(1)).min(1).max(5),
64-
);
62+
const schema = z.preprocess(
63+
(val) => (typeof val === "string" ? JSON.parse(val) : val),
64+
z.array(z.string().min(1)).min(1).max(5),
65+
);
6566

66-
const parsed = schema.safeParse(result.text);
67-
return parsed.success ? parsed.data : [];
67+
const parsed = schema.safeParse(result.text);
68+
return parsed.success ? parsed.data : [];
69+
} catch (error) {
70+
console.error("Tag generation failed:", error);
71+
return [];
72+
}
6873
}

crates/gbnf/assets/tags.gbnf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
root ::= "[" "'" word "'" ("," ws "'" word "'")* "]"
1+
root ::= "[" "\"" word "\"" ("," ws "\"" word "\"")* "]"
22
word ::= [a-zA-Z0-9_-]+
33
ws ::= " "*

crates/gbnf/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ mod tests {
4747
let gbnf = gbnf_validator::Validator::new().unwrap();
4848

4949
for (input, expected) in vec![
50-
("['meeting', 'summary']", true),
51-
("['meeting', 'summary', '']", false),
50+
("[\"meeting\", \"summary\"]", true),
51+
("[\"meeting\", \"summary\", \"\"]", false),
5252
] {
5353
let result = gbnf.validate(TAGS, input).unwrap();
5454
assert_eq!(result, expected, "failed: {}", input);

0 commit comments

Comments
 (0)