1+ import io
2+ import pickle
3+ import sys
4+ import os
5+ import numpy as np
6+ from PIL import Image
7+
8+ import streamlit as st
9+ import torch
10+ from torch import nn
11+ from torchvision import models , transforms
12+
13+ sys .path .append (os .path .abspath (os .path .join (os .path .dirname (__file__ ), "../../../../.." )))
14+ from src .helpers .kaggle import downloadNotebookOutput
15+
16+ def arabicDatesClassifierModel ():
17+ # APP HEADER
18+ st .title ("Arabic Dates Classification 🍂" )
19+ st .markdown (
20+ "This pretrained model classifies an image of **Arabic Dates** into one of the 9 varieties commonly found in the Arabian region."
21+ )
22+
23+ # LOAD MODEL
24+ @st .cache_resource
25+ def load_model_data ():
26+ try :
27+ downloadNotebookOutput ("supratikbhowal" , "arabic-dates-classification" , "notebook" )
28+
29+ PICKLE_SAVE_PATH = "notebook/arabic_dates_classnames.pkl"
30+ with open (PICKLE_SAVE_PATH , "rb" ) as f :
31+ class_names = pickle .load (f )
32+
33+ model = models .resnet50 (weights = models .ResNet50_Weights .DEFAULT )
34+ model .fc = nn .Sequential (
35+ nn .Linear (model .fc .in_features , 512 ),
36+ nn .ReLU (),
37+ nn .Dropout (0.4 ),
38+ nn .Linear (512 , len (class_names ))
39+ )
40+
41+ MODEL_SAVE_PATH = "notebook/arabic_dates_model.pth"
42+ checkpoint = torch .load (MODEL_SAVE_PATH , map_location = torch .device ("cpu" ))
43+ model .load_state_dict (checkpoint ["model_state_dict" ], strict = False )
44+ model .eval ()
45+
46+ return model , class_names
47+
48+ except Exception as e :
49+ st .error (f"🚨 Failed to load model or data: { e } " )
50+ st .stop ()
51+
52+ model , CLASS_NAMES = load_model_data ()
53+ st .success ("✅ Model and class names loaded successfully!" )
54+
55+ # IMAGE PREPROCESSING
56+ transform = transforms .Compose ([
57+ transforms .Resize ((224 , 224 )),
58+ transforms .ToTensor (),
59+ transforms .Normalize ([0.485 , 0.456 , 0.406 ],
60+ [0.229 , 0.224 , 0.225 ])
61+ ])
62+
63+ def preprocess_image (image : Image .Image ):
64+ image = image .convert ("RGB" )
65+ return transform (image ).unsqueeze (0 )
66+
67+ def predict (model , image_tensor , class_names ):
68+ with torch .no_grad ():
69+ outputs = model (image_tensor )
70+ probs = torch .nn .functional .softmax (outputs [0 ], dim = 0 )
71+ probs = probs .numpy ()
72+ sorted_indices = np .argsort (probs )[::- 1 ]
73+ sorted_probs = probs [sorted_indices ]
74+ sorted_labels = [class_names [i ] for i in sorted_indices ]
75+ return sorted_labels , sorted_probs
76+
77+ # FILE UPLOAD
78+ uploaded_file = st .file_uploader ("📸 Upload an image of Arabic Dates" , type = ["jpg" , "jpeg" , "png" ])
79+
80+ if uploaded_file is not None :
81+ image_data = uploaded_file .read ()
82+ image = Image .open (io .BytesIO (image_data ))
83+ st .image (image , caption = "Uploaded Image" , width = 400 )
84+ st .write ("🔍 Analyzing..." )
85+
86+ image_tensor = preprocess_image (image )
87+ labels , probs = predict (model , image_tensor , CLASS_NAMES )
88+
89+ top_class = labels [0 ]
90+ top_prob = probs [0 ] * 100
91+
92+ st .markdown (f"### ✅ Predicted Class: **{ top_class } ** ({ top_prob :.2f} % confidence)" )
93+
94+ prob_dict = {labels [i ]: float (probs [i ] * 100 ) for i in range (len (labels ))}
95+ st .write ("#### 📊 Probability Distribution:" )
96+ st .bar_chart (prob_dict )
97+
98+ else :
99+ st .info ("👆 Upload a clear image of dates to classify." )
100+
101+ # FOOTER
102+ st .caption ("Trained with ResNet50 on the Arabian Dates Dataset — 9 classes" )
0 commit comments