首页  

Java 机器学习库     所属分类 quant 浏览量 6
除了 Weka,Java 生态中还有多款功能丰富、定位不同的机器学习库,
覆盖传统机器学习、深度学习、分布式计算、专用算法等场景


Java 机器学习库 
Spark MLlib 
Deeplearning4j(DL4J) 
Smile(Statistical Machine Intelligence and Learning Engine)
大数据场景:优先选 Spark MLlib(分布式能力强,API 易上手);
深度学习 + 传统 ML:选 DL4J(一站式解决,无需跨语言);
中小型数据集、高性能要求:选 Smile(轻量、速度快);
这些库均能覆盖 sklearn 的核心功能,且完全适配 Java 生态,可根据项目规模、场景复杂度选择。
其他辅助库 Apache Commons Math,ND4J(数值计算库)




1. Apache Spark MLlib
核心特点:
基于 Spark 的分布式机器学习库,支持大规模数据处理,
API 风格接近 sklearn,提供 Pipeline、特征转换、经典算法全覆盖;

适用场景:
大数据量(TB 级)的机器学习任务,分布式训练 / 预测;

核心算法:
分类(逻辑回归、随机森林、SVM)、回归、聚类(K-Means、DBSCAN)、
推荐(ALS)、特征工程(TF-IDF、标准化、独热编码);

极简示例(决策树分类):

import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class SparkMLlibDemo {
    public static void main(String[] args) {
        // 初始化SparkSession
        SparkSession spark = SparkSession.builder()
                .appName("SparkMLlibDemo")
                .master("local[*]") // 本地运行
                .getOrCreate();

        // 加载数据(CSV格式,对标sklearn load_csv)
        Dataset<Row> data = spark.read().format("csv")
                .option("header", "true")
                .option("inferSchema", "true")
                .load("data/iris.csv");

        // 特征组装(对标sklearn的ColumnTransformer)
        VectorAssembler assembler = new VectorAssembler()
                .setInputCols(new String[]{"sepal_length", "sepal_width", "petal_length", "petal_width"})
                .setOutputCol("features");
        Dataset<Row> featuresData = assembler.transform(data).withColumnRenamed("species", "label");

        // 训练决策树(对标sklearn DecisionTreeClassifier)
        DecisionTreeClassifier dt = new DecisionTreeClassifier()
                .setLabelCol("label")
                .setFeaturesCol("features");
        DecisionTreeClassificationModel model = dt.fit(featuresData);

        // 预测(对标sklearn predict)
        model.transform(featuresData).select("features", "label", "prediction").show(5);

        spark.stop();
    }
}


<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.12</artifactId>
    <version>3.5.0</version>
</dependency>


2. Deeplearning4j(DL4J)
核心特点:
Java/Scala 原生的深度学习库,
同时支持传统机器学习算法(如随机森林、逻辑回归),可与 ND4J(数值计算库)无缝集成;

适用场景:需要同时处理深度学习和传统 ML 的 Java 项目;
核心算法:DNN、CNN、RNN(深度学习),随机森林、梯度提升树(传统 ML);

ND4J   N-Dimensional Arrays for Java ,
可以看作是 Java 生态中的 NumPy


极简示例(逻辑回归):

import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class DL4JDemo {
    public static void main(String[] args) throws Exception {
        // 加载鸢尾花数据集(对标sklearn load_iris)
        IrisDataSetIterator iterator = new IrisDataSetIterator(150, 150);

        // 构建逻辑回归模型(单层输出层)
        MultiLayerNetwork model = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .weightInit(WeightInit.XAVIER)
                .learningRate(0.1)
                .list()
                .layer(new DenseLayer.Builder().nIn(4).nOut(10).activation(Activation.RELU).build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(10).nOut(3).activation(Activation.SOFTMAX).build())
                .build());
        model.init();

        // 训练模型
        model.fit(iterator);

        // 评估准确率
        double accuracy = model.evaluate(iterator).accuracy();
        System.out.println("准确率:" + accuracy);
    }
}

Maven 依赖:
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-M2.1</version>
</dependency>
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>1.0.0-M2.1</version>
</dependency>




3. Smile(Statistical Machine Intelligence and Learning Engine)
核心特点:
轻量级、高性能的 Java ML 库,API 简洁,覆盖传统 ML 全场景,速度优于 Weka;

适用场景:中小型数据集、对性能要求高的 Java 端 ML 任务;
核心算法:分类(SVM、决策树)、回归、聚类、降维(PCA)、特征选择;

极简示例(SVM 分类):

import smile.classification.SVM;
import smile.feature.extraction.PCA;
import smile.io.Read;
import smile.math.kernel.GaussianKernel;
import smile.data.DataFrame;
import smile.data.formula.Formula;

public class SmileDemo {
    public static void main(String[] args) throws Exception {
        // 加载数据
        DataFrame iris = Read.csv("data/iris.csv");
        Formula formula = Formula.lhs("species");
        double[][] x = formula.x(iris).toArray();
        int[] y = formula.y(iris).toIntArray();

        // PCA降维(对标sklearn PCA)
        PCA pca = PCA.fit(x).setDimension(2);
        double[][] xPCA = pca.transform(x);

        // 训练SVM(对标sklearn SVC)
        SVM<double[]> svm = new SVM<>(new GaussianKernel(1.0), xPCA, y);
        svm.train();

        // 预测
        int pred = svm.predict(xPCA[0]);
        System.out.println("预测类别:" + pred);
    }
}

Maven 依赖:
<dependency>
    <groupId>com.github.haifengl</groupId>
    <artifactId>smile-core</artifactId>
    <version>2.6.0</version>
</dependency>
<dependency>
    <groupId>com.github.haifengl</groupId>
    <artifactId>smile-io</artifactId>
    <version>2.6.0</version>
</dependency>




4. ELKI(环境感知知识发现)
核心特点:
专注于聚类和异常检测,算法丰富且可配置性极强,支持自定义距离函数、聚类参数;

适用场景:
学术研究、高精度聚类 / 异常检测任务(如欺诈识别、离群点分析);

核心算法:
K-Means、DBSCAN、OPTICS、HDBSCAN(聚类),LOF、OCSVM(异常检测);

极简示例(DBSCAN 聚类):

import de.lmu.ifi.dbs.elki.data.DoubleVector;
import de.lmu.ifi.dbs.elki.data.model.ClusteringModel;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.StaticArrayDatabase;
import de.lmu.ifi.dbs.elki.database.connection.FileBasedDatabaseConnection;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.algorithm.clustering.DBSCAN;
import de.lmu.ifi.dbs.elki.datasource.parser.CSVReaderFormat;
import de.lmu.ifi.dbs.elki.distance.distancefunction.minkowski.EuclideanDistanceFunction;

public class ELKIDemo {
    public static void main(String[] args) {
        // 加载CSV数据
        FileBasedDatabaseConnection dbc = new FileBasedDatabaseConnection(
                new CSVReaderFormat().withSeparator(','), "data/cluster_data.csv");
        Database db = new StaticArrayDatabase(dbc, null);
        db.initialize();

        // 运行DBSCAN聚类(ε=0.5,最小点数=5)
        DBSCAN<DoubleVector> dbscan = new DBSCAN<>(EuclideanDistanceFunction.STATIC, 0.5, 5);
        ClusteringModel<?> result = dbscan.run(db);

        // 输出聚类结果
        Relation<DoubleVector> rel = db.getRelation(DoubleVector.FIELD);
        System.out.println("聚类数量:" + result.getAllClusters().size());
    }
}

Maven 依赖:
<dependency>
    <groupId>de.lmu.ifi.dbs.elki</groupId>
    <artifactId>elki-core</artifactId>
    <version>0.7.5</version>
</dependency>
<dependency>
    <groupId>de.lmu.ifi.dbs.elki</groupId>
    <artifactId>elki-clustering</artifactId>
    <version>0.7.5</version>
</dependency>






轻量化工具库(辅助 ML 流程) 这类库不提供完整 ML 算法,但能支撑特征工程、模型评估等核心环节: 1. Apache Commons Math 核心特点:Java 基础数值计算库,提供统计、线性代数、优化、概率分布等功能; 适用场景:自定义 ML 算法、基础数据预处理(标准化、归一化); 核心功能:均值 / 方差计算、最小二乘拟合、矩阵运算、随机数生成。

上一篇    
springboot3 应用启动 报错 找不到 com.mybatisflex.core.service.IService

python量化项目

箱体理论及量化实战