Skip to content

Commit 2f408c8

Browse files
authored
Adding Arabic Dates Classification using DL (#339)
* Arabic Dates Classification using DL #329 * Updating ArabicDatesClassifierModel.py * Update * Update for Frontend of ArabicDatesClassifierModel * reform for 'Lint and Format / lint-format' * Modify model * 2-space indentation formatting * new edit for ArabicDatesClassifierModel * Modified ArabicDatesClassifierModel (with kaggle model output) * Final edit for ArabicDatesClassifierModel
1 parent 4983a9a commit 2f408c8

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ myenv
2121
*.log
2222

2323
.env
24+
.vscode
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import io
2+
import pickle
3+
import sys
4+
import os
5+
import numpy as np
6+
from PIL import Image
7+
8+
import streamlit as st
9+
import torch
10+
from torch import nn
11+
from torchvision import models, transforms
12+
13+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")))
14+
from src.helpers.kaggle import downloadNotebookOutput
15+
16+
def arabicDatesClassifierModel():
17+
# APP HEADER
18+
st.title("Arabic Dates Classification 🍂")
19+
st.markdown(
20+
"This pretrained model classifies an image of **Arabic Dates** into one of the 9 varieties commonly found in the Arabian region."
21+
)
22+
23+
# LOAD MODEL
24+
@st.cache_resource
25+
def load_model_data():
26+
try:
27+
downloadNotebookOutput("supratikbhowal", "arabic-dates-classification", "notebook")
28+
29+
PICKLE_SAVE_PATH = "notebook/arabic_dates_classnames.pkl"
30+
with open(PICKLE_SAVE_PATH, "rb") as f:
31+
class_names = pickle.load(f)
32+
33+
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
34+
model.fc = nn.Sequential(
35+
nn.Linear(model.fc.in_features, 512),
36+
nn.ReLU(),
37+
nn.Dropout(0.4),
38+
nn.Linear(512, len(class_names))
39+
)
40+
41+
MODEL_SAVE_PATH = "notebook/arabic_dates_model.pth"
42+
checkpoint = torch.load(MODEL_SAVE_PATH, map_location=torch.device("cpu"))
43+
model.load_state_dict(checkpoint["model_state_dict"], strict=False)
44+
model.eval()
45+
46+
return model, class_names
47+
48+
except Exception as e:
49+
st.error(f"🚨 Failed to load model or data: {e}")
50+
st.stop()
51+
52+
model, CLASS_NAMES = load_model_data()
53+
st.success("✅ Model and class names loaded successfully!")
54+
55+
# IMAGE PREPROCESSING
56+
transform = transforms.Compose([
57+
transforms.Resize((224, 224)),
58+
transforms.ToTensor(),
59+
transforms.Normalize([0.485, 0.456, 0.406],
60+
[0.229, 0.224, 0.225])
61+
])
62+
63+
def preprocess_image(image: Image.Image):
64+
image = image.convert("RGB")
65+
return transform(image).unsqueeze(0)
66+
67+
def predict(model, image_tensor, class_names):
68+
with torch.no_grad():
69+
outputs = model(image_tensor)
70+
probs = torch.nn.functional.softmax(outputs[0], dim=0)
71+
probs = probs.numpy()
72+
sorted_indices = np.argsort(probs)[::-1]
73+
sorted_probs = probs[sorted_indices]
74+
sorted_labels = [class_names[i] for i in sorted_indices]
75+
return sorted_labels, sorted_probs
76+
77+
# FILE UPLOAD
78+
uploaded_file = st.file_uploader("📸 Upload an image of Arabic Dates", type=["jpg", "jpeg", "png"])
79+
80+
if uploaded_file is not None:
81+
image_data = uploaded_file.read()
82+
image = Image.open(io.BytesIO(image_data))
83+
st.image(image, caption="Uploaded Image", width=400)
84+
st.write("🔍 Analyzing...")
85+
86+
image_tensor = preprocess_image(image)
87+
labels, probs = predict(model, image_tensor, CLASS_NAMES)
88+
89+
top_class = labels[0]
90+
top_prob = probs[0] * 100
91+
92+
st.markdown(f"### ✅ Predicted Class: **{top_class}** ({top_prob:.2f}% confidence)")
93+
94+
prob_dict = {labels[i]: float(probs[i] * 100) for i in range(len(labels))}
95+
st.write("#### 📊 Probability Distribution:")
96+
st.bar_chart(prob_dict)
97+
98+
else:
99+
st.info("👆 Upload a clear image of dates to classify.")
100+
101+
# FOOTER
102+
st.caption("Trained with ResNet50 on the Arabian Dates Dataset — 9 classes")

0 commit comments

Comments
 (0)