\begin{equation} \label{eq:loss} \ell(y,\hat{y}) = \frac{1}{m}\sum_{i=1}^{m} \mathbf{1}\!\left[y_i \ne \hat{y}_i\right] \end{equation}
// ===================== MULTI CLASS ===================== //
Logger.info("================= Multi-class version =================");
Logger.info("Loading data...");
// === LOADING DATA === //
// Read the data set with the multi-class reader, then take the training and
// test portions of fold 1 (getFold(..., true) / getFold(..., false)).
InstancesContainer ocrAllInstancesMulti = reader.readDataMultiClass(dataPath, Consts.TAB,
        Consts.COLON_SPLITTER);
InstancesContainer ocrTrainInstancesMulti = getFold(ocrAllInstancesMulti, 1, true);
InstancesContainer ocrTestInstancesMulti = getFold(ocrAllInstancesMulti, 1, false);
if (ocrTrainInstancesMulti.getSize() == 0) {
    // Nothing to train on. Returning here also skips every experiment that
    // follows in this method, so report it instead of exiting silently.
    Logger.info("No training instances were loaded for the multi-class version, aborting.");
    return;
}
// ==================== //
// Plain construction instead of double-brace initialisation: avoids creating an
// anonymous subclass at each call site and the raw ArrayList type.
W = new Vector(); // init the first weight vector
W.put(0, 0.0);
arguments = new ArrayList<>(); // model parameters
// NOTE(review): presumably the PassiveAggressive aggressiveness (C) value — confirm.
arguments.add(500.0);
// ======= PA MULTI-CLASS MODEL ====== //
// The same size argument is passed to the inference and the feature functions,
// so compute it once.
// NOTE(review): presumably the number of character labels — confirm against Char2Idx.
int multiClassLabelCount = Char2Idx.char2id.size() - 1;
// create the model
StructEDModel ocr_model_multi_class = new StructEDModel(W, new PassiveAggressive(),
        new TaskLossMultiClass(), new InferenceMultiClass(multiClassLabelCount), null,
        new FeatureFunctionsSparse(multiClassLabelCount, maxFeatures), arguments);
ocr_model_multi_class.train(ocrTrainInstancesMulti, null, null, epochNum, isAvg, true); // train
ocr_model_multi_class.predict(ocrTestInstancesMulti, null, 1, false); // predict
// ======================================================= //
// ====================== STRUCTURED ===================== //
Logger.info("================= Structured version =================");
Logger.info("Loading data...");
// === LOADING DATA === //
// Read the data set with the structured reader, then take the training and
// test portions of fold 1 (getFold(..., true) / getFold(..., false)).
InstancesContainer ocrAllInstances = reader.readData(dataPath, Consts.TAB,
        Consts.COLON_SPLITTER);
InstancesContainer ocrTrainInstancesStruct = getFold(ocrAllInstances, 1, true);
InstancesContainer ocrTestInstancesStruct = getFold(ocrAllInstances, 1, false);
if (ocrTrainInstancesStruct.getSize() == 0) {
    // Nothing to train on: report it instead of exiting silently.
    Logger.info("No training instances were loaded for the structured version, aborting.");
    return;
}
// ==================== //
// ======= PA STRUCTURED MODEL ====== //
// Plain construction instead of double-brace initialisation: avoids creating an
// anonymous subclass at each call site and the raw ArrayList type.
W = new Vector(); // init the first weight vector
W.put(0, 0.0);
arguments = new ArrayList<>(); // model parameters
// NOTE(review): presumably the PassiveAggressive aggressiveness (C) value — confirm.
arguments.add(15.0);
StructEDModel ocr_model = new StructEDModel(W, new PassiveAggressive(),
        new TaskLossOCR(), new InferenceOCR(), null,
        new FeatureFunctionsOCR(maxFeatures), arguments); // create the model
ocr_model.train(ocrTrainInstancesStruct, null, null, epochNum, isAvg, true); // train
ocr_model.predict(ocrTestInstancesStruct, null, 1, false); // predict
// Draw a heat map for the trained structured model; what exactly is plotted
// is defined by Graph.drawHeatMap.
Graph g = new Graph();
g.drawHeatMap(ocr_model, start_transition_characters, nam_of_characters);
// ======================================================= //