!
pip install segmentation-models albumentations opencv-python
import os
os.environ["SM_FRAMEWORK"] = "tf.keras"
# --- 1. Install dependencies ---
!pip install --quiet megatools nibabel segmentation-models
# --- 2. Download dataset from MEGA ---
import os
# Download megatools
if not os.path.exists("megatools"):
!apt-get install -y megatools
# 3. Download the MRI dataset ZIP file from MEGA
mri_url =
"https://mega.nz/file/O45BRBJD#CA8XAaACqlymX3MhcGkJzK8DNp8vZYoxxQW2pBcl4wM"
!megadl {mri_url} --path dataset.zip
# 4. Unzip the MRI dataset
!unzip -q dataset.zip -d dataset
# 5. download the clinical data from MEGA (if not using the CSV we already
prepared)
clinical_url = "https://mega.nz/file/v1RWUY7R#UIgjHQWMqC6BgtvrfeQxnr0sMR0Q79Ek37-
LFrqeYIU"
!megadl {clinical_url} --path clinical_data.csv
# --- INSTALL DEPENDENCIES ---
# Run this only in Colab or if not already installed
!pip install segmentation-models albumentations opencv-python nibabel
# --- SET ENV VARIABLES ---
import os
os.environ["SM_FRAMEWORK"] = "tf.keras"
# --- IMPORT LIBRARIES ---
import numpy as np
import nibabel as nib
import tensorflow as tf
from tensorflow.keras.applications import DenseNet201
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D
import segmentation_models as sm
import cv2
# --- FUNCTION: DENSENET FEATURE EXTRACTION ---
import functools


@functools.lru_cache(maxsize=1)
def _densenet_feature_extractor():
    """Build the DenseNet201 + global-average-pooling model once and cache it."""
    base_model = DenseNet201(weights="imagenet", include_top=False,
                             input_shape=(224, 224, 3))
    pooled = GlobalAveragePooling2D()(base_model.output)
    return Model(inputs=base_model.input, outputs=pooled)


def extract_mri_features_densenet(slice_img):
    """Return pooled DenseNet201 features for a single 2-D image slice.

    The slice is resized to 224x224, replicated to three channels, given a
    batch axis and passed through DenseNet's ``preprocess_input`` before
    prediction.

    Args:
        slice_img: 2-D image array accepted by ``cv2.resize``.

    Returns:
        The array returned by ``Model.predict`` for a batch of one image
        (one row of globally average-pooled features).
    """
    resized = cv2.resize(slice_img, (224, 224))
    img_rgb = np.stack([resized] * 3, axis=-1)  # Convert to 3 channels
    img_rgb = np.expand_dims(img_rgb, axis=0)  # Add batch dim
    img_rgb = tf.keras.applications.densenet.preprocess_input(img_rgb)
    # The model is built lazily and reused: constructing DenseNet201 and
    # loading its weights on every call is needlessly expensive.
    return _densenet_feature_extractor().predict(img_rgb)
# --- FUNCTION: U-NET LESION SEGMENTATION ---
import functools


@functools.lru_cache(maxsize=1)
def _lesion_unet():
    """Build and compile the single-channel U-Net once and cache it."""
    unet_model = sm.Unet('efficientnetb0', input_shape=(256, 256, 1),
                         encoder_weights=None, classes=1, activation='sigmoid')
    unet_model.compile(optimizer='adam', loss='binary_crossentropy')
    return unet_model


def segment_lesions_unet(slice_img):
    """Segment a single 2-D slice with the U-Net and count mask pixels.

    NOTE: the network is created with ``encoder_weights=None`` and no weights
    are loaded or trained here, so the mask is a placeholder until trained
    weights are supplied.

    Args:
        slice_img: 2-D image array accepted by ``cv2.resize``; values are
            divided by 255 before prediction.

    Returns:
        Tuple ``(lesion_volume, binary_mask)``: the number of pixels whose
        predicted probability exceeds 0.5, and the 256x256 ``uint8`` mask.
    """
    resized = cv2.resize(slice_img, (256, 256))
    # Add batch and channel axes -> (1, 256, 256, 1), scaled to [0, 1].
    resized = np.expand_dims(resized, axis=(0, -1)) / 255.0
    # Reuse one cached model instead of rebuilding/compiling it per call.
    prediction = _lesion_unet().predict(resized)[0, :, :, 0]
    binary_mask = (prediction > 0.5).astype(np.uint8)
    lesion_volume = np.sum(binary_mask)
    return lesion_volume, binary_mask
# --- DEMO USAGE WITH FAKE SLICE ---
# Uniform random noise scaled into the 8-bit range stands in for an MRI slice.
sample_mri_slice = (np.random.rand(256, 256) * 255).astype(np.uint8)

# Run both pipelines on the synthetic slice.
deep_features = extract_mri_features_densenet(sample_mri_slice)
lesion_volume, lesion_mask = segment_lesions_unet(sample_mri_slice)

# Report what came back.
print("DenseNet features shape:", deep_features.shape)
print("Lesion volume (pixels):", lesion_volume)
import os
import nibabel as nib
import numpy as np
import tensorflow as tf
import cv2
from tensorflow.keras.applications import DenseNet201
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D
import segmentation_models as sm
from tqdm import tqdm
# Setup env and model
# NOTE(review): segmentation_models is already imported above this point, so
# this assignment presumably cannot influence that import; the explicit
# set_framework() call on the next line is what selects the backend — confirm.
os.environ["SM_FRAMEWORK"] = "tf.keras"
sm.set_framework('tf.keras')
# The return value is discarded here.
sm.framework()
# --- Build DenseNet model ---
# DenseNet201 with ImageNet weights and no classification head, followed by
# global average pooling. Built once at module level and reused by
# extract_densenet_features().
base_model = DenseNet201(weights="imagenet", include_top=False, input_shape=(224,
224, 3))
x = GlobalAveragePooling2D()(base_model.output)
densenet_model = Model(inputs=base_model.input, outputs=x)
# --- Build U-Net model (Untrained for now) ---
# encoder_weights=None and nothing in this section loads or fits weights, so
# predictions from this network are placeholders. Reused by segment_lesions().
unet_model = sm.Unet('efficientnetb0', input_shape=(256, 256, 1),
encoder_weights=None, classes=1, activation='sigmoid')
unet_model.compile(optimizer='adam', loss='binary_crossentropy')
# --- Helper functions ---
def extract_densenet_features(slice_img):
    """Run one 2-D slice through the module-level ``densenet_model``.

    The slice is resized to 224x224, copied into three channels, given a
    leading batch axis and normalised with DenseNet's ``preprocess_input``.
    Returns whatever ``densenet_model.predict`` yields for that batch of one.
    """
    scaled = cv2.resize(slice_img, (224, 224))
    # (224, 224) -> (224, 224, 3) -> (1, 224, 224, 3)
    batch = np.stack((scaled, scaled, scaled), axis=-1)[np.newaxis, ...]
    batch = tf.keras.applications.densenet.preprocess_input(batch)
    return densenet_model.predict(batch)
def segment_lesions(slice_img):
    """Count pixels the module-level ``unet_model`` marks above 0.5.

    The slice is resized to 256x256, wrapped with batch and channel axes,
    divided by 255 and fed to the network; the thresholded mask is summed.
    """
    # (256, 256) -> (1, 256, 256, 1), scaled to [0, 1]
    net_input = cv2.resize(slice_img, (256, 256))[np.newaxis, ..., np.newaxis] / 255.0
    probabilities = unet_model.predict(net_input)[0, :, :, 0]
    return np.sum((probabilities > 0.5).astype(np.uint8))
# --- Process all patients ---
# For every patient directory under root_dir: pick one NIfTI volume, run each
# non-empty axial slice through the DenseNet extractor and the U-Net, and keep
# the slice-averaged feature vector plus the summed mask pixel count.
root_dir = "dataset"  # Replace with your actual path
results = []

# os.listdir() returns entries in arbitrary order; sorting makes both the
# patient order and the scan chosen per patient reproducible.
for patient_id in tqdm(sorted(os.listdir(root_dir))):
    patient_path = os.path.join(root_dir, patient_id)
    if not os.path.isdir(patient_path):
        continue

    # Find MRI file (e.g., FLAIR.nii.gz or T2.nii.gz): first match in name order.
    mri_file = next(
        (
            os.path.join(patient_path, f)
            for f in sorted(os.listdir(patient_path))
            if f.endswith((".nii", ".nii.gz"))
        ),
        None,
    )
    if not mri_file:
        print(f"No MRI found for {patient_id}")
        continue

    # Load MRI volume
    img = nib.load(mri_file).get_fdata()
    if img.ndim != 3:
        # Slicing below assumes (rows, cols, slices); anything else would
        # crash on indexing or hand the models a wrongly shaped array.
        print(f"Skipping {patient_id}: expected a 3-D volume, got shape {img.shape}")
        continue

    num_slices = img.shape[2]
    patient_features = []
    total_lesion_volume = 0

    for i in range(num_slices):
        slice_img = img[:, :, i]
        slice_max = np.max(slice_img)
        # Skip empty slices. `not (max > 0)` also rejects all-negative and
        # NaN maxima, which cannot be scaled into 0..255.
        if not slice_max > 0:
            continue
        # Scale to 8-bit; clip first so negative voxels become 0 instead of
        # wrapping around in the uint8 cast.
        slice_img = (np.clip(slice_img / slice_max, 0.0, 1.0) * 255).astype(np.uint8)

        # Extract DenseNet features
        features = extract_densenet_features(slice_img)
        patient_features.append(features[0])

        # Segment lesion and count pixels
        lesion_pixels = segment_lesions(slice_img)
        total_lesion_volume += lesion_pixels

    # Aggregate features (mean of slices)
    if patient_features:
        aggregated_features = np.mean(np.array(patient_features), axis=0)
    else:
        # No usable slice: fall back to a zero vector of the feature length.
        aggregated_features = np.zeros((1920,))

    results.append({
        "patient_id": patient_id,
        "features": aggregated_features,
        "lesion_volume": total_lesion_volume
    })
# --- Convert to DataFrame or Save ---
import pandas as pd

# One row per patient: the two scalar fields first, then every element of the
# feature vector in its own column named f0, f1, ...
df = pd.DataFrame(
    [
        dict(
            patient_id=r["patient_id"],
            lesion_volume=r["lesion_volume"],
            **{f"f{i}": v for i, v in enumerate(r["features"])},
        )
        for r in results
    ]
)

# Save to CSV
df.to_csv("extracted_features_per_patient.csv", index=False)
print("✅ Feature extraction complete. Saved to extracted_features_per_patient.csv")
import os
import tarfile
import requests
from tqdm import tqdm
def download_ixi(output_dir="ixi_data"):
    """Download the IXI T1 and T2 archives and unpack them into *output_dir*.

    Each tar file is streamed to disk, extracted in place and then deleted.

    Args:
        output_dir: Directory that receives the extracted files; created if
            it does not exist.

    Raises:
        requests.HTTPError: If the server answers with an error status, so an
            error page is never written to disk and handed to ``tarfile``.
    """
    os.makedirs(output_dir, exist_ok=True)
    urls = [
        "http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-T1.tar",
        "http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-T2.tar",
    ]
    for url in urls:
        archive_name = url.split('/')[-1]
        print(f"Downloading {archive_name}...")
        tar_path = os.path.join(output_dir, archive_name)

        # The context manager releases the connection even if writing fails;
        # the timeout keeps a stalled server from hanging the run forever.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            # Content-Length may be absent; tqdm then shows a running count.
            total_bytes = int(response.headers.get("content-length", 0)) or None
            with open(tar_path, 'wb') as f, tqdm(
                total=total_bytes, unit="B", unit_scale=True
            ) as progress:
                # 1 MiB chunks: 1 KiB chunks mean millions of Python-level
                # iterations for archives of this size.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
                        progress.update(len(chunk))

        print("Extracting...")
        with tarfile.open(tar_path) as tar:
            if hasattr(tarfile, "data_filter"):
                # Refuse members that would escape output_dir (absolute
                # paths, "..", unsafe links) where the interpreter supports it.
                tar.extractall(path=output_dir, filter="data")
            else:
                tar.extractall(path=output_dir)
        os.remove(tar_path)  # Clean up


download_ixi()
import pandas as pd
import nibabel as nib
import numpy as np
from tqdm import tqdm
def organize_ixi_data(data_dir="ixi_data"):
    """Index the ``.nii.gz`` files in *data_dir* and write a metadata CSV.

    The subject ID is the part of the file name before the first ``-``.
    One row is produced per subject with these columns:

    - ``id``: subject ID
    - ``t1`` / ``t2``: path of the file whose name contains ``T1`` / ``T2``
      (``None`` when absent)
    - ``flair``: the T1 path with ``T1`` replaced by ``FLAIR`` in the file
      name; the file itself is not created here
    - ``age`` / ``gender``: simulated values drawn from ``np.random``
    - ``label``: always 0 (healthy control)

    Args:
        data_dir: Directory holding the extracted IXI files; the CSV is
            written there as ``ixi_metadata.csv``.

    Returns:
        The metadata as a ``pandas.DataFrame``.
    """
    # Single pass over a sorted listing: group file names by exact subject ID.
    # This avoids re-listing the directory per subject, prevents one ID from
    # matching another that merely shares its prefix, and gives a
    # deterministic row order.
    scans_by_subject = {}
    for filename in sorted(os.listdir(data_dir)):
        if not filename.endswith(".nii.gz"):
            continue
        subject_id = filename.split('-')[0]
        scans = scans_by_subject.setdefault(subject_id, {'t1': None, 't2': None})
        if 'T1' in filename:
            scans['t1'] = filename
        elif 'T2' in filename:
            scans['t2'] = filename

    metadata = []
    for subject_id, scans in tqdm(scans_by_subject.items(), desc="Organizing IXI data"):
        t1_name, t2_name = scans['t1'], scans['t2']
        metadata.append({
            'id': subject_id,
            't1': os.path.join(data_dir, t1_name) if t1_name else None,
            't2': os.path.join(data_dir, t2_name) if t2_name else None,
            # Synthetic FLAIR path (even if the file doesn't exist). Only the
            # file name is rewritten so a "T1" inside data_dir is untouched.
            'flair': (
                os.path.join(data_dir, t1_name.replace('T1', 'FLAIR'))
                if t1_name else None
            ),
            'age': np.random.randint(20, 80),  # Simulated age
            'gender': np.random.choice(['M', 'F']),  # Simulated gender
            'label': 0,  # Healthy control
        })

    # Convert to DataFrame and save
    df = pd.DataFrame(metadata)
    df.to_csv(os.path.join(data_dir, "ixi_metadata.csv"), index=False)
    print(f"Created metadata for {len(df)} subjects")
    return df


# Run the organization
ixi_metadata = organize_ixi_data()
def preprocess_ixi(data_dir="ixi_data"):
    """Preprocess every T1/T2 scan listed in ``ixi_metadata.csv``.

    For each available modality the scan is:

    1. loaded with nibabel,
    2. masked by zeroing voxels at or below its 10th intensity percentile
       (a crude stand-in for skull stripping),
    3. divided by the median of voxels between the 85th and 95th percentile
       (treated here as the white-matter peak),
    4. resampled with linear interpolation to a fixed 182x218x182 grid,
    5. saved as ``processed/<id>_<modality>.nii.gz`` under *data_dir*.

    Failures are reported per scan and do not stop the run.

    Args:
        data_dir: Directory containing ``ixi_metadata.csv``; output goes to
            its ``processed`` sub-directory.
    """
    from scipy.ndimage import zoom
    import warnings

    # Suppress nibabel warnings
    warnings.filterwarnings('ignore', category=UserWarning)

    # Load metadata
    metadata_path = os.path.join(data_dir, "ixi_metadata.csv")
    df = pd.read_csv(metadata_path)

    # Create processed directory
    processed_dir = os.path.join(data_dir, "processed")
    os.makedirs(processed_dir, exist_ok=True)

    # Target grid shape (the dimensions of the 1 mm MNI152 template). Only the
    # array shape is fixed; the resulting voxel size depends on each scan's
    # field of view.
    target_shape = (182, 218, 182)

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Preprocessing"):
        for modality in ['t1', 't2']:
            if pd.isna(row[modality]):
                continue
            # try/except per modality so a bad T1 does not also skip the T2.
            try:
                # Load image
                img = nib.load(row[modality])
                data = img.get_fdata()
                affine = img.affine

                # 1. Simple skull-stripping simulation
                mask = data > np.percentile(data, 10)
                data = data * mask

                # 2. White matter peak normalization
                wm_mask = (data > np.percentile(data, 85)) & (
                    data < np.percentile(data, 95)
                )
                if wm_mask.sum() > 0:  # Skip when no voxel falls in the band
                    data = data / np.median(data[wm_mask])

                # 3. Resampling (if needed)
                if data.shape != target_shape:
                    src_shape = np.array(data.shape, dtype=float)
                    zoom_factors = [t / s for t, s in zip(target_shape, data.shape)]
                    data = zoom(data, zoom_factors, order=1)
                    # zoom() (default grid_mode=False) aligns the first and
                    # last samples of each axis, so output index j maps to
                    # input index j * (n_in - 1) / (n_out - 1). Scale the
                    # affine's voxel axes accordingly so the saved file still
                    # maps voxels to the right world coordinates.
                    index_scale = (src_shape - 1) / (
                        np.array(data.shape, dtype=float) - 1
                    )
                    affine = affine.copy()
                    affine[:3, :3] = affine[:3, :3] * index_scale

                # Save processed
                out_path = os.path.join(
                    processed_dir, f"{row['id']}_{modality}.nii.gz"
                )
                nib.save(nib.Nifti1Image(data, affine), out_path)
            except Exception as e:
                print(f"Error processing {row['id']}: {str(e)}")

    print("Preprocessing complete. Processed files saved in:", processed_dir)


# Run preprocessing
preprocess_ixi()
def create_synthetic_flair(data_dir="ixi_data"):
    """Write a synthetic FLAIR volume for every subject with processed T1 and T2.

    The volume is a fixed blend, ``FLAIR ≈ 0.8*T1 + 0.2*T2`` (a simplified
    approximation), clipped to ``[0, 1]`` and saved as
    ``processed/<id>_flair.nii.gz`` with the T1 image's affine. Subjects
    missing either processed input are skipped; errors are reported per
    subject and do not stop the run.

    Args:
        data_dir: Directory containing ``ixi_metadata.csv`` and the
            ``processed`` sub-directory.
    """
    processed_dir = os.path.join(data_dir, "processed")
    df = pd.read_csv(os.path.join(data_dir, "ixi_metadata.csv"))

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Creating synthetic FLAIR"):
        try:
            # Locate processed T1 and T2
            t1_path = os.path.join(processed_dir, f"{row['id']}_t1.nii.gz")
            t2_path = os.path.join(processed_dir, f"{row['id']}_t2.nii.gz")
            if not (os.path.exists(t1_path) and os.path.exists(t2_path)):
                continue

            # Load the T1 image once; both its data and its affine are used.
            t1_img = nib.load(t1_path)
            t1 = t1_img.get_fdata()
            t2 = nib.load(t2_path).get_fdata()

            # Simple fusion (adjust weights as needed)
            flair = 0.8 * t1 + 0.2 * t2
            flair = np.clip(flair, 0, 1)  # Ensure valid intensity range

            # Save synthetic FLAIR
            out_path = os.path.join(processed_dir, f"{row['id']}_flair.nii.gz")
            nib.save(nib.Nifti1Image(flair, t1_img.affine), out_path)
        except Exception as e:
            print(f"Error creating FLAIR for {row['id']}: {str(e)}")

    print("Synthetic FLAIR generation complete")


# Run if you need FLAIR scans
create_synthetic_flair()
def verify_processed_data(data_dir="ixi_data/processed"):
    """Quality-check processed scans and rewrite the metadata CSV.

    For each subject in ``ixi_metadata.csv`` (in the parent of *data_dir*),
    every existing ``<id>_{t1,t2,flair}.nii.gz`` is loaded and rejected if it
    contains NaNs or has no positive voxel. Subjects with at least two valid
    modalities get a PNG of the middle slices in ``quality_reports`` and are
    kept; all other subjects are dropped from the CSV, which is overwritten
    in place.

    Args:
        data_dir: Directory containing the processed NIfTI files.

    Returns:
        ``pandas.DataFrame`` with only the subjects that passed.
    """
    import matplotlib.pyplot as plt

    metadata_path = os.path.join(os.path.dirname(data_dir), "ixi_metadata.csv")
    df = pd.read_csv(metadata_path)

    # Create quality report directory
    report_dir = os.path.join(os.path.dirname(data_dir), "quality_reports")
    os.makedirs(report_dir, exist_ok=True)

    good_scans = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Quality checking"):
        try:
            # Validate each modality that exists and keep its middle slice so
            # the volume does not have to be loaded again for plotting.
            mid_slices = {}
            for mod in ['t1', 't2', 'flair']:
                mod_path = os.path.join(data_dir, f"{row['id']}_{mod}.nii.gz")
                if not os.path.exists(mod_path):
                    continue
                data = nib.load(mod_path).get_fdata()
                # Explicit raises instead of assert: assertions are stripped
                # under `python -O`, which would let bad scans pass.
                if np.isnan(data).any():
                    raise ValueError("NaN values found")
                if data.max() <= 0:
                    raise ValueError("Empty scan")
                # Copy so the full volume can be freed before the next load.
                mid_slices[mod] = data[:, :, data.shape[2] // 2].copy()

            # NOTE(review): assumes a subject passes only when at least two
            # modalities are valid (the original nesting of the append below
            # was ambiguous) — confirm.
            if len(mid_slices) >= 2:  # At least T1+T2
                fig, axes = plt.subplots(1, len(mid_slices), figsize=(15, 5))
                try:
                    for ax, (mod, mid_slice) in zip(axes, mid_slices.items()):
                        ax.imshow(mid_slice, cmap='gray')
                        ax.set_title(f"{row['id']} {mod.upper()}")
                    fig.savefig(os.path.join(report_dir, f"{row['id']}_qc.png"))
                finally:
                    # Always release the figure, even if plotting/saving fails.
                    plt.close(fig)
                good_scans.append(row['id'])
        except Exception as e:
            print(f"Quality check failed for {row['id']}: {str(e)}")

    # Update metadata with only good scans
    clean_df = df[df['id'].isin(good_scans)]
    clean_df.to_csv(metadata_path, index=False)
    print(f"Quality check complete. {len(good_scans)}/{len(df)} scans passed")
    return clean_df


# Run quality check
verified_metadata = verify_processed_data()