Java 机器学习库
所属分类 quant
浏览量 6
除了 Weka,Java 生态中还有多款功能丰富、定位不同的机器学习库,
覆盖传统机器学习、深度学习、分布式计算、专用算法等场景
Java 机器学习库
Spark MLlib
Deeplearning4j(DL4J)
Smile(Statistical Machine Intelligence and Learning Engine)
大数据场景:优先选 Spark MLlib(分布式能力强,API 易上手);
深度学习 + 传统 ML:选 DL4J(一站式解决,无需跨语言);
中小型数据集、高性能要求:选 Smile(轻量、速度快);
这些库均能覆盖 sklearn 的核心功能,且完全适配 Java 生态,可根据项目规模、场景复杂度选择。
其他辅助库 Apache Commons Math,ND4J(数值计算库)
1. Apache Spark MLlib
核心特点:
基于 Spark 的分布式机器学习库,支持大规模数据处理,
API 风格接近 sklearn,提供 Pipeline、特征转换、经典算法全覆盖;
适用场景:
大数据量(TB 级)的机器学习任务,分布式训练 / 预测;
核心算法:
分类(逻辑回归、随机森林、SVM)、回归、聚类(K-Means、DBSCAN)、
推荐(ALS)、特征工程(TF-IDF、标准化、独热编码);
极简示例(决策树分类):
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
public class SparkMLlibDemo {
public static void main(String[] args) {
// 初始化SparkSession
SparkSession spark = SparkSession.builder()
.appName("SparkMLlibDemo")
.master("local[*]") // 本地运行
.getOrCreate();
// 加载数据(CSV格式,对标sklearn load_csv)
Dataset data = spark.read().format("csv")
.option("header", "true")
.option("inferSchema", "true")
.load("data/iris.csv");
// 特征组装(对标sklearn的ColumnTransformer)
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[]{"sepal_length", "sepal_width", "petal_length", "petal_width"})
.setOutputCol("features");
Dataset featuresData = assembler.transform(data).withColumnRenamed("species", "label");
// 训练决策树(对标sklearn DecisionTreeClassifier)
DecisionTreeClassifier dt = new DecisionTreeClassifier()
.setLabelCol("label")
.setFeaturesCol("features");
DecisionTreeClassificationModel model = dt.fit(featuresData);
// 预测(对标sklearn predict)
model.transform(featuresData).select("features", "label", "prediction").show(5);
spark.stop();
}
}
org.apache.spark
spark-mllib_2.12
3.5.0
2. Deeplearning4j(DL4J)
核心特点:
Java/Scala 原生的深度学习库,
同时支持传统机器学习算法(如随机森林、逻辑回归),可与 ND4J(数值计算库)无缝集成;
适用场景:需要同时处理深度学习和传统 ML 的 Java 项目;
核心算法:DNN、CNN、RNN(深度学习),随机森林、梯度提升树(传统 ML);
ND4J N-Dimensional Arrays for Java ,
可以看作是 Java 生态中的 NumPy
极简示例(逻辑回归):
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class DL4JDemo {
public static void main(String[] args) throws Exception {
// 加载鸢尾花数据集(对标sklearn load_iris)
IrisDataSetIterator iterator = new IrisDataSetIterator(150, 150);
// 构建逻辑回归模型(单层输出层)
MultiLayerNetwork model = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.weightInit(WeightInit.XAVIER)
.learningRate(0.1)
.list()
.layer(new DenseLayer.Builder().nIn(4).nOut(10).activation(Activation.RELU).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nIn(10).nOut(3).activation(Activation.SOFTMAX).build())
.build());
model.init();
// 训练模型
model.fit(iterator);
// 评估准确率
double accuracy = model.evaluate(iterator).accuracy();
System.out.println("准确率:" + accuracy);
}
}
Maven 依赖:
org.deeplearning4j
deeplearning4j-core
1.0.0-M2.1
org.nd4j
nd4j-native-platform
1.0.0-M2.1
3. Smile(Statistical Machine Intelligence and Learning Engine)
核心特点:
轻量级、高性能的 Java ML 库,API 简洁,覆盖传统 ML 全场景,速度优于 Weka;
适用场景:中小型数据集、对性能要求高的 Java 端 ML 任务;
核心算法:分类(SVM、决策树)、回归、聚类、降维(PCA)、特征选择;
极简示例(SVM 分类):
import smile.classification.SVM;
import smile.feature.extraction.PCA;
import smile.io.Read;
import smile.math.kernel.GaussianKernel;
import smile.data.DataFrame;
import smile.data.formula.Formula;
public class SmileDemo {
public static void main(String[] args) throws Exception {
// 加载数据
DataFrame iris = Read.csv("data/iris.csv");
Formula formula = Formula.lhs("species");
double[][] x = formula.x(iris).toArray();
int[] y = formula.y(iris).toIntArray();
// PCA降维(对标sklearn PCA)
PCA pca = PCA.fit(x).setDimension(2);
double[][] xPCA = pca.transform(x);
// 训练SVM(对标sklearn SVC)
SVM svm = new SVM<>(new GaussianKernel(1.0), xPCA, y);
svm.train();
// 预测
int pred = svm.predict(xPCA[0]);
System.out.println("预测类别:" + pred);
}
}
Maven 依赖:
com.github.haifengl
smile-core
2.6.0
com.github.haifengl
smile-io
2.6.0
4. ELKI(环境感知知识发现)
核心特点:
专注于聚类和异常检测,算法丰富且可配置性极强,支持自定义距离函数、聚类参数;
适用场景:
学术研究、高精度聚类 / 异常检测任务(如欺诈识别、离群点分析);
核心算法:
K-Means、DBSCAN、OPTICS、HDBSCAN(聚类),LOF、OCSVM(异常检测);
极简示例(DBSCAN 聚类):
import de.lmu.ifi.dbs.elki.data.DoubleVector;
import de.lmu.ifi.dbs.elki.data.model.ClusteringModel;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.StaticArrayDatabase;
import de.lmu.ifi.dbs.elki.database.connection.FileBasedDatabaseConnection;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.algorithm.clustering.DBSCAN;
import de.lmu.ifi.dbs.elki.datasource.parser.CSVReaderFormat;
import de.lmu.ifi.dbs.elki.distance.distancefunction.minkowski.EuclideanDistanceFunction;
public class ELKIDemo {
public static void main(String[] args) {
// 加载CSV数据
FileBasedDatabaseConnection dbc = new FileBasedDatabaseConnection(
new CSVReaderFormat().withSeparator(','), "data/cluster_data.csv");
Database db = new StaticArrayDatabase(dbc, null);
db.initialize();
// 运行DBSCAN聚类(ε=0.5,最小点数=5)
DBSCAN dbscan = new DBSCAN<>(EuclideanDistanceFunction.STATIC, 0.5, 5);
ClusteringModel> result = dbscan.run(db);
// 输出聚类结果
Relation rel = db.getRelation(DoubleVector.FIELD);
System.out.println("聚类数量:" + result.getAllClusters().size());
}
}
Maven 依赖:
de.lmu.ifi.dbs.elki
elki-core
0.7.5
de.lmu.ifi.dbs.elki
elki-clustering
0.7.5
轻量化工具库(辅助 ML 流程)
这类库不提供完整 ML 算法,但能支撑特征工程、模型评估等核心环节:
1. Apache Commons Math
核心特点:Java 基础数值计算库,提供统计、线性代数、优化、概率分布等功能;
适用场景:自定义 ML 算法、基础数据预处理(标准化、归一化);
核心功能:均值 / 方差计算、最小二乘拟合、矩阵运算、随机数生成。
上一篇
springboot3 应用启动 报错 找不到 com.mybatisflex.core.service.IService
python量化项目
箱体理论及量化实战