Multiclass


In this tutorial we demonstrate how to use StructED for multi-class problems. As a running example we use the standard MNIST benchmark. We supply a small subset of this dataset; you can download the full dataset from mnist.
This tutorial is also suitable for any other multi-class problem; at the end we demonstrate how to apply it to the Iris flower dataset.

Before we begin!

You can use this tutorial in two ways:
  1. Run the code as we implemented it:
     To run the tutorials, simply run the bash script in any of the tutorial directories in our GitHub repository. The tutorials can be found under the tutorials-code directory, in the multiclass folder.

  2. Write the code yourself:
     First set up your development environment. Open a Java project and add the StructED jar file to your build path (it can be found under the bin directory of the StructED repository). You can download all StructED source code and jar files from GitHub.
Now let's begin!

MNIST


MNIST is a dataset of handwritten digits labeled from 0 to 9, and contains 60,000 training examples and 10,000 test examples. Each example has been size-normalized and centered in a fixed-size image of 28 $\times$ 28.

Note that the MNIST database downloaded from the web comes as compressed files, so a little preprocessing is needed before using StructED.
StructED expects its input data in the following format:

Each example should be on its own line (meaning each example should end with \n). Feature/value pairs should be separated by a space character, with a colon (:) between the feature number and its value. Features with value zero can be skipped. The label (target value) should be the first value in each example.

For example, the line: 3 1:0.55 8:0.07 2293:0.11

specifies an example of class 3 for which feature number 1 has the value 0.55, feature number 8 has the value 0.07, and feature number 2293 has the value 0.11; all other features have value 0.
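
To make the format concrete, here is a minimal sketch of how one such line could be parsed. The class SparseLineParser and its nested Example type are hypothetical names used only for illustration; they are not part of StructED, whose own readers should be used in practice.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical helper for illustration only; not part of the StructED API.
class SparseLineParser {

    /** Holds the label and the non-zero features of a single example. */
    static final class Example {
        final String label;
        final Map<Integer, Double> features;

        Example(String label, Map<Integer, Double> features) {
            this.label = label;
            this.features = features;
        }
    }

    /** Parses a line such as "3 1:0.55 8:0.07 2293:0.11". */
    static Example parse(String line) {
        String[] tokens = line.trim().split("\\s+");
        Map<Integer, Double> features = new TreeMap<>();
        // tokens[0] is the label; every other token is a feature:value pair
        for (int i = 1; i < tokens.length; i++) {
            String[] pair = tokens[i].split(":");
            if (pair.length != 2) {
                throw new IllegalArgumentException("Malformed pair: " + tokens[i]);
            }
            features.put(Integer.parseInt(pair[0]), Double.parseDouble(pair[1]));
        }
        return new Example(tokens[0], features);
    }

    public static void main(String[] args) {
        Example ex = parse("3 1:0.55 8:0.07 2293:0.11");
        System.out.println("label=" + ex.label + ", features=" + ex.features);
    }
}
```

Features that do not appear in the map are implicitly zero, which is what keeps the representation compact for mostly empty inputs such as digit images.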

We provide a sample from the MNIST data set with this tutorial.
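
If you preprocess the full dataset yourself, the conversion from a dense image to this line format could look roughly like the sketch below. It makes several assumptions that the tutorial does not state: feature numbers start at 1, pixel intensities are integers in 0..255, and values are scaled to [0, 1]. The class name SparseLineWriter is hypothetical.

```java
import java.util.Locale;

// Hypothetical helper for illustration only; not part of the StructED API.
class SparseLineWriter {

    /**
     * Converts a label and a dense pixel array into one line of the sparse format.
     * Assumptions: 1-based feature numbers, pixels in 0..255, values scaled to [0, 1].
     */
    static String toLine(int label, int[] pixels) {
        StringBuilder sb = new StringBuilder(Integer.toString(label));
        for (int i = 0; i < pixels.length; i++) {
            if (pixels[i] == 0) {
                continue; // zero-valued features can be skipped
            }
            double value = pixels[i] / 255.0;
            sb.append(' ')
              .append(i + 1)
              .append(':')
              .append(String.format(Locale.ROOT, "%.4f", value));
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // a tiny 4-pixel "image" standing in for the 784 pixels of an MNIST digit
        System.out.println(toLine(7, new int[] {0, 255, 0, 51}));
    }
}
```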

The Code


Here we describe which classes we need to add and which interfaces we need to implement. We provide the source code for all of these classes and interfaces.

Task Loss

To add a new loss/cost function, one needs to implement the ITaskLoss interface. In our package we implemented the 0-1 loss function:

\begin{equation} \label{eq:loss} \ell (y,\hat{y}) = \mathbb{1}\left[y \ne \hat{y}\right] \end{equation}


In other words, the loss is one if $y$ is not equal to $\hat{y}$, and zero otherwise.
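
As a plain-Java illustration of the 0-1 loss, consider the sketch below. It deliberately does not implement ITaskLoss, since the exact method signature of that interface is not shown here; the class ZeroOneLoss and its methods are hypothetical stand-ins.

```java
// Hypothetical stand-alone illustration; it does not implement StructED's ITaskLoss.
class ZeroOneLoss {

    /** Returns 0 when the prediction matches the gold label and 1 otherwise. */
    static double loss(String y, String yHat) {
        return y.equals(yHat) ? 0.0 : 1.0;
    }

    /** Average 0-1 loss over a set of examples, i.e. the error rate. */
    static double averageLoss(String[] gold, String[] predicted) {
        if (gold.length != predicted.length) {
            throw new IllegalArgumentException("Arrays must have the same length");
        }
        double total = 0.0;
        for (int i = 0; i < gold.length; i++) {
            total += loss(gold[i], predicted[i]);
        }
        return gold.length == 0 ? 0.0 : total / gold.length;
    }

    public static void main(String[] args) {
        String[] gold = {"0", "1", "2", "3"};
        String[] predicted = {"0", "1", "2", "4"};
        System.out.println("error rate = " + averageLoss(gold, predicted));
    }
}
```

Averaging this loss over a dataset gives the classification error rate, which is the quantity usually reported for MNIST.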

Inference

In the multi-class case, inference simply goes over all the possible classes and returns the one with the highest score. To use a different inference function, all we need to do is implement the IInference interface and use it with the desired learning algorithm.
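
The exhaustive search over classes can be sketched as follows. This is a simplified illustration that assumes a linear score and keeps one dense weight array per class; it is not the code of StructED's own inference classes, and the name ArgMaxInference is hypothetical.

```java
// Hypothetical stand-alone illustration; it does not implement StructED's IInference.
class ArgMaxInference {

    /**
     * Scores every class with a dot product and returns the index of the best one.
     * classWeights[k] is assumed to be the weight vector of class k.
     */
    static int predict(double[][] classWeights, double[] x) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < classWeights.length; k++) {
            double score = 0.0;
            for (int j = 0; j < x.length; j++) {
                score += classWeights[k][j] * x[j];
            }
            // strict comparison keeps the lowest class index on ties
            if (score > bestScore) {
                bestScore = score;
                best = k;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[][] w = {{1.0, 0.0}, {0.0, 1.0}, {-1.0, -1.0}};
        System.out.println("predicted class = " + predict(w, new double[] {0.2, 0.9}));
    }
}
```

The cost is linear in the number of classes, which is why exhaustive inference is practical for problems such as MNIST with 10 classes.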

Feature Functions

In multi-class problems there is no real need for feature functions, but we need to store as many weight vectors as there are classes in the task. To handle this, we concatenate all the weight vectors one after the other into a single weight vector. The feature function then simply places the input vector in the position that corresponds to its class number. To use different feature functions, all we need to do is implement the IFeatureFunctions interface and use it with the desired learning algorithm.
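
The concatenation idea can be sketched as follows: the input features are shifted by an offset of classIndex * maxFeatures, so that the dot product with the single long weight vector only touches the block that belongs to that class. The sketch assumes 0-based feature indices and sparse inputs stored in a map; JointFeatureMap is a hypothetical name and not StructED's FeatureFunctionsSparse.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical stand-alone illustration; it does not implement StructED's IFeatureFunctions.
class JointFeatureMap {

    /** Places the sparse input x into the block of the joint vector that belongs to classIndex. */
    static Map<Integer, Double> phi(Map<Integer, Double> x, int classIndex, int maxFeatures) {
        Map<Integer, Double> joint = new TreeMap<>();
        int offset = classIndex * maxFeatures;
        for (Map.Entry<Integer, Double> e : x.entrySet()) {
            joint.put(offset + e.getKey(), e.getValue());
        }
        return joint;
    }

    /** Dot product between the concatenated weight vector and a sparse joint feature vector. */
    static double dot(double[] w, Map<Integer, Double> sparse) {
        double sum = 0.0;
        for (Map.Entry<Integer, Double> e : sparse.entrySet()) {
            sum += w[e.getKey()] * e.getValue();
        }
        return sum;
    }

    public static void main(String[] args) {
        int numOfClasses = 3;
        int maxFeatures = 4;
        double[] w = new double[numOfClasses * maxFeatures];
        w[8] = 2.0;   // first weight of the block of class 2
        w[11] = 1.0;  // last weight of the block of class 2

        Map<Integer, Double> x = new TreeMap<>();
        x.put(0, 0.5);
        x.put(3, 2.0);

        System.out.println("phi(x, 2) = " + phi(x, 2, maxFeatures));
        System.out.println("score for class 2 = " + dot(w, phi(x, 2, maxFeatures)));
    }
}
```

With this layout, the weight vector has numOfClasses * maxFeatures entries in total, for example 10 * 784 for the MNIST setting used in this tutorial.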

Running The Code


Now we can create a StructEDModel object with the interfaces we have just implemented; all that is left is to choose the model. The example code here is for the MNIST dataset. It also generates a validation error graph and saves it.

Here is a snippet of such code (the complete code can be found in the package repository under the tutorials-code folder):

                     
        Logger.info("Loading MNIST dataset.");

        // === PARAMETERS === //
        // <the path to the mnist train data>
        String trainPath = "data/MNIST/train.txt"; 
        // <the path to the mnist test data>
        String testPath = "data/MNIST/test.data.txt"; 
        // <the path to the mnist validation data>
        String valPath = "data/MNIST/val.data.txt"; 
        int epochNum = 1;
        int readerType = 0;
        int isAvg = 1;
        int numExamples2Display = 3;
        int numOfClasses = 10;
        int maxFeatures = 784;
        Reader reader = getReader(readerType);
        // ================== //

        // load the data
        InstancesContainer mnistTrainInstances = reader.readData(trainPath, Consts.SPACE, 
                Consts.COLON_SPLITTER);
        InstancesContainer mnistDevelopInstances = reader.readData(valPath, Consts.SPACE, 
                Consts.COLON_SPLITTER);
        InstancesContainer mnistTestInstances = reader.readData(testPath, Consts.SPACE, 
                Consts.COLON_SPLITTER);
        if (mnistTrainInstances.getSize() == 0) return;

        // ======= SVM ====== //
        // init the first weight vector
        Vector W = new Vector() {{put(0, 0.0);}}; 
        // model parameters
        ArrayList<Double> arguments = new ArrayList<Double>() {{add(0.1);add(0.1);}}; 

        // build the model
        StructEDModel mnist_model = new StructEDModel(W, new SVM(), new TaskLossMultiClass(),
                new InferenceMultiClassOld(numOfClasses), null, 
                        new FeatureFunctionsSparse(numOfClasses, maxFeatures), arguments);
        // train
        mnist_model.train(mnistTrainInstances, null, mnistDevelopInstances, epochNum, isAvg);
        // predict
        mnist_model.predict(mnistTestInstances, null, numExamples2Display);
        // plot the error on the validation set
        // the true flag indicates that the image is saved to the img folder in the project directory
        // if the img directory does not exist, it will be created
        mnist_model.plotValidationError(true);
              


Another example, using a different model on the Iris dataset:

                     
        Logger.info("Loading IRIS dataset.");

        // ============================ IRIS DATA ============================= //
        // === PARAMETERS === //
        // <the path to the iris train data>
        String trainPath = "data/iris/iris.train.txt"; 
        // <the path to the iris test data>
        String testPath = "data/iris/iris.test.txt"; 
        int epochNum = 10;
        int isAvg = 1;
        int numExamples2Display = 3;
        int numOfClasses = 3;
        int maxFeatures = 4;
        int readerType = 0;
        Reader reader = getReader(readerType);
        // ================== //

        // load the data
        InstancesContainer irisTrainInstances = reader.readData(trainPath, Consts.COMMA_NOTE, 
                Consts.COLON_SPLITTER);
        InstancesContainer irisTestInstances = reader.readData(testPath, Consts.COMMA_NOTE, 
                Consts.COLON_SPLITTER);
        // ======= PA ====== //
        // init the first weight vector to be zeros
        Vector W = new Vector() {{put(0, 0.0);}}; 
        // model parameters
        ArrayList<Double> arguments = new ArrayList<Double>() {{add(1.0);}}; 

        // build the model
        StructEDModel iris_model = new StructEDModel(W, new PassiveAggressive(), new TaskLossMultiClass(),
                new InferenceMultiClassOld(numOfClasses), null, 
                      new FeatureFunctionsSparse(numOfClasses, maxFeatures), arguments);
        // train
        iris_model.train(irisTrainInstances, null, null, epochNum, isAvg);
        // predict
        ArrayList<PredictedLabels> iris_labels = iris_model.predict(irisTestInstances, null, 
                numExamples2Display);

        // printing the predictions
        for (int i = 0; i < iris_labels.size(); i++) {
            Logger.info("Desired Label: " + irisTestInstances.getInstance(i).getLabel()
                + ", Predicted Label: " + iris_labels.get(i).firstKey());
        }