Neural Regression, Representational Similarity, Model Zoology & Neural Taskonomy at Scale in Rodent Visual Cortex
This article is an abridged summary of a longer work appearing at NeurIPS 2021, as well as a conceptual introduction to the Deep Mouse Trap GitHub repo.
What goes on in the brain of a mouse? It’s a seemingly simple question that belies a devilishly complicated scientific endeavor: understanding how the firing and wiring together of neurons in a nervous system produce intelligent behavior. The mouse is arguably the centerpiece of a modern neuroscientific praxis that has availed itself of everything from genetics to cybernetics, yet we still know surprisingly little about key aspects of its neural software. In this project, we’ll be looking at vision, and in particular the ways we’ve increasingly come to model it.
The relative paucity of models we have for characterizing the vision of mice is made all the more conspicuous by the relative excess of models we have for characterizing the vision of another paradigmatic lab animal: monkeys (and in particular the rhesus macaque). Over the last 5 years, our ability to characterize and predict the neural activity of macaque visual cortex has surged in large part thanks to a singular class of model: object-recognizing deep neural networks. So powerful are these models that we can even use them as a sort of neural remote control, synthesizing visual stimuli that drive neural activity beyond the range evoked by handpicked natural images. The success of these models in predicting mouse visual cortex, on the other hand, has been a bit more modest, with some even suggesting that randomly initialized neural networks (neural networks that have never actually learned anything) are as predictive as trained ones – a particularly worrisome suggestion if we’d like to make mechanistic claims about the neural activity we’re predicting as having something to do with visual intelligence.
Here, we re-examine the state of neural network modeling in mouse visual cortex, using a massive optical physiology dataset of over 6,600 highly reliable visual cortical neurons (courtesy of the Allen Brain Observatory), a large battery of neural networks (both trained and randomly initialized), and multiple methods of comparing the activity of those networks to the brain (including both representational similarity and linear mapping). Our intent here is not necessarily to converge on the single best model of the mouse brain per se, but to better understand the kinds of pressures that shape the representations inherent to those models with greater or lesser neural predictivity.
We first preprocess the neural data such that we have the average response per neuron to each of the 119 natural scene images that were used by the Brain Observatory as a free-viewing probe. (The 6,619 neurons in our final set are a subsample of a larger set of 37,398 unique neurons that we’ve filtered for reliability.) Our neural sample includes neurons from 6 different cortical areas that span what has (neuroanatomically) been demarcated as the rodent ventral and dorsal visual pathways.
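This preprocessing step can be sketched in numpy. Everything here is illustrative: the shapes, the split-half reliability measure, and the threshold are stand-ins for the actual Allen Brain Observatory pipeline, which is more involved.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical response tensor: (n_trials, n_images, n_neurons) for the
# 119 natural scenes; the real data come from the Allen Brain Observatory.
n_trials, n_images, n_neurons = 10, 119, 200
responses = rng.normal(size=(n_trials, n_images, n_neurons))

# Trial-averaged response of each neuron to each image.
mean_responses = responses.mean(axis=0)  # (n_images, n_neurons)

def columnwise_pearson(x, y):
    # Pearson r between matching columns of two (n_images, n_neurons) arrays.
    xz = (x - x.mean(0)) / x.std(0)
    yz = (y - y.mean(0)) / y.std(0)
    return (xz * yz).mean(0)

# Split-half reliability: correlate each neuron's image tuning across
# two halves of the trials, then keep only neurons above a threshold.
half_a = responses[0::2].mean(axis=0)
half_b = responses[1::2].mean(axis=0)
reliability = columnwise_pearson(half_a, half_b)  # (n_neurons,)
reliable = mean_responses[:, reliability > 0.5]   # threshold is illustrative
```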
We then compare these responses systematically to the responses of artificial neurons across the layers of a variety of deep net models, selected deliberately to engender meaningful experimental foils we can use to answer thematic questions about representations in the mouse brain. These models include over 90 distinct architectures (e.g. ConvNets, transformers, MLP-Mixers), all trained on object classification with the ImageNet training set; the randomly initialized (untrained) versions of these same architectures; the 24 models of the Taskonomy project (all of which share the same encoder architecture as a backbone); and 20 models (all ResNet50 architectures) trained on a variety of self-supervised tasks. A list of the models we use is available in the table below.
| Model | Description |
|---|---|
| AlexNet | AlexNet trained on image classification with the ImageNet dataset. |
| VGG11 | VGG11 trained on image classification with the ImageNet dataset. |
| VGG13 | VGG13 trained on image classification with the ImageNet dataset. |
| VGG16 | VGG16 trained on image classification with the ImageNet dataset. |
| VGG19 | VGG19 trained on image classification with the ImageNet dataset. |
| VGG11-BatchNorm | VGG11-BatchNorm trained on image classification with the ImageNet dataset. |
| VGG13-BatchNorm | VGG13-BatchNorm trained on image classification with the ImageNet dataset. |
| VGG16-BatchNorm | VGG16-BatchNorm trained on image classification with the ImageNet dataset. |
| VGG19-BatchNorm | VGG19-BatchNorm trained on image classification with the ImageNet dataset. |
| ResNet18 | ResNet18 trained on image classification with the ImageNet dataset. |
| ResNet34 | ResNet34 trained on image classification with the ImageNet dataset. |
| ResNet50 | ResNet50 trained on image classification with the ImageNet dataset. |
| ResNet101 | ResNet101 trained on image classification with the ImageNet dataset. |
| ResNet152 | ResNet152 trained on image classification with the ImageNet dataset. |
| SqueezeNet1.0 | SqueezeNet1.0 trained on image classification with the ImageNet dataset. |
| SqueezeNet1.1 | SqueezeNet1.1 trained on image classification with the ImageNet dataset. |
| DenseNet121 | DenseNet121 trained on image classification with the ImageNet dataset. |
| DenseNet161 | DenseNet161 trained on image classification with the ImageNet dataset. |
| DenseNet169 | DenseNet169 trained on image classification with the ImageNet dataset. |
| DenseNet201 | DenseNet201 trained on image classification with the ImageNet dataset. |
| GoogleNet | GoogleNet trained on image classification with the ImageNet dataset. |
| ShuffleNet-V2-x0.5 | ShuffleNet-V2-x0.5 trained on image classification with the ImageNet dataset. |
| ShuffleNet-V2-x1.0 | ShuffleNet-V2-x1.0 trained on image classification with the ImageNet dataset. |
| MobileNet-V2 | MobileNet-V2 trained on image classification with the ImageNet dataset. |
| ResNext50-32x4D | ResNext50-32x4D trained on image classification with the ImageNet dataset. |
| ResNext101-32x8D | ResNext101-32x8D trained on image classification with the ImageNet dataset. |
| Wide-ResNet50 | Wide-ResNet50 trained on image classification with the ImageNet dataset. |
| Wide-ResNet101 | Wide-ResNet101 trained on image classification with the ImageNet dataset. |
| MNASNet0.5 | MNASNet0.5 trained on image classification with the ImageNet dataset. |
| MNASNet1.0 | MNASNet1.0 trained on image classification with the ImageNet dataset. |
| Inception-V3 | Inception-V3 trained on image classification with the ImageNet dataset. |
| Autoencoder | Image compression and decompression. |
| Object Classification | 1000-way object classification (via knowledge distillation from ImageNet). |
| Scene Classification | Scene Classification (via knowledge distillation from MIT Places). |
| Curvatures | Magnitude of 3D principal curvatures. |
| Denoising | Uncorrupted version of corrupted image. |
| Euclidean Depth | Euclidean depth estimation. |
| Z-Buffer Depth | Z-buffer depth estimation. |
| Occlusion Edges | Edges which include parts of the scene. |
| Texture Edges | Edges computed from RGB only (texture edges). |
| Egomotion | Odometry (camera poses) given three input images. |
| Camera Pose (Fixated) | Relative camera pose with matching optical centers. |
| Inpainting | Filling in masked center of image. |
| Jigsaw | Putting scrambled image pieces back together. |
| 2D Keypoints | Keypoint estimation from RGB-only (texture features). |
| 3D Keypoints | 3D Keypoint estimation from underlying scene 3D. |
| Camera Pose (Nonfixated) | Relative camera pose with distinct optical centers. |
| Surface Normals | Pixel-wise surface normals. |
| Point Matching | Classifying if centers of two images match or not. |
| Reshading | Reshading with new lighting placed at camera location. |
| Room Layout | Orientation and aspect ratio of cubic room layout. |
| Semantic Segmentation | Pixel-wise semantic labeling (via knowledge distillation from MS COCO). |
| Unsupervised 2.5D Segmentation | Segmentation (graph cut approximation) on RGB-D-Normals-Curvature image. |
| Unsupervised 2D Segmentation | Segmentation (graph cut approximation) on RGB. |
| Vanishing Point | Three Manhattan-world vanishing points. |
| Random Weights | Taskonomy architecture randomly initialized. |
| CaIT-S24 | CaIT-S24 trained on image classification with the ImageNet dataset. |
| CoaT-Lite-Mini | CoaT-Lite-Mini trained on image classification with the ImageNet dataset. |
| ConViT-B | ConViT-B trained on image classification with the ImageNet dataset. |
| ConViT-S | ConViT-S trained on image classification with the ImageNet dataset. |
| CSP-DarkNet53 | CSP-DarkNet53 trained on image classification with the ImageNet dataset. |
| CSP-ResNet50 | CSP-ResNet50 trained on image classification with the ImageNet dataset. |
| DLA34 | DLA34 trained on image classification with the ImageNet dataset. |
| DLA169 | DLA169 trained on image classification with the ImageNet dataset. |
| ECA-NFNeT-L0 | ECA-NFNeT-L0 trained on image classification with the ImageNet dataset. |
| ECA-NFNeT-L1 | ECA-NFNeT-L1 trained on image classification with the ImageNet dataset. |
| ECA-Resnet50-D | ECA-Resnet50-D trained on image classification with the ImageNet dataset. |
| ECA-Resnet101-D | ECA-Resnet101-D trained on image classification with the ImageNet dataset. |
| EfficientNet-V2-S | EfficientNet-V2-S trained on image classification with the ImageNet dataset. |
| FBNetC100 | FBNetC100 trained on image classification with the ImageNet dataset. |
| GerNet-L | GerNet-L trained on image classification with the ImageNet dataset. |
| GerNet-S | GerNet-S trained on image classification with the ImageNet dataset. |
| GhostNet100 | GhostNet100 trained on image classification with the ImageNet dataset. |
| HardCoreNAS-A | HardCoreNAS-A trained on image classification with the ImageNet dataset. |
| HardCoreNAS-F | HardCoreNAS-F trained on image classification with the ImageNet dataset. |
| LeViT128 | LeViT128 trained on image classification with the ImageNet dataset. |
| LeViT256 | LeViT256 trained on image classification with the ImageNet dataset. |
| Inception-Resnet-V2 | Inception-Resnet-V2 trained on image classification with the ImageNet dataset. |
| Inception-V3 | Inception-V3 trained on image classification with the ImageNet dataset. |
| Inception-V4 | Inception-V4 trained on image classification with the ImageNet dataset. |
| MLP-Mixer-B16 | MLP-Mixer-B16 trained on image classification with the ImageNet dataset. |
| MLP-Mixer-L16 | MLP-Mixer-L16 trained on image classification with the ImageNet dataset. |
| MixNet-L | MixNet-L trained on image classification with the ImageNet dataset. |
| MixNet-S | MixNet-S trained on image classification with the ImageNet dataset. |
| MNASNet100 | MNASNet100 trained on image classification with the ImageNet dataset. |
| MobileNet-V3 | MobileNet-V3 trained on image classification with the ImageNet dataset. |
| NASNet-A-Large | NASNet-A-Large trained on image classification with the ImageNet dataset. |
| NF-ResNet50 | NF-ResNet50 trained on image classification with the ImageNet dataset. |
| NF-Net-L0 | NF-Net-L0 trained on image classification with the ImageNet dataset. |
| PNASNet-5-Large | PNASNet-5-Large trained on image classification with the ImageNet dataset. |
| RegNetX-64 | RegNetX-64 trained on image classification with the ImageNet dataset. |
| RegNetY-64 | RegNetY-64 trained on image classification with the ImageNet dataset. |
| RepVGG-B3 | RepVGG-B3 trained on image classification with the ImageNet dataset. |
| RepVGG-B3G4 | RepVGG-B3G4 trained on image classification with the ImageNet dataset. |
| Res2Net50-26W-4S | Res2Net50-26W-4S trained on image classification with the ImageNet dataset. |
| ResNest50D | ResNest50D trained on image classification with the ImageNet dataset. |
| ResNetRS50 | ResNetRS50 trained on image classification with the ImageNet dataset. |
| RexNet100 | RexNet100 trained on image classification with the ImageNet dataset. |
| SemNASNet100 | SemNASNet100 trained on image classification with the ImageNet dataset. |
| SEResNet152D | SEResNet152D trained on image classification with the ImageNet dataset. |
| SEResNext50-32x4D | SEResNext50-32x4D trained on image classification with the ImageNet dataset. |
| SKResNet18 | SKResNet18 trained on image classification with the ImageNet dataset. |
| SKResNext50-32x4D | SKResNext50-32x4D trained on image classification with the ImageNet dataset. |
| SPNasNet100 | SPNasNet100 trained on image classification with the ImageNet dataset. |
| Swin-B-P4-W7-224 | Swin-B-P4-W7-224 trained on image classification with the ImageNet dataset. |
| Swin-L-P4-W7-224 | Swin-L-P4-W7-224 trained on image classification with the ImageNet dataset. |
| Swin-S-P4-W7-224 | Swin-S-P4-W7-224 trained on image classification with the ImageNet dataset. |
| EfficientNet-B1 | EfficientNet-B1 trained on image classification with the ImageNet dataset. |
| EfficientNet-B3 | EfficientNet-B3 trained on image classification with the ImageNet dataset. |
| EfficientNet-B5 | EfficientNet-B5 trained on image classification with the ImageNet dataset. |
| EfficientNet-B7 | EfficientNet-B7 trained on image classification with the ImageNet dataset. |
| Visformer | Visformer trained on image classification with the ImageNet dataset. |
| ViT-L-P16-224 | ViT-L-P16-224 trained on image classification with the ImageNet dataset. |
| ViT-S-P16-224 | ViT-S-P16-224 trained on image classification with the ImageNet dataset. |
| ViT-B-P16-224 | ViT-B-P16-224 trained on image classification with the ImageNet dataset. |
| XCeption | XCeption trained on image classification with the ImageNet dataset. |
| XCeption65 | XCeption65 trained on image classification with the ImageNet dataset. |
| ResNet50-JigSaw-P100 | ResNet50-JigSaw-P100 trained via self supervision with the ImageNet dataset. |
| ResNet50-JigSaw-Goyal19 | ResNet50-JigSaw-Goyal19 trained via self supervision with the ImageNet dataset. |
| ResNet50-RotNet | ResNet50-RotNet trained via self supervision with the ImageNet dataset. |
| ResNet50-ClusterFit-16K-RotNet | ResNet50-ClusterFit-16K-RotNet trained via self supervision with the ImageNet dataset. |
| ResNet50-NPID-4KNegative | ResNet50-NPID-4KNegative trained via self supervision with the ImageNet dataset. |
| ResNet50-PIRL | ResNet50-PIRL trained via self supervision with the ImageNet dataset. |
| ResNet50-SimCLR | ResNet50-SimCLR trained via self supervision with the ImageNet dataset. |
| ResNet50-DeepClusterV2 | ResNet50-DeepClusterV2-2x224 trained via self supervision with the ImageNet dataset. |
| ResNet50-DeepClusterV2 | ResNet50-DeepClusterV2-2x224+6x96 trained via self supervision with the ImageNet dataset. |
| ResNet50-SwAV-BS4096 | ResNet50-SwAV-BS4096-2x224 trained via self supervision with the ImageNet dataset. |
| ResNet50-SwAV-BS4096 | ResNet50-SwAV-BS4096-2x224+6x96 trained via self supervision with the ImageNet dataset. |
| ResNet50-MoCoV2-BS256 | ResNet50-MoCoV2-BS256 trained via self supervision with the ImageNet dataset. |
| ResNet50-BarlowTwins-BS2048 | ResNet50-BarlowTwins-BS2048 trained via self supervision with the ImageNet dataset. |
| Dino-VIT-S16 | Dino-VIT-S16 trained via self supervision with the ImageNet dataset. |
| Dino-VIT-S8 | Dino-VIT-S8 trained via self supervision with the ImageNet dataset. |
| Dino-VIT-B16 | Dino-VIT-B16 trained via self supervision with the ImageNet dataset. |
| Dino-VIT-B8 | Dino-VIT-B8 trained via self supervision with the ImageNet dataset. |
| Dino-XCIT-S12-P16 | Dino-XCIT-S12-P16 trained via self supervision with the ImageNet dataset. |
| Dino-XCIT-S12-P8 | Dino-XCIT-S12-P8 trained via self supervision with the ImageNet dataset. |
| Dino-XCIT-M24-P16 | Dino-XCIT-M24-P16 trained via self supervision with the ImageNet dataset. |
| Dino-XCIT-M24-P8 | Dino-XCIT-M24-P8 trained via self supervision with the ImageNet dataset. |
| Dino-ResNet50 | Dino-ResNet50 trained via self supervision with the ImageNet dataset. |
Equipped with our neural data and models, we then employ two distinct methods for mapping the responses of our biological neurons to the responses of our artificial neurons. The first, often called classic representational similarity analysis, is designed to assess representational structure (sometimes referred to as representational geometry) at the level of neural populations – in our case, the neural populations of 6 different visual cortical areas. The key component of representational similarity analysis is the representational (dis)similarity matrix (RDM), a distance matrix computed by taking the pairwise distance (1 - Pearson correlation coefficient, in our case) of each stimulus to every other stimulus across all neural responses in the target population. A given model’s neural predictivity score in this classic representational similarity analysis is simply an average second-order distance (in our case, another 1 - Pearson correlation coefficient) between the RDM of its maximally correspondent layer and each of the cortical RDMs in our sample. Note that this kind of classic representational similarity analysis is a nonparametric mapping, and requires no fits or transformations – just an emergent similarity in how stimuli are organized across the responses of the two neural populations (one artificial, one biological) being compared.
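The RDM computation and the second-order comparison can be sketched in a few lines of numpy. The responses here are hypothetical placeholders; only the upper triangle of each RDM is compared, since the diagonal is zero by construction.

```python
import numpy as np

def rdm(responses):
    # responses: (n_stimuli, n_neurons). np.corrcoef treats rows as
    # variables, giving the stimulus-by-stimulus Pearson matrix, so
    # 1 - r is the pairwise dissimilarity between stimuli.
    return 1.0 - np.corrcoef(responses)

def rsa_similarity(model_responses, brain_responses):
    # Second-order comparison: correlate the off-diagonal (upper-triangle)
    # entries of the two RDMs.
    iu = np.triu_indices(model_responses.shape[0], k=1)
    m = rdm(model_responses)[iu]
    b = rdm(brain_responses)[iu]
    return np.corrcoef(m, b)[0, 1]

rng = np.random.default_rng(0)
# Hypothetical responses of a model layer and a cortical area to 119 images.
layer_responses = rng.normal(size=(119, 512))
area_responses = rng.normal(size=(119, 300))
score = rsa_similarity(layer_responses, area_responses)
```

Note that the two populations need not have the same number of neurons: each RDM is 119 × 119 regardless, which is what lets us compare a 512-unit model layer to a 300-neuron cortical area with no fitting at all.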
Rather than target an entire neural population simultaneously, we can also more closely scrutinize individual neural responses using a method broadly called neural encoding or neural regression. With this method, we take the artificial neural responses of our model as the set of predictors in a regression where we try to predict (always with some form of cross-validation) the responses of a biological neuron to images not included in the regression. What we’re effectively doing in this method is mixing and matching our set of artificial neurons (often with some sort of dimensionality reduction along the way) to approximate the representational profile of a single biological neuron. The better suited those artificial neurons are to this mixing and matching (which we measure with a correlation between the neural responses predicted by the regression and the actual responses of a target neuron), the higher the score of the model that hosts them. (A schematic of our neural regression method can be seen in the figure below.)
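A minimal version of this cross-validated regression can be sketched with scikit-learn. The PCA-plus-ridge combination here is one illustrative choice of dimensionality reduction and regularized linear map, not necessarily the exact pipeline used in the paper, and the data are random placeholders.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold

rng = np.random.default_rng(0)
# Hypothetical data: one model layer's features and one biological
# neuron's mean responses to the same 119 images.
X = rng.normal(size=(119, 2048))  # artificial neurons (predictors)
y = rng.normal(size=119)          # one biological neuron (target)

fold_scores = []
for train, test in KFold(n_splits=5, shuffle=True, random_state=0).split(X):
    # Dimensionality reduction is fit on the training images only, so no
    # information about the held-out images leaks into the mapping.
    pca = PCA(n_components=40).fit(X[train])
    reg = Ridge(alpha=1.0).fit(pca.transform(X[train]), y[train])
    pred = reg.predict(pca.transform(X[test]))
    # Score each fold by correlating predicted and actual responses.
    fold_scores.append(np.corrcoef(pred, y[test])[0, 1])

predictivity = float(np.mean(fold_scores))
```

Repeating this per neuron and per layer, then aggregating the held-out correlations, yields the regression-based predictivity scores we report for each model.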
Combining our neural data, models, and mapping methods, an intuitive first result we obtain is a large set of model rankings, which in the plots below we’ve organized broadly into one of three categories.