import streamlit as st
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
import joblib
import json
import os
import pandas as pd
from io import StringIO
import time
from sklearn.inspection import permutation_importance
# === Page Setup ===
# Configure the browser tab (title, icon) and the base page layout.
# The call is closed explicitly here: without the closing parenthesis the
# statement runs into the CSS injection that follows and the script fails
# to parse.
st.set_page_config(
    layout="wide",
    page_title="FloraSpec AI | Iris Classification",
    page_icon="🌸",
    initial_sidebar_state="expanded",
)
# === Professional CSS Styling ===
# Inject the app-wide stylesheet. Every rule block is closed with its own
# "}" and the Google Fonts @import URL is kept on a single line; an unclosed
# rule or a line break inside the quoted URL makes the browser discard the
# declarations that follow. No blank lines are used inside <style> so the
# Markdown parser keeps the whole element as one HTML block.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap');
:root {
    --primary: #6366f1;
    --primary-dark: #4f46e5;
    --secondary: #10b981;
    --dark: #0f172a;
    --darker: #020617;
    --light: #f8fafc;
    --text-primary: #e2e8f0;
    --text-secondary: #94a3b8;
    --success: #10b981;
    --warning: #f59e0b;
    --danger: #ef4444;
}
html, body, [class*="css"] {
    font-family: 'Inter', sans-serif;
    color: var(--text-primary);
    background-color: var(--dark);
    line-height: 1.6;
}
h1, h2, h3, h4, h5, h6 {
    font-family: 'Inter', sans-serif;
    font-weight: 600;
    color: var(--light);
    margin-top: 0;
}
.stMarkdown p, .stMarkdown li {
    color: var(--text-primary) !important;
}
code, pre, .monospace {
    font-family: 'JetBrains Mono', monospace !important;
    color: var(--secondary) !important;
    background: rgba(15, 23, 42, 0.7) !important;
    padding: 0.2rem 0.4rem !important;
    border-radius: 4px !important;
}
.stApp {
    background: linear-gradient(135deg, var(--darker) 0%, var(--dark) 100%);
}
.stButton > button {
    background: linear-gradient(135deg, var(--primary) 0%, var(--primary-dark) 100%);
    color: white;
    border: none;
    border-radius: 6px;
    padding: 0.5rem 1.5rem;
    font-weight: 500;
    transition: all 0.2s ease;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
.stButton > button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 6px rgba(0,0,0,0.15);
}
.glass-panel {
    background: rgba(15, 23, 42, 0.7);
    backdrop-filter: blur(12px);
    -webkit-backdrop-filter: blur(12px);
    border-radius: 12px;
    border: 1px solid rgba(255,255,255,0.08);
    box-shadow: 0 8px 32px rgba(0,0,0,0.2);
    padding: 1.5rem;
    margin-bottom: 1.5rem;
}
.metric-value {
    font-size: 1.2rem;
    font-weight: 600;
    color: var(--light);
}
.metric-label {
    font-size: 0.9rem;
    color: var(--text-secondary);
}
.block-container {
    padding: 2rem 3rem !important;
}
.confidence-bar {
    height: 8px;
    border-radius: 4px;
    background: rgba(255, 255, 255, 0.1);
    margin-top: 4px;
    overflow: hidden;
}
.confidence-fill {
    height: 100%;
    border-radius: 4px;
    background: linear-gradient(90deg, var(--primary), var(--primary-dark));
}
/* Custom scrollbar */
::-webkit-scrollbar {
    width: 8px;
    height: 8px;
}
::-webkit-scrollbar-track {
    background: rgba(255,255,255,0.05);
}
::-webkit-scrollbar-thumb {
    background: var(--primary);
    border-radius: 4px;
}
/* Responsive adjustments */
@media screen and (max-width: 768px) {
    .block-container {
        padding: 1rem !important;
    }
    .stMarkdown h1 {
        font-size: 1.8rem !important;
    }
    .stSlider {
        width: 100% !important;
    }
}
/* Fix for confusion matrix labels */
.confusion-matrix .xtick, .confusion-matrix .ytick {
    font-size: 12px !important;
}
</style>
""", unsafe_allow_html=True)
# === Load Resources ===
@st.cache_resource
def load_resources():
    """Load and validate every artifact the app reads from ``models/``.

    Loads the label encoder, each model file that exists on disk, the
    min-max scaler, the evaluation-metrics JSON and the per-model confusion
    matrices. Each model is smoke-tested with one sample before it is
    registered.

    Returns:
        A 5-tuple ``(models, scaler, label_encoder, metrics, conf_matrices)``.
        On any failure an error is shown and five ``None`` values are
        returned instead.
    """
    failure = (None, None, None, None, None)
    try:
        # NOTE(review): joblib.load unpickles arbitrary objects; these files
        # must come from a trusted source.
        label_encoder = joblib.load("models/label_encoder.pkl")
        if not hasattr(label_encoder, 'classes_') or len(label_encoder.classes_) == 0:
            st.error("Label encoder is invalid - no classes found")
            return failure
        n_classes = len(label_encoder.classes_)

        model_files = {
            "Logistic Regression": "models/logistic_model.pkl",
            "Random Forest": "models/random_forest_model.pkl",
            "SVM": "models/svm_model.pkl",
            "Decision Tree": "models/decision_tree_model.pkl",
            "Neural Network": "models/ann_model.pkl",
        }
        # One sample is enough to check the output shape of each model.
        dummy_input = np.array([[5.1, 3.5, 1.4, 0.2]])
        models = {}
        for model_name, file_path in model_files.items():
            if not os.path.exists(file_path):
                continue
            try:
                model = joblib.load(file_path)
                prediction = model.predict(dummy_input)
                if model_name == "Neural Network":
                    # The ANN is expected to emit one score per class.
                    if prediction.shape[1] != n_classes:
                        st.warning("ANN model output shape doesn't match classes")
                        continue
                elif prediction.shape not in ((1,), (1, 1)):
                    # Only a warning: the model is still registered below.
                    st.warning(f"Model {model_name} returned unexpected output shape")
                models[model_name] = model
            except Exception as e:
                st.warning(f"Model {model_name} failed validation: {str(e)}")

        if not models:
            st.error("No valid models loaded")
            return failure

        scaler = joblib.load("models/minmax_scaler.pkl")
        with open("models/evaluation_metrics.json", "r", encoding="utf-8") as f:
            metrics = json.load(f)

        matrix_files = {
            "Logistic Regression": "conf_matrix_logistic_regression.npy",
            "Random Forest": "conf_matrix_random_forest.npy",
            "SVM": "conf_matrix_svm.npy",
            "Decision Tree": "conf_matrix_decision_tree.npy",
            "Neural Network": "conf_matrix_ann.npy",
        }
        conf_matrices = {}
        for model_name, file_name in matrix_files.items():
            file_path = os.path.join("models", file_name)
            if os.path.exists(file_path):
                matrix = np.load(file_path)
                # Keep only matrices whose row count matches the class count.
                if matrix.shape[0] == n_classes:
                    conf_matrices[model_name] = matrix

        return models, scaler, label_encoder, metrics, conf_matrices
    except Exception as e:
        st.error(f"Error loading resources: {str(e)}")
        return failure
# === 3D Flower Visualization ===
def render_3d_flower(species):
    """Render a decorative five-lobed 3D surface tinted for *species*.

    Args:
        species: Species name; matched case-insensitively against
            "setosa" and "versicolor". Any other value uses the amber
            colour scale.

    A warning is shown instead of the chart if rendering fails.
    """
    colorscales = {
        "setosa": [[0, '#6366f1'], [1, '#4f46e5']],
        "versicolor": [[0, '#10b981'], [1, '#059669']],
    }
    default_colorscale = [[0, '#f59e0b'], [1, '#d97706']]
    try:
        colorscale = colorscales.get(species.lower(), default_colorscale)

        # Spherical parametrisation whose radius oscillates five times
        # around the azimuth.
        u, v = np.meshgrid(np.linspace(0, 2*np.pi, 50), np.linspace(0, np.pi, 25))
        r = 1 + 0.3 * np.cos(5 * u)
        x = r * np.sin(v) * np.cos(u)
        y = r * np.sin(v) * np.sin(u)
        z = r * np.cos(v)

        fig = go.Figure(data=[
            go.Surface(
                x=x, y=y, z=z,
                colorscale=colorscale,
                showscale=False,
                opacity=0.9,
                lighting=dict(
                    ambient=0.7,
                    diffuse=0.7,
                    roughness=0.1,
                    specular=0.2
                )
            )
        ])
        fig.update_layout(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
                bgcolor='rgba(0,0,0,0)',
                camera=dict(
                    eye=dict(x=1.5, y=1.5, z=0.8)
                )
            ),
            margin=dict(l=0, r=0, b=0, t=0),
            height=350,
            paper_bgcolor='rgba(0,0,0,0)'
        )
        st.plotly_chart(fig, use_container_width=True, config={'displayModeBar': False})
    except Exception as e:
        st.warning(f"Could not render 3D flower: {str(e)}")
# === Feature Importance Visualization ===
def render_feature_importance(model, model_name):
    """Draw a horizontal bar chart of per-feature importance for *model*.

    Args:
        model: Fitted estimator. ``feature_importances_`` is used when
            present, otherwise the magnitude of ``coef_``.
        model_name: Display name; "Neural Network" takes a dedicated branch.

    An info message is shown when the model exposes neither attribute, and
    a warning when rendering fails.
    """
    try:
        features = ['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width']
        if model_name == "Neural Network":
            st.info("Feature importance for Neural Network (Permutation Importance)")
            # NOTE(review): these are hard-coded placeholder values, not a
            # computed permutation importance; the message above overstates
            # what is shown.
            importance = np.array([0.25, 0.15, 0.35, 0.25])  # Placeholder
        elif hasattr(model, 'feature_importances_'):
            importance = model.feature_importances_
        elif hasattr(model, 'coef_'):
            coef = model.coef_
            if hasattr(coef, "toarray"):
                # Sparse coefficient matrices are densified first.
                coef = coef.toarray()
            magnitudes = np.abs(np.asarray(coef))
            # A 2-D coef_ has one row per class (or class pair). Average the
            # magnitudes over all rows instead of showing only the first.
            importance = magnitudes.mean(axis=0) if magnitudes.ndim > 1 else magnitudes
        else:
            st.info("Feature importance not available for this model type")
            return

        fig = go.Figure(go.Bar(
            x=importance,
            y=features,
            orientation='h',
            marker_color='#6366f1'
        ))
        fig.update_layout(
            title=f'{model_name} Feature Importance',
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
            font=dict(color='white'),
            height=300,
            margin=dict(l=50, r=20, b=50, t=50)
        )
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        st.warning(f"Could not render feature importance: {str(e)}")
# === Confusion Matrix ===
def render_confusion_matrix(model_name):
    """Show the stored confusion matrix for *model_name* as a heatmap.

    Reads the module-level ``conf_matrices`` and ``label_encoder``. A warning
    is shown when no matrix is stored for the model, and an error when
    drawing fails.

    Args:
        model_name: Key into ``conf_matrices``; also used in the chart title.
    """
    fig = None
    try:
        if model_name not in conf_matrices:
            st.warning(f"No confusion matrix available for {model_name}")
            return

        cm = conf_matrices[model_name]
        labels = [label.split('-')[-1].capitalize() for label in label_encoder.classes_]

        # Apply the dark style only while this figure is built and rendered,
        # so matplotlib's global style is not changed.
        with plt.style.context('dark_background'):
            fig, ax = plt.subplots(figsize=(6, 5))
            cmap = sns.light_palette("#4f46e5", as_cmap=True)
            sns.heatmap(
                cm,
                annot=True,
                fmt='d',
                cmap=cmap,
                cbar=False,
                ax=ax,
                annot_kws={
                    "size": 14,
                    "color": "black"
                },
                linewidths=0.5,
                linecolor='#475569',
                xticklabels=labels,
                yticklabels=labels,
                square=True,
            )
            ax.set_title(f'{model_name} Confusion Matrix',
                         pad=20,
                         color='white',
                         fontsize=14,
                         fontweight='bold')
            ax.set_xlabel('Predicted Label',
                          color='white',
                          labelpad=10,
                          fontweight='bold')
            ax.set_ylabel('True Label',
                          color='white',
                          labelpad=10,
                          fontweight='bold')
            ax.tick_params(axis='both',
                           colors='white',
                           labelsize=12,
                           rotation=0)
            fig.patch.set_facecolor('#0f172a')
            ax.set_facecolor('#1e293b')
            st.pyplot(fig, clear_figure=True)
    except Exception as e:
        st.error(f"Error displaying confusion matrix: {str(e)}")
    finally:
        # pyplot keeps every figure it creates until it is closed. Close it
        # here so figures do not accumulate across reruns.
        if fig is not None:
            plt.close(fig)
# === Performance Visualization ===
def render_performance(metrics):
    """Plot a grouped bar chart comparing the models' evaluation metrics.

    Args:
        metrics: Mapping of model name to a mapping of metric name to value.
            The metrics read are 'Accuracy', 'Precision', 'Recall' and
            'F1-Score'; a missing metric is plotted as 0.

    An error is shown when rendering fails.
    """
    try:
        model_names = list(metrics.keys())
        metric_colors = {
            'Accuracy': '#6366f1',
            'Precision': '#10b981',
            'Recall': '#f59e0b',
            'F1-Score': '#8b5cf6',
        }

        fig = go.Figure()
        for metric, color in metric_colors.items():
            fig.add_trace(go.Bar(
                x=model_names,
                y=[metrics[name].get(metric, 0) for name in model_names],
                name=metric,
                marker_color=color,
                hovertemplate='%{x}<br>%{y:.3f}<extra></extra>'
            ))
        fig.update_layout(
            barmode='group',
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
            font=dict(color='white'),
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1,
                font=dict(size=12)
            ),
            margin=dict(l=0, r=0, b=0, t=40),
            height=400,
            hovermode="x unified"
        )
        fig.update_xaxes(tickangle=-30)
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        st.error(f"Error rendering performance metrics: {str(e)}")
# === Confidence Score Visualization ===
def render_confidence_scores(probabilities, predicted_class_idx):
    """Plot per-class confidence as horizontal bars, highlighting the winner.

    Reads the module-level ``label_encoder`` for the class names.

    Args:
        probabilities: Sequence of class probabilities in 0..1, or ``None``
            (in which case nothing is drawn).
        predicted_class_idx: Index of the predicted class; its bar gets the
            accent colour.

    A warning is shown when rendering fails.
    """
    try:
        if probabilities is None:
            return

        # Coerce to a float array: multiplying a plain list by 100 would
        # repeat the list rather than scale its values.
        probs = np.asarray(probabilities, dtype=float)
        classes = [label.split('-')[-1].capitalize() for label in label_encoder.classes_]
        colors = ['#6366f1' if i == predicted_class_idx else '#475569'
                  for i in range(len(classes))]

        fig = go.Figure(go.Bar(
            x=probs * 100,
            y=classes,
            orientation='h',
            marker_color=colors,
            text=[f"{p*100:.1f}%" for p in probs],
            textposition='auto',
            textfont=dict(color='white')
        ))
        fig.update_layout(
            title='Classification Confidence Scores',
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
            font=dict(color='white'),
            height=250,
            margin=dict(l=50, r=20, b=50, t=50),
            xaxis=dict(
                range=[0, 100],
                title='Confidence (%)',
                ticksuffix='%'
            ),
            yaxis=dict(
                autorange="reversed"
            )
        )
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        st.warning(f"Could not render confidence scores: {str(e)}")
# === Download Results ===
def get_prediction_download(prediction, probabilities, model_name, input_data,
                            class_labels=None):
    """Build the plain-text report offered by the download button.

    Args:
        prediction: Predicted class name as it should appear in the report.
        probabilities: Per-class probabilities in 0..1, or ``None`` to leave
            out the confidence line and the probability table.
        model_name: Name of the model that produced the prediction.
        input_data: 2-D sequence whose first row holds sepal length, sepal
            width, petal length and petal width, in that order.
        class_labels: Optional class names aligned with ``probabilities``.
            Defaults to the module-level ``label_encoder.classes_``.

    Returns:
        The report as a single string.
    """
    has_probabilities = probabilities is not None
    row = input_data[0]

    output = StringIO()
    output.write("FloraSpec AI Prediction Results\n")
    output.write("=" * 40 + "\n\n")
    output.write(f"Model Used: {model_name}\n")
    output.write(f"Predicted Class: {prediction}\n")
    if has_probabilities:
        predicted_idx = int(np.argmax(probabilities))
        output.write(f"Confidence Score: {probabilities[predicted_idx]*100:.1f}%\n\n")

    output.write("\nInput Features:\n")
    output.write(f"Sepal Length (cm): {row[0]}\n")
    output.write(f"Sepal Width (cm): {row[1]}\n")
    output.write(f"Petal Length (cm): {row[2]}\n")
    output.write(f"Petal Width (cm): {row[3]}\n\n")

    if has_probabilities:
        # Fall back to the global encoder only when no labels were supplied.
        labels = label_encoder.classes_ if class_labels is None else class_labels
        output.write("Class Probabilities:\n")
        for cls, prob in zip(labels, probabilities):
            # str() keeps the formatting working for non-string labels.
            output.write(f"{str(cls).split('-')[-1].capitalize()}: {prob*100:.1f}%\n")

    return output.getvalue()
# === ANN Prediction Handler ===
def predict_with_ann(model, input_data):
    """Run the neural network and decode its highest-scoring class.

    Reads the module-level ``label_encoder`` to turn the winning index into
    a class label.

    Args:
        model: Object whose ``predict`` returns one row of class scores per
            input row.
        input_data: Input passed straight to ``model.predict``.

    Returns:
        ``(predicted_class, probabilities, predicted_class_idx)``, or three
        ``None`` values after showing an error if anything fails.
    """
    try:
        class_scores = model.predict(input_data)[0]
        best_idx = np.argmax(class_scores)
        label = label_encoder.inverse_transform([best_idx])[0]
    except Exception as e:
        st.error(f"ANN prediction failed: {str(e)}")
        return None, None, None
    else:
        return label, class_scores, best_idx
# === Main Application ===
def _run_prediction(sepal_length, sepal_width, petal_length, petal_width, model_choice):
    """Scale the measurements, run the chosen model and package the result.

    Reads the module-level ``scaler``, ``models`` and ``label_encoder``.

    Returns:
        The dict stored in ``st.session_state.prediction_data``, or ``None``
        when the neural-network handler reported a failure (it has already
        shown the error).
    """
    input_data = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
    scaled_input = scaler.transform(input_data)
    model = models[model_choice]

    if model_choice == "Neural Network":
        predicted_class, probabilities, predicted_idx = predict_with_ann(model, scaled_input)
        if predicted_class is None:
            return None
    else:
        prediction = model.predict(scaled_input)
        predicted_class = label_encoder.inverse_transform(prediction.reshape(-1))[0]
        predicted_idx = np.where(label_encoder.classes_ == predicted_class)[0][0]
        probabilities = (
            model.predict_proba(scaled_input)[0] if hasattr(model, "predict_proba") else None
        )

    return {
        'input_data': input_data,
        'scaled_input': scaled_input,
        'model_choice': model_choice,
        'predicted_class': predicted_class,
        'probabilities': probabilities,
        'predicted_idx': predicted_idx,
    }


def _render_prediction_panel(prediction_data, display_class):
    """Render the result header, confidence chart, 3D flower and download button."""
    probabilities = prediction_data["probabilities"]
    predicted_idx = prediction_data["predicted_idx"]
    model_choice = prediction_data["model_choice"]

    species_key = display_class.lower()
    if species_key == 'setosa':
        heading_color = 'var(--primary)'
    elif species_key == 'versicolor':
        heading_color = 'var(--secondary)'
    else:
        heading_color = 'var(--warning)'

    confidence_html = ''
    if probabilities is not None:
        confidence_html = (
            '<p style="color: var(--text-secondary); margin: 0 0 1rem 0;">Confidence: '
            f'{probabilities[predicted_idx]*100:.1f}%</p>'
        )

    # The header is built as one line of HTML: a whitespace-only line in the
    # middle (which an empty confidence paragraph would leave behind) ends a
    # Markdown HTML block early.
    header_html = (
        '<div style="display: flex; justify-content: space-between; '
        'align-items: center; margin-bottom: 1.5rem;">'
        '<div>'
        '<h3 style="margin-top: 0;">Prediction Result</h3>'
        f'<h1 style="color: {heading_color}; margin: 0.5rem 0 1rem 0;">{display_class}</h1>'
        f'{confidence_html}'
        '</div>'
        '<div style="text-align: right;">'
        '<p style="color: var(--text-secondary); margin: 0;">Model Used</p>'
        f'<p style="font-weight: 500; margin: 0;">{model_choice}</p>'
        '</div>'
        '</div>'
    )

    st.markdown('<div class="glass-panel">', unsafe_allow_html=True)
    st.markdown(header_html, unsafe_allow_html=True)
    if probabilities is not None:
        render_confidence_scores(probabilities, predicted_idx)
    render_3d_flower(display_class)
    st.download_button(
        label="Download Prediction Results",
        data=get_prediction_download(
            display_class,
            probabilities,
            model_choice,
            prediction_data["input_data"]
        ),
        file_name=f"floraspec_prediction_{display_class.lower()}.txt",
        mime="text/plain",
        use_container_width=True
    )
    st.markdown("</div>", unsafe_allow_html=True)


def _render_analysis_panel(model_choice):
    """Render the feature-importance and confusion-matrix tabs for one model."""
    st.markdown('<div class="glass-panel">', unsafe_allow_html=True)
    st.markdown(f'<h3 style="margin-top: 0;">{model_choice} Analysis</h3>',
                unsafe_allow_html=True)
    tab1, tab2 = st.tabs(["Feature Importance", "Confusion Matrix"])
    with tab1:
        render_feature_importance(models[model_choice], model_choice)
    with tab2:
        render_confusion_matrix(model_choice)
    st.markdown("</div>", unsafe_allow_html=True)


def _render_performance_panel():
    """Render the metric comparison chart and one confusion-matrix tab per model."""
    st.markdown('<div class="glass-panel">', unsafe_allow_html=True)
    st.markdown('<h3 style="margin-top: 0;">Model Performance</h3>',
                unsafe_allow_html=True)
    render_performance(evaluation_metrics)
    st.markdown('<h4 style="margin-top: 2rem;">Model Confusion Matrices</h4>',
                unsafe_allow_html=True)
    tabs = st.tabs(list(models.keys()))
    for tab, model_name in zip(tabs, models):
        with tab:
            render_confusion_matrix(model_name)
    st.markdown("</div>", unsafe_allow_html=True)


def main():
    """Lay out the page: input controls on the left, results on the right.

    Reads the module-level ``models``, ``scaler``, ``label_encoder`` and
    ``evaluation_metrics``. The latest prediction is kept in
    ``st.session_state`` so it survives Streamlit reruns.
    """
    st.title("FloraSpec AI")
    st.markdown(
        '<p style="color: var(--text-secondary); font-size: 1.1rem; margin-bottom: 2rem;">'
        'FloraSpec: Advanced Iris Flower Classification System</p>',
        unsafe_allow_html=True
    )

    # Initialize session state
    if 'predict_clicked' not in st.session_state:
        st.session_state.predict_clicked = False
    if 'prediction_data' not in st.session_state:
        st.session_state.prediction_data = None

    col1, col2 = st.columns([1, 2], gap="large")

    with col1:
        with st.container():
            st.markdown('<div class="glass-panel">', unsafe_allow_html=True)
            st.markdown('<h3 style="margin-top: 0;">Input Parameters</h3>',
                        unsafe_allow_html=True)

            # Input sliders
            sepal_length = st.slider(
                "Sepal Length (cm)",
                min_value=4.0,
                max_value=8.0,
                value=5.1,
                step=0.1
            )
            sepal_width = st.slider(
                "Sepal Width (cm)",
                min_value=2.0,
                max_value=4.5,
                value=3.5,
                step=0.1
            )
            petal_length = st.slider(
                "Petal Length (cm)",
                min_value=1.0,
                max_value=7.0,
                value=1.4,
                step=0.1
            )
            petal_width = st.slider(
                "Petal Width (cm)",
                min_value=0.1,
                max_value=2.5,
                value=0.2,
                step=0.1
            )

            # Model selection
            model_choice = st.selectbox(
                "Model Selection",
                options=list(models.keys()),
                help="Select the machine learning model to use for prediction"
            )

            if st.button("Run Prediction", use_container_width=True, type="primary"):
                try:
                    result = _run_prediction(
                        sepal_length, sepal_width, petal_length, petal_width, model_choice
                    )
                except Exception as e:
                    st.error(f"Prediction failed: {e}")
                    result = None

                if result is None:
                    # A failed run must not leave the results column pointing
                    # at a half-filled prediction.
                    st.session_state.predict_clicked = False
                else:
                    st.session_state.prediction_data = result
                    st.session_state.predict_clicked = True

            st.markdown("</div>", unsafe_allow_html=True)

    with col2:
        prediction_data = st.session_state.prediction_data
        if st.session_state.get('predict_clicked', False) and prediction_data:
            display_class = prediction_data['predicted_class'].split('-')[-1].capitalize()
            with st.container():
                _render_prediction_panel(prediction_data, display_class)
            with st.container():
                _render_analysis_panel(prediction_data["model_choice"])

        # NOTE(review): the original nesting of this panel was lost. It needs
        # no prediction data, so it is rendered on every run -- confirm this
        # matches the intended layout.
        with st.container():
            _render_performance_panel()
if __name__ == "__main__":
    # These names are module globals read by main() and the render helpers.
    models, scaler, label_encoder, evaluation_metrics, conf_matrices = load_resources()
    # Identity checks avoid relying on how the loaded objects implement "==".
    required = (models, scaler, label_encoder, evaluation_metrics)
    if all(resource is not None for resource in required):
        main()