Implementing EtinyNet-1.0 in PyTorch

Nathan Bailey
Mar 2, 2024


After recently detailing the architecture of an impressive TinyML CNN [1], I set about implementing it in PyTorch. This blog details the network implementation and highlights where specific assumptions had to be made. In addition, I trained and tested the network on Tiny ImageNet 200 [5], a smaller version of ImageNet-1000, the dataset the original authors used for training and evaluation.

Implementation

The implementation of this network was straightforward apart from a few assumptions that needed to be made due to a lack of information provided in the paper.

The overall architecture of the network is shown below. As the figure shows, every stage except the pooling layer halves the spatial size of the tensor. A global pooling layer then reduces each feature map to a single value, and the flattened result is fed into a fully connected layer for classification.

Linear depthwise block, dense linear depthwise block, overall network architecture [1]
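The halving can be checked with the standard output-size formula for convolution and pooling layers, out = floor((in + 2p - k) / s) + 1. The sketch below makes three assumptions, all taken from choices made later in this post:

- a 224x224 input;
- 3x3 kernels with stride 2 and padding 1 for the stem and for the first block of each stage;
- the 7x7 average pool used before the classifier.

`conv_out` is a hypothetical helper name.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of a convolution or pooling layer with a square window."""
    return (size + 2 * padding - kernel) // stride + 1


sizes = [224]
# Stem convolution: 3x3, stride 2, padding 1.
sizes.append(conv_out(sizes[-1], kernel=3, stride=2, padding=1))
# Four stages; only the first block of each stage downsamples (3x3, stride 2, padding 1).
for _ in range(4):
    sizes.append(conv_out(sizes[-1], kernel=3, stride=2, padding=1))
# Global average pooling with a 7x7 window (stride defaults to the window size).
sizes.append(conv_out(sizes[-1], kernel=7, stride=7))
print(sizes)  # [224, 112, 56, 28, 14, 7, 1]
```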

I started by creating the building blocks. The linear depthwise block (LB) was simple to build. I took a depthwise separable convolution and removed the ReLU activation between its depthwise and pointwise convolutions. I then appended a second depthwise convolution after the pointwise layer and added a batch normalization layer after each of the three convolutions.

A depthwise convolution can change the spatial size of its input if desired. The original paper does not specify which depthwise convolution in the LB block is responsible for downsampling. I therefore assumed that the first one handles it. I kept the padding of the final depthwise convolution as "same" (with stride 1) so that it does not change the spatial size of its input.
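For a stride-1 convolution with an odd kernel size k, "same" padding amounts to (k - 1) / 2 pixels on each side. That is why the second depthwise convolution leaves the spatial size untouched. PyTorch only accepts `padding="same"` when the stride is 1, so the downsampling convolution needs an explicit integer padding instead. The small sketch below illustrates both cases for a 56x56 input; `conv_out` and `same_padding` are hypothetical helpers.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of a convolution with a square kernel."""
    return (size + 2 * padding - kernel) // stride + 1


def same_padding(kernel: int) -> int:
    """Symmetric padding that preserves the size for stride 1 and an odd kernel."""
    if kernel % 2 == 0:
        raise ValueError("symmetric 'same' padding needs an odd kernel size")
    return (kernel - 1) // 2


# First depthwise conv of a downsampling block: 3x3, stride 2, padding 1 halves 56 -> 28.
after_first = conv_out(56, kernel=3, stride=2, padding=1)
# Second depthwise conv: stride 1 with "same" padding keeps the size at 28.
after_second = conv_out(after_first, kernel=3, stride=1, padding=same_padding(3))
print(after_first, after_second)  # 28 28
```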

The implementation of the LB block can be seen in the following code snippet.

```python
import torch
import torch.nn.functional as F


class LB(torch.nn.Module):
    """Linear depthwise block: depthwise conv -> pointwise conv -> depthwise conv."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int | str, stride: int, bias: bool = True) -> None:
        super().__init__()
        # First depthwise conv (no activation afterwards); responsible for any downsampling
        self.depthwise_conv_layer_a = torch.nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels, padding=padding, stride=stride, bias=bias)
        self.batch_normalization_a = torch.nn.BatchNorm2d(num_features=in_channels)
        # Pointwise (1x1) conv mixes channels and sets the output channel count
        self.pointwise_layer = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)
        self.batch_normalization_point = torch.nn.BatchNorm2d(num_features=out_channels)
        # Second depthwise conv: stride 1 and "same" padding, so the spatial size is kept
        self.depthwise_conv_layer_b = torch.nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, groups=out_channels, padding="same", stride=1, bias=bias)
        self.batch_normalization_b = torch.nn.BatchNorm2d(num_features=out_channels)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        depthwise_result = self.batch_normalization_a(self.depthwise_conv_layer_a(input_tensor))
        pointwise_result = F.relu(self.batch_normalization_point(self.pointwise_layer(depthwise_result)))
        return F.relu(self.batch_normalization_b(self.depthwise_conv_layer_b(pointwise_result)))
```
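To see why this block is cheap, its trainable parameters can be counted by hand. The sketch below mirrors the `LB` layers above:

- a bias on every convolution;
- the learnable scale and shift of each BatchNorm2d (its running statistics are buffers, not parameters).

It compares an LB with 32 input and 128 output channels against a single standard 3x3 convolution plus batch normalization. `conv_params`, `bn_params` and `lb_params` are hypothetical helper names.

```python
def conv_params(c_in: int, c_out: int, kernel: int, groups: int = 1, bias: bool = True) -> int:
    """Trainable parameters of a 2D convolution with a square kernel."""
    weights = (c_in // groups) * c_out * kernel * kernel
    return weights + (c_out if bias else 0)


def bn_params(channels: int) -> int:
    """BatchNorm2d learns one scale and one shift per channel."""
    return 2 * channels


def lb_params(c_in: int, c_out: int, kernel: int = 3) -> int:
    depthwise_a = conv_params(c_in, c_in, kernel, groups=c_in) + bn_params(c_in)
    pointwise = conv_params(c_in, c_out, 1) + bn_params(c_out)
    depthwise_b = conv_params(c_out, c_out, kernel, groups=c_out) + bn_params(c_out)
    return depthwise_a + pointwise + depthwise_b


print(lb_params(32, 128))                        # 6400
print(conv_params(32, 128, 3) + bn_params(128))  # 37248
```

The LB uses under a fifth of the parameters of the standard convolution. Counting `sum(p.numel() for p in block.parameters())` on an actual `LB(32, 128, 3, padding=1, stride=2)` should give the same 6400.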

The dense linear depthwise block (DLB) follows the same structure as the LB, with added shortcut connections. The input to a block can have a different shape from its output, and in this network it sometimes does. A linear projection is therefore needed before the input can be added to the output. The paper does not describe this either, so I adopted the downsampling shortcut used in ResNet [6]. If the block has a stride greater than 1, or its input and output channel counts differ, the input is projected with a 1x1 convolution followed by batch normalization.

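```python
# ResNet-style projection so the shortcut matches the block's output shape
downsample_block = None
if stride != 1 or in_channels != out_channels:
    downsample_block = torch.nn.Sequential(
        torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
        torch.nn.BatchNorm2d(out_channels)
    )
```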
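The 1x1 convolution maps `in_channels` to `out_channels`, and its spatial output also lines up with the main path. For any input size H, two layers produce floor((H - 1) / 2) + 1 rows and columns:

- a 3x3 convolution with stride 2 and padding 1 (the first depthwise convolution);
- a 1x1 convolution with stride 2 and no padding (the projection).

The quick check below uses a hypothetical `conv_out` helper.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of a convolution with a square kernel."""
    return (size + 2 * padding - kernel) // stride + 1


for size in range(1, 257):
    main_path = conv_out(size, kernel=3, stride=2, padding=1)  # first depthwise conv
    shortcut = conv_out(size, kernel=1, stride=2, padding=0)   # 1x1 projection
    assert main_path == shortcut, size

print(conv_out(56, 3, 2, 1), conv_out(56, 1, 2, 0))  # 28 28
```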

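The shortcut connections are implemented by summing tensors before each activation function. The (projected) input is added to the pointwise output. Both are then added again to the output of the second depthwise convolution.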

```python
import torch
import torch.nn.functional as F


class DLB(torch.nn.Module):
    """Dense linear depthwise block: the LB layout plus shortcut connections."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int | str = 0, stride: int = 1, downsample: torch.nn.Module | None = None, bias: bool = True) -> None:
        super().__init__()
        self.depthwise_conv_layer_a = torch.nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels, padding=padding, stride=stride, bias=bias)
        self.batch_normalization_a = torch.nn.BatchNorm2d(num_features=in_channels)
        self.pointwise_layer = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)
        self.batch_normalization_point = torch.nn.BatchNorm2d(num_features=out_channels)
        self.depthwise_conv_layer_b = torch.nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, groups=out_channels, padding="same", stride=1, bias=bias)
        self.batch_normalization_b = torch.nn.BatchNorm2d(num_features=out_channels)
        # Optional 1x1 projection that reshapes the input for the shortcut
        self.downsample = downsample

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        residual = input_tensor
        depthwise_result = self.batch_normalization_a(self.depthwise_conv_layer_a(input_tensor))
        pointwise_result = self.batch_normalization_point(self.pointwise_layer(depthwise_result))
        if self.downsample is not None:
            residual = self.downsample(input_tensor)
        # First shortcut: add the (projected) input to the pointwise output, then ReLU
        pointwise_result = F.relu(pointwise_result + residual)
        final_depthwise_result = self.batch_normalization_b(self.depthwise_conv_layer_b(pointwise_result))
        # Second shortcut: add the intermediate result and the input before the final ReLU
        return F.relu(final_depthwise_result + pointwise_result + residual)
```

The network itself was built in a modular fashion, starting with the initial convolutional layer. I chose to omit the max pooling layer because, according to the table, it does not change the spatial dimensions of the tensor. I saw little purpose in a pooling layer that pads its input only to pool it back down to the same size.

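```python
# Stem: 3x3 convolution with stride 2 (224x224 -> 112x112); no max pooling afterwards
self.starting_conv_layer = torch.nn.Sequential(
    torch.nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
    torch.nn.BatchNorm2d(32),
    torch.nn.ReLU()
)
```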
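The post does not give the exact pooling settings. Assuming a 3x3 window with stride 1 and padding 1, the pool leaves the stem's 112x112 output unchanged in shape, although it would still change the values. A strided pool would instead halve it. The sketch below compares both cases; `pool_out` is a hypothetical helper.

```python
def pool_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output height/width of a max pooling layer (floor mode, no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


stem_out = (224 + 2 * 1 - 3) // 2 + 1  # 3x3 stem conv, stride 2, padding 1
print(stem_out)                                           # 112
print(pool_out(stem_out, kernel=3, stride=1, padding=1))  # 112: shape unchanged
print(pool_out(stem_out, kernel=3, stride=2, padding=1))  # 56: a strided pool would halve it
```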

A helper method creates the requested building block along with any downsampling layer it needs. The network's structure is described by a list of dictionaries containing the building-block parameters, which is passed to the network's initialisation function.

```python
# Method of the network class: build one LB or DLB block
def _make_layer(self, in_channels: int, out_channels: int, kernel_size: int, padding: int | str, stride: int, layer_type: str) -> torch.nn.Module:
    if layer_type == "lb":
        layer = LB(in_channels, out_channels, kernel_size, padding, stride)
    else:
        # DLB blocks need a 1x1 projection on the shortcut whenever the shapes differ
        downsample_block = None
        if stride != 1 or in_channels != out_channels:
            downsample_block = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                torch.nn.BatchNorm2d(out_channels)
            )
        layer = DLB(in_channels, out_channels, kernel_size, padding, stride, downsample=downsample_block)
    return layer
```

As the table shows, each stage contains multiple blocks. I assumed that the first block of each stage halves the spatial dimensions and that the remaining blocks keep them unchanged. The first block is therefore given a stride of 2 and a padding of 1, and the rest a stride of 1 with "same" padding. The paper also does not state the kernel size, which I assumed to be 3.

```python
# In the network's __init__: build each stage from the block_info description
inter_block_list = []
for block_i in block_info:
    block_section = []
    for idx, layer_value in enumerate(block_i["layer_values"]):
        if idx != 0:
            # Remaining blocks keep the spatial size
            stride = 1
            padding = "same"
        else:
            # First block of each stage halves the spatial size
            padding = 1
            stride = 2
        block = self._make_layer(in_channels=layer_value["in_channels"], out_channels=layer_value["out_channels"], kernel_size=3, padding=padding, stride=stride, layer_type=block_i["block_type"])
        block_section.append(block)
    inter_block_list.append(torch.nn.Sequential(*block_section))

self.blocks = torch.nn.Sequential(*inter_block_list)

# Stage configuration for EtinyNet-1.0, passed to the network as block_info
etinynet_block_info = [
    {
        "block_type": "lb",
        "layer_values": [{"in_channels": 32, "out_channels": 32} for _ in range(4)]
    },
    {
        "block_type": "lb",
        "layer_values": [{"in_channels": 32, "out_channels": 128}] + [{"in_channels": 128, "out_channels": 128} for _ in range(3)]
    },
    {
        "block_type": "dlb",
        "layer_values": [{"in_channels": 128, "out_channels": 192}] + [{"in_channels": 192, "out_channels": 192} for _ in range(2)]
    },
    {
        "block_type": "dlb",
        "layer_values": [{"in_channels": 192, "out_channels": 256}, {"in_channels": 256, "out_channels": 256}, {"in_channels": 256, "out_channels": 512}]
    }
]
```
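The loop above turns `etinynet_block_info` into 14 blocks, with stride 2 only on the first block of each stage. The sketch below replays that logic in plain Python, without building any layers, to list the resulting plan. It also counts which DLB blocks receive a 1x1 projection shortcut. `build_plan` is a hypothetical helper, and the configuration is copied from the listing above.

```python
etinynet_block_info = [
    {"block_type": "lb", "layer_values": [{"in_channels": 32, "out_channels": 32} for _ in range(4)]},
    {"block_type": "lb", "layer_values": [{"in_channels": 32, "out_channels": 128}] + [{"in_channels": 128, "out_channels": 128} for _ in range(3)]},
    {"block_type": "dlb", "layer_values": [{"in_channels": 128, "out_channels": 192}] + [{"in_channels": 192, "out_channels": 192} for _ in range(2)]},
    {"block_type": "dlb", "layer_values": [{"in_channels": 192, "out_channels": 256}, {"in_channels": 256, "out_channels": 256}, {"in_channels": 256, "out_channels": 512}]},
]


def build_plan(block_info):
    """Return (block_type, in_channels, out_channels, stride, padding) for every block."""
    plan = []
    for stage in block_info:
        for idx, layer in enumerate(stage["layer_values"]):
            # Same rule as the network: only the first block of a stage downsamples.
            stride, padding = (2, 1) if idx == 0 else (1, "same")
            plan.append((stage["block_type"], layer["in_channels"], layer["out_channels"], stride, padding))
    return plan


plan = build_plan(etinynet_block_info)
for row in plan:
    print(row)
# DLB blocks whose shortcut needs the 1x1 projection (stride 2 or a channel change).
projections = [row for row in plan if row[0] == "dlb" and (row[3] != 1 or row[1] != row[2])]
print(len(plan))         # 14
print(len(projections))  # 3
```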

Finally, a global average pooling layer is applied before the fully connected classification layer.

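```python
# A 224x224 input reaches this point as a 7x7 map, so a 7x7 average pool acts globally
self.global_pool = torch.nn.AvgPool2d(kernel_size=7)
# 200 output neurons, one per Tiny ImageNet class
self.fully_connected = torch.nn.Linear(block_info[-1]["layer_values"][-1]["out_channels"], 200)
self.layers = torch.nn.Sequential(
    self.starting_conv_layer,
    self.blocks,
    self.global_pool,
    torch.nn.Flatten(),
    self.fully_connected
)
```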
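`AvgPool2d(kernel_size=7)` acts as a global pool only because a 224x224 input reaches this layer as a 7x7 map. At other input resolutions the pooled map is no longer 1x1, and the flattened vector no longer matches the 512 inputs of the fully connected layer. The sketch below checks this with the output-size formula, assuming the five stride-2 downsampling steps described above. `conv_out` and `flattened_features` are hypothetical helpers. `torch.nn.AdaptiveAvgPool2d(1)` is the usual resolution-independent alternative.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of a convolution or pooling layer with a square window."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(input_size: int, channels: int = 512) -> int:
    size = input_size
    # Stem plus the first block of each of the four stages: 3x3, stride 2, padding 1.
    for _ in range(5):
        size = conv_out(size, kernel=3, stride=2, padding=1)
    pooled = conv_out(size, kernel=7, stride=7)  # AvgPool2d(kernel_size=7); stride defaults to 7
    return channels * pooled * pooled


print(flattened_features(224))  # 512: matches the classifier's input size
print(flattened_features(448))  # 2048: would not match
```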

In terms of total parameters, the original EtinyNet has 980K. Our network, with a reduced fully connected layer (200 output neurons instead of 1000), has 770K. With 1000 output neurons, our network has 1.2M parameters. As stated above, we had to make several assumptions about the structure of the network, which may have added parameters that are not in the original. Overall, our network is comparable in size to the original EtinyNet.
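The difference between the 200-class and 1000-class variants comes almost entirely from the classifier. With 512 input features, a fully connected layer has 512 x classes weights plus one bias per class. The arithmetic below is a sketch; for a built model the total would normally come from `sum(p.numel() for p in model.parameters())`. It shows that the extra 800 output neurons add about 0.41M parameters, consistent with the jump from roughly 770K to 1.2M. `linear_params` is a hypothetical helper.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


fc_200 = linear_params(512, 200)
fc_1000 = linear_params(512, 1000)
print(fc_200, fc_1000, fc_1000 - fc_200)  # 102600 513000 410400
```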

Training

The model was trained following the original paper as closely as possible. The original paper randomly crops the input images to 224x224, whereas I simply resized them to 224x224. The images were then randomly flipped horizontally.

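```python
import torchvision

# mean and std are the per-channel normalisation statistics (defined elsewhere)
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean, std)
])
```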

A standard SGD optimizer was used with a starting learning rate of 0.1, momentum of 0.9 and weight decay (L2 regularisation) of 1e-4. The original paper used a cosine learning rate schedule, but I opted for a plateau scheduler instead. It multiplies the learning rate by 0.1, from 0.1 down to a minimum of 1e-7, whenever the validation loss fails to improve by more than 1e-4 for more than 3 consecutive epochs.

```python
# Only optimise parameters that require gradients
optimizer = torch.optim.SGD(filter(lambda param: param.requires_grad, network.parameters()), lr=0.1, momentum=0.9, weight_decay=1e-4)
# Multiply the learning rate by 0.1 when the validation loss plateaus, down to 1e-7
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, factor=0.1, mode='min', patience=3, min_lr=1e-7, threshold_mode='abs', threshold=1e-4)
```
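The plateau rule configured above can be sketched in plain Python. This is a simplified re-implementation for illustration (mode='min', absolute threshold, no cooldown), not PyTorch's own code. A loss counts as an improvement only if it beats the best value so far by more than `threshold`. Once the number of epochs without improvement exceeds `patience`, the learning rate is multiplied by `factor`, never dropping below `min_lr`. `PlateauSketch` is a hypothetical name, and the loss values are illustrative, not from the experiments.

```python
import math


class PlateauSketch:
    """Simplified ReduceLROnPlateau logic: mode='min', threshold_mode='abs', no cooldown."""

    def __init__(self, lr, factor=0.1, patience=3, min_lr=1e-7, threshold=1e-4):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.threshold = threshold
        self.best = math.inf
        self.num_bad_epochs = 0

    def step(self, val_loss):
        if val_loss < self.best - self.threshold:
            # Improvement: remember the new best and reset the counter.
            self.best = val_loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            # Plateau: decay the learning rate, clamped at the minimum.
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.num_bad_epochs = 0
        return self.lr


sketch = PlateauSketch(lr=0.1)
val_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9]  # illustrative values only
lrs = [sketch.step(loss) for loss in val_losses]
print([round(lr, 6) for lr in lrs])  # [0.1, 0.1, 0.1, 0.1, 0.1, 0.01]
```

Starting at 0.1, six such reductions reach the 1e-7 floor.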

Xu et al. [1] also used a batch size of 1024 across 8 GPUs. I did not have access to that hardware and instead trained on a single GTX 1060 (6 GB), which could fit a batch size of 128.

Another issue was the size of the training set. The original paper used ImageNet-1000, which is over 160 GB. This was not feasible for my trials, so I opted for TinyImageNet-200 from Stanford [5]. It is a much smaller dataset, with 200 classes instead of the 1000 classes of ImageNet-1000, and was easier to work with.

The paper does not mention early stopping. Early stopping tracks a metric, which in our case is the validation loss. When this metric does not improve for a set number of epochs, training is stopped to limit overfitting. I included it for this reason.

PyTorch does not provide early stopping natively, so I used the early-stopping-pytorch [7] repository. It is a solid implementation that I have used throughout my PyTorch deep-learning projects.

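```python
import numpy as np
import torch

# EarlyStopping is the class from the early-stopping-pytorch repository [7]
early_stop = EarlyStopping(patience=patience, verbose=True, path=path_to_model + '.pt')
...
# Once per epoch, after train_loss and valid_loss have been collected:
early_stop_loss = np.mean(valid_loss)
early_stop(early_stop_loss, model)

# counter is 0 right after the call only if this epoch's validation loss is the best so far,
# so save a full checkpoint (model, optimizer and scheduler state) for resuming training
if early_stop.counter == 0:
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'validation_loss': np.mean(valid_loss),
            'train_loss': np.mean(train_loss)
        },
        path_to_model + '_complete_collection.tar'
    )
    torch.save(model, path_to_model + '_full_model.pth')
```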
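The stopping rule described above is simple at its core. The class below is a simplified plain-Python sketch, not the early-stopping-pytorch implementation, which also saves a checkpoint whenever the validation loss improves. It counts consecutive epochs without improvement, resets the counter when the loss improves, and raises a flag once the counter reaches `patience`. `EarlyStoppingSketch` and its `delta` argument are hypothetical here, and the loss values are illustrative only.

```python
class EarlyStoppingSketch:
    """Track validation loss and flag when it has stopped improving."""

    def __init__(self, patience: int = 7, delta: float = 0.0):
        self.patience = patience
        self.delta = delta
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss: float) -> None:
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            # Improvement (or first epoch): remember the new best and reset the counter.
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


stopper = EarlyStoppingSketch(patience=3)
val_losses = [1.0, 0.8, 0.85, 0.9, 0.95, 0.7]  # illustrative values only
for epoch, loss in enumerate(val_losses):
    stopper(loss)
    if stopper.early_stop:
        print(f"stopping after epoch {epoch}")  # stopping after epoch 4
        break
```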

Results

The results are shown below. The validation accuracy does not track the training accuracy, which indicates substantial overfitting. At first I suspected a problem in my training process. However, the literature shows that TinyImageNet is notorious for overfitting: Barnes et al. [4] only managed 54% test accuracy. Moreover, the original EtinyNet [1] only reaches 66% test accuracy on ImageNet-1000.

EtinyNet-1.0 Results

Nevertheless, I tried to reduce the overfitting in order to increase the final test accuracy. To do this, I made three changes:

- applied a random crop to the input images and resized them back to 224x224;
- increased the probability of a random horizontal flip from 0.5 to 0.7;
- inserted a dropout layer before the final classification layer.

The dropout probability started at 0.1 and was increased in steps of 0.1 to investigate how stronger regularisation affects the test accuracy.

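```python
# mean and std are the per-channel normalisation statistics (defined elsewhere)
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    # Crop a random region covering 40-100% of the image, then resize it back to 224x224
    torchvision.transforms.RandomResizedCrop((224, 224), scale=(0.4, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(p=0.7),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean, std)
])

# In the network's __init__: dropout before the classifier (p was varied between runs)
self.layers = torch.nn.Sequential(
    self.starting_conv_layer,
    self.blocks,
    self.global_pool,
    torch.nn.Flatten(),
    torch.nn.Dropout(p=0.1),
    self.fully_connected
)
```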
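With `scale=(0.4, 1.0)`, RandomResizedCrop samples a crop covering between 40% and 100% of the area of the already-resized 224x224 image, then resizes it back to 224x224. For a square crop the side length is the square root of the area. The smallest crops are therefore about 142x142 pixels, and the network sees objects magnified by up to about 1.58x. This back-of-the-envelope sketch ignores the aspect-ratio jitter, which by default ranges from 3/4 to 4/3.

```python
import math

image_side = 224
min_scale, max_scale = 0.4, 1.0  # scale argument of RandomResizedCrop

min_area = min_scale * image_side ** 2
min_side = math.sqrt(min_area)    # side of the smallest square crop
max_zoom = image_side / min_side  # magnification after resizing back to 224
print(round(min_side, 1), round(max_zoom, 2))  # 141.7 1.58
```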

The results below show that increasing regularisation in the network reduces overfitting and increases the test accuracy. Our best result, with a dropout probability of 0.4, achieved 60% top-1 and 80% top-5 accuracy.

EtinyNet Results

The gap between the training and validation metrics is a common measure of overfitting. As the dropout probability increases, and with it the amount of regularisation, this gap shrinks. This indicates that the regularisation techniques are working as expected.

Dropout p=0.1
Dropout p=0.3
Dropout p=0.5

Compared with larger networks reported on TinyImageNet, our network performs favourably. For example, ResNet-18, with around 11 million parameters [3], achieved only 42% top-1 test accuracy [2].

Conclusions

In this blog, I detailed the implementation and training of EtinyNet-1.0 in PyTorch, including the design decisions and the assumptions made where the paper was silent. The network achieved results comparable to those of much larger models.

The full code can be found here: GitHub Link


Nathan Bailey

MSc AI and ML Student @ ICL. Ex ML Engineer @ Arm, Ex FPGA Engineer @ Arm + Intel, University of Warwick CSE Graduate, Climber. https://www.nathanbaileyw.com