实现最简单的分布式任务调度框架

前段时间,公司要改造现有的单节点调度为分布式任务调度,然后就研究了目前市面上主流的开源分布式任务调度框架,用起来就一个感觉:麻烦!特别是之前在一个类里写了好多个调度任务,改造起来更加麻烦。我这人又比较懒,总感觉用了别人写好的工具还要改一大堆,心里就有点不舒服。

实现最简单的分布式任务调度框架

于是我就想自己写一个框架,毕竟自己觉得分布式任务调度在所有分布式系统中是最简单的,因为一般公司任务调度本身不可能同时调度海量的任务,很大的并发,改造成分布式主要还是为了分散任务到多个节点,以便同一时间处理更多的任务。

后面有一天,我在公司前台取快递,看到这样一个现象:我们好几个同事(包括我)在前台那从头到尾看快递是不是自己的,是自己的就取走,不是就忽略,然后我就收到了启发。这个场景类比到分布式调度系统中,我们可以认为是快递公司或者快递员已经把每个快递按照我们名字电话分好了快递,我们只需要取走自己的就行了。

但是从另外一个角度看,也可以理解成我们每个人都是从头到尾看了所有快递,然后按照某种约定的规则,如果是自己的快递就拿走,不是自己的就忽略继续看下一个。如果把快递想象成任务,一堆人去拿一堆快递也可以很顺利的拿到各自的快递,那么一堆节点自己去取任务是不是也可以很好的处理各自的任务呢?

传统的分布式任务调度都有一个调度中心,这个调度中心也都要部署称多节点的集群,避免单点故障,然后还有一堆执行器,执行器负责执行调度中心分发的任务。按照上面的启发,我的思路是放弃中心式的调度中心直接由各个执行器节点去公共的地方按照约定的规则去取任务,然后执行。设计示意图如下

实现最简单的分布式任务调度框架

有人可能怀疑那任务db库不是有单点问题吗,我想反问下,难道其他的分布式任务调度框架没有这个问题吗?针对数据库单点我们可以单独类似业务库那样考虑高可用方案,这里不是这篇文章的讨论重点。很明显我们重点放在执行节点那里到底怎么保证高可用,单个任务不会被多个节点同时执行,单个节点执行到一半突然失联了,这个任务怎么办等复杂的问题。

后续我们使用未经修饰的代码的方式一一解决这个问题(未经修饰主要是没有优化结构流水账式的代码风格,主要是很多人包括我自己看别人源码时总是感觉晕头转向的,仿佛置身迷宫般,看起来特别费劲,可能是我自己境界未到吧)

既然省略了集中式的调度,那么既然叫任务调度很明显必须要有调度的过程,不然多个节点去抢一个任务怎么避免冲突呢?我这里解决方式是:首先先明确一个任务的几种状态:待执行,执行中,有异常,已完成。

每个节点起一个线程一直去查很快就要开始执行的待执行任务,然后遍历这些任务,使用乐观锁的方式先更新这个任务的版本号(版本号+1)和状态(变成执行中),如果更新成功就放入节点自己的延时队列中等待执行。

由于每个节点的线程都是去数据库查待执行的任务,很明显变成执行中的任务下次就不会被其他节点再查询到了,至于对于那些在本节点更新状态之前就查到的待执行任务也会经过乐观锁尝试后更新失败从而跳过这个任务,这样就可以避免一个任务同时被多个节点重复执行。关键代码如下:

package com.rdpaas.task.scheduler;  import com.rdpaas.task.common.*; 
import com.rdpaas.task.config.EasyJobConfig; import com.rdpaas.task.repository.NodeRepository; 
import com.rdpaas.task.repository.TaskRepository; import com.rdpaas.task.strategy.Strategy;  
import org.slf4j.Logger; import org.slf4j.LoggerFactory; 
import org.springframework.beans.factory.annotation.Autowired; 
import org.springframework.stereotype.Component;  
import javax.annotation.PostConstruct;  
import java.util.Date; import java.util.List; import java.util.concurrent.*;   @Component public class TaskExecutor {      private static final Logger logger = LoggerFactory.getLogger(TaskExecutor.class);      @Autowired     private TaskRepository taskRepository;      @Autowired     private NodeRepository nodeRepository;      @Autowired     private EasyJobConfig config;     private DelayQueue<DelayItem<Task>> taskQueue = new DelayQueue<>();           private ExecutorService bossPool = Executors.newFixedThreadPool(2);           private ThreadPoolExecutor workerPool;       @PostConstruct     public void init() {          workerPool = new ThreadPoolExecutor(config.getCorePoolSize(), config.getMaxPoolSize(), 60, TimeUnit.SECONDS, new ArrayBlockingQueue<>(config.getQueueSize()));                  bossPool.execute(new Loader());                  bossPool.execute(new Boss());      }      class Loader implements Runnable {          @Override         public void run() {             for(;;) {                 try {                                   List<Task> tasks = taskRepository.listPeddingTasks(config.getFetchDuration());                     if(tasks == null || tasks.isEmpty()) {                         continue;                     }                     for(Task task:tasks) {                          task.setStatus(TaskStatus.DOING);                         task.setNodeId(config.getNodeId());                                                  int n = taskRepository.updateWithVersion(task);                         Date nextStartTime = task.getNextStartTime();                         if(n == 0 || nextStartTime == null) {                             continue;                         }                                                  task = taskRepository.get(task.getId());                         DelayItem<Task> delayItem = new DelayItem<Task>(nextStartTime.getTime() - new Date().getTime(), task);                         taskQueue.offer(delayItem);                      }                     Thread.sleep(config.getFetchPeriod());                 } catch(Exception e) {                     logger.error("fetch task list failed,cause by:{}", e);                 }             }         }      }      class Boss implements Runnable {         @Override         public void run() {             for (;;) {                 try {                                           DelayItem<Task> item = taskQueue.take();                     if(item != null && item.getItem() != null) {                         Task task = item.getItem();                         workerPool.execute(new Worker(task));                     }                  } catch (Exception e) {                     logger.error("fetch task failed,cause by:{}", e);                 }             }         }      }      class Worker implements Runnable {          private Task task;          public Worker(Task task) {             this.task = task;         }          @Override         public void run() {             logger.info("Begin to execute task:{}",task.getId());             TaskDetail detail = null;             try {                 //开始任务                 detail = taskRepository.start(task);                 if(detail == nullreturn;                 //执行任务                 task.getInvokor().invoke();                 //完成任务                 finish(task,detail);                 logger.info("finished execute task:{}",task.getId());             } catch (Exception e) {                 logger.error("execute task:{} error,cause by:{}",task.getId(), e);                 try {                     taskRepository.fail(task,detail,e.getCause().getMessage());                 } catch(Exception e1) {                     logger.error("fail task:{} error,cause by:{}",task.getId(), e);                 }             }         }      }           private void finish(Task task,TaskDetail detail) throws Exception {          //查看是否有子类任务         List<Task> childTasks = taskRepository.getChilds(task.getId());         if(childTasks == null || childTasks.isEmpty()) {             //当没有子任务时完成父任务             taskRepository.finish(task,detail);             return;         } else {             for (Task childTask : childTasks) {                 //开始任务                 TaskDetail childDetail = null;                 try {                     //将子任务状态改成执行中                     childTask.setStatus(TaskStatus.DOING);                     childTask.setNodeId(config.getNodeId());                     //开始子任务                     childDetail = taskRepository.startChild(childTask,detail);                     //使用乐观锁更新下状态,不然这里可能和恢复线程产生并发问题                     int n = taskRepository.updateWithVersion(childTask);                     if (n > 0) {                         //再从数据库取一下,避免上面update修改后version不同步                         childTask = taskRepository.get(childTask.getId());                         //执行子任务                         childTask.getInvokor().invoke();                         //完成子任务                         finish(childTask, childDetail);                     }                 } catch (Exception e) {                     logger.error("execute child task error,cause by:{}", e);                     try {                         taskRepository.fail(childTask, childDetail, e.getCause().getMessage());                     } catch (Exception e1) {                         logger.error("fail child task error,cause by:{}", e);                     }                 }             }                          taskRepository.finish(task,detail);          }      }  }

如上所述,可以保证一个任务同一个时间只会被一个节点调度执行。这时候如果部署多个节点,正常应该可以很顺利的将任务库中的任务都执行到,就像一堆人去前台取快递一样,可以很顺利的拿走所有快递。毕竟对于每个快递不是自己的就是其他人的,自己的快递也不会是其他人的。

但是这里的调度和取快递有一点不一样,取快递的每个人都知道怎么去区分到底哪个快递是自己的。这里的调度完全没这个概念,完全是哪个节点运气好使用乐观锁更新了这个任务状态就是哪个节点的。总的来说区别就是需要一个约定的规则,快递是不是自己的,直接看快递上的名字和手机号码就知道了。任务到底该不该自己执行我们也可以出一个这种规则,明确哪些任务那些应该是哪些节点可以执行,从而避免无谓的锁竞争。

这里可以借鉴负载均衡的那些策略,目前我想实现如下规则:

    id_hash : 按照任务自增id的对节点个数取余,余数值和当前节点的实时序号匹配,可以匹配就可以拿走执行,否则请自觉忽略掉这个任务 least_count:最少执行任务的节点优先去取任务 weight:按照节点权重去取任务 default:默认先到先得,没有其它规则

根据上面规则也可以说是任务的负载均衡策略可以知道除了默认规则,其余规则都需要知道全局的节点信息,比如节点执行次数,节点序号,节点权重等,所以我们需要给节点添加一个心跳,隔一个心跳周期上报一下自己的信息到数据库,心跳核心代码如下:

     private DelayQueue<DelayItem<Node>> heartBeatQueue = new DelayQueue<>();       private ExecutorService bossPool = Executors.newFixedThreadPool(2);    @PostConstruct     public void init() {                  if(config.isRecoverEnable() && config.isHeartBeatEnable()) {                          heartBeatQueue.offer(new DelayItem<>(0,new Node(config.getNodeId())));                          bossPool.execute(new HeartBeat());                          bossPool.execute(new Recover());         }     }   class HeartBeat implements Runnable {         @Override         public void run() {             for(;;) {                 try {                                          DelayItem<Node> item = heartBeatQueue.take();                     if(item != null && item.getItem() != null) {                         Node node = item.getItem();                         handHeartBeat(node);                     }                     heartBeatQueue.offer(new DelayItem<>(config.getHeartBeatSeconds() * 1000,new Node(config.getNodeId())));                 } catch (Exception e) {                     logger.error("task heart beat error,cause by:{} ",e);                 }             }         }     }           private void handHeartBeat(Node node) {         if(node == null) {             return;         }                  Node currNode= nodeRepository.getByNodeId(node.getNodeId());         if(currNode == null) {             node.setRownum(nodeRepository.getNextRownum());             nodeRepository.insert(node);         } else  {             nodeRepository.updateHeartBeat(node.getNodeId());         }      }

数据库有了节点信息后,我们就可以实现各种花式的取任务的策略了,代码如下:

 public interface Strategy {           String DEFAULT = "default";           String ID_HASH = "id_hash";           String LEAST_COUNT = "least_count";           String WEIGHT = "weight";       public static Strategy choose(String key) {         switch(key) {             case ID_HASH:                 return new IdHashStrategy();             case LEAST_COUNT:                 return new LeastCountStrategy();             case WEIGHT:                 return new WeightStrategy();             default:                 return new DefaultStrategy();         }     }      public boolean accept(List<Node> nodes,Task task,Long myNodeId);  }

~

 public class IdHashStrategy implements Strategy {           @Override     public boolean accept(List<Node> nodes, Task task, Long myNodeId) {         int size = nodes.size();         long taskId = task.getId();                  Node myNode = nodes.stream().filter(node -> node.getNodeId() == myNodeId).findFirst().get();         return myNode == null ? false : (taskId % size) + 1 == myNode.getRownum();     }  }

~

 public class LeastCountStrategy implements Strategy {      @Override     public boolean accept(List<Node> nodes, Task task, Long myNodeId) {                   Optional<Node> min = nodes.stream().min((o1, o2) -> o1.getCounts().compareTo(o2.getCounts()));          return min.isPresent()? min.get().getNodeId() == myNodeId : false;     }  }

~

 public class WeightStrategy implements Strategy {      @Override     public boolean accept(List<Node> nodes, Task task, Long myNodeId) {         Node myNode = nodes.stream().filter(node -> node.getNodeId() == myNodeId).findFirst().get();         if(myNode == null) {             return false;         }                  int preWeightSum = nodes.stream().filter(node -> node.getRownum() < myNode.getRownum()).collect(Collectors.summingInt(Node::getWeight));                  int weightSum = nodes.stream().collect(Collectors.summingInt(Node::getWeight));                  int remainder = (int)(task.getId() % weightSum);         return remainder >= preWeightSum && remainder < preWeightSum + myNode.getWeight();     }  }

然后我们再改造下调度类

     private Strategy strategy;       @PostConstruct     public void init() {                  strategy = Strategy.choose(config.getNodeStrategy());                  workerPool = new ThreadPoolExecutor(config.getCorePoolSize(), config.getMaxPoolSize(), 60, TimeUnit.SECONDS, new ArrayBlockingQueue<>(config.getQueueSize()));                  bossPool.execute(new Loader());                  bossPool.execute(new Boss());      }      class Loader implements Runnable {          @Override         public void run() {             for(;;) {                 try {                                           List<Node> nodes = nodeRepository.getEnableNodes(config.getHeartBeatSeconds() * 2);                     if(nodes == null || nodes.isEmpty()) {                         continue;                     }                                          List<Task> tasks = taskRepository.listPeddingTasks(config.getFetchDuration());                     if(tasks == null || tasks.isEmpty()) {                         continue;                     }                     for(Task task:tasks) {                          boolean accept = strategy.accept(nodes, task, config.getNodeId());                                                  if(!accept) {                             continue;                         }                         task.setStatus(TaskStatus.DOING);                         task.setNodeId(config.getNodeId());                                                  int n = taskRepository.updateWithVersion(task);                         Date nextStartTime = task.getNextStartTime();                         if(n == 0 || nextStartTime == null) {                             continue;                         }                                                  task = taskRepository.get(task.getId());                         DelayItem<Task> delayItem = new DelayItem<Task>(nextStartTime.getTime() - new Date().getTime(), task);                         taskQueue.offer(delayItem);                      }                     Thread.sleep(config.getFetchPeriod());                 } catch(Exception e) {                     logger.error("fetch task list failed,cause by:{}", e);                 }             }         }      }

如上可以通过各种花式的负载策略来平衡各个节点获取的任务,同时也可以显著降低各个节点对同一个任务的竞争。

但是还有个问题,假如某个节点拿到了任务更新成了执行中,执行到一半,没执行完也没发生异常,突然这个节点由于各种原因挂了,那么这时候这个任务永远没有机会再执行了。这就是传说中的占着茅坑不拉屎。

解决这个问题可以用最终一致系统常见的方法,异常恢复线程。在这种场景下只需要检测一下指定心跳超时时间(比如默认3个心跳周期)下没有更新心跳时间的节点所属的未完成任务,将这些任务状态重新恢复成待执行,并且下次执行时间改成当前就可以了。核心代码如下:

class Recover implements Runnable {         @Override         public void run() {             for (;;) {                 try {                                          List<Task> tasks = taskRepository.listRecoverTasks(config.getHeartBeatSeconds() * 3);                     if(tasks == null || tasks.isEmpty()) {                         return;                     }                                        List<Node> nodes = nodeRepository.getEnableNodes(config.getHeartBeatSeconds() * 2);                    if(nodes == null || nodes.isEmpty()) {                        return;                    }                    long maxNodeId = nodes.get(nodes.size() - 1).getNodeId();                     for (Task task : tasks) {                                                  long currNodeId = chooseNodeId(nodes,maxNodeId,task.getNodeId());                         long myNodeId = config.getNodeId();                                                  if(currNodeId != myNodeId) {                             continue;                         }                                                  task.setStatus(TaskStatus.PENDING);                         task.setNextStartTime(new Date());                         task.setNodeId(config.getNodeId());                         taskRepository.updateWithVersion(task);                     }                     Thread.sleep(config.getRecoverSeconds() * 1000);                 } catch (Exception e) {                     logger.error("Get next task failed,cause by:{}", e);                 }             }         }      }      private long chooseNodeId(List<Node> nodes,long maxNodeId,long nodeId) {         if(nodeId > maxNodeId) {             return nodes.get(0).getNodeId();         }         return nodes.stream().filter(node -> node.getNodeId() > nodeId).findFirst().get().getNodeId();     }

如上为了避免每个节点的异常恢复线程对同一个任务做无谓的竞争,每个异常任务只能被任务所属节点ID的下一个正常节点去恢复。这样处理后就能确保就算出现了上面那种任务没执行完节点挂了的情况,一段时间后也可以自动恢复。

总的来说上面那些不考虑优化应该可以做为一个还不错的任务调度框架了。如果你们以为这样就完了,我只能说抱歉了,还有,哈哈!前面提到我是嫌弃其它任务调度用起来麻烦,特别是习惯用spring的注解写调度的,那些很可能一个类里写了n个带有@Scheduled注解的调度方法,这样改造起来更加麻烦,我是希望做到如下方式就可以直接整合到分布式任务调度里:

 @Component public class SchedulerTest {      @Scheduled(cron = "0/10 * * * * ?")     public void test1() throws InterruptedException {         SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");         Thread.sleep(2000);         System.out.println("当前时间1:"+sdf.format(new Date()));     }      @Scheduled(cron = "0/20 * * * * ?",parent = "test1")     public void test2() throws InterruptedException {         SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");         Thread.sleep(2000);         System.out.println("当前时间2:"+sdf.format(new Date()));     }      @Scheduled(cron = "0/10 * * * * ?",parent = "test2")     public void test3() throws InterruptedException {         SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");         Thread.sleep(2000);         System.out.println("当前时间3:"+sdf.format(new Date()));     }      @Scheduled(cron = "0/10 * * * * ?",parent = "test3")     public void test4() throws InterruptedException {         SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");         Thread.sleep(2000);         System.out.println("当前时间4:"+sdf.format(new Date()));     }  }

为了达到上述目标,我们还需要在spring启动后加载自定义的注解(名称和spring的一样),代码如下

 @Component public class ContextRefreshedListener implements ApplicationListener<ContextRefreshedEvent> {      @Autowired     private TaskExecutor taskExecutor;           private Map<String,Long> taskIdMap = new HashMap<>();      @Override     public void onApplicationEvent(ContextRefreshedEvent event) {                  if(event.getApplicationContext().getParent()==null){                          ApplicationContext context = event.getApplicationContext();             Map<String,Object> beans = context.getBeansWithAnnotation(org.springframework.scheduling.annotation.EnableScheduling.class);             if(beans == null) {                 return;             }                          Map<String,Method> methodMap = new HashMap<>();                          Map<String,Object> allBeans = context.getBeansWithAnnotation(org.springframework.stereotype.Component.class);             Set<Map.Entry<String,Object>> entrys = allBeans.entrySet();                          for(Map.Entry entry:entrys){                 Object obj = entry.getValue();                 Class clazz = obj.getClass();                 Method[] methods = clazz.getMethods();                 for(Method m:methods) {                     if(m.isAnnotationPresent(Scheduled.class)) {                         methodMap.put(clazz.getName() + Delimiters.DOT + m.getName(),m);                     }                 }             }                          handleSheduledAnn(methodMap);                          taskIdMap.clear();         }     }           private void handleSheduledAnn(Map<String,Method> methodMap) {         if(methodMap == null || methodMap.isEmpty()) {             return;         }         Set<Map.Entry<String,Method>> entrys = methodMap.entrySet();                  for(Map.Entry<String,Method> entry:entrys){             Method m = entry.getValue();             try {                 handleSheduledAnn(methodMap,m);             } catch (Exception e) {                 e.printStackTrace();                 continue;             }         }     }           private void handleSheduledAnn(Map<String,Method> methodMap,Method m) throws Exception {         Class<?> clazz = m.getDeclaringClass();         String name = m.getName();         Scheduled sAnn = m.getAnnotation(Scheduled.class);         String cron = sAnn.cron();         String parent = sAnn.parent();                  if(StringUtils.isEmpty(parent)) {             if(!taskIdMap.containsKey(clazz.getName() + Delimiters.DOT + name)) {                 Long taskId = taskExecutor.addTask(name, cron, new Invocation(clazz, name, new Class[]{}, new Object[]{}));                 taskIdMap.put(clazz.getName() + Delimiters.DOT + name, taskId);             }         } else {             String parentMethodName = parent.lastIndexOf(Delimiters.DOT) == -1 ? clazz.getName() + Delimiters.DOT + parent : parent;             Long parentTaskId = taskIdMap.get(parentMethodName);             if(parentTaskId == null) {                 Method parentMethod = methodMap.get(parentMethodName);                 handleSheduledAnn(methodMap,parentMethod);                                  parentTaskId = taskIdMap.get(parentMethodName);             }             if(parentTaskId != null && !taskIdMap.containsKey(clazz.getName() + Delimiters.DOT + name)) {                 Long taskId = taskExecutor.addChildTask(parentTaskId, name, cron, new Invocation(clazz, name, new Class[]{}, new Object[]{}));                 taskIdMap.put(clazz.getName() + Delimiters.DOT + name, taskId);             }          }       } }

上述代码就完成了spring初始化完成后加载了自己的自定义任务调度的注解,并且也受spring的调度开关@EnableScheduling的控制,实现无缝整合到spring或者springboot中去,达到了我这种的懒人的要求。

好了其实写这个框架差不多就用了5天业余时间,估计会有一些隐藏的坑,不过明显的坑我自己都解决了,开源出来的目的既是为了抛砖引玉,也为了广大屌丝程序员提供一种新的思路,希望对大家有所帮助,同时也希望大家多帮忙找找bug,一起来完善这个东西,大神们请忽略。

极牛网精选文章《实现最简单的分布式任务调度框架》文中所述为作者独立观点,不代表极牛网立场。如若转载请注明出处:https://geeknb.com/6404.html

(34)
打赏 微信公众号 微信公众号 微信小程序 微信小程序
主编的头像主编认证作者
上一篇 2019年11月15日 上午9:00
下一篇 2019年11月15日 上午10:41

相关推荐

发表回复

登录后才能评论
扫码关注
扫码关注
分享本页
返回顶部