"""Fabric Defect Detection - Simplified
Beginner-Friendly Code

This document contains a simplified PyTorch pipeline for fabric defect
detection using transfer learning (ResNet-18). It is designed to be
beginner-friendly and easy to present for academic evaluation.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
# ------------------ Data Preparation ------------------
# The backbone used below is pretrained on ImageNet, so images are normalized
# with the ImageNet per-channel mean/std it was trained with. Without this the
# pretrained filters see inputs on a different scale than they expect.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Load dataset: one sub-folder per class (e.g. train/good, train/oil,
# train/thread). ImageFolder derives label indices from the folder names.
train_data = datasets.ImageFolder("data/train", transform=transform)
test_data = datasets.ImageFolder("data/test", transform=transform)

# Each ImageFolder numbers its classes independently, so a class folder that is
# missing from (or extra in) data/test can shift label indices and make the
# reported accuracy meaningless. Fail loudly if any test class would be mapped
# to a different index than the same class in the training set.
mismatched_classes = sorted(
    name
    for name, idx in test_data.class_to_idx.items()
    if train_data.class_to_idx.get(name) != idx
)
if mismatched_classes:
    raise ValueError(
        "Class folders in data/test do not match data/train "
        f"(mismatched: {mismatched_classes}); "
        f"train classes: {train_data.classes}, test classes: {test_data.classes}"
    )

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
# ------------------ Model Setup ------------------
# Transfer learning: start from an ImageNet-pretrained ResNet-18 and replace its
# final fully-connected layer with one that has one output per fabric class.
weights_enum = getattr(models, "ResNet18_Weights", None)
if weights_enum is not None:
    # torchvision >= 0.13: the `weights` enum replaces the deprecated
    # `pretrained=True` flag (DEFAULT is the same ImageNet checkpoint).
    model = models.resnet18(weights=weights_enum.DEFAULT)
else:
    # Older torchvision without the weights enums.
    model = models.resnet18(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(train_data.classes))
# ------------------ Training Setup ------------------
LEARNING_RATE = 1e-3  # Adam step size; all model parameters are optimized

# Cross-entropy on the raw class scores (logits) for multi-class classification.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# ------------------ Training Loop ------------------
NUM_EPOCHS = 5  # kept small for a quick demo

for epoch in range(NUM_EPOCHS):
    model.train()  # put layers such as BatchNorm into training mode
    running_loss, seen = 0.0, 0
    for images, labels in train_loader:
        optimizer.zero_grad()  # clear gradients left over from the previous step
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss() averages over the batch by default, so weight by
        # the batch size to get a true per-sample mean even when the final
        # batch is smaller than the others.
        running_loss += loss.item() * labels.size(0)
        seen += labels.size(0)
    # Report the mean loss over the whole epoch rather than only the last
    # batch's loss (which is noisy, and undefined if the loader is empty).
    epoch_loss = running_loss / seen if seen else float("nan")
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")
# ------------------ Evaluation ------------------
# Measure top-1 accuracy on the held-out test set.
model.eval()  # switch layers such as BatchNorm/Dropout to inference behaviour
correct, total = 0, 0
with torch.no_grad():  # gradients are not needed for evaluation
    for images, labels in test_loader:
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # highest-scoring class for each image
        correct += int((predicted == labels).sum())
        total += len(labels)
accuracy_pct = 100 * correct / total
print(f"Accuracy: {accuracy_pct:.2f}%")