Skip to content

Commit 37e6998

Browse files
committed
fix: reap expired skill uploads
1 parent be71e3c commit 37e6998

2 files changed

Lines changed: 192 additions & 0 deletions

File tree

tests/unit/test_skill_service.py

Lines changed: 122 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import os
77
import uuid
88
from contextlib import asynccontextmanager
9+
from datetime import UTC, datetime, timedelta
910

1011
import pytest
1112
from dotenv import dotenv_values
@@ -790,6 +791,66 @@ async def test_create_draft_upload_rejects_invalid_normalized_content_type(
790791
).scalars()
791792
assert upload_rows.all() == []
792793

794+
async def test_create_draft_upload_reaps_expired_incomplete_uploads(
    self,
    skill_service: SkillService,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Opening a new upload session reaps an older, expired, incomplete one.

    The stale session's row must be gone and its staged object must have been
    passed to ``blob.delete_file`` with the skills bucket.
    """
    stale_payload = b"stale upload"
    fresh_payload = b"fresh upload"

    created = await skill_service.create_skill(SkillCreate(name="reap-uploads"))

    stale_params = SkillUploadSessionCreate(
        sha256=hashlib.sha256(stale_payload).hexdigest(),
        size_bytes=len(stale_payload),
        content_type="text/plain; charset=utf-8",
    )
    stale_upload = await skill_service.create_draft_upload(
        skill_id=created.id,
        params=stale_params,
    )

    # Force the first session into the past so the next call treats it as expired.
    stale_upload_row = await skill_service.session.scalar(
        select(SkillUploadModel).where(SkillUploadModel.id == stale_upload.upload_id)
    )
    assert stale_upload_row is not None
    one_minute_ago = datetime.now(UTC) - timedelta(minutes=1)
    stale_upload_row.expires_at = one_minute_ago
    skill_service.session.add(stale_upload_row)
    await skill_service.session.commit()

    # Record the arguments of the blob deletion instead of touching storage.
    deleted: dict[str, str] = {}

    async def fake_delete_file(*, key: str, bucket: str) -> None:
        deleted.update(key=key, bucket=bucket)

    monkeypatch.setattr(
        "tracecat.agent.skill.service.blob.delete_file",
        fake_delete_file,
    )

    fresh_params = SkillUploadSessionCreate(
        sha256=hashlib.sha256(fresh_payload).hexdigest(),
        size_bytes=len(fresh_payload),
        content_type="text/plain; charset=utf-8",
    )
    fresh_upload = await skill_service.create_draft_upload(
        skill_id=created.id,
        params=fresh_params,
    )

    assert fresh_upload.upload_id != stale_upload.upload_id

    expected_deletion = {
        "key": stale_upload.key,
        "bucket": config.TRACECAT__BLOB_STORAGE_BUCKET_SKILLS,
    }
    assert deleted == expected_deletion

    remaining_row = await skill_service.session.scalar(
        select(SkillUploadModel).where(SkillUploadModel.id == stale_upload.upload_id)
    )
    assert remaining_row is None
853+
793854
async def test_attach_uploaded_blob_rejects_size_mismatch(
794855
self,
795856
skill_service: SkillService,
@@ -855,6 +916,67 @@ async def fake_open_download_stream(*, key: str, bucket: str):
855916
)
856917
assert iterated is False
857918

919+
async def test_attach_uploaded_blob_deletes_expired_staged_key(
    self,
    skill_service: SkillService,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Attaching an expired upload fails with ``upload_expired`` and deletes its staged key."""
    payload = b"expired upload"

    created = await skill_service.create_skill(SkillCreate(name="expired-upload"))
    draft = await skill_service.get_draft(created.id)
    assert draft is not None

    upload_params = SkillUploadSessionCreate(
        sha256=hashlib.sha256(payload).hexdigest(),
        size_bytes=len(payload),
        content_type="text/plain; charset=utf-8",
    )
    upload = await skill_service.create_draft_upload(
        skill_id=created.id,
        params=upload_params,
    )

    # Backdate the session so the attach operation sees it as expired.
    upload_row = await skill_service.session.scalar(
        select(SkillUploadModel).where(SkillUploadModel.id == upload.upload_id)
    )
    assert upload_row is not None
    one_minute_ago = datetime.now(UTC) - timedelta(minutes=1)
    upload_row.expires_at = one_minute_ago
    skill_service.session.add(upload_row)
    await skill_service.session.commit()

    # Record the arguments of the blob deletion instead of touching storage.
    deleted: dict[str, str] = {}

    async def fake_delete_file(*, key: str, bucket: str) -> None:
        deleted.update(key=key, bucket=bucket)

    monkeypatch.setattr(
        "tracecat.agent.skill.service.blob.delete_file",
        fake_delete_file,
    )

    with pytest.raises(TracecatValidationError) as exc_info:
        attach_op = SkillDraftAttachUploadedBlobOp(
            path="references/uploaded.txt",
            upload_id=upload.upload_id,
        )
        patch = SkillDraftPatch(
            base_revision=draft.draft_revision,
            operations=[attach_op],
        )
        await skill_service.patch_draft(skill_id=created.id, params=patch)

    expected_detail = {
        "code": "upload_expired",
        "upload_id": str(upload.upload_id),
    }
    assert exc_info.value.detail == expected_detail

    expected_deletion = {
        "key": upload.key,
        "bucket": config.TRACECAT__BLOB_STORAGE_BUCKET_SKILLS,
    }
    assert deleted == expected_deletion
979+
858980
async def test_attach_uploaded_blob_stops_streaming_after_size_exceeded(
859981
self,
860982
skill_service: SkillService,

tracecat/agent/skill/service.py

Lines changed: 70 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -179,6 +179,20 @@ def _staged_upload_key_for(self, *, upload_id: uuid.UUID, sha256: str) -> str:
179179
normalized_sha256 = self._normalize_sha256(sha256)
180180
return f"skills/{self.workspace_id}/uploads/{upload_id}/{normalized_sha256}"
181181

182+
def _staged_upload_prefix(self) -> str:
183+
"""Return the storage-prefix used for staged upload objects."""
184+
185+
return f"skills/{self.workspace_id}/uploads/"
186+
187+
def _is_staged_upload_object(self, upload: SkillUploadModel) -> bool:
188+
"""Return whether an upload still points at a temporary staged object."""
189+
190+
return (
191+
upload.completed_at is None
192+
and upload.blob_id is None
193+
and upload.key.startswith(self._staged_upload_prefix())
194+
)
195+
182196
@staticmethod
183197
def _normalize_path(path: str) -> str:
184198
"""Normalize and validate a relative POSIX draft path.
@@ -478,6 +492,11 @@ async def _materialize_uploaded_blob(self, upload: SkillUploadModel) -> SkillBlo
478492
return blob_row
479493

480494
if upload.expires_at < datetime.now(UTC):
495+
if self._is_staged_upload_object(upload):
496+
await self._delete_staged_upload_object_best_effort(
497+
upload,
498+
reason="upload_expired",
499+
)
481500
raise TracecatValidationError(
482501
"Skill upload session has expired",
483502
detail={"code": "upload_expired", "upload_id": str(upload.id)},
@@ -551,6 +570,51 @@ async def _materialize_uploaded_blob(self, upload: SkillUploadModel) -> SkillBlo
551570
await self.session.flush()
552571
return blob_row
553572

573+
async def _delete_staged_upload_object_best_effort(
    self,
    upload: SkillUploadModel,
    *,
    reason: str,
) -> None:
    """Try to delete the staged object behind ``upload``; never raise.

    Args:
        upload: Upload row whose ``key`` and ``bucket`` identify the object.
        reason: Short tag recorded in the warning log if deletion fails.

    Any ``Exception`` raised by the deletion is logged as a warning and
    suppressed, so the caller's own flow is not interrupted.
    """
    try:
        await blob.delete_file(key=upload.key, bucket=upload.bucket)
    except Exception as delete_error:
        # Best-effort cleanup: report the failure and carry on.
        warn = self.logger.warning
        log_context = {
            "upload_id": str(upload.id),
            "key": upload.key,
            "bucket": upload.bucket,
            "reason": reason,
            "error": str(delete_error),
        }
        warn("Failed to delete staged skill upload object", **log_context)
592+
593+
async def _reap_expired_incomplete_uploads(self) -> list[SkillUploadModel]:
    """Delete this workspace's expired, incomplete upload rows.

    Selects rows with no ``completed_at`` whose ``expires_at`` is in the past,
    using ``FOR UPDATE SKIP LOCKED``, marks them for deletion and flushes the
    session. This method does not commit.

    Returns:
        The deleted rows that still pointed at a staged object (per
        ``_is_staged_upload_object``), or an empty list if nothing expired.
    """
    now = datetime.now(UTC)
    criteria = (
        SkillUploadModel.workspace_id == self.workspace_id,
        SkillUploadModel.completed_at.is_(None),
        SkillUploadModel.expires_at < now,
    )
    query = select(SkillUploadModel).where(*criteria).with_for_update(skip_locked=True)

    result = await self.session.execute(query)
    rows = result.scalars().all()
    if not rows:
        return []

    for row in rows:
        await self.session.delete(row)
    await self.session.flush()

    # Only rows that still reference a temporary object need blob cleanup.
    return list(filter(self._is_staged_upload_object, rows))
617+
554618
async def _list_draft_rows(
555619
self, skill_id: uuid.UUID
556620
) -> list[tuple[SkillDraftFile, SkillBlob]]:
@@ -1402,6 +1466,7 @@ async def create_draft_upload(
14021466
skill = await self.get_skill(skill_id)
14031467
if skill is None:
14041468
raise TracecatNotFoundError(f"Skill '{skill_id}' not found")
1469+
expired_uploads = await self._reap_expired_incomplete_uploads()
14051470

14061471
upload_id = uuid.uuid4()
14071472
expires_at = datetime.now(UTC) + timedelta(seconds=DEFAULT_UPLOAD_TTL_SECONDS)
@@ -1424,6 +1489,11 @@ async def create_draft_upload(
14241489
upload_row.id = upload_id
14251490
self.session.add(upload_row)
14261491
await self.session.commit()
1492+
for expired_upload in expired_uploads:
1493+
await self._delete_staged_upload_object_best_effort(
1494+
expired_upload,
1495+
reason="reap_expired_upload",
1496+
)
14271497
return SkillUploadSessionRead(
14281498
upload_id=upload_id,
14291499
upload_url=await blob.generate_presigned_upload_url(

0 commit comments

Comments
 (0)