Skip to content

Commit 75596e9

Browse files
committed
refactor(discord): unify DM command auth handling
1 parent 12c1257 commit 75596e9

File tree

4 files changed

+120
-64
lines changed

4 files changed

+120
-64
lines changed

src/discord/monitor/dm-command-auth.ts

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,32 @@ export type DiscordDmCommandAccess = {
1717
allowMatch: ReturnType<typeof resolveDiscordAllowListMatch> | { allowed: false };
1818
};
1919

20+
function resolveSenderAllowMatch(params: {
21+
allowEntries: string[];
22+
sender: { id: string; name?: string; tag?: string };
23+
allowNameMatching: boolean;
24+
}) {
25+
const allowList = normalizeDiscordAllowList(params.allowEntries, DISCORD_ALLOW_LIST_PREFIXES);
26+
return allowList
27+
? resolveDiscordAllowListMatch({
28+
allowList,
29+
candidate: params.sender,
30+
allowNameMatching: params.allowNameMatching,
31+
})
32+
: ({ allowed: false } as const);
33+
}
34+
35+
function resolveDmPolicyCommandAuthorization(params: {
36+
dmPolicy: DiscordDmPolicy;
37+
decision: DmGroupAccessDecision;
38+
commandAuthorized: boolean;
39+
}) {
40+
if (params.dmPolicy === "open" && params.decision === "allow") {
41+
return true;
42+
}
43+
return params.commandAuthorized;
44+
}
45+
2046
export async function resolveDiscordDmCommandAccess(params: {
2147
accountId: string;
2248
dmPolicy: DiscordDmPolicy;
@@ -40,30 +66,19 @@ export async function resolveDiscordDmCommandAccess(params: {
4066
allowFrom: params.configuredAllowFrom,
4167
groupAllowFrom: [],
4268
storeAllowFrom,
43-
isSenderAllowed: (allowEntries) => {
44-
const allowList = normalizeDiscordAllowList(allowEntries, DISCORD_ALLOW_LIST_PREFIXES);
45-
const allowMatch = allowList
46-
? resolveDiscordAllowListMatch({
47-
allowList,
48-
candidate: params.sender,
49-
allowNameMatching: params.allowNameMatching,
50-
})
51-
: { allowed: false };
52-
return allowMatch.allowed;
53-
},
69+
isSenderAllowed: (allowEntries) =>
70+
resolveSenderAllowMatch({
71+
allowEntries,
72+
sender: params.sender,
73+
allowNameMatching: params.allowNameMatching,
74+
}).allowed,
5475
});
5576

56-
const commandAllowList = normalizeDiscordAllowList(
57-
access.effectiveAllowFrom,
58-
DISCORD_ALLOW_LIST_PREFIXES,
59-
);
60-
const allowMatch = commandAllowList
61-
? resolveDiscordAllowListMatch({
62-
allowList: commandAllowList,
63-
candidate: params.sender,
64-
allowNameMatching: params.allowNameMatching,
65-
})
66-
: { allowed: false };
77+
const allowMatch = resolveSenderAllowMatch({
78+
allowEntries: access.effectiveAllowFrom,
79+
sender: params.sender,
80+
allowNameMatching: params.allowNameMatching,
81+
});
6782

6883
const commandAuthorized = resolveCommandAuthorizedFromAuthorizers({
6984
useAccessGroups: params.useAccessGroups,
@@ -75,13 +90,15 @@ export async function resolveDiscordDmCommandAccess(params: {
7590
],
7691
modeWhenAccessGroupsOff: "configured",
7792
});
78-
const effectiveCommandAuthorized =
79-
access.decision === "allow" && params.dmPolicy === "open" ? true : commandAuthorized;
8093

8194
return {
8295
decision: access.decision,
8396
reason: access.reason,
84-
commandAuthorized: effectiveCommandAuthorized,
97+
commandAuthorized: resolveDmPolicyCommandAuthorization({
98+
dmPolicy: params.dmPolicy,
99+
decision: access.decision,
100+
commandAuthorized,
101+
}),
85102
allowMatch,
86103
};
87104
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js";
2+
import type { DiscordDmCommandAccess } from "./dm-command-auth.js";
3+
4+
export async function handleDiscordDmCommandDecision(params: {
5+
dmAccess: DiscordDmCommandAccess;
6+
accountId: string;
7+
sender: {
8+
id: string;
9+
tag?: string;
10+
name?: string;
11+
};
12+
onPairingCreated: (code: string) => Promise<void>;
13+
onUnauthorized: () => Promise<void>;
14+
upsertPairingRequest?: typeof upsertChannelPairingRequest;
15+
}): Promise<boolean> {
16+
if (params.dmAccess.decision === "allow") {
17+
return true;
18+
}
19+
20+
if (params.dmAccess.decision === "pairing") {
21+
const upsertPairingRequest = params.upsertPairingRequest ?? upsertChannelPairingRequest;
22+
const { code, created } = await upsertPairingRequest({
23+
channel: "discord",
24+
id: params.sender.id,
25+
accountId: params.accountId,
26+
meta: {
27+
tag: params.sender.tag,
28+
name: params.sender.name,
29+
},
30+
});
31+
if (created) {
32+
await params.onPairingCreated(code);
33+
}
34+
return false;
35+
}
36+
37+
await params.onUnauthorized();
38+
return false;
39+
}

src/discord/monitor/message-handler.preflight.ts

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import { enqueueSystemEvent } from "../../infra/system-events.js";
2525
import { logDebug } from "../../logger.js";
2626
import { getChildLogger } from "../../logging.js";
2727
import { buildPairingReply } from "../../pairing/pairing-messages.js";
28-
import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js";
2928
import { resolveAgentRoute } from "../../routing/resolve-route.js";
3029
import { DEFAULT_ACCOUNT_ID, resolveAgentIdFromSessionKey } from "../../routing/session-key.js";
3130
import { fetchPluralKitMessageInfo } from "../pluralkit.js";
@@ -42,6 +41,7 @@ import {
4241
resolveGroupDmAllow,
4342
} from "./allow-list.js";
4443
import { resolveDiscordDmCommandAccess } from "./dm-command-auth.js";
44+
import { handleDiscordDmCommandDecision } from "./dm-command-decision.js";
4545
import {
4646
formatDiscordUserTag,
4747
resolveDiscordSystemLocation,
@@ -175,6 +175,7 @@ export async function preflightDiscordMessage(
175175
const dmPolicy = params.discordConfig?.dmPolicy ?? params.discordConfig?.dm?.policy ?? "pairing";
176176
const useAccessGroups = params.cfg.commands?.useAccessGroups !== false;
177177
const resolvedAccountId = params.accountId ?? DEFAULT_ACCOUNT_ID;
178+
const allowNameMatching = isDangerousNameMatchingEnabled(params.discordConfig);
178179
let commandAuthorized = true;
179180
if (isDirectMessage) {
180181
if (dmPolicy === "disabled") {
@@ -190,25 +191,23 @@ export async function preflightDiscordMessage(
190191
name: sender.name,
191192
tag: sender.tag,
192193
},
193-
allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig),
194+
allowNameMatching,
194195
useAccessGroups,
195196
});
196197
commandAuthorized = dmAccess.commandAuthorized;
197198
if (dmAccess.decision !== "allow") {
198199
const allowMatchMeta = formatAllowlistMatchMeta(
199200
dmAccess.allowMatch.allowed ? dmAccess.allowMatch : undefined,
200201
);
201-
if (dmAccess.decision === "pairing") {
202-
const { code, created } = await upsertChannelPairingRequest({
203-
channel: "discord",
202+
await handleDiscordDmCommandDecision({
203+
dmAccess,
204+
accountId: resolvedAccountId,
205+
sender: {
204206
id: author.id,
205-
accountId: resolvedAccountId,
206-
meta: {
207-
tag: formatDiscordUserTag(author),
208-
name: author.username ?? undefined,
209-
},
210-
});
211-
if (created) {
207+
tag: formatDiscordUserTag(author),
208+
name: author.username ?? undefined,
209+
},
210+
onPairingCreated: async (code) => {
212211
logVerbose(
213212
`discord pairing request sender=${author.id} tag=${formatDiscordUserTag(author)} (${allowMatchMeta})`,
214213
);
@@ -229,12 +228,13 @@ export async function preflightDiscordMessage(
229228
} catch (err) {
230229
logVerbose(`discord pairing reply failed for ${author.id}: ${String(err)}`);
231230
}
232-
}
233-
} else {
234-
logVerbose(
235-
`Blocked unauthorized discord sender ${sender.id} (dmPolicy=${dmPolicy}, ${allowMatchMeta})`,
236-
);
237-
}
231+
},
232+
onUnauthorized: async () => {
233+
logVerbose(
234+
`Blocked unauthorized discord sender ${sender.id} (dmPolicy=${dmPolicy}, ${allowMatchMeta})`,
235+
);
236+
},
237+
});
238238
return null;
239239
}
240240
}
@@ -570,7 +570,7 @@ export async function preflightDiscordMessage(
570570
guildInfo,
571571
memberRoleIds,
572572
sender,
573-
allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig),
573+
allowNameMatching,
574574
});
575575

576576
if (!isDirectMessage) {
@@ -587,7 +587,7 @@ export async function preflightDiscordMessage(
587587
name: sender.name,
588588
tag: sender.tag,
589589
},
590-
{ allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig) },
590+
{ allowNameMatching },
591591
)
592592
: false;
593593
const commandGate = resolveControlCommandGate({

src/discord/monitor/native-command.ts

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ import { logVerbose } from "../../globals.js";
4646
import { createSubsystemLogger } from "../../logging/subsystem.js";
4747
import { getAgentScopedMediaLocalRoots } from "../../media/local-roots.js";
4848
import { buildPairingReply } from "../../pairing/pairing-messages.js";
49-
import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js";
5049
import { resolveAgentRoute } from "../../routing/resolve-route.js";
5150
import { resolveAgentIdFromSessionKey } from "../../routing/session-key.js";
5251
import { buildUntrustedChannelMetadata } from "../../security/channel-metadata.js";
@@ -65,6 +64,7 @@ import {
6564
resolveDiscordOwnerAllowFrom,
6665
} from "./allow-list.js";
6766
import { resolveDiscordDmCommandAccess } from "./dm-command-auth.js";
67+
import { handleDiscordDmCommandDecision } from "./dm-command-decision.js";
6868
import { resolveDiscordChannelInfo } from "./message-utils.js";
6969
import {
7070
readDiscordModelPickerRecentModels,
@@ -1269,6 +1269,7 @@ async function dispatchDiscordCommandInteraction(params: {
12691269
const memberRoleIds = Array.isArray(interaction.rawData.member?.roles)
12701270
? interaction.rawData.member.roles.map((roleId: string) => String(roleId))
12711271
: [];
1272+
const allowNameMatching = isDangerousNameMatchingEnabled(discordConfig);
12721273
const ownerAllowList = normalizeDiscordAllowList(
12731274
discordConfig?.allowFrom ?? discordConfig?.dm?.allowFrom ?? [],
12741275
["discord:", "user:", "pk:"],
@@ -1282,7 +1283,7 @@ async function dispatchDiscordCommandInteraction(params: {
12821283
name: sender.name,
12831284
tag: sender.tag,
12841285
},
1285-
{ allowNameMatching: isDangerousNameMatchingEnabled(discordConfig) },
1286+
{ allowNameMatching },
12861287
)
12871288
: false;
12881289
const guildInfo = resolveDiscordGuildEntry({
@@ -1366,22 +1367,20 @@ async function dispatchDiscordCommandInteraction(params: {
13661367
name: sender.name,
13671368
tag: sender.tag,
13681369
},
1369-
allowNameMatching: isDangerousNameMatchingEnabled(discordConfig),
1370+
allowNameMatching,
13701371
useAccessGroups,
13711372
});
13721373
commandAuthorized = dmAccess.commandAuthorized;
13731374
if (dmAccess.decision !== "allow") {
1374-
if (dmAccess.decision === "pairing") {
1375-
const { code, created } = await upsertChannelPairingRequest({
1376-
channel: "discord",
1375+
await handleDiscordDmCommandDecision({
1376+
dmAccess,
1377+
accountId,
1378+
sender: {
13771379
id: user.id,
1378-
accountId,
1379-
meta: {
1380-
tag: sender.tag,
1381-
name: sender.name,
1382-
},
1383-
});
1384-
if (created) {
1380+
tag: sender.tag,
1381+
name: sender.name,
1382+
},
1383+
onPairingCreated: async (code) => {
13851384
await respond(
13861385
buildPairingReply({
13871386
channel: "discord",
@@ -1390,10 +1389,11 @@ async function dispatchDiscordCommandInteraction(params: {
13901389
}),
13911390
{ ephemeral: true },
13921391
);
1393-
}
1394-
} else {
1395-
await respond("You are not authorized to use this command.", { ephemeral: true });
1396-
}
1392+
},
1393+
onUnauthorized: async () => {
1394+
await respond("You are not authorized to use this command.", { ephemeral: true });
1395+
},
1396+
});
13971397
return;
13981398
}
13991399
}
@@ -1403,7 +1403,7 @@ async function dispatchDiscordCommandInteraction(params: {
14031403
guildInfo,
14041404
memberRoleIds,
14051405
sender,
1406-
allowNameMatching: isDangerousNameMatchingEnabled(discordConfig),
1406+
allowNameMatching,
14071407
});
14081408
const authorizers = useAccessGroups
14091409
? [
@@ -1509,7 +1509,7 @@ async function dispatchDiscordCommandInteraction(params: {
15091509
channelConfig,
15101510
guildInfo,
15111511
sender: { id: sender.id, name: sender.name, tag: sender.tag },
1512-
allowNameMatching: isDangerousNameMatchingEnabled(discordConfig),
1512+
allowNameMatching,
15131513
});
15141514
const ctxPayload = finalizeInboundContext({
15151515
Body: prompt,

0 commit comments

Comments
 (0)