最具影响力的数字化技术在线社区

168大数据

 找回密码
 立即注册

QQ登录

只需一步,快速开始

1 2 3 4 5
打印 上一主题 下一主题
开启左侧

Spark范例:K-means算法

[复制链接]
跳转到指定楼层
楼主
发表于 2014-10-24 17:08:38 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式

马上注册,结交更多数据大咖,获取更多知识干货,轻松玩转大数据

您需要 登录 才可以下载或查看,没有帐号?立即注册

x

k-means 算法接受输入量 k ;然后将n个数据对象划分为 k个聚类以便使得所获得的聚类满足:同一聚类中的对象相似度较高;而不同聚类中的对象相似度较小。聚类相似度是利用各聚类中对象的均值所获得一个“中心对象”(引力中心)来进行计算的。

算法首先会随机确定K个中心位置(位于空间中代表聚类中心的点),然后将各个数据项分配给最临近的中心点。待分配完成之后,聚类中心就会移到分配给该聚类的所有节点的平均位置处,然后整个分配过程重新开始。这个过程会一直重复下去,直到分配过程不再产生变化为止。下图是涉及5个数据项和2个聚类的过程。

Spark的例子里也提供了K-means算法的实现,我加了注释方便理解:

[AppleScript] 纯文本查看 复制代码
package spark.examples
 
import java.util.Random
import spark.SparkContext
import spark.SparkContext._
import spark.examples.Vector._
 
object SparkKMeans {
  /**
   * line -> vector
   */
  def parseVector(line: String): Vector = {
    return new Vector(line.split(' ').map(_.toDouble))
  }
 
  /**
   * 计算该节点的最近中心节点
   */
  def closestCenter(p: Vector, centers: Array[Vector]): Int = {
    var bestIndex = 0
    var bestDist = p.squaredDist(centers(0))//差平方之和
    for (i <- 1 until centers.length) {
      val dist = p.squaredDist(centers(i))
      if (dist < bestDist) {
        bestDist = dist
        bestIndex = i
      }
    }
    return bestIndex
  }
 
  def main(args: Array[String]) {
    if (args.length < 3) {
      System.err.println("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters>")
      System.exit(1)
    }
    val sc = new SparkContext(args(0), "SparkKMeans")
    val lines = sc.textFile(args(1), args(5).toInt)
    val points = lines.map(parseVector(_)).cache() //文本中每行为一个节点,再将每个节点转换成Vector
    val dimensions = args(2).toInt//节点的维度
    val k = args(3).toInt //聚类个数
    val iterations = args(4).toInt//迭代次数
 
    // 随机初始化k个中心节点
    val rand = new Random(42)
    var centers = new Array[Vector](k)
    for (i <- 0 until k)
      centers(i) = Vector(dimensions, _ => 2 * rand.nextDouble - 1)
    println("Initial centers: " + centers.mkString(", "))
    val time1 = System.currentTimeMillis()
    for (i <- 1 to iterations) {
      println("On iteration " + i)
 
      // Map each point to the index of its closest center and a (point, 1) pair
      // that we will use to compute an average later
      val mappedPoints = points.map { p => (closestCenter(p, centers), (p, 1)) }
 
      val newCenters = mappedPoints.reduceByKey {
        case ((sum1, count1), (sum2, count2)) => (sum1 + sum2, count1 + count2) //(向量相加, 计数器相加)
      }.map { 
        case (id, (sum, count)) => (id, sum / count)//根据前面的聚类,重新计算中心节点的位置
      }.collect
 
      // 更新中心节点
      for ((id, value) <- newCenters) {
        centers(id) = value
      }
    }
    val time2 = System.currentTimeMillis()
    println("Final centers: " + centers.mkString(", ") + ", time: "+(time2- time1) )
  }
}

例子中使用了iterations来限制迭代次数,并不是一种好的方法。可以设置一个阀值,在更新中心节点前,判断新的节点和上一次计算的中心计算差平方之和是否已经到了阀值内,如果是,则不需要继续计算下去。

其中用到的Vector类 https://github.com/mesos/spark/b ... amples/Vector.scala



楼主热帖
分享到:  QQ好友和群QQ好友和群 QQ空间QQ空间 腾讯微博腾讯微博 腾讯朋友腾讯朋友
收藏收藏 转播转播 分享分享 分享淘帖 赞 踩

168大数据 - 论坛版权1.本主题所有言论和图片纯属网友个人见解,与本站立场无关
2.本站所有主题由网友自行投稿发布。若为首发或独家,该帖子作者与168大数据享有帖子相关版权。
3.其他单位或个人使用、转载或引用本文时必须同时征得该帖子作者和168大数据的同意,并添加本文出处。
4.本站所收集的部分公开资料来源于网络,转载目的在于传递价值及用于交流学习,并不代表本站赞同其观点和对其真实性负责,也不构成任何其他建议。
5.任何通过此网页连接而得到的资讯、产品及服务,本站概不负责,亦不负任何法律责任。
6.本站遵循行业规范,任何转载的稿件都会明确标注作者和来源,若标注有误或遗漏而侵犯到任何版权问题,请尽快告知,本站将及时删除。
7.168大数据管理员和版主有权不事先通知发贴者而删除本文。

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

关闭

站长推荐上一条 /1 下一条

关于我们|小黑屋|Archiver|168大数据 ( 京ICP备14035423号|申请友情链接

GMT+8, 2024-5-5 08:10

Powered by BI168大数据社区

© 2012-2014 168大数据

快速回复 返回顶部 返回列表