# How You Structure Information Matters
_This example has been adapted from the [SciKit-Learn manual](https://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html)._

For many types of machine learning problems, the way you structure (model) the data is important. This example shows how the structure of image data can be modified so that classical machine learning models can be used to predict the contents of the image.

#### MNIST
MNIST is an image dataset often used to test computer vision algorithms. It is composed of handwritten digits with a training set of 60,000 examples and a test set of 10,000 examples. [You can download the entire dataset online](http://yann.lecun.com/exdb/mnist/).

It is considered the "Hello World" of computer vision. SciKit-Learn distributes a smaller dataset in the same spirit: 1,797 handwritten digits stored as 8x8 pixel images, available through `load_digits`. This notebook uses that bundled dataset and shows how to structure the data and submit it, together with its labels, to an SVM to determine the digit.

### Import Dependencies and Configure Visualization

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# sklearn includes a complement of datasets which can be used
# to explore different types of machine learning examples
from sklearn import datasets

# SVM: Support vector machine
from sklearn import svm

# Ensemble/RandomForestClassifier
from sklearn.ensemble import RandomForestClassifier

# sklearn.metrics helps to quantify the quality of predictions
# See https://scikit-learn.org/stable/modules/model_evaluation.html
from sklearn import metrics
from sklearn.model_selection import train_test_split

### Load Data from SciKit-Learn
_Dataset is structured as a multi-dimensional array._

The data is structured as a set of 8x8 pixel images layered on top of one another.

In [None]:
# The digits dataset bundled with SciKit-Learn is a set of hand-drawn
# numbers encoded in 8x8 pixel images. It is a smaller relative of MNIST,
# the "Hello World" of computer vision.
digits = datasets.load_digits()
print('Dataset data type: %s' % type(digits.images))
print("Dataset structure. Count: %d. Height: %d. Width: %d." % digits.images.shape)

#### Example Images

Combine the images and their labels into a single structure for visualization.

* `zip` pairs up the elements of two separate iterables, yielding one tuple per pair.

In [None]:
# Pull the first ten images and their associated targets
example_images = tuple(zip(digits.images[:10], digits.target))
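
The `zip` call above receives ten images but the full list of targets. That works because `zip` stops as soon as its shortest input is exhausted. A minimal sketch of that behavior, using hypothetical placeholder strings in place of image arrays:

```python
# Hypothetical stand-ins for image arrays and their labels
images = ["img_a", "img_b", "img_c"]
labels = [0, 1, 2, 3, 4]  # deliberately longer than images

# zip stops at the shortest input, so only three pairs are produced
pairs = tuple(zip(images, labels))
print(pairs)  # (('img_a', 0), ('img_b', 1), ('img_c', 2))
```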

In [None]:
# Create a figure showing the images
for i, (idata, ilabel) in enumerate(example_images):
    sub = plt.subplot(3, 4, i+1)
    sub.imshow(idata, cmap='Greys')

## Create Classification Models to Predict Digits

#### Re-shape Data for ML Input
To apply a classifier to the image, the data model needs to be "flattened" into a single dimension that can be fed to the classifier.

* In the example below, we use the `reshape` method to transform the input array.
  - The resulting array has one row per sample (the length of the input array) and one column per pixel, with each 8x8 image flattened into a row of 64 values.
  - Passing `-1` as the second dimension tells NumPy to infer that dimension's length from the remaining dimensions in the structure. [See the NumPy documentation for details](https://docs.scipy.org/doc/numpy/reference/generated/numpy.reshape.html).
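
To see what the flattening does on something small enough to read, here is a minimal sketch that uses a hypothetical stack of three 2x2 "images" rather than the digits data. The names `stack` and `flat` are illustrative only.

```python
import numpy as np

# A hypothetical stack of 3 "images", each 2x2 pixels
stack = np.arange(12).reshape((3, 2, 2))

# Keep the sample axis and let NumPy infer the flattened length (2 * 2 = 4)
flat = stack.reshape((len(stack), -1))

print(stack.shape)  # (3, 2, 2)
print(flat.shape)   # (3, 4)
print(flat[0])      # [0 1 2 3]
```

Each row of `flat` holds the pixels of one image read row by row, which is the same layout the classifier receives for the 8x8 digits.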

In [None]:
# Flatten the structure of the array using reshape
# -1 tells NumPy to determine the row size from the
# length of the inner dimensions.
mnist_mldata = digits.images.reshape((len(digits.images), -1))
print('Two-dimensional structure dimensions', mnist_mldata.shape) 

#### Create Machine Learning Model: Support Vector Classifier (`SVC`)
1. Split data into training/testing sets
2. Initialize machine learning algorithm
3. Train model instance on data: `fit`
4. Assess model performance: `score`

_Support Vector Machines use a linear model for classification problems and a "kernel" to help fit the data to the outcome classes. For classification, the SVM attempts to find a "line" (or a hyperplane for higher-dimensional datasets) that separates the data into classes matching the training labels. See [Support Vector Machines: An Overview](https://towardsdatascience.com/https-medium-com-pupalerushikesh-svm-f4b42800e989) for details._
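
The separating "line" is easiest to picture in two dimensions. The sketch below is an illustration on hypothetical toy data, not part of the digits workflow: it fits an `SVC` with `kernel="linear"` to two well-separated clusters and prints the coefficients of the boundary it finds.

```python
import numpy as np
from sklearn import svm

# Hypothetical toy data: two well-separated clusters in 2-D
X = np.array([[0, 0], [0, 1], [1, 0],
              [3, 3], [3, 4], [4, 3]], dtype=float)
y = np.array([0, 0, 0, 1, 1, 1])

# A linear kernel makes the decision boundary a straight line
clf = svm.SVC(kernel="linear").fit(X, y)

# coef_ and intercept_ describe the line w . x + b = 0
print(clf.coef_, clf.intercept_)

# Points near each cluster fall on the expected side of the line
print(clf.predict([[0.5, 0.5], [3.5, 3.5]]))  # [0 1]
```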

In [None]:
# Step 1: Split data into training/testing sets
X_train, X_test, y_train, y_test = train_test_split(
    mnist_mldata, digits.target, test_size=0.3, shuffle=False)

_**Note**: When creating the testing/training splits, we did not shuffle the data. This preserves the original ordering so that the matching images can be retrieved later for visualization._
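
The effect of `shuffle=False` can be checked on a small array. This is a minimal sketch with hypothetical data (ten numbered rows): without shuffling, the test set is simply the tail of the input, so slicing the original array at `len(X_train)` recovers the same rows.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical data: ten samples whose value equals their position
X = np.arange(10).reshape((10, 1))
y = np.arange(10)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, shuffle=False)

# With shuffle=False the order is preserved: training rows come first
# and the test rows are the remaining rows at the end of the array
print(y_train)
print(y_test)
print(np.array_equal(X_test, X[len(X_train):]))
```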

In [None]:
# Step 2: Initialize machine learning algorithm
mnist_svm = svm.SVC()

In [None]:
# Step 3: Train model instance on the data
mnist_svm.fit(X_train, y_train)

In [None]:
# Step 4: Assess model performance on the testing data
mnist_svm.score(X_test, y_test)

#### Alternative Classification Algorithm: Random Forest
When working with a classification problem, it is often useful to compare multiple algorithms against one another to see their relative performance.

In [None]:
# Initialize model instance, fit training data, assess
mnist_rf = RandomForestClassifier()
mnist_rf.fit(X_train, y_train)
mnist_rf.score(X_test, y_test)
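
When more than two algorithms are compared, a loop keeps the comparison uniform. The following is a sketch under the same assumptions as the cells above (unshuffled 70/30 split, default hyperparameters). The `random_state=0` argument is an addition for reproducibility, and the names `models` and `scores` are illustrative.

```python
from sklearn import datasets, svm
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Rebuild the flattened digits data so this block runs on its own
digits = datasets.load_digits()
data = digits.images.reshape((len(digits.images), -1))
X_train, X_test, y_train, y_test = train_test_split(
    data, digits.target, test_size=0.3, shuffle=False)

# random_state is an assumption added here for repeatable results
models = {
    "SVC": svm.SVC(),
    "RandomForest": RandomForestClassifier(random_state=0),
}

# Fit every model on the same split and record its test accuracy
scores = {}
for name, model in models.items():
    model.fit(X_train, y_train)
    scores[name] = model.score(X_test, y_test)
    print("%s accuracy: %.3f" % (name, scores[name]))
```

Because every model sees the same training and testing rows, differences in the printed accuracy reflect the algorithms rather than the split.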

### Apply Model to Data and Visualize Results
The example code below uses the SVC model to create a set of predictions and plot the images with the predicted and actual labels. _When we created the testing and training datasets, we didn't shuffle the data. This allows us to pull the images from the original array for visualization._

In [None]:
# Apply SVC model to create a set of "predictions"
predictions_svm = mnist_svm.predict(X_test)
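
The `metrics` module imported at the top of the notebook can break the single accuracy figure down by digit. This is a minimal sketch, assuming the same unshuffled 70/30 split and default `SVC` as above; the names `clf`, `predicted`, and `cm` are illustrative.

```python
from sklearn import datasets, metrics, svm
from sklearn.model_selection import train_test_split

# Rebuild the data and model so this block runs on its own
digits = datasets.load_digits()
data = digits.images.reshape((len(digits.images), -1))
X_train, X_test, y_train, y_test = train_test_split(
    data, digits.target, test_size=0.3, shuffle=False)

clf = svm.SVC().fit(X_train, y_train)
predicted = clf.predict(X_test)

# Precision, recall, and F1 score for each digit class
print(metrics.classification_report(y_test, predicted))

# Rows are actual digits, columns are predicted digits;
# off-diagonal entries count the misclassified images
cm = metrics.confusion_matrix(y_test, predicted)
print(cm)
```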

In [None]:
# Create a structure including sample images, predictions, and labels
example_predictions_svm = tuple(
    zip(digits.images[len(X_train):], predictions_svm, y_test))

# Keep only the entries where the prediction differs from the actual label
example_incorrect_predictions_svm = tuple(
    filter(lambda v: v[1] != v[2], example_predictions_svm))
print('Number of incorrectly matched images: {}/{} ({:.3f})'.format(
    len(example_incorrect_predictions_svm), len(example_predictions_svm), 
    len(example_incorrect_predictions_svm)/len(example_predictions_svm)))

#### Sample of Correctly Matched Images

In [None]:
plt.figure(figsize=(6, 6))

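# Keep only the entries where the prediction matches the actual label
example_correct_predictions_svm = tuple(
    filter(lambda v: v[1] == v[2], example_predictions_svm))

# Visualize correctly matched images
for i, (idata, p, ilabel) in enumerate(example_correct_predictions_svm[:10]):
    sub = plt.subplot(3, 4, i+1)
    sub.imshow(idata, cmap='Greys')
    plt.title('P: %s. A: %s' % (p, ilabel))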

#### Sample of Incorrectly Matched Images

In [None]:
plt.figure(figsize=(6, 6))

# Visualize incorrectly matched images
for i, (idata, p, ilabel) in enumerate(example_incorrect_predictions_svm[:12]):
    sub = plt.subplot(3, 4, i+1)
    sub.imshow(idata, cmap='Greys')
    plt.title('P: %s. A: %s' % (p, ilabel))