public class StructEDModel
extends java.lang.Object
implements java.io.Serializable
| Constructor and Description |
|---|
StructEDModel(Vector init_weights,
IUpdateRule updateRule,
ITaskLoss taskLoss,
IInference inference,
IKernel kernel,
IFeatureFunctions phi,
java.util.ArrayList<java.lang.Double> args) |
StructEDModel(Vector init_weights,
IUpdateRule updateRule,
ITaskLoss taskLoss,
IInference inference,
IKernel kernel,
IFeatureFunctions phi,
java.util.ArrayList<java.lang.Double> args,
boolean isShuffle) |
| Modifier and Type | Method and Description |
|---|---|
double |
getCumulative_loss() |
Vector |
getWeights() |
void |
loadModel(java.lang.String path)
Loads the model. Note: only the weight vector is loaded; the rest of the configuration must be predefined.
|
void |
plotValidationError(boolean save)
Plots the error graph for the validation set.
|
java.util.ArrayList<PredictedLabels> |
predict(InstancesContainer instances,
java.util.ArrayList<java.lang.Double> task_loss_params,
int numPredictions2Return,
boolean verbose)
Predicts using the model and returns the scores of the best matches.
|
void |
saveModel(java.lang.String path)
Saves the model. Note: only the weight vector is saved; the rest of the configuration must be predefined.
|
void |
setFeatureFunctions(IFeatureFunctions phi) |
void |
setInference(IInference inference) |
void |
setKernel(IKernel kernel) |
void |
setReShuffle(boolean isShuffle) |
void |
setTaskLoss(ITaskLoss taskLoss) |
void |
setUpdateRule(IUpdateRule updateRule) |
void |
train(InstancesContainer trainInstances,
java.util.ArrayList<java.lang.Double> task_loss_params,
InstancesContainer developInstances,
int epoch,
int isAvg,
boolean verbose)
Trains the model on the training instances.
|
public StructEDModel(Vector init_weights, IUpdateRule updateRule, ITaskLoss taskLoss, IInference inference, IKernel kernel, IFeatureFunctions phi, java.util.ArrayList<java.lang.Double> args)
public StructEDModel(Vector init_weights, IUpdateRule updateRule, ITaskLoss taskLoss, IInference inference, IKernel kernel, IFeatureFunctions phi, java.util.ArrayList<java.lang.Double> args, boolean isShuffle)
public void train(InstancesContainer trainInstances, java.util.ArrayList<java.lang.Double> task_loss_params, InstancesContainer developInstances, int epoch, int isAvg, boolean verbose) throws java.lang.Exception
epoch - the number of epochs to run over the data
trainInstances - the training instances (train set)
task_loss_params - the task loss parameters; if there are no parameters, set this to null
developInstances - the development instances (dev set); set this to null if there is no validation set
isAvg - an indicator of whether or not to average the weights
Throws: java.lang.Exception
public java.util.ArrayList<PredictedLabels> predict(InstancesContainer instances, java.util.ArrayList<java.lang.Double> task_loss_params, int numPredictions2Return, boolean verbose)
instances - the test set
task_loss_params - the parameters for the task loss function; set this to null if there are no task loss parameters
numPredictions2Return - the number of examples to return in the scores
public void plotValidationError(boolean save)
save - a flag indicating whether or not to write the image
public void saveModel(java.lang.String path)
throws java.io.IOException
path - the path to save the model to
Throws: java.io.IOException
public void loadModel(java.lang.String path)
throws java.io.IOException,
java.lang.ClassNotFoundException
path - the model's path
Throws: java.io.IOException, java.lang.ClassNotFoundException
public Vector getWeights()
public double getCumulative_loss()
public void setUpdateRule(IUpdateRule updateRule)
public void setTaskLoss(ITaskLoss taskLoss)
public void setKernel(IKernel kernel)
public void setInference(IInference inference)
public void setFeatureFunctions(IFeatureFunctions phi)
public void setReShuffle(boolean isShuffle)