Save and Load StructED Models


In this example we demonstrate how to save a trained StructED model to a file and load it back. As before, we use the dummy data.

Before we begin!

You can use this tutorial in two ways:
  1. Run the code as we implemented it:
     To run a tutorial, simply run the bash script in its directory in our GitHub repository. The tutorials are in the tutorials-code directory; this one is in the save_load folder.

  2. Write the code yourself:
     First, set up your development environment. Open a Java project and add the StructED jar file to your build path (it can be found in the bin directory of the StructED repository). You can download all of the StructED source code and jar files from GitHub.
Now let's begin!

The Code


Here is a snippet of the code (the complete code can be found in the package repository, under the tutorials-code folder):

        /*
            Train a Passive Aggressive model on the dummy data
         */
        Logger.info("Dummy data example.");
        // parameters
        int readerType = 0;
        int epochNum = 3;
        int isAvg = 1;
        int numExamples2Display = 3;
        String trainPath = "data/train.txt";
        String testPath = "data/test.txt";

        // load the data
        Reader reader = getReader(readerType);
        InstancesContainer dummyTrainInstances = reader.readData(trainPath, Consts.SPACE, 
              Consts.COLON_SPLITTER);
        InstancesContainer dummyTestInstances = reader.readData(testPath, Consts.SPACE, 
              Consts.COLON_SPLITTER);
        if ( dummyTrainInstances.getSize() == 0 ) return;

        /* PASSIVE AGGRESSIVE MODEL */
        // init the first weight vector
        Vector W = new Vector() {{put(0, 0.0);}}; 
        // model parameters
        ArrayList<Double> arguments = new ArrayList<Double>(){{add(3.0);}}; 
        // task loss parameters
        ArrayList<Double> task_loss_params = new ArrayList<Double>(){{add(1.0);}}; 

        // building the model
        StructEDModel dummy_model = new StructEDModel(W, new PassiveAggressive(), new TaskLossDummyData(),
                new InferenceDummyData(), null, new FeatureFunctionsDummy(), arguments);
        // train
        dummy_model.train(dummyTrainInstances, task_loss_params, null, epochNum, isAvg);
        // predict
        ArrayList<PredictedLabels> labels = dummy_model.predict(dummyTestInstances, 
                task_loss_params, numExamples2Display);

        // print the prediction
        for(int i=0 ; i<dummyTestInstances.getSize() ; i++)
            Logger.info("Y = "+dummyTestInstances.getInstance(i).getLabel()+", Y_HAT = "+
                labels.get(i).firstKey());
        Logger.info("");

        /*
            An example for saving and loading our trained model
            For this we use the dummy data
         */      
        Logger.info("Save and Load the Dummy model.");
        
        // save the PA model as dummy.model
        dummy_model.saveModel("dummy.model");

        /* prepare the model to be loaded */
        // model parameters 
        arguments = new ArrayList<Double>(){{add(3.0);}};
        // build the model
        StructEDModel loaded_model = new StructEDModel(null, new PassiveAggressive(), new TaskLossDummyData(),
                new InferenceDummyData(), null, new FeatureFunctionsDummy(), arguments);

        // load the saved model
        loaded_model.loadModel("dummy.model");
        // predict with the loaded model and keep its labels for printing
        labels = loaded_model.predict(dummyTestInstances, task_loss_params, numExamples2Display);

        // print the prediction
        for(int i=0 ; i<dummyTestInstances.getSize() ; i++)
            Logger.info("Y = "+dummyTestInstances.getInstance(i).getLabel()+", Y_HAT = "+
                labels.get(i).firstKey());
        Logger.info("");
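The snippet relies on StructEDModel.saveModel and loadModel, and this tutorial does not describe the file format they use. As a library-independent illustration of the same save, load, and predict round trip, here is a minimal Java sketch. It assumes, purely for illustration, that a model is just a sparse weight vector (feature index to weight) persisted with standard Java serialization. The class and method names (ModelRoundTrip, saveWeights, loadWeights, predict) are hypothetical, and this is not StructED's actual on-disk format.

```java
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: NOT StructED's saveModel/loadModel implementation or file format.
class ModelRoundTrip {

    // Write a sparse weight vector (feature index -> weight) to disk with Java serialization.
    static void saveWeights(Map<Integer, Double> weights, Path file) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(Files.newOutputStream(file))) {
            out.writeObject(new HashMap<>(weights)); // HashMap is Serializable
        }
    }

    // Read the weight vector back; the cast is unchecked because readObject returns Object.
    @SuppressWarnings("unchecked")
    static Map<Integer, Double> loadWeights(Path file) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(Files.newInputStream(file))) {
            return (Map<Integer, Double>) in.readObject();
        }
    }

    // Toy linear scorer: dot product of sparse features with the weights (missing weights count as 0).
    static double score(Map<Integer, Double> weights, Map<Integer, Double> features) {
        double s = 0.0;
        for (Map.Entry<Integer, Double> e : features.entrySet())
            s += weights.getOrDefault(e.getKey(), 0.0) * e.getValue();
        return s;
    }

    // Toy binary prediction: +1 if the score is non-negative, otherwise -1.
    static int predict(Map<Integer, Double> weights, Map<Integer, Double> features) {
        return score(weights, features) >= 0 ? 1 : -1;
    }

    public static void main(String[] args) throws Exception {
        // Stand-in for a trained model's weights.
        Map<Integer, Double> trained = new HashMap<>();
        trained.put(0, 0.5);
        trained.put(1, -1.25);
        trained.put(2, 2.0);

        // Save, then load into a fresh object, as the tutorial does with dummy.model.
        Path file = Files.createTempFile("dummy", ".model");
        saveWeights(trained, file);
        Map<Integer, Double> loaded = loadWeights(file);

        List<Map<Integer, Double>> test = List.of(
                Map.of(0, 1.0, 1, 1.0),  // score 0.5 - 1.25 = -0.75
                Map.of(2, 1.0),          // score 2.0
                Map.of(0, 2.0, 1, 0.5)); // score 1.0 - 0.625 = 0.375

        // Predictions from the trained and the loaded weights should agree.
        for (Map<Integer, Double> x : test)
            System.out.println("Y_HAT (trained) = " + predict(trained, x)
                    + ", Y_HAT (loaded) = " + predict(loaded, x));

        Files.deleteIfExists(file);
    }
}
```

The two print loops in the tutorial serve the same purpose: if saving and loading work correctly, the loaded model's predictions match those of the model that was trained.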