
171_Image Classification using Torchvision(4)

elif 2024. 5. 20. 21:51

Continuing from the previous post (170_Image Classification using Torchvision(3)).

 

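import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def show_confusion_matrix(confusion_matrix, class_names):
    cm = confusion_matrix.copy()

    # Raw count in each cell, read row by row
    cell_counts = cm.flatten()

    # Divide each row by its total so every row sums to 1
    cm_row_norm = cm / cm.sum(axis=1)[:, np.newaxis]

    # Row-wise proportions formatted to two decimal places
    row_percentages = ["{0:.2f}".format(value) for value in cm_row_norm.flatten()]

    # Each annotation shows the count on the first line and the proportion on the second
    cell_labels = [f"{cnt}\n{per}" for cnt, per in zip(cell_counts, row_percentages)]
    cell_labels = np.asarray(cell_labels).reshape(cm.shape[0], cm.shape[1])

    df_cm = pd.DataFrame(cm_row_norm, index=class_names, columns=class_names)

    # Color by proportion; fmt="" makes seaborn print the prepared strings as-is
    hmap = sns.heatmap(df_cm, annot=cell_labels, fmt="", cmap="Blues")
    hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right')
    hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right')
    plt.ylabel('True Sign')
    plt.xlabel('Predicted Sign')


# y_test, y_pred and class_names come from the previous posts
cm = confusion_matrix(y_test, y_pred)
show_confusion_matrix(cm, class_names)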

 

 

The code above defines a function that visualizes the confusion matrix and then uses it to evaluate the model's predictions on the test set. The function first flattens the confusion matrix into a 1D array to get the raw count in each cell. It then normalizes each row by dividing it by that row's sum, which turns the counts into the proportion of each true class assigned to each predicted class, and formats these proportions to two decimal places. The count and the proportion for each cell are combined into a single label, and the label array is reshaped back to the shape of the confusion matrix. Finally, the normalized matrix is converted to a Pandas DataFrame with the class names as row and column labels, and Seaborn draws it as a heatmap annotated with those labels.
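To make the row normalization and the cell labels concrete without plotting anything, here is a small NumPy-only sketch on a made-up 3-class matrix (rows are true classes, columns are predicted classes). The helper names `row_normalize` and `make_cell_labels` are hypothetical. Unlike `show_confusion_matrix`, this version also guards against a row that sums to zero (a class with no test samples), which would otherwise produce NaN from a 0/0 division; treat that guard as an optional assumption, not part of the original code.

```python
import numpy as np

def row_normalize(cm):
    """Divide each row of a confusion matrix by its row sum.

    Rows that sum to zero (a class with no true samples) are left as zeros
    instead of producing NaN from a division by zero.
    """
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)  # shape (n_classes, 1), broadcasts over columns
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

def make_cell_labels(cm):
    """Build one 'count + newline + proportion' label per cell, shaped like cm."""
    cm = np.asarray(cm)
    norm = row_normalize(cm)
    labels = [f"{cnt}\n{per:.2f}" for cnt, per in zip(cm.flatten(), norm.flatten())]
    return np.asarray(labels).reshape(cm.shape)

# Toy confusion matrix (made-up numbers); the third class has no samples at all
toy_cm = np.array([[8, 2, 0],
                   [1, 9, 0],
                   [0, 0, 0]])
print(row_normalize(toy_cm))
print(make_cell_labels(toy_cm)[0, 0])
```

Reading the first row of the normalized matrix: 80% of the true class-0 samples were predicted as class 0 and 20% as class 1, which is exactly what the heatmap colors encode.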

 

Next, let's see how the model classifies new data that is not part of the dataset. The example image used here, stop-sign.jpg, is as follows.

 

 

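import torch
import torch.nn.functional as F
from PIL import Image

def predict_proba(model, image_path):
    # Load the image and force 3 RGB channels (e.g. drop an alpha channel)
    img = Image.open(image_path)
    img = img.convert('RGB')
    # Apply the test-time preprocessing and add a batch dimension
    img = transforms['test'](img).unsqueeze(0)

    # Evaluation mode and no gradient tracking for inference
    model.eval()
    with torch.no_grad():
        pred = model(img.to(device))
    # Convert the logits to class probabilities
    pred = F.softmax(pred, dim=1)
    return pred.cpu().numpy().flatten()


# base_model, transforms and device are defined in the previous posts
pred = predict_proba(base_model, 'stop-sign.jpg')
pred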
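`predict_proba` turns the model's raw outputs (logits) into probabilities with `F.softmax(pred, dim=1)`. To illustrate what that step computes, here is a NumPy sketch of the same operation; the `softmax` helper and the logit values are made up for illustration and are not part of the original code.

```python
import numpy as np

def softmax(logits, axis=1):
    """Numerically stable softmax along `axis`, like F.softmax(x, dim=axis)."""
    logits = np.asarray(logits, dtype=float)
    # Subtracting the max does not change the result but prevents overflow in exp
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# Made-up logits for a batch of one image over 4 classes
logits = np.array([[2.0, 0.5, -1.0, 0.1]])
probs = softmax(logits, axis=1).flatten()  # flatten drops the batch dimension
print(probs.round(3), probs.sum())
print(int(probs.argmax()))  # index of the most likely class
```

The probabilities are all positive and sum to 1, and the ordering of the logits is preserved, so the class with the largest logit is also the predicted class.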

 

 

To make the result easier to read, the predicted probabilities can be plotted as a horizontal bar chart, as follows.

 

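import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

def show_prediction_confidence(prediction, class_names):
    # One row per class: its name and its predicted probability
    pred_df = pd.DataFrame({
        'class_names': class_names,
        'values': prediction
    })
    sns.barplot(x='values', y='class_names', data=pred_df, orient='h')
    # Probabilities lie in [0, 1]
    plt.xlim([0, 1])

show_prediction_confidence(pred, class_names)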
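Besides the bar chart, the most likely classes can be read directly from the probability vector returned by `predict_proba`. Below is a minimal sketch; the `top_k_predictions` helper, the class names, and the probabilities are hypothetical placeholders, not values from the post.

```python
import numpy as np

def top_k_predictions(probs, class_names, k=3):
    """Return the k most likely (class_name, probability) pairs, best first."""
    probs = np.asarray(probs, dtype=float)
    # argsort is ascending, so reverse it and keep the first k indices
    order = np.argsort(probs)[::-1][:k]
    return [(class_names[i], float(probs[i])) for i in order]

# Hypothetical class names and probabilities, for illustration only
names = ['class_a', 'class_b', 'class_c', 'class_d']
probs = np.array([0.05, 0.10, 0.80, 0.05])
print(top_k_predictions(probs, names, k=2))
```

In the post's setting this would be called as `top_k_predictions(pred, class_names)`, with `pred` and `class_names` from the code above.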

 

 

The model correctly predicted the stop sign, assigning the highest probability to that class. So on this new image, which was not part of the dataset, the model performed well. Of course, a single image is not enough to judge how well the model generalizes, and more examples should be examined, but for this blog post it should suffice. Detailed information and the full code can be found in the reference below.

 

ref: Venelin Valkov, Get SH*T Done with PyTorch: Solve Real-world Machine Learning Problems with Deep Neural Networks in Python (2020)