public abstract class Classifier<FeaturesType,E extends Classifier<FeaturesType,E,M>,M extends ClassificationModel<FeaturesType,M>> extends Predictor<FeaturesType,E,M>
Single-label binary or multiclass classification. Classes are indexed {0, 1, ..., numClasses - 1}.
Constructor and Description |
---|
Classifier() |
Modifier and Type | Method and Description |
---|---|
protected RDD<LabeledPoint> | extractLabeledPoints(Dataset<?> dataset, int numClasses) - Extract labelCol and featuresCol from the given dataset, and put them in an RDD with strong types. |
Param<java.lang.String> | featuresCol() - Param for features column name. |
java.lang.String | getFeaturesCol() |
java.lang.String | getLabelCol() |
protected int | getNumClasses(Dataset<?> dataset, int maxNumClasses) - Get the number of classes. |
java.lang.String | getPredictionCol() |
java.lang.String | getRawPredictionCol() |
Param<java.lang.String> | labelCol() - Param for label column name. |
Param<java.lang.String> | predictionCol() - Param for prediction column name. |
Param<java.lang.String> | rawPredictionCol() - Param for raw prediction (a.k.a. confidence) column name. |
E | setRawPredictionCol(java.lang.String value) |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) - Validates and transforms the input schema with the provided param map. |
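Methods inherited from class org.apache.spark.ml.Predictor: copy, extractLabeledPoints, fit, setFeaturesCol, setLabelCol, setPredictionCol, train, transformSchema
Methods inherited from class org.apache.spark.ml.PipelineStage: transformSchema
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface org.apache.spark.ml.param.Params: clear, copy, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams
Methods inherited from interface org.apache.spark.ml.util.Identifiable: toString, uid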
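Parameter handling comes from the inherited Params methods (set, setDefault, isSet, isDefined, getOrDefault). The column getters in the summary, such as getLabelCol(), resolve a param the same way: an explicitly set value wins over a default, and a param with neither is an error. The sketch below is a hypothetical, dependency-free model of that lookup order. `SimpleParamMap` is an illustrative name, not a Spark class; real Spark params are typed `Param<T>` objects owned by a uid, whereas here they are plain string keys. The values "label" and "features" mirror Spark's usual default column names.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;

// Hypothetical stand-in for the lookup order of Spark's Params:
// an explicitly set value takes precedence over a default value.
final class SimpleParamMap {
    private final Map<String, String> defaults = new HashMap<>();
    private final Map<String, String> explicit = new HashMap<>();

    SimpleParamMap setDefault(String param, String value) {
        defaults.put(param, value);
        return this;
    }

    SimpleParamMap set(String param, String value) {
        explicit.put(param, value);
        return this;
    }

    // True only when the user set the param explicitly.
    boolean isSet(String param) {
        return explicit.containsKey(param);
    }

    // True when the param has either an explicit value or a default.
    boolean isDefined(String param) {
        return explicit.containsKey(param) || defaults.containsKey(param);
    }

    // Explicit value first, then default, otherwise an error.
    String getOrDefault(String param) {
        if (explicit.containsKey(param)) {
            return explicit.get(param);
        }
        if (defaults.containsKey(param)) {
            return defaults.get(param);
        }
        throw new NoSuchElementException("Param " + param + " has no value and no default");
    }
}

class ParamMapDemo {
    public static void main(String[] args) {
        SimpleParamMap params = new SimpleParamMap()
                .setDefault("labelCol", "label")
                .setDefault("featuresCol", "features");
        params.set("labelCol", "target");
        System.out.println(params.getOrDefault("labelCol"));    // target (explicit value)
        System.out.println(params.getOrDefault("featuresCol")); // features (default)
        System.out.println(params.isSet("featuresCol"));        // false
    }
}
```

Keeping defaults and explicit values in separate maps is what makes isSet and isDefined distinguishable.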
public E setRawPredictionCol(java.lang.String value)
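The type parameter E in the class declaration (E extends Classifier<FeaturesType,E,M>) is a self type: it lets setRawPredictionCol, declared once on the abstract base, return the concrete subclass, so setter calls can be chained without casts. The sketch below reproduces the pattern in plain Java with the generics reduced to the self type only. `BaseEstimator`, `MyClassifier` and `setMaxIter` are hypothetical names, and the sketch returns the subclass through an abstract self() method rather than an unchecked cast.

```java
// Hypothetical base class using an F-bounded ("self") type parameter,
// the same idea as E in Classifier<FeaturesType, E, M>.
abstract class BaseEstimator<E extends BaseEstimator<E>> {
    private String rawPredictionCol = "rawPrediction";

    // Each concrete subclass returns itself with its own static type.
    protected abstract E self();

    // Declared once on the base, yet typed to return the subclass E.
    public E setRawPredictionCol(String value) {
        this.rawPredictionCol = value;
        return self();
    }

    public String getRawPredictionCol() {
        return rawPredictionCol;
    }
}

// Hypothetical concrete estimator with one extra setter of its own.
final class MyClassifier extends BaseEstimator<MyClassifier> {
    private int maxIter = 10;

    @Override
    protected MyClassifier self() {
        return this;
    }

    public MyClassifier setMaxIter(int value) {
        this.maxIter = value;
        return this;
    }

    public int getMaxIter() {
        return maxIter;
    }
}

class SelfTypeDemo {
    public static void main(String[] args) {
        // setRawPredictionCol returns MyClassifier, so setMaxIter can follow directly.
        MyClassifier clf = new MyClassifier()
                .setRawPredictionCol("scores")
                .setMaxIter(50);
        System.out.println(clf.getRawPredictionCol() + " " + clf.getMaxIter()); // scores 50
    }
}
```

Without the self type, setRawPredictionCol would return the base type and the chained setMaxIter call would not compile.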
protected RDD<LabeledPoint> extractLabeledPoints(Dataset<?> dataset, int numClasses)
Extract labelCol and featuresCol from the given dataset, and put them in an RDD with strong types.
Parameters:
dataset - DataFrame with columns for labels (NumericType) and features (Vector). Labels are cast to DoubleType.
numClasses - Number of classes the label can take. Labels must be integers in the range [0, numClasses).
Throws:
SparkException - if any label is not an integer >= 0
protected int getNumClasses(Dataset<?> dataset, int maxNumClasses)
Get the number of classes. This looks in the label column's metadata first; if numClasses is not specified there, it assumes classes are indexed 0, 1, ..., numClasses - 1 and computes numClasses from the maximum label value. Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere, such as in extractLabeledPoints().
Parameters:
dataset - Dataset which contains a column labelCol
maxNumClasses - Maximum number of classes allowed when inferred from data. If numClasses is specified in the metadata, then maxNumClasses is ignored.
Throws:
java.lang.IllegalArgumentException - if metadata does not specify numClasses, and the actual numClasses exceeds maxNumClasses
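Together, extractLabeledPoints and getNumClasses enforce two label rules: every label must be a non-negative integer below numClasses, and when metadata does not give numClasses, it is inferred from the data and capped by maxNumClasses. The sketch below models both rules over a plain double[] of labels. `LabelRules`, `validateLabels` and `inferNumClasses` are hypothetical names; metadata is modelled as a nullable Integer, the inference rule (maximum label + 1) is an assumption consistent with 0-based class indices, and IllegalArgumentException stands in for Spark's SparkException so the code has no dependencies.

```java
import java.util.Arrays;

// Hypothetical helpers modelling the documented label rules. Not Spark code.
final class LabelRules {
    private LabelRules() {}

    // A label is valid when it is a finite integer in [0, numClasses).
    static void validateLabels(double[] labels, int numClasses) {
        for (double label : labels) {
            // NaN fails the equality test; infinities are rejected explicitly.
            boolean isInteger = label == Math.floor(label) && !Double.isInfinite(label);
            if (!isInteger || label < 0 || label >= numClasses) {
                throw new IllegalArgumentException(
                        "Label " + label + " is not an integer in [0, " + numClasses + ")");
            }
        }
    }

    // metadataNumClasses == null models "numClasses not present in column metadata".
    // Labels are not validated here; that is left to validateLabels.
    static int inferNumClasses(double[] labels, Integer metadataNumClasses, int maxNumClasses) {
        if (metadataNumClasses != null) {
            return metadataNumClasses; // metadata wins; maxNumClasses is ignored
        }
        if (labels.length == 0) {
            throw new IllegalArgumentException("Cannot infer numClasses from an empty dataset");
        }
        double maxLabel = Arrays.stream(labels).max().getAsDouble();
        // Assumption: 0-based class indices, so numClasses = max label + 1.
        double inferred = Math.floor(maxLabel) + 1;
        if (inferred > maxNumClasses) {
            throw new IllegalArgumentException("Inferred " + (long) inferred
                    + " classes, which exceeds maxNumClasses = " + maxNumClasses);
        }
        return (int) inferred;
    }
}

class LabelRulesDemo {
    public static void main(String[] args) {
        double[] labels = {0.0, 2.0, 1.0, 2.0};
        int k = LabelRules.inferNumClasses(labels, null, 100);
        LabelRules.validateLabels(labels, k);
        System.out.println("numClasses = " + k); // numClasses = 3
    }
}
```

Comparing against maxNumClasses before casting to int keeps a very large label from overflowing the inferred count.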
public Param<java.lang.String> rawPredictionCol()
public java.lang.String getRawPredictionCol()
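rawPredictionCol names the column that holds the raw prediction, assumed here to be one confidence score per class (larger means more confident). A common way to derive the predictionCol value from it is to take the index of the largest score, which always yields a class index in {0, 1, ..., numClasses - 1}. The sketch below assumes that argmax rule, although concrete models may define the mapping differently, and uses a plain double[] in place of Spark's Vector. `RawPredictionSketch.predict` is a hypothetical name.

```java
// Hypothetical helper: turn a raw prediction (one score per class) into a
// predicted class index by taking the argmax. Not Spark code.
final class RawPredictionSketch {
    private RawPredictionSketch() {}

    static int predict(double[] rawPrediction) {
        if (rawPrediction.length == 0) {
            throw new IllegalArgumentException("rawPrediction must have one entry per class");
        }
        int best = 0;
        for (int i = 1; i < rawPrediction.length; i++) {
            // Strict '>' keeps the lowest index when scores tie.
            if (rawPrediction[i] > rawPrediction[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] raw = {0.2, 1.7, -0.4}; // numClasses = 3
        System.out.println(predict(raw)); // 1
    }
}
```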
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether this call is made while fitting (training) rather than transforming
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.
public Param<java.lang.String> labelCol()
public java.lang.String getLabelCol()
public Param<java.lang.String> featuresCol()
public java.lang.String getFeaturesCol()
public Param<java.lang.String> predictionCol()
public java.lang.String getPredictionCol()
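validateAndTransformSchema works with the column names returned by these getters: it checks the input schema and returns the output schema before any data is read. The sketch below models one plausible version of that contract with a column-name-to-type map. The features column must have the expected type, the label column must be numeric only when fitting is true, and the prediction and raw prediction columns are appended, refusing to overwrite an existing column. These steps, the plain-string type names ("vector" standing in for VectorUDT), and `SchemaSketch` itself are assumptions for illustration; Spark's StructType and DataType are not used.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

// Hypothetical, dependency-free model of validateAndTransformSchema.
// A schema is an ordered map from column name to a type name.
final class SchemaSketch {
    // Assumed set of numeric type names accepted for the label column.
    private static final Set<String> NUMERIC = Set.of("int", "long", "float", "double");

    private final String labelCol;
    private final String featuresCol;
    private final String predictionCol;
    private final String rawPredictionCol;

    SchemaSketch(String labelCol, String featuresCol, String predictionCol, String rawPredictionCol) {
        this.labelCol = labelCol;
        this.featuresCol = featuresCol;
        this.predictionCol = predictionCol;
        this.rawPredictionCol = rawPredictionCol;
    }

    Map<String, String> validateAndTransformSchema(
            Map<String, String> schema, boolean fitting, String featuresDataType) {
        // The features column must exist with the expected type.
        String featuresType = schema.get(featuresCol);
        if (!featuresDataType.equals(featuresType)) {
            throw new IllegalArgumentException("Column " + featuresCol + " must be of type "
                    + featuresDataType + " but was " + featuresType);
        }
        // Labels are only required while fitting; at transform time they may be absent.
        String labelType = schema.get(labelCol);
        if (fitting && (labelType == null || !NUMERIC.contains(labelType))) {
            throw new IllegalArgumentException("Column " + labelCol + " must be numeric");
        }
        // Copy the input, then append the output columns.
        Map<String, String> out = new LinkedHashMap<>(schema);
        appendColumn(out, predictionCol, "double");
        appendColumn(out, rawPredictionCol, "vector");
        return out;
    }

    // Appends a column, refusing to overwrite an existing one.
    private static void appendColumn(Map<String, String> schema, String name, String type) {
        if (schema.containsKey(name)) {
            throw new IllegalArgumentException("Column " + name + " already exists");
        }
        schema.put(name, type);
    }

    public static void main(String[] args) {
        SchemaSketch s = new SchemaSketch("label", "features", "prediction", "rawPrediction");
        Map<String, String> input = new LinkedHashMap<>();
        input.put("label", "double");
        input.put("features", "vector");
        System.out.println(s.validateAndTransformSchema(input, true, "vector"));
        // {label=double, features=vector, prediction=double, rawPrediction=vector}
    }
}
```

The fitting flag is what allows the same check to serve both training data, which must carry labels, and scoring data, which usually does not.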