【JAVA多线程】JDK线程同步工具:Phaser

目录

1.干什么的?

2.代码示例

2.1.线程间的协作

2.2.协作API

2.3.动态调整线程数

3.树形结构

4.state

5.Treiber Stack

6.核心源码


1.干什么的?

线程的同步无非就Semaphore、CountDownLatch、CyclicBarrier去应付的三大类情况了,其中线程之间有协作关系的是CountDownLatch、CyclicBarrier去对付的情况。线程间有协作关系的场景里线程数量是可能需要动态调整的,尤其是CyclicBarrier要面对的分阶段执行的场景,这个阶段可能是10条线程,下个阶段可能需要8条线程......Phaser就是用来干这事儿的。

Phaser:

  • 可以实现CountDownLatch、CyclicBarrier需要应付的线程之间协作的场景,也就是说可以实现这两者的效果。

  • 可以动态调整线程数量(核心能力)

除了以上外,Phaser还有一个核心点:

  • 支持任务之间存在前后依赖关系,比如B任务依赖于A任务的结果,支持B任务在A任务之前执行。

2.代码示例

2.1.线程间的协作

phaser实现CountDownLatch的效果,一条线程等待其它线程执行完:

public static void main(String[] args) throws InterruptedException {
        phaserDemo();
    }
​
    private static void phaserDemo1() {
        Phaser phaser = new Phaser(10);
        Thread threadParent = new Thread(() -> {
            //主线调用awaitAdvance程阻塞在当前轮次
            System.out.println(Thread.currentThread().getName()+"......awaitAdvance");
            phaser.awaitAdvance(phaser.getPhase());
            System.out.println(Thread.currentThread().getName()+"......wakeUp");
        });
        Thread threadChild = new Thread(() -> {
            //10条子线程随机到达
            ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(10, 20, 2L, TimeUnit.MINUTES, new ArrayBlockingQueue(10));
            for (int i = 0; i < threadPoolExecutor.getCorePoolSize(); i++) {
                threadPoolExecutor.execute(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            Thread.sleep(1000);
                            System.out.println(Thread.currentThread().getName() + "......arrive");
                            //子线程到达
                            phaser.arrive();
                        } catch (InterruptedException e) {
                            throw new RuntimeException(e);
                        }
                    }
                });
            }
            //关闭线程池,不然程序不会退出
            threadPoolExecutor.shutdown();
        });
        threadParent.start();
        threadChild.start();
    }

执行结果:

phaser实现CyclicBarrier的效果,线程之间相互协作:

private static void phaserDemo2(){
        Phaser phaser = new Phaser(10);
        ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(10, 20, 2L, TimeUnit.MINUTES, new ArrayBlockingQueue(10));
        for(int i=0;i<threadPoolExecutor.getCorePoolSize();i++){
            threadPoolExecutor.execute(new Runnable() {
                @Override
                public void run() {
                    try {
                        Thread.sleep(1000);
                        //线程之间相互到达、等待
                        System.out.println("phase"+phaser.getPhase()+"......"+Thread.currentThread().getName()+"......arriveAndAwaitAdvance");
                        phaser.arriveAndAwaitAdvance();
                        System.out.println(Thread.currentThread().getName()+"......wakeUp");
                    } catch (InterruptedException e) {
                        throw new RuntimeException(e);
                    }
                }
            });
        }
​
        for(int i=0;i<threadPoolExecutor.getCorePoolSize();i++){
            threadPoolExecutor.execute(new Runnable() {
                @Override
                public void run() {
                    try {
                        Thread.sleep(1000);
                        //线程之间相互到达、等待
                        System.out.println("phase"+phaser.getPhase()+"......"+Thread.currentThread().getName()+"......arriveAndAwaitAdvance");
                        phaser.arriveAndAwaitAdvance();
                        System.out.println(Thread.currentThread().getName()+"......wakeUp");
                    } catch (InterruptedException e) {
                        throw new RuntimeException(e);
                    }
                }
            });
        }
        threadPoolExecutor.shutdown();
    }

执行结果:

2.2.协作API

  • arrive和awaitAdvance

    • arrive,线程已到达

    • awaitAdvance,线程阻塞在当前轮次

  • arriveAndAwaitAdvance

    • 线程到达并等待其它参与协作的线程全部到达

  • arriveAndDeregister

    • 线程到达,并从phaser中注销,即协作的线程数量-1

2.3.动态调整线程数

phaser可以动态的调整参与协作的线程数量是其核心能力,调整的方式有如下:

  • 注册新参与者: 使用register() 方法来将协作线程数+1

  • 取消注册参与者: 使用arriveAndDeregister() 方法来表示当前线程到达并且phaser的协作线程数-1

  • 手动调整: 使用 bulkRegister(int parties) 方法一次注册多个参与者。 使用 bulkArriveAndDeregister(int parties) 方法一次取消注册多个参与者。

  • 查询参与者状态: 使用 getRegisteredParties() 获取注册的参与者数量。 使用 getArrivedParties() 获取已到达当前阶段的参与者数量。

以上都是Phaser的API,可以在任何阶段来进行线程的调整。

3.树形结构

前面我们说过了Phaser有一个核心能力是支持任务之间的前后依赖关系,如B任务依赖于A任务,那么A任务在B任务之前执行。

再说清楚一点:

Phaser支持在单轮次里用依赖关系来控制任务的执行顺序!

如何实现喃?其实很容易能想到,链表关系,但是链表关系很明显只支持单依赖,没办法支持多依赖,用树形结构就能支持多依赖了,B任务和C任务依赖于A任务,那么A任务是父节点,B任务和C任务是子节点。

Phaser允许组成树形结构:

Phaser phaser0=new Phaser(1);
        Phaser phaser1 = new Phaser(phaser0,1);
        Phaser phaser2 = new Phaser(phaser0,1);
        Phaser phaser3 = new Phaser(phaser1,1);
        Phaser phaser4 = new Phaser(phaser2,1);
        Phaser phaser5 = new Phaser(phaser2,1);

再推深一点:

先从构造方法进去,可以看到所有构造本质上都是调用的一个构造方法:

public Phaser(int parties) {
    this(null, parties);
}

这个构造方法里会完成state和Treiber Stack的初始化:

这里我们可以看到Phaser是允许传一个父Phaser进构造方法来组成树形结构的,不管组成树状结构与否,全局都用的一个Treiber Stack和state。

所以我们更能理解了,树形结构的存在就是为了以依赖关系来控制执行顺序,父子之间只有执行顺序的差异,资源都是共用的一套。

4.state

其中轮数叫phase。

Phaser给出了一堆方法用来获取这四部分的值:

5.Treiber Stack

phaser并没有依托于AQS来实现,所以自己实现了一套完整的线程阻塞唤醒逻辑,被阻塞的线程放在Treiber Stack中。

Treiber Stack是R.Kent Treibe r在其于1986年发表的论文Systems Programming:Coping with Para llelism 中首次提出。说白了就是一个用链表实现的栈。

Phaser内部实现了Treiber Stack。因为是链表结构,所以有Node节点,节点里面存线程:

由于是栈结构,所以Phaser里面只有一个头指针,永远的指向栈顶,只是为了减少并发冲突,这里定义了2个链表,也就是2个Treiber Sta ck。当phase为奇数轮的时候,阻塞线程放在oddQ里面;当phase为偶 数轮的时候,阻塞线程放在evenQ里面:

6.核心源码

核心源码无非就是阻塞等待的awaitAdvance方法和线程到达的arrive方法。

先来看arrive方法,arrive方法会去调用doArrive方法,默认传参ONE_ARRIVAL为1:

叫大模型帮我们给doArrive方法加上注释来读一下:

private int doArrive(int adjust) {
    final Phaser root = this.root;
    for (;;) {
        // 如果当前对象是根节点,则直接获取 state 字段;
        // 否则调用 reconcileState() 方法来获取正确的状态。
        long s = (root == this) ? state : reconcileState();
​
        // 提取当前阶段的值。
        int phase = (int)(s >>> PHASE_SHIFT);
​
        // 如果阶段小于 0,这通常意味着发生了错误,直接返回阶段值。
        if (phase < 0)
            return phase;
​
        // 获取状态中的计数部分。
        int counts = (int)s;
​
        // 计算未到达的参与者数量。
        int unarrived = (counts == EMPTY) ? 0 : (counts & UNARRIVED_MASK);
​
        // 如果没有未到达的参与者,则抛出异常。
        if (unarrived <= 0)
            throw new IllegalStateException(badArrive(s));
​
        // 使用 CAS 更新状态。adjust 参数决定了状态应该如何被调整。
        if (UNSAFE.compareAndSwapLong(this, stateOffset, s, s-=adjust)) {
            // 如果只有一个未到达的参与者,则执行以下逻辑:
            if (unarrived == 1) {
                // 获取下一阶段的基础状态。
                long n = s & PARTIES_MASK;
​
                // 提取下一阶段未到达的参与者数量。
                int nextUnarrived = (int)n >>> PARTIES_SHIFT;
​
                // 如果当前对象是根节点,则进行以下操作:
                if (root == this) {
                    // 如果 onAdvance 返回 true,则设置终止标志。
                    if (onAdvance(phase, nextUnarrived))
                        n |= TERMINATION_BIT;
                    // 如果下一阶段没有未到达的参与者,则设置空闲标志。
                    else if (nextUnarrived == 0)
                        n |= EMPTY;
                    // 否则,设置下一阶段未到达的参与者数量。
                    else
                        n |= nextUnarrived;
​
                    // 更新下一阶段的阶段号。
                    int nextPhase = (phase + 1) & MAX_PHASE;
                    n |= (long)nextPhase << PHASE_SHIFT;
​
                    // 再次使用 CAS 更新状态。
                    UNSAFE.compareAndSwapLong(this, stateOffset, s, n);
​
                    // 调用 releaseWaiters 方法释放等待者。
                    releaseWaiters(phase);
                }
                // 如果当前对象不是根节点且下一阶段没有未到达的参与者,则递归调用父节点的 doArrive 方法,并设置空闲标志。
                else if (nextUnarrived == 0) {
                    phase = parent.doArrive(ONE_DEREGISTER);
                    UNSAFE.compareAndSwapLong(this, stateOffset,
                                              s, s | EMPTY);
                }
                // 否则,递归调用父节点的 doArrive 方法。
                else
                    phase = parent.doArrive(ONE_ARRIVAL);
            }
            // 返回当前阶段号。
            return phase;
        }
    }
}

对上面的源码进行一下总结:

  • 未到达者unarrived==1,说明当前这条线程就是最后一条未到达线程,所以会进行一些列的资源操作,推进到下一轮。

  • 其中releaseWaiters方法会去唤醒阻塞在当前轮次的线程,也就是调用awaitAdvance方法的线程。

  • 先去让父Phaser去doArrive,因为Phaser的树形结构的存在是为了满足任务之间有依赖的情况,设计上是子Phaser对父Phaser有依赖,在组成树形结构的时候,依赖者为子,被依赖者为父。这样能在单轮次里面绝对控制住依赖关系的先后执行。

awaitAdvance的逻辑也不复杂,核心就是将线程加入Treiber Stack,然后阻塞该线程:

private int internalAwaitAdvance(int phase, QNode node) {
    // assert root == this;  // 确保当前对象是根节点
​
    releaseWaiters(phase-1);          // 清理旧的等待队列
​
    boolean queued = false;           // 表示节点是否已经被加入到等待队列中
    int lastUnarrived = 0;            // 用于增加自旋次数的变量
    int spins = SPINS_PER_ARRIVAL;    // 初始自旋次数
​
    long s;
    int p;
    while ((p = (int)((s = state) >>> PHASE_SHIFT)) == phase) {
        // 如果 node 为空,则表示当前线程正在以不可中断模式自旋等待
        if (node == null) {
            int unarrived = (int)s & UNARRIVED_MASK;
            if (unarrived != lastUnarrived &&
                (lastUnarrived = unarrived) < NCPU)
                spins += SPINS_PER_ARRIVAL;  // 增加自旋次数
​
            boolean interrupted = Thread.interrupted();  // 检查线程是否被中断
            if (interrupted || --spins < 0) {  // 如果线程被中断或自旋次数耗尽
                node = new QNode(this, phase, false, false, 0L);
                node.wasInterrupted = interrupted;  // 创建 QNode 并记录中断状态
            }
        }
        // 如果 node 不为空并且可以被释放,则退出循环
        else if (node.isReleasable())
            break;
        // 如果 node 还没有被加入到队列中
        else if (!queued) {
            AtomicReference<QNode> head = (phase & 1) == 0 ? evenQ : oddQ;  // 获取当前阶段的队列头
            QNode q = node.next = head.get();  // 将 node 加入到队列中
            if ((q == null || q.phase == phase) &&
                (int)(state >>> PHASE_SHIFT) == phase)  // 避免加入过时的节点
                queued = head.compareAndSet(q, node);  // 尝试将 node 设置为队列头
        }
        // 如果 node 已经加入到队列中,则阻塞当前线程
        else {
            try {
                ForkJoinPool.managedBlock(node);
            } catch (InterruptedException ie) {
                node.wasInterrupted = true;  // 如果线程被中断,则记录中断状态
            }
        }
    }
​
    // 如果 node 不为空
    if (node != null) {
        if (node.thread != null)
            node.thread = null;  // 清理 thread 引用
​
        if (node.wasInterrupted && !node.interruptible)
            Thread.currentThread().interrupt();  // 如果线程被中断且不可中断,则重新设置中断标志
​
        if (p == phase && (p = (int)(state >>> PHASE_SHIFT)) == phase)
            return abortWait(phase);  // 如果阶段没有变化,则可能需要清理
    }
​
    releaseWaiters(phase);  // 清理等待队列
​
    return p;  // 返回当前阶段号
}

注意:

在前面聊Treiber的时候只说了它的栈的结构,但是没有说线程的阻塞,这里可以看到线程是通过调用ForkJoinPool.managedBlock来阻塞的,这里用ForkJoinPool来阻塞线程,并不是为了用到ForkJoinPool的高效调度能力,毕竟在这里所有线程都不是放在一个ForkJoinPool里的,只是用managedBlock将线程封装成了ForkJoinWorkerThread类型而已,ForkJoinWorkerThread能支持对中断进行响应,仅此而已,这里不要被迷惑了,觉得用到了ForkJoinPool核心目的是为了加快资源调度,其实不是这样的。

评论 28
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

_BugMan

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值