blog's Issues

An ad-hoc query service built on SparkSQL

IQL (project repository: https://github.com/teeyog/IQL)

README-EN

An ad-hoc query service implemented on top of SparkSQL, with the following features:

  • Elegant interaction: multiple datasources/sinks are supported, and data from different sources can be mixed in one computation
  • Long-running Spark service with ZooKeeper-based automatic engine discovery
  • Load balancing: queries are executed on a randomly chosen engine among several
  • Multi-session mode for parallel queries
  • Spark FAIR scheduling, so large jobs cannot monopolize resources
  • Spark dynamic resource allocation, so no executor resources are held when there are no jobs
  • Can be launched in both Cluster and Client mode
  • Streams can be added dynamically via SQL, based on Structured Streaming
  • Interactive data analysis similar to spark-shell
  • Efficient script management, with import/include syntax to link scripts together
  • Permission checks on data source operations

Supported data sources: hdfs, hive, hbase, kafka, mysql, es, mongo

Supported file formats: parquet, csv, orc, json, text, xml

In addition to the sinks supported by Structured Streaming, support for HBase, MySQL and ES sinks has been added.

Quickstart

HBase

Load data

load hbase.t_mbl_user_version_info 
where `spark.table.schema`="userid:String,osversion:String,toolversion:String"
	   and `hbase.table.schema`=":rowkey,info:osversion,info:toolversion" 
	   and `hbase.zookeeper.quorum`="localhost:2181"
as tb;
Parameter | Description | Default
hbase.zookeeper.quorum | ZooKeeper quorum address | localhost:2181
spark.table.schema | Schema of the Spark temp table (e.g. "ID:String,appname:String,age:Int") |
hbase.table.schema | Schema of the HBase table (e.g. ":rowkey,info:appname,info:age") |
spark.rowkey.view.name | Name of the temp view created from the rowkey DataFrame; when set, only the data for those rowkeys is fetched |

Data for a given set of rowkeys can be fetched: spark.rowkey.view.name is the temp view that holds the rowkey set; by default its first column is taken as the rowkey column.

Save data

save tb1 as hbase.tableName 
where `hbase.zookeeper.quorum`="localhost:2181"
      and `hbase.table.rowkey.filed`="name"
Parameter | Description | Default
hbase.zookeeper.quorum | ZooKeeper quorum address | localhost:2181
hbase.table.rowkey.field | Field of the Spark temp table to use as the HBase rowkey | first field
bulkload.enable | Whether to use bulkload | false
hbase.table.name | HBase table name |
hbase.table.family | Column family | info
hbase.table.region.splits | Pre-split option 1: specify the split boundaries directly as an array string, e.g. ['1','2','3'] |
hbase.table.rowkey.prefix | Pre-split option 2: when the rowkey is numeric, specify only the prefix format, e.g. 00 yields 100 regions (00-99) |
hbase.table.startKey | Pre-split start key |
hbase.table.endKey | Pre-split end key |
hbase.table.numReg | Number of regions |
hbase.check_table | Whether to check that the table exists before writing to HBase | false
hbase.cf.ttl | TTL of the column family |

MySQL

  • Load data
load jdbc.ai_log_count 
where driver="com.mysql.jdbc.Driver" 
      and url="jdbc:mysql://localhost/db?characterEncoding=utf8" 
      and user="root" 
      and password="***" 
as tb; 
  • Save data
save append tb as jdbc.aatest_delete;

File operations (format can be: json, orc, csv, parquet, text)

  • Load data
load format.`path` as tb;
  • Save data
save tb as format.`path` partitionBy uid coalesce 2;

Kafka

  • Batch
load kafka.`topicName`
where maxRatePerPartition="200"
   and `group.id`="consumerGroupId"
as tb;
select * from tb;
Parameter | Description | Default
autoCommitOffset | Whether to commit the offset | false
  • Streaming
load kafka.`mc-monitor` 
where startingoffsets="latest"
	and failOnDataLoss="false"
	and `spark.job.mode`="stream" 
as tb1;

register watermark.tb1
where eventTimeCol="timestamp"
and delayThreshold="10 seconds"

select window.end as time_end,
count(1) as count
from tb1 a  
group by  window(a.timestamp,"10 seconds","10 seconds")
as tb2;

save tb2 as json.`/tmp/abc6` 
where outputMode="Append"
	and streamName="Stream"
	and duration="10"
	and sendDingDingOnTerminated="true"
	and `mail.receiver`="[email protected]"
	and checkpointLocation="/tmp/cp/cp16";
Parameter | Description | Default
spark.job.mode | Job mode (batch: offline job, stream: streaming job) | batch
mail.receiver | E-mail notification on job failure (multiple addresses separated by commas) |
sendDingDingOnTerminated | DingTalk robot notification on termination | false

Failed streaming jobs are restarted automatically; the number of attempts can be configured with streamJobMaxAttempts (default 3).

Registering UDFs dynamically

register udf.`myupper`
where func="
	def apply(name:String)={
		name.toUpperCase
	}
";

load jsonStr.'
{"name":"ufo"}
{"name":"uu"}
{"name":"HIN"}
' as tb1;

select myupper(name) as newName from tb1;

include syntax (equivalent to import): pull in script fragments by path

import syntax

References

StreamingPro: MLSQL

Spark SQL at Ximalaya: xql

Checkpoint Source Code Analysis

Preface

In a Spark application we often end up with RDDs that are expensive to compute, produced through long chains of complex transformations, i.e. RDDs with a long lineage and wide dependencies. In such cases it is worth persisting the RDD.

cache can also persist data to disk, but it simply writes each partition's output directly. checkpoint, on the other hand, launches a separate job after the logical job finishes for any RDD marked for checkpointing, so that RDD is computed twice. It is therefore recommended to cache the RDD in memory before checkpointing it, so the checkpoint job only has to write the cached data to disk.
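
As a quick usage sketch (the path and RDD contents are made up for illustration), following the advice above to cache before checkpointing:

import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("checkpoint-demo").setMaster("local[2]"))
// In local mode a local path is fine; in cluster mode this must be an HDFS directory,
// as the setCheckpointDir source below warns.
sc.setCheckpointDir("/tmp/checkpoints")

val rdd = sc.parallelize(1 to 1000).map(_ * 2).filter(_ % 3 == 0)
rdd.cache()                   // keep the computed partitions in memory
rdd.checkpoint()              // only marks the RDD; nothing is written yet
rdd.count()                   // the action runs the job, then doCheckpoint() writes the data out
println(rdd.toDebugString)    // the lineage now starts from a ReliableCheckpointRDD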

How checkpoint is implemented

To use checkpointing you must first call SparkContext's setCheckpointDir method to set a directory where all checkpoint data will be stored. Let's look at that method:

def setCheckpointDir(directory: String) {
    if (!isLocal && Utils.nonLocalPaths(directory).isEmpty) {
      logWarning("Spark is not running in local mode, therefore the checkpoint directory " +
        s"must not be on the local filesystem. Directory '$directory' " +
        "appears to be on the local filesystem.")
    }
    checkpointDir = Option(directory).map { dir =>
      val path = new Path(dir, UUID.randomUUID().toString)
      val fs = path.getFileSystem(hadoopConfiguration)
      fs.mkdirs(path)
      fs.getFileStatus(path).getPath.toString
    }
  }

In non-local mode, directory must be a directory on HDFS; under it a subdirectory with a unique, UUID-generated name is created.
Calling rdd.checkpoint() then marks the RDD for checkpointing:

def checkpoint(): Unit = RDDCheckpointData.synchronized { 
    if (context.checkpointDir.isEmpty) {
      throw new SparkException("Checkpoint directory has not been set in the SparkContext")
    } else if (checkpointData.isEmpty) {
      checkpointData = Some(new ReliableRDDCheckpointData(this))
    }
  }

It first checks that checkpointDir has been set, then checks whether checkpointData.isEmpty holds. checkpointData is defined as follows:

private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None

An RDDCheckpointData is paired one-to-one with an RDD and holds all checkpoint-related information. Here checkpointData is instantiated with new ReliableRDDCheckpointData(this); ReliableRDDCheckpointData is a subclass of RDDCheckpointData. At this point the RDD is merely marked for checkpointing; no checkpoint has actually been performed yet.

When does the checkpoint happen?

When an action is performed, SparkContext's runJob is invoked:

def runJob[T, U: ClassTag](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      resultHandler: (Int, U) => Unit): Unit = {
    if (stopped.get()) {
      throw new IllegalStateException("SparkContext has been shutdown")
    }
    val callSite = getCallSite
    val cleanedFunc = clean(func)
    logInfo("Starting job: " + callSite.shortForm)
    if (conf.getBoolean("spark.logLineage", false)) {
      logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString)
    }
    dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
    progressBar.foreach(_.finishAll())
    rdd.doCheckpoint()
  }

We can see that rdd.doCheckpoint() runs after the job finishes; this is where the RDDs marked earlier actually get checkpointed. Let's look at the method:

private[spark] def doCheckpoint(): Unit = {
    RDDOperationScope.withScope(sc, "checkpoint", allowNesting = false, ignoreParent = true) {
      if (!doCheckpointCalled) {
        doCheckpointCalled = true
        if (checkpointData.isDefined) {
          if (checkpointAllMarkedAncestors) {
              dependencies.foreach(_.rdd.doCheckpoint())
          }
          checkpointData.get.checkpoint()
        } else {
      dependencies.foreach(_.rdd.doCheckpoint())
        }
      }
    }
  }

It first checks whether the checkpoint has already been handled; only if not does it proceed, setting doCheckpointCalled to true. Since checkpointData was initialized earlier, checkpointData.isDefined holds. If the parents of the RDD whose checkpointData is defined should also be checkpointed, they must be checkpointed first, because once an RDD checkpoints itself, it cuts its parents out of the lineage. Let's follow checkpointData.get.checkpoint():

final def checkpoint(): Unit = {
    // Guard against multiple threads checkpointing the same RDD by
    // atomically flipping the state of this RDDCheckpointData
    RDDCheckpointData.synchronized {
      if (cpState == Initialized) {
        cpState = CheckpointingInProgress
      } else {
        return
      }
    }

    val newRDD = doCheckpoint()

    // Update our state and truncate the RDD lineage
    RDDCheckpointData.synchronized {
      cpRDD = Some(newRDD)
      cpState = Checkpointed
      rdd.markCheckpointed()
    }
  }

The checkpoint state is first set to CheckpointingInProgress, then doCheckpoint is executed, returning a newRDD. Let's see what doCheckpoint does:

protected override def doCheckpoint(): CheckpointRDD[T] = {
    val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir)
    if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
      rdd.context.cleaner.foreach { cleaner =>
        cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
      }
    }
    logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}")
    newRDD
  }

ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir) writes an RDD out to multiple checkpoint files and returns a ReliableCheckpointRDD that represents the RDD:

def writeRDDToCheckpointDirectory[T: ClassTag](
      originalRDD: RDD[T],
      checkpointDir: String,
      blockSize: Int = -1): ReliableCheckpointRDD[T] = {
    val sc = originalRDD.sparkContext
    // Create the output path for the checkpoint
    val checkpointDirPath = new Path(checkpointDir)
    val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration)
    if (!fs.mkdirs(checkpointDirPath)) {
      throw new SparkException(s"Failed to create checkpoint path $checkpointDirPath")
    }
    // Save to file, and reload it as an RDD
    val broadcastedConf = sc.broadcast(
      new SerializableConfiguration(sc.hadoopConfiguration))
    // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
    sc.runJob(originalRDD,
      writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _)
    if (originalRDD.partitioner.nonEmpty) {
      writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
    }
    val newRDD = new ReliableCheckpointRDD[T](
      sc, checkpointDirPath.toString, originalRDD.partitioner)
    if (newRDD.partitions.length != originalRDD.partitions.length) {
      throw new SparkException(
        s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
          s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})")
    }
    newRDD
  }

After preparing some configuration (broadcasting the Hadoop configuration, etc.), a job is launched to write the checkpoint files; the actual writing is done by ReliableCheckpointRDD.writePartitionToCheckpointFile. Once the checkpoint is written, a new ReliableCheckpointRDD instance is returned. Here is the writePartitionToCheckpointFile implementation:

def writePartitionToCheckpointFile[T: ClassTag](
      path: String,
      broadcastedConf: Broadcast[SerializableConfiguration],
      blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
    val env = SparkEnv.get
    val outputDir = new Path(path)
    val fs = outputDir.getFileSystem(broadcastedConf.value.value)

    val finalOutputName = ReliableCheckpointRDD.checkpointFileName(ctx.partitionId())
    val finalOutputPath = new Path(outputDir, finalOutputName)
    val tempOutputPath =
      new Path(outputDir, s".$finalOutputName-attempt-${ctx.attemptNumber()}")

    if (fs.exists(tempOutputPath)) {
      throw new IOException(s"Checkpoint failed: temporary path $tempOutputPath already exists")
    }
    val bufferSize = env.conf.getInt("spark.buffer.size", 65536)

    val fileOutputStream = if (blockSize < 0) {
      fs.create(tempOutputPath, false, bufferSize)
    } else {
      // This is mainly for testing purpose
      fs.create(tempOutputPath, false, bufferSize,
        fs.getDefaultReplication(fs.getWorkingDirectory), blockSize)
    }
    val serializer = env.serializer.newInstance()
    val serializeStream = serializer.serializeStream(fileOutputStream)
    Utils.tryWithSafeFinally {
      serializeStream.writeAll(iterator)
    } {
      serializeStream.close()
    }

    if (!fs.rename(tempOutputPath, finalOutputPath)) {
      if (!fs.exists(finalOutputPath)) {
        logInfo(s"Deleting tempOutputPath $tempOutputPath")
        fs.delete(tempOutputPath, false)
        throw new IOException("Checkpoint failed: failed to save output of task: " +
          s"${ctx.attemptNumber()} and final output path does not exist: $finalOutputPath")
      } else {
        // Some other copy of this task must've finished before us and renamed it
        logInfo(s"Final output path $finalOutputPath already exists; not overwriting it")
        if (!fs.delete(tempOutputPath, false)) {
          logWarning(s"Error deleting ${tempOutputPath}")
        }
      }
    }
  }

This is ordinary HDFS file-writing code: the data of one RDD partition is written to a file under the checkpoint directory.

doCheckpoint() is now complete and has returned a new RDD, a ReliableCheckpointRDD, whose reference is stored in cpRDD; the checkpoint state is then marked Checkpointed. What does rdd.markCheckpointed() do?

private[spark] def markCheckpointed(): Unit = {
    clearDependencies()
    partitions_ = null
    deps = null    // Forget the constructor argument for dependencies too
  }

Finally, all of the RDD's dependencies are cleared.

Checkpoint write summary

  • Initialized
  • marked for checkpointing
  • checkpointing in progress
  • checkpointed

When is the checkpoint read?

When a partition's data is needed, rdd.iterator() is called to compute that partition of the RDD. Let's look at RDD's iterator() implementation:

final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      getOrCompute(split, context)
    } else {
      computeOrReadCheckpoint(split, context)
    }
  }

If the data is not found in the cache, Spark then checks whether the RDD has been checkpointed; isCheckpointedAndMaterialized is the state flag that is set once a checkpoint succeeds, i.e. cpState = Checkpointed.

private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
  {
    if (isCheckpointedAndMaterialized) {
      firstParent[T].iterator(split, context)
    } else {
      compute(split, context)
    }
  }

If the RDD has been checkpointed successfully, the parent RDD's iterator(), i.e. CheckpointRDD.iterator(), is used directly; otherwise the RDD's own compute method is called.

final def dependencies: Seq[Dependency[_]] = {
    checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
      if (dependencies_ == null) {
        dependencies_ = getDependencies
      }
      dependencies_
    }
  }

When fetching an RDD's dependencies, Spark first tries to get them from checkpointRDD; if that succeeds it returns the ReliableCheckpointRDD wrapped in a OneToOneDependency, otherwise it returns the real dependencies.

Master and Worker Startup in Standalone Mode

This analysis is based on Spark 2.1.

Preface

As a distributed computing framework, Spark supports several deployment modes:

  • Local mode (single machine)
  • Local pseudo-cluster mode (a cluster simulated on one machine)
  • Standalone Client mode (cluster)
  • Standalone Cluster mode (cluster)
  • YARN Client mode (cluster)
  • YARN Cluster mode (cluster)

Standalone is Spark's built-in cluster manager and requires the Master and Worker daemons to be started; this post walks through both startup flows at the source level. Master and Worker communicate over Netty-based RPC; for background on Spark's RPC, the article 深入解析Spark中的RPC (an in-depth look at RPC in Spark) is recommended.

Master startup

The Master is started via the start-master.sh script, which ultimately runs the class:

org.apache.spark.deploy.master.Master

Let's look at its main method:

def main(argStrings: Array[String]) {
    Utils.initDaemon(log)
    val conf = new SparkConf
    val args = new MasterArguments(argStrings, conf)
    // Create the RpcEnv and start the RPC service
    val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf)
    // Block and wait
    rpcEnv.awaitTermination()
  }

main first builds a SparkConf from the arguments, starts an RpcEnv and creates an Endpoint via startRpcEnvAndEndpoint, then calls awaitTermination to block while the server listens for and handles requests. Let's look at startRpcEnvAndEndpoint in detail:

  def startRpcEnvAndEndpoint(
      host: String,
      port: Int,
      webUiPort: Int,
      conf: SparkConf): (RpcEnv, Int, Option[Int]) = {
    val securityMgr = new SecurityManager(conf)
    // Create the RpcEnv
    val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr)
    // Create an Endpoint through the rpcEnv
    val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME,
      new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf))
    val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest)
    (rpcEnv, portsResponse.webUIPort, portsResponse.restPort)
  }

The RpcEnv is created first; it is the core of Spark's RPC. An RpcEndpoint defines the message-handling logic and, once created, is managed by the RpcEnv. Its lifecycle is onStart, receive, onStop, where receive may be called concurrently; a ThreadSafeRpcEndpoint's receive is thread-safe, i.e. accessed by only one thread at a time.
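
To make the lifecycle concrete, here is a minimal sketch of a hypothetical endpoint (EchoEndpoint is not from the Spark code base; note that these traits are private[spark], so such code must live under an org.apache.spark package):

import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

class EchoEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
  override def onStart(): Unit = {
    // called once, right after the endpoint is registered with the RpcEnv
  }
  override def receive: PartialFunction[Any, Unit] = {
    case msg: String => println(s"got one-way message: $msg")   // handles RpcEndpointRef.send
  }
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case msg: String => context.reply(msg)                      // handles RpcEndpointRef.ask
  }
  override def onStop(): Unit = {
    // called once when the endpoint is unregistered / the RpcEnv shuts down
  }
}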

The Endpoint registered with rpcEnv in this method is the Master (which extends ThreadSafeRpcEndpoint). The Master's constructor creates the variables that hold its bookkeeping state:

 ...
  // A HashSet holding the WorkerInfo of registered Workers
  val workers = new HashSet[WorkerInfo]
  // A HashSet holding the applications submitted by clients (SparkSubmit)
  val apps = new HashSet[ApplicationInfo]
  // Applications waiting to be scheduled
  val waitingApps = new ArrayBuffer[ApplicationInfo]
  // Holds DriverInfo
  val drivers = new HashSet[DriverInfo]
 ...

Since the Master is an Endpoint managed by the RpcEnv, its lifecycle method onStart runs first:

override def onStart(): Unit = {
   ...
    checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable {
      override def run(): Unit = Utils.tryLogNonFatalError {
        self.send(CheckForWorkerTimeOut)
      }
    }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS)
   ...
  }

A task is added to the thread pool that checks for timed-out Workers every WORKER_TIMEOUT_MS (60 seconds by default); it simply sends a CheckForWorkerTimeOut message to the Master itself. More on this later.

Worker startup

Workers on the cluster nodes are started via the start-slaves.sh script, which underneath runs the class:

org.apache.spark.deploy.worker.Worker

Let's look at its main method:

def main(argStrings: Array[String]) {
    Utils.initDaemon(log)
    val conf = new SparkConf
    val args = new WorkerArguments(argStrings, conf)
    val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores,
      args.memory, args.masters, args.workDir, conf = conf)
    rpcEnv.awaitTermination()
  }

As with the Master, a SparkConf is first built from the arguments, then startRpcEnvAndEndpoint starts an RpcEnv and creates an Endpoint, and awaitTermination blocks while the server listens for and handles requests.

 def startRpcEnvAndEndpoint(
      host: String,
      port: Int,
      webUiPort: Int,
      cores: Int,
      memory: Int,
      masterUrls: Array[String],
      workDir: String,
      workerNumber: Option[Int] = None,
      conf: SparkConf = new SparkConf): RpcEnv = {

    // The LocalSparkCluster runs multiple local sparkWorkerX RPC Environments
    val systemName = SYSTEM_NAME + workerNumber.map(_.toString).getOrElse("")
    val securityMgr = new SecurityManager(conf)
    val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr)
    val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_))
    rpcEnv.setupEndpoint(ENDPOINT_NAME, new Worker(rpcEnv, webUiPort, cores, memory,
      masterAddresses, ENDPOINT_NAME, workDir, conf, securityMgr))
    rpcEnv
  }

Here a new Worker instance is created as the Endpoint and registered with the RpcEnv. The Worker's constructor initializes the heartbeat timeout to 1/4 of the Master's timeout, along with other variables.

Worker registers with the Master

Per its lifecycle, the Worker executes its onStart() method:

override def onStart() {
   ...
    registerWithMaster()
   ...
  }

onStart() calls registerWithMaster to register the Worker with the Master:

private def registerWithMaster() {
    // onDisconnected may be triggered multiple times, so don't attempt registration
    // if there are outstanding registration attempts scheduled.
    registrationRetryTimer match {
      case None =>
        // Whether registration has succeeded yet
        registered = false
        // Try to register with every Master
        registerMasterFutures = tryRegisterAllMasters()
        // Number of connection attempts
        connectionAttemptCount = 0
        // Re-registration is needed when the network or the Master fails;
        // once the retry count exceeds the threshold, the Worker simply exits
        registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate(
          new Runnable {
            override def run(): Unit = Utils.tryLogNonFatalError {
              Option(self).foreach(_.send(ReregisterWithMaster))
            }
          },
          INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS,
          INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS,
          TimeUnit.SECONDS))
      case Some(_) =>
        logInfo("Not spawning another attempt to register with the master, since there is an" +
          " attempt scheduled already.")
    }
  }

registrationRetryTimer is None on the first call, so the Worker registers itself with the Masters via tryRegisterAllMasters; it also starts a timer that retries registration a limited number of times (re-registration is needed when the network or the Master fails). First, let's see how tryRegisterAllMasters registers with the Masters:

private def tryRegisterAllMasters(): Array[JFuture[_]] = {
    masterRpcAddresses.map { masterAddress =>
      registerMasterThreadPool.submit(new Runnable {
        override def run(): Unit = {
          try {
            logInfo("Connecting to master " + masterAddress + "...")
            val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME)
            registerWithMaster(masterEndpoint)
          } catch {
            case ie: InterruptedException => // Cancelled
            case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e)
          }
        }
      })
    }
  }

Here rpcEnv.setupEndpointRef is called. An RpcEndpointRef is a reference to an RpcEndpoint in some RpcEnv; it is a serializable entity that can be sent over the network or stored for later use. An RpcEndpointRef has an address and a name, and its send method delivers asynchronous one-way messages to the corresponding RpcEndpoint.

So the code as a whole iterates over all masterRpcAddresses and calls registerWithMaster, passing in the RpcEndpointRef that references the Master-side RpcEndpoint. Let's continue with registerWithMaster:

private def registerWithMaster(masterEndpoint: RpcEndpointRef): Unit = {
    masterEndpoint.ask[RegisterWorkerResponse](RegisterWorker(
      workerId, host, port, self, cores, memory, workerWebUiUrl))
      .onComplete {
        // This is a very fast action so we can use "ThreadUtils.sameThread"
        case Success(msg) =>
          Utils.tryLogNonFatalError {
            handleRegisterResponse(msg)
          }
        case Failure(e) =>
          logError(s"Cannot register with master: ${masterEndpoint.address}", e)
          System.exit(1)
      }(ThreadUtils.sameThread)
  }

Through the RpcEndpointRef, the Worker sends a RegisterWorker message to the Master carrying its workerId, host, port, cores, memory and other parameters, with success and failure callbacks that we will come back to shortly.

Master handles Worker registration

The Master handles events that require a reply in receiveAndReply (one-way messages go through receive). The handling of the RegisterWorker registration message:

case RegisterWorker(
        id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) =>
      logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
        workerHost, workerPort, cores, Utils.megabytesToString(memory)))
      // The current Master is in STANDBY state
      if (state == RecoveryState.STANDBY) {
        context.reply(MasterInStandby)
      // The Worker is already registered
      } else if (idToWorker.contains(id)) {
        context.reply(RegisterWorkerFailed("Duplicate worker ID"))
      } else {
        // Build a WorkerInfo from the Worker's registration information
        val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory,
          workerRef, workerWebUiUrl)
        if (registerWorker(worker)) {
          // Persist the Worker information
          persistenceEngine.addWorker(worker)
          // Reply to the Worker that registration succeeded
          context.reply(RegisteredWorker(self, masterWebUiUrl))
          // A new Worker means new resources, so schedule the waiting apps
          schedule()
        } else {
          val workerAddress = worker.endpoint.address
          logWarning("Worker registration failed. Attempted to re-register worker at same " +
            "address: " + workerAddress)
          // Reply to the Worker that registration failed
          context.reply(RegisterWorkerFailed("Attempted to re-register worker at same address: "
            + workerAddress))
        }
      }
  1. If the current Master is in STANDBY state, reply with MasterInStandby.
  2. If the Worker is already registered, reply with RegisterWorkerFailed.
  3. Otherwise build a WorkerInfo from the registration information and call registerWorker to register it:
  • If registration succeeds, the Worker info is persisted and a success message is sent back; since a new Worker means more resources, schedule() is called to schedule the waiting apps.
  • If registration fails, a failure message is sent straight back to the Worker.

So how is success or failure decided? Let's step into registerWorker:

private def registerWorker(worker: WorkerInfo): Boolean = {
    // There may be one or more refs to dead workers on this same node (w/ different ID's),
    // remove them.
    workers.filter { w =>
      (w.host == worker.host && w.port == worker.port) && (w.state == WorkerState.DEAD)
    }.foreach { w =>
      workers -= w
    }
    // Get the new worker's workerAddress
    val workerAddress = worker.endpoint.address
    if (addressToWorker.contains(workerAddress)) {
      // Look up the previously registered Worker for this workerAddress
      val oldWorker = addressToWorker(workerAddress)
      // UNKNOWN means the Master is in recovery and the Worker is being recovered
      if (oldWorker.state == WorkerState.UNKNOWN) {
        // Remove the old Worker and accept the newly registered one
        removeWorker(oldWorker)
      } else {
        logInfo("Attempted to re-register worker at same address: " + workerAddress)
        return false
      }
    }
    // Update the bookkeeping structures
    workers += worker
    idToWorker(worker.id) = worker
    addressToWorker(workerAddress) = worker
    true
  }

It iterates over all managed Workers: any Worker with the same host and port as the new one that is in the DEAD (timed-out) state is removed from workers. If addressToWorker already contains a Worker with the same workerAddress as the new one, the old Worker is looked up: if its state is UNKNOWN, the Master is in recovery and the Worker is being recovered, so the old Worker is removed, the new one is added, and registration succeeds; if the old Worker is in any other state, this is a duplicate registration and registration fails.

Worker handles the Master's registration response

private def registerWithMaster(masterEndpoint: RpcEndpointRef): Unit = {
    masterEndpoint.ask[RegisterWorkerResponse](RegisterWorker(
      workerId, host, port, self, cores, memory, workerWebUiUrl))
      .onComplete {
        // This is a very fast action so we can use "ThreadUtils.sameThread"
        case Success(msg) =>
          Utils.tryLogNonFatalError {
            handleRegisterResponse(msg)
          }
        case Failure(e) =>
          logError(s"Cannot register with master: ${masterEndpoint.address}", e)
          System.exit(1)
      }(ThreadUtils.sameThread)
  }

This registerWithMaster method is the one called when the Worker registers with the Master; its completion callback processes the result, and handleRegisterResponse handles the different kinds of responses:

private def handleRegisterResponse(msg: RegisterWorkerResponse): Unit = synchronized {
    msg match {
      // Registration succeeded
      case RegisteredWorker(masterRef, masterWebUiUrl) =>
        logInfo("Successfully registered with master " + masterRef.address.toSparkURL)
        // Mark registration as successful
        registered = true
        // Update the master references and cancel the other registration retries
        changeMaster(masterRef, masterWebUiUrl)
        // Send heartbeats to the Master
        forwordMessageScheduler.scheduleAtFixedRate(new Runnable {
          override def run(): Unit = Utils.tryLogNonFatalError {
            self.send(SendHeartbeat)
          }
        }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS)
       ...
      // Registration failed: exit the process
      case RegisterWorkerFailed(message) =>
        if (!registered) {
          logError("Worker registration failed: " + message)
          System.exit(1)
        }
      // The Master is not the active Master: ignore
      case MasterInStandby =>
        // Ignore. Master not yet ready.
    }
  }
  1. If registration fails, the Worker receives RegisterWorkerFailed and exits.
  2. If the Master is in Standby state, the message is simply ignored.
  3. On successful registration the Worker receives RegisteredWorker: it first marks itself as registered, then changeMaster updates a few variables (activeMasterUrl, master, connected, etc.) and cancels any other registration retries in flight. A task is then submitted to the thread pool that sends the Worker itself a SendHeartbeat message every HEARTBEAT_MILLIS; the handler in receive then sends the heartbeat to the Master:
 case SendHeartbeat =>
      if (connected) { sendToMaster(Heartbeat(workerId, self)) }

Master handles heartbeats

case Heartbeat(workerId, worker) =>
      idToWorker.get(workerId) match {
        case Some(workerInfo) =>
          workerInfo.lastHeartbeat = System.currentTimeMillis()
        case None =>
          if (workers.map(_.id).contains(workerId)) {
            logWarning(s"Got heartbeat from unregistered worker $workerId." +
              " Asking it to re-register.")
            worker.send(ReconnectWorker(masterUrl))
          } else {
            logWarning(s"Got heartbeat from unregistered worker $workerId." +
              " This worker was never registered, so ignoring the heartbeat.")
          }
      }

The Master looks up the corresponding workerInfo; if found, it updates the last heartbeat time lastHeartbeat, otherwise it tells the Worker to reconnect.

Master detects Worker heartbeat timeouts

As noted above, the Master's onStart starts a dedicated task that checks whether Workers have timed out. Here is how the Master handles it:

case CheckForWorkerTimeOut =>
      timeOutDeadWorkers()

private def timeOutDeadWorkers() {
    // Copy the workers into an array so we don't modify the hashset while iterating through it
    val currentTime = System.currentTimeMillis()
    val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT_MS).toArray
    for (worker <- toRemove) {
      if (worker.state != WorkerState.DEAD) {
        logWarning("Removing %s because we got no heartbeat in %d seconds".format(
          worker.id, WORKER_TIMEOUT_MS / 1000))
        removeWorker(worker)
      } else {
        if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT_MS)) {
          workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it
        }
      }
    }
  }

It iterates over all managed Workers; any Worker whose last heartbeat is older than the timeout is considered timed out and is removed from the worker list.

Spark Streaming ReceiverTracker: Data Generation and Storage

Before reading this Spark Streaming source analysis, it helps to be familiar with Spark Core.

Preface

Spark Streaming is built on Spark Core and breaks a streaming computation into a series of small batch jobs.

In Spark Streaming, the overall dynamic scheduling of jobs is handled by JobScheduler, which has two very important members: JobGenerator and ReceiverTracker. JobGenerator turns each batch into a concrete RDD DAG, while ReceiverTracker is responsible for the data sources.

A DStream in Spark Streaming can be seen as a template for Spark Core RDDs, and DStreamGraph as a template for the RDD DAG.

Following the flow through an example

Like RDDs, DStreams have transformation and output operations. Transformations produce new DStreams; typical ones are map(), filter(), reduce(), join(), etc. An RDD output operation triggers an action, whereas a DStream output operation creates a new ForeachDStream that records the work to be done in a function func.

Here is an example:

val conf = new SparkConf().setMaster("local[2]")
                          .setAppName("NetworkWordCount")
val ssc = new StreamingContext(conf, Seconds(1))
val lines = ssc.socketTextStream("localhost", 9999)
val words = lines.flatMap(_.split(" "))      
val pairs = words.map(word => (word, 1))    
val wordCounts = pairs.reduceByKey(_ + _)   
wordCounts.print()
ssc.start()
ssc.awaitTermination()

When the StreamingContext is created, a graph: DStreamGraph is created along with it:

private[streaming] val graph: DStreamGraph = {
    if (isCheckpointPresent) {
      _cp.graph.setContext(this)
      _cp.graph.restoreCheckpointData()
      _cp.graph
    } else {
      require(_batchDur != null, "Batch duration for StreamingContext cannot be null")
      val newGraph = new DStreamGraph()
      newGraph.setBatchDuration(_batchDur)
      newGraph
    }
  }

If a checkpoint is available, the graph is restored from it; otherwise a new one is created. The graph is used to build the RDD DAG dynamically. DStreamGraph has two important members, inputStreams and outputStreams:

private val inputStreams = new ArrayBuffer[InputDStream[_]]()
private val outputStreams = new ArrayBuffer[DStream[_]]()

The way Spark Streaming records the DStream DAG is that the DStreamGraph instance keeps all outputStreams. Each output stream is linked to its parent DStreams through its dependencies, so by walking backwards from the outputStreams you can reach every upstream DStream. The DStreamGraph also keeps all inputStreams, to avoid a BFS over the output streams every time an input stream needs to be found.

Back to the example: ssc.socketTextStream creates a ReceiverInputDStream, and its parent class InputDStream adds that ReceiverInputDStream to inputStreams.

Next, flatMap is called:

def flatMap[U: ClassTag](flatMapFunc: T => TraversableOnce[U]): DStream[U] = ssc.withScope {
    new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc))
  }

--------------------------------------------------------------------

private[streaming]
class FlatMappedDStream[T: ClassTag, U: ClassTag](
    parent: DStream[T],
    flatMapFunc: T => TraversableOnce[U]
  ) extends DStream[U](parent.ssc) {

  override def dependencies: List[DStream[_]] = List(parent)

  override def slideDuration: Duration = parent.slideDuration

  override def compute(validTime: Time): Option[RDD[U]] = {
    parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc))
  }
}

A FlatMappedDStream is created, and its compute method calls flatMap on the parent DStream's (the ReceiverInputDStream's) RDD for the corresponding batch time, effectively building rdd.flatMap(func). The following operations are analogous, eventually producing rdd.flatMap(func1).map(func2).reduceByKey(func3).take(), which is exactly Spark Core code. Note also that dependencies points directly to the constructor parameter parent, i.e. the ReceiverInputDStream just created; every new DStream's dependencies point to its parent DStream, forming a dependency chain, i.e. the DStream DAG.

Now let's look at the final print() operation:

----
def print(num: Int): Unit = ssc.withScope {
    def foreachFunc: (RDD[T], Time) => Unit = {
      (rdd: RDD[T], time: Time) => {
        val firstNum = rdd.take(num + 1)
        // scalastyle:off println
        println("-------------------------------------------")
        println(s"Time: $time")
        println("-------------------------------------------")
        firstNum.take(num).foreach(println)
        if (firstNum.length > num) println("...")
        println()
        // scalastyle:on println
      }
    }
    foreachRDD(context.sparkContext.clean(foreachFunc), displayInnerRDDOps = false)
  }
----
private def foreachRDD(
      foreachFunc: (RDD[T], Time) => Unit,
      displayInnerRDDOps: Boolean): Unit = {
    new ForEachDStream(this,
      context.sparkContext.clean(foreachFunc, false), displayInnerRDDOps).register()
  }
----
#ForEachDStream
override def generateJob(time: Time): Option[Job] = {
    parent.getOrCompute(time) match {
      case Some(rdd) =>
        val jobFunc = () => createRDDWithLocalProperties(time, displayInnerRDDOps) {
          foreachFunc(rdd, time)
        }
        Some(new Job(time, jobFunc))
      case None => None
    }
  }

print() builds a foreachFunc function that takes records from an RDD and prints them (a Spark Core action). It then creates a ForEachDStream instance and calls register():

 private[streaming] def register(): DStream[T] = {
    ssc.graph.addOutputStream(this)
    this
  }

This adds the output stream to the DStreamGraph's outputStreams. The foreachFunc built above ends up inside the ForEachDStream instance's generateJob method, where a Streaming Job is created; the job's run method invokes this function, which is what triggers the action.

Note that a Spark Streaming Job is not the same thing as a Spark Core Job: a Streaming Job runs the function constructed above, and inside that function there may be several Core jobs, or none at all.
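
As a small illustration of that point, reusing wordCounts from the example above, a single output operation (one Streaming Job per batch) can run several Core jobs:

wordCounts.foreachRDD { rdd =>
  val total  = rdd.count()   // Core job #1
  val sample = rdd.take(5)   // Core job #2
  println(s"batch total = $total, sample = ${sample.mkString(", ")}")
}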

Implementing a KafkaSink with Spark Metrics

Background

Monitoring is a very important part of Spark. Runtime information is reported through the ListenerBus and the MetricsSystem. Through Spark's metrics system, the collected metrics can be sent to all kinds of sinks, such as HTTP, JMX or CSV files.
The currently supported sinks include:

  • ConsoleSink
  • CSVSink
  • JmxSink
  • MetricsServlet
  • GraphiteSink
  • GangliaSink

Sometimes we need the metrics in real time, for example to analyze and visualize them with Spark; in that case a KafkaSink that streams the metrics to Kafka as they are collected is very convenient, hence this post.

Implementation

Every sink must extend the Sink trait:

private[spark] trait Sink {
  def start(): Unit
  def stop(): Unit
  def report(): Unit
}

When the sink is registered with the metrics system, start is called to perform initialization, report does the actual output, and stop can close connections and clean up. Straight to the code:

package org.apache.spark.metrics.sink

import java.util.concurrent.TimeUnit
import java.util.{Locale, Properties}

import com.codahale.metrics.MetricRegistry
import org.apache.kafka.clients.producer.KafkaProducer
import org.apache.spark.SecurityManager
import org.apache.spark.internal.Logging

private[spark] class KafkaSink(val property: Properties, val registry: MetricRegistry,
                               securityMgr: SecurityManager) extends Sink with Logging{

    val KAFKA_KEY_PERIOD = "period"
    val KAFKA_DEFAULT_PERIOD = 10

    val KAFKA_KEY_UNIT = "unit"
    val KAFKA_DEFAULT_UNIT = "SECONDS"

    val KAFKA_TOPIC = "topic"
    val KAFKA_DEFAULT_TOPIC = "kafka-sink-topic"

    val KAFAK_BROKERS = "kafka-brokers"
    val KAFAK_DEFAULT_BROKERS = "XXX:9092"

    val TOPIC = Option(property.getProperty(KAFKA_TOPIC)).getOrElse(KAFKA_DEFAULT_TOPIC)
    val BROKERS = Option(property.getProperty(KAFAK_BROKERS)).getOrElse(throw new IllegalStateException("kafka-brokers is null!"))

    private val kafkaProducerConfig = new Properties()
    kafkaProducerConfig.put("bootstrap.servers",BROKERS)
    kafkaProducerConfig.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer")
    kafkaProducerConfig.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer")

    private val producer = new KafkaProducer[String, String](kafkaProducerConfig)

    private val reporter: KafkaReporter = KafkaReporter.forRegistry(registry)
        .topic(TOPIC)
        .build(producer)


    val pollPeriod = Option(property.getProperty(KAFKA_KEY_PERIOD)) match {
        case Some(s) => s.toInt
        case None => KAFKA_DEFAULT_PERIOD
    }

    val pollUnit: TimeUnit = Option(property.getProperty(KAFKA_KEY_UNIT)) match {
        case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT))
        case None => TimeUnit.valueOf(KAFKA_DEFAULT_UNIT)
    }

    override def start(): Unit = {
        log.info("I4 Metrics System KafkaSink Start ......")
        reporter.start(pollPeriod, pollUnit)
    }

    override def stop(): Unit = {
        log.info("I4 Metrics System KafkaSink Stop ......")
        reporter.stop()
        producer.close()
    }

    override def report(): Unit = {
        log.info("I4 Metrics System KafkaSink Report ......")
        reporter.report()
    }
}

The KafkaReporter class:

package org.apache.spark.metrics.sink;

import com.alibaba.fastjson.JSONObject;
import com.codahale.metrics.*;
import com.twitter.bijection.Injection;
import com.twitter.bijection.avro.GenericAvroCodecs;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;
import java.util.SortedMap;
import java.util.concurrent.TimeUnit;

public class KafkaReporter  extends ScheduledReporter  {

    private static final Logger LOGGER = LoggerFactory.getLogger(KafkaReporter.class);

    public static KafkaReporter.Builder forRegistry(MetricRegistry registry) {
        return new KafkaReporter.Builder(registry);
    }

    private KafkaProducer producer;
    private Clock clock;
    private String topic;

    private KafkaReporter(MetricRegistry registry,
                        TimeUnit rateUnit,
                        TimeUnit durationUnit,
                        MetricFilter filter,
                        Clock clock,
                        String topic,
                        KafkaProducer producer) {
        super(registry, "kafka-reporter", filter, rateUnit, durationUnit);
        this.producer = producer;
        this.topic = topic;
        this.clock = clock;
    }

    @Override
    public void report(SortedMap<String, Gauge> gauges, SortedMap<String, Counter> counters, SortedMap<String, Histogram> histograms, SortedMap<String, Meter> meters, SortedMap<String, Timer> timers) {
        final long timestamp = TimeUnit.MILLISECONDS.toSeconds(clock.getTime());

        // Gauge
        for (Map.Entry<String, Gauge> entry : gauges.entrySet()) {
            reportGauge(timestamp,entry.getKey(), entry.getValue());
        }
        // Histogram
//        for (Map.Entry<String, Histogram> entry : histograms.entrySet()) {
//            reportHistogram(timestamp, entry.getKey(), entry.getValue());
//        }
    }


    private void reportGauge(long timestamp, String name, Gauge gauge) {
        report(timestamp, name, gauge.getValue());
    }

    private void reportHistogram(long timestamp, String name, Histogram histogram) {
        final Snapshot snapshot = histogram.getSnapshot();
        report(timestamp, name, snapshot.getMax());
    }

    private void report(long timestamp, String name,  Object values) {
        JSONObject jsonObject = new JSONObject();
        jsonObject.put("name",name);
        jsonObject.put("timestamp",timestamp);
        jsonObject.put("value",values);
        producer.send(new ProducerRecord(topic,name, jsonObject.toJSONString()));
    }


    public static class Builder {

        private final MetricRegistry registry;
        private TimeUnit rateUnit;
        private TimeUnit durationUnit;
        private MetricFilter filter;
        private Clock clock;
        private String topic;

        private Builder(MetricRegistry registry) {
            this.registry = registry;
            this.rateUnit = TimeUnit.SECONDS;
            this.durationUnit = TimeUnit.MILLISECONDS;
            this.filter = MetricFilter.ALL;
            this.clock = Clock.defaultClock();
        }

        /**
         * Convert rates to the given time unit.
         *
         * @param rateUnit a unit of time
         * @return {@code this}
         */
        public KafkaReporter.Builder convertRatesTo(TimeUnit rateUnit) {
            this.rateUnit = rateUnit;
            return this;
        }

        /**
         * Convert durations to the given time unit.
         *
         * @param durationUnit a unit of time
         * @return {@code this}
         */
        public KafkaReporter.Builder convertDurationsTo(TimeUnit durationUnit) {
            this.durationUnit = durationUnit;
            return this;
        }

        /**
         * Use the given {@link Clock} instance for the time.
         *
         * @param clock a {@link Clock} instance
         * @return {@code this}
         */
        public Builder withClock(Clock clock) {
            this.clock = clock;
            return this;
        }

        /**
         * Only report metrics which match the given filter.
         *
         * @param filter a {@link MetricFilter}
         * @return {@code this}
         */
        public KafkaReporter.Builder filter(MetricFilter filter) {
            this.filter = filter;
            return this;
        }

        /**
         * Only report metrics which match the given filter.
         *
         * @param topic a
         * @return {@code this}
         */
        public KafkaReporter.Builder topic(String topic) {
            this.topic = topic;
            return this;
        }

        /**
         * Builds a {@link KafkaReporter} with the given properties, writing {@code .csv} files to the
         * given directory.
         *
         * @return a {@link KafkaReporter}
         */
        public KafkaReporter build(KafkaProducer producer) {
            return new KafkaReporter(registry,
                    rateUnit,
                    durationUnit,
                    filter,
                    clock,
                    topic,
                    producer);
        }
    }
}

The report method is where the different metric types are handed over and the corresponding output happens.

How to use it

The sink to register, together with its parameters, can be set either in the metrics configuration file or programmatically:

spark.metrics.conf.*.sink.kafka.class=org.apache.spark.metrics.sink.KafkaSink
spark.metrics.conf.*.sink.kafka.kafka-brokers=XXX:9092
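
For the programmatic route, a minimal sketch (the app name and topic are hypothetical, and the KafkaSink class must of course be on the classpath):

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

// Spark's MetricsConfig picks up every SparkConf key prefixed with "spark.metrics.conf.".
val conf = new SparkConf()
  .setAppName("metrics-to-kafka-demo")
  .set("spark.metrics.conf.*.sink.kafka.class", "org.apache.spark.metrics.sink.KafkaSink")
  .set("spark.metrics.conf.*.sink.kafka.kafka-brokers", "XXX:9092")
  .set("spark.metrics.conf.*.sink.kafka.topic", "spark-metrics")   // maps to the KAFKA_TOPIC key
  .set("spark.metrics.conf.*.sink.kafka.period", "10")             // report every 10 ...
  .set("spark.metrics.conf.*.sink.kafka.unit", "seconds")          // ... seconds

val spark = SparkSession.builder().config(conf).getOrCreate()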

Shuffle Read Explained (Sort-Based Shuffle)

For Shuffle Write, see the Shuffle Write analysis.

This post covers the reduce side of the shuffle. The first RDD of the stage downstream of a shuffle is a ShuffledRDD; its compute method obtains an iterator over the data that the upstream stage's shuffle write spilled to disk files:

 override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }

The shuffleManager (here a SortShuffleManager) is obtained from SparkEnv; a Reader is obtained from the manager and its read method is called to get the iterator.

override def getReader[K, C](
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext): ShuffleReader[K, C] = {
    new BlockStoreShuffleReader(
      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
  }

getReader instantiates a BlockStoreShuffleReader, whose parameters include the partitionIds of the partitions to fetch. Let's look at its read method:

 override def read(): Iterator[Product2[K, C]] = {
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      // Metadata describing where the data is stored
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // Maximum size transferred per remote request
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

    // Wrap the streams with compression / encryption
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      serializerManager.wrapStream(blockId, inputStream)
    }
  
    val serializerInstance = dep.serializer.newInstance()

    // Create a key/value iterator for each stream
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
       serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // Update task metrics as each record is read
    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
    // Build the complete iterator
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map { record =>
        readMetrics.incRecordsRead(1)
        record
      },
      context.taskMetrics().mergeShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // Already combined once on the map side
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // Aggregate only on the reduce side
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // If a total ordering of the keys is required
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }

First a ShuffleBlockFetcherIterator is instantiated; one of its arguments is:

mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)

This method obtains the metadata describing where the reduce side's input comes from. It returns Seq[(BlockManagerId, Seq[(BlockId, Long)])], i.e. which blocks on which nodes hold the data and how large each block is. Let's see how getMapSizesByExecutorId is implemented:

def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
    // Fetch the metadata
    val statuses = getStatuses(shuffleId)
    // Convert the format and extract the metadata for the requested partitions
    statuses.synchronized {
      return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
    }
  }
  • Pass the shuffleId in to get all of that shuffle's metadata
  • Convert the format and extract the metadata for the requested partitions

Stepping into getStatuses:

private def getStatuses(shuffleId: Int): Array[MapStatus] = {
    // Try to read it directly from mapStatuses
    val statuses = mapStatuses.get(shuffleId).orNull
    if (statuses == null) {
      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
      val startTime = System.currentTimeMillis
      var fetchedStatuses: Array[MapStatus] = null
      ......
      if (fetchedStatuses == null) {
        // We won the race to fetch the statuses; do so
        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
        // This try-finally prevents hangs due to timeouts:
        try {
          // Fetch the metadata remotely
          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
          // Deserialize
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
          logInfo("Got the output locations")
          // Store it in mapStatuses
          mapStatuses.put(shuffleId, fetchedStatuses)
        } finally {
          fetching.synchronized {
            fetching -= shuffleId
            fetching.notifyAll()
          }
        }
      } 
     .....
      }
    } else {
      return statuses
    }
  }

If the statuses are found in mapStatuses they are returned directly; otherwise the tracker sends a GetMapOutputStatuses message to the mapOutputTrackerMaster to fetch the metadata.

Recall that each Executor corresponds to one CoarseGrainedExecutorBackend; when the CoarseGrainedExecutorBackend is built, a SparkEnv is created, and creating the SparkEnv creates a mapOutputTracker. So mapOutputTracker and Executor are one-to-one: every Executor has its own mapOutputTracker maintaining metadata.

mapStatuses is where that mapOutputTracker keeps its metadata. The metadata of every shuffle write completed on that Executor is stored in its mapStatuses, and metadata fetched remotely about shuffle writes completed on other Executors is also cached in the local mapStatuses.

On the Executor side the tracker is a MapOutputTrackerWorker, and on the Driver side a MapOutputTrackerMaster; both are created when SparkEnv is instantiated. The result of every shuffle task completed on an Executor is registered with the driver's MapOutputTrackerMaster, so the master's mapStatuses holds all the metadata. When a task on an Executor needs the output of a shuffle, it first looks in its own mapStatuses and, if not found, asks the MapOutputTrackerMaster for the metadata.

The MapOutputTrackerMaster handles the message as follows:

case GetMapOutputStatuses(shuffleId: Int) =>
      val hostPort = context.senderAddress.hostPort
      logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
      val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context))

It calls the tracker's post method:

 def post(message: GetMapOutputMessage): Unit = {
    mapOutputRequests.offer(message)
  }

The message is added to mapOutputRequests, a linked blocking queue; when the MapOutputTrackerMaster is initialized, a dedicated thread pool is started to serve these requests:

private val threadpool: ThreadPoolExecutor = {
    val numThreads = conf.getInt("spark.shuffle.mapOutput.dispatcher.numThreads", 8)
    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "map-output-dispatcher")
    for (i <- 0 until numThreads) {
      pool.execute(new MessageLoop)
    }
    pool
  }

Here is how the run method of the MessageLoop worker class is defined:

private class MessageLoop extends Runnable {
    override def run(): Unit = {
      try {
        while (true) {
          try {
            // Take one GetMapOutputMessage off the queue
            val data = mapOutputRequests.take()
             if (data == PoisonPill) {
              // Put PoisonPill back so that other MessageLoops can see it.
              mapOutputRequests.offer(PoisonPill)
              return
            }
            val context = data.context
            val shuffleId = data.shuffleId
            val hostPort = context.senderAddress.hostPort
            logDebug("Handling request to send map output locations for shuffle " + shuffleId +
              " to " + hostPort)
            // Get the serialized metadata for this shuffleId
            val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
            // Send the data back
            context.reply(mapOutputStatuses)
          } catch {
            case NonFatal(e) => logError(e.getMessage, e)
          }
        }
      } catch {
        case ie: InterruptedException => // exit
      }
    }
  }

It fetches the serialized metadata for the given shuffleId and returns it; let's look at the getSerializedMapOutputStatuses implementation:

def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
    var statuses: Array[MapStatus] = null
    var retBytes: Array[Byte] = null
    var epochGotten: Long = -1

    // Look up the MapStatuses in the cache; fall back to mapStatuses if missing
    def checkCachedStatuses(): Boolean = {
      epochLock.synchronized {
        if (epoch > cacheEpoch) {
          cachedSerializedStatuses.clear()
          clearCachedBroadcast()
          cacheEpoch = epoch
        }
        cachedSerializedStatuses.get(shuffleId) match {
          case Some(bytes) =>
            retBytes = bytes
            true
          case None =>
            logDebug("cached status not found for : " + shuffleId)
            statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus])
            epochGotten = epoch
            false
        }
      }
    }

    if (checkCachedStatuses()) return retBytes
    var shuffleIdLock = shuffleIdLocks.get(shuffleId)
    if (null == shuffleIdLock) {
      val newLock = new Object()
      // in general, this condition should be false - but good to be paranoid
      val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock)
      shuffleIdLock = if (null != prevLock) prevLock else newLock
    }
    // synchronize so we only serialize/broadcast it once since multiple threads call
    // in parallel
    shuffleIdLock.synchronized {
      if (checkCachedStatuses()) return retBytes

      // Serialize the statuses
      val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager,
        isLocal, minSizeForBroadcast)
      logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
      // Add them into the table only if the epoch hasn't changed while we were working
      epochLock.synchronized {
        if (epoch == epochGotten) {
          cachedSerializedStatuses(shuffleId) = bytes
          if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast
        } else {
          logInfo("Epoch changed, not caching!")
          removeBroadcast(bcast)
        }
      }
      bytes
    }
  }

The overall idea: first look the metadata (MapStatuses) up in the cache and return it directly if present; otherwise read it from mapStatuses, serialize it and return it. The reply goes back to the MapOutputTrackerWorker that asked; on receiving it, that tracker deserializes the metadata and adds it to the current Executor's mapStatuses.

Back in getMapSizesByExecutorId: once getStatuses has returned all the metadata for the shuffleId, convertMapStatuses turns it into location information of the form Seq[(BlockManagerId, Seq[(BlockId, Long)])], which is used to read the requested partitions:

private def convertMapStatuses(
      shuffleId: Int,
      startPartition: Int,
      endPartition: Int,
      statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    assert (statuses != null)
    // Holds the metadata for the requested partitions
    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
    for ((status, mapId) <- statuses.zipWithIndex) {
      if (status == null) {
        val errorMessage = s"Missing an output location for shuffle $shuffleId"
        logError(errorMessage)
        throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
      } else {
        for (part <- startPartition until endPartition) {
          splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
            ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
        }
      }
    }

    splitsByAddress.toSeq
  }

The parameter statuses: Array[MapStatus] is the metadata of all the upstream stage's shuffle write files obtained above, ordered by the map-side partitionId. zipWithIndex pairs each element with its index in the array, and that index is exactly the map-side partitionId. A ShuffleBlockId is then built from the shuffleId, the map partitionId and the reduce partitionId (on the map side, ShuffleBlockIds are always built with reducePartitionId 0, since one ShuffleMapTask produces a single block; the real reducePartitionId added here is needed later, when the index file is used to find the offset of the reduce partition), together with an estimate of the block size, because the size of remote fetches has to be limited later. Finally the location information is returned.
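
As a small illustration of the two pieces used here (the IDs are made up), zipWithIndex recovers the map-side partition id from the array position, and ShuffleBlockId simply encodes (shuffleId, mapId, reduceId) into a block name:

import org.apache.spark.storage.ShuffleBlockId

// Stand-ins for the MapStatus array; the array index is the map-side partition id.
val statuses = Array("statusOfMap0", "statusOfMap1", "statusOfMap2")

statuses.zipWithIndex.foreach { case (_, mapId) =>
  val blockId = ShuffleBlockId(0, mapId, 2)   // shuffleId = 0, reduce partition = 2
  println(blockId)                            // shuffle_0_0_2, shuffle_0_1_2, shuffle_0_2_2
}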

At this point mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) is done and has returned the MapStatus metadata for the requested partitions.

When the ShuffleBlockFetcherIterator object is constructed, its initialize() method is called:

private[this] def initialize(): Unit = {
    // Add a task completion callback (called in both success case and failure case) to cleanup.
    context.addTaskCompletionListener(_ => cleanup())

    // Split local and remote blocks and return the remote FetchRequests
    val remoteRequests = splitLocalRemoteBlocks()
    // Add the remote requests to the fetchRequests queue in random order
    fetchRequests ++= Utils.randomize(remoteRequests)
    assert ((0 == reqsInFlight) == (0 == bytesInFlight),
      "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
      ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)

    // Dequeue remote requests from fetchRequests and send them with sendRequest
    fetchUpToMaxBytes()

    val numFetches = remoteRequests.size - fetchRequests.size
    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

    // Fetch the local blocks
    fetchLocalBlocks()
    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
  }
  • Split local and remote blocks, and add the resulting remote FetchRequests to the fetchRequests queue
  • Dequeue remote requests from fetchRequests and send them with sendRequest to fetch the remote data
  • Fetch the local blocks

First, how local and remote blocks are distinguished:

private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
    // Divide the max in-flight size by 5 to increase parallelism (up to 5 parallel fetches)
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)

    // Array holding the remote fetch requests
    val remoteRequests = new ArrayBuffer[FetchRequest]

    // Tracks total number of blocks (including zero sized blocks)
    var totalBlocks = 0
    for ((address, blockInfos) <- blocksByAddress) {
      totalBlocks += blockInfos.size
      // If the block lives on the current executor it is local, otherwise remote
      if (address.executorId == blockManager.blockManagerId.executorId) {
        // Filter out zero-sized blocks
        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
        numBlocksToFetch += localBlocks.size
      } else {
        val iterator = blockInfos.iterator
        var curRequestSize = 0L
        var curBlocks = new ArrayBuffer[(BlockId, Long)]
        while (iterator.hasNext) {
          val (blockId, size) = iterator.next()
          // Skip empty blocks
          if (size > 0) {
            curBlocks += ((blockId, size))
            remoteBlocks += blockId
            numBlocksToFetch += 1
            curRequestSize += size
          } else if (size < 0) {
            throw new BlockException(blockId, "Negative block size " + size)
          }
          // Once the accumulated size exceeds the limit, create a FetchRequest and add it to remoteRequests
          if (curRequestSize >= targetRequestSize) {
            // Add this FetchRequest
            remoteRequests += new FetchRequest(address, curBlocks)
            curBlocks = new ArrayBuffer[(BlockId, Long)]
            logDebug(s"Creating fetch request of $curRequestSize at $address")
            curRequestSize = 0
          }
        }
        // Wrap the remaining blocks in a FetchRequest and add it to remoteRequests
        if (curBlocks.nonEmpty) {
          remoteRequests += new FetchRequest(address, curBlocks)
        }
      }
    }
    logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
    remoteRequests
  }
  • To increase the parallelism of remote fetches, the request size limit is divided by 5, so that up to 5 threads can be reading from up to 5 nodes at once (see the quick numeric sketch below)
  • A block counts as local if it lives on the same executor as the current one
  • The blocks of each remote node (Executor) are walked, and whenever the accumulated request size for one node exceeds the limit, a FetchRequest is built and added to remoteRequests; finally remoteRequests is returned. A FetchRequest wraps one request's data: the address plus the blockIds and their sizes
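
A quick numeric sketch with the default settings (values taken from the code above):

// spark.reducer.maxSizeInFlight defaults to 48m:
val maxBytesInFlight  = 48L * 1024 * 1024                    // 50331648 bytes
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)   // 10066329 bytes, roughly 9.6 MB
// Blocks on one remote executor are grouped into FetchRequests of about 9.6 MB each,
// so roughly 5 requests can be in flight in parallel within the 48 MB limit.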

After local and remote blocks are split, the remote requests are added to the fetchRequests queue and fetchUpToMaxBytes() is called to fetch the remote data:

private def fetchUpToMaxBytes(): Unit = {
    // Send fetch requests up to maxBytesInFlight
    while (fetchRequests.nonEmpty &&
      (bytesInFlight == 0 ||
        (reqsInFlight + 1 <= maxReqsInFlight &&
          bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) {
      sendRequest(fetchRequests.dequeue())
    }
  }

It dequeues a FetchRequest from fetchRequests and calls the sendRequest method:

 private[this] def sendRequest(req: FetchRequest) {
    logDebug("Sending request for %d blocks (%s) from %s".format(
      req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
    bytesInFlight += req.size
    reqsInFlight += 1

    // 转成map  Map[blockId,size]
    val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
    val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
    val blockIds = req.blocks.map(_._1.toString)

    val address = req.address
    // 通过shuffleClient的fetchBlocks方法来获取对应远程节点上的数据
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      new BlockFetchingListener {
        // 将结果保存到results中
        override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
          // Only add the buffer to results queue if the iterator is not zombie,
          // i.e. cleanup() has not been called yet.
          ShuffleBlockFetcherIterator.this.synchronized {
            if (!isZombie) {
              // Increment the ref count because we need to pass this to a different thread.
              // This needs to be released after use.
              buf.retain()
              remainingBlocks -= blockId
              results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
                remainingBlocks.isEmpty))
              logDebug("remainingBlocks: " + remainingBlocks)
            }
          }
          logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
        }

        override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
          logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
          results.put(new FailureFetchResult(BlockId(blockId), address, e))
        }
      }
    )
  }

The data on the remote node is fetched through shuffleClient.fetchBlocks, which by default is implemented by NettyBlockTransferService.fetchBlocks. Whether the fetch succeeds or fails, a SuccessFetchResult or a FailureFetchResult is built and put into results.

After the remote fetches have been kicked off, fetchLocalBlocks() is called to fetch the local blocks:

private[this] def fetchLocalBlocks() {
    val iter = localBlocks.iterator
    while (iter.hasNext) {
      val blockId = iter.next()
      try {
        val buf = blockManager.getBlockData(blockId)
        shuffleMetrics.incLocalBlocksFetched(1)
        shuffleMetrics.incLocalBytesRead(buf.size)
        buf.retain()
        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
      } catch {
        case e: Exception =>
          // If we see an exception, stop immediately.
          logError(s"Error occurred while fetching local blocks", e)
          results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
          return
      }
    }
  }

It iterates over the local blocks to fetch, reads each one directly from the blockManager, and wraps the result as a SuccessFetchResult or FailureFetchResult placed into results. Here is the implementation of blockManager.getBlockData(blockId):

override def getBlockData(blockId: BlockId): ManagedBuffer = {
    if (blockId.isShuffle) {
      shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
    } else {
      getLocalBytes(blockId) match {
        case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer)
        case None =>
          // If this block manager receives a request for a block that it doesn't have then it's
          // likely that the master has outdated block statuses for this block. Therefore, we send
          // an RPC so that this block is marked as being unavailable from this block manager.
          reportBlockStatus(blockId, BlockStatus.empty)
          throw new BlockNotFoundException(blockId.toString)
      }
    }
  }

For shuffle blocks this delegates to the shuffle block resolver; let's look at its getBlockData method:

override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
    // 根据ShuffleID和MapID获取索引文件
    val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
    val in = new DataInputStream(new FileInputStream(indexFile))
    try {
      // 跳到对应Block的数据区
      ByteStreams.skipFully(in, blockId.reduceId * 8)
      // partition对应的开始offset
      val offset = in.readLong()
      // partition对应的结束offset
      val nextOffset = in.readLong()
      new FileSegmentManagedBuffer(
        transportConf,
        getDataFile(blockId.shuffleId, blockId.mapId),
        offset,
        nextOffset - offset)
    } finally {
      in.close()
    }
  }

It locates the index file by shuffleId and mapId, opens an input stream on it, skips to the entry for the block's reduceId (mentioned earlier when fetching the partition metadata), reads the start offset and the end offset, and then reads that byte range from the data file.
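A hypothetical sketch of the index-file layout implied by the code above: the index file holds (numPartitions + 1) longs, and the segment for reduceId r is [offsets(r), offsets(r + 1)). The offsets here are made up:

import java.io.{DataInputStream, DataOutputStream, File, FileInputStream, FileOutputStream}

object IndexFileSketch extends App {
  val offsets = Array(0L, 120L, 300L, 460L)             // hypothetical partition boundaries
  val indexFile = File.createTempFile("shuffle_0_0_0", ".index")
  val out = new DataOutputStream(new FileOutputStream(indexFile))
  offsets.foreach(out.writeLong)
  out.close()

  val reduceId = 1
  val in = new DataInputStream(new FileInputStream(indexFile))
  in.skipBytes(reduceId * 8)                            // jump to this reducer's entry
  val offset = in.readLong()                            // 120
  val nextOffset = in.readLong()                        // 300
  in.close()
  println(s"read bytes [$offset, $nextOffset) of the data file for reduce partition $reduceId")
}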

With all the fetch results available, let's go back to the read() method:

 override def read(): Iterator[Product2[K, C]] = {
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      // 与mapOutputTrackerMaster通信获取存储数据位置的元数据
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // 每次传输的最大大小
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

    // 用压缩加密来包装流
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      serializerManager.wrapStream(blockId, inputStream)
    }
  
    val serializerInstance = dep.serializer.newInstance()

    // 对每个流生成K/V迭代器
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
       serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // 每条记录读取后更新任务度量
    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
    // 生成完整的迭代器
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map { record =>
        readMetrics.incRecordsRead(1)
        record
      },
      context.taskMetrics().mergeShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // 在map端已经聚合一次了
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // 只在reduce端聚合
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // 若需要全局排序
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }

ShuffleBlockFetcherIterator extends Iterator, so results can be iterated over; its next() method returns each FetchResult as a (blockId, inputStream) pair:

case SuccessFetchResult(blockId, address, _, buf, _) =>
        try {
          (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
        } catch {
          case NonFatal(t) =>
            throwFetchFailedException(blockId, address, t)
        }

The second half of read() performs aggregation and sorting, which is very similar to the Shuffle Write path, so only a brief description is given here.

When aggregation is required, combineCombinersByKey is used if map-side combine has already happened, otherwise combineValuesByKey; both end up calling ExternalAppendOnlyMap.insertAll(iter):

def combineCombinersByKey(
      iter: Iterator[_ <: Product2[K, C]],
      context: TaskContext): Iterator[(K, C)] = {
    val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
    combiners.insertAll(iter)
    updateMetrics(context, combiners)
    combiners.iterator
  }
def combineValuesByKey(
      iter: Iterator[_ <: Product2[K, V]],
      context: TaskContext): Iterator[(K, C)] = {
    val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
    combiners.insertAll(iter)
    updateMetrics(context, combiners)
    combiners.iterator
  }
def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
    if (currentMap == null) {
      throw new IllegalStateException(
        "Cannot insert new elements into a map after calling iterator")
    }
    // An update function for the map that we reuse across entries to avoid allocating
    // a new closure each time
    var curEntry: Product2[K, V] = null
    val update: (Boolean, C) => C = (hadVal, oldVal) => {
      if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
    }

    while (entries.hasNext) {
      curEntry = entries.next()
      val estimatedSize = currentMap.estimateSize()
      if (estimatedSize > _peakMemoryUsedBytes) {
        _peakMemoryUsedBytes = estimatedSize
      }
      if (maybeSpill(currentMap, estimatedSize)) {
        currentMap = new SizeTrackingAppendOnlyMap[K, C]
      }
      currentMap.changeValue(curEntry._1, update)
      addElementsRead()
    }
  }

The iteration inside insertAll ultimately calls the next() method of the ShuffleBlockFetcherIterator described above to pull the data.

On every update/insert the size of currentMap is estimated to decide whether it needs to be spilled to disk; if so, the map's entries are sorted by key with the defined keyComparator, written to a temporary spill file, and a fresh map is created for subsequent data.
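A minimal sketch of the update semantics in insertAll, assuming a word-count style aggregator (createCombiner = v => v, mergeValue = _ + _) and ignoring the size-tracking and spilling parts:

object AppendOnlyMapUpdateSketch extends App {
  val createCombiner: Int => Int = v => v
  val mergeValue: (Int, Int) => Int = _ + _

  val map = scala.collection.mutable.HashMap.empty[String, Int]
  val entries = Iterator(("a", 1), ("b", 1), ("a", 1))
  for ((k, v) <- entries) {
    map(k) = map.get(k) match {
      case Some(old) => mergeValue(old, v)   // hadVal == true
      case None      => createCombiner(v)    // hadVal == false
    }
  }
  println(map)                               // Map(a -> 2, b -> 1)
}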

After combiners (an ExternalAppendOnlyMap) has finished insertAll, its iterator is called to return an iterator over the complete partition data (the in-memory map plus the spill files):

override def iterator: Iterator[(K, C)] = {
    if (currentMap == null) {
      throw new IllegalStateException(
        "ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
    }
    if (spilledMaps.isEmpty) {
      CompletionIterator[(K, C), Iterator[(K, C)]](
        destructiveIterator(currentMap.iterator), freeCurrentMap())
    } else {
      new ExternalIterator()
    }
  }

Following the instantiation of the ExternalIterator class:

// A queue that maintains a buffer for each stream we are currently merging
    // This queue maintains the invariant that it only contains non-empty buffers
    private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]

    // Input streams are derived both from the in-memory map and spilled maps on disk
    // The in-memory map is sorted in place, while the spilled maps are already in sorted order
    private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator(
      currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap())
    private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)

    inputStreams.foreach { it =>
      val kcPairs = new ArrayBuffer[(K, C)]
      readNextHashCode(it, kcPairs)
      if (kcPairs.length > 0) {
        mergeHeap.enqueue(new StreamBuffer(it, kcPairs))
      }
    }

The sorted contents of currentMap are combined with the iterators over the spill files to form inputStreams; iterating over inputStreams loads the buffered key/value pairs into mergeHeap, which is then consumed in ExternalIterator's next() method.

Finally, if a total ordering of the data is required, an ExternalSorter created with only the ordering parameter is used (via insertAll) to sort it, just as in Shuffle Write, so it is not covered in detail here.

In the end, read() returns an iterator over all the data of the requested partitions.

[Spark SQL] Source Code Analysis: Parser

Preface

From the previous post we know that the overall SparkSQL processing flow is as follows:

  • The sqlText is parsed by the SqlParser into an unresolved LogicalPlan;
  • The analyzer module binds it against the catalog to produce a resolved LogicalPlan;
  • The optimizer module optimizes the resolved LogicalPlan into an optimized LogicalPlan;
  • SparkPlan converts the LogicalPlan into a PhysicalPlan;
  • prepareForExecution() turns the PhysicalPlan into an executable physical plan;
  • execute() runs the executable physical plan.

The Parser module in detail

The Parser splits the SQL string into tokens and then, following the grammar rules, builds a syntax tree from them. The SQL statement we write is just a string; it has to go through lexical and syntactic analysis to become a syntax tree. Spark 1.x used Scala's native parser combinators, while from 2.x onward SparkSQL uses the third-party parser generator ANTLR4, which brought a significant performance improvement.

ANTLR4 requires a grammar file; SparkSQL's grammar file lives at sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4.
ANTLR can generate the lexer and parser code via a build plugin; in SparkSQL these are SqlBaseLexer and SqlBaseParser. There are two ways to walk the resulting tree: the Listener pattern and the Visitor pattern.

The Listener pattern is a passive traversal: ANTLR generates a ParseTreeListener class with enter/exit callbacks for every node of the grammar. We only implement the callbacks we care about, instantiate a ParseTreeWalker, and ANTLR walks the tree top-down, invoking our logic along the way.

The Visitor pattern is an active traversal: we explicitly control the order in which nodes are visited. It lets us define new operations on the tree elements without changing their classes. SparkSQL uses this pattern to walk the parse tree.

Lexical and syntactic analysis turn the SQL statement into an ANTLR 4 ParseTree. Then, in parsePlan, the AstBuilder converts the ANTLR 4 tree into a Catalyst logical plan. Let's look at the source:

// Snippet 1
val spark = SparkSession
    .builder
    .appName("SparkSQL Test") 
    .master("local[4]") 
    .getOrCreate()
spark.sql("select * from table").show(false) 

---
// Snippet 2
def sql(sqlText: String): DataFrame = {
    Dataset.ofRows(self, sessionState.sqlParser.parsePlan(sqlText))
}

---
// Snippet 3
override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
    astBuilder.visitSingleStatement(parser.singleStatement()) match {
      case plan: LogicalPlan => plan
      case _ =>
        val position = Origin(None, None)
        throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position)
    }
  }

---
// Snippet 4
protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
    logInfo(s"Parsing command: $command")

    val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command))
    lexer.removeErrorListeners()
    lexer.addErrorListener(ParseErrorListener)

    val tokenStream = new CommonTokenStream(lexer)
    val parser = new SqlBaseParser(tokenStream)
    parser.addParseListener(PostProcessor)
    parser.removeErrorListeners()
    parser.addErrorListener(ParseErrorListener)

    try {
      try {
        // first, try parsing with potentially faster SLL mode
        parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
        toResult(parser)
      ...

In snippet 2, sqlParser is a SparkSqlParser; its member val astBuilder = new SparkSqlAstBuilder(conf) is the key class that converts the ANTLR tree into Catalyst expressions.

In snippet 3, parsePlan first runs the parse method (snippet 4), which instantiates the lexer and the parser and then hands the ANTLR parser (parser: SqlBaseParser) to the curried function in snippet 3. There astBuilder converts it into Catalyst expressions: the first call is visitSingleStatement, singleStatement being the top-level rule of the grammar file, after which the whole syntax tree is traversed explicitly with ANTLR's visitor pattern, replacing every node with a LogicalPlan or a TableIdentifier.
The AST produced by the Parser is shown in the figure below:
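Independently of the figure, the unresolved plan can be reproduced by driving the parser directly. A quick experiment, assuming a Spark 2.x build on the classpath (the constructor signature may differ in other versions), with a hypothetical table name `people`; the printed tree is roughly:

import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.internal.SQLConf

object ParserDemo extends App {
  val parser = new SparkSqlParser(new SQLConf)
  val plan = parser.parsePlan("SELECT name, age FROM people WHERE age > 18")
  println(plan.treeString)
  // 'Project ['name, 'age]
  // +- 'Filter ('age > 18)
  //    +- 'UnresolvedRelation `people`
}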

[Spark SQL] Main Execution Flow

Prerequisites

First, two very important data structures in Spark SQL: Tree and Rule.

The first thing SparkSQL does is parse the SQL text into a syntax tree. The tree is made of node objects; a node has a specific data type and zero or more children, and is represented in SparkSQL by a TreeNode object. A concrete example:

  • Literal(value: Int): a constant
  • Attribute(name: String): the variable name
  • Add(left: TreeNode, right: TreeNode): the sum of two expressions

x + (1 + 2) is represented in code as: Add(Attribute(x), Add(Literal(1), Literal(2)))

A Rule is a transformation applied to a Tree via pattern matching: when a node matches, the rule rewrites it; when it does not, matching continues on the child nodes. For example, the Optimizer has a constant-folding rule that replaces two constant nodes with a single constant node holding their sum, as shown in the figure below:

As the figure shows, the rule first fails to match the outer Add node and then successfully matches its child Add node.
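A toy sketch (not Catalyst itself) of how such a rule works through pattern matching on the tree defined above:

sealed trait TreeNode
case class Literal(value: Int) extends TreeNode
case class Attribute(name: String) extends TreeNode
case class Add(left: TreeNode, right: TreeNode) extends TreeNode

object ConstantFolding {
  def apply(node: TreeNode): TreeNode = node match {
    case Add(Literal(a), Literal(b)) => Literal(a + b)           // rule matches: fold the constants
    case Add(l, r)                   => Add(apply(l), apply(r))  // no match: recurse into the children
    case other                       => other
  }
}

// ConstantFolding(Add(Attribute("x"), Add(Literal(1), Literal(2))))
//   == Add(Attribute("x"), Literal(3))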

Overall Flow

The figure below shows how SparkSQL turns a query into RDDs; the red part is Catalyst, SparkSQL's optimizer framework, whose design is largely the same as most big-data SQL engines (Impala, Presto, Hive with Calcite, etc.). Each component is briefly described below and will be covered in detail in later posts.

Parser

  1. The sqlText is first parsed into a syntax tree by SparkSqlParser.
  2. Spark 1.x used Scala's native parser combinators; from 2.x onward SparkSQL uses the third-party parser generator ANTLR4, where you only define the grammar and a plugin generates the parsing code.
  3. AstBuilder then walks the tree with ANTLR's visitor pattern, under its own control, replacing the ANTLR nodes with Catalyst (optimizer framework) types. All of these types extend the TreeNode trait, whose children: Seq[BaseType] member is what gives the tree its structure.
  4. The AST produced by this step is the unresolved LogicalPlan.

Analyzer

  1. The previous step only splits the SQL string with ANTLR4 and has SparkSqlParser produce various LogicalPlans (subclasses of TreeNode); what each LogicalPlan actually refers to is still unknown.
  2. The Analyzer resolves the unknown attributes and relations using the catalog and a set of resolution rules: for example, a table name like user has to be resolved against the Catalog to decide whether it is a temporary table, a temporary view, a Hive table or a Hive view, and what its schema is.
  3. Applying Rules to the Tree is carried out by a RuleExecutor (the later Optimizer also extends RuleExecutor); resolution traverses the tree recursively, replacing the old LogicalPlan with the newly resolved one.
  4. The AST produced by this step is the resolved LogicalPlan. If there is no action, the subsequent optimization and physical planning are never executed.

Optimizer

  1. This step applies accumulated SQL optimization experience to the plan: predicate pushdown, column pruning, constant folding, and so on.
  2. The Optimizer also extends RuleExecutor and defines a batch of rules; like the Analyzer it processes the input plan recursively. The AST produced by this step is the optimized LogicalPlan.

SparkPlanner

The optimized LogicalPlan is still purely logical. SparkPlanner next applies a series of Strategies to the optimized LogicalPlan, turning it into operations that can act on real data and binding data to RDDs; the result of this step is the PhysicalPlan.

prepareForExecution

This module turns the physical plan into an executable physical plan, mainly by inserting shuffle operations and conversions to/from the internal row format.

execute

Finally SparkPlan.execute() is called to run the computation. Every SparkPlan implements execute, and each implementation typically calls execute() on its children recursively, so the whole tree ends up being evaluated.
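A quick way to inspect the artifact of each step described above, assuming `spark` is a SparkSession with a table `people` registered (the table name is hypothetical):

val df = spark.sql("SELECT name FROM people WHERE age > 18")
println(df.queryExecution.logical)        // unresolved LogicalPlan from the Parser
println(df.queryExecution.analyzed)       // resolved LogicalPlan from the Analyzer
println(df.queryExecution.optimizedPlan)  // optimized LogicalPlan from the Optimizer
println(df.queryExecution.sparkPlan)      // physical plan chosen by the SparkPlanner
println(df.queryExecution.executedPlan)   // executable plan after prepareForExecution
df.queryExecution.executedPlan.execute()  // RDD[InternalRow]; normally driven by an action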


Finally, the overall flow chart (the end-to-end SQL execution flow):

Each module will be analyzed in detail in subsequent posts.

BlockManager Explained

Overview

BlockManager is Spark's own storage system; RDD caching, shuffle output, broadcast variables and so on are all built on top of it. It is distributed: there is a BlockManager node on the driver and on every executor, and the block information stored on each node is reported to the BlockManagerMaster on the driver for centralized management. BlockManager exposes get and put interfaces and can store data in memory, on disk, or off-heap.

Creating and registering the BlockManager

Both blockManagerMaster and blockManager are created while constructing the SparkEnv. On the driver the SparkEnv is created together with the SparkContext; on the executor it is created when the daemon process CoarseGrainedExecutorBackend starts. Here is how the blockManager is created inside SparkEnv:

// get&put 远程block的时候就是通过blockTransferService 完成的
val blockTransferService =
      new NettyBlockTransferService(conf, securityManager, bindAddress, advertiseAddress,
        blockManagerPort, numUsableCores)

 val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint(
      BlockManagerMaster.DRIVER_ENDPOINT_NAME,
      new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)),
      conf, isDriver)

    // NB: blockManager is not valid until initialize() is called later.
    val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster,
      serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager,
      blockTransferService, securityManager, numUsableCores)

When constructing blockManagerMaster, the driver creates a BlockManagerMasterEndpoint and registers it in the rpcEnv, while each executor obtains a reference to the driver's BlockManagerMasterEndpoint for later communication. Both sides then create their own blockManager, and each blockManager creates a BlockManagerSlaveEndpoint.

A newly created blockManager cannot be used right away; its initialize method is called to register with the master. On receiving the registration the master stores the blockManager's information in the blockManagerInfo map, keyed by blockManagerId (which carries the executorId, host, port, etc.), with a BlockManagerInfo value (which holds the detailed block status and a reference to the BlockManagerSlaveEndpoint). After registration the blockManager is ready for real work.

Messages between master and slaves

slave -> master

    // slave向master注册,会保存在master的blockManagerInfo中
    case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) =>
      context.reply(register(blockManagerId, maxMemSize, slaveEndpoint))
    
    // 一个Block的更新消息,BlockId作为一个Block的唯一标识,会保存Block所在的节点和位置关系,以及block 存储级别,大小 占用内存和磁盘大小
    case _updateBlockInfo @
        UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
      context.reply(updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size))
      listenerBus.post(SparkListenerBlockUpdated(BlockUpdatedInfo(_updateBlockInfo)))
  
    // 用于获取指定 blockId 的 block 所在的 BlockManagerId 列表
    case GetLocations(blockId) =>
      context.reply(getLocations(blockId))
    
    // 获取多个Block所在 的位置,位置中会反映Block位于哪个 executor, host 和端口
    case GetLocationsMultipleBlockIds(blockIds) =>
      context.reply(getLocationsMultipleBlockIds(blockIds))

    // 一个block有可能在多个节点上存在,返回一个节点列表
    case GetPeers(blockManagerId) =>
      context.reply(getPeers(blockManagerId))
    
    // 根据BlockId,获取所在executorEndpointRef 也就是 BlockManagerSlaveEndpoint的引用
    case GetExecutorEndpointRef(executorId) =>
      context.reply(getExecutorEndpointRef(executorId))

    // 获取所有节点上的BlockManager的最大内存和剩余内存
    case GetMemoryStatus =>
      context.reply(memoryStatus)
    
    // 获取所有节点上的BlockManager的最大磁盘空间和剩余磁盘空间
    case GetStorageStatus =>
      context.reply(storageStatus)

    // 获取一个Block的状态信息,位置,占用内存和磁盘大小
    case GetBlockStatus(blockId, askSlaves) =>
      context.reply(blockStatus(blockId, askSlaves))

    // 获取一个Block的存储级别和所占内存和磁盘大小
    case GetMatchingBlockIds(filter, askSlaves) =>
      context.reply(getMatchingBlockIds(filter, askSlaves))
 
    // 删除Rdd对应的Block数据
    case RemoveRdd(rddId) =>
      context.reply(removeRdd(rddId))
 
    // 删除 shuffleId对应的BlockId的Block
    case RemoveShuffle(shuffleId) =>
      context.reply(removeShuffle(shuffleId))

    // 删除Broadcast对应的Block数据
    case RemoveBroadcast(broadcastId, removeFromDriver) =>
      context.reply(removeBroadcast(broadcastId, removeFromDriver))
    
    // 删除一个Block数据,会找到数据所在的slave,然后向slave发送一个删除消息
    case RemoveBlock(blockId) =>
      removeBlockFromWorkers(blockId)
      context.reply(true)
    
    // 从BlockManagerInfo中删除一个BlockManager, 并且删除这个 BlockManager上的所有的Blocks
    case RemoveExecutor(execId) =>
      removeExecutor(execId)
      context.reply(true)

    // 用于停止 driver 或 executor 端的 BlockManager
    case StopBlockManagerMaster =>
      context.reply(true)
      stop()

    // slave 发送心跳给 master , 证明自己还活着
    case BlockManagerHeartbeat(blockManagerId) =>
      context.reply(heartbeatReceived(blockManagerId))
    
    // 用于检查 executor 是否有缓存 blocks(广播变量的 blocks 不作考虑,因为广播变量的 block 不会汇报给 Master)
    case HasCachedBlocks(executorId) =>
      blockManagerIdByExecutor.get(executorId) match {
        case Some(bm) =>
          if (blockManagerInfo.contains(bm)) {
            val bmInfo = blockManagerInfo(bm)
            context.reply(bmInfo.cachedBlocks.nonEmpty)
          } else {
            context.reply(false)
          }
        case None => context.reply(false)
      }

master -> slave

    // slave删除自己BlockManager上的一个Block
    case RemoveBlock(blockId) =>
      doAsync[Boolean]("removing block " + blockId, context) {
        blockManager.removeBlock(blockId)
        true
      }
     
    // 删除Rdd对应的Block数据
    case RemoveRdd(rddId) =>
      doAsync[Int]("removing RDD " + rddId, context) {
        blockManager.removeRdd(rddId)
      }

    // 删除 shuffleId对应的BlockId的Block
    case RemoveShuffle(shuffleId) =>
      doAsync[Boolean]("removing shuffle " + shuffleId, context) {
        if (mapOutputTracker != null) {
          mapOutputTracker.unregisterShuffle(shuffleId)
        }
        SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
      }

    // 删除 BroadcastId对应的BlockId的Block
    case RemoveBroadcast(broadcastId, _) =>
      doAsync[Int]("removing broadcast " + broadcastId, context) {
        blockManager.removeBroadcast(broadcastId, tellMaster = true)
      }

    // 获取一个Block的存储级别和所占内存和磁盘大小
    case GetBlockStatus(blockId, _) =>
      context.reply(blockManager.getStatus(blockId))

    case GetMatchingBlockIds(filter, _) =>
      context.reply(blockManager.getMatchingBlockIds(filter))

    case TriggerThreadDump =>
      context.reply(Utils.getThreadDump())

Storage

When the blockManager is created it also creates a MemoryStore and a DiskStore to read and write blocks.

  private[spark] val memoryStore =
    new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this)
  private[spark] val diskStore = new DiskStore(conf, diskBlockManager)
  memoryManager.setMemoryStore(memoryStore)

DiskStore

DiskStore stores data on disk. It has a DiskBlockManager member whose main job is mapping logical blocks to blocks on disk: each blockId corresponds to one file on disk.

def getFile(filename: String): File = {
    // Figure out which local directory it hashes to, and which subdirectory in that
    val hash = Utils.nonNegativeHash(filename)
    val dirId = hash % localDirs.length
    val subDirId = (hash / localDirs.length) % subDirsPerLocalDir

    // Create the subdirectory if it doesn't already exist
    val subDir = subDirs(dirId).synchronized {
      val old = subDirs(dirId)(subDirId)
      if (old != null) {
        old
      } else {
        val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
        if (!newDir.exists() && !newDir.mkdir()) {
          throw new IOException(s"Failed to create local dir in $newDir.")
        }
        subDirs(dirId)(subDirId) = newDir
        newDir
      }
    }

    new File(subDir, filename)
  }

The hash of the blockId modulo the number of localDirs decides which localDir the file is created in. localDirs is a list of directories configured via SPARK_LOCAL_DIRS (comma separated); configuring several directories spreads the disk I/O load. In addition, Spark creates 64 sub-directories (configurable via spark.diskStore.subDirectories) under each localDir to spread the files, and the sub-directory is also chosen from the hash of the blockId. A small worked sketch of this arithmetic follows.
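A worked sketch of the directory-selection arithmetic in getFile, assuming two configured local dirs and the default 64 sub-directories (paths and hashing are stand-ins for the real Utils.nonNegativeHash):

object DirSelectionSketch extends App {
  val localDirs = Array("/data1/spark", "/data2/spark")    // hypothetical SPARK_LOCAL_DIRS entries
  val subDirsPerLocalDir = 64                              // spark.diskStore.subDirectories
  val filename = "shuffle_0_3_0.data"

  val hash = filename.hashCode & Int.MaxValue              // stand-in for Utils.nonNegativeHash
  val dirId = hash % localDirs.length
  val subDirId = (hash / localDirs.length) % subDirsPerLocalDir

  println(f"${localDirs(dirId)}/$subDirId%02x/$filename")  // e.g. /data2/spark/1f/shuffle_0_3_0.data
}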

The putBytes method of DiskStore is what actually writes data to disk:

def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = {
    put(blockId) { fileOutputStream =>
      val channel = fileOutputStream.getChannel
      Utils.tryWithSafeFinally {
        bytes.writeFully(channel)
      } {
        channel.close()
      }
    }
  }
def put(blockId: BlockId)(writeFunc: FileOutputStream => Unit): Unit = {
    if (contains(blockId)) {
      throw new IllegalStateException(s"Block $blockId is already present in the disk store")
    }
    logDebug(s"Attempting to put block $blockId")
    val startTime = System.currentTimeMillis
    val file = diskManager.getFile(blockId)
    val fileOutputStream = new FileOutputStream(file)
    var threwException: Boolean = true
    try {
      writeFunc(fileOutputStream)
      threwException = false
    } finally {
      try {
        Closeables.close(fileOutputStream, threwException)
      } finally {
         if (threwException) {
          remove(blockId)
        }
      }
    }
    val finishTime = System.currentTimeMillis
    logDebug("Block %s stored as %s file on disk in %d ms".format(
      file.getName,
      Utils.bytesToString(file.length()),
      finishTime - startTime))
  }

It takes a blockId and the bytes to write, resolves the concrete file from the blockId, opens a file output stream on it, and writes the bytes straight into that stream.

DiskStore also has an important getBytes method for reading: it resolves the disk file from the blockId and returns its content as a byte buffer.

There are further methods to query the size of the file for a blockId, check whether the file exists, delete it, and so on.

MemoryStore

MemoryStore stores data in the JVM heap. Internally it maintains a hash map of all blocks, keyed by block id.

private val entries = new LinkedHashMap[BlockId, MemoryEntry[_]](32, 0.75f, true)

Storing in memory is not as simple as storing on disk, since memory is limited; MemoryStore has a dedicated memoryManager member for that. For the details of Spark memory management see the post 内存管理 MemoryManager 解析.

Putting data in memory requires enough free memory, otherwise it leads to OOM.

  • If the data for a blockId is already available as bytes, memory can be requested according to its size; if enough memory cannot be acquired, storing in memory fails. The corresponding method is:

    putBytes[T: ClassTag](blockId: BlockId, size: Long, memoryMode: MemoryMode, _bytes: () => ChunkedByteBuffer): Boolean
    
  • If the data for a blockId is written through an iterator, its size cannot be known in advance. The approach is to unroll the iterator gradually while checking whether there is still free memory. If the iterator unrolls completely, the unroll memory is converted directly into storage memory, so no extra allocation is needed to store the block. If it cannot be fully unrolled, an iterator over the block data is returned, composed of the already-unrolled data plus the remaining iterator; in that case storing in memory has failed, and the storage level decides whether the data is written to disk instead. The corresponding methods are:

    putIteratorAsValues[T](blockId: BlockId, values: Iterator[T], classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long]
    
    putIteratorAsBytes[T](blockId: BlockId, values: Iterator[T], classTag: ClassTag[T], memoryMode: MemoryMode): Either[PartiallySerializedBlock[T], Long]
    

Reading from the MemoryStore also comes in two flavors: one returns the block data as a byte buffer, the other as an iterator.

What BlockManager is used for

Typical BlockManager use cases include the following (a small cache example follows the list):

  • The data produced during Spark shuffle is stored through the BlockManager.
  • When Spark broadcasts data to tasks scheduled on multiple executors, the broadcast implementation stores its data in the BlockManager.
  • When an RDD is cached, the cached data is stored via the BlockManager.
  • The data received by a Spark Streaming ReceiverInputDStream is first put into the BlockManager and then wrapped as a BlockRDD for further computation.
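An example of the RDD-cache path, assuming a running SparkContext `sc`: persisting an RDD stores its partitions as blocks (blockId rdd_<rddId>_<partitionId>) in each executor's BlockManager.

import org.apache.spark.storage.StorageLevel

val data = sc.parallelize(1 to 1000000)
val cached = data.map(_ * 2).persist(StorageLevel.MEMORY_AND_DISK)
cached.count()   // the first action materializes the blocks through the BlockManager put path
cached.count()   // later actions read the cached blocks back via BlockManager.get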


Task Execution Flow

Preface

The post TaskScheduler 任务提交与调度源码解析 covered the logical assignment of tasks to executors: TaskSchedulerImpl.resourceOffers() returns a Seq[Seq[TaskDescription]], i.e. a description of which task should run on which executor. That assignment is purely logical; the tasks have not actually reached the executors yet. This post walks through the source to see how tasks are shipped to executors and executed.

The driver sends the LaunchTask event

After resourceOffers has logically assigned the tasks, CoarseGrainedSchedulerBackend calls launchTasks with the Seq[Seq[TaskDescription]]:

private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
      for (task <- tasks.flatten) {
       //序列化TaskDescription
        val serializedTask = ser.serialize(task)
        if (serializedTask.limit >= maxRpcMessageSize) {
          scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
            try {
              var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
                "spark.rpc.message.maxSize (%d bytes). Consider increasing " +
                "spark.rpc.message.maxSize or using broadcast variables for large values."
              msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize)
              taskSetMgr.abort(msg)
            } catch {
              case e: Exception => logError("Exception in error callback", e)
            }
          }
        }
        else {
          //根据executorId获取executor描述信息executorData
          val executorData = executorDataMap(task.executorId)
          //减少相应的freeCores
          executorData.freeCores -= scheduler.CPUS_PER_TASK

          logInfo(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " +
            s"${executorData.executorHost}.")
          //利用executorData中的executorEndpoint,发送LaunchTask事件,LaunchTask事件中包含序列化后的task 
          executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
        }
      }
    }

Each TaskDescription is serialized and checked against the RPC message size limit (spark.rpc.message.maxSize). If it fits, a LaunchTask event containing the serialized task is sent through executorData.executorEndpoint, the RPC reference the driver uses to talk to that executor, handing the task over to the executor for execution.

The executor receives the LaunchTask event

The driver sends tasks through its backend process CoarseGrainedSchedulerBackend, so naturally the executor receives them through its own backend process, CoarseGrainedExecutorBackend. That process has a one-to-one relationship with the executor and provides the communication channel between executor and driver. Here is how it handles the event:

case LaunchTask(data) =>
      if (executor == null) {
        exitExecutor(1, "Received LaunchTask command but executor was null")
      } else {
        // 将TaskDescription反序列化
        val taskDesc = ser.deserialize[TaskDescription](data.value)
        logInfo("Got assigned task " + taskDesc.taskId)
        //调用executor的launchTask来加载该task
        executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
          taskDesc.name, taskDesc.serializedTask)
      }

The payload is deserialized into a TaskDescription and executor.launchTask is called to run the task; following that call:

def launchTask(
      context: ExecutorBackend,
      taskId: Long,
      attemptNumber: Int,
      taskName: String,
      serializedTask: ByteBuffer): Unit = {
    // 创建一个TaskRunner
    val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
      serializedTask)
    runningTasks.put(taskId, tr)
    //将tr放到线程池中执行
    threadPool.execute(tr)
  }

A TaskRunner (which implements Runnable) is created and submitted to the thread pool. The interesting part is TaskRunner.run; the method is long, so only the core logic is kept here:

override def run(): Unit = {
       ...
      try {
        //反序列化task,得到taskFiles、jar包taskFiles和Task二进制数据taskBytes  
        val (taskFiles, taskJars, taskProps, taskBytes) =
          Task.deserializeWithDependencies(serializedTask)

        Executor.taskDeserializationProps.set(taskProps)
       //下载task依赖的文件和jar包
        updateDependencies(taskFiles, taskJars)
       //反序列化出task
        task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
        ...
        val value = try {
          //调用task的run方法,真正执行task
          val res = task.run(
            taskAttemptId = taskId,
            attemptNumber = attemptNumber,
            metricsSystem = env.metricsSystem)
          threwException = false
          //返回结果
          res
        } finally {
          val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
          //通过任务内存管理器清理所有的分配的内存  
          val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
          if (freedMemory > 0 && !threwException) {
            val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
            if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
              throw new SparkException(errMsg)
            } else {
              logWarning(errMsg)
            }
          }
        ...
       
        val resultSer = env.serializer.newInstance()
        val beforeSerialization = System.currentTimeMillis()
        //序列化task结果value
        val valueBytes = resultSer.serialize(value)
        val afterSerialization = System.currentTimeMillis()
        ...
        // 将序列化后的结果包装成DirectTaskResult对象
        val directResult = new DirectTaskResult(valueBytes, accumUpdates)
        //再将directResult 序列化,
        val serializedDirectResult = ser.serialize(directResult)
        val resultSize = serializedDirectResult.limit

        // directSend = sending directly back to the driver
        val serializedResult: ByteBuffer = {
          //若task结果大于所有maxResultSize(可配置,默认1G),则直接丢弃,driver在返回的对象中拿不到对应的结果
          if (maxResultSize > 0 && resultSize > maxResultSize) { 
            ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
          //若task结果大小超过akka最大能传输的大小(运行结果无法通过消息传递 ),则将结果写入BlockManager  
          } else if (resultSize > maxDirectResultSize) {
            val blockId = TaskResultBlockId(taskId)
            env.blockManager.putBytes(
              blockId,
              new ChunkedByteBuffer(serializedDirectResult.duplicate()),
              StorageLevel.MEMORY_AND_DISK_SER)
            logInfo(
              s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
            ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
          //结果比较小能以消息传递,直接返回
          } else {
            logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
            serializedDirectResult
          }
        }
        // 向Driver端发状态更新
        execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

      } catch { 
          ...
          //向Driver端发状态更新
          execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
          ...
      } finally {
        // 不管成功与否,都需要将task从runningTasks中移除
        runningTasks.remove(taskId)
      }
    }
  • Task.deserializeWithDependencies deserializes the payload into the task's files (taskFiles), jars (taskJars) and the task binary (taskBytes)
  • Download the files and jars the task depends on
  • Deserialize the task itself
  • Call task.run, which actually executes the task and returns the result
  • Clean up the allocated memory
  • Serialize the task result, wrap it as a DirectTaskResult, serialize it again, and return it to the driver in different ways depending on its size (a small config sketch follows this list):
    • If the result is larger than maxResultSize (configurable, 1 GB by default), it is simply dropped and the driver cannot obtain it from the returned object
    • If the result exceeds the maximum size that can be carried directly in an RPC message, it is written to the BlockManager and only a reference is returned
    • If the result is small enough to be carried in a message, it is returned directly
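The size thresholds referenced above, expressed as SparkConf settings (the values are illustrative, not recommendations):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.driver.maxResultSize", "1g")    // results larger than this are dropped
  .set("spark.rpc.message.maxSize", "128")    // in MB; larger direct results go through the BlockManager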

Finally, CoarseGrainedExecutorBackend.statusUpdate sends the result back to the driver: it uses driverRpcEndpointRef to send a StatusUpdate message containing serializedResult.

Now let's see what task.run actually does:

final def run(
      taskAttemptId: Long,
      attemptNumber: Int,
      metricsSystem: MetricsSystem): T = {
    SparkEnv.get.blockManager.registerTask(taskAttemptId)
    //创建一个task运行的上下文实例
    context = new TaskContextImpl(
      stageId,
      partitionId,
      taskAttemptId,
      attemptNumber,
      taskMemoryManager,
      localProperties,
      metricsSystem,
      metrics)
    TaskContext.setTaskContext(context)
    taskThread = Thread.currentThread()
    if (_killed) {
      kill(interruptThread = false)
    }
    try {
      runTask(context)
    } catch { 
     ...
    } finally { 
     ... //标记完成,释放内存
    }
  }

Following on to runTask: there are two Task implementations, ResultTask (the tasks of a ResultStage, one per partition of the final RDD) and ShuffleMapTask (the tasks of a ShuffleMapStage, one per partition of that stage's last RDD), each with its own runTask. First ResultTask:

override def runTask(context: TaskContext): U = { 
    val deserializeStartTime = System.currentTimeMillis()
    val ser = SparkEnv.get.closureSerializer.newInstance()
    //反序列化
    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
    //对rdd的指定分区的迭代器执行func函数,并返回结果
    func(context, rdd.iterator(partition, context))
  }
  • Deserialize the rdd and the func from the broadcast variable, whose data comes from taskBinary
  • Apply func to the iterator of the rdd's target partition and return the result

The func here differs per operation; the records of a partition are obtained through its iterator.
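For illustration (a simplification, not the exact closure Spark serializes): for RDD.collect() the per-partition function shipped inside the ResultTask essentially materializes the partition iterator into an array that is sent back to the driver.

import org.apache.spark.TaskContext

val collectFunc: (TaskContext, Iterator[String]) => Array[String] =
  (_: TaskContext, iter: Iterator[String]) => iter.toArray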

Now the ShuffleMapTask implementation. Its output is written to disk by Shuffle Write, preparing the data for the Shuffle Read of the downstream stage:

override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val deserializeStartTime = System.currentTimeMillis()
    val ser = SparkEnv.get.closureSerializer.newInstance()
    // 使用广播变量反序列化出rdd和ShuffleDependency
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime

    var writer: ShuffleWriter[Any, Any] = null
    try {
      // 获取shuffleManager
      val manager = SparkEnv.get.shuffleManager
      // 通过shuffleManager的getWriter()方法,获得shuffle的writer  
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      // 通过rdd指定分区的迭代器iterator方法来遍历每一条数据,再之上再调用writer的write方法以写数据
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }
  • Deserialize the rdd and the ShuffleDependency from the broadcast variable, whose data comes from taskBinary
  • Obtain a writer from the ShuffleManager and use its write method to write one partition of the rdd to disk
  • The rdd's iterator method is what walks over all the data of the corresponding partition

How the driver handles the returned results will be covered in a later post.

State Management: updateStateByKey & mapWithState

Preface

A Spark Streaming job runs 24/7 and sometimes needs to maintain state. For word count, for instance, the data of each batch is not independent but has to be accumulated, so Spark Streaming must keep state across batches. There are currently two options, updateStateByKey and mapWithState; mapWithState was added in Spark 1.6, with an officially claimed ~10x performance improvement.

updateStateByKey

An example first:

def updateFunction(currValues:Seq[Int],preValue:Option[Int]): Option[Int] = {
       val currValueSum = currValues.sum
        //the Int types above can be replaced by any object type
        Some(currValueSum + preValue.getOrElse(0)) //sum of the current values plus the previous state
    }
    kafkaStream.map(r => (r._2,1)).updateStateByKey(updateFunction _)

updateFunction is the state-update logic we implement ourselves: currValues contains all the values of the current batch and preValue is the previously maintained state. updateStateByKey returns a DStream that contains the entire accumulated state. Let's see how the state is managed underneath; following the source leads to the core method:

  private [this] def computeUsingPreviousRDD(
      batchTime: Time,
      parentRDD: RDD[(K, V)],
      prevStateRDD: RDD[(K, S)]) = {
    // Define the function for the mapPartition operation on cogrouped RDD;
    // first map the cogrouped tuple to tuples of required type,
    // and then apply the update function
    val updateFuncLocal = updateFunc
    val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
      val i = iterator.map { t =>
        val itr = t._2._2.iterator
        val headOption = if (itr.hasNext) Some(itr.next()) else None
        (t._1, t._2._1.toSeq, headOption)
      }
      updateFuncLocal(batchTime, i)
    }
    val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
    val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
    Some(stateRDD)
  }

So the parentRDD is cogrouped with the prevStateRDD and finalFunc is applied to each partition. Inside finalFunc the tuples have the shape (t._1, t._2._1.toSeq, headOption), i.e. (key, currValues, preValue), which is exactly the structure of the updateFunction we had to implement; indeed, our function has already been wrapped once:

def updateStateByKey[S: ClassTag](
      updateFunc: (Seq[V], Option[S]) => Option[S],
      partitioner: Partitioner
    ): DStream[(K, S)] = ssc.withScope {
    val cleanedUpdateF = sparkContext.clean(updateFunc)
    val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
      iterator.flatMap(t => cleanedUpdateF(t._2, t._3).map(s => (t._1, s)))
    }
    updateStateByKey(newUpdateFunc, partitioner, true)
  }

So every call of updateStateByKey cogroups the old state RDD with the current batch's RDD to build a new state RDD. Even if only one record actually needs updating, the two RDDs still have to be cogrouped and every record is recomputed; and as the state keeps growing, each batch gets slower and slower.

mapWithState was introduced to solve exactly this problem.

mapWithState

Again an example first:

   val initialRDD = ssc.sparkContext.parallelize(List[(String, Int)]())
    //custom mappingFunction: accumulate the word counts and update the state
    val mappingFunc = (word: String, count: Option[Int], state: State[Int]) => {
      val sum = count.getOrElse(0) + state.getOption.getOrElse(0)
      val output = (word, sum)
      state.update(sum)
      output
    }
    //call mapWithState to manage the state of the stream
    kafkaStream.map(r => (r._2,1)).mapWithState(StateSpec.function(mappingFunc).initialState(initialRDD)).print()

Here initialRDD is the initial state (updateStateByKey has a corresponding API as well). mappingFunc is again the state-update logic we implement ourselves: state.update() updates the state and output is what ends up in the DStream returned by mapWithState. Note that the function is not passed in directly but wrapped in a StateSpec object, which is simply a wrapper around the function. Let's trace the source to see how the state is managed; a MapWithStateDStreamImpl instance is created:

def mapWithState[StateType: ClassTag, MappedType: ClassTag](
      spec: StateSpec[K, V, StateType, MappedType]
    ): MapWithStateDStream[K, V, StateType, MappedType] = {
    new MapWithStateDStreamImpl[K, V, StateType, MappedType](
      self,
      spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
    )
  }

Naturally we look at how its compute method is implemented:

 private val internalStream =
    new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)
 
  override def compute(validTime: Time): Option[RDD[MappedType]] = {
    internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
  }

compute delegates the work to internalStream, an InternalMapWithStateDStream; the main logic is in InternalMapWithStateDStream.compute:

override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
    // Get the previous state or create a new empty state RDD
    val prevStateRDD = getOrCompute(validTime - slideDuration) match {
      case Some(rdd) =>
        if (rdd.partitioner != Some(partitioner)) {
          // If the RDD is not partitioned the right way, let us repartition it using the
          // partition index as the key. This is to ensure that state RDD is always partitioned
          // before creating another state RDD using it
          MapWithStateRDD.createFromRDD[K, V, S, E](
            rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
        } else {
          rdd
        }
      case None =>
        MapWithStateRDD.createFromPairRDD[K, V, S, E](
          spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
          partitioner,
          validTime
        )
    }

    // Compute the new state RDD with previous state RDD and partitioned data RDD
    // Even if there is no data RDD, use an empty one to create a new state RDD
    val dataRDD = parent.getOrCompute(validTime).getOrElse {
      context.sparkContext.emptyRDD[(K, V)]
    }
    val partitionedDataRDD = dataRDD.partitionBy(partitioner)
    val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
      (validTime - interval).milliseconds
    }
    Some(new MapWithStateRDD(
      prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
  }

It first obtains the prevStateRDD and the parentRDD, making sure both use the same partitioner, and then builds a MapWithStateRDD from the two RDDs, the user-defined mappingFunction, and the key timeout. MapWithStateRDD extends RDD[MapWithStateRDDRecord[K, S, E]]: its elements are MapWithStateRDDRecord objects, one per partition, each holding the state of that partition (which is why both RDDs must share the same partitioner). Now the compute method of MapWithStateRDD:

 override def compute(
      partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {

    val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
    val prevStateRDDIterator = prevStateRDD.iterator(
      stateRDDPartition.previousSessionRDDPartition, context)
    val dataIterator = partitionedDataRDD.iterator(
      stateRDDPartition.partitionedDataRDDPartition, context)

    val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
    val newRecord = MapWithStateRDDRecord.updateRecordWithData(
      prevRecord,
      dataIterator,
      mappingFunction,
      batchTime,
      timeoutThresholdTime,
      removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
    )
    Iterator(newRecord)
  }

It gets the iterators of the corresponding partitions of prevStateRDD and partitionedDataRDD, then takes the single record of the prevStateRDD partition (each partition holds exactly one MapWithStateRDDRecord, which maintains the state of all keys in that partition), calls the core update method, and finally returns an iterator containing just the one new record. Let's look at that core update logic:

 def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
    prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
    dataIterator: Iterator[(K, V)],
    mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
    batchTime: Time,
    timeoutThresholdTime: Option[Long],
    removeTimedoutData: Boolean
  ): MapWithStateRDDRecord[K, S, E] = {
    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
    val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }

    val mappedData = new ArrayBuffer[E]
    val wrappedState = new StateImpl[S]()

    // Call the mapping function on each record in the data iterator, and accordingly
    // update the states touched, and collect the data returned by the mapping function
    dataIterator.foreach { case (key, value) =>
      wrappedState.wrap(newStateMap.get(key))
      val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
      if (wrappedState.isRemoved) {
        newStateMap.remove(key)
      } else if (wrappedState.isUpdated
          || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
      }
      mappedData ++= returned
    }

    // Get the timed out state records, call the mapping function on each and collect the
    // data returned
    if (removeTimedoutData && timeoutThresholdTime.isDefined) {
      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
        wrappedState.wrapTimingOutState(state)
        val returned = mappingFunction(batchTime, key, None, wrappedState)
        mappedData ++= returned
        newStateMap.remove(key)
      }
    }

    MapWithStateRDDRecord(newStateMap, mappedData)
  }

It first copies the previous state map and then defines two variables: mappedData, the data that will eventually be returned, and wrappedState, a wrapper around a State that adds some extra helper methods.

It then iterates over the current batch's data: for each key it pulls the existing state out of the state map and updates it with the user-defined function. Depending on whether the state was removed, updated, or is subject to a timeout, newStateMap is updated accordingly, and the values returned by the mapping function are collected into mappedData.

If a timeout is configured, timed-out keys are also removed and their mapping-function output is appended to mappedData. Finally a MapWithStateRDDRecord is built from the new state map newStateMap and the mappedData to be returned.

Back in the compute method of the MapWithStateDStreamImpl instance mentioned earlier:

  override def compute(validTime: Time): Option[RDD[MappedType]] = {
    internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
  }

it is precisely this mappedData that gets returned.

Notice that only the updated records are returned. To get the full state, call stateSnapshots after mapWithState; to clear the state of a key, call state.remove() inside your mapping function. A short usage sketch follows.
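Usage sketch, reusing kafkaStream / mappingFunc / initialRDD from the example above and adding a hypothetical 30-minute key timeout:

import org.apache.spark.streaming.{Minutes, StateSpec}

val stateStream = kafkaStream.map(r => (r._2, 1))
  .mapWithState(StateSpec.function(mappingFunc).initialState(initialRDD).timeout(Minutes(30)))

stateStream.print()                  // only the keys updated (or timed out) in this batch
stateStream.stateSnapshots().print() // a snapshot of every key and its current state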

Summary

  • updateStateByKey cogroups the prevStateRDD with the parentRDD and runs the user-defined update function over every record, even when the current batch contains only one record; this is expensive, and the processing time grows with the size of the maintained state.
  • mapWithState builds a MapWithStateRDD whose elements are MapWithStateRDDRecord objects, one per partition, each recording the full state of that partition. Each batch only updates the keys present in that batch instead of recomputing everything the way updateStateByKey does.

DAGScheduler Stage Submission Source Code Analysis

After the DAGScheduler has divided the job into stages ([spark] DAGScheduler划分stage源码解析), it submits them via submitStage(finalStage):

 private def submitStage(stage: Stage) {
    val jobId = activeJobForStage(stage)
    if (jobId.isDefined) {
      logDebug("submitStage(" + stage + ")")
      if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
        //获取未计算完的parentStage,判断是否计算完的条件是
        //_numAvailableOutputs == numPartitions,既有效输出个数是否等于分区数。
        //根据stageid从小到大排序,是因为越前面的stageid越小。
        val missing = getMissingParentStages(stage).sortBy(_.id)
        logDebug("missing: " + missing)
        if (missing.isEmpty) { 
          logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
          submitMissingTasks(stage, jobId.get) //若当前stage没有任何依赖或者所有依赖都已经准备好,则提交task。
        } else {
          //若有未提交的父Stage,则递归提交父Stage
          //标记当前stage为waitingStages ,先等待父stage执行完。
          for (parent <- missing) {
            submitStage(parent) 
          }
          waitingStages += stage
        }
      }
    } else {
      abortStage(stage, "No active job for stage " + stage.id, None)
    }
  }

Here is the implementation of getMissingParentStages:

private def getMissingParentStages(stage: Stage): List[Stage] = {
    val missing = new HashSet[Stage] //未计算完的stage
    val visited = new HashSet[RDD[_]] //被访问过的stage
    // We are manually maintaining a stack here to prevent StackOverflowError
    // caused by recursively visiting
    val waitingForVisit = new Stack[RDD[_]] //等待被访问的stage
    def visit(rdd: RDD[_]) {
      if (!visited(rdd)) {
        visited += rdd
        //先判断是否有未cache的分区,若全部都被cache了就不用计算parent Stage了。
        //遍历rdd的所有依赖,当是宽依赖时获取其对应依赖的宽依赖并判断该stage是否可用。
        //判断条件是该stage输出个数是否等于该stage的finalRDD分区数。
        //不等于时说明还有未计算的分区,则将该stage加入missing;
        //若为窄依赖则继续往上遍历。
        val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil)
        if (rddHasUncachedPartitions) {
          for (dep <- rdd.dependencies) {
            dep match {
              case shufDep: ShuffleDependency[_, _, _] =>
                val mapStage = getShuffleMapStage(shufDep, stage.firstJobId)
                if (!mapStage.isAvailable) {
                  missing += mapStage
                }
              case narrowDep: NarrowDependency[_] =>
                waitingForVisit.push(narrowDep.rdd)
            }
          }
        }
      }
    }
    waitingForVisit.push(stage.rdd)
    while (waitingForVisit.nonEmpty) {
      visit(waitingForVisit.pop())
    }
    missing.toList
  }

If the current stage has no dependencies, or all of its dependencies are ready, its tasks are submitted through submitMissingTasks; let's look at the implementation:

private def submitMissingTasks(stage: Stage, jobId: Int) {
    stage.pendingPartitions.clear()
    // 获取需要计算的分区
    val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
    val properties = jobIdToActiveJob(jobId).properties
    runningStages += stage  // 标记stage为running状态
    ......
    //获取task最佳计算位置
    val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
      stage match {
        case s: ShuffleMapStage =>
          partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
        case s: ResultStage =>
          val job = s.activeJob.get
          partitionsToCompute.map { id =>
            val p = s.partitions(id)
            (id, getPreferredLocs(stage.rdd, p))
          }.toMap
      }
    } catch {
        ...
    }

    stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq)
    listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))

    var taskBinary: Broadcast[Array[Byte]] = null
    try {
      val taskBinaryBytes: Array[Byte] = stage match {
        case stage: ShuffleMapStage =>
          JavaUtils.bufferToArray(
            closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
        case stage: ResultStage =>
          JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
      }
      taskBinary = sc.broadcast(taskBinaryBytes)
    } catch {
       ...
    }
    val tasks: Seq[Task[_]] = try {
      stage match {
        case stage: ShuffleMapStage =>
          partitionsToCompute.map { id =>
            val locs = taskIdToLocations(id)
            val part = stage.rdd.partitions(id)
            new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, stage.latestInfo.taskMetrics, properties)
          }
        case stage: ResultStage =>
          val job = stage.activeJob.get
          partitionsToCompute.map { id =>
            val p: Int = stage.partitions(id)
            val part = stage.rdd.partitions(p)
            val locs = taskIdToLocations(id)
            new ResultTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, id, properties, stage.latestInfo.taskMetrics)
          }
      }
    } catch {
      ...
    }
    if (tasks.size > 0) {
      stage.pendingPartitions ++= tasks.map(_.partitionId)
      taskScheduler.submitTasks(new TaskSet(
        tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))
      stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
    } else {  
        ...
    }
  }

Each step is explained below.
stage.findMissingPartitions returns the partitions that still need to be computed; the implementation differs per stage type:

//ShuffleMapStage
//根据partition是否有对应的outputLocs来判断哪些分区需要被计算,计算过的partition会被outputLocs记录
override def findMissingPartitions(): Seq[Int] = {
    val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
    assert(missing.size == numPartitions - _numAvailableOutputs,
      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
    missing
  }

//ResultStage
//计算过的分区会被job记录为finish
 override def findMissingPartitions(): Seq[Int] = {
    val job = activeJob.get
    (0 until job.numPartitions).filter(id => !job.finished(id))
  }

taskIdToLocations holds the preferred location of each task, computed mainly by getPreferredLocs:

private def getPreferredLocsInternal(
      rdd: RDD[_],
      partition: Int,
      visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = {
    // If the partition has already been visited, no need to re-visit.
    // This avoids exponential path exploration.  SPARK-695
    if (!visited.add((rdd, partition))) {
      // Nil has already been returned for previously visited partitions.
      return Nil
    }
    // If the partition is cached, return the cache locations
    val cached = getCacheLocs(rdd)(partition)
    if (cached.nonEmpty) {
      return cached
    }
    // If the RDD has some placement preferences (as is the case for input RDDs), get those
    val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
    if (rddPrefs.nonEmpty) {
      return rddPrefs.map(TaskLocation(_))
    }
    // If the RDD has narrow dependencies, pick the first partition of the first narrow dependency
    // that has any placement preferences. Ideally we would choose based on transfer sizes,
    // but this will do for now.
    rdd.dependencies.foreach {
      case n: NarrowDependency[_] =>
        for (inPart <- n.getParents(partition)) {
          val locs = getPreferredLocsInternal(n.rdd, inPart, visited)
          if (locs != Nil) {
            return locs
          }
        }
      case _ =>
    }
    Nil
  }
  • In getCacheLocs, cacheLocs caches the locations of the RDD's partitions as TaskLocation instances. If the partition's locations are found in cacheLocs they are returned directly; otherwise, if the RDD's storage level is NONE, Nil is returned and recorded in cacheLocs, and if not, the blockManagerMaster is asked which blockManagers hold the partition and ExecutorCacheTaskLocation instances are created and put into cacheLocs.
  • rdd.preferredLocations first tries to get the partition's locations from a checkpoint; failing that, it calls the RDD's own getPreferredLocations(split), whose implementation differs per RDD. HadoopRDD, for example, gets the partition's locations from the Hadoop InputSplit.
  • If neither source yields anything, the preferred locations are found by recursing into the parent RDD's partitions. Note: this only applies to narrow dependencies. (A small locality-checking sketch follows.)
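Data locality can also be inspected directly, assuming an HDFS file and a running SparkContext `sc` (the path is hypothetical):

val rdd = sc.textFile("hdfs:///data/logs/2018-01-01")
val firstPartition = rdd.partitions(0)
println(rdd.preferredLocations(firstPartition))  // hostnames holding the HDFS blocks of partition 0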

Once the preferred locations are known, a serialized binary is broadcast to every executor depending on the stage type: for a ShuffleMapStage, the stage's final RDD together with the stage's shuffleDep; for a ResultStage, the stage's final RDD together with stage.func. In other words, the task's actual execution logic has been serialized into taskBinary and broadcast to every executor.

 var taskBinary: Broadcast[Array[Byte]] = null
    try {
      // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
      // For ResultTask, serialize and broadcast (rdd, func).
      val taskBinaryBytes: Array[Byte] = stage match {
        case stage: ShuffleMapStage =>
          JavaUtils.bufferToArray(
            closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
        case stage: ResultStage =>
          JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
      }

      taskBinary = sc.broadcast(taskBinaryBytes)
    } catch {
    }

Depending on the stage type, a different kind of task is generated: one task per partition, each carrying the target partition's location information; finally all the tasks are wrapped into a TaskSet and submitted.

stage match {
        case stage: ShuffleMapStage =>
          partitionsToCompute.map { id =>
            val locs = taskIdToLocations(id)
            val part = stage.rdd.partitions(id)
            new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, stage.latestInfo.taskMetrics, properties)
          }
        case stage: ResultStage =>
          val job = stage.activeJob.get
          partitionsToCompute.map { id =>
            val p: Int = stage.partitions(id)
            val part = stage.rdd.partitions(p)
            val locs = taskIdToLocations(id)
            new ResultTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, id, properties, stage.latestInfo.taskMetrics)
          }
      }

taskScheduler.submitTasks(new TaskSet(
        tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) 
    }

At this point the DAGScheduler has finished dividing the job into stages and handed them to the TaskScheduler as TaskSets; how the TaskScheduler submits and manages the tasks will be covered in a follow-up post.

Spark Speculative Execution

Overview

Speculative execution means that for a straggler task in a stage, another copy of the task is launched on an executor of a different node; whichever copy finishes first provides the final result, and the other running copies are killed. Speculative execution is disabled by default and can be enabled with the spark.speculation property.

Detecting tasks that need speculation

After the SparkContext has created the schedulerBackend and the taskScheduler, it immediately calls taskScheduler.start():

override def start() {
    backend.start()
    if (!isLocal && conf.getBoolean("spark.speculation", false)) {
      logInfo("Starting speculative execution thread")
      speculationScheduler.scheduleAtFixedRate(new Runnable {
        override def run(): Unit = Utils.tryOrStopSparkContext(sc) {
          checkSpeculatableTasks()
        }
      }, SPECULATION_INTERVAL_MS, SPECULATION_INTERVAL_MS, TimeUnit.MILLISECONDS)
    }
  }

After starting the SchedulerBackend, and as long as it is not running in local mode, the TaskScheduler checks whether speculation is enabled (disabled by default, enabled via spark.speculation). If it is, a thread is started that every SPECULATION_INTERVAL_MS (100 ms by default, configurable via spark.speculation.interval) calls checkSpeculatableTasks to detect tasks that should be speculated:

// Check for speculatable tasks in all our active jobs.
  def checkSpeculatableTasks() {
    var shouldRevive = false
    synchronized {
      shouldRevive = rootPool.checkSpeculatableTasks()
    }
    if (shouldRevive) {
      backend.reviveOffers()
    }
  }

checkSpeculatableTasks then delegates to rootPool to decide whether any tasks need speculative execution; if so, the SchedulerBackend's reviveOffers is called to try to obtain resources for the speculative copies. Let's follow the detection logic further:

override def checkSpeculatableTasks(): Boolean = {
    var shouldRevive = false
    for (schedulable <- schedulableQueue.asScala) {
      shouldRevive |= schedulable.checkSpeculatableTasks()
    }
    shouldRevive
  }

rootPool in turn calls checkSpeculatableTasks on each schedulable in schedulableQueue, a ConcurrentLinkedQueue[Schedulable] whose elements are TaskSetManagers. Looking at TaskSetManager's checkSpeculatableTasks, we finally reach the root of the detection logic:

 override def checkSpeculatableTasks(): Boolean = {
    // No need to check if there is only one task or no tasks are left to run
    if (isZombie || numTasks == 1) {  
      return false
    }
    var foundTasks = false
    // total task count * SPECULATION_QUANTILE (0.75 by default, set via spark.speculation.quantile)
    val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
    logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
    // have the successful tasks passed 75% of the total, and is the success count greater than 0?
    if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
      val time = clock.getTimeMillis()
      // collect the durations of the successfully finished tasks and sort them
      val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
      Arrays.sort(durations)
      // take the median of those durations
      val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.length - 1))
      // median * SPECULATION_MULTIPLIER (1.5 by default, set via spark.speculation.multiplier)
      val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
      logDebug("Task length threshold for speculation: " + threshold)
      // walk through the tasks of this TaskSet and add to speculatableTasks those that have not
      // succeeded, have exactly one running copy, have run longer than the threshold, and are not already listed
      for ((tid, info) <- taskInfos) {
        val index = info.index
        if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
          !speculatableTasks.contains(index)) {
          logInfo(
            "Marking task %d in stage %s (on %s) as speculatable because it ran more than %.0f ms"
              .format(index, taskSet.id, info.host, threshold))
          speculatableTasks += index
          foundTasks = true
        }
      }
    }
    foundTasks
  }

The checking logic is clear from the comments: once the number of successful tasks exceeds 75% of the total (configurable via spark.speculation.quantile), the run times of all successful tasks are collected and their median is taken; that median multiplied by 1.5 (configurable via spark.speculation.multiplier) becomes the duration threshold, and any still-running task whose run time exceeds the threshold is marked for speculation. In short, speculation targets the tasks dragging down overall progress, in order to speed up the whole stage.
The rough flow of the algorithm is shown in the accompanying flow diagram (omitted here).
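A self-contained sketch of the same arithmetic with made-up durations (the real logic lives in TaskSetManager.checkSpeculatableTasks):

import java.util.Arrays

// Hypothetical successful task durations in ms for a 10-task stage.
val durations = Array(2000L, 2100L, 2200L, 2300L, 2500L, 2600L, 2800L, 9000L)
val numTasks = 10
val quantile = 0.75   // spark.speculation.quantile
val multiplier = 1.5  // spark.speculation.multiplier

val minFinished = (quantile * numTasks).floor.toInt   // 7 tasks must have finished
if (durations.length >= minFinished) {
  Arrays.sort(durations)
  val median = durations(math.min((0.5 * durations.length).round.toInt, durations.length - 1))
  val threshold = math.max(multiplier * median, 100)  // 1.5 * 2500 = 3750 ms
  println(s"tasks still running after $threshold ms become speculatable")
}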

When speculative tasks get scheduled

When a TaskSetManager, under the delay-scheduling policy, assigns a task to an executor, it calls dequeueTask:

private def dequeueTask(execId: String, host: String, maxLocality: TaskLocality.Value)
    : Option[(Int, TaskLocality.Value, Boolean)] =
  {
    for (index <- dequeueTaskFromList(execId, getPendingTasksForExecutor(execId))) {
      return Some((index, TaskLocality.PROCESS_LOCAL, false))
    }

    if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) {
      for (index <- dequeueTaskFromList(execId, getPendingTasksForHost(host))) {
        return Some((index, TaskLocality.NODE_LOCAL, false))
      }
    }
   ......
    // find a speculative task if all others tasks have been scheduled
    dequeueSpeculativeTask(execId, host, maxLocality).map {
      case (taskIndex, allowedLocality) => (taskIndex, allowedLocality, true)}
  }

The last part of this method schedules speculative tasks once all other tasks have been scheduled; let's look at its implementation:

protected def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
    : Option[(Int, TaskLocality.Value)] =
  {
    // remove tasks that have already finished successfully from the speculative list, because
    // some time passes between detection and scheduling and some of them may have completed
    speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
     // a task may run on this executor's host only if:
     // it does not already have an attempt running on that host, and
     // the executor is not blacklisted for the task (the task failed there and the blacklist timeout has not expired)
    def canRunOnHost(index: Int): Boolean =
      !hasAttemptOnHost(index, host) && !executorIsBlacklisted(execId, index)
 
    if (!speculatableTasks.isEmpty) {
      // find a task index that can be launched on this executor
      for (index <- speculatableTasks if canRunOnHost(index)) {
        // the task's preferred locations
        val prefs = tasks(index).preferredLocations 
        val executors = prefs.flatMap(_ match {
          case e: ExecutorCacheTaskLocation => Some(e.executorId)
          case _ => None
        });
        // if a preferred location is an ExecutorCacheTaskLocation and that executor is the current one,
        // return the task's index within the TaskSet together with its locality level
        if (executors.contains(execId)) {
          speculatableTasks -= index
          return Some((index, TaskLocality.PROCESS_LOCAL))
        }
      }

      // this check is delay scheduling at work: even speculative tasks are launched at the best locality level possible
      if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
        for (index <- speculatableTasks if canRunOnHost(index)) {
          val locations = tasks(index).preferredLocations.map(_.host)
          if (locations.contains(host)) {
            speculatableTasks -= index
            return Some((index, TaskLocality.NODE_LOCAL))
          }
        }
      }

       ........
    }
    None
  }

The method is long, so only the first part is shown; the remaining branches follow the same pattern and the in-code comments make them clear. Already-successful tasks are filtered out first; a speculative copy is never run on the same host as the attempt already running, nor on a blacklisted executor. Then, under the delay-scheduling policy, the task's preferred locations decide whether it is scheduled on this executor and at which locality level.

Scheduling Modes (FIFO & FAIR)

Preface

Scheduling of a Spark application happens at two levels: YARN schedules across Spark applications, while inside one application (a single SparkContext) the multiple TaskSetManagers are scheduled against each other. Here we only analyse the in-application scheduling.

Spark has two scheduling modes: FIFO (first in, first out) and FAIR. FIFO is the default: whoever submits first runs first. FAIR additionally supports grouping jobs into scheduling pools that can carry different weights, and weight, resources and so on decide what runs first. The mode is set via spark.scheduler.mode.
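A minimal sketch of switching an application to FAIR mode; the allocation-file path is a made-up example:

import org.apache.spark.SparkConf

// Illustrative configuration: use FAIR scheduling inside one application.
val conf = new SparkConf()
  .set("spark.scheduler.mode", "FAIR")
  // optional: a custom pool definition file (path is hypothetical)
  .set("spark.scheduler.allocation.file", "/path/to/fairscheduler.xml")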

Scheduling pool initialization

After the DAGScheduler has split the job into stages and submitted them as TaskSets to the TaskScheduler, the TaskScheduler implementation creates a TaskSetManager for each TaskSet and adds it to the scheduling pool:

schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

schedulableBuilder is instantiated inside TaskSchedulerImpl's initialize method, which SparkContext calls as scheduler.initialize(backend) right after creating the TaskSchedulerImpl via newTaskSchedulerImpl(sc).

def initialize(backend: SchedulerBackend) {
    this.backend = backend
    // temporarily set rootPool name to empty
    rootPool = new Pool("", schedulingMode, 0, 0)
    schedulableBuilder = {
      schedulingMode match {
        case SchedulingMode.FIFO =>
          new FIFOSchedulableBuilder(rootPool)
        case SchedulingMode.FAIR =>
          new FairSchedulableBuilder(rootPool, conf)
        case _ =>
          throw new IllegalArgumentException(s"Unsupported spark.scheduler.mode: $schedulingMode")
      }
    }
    schedulableBuilder.buildPools()
  }

As the code shows, a different builder is created depending on the configured mode; schedulableBuilder has two implementations, FIFOSchedulableBuilder and FairSchedulableBuilder, and right afterwards schedulableBuilder.buildPools() is called. Let's see how each implements it.

override def buildPools() {
    // nothing
  }

FIFOSchedulableBuilder does nothing at all.

override def buildPools() {
    var is: Option[InputStream] = None
    try {
      is = Option {
        schedulerAllocFile.map { f =>
          new FileInputStream(f)
        }.getOrElse {
          Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
        }
      }
      // build the fair scheduler pools from the configuration file
      is.foreach { i => buildFairSchedulerPool(i) }
    } finally {
      is.foreach(_.close())
    }

    // finally create "default" pool
    buildDefaultPool()
  }

FairSchedulableBuilder's buildPools first reads the FAIR-mode configuration file, located at SPARK_HOME/conf/fairscheduler.xml by default; a user-defined file can be specified with spark.scheduler.allocation.file. A template looks like this:

<allocations>
  <pool name="production">
    <schedulingMode>FAIR</schedulingMode>
    <weight>1</weight>
    <minShare>2</minShare>
  </pool>
  <pool name="test">
    <schedulingMode>FIFO</schedulingMode>
    <weight>2</weight>
    <minShare>3</minShare>
  </pool>
</allocations>

Where:

  • name: the pool's name; a program selects a pool via spark.scheduler.pool (see the sketch after this list), and if none is specified the pool named default is used.
  • schedulingMode: the scheduling mode used inside the pool
  • weight: the pool's weight (a pool with weight 2 gets twice the resources of a pool with weight 1); setting it to something like 1000 makes the pool run as soon as it has any task. Default 1.
  • minShare: the minimum resources (cores) the pool should get, default 0
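As referenced in the name entry above, a job is routed to a pool by setting a thread-local property on the SparkContext before submitting it (sc is assumed to be an existing SparkContext; the pool name matches the sample file):

// Submit the next jobs from this thread into the "production" pool.
sc.setLocalProperty("spark.scheduler.pool", "production")
// ... run actions here ...
// Reset to the default pool afterwards.
sc.setLocalProperty("spark.scheduler.pool", null)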

FAIR mode can define multiple pools, i.e. rootPool itself holds a set of Pools and each Pool holds TaskSetManagers.
FairSchedulableBuilder builds these pools from the configuration file via buildFairSchedulerPool.

private def buildFairSchedulerPool(is: InputStream) {
    val xml = XML.load(is)
    for (poolNode <- (xml \\ POOLS_PROPERTY)) {

      val poolName = (poolNode \ POOL_NAME_PROPERTY).text
      var schedulingMode = DEFAULT_SCHEDULING_MODE
      var minShare = DEFAULT_MINIMUM_SHARE
      var weight = DEFAULT_WEIGHT

      val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text
      if (xmlSchedulingMode != "") {
        try {
          schedulingMode = SchedulingMode.withName(xmlSchedulingMode)
        } catch {
          case e: NoSuchElementException =>
            logWarning(s"Unsupported schedulingMode: $xmlSchedulingMode, " +
              s"using the default schedulingMode: $schedulingMode")
        }
      }

      val xmlMinShare = (poolNode \ MINIMUM_SHARES_PROPERTY).text
      if (xmlMinShare != "") {
        minShare = xmlMinShare.toInt
      }

      val xmlWeight = (poolNode \ WEIGHT_PROPERTY).text
      if (xmlWeight != "") {
        weight = xmlWeight.toInt
      }

      val pool = new Pool(poolName, schedulingMode, minShare, weight)
      rootPool.addSchedulable(pool)
      logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format(
        poolName, schedulingMode, minShare, weight))
    }
  }

A Pool object is instantiated from each field's value (falling back to the defaults when unset) and added to rootPool.

A Spark application has one TaskScheduler, and a TaskScheduler has a single rootPool. Under FIFO there is only one level of pool: the rootPool directly holds the TaskSetManagers. Under FAIR there are two levels: the rootPool holds child pools and the child pools hold the TaskSetManagers. In both cases the rootPool is created when the SchedulableBuilder is instantiated.

private def buildDefaultPool() {
    if (rootPool.getSchedulableByName(DEFAULT_POOL_NAME) == null) {
      val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE,
        DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)
      rootPool.addSchedulable(pool)
      logInfo("Created default pool %s, schedulingMode: %s, minShare: %d, weight: %d".format(
        DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT))
    }
  }

If none of the pools built from the configuration file is named default, a pool named default with all-default parameters is created.

Adding a TaskSetManager to the pool

Both scheduling modes end up in the same implementation, except that FAIR first looks up the pool to use (the pool named default unless specified otherwise) before adding:

override def addSchedulable(schedulable: Schedulable) {
    require(schedulable != null)
    schedulableQueue.add(schedulable)
    schedulableNameToSchedulable.put(schedulable.name, schedulable)
    schedulable.parent = this
  }

A TaskSetManager is appended to the tail of the queue and taken from the head. Under FIFO the parent pool is always the rootPool, whereas under FAIR a TaskSetManager's parent pool is one of the rootPool's child pools.

How the pool sorts TaskSetManagers

Once the TaskScheduler has obtained executor resources through the SchedulerBackend, it schedules all the TaskSetManagers, fetching them in sorted order via rootPool.getSortedTaskSetQueue.

override def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] = {
    var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
    val sortedSchedulableQueue =
      schedulableQueue.asScala.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator)
    for (schedulable <- sortedSchedulableQueue) {
      sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue
    }
    sortedTaskSetQueue
  }

The core of the ordering lies in taskSetSchedulingAlgorithm.comparator, and the two modes plug in different implementations of taskSetSchedulingAlgorithm:

var taskSetSchedulingAlgorithm: SchedulingAlgorithm = {
    schedulingMode match {
      case SchedulingMode.FAIR =>
        new FairSchedulingAlgorithm()
      case SchedulingMode.FIFO =>
        new FIFOSchedulingAlgorithm()
      case _ =>
        val msg = "Unsupported scheduling mode: $schedulingMode. Use FAIR or FIFO instead."
        throw new IllegalArgumentException(msg)
    }
  }

FIFO uses FIFOSchedulingAlgorithm and FAIR uses FairSchedulingAlgorithm. Let's look at the comparator of each, starting with FIFO:

override def comparator(s1: Schedulable, s2: Schedulable): Boolean = {
    val priority1 = s1.priority
    val priority2 = s2.priority
    var res = math.signum(priority1 - priority2)
    if (res == 0) {
      val stageId1 = s1.stageId
      val stageId2 = s2.stageId
      res = math.signum(stageId1 - stageId2)
    }
    res < 0
  }

  1. Compare priority first; under FIFO this is actually the job ID. Jobs submitted earlier have smaller job IDs, and the smaller the priority value, the higher the scheduling priority.
  2. If the priorities are equal, the two TaskSetManagers belong to the same job, so the stage IDs are compared and the smaller stage ID wins.

Now the FAIR ordering algorithm:

override def comparator(s1: Schedulable, s2: Schedulable): Boolean = {
    val minShare1 = s1.minShare
    val minShare2 = s2.minShare
    val runningTasks1 = s1.runningTasks
    val runningTasks2 = s2.runningTasks
    val s1Needy = runningTasks1 < minShare1
    val s2Needy = runningTasks2 < minShare2
    val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0)
    val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0)
    val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble
    val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble

    var compare = 0
    if (s1Needy && !s2Needy) {
      return true
    } else if (!s1Needy && s2Needy) {
      return false
    } else if (s1Needy && s2Needy) {
      compare = minShareRatio1.compareTo(minShareRatio2)
    } else {
      compare = taskToWeightRatio1.compareTo(taskToWeightRatio2)
    }
    if (compare < 0) {
      true
    } else if (compare > 0) {
      false
    } else {
      s1.name < s2.name
    }
  }

  1. A schedulable whose running task count is below its minShare ranks ahead of one that has already reached its minShare.
  2. If both are below their minShare, compare their minShare usage ratios (runningTasks / minShare); the lower the ratio, the higher the priority.
  3. If neither is below its minShare, compare their weight usage ratios (runningTasks / weight); again the lower ratio wins.
  4. If those ratios are also equal, compare by name.

In FAIR mode the child pools are sorted first, and then the TaskSetManagers inside each child pool are sorted. Because both Pool and TaskSetManager extend the Schedulable trait, the same FairSchedulingAlgorithm comparator is used at both levels.
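A small standalone sketch that replays the comparison for two hypothetical pools (it copies the rules above rather than calling Spark classes):

// (minShare, weight, runningTasks) for two hypothetical pools.
case class PoolStats(name: String, minShare: Int, weight: Int, runningTasks: Int)

def fairLessThan(s1: PoolStats, s2: PoolStats): Boolean = {
  val s1Needy = s1.runningTasks < s1.minShare
  val s2Needy = s2.runningTasks < s2.minShare
  val ratio1 = s1.runningTasks.toDouble / math.max(s1.minShare, 1.0)
  val ratio2 = s2.runningTasks.toDouble / math.max(s2.minShare, 1.0)
  val tw1 = s1.runningTasks.toDouble / s1.weight
  val tw2 = s2.runningTasks.toDouble / s2.weight
  if (s1Needy && !s2Needy) true
  else if (!s1Needy && s2Needy) false
  else if (s1Needy && s2Needy) ratio1 < ratio2 || (ratio1 == ratio2 && s1.name < s2.name)
  else tw1 < tw2 || (tw1 == tw2 && s1.name < s2.name)
}

// "test" has not reached its minShare yet, so it is scheduled before "production".
println(fairLessThan(PoolStats("test", 3, 2, 2), PoolStats("production", 2, 1, 5)))  // true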

Memory Management: MemoryManager Explained

Overview

Spark has two memory-management schemes; the new and old schemes are implemented by UnifiedMemoryManager and StaticMemoryManager respectively.

The old scheme is static: storageMemory (storage memory) and executionMemory (execution memory) each own a fixed share and cannot borrow from one another, so when one side has plenty of memory and the other runs short, resources are wasted. The new scheme manages the two regions jointly: each starts with half, but either side can borrow from the other when it runs short, which uses memory more effectively and raises overall utilization.

Overall, memory splits into three regions: storageMemory, executionMemory, and a system reserve. storageMemory caches RDDs, unrolls partitions, and holds direct task results, broadcast variables, and (in Spark Streaming receiver mode) each batch's blocks. executionMemory backs the buffers used in shuffle, join, sort and aggregation. Everything outside these two regions is reserved for the system.

The old scheme: StaticMemoryManager

SparkEnv creates the memoryManager:

val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false)
    val memoryManager: MemoryManager =
      if (useLegacyMemoryManager) {
        new StaticMemoryManager(conf, numUsableCores)
      } else {
        UnifiedMemoryManager(conf, numUsableCores)
      }

The unified scheme, UnifiedMemoryManager, is the default; let's first take a brief look at the old StaticMemoryManager.

The memory available to storageMemory is:

systemMaxMemory * memoryFraction * safetyFraction

Where:

  • systemMaxMemory: Runtime.getRuntime.maxMemory, i.e. the maximum memory the JVM can obtain.
  • memoryFraction: controlled by spark.storage.memoryFraction, default 0.6.
  • safetyFraction: controlled by spark.storage.safetyFraction, default 0.9; cached block sizes are only estimates, so a safety factor is applied.

The memory available to executionMemory is:

systemMaxMemory * memoryFraction * safetyFraction

Where:

  • systemMaxMemory: Runtime.getRuntime.maxMemory, i.e. the maximum memory the JVM can obtain.
  • memoryFraction: controlled by spark.shuffle.memoryFraction, default 0.2.
  • safetyFraction: controlled by spark.shuffle.safetyFraction, default 0.8.

Whatever lies outside the memoryFraction and safety factors is left reserved for the system.

The amount given to executionMemory directly affects how often shuffles spill: increasing it reduces the number of spills, but the capacity storageMemory can cache shrinks accordingly.

Once execution and storage have been assigned their shares, the sizes never change; each side can only allocate from its own region and cannot borrow from the other, which wastes resources. In addition, under this scheme only execution memory supports off-heap; storage memory does not.
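A back-of-the-envelope sketch with the default fractions, assuming a hypothetical 8 GB executor heap:

// Hypothetical 8 GB heap, default fractions of the legacy StaticMemoryManager.
val systemMaxMemory = 8L * 1024 * 1024 * 1024

val storageMemory     = (systemMaxMemory * 0.6 * 0.9).toLong  // ≈ 4.32 GB for cached blocks
val executionMemory   = (systemMaxMemory * 0.2 * 0.8).toLong  // ≈ 1.28 GB for shuffle/sort/agg
val reservedForSystem = systemMaxMemory - storageMemory - executionMemory  // ≈ 2.4 GB left over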

The new scheme: UnifiedMemoryManager

Since the new scheme manages storageMemory and executionMemory together, let's see how much memory the two can get in total.

private def getMaxMemory(conf: SparkConf): Long = {
    val systemMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory)
    val reservedMemory = conf.getLong("spark.testing.reservedMemory",
      if (conf.contains("spark.testing")) 0 else RESERVED_SYSTEM_MEMORY_BYTES)
    val minSystemMemory = (reservedMemory * 1.5).ceil.toLong
    if (systemMemory < minSystemMemory) {
      throw new IllegalArgumentException(s"System memory $systemMemory must " +
        s"be at least $minSystemMemory. Please increase heap size using the --driver-memory " +
        s"option or spark.driver.memory in Spark configuration.")
    }
    // SPARK-12759 Check executor memory to fail fast if memory is insufficient
    if (conf.contains("spark.executor.memory")) {
      val executorMemory = conf.getSizeAsBytes("spark.executor.memory")
      if (executorMemory < minSystemMemory) {
        throw new IllegalArgumentException(s"Executor memory $executorMemory must be at least " +
          s"$minSystemMemory. Please increase executor memory using the " +
          s"--executor-memory option or spark.executor.memory in Spark configuration.")
      }
    }
    val usableMemory = systemMemory - reservedMemory
    val memoryFraction = conf.getDouble("spark.memory.fraction", 0.6)
    (usableMemory * memoryFraction).toLong
  }

First, 300 MB is set aside as reservedMemory for the system; if either the maximum memory the JVM can obtain or the configured executor memory is less than 1.5 times reservedMemory (i.e. 450 MB), an exception is thrown. The memory that storage and execution end up sharing is:

 (heap space - 300MB) * spark.memory.fraction (0.6 by default)

storage and execution each start with 50% of this amount.
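The same back-of-the-envelope calculation for the unified scheme, again assuming a hypothetical 8 GB heap, the default spark.memory.fraction of 0.6 and the even initial split:

// Hypothetical 8 GB heap under UnifiedMemoryManager defaults.
val heap     = 8L * 1024 * 1024 * 1024
val reserved = 300L * 1024 * 1024          // RESERVED_SYSTEM_MEMORY_BYTES
val usable   = heap - reserved
val unified  = (usable * 0.6).toLong       // ≈ 4.62 GB shared by storage + execution

val initialStorage   = (unified * 0.5).toLong   // each side starts with ≈ 2.31 GB
val initialExecution = unified - initialStorage // and may borrow from the other when it is free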

Acquiring storage memory

Requesting numBytes of memory for a given blockId:

override def acquireStorageMemory(
      blockId: BlockId,
      numBytes: Long,
      memoryMode: MemoryMode): Boolean = synchronized {
    assertInvariants()
    assert(numBytes >= 0)
    val (executionPool, storagePool, maxMemory) = memoryMode match {
      case MemoryMode.ON_HEAP => (
        onHeapExecutionMemoryPool,
        onHeapStorageMemoryPool,
        maxOnHeapStorageMemory)
      case MemoryMode.OFF_HEAP => (
        offHeapExecutionMemoryPool,
        offHeapStorageMemoryPool,
        maxOffHeapMemory)
    }
    // the request exceeds the combined storage + execution memory
    if (numBytes > maxMemory) {
      // Fail fast if the block simply won't fit
      logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " +
        s"memory limit ($maxMemory bytes)")
      return false
    }
    // the request exceeds the free storage memory
    if (numBytes > storagePool.memoryFree) {
      // There is not enough free memory in the storage pool, so try to borrow free memory from
      // the execution pool.
      val memoryBorrowedFromExecution = Math.min(executionPool.memoryFree, numBytes)
      executionPool.decrementPoolSize(memoryBorrowedFromExecution)
      storagePool.incrementPoolSize(memoryBorrowedFromExecution)
    }
    storagePool.acquireMemory(blockId, numBytes)
  }

  • If the requested numBytes exceeds the combined memory of both pools, false is returned immediately: the request fails.
  • If numBytes exceeds the free storage memory, memory must be borrowed from the executionPool:
    • the amount borrowed is the smaller of execution's free memory and numBytes (in my view it should really be the smaller of execution's free memory and numBytes minus storage's free memory)
    • execution's poolSize is decreased
    • storage's poolSize is increased

Even after borrowing from the executionPool there may still not be enough for numBytes, because memory that execution is actively using cannot be taken over. storagePool.acquireMemory is then called, and when the free memory still falls short of numBytes it evicts RDD blocks cached in storage to increase storagePool.memoryFree:

def acquireMemory(blockId: BlockId, numBytes: Long): Boolean = lock.synchronized {
    val numBytesToFree = math.max(0, numBytes - memoryFree)
    acquireMemory(blockId, numBytes, numBytesToFree)
  }

It computes how much is still missing after the borrow to satisfy numBytes, i.e. the amount that must be freed, numBytesToFree, and then calls acquireMemory:

def acquireMemory(
      blockId: BlockId,
      numBytesToAcquire: Long,
      numBytesToFree: Long): Boolean = lock.synchronized {
    assert(numBytesToAcquire >= 0)
    assert(numBytesToFree >= 0)
    assert(memoryUsed <= poolSize)
    if (numBytesToFree > 0) {
      memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree, memoryMode)
    }
    // NOTE: If the memory store evicts blocks, then those evictions will synchronously call
    // back into this StorageMemoryPool in order to free memory. Therefore, these variables
    // should have been updated.
    val enoughMemory = numBytesToAcquire <= memoryFree
    if (enoughMemory) {
      _memoryUsed += numBytesToAcquire
    }
    enoughMemory
  }

When numBytesToFree is greater than 0, blocks cached in memory really are evicted; after eviction the free memory is checked again, and if it now satisfies numBytes, numBytes is added to the used-memory counter.

Let's see how blocks are actually evicted from storage when that becomes necessary:

private[spark] def evictBlocksToFreeSpace(
      blockId: Option[BlockId],
      space: Long,
      memoryMode: MemoryMode): Long = {
    assert(space > 0)
    memoryManager.synchronized {
      var freedMemory = 0L
      val rddToAdd = blockId.flatMap(getRddId)
      val selectedBlocks = new ArrayBuffer[BlockId]
      def blockIsEvictable(blockId: BlockId, entry: MemoryEntry[_]): Boolean = {
        entry.memoryMode == memoryMode && (rddToAdd.isEmpty || rddToAdd != getRddId(blockId))
      }
      // This is synchronized to ensure that the set of entries is not changed
      // (because of getValue or getBytes) while traversing the iterator, as that
      // can lead to exceptions.
      entries.synchronized {
        val iterator = entries.entrySet().iterator()
        while (freedMemory < space && iterator.hasNext) {
          val pair = iterator.next()
          val blockId = pair.getKey
          val entry = pair.getValue
          if (blockIsEvictable(blockId, entry)) {
            // We don't want to evict blocks which are currently being read, so we need to obtain
            // an exclusive write lock on blocks which are candidates for eviction. We perform a
            // non-blocking "tryLock" here in order to ignore blocks which are locked for reading:
            if (blockInfoManager.lockForWriting(blockId, blocking = false).isDefined) {
              selectedBlocks += blockId
              freedMemory += pair.getValue.size
            }
          }
        }
      }

      def dropBlock[T](blockId: BlockId, entry: MemoryEntry[T]): Unit = {
        val data = entry match {
          case DeserializedMemoryEntry(values, _, _) => Left(values)
          case SerializedMemoryEntry(buffer, _, _) => Right(buffer)
        }
        val newEffectiveStorageLevel =
          blockEvictionHandler.dropFromMemory(blockId, () => data)(entry.classTag)
        if (newEffectiveStorageLevel.isValid) {
          // The block is still present in at least one store, so release the lock
          // but don't delete the block info
          blockInfoManager.unlock(blockId)
        } else {
          // The block isn't present in any store, so delete the block info so that the
          // block can be stored again
          blockInfoManager.removeBlock(blockId)
        }
      }

      if (freedMemory >= space) {
        logInfo(s"${selectedBlocks.size} blocks selected for dropping " +
          s"(${Utils.bytesToString(freedMemory)} bytes)")
        for (blockId <- selectedBlocks) {
          val entry = entries.synchronized { entries.get(blockId) }
          // This should never be null as only one task should be dropping
          // blocks and removing entries. However the check is still here for
          // future safety.
          if (entry != null) {
            dropBlock(blockId, entry)
          }
        }
        logInfo(s"After dropping ${selectedBlocks.size} blocks, " +
          s"free memory is ${Utils.bytesToString(maxMemory - blocksMemoryUsed)}")
        freedMemory
      } else {
        blockId.foreach { id =>
          logInfo(s"Will not store $id")
        }
        selectedBlocks.foreach { id =>
          blockInfoManager.unlock(id)
        }
        0L
      }
    }
  }

In Spark, in-memory blocks are stored through the memoryStore, which uses

private val entries = new LinkedHashMap[BlockId, MemoryEntry[_]](32, 0.75f, true)

to map each blockId to its MemoryEntry (a wrapper around the value). The method also defines two helpers: blockIsEvictable checks whether the block being visited belongs to the same RDD as the block being stored, because a block must not be evicted to make room for another block of the same RDD; dropBlock actually removes a block from memory, writing it to a disk file first if its StorageLevel includes disk.

In short: the blocks currently held in the memoryStore (excluding those that belong to the same RDD as the block being stored) are traversed until the accumulated size of the selected blocks exceeds the amount that needs to be freed; it may also happen that the whole store is traversed without reaching that amount. Only if enough memory can be freed are the evictions actually carried out; otherwise nothing is evicted, since a block cannot live partly in memory and partly on disk. The method returns the amount of memory actually freed by evicting blocks.

To summarize the process of acquiring storage memory (in MemoryMode.ON_HEAP):

  • If numBytes exceeds the combined storage and execution memory, the request fails right away (false is returned).
  • If numBytes exceeds the free storage memory, min(executionFree, numBytes) is borrowed from execution and both poolSizes are updated.
  • If that is still not enough, blocks cached in storage are evicted to make up the difference.
    • The eviction is only carried out if the blocks cached in the memoryStore can cover the shortfall (blocks are traversed until enough memory is found); otherwise nothing is evicted.
  • Finally, true is returned if the free memory now satisfies numBytes, false otherwise.

Acquiring execution memory

When execution memory runs short it borrows from storage, and if even that cannot fully cover the request it simply borrows as much as it can. Let's see how a request for execution memory is handled (in MemoryMode.ON_HEAP):

override private[memory] def acquireExecutionMemory(
      numBytes: Long,
      taskAttemptId: Long,
      memoryMode: MemoryMode): Long = synchronized {
    assertInvariants()
    assert(numBytes >= 0)
    val (executionPool, storagePool, storageRegionSize, maxMemory) = memoryMode match {
      case MemoryMode.ON_HEAP => (
        onHeapExecutionMemoryPool,
        onHeapStorageMemoryPool,
        onHeapStorageRegionSize,
        maxHeapMemory)
      case MemoryMode.OFF_HEAP => (
        offHeapExecutionMemoryPool,
        offHeapStorageMemoryPool,
        offHeapStorageMemory,
        maxOffHeapMemory)
    }

    /**
     * Grow the execution pool by evicting cached blocks, thereby shrinking the storage pool.
     *
     * When acquiring memory for a task, the execution pool may need to make multiple
     * attempts. Each attempt must be able to evict storage in case another task jumps in
     * and caches a large block between the attempts. This is called once per attempt.
     */
    def maybeGrowExecutionPool(extraMemoryNeeded: Long): Unit = {
      if (extraMemoryNeeded > 0) {
        // There is not enough free memory in the execution pool, so try to reclaim memory from
        // storage. We can reclaim any free memory from the storage pool. If the storage pool
        // has grown to become larger than `storageRegionSize`, we can evict blocks and reclaim
        // the memory that storage has borrowed from execution.
        val memoryReclaimableFromStorage = math.max(
          storagePool.memoryFree,
          storagePool.poolSize - storageRegionSize)
        if (memoryReclaimableFromStorage > 0) {
          // Only reclaim as much space as is necessary and available:
          val spaceToReclaim = storagePool.freeSpaceToShrinkPool(
            math.min(extraMemoryNeeded, memoryReclaimableFromStorage))
          storagePool.decrementPoolSize(spaceToReclaim)
          executionPool.incrementPoolSize(spaceToReclaim)
        }
      }
    }

    /**
     * The size the execution pool would have after evicting storage memory.
     *
     * The execution memory pool divides this quantity among the active tasks evenly to cap
     * the execution memory allocation for each task. It is important to keep this greater
     * than the execution pool size, which doesn't take into account potential memory that
     * could be freed by evicting storage. Otherwise we may hit SPARK-12155.
     *
     * Additionally, this quantity should be kept below `maxMemory` to arbitrate fairness
     * in execution memory allocation across tasks, Otherwise, a task may occupy more than
     * its fair share of execution memory, mistakenly thinking that other tasks can acquire
     * the portion of storage memory that cannot be evicted.
     */
    def computeMaxExecutionPoolSize(): Long = {
      maxMemory - math.min(storagePool.memoryUsed, storageRegionSize)
    }

    executionPool.acquireMemory(
      numBytes, taskAttemptId, maybeGrowExecutionPool, computeMaxExecutionPoolSize)
  }

Let's first explain the two inner functions:

maybeGrowExecutionPool is the function that borrows memory from storage. The maximum that can be reclaimed, memoryReclaimableFromStorage, is the larger of storage's free memory and the memory storage previously borrowed from execution (which must be given back even if it is in use and has to be evicted). If memoryReclaimableFromStorage is 0, storage never borrowed from execution and currently has no free memory to lend.
The amount actually borrowed is the smaller of the memory still needed and memoryReclaimableFromStorage (borrow exactly as much as is missing). Following storagePool.freeSpaceToShrinkPool:

def freeSpaceToShrinkPool(spaceToFree: Long): Long = lock.synchronized {
    val spaceFreedByReleasingUnusedMemory = math.min(spaceToFree, memoryFree)
    val remainingSpaceToFree = spaceToFree - spaceFreedByReleasingUnusedMemory
    if (remainingSpaceToFree > 0) {
      // If reclaiming free memory did not adequately shrink the pool, begin evicting blocks:
      val spaceFreedByEviction =
        memoryStore.evictBlocksToFreeSpace(None, remainingSpaceToFree, memoryMode)
      // When a block is released, BlockManager.dropFromMemory() calls releaseMemory(), so we do
      // not need to decrement _memoryUsed here. However, we do need to decrement the pool size.
      spaceFreedByReleasingUnusedMemory + spaceFreedByEviction
    } else {
      spaceFreedByReleasingUnusedMemory
    }
  }

If storage's free memory cannot cover the requested amount, blocks cached in storage are evicted to make up the rest.

computeMaxExecutionPoolSize computes the maximum memory execution could own after potentially evicting storage.

These two functions are then passed as arguments to executionPool.acquireMemory:

private[memory] def acquireMemory(
      numBytes: Long,
      taskAttemptId: Long,
      maybeGrowPool: Long => Unit = (additionalSpaceNeeded: Long) => Unit,
      computeMaxPoolSize: () => Long = () => poolSize): Long = lock.synchronized {
    assert(numBytes > 0, s"invalid number of bytes requested: $numBytes")

    // TODO: clean up this clunky method signature

    // Add this task to the taskMemory map just so we can keep an accurate count of the number
    // of active tasks, to let other tasks ramp down their memory in calls to `acquireMemory`
    if (!memoryForTask.contains(taskAttemptId)) {
      memoryForTask(taskAttemptId) = 0L
      // This will later cause waiting tasks to wake up and check numTasks again
      lock.notifyAll()
    }

    // Keep looping until we're either sure that we don't want to grant this request (because this
    // task would have more than 1 / numActiveTasks of the memory) or we have enough free
    // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
    // TODO: simplify this to limit each task to its own slot
    while (true) {
      val numActiveTasks = memoryForTask.keys.size
      val curMem = memoryForTask(taskAttemptId)

      // In every iteration of this loop, we should first try to reclaim any borrowed execution
      // space from storage. This is necessary because of the potential race condition where new
      // storage blocks may steal the free execution memory that this task was waiting for.
      maybeGrowPool(numBytes - memoryFree)

      // Maximum size the pool would have after potentially growing the pool.
      // This is used to compute the upper bound of how much memory each task can occupy. This
      // must take into account potential free memory as well as the amount this pool currently
      // occupies. Otherwise, we may run into SPARK-12155 where, in unified memory management,
      // we did not take into account space that could have been freed by evicting cached blocks.
      val maxPoolSize = computeMaxPoolSize()
      val maxMemoryPerTask = maxPoolSize / numActiveTasks
      val minMemoryPerTask = poolSize / (2 * numActiveTasks)

      // How much we can grant this task; keep its share within 0 <= X <= 1 / numActiveTasks
      val maxToGrant = math.min(numBytes, math.max(0, maxMemoryPerTask - curMem))
      // Only give it as much memory as is free, which might be none if it reached 1 / numTasks
      val toGrant = math.min(maxToGrant, memoryFree)

      // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
      // if we can't give it this much now, wait for other tasks to free up memory
      // (this happens if older tasks allocated lots of memory before N grew)
      if (toGrant < numBytes && curMem + toGrant < minMemoryPerTask) {
        logInfo(s"TID $taskAttemptId waiting for at least 1/2N of $poolName pool to be free")
        lock.wait()
      } else {
        memoryForTask(taskAttemptId) += toGrant
        return toGrant
      }
    }
    0L  // Never reached
  }

Inside, the bounds on how much execution memory a single task may use are defined:

val maxPoolSize = computeMaxPoolSize()
val maxMemoryPerTask = maxPoolSize / numActiveTasks
val minMemoryPerTask = poolSize / (2 * numActiveTasks)

Here maxPoolSize is the maximum memory the executionMemoryPool could have after borrowing from storage. A task's usable memory is kept within the range 1/(2*numActiveTasks) to 1/numActiveTasks of the pool, which keeps resource usage balanced across tasks.
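For instance (hypothetical numbers, assuming the pool could grow to 2 GB, currently sits at 1.5 GB, and 4 tasks are active):

// Hypothetical sizes for the per-task execution-memory bounds.
val maxPoolSize    = 2048L * 1024 * 1024   // pool size after potentially evicting storage
val poolSize       = 1536L * 1024 * 1024   // current pool size before growing
val numActiveTasks = 4

val maxMemoryPerTask = maxPoolSize / numActiveTasks     // 512 MB upper bound per task
val minMemoryPerTask = poolSize / (2 * numActiveTasks)  // 192 MB: below this the task blocks and waits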

The code flow for acquiring execution memory:

  1. Get the memory the task has already been granted.
  2. If numBytes exceeds execution's free memory, borrow from storage via maybeGrowPool.
  3. The maximum that can be granted, maxToGrant, is the smaller of numBytes and (maxMemoryPerTask - curMem).
  4. The memory actually granted in this iteration, toGrant, is the smaller of maxToGrant and the memory free after the borrow.
  5. If toGrant is less than numBytes and the task's total (existing plus granted) is still below the per-task minimum minMemoryPerTask, the call blocks until enough memory becomes available or new tasks arrive and lower minMemoryPerTask.
    Otherwise the granted amount is returned directly.

That covers how storage and execution memory are requested and borrowed from each other. Many things consume these two regions (see the overview); caching RDDs requests storage memory and running tasks request execution memory, so let's look at when each request happens.

Caching RDDs

final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      getOrCompute(split, context)
    } else {
      computeOrReadCheckpoint(split, context)
    }
  }

Each RDD partition's data is obtained through its iterator. If the storage level is not NONE, Spark first tries to fetch it from the storage media (memory, disk files, and so on); the first access of course finds nothing, so the partition is computed and then cached so that later computations can fetch it directly. Serialized and deserialized data are cached differently; the deserialized path is:

memoryStore.putIteratorAsValues(blockId, iterator(), classTag)

 private[storage] def putIteratorAsValues[T](
      blockId: BlockId,
      values: Iterator[T],
      classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = {

    require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")

    // Number of elements unrolled so far
    var elementsUnrolled = 0
    // Whether there is still enough memory for us to continue unrolling this block
    var keepUnrolling = true
    // Initial per-task memory to request for unrolling blocks (bytes).
    val initialMemoryThreshold = unrollMemoryThreshold
    // How often to check whether we need to request more memory
    val memoryCheckPeriod = 16
    // Memory currently reserved by this task for this particular unrolling operation
    var memoryThreshold = initialMemoryThreshold
    // Memory to request as a multiple of current vector size
    val memoryGrowthFactor = 1.5
    // Keep track of unroll memory used by this particular block / putIterator() operation
    var unrollMemoryUsedByThisBlock = 0L
    // Underlying vector for unrolling the block
    var vector = new SizeTrackingVector[T]()(classTag)

    // Request enough memory to begin unrolling
    keepUnrolling =
      reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, MemoryMode.ON_HEAP)

    if (!keepUnrolling) {
      logWarning(s"Failed to reserve initial memory threshold of " +
        s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
    } else {
      unrollMemoryUsedByThisBlock += initialMemoryThreshold
    }

    // Unroll this block safely, checking whether we have exceeded our threshold periodically
    while (values.hasNext && keepUnrolling) {
      vector += values.next()
      if (elementsUnrolled % memoryCheckPeriod == 0) {
        // If our vector's size has exceeded the threshold, request more memory
        val currentSize = vector.estimateSize()
        if (currentSize >= memoryThreshold) {
          val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong
          keepUnrolling =
            reserveUnrollMemoryForThisTask(blockId, amountToRequest, MemoryMode.ON_HEAP)
          if (keepUnrolling) {
            unrollMemoryUsedByThisBlock += amountToRequest
          }
          // New threshold is currentSize * memoryGrowthFactor
          memoryThreshold += amountToRequest
        }
      }
      elementsUnrolled += 1
    }

    if (keepUnrolling) {
      // We successfully unrolled the entirety of this block
      val arrayValues = vector.toArray
      vector = null
      val entry =
        new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag)
      val size = entry.size
      def transferUnrollToStorage(amount: Long): Unit = {
        // Synchronize so that transfer is atomic
        memoryManager.synchronized {
          releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, amount)
          val success = memoryManager.acquireStorageMemory(blockId, amount, MemoryMode.ON_HEAP)
          assert(success, "transferring unroll memory to storage memory failed")
        }
      }
      // Acquire storage memory if necessary to store this block in memory.
      val enoughStorageMemory = {
        if (unrollMemoryUsedByThisBlock <= size) {
          val acquiredExtra =
            memoryManager.acquireStorageMemory(
              blockId, size - unrollMemoryUsedByThisBlock, MemoryMode.ON_HEAP)
          if (acquiredExtra) {
            transferUnrollToStorage(unrollMemoryUsedByThisBlock)
          }
          acquiredExtra
        } else { // unrollMemoryUsedByThisBlock > size
          // If this task attempt already owns more unroll memory than is necessary to store the
          // block, then release the extra memory that will not be used.
          val excessUnrollMemory = unrollMemoryUsedByThisBlock - size
          releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, excessUnrollMemory)
          transferUnrollToStorage(size)
          true
        }
      }
      if (enoughStorageMemory) {
        entries.synchronized {
          entries.put(blockId, entry)
        }
        logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(
          blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed)))
        Right(size)
      } else {
        assert(currentUnrollMemoryForThisTask >= unrollMemoryUsedByThisBlock,
          "released too much unroll memory")
        Left(new PartiallyUnrolledIterator(
          this,
          MemoryMode.ON_HEAP,
          unrollMemoryUsedByThisBlock,
          unrolled = arrayValues.toIterator,
          rest = Iterator.empty))
      }
    } else {
      // We ran out of space while unrolling the values for this block
      logUnrollFailureMessage(blockId, vector.estimateSize())
      Left(new PartiallyUnrolledIterator(
        this,
        MemoryMode.ON_HEAP,
        unrollMemoryUsedByThisBlock,
        unrolled = vector.iterator,
        rest = values))
    }
  }

The code is long enough to make your head spin, but that's fine — let's walk through it bit by bit.

The blockId parameter uniquely identifies a block, in the format "rdd_" + rddId + "_" + splitIndex, and values is the iterator over that partition's data:

  1. reserveUnrollMemoryForThisTask requests initialMemoryThreshold bytes from storage (configurable via spark.storage.unrollMemoryThreshold, default 1 MB) to start unrolling the iterator:
    def reserveUnrollMemoryForThisTask(
      blockId: BlockId,
      memory: Long,
      memoryMode: MemoryMode): Boolean = {
    memoryManager.synchronized {
      val success = memoryManager.acquireUnrollMemory(blockId, memory, memoryMode)
      if (success) {
        val taskAttemptId = currentTaskAttemptId()
        val unrollMemoryMap = memoryMode match {
          case MemoryMode.ON_HEAP => onHeapUnrollMemoryMap
          case MemoryMode.OFF_HEAP => offHeapUnrollMemoryMap
        }
        unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory
      }
      success
    }
    }
    
    Following acquireUnrollMemory, you can see that underneath it calls the acquireStorageMemory method discussed earlier; on success, the granted amount is added to onHeapUnrollMemoryMap, i.e. the memory used for unrolling.
  2. If the request succeeds, unrollMemoryUsedByThisBlock — the unroll memory used for this block — is updated.
  3. The iterator is then traversed; traversal stops either when the iterator is exhausted or when a memory request fails (see the sketch after this list).
    • Each element is appended to a SizeTrackingVector (backed by an array); every 16 elements the vector's estimated size is checked against memoryThreshold (the memory reserved so far).
    • If it exceeds memoryThreshold, the next request is computed as 1.5 times the current vector size minus the memory already reserved.
    • Storage is asked for that amount; on success unrollMemoryUsedByThisBlock is updated and the traversal continues, otherwise the traversal stops.
  4. After the loop, if keepUnrolling is true the values must have been fully unrolled; if it is false they were not, meaning not enough memory could be reserved to unroll them and caching this partition in memory has failed.
  5. Provided the values were fully unrolled, the vector is turned into a DeserializedMemoryEntry, which records the data's size; that size is then compared with the amount of memory reserved:
    • If less memory was reserved than the data's size, the difference is requested from storage; on success the unroll memory is transferred to storage memory. The transfer releases all the unroll memory held by this task and then requests the same amount of storage memory. Unroll memory is in fact storage memory, so the same pool is debited and credited and the net result does not change, but the procedure is still followed in order to keep MemoryStore decoupled from MemoryManager.
    • If more memory was reserved than the data's size, the excess unroll memory is released and the remainder is transferred to storage.
    • Finally, the blockId and its entry are added to the entries map managed by the memoryStore.
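As referenced in step 3, a small sketch of the unroll-memory request arithmetic with illustrative sizes (the real bookkeeping lives in putIteratorAsValues and reserveUnrollMemoryForThisTask):

// Illustrative unroll bookkeeping (hypothetical sizes, defaults assumed).
var memoryThreshold = 1L * 1024 * 1024                // spark.storage.unrollMemoryThreshold: 1 MB
var unrollMemoryUsedByThisBlock = memoryThreshold     // assume the initial reservation succeeded
val memoryGrowthFactor = 1.5

// Checked every 16 elements: suppose the vector is now estimated at 3 MB.
val currentSize = 3L * 1024 * 1024
if (currentSize >= memoryThreshold) {
  // request 1.5 * currentSize - alreadyReserved = 3.5 MB more
  val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong
  unrollMemoryUsedByThisBlock += amountToRequest      // assume the request succeeded
  memoryThreshold += amountToRequest                  // new threshold = 1.5 * currentSize = 4.5 MB
}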

Caching serialized RDDs supports both ON_HEAP and OFF_HEAP and works much like the deserialized path, except the data is written as a stream into a ByteBuffer: with MemoryMode ON_HEAP the ByteBuffer is a HeapByteBuffer (on-heap memory), while with OFF_HEAP it is a DirectByteBuffer (pointing to off-heap memory). Finally a SerializedMemoryEntry is built from the data and stored in the memoryStore's entries.

How shuffle uses execution memory

During shuffle write, data is not written straight to disk (see the Shuffle Write analysis for details) but first into an in-memory collection; the memory this collection occupies is execution memory. The initial grant is 5 MB, configurable via spark.shuffle.spill.initialMemoryThreshold. After every insert the collection checks whether it must spill to disk, and before spilling it tries to request more execution memory to avoid the spill, as the code below shows:

protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
    var shouldSpill = false
    if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
      // Claim up to double our current memory from the shuffle memory pool
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      val granted = acquireMemory(amountToRequest)
      myMemoryThreshold += granted
      // If we were granted too little memory to grow further (either tryToAcquire returned 0,
      // or we already had more memory than myMemoryThreshold), spill the current collection
      shouldSpill = currentMemory >= myMemoryThreshold
    }
    shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
    // Actually spill
    if (shouldSpill) {
      _spillCount += 1
      logSpillage(currentMemory)
      spill(collection)
      _elementsRead = 0
      _memoryBytesSpilled += currentMemory
      releaseMemory()
    }
    shouldSpill
  }

When the number of inserts/updates is a multiple of 32 and the collection's current size is at least the memory already granted, more execution memory is requested to try to avoid spilling; the request is twice the current collection size minus the memory already granted. Following acquireMemory:
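A tiny worked example with made-up numbers, mirroring the variables in the excerpt above:

// Hypothetical state at one maybeSpill check.
val currentMemory     = 7L * 1024 * 1024   // estimated size of the in-memory collection
var myMemoryThreshold = 5L * 1024 * 1024   // execution memory granted so far

val amountToRequest = 2 * currentMemory - myMemoryThreshold   // ask for 9 MB more
val granted         = 2L * 1024 * 1024                        // suppose the pool can only spare 2 MB
myMemoryThreshold  += granted                                 // threshold is now 7 MB
val shouldSpill = currentMemory >= myMemoryThreshold          // true -> spill the collection to disk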

 public long acquireMemory(long size) {
    long granted = taskMemoryManager.acquireExecutionMemory(size, this);
    used += granted;
    return granted;
  }

This is exactly the execution-memory acquisition path we discussed above, so it is not repeated here.

References

http://www.jianshu.com/p/999ef21dffe8

Integrating Spark with HBase (a custom HBase DataSource)

Background

Spark supports many data sources, but it offers no particularly elegant API for reading from or writing to HBase, even though integrating Spark with HBase is a very common scenario. So, using Spark's DataSource API, I implemented a set of APIs that make working with HBase more convenient.

Writing to HBase

Writing stores the data in HBase with types derived from the DataFrame's schema. A usage example first:

import spark.implicits._
import org.apache.hack.spark._
val df = spark.createDataset(Seq(("ufo",  "play"), ("yy",  ""))).toDF("name", "like")
// option 1
val options = Map(
            "hbase.table.rowkey.field" -> "name",
            "hbase.table.numReg" -> "12",
            "hbase.table.rowkey.prefix" -> "00",
            "bulkload.enable" -> "false"
        )
df.saveToHbase("hbase_table", Some("XXX:2181"), options)
// option 2
df1.write.format("org.apache.spark.sql.execution.datasources.hbase")
            .options(Map(
                "hbase.table.rowkey.field" -> "name",
                "hbase.table.name" -> "hbase_table",
                "hbase.zookeeper.quorum" -> "XXX:2181",
                "hbase.table.rowkey.prefix" -> "00",
                "hbase.table.numReg" -> "12",
                "bulkload.enable" -> "false"
            )).save()

Both ways achieve the same result. The parameters mean the following:

  • hbase.zookeeper.quorum: the ZooKeeper address
  • hbase.table.rowkey.field: which field of the Spark table to use as the HBase rowkey; defaults to the first field
  • bulkload.enable: whether to use bulkload; off by default, but it must be enabled when the HBase table being written has only the rowkey column
  • hbase.table.name: the HBase table name
  • hbase.table.family: the column family, default info
  • hbase.table.startKey: the start key for pre-splitting; if the HBase table does not exist it is created automatically, and without the following three parameters it gets a single region
  • hbase.table.endKey: the end key for pre-splitting
  • hbase.table.numReg: the number of regions
  • hbase.table.rowkey.prefix: when the rowkey starts with digits, the prefix format to use for pre-splitting, e.g. 00
  • hbase.check_table: whether to check that the table exists before writing, default false

Reading from HBase

Example code:

// option 1
import org.apache.hack.spark._
 val options = Map(
    "spark.table.schema" -> "appid:String,appstoreid:int,firm:String",
    "hbase.table.schema" -> ":rowkey,info:appStoreId,info:firm"
)
spark.hbaseTableAsDataFrame("hbase_table", Some("XXX:2181")).show(false)
// option 2
spark.read.format("org.apache.spark.sql.execution.datasources.hbase").
            options(Map(
            "spark.table.schema" -> "appid:String,appstoreid:int,firm:String",
            "hbase.table.schema" -> ":rowkey,info:appStoreId,info:firm",
            "hbase.zookeeper.quorum" -> "XXX:2181",
            "hbase.table.name" -> "hbase_table"
        )).load.show(false)  

Specifying the schema mapping between the Spark and HBase tables is optional. By default two fields are produced, rowkey and content, where content is a JSON string built from all columns; individual field types can be set via field.type.fieldname and default to StringType. Data mapped this way still has to be reshaped by your Spark program into what you actually want, and every column is scanned, so it is relatively inefficient.

So we can define a custom schema mapping instead:

  • hbase.zookeeper.quorum: the ZooKeeper address
  • spark.table.schema: the schema of the Spark temp table, e.g. "ID:String,appname:String,age:Int"
  • hbase.table.schema: the corresponding HBase schema, e.g. ":rowkey,info:appname,info:age"
  • hbase.table.name: the HBase table name
  • spark.rowkey.view.name: the name of the temp view built from a DataFrame of rowkeys (when set, only the data for those rowkeys is fetched)

Note that the two schemas correspond one to one, and HBase only scans the columns listed in hbase.table.schema.
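Building on the spark.rowkey.view.name option documented above, a hedged sketch of fetching only a chosen set of row keys (the view name and keys are made up; an active SparkSession named spark is assumed):

import spark.implicits._

// Register a temp view whose first column holds the wanted row keys.
Seq("row-001", "row-002").toDF("rowkey").createOrReplaceTempView("wanted_rowkeys")

spark.read.format("org.apache.spark.sql.execution.datasources.hbase")
  .options(Map(
    "hbase.zookeeper.quorum" -> "XXX:2181",
    "hbase.table.name"       -> "hbase_table",
    "spark.table.schema"     -> "appid:String,appstoreid:int,firm:String",
    "hbase.table.schema"     -> ":rowkey,info:appStoreId,info:firm",
    "spark.rowkey.view.name" -> "wanted_rowkeys"
  )).load().show(false)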

Core code

Writing to HBase

class DataFrameFunctions(data: DataFrame) extends Logging with Serializable {

    def saveToHbase(tableName: String, zkUrl: Option[String] = None,
                    options: Map[String, String] = new HashMap[String, String]): Unit = {

        val wrappedConf = {
            implicit val formats = DefaultFormats
            val hc = HBaseConfiguration.create()
            hc.set("hbase.zookeeper.quorum", zkUrl.getOrElse("127.0.0.1:2181"))
            new SerializableConfiguration(hc)
        }
        val hbaseConf = wrappedConf.value

        val rowkey = options.getOrElse("rowkey.field", data.schema.head.name)
        val family = options.getOrElse("family", "info")
        val numReg = options.getOrElse("numReg", -1).toString.toInt
        val startKey = options.getOrElse("startKey", null)
        val endKey = options.getOrElse("endKey", null)

        val rdd = data.rdd
        val f = family

        val tName = TableName.valueOf(tableName)
        val connection = ConnectionFactory.createConnection(hbaseConf)
        val admin = connection.getAdmin
        if (!admin.isTableAvailable(tName)) {
            HBaseUtils.createTable(connection, tName, family, startKey, endKey, numReg)
        }
        connection.close()
        if (hbaseConf.get("mapreduce.output.fileoutputformat.outputdir") == null) {
            hbaseConf.set("mapreduce.output.fileoutputformat.outputdir", "/tmp")
        }
        val jobConf = new JobConf(hbaseConf, this.getClass)
        jobConf.set(TableOutputFormat.OUTPUT_TABLE, tableName)

        val job = Job.getInstance(jobConf)
        job.setOutputKeyClass(classOf[ImmutableBytesWritable])
        job.setOutputValueClass(classOf[Result])
        job.setOutputFormatClass(classOf[TableOutputFormat[ImmutableBytesWritable]])

        val fields = data.schema.toArray
        val rowkeyIndex = fields.zipWithIndex.filter(f => f._1.name == rowkey).head._2
        val otherFields = fields.zipWithIndex.filter(f => f._1.name != rowkey)

        lazy val setters = otherFields.map(r => HBaseUtils.makeHbaseSetter(r))
        lazy val setters_bulkload = otherFields.map(r => HBaseUtils.makeHbaseSetter_bulkload(r))

        options.getOrElse("bulkload.enable", "true") match {

            case "true" =>
                val tmpPath = s"/tmp/bulkload/${tableName}" + System.currentTimeMillis()
                def convertToPut_bulkload(row: Row) = {
                    val rk = Bytes.toBytes(row.getString(rowkeyIndex))
                    setters_bulkload.map(_.apply(rk, row, f))
                }
                rdd.flatMap(convertToPut_bulkload)
                    .saveAsNewAPIHadoopFile(tmpPath, classOf[ImmutableBytesWritable], classOf[KeyValue], classOf[HFileOutputFormat2], job.getConfiguration)

                val bulkLoader: LoadIncrementalHFiles = new LoadIncrementalHFiles(hbaseConf)
                bulkLoader.doBulkLoad(new Path(tmpPath), new HTable(hbaseConf, tableName))

            case "false" =>
                def convertToPut(row: Row) = {
                    val put = new Put(Bytes.toBytes(row.getString(rowkeyIndex)))
                    setters.foreach(_.apply(put, row, f))
                    (new ImmutableBytesWritable, put)
                }
                rdd.map(convertToPut).saveAsNewAPIHadoopDataset(job.getConfiguration)
        }
    }
}

Reading from HBase

class SparkSqlContextFunctions(@transient val spark: SparkSession) extends Serializable {

    private val SPARK_TABLE_SCHEMA: String = "spark.table.schema"
    private val HBASE_TABLE_SCHEMA: String = "hbase.table.schema"

    def hbaseTableAsDataFrame(table: String, zkUrl: Option[String] = None,
                              options:Map[String, String] = new HashMap[String, String]
                             ): DataFrame = {

        val wrappedConf = {
            val hc = HBaseConfiguration.create()
            hc.set("hbase.zookeeper.quorum", zkUrl.getOrElse("127.0.0.1:2181"))
            hc.set(TableInputFormat.INPUT_TABLE, table)
            if (options.contains(HBASE_TABLE_SCHEMA)) {
                var str = ArrayBuffer[String]()
                options(HBASE_TABLE_SCHEMA)
                    .split(",", -1).map(field =>
                    if (!field.startsWith(":")) {
                        str += field
                    }
                )
                if (str.length > 1) hc.set(TableInputFormat.SCAN_COLUMNS, str.mkString(" "))
            }
            Array(SPARK_TABLE_SCHEMA,HBASE_TABLE_SCHEMA,TableInputFormat.SCAN_ROW_START,TableInputFormat.SCAN_ROW_STOP).foldLeft((hc,options)) {
                case ((_hc,_options),pram) => if(_options.contains(pram)) _hc.set(pram,_options(pram))
                    (_hc,_options)
            }
            new SerializableConfiguration(hc)
        }
        def hbaseConf = wrappedConf.value

        def schema: StructType = {
            import org.apache.spark.sql.types._
            Option(hbaseConf.get(SPARK_TABLE_SCHEMA)) match {
                case Some(schema) => HBaseUtils.registerSparkTableSchema(schema)
                case None =>
                    StructType(
                        Array(
                            StructField("rowkey", StringType, nullable = false),
                            StructField("content", StringType)
                        )
                    )
            }
        }

        Option(hbaseConf.get(SPARK_TABLE_SCHEMA)) match {
            case Some(s) =>
                require(hbaseConf.get(HBASE_TABLE_SCHEMA).nonEmpty, "Because the parameter spark.table.schema has been set, hbase.table.schema also needs to be set.")
                val sparkTableSchemas = schema.fields.map(f => SparkTableSchema(f.name, f.dataType))
                val hBaseTableSchemas = HBaseUtils.registerHbaseTableSchema(hbaseConf.get(HBASE_TABLE_SCHEMA))
                require(sparkTableSchemas.length == hBaseTableSchemas.length, "The length of the parameter spark.table.schema must be the same as the parameter hbase.table.schema.")
                val schemas = sparkTableSchemas.zip(hBaseTableSchemas)
                val setters = schemas.map(schema => HBaseUtils.makeHbaseGetter(schema))

                val hBaseRDD = spark.sparkContext.newAPIHadoopRDD(hbaseConf, classOf[TableInputFormat], classOf[ImmutableBytesWritable], classOf[Result])
                    .map { case (_, result) => Row.fromSeq(setters.map(r => r.apply(result)).toSeq) }
                spark.createDataFrame(hBaseRDD, schema)

            case None =>
                val hBaseRDD = spark.sparkContext.newAPIHadoopRDD(hbaseConf, classOf[TableInputFormat], classOf[ImmutableBytesWritable], classOf[Result])
                    .map { line =>
                        val rowKey = Bytes.toString(line._2.getRow)

                        implicit val formats = Serialization.formats(NoTypeHints)

                        val content = line._2.getMap.navigableKeySet().flatMap { f =>
                            line._2.getFamilyMap(f).map { c =>
                                val columnName = Bytes.toString(f) + ":" + Bytes.toString(c._1)
                                    options.get("field.type." + columnName) match {
                                    case Some(i) =>
                                        val value = i match {
                                            case "LongType" => Bytes.toLong(c._2)
                                            case "FloatType" => Bytes.toFloat(c._2)
                                            case "DoubleType" => Bytes.toDouble(c._2)
                                            case "IntegerType" => Bytes.toInt(c._2)
                                            case "BooleanType" => Bytes.toBoolean(c._2)
                                            case "BinaryType" => c._2
                                            case "TimestampType" => new Timestamp(Bytes.toLong(c._2))
                                            case "DateType" => new java.sql.Date(Bytes.toLong(c._2))
                                            case _ => Bytes.toString(c._2)
                                        }
                                        (columnName, value)
                                    case None => (columnName, Bytes.toString(c._2))
                                }
                            }
                        }.toMap
                        val contentStr = Serialization.write(content)
                        Row.fromSeq(Seq(rowKey,contentStr))
                    }
                spark.createDataFrame(hBaseRDD, schema)
        }
    }
}

An extended DataSource must be exposed through a class named DefaultSource:

class DefaultSource extends CreatableRelationProvider with RelationProvider with DataSourceRegister {

    override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation =
        HBaseRelation(parameters, None)(sqlContext)

    override def shortName(): String = "hbase"

    override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
        val relation = InsertHBaseRelation(data, parameters)(sqlContext)
        relation.insert(data, false)
        relation
    }
}

private[sql] case class InsertHBaseRelation(
                                               dataFrame: DataFrame,
                                               parameters: Map[String, String]
                                           )(@transient val sqlContext: SQLContext)
    extends BaseRelation with InsertableRelation with Logging {

    override def insert(data: DataFrame, overwrite: Boolean): Unit = {

        def getZkURL: String = parameters.getOrElse("zk", parameters.getOrElse("hbase.zookeeper.quorum", sys.error("You must specify parameter zkurl...")))
        def getOutputTableName: String = parameters.getOrElse("outputTableName", sys.error("You must specify parameter outputTableName..."))

        import org.apache.hack.spark._
        data.saveToHbase(getOutputTableName, Some(getZkURL), parameters)
    }
    override def schema: StructType = dataFrame.schema
}

private[sql] case class HBaseRelation(
                                         parameters: Map[String, String],
                                         userSpecifiedschema: Option[StructType]
                                     )(@transient val sqlContext: SQLContext)
    extends BaseRelation with TableScan with Logging {

    def getZkURL: String = parameters.getOrElse("zk", parameters.getOrElse("hbase.zookeeper.quorum", sys.error("You must specify parameter zkurl...")))
    def getInputTableName: String = parameters.getOrElse("inputTableName", sys.error("You must specify parameter imputTableName..."))

    def buildScan(): RDD[Row] = {
        import org.apache.hack.spark._
        sqlContext.sparkSession.hbaseTableAsDataFrame(getInputTableName, Some(getZkURL), parameters).rdd
    }

    override def schema: StructType = {
        import org.apache.hack.spark._
        sqlContext.sparkSession.hbaseTableAsDataFrame(getInputTableName, Some(getZkURL), parameters).schema
    }
}


Tracing the Whole Job-Scheduling Flow, Starting from spark-submit

This article is based on Spark 2.1 in Standalone Cluster mode.

Overview

A Spark application can be launched in Client mode or Cluster mode. The difference is that in Client mode the Driver starts on the node where spark-submit is executed, whereas in Cluster mode the Master picks a Worker at random and that Worker starts the Driver through DriverWrapper.

The rough flow is:

  • spark-submit invokes the SparkSubmit class, which uses reflection to call Client; Client talks to the Master to submit the driver, and once a success reply arrives the JVM exits (the SparkSubmit process ends).
  • On receiving the driver submission, the Master randomly picks a Worker whose resources satisfy the driver's requirements and sends it a launch-driver message. The Worker assembles a Linux command from the driver's description, starts DriverWrapper, which in turn starts the Driver, and finally reports the Driver's execution state back to the Master.
  • Once the driver is up, the application is registered: while the SparkContext starts, an AppClient is created and asks the Master to register the application.
  • The Master then schedules the application: a scheduling algorithm decides on which Workers executors should be launched, and the Master sends those Workers launch-executor messages.
  • Each Worker assembles a Linux command and starts a CoarseGrainedExecutorBackend process, which registers its executor with the Driver; after a successful registration an Executor object is created inside CoarseGrainedExecutorBackend.
  • From there the jobs run; see the earlier articles for that part.

Submit Driver

Submitting your own application with the spark-submit shell command ultimately invokes, via java -cp, the class:

org.apache.spark.deploy.SparkSubmit

In that class's main method, in Cluster mode without REST, the Client class is invoked by reflection:

org.apache.spark.deploy.Client

Client's main method obtains an EndpointRef for talking to the Master and creates a ClientEndpoint named Client; in its onStart lifecycle hook it builds a DriverDescription for the driver, which includes the mainClass that will eventually start the Driver:

org.apache.spark.deploy.worker.DriverWrapper

It then sends the Master a RequestSubmitDriver message. The Master persists the DriverInfo to its storage system, schedules via schedule(), and replies to the Client with a SubmitDriverResponse. On receiving the successful-submission reply, the Client sends the Master a RequestDriverStatus message asking for the driver's state, and once a DriverStatusResponse confirms that the driver exists on the Master, the JVM exits (the SparkSubmit process ends).
The flow is shown in the accompanying diagram (omitted here).

Master LaunchDriver

As mentioned above, when the Master receives the driver-submission message it calls schedule():

private def schedule(): Unit = { 
    val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE))
    val numWorkersAlive = shuffledAliveWorkers.size
    var curPos = 0
    for (driver <- waitingDrivers.toList) {   
      var launched = false
      var numWorkersVisited = 0
      while (numWorkersVisited < numWorkersAlive && !launched) {
        val worker = shuffledAliveWorkers(curPos)
        numWorkersVisited += 1
        if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) {
          launchDriver(worker, driver)
          waitingDrivers -= driver
          launched = true
        }
        curPos = (curPos + 1) % numWorkersAlive
      }
    }
    startExecutorsOnWorkers()
  }

该方法会先打乱Worker防止Driver集中在一台Worker上,当Worker的资源满足driver所需要的资源,则会调用launchDriver方法。

在launchDriver方法里会向对应的Worker发送一个LaunchDriver消息,该Worker接收到消息后通过driver的各种描述信息创建一个DriverRunner,然后调用其start方法。

start方法中将driver的参数组织成Linux命令,通过java -cp来运行上面提到的DriverWrapper类来启动Driver,而不是直接启动,这是为了Driver程序和启动Driver的Worker程序共命运(源码注释中称为share fate),即如果此Worker挂了,对应的Driver也会停止。

最后将Driver的执行状态返回给Master。
流程如图:

Register APP

Driver起来后当然会涉及到APP向Master的注册,在创建SparkContext的时候,会创建SchedulerBackend和TaskScheduler:

val (sched, ts) = SparkContext.createTaskScheduler(this, master, deployMode)

接着调用了TaskScheduler(TaskSchedulerImpl)的start方法,start方法里面又调用了SchedulerBackend(standalone模式下是StandaloneSchedulerBackend)的start方法:

override def start() {
    super.start()
    ...
    val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command,
      appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit)
    client = new StandaloneAppClient(sc.env.rpcEnv, masters, appDesc, this, conf)
    client.start()
    ...
  }

super.start()中创建了driverEndpoint。先根据application的参数创建了ApplicationDescription,又创建了StandaloneAppClient并调用其start方法,在start方法中创建了名为AppClient的Endpoint,在其生命周期的onStart方法中向Master发送了RegisterApplication消息进行注册app。

Master收到RegisterApplication消息后,创建描述application的ApplicationInfo,并持久化到存储系统,随后向AppClient返回RegisteredApplication的消息,然后通过schedule()去调度application。
流程如图:

Launch Executor

在上文Master LaunchDriver中解析了该方法的前半部分，说明了是如何将Driver调度到Worker上启动的。

private def schedule(): Unit = { 
    val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE))
    val numWorkersAlive = shuffledAliveWorkers.size
    var curPos = 0
    for (driver <- waitingDrivers.toList) {   
      var launched = false
      var numWorkersVisited = 0
      while (numWorkersVisited < numWorkersAlive && !launched) {
        val worker = shuffledAliveWorkers(curPos)
        numWorkersVisited += 1
        if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) {
          launchDriver(worker, driver)
          waitingDrivers -= driver
          launched = true
        }
        curPos = (curPos + 1) % numWorkersAlive
      }
    }
    startExecutorsOnWorkers()
  }

现在来说说后部分 startExecutorsOnWorkers()是怎么在Worker上启动Executor的:

private def startExecutorsOnWorkers(): Unit = {
    // 遍历所有等待调度的application,顺序为FIFO
    for (app <- waitingApps if app.coresLeft > 0) {
      val coresPerExecutor: Option[Int] = app.desc.coresPerExecutor
      // 过滤出资源能满足APP对于每一个Executor需求的Worker
      val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE)
        .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB &&
          worker.coresFree >= coresPerExecutor.getOrElse(1))
        .sortBy(_.coresFree).reverse
      // 对Executor的调度(为每个Worker分配的core数)
      val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps)

      // 根据前面调度好的,在对应Worker上启动Executor
      for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) {
        allocateWorkerResourceToExecutors(
          app, assignedCores(pos), coresPerExecutor, usableWorkers(pos))
      }
    }
  }

先过滤出能满足application对于一个Executor资源要求的Worker,然后对Executor进行调度,策略有两种:

  • 使用spreadOutApps算法分配资源,即Executor分布在尽可能多的Worker节点上
  • Executor聚集在某些Worker节点上

是否启用spreadOutApps算法通过参数spark.deploy.spreadOut配置，默认为true；scheduleExecutorsOnWorkers方法返回的就是每个Worker能分配到的core数。
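
这两种策略的差别可以用一个简化的小例子直观感受一下（以下只是示意代码，并非Spark源码中scheduleExecutorsOnWorkers的实现）：

object SpreadOutDemo {
    // freeCores: 各Worker空闲core数；coresNeeded: app还需要的core数
    def assign(freeCores: Array[Int], coresNeeded: Int, spreadOut: Boolean): Array[Int] = {
        val assigned = Array.fill(freeCores.length)(0)
        var left = coresNeeded
        var pos = 0
        def canUse(i: Int): Boolean = freeCores(i) - assigned(i) > 0
        while (left > 0 && assigned.indices.exists(canUse)) {
            if (canUse(pos)) {
                assigned(pos) += 1
                left -= 1
            }
            // spreadOut：每分配一个core就轮到下一个Worker；否则先把当前Worker占满再换
            if (spreadOut || !canUse(pos)) pos = (pos + 1) % freeCores.length
        }
        assigned
    }

    def main(args: Array[String]): Unit = {
        val free = Array(4, 4, 4)
        println(assign(free, 6, spreadOut = true).mkString(","))   // 2,2,2：尽量打散
        println(assign(free, 6, spreadOut = false).mkString(","))  // 4,2,0：尽量集中
    }
}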

然后通过allocateWorkerResourceToExecutors去计算该Worker上需要启动的Executor:

private def allocateWorkerResourceToExecutors(
      app: ApplicationInfo,
      assignedCores: Int,
      coresPerExecutor: Option[Int],
      worker: WorkerInfo): Unit = {
    // 计算在该Worker上启动的Executor数,总cores / 一个Executor所需
    // 若没有指定一个Executor所需core数,则将分到的core数都给一个Executor
    val numExecutors = coresPerExecutor.map { assignedCores / _ }.getOrElse(1)
    val coresToAssign = coresPerExecutor.getOrElse(assignedCores)
    for (i <- 1 to numExecutors) {
      val exec = app.addExecutor(worker, coresToAssign) 
      launchExecutor(worker, exec)
      app.state = ApplicationState.RUNNING
    }
  }

通过计算得到该Worker需要启动的Executor数,然后调用launchExecutor方法通过与对应的Worker通信来发送LaunchExecutor消息。
流程如图:

对应的Worker收到消息后将收到的信息封装成ExecutorRunner对象,并调用其start方法:

case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_) =>
          ...
          val manager = new ExecutorRunner(
            appId,
            execId,
            appDesc.copy(command = Worker.maybeUpdateSSLSettings(appDesc.command, conf)),
            cores_,
            memory_,
            self,
            workerId,
            host,
            webUi.boundPort,
            publicAddress,
            sparkHome,
            executorDir,
            workerUri,
            conf,
            appLocalDirs, ExecutorState.RUNNING)
          executors(appId + "/" + execId) = manager
          manager.start()
          ...

在manager的start方法中调用了fetchAndRunExecutor方法:

private def fetchAndRunExecutor() {
    try { 
      val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf),
        memory, sparkHome.getAbsolutePath, substituteVariables)
      ...
      process = builder.start() 
     ...
    }

这里和启动Driver的方式类似，通过收到的信息拼接成Linux命令，通过java -cp来启动CoarseGrainedExecutorBackend进程。
流程如图:

在CoarseGrainedExecutorBackend的main方法里创建了名为Executor的Endpoint,在其生命周期的onStart()方法里向Driver发送了RegisterExecutor消息。

Driver收到消息后根据Executor信息创建了ExecutorData对象,并加入到executorDataMap集合中,然后返回RegisteredExecutor消息给CoarseGrainedExecutorBackend。

CoarseGrainedExecutorBackend收到RegisteredExecutor后:

case RegisteredExecutor =>
      logInfo("Successfully registered with driver")
      try {
        executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
      } catch {
        case NonFatal(e) =>
          exitExecutor(1, "Unable to create executor due to " + e.getMessage, e)
      }

便创建了一个Executor对象,此对象将执行Driver分配的Task。
流程如图:

接着就是通过DAGScheduler、TaskScheduler等进行Stage的划分、Task的调度执行，最终将Task结果返回到Driver，具体可看前面的文章。


Shuffle Write解析 (Sort Based Shuffle)

本文基于 Spark 2.1 进行解析

前言

从 Spark 2.0 开始移除了Hash Based Shuffle,想要了解可参考Shuffle 过程,本文将讲解 Sort Based Shuffle。

ShuffleMapTask的结果(ShuffleMapStage中FinalRDD的数据)都将写入磁盘,以供后续Stage拉取,即整个Shuffle包括前Stage的Shuffle Write和后Stage的Shuffle Read,由于内容较多,本文先解析Shuffle Write。

概述:

  • 写records到内存缓冲区(一个数组维护的map),每次insert&update都需要检查是否达到溢写条件。
  • 若需要溢写,将集合中的数据根据partitionId和key(若需要)排序后顺序溢写到一个临时的磁盘文件,并释放内存新建一个map放数据,每次溢写都是写一个新的临时文件。
  • 一个task最终对应一个文件：将还在内存中的数据和已经spill的文件根据reduce端的partitionId进行合并，合并后需要再次聚合排序（有需要的情况下），再按partition的顺序写入最终文件，并返回每个partition在文件中的偏移量，最后以MapStatus对象返回给driver并注册到MapOutputTrackerMaster中，以便后续reduce端通过它来访问。

入口

执行一个ShuffleMapTask最终的执行逻辑是调用了ShuffleMapTask类的runTask()方法：

override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val deserializeStartTime = System.currentTimeMillis()
    val ser = SparkEnv.get.closureSerializer.newInstance()
    // 从广播变量中反序列化出finalRDD和dependency
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime

    var writer: ShuffleWriter[Any, Any] = null
    try {
      // 获取shuffleManager
      val manager = SparkEnv.get.shuffleManager
      // 通过shuffleManager的getWriter()方法,获得shuffle的writer
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
       // 通过rdd指定分区的迭代器iterator方法来遍历每一条数据，在其之上再调用writer的write方法以写数据
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }

其中的finalRDD和dependency是在Driver端DAGScheduler中提交Stage的时候加入广播变量的。

接着通过SparkEnv获取shuffleManager,默认使用的是sort(对应的是org.apache.spark.shuffle.sort.SortShuffleManager),可通过spark.shuffle.manager设置。
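
对应的配置方式示意如下（2.x 下默认即为 sort，这里只是显式写出）：

val conf = new org.apache.spark.SparkConf()
    .set("spark.shuffle.manager", "sort")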

然后manager.getWriter返回的是SortShuffleWriter,我们直接看writer.write发生了什么:

override def write(records: Iterator[Product2[K, V]]): Unit = {
    sorter = if (dep.mapSideCombine) {
      require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
      new ExternalSorter[K, V, C](
        context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
    } else {
      new ExternalSorter[K, V, V](
        context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
    }
    // 写内存缓冲区,超过阈值则溢写到磁盘文件
    sorter.insertAll(records)
    // 获取该task的最终输出文件
    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
    val tmp = Utils.tempFileWith(output)
    try {
      val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
      // merge后写到data文件
      val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
      // 写index文件
      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
    } finally {
      if (tmp.exists() && !tmp.delete()) {
        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
      }
    }
  }
  • 通过判断是否有map端的combine来创建不同的ExternalSorter,若有则将对应的aggregator和keyOrdering作为参数传入。
  • 调用sorter.insertAll(records),将records写入内存缓冲区,超过阈值则溢写到磁盘文件。
  • Merge内存记录和所有被spill到磁盘的文件,并写到最终的数据文件.data中。
  • 将每个partition的偏移量写到index文件中。

先细看sorter.insertAll是怎么写到内存，并spill到磁盘文件的：

def insertAll(records: Iterator[Product2[K, V]]): Unit = {
    // TODO: stop combining if we find that the reduction factor isn't high
    val shouldCombine = aggregator.isDefined
    // 若需要Combine
    if (shouldCombine) {
      // 获取对新value合并到聚合结果中的函数
      val mergeValue = aggregator.get.mergeValue
      // 获取创建初始聚合值的函数
      val createCombiner = aggregator.get.createCombiner
      var kv: Product2[K, V] = null
      // 通过mergeValue 对已有的聚合结果的新value进行合并,通过createCombiner 对没有聚合结果的新value初始化聚合结果
      val update = (hadValue: Boolean, oldValue: C) => {
        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
      }
      // 遍历records
      while (records.hasNext) {
        addElementsRead()
        kv = records.next()
        // 使用update函数进行value的聚合
        map.changeValue((getPartition(kv._1), kv._1), update)
        // 是否需要spill到磁盘文件
        maybeSpillCollection(usingMap = true)
      }
    // 不需要Combine
    } else {
      // Stick values into our buffer
      while (records.hasNext) {
        addElementsRead()
        val kv = records.next()
        buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
        maybeSpillCollection(usingMap = false)
      }
    }
  }
  • 需要聚合的情况,遍历records拿到record的KV,通过map的changeValue方法并根据update函数来对相同K的V进行聚合,这里的map是PartitionedAppendOnlyMap类型,只能添加数据不能删除数据,底层实现是一个数组,数组中存KV键值对的方式是[K1,V1,K2,V2...],每一次操作后都会判断是否要spill到磁盘。

  • 不需要聚合的情况,直接将record放入buffer,然后判断是否要溢写到磁盘。

先看map.changeValue方法到底是怎么通过map实现对数据combine的:

override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
    // 通过聚合算法得到newValue
    val newValue = super.changeValue(key, updateFunc)
    // 更新对map大小的采样
    super.afterUpdate()
    newValue
  }

super.changeValue的实现:

def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
    ...
    // 根据k 得到pos
    var pos = rehash(k.hashCode) & mask
    var i = 1
    while (true) {
      // 从data中获取该位置的原来的key
      val curKey = data(2 * pos)  
      // 若原来的key和当前的key相等,则将两个值进行聚合
      if (k.eq(curKey) || k.equals(curKey)) {
        val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        return newValue
       // 若当前key对应的位置没有key,则将当前key作为该位置的key
       // 并通过update方法初始化该位置的聚合结果
      } else if (curKey.eq(null)) {
        val newValue = updateFunc(false, null.asInstanceOf[V])
        data(2 * pos) = k
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        // 扩容
        incrementSize()
        return newValue
      // 若对应位置有key但不和当前key相等,即hash冲突了,则继续向后遍历
      } else {
        val delta = i
        pos = (pos + delta) & mask
        i += 1
      }
    }
    null.asInstanceOf[V] // Never reached but needed to keep compiler happy
  }

根据K的hashCode再哈希与上掩码 得到 pos,2 * pos 为 k 应该所在的位置,2 * pos + 1 为 k 对应的 v 所在的位置,获取k应该所在位置的原来的key:

  • 若原来的key和当前的 k 相等,则通过update函数将两个v进行聚合并更新该位置的value
  • 若原来的key存在但不和当前的k 相等,则说明hash冲突了,更新pos继续遍历
  • 若原来的key不存在,则将当前k作为该位置的key,并通过update函数初始化该k对应的聚合结果,接着会通过incrementSize()方法进行扩容:
     private def incrementSize() {
        curSize += 1
        if (curSize > growThreshold) {
          growTable()
        }
      }
    
    更新curSize，若当前大小超过了阈值growThreshold（growThreshold是当前容量capacity的0.7倍），则通过growTable()来扩容：
protected def growTable() {
    // 容量翻倍
    val newCapacity = capacity * 2
    require(newCapacity <= MAXIMUM_CAPACITY, s"Can't contain more than ${growThreshold} elements")
    //生成新的数组来存数据
    val newData = new Array[AnyRef](2 * newCapacity)
    val newMask = newCapacity - 1
    var oldPos = 0
    while (oldPos < capacity) {
      // 将旧数组中的数据重新计算位置放到新的数组中
      if (!data(2 * oldPos).eq(null)) {
        val key = data(2 * oldPos)
        val value = data(2 * oldPos + 1)
        var newPos = rehash(key.hashCode) & newMask
        var i = 1
        var keepGoing = true
        while (keepGoing) {
          val curKey = newData(2 * newPos)
          if (curKey.eq(null)) {
            newData(2 * newPos) = key
            newData(2 * newPos + 1) = value
            keepGoing = false
          } else {
            val delta = i
            newPos = (newPos + delta) & newMask
            i += 1
          }
        }
      }
      oldPos += 1
    }
    // 替换及更新变量
    data = newData
    capacity = newCapacity
    mask = newMask
    growThreshold = (LOAD_FACTOR * newCapacity).toInt
  }

这里重新创建了一个两倍capacity的数组来存放数据，将原来数组中的数据通过重新计算位置放到新数组里，将data替换为新的数组，并更新一些变量。

此时聚合已经完成，回到changeValue方法里面，接下来会执行super.afterUpdate()方法来对map的大小进行采样：

protected def afterUpdate(): Unit = {
    numUpdates += 1
    if (nextSampleNum == numUpdates) {
      takeSample()
    }
  }

若每遍历更新一条record都对map进行一次采样来估计大小，假设采样一次需要1ms，100w次采样就会花上16.7分钟，性能大大降低。所以这里只有当update次数达到nextSampleNum的时候才通过takeSample()采样一次：

private def takeSample(): Unit = {
    samples.enqueue(Sample(SizeEstimator.estimate(this), numUpdates))
    // Only use the last two samples to extrapolate
    if (samples.size > 2) {
      samples.dequeue()
    }
    // 估计每次更新的变化量
    val bytesDelta = samples.toList.reverse match {
      case latest :: previous :: tail =>
        (latest.size - previous.size).toDouble / (latest.numUpdates - previous.numUpdates)
      // If fewer than 2 samples, assume no change
      case _ => 0
    }
    // 更新变化量
    bytesPerUpdate = math.max(0, bytesDelta)
    // 获取下次采样的次数
    nextSampleNum = math.ceil(numUpdates * SAMPLE_GROWTH_RATE).toLong
  }

这里估计每次更新的变化量的逻辑是：(当前map大小 - 上次采样时的大小) / (当前update的次数 - 上次采样时的update次数)。

接着计算下次需要采样的update次数，该次数是指数级增长的，增长率是1.1：每次采样后，下一次采样点为当前update次数的1.1倍（向上取整），以此类推，开始增长慢，后面增长跨度会非常大。
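
下面用一小段示意代码直观看一下采样点的增长情况（假设SAMPLE_GROWTH_RATE为1.1）：

object SampleGrowthDemo {
    def main(args: Array[String]): Unit = {
        val points = scala.collection.mutable.ArrayBuffer[Long]()
        var next = 1L
        while (next < 1000000L) {
            points += next
            next = math.ceil(next * 1.1).toLong
        }
        // 前10个采样点几乎每次update都采样一次
        println(points.take(10).mkString(","))      // 1,2,3,4,5,6,7,8,9,10
        // 越往后相邻采样点的间隔越大（约为前一个采样点的10%）
        println(points.takeRight(2).mkString(","))
    }
}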

这里采样完成后回到insertAll方法，接着通过maybeSpillCollection方法判断是否需要spill：

 private def maybeSpillCollection(usingMap: Boolean): Unit = {
    var estimatedSize = 0L
    if (usingMap) {
      estimatedSize = map.estimateSize()
      if (maybeSpill(map, estimatedSize)) {
        map = new PartitionedAppendOnlyMap[K, C]
      }
    } else {
      estimatedSize = buffer.estimateSize()
      if (maybeSpill(buffer, estimatedSize)) {
        buffer = new PartitionedPairBuffer[K, C]
      }
    }

    if (estimatedSize > _peakMemoryUsedBytes) {
      _peakMemoryUsedBytes = estimatedSize
    }
  }

通过集合的estimateSize方法估计map的大小,若需要spill则将集合中的数据spill到磁盘文件,并且为集合创建一个新的对象放数据。先看看估计大小的方法estimateSize:

 def estimateSize(): Long = {
    assert(samples.nonEmpty)
    val extrapolatedDelta = bytesPerUpdate * (numUpdates - samples.last.numUpdates)
    (samples.last.size + extrapolatedDelta).toLong
  }

以上次采样后更新的bytesPerUpdate作为最近平均每次更新的大小，估计当前占用内存：(当前update次数 - 上次采样时的update次数) * 每次更新大小 + 上次采样记录的大小。
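
套用estimateSize的公式做一次手工估算（数值纯属假设）：

// 上次采样时：大小4000000字节、update次数100000；bytesPerUpdate估计为40
val lastSampleSize    = 4000000L
val lastSampleUpdates = 100000L
val bytesPerUpdate    = 40.0
val numUpdates        = 120000L   // 当前update次数
val estimated = (lastSampleSize + bytesPerUpdate * (numUpdates - lastSampleUpdates)).toLong
// estimated = 4000000 + 40 * 20000 = 4800000 字节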

获取到当前集合的大小后调用maybeSpill判断是否需要spill:

protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
    var shouldSpill = false
    if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
      // Claim up to double our current memory from the shuffle memory pool
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      val granted = acquireMemory(amountToRequest)
      // 更新申请到的内存
      myMemoryThreshold += granted 
      // 集合大小还是比申请到的内存大?spill : no spill
      shouldSpill = currentMemory >= myMemoryThreshold
    }
    shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
    // Actually spill
    if (shouldSpill) {
      _spillCount += 1
      logSpillage(currentMemory)
      spill(collection)
      _elementsRead = 0
      _memoryBytesSpilled += currentMemory
      releaseMemory()
    }
    shouldSpill
  }

这里有两种情况都可导致spill:

  • 当前集合包含的records数超过了 numElementsForceSpillThreshold(默认为Long.MaxValue,可通过spark.shuffle.spill.numElementsForceSpillThreshold设置)
  • 当前集合包含的records数为32的整数倍,并且当前集合的大小超过了申请的内存myMemoryThreshold(第一次申请默认为5 * 1024 * 1024,可通过spark.shuffle.spill.initialMemoryThreshold设置),此时并不会立即spill,会尝试申请更多的内存避免spill,这里尝试申请的内存为2倍集合大小减去当前已经申请的内存大小(实际申请到的内存为granted),若加上原来的内存还是比当前集合的大小要小则需要spill。
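
对第二种情况的内存申请逻辑，用一组假设的数值示意如下：

// 假设当前集合估计大小为10MB，之前已申请到的内存myMemoryThreshold为5MB
val currentMemory     = 10L * 1024 * 1024
var myMemoryThreshold = 5L * 1024 * 1024
val amountToRequest   = 2 * currentMemory - myMemoryThreshold   // 尝试再申请15MB
val granted           = 3L * 1024 * 1024                        // 假设只申请到3MB
myMemoryThreshold += granted                                    // 总共8MB
val shouldSpill = currentMemory >= myMemoryThreshold            // 10MB >= 8MB，需要spill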

若需要spill，则更新spill次数，调用spill(collection)方法进行溢写磁盘，并释放内存。
跟进spill方法看看其具体实现:

override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
    // 传入comparator将集合中的数据先根据partition排序再通过key排序后返回一个迭代器
    val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
    // 写到磁盘文件,并返回一个对该文件的描述对象SpilledFile
    val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
    // 添加到spill文件数组
    spills.append(spillFile)
  }

继续跟进看看spillMemoryIteratorToDisk的实现:

private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)
      : SpilledFile = {
    // 生成临时文件和blockId
    val (blockId, file) = diskBlockManager.createTempShuffleBlock()

    // 这些值在每次flush后会被重置
    var objectsWritten: Long = 0
    var spillMetrics: ShuffleWriteMetrics = null
    var writer: DiskBlockObjectWriter = null
    def openWriter(): Unit = {
      assert (writer == null && spillMetrics == null)
      spillMetrics = new ShuffleWriteMetrics
      writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
    }
    openWriter()

    // 按写入磁盘的顺序记录每个批次的大小
    val batchSizes = new ArrayBuffer[Long]

    // 记录每个分区有多少元素
    val elementsPerPartition = new Array[Long](numPartitions)

    // Flush  writer 内容到磁盘,并更新相关变量
    def flush(): Unit = {
      val w = writer
      writer = null
      w.commitAndClose()
      _diskBytesSpilled += spillMetrics.bytesWritten
      batchSizes.append(spillMetrics.bytesWritten)
      spillMetrics = null
      objectsWritten = 0
    }

    var success = false
    try {
      // 遍历迭代器
      while (inMemoryIterator.hasNext) {
        val partitionId = inMemoryIterator.nextPartition()
        require(partitionId >= 0 && partitionId < numPartitions,
          s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
        inMemoryIterator.writeNext(writer)
        elementsPerPartition(partitionId) += 1
        objectsWritten += 1
        // 元素个数达到批量序列化大小则flush到磁盘
        if (objectsWritten == serializerBatchSize) {
          flush()
          openWriter()
        }
      }
      // 将剩余的数据flush
      if (objectsWritten > 0) {
        flush()
      } else if (writer != null) {
        val w = writer
        writer = null
        w.revertPartialWritesAndClose()
      }
      success = true
    } finally {
        ...
    }
    // 返回SpilledFile
    SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)
  }

通过diskBlockManager创建临时文件和blockId，临时文件名格式为 "temp_shuffle_" + id，遍历内存数据迭代器，并调用Writer(DiskBlockObjectWriter)的write方法，当写的条数达到批量序列化大小(serializerBatchSize)则flush到磁盘文件，并重新打开writer，以及更新batchSizes等信息。

最后返回一个SpilledFile对象,该对象包含了溢写的临时文件File,blockId,每次flush的到磁盘的大小,每个partition对应的数据条数。

spill完成,并且insertAll方法也执行完成,回到开始的SortShuffleWriter的write方法:

override def write(records: Iterator[Product2[K, V]]): Unit = {
    ...
    // 写内存缓冲区,超过阈值则溢写到磁盘文件
    sorter.insertAll(records)
    // 获取该task的最终输出文件
    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
    val tmp = Utils.tempFileWith(output)
    try {
      val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
      // merge后写到data文件
      val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
      // 写index文件
      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
    } finally {
      if (tmp.exists() && !tmp.delete()) {
        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
      }
    }
  }

获取最后的输出文件名及blockId,文件格式:

 "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data"

接着通过sorter.writePartitionedFile方法来写文件，其中包括内存及所有spill文件的merge操作，看看其具体实现：

def writePartitionedFile(
      blockId: BlockId,
      outputFile: File): Array[Long] = {

    val writeMetrics = context.taskMetrics().shuffleWriteMetrics

    // 跟踪每个分区在文件中的range
    val lengths = new Array[Long](numPartitions)
    // 数据只存在内存中
    if (spills.isEmpty) { 
      val collection = if (aggregator.isDefined) map else buffer
      // 将内存中的数据先通过partitionId再通过k排序后返回一个迭代器
      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
      // 遍历数据写入磁盘
      while (it.hasNext) {
        val writer = blockManager.getDiskWriter(
          blockId, outputFile, serInstance, fileBufferSize, writeMetrics)
        val partitionId = it.nextPartition()
        //等待一个partition的数据写完后刷新到磁盘文件
        while (it.hasNext && it.nextPartition() == partitionId) {
          it.writeNext(writer)
        }
        writer.commitAndClose()
        val segment = writer.fileSegment()
        // 记录每个partition数据长度
        lengths(partitionId) = segment.length
      }
    } else {
      // 有数据spill到磁盘,先merge
      for ((id, elements) <- this.partitionedIterator) {
        if (elements.hasNext) {
          val writer = blockManager.getDiskWriter(
            blockId, outputFile, serInstance, fileBufferSize, writeMetrics)
          for (elem <- elements) {
            writer.write(elem._1, elem._2)
          }
          writer.commitAndClose()
          val segment = writer.fileSegment()
          lengths(id) = segment.length
        }
      }
    }

    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)

    lengths
  }
  • 数据只存在内存中而没有spill文件,根据传入的比较函数comparator来对集合里的数据先根据partition排序再对里面的key排序并返回一个迭代器,遍历该迭代器得到所有recored,每一个partition对应一个writer,一个partition的数据写完后再flush到磁盘文件,并记录该partition的数据长度。
  • 数据有spill文件,通过方法partitionedIterator对内存和spill文件的数据进行merge-sort后返回一个(partitionId,对应分区的数据的迭代器)的迭代器,也是一个partition对应一个Writer,写完一个partition再flush到磁盘,并记录该partition数据的长度。

接下来看看通过this.partitionedIterator方法是怎么将内存及spill文件的数据进行merge-sort的:

def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
    val usingMap = aggregator.isDefined
    val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
    if (spills.isEmpty) {
      if (!ordering.isDefined) {
        // 只根据partitionId排序,不需要对key排序
        groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
      } else {
        // 需要对partitionID和key进行排序
        groupByPartition(destructiveIterator(
          collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
      }
    } else {
      // Merge spilled and in-memory data
      merge(spills, destructiveIterator(
        collection.partitionedDestructiveSortedIterator(comparator)))
    }
  }

这里在有spill文件的情况下会执行下面的merge方法，传入的是spill文件数组和内存中的数据经过partitionId和key排序后的数据迭代器，看看merge：

private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
      : Iterator[(Int, Iterator[Product2[K, C]])] = {
    // 每个文件对应一个Reader
    val readers = spills.map(new SpillReader(_)) 
    val inMemBuffered = inMemory.buffered
    (0 until numPartitions).iterator.map { p =>
      // 获取内存中当前partition对应的Iterator
      val inMemIterator = new IteratorForPartition(p, inMemBuffered)
      // 将spill文件对应的partition的数据与内存中对应partition数据合并
      val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
      if (aggregator.isDefined) {
        // 对key进行聚合并排序
        (p, mergeWithAggregation(
          iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
      } else if (ordering.isDefined) {
        // 排序
        (p, mergeSort(iterators, ordering.get))
      } else {
        (p, iterators.iterator.flatten)
      }
    }
  }

merge方法将属于同一个reduce端的partition的内存数据和spill文件数据合并起来,再进行聚合排序(有需要的话),最后返回(reduce对应的partitionId,该分区数据迭代器)

将数据merge-sort后写入最终的文件后,需要将每个partition的偏移量持久化到文件以供后续每个reduce根据偏移量获取自己的数据,写偏移量的逻辑很简单,就是根据前面得到的partition长度的数组将偏移量写到index文件中,对应的文件名为:

 "shuffle_" + shuffleId + "_" + mapId + "_" + NOOP_REDUCE_ID + ".index"

最后创建一个MapStatus实例返回,包含了reduce端每个partition对应的偏移量。

该对象将返回到Driver端的DAGScheduler处理，被添加到对应stage的OutputLoc里，当该stage的所有task完成的时候会将这些结果注册到MapOutputTrackerMaster，以便下一个stage的task可以通过它来获取shuffle结果的元数据信息。

至此Shuffle Write完成!

Task成功执行的结果处理

前言

在文章Task执行流程 中介绍了task是怎么被分配到executor上执行的,本文讲解task成功执行时将结果返回给driver的处理流程。

Driver端接收task完成事件

在executor上成功执行完task并拿到serializedResult 之后,通过CoarseGrainedExecutorBackend的statusUpdate方法来返回结果给driver,该方法会使用driverRpcEndpointRef 发送一条包含 serializedResult 的 StatusUpdate 消息给 driver。

execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
    val msg = StatusUpdate(executorId, taskId, state, data)
    driver match {
      case Some(driverRef) => driverRef.send(msg)
      case None => logWarning(s"Drop $msg because has not yet connected to driver")
    }
  }

而在driver端CoarseGrainedSchedulerBackend 在接收到StatusUpdate事件的处理代码如下:

case StatusUpdate(executorId, taskId, state, data) =>
        scheduler.statusUpdate(taskId, state, data.value)
        if (TaskState.isFinished(state)) {
          executorDataMap.get(executorId) match {
            case Some(executorInfo) =>
              executorInfo.freeCores += scheduler.CPUS_PER_TASK
              makeOffers(executorId)
            case None =>
              // Ignoring the update since we don't know about the executor.
              logWarning(s"Ignored task status update ($taskId state $state) " +
                s"from unknown executor with ID $executorId")
          }
        }
  • 调用TaskSchedulerImpl的statusUpdate方法来告知task的执行状态以触发相应的操作
  • task结束，空闲出相应的资源，将task对应的executor的cores进行更新
  • 结束的task对应的executor上有了空闲资源,为其分配task

这里我们重点看看在TaskSchedulerImpl里面根据task的状态做了什么样的操作:

def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
    var failedExecutor: Option[String] = None
    var reason: Option[ExecutorLossReason] = None
    synchronized {
      try {
        // task丢失，则标记对应的executor也丢失，并涉及到一些映射更新
        if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
          // We lost this entire executor, so remember that it's gone
          val execId = taskIdToExecutorId(tid)

          if (executorIdToTaskCount.contains(execId)) {
            reason = Some(
              SlaveLost(s"Task $tid was lost, so marking the executor as lost as well."))
            removeExecutor(execId, reason.get)
            failedExecutor = Some(execId)
          }
        }
        //获取task所在的taskSetManager
        taskIdToTaskSetManager.get(tid) match {
          case Some(taskSet) =>
            if (TaskState.isFinished(state)) {
              taskIdToTaskSetManager.remove(tid)
              taskIdToExecutorId.remove(tid).foreach { execId =>
                if (executorIdToTaskCount.contains(execId)) {
                  executorIdToTaskCount(execId) -= 1
                }
              }
            }
            // task成功的处理
            if (state == TaskState.FINISHED) {
              // 将当前task从taskSet中正在执行的task列表中移除
              taskSet.removeRunningTask(tid)
              //成功执行时,在线程池中处理任务的结果
              taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
            //处理失败的情况
            } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
              taskSet.removeRunningTask(tid)
              taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
            }
          case None =>
            logError(
              ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
                "likely the result of receiving duplicate task finished status updates)")
                .format(state, tid))
        }
      } catch {
        case e: Exception => logError("Exception in statusUpdate", e)
      }
    }
    // Update the DAGScheduler without holding a lock on this, since that can deadlock
    if (failedExecutor.isDefined) {
      assert(reason.isDefined)
      dagScheduler.executorLost(failedExecutor.get, reason.get)
      backend.reviveOffers()
    }
  }

task状态为LOST时，则标记对应的executor也丢失，并涉及一些映射的更新，意味着该executor上对应的task需要重新分配；其他一些状态暂时不做解析。主要看task状态为FINISHED时，通过taskResultGetter的enqueueSuccessfulTask方法将task结果的处理丢到了线程池中执行：

def enqueueSuccessfulTask(
      taskSetManager: TaskSetManager,
      tid: Long,
      serializedData: ByteBuffer): Unit = {
    getTaskResultExecutor.execute(new Runnable {
      override def run(): Unit = Utils.logUncaughtExceptions {
        try {
          // 从serializedData反序列化出result和结果大小
          val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
            // 可直接获取的结果
            case directResult: DirectTaskResult[_] =>
              // taskSet的总结果大小超过限制
              if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
                return
              } 
              directResult.value()
              // 直接返回结果及大小
              (directResult, serializedData.limit())
            // 可间接的获取执行结果,需借助BlockManager来获取
            case IndirectTaskResult(blockId, size) =>
              // 若大小超过了taskSetManager能抓取的最大限制，则删除远程节点上对应的block
              if (!taskSetManager.canFetchMoreResults(size)) {
                // dropped by executor if size is larger than maxResultSize
                sparkEnv.blockManager.master.removeBlock(blockId)
                return
              }
              logDebug("Fetching indirect task result for TID %s".format(tid))
              // 标记Task为需要远程抓取的Task并通知DAGScheduler              
              scheduler.handleTaskGettingResult(taskSetManager, tid)
              // 从远程的BlockManager上获取Task计算结果 
              val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
              // 抓取结果失败,结果丢失
              if (!serializedTaskResult.isDefined) {
               // 在Task执行结束获得结果后到driver远程去抓取结果之间,如果运行task的机器挂掉,
               // 或者该机器的BlockManager已经刷新掉了Task执行结果,都会导致远程抓取结果失败。
                scheduler.handleFailedTask(
                  taskSetManager, tid, TaskState.FINISHED, TaskResultLost)
                return
              }
              // 抓取结果成功,反序列化结果
              val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
                serializedTaskResult.get.toByteBuffer)
                // 删除远程BlockManager对应的结果
               sparkEnv.blockManager.master.removeBlock(blockId)
              // 返回结果
              (deserializedResult, size)
          }
          ...
        // 通知scheduler处理成功Task
        scheduler.handleSuccessfulTask(taskSetManager, tid, result)
        } catch { 
          ...
        }
      }
    })
  }
  • 将serializedData反序列化
  • 若是可以直接获取的结果(DirectTaskResult),在当前taskSet已完成task的结果总大小还未超过限制(spark.driver.maxResultSize,默认1G)时可以直接返回其反序列化后的结果。
  • 若是可间接获取的结果(IndirectTaskResult),在大小满足条件的前提下,标记Task为需要远程抓取的Task并通知DAGScheduler,从远程的BlockManager上获取Task计算结果,若获取失败则通知scheduler进行失败处理,失败原因有两种:
    • 在Task执行结束获得结果后到driver远程去抓取结果之间,如果运行task的机器挂掉
    • 该机器的BlockManager已经刷新掉了Task执行结果
  • 远程获取结果成功后，删除远程BlockManager上对应的block，并返回其反序列化后的结果
  • 最后将该task对应的TaskSetManager、tid和结果作为参数，通知scheduler处理成功的task
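
上面提到的结果总大小上限 spark.driver.maxResultSize 可按需调整，配置示意：

val conf = new org.apache.spark.SparkConf()
    .set("spark.driver.maxResultSize", "2g")   // 默认1g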

继续跟进scheduler是如何处理成功的task:

def handleSuccessfulTask(
      taskSetManager: TaskSetManager,
      tid: Long,
      taskResult: DirectTaskResult[_]): Unit = synchronized {
    taskSetManager.handleSuccessfulTask(tid, taskResult)
  }

里面调用了该taskSetManager对成功task的处理方法:

def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = {
    val info = taskInfos(tid)
    val index = info.index
    info.markSuccessful()
    // 从正在运行的task列表中移除该task
    removeRunningTask(tid)
    // 通知dagScheduler
    sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info)
    // 标记该task成功处理
    if (!successful(index)) {
      tasksSuccessful += 1
      logInfo("Finished task %s in stage %s (TID %d) in %d ms on %s (%d/%d)".format(
        info.id, taskSet.id, info.taskId, info.duration, info.host, tasksSuccessful, numTasks))
      // Mark successful and stop if all the tasks have succeeded.
      successful(index) = true
      if (tasksSuccessful == numTasks) {
        isZombie = true
      }
    } else {
      logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id +
        " because task " + index + " has already completed successfully")
    }
    // 从失败过的task->executor中移除
    failedExecutors.remove(index)
    // 若该taskSet所有task都成功执行
    maybeFinishTaskSet()
  }

逻辑很简单：标记task成功运行、更新failedExecutors、以及taskSet所有task都成功执行时的一些处理。我们具体看看是怎么通知dagScheduler的，这里调用了dagScheduler的taskEnded方法：

def taskEnded(
      task: Task[_],
      reason: TaskEndReason,
      result: Any,
      accumUpdates: Seq[AccumulatorV2[_, _]],
      taskInfo: TaskInfo): Unit = {
    eventProcessLoop.post(
      CompletionEvent(task, reason, result, accumUpdates, taskInfo))
  }

这里向DAGScheduler Post了一个CompletionEvent事件，在DAGScheduler#doOnReceive有对应的处理：

// DAGScheduler#doOnReceive
 case completion: CompletionEvent =>
      dagScheduler.handleTaskCompletion(completion)

继续看看 dagScheduler#handleTaskCompletion的实现,代码太长,列出主要逻辑部分:

 private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
    ...
    val stage = stageIdToStage(task.stageId)
    event.reason match {
      case Success =>
        // 从该stage中等待处理的partition列表中移除Task对应的partition 
        stage.pendingPartitions -= task.partitionId
        task match {
          case rt: ResultTask[_, _] =>
            // Cast to ResultStage here because it's part of the ResultTask
            // TODO Refactor this out to a function that accepts a ResultStage
            val resultStage = stage.asInstanceOf[ResultStage]
            resultStage.activeJob match {
              case Some(job) =>
                if (!job.finished(rt.outputId)) {
                  updateAccumulators(event)
                  job.finished(rt.outputId) = true
                  job.numFinished += 1
                  // If the whole job has finished, remove it
                  if (job.numFinished == job.numPartitions) {
                    markStageAsFinished(resultStage)
                    cleanupStateForJobAndIndependentStages(job)
                    listenerBus.post(
                      SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded))
                  }

                  // taskSucceeded runs some user code that might throw an exception. Make sure
                  // we are resilient against that.
                  try {
                    job.listener.taskSucceeded(rt.outputId, event.result)
                  } catch {
                    case e: Exception =>
                      // TODO: Perhaps we want to mark the resultStage as failed?
                      job.listener.jobFailed(new SparkDriverExecutionException(e))
                  }
                }
              case None =>
                logInfo("Ignoring result from " + rt + " because its job has finished")
            }
          // 若是ShuffleMapTask
          case smt: ShuffleMapTask =>
            val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
            updateAccumulators(event)
            val status = event.result.asInstanceOf[MapStatus]
            val execId = status.location.executorId
            logDebug("ShuffleMapTask finished on " + execId)
            // 忽略在集群中游走的ShuffleMapTask(来自一个失效的节点的Task结果)。
            if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
              logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
            } else {
              // 将结果保存到对应的Stage
              shuffleStage.addOutputLoc(smt.partitionId, status)
            }
            // 若当前stage的所有task已经全部执行完毕
            if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) {
              markStageAsFinished(shuffleStage)
              logInfo("looking for newly runnable stages")
              logInfo("running: " + runningStages)
              logInfo("waiting: " + waitingStages)
              logInfo("failed: " + failedStages)

              // 将stage的结果注册到MapOutputTrackerMaster
              mapOutputTracker.registerMapOutputs(
                shuffleStage.shuffleDep.shuffleId,
                shuffleStage.outputLocInMapOutputTrackerFormat(),
                changeEpoch = true)
              // 清除本地缓存
              clearCacheLocs()
              // 若stage一些task执行失败没有结果,重新提交stage来调度执行未执行的task
              if (!shuffleStage.isAvailable) {
                // Some tasks had failed; let's resubmit this shuffleStage
                // TODO: Lower-level scheduler should also deal with this
                logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
                  ") because some of its tasks had failed: " +
                  shuffleStage.findMissingPartitions().mkString(", "))
                submitStage(shuffleStage)
              } else {
                // 标记所有等待这个Stage结束的Map-Stage Job为结束状态 
                if (shuffleStage.mapStageJobs.nonEmpty) {
                  val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep)
                  for (job <- shuffleStage.mapStageJobs) {
                    markMapStageJobAsFinished(job, stats)
                  }
                }
              }

              // Note: newly runnable stages will be submitted below when we submit waiting stages
            }
        }
        ...
    }
    submitWaitingStages()
  }

当task为ShuffleMapTask时，在该task不是来自失效节点的前提下将结果保存到stage中；若当前stage的所有task都运行完毕（不一定成功），则将所有结果注册到MapOutputTrackerMaster（以便下一个stage的task可以通过它来获取shuffle结果的元数据信息），然后清空本地缓存；当该stage有task没有成功执行也就没有结果时，需要重新提交该stage运行未完成的task；若所有task都成功完成，说明该stage已经完成，则会去标记所有等待这个Stage结束的Map-Stage Job为结束状态。

当task为ResultTask时，增加job完成的task数，若所有task全部完成即job已经完成，则标记该stage完成并从runningStages中移除；在cleanupStateForJobAndIndependentStages方法中，遍历当前job的所有stage，若对应stage不再被其他job依赖则直接将此stage移除，然后将当前job从activeJob中移除。

最后调用job.listener.taskSucceeded(rt.outputId, event.result),实际调用的是JobWaiter(JobListener的具体实现)的taskSucceeded方法:

override def taskSucceeded(index: Int, result: Any): Unit = {
    // resultHandler call must be synchronized in case resultHandler itself is not thread safe.
    synchronized {
      resultHandler(index, result.asInstanceOf[T])
    }
    if (finishedTasks.incrementAndGet() == totalTasks) {
      jobPromise.success(())
    }
  }

这里的resultHandler就是在action操作触发runJob的时候规定的一种结果处理器:

def runJob[T, U: ClassTag](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int]): Array[U] = {
    val results = new Array[U](partitions.size)
    runJob[T, U](rdd, func, partitions, (index, res) => results(index) = res)
    results
  }

这里的(index, res) => results(index) = res 就是resultHandler,也就是将这里的results数组填满再返回,根据不同的action进行不同操作。
若完成的task数和totalTasks数相等,则该job成功执行,打印日志完成。

Spark 实现MySQL update操作

背景

目前 spark 对 MySQL 的操作只有 Append,Overwrite,ErrorIfExists,Ignore几种表级别的模式,有时我们需要对表进行行级别的操作,比如update。即我们需要构造这样的语句出来:insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name =? ,age=?;

需求:我们的目的是既不影响以前写的代码,又不引入新的API,只需新加一个配置如:savemode=update这样的形式来实现。

实践

要满足以上需求,肯定是要改源码的,首先创建自己的saveMode,只是新加了一个Update而已:

public enum I4SaveMode {
    Append,
    Overwrite,
    ErrorIfExists,
    Ignore,
    Update
}

JDBC数据源的相关实现主要在JdbcRelationProvider里,我们需要关注的是createRelation方法,我们可以在此方法里,把SaveMode改成我们自己的mode,并把mode带到saveTable方法里,所以改造后的方法如下(改了的地方都有注释):

   override def createRelation(
                                   sqlContext: SQLContext,
                                   mode: SaveMode,
                                   parameters: Map[String, String],
                                   df: DataFrame): BaseRelation = {
        val options = new JDBCOptions(parameters)
        val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis
        // 替换成自己的saveMode
        var saveMode = mode match {
                case SaveMode.Overwrite => I4SaveMode.Overwrite
                case SaveMode.Append => I4SaveMode.Append
                case SaveMode.ErrorIfExists => I4SaveMode.ErrorIfExists
                case SaveMode.Ignore => I4SaveMode.Ignore
            }
        //重点在这里,检查是否有saveMode=update的参数,并设为对应的模式
        val parameterLower = parameters.map(kv => (kv._1.toLowerCase,kv._2))
        if(parameterLower.keySet.contains("savemode")){
            saveMode = if(parameterLower.get("savemode").get.equals("update")) I4SaveMode.Update else saveMode
        }
        val conn = JdbcUtils.createConnectionFactory(options)()
        try {
            val tableExists = JdbcUtils.tableExists(conn, options)
            if (tableExists) {
                saveMode match {
                    case I4SaveMode.Overwrite =>
                        if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
                            // In this case, we should truncate table and then load.
                            truncateTable(conn, options.table)
                            val tableSchema = JdbcUtils.getSchemaOption(conn, options)
                            saveTable(df, tableSchema, isCaseSensitive, options, saveMode)
                        } else {
                        ......
    }

接下来就是saveTable方法:

def saveTable(
      df: DataFrame,
      tableSchema: Option[StructType],
      isCaseSensitive: Boolean,
      options: JDBCOptions,
      mode: I4SaveMode): Unit = { 
    ......
    val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
    .....
    repartitionedDF.foreachPartition(iterator => savePartition(
      getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
    )
  }

这里通过getInsertStatement方法构造sql语句，接着遍历每个分区进行对应的save操作，我们先看构造语句是怎么改的（改了的地方都有注释）：

def getInsertStatement(
      table: String,
      rddSchema: StructType,
      tableSchema: Option[StructType],
      isCaseSensitive: Boolean,
      dialect: JdbcDialect,
      mode: I4SaveMode): String = {
    val columns = if (tableSchema.isEmpty) {
      rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
    } else {
      val columnNameEquality = if (isCaseSensitive) {
        org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
      } else {
        org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
      } 
      val tableColumnNames = tableSchema.get.fieldNames
      rddSchema.fields.map { col =>
        val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
          throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
        }
        dialect.quoteIdentifier(normalizedName)
      }.mkString(",")
    } 
    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
    // s"INSERT INTO $table ($columns) VALUES ($placeholders)"
   //若为update模式需要单独构造
    mode match {
            case I4SaveMode.Update ⇒
                val duplicateSetting = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).map(name ⇒ s"$name=?").mkString(",")
                s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"
            case _ ⇒ s"INSERT INTO $table ($columns) VALUES ($placeholders)"
        }
  }

只需判断是否是update模式来构造对应的 sql语句,接着主要是看 savePartition 方法,看看具体是怎么保存的:

 def savePartition(
      getConnection: () => Connection,
      table: String,
      iterator: Iterator[Row],
      rddSchema: StructType,
      insertStmt: String,
      batchSize: Int,
      dialect: JdbcDialect,
      isolationLevel: Int): Iterator[Byte] = {
    val conn = getConnection()
    var committed = false

    var finalIsolationLevel = Connection.TRANSACTION_NONE
    if (isolationLevel != Connection.TRANSACTION_NONE) {
      try {
        val metadata = conn.getMetaData
        if (metadata.supportsTransactions()) {
          // Update to at least use the default isolation, if any transaction level
          // has been chosen and transactions are supported
          val defaultIsolation = metadata.getDefaultTransactionIsolation
          finalIsolationLevel = defaultIsolation
          if (metadata.supportsTransactionIsolationLevel(isolationLevel))  {
            // Finally update to actually requested level if possible
            finalIsolationLevel = isolationLevel
          } else {
            logWarning(s"Requested isolation level $isolationLevel is not supported; " +
                s"falling back to default isolation level $defaultIsolation")
          }
        } else {
          logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
        }
      } catch {
        case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
      }
    }
    val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

    try {
      if (supportsTransactions) {
        conn.setAutoCommit(false) // Everything in the same db transaction.
        conn.setTransactionIsolation(finalIsolationLevel)
      }
      val stmt = conn.prepareStatement(insertStmt)
      val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
      val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
      val numFields = rddSchema.fields.length

      try {
        var rowCount = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i)
            }
            i = i + 1
          }
          stmt.addBatch()
          rowCount += 1
          if (rowCount % batchSize == 0) {
            stmt.executeBatch()
            rowCount = 0
          }
        }
        if (rowCount > 0) {
          stmt.executeBatch()
        }
      } finally {
        stmt.close()
      }
      if (supportsTransactions) {
        conn.commit()
      }
      committed = true
      Iterator.empty
    } catch {
      case e: SQLException =>
        val cause = e.getNextException
        if (cause != null && e.getCause != cause) {
          // If there is no cause already, set 'next exception' as cause. If cause is null,
          // it *may* be because no cause was set yet
          if (e.getCause == null) {
            try {
              e.initCause(cause)
            } catch {
              // Or it may be null because the cause *was* explicitly initialized, to *null*,
              // in which case this fails. There is no other way to detect it.
              // addSuppressed in this case as well.
              case _: IllegalStateException => e.addSuppressed(cause)
            }
          } else {
            e.addSuppressed(cause)
          }
        }
        throw e
    } finally {
      if (!committed) {
        // The stage must fail.  We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        if (supportsTransactions) {
          conn.rollback()
        }
        conn.close()
      } else {
        // The stage must succeed.  We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
        }
      }
    }
  }

大体思路就是在迭代该分区数据进行插入之前，就先根据数据的schema设置好插入模板setters，迭代的时候只需将此模板应用到每一行数据上就行了，避免了每一行都需要去判断数据类型。
在非update的情况下:insert into tb (id,name,age) values (?,?,?)
在update情况下:insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name =? ,age=?;
即占位符多了一倍,在update模式下进行写入的时候需要向PreparedStatement多喂一遍数据。原本的makeSetter方法如下:

private def makeSetter(
      conn: Connection,
      dialect: JdbcDialect,
      dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))
    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))
    ...
  }

我们只需要再加一个相对位置参数offset来控制,即改造成:

private def makeSetter(
       conn: Connection,
       dialect: JdbcDialect,
       dataType: DataType): JDBCValueSetter = dataType match {
     case IntegerType ⇒
        (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
             stmt.setInt(pos + 1, row.getInt(pos - offset))
     case LongType ⇒
        (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
             stmt.setLong(pos + 1, row.getLong(pos - offset))
    ...

在非update模式下offset就为0；在update模式下，位置未超过原字段个数（即代码中的midField）时offset为0，超过时offset为midField。改造后的savePartition方法为：

def savePartition(
	             getConnection: () => Connection,
	             table: String,
	             iterator: Iterator[Row],
	             rddSchema: StructType,
	             insertStmt: String,
	             batchSize: Int,
	             dialect: JdbcDialect,
	             isolationLevel: Int,
	             mode: I4SaveMode): Iterator[Byte] = {
	...
	//判断是否为update
	val isUpdateMode = mode == I4SaveMode.Update
	val stmt = conn.prepareStatement(insertStmt)
	val setters: Array[JDBCValueSetter] = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
	val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
	val length = rddSchema.fields.length
	// update模式下占位符是2倍
	val numFields = if (isUpdateMode) length * 2 else length
	val midField = numFields / 2
	try {
	    var rowCount = 0
	    while (iterator.hasNext) {
	        val row = iterator.next()
	        var i = 0
	        while (i < numFields) {
	            if (isUpdateMode) {
	                // update模式下未超过字段长度,offset为0
	                i < midField match {
	                    case true ⇒
	                        if (row.isNullAt(i)) {
	                            stmt.setNull(i + 1, nullTypes(i))
	                        } else {
	                            setters(i).apply(stmt, row, i, 0)
	                        }
	                    // update模式下超过字段长度,offset为midField,即字段长度
	                    case false ⇒
	                        if (row.isNullAt(i - midField)) {
	                            stmt.setNull(i + 1, nullTypes(i - midField))
	                        } else {
	                            setters(i - midField).apply(stmt, row, i, midField)
	                        }
	                }
	            
	            } else {
	                if (row.isNullAt(i)) {
	                    stmt.setNull(i + 1, nullTypes(i))
	                } else {
	                    setters(i).apply(stmt, row, i, 0)
	                }
	            }
	            i = i + 1
	        }
	      ...
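
上面offset的换算关系可以用一个小例子直观验证（仅为示意）：

// 3个字段、update模式下共6个占位符，看每个占位符取行内哪个下标的值
val length = 3
val numFields = length * 2
val midField = numFields / 2
(0 until numFields).foreach { i =>
    val offset = if (i < midField) 0 else midField
    println(s"占位符${i + 1} -> row(${i - offset})")
}
// 输出：占位符1~3 -> row(0)~row(2)，占位符4~6 -> 再取一遍 row(0)~row(2)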

改造好源码后,需要重新编译打包,替换掉线上对应的jar即可。其实这里有个捷径,自己创建相同的包名,改好源码后打成jar包,把该jar里面的class文件替换掉线上jar里面对应的那些class文件就可以了。

如何使用

若需要使用到update模式:

df.write.option("saveMode","update").jdbc(...)

参考

https://blog.csdn.net/cjuexuan/article/details/52333970

SparkStreaming 动态生成 Job 并提交执行

前言

Spark Streaming 的 Job 是由 JobGenerator 每隔 batchDuration 时间动态生成的，每个batch对应提交一个JobSet，因为针对一个batch可能有多个输出操作。

概述流程:

  • 定时器定时向 eventLoop 发送生成job的请求
  • 通过receiverTracker 为当前batch分配block
  • 为当前batch生成对应的 Jobs
  • 将Jobs封装成JobSet 提交执行

入口

在 JobGenerator 初始化的时候就创建了一个定时器:

private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
    longTime => eventLoop.post(GenerateJobs(new Time(longTime))), "JobGenerator")

每隔 batchDuration 就会向 eventLoop 发送 GenerateJobs(new Time(longTime))消息,eventLoop的事件处理方法中会调用generateJobs(time)方法:

      case GenerateJobs(time) => generateJobs(time)
private def generateJobs(time: Time) {
    // Checkpoint all RDDs marked for checkpointing to ensure their lineages are
    // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847).
    ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")
    Try {
      jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch
      graph.generateJobs(time) // generate jobs using allocated block
    } match {
      case Success(jobs) =>
        val streamIdToInputInfos = jobScheduler.inputInfoTracker.getInfo(time)
        jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToInputInfos))
      case Failure(e) =>
        jobScheduler.reportError("Error generating jobs for time " + time, e)
        PythonDStream.stopStreamingContextIfPythonProcessIsDead(e)
    }
    eventLoop.post(DoCheckpoint(time, clearCheckpointDataLater = false))
  }

为当前batchTime分配Block

首先调用receiverTracker.allocateBlocksToBatch(time)方法为当前batchTime分配对应的Block，最终会调用receiverTracker的Block管理者receivedBlockTracker的allocateBlocksToBatch方法：

def allocateBlocksToBatch(batchTime: Time): Unit = synchronized {
    if (lastAllocatedBatchTime == null || batchTime > lastAllocatedBatchTime) {
      val streamIdToBlocks = streamIds.map { streamId =>
          (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true))
      }.toMap
      val allocatedBlocks = AllocatedBlocks(streamIdToBlocks)
      if (writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))) {
        timeToAllocatedBlocks.put(batchTime, allocatedBlocks)
        lastAllocatedBatchTime = batchTime
      } else {
        logInfo(s"Possibly processed batch $batchTime needs to be processed again in WAL recovery")
      }
    } else {
      logInfo(s"Possibly processed batch $batchTime needs to be processed again in WAL recovery")
    }
  }
private def getReceivedBlockQueue(streamId: Int): ReceivedBlockQueue = {
    streamIdToUnallocatedBlockQueues.getOrElseUpdate(streamId, new ReceivedBlockQueue)
  }

可以看到是从streamIdToUnallocatedBlockQueues中获取到所有streamId对应的未分配的blocks,该队列的信息是supervisor 存储好Block后向receiverTracker上报的Block信息,详情可见 ReceiverTracker 数据产生与存储

获取到所有streamId对应的未分配的blockInfos后,将其放入了timeToAllocatedBlocks:Map[Time, AllocatedBlocks]中,后面生成RDD的时候会用到。

为当前batchTime生成Jobs

调用DStreamGraph的generateJobs方法为当前batchTime生成job：

 def generateJobs(time: Time): Seq[Job] = {
    logDebug("Generating jobs for time " + time)
    val jobs = this.synchronized {
      outputStreams.flatMap { outputStream =>
        val jobOption = outputStream.generateJob(time)
        jobOption.foreach(_.setCallSite(outputStream.creationSite))
        jobOption
      }
    }
    logDebug("Generated " + jobs.length + " jobs for time " + time)
    jobs
  }

一个outputStream就对应一个job,遍历所有的outputStreams,为其生成job:

# ForEachDStream
override def generateJob(time: Time): Option[Job] = {
    parent.getOrCompute(time) match {
      case Some(rdd) =>
        val jobFunc = () => createRDDWithLocalProperties(time, displayInnerRDDOps) {
          foreachFunc(rdd, time)
        }
        Some(new Job(time, jobFunc))
      case None => None
    }
  }

先获取到time对应的RDD,然后将其作为参数再调用foreachFunc方法,foreachFunc方法是通过构造器传过来的,我们来看看print()输出的情况:

def print(num: Int): Unit = ssc.withScope {
    def foreachFunc: (RDD[T], Time) => Unit = {
      (rdd: RDD[T], time: Time) => {
        val firstNum = rdd.take(num + 1)
        // scalastyle:off println
        println("-------------------------------------------")
        println(s"Time: $time")
        println("-------------------------------------------")
        firstNum.take(num).foreach(println)
        if (firstNum.length > num) println("...")
        println()
        // scalastyle:on println
      }
    }
    foreachRDD(context.sparkContext.clean(foreachFunc), displayInnerRDDOps = false)
  }

这里构造的foreachFunc方法就是最终和rdd一起提交job的执行方法，也就是对rdd调用take()后打印，真正触发action操作的正是这个函数。再来看看是怎么拿到rdd的：每个DStream都有一个generatedRDDs: Map[Time, RDD[T]]变量，用来保存time对应的RDD，若获取不到则会通过compute()方法来计算。对于需要在executor上启动Receiver来接收数据的ReceiverInputDStream来说：

 override def compute(validTime: Time): Option[RDD[T]] = {
    val blockRDD = {

      if (validTime < graph.startTime) {
        // If this is called for any time before the start time of the context,
        // then this returns an empty RDD. This may happen when recovering from a
        // driver failure without any write ahead log to recover pre-failure data.
        new BlockRDD[T](ssc.sc, Array.empty)
      } else {
        // Otherwise, ask the tracker for all the blocks that have been allocated to this stream
        // for this batch
        val receiverTracker = ssc.scheduler.receiverTracker
        val blockInfos = receiverTracker.getBlocksOfBatch(validTime).getOrElse(id, Seq.empty)

        // Register the input blocks information into InputInfoTracker
        val inputInfo = StreamInputInfo(id, blockInfos.flatMap(_.numRecords).sum)
        ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)

        // Create the BlockRDD
        createBlockRDD(validTime, blockInfos)
      }
    }
    Some(blockRDD)
  }

会通过receiverTracker来获取该batch对应的blocks,前面已经分析过为所有streamId分配了对应的未分配的block,并且放在了timeToAllocatedBlocks:Map[Time, AllocatedBlocks]中,这里底层就是从这个timeToAllocatedBlocks获取到的blocksInfo,然后调用了createBlockRDD(validTime, blockInfos)通过blockId创建了RDD。

最后，通过此RDD和foreachFunc构建jobFunc，并创建Job返回。

封装jobs成JobSet并提交执行

每个outputStream对应一个Job,最终就会生成一个jobs,为这个jobs创建JobSet,并通过jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToInputInfos))来提交这个JobSet:

jobSet.jobs.foreach(job => jobExecutor.execute(new JobHandler(job)))

然后通过jobExecutor来执行,jobExecutor是一个线程池,并行度默认为1,可通过spark.streaming.concurrentJobs配置,即同时可执行几个批次的数据。

处理类JobHandler中调用的是Job.run(),执行的是前面构建的 jobFunc 方法。
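例如，若希望同一时间并行执行多个批次的job，可以在创建SparkConf时调大该配置（以下仅为示意，应用名与数值都是示例）：

import org.apache.spark.SparkConf

// 示意：允许同时执行2个批次的job（默认是1）
val conf = new SparkConf()
  .setAppName("StreamingConcurrentJobsDemo")
  .set("spark.streaming.concurrentJobs", "2")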

spark任务之Task失败监控

需求

spark应用程序中,只要task失败就发送邮件,并携带错误原因。

背景

在spark程序中，task有失败重试机制（根据 spark.task.maxFailures 配置，默认是4次），当task执行失败时，并不会直接导致整个应用程序down掉，只有在重试了 spark.task.maxFailures 次后仍然失败的情况下才会使程序down掉。另外，spark on yarn模式还会受yarn的重试机制影响而重启这个spark程序，根据 yarn.resourcemanager.am.max-attempts 配置（默认是2次）。

即使task失败过几次，经过重试（或者应用被yarn重启）后最终执行成功了，一切看起来都好像没有发生过，我们只能通过spark的监控UI去看是否有失败的task，若有还得去查是哪个task由于什么原因失败的。基于以上原因，我们需要做一个task失败的监控，只要失败就带上错误原因通知我们，以便及时发现问题，使我们的程序更加健壮。

捕获Task失败事件

顺藤摸瓜,task在Executor中执行,跟踪源码看task在失败后都干了啥?

  1. 在executor中task执行完不管成功与否都会向execBackend报告task的状态;
 execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
  2. 在CoarseGrainedExecutorBackend中会向driver发送StatusUpdate状态变更信息；
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
    val msg = StatusUpdate(executorId, taskId, state, data)
    driver match {
      case Some(driverRef) => driverRef.send(msg)
      case None => logWarning(s"Drop $msg because has not yet connected to driver")
    }
  }
  3. CoarseGrainedSchedulerBackend收到消息后又调用了scheduler的方法；
override def receive: PartialFunction[Any, Unit] = {
      case StatusUpdate(executorId, taskId, state, data) =>
        scheduler.statusUpdate(taskId, state, data.value)
        ......
  4. 由于代码繁琐，这里只列出嵌套调用关系中关键的几行代码，最后向eventProcessLoop发送了CompletionEvent事件；
taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
taskSetManager.handleFailedTask(tid, taskState, reason)
sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info)
eventProcessLoop.post(CompletionEvent(task, reason, result, accumUpdates, taskInfo)) 
  5. DAGSchedulerEventProcessLoop的处理方法 handleTaskCompletion(event: CompletionEvent)中有着最为关键的一行代码，这里listenerBus把task的状态发了出去，凡是监听了SparkListenerTaskEnd的listener都可以获取到对应的消息，而且带上了失败的原因(event.reason)。其实第一遍走源码并没有注意到前面提到的sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info)方法，后面根据SparkUI的page页面往回追溯才发现。
 listenerBus.post(SparkListenerTaskEnd(
       stageId, task.stageAttemptId, taskType, event.reason, event.taskInfo, taskMetrics))

自定义监听器

需要获取到SparkListenerTaskEnd事件,得继承SparkListener类并重写onTaskEnd方法,
在方法中获取task失败的reason,发送邮件给对应的负责人。这样我们就可以第一时间知道哪个task是以什么原因失败了。

import cn.i4.utils.MailUtil
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

class I4SparkAppListener(conf: SparkConf) extends SparkListener with Logging {

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
    val info = taskEnd.taskInfo
    // If stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task
    // completion event is for. Let's just drop it here. This means we might have some speculation
    // tasks on the web ui that's never marked as complete.
    if (info != null && taskEnd.stageAttemptId != -1) {
      val errorMessage: Option[String] =
        taskEnd.reason match {
          case kill: TaskKilled =>
            Some(kill.toErrorString)
          case e: ExceptionFailure =>
            Some(e.toErrorString)
          case e: TaskFailedReason =>
            Some(e.toErrorString)
          case _ => None
        }
      if (errorMessage.nonEmpty) {
        if (conf.getBoolean("enableSendEmailOnTaskFail", false)) {
          val args = Array("********@qq.com", "spark任务监控", errorMessage.get)
          try {
            MailUtil.sendMail(args)
          } catch {
            case e: Exception =>
          }
        }
      }
    }
  }
}

注意这里还需要在我们的spark程序中注册好这个listener:

.config("enableSendEmailOnTaskFail", "true")
.config("spark.extraListeners", "cn.i4.monitor.streaming.I4SparkAppListener")

总结

这里只是实现了一个小demo，可以做得更完善使之更通用，比如加上应用程序的名字、host、stageId、taskId等，单独打成jar包放到classPath，并把该listener的注册放到默认配置文件中永久生效，只需通过enableSendEmailOnTaskFail控制是否启用。

RDD缓存源码解析

spark的缓存机制保证了需要访问重复数据的应用(如迭代型算法和交互式应用)可以运行的更快。

完整的存储级别介绍如下所示:

| Storage Level | Meaning |
| --- | --- |
| MEMORY_ONLY | 将RDD作为非序列化的Java对象存储在JVM中。如果RDD不能被内存完全装下，部分分区将不会被缓存，并在需要的时候被重新计算。这是系统默认的存储级别。 |
| MEMORY_AND_DISK | 将RDD作为非序列化的Java对象存储在JVM中。如果RDD不能被内存完全装下，超出的分区将被保存在磁盘上，并在需要时读取。 |
| MEMORY_ONLY_SER | 将RDD作为序列化的Java对象存储（每个分区一个byte数组）。这种方式比非序列化方式更节省空间，特别是配合快速的序列化工具使用时，但读取时会更耗费CPU。 |
| MEMORY_AND_DISK_SER | 和MEMORY_ONLY_SER类似，但对于内存装不下的分区，不是在每次需要时重新计算，而是将其存储到磁盘中。 |
| DISK_ONLY | 仅将RDD分区存储到磁盘中。 |
| MEMORY_ONLY_2, MEMORY_AND_DISK_2, etc. | 和上面的存储级别类似，但会把每个分区复制到集群的两个节点上。 |

如何使用

我们可以利用不同的存储级别存储每一个被持久化的RDD。可以存储在内存中,也可以序列化后存储在磁盘上等方式。Spark也会自动持久化一些shuffle操作(如reduceByKey)中的中间数据,即使用户没有调用persist方法。这样的好处是避免了在shuffle出错情况下,需要重复计算整个输入。
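一个简单的使用示意如下（假设sc为已有的SparkContext，路径仅为示例）：

import org.apache.spark.storage.StorageLevel

val logs = sc.textFile("/path/to/logs")                // 路径仅为示例
val errors = logs.filter(_.contains("ERROR"))
                 .persist(StorageLevel.MEMORY_AND_DISK_SER)
errors.count()   // 第一次action触发计算，并真正完成缓存
errors.take(10)  // 之后的action可直接命中缓存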

系统将要计算 RDD partition 的时候就去判断 partition 要不要被 cache。如果要被 cache 的话,先将 partition 计算出来,然后 cache 到内存。

我们也可以通过persist()或者cache()方法持久化一个rdd,但只有当action操作时才会触发cache的真正执行,下面看看两者的区别:

def cache(): this.type = persist()

def persist(): this.type = persist(StorageLevel.MEMORY_ONLY)

def persist(newLevel: StorageLevel): this.type = {
    if (isLocallyCheckpointed) {  //该RDD之前被checkpoint过,说明RDD已经被缓存过。
       //我们只需要直接覆盖原来的存储级别即可
      persist(LocalRDDCheckpointData.transformStorageLevel(newLevel), allowOverride = true)
    } else {
      persist(newLevel, allowOverride = false)
    }
  }

private def persist(newLevel: StorageLevel, allowOverride: Boolean): this.type = {
    // 原来的存储级别不为NONE;新存储级别!=原来的存储界别;不允许覆盖
    if (storageLevel != StorageLevel.NONE && newLevel != storageLevel && !allowOverride) {
      throw new UnsupportedOperationException(  
        "Cannot change storage level of an RDD after it was already assigned a level")
    }
    if (storageLevel == StorageLevel.NONE) {  // 第一次调用persist
      sc.cleaner.foreach(_.registerRDDForCleanup(this))  // 通过sc来清理注册
      sc.persistRDD(this) //缓存RDD
    }
    storageLevel = newLevel //更新存储级别
    this
  }

可以直观的看到cache直接调用了无参的persist(),该方法即默认使用了StorageLevel.MEMORY_ONLY级别的存储,另外两个重载的方法细节可看代码中的注释。

什么时候会用到缓存的RDD

当真正需要计算某个分区的数据时,将会触发RDD的iterator方法执行,该方法会返回一个迭代器,迭代器可遍历分区所有数据。

final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      getOrCompute(split, context)
    } else {
      computeOrReadCheckpoint(split, context)
    }
  }

执行的第一步便是检查当前RDD的存储级别,若不为NONE则之前肯定对RDD执行过persist操作,继续跟进getOrCompute方法

private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
    val blockId = RDDBlockId(id, partition.index)
    var readCachedBlock = true
    // This method is called on executors, so we need call SparkEnv.get instead of sc.env.
    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
      readCachedBlock = false
      computeOrReadCheckpoint(partition, context)
    }) match {
       ...
    }
  }

通过rdd id和partition index唯一标识一个block，由blockManager的getOrElseUpdate方法获取对应的block，若未获取到则执行computeOrReadCheckpoint来计算，未获取到的原因可能是数据丢失，或者该rdd虽然被persist了但这是第一次计算。跟进方法getOrElseUpdate：

 def getOrElseUpdate[T](
      blockId: BlockId,
      level: StorageLevel,
      classTag: ClassTag[T],
      makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
    // Attempt to read the block from local or remote storage. If it's present, then we don't need
    // to go through the local-get-or-put path.
    get[T](blockId)(classTag) match {
      case Some(block) =>
        return Left(block)
      case _ =>
        // Need to compute the block.
    }
    // Initially we hold no locks on this block.
    doPutIterator(blockId, makeIterator, level, classTag, keepReadLock = true) match {
          ...
    }
  }

getOrElseUpdate方法中会尝试从本地或者远程存储介质中获取数据，若未获取到则会通过computeOrReadCheckpoint来获取数据，该方法也会在存储级别为NONE时被调用，跟进方法computeOrReadCheckpoint：

private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
  {
    if (isCheckpointedAndMaterialized) {
      firstParent[T].iterator(split, context)
    } else {
      compute(split, context)
    }
  }

若当前RDD被checkpoint过,则直接调用其父RDD checkpointRDD的iterator方法来获取数据,最后实在取不到数据就只有通过RDD的compute计算出来了。

获取 cached partitions 的存储位置

partition 被 cache 后所在节点上的 blockManager 会通知 driver 上的 blockMangerMasterActor 说某 rdd 的 partition 已经被我 cache 了,这个信息会存储在 blockMangerMasterActor 的 blockLocations: HashMap中。等到 task 执行需要 cached rdd 的时候,会调用 blockManagerMaster 的 getLocations(blockId) 去询问某 partition 的存储位置,这个询问信息会发到 driver 那里,driver 查询 blockLocations 获得位置信息并将信息送回。

读取其他节点上的 cached partition:task 得到 cached partition 的位置信息后,将 GetBlock(blockId) 的请求通过 connectionManager 发送到目标节点。目标节点收到请求后从本地 blockManager 那里的 memoryStore 读取 cached partition,最后发送回来。

[Spark SQL] 源码解析之Analyzer

前言

由前面博客我们知道了SparkSql整个解析流程如下:

  • sqlText 经过 SqlParser 解析成 Unresolved LogicalPlan;
  • analyzer 模块结合catalog进行绑定,生成 resolved LogicalPlan;
  • optimizer 模块对 resolved LogicalPlan 进行优化,生成 optimized LogicalPlan;
  • SparkPlan 将 LogicalPlan 转换成PhysicalPlan;
  • prepareForExecution()将 PhysicalPlan 转换成可执行物理计划;
  • 使用 execute()执行可执行物理计划;

详解analyzer模块

Analyzer模块将Unresolved LogicalPlan结合元数据catalog进行绑定,最终转化为Resolved LogicalPlan。跟着代码看流程:

// 代码1
spark.sql("select * from table").show(false)
---
// 代码2
def sql(sqlText: String): DataFrame = {
    Dataset.ofRows(self, sessionState.sqlParser.parsePlan(sqlText))
  }
---
// 代码3
def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
    val qe = sparkSession.sessionState.executePlan(logicalPlan)
    qe.assertAnalyzed()
    new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))
  }

代码2中的后半段sessionState.sqlParser.parsePlan(sqlText)在上篇博客已经解析,即将sqlText通过第三方解析器antlr解析成语法树。

接着进入代码3,通过Unresolved LogicalPlan创建QueryExecution对象, 这是一个非常关键的类,analyzer 、optimizer 、SparkPlan、executedPlan等都是在该类中触发的。继续跟着代码3走:

// 代码4
def assertAnalyzed(): Unit = {
    // Analyzer is invoked outside the try block to avoid calling it again from within the
    // catch block below.
    analyzed
   ...
// 代码5
lazy val analyzed: LogicalPlan = {
    SparkSession.setActiveSession(sparkSession)
    sparkSession.sessionState.analyzer.execute(logical)
  }

最终调用analyzer的execute方法，该方法定义在Analyzer的父类RuleExecutor中；Analyzer另外还混入了CheckAnalysis，用于对 plan 做检查，如果检查失败则抛出用户层面的错误：

class Analyzer(
    catalog: SessionCatalog,
    conf: SQLConf,
    maxIterations: Int)
  extends RuleExecutor[LogicalPlan] with CheckAnalysis {

可以看到构造器中有SessionCatalog类型的catalog,此类管理着临时表、view、函数及外部依赖元数据(如hive metastore),是analyzer进行绑定的桥梁。

继承了RuleExecutor的类（Analyzer、Optimizer）需要实现def batches: Seq[Batch]方法，execute方法中会对这些batches进行遍历执行。batches由多个Batch构成，每个Batch由多个Rule构成，Batch的定义为protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*)。Strategy是每个Batch的执行策略，即该batch的最大执行次数maxIterations，Once和FixedPoint分别表示执行一次和执行多次（默认最多100次）。停止执行batch的条件有两个：一是在执行maxIterations次之前，应用规则前后plan没有变化；二是执行次数达到maxIterations。batch里面的所有规则都继承了Rule，execute方法就是遍历这些batches，将所有的规则应用到LogicalPlan上。
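为了直观感受Rule的形态，下面给出一个最简单的自定义Rule示意（NoopRule是为说明而虚构的名字，不属于Spark源码，它对plan不做任何改动）：

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// 示意：一个什么都不做的Rule，仅展示 Rule[LogicalPlan] 的基本结构
object NoopRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}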

接下来我们看看execute中具体是怎么做的:

def execute(plan: TreeType): TreeType = {
    var curPlan = plan
    //遍历batches
    batches.foreach { batch =>
      val batchStartPlan = curPlan
      var iteration = 1 //每个batch单独计数
      var lastPlan = curPlan //保存遍历batch之前的plan,以便和遍历后的plan进行比较,若无变化则停止执行当前batch
      var continue = true

      // Run until fix point (or the max number of iterations as specified in the strategy.
      while (continue) {
        curPlan = batch.rules.foldLeft(curPlan) { // 遍历一个batch所有的Rule,并应用到LogicalPlan上
          case (plan, rule) =>
            val startTime = System.nanoTime()
            val result = rule(plan)  // 规则应用到LogicalPlan
            val runTime = System.nanoTime() - startTime
            RuleExecutor.timeMap.addAndGet(rule.ruleName, runTime)

            if (!result.fastEquals(plan)) {
              logTrace(
                s"""
                  |=== Applying Rule ${rule.ruleName} ===
                  |${sideBySide(plan.treeString, result.treeString).mkString("\n")}
                """.stripMargin)
            }

            result
        }
        iteration += 1 //对当前batch执行次数进行计数
        if (iteration > batch.strategy.maxIterations) { // 若大于了执行策略定义的次数,则停止执行此batch
          // Only log if this is a rule that is supposed to run more than once.
          if (iteration != 2) {
            val message = s"Max iterations (${iteration - 1}) reached for batch ${batch.name}"
            if (Utils.isTesting) {
              throw new TreeNodeException(curPlan, message, null)
            } else {
              logWarning(message)
            }
          }
          continue = false
        }

        if (curPlan.fastEquals(lastPlan)) { // 若执行batch前后,plan没有变化,则停止执行此batch
          logTrace(
            s"Fixed point reached for batch ${batch.name} after ${iteration - 1} iterations.")
          continue = false
        }
        lastPlan = curPlan
      }

      if (!batchStartPlan.fastEquals(curPlan)) {
        logDebug(
          s"""
          |=== Result of Batch ${batch.name} ===
          |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")}
        """.stripMargin)
      } else {
        logTrace(s"Batch ${batch.name} has no effect.")
      }
    }

    curPlan
  }

主要执行步骤都在代码中进行了注释。
batch和里面的rules都是连续执行的,每执行完一个batch都判断此batch执行的次数是否达到maxIterations 和执行此batch前后是否有变化,达到maxIterations 或者执行batch前后无变化都不再执行此batch。

Analyzer的batches 如下:

lazy val batches: Seq[Batch] = Seq(
    Batch("Hints", fixedPoint,
      new ResolveHints.ResolveBroadcastHints(conf),
      ResolveHints.RemoveAllHints),
    Batch("Simple Sanity Check", Once,
      LookupFunctions),
    Batch("Substitution", fixedPoint,
      CTESubstitution,
      WindowsSubstitution,
      EliminateUnions,
      new SubstituteUnresolvedOrdinals(conf)),
    Batch("Resolution", fixedPoint,
      ResolveTableValuedFunctions ::
      ResolveRelations ::
      ResolveReferences ::
      ResolveCreateNamedStruct ::
      ResolveDeserializer ::
      ResolveNewInstance ::
      ResolveUpCast ::
      ResolveGroupingAnalytics ::
      ResolvePivot ::
      ResolveOrdinalInOrderByAndGroupBy ::
      ResolveAggAliasInGroupBy ::
      ResolveMissingReferences ::
      ExtractGenerator ::
      ResolveGenerate ::
      ResolveFunctions ::
      ResolveAliases ::
      ResolveSubquery ::
      ResolveWindowOrder ::
      ResolveWindowFrame ::
      ResolveNaturalAndUsingJoin ::
      ExtractWindowExpressions ::
      GlobalAggregates ::
      ResolveAggregateFunctions ::
      TimeWindowing ::
      ResolveInlineTables(conf) ::
      ResolveTimeZone(conf) ::
      TypeCoercion.typeCoercionRules ++
      extendedResolutionRules : _*),
    Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
    Batch("View", Once,
      AliasViewChild(conf)),
    Batch("Nondeterministic", Once,
      PullOutNondeterministic),
    Batch("UDF", Once,
      HandleNullInputsForUDF),
    Batch("FixNullability", Once,
      FixNullability),
    Batch("Subquery", Once,
      UpdateOuterReferences),
    Batch("Cleanup", fixedPoint,
      CleanupAliases)
  )

继续回到代码3（如下代码），这里通过analyzer模块和catalog绑定完后，由sparkSession、queryExecution和Row编码器构造了Dataset就返回了，并没有继续执行后面的其他模块，其他模块都是lazy的，只有触发了action操作的时候才会去执行。

def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
    val qe = sparkSession.sessionState.executePlan(logicalPlan)
    qe.assertAnalyzed()
    new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))
  }

接下来举例子看看Analyzer模块中的规则Rule是怎么通过catalog进行绑定的。

ResolveRelations

此规则是通过catalog替换掉UnresolvedRelation:

UnresolvedRelation(tableIdentifier: TableIdentifier)

case class TableIdentifier(table: String, database: Option[String])

即可以从中获取到database和table的名字,接下来从入口方法apply看是怎么一步一步替换掉的:

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
      case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
        EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
          case v: View =>
            u.failAnalysis(s"Inserting into a view is not allowed. View: ${v.desc.identifier}.")
          case other => i.copy(table = other)
        }
      case u: UnresolvedRelation => resolveRelation(u)
    }

首先执行的是plan的resolveOperators方法，它接收一个偏函数（PartialFunction[LogicalPlan, LogicalPlan]）作为参数，跟进看看：

def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
    if (!analyzed) {
      val afterRuleOnChildren = mapChildren(_.resolveOperators(rule))
      if (this fastEquals afterRuleOnChildren) {
        CurrentOrigin.withOrigin(origin) {
          rule.applyOrElse(this, identity[LogicalPlan])
        }
      } else {
        CurrentOrigin.withOrigin(origin) {
          rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan])
        }
      }
    } else {
      this
    }
  }

首先判断此plan是否已经被处理过，接着调用mapChildren并传入resolveOperators方法，其实就是一个递归调用，它会优先处理子节点，然后再处理自己；如果处理后的LogicalPlan和当前的相等，说明子节点没有变化（或没有子节点），则把规则应用到它自己上，反之把规则应用到返回的新plan上。

回到前面看看这个Rule是怎么应用起来的:

case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
        EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
          case v: View =>
            u.failAnalysis(s"Inserting into a view is not allowed. View: ${v.desc.identifier}.")
          case other => i.copy(table = other)
        }
      case u: UnresolvedRelation => resolveRelation(u)

先看第二种情况若为UnresolvedRelation,则调用resolveRelation方法进行解析:

def resolveRelation(plan: LogicalPlan): LogicalPlan = plan match {  
                                    //不是这种情况 select * from parquet.`/path/to/query`
      case u: UnresolvedRelation if !isRunningDirectlyOnFiles(u.tableIdentifier) => 
        val defaultDatabase = AnalysisContext.get.defaultDatabase // 获取默认database
        val relation = lookupTableFromCatalog(u, defaultDatabase)
        resolveRelation(relation)
      // The view's child should be a logical plan parsed from the `desc.viewText`, the variable
      // `viewText` should be defined, or else we throw an error on the generation of the View
      // operator.
      case view @ View(desc, _, child) if !child.resolved =>
        // Resolve all the UnresolvedRelations and Views in the child.
        val newChild = AnalysisContext.withAnalysisContext(desc.viewDefaultDatabase) {
          if (AnalysisContext.get.nestedViewDepth > conf.maxNestedViewDepth) {
            view.failAnalysis(s"The depth of view ${view.desc.identifier} exceeds the maximum " +
              s"view resolution depth (${conf.maxNestedViewDepth}). Analysis is aborted to " +
              "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " +
              "aroud this.")
          }
          execute(child)
        }
        view.copy(child = newChild)
      case p @ SubqueryAlias(_, view: View) =>
        val newChild = resolveRelation(view)
        p.copy(child = newChild)
      case _ => plan
    }

这里第一次进来肯定是先进入第一个case,然后会调用lookupTableFromCatalog方法从catalog中找关系,此方法最终调用了SessionCatalog的lookupRelation方法:

def lookupRelation(name: TableIdentifier): LogicalPlan = {
    synchronized {
      val db = formatDatabaseName(name.database.getOrElse(currentDb))
      val table = formatTableName(name.table)
      if (db == globalTempViewManager.database) {
        globalTempViewManager.get(table).map { viewDef =>
          SubqueryAlias(table, viewDef)
        }.getOrElse(throw new NoSuchTableException(db, table))
      } else if (name.database.isDefined || !tempTables.contains(table)) {
        val metadata = externalCatalog.getTable(db, table)
        if (metadata.tableType == CatalogTableType.VIEW) {
          val viewText = metadata.viewText.getOrElse(sys.error("Invalid view without text."))
          // The relation is a view, so we wrap the relation by:
          // 1. Add a [[View]] operator over the relation to keep track of the view desc;
          // 2. Wrap the logical plan in a [[SubqueryAlias]] which tracks the name of the view.
          val child = View(
            desc = metadata,
            output = metadata.schema.toAttributes,
            child = parser.parsePlan(viewText))
          SubqueryAlias(table, child)
        } else {
          val tableRelation = CatalogRelation(
            metadata,
            // we assume all the columns are nullable.
            metadata.dataSchema.asNullable.toAttributes,
            metadata.partitionSchema.asNullable.toAttributes)
          SubqueryAlias(table, tableRelation)
        }
      } else {
        SubqueryAlias(table, tempTables(table))
      }
    }
  }
  • 若db等于globalTempViewManager.database（globalTempViewManager维护了一个全局viewName到其元数据LogicalPlan的映射：val viewDefinitions = new mutable.HashMap[String, LogicalPlan]），则直接从globalTempViewManager获取并返回。

  • 若database已定义,且临时表中未有此table:
    从externalCatalog(如hive)中获取table对应的元数据信息metadata:CatalogTable，此对象包含了table对应的类型(内部表、外部表还是view)、存储格式、字段schema信息等：

    • 若返回的table是View类型则构造View对象(包括将viewText通过parser模块解析成语法树),并传入构造一个SubqueryAlias返回
    • 否则说明此table名对应的就是一个如hive的table表，通过metadata、数据和分区列的schema构造了CatalogRelation，并以此tableRelation构造SubqueryAlias返回。这里就可以看出从一个未绑定的UnresolvedRelation到通过catalog替换的过程。
  • 否则说明是个session级别的临时表，从tempTables获取到包含元数据信息的LogicalPlan并构造SubqueryAlias返回。

再次回到resolveRelation方法:

def resolveRelation(plan: LogicalPlan): LogicalPlan = plan match {
      case u: UnresolvedRelation if !isRunningDirectlyOnFiles(u.tableIdentifier) =>
        val defaultDatabase = AnalysisContext.get.defaultDatabase
        val relation = lookupTableFromCatalog(u, defaultDatabase)
        resolveRelation(relation)
      // The view's child should be a logical plan parsed from the `desc.viewText`, the variable
      // `viewText` should be defined, or else we throw an error on the generation of the View
      // operator.
      case view @ View(desc, _, child) if !child.resolved =>
        // Resolve all the UnresolvedRelations and Views in the child.
        val newChild = AnalysisContext.withAnalysisContext(desc.viewDefaultDatabase) {
          if (AnalysisContext.get.nestedViewDepth > conf.maxNestedViewDepth) {
            view.failAnalysis(s"The depth of view ${view.desc.identifier} exceeds the maximum " +
              s"view resolution depth (${conf.maxNestedViewDepth}). Analysis is aborted to " +
              "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " +
              "aroud this.")
          }
          execute(child)
        }
        view.copy(child = newChild)
      case p @ SubqueryAlias(_, view: View) =>
        val newChild = resolveRelation(view)
        p.copy(child = newChild)
      case _ => plan
    }

经过lookupTableFromCatalog方法后,又调用了resolveRelation方法本身:

  • case UnresolvedRelation：上面已经讲过了
  • case View,通过上面的解析可知这可能是外部catalog(如hive)的View,其child是viewText被parser模块解析完的Unresolved LogicalPlan,调用execute方法进行analyze。简单的说若是View,则会获取viewText重走parser和analyzer模块。
  • case SubqueryAlias(_, view: View):为view调用resolveRelation方法
  • case _ :若是其他情况,直接返回plan

总之经过resolveRelation方法之后,返回的plan是已经和实际元数据绑定好的plan,可能是从globalTempViewManager直接获取的,可能是从tempTables直接获取,也可能是从externalCatalog获取的元数据。

再回到最初的apply方法:

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
      case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
        EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
          case v: View =>
            u.failAnalysis(s"Inserting into a view is not allowed. View: ${v.desc.identifier}.")
          case other => i.copy(table = other)
        }
      case u: UnresolvedRelation => resolveRelation(u)
    }

这里第二种情况已经分析完,再看看第一种情况,若plan是InsertIntoTable类型并且其对应的table还未绑定,则调用lookupTableFromCatalog方法与catalog进行analyze之后应用到Rule EliminateSubqueryAliases:

object EliminateSubqueryAliases extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case SubqueryAlias(_, child) => child
  }
}

遍历子节点有两种方式：transformDown（默认，前序遍历）和transformUp（后序遍历）。
UnresolvedRelation解析后可能会是SubqueryAlias,真正有用的是其child(CatalogRelation),一旦解析完就将其删除掉保留child。
到这里Rule ResolveRelations就解析完了,其他就不再一一列举了。

RDD解析

RDD(Resilient Distributed Dataset):弹性分布式数据集。

特性

  • A list of partitions (可分片)
  • A function for computing each split (compute func)
  • A list of dependencies on other RDDs (依赖)
  • A Partitioner for key-value RDDs (分片器,决定一条数据属于某分片)
  • A list of preferred locations to compute each split on (e.g. block locations for an HDFS file) (位置优先)

Partition

  1. RDD能并行计算的原因就是Partition,一个RDD可有多个partition,每个partition一个task任务,每个partition代表了该RDD一部分数据,分区内部并不会存储具体的数据,访问数据时是通过partition的迭代器,iterator 可遍历到所有数据。
  2. partition的个数需要视情况而定：RDD 可以通过创建操作或者转换操作得到；转换操作中，分区的个数会根据转换操作对应的多个 RDD 之间的依赖关系确定，窄依赖的子 RDD 由父 RDD 分区个数决定，Shuffle 依赖由子 RDD 的分区器决定；从集合中创建RDD时默认个数为defaultParallelism（指定及查看分区数的示意代码见本节列表之后），当该值没有设定时：
    • 本地模式: conf.getInt("spark.default.parallelism", totalCores) // CPU cores
    • Mesos: conf.getInt("spark.default.parallelism", 8) // 8
    • Standalone&Yarn: conf.getInt("spark.default.parallelism", math.max(totalCoreCount.get(), 2))
  3. 特质Partition只有一个返回index的方法,很多具体的 RDD 也会有自己实现的 partition。
  trait Partition extends Serializable { 
    def index: Int
    override def hashCode(): Int = index
    override def equals(other: Any): Boolean = super.equals(other)
  }
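下面是指定及查看分区个数的示意代码（假设sc为已有的SparkContext）：

val rdd = sc.parallelize(1 to 100, 4)               // 显式指定4个分区
println(rdd.getNumPartitions)                       // 4
println(rdd.partitions.map(_.index).mkString(","))  // 0,1,2,3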

compute func

每个具体的RDD都得实现compute 方法,该方法接受的参数之一是一个Partition 对象,目的是计算该分区中的数据。
我们通过map方法来看具体的实现:

def map[U: ClassTag](f: T => U): RDD[U] = withScope {
    val cleanF = sc.clean(f)
    new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
  }

调用map时都会new一个MapPartitionsRDD实例,并且接收一个方法作为参数,该方法接收一个迭代器(后面会细讲),对该RDD的map操作函数f将作用于这个迭代器的每一条数据。在MapPartitionsRDD中是通过compute方法来计算对应分区的数据:

override def compute(split: Partition, context: TaskContext): Iterator[U] =
    f(context, split.index, firstParent[T].iterator(split, context))

这里将调用该RDD(MapPartitionsRDD)内的第一个父 RDD 的 iterator 方法，该方法的目的是拉取父 RDD 对应分区内的数据。iterator方法会返回一个迭代器，对应的是父RDD计算完成的数据，该迭代器将作为 f 方法的一个参数，这个 f 方法就是上面提到的创建MapPartitionsRDD实例时传入的方法。

其实RDD的compute方法也类似。接下来我们看看iterator方法究竟都做了什么事:

final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      getOrCompute(split, context)
    } else {
      computeOrReadCheckpoint(split, context)
    }
  }

RDD的iterator方法即遍历对应分区的数据，先判断该RDD的存储级别：若不为NONE，说明该RDD经过了持久化操作，其数据可能已经存在于缓存中（经历过一次计算过程），可尝试直接将数据返回。

private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
    val blockId = RDDBlockId(id, partition.index)
    var readCachedBlock = true
    // This method is called on executors, so we need call SparkEnv.get instead of sc.env.
    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
      readCachedBlock = false
      computeOrReadCheckpoint(partition, context)
    }) match {
      case Left(blockResult) =>
        if (readCachedBlock) {
          val existingMetrics = context.taskMetrics().inputMetrics
          existingMetrics.incBytesRead(blockResult.bytes)
          new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
            override def next(): T = {
              existingMetrics.incRecordsRead(1)
              delegate.next()
            }
          }
        } else {
          new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
        }
       ...
    }
  }

通过RDD id和partition index唯一标识一个block，先从缓存中取数据，但也有可能取不到数据：

  • 数据丢失
  • RDD 经过持久化操作,但是是当前分区数据是第一次被计算,因此会出现拉取得到数据为 None

取不到的时候则调用computeOrReadCheckpoint来获取并加入缓存。
若RDD的存储级别为NONE，则直接通过computeOrReadCheckpoint方法来计算。

private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
  {
    if (isCheckpointedAndMaterialized) {
      firstParent[T].iterator(split, context)
    } else {
      compute(split, context)
    }
  }

该方法会先检查当前RDD是否被checkpoint,若是则直接从依赖的checkpointRDD中获取迭代对象,若不是则需要通过compute方法计算。

dependency

RDD的容错机制就是通过dependency实现的，对外称为血统（Lineage）关系，在源码里则体现为dependency，抽象类Dependency只有一个返回对应RDD的方法。

abstract class Dependency[T] extends Serializable {
  def rdd: RDD[T]
}

每个RDD都有一个返回其所依赖的dependences:Seq[Dependency[_]] 的dependencies方法,Dependency里面存的就是父RDD,递归RDD+遍历这个dependences将可得到整个DAG。

依赖分为两种，分别是窄依赖（Narrow Dependency）和 Shuffle 依赖（Shuffle Dependency，也称宽依赖）。在窄依赖中，父RDD的一个分区至多被一个子RDD的一个分区所依赖，分区数据不可被拆分。

在宽依赖中，父RDD的一个分区被子RDD的多个分区所依赖，分区数据会被拆分。

一次转换操作可同时包含窄依赖和宽依赖。

窄依赖的抽象类为NarrowDependency,对应实现分别是 OneToOneDependency (一对一依赖)类和RangeDependency (范围依赖)类。

一对一依赖表示子 RDD 分区的编号与父 RDD 分区的编号完全一致的情况,若两个 RDD 之间存在着一对一依赖,则子 RDD 的分区个数、分区内记录的个数都将继承自父 RDD。

范围依赖是依赖关系中的一个特例,只被用于表示 UnionRDD 与父 RDD 之间的依赖关系。

宽依赖的对应实现为 ShuffleDependency 类，宽依赖支持两种 Shuffle Manager，即 HashShuffleManager 和 SortShuffleManager。
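可以通过dependencies和toDebugString直观地查看依赖与血统（示意，假设sc为已有的SparkContext）：

val words  = sc.parallelize(Seq("a", "b", "a"))
val counts = words.map((_, 1)).reduceByKey(_ + _)
println(counts.dependencies)   // reduceByKey产生ShuffleDependency（宽依赖）
println(counts.toDebugString)  // 以缩进形式打印整条lineage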

Partitioner

partitioner就是决定一条数据应该属于哪个分区的分区器，但只有 (k, v) 类型的 RDD 才能有 partitioner，因为分区都是由其 k 来决定的。
抽象类 Partitioner 定义了分区个数numPartitions，以及根据传入的key返回分区index的getPartition方法：

abstract class Partitioner extends Serializable {
  def numPartitions: Int
  def getPartition(key: Any): Int
}

Spark 内置了两种分区器,分别是哈希分区器(Hash Partitioner)和范围分区器(Range Partitioner)。

Hash Partitioner

我们来看HashPartitioner的定义，主要是getPartition方法，当key为null时直接返回0（即落到第0个分区）：

class HashPartitioner(partitions: Int) extends Partitioner {
  def numPartitions: Int = partitions
  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }
  override def equals(other: Any): Boolean = other match {
    case h: HashPartitioner =>
      h.numPartitions == numPartitions
    case _ =>
      false
  }
  override def hashCode: Int = numPartitions
}

key不为null时将调用Utils的nonNegativeMod方法，即用key的hashCode对numPartitions（即mod）取余：若余数非负，则直接返回；若余数为负，则返回余数加上mod。

 def nonNegativeMod(x: Int, mod: Int): Int = {
    val rawMod = x % mod
    rawMod + (if (rawMod < 0) mod else 0)
  }
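可以按上面的逻辑简单验证一下取余的效果（示意，方法名为演示而起）：

def demoNonNegativeMod(x: Int, mod: Int): Int = {   // 与Utils.nonNegativeMod逻辑相同
  val rawMod = x % mod
  rawMod + (if (rawMod < 0) mod else 0)
}
assert(demoNonNegativeMod(7, 4) == 3)   // 7 % 4 = 3，直接返回
assert(demoNonNegativeMod(-7, 4) == 1)  // -7 % 4 = -3，为负则加上mod，得1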

RangePartitioner

在key的hashCode分布不均的情况下，会导致通过HashPartitioner分出来的分区数据倾斜、不均匀，这时就需要用到RangePartitioner分区器。该分区器运行速度相对HashPartitioner较慢，原理也更复杂。

RangePartitioner会将一个范围内的key映射到同一个partition，也就是一个partition内的key一定都比另一个partition内的key大或者小，而怎么划分这些范围的边界是关键，既要保证分布均匀又要减少遍历次数。具体实现可参考 Spark分区器HashPartitioner和RangePartitioner代码详解
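两种分区器的使用方式类似，都是通过partitionBy传入（示意，假设sc为已有的SparkContext）：

import org.apache.spark.{HashPartitioner, RangePartitioner}

val pairs   = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3), ("d", 4)))
val byHash  = pairs.partitionBy(new HashPartitioner(2))
val byRange = pairs.partitionBy(new RangePartitioner(2, pairs))  // 需要key可排序
println(byHash.partitioner)   // Some(HashPartitioner)
println(byRange.partitioner)  // Some(RangePartitioner)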

preferred locations

每个具体的RDD实例都需要实现自己的getPreferredLocations方法,RDD位置优先即返回partition的存储位置,该位置和spark的任务调度有关,尽量将计算移到该partition对应的地方。
以从Hadoop中读取数据生成RDD为例,preferredLocations返回每一个数据块所在的机器名或者IP地址,如果每一个数据块是多份存储的(HDFS副本数),那么就会返回多个机器地址。

SparkStreaming DStream 和 DStreamGraph 解析

看 spark streaming 源码解析之前最好先了解spark core的内容。

前言

Spark Streaming 是基于Spark Core将流式计算分解成一系列的小批处理任务来执行。

在Spark Streaming里，总体负责任务动态调度的是JobScheduler，而JobScheduler有两个很重要的成员：JobGenerator 和 ReceiverTracker。JobGenerator 负责将每个 batch 生成具体的 RDD DAG，而ReceiverTracker负责数据的来源。

Spark Streaming里的DStream可以看成是Spark Core里的RDD的模板,DStreamGraph是RDD DAG的模板。

跟着例子看流程

DStream 也和 RDD 一样有着转换(transformation)和 输出(output)操作,通过 transformation 操作会产生新的DStream,典型的 transformation 操作有map(), filter(), reduce(), join()等。RDD的输出操作会触发action,而DStream的输出操作也会新建一个ForeachDStream,用一个函数func来记录所需要做的操作。

下面看一个例子:

val conf = new SparkConf().setMaster("local[2]")
                          .setAppName("NetworkWordCount")
val ssc = new StreamingContext(conf, Seconds(1))
val lines = ssc.socketTextStream("localhost", 9999)
val words = lines.flatMap(_.split(" "))      
val pairs = words.map(word => (word, 1))    
val wordCounts = pairs.reduceByKey(_ + _)   
wordCounts.print()
ssc.start()
ssc.awaitTermination()

在创建 StreamingContext 的时候其实就创建了 graph: DStreamGraph：

private[streaming] val graph: DStreamGraph = {
    if (isCheckpointPresent) {
      _cp.graph.setContext(this)
      _cp.graph.restoreCheckpointData()
      _cp.graph
    } else {
      require(_batchDur != null, "Batch duration for StreamingContext cannot be null")
      val newGraph = new DStreamGraph()
      newGraph.setBatchDuration(_batchDur)
      newGraph
    }
  }

若 checkpoint 可用，会优先从 checkpoint 恢复 graph，否则新建一个。graph用来动态地创建RDD DAG，DStreamGraph有两个重要的成员：inputStreams 和 outputStreams。

private val inputStreams = new ArrayBuffer[InputDStream[_]]()
private val outputStreams = new ArrayBuffer[DStream[_]]()

Spark Streaming记录DStream DAG 的方式就是通过DStreamGraph 实例记录所有的outputStreams ,因为outputStream 会通过依赖
dependencies 来和parent DStream形成依赖链,通过outputStreams 向前追溯遍历就可以得到所有上游的DStream,另外,DStreamGraph 还会记录所有的inputStreams ,避免每次为查找 input stream 而对 output steam 进行 BFS 的消耗。

继续回到例子,这里通过ssc.socketTextStream 创建了一个ReceiverInputDStream,在其父类 InputDStream 中会将该ReceiverInputDStream添加到inputStream里。

接着调用了flatMap方法:

def flatMap[U: ClassTag](flatMapFunc: T => TraversableOnce[U]): DStream[U] = ssc.withScope {
    new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc))
  }

--------------------------------------------------------------------

private[streaming]
class FlatMappedDStream[T: ClassTag, U: ClassTag](
    parent: DStream[T],
    flatMapFunc: T => TraversableOnce[U]
  ) extends DStream[U](parent.ssc) {

  override def dependencies: List[DStream[_]] = List(parent)

  override def slideDuration: Duration = parent.slideDuration

  override def compute(validTime: Time): Option[RDD[U]] = {
    parent.getOrCompute(validTime).map(_.flatMap(flatMapFunc))
  }
}

这里创建了一个 FlatMappedDStream，该类的compute方法是在父 DStream(ReceiverInputDStream) 对应batch时间的RDD上调用flatMap方法，也就是构造了 rdd.flatMap(func) 这样的代码；后面的操作类似，最终形成的是 rdd.flatMap(func1).map(func2).reduceByKey(func3).take()，这正是spark core里的东西。另外其dependencies直接指向了其构造参数parent，也就是刚才的ReceiverInputDStream，每个新建的DStream的dependencies都指向了其父DStream，这样就构成了一条依赖链，也就形成了DStream DAG。

这里我们再看看最后的 print() 操作:

----
def print(num: Int): Unit = ssc.withScope {
    def foreachFunc: (RDD[T], Time) => Unit = {
      (rdd: RDD[T], time: Time) => {
        val firstNum = rdd.take(num + 1)
        // scalastyle:off println
        println("-------------------------------------------")
        println(s"Time: $time")
        println("-------------------------------------------")
        firstNum.take(num).foreach(println)
        if (firstNum.length > num) println("...")
        println()
        // scalastyle:on println
      }
    }
    foreachRDD(context.sparkContext.clean(foreachFunc), displayInnerRDDOps = false)
  }
----
private def foreachRDD(
      foreachFunc: (RDD[T], Time) => Unit,
      displayInnerRDDOps: Boolean): Unit = {
    new ForEachDStream(this,
      context.sparkContext.clean(foreachFunc, false), displayInnerRDDOps).register()
  }
----
// ForEachDStream
override def generateJob(time: Time): Option[Job] = {
    parent.getOrCompute(time) match {
      case Some(rdd) =>
        val jobFunc = () => createRDDWithLocalProperties(time, displayInnerRDDOps) {
          foreachFunc(rdd, time)
        }
        Some(new Job(time, jobFunc))
      case None => None
    }
  }

在print() 方法里构建了一个foreachFunc方法:对一个rdd进行了take操作并打印(spark core中的action操作)。随后创建了ForEachDStream实例并调用了register()方法:

 private[streaming] def register(): DStream[T] = {
    ssc.graph.addOutputStream(this)
    this
  }

将 OutputStream 添加到DStreamGraph的outputStreams 里。可以看到刚才构建的 foreachFunc 方法最终用在了ForEachDStream实例的generateJob方法里，并创建了一个Streaming 中的Job，job的run方法中会调用这个方法，也就是会触发action操作。

注意这里Spark Streaming的Job和Spark Core里的Job是不一样的,Streaming的Job执行的是前面构造的方法,方法里面是Core里的Job,方法可以定义多个core里的Job,也可以一个core里的job都没有。
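除了print()，更常用的是直接用foreachRDD自定义输出逻辑，沿用上面例子中的wordCounts（示意）：

wordCounts.foreachRDD { (rdd, time) =>
  // 该函数在driver端每个batch被调用一次，rdd上的action才会触发真正的Spark Core job
  println(s"batch $time => ${rdd.count()} 条记录")
}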

[Spark SQL] 源码解析之Optimizer

前言

由前面博客我们知道了SparkSql整个解析流程如下:

  • sqlText 经过 SqlParser 解析成 Unresolved LogicalPlan;
  • analyzer 模块结合catalog进行绑定,生成 resolved LogicalPlan;
  • optimizer 模块对 resolved LogicalPlan 进行优化,生成 optimized LogicalPlan;
  • SparkPlan 将 LogicalPlan 转换成PhysicalPlan;
  • prepareForExecution()将 PhysicalPlan 转换成可执行物理计划;
  • 使用 execute()执行可执行物理计划;

详解optimizer 模块

optimizer 以及之后的模块都只会在触发了action操作后才会执行。优化器是用来将Resolved LogicalPlan转化为optimized LogicalPlan的。

optimizer 就是根据大佬们多年的SQL优化经验来对语法树进行优化,比如谓词下推、列值裁剪、常量累加等。优化的模式和Analyzer非常相近,Optimizer 同样继承了RuleExecutor,并定义了很多优化的Rule:

def batches: Seq[Batch] = {
    // Technically some of the rules in Finish Analysis are not optimizer rules and belong more
    // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime).
    // However, because we also use the analyzer to canonicalized queries (for view definition),
    // we do not eliminate subqueries or compute current time in the analyzer.
    Batch("Finish Analysis", Once,
      EliminateSubqueryAliases,
      EliminateView,
      ReplaceExpressions,
      ComputeCurrentTime,
      GetCurrentDatabase(sessionCatalog),
      RewriteDistinctAggregates,
      ReplaceDeduplicateWithAggregate) ::
    //////////////////////////////////////////////////////////////////////////////////////////
    // Optimizer rules start here
    //////////////////////////////////////////////////////////////////////////////////////////
    // - Do the first call of CombineUnions before starting the major Optimizer rules,
    //   since it can reduce the number of iteration and the other rules could add/move
    //   extra operators between two adjacent Union operators.
    // - Call CombineUnions again in Batch("Operator Optimizations"),
    //   since the other rules might make two separate Unions operators adjacent.
    Batch("Union", Once,
      CombineUnions) ::
    Batch("Pullup Correlated Expressions", Once,
      PullupCorrelatedPredicates) ::
    Batch("Subquery", Once,
      OptimizeSubqueries) ::
    Batch("Replace Operators", fixedPoint,
      ReplaceIntersectWithSemiJoin,
      ReplaceExceptWithAntiJoin,
      ReplaceDistinctWithAggregate) :: // aggregate替换distinct
    Batch("Aggregate", fixedPoint,
      RemoveLiteralFromGroupExpressions,
      RemoveRepetitionFromGroupExpressions) ::
    Batch("Operator Optimizations", fixedPoint, Seq(
      // Operator push down
      PushProjectionThroughUnion, //将Projection下推到Union的子节点（投影下推）
      ReorderJoin(conf),
      EliminateOuterJoin(conf),
      PushPredicateThroughJoin,
      PushDownPredicate,
      LimitPushDown(conf),
      ColumnPruning, //列剪裁
      InferFiltersFromConstraints(conf),
      // Operator combine
      CollapseRepartition,
      CollapseProject,
      CollapseWindow,
      CombineFilters, //合并filter
      CombineLimits, //合并limit
      CombineUnions,
      // Constant folding and strength reduction
      NullPropagation(conf), //null处理
      FoldablePropagation,
      OptimizeIn(conf), // 关键字in的优化,替代为InSet
      ConstantFolding, //针对常量的优化,在这里会直接计算可以获得的常量
      ReorderAssociativeOperator,
      LikeSimplification, //表达式简化
      BooleanSimplification,
      SimplifyConditionals,
      RemoveDispensableExpressions,
      SimplifyBinaryComparison,
      PruneFilters(conf),
      EliminateSorts,
      SimplifyCasts,
      SimplifyCaseConversionExpressions,
      RewriteCorrelatedScalarSubquery,
      EliminateSerialization,
      RemoveRedundantAliases,
      RemoveRedundantProject,
      SimplifyCreateStructOps,
      SimplifyCreateArrayOps,
      SimplifyCreateMapOps) ++
      extendedOperatorOptimizationRules: _*) ::
    Batch("Check Cartesian Products", Once,
      CheckCartesianProducts(conf)) ::
    Batch("Join Reorder", Once,
      CostBasedJoinReorder(conf)) ::
    Batch("Decimal Optimizations", fixedPoint, //精度优化
      DecimalAggregates(conf)) ::
    Batch("Object Expressions Optimization", fixedPoint,
      EliminateMapObjects,
      CombineTypedFilters) ::
    Batch("LocalRelation", fixedPoint,
      ConvertToLocalRelation,
      PropagateEmptyRelation) ::
    Batch("OptimizeCodegen", Once,
      OptimizeCodegen(conf)) ::
    Batch("RewriteSubquery", Once,
      RewritePredicateSubquery,
      CollapseProject) :: Nil
  }

batch的执行和analyzer一样是通过RuleExecutor的execute方法依次遍历,这里不再解析。这里有部分优化的例子
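若想直观地看到优化前后的plan，可以借助explain(true)或queryExecution（示意，表名与SQL仅为示例）：

val df = spark.sql("SELECT a + 1 + 2 AS b FROM t WHERE 1 = 1")
df.explain(true)                          // 依次打印Parsed/Analyzed/Optimized LogicalPlan及Physical Plan
println(df.queryExecution.optimizedPlan)  // 单独查看optimized LogicalPlan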

数据本地化及延迟调度

前言

Spark数据本地化即移动计算而不是移动数据,而现实又是残酷的,不是想要在数据块的地方计算就有足够的资源提供,为了让task能尽可能的以最优本地化级别(Locality Levels)来启动,Spark的延迟调度应运而生,资源不够可在该Locality Levels对应的限制时间内重试,超过限制时间后还无法启动则降低Locality Levels再尝试启动……

本地化级别(Locality Levels)

  • PROCESS_LOCAL:进程本地化,代码和数据在同一个进程中,也就是在同一个executor中;计算数据的task由executor执行,数据在executor的BlockManager中,性能最好
  • NODE_LOCAL:节点本地化,代码和数据在同一个节点中;比如说,数据作为一个HDFS block块在节点上,而task在节点上某个executor中运行;或者是数据和task在一个节点上的不同executor中,数据需要在进程间进行传输
  • NO_PREF:对于task来说,数据从哪里获取都一样,没有好坏之分,比如说SparkSQL读取MySql中的数据
  • RACK_LOCAL:机架本地化,数据和task在一个机架的两个节点上,数据需要通过网络在节点之间进行传输
  • ANY:数据和task可能在集群中的任何地方,而且不在一个机架中,性能最差

这些Task的本地化级别其实描述的就是计算与数据的位置关系,这个最终的关系是如何产生的呢?接下来对其来龙去脉进行详细的讲解。

DAGScheduler提交tasks

DAGScheduler对job进行stage划分完后,会通过submitMissingTasks方法将Stage以TaskSet的形式提交给TaskScheduler,看看该方法关于位置优先的一些代码:

...
// 获取还未执行或未成功执行分区的id
val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
...
// 通过getPreferredLocs方法获取rdd该分区的优先位置
val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
      stage match {
        case s: ShuffleMapStage =>
          partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
        case s: ResultStage =>
          val job = s.activeJob.get
          partitionsToCompute.map { id =>
            val p = s.partitions(id)
            (id, getPreferredLocs(stage.rdd, p))
          }.toMap
      }
    } catch { 
    }
...
//通过最优位置等信息构建Task
val tasks: Seq[Task[_]] = try {
      stage match {
        case stage: ShuffleMapStage =>
          partitionsToCompute.map { id =>
            val locs = taskIdToLocations(id)
            val part = stage.rdd.partitions(id)
            new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, stage.latestInfo.taskMetrics, properties)
          }

        case stage: ResultStage =>
          val job = stage.activeJob.get
          partitionsToCompute.map { id =>
            val p: Int = stage.partitions(id)
            val part = stage.rdd.partitions(p)
            val locs = taskIdToLocations(id)
            new ResultTask(stage.id, stage.latestInfo.attemptId,
              taskBinary, part, locs, id, properties, stage.latestInfo.taskMetrics)
          }
      }
    } catch { 
    }
...
//将所有task以TaskSet的形式提交给TaskScheduler
taskScheduler.submitTasks(new TaskSet(
        tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))

注意这里提交的TaskSet里面的Task已经包含了该Task的优先位置,而该优先位置是通过getPreferredLocs方法获取,可以简单看看其实现:

private def getPreferredLocsInternal(
      rdd: RDD[_],
      partition: Int,
      visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = {
    ...
    // 从缓存中获取
    val cached = getCacheLocs(rdd)(partition)
    if (cached.nonEmpty) {
      return cached
    }
    // 直接通过rdd的preferredLocations方法获取
    val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
    if (rddPrefs.nonEmpty) {
      return rddPrefs.map(TaskLocation(_))
    }
    // 递归从parent Rdd获取(窄依赖)
    rdd.dependencies.foreach {
      case n: NarrowDependency[_] =>
        for (inPart <- n.getParents(partition)) {
          val locs = getPreferredLocsInternal(n.rdd, inPart, visited)
          if (locs != Nil) {
            return locs
          }
        }
      case _ =>
    }
    Nil
  }

无论是通过哪种方式获取RDD分区的优先位置,第一次计算的数据来源肯定都是通过RDD的preferredLocations方法获取的,不同的RDD有不同的preferredLocations实现,但是数据无非就是在三个地方存在,被cache到内存、HDFS、磁盘,而这三种方式的TaskLocation都有具体的实现:

//数据在内存中
private [spark] case class ExecutorCacheTaskLocation(override val host: String, executorId: String)
  extends TaskLocation {
  override def toString: String = s"${TaskLocation.executorLocationTag}${host}_$executorId"
}
//数据在磁盘上(非HDFS上)
private [spark] case class HostTaskLocation(override val host: String) extends TaskLocation {
  override def toString: String = host
}
//数据在HDFS上
private [spark] case class HDFSCacheTaskLocation(override val host: String) extends TaskLocation {
  override def toString: String = TaskLocation.inMemoryLocationTag + host
}

所以,在实例化Task的时候传的优先位置就是这三种的其中一种。

Locality levels生成

DAGScheduler将TaskSet提交给TaskScheduler后，TaskScheduler会为每个TaskSet创建一个TaskSetManager来对其Task进行管理，在初始化TaskSetManager的时候就会通过computeValidLocalityLevels计算该TaskSet包含的locality levels：

private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
    import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY}
    val levels = new ArrayBuffer[TaskLocality.TaskLocality]
    if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0 &&
        pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))) {
      levels += PROCESS_LOCAL
    }
    if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0 &&
        pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) {
      levels += NODE_LOCAL
    }
    if (!pendingTasksWithNoPrefs.isEmpty) {
      levels += NO_PREF
    }
    if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0 &&
        pendingTasksForRack.keySet.exists(sched.hasHostAliveOnRack(_))) {
      levels += RACK_LOCAL
    }
    levels += ANY
    logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
    levels.toArray
  }

程序会依次判断该TaskSetManager是否包含各个级别，逻辑都类似，我们就细看第一个，pendingTasksForExecutor的定义与添加：

// key为executorId,value为在该executor上有缓存的数据块对应的taskid数组
private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
...
//遍历所有该TaskSet的所有task进行添加
for (i <- (0 until numTasks).reverse) {
    addPendingTask(i)
  }
...
private def addPendingTask(index: Int) {
    for (loc <- tasks(index).preferredLocations) {
      loc match {
        case e: ExecutorCacheTaskLocation =>
          pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer) += index
        case e: HDFSCacheTaskLocation =>
          val exe = sched.getExecutorsAliveOnHost(loc.host)
          exe match {
            case Some(set) =>
              for (e <- set) {
                pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer) += index
              }
              logInfo(s"Pending task $index has a cached location at ${e.host} " +
                ", where there are executors " + set.mkString(","))
            case None => logDebug(s"Pending task $index has a cached location at ${e.host} " +
                ", but there are no executors alive there.")
          }
        case _ =>
      }
      pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer) += index
      for (rack <- sched.getRackForHost(loc.host)) {
        pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer) += index
      }
    }

    if (tasks(index).preferredLocations == Nil) {
      pendingTasksWithNoPrefs += index
    }

    allPendingTasks += index  // No point scanning this whole list to find the old task there
  }

注意这里的addPendingTask方法，会遍历该TaskSetManager管理的所有Task的优先位置（上文已解析），若是ExecutorCacheTaskLocation (缓存在内存中)则添加对应的executorId和taskId到pendingTasksForExecutor，同时还会添加到低级别需要的pendingTasksForHost、pendingTasksForRack中，说明假设一个 task 的最优本地性为 X，那么该 task 同时也具有其他所有本地性比X差的本地性。
回到上面的本地性级别判断:

if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0 &&
        pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))) {
      levels += PROCESS_LOCAL
    }

主要是看第三个判断 pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))，其中，pendingTasksForExecutor.keySet就是上面说明的、缓存有task对应数据块的executor的executorId集合，sched.isExecutorAlive(_)则是判断参数中的 executor id 当前是否 active。所以整行代码的意思是：缓存了task所需数据块的executors中是否有active的，若有则添加PROCESS_LOCAL级别到该TaskSet的LocalityLevels中。

后面的其他本地性级别是同样的逻辑，这里不再细讲，区别只是判断的对象不同，例如判断存放task对应数据块的hosts中是否有Alive的等……

至此,TaskSet包含的LocalityLevels就已经计算完。

延迟调度策略

若spark跑在yarn上，其实有两层延迟调度：第一层是yarn尽量将spark的executor分配到有数据的nodemanager上，若这一层没有做到data locality，到spark阶段data locality就更不可能了。

延迟调度的目的是为了减小网络及IO开销，在数据量大而计算逻辑简单(task执行时间小于数据传输时间)的情况下效果明显。

Spark调度总是会尽量让每个task以最高的本地性级别来启动,当一个task以X本地性级别启动,但是该本地性级别对应的所有节点都没有空闲资源而启动失败,此时并不会马上降低本地性级别启动而是在某个时间长度内再次以X本地性级别来启动该task,若超过限时时间则降级启动。

TaskSetManager会以某一种TaskSet包含的本地性级别遍历每个可用executor资源，尝试在该executor上启动当前管理的tasks，那么是如何决定某个task能否在该executor上启动呢？首先都会通过getAllowedLocalityLevel(curTime)方法计算当前TaskSetManager中未执行的tasks的最高本地性级别：

private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
    // Remove the scheduled or finished tasks lazily
    def tasksNeedToBeScheduledFrom(pendingTaskIds: ArrayBuffer[Int]): Boolean = {
      var indexOffset = pendingTaskIds.size
      while (indexOffset > 0) {
        indexOffset -= 1
        val index = pendingTaskIds(indexOffset)
        if (copiesRunning(index) == 0 && !successful(index)) {
          return true
        } else {
          pendingTaskIds.remove(indexOffset)
        }
      }
      false
    }
    // Walk through the list of tasks that can be scheduled at each location and returns true
    // if there are any tasks that still need to be scheduled. Lazily cleans up tasks that have
    // already been scheduled.
    def moreTasksToRunIn(pendingTasks: HashMap[String, ArrayBuffer[Int]]): Boolean = {
      val emptyKeys = new ArrayBuffer[String]
      val hasTasks = pendingTasks.exists {
        case (id: String, tasks: ArrayBuffer[Int]) =>
          if (tasksNeedToBeScheduledFrom(tasks)) {
            true
          } else {
            emptyKeys += id
            false
          }
      }
      // The key could be executorId, host or rackId
      emptyKeys.foreach(id => pendingTasks.remove(id))
      hasTasks
    }

    while (currentLocalityIndex < myLocalityLevels.length - 1) {
      val moreTasks = myLocalityLevels(currentLocalityIndex) match {
        case TaskLocality.PROCESS_LOCAL => moreTasksToRunIn(pendingTasksForExecutor)
        case TaskLocality.NODE_LOCAL => moreTasksToRunIn(pendingTasksForHost)
        case TaskLocality.NO_PREF => pendingTasksWithNoPrefs.nonEmpty
        case TaskLocality.RACK_LOCAL => moreTasksToRunIn(pendingTasksForRack)
      }
      if (!moreTasks) {
        // This is a performance optimization: if there are no more tasks that can
        // be scheduled at a particular locality level, there is no point in waiting
        // for the locality wait timeout (SPARK-4939).
        lastLaunchTime = curTime
        logDebug(s"No tasks for locality level ${myLocalityLevels(currentLocalityIndex)}, " +
          s"so moving to locality level ${myLocalityLevels(currentLocalityIndex + 1)}")
        currentLocalityIndex += 1
      } else if (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex)) {
        // Jump to the next locality level, and reset lastLaunchTime so that the next locality
        // wait timer doesn't immediately expire
        lastLaunchTime += localityWaits(currentLocalityIndex)
        logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex + 1)} after waiting for " +
          s"${localityWaits(currentLocalityIndex)}ms")
        currentLocalityIndex += 1
      } else {
        return myLocalityLevels(currentLocalityIndex)
      }
    }
    myLocalityLevels(currentLocalityIndex)
  }

循环条件里的currentLocalityIndex是getAllowedLocalityLevel前一次被调用返回的LocalityLevel在 myLocalityLevels 中的索引，初始值为0，myLocalityLevels则是TaskSetManager所有tasks包含的本地性级别。

  • 先判断myLocalityLevels(currentLocalityIndex)对应的level是否还有未执行的task，通过moreTasksToRunIn方法获取(逻辑很简单：执行完及正在执行的task都从对应列表中移除，若还有未执行过的task则直接返回true)
  • 若没有,则currentLocalityIndex 加一继续循环(降级)
  • 若有,则先判断当前时间与上次以该级别启动时间之差是否超过了该级别能容忍的时间限制,若未超过,则直接返回对应的LocalityLevel,若超过,则currentLocalityIndex 加一继续循环(降级)

至此，就取出了该TaskSetManager中未执行的tasks的最高本地性级别（取和maxLocality中级别高的作为最终的allowedLocality）。

最终决定是否在某个executor上启动某个task的是方法dequeueTask(execId, host, allowedLocality)

private def dequeueTask(execId: String, host: String, maxLocality: TaskLocality.Value)
    : Option[(Int, TaskLocality.Value, Boolean)] =
  {
    for (index <- dequeueTaskFromList(execId, getPendingTasksForExecutor(execId))) {
      return Some((index, TaskLocality.PROCESS_LOCAL, false))
    }

    if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) {
      for (index <- dequeueTaskFromList(execId, getPendingTasksForHost(host))) {
        return Some((index, TaskLocality.NODE_LOCAL, false))
      }
    }

    if (TaskLocality.isAllowed(maxLocality, TaskLocality.NO_PREF)) {
      // Look for noPref tasks after NODE_LOCAL for minimize cross-rack traffic
      for (index <- dequeueTaskFromList(execId, pendingTasksWithNoPrefs)) {
        return Some((index, TaskLocality.PROCESS_LOCAL, false))
      }
    }
    ...
  }

通过TaskLocality.isAllowed方法来保证只以比allowedLocality级别高(可相等)的locality来启动task,因为一个 task 拥有比最优本地性 差的其他所有本地性。这样就保证了能尽可能的以高本地性级别来启动一个task。

优化建议

可通过Spark UI来查看某个job的task的locality level，若都是NODE_LOCAL、ANY，那么可以调整数据本地化的等待时长（配置方式见下面列表之后的示意）：

  • spark.locality.wait 全局的,适用于每个locality level,默认为3s
  • spark.locality.wait.process
  • spark.locality.wait.node
  • spark.locality.wait.rack
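配置方式示意如下（数值仅为示例）：

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.locality.wait", "6s")        // 全局等待时长
  .set("spark.locality.wait.node", "3s")   // 单独为NODE_LOCAL级别设置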

TaskScheduler 任务提交与调度源码解析

在DAGScheduler划分为Stage并以TaskSet的形式提交给TaskScheduler后，再由TaskScheduler通过TaskSetManager对taskSet的task进行调度与执行。

taskScheduler.submitTasks(new TaskSet(
        tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))

submitTasks方法的实现在TaskScheduler的实现类TaskSchedulerImpl中。先看整个实现:

override def submitTasks(taskSet: TaskSet) {
    val tasks = taskSet.tasks
    logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
    this.synchronized {
      val manager = createTaskSetManager(taskSet, maxTaskFailures)
      val stage = taskSet.stageId
      val stageTaskSets =
        taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
      stageTaskSets(taskSet.stageAttemptId) = manager
      val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
        ts.taskSet != taskSet && !ts.isZombie
      }
      if (conflictingTaskSet) {
        throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
          s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
      }
      schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

      if (!isLocal && !hasReceivedTask) {
        starvationTimer.scheduleAtFixedRate(new TimerTask() {
          override def run() {
            if (!hasLaunchedTask) {
              logWarning("Initial job has not accepted any resources; " +
                "check your cluster UI to ensure that workers are registered " +
                "and have sufficient resources")
            } else {
              this.cancel()
            }
          }
        }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
      }
      hasReceivedTask = true
    }
    backend.reviveOffers()
  }
val manager = createTaskSetManager(taskSet, maxTaskFailures)

先为当前TaskSet创建TaskSetManager。TaskSetManager负责管理一个单独taskSet中的每一个task，决定某个task是否在某个executor上启动；如果task失败，负责重试task直到达到最大重试次数，并通过延迟调度来为task执行位置感知调度。

val stageTaskSets =
        taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
      stageTaskSets(taskSet.stageAttemptId) = manager

key为stageId，value为一个HashMap，其中key为stageAttemptId，value为TaskSetManager。

val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
        ts.taskSet != taskSet && !ts.isZombie
      }
      if (conflictingTaskSet) {
        throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
          s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
      }

isZombie marks whether a TaskSetManager no longer has tasks that need to run (all tasks succeeded, or the stage was removed). If there is already another active (non-zombie) TaskSetManager for this stage whose TaskSet differs from the newly submitted one, an exception is thrown; this guarantees that a stage never has two TaskSets running at the same time.

schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

The TaskSet is then added to the scheduling pool. schedulableBuilder is a SchedulableBuilder, a trait with two implementations, FIFOSchedulableBuilder and FairSchedulableBuilder; FIFO is the default.
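A minimal sketch of switching to the FAIR scheduler and assigning jobs to a pool; the pool name and XML path are placeholders:

import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .set("spark.scheduler.mode", "FAIR")
  // Optional: point at a fairscheduler.xml that defines the pools.
  .set("spark.scheduler.allocation.file", "/path/to/fairscheduler.xml")

val sc = new SparkContext(conf)
// Jobs submitted from this thread go into the named pool.
sc.setLocalProperty("spark.scheduler.pool", "adhoc")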

schedulableBuilder is instantiated inside TaskSchedulerImpl.initialize: when SparkContext builds the TaskSchedulerImpl (new TaskSchedulerImpl(sc)), it calls scheduler.initialize(backend), and initialize creates the builder according to the scheduling mode.

def initialize(backend: SchedulerBackend) {
    this.backend = backend
    // temporarily set rootPool name to empty
    rootPool = new Pool("", schedulingMode, 0, 0)
    schedulableBuilder = {
      schedulingMode match {
        case SchedulingMode.FIFO =>
          new FIFOSchedulableBuilder(rootPool)
        case SchedulingMode.FAIR =>
          new FairSchedulableBuilder(rootPool, conf)
        case _ =>
          throw new IllegalArgumentException(s"Unsupported spark.scheduler.mode: $schedulingMode")
      }
    }
    schedulableBuilder.buildPools()
  }
backend.reviveOffers()

Next, SchedulerBackend's reviveOffers method is called to request resources from the scheduler backend. backend is likewise passed in through scheduler.initialize(backend); it is created in SparkContext:

val (sched, ts) = SparkContext.createTaskScheduler(this, master, deployMode)

Back to requesting resources from the scheduler backend: CoarseGrainedSchedulerBackend's reviveOffers method is invoked, which sends a ReviveOffers message to driverEndpoint.

 override def reviveOffers() {
    driverEndpoint.send(ReviveOffers)
  }

When driverEndpoint receives the ReviveOffers message, it calls makeOffers.

private def makeOffers() {
      // Filter out executors under killing
      val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
      val workOffers = activeExecutors.map { case (id, executorData) =>
        new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
      }.toSeq
      launchTasks(scheduler.resourceOffers(workOffers))
    }

This method filters out the live executors and wraps each of them into a WorkerOffer, which carries three pieces of information: executorId, host, and free cores. executorDataMap is a HashMap[String, ExecutorData] keyed by executorId; the ExecutorData value holds the executor's host, RPC endpoint, totalCores and freeCores.
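A minimal standalone sketch of that filter-and-wrap step, with simplified stand-ins for ExecutorData and WorkerOffer (the real classes live in org.apache.spark.scheduler; everything here is illustrative):

import scala.collection.mutable.HashMap

// Simplified stand-ins for illustration only.
case class ExecutorDataStub(executorHost: String, totalCores: Int, freeCores: Int)
case class WorkerOfferStub(executorId: String, host: String, cores: Int)

val executorDataMap = HashMap(
  "exec-1" -> ExecutorDataStub("node-a", 4, 4),
  "exec-2" -> ExecutorDataStub("node-b", 4, 0))

def executorIsAlive(id: String): Boolean = true  // placeholder liveness check

// Keep only live executors and expose their free cores as offers.
val workOffers = executorDataMap
  .filterKeys(executorIsAlive)
  .map { case (id, data) => WorkerOfferStub(id, data.executorHost, data.freeCores) }
  .toSeq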

When the client registers the Application with the Master, the Master has already allocated and started the executors for the application; those executors then register with CoarseGrainedSchedulerBackend, and their registration information is what gets stored in executorDataMap.

launchTasks(scheduler.resourceOffers(workOffers))

Look first at scheduler.resourceOffers(workOffers): TaskSchedulerImpl.resourceOffers turns the offered resources into the Seq[Seq[TaskDescription]] of tasks to run, which CoarseGrainedSchedulerBackend then dispatches to the executors. The implementation:

def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
    // Mark whether a new executor has joined
    var newExecAvail = false
    // Update executor, host and rack information
    for (o <- offers) {
      executorIdToHost(o.executorId) = o.host
      executorIdToTaskCount.getOrElseUpdate(o.executorId, 0)
      if (!executorsByHost.contains(o.host)) {
        executorsByHost(o.host) = new HashSet[String]()
        executorAdded(o.executorId, o.host)
        newExecAvail = true
      }
      for (rack <- getRackForHost(o.host)) {
        hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += o.host
      }
    }

    // Shuffle the offers randomly so tasks do not pile up on a few nodes (load balancing)
    val shuffledOffers = Random.shuffle(offers)
    // Build a 2-D structure holding the tasks to be assigned to each executor
    val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
    // Available cores on each executor
    val availableCpus = shuffledOffers.map(o => o.cores).toArray
    // Sorted queue of TaskSets; FIFO and FAIR orderings are supported, FIFO by default
    val sortedTaskSets = rootPool.getSortedTaskSetQueue
    for (taskSet <- sortedTaskSets) {
      logDebug("parentName: %s, name: %s, runningTasks: %s".format(
        taskSet.parent.name, taskSet.name, taskSet.runningTasks))
      if (newExecAvail) { // if a newly allocated executor is available
        taskSet.executorAdded() // recompute the TaskSetManager's locality levels
      }
    }

    // A double loop: for each TaskSet in scheduling order, try to launch tasks at each
    // locality level in turn; for every (taskSet, locality) pair all usable executors are
    // examined and the launchable tasks are appended to tasks: Seq[Seq[TaskDescription]]
    // Locality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY
    var launchedTask = false
    for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) {
      do {
       // Offer the resources to the taskSet following locality preferences
        launchedTask = resourceOfferSingleTaskSet(
            taskSet, maxLocality, shuffledOffers, availableCpus, tasks)
      } while (launchedTask)
    }

    if (tasks.size > 0) {
      hasLaunchedTask = true
    }
    return tasks
  }

Stepping into resourceOfferSingleTaskSet:

private def resourceOfferSingleTaskSet(
      taskSet: TaskSetManager,
      maxLocality: TaskLocality,
      shuffledOffers: Seq[WorkerOffer],
      availableCpus: Array[Int],
      tasks: Seq[ArrayBuffer[TaskDescription]]) : Boolean = {
    var launchedTask = false
    // Iterate over all executors
    for (i <- 0 until shuffledOffers.size) {
      val execId = shuffledOffers(i).executorId
      val host = shuffledOffers(i).host
      // Only if this executor still has enough free cores for one task
      if (availableCpus(i) >= CPUS_PER_TASK) {
        try {
          // Ask the taskSet which tasks can be launched on this executor
          for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
            // Record the task to be launched on this executor
            tasks(i) += task 
            val tid = task.taskId 
            taskIdToTaskSetManager(tid) = taskSet // task -> taskSetManager
            taskIdToExecutorId(tid) = execId // task -> executorId
            executorIdToTaskCount(execId) += 1 // one more task on this executor
            executorsByHost(host) += execId // host -> executorId
            availableCpus(i) -= CPUS_PER_TASK // subtract the cores consumed by the task
            assert(availableCpus(i) >= 0)
            launchedTask = true
          }
        } catch {
          case e: TaskNotSerializableException =>
            logError(s"Resource offer failed, task set ${taskSet.name} was not serializable")
            // Do not offer resources for this task, but don't throw an error to allow other
            // task sets to be submitted.
            return launchedTask
        }
      }
    }
    return launchedTask
  }

This method iterates over all usable executors and, as long as an executor still has at least CPUS_PER_TASK free cores, calls resourceOffer to obtain the tasks of the TaskSet that can be launched on it, appending them to tasks before returning. Now look at the implementation of resourceOffer:

def resourceOffer(
      execId: String,
      host: String,
      maxLocality: TaskLocality.TaskLocality)
    : Option[TaskDescription] =
  {
    if (!isZombie) {
      val curTime = clock.getTimeMillis()

      var allowedLocality = maxLocality

      if (maxLocality != TaskLocality.NO_PREF) {
        allowedLocality = getAllowedLocalityLevel(curTime)
        if (allowedLocality > maxLocality) {
          // We're not allowed to search for farther-away tasks
          allowedLocality = maxLocality
        }
      }

      dequeueTask(execId, host, allowedLocality) match {
        case Some((index, taskLocality, speculative)) =>
          // Found a task; do some bookkeeping and return a task description
          val task = tasks(index)
          val taskId = sched.newTaskId()
          // Do various bookkeeping
          copiesRunning(index) += 1
          val attemptNum = taskAttempts(index).size
          val info = new TaskInfo(taskId, index, attemptNum, curTime,
            execId, host, taskLocality, speculative)
          taskInfos(taskId) = info
          taskAttempts(index) = info :: taskAttempts(index)
          // Update our locality level for delay scheduling
          // NO_PREF will not affect the variables related to delay scheduling
          if (maxLocality != TaskLocality.NO_PREF) {
            currentLocalityIndex = getLocalityIndex(taskLocality)
            lastLaunchTime = curTime
          }
          // Serialize and return the task
          val startTime = clock.getTimeMillis()
          val serializedTask: ByteBuffer = try {
            Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
          } catch {
            // If the task cannot be serialized, then there's no point to re-attempt the task,
            // as it will always fail. So just abort the whole task-set.
            case NonFatal(e) =>
              val msg = s"Failed to serialize task $taskId, not attempting to retry it."
              logError(msg, e)
              abort(s"$msg Exception during serialization: $e")
              throw new TaskNotSerializableException(e)
          }
          if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 &&
              !emittedTaskSizeWarning) {
            emittedTaskSizeWarning = true
            logWarning(s"Stage ${task.stageId} contains a task of very large size " +
              s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " +
              s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.")
          }
          addRunningTask(taskId)

          // We used to log the time it takes to serialize the task, but task size is already
          // a good proxy to task serialization time.
          // val timeTaken = clock.getTime() - startTime
          val taskName = s"task ${info.id} in stage ${taskSet.id}"
          logInfo(s"Starting $taskName (TID $taskId, $host, partition ${task.partitionId}," +
            s" $taskLocality, ${serializedTask.limit} bytes)")

          sched.dagScheduler.taskStarted(task, info)
          return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId,
            taskName, index, serializedTask))
        case _ =>
      }
    }
    None
  }
 if (maxLocality != TaskLocality.NO_PREF) {
        allowedLocality = getAllowedLocalityLevel(curTime)
        if (allowedLocality > maxLocality) {
          // We're not allowed to search for farther-away tasks
          allowedLocality = maxLocality
        }
      }

getAllowedLocalityLevel(curTime) applies delay scheduling to pick a suitable locality level; the goal is always to launch each task with the best locality possible. It returns the best locality level among the tasks of this TaskSet that have not run yet; that level becomes the worst locality this round of scheduling will tolerate, so the subsequent search only considers levels at least that good. allowedLocality ends up being the stricter (more local) of getAllowedLocalityLevel(curTime) and maxLocality.
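A heavily simplified sketch of the delay-scheduling idea behind getAllowedLocalityLevel; this is illustrative only (the real method also prunes finished tasks from the pending lists while it scans):

// Illustrative sketch of delay scheduling, not the real Spark implementation.
// myLocalityLevels: locality levels this TaskSetManager can use, best first.
// localityWaits:    configured wait (ms) for each of those levels.
// state:            (currentLocalityIndex, lastLaunchTime), the delay-scheduling state.
def getAllowedLocalityLevelSketch(
    curTime: Long,
    myLocalityLevels: Array[String],
    localityWaits: Array[Long],
    state: (Int, Long)): (String, (Int, Long)) = {
  var (currentLocalityIndex, lastLaunchTime) = state
  // While we have already waited longer than the current level allows, fall back one level.
  while (currentLocalityIndex < myLocalityLevels.length - 1 &&
         curTime - lastLaunchTime >= localityWaits(currentLocalityIndex)) {
    lastLaunchTime += localityWaits(currentLocalityIndex) // carry the spent wait over
    currentLocalityIndex += 1
  }
  (myLocalityLevels(currentLocalityIndex), (currentLocalityIndex, lastLaunchTime))
}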

dequeueTask then searches for a suitable task under allowedLocality. A non-empty result means a task has been assigned to this executor: the bookkeeping structures are updated, the task id is added to the running tasks, the delay-scheduling state is refreshed, the task is serialized, DAGScheduler is notified, and finally a TaskDescription is returned. Let's look at how dequeueTask is implemented:

private def dequeueTask(execId: String, host: String, maxLocality: TaskLocality.Value)
    : Option[(Int, TaskLocality.Value, Boolean)] =
  {
    for (index <- dequeueTaskFromList(execId, getPendingTasksForExecutor(execId))) {
      return Some((index, TaskLocality.PROCESS_LOCAL, false))
    }

    if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) {
      for (index <- dequeueTaskFromList(execId, getPendingTasksForHost(host))) {
        return Some((index, TaskLocality.NODE_LOCAL, false))
      }
    }

    if (TaskLocality.isAllowed(maxLocality, TaskLocality.NO_PREF)) {
      // Look for noPref tasks after NODE_LOCAL for minimize cross-rack traffic
      for (index <- dequeueTaskFromList(execId, pendingTasksWithNoPrefs)) {
        return Some((index, TaskLocality.PROCESS_LOCAL, false))
      }
    }

    if (TaskLocality.isAllowed(maxLocality, TaskLocality.RACK_LOCAL)) {
      for {
        rack <- sched.getRackForHost(host)
        index <- dequeueTaskFromList(execId, getPendingTasksForRack(rack))
      } {
        return Some((index, TaskLocality.RACK_LOCAL, false))
      }
    }

    if (TaskLocality.isAllowed(maxLocality, TaskLocality.ANY)) {
      for (index <- dequeueTaskFromList(execId, allPendingTasks)) {
        return Some((index, TaskLocality.ANY, false))
      }
    }

    // find a speculative task if all others tasks have been scheduled
    dequeueSpeculativeTask(execId, host, maxLocality).map {
      case (taskIndex, allowedLocality) => (taskIndex, allowedLocality, true)}
  }

It first checks whether there are pending PROCESS_LOCAL tasks for this execId and, if so, dequeues one. Otherwise it only tries the remaining levels permitted by the given locality constraint (NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY, in that order), dequeuing a task from the matching pending list if one exists.

dequeueTaskFromList pops a task from the tail of the given pending list (e.g. the PROCESS_LOCAL list for this execId) and returns its index within the TaskSet. Stepping into that method:

private def dequeueTaskFromList(execId: String, list: ArrayBuffer[Int]): Option[Int] = {
    var indexOffset = list.size
    while (indexOffset > 0) {
      indexOffset -= 1
      val index = list(indexOffset)
      if (!executorIsBlacklisted(execId, index)) {
        // This should almost always be list.trimEnd(1) to remove tail
        list.remove(indexOffset)
        if (copiesRunning(index) == 0 && !successful(index)) {
          return Some(index)
        }
      }
    }
    None
  }

There is a blacklist mechanism here: executorIsBlacklisted checks whether this executor is blacklisted for the task. The blacklist records the executor id and host on which the task last failed, together with a "dark" period during which the task should not be scheduled onto that node again.

private def executorIsBlacklisted(execId: String, taskId: Int): Boolean = {
    if (failedExecutors.contains(taskId)) {
      val failed = failedExecutors.get(taskId).get
      return failed.contains(execId) &&
        clock.getTimeMillis() - failed.get(execId).get < EXECUTOR_TASK_BLACKLIST_TIMEOUT
    }
    false
  }
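In the Spark version this code comes from, that "dark" period is controlled by the legacy setting spark.scheduler.executorTaskBlacklistTime (milliseconds, 0 by default, i.e. disabled). A minimal sketch, assuming that setting is still available in your version:

import org.apache.spark.SparkConf

// Keep a failed task away from the executor it failed on for 60 seconds.
val conf = new SparkConf().set("spark.scheduler.executorTaskBlacklistTime", "60000")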

Note the last piece of code in dequeueTask:

 // find a speculative task if all others tasks have been scheduled
    dequeueSpeculativeTask(execId, host, maxLocality).map {
      case (taskIndex, allowedLocality) => (taskIndex, allowedLocality, true)}

This is speculative execution: a speculative task launches additional copies of a task on other executors, and once one copy succeeds, the copies running elsewhere are killed. Speculative copies are only launched for tasks that are running slowly.
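Speculation is off by default; a minimal sketch of enabling it, with the thresholds shown only as examples:

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.speculation", "true")            // enable speculative execution
  .set("spark.speculation.interval", "100ms")  // how often to check for slow tasks
  .set("spark.speculation.multiplier", "1.5")  // "slow" = 1.5x the median task time
  .set("spark.speculation.quantile", "0.75")   // only after 75% of tasks have finished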

After scheduler.resourceOffers(workOffers) has returned the Seq[Seq[TaskDescription]] describing which tasks start on which executors, launchTasks is called to actually start them:

private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
      for (task <- tasks.flatten) {
        val serializedTask = ser.serialize(task)
        if (serializedTask.limit >= maxRpcMessageSize) {
          scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
            try {
              var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
                "spark.rpc.message.maxSize (%d bytes). Consider increasing " +
                "spark.rpc.message.maxSize or using broadcast variables for large values."
              msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize)
              taskSetMgr.abort(msg)
            } catch {
              case e: Exception => logError("Exception in error callback", e)
            }
          }
        }
        else {
          val executorData = executorDataMap(task.executorId)
          executorData.freeCores -= scheduler.CPUS_PER_TASK

          logInfo(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " +
            s"${executorData.executorHost}.")

          executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
        }
      }
    }

Each task is serialized first. If the serialized task reaches maxRpcMessageSize (spark.rpc.message.maxSize, 128 MB by default), the task is dropped and its TaskSetManager is aborted, which puts it into zombie mode; otherwise a LaunchTask message is sent to the executor to run the task.
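When a task blows past that limit, it is usually because a large object was captured in the closure. A minimal sketch of the broadcast-variable fix the error message suggests; the lookup table here is a made-up example:

import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("broadcast-example"))

// A large lookup table captured directly in a closure would be shipped inside every task.
val bigLookup: Map[Int, String] = (1 to 1000000).map(i => i -> s"value-$i").toMap

// Broadcasting ships it to each executor once, keeping the serialized tasks small.
val lookupBc = sc.broadcast(bigLookup)

val enriched = sc.parallelize(1 to 100)
  .map(id => (id, lookupBc.value.getOrElse(id, "missing")))
enriched.take(5).foreach(println)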

DAGScheduler Stage Division Source Code Analysis

Overview

A Spark application only actually submits and runs a job when it hits an action. DAGScheduler builds a DAG from the dependencies between RDDs and splits it into stages at every ShuffleDependency. A stage contains multiple tasks, whose number is determined by the partitioning of that stage's final RDD, and all tasks within a stage run the same computation. Once the stages are determined, DAGScheduler builds a TaskSet per stage and submits it to TaskScheduler, which handles the concrete task scheduling and launches the tasks on the worker nodes.
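For a concrete picture, here is a minimal word-count style job; reduceByKey introduces a ShuffleDependency, so the DAG below is split into one ShuffleMapStage (the map side) and one ResultStage (triggered by the count action). The input path is only a placeholder:

import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("stage-split-example"))

val lines  = sc.textFile("hdfs:///tmp/input")          // narrow: stays in the first stage
val counts = lines.flatMap(_.split(" "))
  .map(word => (word, 1))                              // still narrow dependencies
  .reduceByKey(_ + _)                                  // ShuffleDependency -> stage boundary
val numDistinctWords = counts.count()                  // action: submits the job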

Job submission

Taking count as an example, the call chain through the source looks like this:

def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum
    SparkContext#runJob (several overloads)
        SparkContext#dagScheduler.runJob
            DAGScheduler#runJob
                DAGScheduler#submitJob
                    eventProcessLoop.post(JobSubmitted(...))

eventProcessLoop is a DAGSchedulerEventProcessLoop(this); you can think of DAGSchedulerEventProcessLoop as DAGScheduler's external interface, hiding its internal implementation details. Whether a message comes from inside or outside, DAGScheduler handles it through the same code path, which keeps the logic clear and uniform.
eventProcessLoop receives the various events and processes them in its doOnReceive method:

 private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
    case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
      dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)

    case MapStageSubmitted(jobId, dependency, callSite, listener, properties) =>
      dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties)

    ......
}

When the event posted is JobSubmitted, it is handled by dagScheduler.handleJobSubmitted.

Stage division

The first thing handleJobSubmitted does is divide the job into stages by tracing backwards from the final RDD.

private[scheduler] def handleJobSubmitted(jobId: Int,
    finalRDD: RDD[_],
    func: (TaskContext, Iterator[_]) => _,
    partitions: Array[Int],
    callSite: CallSite,
    listener: JobListener,
    properties: Properties) {
  var finalStage: ResultStage = null
  try { 
 // Stage division walks backwards from the last stage, which is a ResultStage
    finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite)
  } catch {
    case e: Exception =>
      logWarning("Creating new stage failed due to exception - job: " + jobId, e)
      listener.jobFailed(e)
      return
  }
  // Create an ActiveJob object for this job
  val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
  clearCacheLocs()
  logInfo("Got job %s (%s) with %d output partitions".format(
    job.jobId, callSite.shortForm, partitions.length))
  logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")")
  logInfo("Parents of final stage: " + finalStage.parents)
  logInfo("Missing parents: " + getMissingParentStages(finalStage))

  val jobSubmissionTime = clock.getTimeMillis()
  jobIdToActiveJob(jobId) = job // mark the job as active
  activeJobs += job 
  finalStage.setActiveJob(job)
  val stageIds = jobIdToStageIds(jobId).toArray
  val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
  listenerBus.post( // post the job-start event to LiveListenerBus
    SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
  submitStage(finalStage) // submit the stage

  submitWaitingStages()
}

Stepping into newResultStage:

private def newResultStage(
      rdd: RDD[_],
      func: (TaskContext, Iterator[_]) => _,
      partitions: Array[Int],
      jobId: Int,
      callSite: CallSite): ResultStage = {
    val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) // get the parent stages of this stage
    val stage = new ResultStage(id, rdd, func, partitions, parentStages, jobId, callSite)
    stageIdToStage(id) = stage // associate the stage with its id
    updateJobIdStageIdMaps(jobId, stage) // update the stages belonging to the job
    stage
  }

It instantiates a ResultStage directly, but needs parentStages as a parameter; let's see what getParentStagesAndId does:

private def getParentStagesAndId(rdd: RDD[_], firstJobId: Int): (List[Stage], Int) = {
    val parentStages = getParentStages(rdd, firstJobId)
    val id = nextStageId.getAndIncrement()
    (parentStages, id)
  }

It obtains the parent stages and returns a unique id for the new stage. Because stages are generated by recursing backwards, the stages furthest upstream are created first and therefore get the smallest ids; a parent stage's id is always smaller than its child's. Continue into getParentStages:

private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
    val parents = new HashSet[Stage] // all parent stages of the current stage
    val visited = new HashSet[RDD[_]] // RDDs that have already been visited
    // We are manually maintaining a stack here to prevent StackOverflowError
    // caused by recursively visiting
    val waitingForVisit = new Stack[RDD[_]] // RDDs waiting to be visited
    def visit(r: RDD[_]) {
      if (!visited(r)) { // not visited yet
        visited += r  // mark as visited
        // Kind of ugly: need to register RDDs with the cache here since
        // we can't do it in its constructor because # of partitions is unknown
        for (dep <- r.dependencies) { // walk over all its dependencies
          dep match {
            case shufDep: ShuffleDependency[_, _, _] => // wide dependency: create a new ShuffleMapStage
              parents += getShuffleMapStage(shufDep, firstJobId)
            case _ => // narrow dependency (belongs to the current stage): push the rdd and keep
              waitingForVisit.push(dep.rdd) // walking until a wide dependency or no dependency is reached
          }
        }
      }
    }
    waitingForVisit.push(rdd) // push the current rdd onto the stack
    while (waitingForVisit.nonEmpty) { // keep visiting while the stack is not empty
      visit(waitingForVisit.pop())
    }
    parents.toList
  }

Given an RDD, it returns the set of stages that RDD depends on. Every dependency is walked: a narrow dependency means keep walking backwards, while a ShuffleDependency is turned into a ShuffleMapStage via getShuffleMapStage and added to the parent-stage list. Note that parentStages here are only the stages the current stage directly depends on (each of them has its own parentStages in turn), not every stage in the DAG. Continue into getShuffleMapStage:

private def getShuffleMapStage(
      shuffleDep: ShuffleDependency[_, _, _],
      firstJobId: Int): ShuffleMapStage = {
    shuffleToMapStage.get(shuffleDep.shuffleId) match {
      case Some(stage) => stage // already registered in shuffleToMapStage: return it directly
      case None => // not registered yet, a new stage has to be created
        // First create a ShuffleMapStage for every ancestor shuffle of this one
       getAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep =>
          if (!shuffleToMapStage.contains(dep.shuffleId)) {
            shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId) // update the shuffleToMapStage mapping
          }
        }
        // Then create the stage for this shuffle itself
        val stage = newOrUsedShuffleStage(shuffleDep, firstJobId)
        shuffleToMapStage(shuffleDep.shuffleId) = stage
        stage
    }
  }

It first looks the stage up in shuffleToMapStage by shuffleId, and only computes one when the lookup fails (on the first pass the lookup always returns None). Let's first see what getAncestorShuffleDependencies does:

private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
    val parents = new Stack[ShuffleDependency[_, _, _]] // all ancestor ShuffleDependencies of this one (not only the direct ones)
    val visited = new HashSet[RDD[_]] // RDDs that have already been visited
    // RDDs waiting to be visited
    val waitingForVisit = new Stack[RDD[_]]
    def visit(r: RDD[_]) {
      if (!visited(r)) { // not visited yet
        visited += r // mark as visited
        for (dep <- r.dependencies) { // walk over the direct dependencies
          dep match {
            case shufDep: ShuffleDependency[_, _, _] => 
              if (!shuffleToMapStage.contains(shufDep.shuffleId)) { // a ShuffleDependency with no stage mapped yet: collect it
                parents.push(shufDep)
              }
            case _ =>
          }
          waitingForVisit.push(dep.rdd)  // keep walking even past a ShuffleDependency's rdd
        }
      }
    }

    waitingForVisit.push(rdd)
    while (waitingForVisit.nonEmpty) {
      visit(waitingForVisit.pop())
    }
    parents
  }

It looks very similar to getParentStages; the difference is that it collects all ancestor ShuffleDependencies, not just the direct parent ShuffleDependencies.

After a ShuffleMapStage has been created for every ancestor shuffle of the current one, newOrUsedShuffleStage is used to obtain the stage for the current shuffle dependency, which is then associated with its shuffleId. The implementation of newOrUsedShuffleStage:

private def newOrUsedShuffleStage(
      shuffleDep: ShuffleDependency[_, _, _],
      firstJobId: Int): ShuffleMapStage = {
    val rdd = shuffleDep.rdd // the rdd this dependency points at
    val numTasks = rdd.partitions.length // number of partitions
    val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite) // build the ShuffleMapStage for this rdd
    if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
    // The shuffle is already registered with MapOutputTracker, i.e. the stage was computed before:
    // recover the results (map output locations) from MapOutputTracker
      val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
      val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
      (0 until locs.length).foreach { i => // restore the shuffle-write output locations
        if (locs(i) ne null) {
          // locs(i) will be null if missing
          stage.addOutputLoc(i, locs(i))
        }
      }
    } else { // not registered/computed yet
      // Kind of ugly: need to register RDDs with the cache and map output tracker here
      // since we can't do it in the RDD constructor because # of partitions is unknown
      logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
      mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length)  // register it
    }
    stage
  }

Continue with newShuffleMapStage:

private def newShuffleMapStage(
      rdd: RDD[_],
      numTasks: Int,
      shuffleDep: ShuffleDependency[_, _, _],
      firstJobId: Int,
      callSite: CallSite): ShuffleMapStage = {
    val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, firstJobId) // get the parent stages and the stage id
    val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages,
      firstJobId, callSite, shuffleDep) // instantiate a ShuffleMapStage
    stageIdToStage(id) = stage // associate the stage with its id
    updateJobIdStageIdMaps(firstJobId, stage) // update the stages belonging to the job
    stage
  }

Looks remarkably similar to newResultStage, doesn't it? Exactly: here a ShuffleMapStage is created, and the getParentStagesAndId call inside is what makes the whole process recursive.

Stages are thus generated recursively by tracing backwards from the final RDD: the earliest ShuffleMapStages are created first, and the ResultStage last. With that, DAGScheduler's stage division is complete.
