2022 Short Course on the Application of Machine Learning for Automated Quantification of Behavior
Day 4 Workshop - Training Classifiers to Classify Animal Behavior
Purpose/Expectations
The contents of this course will give insight into characteristics that matter when training behavioral classifiers. The usage of JABS code in this tutorial is strictly limited to reading in project annotations and feature vectors.
Within this tutorial, you will be provided with 2 example annotated projects that contain sparse labels that will be used for training and dense labels that you will use for final evaluation of your best model.
Expected knowledge before taking the course
Familiarity with animal pose data and frame-level features derived from that data
Basic python experience
Some numpy experience
Familiarity with generating plots in python (preferably with plotnine, which uses ggplot-style syntax)
Expected takeaways of the course
Learn the core process to train a frame-wise predictor of behavior
Become familiar with key components to take into account when training these types of classifiers
Gain insight into some of the design decisions our group integrated into the JABS software
This is a static web-version of the ipython notebook found here.
Step 0
Import a lot of libraries that we will be using in this tutorial
import numpy as np
import pandas as pd
from itertools import chain
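import sklearn
# Import the submodules used below explicitly; a bare `import sklearn` does not guarantee they are loaded
import sklearn.ensemble
import sklearn.metrics
import sklearn.model_selection
import sklearn.tree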
# JABS install directory
import sys
sys.path.append('JABS-behavior-classifier/')
from src.project import Project
from src.classifier import Classifier
# Plotting library
import plotnine as p9
%matplotlib notebook
Step 1
Import the project annotations that we are providing. These annotations were created using our JABS software, so we can use that library to extract the annotations in python.
project = Project('JABS-Training-Data/')
# We're picking the first behavior in the project. Each JABS project can contain multiple behavior annotations.
behavior = project.load_metadata()['behaviors'][0]
# Here we read in all the important information for all labels
training_data = project.get_labeled_features(behavior, window_size=5, use_social_features=True)
# Since this is more of an internal format, we can separate out the data to be more easily accessible
all_labels = training_data[0]['labels']
annotation_animal = training_data[0]['groups']
annotation_animal_names = [training_data[1][x]['video'] + ' animal ' + str(training_data[1][x]['identity']) for x in annotation_animal]
feature_names = training_data[0]['column_names']
all_features = Classifier().combine_data(training_data[0]['per_frame'],training_data[0]['window'])
Also extract some bout-level groups, something that JABS doesn’t normally do
raw_annotations = [project.load_video_labels(training_data[1][x]['video']).as_dict()['labels'][str(training_data[1][x]['identity'])][behavior] for x in np.unique(annotation_animal)]
num_animals_annotated = len(raw_annotations)
bout_data = [[np.repeat(animal_idx+bout_idx*num_animals_annotated, bout['end']-bout['start']+1) for bout_idx, bout in enumerate(animal_data)] for animal_idx, animal_data in enumerate(raw_annotations)]
bout_data = np.concatenate(list(chain.from_iterable(bout_data)))
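To make the bout-ID construction above concrete, here is a minimal sketch on made-up annotations. The `toy_annotations` structure (a list of per-animal lists of dicts with `start` and `end` frame indices, with `end` treated as inclusive) is an assumption that mirrors how the code above indexes each bout; it is not read from a real JABS project.

```python
import numpy as np
from itertools import chain

# Hypothetical annotations: animal 0 has two bouts, animal 1 has one.
# 'end' is treated as inclusive, which is why the +1 appears below.
toy_annotations = [
    [{'start': 0, 'end': 2}, {'start': 10, 'end': 11}],
    [{'start': 5, 'end': 5}],
]
num_animals = len(toy_annotations)

# Same ID scheme as above: animal_idx + bout_idx * num_animals
# is unique for every (animal, bout) pair.
toy_bout_ids = [
    [np.repeat(animal_idx + bout_idx * num_animals, bout['end'] - bout['start'] + 1)
     for bout_idx, bout in enumerate(animal_bouts)]
    for animal_idx, animal_bouts in enumerate(toy_annotations)
]
toy_bout_ids = np.concatenate(list(chain.from_iterable(toy_bout_ids)))
print(toy_bout_ids)  # [0 0 0 2 2 1]
```

Each labeled frame ends up tagged with the ID of the bout it came from, which is what later allows whole bouts to be held out together.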
At this point we have extracted feature data and annotations from a JABS project. Here’s a description of important variables:
all_features: the matrix of all the features. The first dimension is the frame index and the second dimension is the feature index.
all_labels: a vector of labels. 0 = not behavior, 1 = behavior.
annotation_animal: a vector with a unique number for each animal.
bout_data: a vector with a unique bout number for each annotated bout.
feature_names: the name of each feature. Each value here names one entry along the 2nd dimension of the feature matrix.
Step 2
Read in the held-out test set
This set differs from the validation split we will create in 2 key ways:
1. This set is held out and should not be used to tune model parameters.
2. It is densely annotated (every frame has a label), rather than sparsely annotated like the training data.
test_project = Project('JABS-Test-Data/')
test_data = test_project.get_labeled_features(behavior, window_size=5, use_social_features=True)
test_labels = test_data[0]['labels']
test_animals = test_data[0]['groups']
test_features = Classifier().combine_data(test_data[0]['per_frame'],test_data[0]['window'])
Step 3
Inspect characteristics of the dataset
Since the annotations are located in all_labels, we should inspect some of its characteristics.
Question 1
How many annotations do we have?
print(len(all_labels))
3407
Question 2
How many labels do we have of each class: 0 (not-behavior) and 1 (behavior)
num_not_behavior = sum(all_labels==0)
print(num_not_behavior)
num_behavior = sum(all_labels==1)
print(num_behavior)
2294
1113
Question 3-4
How many bouts were annotated?
How many animals were annotated?
num_bouts = len(np.unique(bout_data))
print(num_bouts)
num_animals = len(np.unique(annotation_animal))
print(num_animals)
235
21
Discussion 1
What characteristics of the dataset may be important for creating a good model?
Hypotheticals to think about:
If you have a rare behavior, is it okay to label a lot more “not behavior”? What might the model try and do to improve accuracy if that ratio becomes 1000:1?
If one animal tends to express the behavior more than another, is it okay to provide more labels from that individual?
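As a hedged illustration of the 1000:1 hypothetical, the sketch below builds a made-up label vector (not the tutorial data) and scores a "classifier" that always predicts not-behavior. The variable names are invented for this example.

```python
import numpy as np

# Hypothetical label vector with a 1000:1 ratio of not-behavior (0) to behavior (1)
toy_labels = np.concatenate([np.zeros(100000, dtype=int), np.ones(100, dtype=int)])

# A "classifier" that ignores its input and always predicts not-behavior
always_negative = np.zeros_like(toy_labels)

accuracy = np.mean(always_negative == toy_labels)
recall = np.sum((always_negative == 1) & (toy_labels == 1)) / np.sum(toy_labels == 1)
print(accuracy, recall)
```

The accuracy is above 99.9% even though not a single behavior frame is found, which is why accuracy alone is a poor guide for rare behaviors.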
Discussion on the balancing of data
Rules of thumb:
Balanced annotations tend to create better classifiers. The higher the degree of imbalance, the more likely the classifier will just cheat and learn to predict one state over another. Generally, you should aim to not have more than a 2:1 ratio, but performance only particularly degrades when worse than a 5:1 or 10:1 ratio. Re-sampling can address this.
More bouts is generally better. Annotations within a single bout are highly correlated in time, so adding more frames from the same bout tends to teach the classifier decisions specific to that bout rather than decisions that generalize.
More animals is generally better, because some animals may express behavior in a more unique style. Having more animals enables better generalization. We’ve observed in the past that limiting to specific strains of animals in training hurts generalization to different strains. This keys in on adequately sampling from the population variation of the experiments you plan on running inferences on - include at least some of each genotype.
Extra visualizations of the dataset
plotting_df = pd.DataFrame({'annotations':all_labels, 'bout':bout_data, 'animal':annotation_animal, 'annotation_idx':np.arange(len(all_labels))})
# Histogram of the class labels
plot_labels = p9.ggplot() + \
p9.geom_histogram(mapping=p9.aes(x='factor(annotations)'), data=plotting_df, bins=2, fill='#9e9e9e', color='#000000', size=2) + \
p9.labs(title='Class Labels', x='Class', y='Count') + \
p9.scale_x_discrete(labels=['Not-Behavior',behavior]) + \
p9.theme_bw()
plot_labels.draw().show()
# Class labels per-animal
plot_animals = p9.ggplot() + \
p9.geom_histogram(mapping=p9.aes(x='factor(annotations)'), data=plotting_df, bins=2, fill='#9e9e9e', color='#000000', size=2) + \
p9.labs(title='Class Labels by Animal', x='Class', y='Count') + \
p9.scale_x_discrete(labels=['Not-Behavior',behavior]) + \
p9.facet_wrap('animal') + \
p9.theme_bw()
plot_animals.draw().show()
Example Output:
Step 4
Splitting the training data into train and validation
We will be using the training portion of the data to allow the algorithm to learn the best parameters to make predictions
The validation will be held out to evaluate performance of the classifier
Here we define a handful of functions that accept the features and labels and return the split data
# Defining functions to split the data
# Here we're going to manually shuffle and split
# Naive approach 1 - ordered split (deterministic, no shuffling)
def split_data(features, labels, percent_train=0.75):
    available_examples = np.arange(len(labels))
    train_idxs = available_examples[:int(len(labels)*percent_train)]
    test_idxs = available_examples[int(len(labels)*percent_train):]
    # Separate out the training data
    train_features = features[train_idxs]
    train_labels = labels[train_idxs]
    # Separate out the validation data
    valid_features = features[test_idxs]
    valid_labels = labels[test_idxs]
    return train_features, train_labels, valid_features, valid_labels
# Naive approach 2 - random shuffle
def random_split_data(features, labels, percent_train=0.75):
    available_examples = np.arange(len(labels))
    np.random.shuffle(available_examples)
    train_idxs = available_examples[:int(len(labels)*percent_train)]
    test_idxs = available_examples[int(len(labels)*percent_train):]
    # Separate out the training data
    train_features = features[train_idxs]
    train_labels = labels[train_idxs]
    # Separate out the validation data
    valid_features = features[test_idxs]
    valid_labels = labels[test_idxs]
    return train_features, train_labels, valid_features, valid_labels
# Using sklearn to split the data
# Stratified splitting:
# Attempts to preserve the train/valid class representations
def sklean_stratified_split(features, labels, percent_train=0.75):
    train_idxs, test_idxs = list(sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, train_size=percent_train).split(features, labels))[0]
    # Separate out the training data
    train_features = features[train_idxs]
    train_labels = labels[train_idxs]
    # Separate out the validation data
    valid_features = features[test_idxs]
    valid_labels = labels[test_idxs]
    return train_features, train_labels, valid_features, valid_labels
# Group splitting
# Leaves one group out within the split
# Note that the group that is left out is not guaranteed to contain both labels, so the metrics may sometimes involve a division by 0.
# Note: percent_train is not used by this split.
def sklearn_logo_split(features, labels, groups, percent_train=0.75):
    train_idxs, test_idxs = list(sklearn.model_selection.LeaveOneGroupOut().split(features, labels, groups))[0]
    # Separate out the training data
    train_features = features[train_idxs]
    train_labels = labels[train_idxs]
    # Separate out the validation data
    valid_features = features[test_idxs]
    valid_labels = labels[test_idxs]
    return train_features, train_labels, valid_features, valid_labels
# Group splitting #2
# Leaves multiple groups out; train_size sets the fraction of groups (not of annotations) placed in the training set
def sklearn_group_split(features, labels, groups, percent_train=0.75):
    train_idxs, test_idxs = list(sklearn.model_selection.GroupShuffleSplit(n_splits=1, train_size=percent_train).split(features, labels, groups))[0]
    # Separate out the training data
    train_features = features[train_idxs]
    train_labels = labels[train_idxs]
    # Separate out the validation data
    valid_features = features[test_idxs]
    valid_labels = labels[test_idxs]
    return train_features, train_labels, valid_features, valid_labels
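One way to see the practical difference between the split strategies is to check whether any group (animal or bout) contributes frames to both sides of the split. The sketch below is illustrative only: the arrays are synthetic stand-ins for `all_features`, `all_labels`, and `annotation_animal`, and `shared_groups` is a helper invented for this example.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit, StratifiedShuffleSplit

rng = np.random.default_rng(0)
# Synthetic stand-ins: 200 frames, 5 features, 10 "animals" with 20 frames each
toy_features = rng.normal(size=(200, 5))
toy_labels = rng.integers(0, 2, size=200)
toy_groups = np.repeat(np.arange(10), 20)

def shared_groups(splitter, **split_kwargs):
    # Return the group IDs that appear in both the train and validation indices
    train_idx, valid_idx = next(splitter.split(toy_features, toy_labels, **split_kwargs))
    return np.intersect1d(toy_groups[train_idx], toy_groups[valid_idx])

stratified_overlap = shared_groups(
    StratifiedShuffleSplit(n_splits=1, train_size=0.75, random_state=0))
group_overlap = shared_groups(
    GroupShuffleSplit(n_splits=1, train_size=0.75, random_state=0), groups=toy_groups)

print(len(stratified_overlap), len(group_overlap))
```

With the group-aware splitter the overlap is empty by construction, so validation performance reflects animals the classifier has never seen. A frame-level shuffle usually places frames from the same animal on both sides.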
Experiment 1
Inspect/Visualize the effects of different splits of the annotations
train_features, train_labels, valid_features, valid_labels = sklearn_group_split(all_features, all_labels, annotation_animal)
train_df = pd.DataFrame({'state':'train', 'label':train_labels})
valid_df = pd.DataFrame({'state':'valid', 'label':valid_labels})
plot_df = pd.concat([train_df, valid_df])
split_plot = p9.ggplot(plot_df) + \
p9.geom_bar(p9.aes(x='state', fill='factor(label)')) + \
p9.theme_bw()
split_plot.draw().show()
Example Output:
Label Balancing
Since we’re aware that we have roughly 2x more annotations for not-behavior (2294 vs. 1113), we should probably consider balancing the labels
We can take multiple approaches to creating a more balanced dataset:
1. Downsampling: Discard data from the class with the most annotations until the labels are balanced
2. Upsampling: Duplicate data from the class with the fewest annotations until the labels are balanced
# Downsamples the features, labels, and groups such that the labels are balanced
def downsample_balanced(features, labels, groups=None):
    _, label_counts = np.unique(labels, return_counts=True)
    smallest_class_count = np.min(label_counts)
    class_0_idxs = np.where(labels==0)[0]
    class_1_idxs = np.where(labels==1)[0]
    class_0_idxs = np.random.choice(class_0_idxs, smallest_class_count, replace=False)
    class_1_idxs = np.random.choice(class_1_idxs, smallest_class_count, replace=False)
    selected_samples = np.sort(np.concatenate([class_0_idxs, class_1_idxs]))
    new_features = features[selected_samples,:]
    new_labels = labels[selected_samples]
    if groups is not None:
        new_groups = groups[selected_samples]
        return new_features, new_labels, new_groups
    else:
        return new_features, new_labels, None
Question 5
Write a function to upsample the data
# Upsamples the features, labels, and groups such that the labels are balanced
# Note: both classes are drawn with replacement, so the larger class is resampled as well
def upsample_balanced(features, labels, groups=None):
    _, label_counts = np.unique(labels, return_counts=True)
    largest_class_count = np.max(label_counts)
    class_0_idxs = np.where(labels==0)[0]
    class_1_idxs = np.where(labels==1)[0]
    class_0_idxs = np.random.choice(class_0_idxs, largest_class_count, replace=True)
    class_1_idxs = np.random.choice(class_1_idxs, largest_class_count, replace=True)
    selected_samples = np.sort(np.concatenate([class_0_idxs, class_1_idxs]))
    new_features = features[selected_samples,:]
    new_labels = labels[selected_samples]
    if groups is not None:
        new_groups = groups[selected_samples]
        return new_features, new_labels, new_groups
    else:
        return new_features, new_labels, None
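As an alternative sketch (not the tutorial's own solution), upsampling can also keep every original row and only add duplicates of the smaller class, here using `sklearn.utils.resample`. The arrays are synthetic, and the example assumes class 1 is the minority class.

```python
import numpy as np
from sklearn.utils import resample

rng = np.random.default_rng(0)
# Synthetic stand-ins: 20 not-behavior rows and 10 behavior rows
toy_features = rng.normal(size=(30, 4))
toy_labels = np.array([0] * 20 + [1] * 10)

# Assumption for this sketch: class 1 is the minority class
minority_mask = toy_labels == 1
n_extra = int(np.sum(~minority_mask) - np.sum(minority_mask))

# Draw extra minority rows with replacement; all original rows are kept
extra_features, extra_labels = resample(
    toy_features[minority_mask], toy_labels[minority_mask],
    replace=True, n_samples=n_extra, random_state=0)

up_features = np.concatenate([toy_features, extra_features])
up_labels = np.concatenate([toy_labels, extra_labels])
print(np.unique(up_labels, return_counts=True))
```

Keeping the originals means no annotated frame is lost to the random draw, at the cost of the duplicated minority rows being exact copies.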
Experiment 2
Inspect/Visualize the effects of downsampling the annotations
balanced_features, balanced_labels, balanced_groups = downsample_balanced(all_features, all_labels, annotation_animal)
train_features, train_labels, valid_features, valid_labels = sklearn_group_split(balanced_features, balanced_labels, balanced_groups)
train_df = pd.DataFrame({'state':'train', 'label':train_labels})
valid_df = pd.DataFrame({'state':'valid', 'label':valid_labels})
plot_df = pd.concat([train_df, valid_df])
split_plot = p9.ggplot(plot_df) + \
p9.geom_bar(p9.aes(x='state', fill='factor(label)')) + \
p9.theme_bw()
split_plot.draw().show()
Example Output:
Step 5
Train classifiers
For the purposes of this tutorial, we will rely on SKLearn to train classifiers.
Good recommended classifiers: 1. Adaboost 2. Random Forest
# Define how to train a decision tree classifier
def train_dt_classifier(train_features, train_labels):
    # Create a classifier object
    # Note here we are using a bunch of default values that SKLearn has picked for us.
    # Check the documentation of the various parameters you can adjust for the decision trees:
    # https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn-tree-decisiontreeclassifier
    classifier = sklearn.tree.DecisionTreeClassifier()
    # Train the classifier on our training features + labels
    classifier = classifier.fit(train_features, train_labels)
    return classifier
Question 6
Write functions to train an adaboost and random forest classifier
# Adaboost Classifier
def train_ada_classifier(train_features, train_labels):
    classifier = sklearn.ensemble.AdaBoostClassifier()
    classifier = classifier.fit(train_features, train_labels)
    return classifier
# Random forest Classifier
def train_rf_classifier(train_features, train_labels):
    classifier = sklearn.ensemble.RandomForestClassifier()
    classifier = classifier.fit(train_features, train_labels)
    return classifier
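The functions above use the default settings. A minimal sketch of a variant that exposes a few hyperparameters is shown below; `train_rf_with_params` is a hypothetical helper written for this example, and the data comes from `make_classification` rather than the JABS project.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic stand-in for (train_features, train_labels)
toy_features, toy_labels = make_classification(n_samples=200, n_features=10, random_state=0)

# Hypothetical helper: same idea as train_rf_classifier, with adjustable parameters
def train_rf_with_params(train_features, train_labels, n_estimators=100, max_depth=None, random_state=0):
    classifier = RandomForestClassifier(
        n_estimators=n_estimators, max_depth=max_depth, random_state=random_state)
    return classifier.fit(train_features, train_labels)

small_forest = train_rf_with_params(toy_features, toy_labels, n_estimators=10, max_depth=3)
print(len(small_forest.estimators_))  # 10
```

Fixing `random_state` makes repeated runs comparable, which helps when the only thing meant to change between runs is the parameter under test.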
Step 6
Frame-wise evaluation of the classifier
Since our classifier predicts for every frame, we can estimate its performance by comparing the held-out validation data with the predictions it is making.
For evaluating classifier performance, we must first calculate the 4 values of a confusion matrix:
1. True positives (TP): When both the classifier and the ground truth assign “behavior”
2. False negatives (FN): When the classifier predicts “not behavior” and the ground truth assigns “behavior”
3. False positives (FP): When the classifier predicts “behavior” and the ground truth assigns “not behavior”
4. True negatives (TN): When both the classifier and the ground truth assign “not behavior”
From these, we can calculate 4 additional useful performance metrics for describing the classifier. Wikipedia has a good article outlining these as well as more, but these are the 4 most common.
1. Accuracy: The fraction of all predictions that were correct.
2. Precision: Of the predictions for “behavior”, what fraction were correct?
3. Recall: Of the ground truths labeled “behavior”, what fraction did the classifier predict?
4. F1 (F-score; the beta = 1 case of F-beta): Harmonic mean of precision and recall
Question 7
Complete the definitions of the confusion matrix and the performance metrics below
# Routine for evaluating the classifier performance
def eval_classifier_performance(classifier, features, labels, print_results=True):
    # Predict on held-out set
    predictions = classifier.predict(features)
    # Accuracy
    # Students will write these
    acc = np.mean(predictions == labels)
    # Confusion matrix values for a binary classifier
    tp_count = np.sum(np.logical_and(predictions==1, labels==1))
    tn_count = np.sum(np.logical_and(predictions==0, labels==0))
    fp_count = np.sum(np.logical_and(predictions==1, labels==0))
    fn_count = np.sum(np.logical_and(predictions==0, labels==1))
    # Calculate other metrics
    # Of the predictions we made, how many were correct
    precision = tp_count/(tp_count + fp_count)
    # Of the things we were looking for, how many did we find
    recall = tp_count/(tp_count + fn_count)
    # Harmonic mean of Pr and Re
    f1 = 2*(precision * recall)/(precision + recall)
    if print_results:
        print('Accuracy: ' + str(acc))
        print('Precision: ' + str(precision))
        print('Recall: ' + str(recall))
        print('F1-score: ' + str(f1))
        # We can also show that sklearn has tools for providing these metrics
        # Note that we've calculated the "precision" and "recall" for the behavior (value == 1)
        # print(sklearn.metrics.classification_report(labels, predictions))
        # Manually obtaining each
        # sklearn.metrics.precision_score(labels, predictions)
        # sklearn.metrics.recall_score(labels, predictions)
        # sklearn.metrics.f1_score(labels, predictions)
    return acc, precision, recall, f1
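A small worked example can help check the definitions by hand. The label and prediction vectors below are made up for illustration; the manual values are compared against the corresponding `sklearn.metrics` functions.

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score

# Made-up ground truth and predictions for 10 frames
toy_labels      = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
toy_predictions = np.array([1, 1, 1, 0, 1, 0, 0, 0, 0, 0])

tp = np.sum((toy_predictions == 1) & (toy_labels == 1))  # 3
fp = np.sum((toy_predictions == 1) & (toy_labels == 0))  # 1
fn = np.sum((toy_predictions == 0) & (toy_labels == 1))  # 1
tn = np.sum((toy_predictions == 0) & (toy_labels == 0))  # 5

accuracy = (tp + tn) / len(toy_labels)               # 8 / 10
precision = tp / (tp + fp)                           # 3 / 4
recall = tp / (tp + fn)                              # 3 / 4
f1 = 2 * precision * recall / (precision + recall)   # 0.75

print(accuracy, precision, recall, f1)
print(precision_score(toy_labels, toy_predictions),
      recall_score(toy_labels, toy_predictions),
      f1_score(toy_labels, toy_predictions))
```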
Step 7
Run through the steps: 1. Split the training data into train/valid 2. Train a classifier on the train portion of the split 3. Evaluate the classifier
# Step 1: Split the data into train/valid
train_features, train_labels, valid_features, valid_labels = random_split_data(all_features, all_labels)
# Train a classifier
dt_classifier = train_dt_classifier(train_features, train_labels)
# Evaluate the classifier
eval_classifier_performance(dt_classifier, valid_features, valid_labels)
# Repeat these steps to test how different parameters influence performance
Accuracy: 0.9647887323943662
Precision: 0.9651567944250871
Recall: 0.9326599326599326
F1-score: 0.9486301369863013
Experiment 3
See performance of different classifiers and sampling strategies
# Example with extra step for balanced labels
balanced_features, balanced_labels, balanced_groups = downsample_balanced(all_features, all_labels, annotation_animal)
train_features, train_labels, valid_features, valid_labels = sklearn_group_split(balanced_features, balanced_labels, balanced_groups)
dt_classifier = train_dt_classifier(train_features, train_labels)
eval_classifier_performance(dt_classifier, valid_features, valid_labels)
Accuracy: 0.7521663778162911
Precision: 0.7605633802816901
Recall: 0.7422680412371134
F1-score: 0.7513043478260869
Experiment 4
Expand the search of parameters by also adjusting the parameters of the model, e.g. the number of trees in the random forest.
We provide an example structure for looping through 10 training splits for the decision tree classifier.
Suggested tests: 1. Change the classifier type 2. Change the train/valid split approach 3. Include balanced data 4. Test various parameters of the classifier
all_results = []
# random splits
for test_model in np.arange(10):
    train_features, train_labels, valid_features, valid_labels = random_split_data(all_features, all_labels)
    dt_classifier = train_dt_classifier(train_features, train_labels)
    acc, pr, re, f1 = eval_classifier_performance(dt_classifier, valid_features, valid_labels, print_results=False)
    all_results.append(pd.DataFrame({'classifier':['decision tree'], 'split_method':['random'], 'accuracy':[acc], 'precision':[pr], 'recall':[re], 'f1':[f1]}))
    ada_classifier = train_ada_classifier(train_features, train_labels)
    acc, pr, re, f1 = eval_classifier_performance(ada_classifier, valid_features, valid_labels, print_results=False)
    all_results.append(pd.DataFrame({'classifier':['ada boost'], 'split_method':['random'], 'accuracy':[acc], 'precision':[pr], 'recall':[re], 'f1':[f1]}))
    rf_classifier = train_rf_classifier(train_features, train_labels)
    acc, pr, re, f1 = eval_classifier_performance(rf_classifier, valid_features, valid_labels, print_results=False)
    all_results.append(pd.DataFrame({'classifier':['random forest'], 'split_method':['random'], 'accuracy':[acc], 'precision':[pr], 'recall':[re], 'f1':[f1]}))
# leave bouts out splits
for test_model in np.arange(10):
    train_features, train_labels, valid_features, valid_labels = sklearn_group_split(all_features, all_labels, bout_data)
    dt_classifier = train_dt_classifier(train_features, train_labels)
    acc, pr, re, f1 = eval_classifier_performance(dt_classifier, valid_features, valid_labels, print_results=False)
    all_results.append(pd.DataFrame({'classifier':['decision tree'], 'split_method':['leave bouts out'], 'accuracy':[acc], 'precision':[pr], 'recall':[re], 'f1':[f1]}))
    ada_classifier = train_ada_classifier(train_features, train_labels)
    acc, pr, re, f1 = eval_classifier_performance(ada_classifier, valid_features, valid_labels, print_results=False)
    all_results.append(pd.DataFrame({'classifier':['ada boost'], 'split_method':['leave bouts out'], 'accuracy':[acc], 'precision':[pr], 'recall':[re], 'f1':[f1]}))
    rf_classifier = train_rf_classifier(train_features, train_labels)
    acc, pr, re, f1 = eval_classifier_performance(rf_classifier, valid_features, valid_labels, print_results=False)
    all_results.append(pd.DataFrame({'classifier':['random forest'], 'split_method':['leave bouts out'], 'accuracy':[acc], 'precision':[pr], 'recall':[re], 'f1':[f1]}))
# leave animals out splits
for test_model in np.arange(10):
    train_features, train_labels, valid_features, valid_labels = sklearn_group_split(all_features, all_labels, annotation_animal)
    dt_classifier = train_dt_classifier(train_features, train_labels)
    acc, pr, re, f1 = eval_classifier_performance(dt_classifier, valid_features, valid_labels, print_results=False)
    all_results.append(pd.DataFrame({'classifier':['decision tree'], 'split_method':['leave animal out'], 'accuracy':[acc], 'precision':[pr], 'recall':[re], 'f1':[f1]}))
    ada_classifier = train_ada_classifier(train_features, train_labels)
    acc, pr, re, f1 = eval_classifier_performance(ada_classifier, valid_features, valid_labels, print_results=False)
    all_results.append(pd.DataFrame({'classifier':['ada boost'], 'split_method':['leave animal out'], 'accuracy':[acc], 'precision':[pr], 'recall':[re], 'f1':[f1]}))
    rf_classifier = train_rf_classifier(train_features, train_labels)
    acc, pr, re, f1 = eval_classifier_performance(rf_classifier, valid_features, valid_labels, print_results=False)
    all_results.append(pd.DataFrame({'classifier':['random forest'], 'split_method':['leave animal out'], 'accuracy':[acc], 'precision':[pr], 'recall':[re], 'f1':[f1]}))
# Flatten the list of dataframes into one
results_df = pd.concat(all_results)
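Alongside a plot, a numeric summary per classifier and split method can make the scan easier to compare. The sketch below uses a tiny table of made-up F1 values in the same column layout as `results_df`; the numbers are placeholders and not results from this tutorial.

```python
import pandas as pd

# Made-up placeholder values, laid out like results_df above
toy_results = pd.DataFrame({
    'classifier': ['decision tree', 'decision tree', 'random forest', 'random forest'],
    'split_method': ['random', 'random', 'random', 'random'],
    'f1': [0.90, 0.94, 0.95, 0.97],
})

# Mean, spread, and number of repeats for each (split method, classifier) pair
summary = (toy_results
           .groupby(['split_method', 'classifier'])['f1']
           .agg(['mean', 'std', 'count'])
           .reset_index())
print(summary)
```

The same `groupby` call should work on `results_df` for any of the metric columns, assuming the column names used in the loops above.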
Inspect the results of your scan(s)
accuracy_plot = p9.ggplot(data=results_df) + \
p9.geom_point(p9.aes(x='classifier', y='accuracy')) + \
p9.facet_wrap('split_method') + \
p9.labs(title='Classifier accuracy', x='Classifier type', y='Accuracy') + \
p9.theme_bw()
accuracy_plot.draw().show()
Example Output:
Step 8
Evaluate your best classifier on the held-out test dataset
You could also potentially just train a classifier on the entire training set (after you’ve tuned the numbers with the validation set)
#best_classifier = dt_classifier
best_classifier = train_rf_classifier(all_features, all_labels)
eval_classifier_performance(best_classifier, test_features, test_labels)
Accuracy: 0.9844444444444445
Precision: 0.8036332179930796
Recall: 0.6761280931586608
F1-score: 0.7343873517786562
You may observe that performance has dropped a bit, but if you selected the correct train/valid split approach and didn’t over-fit on your validation data, it should still perform well.
Discussion 2
Up until now, we’ve been focussing on frame-level agreement, which is simple enough for training classifiers.
Are there better ways to evaluate performance of a behavior classifier?
One approach could be evaluating dense annotations for bout-level agreement.
Other notes on the importance of bout-level agreement: e.g. filtering, stitching, and other post-processing steps.
Step 9
Bout-level agreement on dense annotation
In order to measure bout-level agreement, we need to transform the data into starts and ends
We can either write the detection of bouts within predictions manually or re-use a classical compression algorithm called run length encoding
Note that typical RLE algorithms will encode all states (both “behavior” and “not-behavior”), while we primarily care about only 1 of the 2 states.
# Run length encoding, implemented using numpy
# Accepts a 1d vector
# Returns a tuple containing (starts, durations, values)
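def rle(inarray):
    ia = np.asarray(inarray)
    n = len(ia)
    if n == 0:
        return (None, None, None)
    else:
        # True wherever the value changes between consecutive frames
        y = ia[1:] != ia[:-1]
        # Index of the last frame of each run
        i = np.append(np.where(y), n - 1)
        # Run durations
        z = np.diff(np.append(-1, i))
        # Run start positions
        p = np.cumsum(np.append(0, z))[:-1]
        return(p, z, ia[i])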
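To see what the encoding returns, here is a short usage sketch on a made-up 6-frame vector. The function is repeated (with renamed local variables) only so the snippet runs on its own.

```python
import numpy as np

def rle(inarray):
    # Same logic as the function above, repeated so this snippet is self-contained
    ia = np.asarray(inarray)
    n = len(ia)
    if n == 0:
        return (None, None, None)
    change = ia[1:] != ia[:-1]
    run_ends = np.append(np.where(change), n - 1)
    durations = np.diff(np.append(-1, run_ends))
    starts = np.cumsum(np.append(0, durations))[:-1]
    return (starts, durations, ia[run_ends])

starts, durations, values = rle([0, 0, 1, 1, 1, 0])
print(starts, durations, values)  # [0 2 5] [2 3 1] [0 1 0]

# Keep only the runs of "behavior" (value 1)
behavior_runs = values == 1
print(starts[behavior_runs], durations[behavior_runs])  # [2] [3]
```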
Use the RLE algorithm to encode predictions
We also need to separate the data by animal, because bouts can only exist within animal
test_predictions = best_classifier.predict(test_features)
# Arrays to store the list of bouts by animal
all_pr_bouts = []
all_gt_bouts = []
# Loop over each animal and add their bouts to the lists
for animal_id in np.unique(test_animals):
    animal_indices = test_animals==animal_id
    test_bout_predictions = rle(test_predictions[animal_indices])
    # Only store the bouts of behavior
    behavior_bouts = test_bout_predictions[2]==1
    if np.any(behavior_bouts):
        test_bout_predictions = (test_bout_predictions[0][behavior_bouts], test_bout_predictions[1][behavior_bouts])
        all_pr_bouts.append(test_bout_predictions)
    else:
        all_pr_bouts.append(([],[]))
    test_bout_gt = rle(test_labels[animal_indices])
    # Only store the bouts of behavior
    behavior_bouts = test_bout_gt[2]==1
    if np.any(behavior_bouts):
        test_bout_gt = (test_bout_gt[0][behavior_bouts], test_bout_gt[1][behavior_bouts])
        all_gt_bouts.append(test_bout_gt)
    else:
        all_gt_bouts.append(([],[]))
Display the predictions next to their GT
# Note: this reuses the name bout_data, overwriting the bout IDs created in Step 1
bout_data = []
for animal_idx, animal_id in enumerate(np.unique(test_animals)):
    gt_bout_data = all_gt_bouts[animal_idx]
    if len(gt_bout_data[0])>0:
        gt_bout_df = pd.DataFrame({'state':'GT', 'animal':animal_id, 'animal_idx':animal_idx, 'start_time':gt_bout_data[0], 'end_time':gt_bout_data[0]+gt_bout_data[1]})
        bout_data.append(gt_bout_df)
    pr_bout_data = all_pr_bouts[animal_idx]
    if len(pr_bout_data[0])>0:
        pr_bout_df = pd.DataFrame({'state':'PR', 'animal':animal_id, 'animal_idx':animal_idx, 'start_time':pr_bout_data[0], 'end_time':pr_bout_data[0]+pr_bout_data[1]})
        bout_data.append(pr_bout_df)
bout_data = pd.concat(bout_data)
bout_plot = p9.ggplot(bout_data) + \
p9.geom_rect(p9.aes(xmin='start_time', xmax='end_time', ymin='animal_idx-0.25', ymax='animal_idx+0.25', fill='factor(state)'), alpha=0.5) + \
p9.labs(title='Bout annotations and predictions', x='Time, frame', y='Animal index', fill='State') + \
p9.theme_bw()
bout_plot.draw().show()
Example Output:
Step 10
Strategies for evaluating performance of bout-based metrics
Since predictions are almost never going to be identical to the ground truth, we need to adjust our strategy for calling correct and incorrect classifications of the confusion matrix. To do this, we should attempt to detect how much the predictions overlap.
In classical image processing, this was done via calculating the intersection over union (IoU) between predictions and ground truth. We can adopt the same technique for our 1-D problem (time).
# Calculates the intersection of 2 bouts
# Each bout is a tuple of (start, duration)
def calculate_intersection(gt_bout, pr_bout):
    # Detect the larger of the 2 start times
    max_start_time = np.max([gt_bout[0], pr_bout[0]])
    # Detect the smaller of the 2 end times
    gt_bout_end = gt_bout[0]+gt_bout[1]
    pr_bout_end = pr_bout[0]+pr_bout[1]
    min_end_time = np.min([gt_bout_end, pr_bout_end])
    # Detect if the 2 bouts intersected at all
    if max_start_time < min_end_time:
        return min_end_time-max_start_time
    else:
        return 0
# Calculates the union of 2 bouts
# Each bout is a tuple of (start, duration)
def calculate_union(gt_bout, pr_bout):
    # If the 2 don't intersect, we can just sum up the durations
    if calculate_intersection(gt_bout, pr_bout) == 0:
        return gt_bout[1] + pr_bout[1]
    # They do intersect
    else:
        min_start_time = np.min([gt_bout[0], pr_bout[0]])
        gt_bout_end = gt_bout[0]+gt_bout[1]
        pr_bout_end = pr_bout[0]+pr_bout[1]
        max_end_time = np.max([gt_bout_end, pr_bout_end])
        return max_end_time - min_start_time
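As a quick worked example of temporal IoU, the sketch below computes it for two made-up bouts. `temporal_iou` is a compact helper written for this illustration; it uses inclusion-exclusion for the union, which gives the same value as the two functions above.

```python
def temporal_iou(bout_a, bout_b):
    # Each bout is (start, duration); the end frame is exclusive
    start_a, end_a = bout_a[0], bout_a[0] + bout_a[1]
    start_b, end_b = bout_b[0], bout_b[0] + bout_b[1]
    intersection = max(0, min(end_a, end_b) - max(start_a, start_b))
    union = (end_a - start_a) + (end_b - start_b) - intersection
    return intersection / union

gt_bout = (10, 20)  # covers frames 10..29
pr_bout = (20, 20)  # covers frames 20..39

# 10 shared frames out of 30 frames covered in total
print(temporal_iou(gt_bout, pr_bout))
```

Identical bouts give an IoU of 1 and bouts that never overlap give 0, so a threshold between those extremes decides how much overlap counts as a match.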
all_intersections = []
all_unions = []
all_ious = []
# Loop over the animals again
num_test_animals = len(np.unique(test_animals))
for animal_idx in np.arange(num_test_animals):
    # For each animal, we want a matrix of intersections, unions, and ious
    num_gt_bouts = len(all_gt_bouts[animal_idx][0])
    num_pr_bouts = len(all_pr_bouts[animal_idx][0])
    animal_intersection_mat = np.zeros([num_gt_bouts, num_pr_bouts])
    animal_union_mat = np.zeros([num_gt_bouts, num_pr_bouts])
    # Do a second loop for each GT bout
    for gt_idx in np.arange(num_gt_bouts):
        gt_bout = [all_gt_bouts[animal_idx][0][gt_idx], all_gt_bouts[animal_idx][1][gt_idx]]
        # Final loop for each proposed bout
        for pr_idx in np.arange(num_pr_bouts):
            pr_bout = [all_pr_bouts[animal_idx][0][pr_idx], all_pr_bouts[animal_idx][1][pr_idx]]
            # Calculate the intersections, unions, and IoUs
            animal_intersection_mat[gt_idx, pr_idx] = calculate_intersection(gt_bout, pr_bout)
            animal_union_mat[gt_idx, pr_idx] = calculate_union(gt_bout, pr_bout)
    # The IoU matrix will just be i / u
    animal_iou_mat = animal_intersection_mat/animal_union_mat
    # Add the animal data to the resulting lists
    all_intersections.append(animal_intersection_mat)
    all_unions.append(animal_union_mat)
    all_ious.append(animal_iou_mat)
Now that we have IoUs, we can try to apply thresholds for calculating things like precision and recall.
Since it’s hard to define “True Negatives” with the IoU technique, we can skip that for now.
# Define detection metrics, given an IoU threshold
def calc_temporal_iou_metrics(iou_data, threshold):
    tp_counts = 0
    fn_counts = 0
    fp_counts = 0
    for cur_iou_mat in iou_data:
        # Rows are GT bouts: a TP if any prediction reaches the threshold, otherwise a FN.
        # Using >= here and < below keeps the cases complementary when IoU == threshold.
        tp_counts += np.sum(np.any(cur_iou_mat>=threshold, axis=1))
        fn_counts += np.sum(np.all(cur_iou_mat<threshold, axis=1))
        # Columns are predicted bouts: a FP if no GT bout reaches the threshold
        fp_counts += np.sum(np.all(cur_iou_mat<threshold, axis=0))
    precision = tp_counts/(tp_counts + fp_counts)
    recall = tp_counts/(tp_counts + fn_counts)
    f1 = 2*(precision * recall)/(precision + recall)
    return precision, recall, f1
all_iou_results = []
ious_to_evaluate = all_ious
for cur_iou in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    performance = calc_temporal_iou_metrics(ious_to_evaluate, cur_iou)
    all_iou_results.append(pd.DataFrame({'iou_threshold':[cur_iou], 'precision':[performance[0]], 'recall':[performance[1]], 'f1':[performance[2]]}))
iou_df = pd.concat(all_iou_results)
iou_df = pd.melt(iou_df, id_vars = ['iou_threshold'], value_vars = ['precision','recall','f1'])
iou_plot = p9.ggplot(iou_df) + \
p9.geom_line(p9.aes(x='iou_threshold', y='value', color='factor(variable)')) + \
p9.theme_bw()
iou_plot.draw().show()
Example Output:
Discussion 3
Pros and Cons to bout based analysis
False positives are overestimated, because there may be multiple predicted bouts within 1 real bout.
Are there potential methods to try and fix this issue?
Possible techniques include:
1. Filtering out short/spurious predictions
2. Stitching together multiple predictions close in time
3. Adjusting what we’re calling TP/FN/FP to be more lenient on the exact number of bouts
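The overestimation of false positives can be seen on a toy case. The sketch below assumes bouts are stored as (start, duration) pairs in frames, as in the cells above; `interval_iou` and `count_false_positives` are hypothetical helpers written for this illustration, not the `calculate_intersection` and `calculate_union` functions used elsewhere in this notebook.

```python
import numpy as np

def interval_iou(bout_a, bout_b):
    """IoU of two bouts given as (start, duration) in frames. Hypothetical helper."""
    a_start, a_end = bout_a[0], bout_a[0] + bout_a[1]
    b_start, b_end = bout_b[0], bout_b[0] + bout_b[1]
    intersection = max(0, min(a_end, b_end) - max(a_start, b_start))
    union = (a_end - a_start) + (b_end - b_start) - intersection
    return intersection / union

def count_false_positives(gt, predictions, threshold):
    """Count predicted bouts that match no ground-truth bout at this threshold."""
    iou = np.array([[interval_iou(g, p) for p in predictions] for g in gt])
    return int(np.sum(~np.any(iou >= threshold, axis=0)))

gt_bouts = [(100, 60)]                 # one real bout, frames 100 to 160
fragmented = [(100, 40), (145, 15)]    # same bout predicted with a 5-frame dropout
stitched = [(100, 60)]                 # prediction after closing the 5-frame gap

fp_fragmented = count_false_positives(gt_bouts, fragmented, 0.5)
fp_stitched = count_false_positives(gt_bouts, stitched, 0.5)
print(fp_fragmented, fp_stitched)   # 1 0
```

The second fragment lies entirely inside the real bout, yet its IoU is only 15/60 = 0.25, so it is counted as a false positive. Stitching the two fragments together removes that false positive without changing which frames are mostly predicted as behavior.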
Step 11
Post-processing techniques
def merge_behavior_gaps(bout_starts, bout_durations, bout_states, max_gap_size, state_to_merge = False):
    # Inputs are run-length encoded: parallel arrays of start frame, duration, and state per run.
    # Every interior run of state_to_merge shorter than max_gap_size is merged into its two neighbors.
    gaps_to_remove = np.logical_and(bout_states==state_to_merge, bout_durations<max_gap_size)
    new_durations = np.copy(bout_durations)
    new_starts = np.copy(bout_starts)
    new_states = np.copy(bout_states)
    if np.any(gaps_to_remove):
        # Go through backwards removing gaps, so earlier indices stay valid after each deletion
        for cur_gap in np.where(gaps_to_remove)[0][::-1]:
            # Nothing earlier or later to join together, ignore
            if cur_gap == 0 or cur_gap == len(new_durations)-1:
                pass
            else:
                # The previous run absorbs the gap and the run that follows it
                cur_duration = np.sum(new_durations[cur_gap-1:cur_gap+2])
                new_durations[cur_gap-1] = cur_duration
                new_durations = np.delete(new_durations, [cur_gap, cur_gap+1])
                new_starts = np.delete(new_starts, [cur_gap, cur_gap+1])
                new_states = np.delete(new_states, [cur_gap, cur_gap+1])
    return new_starts, new_durations, new_states
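To see what this kind of post-processing does to a prediction vector, the sketch below applies the same idea directly to frame-wise predictions rather than to run-length encoded bouts. It is an illustrative alternative, not the function above: `run_length_encode` and `flip_short_runs` are hypothetical helpers, the input is assumed to be a non-empty boolean vector, and the toy predictions are invented.

```python
import numpy as np

def run_length_encode(values):
    """Return (starts, durations, states) for the runs in a non-empty 1-D array."""
    values = np.asarray(values)
    change_points = np.flatnonzero(values[1:] != values[:-1]) + 1
    starts = np.concatenate(([0], change_points))
    durations = np.diff(np.concatenate((starts, [len(values)])))
    return starts, durations, values[starts]

def flip_short_runs(values, max_run_length, state_to_flip):
    """Flip interior runs of state_to_flip that are shorter than max_run_length."""
    values = np.array(values, dtype=bool)
    starts, durations, states = run_length_encode(values)
    # Skip the first and last runs: they have no neighbor on one side to merge with
    for idx in range(1, len(starts) - 1):
        if states[idx] == state_to_flip and durations[idx] < max_run_length:
            values[starts[idx]:starts[idx] + durations[idx]] = not state_to_flip
    return values

# Toy frame-wise predictions: a bout with a 2-frame dropout, plus a 1-frame blip
toy_predictions = np.array([0,0,1,1,1,0,0,1,1,1,1,0,0,0,0,0,1,0,0,0], dtype=bool)

# First stitch short gaps (not-behavior runs under 3 frames) ...
stitched = flip_short_runs(toy_predictions, 3, state_to_flip=False)
# ... then drop short bouts (behavior runs under 2 frames)
cleaned = flip_short_runs(stitched, 2, state_to_flip=True)

starts, durations, states = run_length_encode(cleaned)
print(starts, durations, states)   # [ 0  2 11] [2 9 9] [False  True False]
```

The two fragments are joined into a single 9-frame bout and the isolated 1-frame prediction is removed. The order matters: stitching gaps first keeps fragments of a real bout from being discarded as too short.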
Experiment 5
Test performance across multiple filtering parameters. Run the next 2 cells and change the 2 variables at the top to make the bout filtering more or less stringent.
# Filter out short bouts
min_bout_duration = 9
# Remove short breaks in bouts
min_gap_duration = 5
filtered_pr_bouts = []
for animal_id in np.unique(test_animals):
    animal_indices = test_animals==animal_id
    raw_bout_predictions = rle(test_predictions[animal_indices])
    # Remove short breaks in bouts (not-behavior runs shorter than min_gap_duration)
    filtered_bout_predictions = merge_behavior_gaps(raw_bout_predictions[0], raw_bout_predictions[1], raw_bout_predictions[2], min_gap_duration, state_to_merge=False)
    # Filter out short bouts (behavior runs shorter than min_bout_duration)
    filtered_bout_predictions = merge_behavior_gaps(filtered_bout_predictions[0], filtered_bout_predictions[1], filtered_bout_predictions[2], min_bout_duration, state_to_merge=True)
    behavior_bouts = filtered_bout_predictions[2]==1
    if np.any(behavior_bouts):
        filtered_bout_predictions = (filtered_bout_predictions[0][behavior_bouts], filtered_bout_predictions[1][behavior_bouts])
        filtered_pr_bouts.append(filtered_bout_predictions)
    else:
        filtered_pr_bouts.append(([],[]))
filtered_intersections = []
filtered_unions = []
filtered_ious = []
# Loop over the animals again
num_test_animals = len(np.unique(test_animals))
for animal_idx in np.arange(num_test_animals):
    # For each animal, we want a matrix of intersections, unions, and ious
    num_gt_bouts = len(all_gt_bouts[animal_idx][0])
    num_pr_bouts = len(filtered_pr_bouts[animal_idx][0])
    animal_intersection_mat = np.zeros([num_gt_bouts, num_pr_bouts])
    animal_union_mat = np.zeros([num_gt_bouts, num_pr_bouts])
    # Do a second loop for each GT bout
    for gt_idx in np.arange(num_gt_bouts):
        gt_bout = [all_gt_bouts[animal_idx][0][gt_idx], all_gt_bouts[animal_idx][1][gt_idx]]
        # Final loop for each proposed bout
        for pr_idx in np.arange(num_pr_bouts):
            pr_bout = [filtered_pr_bouts[animal_idx][0][pr_idx], filtered_pr_bouts[animal_idx][1][pr_idx]]
            # Calculate the intersections, unions, and IoUs
            animal_intersection_mat[gt_idx, pr_idx] = calculate_intersection(gt_bout, pr_bout)
            animal_union_mat[gt_idx, pr_idx] = calculate_union(gt_bout, pr_bout)
    # The IoU matrix will just be i / u
    animal_iou_mat = animal_intersection_mat/animal_union_mat
    # Add the animal data to the resulting lists
    filtered_intersections.append(animal_intersection_mat)
    filtered_unions.append(animal_union_mat)
    filtered_ious.append(animal_iou_mat)
Plot the performance with post-processing!
all_iou_results = []
for cur_iou in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    performance = calc_temporal_iou_metrics(filtered_ious, cur_iou)
    all_iou_results.append(pd.DataFrame({'iou_threshold':[cur_iou], 'precision':[performance[0]], 'recall':[performance[1]], 'f1':[performance[2]]}))
iou_df_filtered = pd.concat(all_iou_results)
iou_df_filtered = pd.melt(iou_df_filtered, id_vars = ['iou_threshold'], value_vars = ['precision','recall','f1'])
iou_df_filtered['filtered'] = True
iou_df['filtered'] = False
iou_df_filtered = pd.concat([iou_df_filtered, iou_df])
iou_plot = p9.ggplot(iou_df_filtered) + \
    p9.geom_line(p9.aes(x='iou_threshold', y='value', color='filtered')) + \
    p9.facet_wrap('variable') + \
    p9.theme_bw()
iou_plot.draw().show()
Example Output:
bout_data = []
for animal_idx, animal_id in enumerate(np.unique(test_animals)):
    gt_bout_data = all_gt_bouts[animal_idx]
    if len(gt_bout_data[0])>0:
        gt_bout_df = pd.DataFrame({'state':'GT', 'animal':animal_id, 'animal_idx':animal_idx, 'start_time':gt_bout_data[0], 'end_time':gt_bout_data[0]+gt_bout_data[1]})
        bout_data.append(gt_bout_df)
    pr_bout_data = filtered_pr_bouts[animal_idx]
    if len(pr_bout_data[0])>0:
        pr_bout_df = pd.DataFrame({'state':'PR', 'animal':animal_id, 'animal_idx':animal_idx, 'start_time':pr_bout_data[0], 'end_time':pr_bout_data[0]+pr_bout_data[1]})
        bout_data.append(pr_bout_df)
bout_data = pd.concat(bout_data)
bout_plot = p9.ggplot(bout_data) + \
    p9.geom_rect(p9.aes(xmin='start_time', xmax='end_time', ymin='animal_idx-0.25', ymax='animal_idx+0.25', fill='factor(state)'), alpha=0.5) + \
    p9.labs(title='Bout annotations and predictions', x='Time, frame', y='Animal index', fill='State') + \
    p9.theme_bw()
bout_plot.draw().show()
Example Output: