by Michael Schlauch
Introduction to image classification using camera trap images
Camera traps are a tool used by conservationists to study and monitor a wide range of ecologies while limiting human interference. However, they also generate a vast amount of data that quickly exceeds the capacity of humans to sift through. That's where machine learning can help! Advances in computer vision can help automate tasks like species detection and identification, so that humans can spend more time learning from and protecting these ecologies.
This post walks through an initial approach for the Conservision Practice Area challenge on DrivenData, a practice competition where you identify animal species in a real-world dataset of wildlife images from Taï National Park in Côte d'Ivoire. The competition is designed to be accessible to participants at all levels, which makes it a great place to dive into the world of data science competitions and computer vision.
We will go through the following steps in order to train a PyTorch model that can be used to identify the species of animal in a given image:
- Set up your environment (feel free to skip)
- Download the data
- Explore the data
- Split into train and evaluation sets
- Build the model
- Training
- Evaluation
- Create submission
The only prerequisite is a basic familiarity with Python and some of the core concepts behind deep learning. We'll guide you step by step through the rest.
Let's get started!
1. Set up your environment
Feel free to skip this step if you already have an environment set up.
The folks on our team typically use conda to manage environments. Once you have conda installed, you can create a new "conserviz" environment (name it whatever you like) with:
conda create -n conserviz python=3.8
Then we activate the new environment and install the required libraries with pip. The pip command below includes all the libraries we'll need for this notebook. Finally, launch a Jupyter notebook from this new environment.
conda activate conserviz
pip install pandas matplotlib Pillow tqdm scikit-learn torch torchvision
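If you'd like to verify the installation, a quick optional check like the one below imports the main libraries and prints their versions (this snippet isn't part of the benchmark itself):

import pandas as pd
import sklearn
import torch
import torchvision

# print versions to confirm everything installed correctly
print("pandas:", pd.__version__)
print("scikit-learn:", sklearn.__version__)
print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
# check whether a GPU is available for training
print("CUDA available:", torch.cuda.is_available())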
2. Download the data
Download the competition data from the Data Download page. You'll need to first register for the competition by clicking on "Compete" and agreeing to the rules.
The competition.zip file contains everything you need to take part in this competition, including this notebook, benchmark.ipynb. Unzip the archive into a location of your choice. The file structure should look like this:
├── benchmark.ipynb
├── submission_format.csv
├── test_features
│ ├── ZJ000000.jpg
│ ├── ZJ000001.jpg
│ └── ...
├── test_features.csv
├── train_features
│ ├── ZJ016488.jpg
│ ├── ZJ016489.jpg
│ └── ...
├── train_features.csv
└── train_labels.csv
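If you'd like to confirm the files ended up where expected, here's a small optional check (it assumes the notebook is running from the directory containing the unzipped files):

from pathlib import Path

# assumes we're running from the unzipped competition directory
data_dir = Path(".")
for name in ["train_features.csv", "train_labels.csv", "test_features.csv", "submission_format.csv"]:
    assert (data_dir / name).exists(), f"missing {name}"

# count the images in each folder
print("train images:", len(list((data_dir / "train_features").glob("*.jpg"))))
print("test images:", len(list((data_dir / "test_features").glob("*.jpg"))))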
Next, let's import some of the usual suspects:
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from tqdm import tqdm
Read in the train and test CSVs first and see what they look like.
train_features = pd.read_csv("train_features.csv", index_col="id")
test_features = pd.read_csv("test_features.csv", index_col="id")
train_labels = pd.read_csv("train_labels.csv", index_col="id")
The features CSVs contain the image ID, filepath, and site ID for each image.
train_features.head()
test_features.head()
The train_labels CSV is an indicator matrix of the species identified in each of the training images. Images are labeled "blank" if no animal was detected.
train_labels.head()
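Each image should have exactly one label. As an optional sanity check (not part of the original benchmark), we can confirm that every row of the indicator matrix sums to one, and collapse the one-hot columns into a single species column to see the labels at a glance:

# every image should have exactly one label
assert (train_labels.sum(axis=1) == 1).all()

# collapse the one-hot columns into a single species label per image
train_labels.idxmax(axis=1).head()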
Let's store a sorted list of the labels, so that we can order our model's inputs and outputs consistently.
species_labels = sorted(train_labels.columns.unique())
species_labels
3. Explore the data
Now let's see what some of the actual images look like. The code below iterates through the list of species and selects a single random image from each to display, along with its image ID and label. You can try changing the random_state variable to display a new set of images.
import matplotlib.image as mpimg

random_state = 42

# we'll create a grid with 8 positions, one for each label (7 species, plus blanks)
fig, axes = plt.subplots(nrows=4, ncols=2, figsize=(20, 20))

# iterate through each species
for species, ax in zip(species_labels, axes.flat):
    # get a random image ID for this species
    img_id = (
        train_labels[train_labels.loc[:, species] == 1]
        .sample(1, random_state=random_state)
        .index[0]
    )
    # read the file at that image's filepath into a numpy array
    img = mpimg.imread(train_features.loc[img_id].filepath)
    # plot the image, titled with its ID and label
    ax.imshow(img)
    ax.set_title(f"{img_id} | {species}")
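Before moving on, it's also worth a quick look at how balanced the labels are, since heavily skewed classes can affect training. This short addition (not in the original code) counts the training images per label:

# count how many training images carry each label
label_counts = train_labels[species_labels].sum().sort_values(ascending=False)
print(label_counts)

# visualize the distribution as a bar chart
label_counts.plot(kind="bar", figsize=(10, 4), title="Number of training images per label")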