Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sparrow.env
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ SERVER_BASE_URL=https://server.sparrow-earth.com
TZ=Etc/UTC
ONLY_SAVE_ANIMALS=true
FTP_USER=camera
FTP_PASS=
FTP_PASS=
DRAW_BOXES=true
187 changes: 135 additions & 52 deletions sparrow/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@
"""
This script uses Triton Inference Server to perform object detection using the MegaDetectorV6 model,
and then for each "animal" detection, it crops the bounding box and sends it to a classification
model (e.g., AI4GAmazonClassification) for species classification. The final annotated image is saved
along with logging details in a CSV file.
model (e.g., AI4GAmazonClassification) for species classification.

Results are:
- Logged to CSV
- For JPEG outputs, all bounding boxes (with labels & scores) are stored as JSON in EXIF
UserComment so server-side can turn overlays on/off later.

By default, box drawing is enabled (DRAW_BOXES=true). You can disable drawing with:
DRAW_BOXES=false
"""

import os
Expand All @@ -13,6 +20,7 @@
import logging
import threading
from datetime import datetime

from PIL import Image, ImageFile, ImageDraw, ImageFont
import numpy as np
import tritonclient.http as httpclient
Expand All @@ -23,6 +31,7 @@
from filelock import FileLock
from utils.sparrow_id import get_hardware_id
from utils.detection_utils import non_max_suppression, scale_boxes
import piexif # EXIF metadata

# Setup Logging & Folders
LOGS_DIR = "/app/logs"
Expand All @@ -40,6 +49,7 @@
log = logging.getLogger("inference")

ONLY_SAVE_ANIMALS = os.getenv("ONLY_SAVE_ANIMALS", "false").strip().lower() == "true"
DRAW_BOXES = os.getenv("DRAW_BOXES", "true").strip().lower() == "true"

# Model Config Sync
CONFIG_DIR = "/app/config"
Expand Down Expand Up @@ -98,6 +108,7 @@

os.makedirs(CONFIG_DIR, exist_ok=True)


def load_model_config():
"""Load model_settings.json (create default if missing)."""
if not os.path.isfile(MODEL_CONFIG_FILE):
Expand All @@ -112,6 +123,7 @@ def load_model_config():
model_logger.error(f"Failed to load model_settings.json: {e}")
return DEFAULT_MODEL_CONFIG.copy()


def save_model_config(config):
"""Atomically save model_settings.json."""
tmp_path = f"{MODEL_CONFIG_FILE}.tmp"
Expand All @@ -123,6 +135,7 @@ def save_model_config(config):
except Exception as e:
model_logger.error(f"Failed to save model_settings.json: {e}")


def fetch_model_settings(unique_id, auth_key):
"""
Fetch updated model settings from the server.
Expand All @@ -145,61 +158,50 @@ def fetch_model_settings(unique_id, auth_key):
except Exception as e:
model_logger.warning(f"Could not fetch model settings: {e}")


def model_settings_fetch_loop(unique_id, auth_key):
    """Poll the server for model settings forever, once per minute.

    Intended to run as the target of a background thread: it logs a start
    message, then alternates between calling ``fetch_model_settings`` with
    the given credentials and sleeping. It never returns.
    """
    poll_interval_s = 60  # seconds to wait between two fetches
    model_logger.info("Started model settings background fetch thread.")
    while True:
        fetch_model_settings(unique_id, auth_key)
        time.sleep(poll_interval_s)


def get_current_model_name():
    """Return the classification model name from the loaded model config.

    Looks up the 'selected_model' key and falls back to
    'AI4GAmazonClassification' when the key is absent.
    """
    settings = load_model_config()
    return settings.get("selected_model", "AI4GAmazonClassification")


def get_current_labels():
    """Return the label dictionary from the loaded model config.

    Falls back to the default label set when the config has no entry.
    """
    settings = load_model_config()
    # NOTE: "lables" (sic) is the key actually used in the config; keep the
    # spelling as-is so lookups continue to match.
    fallback_labels = DEFAULT_MODEL_CONFIG["lables"]
    return settings.get("lables", fallback_labels)


def is_classification_enabled():
    """Return the 'classification_enabled' setting from the model config.

    Defaults to True when the key is absent.
    """
    settings = load_model_config()
    return settings.get("classification_enabled", True)


def is_keep_blanks_enabled():
    """Return the 'keep_blanks' setting from the model config.

    Defaults to False when the key is absent.
    """
    settings = load_model_config()
    return settings.get("keep_blanks", False)


def get_detection_threshold():
    """Return the detection confidence threshold from the model config.

    Falls back to DEFAULT_MODEL_CONFIG['detection_threshold'] when the
    loaded config has no 'detection_threshold' key.
    """
    return load_model_config().get(
        "detection_threshold",
        DEFAULT_MODEL_CONFIG["detection_threshold"],
    )


# Image & Preprocess Utils
ImageFile.LOAD_TRUNCATED_IMAGES = True


def load_font():
    """Return the default font provided by Pillow's ImageFont module."""
    default_font = ImageFont.load_default()
    return default_font


def letterbox(im, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
"""Resize and pad image to meet stride-multiple constraints."""
if isinstance(im, Image.Image):
Expand Down Expand Up @@ -232,10 +234,12 @@ def letterbox(im, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=Tru
im = F.pad(im * 255.0, padding, value=114) / 255.0
return im


# MegaDetector classes
class_name_to_id = {0: "animal", 1: "person", 2: "vehicle"}
colors = ["red", "blue", "purple"]


def preprocess_classification(img):
"""
Preprocess a PIL image for classification:
Expand All @@ -247,6 +251,7 @@ def preprocess_classification(img):
img_np = np.expand_dims(img_np, axis=0).astype(np.float32)
return img_np


# Triton / IO Setup
TRITON_URL = (os.getenv("TRITON_SERVER_URL") or os.getenv("TRITON_URL", "http://triton:8000")).strip().rstrip("/")
if TRITON_URL.startswith(("http://", "https://")):
Expand All @@ -266,6 +271,7 @@ def preprocess_classification(img):
csv_file = '/app/static/data/detections.csv'
os.makedirs(os.path.dirname(csv_file), exist_ok=True)


def write_to_csv(image_name, detection, confidence, date):
"""Append detection results to CSV."""
file_exists = os.path.isfile(csv_file)
Expand All @@ -275,6 +281,41 @@ def write_to_csv(image_name, detection, confidence, date):
writer.writerow(['Image Name', 'Detection', 'Confidence Score', 'Date'])
writer.writerow([image_name, detection, confidence, date])


def save_jpeg_with_boxes(img, boxes_meta, out_path):
    """
    Save a JPEG with bounding boxes stored as JSON in EXIF UserComment.

    EXIF already attached to ``img`` (``img.info["exif"]``) is carried over
    when it can be parsed and re-serialized. If it cannot, the image is still
    saved with a minimal EXIF block holding only the UserComment, so the
    detection metadata is not lost because of malformed camera EXIF.

    Args:
        img: image object exposing ``.info`` and ``.save`` (a PIL image).
        boxes_meta: JSON-serializable list of dicts, each like:
            {
                "x1": float (normalized 0-1),
                "y1": float,
                "x2": float,
                "y2": float,
                "label": str,
                "score": float,
                "class_id": int,
                "source": str,
                "model": str or None
            }
        out_path: destination path of the JPEG file.

    Raises:
        Whatever ``json.dumps`` raises for non-serializable metadata and
        whatever ``img.save`` raises on write failure.
    """
    logger = logging.getLogger("inference")

    def _empty_exif():
        # Fresh dict each time: the IFD sub-dicts are mutated below.
        return {"0th": {}, "Exif": {}, "GPS": {}, "1st": {}, "thumbnail": None}

    exif_dict = _empty_exif()
    exif_bytes_in = img.info.get("exif", b"")
    if exif_bytes_in:
        try:
            exif_dict = piexif.load(exif_bytes_in)
        except Exception as exc:  # best-effort: unreadable EXIF is dropped
            logger.warning(
                "Could not parse existing EXIF for %s; starting from empty EXIF: %s",
                out_path, exc,
            )
            exif_dict = _empty_exif()

    # ensure_ascii=True (the json default, made explicit) guarantees the
    # payload is pure ASCII, which is what the "ASCII" prefix below declares.
    payload = json.dumps(boxes_meta, ensure_ascii=True).encode("utf-8")
    # EXIF UserComment starts with an 8-byte character-code prefix.
    user_comment = b"ASCII\0\0\0" + payload
    exif_dict.setdefault("Exif", {})[piexif.ExifIFD.UserComment] = user_comment

    try:
        exif_bytes_out = piexif.dump(exif_dict)
    except Exception as exc:
        # Existing tags that load fine can still fail to re-serialize;
        # keep the box metadata rather than failing the whole save.
        logger.warning(
            "Could not re-serialize existing EXIF for %s; writing boxes only: %s",
            out_path, exc,
        )
        minimal_exif = _empty_exif()
        minimal_exif["Exif"][piexif.ExifIFD.UserComment] = user_comment
        exif_bytes_out = piexif.dump(minimal_exif)

    img.save(out_path, format="JPEG", exif=exif_bytes_out)


# Background Settings Fetch
try:
with open(AUTH_KEY_PATH, "r") as f:
Expand All @@ -292,7 +333,7 @@ def write_to_csv(image_name, detection, confidence, date):

if AUTH_KEY and UNIQUE_ID:
model_thread = threading.Thread(
target=model_settings_fetch_loop,
target=model_settings_fetch_loop,
args=(UNIQUE_ID, AUTH_KEY),
daemon=True
)
Expand Down Expand Up @@ -365,38 +406,43 @@ def write_to_csv(image_name, detection, confidence, date):
md_confidence = pred[:, 4]
md_class_id = pred[:, 5].astype(int)

annotated_img = image.copy()
draw = ImageDraw.Draw(annotated_img)
font = load_font()

drew_any = False # track whether we drew any boxes (after filtering)
skipped_count = 0 # track how many non-animal detections we skip
drew_any = False # we had at least one kept detection
skipped_count = 0 # how many non-animal detections we skip

# Metadata for EXIF (one dict per detection)
boxes_meta = []
img_w, img_h = image.size

# Only create drawing context if we actually want boxes rendered
annotated_img = image.copy() if DRAW_BOXES else image
draw = ImageDraw.Draw(annotated_img) if DRAW_BOXES else None

for i in range(len(pred)):
cls_id = md_class_id[i]

# Skip non-animals (person=1, vehicle=2) if ONLY_SAVE_ANIMALS is enabled
if ONLY_SAVE_ANIMALS and cls_id in (1, 2):
# Log the skip with bbox + confidence
try:
x1, y1, x2, y2 = [float(v) for v in xyxy[i]]
x1_s, y1_s, x2_s, y2_s = [float(v) for v in xyxy[i]]
except Exception:
x1 = y1 = x2 = y2 = -1.0
x1_s = y1_s = x2_s = y2_s = -1.0
label_skipped = "person" if cls_id == 1 else "vehicle"
conf = float(md_confidence[i])
conf_s = float(md_confidence[i])
log.info(
f"Skipping {label_skipped} (conf={conf:.2f}) due to ONLY_SAVE_ANIMALS; "
f"image={image_name}, box=({x1:.1f},{y1:.1f},{x2:.1f},{y2:.1f})"
f"Skipping {label_skipped} (conf={conf_s:.2f}) due to ONLY_SAVE_ANIMALS; "
f"image={image_name}, box=({x1_s:.1f},{y1_s:.1f},{x2_s:.1f},{y2_s:.1f})"
)
skipped_count += 1
continue

md_label = class_name_to_id[cls_id]
det_conf = md_confidence[i]
x1, y1, x2, y2 = xyxy[i]

# Only run classification if it's an "animal" AND classification is enabled
if cls_id == 0 and is_classification_enabled():
x1, y1, x2, y2 = xyxy[i]
cropped = image.crop((x1, y1, x2, y2))
cropped_np = preprocess_classification(cropped)

Expand All @@ -420,21 +466,50 @@ def write_to_csv(image_name, detection, confidence, date):

write_to_csv(image_name, detected_class, clf_conf, date)
label = f"{detected_class} {clf_conf:.2f}"

stored_label = detected_class
stored_conf = clf_conf
stored_model = current_model_name
else:
# For person/vehicle, or if classification disabled, use MD label only
# (This path is not reached for non-animals when ONLY_SAVE_ANIMALS skipped above)
write_to_csv(image_name, md_label, det_conf, date)
label = f"{md_label} {det_conf:.2f}"

# Draw bounding box and label
draw.rectangle(xyxy[i], outline=colors[cls_id], width=2)
text_bbox = draw.textbbox((xyxy[i][0], xyxy[i][1] - 20), label, font=font)
draw.rectangle(
[text_bbox[0], text_bbox[1] - 2, text_bbox[2] + 2, text_bbox[3] + 2],
fill=colors[cls_id]
stored_label = md_label
stored_conf = float(det_conf)
stored_model = None

# Optionally draw bounding box and label
if DRAW_BOXES and draw is not None:
draw.rectangle(xyxy[i], outline=colors[cls_id], width=2)
text_bbox = draw.textbbox((xyxy[i][0], xyxy[i][1] - 20), label, font=font)
draw.rectangle(
[text_bbox[0], text_bbox[1] - 2, text_bbox[2] + 2, text_bbox[3] + 2],
fill=colors[cls_id]
)
draw.text((xyxy[i][0] + 2, xyxy[i][1] - 20), label, font=font, fill='white')

drew_any = True # we have at least one kept detection

# Store normalized coordinates + label in metadata list
norm_x1 = float(x1) / float(img_w)
norm_y1 = float(y1) / float(img_h)
norm_x2 = float(x2) / float(img_w)
norm_y2 = float(y2) / float(img_h)

boxes_meta.append(
{
"x1": norm_x1,
"y1": norm_y1,
"x2": norm_x2,
"y2": norm_y2,
"label": stored_label,
"score": float(stored_conf),
"class_id": int(cls_id),
"source": "megadetectorv6",
"model": stored_model,
}
)
draw.text((xyxy[i][0] + 2, xyxy[i][1] - 20), label, font=font, fill='white')
drew_any = True

# Per-image summary for skipped detections
if ONLY_SAVE_ANIMALS and skipped_count:
Expand All @@ -451,8 +526,16 @@ def write_to_csv(image_name, detection, confidence, date):
print(f"Removed source file {image_path} (all detections filtered)")
continue

annotated_img.save(os.path.join(output_dir, image_name))
print(f"Saved {os.path.join(output_dir, image_name)}")
# Save CLEAN image, embedding boxes in EXIF if JPEG
out_path = os.path.join(output_dir, image_name)
img_to_save = image # always save original pixels

if image_name.lower().endswith((".jpg", ".jpeg")):
save_jpeg_with_boxes(img_to_save, boxes_meta, out_path)
else:
img_to_save.save(out_path)

print(f"Saved {out_path}")

# Remove original after processing
os.remove(image_path)
Expand Down
1 change: 1 addition & 0 deletions sparrow/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ librosa==0.10.2.post1
numba==0.59.1
llvmlite==0.42.0
pyftpdlib==2.1.0
piexif==1.1.3