In this tutorial we demonstrate how to use StructED for multi-class problems. We use the standard multi-class benchmark MNIST; we supply a small subset of this dataset, and you can download the full data set from
mnist.
This tutorial is also suitable for any other multi-class problem; at the end we demonstrate how to use it with the
Iris flower data set.
Before we begin!
You can use this tutorial in two ways:
- Run the code as we implemented it:
In order to run the tutorials you can simply run the bash script in any of the tutorial directories in our GitHub repository. The tutorials can be found in the multiclass folder under the tutorials-code directory.
- Write the code yourself:
First we should set up our development environment. Open a Java project and add the StructED jar file to your build path (it can be found under the bin directory of the StructED repository). You can download all StructED source code and jar files from GitHub.
Now let's begin!
MNIST
MNIST is a dataset of handwritten digits labeled from 0 to 9, and contains 60,000 training examples and 10,000 test examples. Each example has been size-normalized and centered in a fixed-size image of 28 $\times$ 28 pixels.
Notice that when downloading MNIST from the web we get compressed files, so a little preprocessing is needed before using StructED.
StructED expects its input data in the following format:
Each example should be on a separate line (i.e., each example should end with \n). Each feature/value pair should be separated by a space character, with a colon between the feature number and its value. Features with value zero can be skipped. The label (target value) should be the first value in each example.
For example, the line: 3 1:0.55 8:0.07 2293:0.11
specifies an example of class 3 for which feature number 1 has the value 0.55, feature number 8 has the value 0.07, and feature number 2293 has the value 0.11; all other features have the value 0.
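The format above can be parsed in a few lines of plain Java. This is only a standalone sketch to make the format concrete, not StructED's own reader (which is used in the code below); the class and method names here are hypothetical:

```java
import java.util.HashMap;
import java.util.Map;

// Standalone sketch of a parser for the sparse "label idx:val idx:val" format.
// This is NOT StructED's reader; it only illustrates the file format.
public class SparseLineParser {
    // The label (target value) is the first token on the line.
    public static int parseLabel(String line) {
        return Integer.parseInt(line.trim().split("\\s+")[0]);
    }

    // The remaining tokens are "featureNumber:value" pairs; absent features are 0.
    public static Map<Integer, Double> parseFeatures(String line) {
        Map<Integer, Double> features = new HashMap<>();
        String[] tokens = line.trim().split("\\s+");
        for (int i = 1; i < tokens.length; i++) {   // token 0 is the label
            String[] pair = tokens[i].split(":");
            features.put(Integer.parseInt(pair[0]), Double.parseDouble(pair[1]));
        }
        return features;
    }

    public static void main(String[] args) {
        String line = "3 1:0.55 8:0.07 2293:0.11";
        System.out.println(parseLabel(line));            // 3
        System.out.println(parseFeatures(line).get(8));  // 0.07
    }
}
```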
We provide a sample from the MNIST data set with this tutorial.
The Code
Here we present the classes we need to add and the interfaces we need to implement. We provide the source code for all of these classes and interfaces.
Task Loss
In order to add a new loss/cost function one needs to implement the ITaskLoss interface. In our package we implemented the 0-1 loss function:
\begin{equation}
\label{eq:loss}
\ell (y,\hat{y}) = \mathbb{1}\!\left[y \ne \hat{y}\right]
\end{equation}
In other words, the loss is one if $y$ is not equal to $\hat{y}$, and zero otherwise.
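As a sketch, the 0-1 loss above amounts to a single comparison. The following is plain Java illustrating the formula only; it does not reproduce the exact signature of StructED's ITaskLoss interface:

```java
// Standalone sketch of the 0-1 loss from the equation above:
// returns 1.0 when the predicted label differs from the true label, 0.0 otherwise.
// StructED's actual ITaskLoss interface may use a different signature.
public class ZeroOneLoss {
    public static double loss(String y, String yHat) {
        return y.equals(yHat) ? 0.0 : 1.0;
    }

    public static void main(String[] args) {
        System.out.println(loss("3", "3")); // 0.0
        System.out.println(loss("3", "7")); // 1.0
    }
}
```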
Inference
In the multi-class case our inference simply goes over all the possible classes and picks the highest-scoring one. If we would like to use another inference function, all we need to do is implement the IInference interface and use it with the desired learning algorithm.
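Conceptually, this exhaustive inference is an argmax over per-class scores. The sketch below assumes each class already has a score (e.g., the dot product of its weight sub-vector with the features); it is not StructED's IInference code:

```java
// Standalone sketch of exhaustive multi-class inference: given one score per
// class, return the index of the highest-scoring class.
// StructED's IInference implementation computes the scores from the weight
// vector and the feature functions; here the scores are supplied directly.
public class ArgmaxInference {
    public static int predict(double[] classScores) {
        int best = 0;
        for (int y = 1; y < classScores.length; y++) {
            if (classScores[y] > classScores[best]) {
                best = y;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] scores = {0.1, 2.3, -0.5, 1.7};
        System.out.println(predict(scores)); // 1
    }
}
```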
Feature Functions
In multi-class problems there is no real need for feature functions, but we need to store as many weight vectors as the number of classes defined by the task settings. To do that, we simply concatenate all the vectors one after the other into a single weight vector.
Thus, the feature function just places the feature vector in the right position according to its class number.
If we would like to use another feature function, all we need to do is implement the IFeatureFunctions interface and use it with the desired learning algorithm.
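The concatenation trick can be sketched as follows. This is plain Java, with maxFeatures playing the same role as in the snippets below; the sparse-map representation is an assumption for illustration, not StructED's internal one:

```java
import java.util.HashMap;
import java.util.Map;

// Standalone sketch of the multi-class feature function: place the example's
// sparse features into the block of the concatenated weight vector that
// belongs to the given class. For class y the features occupy indices
// [y * maxFeatures, (y + 1) * maxFeatures).
public class MultiClassFeatures {
    public static Map<Integer, Double> featureFunction(Map<Integer, Double> x,
                                                       int y, int maxFeatures) {
        Map<Integer, Double> phi = new HashMap<>();
        for (Map.Entry<Integer, Double> e : x.entrySet()) {
            phi.put(y * maxFeatures + e.getKey(), e.getValue());
        }
        return phi;
    }

    public static void main(String[] args) {
        Map<Integer, Double> x = new HashMap<>();
        x.put(1, 0.55);
        x.put(8, 0.07);
        // For class 3 with 784 features per class, feature 1 lands at 3*784+1 = 2353.
        System.out.println(featureFunction(x, 3, 784).get(2353)); // 0.55
    }
}
```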
Running The Code
Now we can create a StructEDModel object with the interfaces we have just implemented; all that is left is to choose the model. The example code here is for the MNIST dataset; it also generates a validation-error graph and saves it.
Here is a snippet of such code (the complete code can be found in the package repository under the tutorials-code folder):
Logger.info("Loading MNIST dataset.");
// === PARAMETERS === //
// <the path to the mnist train data>
String trainPath = "data/MNIST/train.txt";
// <the path to the mnist test data>
String testPath = "data/MNIST/test.data.txt";
// <the path to the mnist validation data>
String valPath = "data/MNIST/val.data.txt";
int epochNum = 1;
int readerType = 0;
int isAvg = 1;
int numExamples2Display = 3;
int numOfClasses = 10;
int maxFeatures = 784;
Reader reader = getReader(readerType);
// ================== //
// load the data
InstancesContainer mnistTrainInstances = reader.readData(trainPath, Consts.SPACE,
Consts.COLON_SPLITTER);
InstancesContainer mnistDevelopInstances = reader.readData(valPath, Consts.SPACE,
Consts.COLON_SPLITTER);
InstancesContainer mnistTestInstances = reader.readData(testPath, Consts.SPACE,
Consts.COLON_SPLITTER);
if (mnistTrainInstances.getSize() == 0) return;
// ======= SVM ====== //
// init the first weight vector
Vector W = new Vector() {{put(0, 0.0);}};
// model parameters
ArrayList<Double> arguments = new ArrayList<Double>() {{add(0.1);add(0.1);}};
// build the model
StructEDModel mnist_model = new StructEDModel(W, new SVM(), new TaskLossMultiClass(),
new InferenceMultiClassOld(numOfClasses), null,
new FeatureFunctionsSparse(numOfClasses, maxFeatures), arguments);
// train
mnist_model.train(mnistTrainInstances, null, mnistDevelopInstances, epochNum, isAvg);
// predict
mnist_model.predict(mnistTestInstances, null, numExamples2Display);
// plot the error on the validation set
// the true flag indicates that we save the image to the img folder in the project directory
// if the img directory does not exist it will be created
mnist_model.plotValidationError(true);
Here is another example, using a different model, on the Iris dataset:
Logger.info("Loading IRIS dataset.");
// ============================ IRIS DATA ============================= //
// === PARAMETERS === //
// <the path to the iris train data>
String trainPath = "data/iris/iris.train.txt";
// <the path to the iris test data>
String testPath = "data/iris/iris.test.txt";
int epochNum = 10;
int isAvg = 1;
int numExamples2Display = 3;
int numOfClasses = 3;
int maxFeatures = 4;
int readerType = 0;
Reader reader = getReader(readerType);
// ================== //
// load the data
InstancesContainer irisTrainInstances = reader.readData(trainPath, Consts.COMMA_NOTE,
Consts.COLON_SPLITTER);
InstancesContainer irisTestInstances = reader.readData(testPath, Consts.COMMA_NOTE,
Consts.COLON_SPLITTER);
// ======= PA ====== //
// init the first weight vector to be zeros
Vector W = new Vector() {{put(0, 0.0);}};
// model parameters
ArrayList<Double> arguments = new ArrayList<Double>() {{add(1.0);}};
// build the model
StructEDModel iris_model = new StructEDModel(W, new PassiveAggressive(), new TaskLossMultiClass(),
new InferenceMultiClassOld(numOfClasses), null,
new FeatureFunctionsSparse(numOfClasses, maxFeatures), arguments);
// train
iris_model.train(irisTrainInstances, null, null, epochNum, isAvg);
// predict
ArrayList<PredictedLabels> iris_labels = iris_model.predict(irisTestInstances, null,
numExamples2Display);
// printing the predictions
for(int i=0 ; i<iris_labels.size() ; i++)
Logger.info("Desired Label: "+irisTestInstances.getInstance(i).getLabel()+", Predicted Label: "+
iris_labels.get(i).firstKey());