Skip to content

Commit 3330802

Browse files
committed
refac
1 parent f96e8f0 commit 3330802

File tree

1 file changed

+42
-25
lines changed

1 file changed

+42
-25
lines changed

backend/open_webui/models/groups.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,10 @@ def get_all_groups(self, db: Optional[Session] = None) -> list[GroupModel]:
164164

165165
def get_groups(self, filter, db: Optional[Session] = None) -> list[GroupResponse]:
166166
with get_db_context(db) as db:
167-
query = db.query(Group)
167+
member_count = func.count(GroupMember.user_id).label("member_count")
168+
query = db.query(Group, member_count).outerjoin(
169+
GroupMember, GroupMember.group_id == Group.id
170+
)
168171

169172
if filter:
170173
if "query" in filter:
@@ -179,9 +182,6 @@ def get_groups(self, filter, db: Optional[Session] = None) -> list[GroupResponse
179182
json_share_lower = func.lower(json_share_str)
180183

181184
if share_value:
182-
# Groups open to anyone: data is null, config.share is null, or share is true
183-
# Use case-insensitive string comparison to handle variations like "True", "TRUE"
184-
# Handle potential JSON boolean to string casting issues by checking for both string 'true' and boolean equivalence if possible,
185185
anyone_can_share = or_(
186186
Group.data.is_(None),
187187
json_share_str.is_(None),
@@ -190,7 +190,6 @@ def get_groups(self, filter, db: Optional[Session] = None) -> list[GroupResponse
190190
)
191191

192192
if member_id:
193-
# Also include member-only groups where user is a member
194193
member_groups_select = select(GroupMember.group_id).where(
195194
GroupMember.user_id == member_id
196195
)
@@ -211,21 +210,28 @@ def get_groups(self, filter, db: Optional[Session] = None) -> list[GroupResponse
211210
else:
212211
# Only apply member_id filter when share filter is NOT present
213212
if "member_id" in filter:
214-
query = query.join(
215-
GroupMember, GroupMember.group_id == Group.id
216-
).filter(GroupMember.user_id == filter["member_id"])
213+
query = query.filter(
214+
Group.id.in_(
215+
select(GroupMember.group_id).where(
216+
GroupMember.user_id == filter["member_id"]
217+
)
218+
)
219+
)
220+
221+
results = (
222+
query.group_by(Group.id)
223+
.order_by(Group.updated_at.desc())
224+
.all()
225+
)
217226

218-
groups = query.order_by(Group.updated_at.desc()).all()
219-
group_ids = [group.id for group in groups]
220-
member_counts = self.get_group_member_counts_by_ids(group_ids, db=db)
221227
return [
222228
GroupResponse.model_validate(
223229
{
224230
**GroupModel.model_validate(group).model_dump(),
225-
"member_count": member_counts.get(group.id, 0),
231+
"member_count": count or 0,
226232
}
227233
)
228-
for group in groups
234+
for group, count in results
229235
]
230236

231237
def search_groups(
@@ -242,31 +248,42 @@ def search_groups(
242248
if "query" in filter:
243249
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
244250
if "member_id" in filter:
245-
query = query.join(
246-
GroupMember, GroupMember.group_id == Group.id
247-
).filter(GroupMember.user_id == filter["member_id"])
251+
query = query.filter(
252+
Group.id.in_(
253+
select(GroupMember.group_id).where(
254+
GroupMember.user_id == filter["member_id"]
255+
)
256+
)
257+
)
248258

249259
if "share" in filter:
250-
# 'share' is stored in data JSON, support both sqlite and postgres
251260
share_value = filter["share"]
252-
print("Filtering by share:", share_value)
253261
query = query.filter(
254262
Group.data.op("->>")("share") == str(share_value)
255263
)
256264

257265
total = query.count()
258-
query = query.order_by(Group.updated_at.desc())
259-
groups = query.offset(skip).limit(limit).all()
260-
group_ids = [group.id for group in groups]
261-
member_counts = self.get_group_member_counts_by_ids(group_ids, db=db)
266+
267+
member_count = func.count(GroupMember.user_id).label("member_count")
268+
results = (
269+
query.add_columns(member_count)
270+
.outerjoin(GroupMember, GroupMember.group_id == Group.id)
271+
.group_by(Group.id)
272+
.order_by(Group.updated_at.desc())
273+
.offset(skip)
274+
.limit(limit)
275+
.all()
276+
)
262277

263278
return {
264279
"items": [
265280
GroupResponse.model_validate(
266-
**GroupModel.model_validate(group).model_dump(),
267-
member_count=member_counts.get(group.id, 0),
281+
{
282+
**GroupModel.model_validate(group).model_dump(),
283+
"member_count": count or 0,
284+
}
268285
)
269-
for group in groups
286+
for group, count in results
270287
],
271288
"total": total,
272289
}

0 commit comments

Comments
 (0)