邏輯回歸演算法原理及Spark MLlib調用
邏輯回歸
演算法原理:
邏輯回歸是一個流行的二分類問題預測方法。它是 Generalized Linear models 的一個特殊應用以預測結果概率。它是一個線性模型如下列方程所示,其中損失函數為邏輯損失:
對於二分類問題,演算法產出一個二值邏輯回歸模型。給定一個新數據,由x表示,則模型通過下列邏輯方程來預測:
其中。默認情況下,如果,結果為正,否則為負。和線性SVMs不同,邏輯回歸的原始輸出有概率解釋(x為正的概率)。
二分類邏輯回歸可以擴展為多分類邏輯回歸來訓練和預測多類別分類問題。如一個分類問題有K種可能結果,我們可以選取其中一種結果作為「中心點「,其他K-1個結果分別視為中心點結果的對立點。在spark.mllib中,取第一個類別為中心點類別。
*目前spark.ml邏輯回歸工具僅支持二分類問題,多分類回歸將在未來完善。
*當使用無攔截的連續非零列訓練LogisticRegressionModel時,Spark MLlib為連續非零列輸出零係數。這種處理不同於libsvm與R glmnet相似。
參數:
elasticNetParam:
類型:雙精度型。
含義:彈性網路混合參數,範圍[0,1]。
featuresCol:
類型:字元串型。
含義:特徵列名。
fitIntercept:
類型:布爾型。
含義:是否訓練攔截對象。
labelCol:
類型:字元串型。
含義:標籤列名。
maxIter:
類型:整數型。
含義:最多迭代次數(>=0)。
predictionCol:
類型:字元串型。
含義:預測結果列名。
probabilityCol:
類型:字元串型。
含義:用以預測類別條件概率的列名。
regParam:
類型:雙精度型。
含義:正則化參數(>=0)。
standardization:
類型:布爾型。
含義:訓練模型前是否需要對訓練特徵進行標準化處理。
threshold:
類型:雙精度型。
含義:二分類預測的閥值,範圍[0,1]。
thresholds:
類型:雙精度數組型。
含義:多分類預測的閥值,以調整預測結果在各個類別的概率。
tol:
類型:雙精度型。
含義:迭代演算法的收斂性。
weightCol:
類型:字元串型。
含義:列權重。
示例:
下面的例子展示如何訓練使用彈性網路正則化的邏輯回歸模型。elasticNetParam對應於,regParam對應於。
Scala:
import org.apache.spark.ml.classification.LogisticRegressionnn// Load training datanval training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")nnval lr = new LogisticRegression()n .setMaxIter(10)n .setRegParam(0.3)n .setElasticNetParam(0.8)nn// Fit the modelnval lrModel = lr.fit(training)nn// Print the coefficients and intercept for logistic regressionnprintln(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")n
Java:
import org.apache.spark.ml.classification.LogisticRegression;nimport org.apache.spark.ml.classification.LogisticRegressionModel;nimport org.apache.spark.sql.Dataset;nimport org.apache.spark.sql.Row;nimport org.apache.spark.sql.SparkSession;nn// Load training datanDataset<Row> training = spark.read().format("libsvm")n .load("data/mllib/sample_libsvm_data.txt");nnLogisticRegression lr = new LogisticRegression()n .setMaxIter(10)n .setRegParam(0.3)n .setElasticNetParam(0.8);nn// Fit the modelnLogisticRegressionModel lrModel = lr.fit(training);nn// Print the coefficients and intercept for logistic regressionnSystem.out.println("Coefficients: "n + lrModel.coefficients() + " Intercept: " + lrModel.intercept());n
from pyspark.ml.classification import LogisticRegressionnn# Load training datantraining = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")nnlr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)nn# Fit the modelnlrModel = lr.fit(training)nn# Print the coefficients and intercept for logistic regressionnprint("Coefficients: " + str(lrModel.coefficients))nprint("Intercept: " + str(lrModel.intercept))n
Python:
from pyspark.ml.classification import LogisticRegressionnn# Load training datantraining = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")nnlr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)nn# Fit the modelnlrModel = lr.fit(training)nn# Print the coefficients and intercept for logistic regressionnprint("Coefficients: " + str(lrModel.coefficients))nprint("Intercept: " + str(lrModel.intercept))n
spark.ml邏輯回歸工具同樣支持提取模總結。LogisticRegressionTrainingSummary提供LogisticRegressionModel的總結。目前僅支持二分類問題,所以總結必須明確投擲到BinaryLogisticRegressionTrainingSummary。支持多分類問題後可能有所改善。
繼續上面的例子:
Scala:
import org.apache.spark.ml.Pipelinenimport org.apache.spark.ml.classification.DecisionTreeClassificationModelnimport org.apache.spark.ml.classification.DecisionTreeClassifiernimport org.apache.spark.ml.evaluation.MulticlassClassificationEvaluatornimport org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}nn// Load the data stored in LIBSVM format as a DataFrame.nval data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")nn// Index labels, adding metadata to the label column.n// Fit on whole dataset to include all labels in index.nval labelIndexer = new StringIndexer()n .setInputCol("label")n .setOutputCol("indexedLabel")n .fit(data)n// Automatically identify categorical features, and index them.nval featureIndexer = new VectorIndexer()n .setInputCol("features")n .setOutputCol("indexedFeatures")n .setMaxCategories(4) // features with > 4 distinct values are treated as continuous.n .fit(data)nn// Split the data into training and test sets (30% held out for testing).nval Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))nn// Train a DecisionTree model.nval dt = new DecisionTreeClassifier()n .setLabelCol("indexedLabel")n .setFeaturesCol("indexedFeatures")nn// Convert indexed labels back to original labels.nval labelConverter = new IndexToString()n .setInputCol("prediction")n .setOutputCol("predictedLabel")n .setLabels(labelIndexer.labels)nn// Chain indexers and tree in a Pipeline.nval pipeline = new Pipeline()n .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))nn// Train model. This also runs the indexers.nval model = pipeline.fit(trainingData)nn// Make predictions.nval predictions = model.transform(testData)nn// Select example rows to display.npredictions.select("predictedLabel", "label", "features").show(5)nn// Select (prediction, true label) and compute test error.nval evaluator = new MulticlassClassificationEvaluator()n .setLabelCol("indexedLabel")n .setPredictionCol("prediction")n .setMetricName("accuracy")nval accuracy = evaluator.evaluate(predictions)nprintln("Test Error = " + (1.0 - accuracy))nnval treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]nprintln("Learned classification tree model:n" + treeModel.toDebugString)n
Java:
import org.apache.spark.ml.Pipeline;nimport org.apache.spark.ml.PipelineModel;nimport org.apache.spark.ml.PipelineStage;nimport org.apache.spark.ml.classification.DecisionTreeClassifier;nimport org.apache.spark.ml.classification.DecisionTreeClassificationModel;nimport org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;nimport org.apache.spark.ml.feature.*;nimport org.apache.spark.sql.Dataset;nimport org.apache.spark.sql.Row;nimport org.apache.spark.sql.SparkSession;nn// Load the data stored in LIBSVM format as a DataFrame.nDataset<Row> data = sparkn .read()n .format("libsvm")n .load("data/mllib/sample_libsvm_data.txt");nn// Index labels, adding metadata to the label column.n// Fit on whole dataset to include all labels in index.nStringIndexerModel labelIndexer = new StringIndexer()n .setInputCol("label")n .setOutputCol("indexedLabel")n .fit(data);nn// Automatically identify categorical features, and index them.nVectorIndexerModel featureIndexer = new VectorIndexer()n .setInputCol("features")n .setOutputCol("indexedFeatures")n .setMaxCategories(4) // features with > 4 distinct values are treated as continuous.n .fit(data);nn// Split the data into training and test sets (30% held out for testing).nDataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});nDataset<Row> trainingData = splits[0];nDataset<Row> testData = splits[1];nn// Train a DecisionTree model.nDecisionTreeClassifier dt = new DecisionTreeClassifier()n .setLabelCol("indexedLabel")n .setFeaturesCol("indexedFeatures");nn// Convert indexed labels back to original labels.nIndexToString labelConverter = new IndexToString()n .setInputCol("prediction")n .setOutputCol("predictedLabel")n .setLabels(labelIndexer.labels());nn// Chain indexers and tree in a Pipeline.nPipeline pipeline = new Pipeline()n .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter});nn// Train model. This also runs the indexers.nPipelineModel model = pipeline.fit(trainingData);nn// Make predictions.nDataset<Row> predictions = model.transform(testData);nn// Select example rows to display.npredictions.select("predictedLabel", "label", "features").show(5);nn// Select (prediction, true label) and compute test error.nMulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()n .setLabelCol("indexedLabel")n .setPredictionCol("prediction")n .setMetricName("accuracy");ndouble accuracy = evaluator.evaluate(predictions);nSystem.out.println("Test Error = " + (1.0 - accuracy));nnDecisionTreeClassificationModel treeModel =n (DecisionTreeClassificationModel) (model.stages()[2]);nSystem.out.println("Learned classification tree model:n" + treeModel.toDebugString());n
推薦閱讀:
※一個演算法工程師的日常是怎樣的?
※AlphaGo Zero 的出現在意料之中!
※馬斯克:如果機器取代人類工作,那人類靠什麼生存?
※誤差分解以及為什麼DL這麼萬能
※帶你讀機器學習經典(三): Python機器學習(Chapter 1&2)