A neural network for elastography (part 2)
Let's train a neural network to solve a (slightly more realistic) typical problem encountered in elastography.
In part 1, we talked about elastography, and we created a model that predicts the stiffness from a given displacement field. That displacement field had no noise, so here we add some to make it look more realistic. Our goal is to beat the part 1 model despite the noise.
The Introduction and Data Preparation sections are almost identical to part 1, so we won't comment much on them. For details, check part 1.
As mentioned in part 1, it is easy to run all the code we show. Just click on the "Open in Colab" button at the top of this page. It will open this page as a Colab notebook.
First, if you're using Colab, you need to install PyTorch and fastai.
!pip install torch
!pip install fastai --upgrade -q
Then, we download and unzip the data we're going to use.
import os.path
if not os.path.exists('MNIST_input_files.zip'):
    !wget https://open.bu.edu/bitstream/handle/2144/38693/MNIST_input_files.zip
if not os.path.exists('FEA_displacement_results_step5.zip'):
    !wget https://open.bu.edu/bitstream/handle/2144/38693/FEA_displacement_results_step5.zip
!unzip -n "MNIST_input_files.zip" -d "."
!unzip -n "FEA_displacement_results_step5.zip" -d "uniaxial/"
We import some libraries.
import pandas as pd
import torch
from fastai.vision.all import *
We load the data from text files into PyTorch tensors.
df = pd.read_csv('MNIST_input_files/mnist_img_train.txt', sep=' ', header=None)
imgs_train = torch.tensor(df.values)
df = pd.read_csv('MNIST_input_files/mnist_img_test.txt', sep=' ', header=None)
imgs_valid = torch.tensor(df.values)
imgs_train.shape, imgs_valid.shape
df = pd.read_csv('uniaxial/FEA_displacement_results_step5/summary_dispx_train_step5.txt', sep=' ', header=None)
ux5s_train = torch.tensor(df.values)
df = pd.read_csv('uniaxial/FEA_displacement_results_step5/summary_dispy_train_step5.txt', sep=' ', header=None)
uy5s_train = torch.tensor(df.values)
df = pd.read_csv('uniaxial/FEA_displacement_results_step5/summary_dispx_test_step5.txt', sep=' ', header=None)
ux5s_valid = torch.tensor(df.values)
df = pd.read_csv('uniaxial/FEA_displacement_results_step5/summary_dispy_test_step5.txt', sep=' ', header=None)
uy5s_valid = torch.tensor(df.values)
ux5s_train.shape, uy5s_train.shape, ux5s_valid.shape, uy5s_valid.shape
We reshape the data.
ux5s_train = torch.reshape(ux5s_train, (-1,1,28,28))
uy5s_train = torch.reshape(uy5s_train, (-1,1,28,28))
ux5s_valid = torch.reshape(ux5s_valid, (-1,1,28,28))
uy5s_valid = torch.reshape(uy5s_valid, (-1,1,28,28))
imgs_train = torch.reshape(imgs_train, (-1,1,28,28))
imgs_valid = torch.reshape(imgs_valid, (-1,1,28,28))
ux5s_train.shape, uy5s_train.shape, ux5s_valid.shape, uy5s_valid.shape, imgs_train.shape, imgs_valid.shape
The displacement field is upside down relative to the stiffness image, so we flip it.
show_image(imgs_train[2], cmap='Greys')
show_image(uy5s_train[2], cmap='Greys')
ux5s_train = torch.flip(ux5s_train, dims=[1,2])
uy5s_train = torch.flip(uy5s_train, dims=[1,2])
ux5s_valid = torch.flip(ux5s_valid, dims=[1,2])
uy5s_valid = torch.flip(uy5s_valid, dims=[1,2])
show_image(imgs_train[2], cmap='Greys')
show_image(uy5s_train[2], cmap='Greys')
We convert from pixel values to stiffness.
imgs_train = (imgs_train / 255.) * 99. + 1.
imgs_valid = (imgs_valid / 255.) * 99. + 1.
imgs_train = imgs_train / 100.
imgs_valid = imgs_valid / 100.
imgs_train.min(), imgs_valid.min(), imgs_train.max(), imgs_valid.max()
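Taken together, the two steps above form one affine map from pixel intensity (0 to 255) to a normalized stiffness (0.01 to 1.0). The sketch below restates that map in plain Python, together with its inverse, which is the same expression used later for the error metrics. The helper names `pixel_to_stiffness` and `stiffness_to_pixel` are hypothetical and do not appear in the notebook.

```python
def pixel_to_stiffness(p):
    """Map a pixel intensity in [0, 255] to a normalized stiffness in [0.01, 1.0]."""
    return ((p / 255.0) * 99.0 + 1.0) / 100.0


def stiffness_to_pixel(s):
    """Inverse map: normalized stiffness back to a pixel intensity."""
    return ((s * 100.0) - 1.0) * (255.0 / 99.0)


# Background pixels (0) become the softest material, white pixels (255) the stiffest.
softest = pixel_to_stiffness(0)
stiffest = pixel_to_stiffness(255)
round_trip = stiffness_to_pixel(pixel_to_stiffness(128))
print(softest, stiffest, round_trip)
```

The stiffest material is therefore 100 times stiffer than the background, and every target value lies at or below 1.0.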
We put horizontal and vertical displacements into the same tensor.
us_train = torch.cat([ux5s_train, uy5s_train], dim=1)
us_valid = torch.cat([ux5s_valid, uy5s_valid], dim=1)
us_train.shape, us_valid.shape
We do some cleaning.
del ux5s_train; del uy5s_train; del ux5s_valid; del uy5s_valid
Okay, now we have some new stuff.
This fastai Transform is the same as in part 1, except that it adds noise. The noise is drawn from a normal distribution with mean 0 and standard deviation 0.1, and it is added to the normalized displacements. We will see that the noise makes the task more challenging.
class GetNormalizedData(Transform):
    def __init__(self, us, imgs, mean, std, noise_std):
        self.us, self.imgs = us, imgs
        self.mean, self.std = mean, std
        self.noise_std = noise_std
    def encodes(self, i):
        us_norm = torch.true_divide((self.us[i] - self.mean.view(2,1,1)), self.std.view(2,1,1))
        us_with_noise = us_norm + torch.randn(2,28,28)*self.noise_std
        return (us_with_noise.float(), self.imgs[i].float())
We calculate the mean and standard deviation from our dataset so we can normalize the data.
us_mean = torch.mean(us_train, dim=[0,2,3])
us_std = torch.std(us_train, dim=[0,2,3])
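Because the noise is added after normalization, the noise level can be read relative to a signal whose per-channel standard deviation is roughly 1 over the training set. The rough calculation below assumes exactly unit signal standard deviation and zero mean; this holds only on average across the dataset, and individual images will differ. Under that assumption, it turns the noise standard deviation of 0.1 into a signal-to-noise ratio.

```python
import math

signal_std = 1.0  # assumption: per-channel std after normalization
noise_std = 0.1   # value used in this notebook

# Power ratio between signal and noise, and the same ratio in decibels.
snr_power = (signal_std / noise_std) ** 2
snr_db = 10.0 * math.log10(snr_power)
print(snr_power, snr_db)
```

Under these assumptions the noise amplitude is about 10% of the typical signal spread, or roughly 20 dB.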
We create fastai TfmdLists and DataLoaders.
noise_std = 0.1
train_tl = TfmdLists(range(len(us_train[:])),
                     GetNormalizedData(us_train[:], imgs_train[:], us_mean, us_std, noise_std)
                     )
valid_tl = TfmdLists(range(len(us_valid[:])),
                     GetNormalizedData(us_valid[:], imgs_valid[:], us_mean, us_std, noise_std)
                     )
dls = DataLoaders.from_dsets(train_tl, valid_tl, bs=64)
if torch.cuda.is_available(): dls = dls.cuda()
Okay, let's take a look at our batch.
x,y = dls.one_batch()
x.shape, y.shape, x.mean(), x.std()
show_image(x[0][1], cmap='Greys')
show_image(y[0], cmap='Greys')
show_image(x[1][1], cmap='Greys')
show_image(y[1], cmap='Greys')
show_image(x[2][1], cmap='Greys')
show_image(y[2], cmap='Greys')
Okay, a standard deviation of 0.1 seems to give a challenging yet feasible task. In the images above we can still identify the digits in the displacement field, but the details are hard to see.
Remember, our goal is to beat the model created in part 1. That model got a validation loss of 0.0065, so that's the number to beat. The task now is definitely more difficult, so we will use a deeper neural network.
We will not train it all at once (we probably could, try it). We begin by training the same model we used in part 1. Then, we remove the output layer and add a new hidden layer and a new output layer, and we train the whole model. After that, we repeat the process of removing/adding layers and training.
We create a class Base, which will be the input/hidden layer, and a class Head, which will be the output layer. We multiply the activations by 1.1 after the sigmoid, so the model returns values between 0 and 1.1.
class Base(nn.Module):
    def __init__(self, n_in, n_out):
        super(Base, self).__init__()
        self.conv = nn.Conv2d(n_in, n_out, kernel_size=5, stride=1, padding=2, padding_mode='reflect', bias=True)
        self.relu = nn.ReLU()
    def forward(self, x):
        return self.relu(self.conv(x))
class Head(nn.Module):
    def __init__(self, n_in, n_out, y_range):
        super(Head, self).__init__()
        self.y_range = y_range
        self.conv = nn.Conv2d(n_in, n_out, 5, 1, 2, padding_mode='reflect', bias=True)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        return self.sigmoid(self.conv(x))*self.y_range
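One likely reason for scaling the sigmoid to 1.1 rather than 1.0 is that the largest target value is 1.0, which a plain sigmoid only approaches asymptotically. This is an interpretation; the text above does not state the motivation. The sketch below, in plain Python with hypothetical helpers `head_output` and `logit_for`, computes the pre-activation the Head's convolution must produce to hit the extreme targets.

```python
import math


def head_output(z, y_range=1.1):
    """Scaled sigmoid, mirroring Head.forward: sigmoid(z) * y_range."""
    return y_range / (1.0 + math.exp(-z))


def logit_for(target, y_range=1.1):
    """Pre-activation z for which head_output(z, y_range) equals target."""
    ratio = target / y_range
    return math.log(ratio / (1.0 - ratio))


z_max = logit_for(1.0)   # stiffest material (normalized stiffness 1.0)
z_min = logit_for(0.01)  # softest material (normalized stiffness 0.01)
print(z_max, z_min)
```

With a range of 1.1, a target of 1.0 is reached at a finite pre-activation of ln(10), about 2.3, so the network does not need to push its outputs toward infinity to fit the stiffest pixels.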
We begin with the same architecture used in part 1.
cnn1 = nn.Sequential(Base(2, 4), Head(4, 1, 1.1))
l1 = Learner(dls, cnn1, loss_func=F.mse_loss, model_dir='')
We find a good candidate for the learning rate, and then we train the model.
l1.lr_find()
l1.fit_one_cycle(10,lr_max=1e-2)
We can save it.
l1.path = Path('')
l1.save('02_cnn1_ep10')
Let's see how our model is doing.
preds, targets, losses = l1.get_preds(with_loss=True)
_, indices = torch.sort(losses, descending=True)
for i in indices[:3]:
    show_image(torch.cat([preds, torch.ones(preds.shape[0],1,28,1), targets], dim=3)[i], cmap='Greys_r')
I_preds = ((preds*100)-1)*(255/99)
I_targets = ((targets*100)-1)*(255/99)
I_error = torch.abs(I_preds - I_targets)
I_mean = torch.mean(I_error)
I_std = torch.std(I_error)
I_mean, I_std
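The metric above converts predictions and targets back to pixel intensities before taking the absolute difference. Since the conversion is affine, the offset cancels and the pixel error is just the stiffness error times a constant. The sketch below checks this on two made-up values (0.80 and 0.55 are illustrative, not taken from the data) and uses the same constant to express the part 1 validation loss of 0.0065 as a root-mean-square error in pixel units. That last step assumes the loss is a plain mean squared error over the same pixels.

```python
import math

SCALE = 100.0 * 255.0 / 99.0  # pixel units per unit of normalized stiffness


def pixel_error(pred, target):
    """Absolute error in pixel units, computed the long way as in the notebook."""
    def to_pixel(s):
        return ((s * 100.0) - 1.0) * (255.0 / 99.0)
    return abs(to_pixel(pred) - to_pixel(target))


# The -1 offset cancels in the difference, so both forms agree.
long_form = pixel_error(0.80, 0.55)
short_form = abs(0.80 - 0.55) * SCALE

# An MSE in stiffness units corresponds to an RMS error in pixel units.
rms_pixels = math.sqrt(0.0065) * SCALE
print(long_form, short_form, rms_pixels)
```

Under that assumption, a validation loss of 0.0065 corresponds to an RMS error of a little under 21 on the 0 to 255 scale.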
The mean and standard deviation of the pixel error are 14.6 and 30.9, respectively. It does not seem terrible, but we are still far from our goal. Let's add more layers.
We take the first layer of our previous model and add two new layers.
cnn2 = nn.Sequential(l1.model[0], Base(4, 8), Head(8, 1, 1.1))
l2 = Learner(dls, cnn2, loss_func=F.mse_loss, model_dir='')
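The remove-the-head, add-two-layers step is repeated for every stage, so it can be wrapped in a small helper. The sketch below is one possible way to do that in plain PyTorch; the function `grow` is hypothetical and not part of the notebook, and `Base` and `Head` are redefined here only to keep the example self-contained.

```python
import torch
from torch import nn


class Base(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Conv2d(n_in, n_out, 5, 1, 2, padding_mode='reflect')
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


class Head(nn.Module):
    def __init__(self, n_in, n_out, y_range):
        super().__init__()
        self.y_range = y_range
        self.conv = nn.Conv2d(n_in, n_out, 5, 1, 2, padding_mode='reflect')

    def forward(self, x):
        return torch.sigmoid(self.conv(x)) * self.y_range


def grow(model, n_hidden, y_range=1.1):
    """Drop the old Head, keep the trained Base layers, append a new Base and Head."""
    kept = list(model.children())[:-1]  # the same module objects, not copies
    n_prev = kept[-1].conv.out_channels
    return nn.Sequential(*kept, Base(n_prev, n_hidden), Head(n_hidden, 1, y_range))


cnn1 = nn.Sequential(Base(2, 4), Head(4, 1, 1.1))
cnn2 = grow(cnn1, 8)
cnn3 = grow(cnn2, 16)
out = cnn3(torch.randn(3, 2, 28, 28))
print(out.shape)
```

Note that the kept layers are shared rather than copied, both here and in the notebook's `nn.Sequential(l1.model[0], ...)`. Training the larger model therefore also updates the weights seen by the previous one.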
We find a good learning rate and train our model.
l2.lr_find()
l2.fit_one_cycle(5, lr_max=3e-3)
l2.fit_one_cycle(5, lr_max=3e-3)
l2.path = Path('')
l2.save('models/02_cnn2_ep10')
Let's see if the model is improving.
preds, targets, losses = l2.get_preds(with_loss=True)
_, indices = torch.sort(losses, descending=True)
for i in indices[:3]:
    show_image(torch.cat([preds, torch.ones(preds.shape[0],1,28,1), targets], dim=3)[i], cmap='Greys_r')
I_preds = ((preds*100)-1)*(255/99)
I_targets = ((targets*100)-1)*(255/99)
I_error = torch.abs(I_preds - I_targets)
I_mean = torch.mean(I_error)
I_std = torch.std(I_error)
I_mean, I_std
Nice, the mean and the standard deviation decreased. Let's add a new layer.
Again, we keep everything but the last layer and add two new layers (Base and Head).
cnn3 = nn.Sequential(l2.model[0], l2.model[1], Base(8, 16), Head(16, 1, 1.1))
l3 = Learner(dls, cnn3, loss_func=F.mse_loss, model_dir='')
We find a good learning rate and train the model.
l3.lr_find()
l3.fit_one_cycle(5, lr_max=1e-2)
l3.fit_one_cycle(5, lr_max=1e-2)
l3.path = Path('')
l3.save('models/02_cnn3_ep10')
Nice, validation loss is under 0.01. Let's add more layers.
cnn4 = nn.Sequential(l3.model[0], l3.model[1], l3.model[2], Base(16, 32), Head(32, 1, 1.1))
l4 = Learner(dls, cnn4, loss_func=F.mse_loss, model_dir='')
l4.lr_find()
l4.fit_one_cycle(5, lr_max=3e-3)
l4.fit_one_cycle(5, lr_max=3e-3)
The loss increased during the first epochs. This could be because the learning rate is too high for the first layers, which are already trained. Next time we'll try something different.
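One option for this situation is to give the already-trained layers a smaller learning rate than the freshly added ones. This is an assumption about what could be tried, not what the notebook does next. The sketch below shows the idea with plain PyTorch parameter groups on a small stand-in model; the layer sizes and the two learning rates are illustrative only.

```python
import torch
from torch import nn

# Stand-in model: the first conv plays the role of the already-trained layers,
# the last conv plays the role of the freshly added ones.
model = nn.Sequential(
    nn.Conv2d(2, 4, 5, padding=2),
    nn.ReLU(),
    nn.Conv2d(4, 1, 5, padding=2),
)
pretrained = list(model[0].parameters())
fresh = list(model[2].parameters())

# One parameter group per learning rate.
opt = torch.optim.Adam([
    {"params": pretrained, "lr": 3e-4},  # gentler updates for trained weights
    {"params": fresh, "lr": 3e-3},       # larger steps for new weights
])
group_lrs = [g["lr"] for g in opt.param_groups]
print(group_lrs)
```

In fastai, the same idea is usually expressed by giving the Learner a splitter that defines the parameter groups and passing a slice as `lr_max` to `fit_one_cycle`.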
l4.path = Path('')
l4.save('models/02_cnn4_ep10')
Let's add more layers.
cnn5 = nn.Sequential(*l4.model[0:4], Base(32, 64), Head(64, 1, 1.1))
l5 = Learner(dls, cnn5, loss_func=F.mse_loss, model_dir='')
l5.lr_find()
l5.fit_one_cycle(5, lr_max=3e-3)
This time we run lr_find() again to train with a different learning rate.
l5.lr_find()
l5.fit_one_cycle(5, lr_max=1e-5)
l5.path = Path('')
l5.save('models/02_cnn5_std=0.1_ep10')
Validation loss was 0.0065 in part 1, so we're very close. Let's add more layers.
cnn6 = nn.Sequential(*l5.model[0:5], Base(64, 128), Head(128, 1, 1.1))
l6 = Learner(dls, cnn6, loss_func=F.mse_loss, model_dir='')
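A plausible reason why depth helps here is the receptive field. This is an interpretation, not a claim made above. Each 5x5, stride-1 convolution widens the receptive field by 4 pixels per side, so more layers let each output pixel draw on a larger part of the displacement field. The sketch below computes the receptive field and the parameter count for cnn6 in plain Python, with channel sizes read from the model definitions above; the helper names are hypothetical.

```python
def conv_params(n_in, n_out, k=5):
    """Weights plus biases of one k x k convolution."""
    return n_in * n_out * k * k + n_out


def receptive_field(n_convs, k=5):
    """Receptive field (pixels per side) of n stacked stride-1 k x k convolutions."""
    return 1 + n_convs * (k - 1)


channels = [2, 4, 8, 16, 32, 64, 128, 1]  # cnn6: six Base layers, then Head
layers = list(zip(channels[:-1], channels[1:]))
total_params = sum(conv_params(a, b) for a, b in layers)

rf_cnn1 = receptive_field(2)            # Base + Head
rf_cnn6 = receptive_field(len(layers))  # six Base layers + Head
print(rf_cnn1, rf_cnn6, total_params)
# prints: 9 29 276453
```

So cnn1 sees a 9x9 neighbourhood per output pixel, while cnn6 sees 29x29, which is slightly larger than the whole 28x28 image.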
l6.lr_find()
l6.fit_one_cycle(5, lr_max=3e-3)
l6.lr_find()
l6.fit_one_cycle(5, lr_max=1e-5)
l6.path = Path('')
l6.save('models/02_cnn6_ep10')
We did it! Our validation loss is less than 0.0065.
Let's now check how good the predictions are.
inputs, preds, targets, losses = l6.get_preds(with_input=True, with_loss=True)
On the left, we show the ten predictions with the highest losses. On the right, we show the corresponding targets.
_, indices = torch.sort(losses, descending=True)
for i in indices[:10]:
    show_image(torch.cat([preds, torch.ones(preds.shape[0],1,28,1), targets], dim=3)[i], cmap='Greys_r')
It seems nice. Even the worst predictions capture some details of the handwritten digits.
Let's calculate the same metrics we used in part 1.
I_preds = ((preds*100)-1)*(255/99)
I_targets = ((targets*100)-1)*(255/99)
I_error = torch.abs(I_preds - I_targets)
I_mean = torch.mean(I_error)
I_std = torch.std(I_error)
I_mean, I_std
In part 1, the mean was 8.5 and the standard deviation was 19. Our new model got a mean of 6.7 and a standard deviation of 18.9.
So we did it! We beat the model created in part 1, even though here we added noise to the data, while in part 1 we did not. Of course, we could keep improving the model. For instance, we could tweak the hyperparameters, include data from other experiments (such as shear or confined compression), add more layers, etc. However, this is left as an exercise for the reader 😆.
To recap, in this notebook, we've created a model that, given a noisy two-dimensional displacement field, predicts the stiffness throughout a body that has a hard inclusion. Below, we show one example of input (noisy displacement field), prediction (predicted stiffness), and target (correct stiffness). We see that our model has determined the shape of the hard inclusion with good resolution, despite the noise in the displacement field.
i = 123
show_image(inputs[i][0], cmap='Greys_r', title='Input: Horizontal displacement')
show_image(inputs[i][1], cmap='Greys_r', title=' Input: Vertical displacement ')
show_image(preds[i], cmap='Greys_r', title=' Output: Predicted stiffness ')
show_image(targets[i], cmap='Greys_r', title=' Target: Correct stiffness ')