Skip to content

Commit 1a6c686

Browse files
committed
2 parents 4b7cf64 + 48e83ba commit 1a6c686

File tree

20 files changed

+413
-25
lines changed

20 files changed

+413
-25
lines changed

label_studio/core/all_urls.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1972,5 +1972,11 @@
19721972
"module": "jwt_auth.views.LSAPITokenRotateView",
19731973
"name": "jwt_auth:token_rotate",
19741974
"decorators": ""
1975+
},
1976+
{
1977+
"url": "/api/session-policy/",
1978+
"module": "session_policy.api.SessionTimeoutPolicyView",
1979+
"name": "session_policy:session-policy",
1980+
"decorators": ""
19751981
}
19761982
]

label_studio/core/middleware.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,9 @@ def process_request(self, request) -> None:
216216

217217
active_org = request.user.active_organization
218218
if flag_set('fflag_feat_utc_46_session_timeout_policy', user=request.user) and active_org:
219-
org_max_session_age = timedelta(hours=active_org.session_timeout_policy.max_session_age).total_seconds()
219+
org_max_session_age = timedelta(minutes=active_org.session_timeout_policy.max_session_age).total_seconds()
220220
max_time_between_activity = timedelta(
221-
hours=active_org.session_timeout_policy.max_time_between_activity
221+
minutes=active_org.session_timeout_policy.max_time_between_activity
222222
).total_seconds()
223223

224224
if (current_time - last_login) > org_max_session_age:
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Generated by Django 5.1.9 on 2025-06-04 19:36
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
8+
dependencies = [
9+
("core", "0001_initial"),
10+
]
11+
12+
operations = [
13+
migrations.CreateModel(
14+
name="DeletedRow",
15+
fields=[
16+
(
17+
"id",
18+
models.AutoField(
19+
auto_created=True,
20+
primary_key=True,
21+
serialize=False,
22+
verbose_name="ID",
23+
),
24+
),
25+
("model", models.CharField(max_length=1024)),
26+
("row_id", models.IntegerField(null=True)),
27+
("data", models.JSONField(blank=True, null=True)),
28+
("reason", models.TextField(blank=True, null=True)),
29+
("created_at", models.DateTimeField(auto_now_add=True)),
30+
("updated_at", models.DateTimeField(auto_now=True)),
31+
("organization_id", models.IntegerField(blank=True, null=True)),
32+
("project_id", models.IntegerField(blank=True, null=True)),
33+
("user_id", models.IntegerField(blank=True, null=True)),
34+
],
35+
),
36+
]

label_studio/core/models.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import json
12
import logging
23

4+
from django.core import serializers
35
from django.db import models
46
from django.db.models import JSONField
57
from django.utils.translation import gettext_lazy as _
@@ -42,3 +44,39 @@ class AsyncMigrationStatus(models.Model):
4244

4345
def __str__(self):
4446
return f'(id={self.id}) ' + self.name + (' at project ' + str(self.project) if self.project else '')
47+
48+
49+
class DeletedRow(models.Model):
50+
"""
51+
Model to store deleted rows of other models.
52+
Useful for using as backup for deleted rows, in case we need to restore them.
53+
"""
54+
55+
model = models.CharField(max_length=1024) # tasks.task, projects.project, etc.
56+
row_id = models.IntegerField(null=True) # primary key of the deleted row. task.id, project.id, etc.
57+
data = JSONField(null=True, blank=True) # serialized json of the deleted row.
58+
reason = models.TextField(null=True, blank=True) # reason for deletion.
59+
created_at = models.DateTimeField(auto_now_add=True)
60+
updated_at = models.DateTimeField(auto_now=True)
61+
62+
# optional fields for searching purposes
63+
organization_id = models.IntegerField(null=True, blank=True)
64+
project_id = models.IntegerField(null=True, blank=True)
65+
user_id = models.IntegerField(null=True, blank=True)
66+
67+
@classmethod
68+
def serialize_and_create(cls, model, **kwargs) -> 'DeletedRow':
69+
data = json.loads(serializers.serialize('json', [model]))[0]
70+
model = data['model']
71+
row_id = int(data['pk'])
72+
return cls.objects.create(model=model, row_id=row_id, data=data, **kwargs)
73+
74+
@classmethod
75+
def bulk_serialize_and_create(cls, queryset, **kwargs) -> list['DeletedRow']:
76+
serialized_data = json.loads(serializers.serialize('json', queryset))
77+
bulk_objects = []
78+
for data in serialized_data:
79+
model = data['model']
80+
row_id = int(data['pk'])
81+
bulk_objects.append(cls(model=model, row_id=row_id, data=data, **kwargs))
82+
return cls.objects.bulk_create(bulk_objects)

label_studio/core/tests/__init__.py

Whitespace-only changes.
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import json
2+
3+
from core.models import DeletedRow
4+
from django.core import serializers
5+
from django.test import TestCase
6+
from organizations.models import Organization
7+
from organizations.tests.factories import OrganizationFactory
8+
from projects.models import Project
9+
from projects.tests.factories import ProjectFactory
10+
from tasks.models import Task
11+
from tasks.tests.factories import TaskFactory
12+
13+
14+
class TestDeletedRow(TestCase):
15+
def _assert_delete_and_restore_equal(self, drow, original):
16+
original_dict = original.__dict__.copy()
17+
original_id = original.id
18+
original_dict.pop('_state')
19+
original_created_at = original_dict.pop('created_at')
20+
original_updated_at = original_dict.pop('updated_at')
21+
original.delete()
22+
23+
for deserialized_object in serializers.deserialize('json', json.dumps([drow.data])):
24+
deserialized_object.save()
25+
new_object = original.__class__.objects.get(id=original_id)
26+
27+
new_dict = new_object.__dict__.copy()
28+
new_dict.pop('_state')
29+
new_created_at = new_dict.pop('created_at')
30+
new_updated_at = new_dict.pop('updated_at')
31+
32+
assert new_dict == original_dict
33+
# Datetime loses microsecond precision, so we can't compare them directly
34+
assert abs((new_created_at - original_created_at).total_seconds()) < 0.001
35+
assert abs((new_updated_at - original_updated_at).total_seconds()) < 0.001
36+
37+
def test_serialize_organization(self):
38+
organization = OrganizationFactory()
39+
drow = DeletedRow.serialize_and_create(organization, reason='reason', organization_id=organization.id)
40+
assert drow.row_id == organization.id
41+
assert drow.model == 'organizations.organization'
42+
assert drow.data['fields']['title'] == organization.title
43+
assert drow.data['fields']['token'] == organization.token
44+
assert drow.data['fields']['created_by'] == organization.created_by_id
45+
assert drow.reason == 'reason'
46+
assert drow.organization_id == organization.id
47+
assert drow.project_id is None
48+
assert drow.user_id is None
49+
self._assert_delete_and_restore_equal(drow, organization)
50+
51+
def test_serialize_project(self):
52+
project = ProjectFactory()
53+
drow = DeletedRow.serialize_and_create(
54+
project, reason='reason', organization_id=project.organization.id, project_id=project.id
55+
)
56+
assert drow.row_id == project.id
57+
assert drow.model == 'projects.project'
58+
assert drow.data['fields']['title'] == project.title
59+
assert drow.data['fields']['organization'] == project.organization.id
60+
assert drow.reason == 'reason'
61+
assert drow.organization_id == project.organization.id
62+
assert drow.project_id == project.id
63+
self._assert_delete_and_restore_equal(drow, project)
64+
65+
def test_serialize_task(self):
66+
organization = OrganizationFactory()
67+
project = ProjectFactory(organization=organization)
68+
task = TaskFactory(project=project)
69+
drow = DeletedRow.serialize_and_create(
70+
task,
71+
reason='reason',
72+
organization_id=organization.id,
73+
project_id=project.id,
74+
user_id=organization.created_by_id,
75+
)
76+
assert drow.row_id == task.id
77+
assert drow.model == 'tasks.task'
78+
assert drow.data['fields']['project'] == project.id
79+
assert drow.reason == 'reason'
80+
assert drow.organization_id == organization.id
81+
assert drow.project_id == project.id
82+
assert drow.user_id == organization.created_by_id
83+
self._assert_delete_and_restore_equal(drow, task)
84+
85+
def test_bulk_serialize_and_create(self):
86+
organization_1 = OrganizationFactory()
87+
organization_2 = OrganizationFactory()
88+
drows = DeletedRow.bulk_serialize_and_create(Organization.objects.all(), reason='reason')
89+
assert len(drows) == 2
90+
assert drows[0].model == 'organizations.organization'
91+
assert drows[0].row_id == organization_1.id
92+
assert drows[0].data['fields']['title'] == organization_1.title
93+
assert drows[0].data['fields']['token'] == organization_1.token
94+
assert drows[0].data['fields']['created_by'] == organization_1.created_by_id
95+
assert drows[0].reason == 'reason'
96+
assert drows[1].model == 'organizations.organization'
97+
assert drows[1].row_id == organization_2.id
98+
assert drows[1].data['fields']['title'] == organization_2.title
99+
assert drows[1].data['fields']['token'] == organization_2.token
100+
assert drows[1].data['fields']['created_by'] == organization_2.created_by_id
101+
assert drows[1].reason == 'reason'
102+
103+
project_1 = ProjectFactory(organization=organization_1)
104+
project_2 = ProjectFactory(organization=organization_1)
105+
drows = DeletedRow.bulk_serialize_and_create(
106+
Project.objects.all(), reason='reason', organization_id=organization_1.id
107+
)
108+
assert len(drows) == 2
109+
assert drows[0].model == 'projects.project'
110+
assert drows[0].row_id == project_1.id
111+
assert drows[0].data['fields']['title'] == project_1.title
112+
assert drows[0].data['fields']['organization'] == organization_1.id
113+
assert drows[0].reason == 'reason'
114+
assert drows[0].organization_id == organization_1.id
115+
assert drows[1].model == 'projects.project'
116+
assert drows[1].row_id == project_2.id
117+
assert drows[1].data['fields']['title'] == project_2.title
118+
assert drows[1].data['fields']['organization'] == organization_1.id
119+
assert drows[1].reason == 'reason'
120+
assert drows[1].organization_id == organization_1.id
121+
122+
task_1 = TaskFactory(project=project_1)
123+
task_2 = TaskFactory(project=project_1)
124+
drows = DeletedRow.bulk_serialize_and_create(
125+
Task.objects.all(), reason='reason', organization_id=organization_1.id, project_id=project_1.id
126+
)
127+
assert len(drows) == 2
128+
assert drows[0].model == 'tasks.task'
129+
assert drows[0].row_id == task_1.id
130+
assert drows[0].data['fields']['project'] == project_1.id
131+
assert drows[0].reason == 'reason'
132+
assert drows[0].organization_id == organization_1.id
133+
assert drows[0].project_id == project_1.id
134+
assert drows[1].model == 'tasks.task'
135+
assert drows[1].row_id == task_2.id
136+
assert drows[1].data['fields']['project'] == project_1.id
137+
assert drows[1].reason == 'reason'
138+
assert drows[1].organization_id == organization_1.id
139+
assert drows[1].project_id == project_1.id

label_studio/core/urls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
path('__lsa/', views.collect_metrics, name='collect_metrics'),
112112
re_path(r'^api-auth/', include('rest_framework.urls', namespace='rest_framework')),
113113
re_path(r'^', include('jwt_auth.urls')),
114+
re_path(r'^', include('session_policy.urls')),
114115
]
115116

116117
if settings.DEBUG:

label_studio/feature_flags.json

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,33 @@
136136
"version": 3,
137137
"deleted": false
138138
},
139+
"ff_all_fit_29_org_settings_page": {
140+
"key": "ff_all_fit_29_org_settings_page",
141+
"on": false,
142+
"prerequisites": [],
143+
"targets": [],
144+
"contextTargets": [],
145+
"rules": [],
146+
"fallthrough": {
147+
"variation": 0
148+
},
149+
"offVariation": 1,
150+
"variations": [
151+
true,
152+
false
153+
],
154+
"clientSideAvailability": {
155+
"usingMobileKey": false,
156+
"usingEnvironmentId": false
157+
},
158+
"clientSide": false,
159+
"salt": "109dbd50896a4840bf8fdbc6d8215a45",
160+
"trackEvents": false,
161+
"trackEventsFallthrough": false,
162+
"debugEventsUntilDate": null,
163+
"version": 2,
164+
"deleted": false
165+
},
139166
"ff_back_2004_async_review_24032022_short": {
140167
"key": "ff_back_2004_async_review_24032022_short",
141168
"on": false,

label_studio/jwt_auth/views.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ class LSAPITokenRotateView(TokenViewBase):
192192
authentication_classes = [JWTAuthentication, TokenAuthenticationPhaseout, SessionAuthentication]
193193
permission_classes = [IsAuthenticated]
194194
_serializer_class = 'jwt_auth.serializers.LSAPITokenRotateSerializer'
195+
token_class = LSAPIToken
195196

196197
@swagger_auto_schema(
197198
tags=['JWT'],
@@ -219,5 +220,9 @@ def post(self, request, *args, **kwargs):
219220
return Response({'detail': 'Token is invalid or already blacklisted.'}, status=status.HTTP_400_BAD_REQUEST)
220221

221222
# Create a new token for the user
222-
new_token = LSAPIToken.for_user(request.user)
223+
new_token = self.create_token(request.user)
223224
return Response({'refresh': new_token.get_full_jwt()}, status=status.HTTP_200_OK)
225+
226+
def create_token(self, user):
227+
"""Create a new token for the user. Can be overridden by child classes to use different token classes."""
228+
return self.token_class.for_user(user)

label_studio/session_policy/api.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from django.utils.decorators import method_decorator
2+
from drf_yasg.utils import swagger_auto_schema
3+
from rest_framework import generics
4+
from rest_framework.permissions import IsAuthenticated
5+
6+
from .models import SessionTimeoutPolicy
7+
from .serializers import SessionTimeoutPolicySerializer
8+
9+
10+
@method_decorator(
11+
name='get',
12+
decorator=swagger_auto_schema(
13+
tags=['Session Policy'],
14+
operation_summary='Retrieve Session Policy',
15+
operation_description='Retrieve session timeout policy for the currently active organization.',
16+
),
17+
)
18+
@method_decorator(
19+
name='patch',
20+
decorator=swagger_auto_schema(
21+
tags=['Session Policy'],
22+
operation_summary='Update Session Policy',
23+
operation_description='Update session timeout policy for the currently active organization.',
24+
),
25+
)
26+
class SessionTimeoutPolicyView(generics.RetrieveUpdateAPIView):
27+
"""
28+
API endpoint for retrieving and updating organization's session timeout policy
29+
"""
30+
31+
serializer_class = SessionTimeoutPolicySerializer
32+
permission_classes = [IsAuthenticated]
33+
http_method_names = ['get', 'patch'] # Explicitly specify allowed methods
34+
35+
def get_object(self):
36+
# Get the organization from the request
37+
org = self.request.user.active_organization
38+
# Get or create the session policy for the organization
39+
policy, _ = SessionTimeoutPolicy.objects.get_or_create(organization=org)
40+
return policy

0 commit comments

Comments
 (0)