当前位置:首页 > JavaScript

js实现线性回归

2026-02-03 08:47:23JavaScript

线性回归的数学原理

线性回归模型的基本形式为: $$ y = w \cdot x + b $$

其中:

  • $y$ 是预测值
  • $x$ 是特征值
  • $w$ 是权重(斜率)
  • $b$ 是偏置(截距)

损失函数(均方误差): $$ L(w, b) = \frac{1}{2n} \sum_{i=1}^n (y_i - (w \cdot x_i + b))^2 $$

梯度下降实现

使用梯度下降法更新参数: $$ w := w - \alpha \cdot \frac{\partial L}{\partial w} $$ $$ b := b - \alpha \cdot \frac{\partial L}{\partial b} $$

偏导数计算: $$ \frac{\partial L}{\partial w} = -\frac{1}{n} \sum_{i=1}^n x_i (y_i - (w \cdot xi + b)) $$ $$ \frac{\partial L}{\partial b} = -\frac{1}{n} \sum{i=1}^n (y_i - (w \cdot x_i + b)) $$

JavaScript实现代码

class LinearRegression {
  constructor(learningRate = 0.01, iterations = 1000) {
    this.learningRate = learningRate;
    this.iterations = iterations;
    this.weights = null;
    this.bias = 0;
  }

  fit(X, y) {
    const n = X.length;
    const d = X[0].length;
    this.weights = Array(d).fill(0);

    for (let iter = 0; iter < this.iterations; iter++) {
      let gradW = Array(d).fill(0);
      let gradB = 0;

      for (let i = 0; i < n; i++) {
        const error = this.predict(X[i]) - y[i];

        for (let j = 0; j < d; j++) {
          gradW[j] += (1/n) * X[i][j] * error;
        }
        gradB += (1/n) * error;
      }

      for (let j = 0; j < d; j++) {
        this.weights[j] -= this.learningRate * gradW[j];
      }
      this.bias -= this.learningRate * gradB;
    }
  }

  predict(x) {
    let result = this.bias;
    for (let i = 0; i < x.length; i++) {
      result += x[i] * this.weights[i];
    }
    return result;
  }
}

// 使用示例
const X = [[1], [2], [3], [4]];  // 特征
const y = [2, 4, 6, 8];          // 标签

const model = new LinearRegression();
model.fit(X, y);

console.log(model.predict([5]));  // 预测结果接近10

使用矩阵运算优化

对于大数据集,可以使用矩阵运算提高效率:

class MatrixLinearRegression {
  constructor(learningRate = 0.01, iterations = 1000) {
    this.learningRate = learningRate;
    this.iterations = iterations;
    this.weights = null;
  }

  fit(X, y) {
    // 添加偏置列
    X = X.map(x => [1, ...x]);
    const n = X.length;
    const d = X[0].length;

    // 初始化权重
    this.weights = Array(d).fill(0);

    for (let iter = 0; iter < this.iterations; iter++) {
      // 计算预测值
      const predictions = X.map(x => {
        return x.reduce((sum, xi, i) => sum + xi * this.weights[i], 0);
      });

      // 计算梯度
      const errors = predictions.map((p, i) => p - y[i]);
      const gradients = Array(d).fill(0);

      for (let j = 0; j < d; j++) {
        gradients[j] = (1/n) * errors.reduce(
          (sum, err, i) => sum + err * X[i][j], 0
        );
      }

      // 更新权重
      for (let j = 0; j < d; j++) {
        this.weights[j] -= this.learningRate * gradients[j];
      }
    }
  }

  predict(x) {
    x = [1, ...x];
    return x.reduce((sum, xi, i) => sum + xi * this.weights[i], 0);
  }
}

使用现成库实现

推荐使用ml.js库实现更完整的线性回归:

js实现线性回归

// 安装:npm install ml-regression
const { LinearRegression } = require('ml-regression');

const X = [1, 2, 3, 4];
const y = [2, 4, 6, 8];

const regression = new LinearRegression(X, y);
console.log(regression.predict(5));  // 10
console.log(regression.toString(3)); // 显示模型方程

标签: 线性js
分享给朋友:

相关文章

js实现动画

js实现动画

使用 CSS 动画与 JavaScript 控制 通过 JavaScript 动态添加或移除 CSS 类来触发动画。CSS 定义关键帧(@keyframes),JavaScript 通过 classL…

js实现预览

js实现预览

文件上传预览实现 使用JavaScript实现文件上传预览功能,可以通过FileReader对象读取文件内容并显示预览。以下代码支持图片、文本和PDF文件的预览: // HTML部分需要包含inpu…

js实现点击显示和隐藏

js实现点击显示和隐藏

实现点击显示和隐藏的JavaScript方法 使用classList.toggle切换类名 通过添加/移除CSS类控制元素的显示与隐藏,需提前在CSS中定义隐藏样式(如display: none)。…

js实现自举

js实现自举

实现 JavaScript 自举的方法 JavaScript 自举(Bootstrap)通常指通过 JavaScript 动态加载自身或依赖的脚本,实现运行时环境的初始化或扩展。以下是几种常见实现方式…

js实现视口

js实现视口

js实现视口检测的方法 使用JavaScript检测元素是否进入视口(viewport)可以通过Intersection Observer API或手动计算元素位置实现。以下是两种常见方法: Int…

js实现下拉菜单

js实现下拉菜单

使用HTML和CSS创建基础结构 HTML部分需要包含一个触发下拉的按钮和隐藏的下拉菜单内容: <div class="dropdown"> <button class="dr…