$$ \newcommand{\R}{\mathbb{R}} \newcommand{\E}{\mathbb{E}} \newcommand{\x}{\mathbf{x}} \newcommand{\y}{\mathbf{y}} \newcommand{\wv}{\mathbf{w}} \newcommand{\av}{\mathbf{\alpha}} \newcommand{\bv}{\mathbf{b}} \newcommand{\N}{\mathbb{N}} \newcommand{\id}{\mathbf{I}} \newcommand{\ind}{\mathbf{1}} \newcommand{\0}{\mathbf{0}} \newcommand{\unit}{\mathbf{e}} \newcommand{\one}{\mathbf{1}} \newcommand{\zero}{\mathbf{0}} \newcommand\rfrac[2]{^{#1}\!/_{#2}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} $$

K近邻算法(KNN)

描述

实现一个精密K最近邻居连接(exact k-nearest neighbors join)算法。假设有一个训练集 $A$ 和一个测试集 $B$,该算法返回

该暴力方法的目标是计算每个训练点和测试点之间的距离。为了使计算每个训练点之间距离的暴力计算过程更加简化和平滑,本方法使用一个四叉树。该四叉树在训练点的数量上有很好的扩展性,但是在空间维度上的扩展性表现不佳。本算法会自动选择是否采用该四叉树,用户也可以通过设置一个参数来覆盖算法的决定,强制指定是否使用该四叉树。

操作

KNN 是一个 Predictor。 正如所示, 它支持 fit 和 predict 操作。

拟合

KNN 通过一个给定的 Vector 集来训练:

  • fit[T <: Vector]: DataSet[T] => Unit

预测

KNN 为所有的FlinkML的 Vector 的子类预测对应的类别标签:

  • predict[T <: Vector]: DataSet[T] => DataSet[(T, Array[Vector])], 这里 (T, Array[Vector]) 元组对应 (test point, K-nearest training points)

参数

KNN的实现可以由以下参数控制:

               
参数描述
K

定义要搜索的最近邻居数量。也就是说,对于每一个测试点,该算法会从训练集中找到K个最近邻居.(默认值: 5)

DistanceMetric

设置用来计算两点之间距离的度量标准。如果没有指定度量标准,则[[org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric]] 被使用.(默认值: EuclideanDistanceMetric)

Blocks

设置输入数据将会被切分的块数。该数目至少应该被设置成与并行度相等。如果没有指定块数,则使用作为输入的 [[DataSet]] 的平行度作为块数.(默认值: None)

UseQuadTree

一个布尔参数,该参数用来指定是否使用能够对训练集进行分区,并且有可能简化平滑KNN搜索的四叉树。如果该值没有指定,则代码会自动决定是否使用一个四叉树。四叉树的使用在训练点和测试点的数量上有很好的扩展性,但在维度上的扩展性表现不佳.(默认值: None)

SizeHint          

指定训练集或测试集是否小到能优化KNN搜索所需的向量乘操作。如果训练集小,该值应该是 `CrossHint.FIRST_IS_SMALL`,如果测试集小,则设置成 `CrossHint.SECOND_IS_SMALL`.(默认值: None)

示例

import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
import org.apache.flink.api.scala._
import org.apache.flink.ml.nn.KNN
import org.apache.flink.ml.math.Vector
import org.apache.flink.ml.metrics.distances.SquaredEuclideanDistanceMetric

val env = ExecutionEnvironment.getExecutionEnvironment

// 准备数据
val trainingSet: DataSet[Vector] = ...
val testingSet: DataSet[Vector] = ...

val knn = KNN()
  .setK(3)
  .setBlocks(10)
  .setDistanceMetric(SquaredEuclideanDistanceMetric())
  .setUseQuadTree(false)
  .setSizeHint(CrossHint.SECOND_IS_SMALL)

// 运行 knn join
knn.fit(trainingSet)
val result = knn.predict(testingSet).collect()

关于使用和不使用四叉树计算KNN的更多细节,参照该介绍: http://danielblazevski.github.io/