\begin{equation} \label{eq:loss} \ell(y,\hat{y}) = \mathbb{1}\left[y \ne \hat{y}\right] \end{equation}
// ============================ MNIST DATA ============================ //
// The MNIST example lives in its own block so that its locals (trainPath,
// epochNum, W, arguments, ...) go out of scope before the IRIS example
// re-declares variables with the same names; without this block the two
// examples cannot share one method body (duplicate local declarations).
{
    Logger.info("Loading MNIST dataset.");
    // === PARAMETERS === //
    // <the path to the mnist train data>
    String trainPath = "data/MNIST/train.txt";
    // <the path to the mnist test data>
    String testPath = "data/MNIST/test.data.txt";
    // <the path to the mnist validation data>
    String valPath = "data/MNIST/val.data.txt";
    // number of training epochs passed to train()
    int epochNum = 1;
    // selects which Reader getReader() returns
    int readerType = 0;
    // averaging flag passed to train() — NOTE(review): presumably 1 enables weight averaging; confirm
    int isAvg = 1;
    // number of examples predict() is asked to display
    int numExamples2Display = 3;
    int numOfClasses = 10;
    // presumably one feature per pixel of a 28x28 image — TODO confirm
    int maxFeatures = 784;
    Reader reader = getReader(readerType);
    // ================== //
    // load the train / validation / test sets (space-separated, colon splitter)
    InstancesContainer mnistTrainInstances = reader.readData(trainPath, Consts.SPACE,
            Consts.COLON_SPLITTER);
    InstancesContainer mnistDevelopInstances = reader.readData(valPath, Consts.SPACE,
            Consts.COLON_SPLITTER);
    InstancesContainer mnistTestInstances = reader.readData(testPath, Consts.SPACE,
            Consts.COLON_SPLITTER);
    // nothing to train on: say so instead of stopping silently
    if (mnistTrainInstances.getSize() == 0) {
        Logger.info("MNIST training set at " + trainPath + " is empty; stopping.");
        return;
    }
    // ======= SVM ====== //
    // init the first weight vector with a single 0.0 entry at index 0
    // (plain construction instead of double-brace init, which creates an anonymous subclass)
    Vector W = new Vector();
    W.put(0, 0.0);
    // model parameters handed to the SVM — NOTE(review): confirm which hyper-parameters these two values are
    ArrayList<Double> arguments = new ArrayList<Double>();
    arguments.add(0.1);
    arguments.add(0.1);
    // build the model
    StructEDModel mnist_model = new StructEDModel(W, new SVM(), new TaskLossMultiClass(),
            new InferenceMultiClassOld(numOfClasses), null,
            new FeatureFunctionsSparse(numOfClasses, maxFeatures), arguments);
    // train, passing the validation set as the development set
    mnist_model.train(mnistTrainInstances, null, mnistDevelopInstances, epochNum, isAvg);
    // predict on the test set
    mnist_model.predict(mnistTestInstances, null, numExamples2Display);
    // plot the error on the validation set;
    // the true flag saves the image to the img folder in the project directory,
    // creating that directory if it does not exist
    mnist_model.plotValidationError(true);
}
Logger.info("Loading IRIS dataset.");
// ============================ IRIS DATA ============================= //
// === PARAMETERS === //
// <the path to the iris train data>
String trainPath = "data/iris/iris.train.txt";
// <the path to the iris test data>
String testPath = "data/iris/iris.test.txt";
// number of training epochs passed to train()
int epochNum = 10;
// averaging flag passed to train() — NOTE(review): presumably 1 enables weight averaging; confirm
int isAvg = 1;
// number of examples predict() is asked to display
int numExamples2Display = 3;
int numOfClasses = 3;
int maxFeatures = 4;
// selects which Reader getReader() returns
int readerType = 0;
Reader reader = getReader(readerType);
// ================== //
// load the train / test sets (Consts.COMMA_NOTE separator, colon splitter)
InstancesContainer irisTrainInstances = reader.readData(trainPath, Consts.COMMA_NOTE,
        Consts.COLON_SPLITTER);
InstancesContainer irisTestInstances = reader.readData(testPath, Consts.COMMA_NOTE,
        Consts.COLON_SPLITTER);
// same guard as the MNIST example: do not build a model from an empty training set
if (irisTrainInstances.getSize() == 0) {
    Logger.info("IRIS training set at " + trainPath + " is empty; stopping.");
    return;
}
// ======= PA ====== //
// init the first weight vector with a single 0.0 entry at index 0
// (plain construction instead of double-brace init, which creates an anonymous subclass)
Vector W = new Vector();
W.put(0, 0.0);
// model parameters — NOTE(review): presumably the Passive-Aggressive C parameter; confirm
ArrayList<Double> arguments = new ArrayList<Double>();
arguments.add(1.0);
// build the model
StructEDModel iris_model = new StructEDModel(W, new PassiveAggressive(), new TaskLossMultiClass(),
        new InferenceMultiClassOld(numOfClasses), null,
        new FeatureFunctionsSparse(numOfClasses, maxFeatures), arguments);
// train without a development set (null)
iris_model.train(irisTrainInstances, null, null, epochNum, isAvg);
// predict
ArrayList<PredictedLabels> iris_labels = iris_model.predict(irisTestInstances, null,
        numExamples2Display);
// log each test instance's label next to the first key of its PredictedLabels
// (assumes predict returns one entry per test instance, in order — TODO confirm)
for (int i = 0; i < iris_labels.size(); i++) {
    Logger.info("Desire Label: " + irisTestInstances.getInstance(i).getLabel()
            + ", Predicted Label: " + iris_labels.get(i).firstKey());
}