|
6 | 6 | import os |
7 | 7 | import uuid |
8 | 8 | from contextlib import asynccontextmanager |
| 9 | +from datetime import UTC, datetime, timedelta |
9 | 10 |
|
10 | 11 | import pytest |
11 | 12 | from dotenv import dotenv_values |
@@ -790,6 +791,66 @@ async def test_create_draft_upload_rejects_invalid_normalized_content_type( |
790 | 791 | ).scalars() |
791 | 792 | assert upload_rows.all() == [] |
792 | 793 |
|
| 794 | + async def test_create_draft_upload_reaps_expired_incomplete_uploads( |
| 795 | + self, |
| 796 | + skill_service: SkillService, |
| 797 | + monkeypatch: pytest.MonkeyPatch, |
| 798 | + ) -> None: |
| 799 | + """Creating a new upload session should clean up older expired sessions.""" |
| 800 | + |
| 801 | + created = await skill_service.create_skill(SkillCreate(name="reap-uploads")) |
| 802 | + stale_upload = await skill_service.create_draft_upload( |
| 803 | + skill_id=created.id, |
| 804 | + params=SkillUploadSessionCreate( |
| 805 | + sha256=hashlib.sha256(b"stale upload").hexdigest(), |
| 806 | + size_bytes=len(b"stale upload"), |
| 807 | + content_type="text/plain; charset=utf-8", |
| 808 | + ), |
| 809 | + ) |
| 810 | + stale_upload_row = await skill_service.session.scalar( |
| 811 | + select(SkillUploadModel).where( |
| 812 | + SkillUploadModel.id == stale_upload.upload_id |
| 813 | + ) |
| 814 | + ) |
| 815 | + assert stale_upload_row is not None |
| 816 | + stale_upload_row.expires_at = datetime.now(UTC) - timedelta(minutes=1) |
| 817 | + skill_service.session.add(stale_upload_row) |
| 818 | + await skill_service.session.commit() |
| 819 | + |
| 820 | + deleted: dict[str, str] = {} |
| 821 | + |
| 822 | + async def fake_delete_file(*, key: str, bucket: str) -> None: |
| 823 | + deleted["key"] = key |
| 824 | + deleted["bucket"] = bucket |
| 825 | + |
| 826 | + monkeypatch.setattr( |
| 827 | + "tracecat.agent.skill.service.blob.delete_file", |
| 828 | + fake_delete_file, |
| 829 | + ) |
| 830 | + |
| 831 | + fresh_upload = await skill_service.create_draft_upload( |
| 832 | + skill_id=created.id, |
| 833 | + params=SkillUploadSessionCreate( |
| 834 | + sha256=hashlib.sha256(b"fresh upload").hexdigest(), |
| 835 | + size_bytes=len(b"fresh upload"), |
| 836 | + content_type="text/plain; charset=utf-8", |
| 837 | + ), |
| 838 | + ) |
| 839 | + |
| 840 | + assert fresh_upload.upload_id != stale_upload.upload_id |
| 841 | + assert deleted == { |
| 842 | + "key": stale_upload.key, |
| 843 | + "bucket": config.TRACECAT__BLOB_STORAGE_BUCKET_SKILLS, |
| 844 | + } |
| 845 | + assert ( |
| 846 | + await skill_service.session.scalar( |
| 847 | + select(SkillUploadModel).where( |
| 848 | + SkillUploadModel.id == stale_upload.upload_id |
| 849 | + ) |
| 850 | + ) |
| 851 | + is None |
| 852 | + ) |
| 853 | + |
793 | 854 | async def test_attach_uploaded_blob_rejects_size_mismatch( |
794 | 855 | self, |
795 | 856 | skill_service: SkillService, |
@@ -855,6 +916,67 @@ async def fake_open_download_stream(*, key: str, bucket: str): |
855 | 916 | ) |
856 | 917 | assert iterated is False |
857 | 918 |
|
| 919 | + async def test_attach_uploaded_blob_deletes_expired_staged_key( |
| 920 | + self, |
| 921 | + skill_service: SkillService, |
| 922 | + monkeypatch: pytest.MonkeyPatch, |
| 923 | + ) -> None: |
| 924 | + """Expired staged uploads should be deleted before the validation error returns.""" |
| 925 | + |
| 926 | + created = await skill_service.create_skill(SkillCreate(name="expired-upload")) |
| 927 | + draft = await skill_service.get_draft(created.id) |
| 928 | + assert draft is not None |
| 929 | + |
| 930 | + upload = await skill_service.create_draft_upload( |
| 931 | + skill_id=created.id, |
| 932 | + params=SkillUploadSessionCreate( |
| 933 | + sha256=hashlib.sha256(b"expired upload").hexdigest(), |
| 934 | + size_bytes=len(b"expired upload"), |
| 935 | + content_type="text/plain; charset=utf-8", |
| 936 | + ), |
| 937 | + ) |
| 938 | + upload_row = await skill_service.session.scalar( |
| 939 | + select(SkillUploadModel).where(SkillUploadModel.id == upload.upload_id) |
| 940 | + ) |
| 941 | + assert upload_row is not None |
| 942 | + upload_row.expires_at = datetime.now(UTC) - timedelta(minutes=1) |
| 943 | + skill_service.session.add(upload_row) |
| 944 | + await skill_service.session.commit() |
| 945 | + |
| 946 | + deleted: dict[str, str] = {} |
| 947 | + |
| 948 | + async def fake_delete_file(*, key: str, bucket: str) -> None: |
| 949 | + deleted["key"] = key |
| 950 | + deleted["bucket"] = bucket |
| 951 | + |
| 952 | + monkeypatch.setattr( |
| 953 | + "tracecat.agent.skill.service.blob.delete_file", |
| 954 | + fake_delete_file, |
| 955 | + ) |
| 956 | + |
| 957 | + with pytest.raises(TracecatValidationError) as exc_info: |
| 958 | + await skill_service.patch_draft( |
| 959 | + skill_id=created.id, |
| 960 | + params=SkillDraftPatch( |
| 961 | + base_revision=draft.draft_revision, |
| 962 | + operations=[ |
| 963 | + SkillDraftAttachUploadedBlobOp( |
| 964 | + path="references/uploaded.txt", |
| 965 | + upload_id=upload.upload_id, |
| 966 | + ) |
| 967 | + ], |
| 968 | + ), |
| 969 | + ) |
| 970 | + |
| 971 | + assert exc_info.value.detail == { |
| 972 | + "code": "upload_expired", |
| 973 | + "upload_id": str(upload.upload_id), |
| 974 | + } |
| 975 | + assert deleted == { |
| 976 | + "key": upload.key, |
| 977 | + "bucket": config.TRACECAT__BLOB_STORAGE_BUCKET_SKILLS, |
| 978 | + } |
| 979 | + |
858 | 980 | async def test_attach_uploaded_blob_stops_streaming_after_size_exceeded( |
859 | 981 | self, |
860 | 982 | skill_service: SkillService, |
|
0 commit comments