Spark快速获得CrossValidator的最佳模型参数

10 11月

本文版权归作者所有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利。

转载自夜明的孤行灯

本文链接地址: https://www.huangyunkun.com/2016/11/10/spark-get-crossvalidator-best-model-params/

Spark提供了便利的Pipeline模型,可以轻松的创建自己的学习模型。

但是大部分模型都是需要提供参数的,如果不提供就是默认参数,那么怎么选择参数就是一个比较常见的问题。Spark提供在org.apache.spark.ml.tuning包下提供了模型选择器,可以替换参数然后比较模型输出。

目前有CrossValidator和TrainValidationSplit两种,比如一个文本情感预测模型。

Pipeline只有三步,第一步切词,第二部Hashing TF,第三部NB分类

Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[]{tokenizer, hashingTF, naiveBayes});

ParamMap[] paramMaps = new ParamGridBuilder()
                .addGrid(hashingTF.numFeatures(), new int[]{10000, 100000, 500000, 1000000})
                .build();
CrossValidator cv = new CrossValidator()
                .setEstimator(pipeline)
                .setEvaluator(new BinaryClassificationEvaluator())
                .setEstimatorParamMaps(paramMaps);

其中Hashing TF的参数选择非常重要,我们这里就随便尝试几种,然后放在CrossValidator中去。

最后我们会获得一个CrossValidatorModel类,这里有两种选择。

第一种是自己手动获取其中的参数,因为bestModel的参数就是我们最后选择的参数

Pipeline bestPipeline = (Pipeline) model.bestModel().parent();
PipelineStage stage = bestPipeline.getStages()[1];
stage.extractParamMap().get(stage.getParam("numFeatures"));

这种方法可以获得值,但是需要根据你模型情况修改获取的位置。

如果你只是想知道最佳参数是多少,并不是需要在上下文中使用,那还有一个更简单的方法。

修改log4j的配置,添加

log4j.logger.org.apache.spark.ml.tuning.TrainValidationSplit=INFO
log4j.logger.org.apache.spark.ml.tuning.CrossValidator=INFO

效果如下:

spark-best-model-params

本文版权归作者所有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利。

转载自夜明的孤行灯

本文链接地址: https://www.huangyunkun.com/2016/11/10/spark-get-crossvalidator-best-model-params/

One Reply to “Spark快速获得CrossValidator的最佳模型参数”

发表评论

电子邮件地址不会被公开。