面试官:讲讲你对ConcurrentHashMap的理解

Published on with 0 views and 0 comments

为什么要用 ConcurrentHashMap?

HashMap -> 非线程安全的

HashTable -> synchronized(偏向锁、轻量级锁(CAS)),锁的粒度太粗

ConcurrentHashMap -> 锁的粒度细,而且有很多优化操作在里面,比如它的并发扩容、高低位迁移、红黑树、链表等等。

ConcurrentHashMap 的使用

jdk1.8 的 map 引入了新的几个方法:

computeIfAbsent

computeIfPresent

compute(computeIfAbsent 和 computeIfPresent 两者的结合)

merge(可用于计数)

ConcurrentHashMap 的存储结构

链表用来解决 hash 冲突问题,红黑树用来解决链表过长的问题。
image20220401145445jvn90z120220402222819mbygy7y.png

put 方法源码分析
final V putVal(K key, V value, boolean onlyIfAbsent) {
        if (key == null || value == null) throw new NullPointerException();
	//计算哈希值
        int hash = spread(key.hashCode());
        int binCount = 0;
	//自旋,保证cas成功
        for (Node<K,V>[] tab = table;;) {
            Node<K,V> f; int n, i, fh;
            if (tab == null || (n = tab.length) == 0)
		//初始化数组
                tab = initTable();
	    //计算数组下标位置,如果这个下标没有存放node的话
            else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
		//如果cas成功了就直接结束
                if (casTabAt(tab, i, null,
                             new Node<K,V>(hash, key, value, null)))
                    break;                   // no lock when adding to empty bin
            }
	    //MOVED表示当前节点正在扩容时迁移
            else if ((fh = f.hash) == MOVED)
		//帮忙扩容
                tab = helpTransfer(tab, f);
            else {
                V oldVal = null;
		//锁住当前node节点,避免线程冲突问题,优化体现在这里,粒度更细
                synchronized (f) {
		    //重新判断,避免在加锁前f发生了变化。
                    if (tabAt(tab, i) == f) {
			//这个针对链表
                        if (fh >= 0) {
                            binCount = 1;
			    //遍历整个链表,binCount 是步长。
                            for (Node<K,V> e = f;; ++binCount) {
                                K ek;
				//判断hash相同,并且key也相同,则覆盖
                                if (e.hash == hash &&
                                    ((ek = e.key) == key ||
                                     (ek != null && key.equals(ek)))) {
                                    oldVal = e.val;
                                    if (!onlyIfAbsent)
                                        e.val = value;
                                    break;
                                }
				//如果hash不存在,就遍历找到最后一个节点e,然后把当前的key/value添加到链表中,尾插法
                                Node<K,V> pred = e;
                                if ((e = e.next) == null) {
                                    pred.next = new Node<K,V>(hash, key,
                                                              value, null);
                                    break;
                                }
                            }
                        }
			//这个针对红黑树
                        else if (f instanceof TreeBin) {
                            Node<K,V> p;
                            binCount = 2;
                            if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,
                                                           value)) != null) {
                                oldVal = p.val;
                                if (!onlyIfAbsent)
                                    p.val = value;
                            }
                        }
                    }
                }
		//到此处,binCount 相当于统计了链表的长度
                if (binCount != 0) {
		    //如果链表的长度大于等于8,则会调用treeifyBin(tab, i)方法,根据阈值来判断转化为红黑树还是扩容
		    //当数组长度大于64时,才回去做红黑树的转化
                    if (binCount >= TREEIFY_THRESHOLD)
                        treeifyBin(tab, i);
                    if (oldVal != null)
                        return oldVal;
                    break;
                }
            }
        }
        addCount(1L, binCount);
        return null;
    }
数组初始化 initTable()

有可能多个线程去调用 initTable()方法去初始化,用 cas 加锁就行了,成功一次就行了。

private transient volatile int sizeCtl;

这个 sizeCtl 有不同的值,而且会经常用到

默认情况下为 0

当数组初始化过程中时它为-1

当数组初始化完成后它为下次扩容的阈值即 0.75*n

扩容中时它的值会小于-1,扩容完成后又会变成下次扩容的阈值即 0.75*n

private final Node<K,V>[] initTable() {
        Node<K,V>[] tab; int sc;
	//只要table没有初始化就不断循环
        while ((tab = table) == null || tab.length == 0) {
	    //sizeCtl相当于aqs中的state,就是配合cas的一个变量值,用于标记抢到了锁
            if ((sc = sizeCtl) < 0)
                Thread.yield(); // lost initialization race; just spin
	    //通过cas来占用一个锁的标记
            else if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
                try {
                    if ((tab = table) == null || tab.length == 0) {
			//n默认=16
                        int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
                        @SuppressWarnings("unchecked")
                        Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
                        table = tab = nt;
			//sc = 16 - (16 >> 2),这里的sc是扩容的阈值,即12
                        sc = n - (n >>> 2);
                    }
                } finally {
                    sizeCtl = sc;
                }
                break;
            }
        }
        return tab;
    }
扩容 treeifyBin()
private final void treeifyBin(Node<K,V>[] tab, int index) {
        Node<K,V> b; int n, sc;
        if (tab != null) {
	    //如果tab的长度小于64,则进行扩容
            if ((n = tab.length) < MIN_TREEIFY_CAPACITY)
                tryPresize(n << 1);
	    //否则进行红黑树的转化
            else if ((b = tabAt(tab, index)) != null && b.hash >= 0) {
                synchronized (b) {
                    if (tabAt(tab, index) == b) {
                        TreeNode<K,V> hd = null, tl = null;
                        for (Node<K,V> e = b; e != null; e = e.next) {
                            TreeNode<K,V> p =
                                new TreeNode<K,V>(e.hash, e.key, e.val,
                                                  null, null);
                            if ((p.prev = tl) == null)
                                hd = p;
                            else
                                tl.next = p;
                            tl = p;
                        }
                        setTabAt(tab, index, new TreeBin<K,V>(hd));
                    }
                }
            }
        }
    }
扩容方法 tryPresize()

并发扩容

扩容的本质是 16->32,将数组扩容一倍,然后将老数组的数据迁移到新的数组

private final void tryPresize(int size) {
	//c是扩容后数组应该为多大
	//MAXIMUM_CAPACITY >>> 1 数组最大长度无符号右移1位,即二分之一大小,536870912
	//tableSizeFor会将size自动设置成16的倍数
        int c = (size >= (MAXIMUM_CAPACITY >>> 1)) ? MAXIMUM_CAPACITY :
            tableSizeFor(size + (size >>> 1) + 1);
        int sc;
	//根据sizeCtl判断,数组开始扩容时就退出while循环
        while ((sc = sizeCtl) >= 0) {
            Node<K,V>[] tab = table; int n;
	    //如果为空就初始化数组,跟之前的initTable()方法一样
            if (tab == null || (n = tab.length) == 0) {
		//初始容量和扩容的目标容量,谁最大选谁。
                n = (sc > c) ? sc : c;
                if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
                    try {
                        if (table == tab) {
                            @SuppressWarnings("unchecked")
                            Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
                            table = nt;
                            sc = n - (n >>> 2);
                        }
                    } finally {
                        sizeCtl = sc;
                    }
                }
            }
	    //如果已经是最大容量了,直接返回
            else if (c <= sc || n >= MAXIMUM_CAPACITY)
                break;
            else if (tab == table) {
		//扩容戳. 保证当前扩容范围的唯一性
                int rs = resizeStamp(n);
		//判断sc<0,第一次扩容的时候,不会走这段逻辑,因为只有在扩容中的时候sc才会小于0变成-1
                if (sc < 0) {
                    Node<K,V>[] nt;
		    //表示扩容结束
                    if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                        sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
                        transferIndex <= 0)
                        break;
		    //高16位表示当前的扩容标记,保证唯一性
		    //低16位表示当前扩容的线程数量
		    //每增加一个扩容线程,就会在低16位+1
                    if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
                        transfer(tab, nt);
                }
		//第一次扩容走这段逻辑
                else if (U.compareAndSwapInt(this, SIZECTL, sc,
                                             (rs << RESIZE_STAMP_SHIFT) + 2))
                    transfer(tab, null);
            }
        }
    }
实现数据转移 transfer()

重点要理解如何实现多个线程的数组进行数据迁移
image2022040117080493c2d9620220402222819sbjww0l.png

private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
        int n = tab.length, stride;
	//计算每个线程处理数据的区间大小,默认最小是16,当数组长度大时,会扩大区间大小
        if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
            stride = MIN_TRANSFER_STRIDE; // subdivide range
	//nextTab表示扩容之后的数组
        if (nextTab == null) {            // initiating
            try {
                @SuppressWarnings("unchecked")
		//在原来的基础上扩大两倍
                Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
                nextTab = nt;
            } catch (Throwable ex) {      // try to cope with OOME
                sizeCtl = Integer.MAX_VALUE;
                return;
            }
            nextTable = nextTab;
            transferIndex = n;
        }
        int nextn = nextTab.length;
	//fwd用来表示已经迁移完的状态,如果某个old数组的节点完成了迁移,则需要更改fwd
        ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
        boolean advance = true;
        boolean finishing = false; // to ensure sweep before committing nextTab
	//for循环自旋
        for (int i = 0, bound = 0;;) {
            Node<K,V> f; int fh;
            while (advance) {
                int nextIndex, nextBound;
                if (--i >= bound || finishing)
                    advance = false;
                else if ((nextIndex = transferIndex) <= 0) {
                    i = -1;
                    advance = false;
                }
		//修改TRANSFERINDEX,区间的计算
		//假设数组长度是32,第一次是[16(nextBound),31(i)],第二次是[0,15]
                else if (U.compareAndSwapInt
                         (this, TRANSFERINDEX, nextIndex,
                          nextBound = (nextIndex > stride ?
                                       nextIndex - stride : 0))) {
                    bound = nextBound;
                    i = nextIndex - 1;
                    advance = false;
                }
            }
	    //是否扩容结束
            if (i < 0 || i >= n || i + n >= nextn) {
                int sc;
                if (finishing) {
                    nextTable = null;
                    table = nextTab;
                    sizeCtl = (n << 1) - (n >>> 1);
                    return;
                }
		//sc-1代表线程数-1
                if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
		    //sc-2是因为之前是+2所以要-2,用2是因为不想和其他状态值有冲突
                    if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
                        return;
                    finishing = advance = true;
                    i = n; // recheck before commit
                }
            }
	    //得到数组下标为i的值,如果为空说明不需要迁移直接修改为fwd表示迁移完成
            else if ((f = tabAt(tab, i)) == null)
                advance = casTabAt(tab, i, null, fwd);
	    //判断当前节点是否已经被处理过了,如果是,进入下一次遍历
            else if ((fh = f.hash) == MOVED)
                advance = true; // already processed
            else {
		//针对当前要去迁移的节点进行加锁,保证迁移过程中其他线程调用put时要等待
		//针对链表或者红黑树做不同处理
		//当数组扩容之后会出现数组元素位置发生变化的情况,因为15&hash和31&hash的结果不一定一样
                synchronized (f) {
                    if (tabAt(tab, i) == f) {
			//ln表示低位链表,hn表示高位链表
			//低位链表表示遍历到某一个链表节点时发现这个节点及其后方节点都不需要变,
			//直接将后面的所有节点变为一个链表
                        Node<K,V> ln, hn;
                        if (fh >= 0) {
			    //fh代表hash值,n代表数组长度
                            int runBit = fh & n;
                            Node<K,V> lastRun = f;
			    //遍历当前节点的链表,
                            for (Node<K,V> p = f.next; p != null; p = p.next) {
                                int b = p.hash & n;
                                if (b != runBit) {
                                    runBit = b;
                                    lastRun = p;
                                }
                            }
                            if (runBit == 0) {
                                ln = lastRun;
                                hn = null;
                            }
                            else {
                                hn = lastRun;
                                ln = null;
                            }
			    //遍历链表重新计算数组位置
                            for (Node<K,V> p = f; p != lastRun; p = p.next) {
                                int ph = p.hash; K pk = p.key; V pv = p.val;
                                if ((ph & n) == 0)
                                    ln = new Node<K,V>(ph, pk, pv, ln);
                                else
                                    hn = new Node<K,V>(ph, pk, pv, hn);
                            }
                            setTabAt(nextTab, i, ln);
                            setTabAt(nextTab, i + n, hn);
                            setTabAt(tab, i, fwd);
                            advance = true;
                        }
                        else if (f instanceof TreeBin) {
                            TreeBin<K,V> t = (TreeBin<K,V>)f;
                            TreeNode<K,V> lo = null, loTail = null;
                            TreeNode<K,V> hi = null, hiTail = null;
                            int lc = 0, hc = 0;
                            for (Node<K,V> e = t.first; e != null; e = e.next) {
                                int h = e.hash;
                                TreeNode<K,V> p = new TreeNode<K,V>
                                    (h, e.key, e.val, null, null);
                                if ((h & n) == 0) {
                                    if ((p.prev = loTail) == null)
                                        lo = p;
                                    else
                                        loTail.next = p;
                                    loTail = p;
                                    ++lc;
                                }
                                else {
                                    if ((p.prev = hiTail) == null)
                                        hi = p;
                                    else
                                        hiTail.next = p;
                                    hiTail = p;
                                    ++hc;
                                }
                            }
                            ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :
                                (hc != 0) ? new TreeBin<K,V>(lo) : t;
                            hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :
                                (lc != 0) ? new TreeBin<K,V>(hi) : t;
                            setTabAt(nextTab, i, ln);
                            setTabAt(nextTab, i + n, hn);
                            setTabAt(tab, i, fwd);
                            advance = true;
                        }
                    }
                }
            }
        }
    }
扩容完成之后统计个数 addCount()
private final void addCount(long x, int check) {
        CounterCell[] as; long b, s;
	//统计元素个数
        if ((as = counterCells) != null ||
            !U.compareAndSwapLong(this, BASECOUNT, b = baseCount, s = b + x)) {
            CounterCell a; long v; int m;
            boolean uncontended = true;
            if (as == null || (m = as.length - 1) < 0 ||
                (a = as[ThreadLocalRandom.getProbe() & m]) == null ||
                !(uncontended =
                  U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))) {
                fullAddCount(x, uncontended);
                return;
            }
            if (check <= 1)
                return;
            s = sumCount();
        }
	//是否要做扩容
        if (check >= 0) {
            Node<K,V>[] tab, nt; int n, sc;
            while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
                   (n = tab.length) < MAXIMUM_CAPACITY) {
                int rs = resizeStamp(n);
                if (sc < 0) {
                    if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                        sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
                        transferIndex <= 0)
                        break;
                    if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
                        transfer(tab, nt);
                }
                else if (U.compareAndSwapInt(this, SIZECTL, sc,
                                             (rs << RESIZE_STAMP_SHIFT) + 2))
                    transfer(tab, null);
                s = sumCount();
            }
        }
    }

fullAddCount()

private final void fullAddCount(long x, boolean wasUncontended) {
        int h;
        if ((h = ThreadLocalRandom.getProbe()) == 0) {
            ThreadLocalRandom.localInit();      // force initialization
            h = ThreadLocalRandom.getProbe();
            wasUncontended = true;
        }
        boolean collide = false;                // True if last slot nonempty
        for (;;) {
            CounterCell[] as; CounterCell a; int n; long v;
            if ((as = counterCells) != null && (n = as.length) > 0) {
                if ((a = as[(n - 1) & h]) == null) {
                    if (cellsBusy == 0) {            // Try to attach new Cell
                        CounterCell r = new CounterCell(x); // Optimistic create
                        if (cellsBusy == 0 &&
                            U.compareAndSwapInt(this, CELLSBUSY, 0, 1)) {
                            boolean created = false;
                            try {               // Recheck under lock
                                CounterCell[] rs; int m, j;
                                if ((rs = counterCells) != null &&
                                    (m = rs.length) > 0 &&
                                    rs[j = (m - 1) & h] == null) {
                                    rs[j] = r;
                                    created = true;
                                }
                            } finally {
                                cellsBusy = 0;
                            }
                            if (created)
                                break;
                            continue;           // Slot is now non-empty
                        }
                    }
                    collide = false;
                }
                else if (!wasUncontended)       // CAS already known to fail
                    wasUncontended = true;      // Continue after rehash
                else if (U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))
                    break;
                else if (counterCells != as || n >= NCPU)
                    collide = false;            // At max size or stale
                else if (!collide)
                    collide = true;
		//扩容counterCells 数组
                else if (cellsBusy == 0 &&
                         U.compareAndSwapInt(this, CELLSBUSY, 0, 1)) {
                    try {
                        if (counterCells == as) {// Expand table unless stale
                            CounterCell[] rs = new CounterCell[n << 1];
			    //遍历数组添加到新的数组中
                            for (int i = 0; i < n; ++i)
                                rs[i] = as[i];
                            counterCells = rs;
                        }
                    } finally {
                        cellsBusy = 0;
                    }
                    collide = false;
                    continue;                   // Retry with expanded table
                }
                h = ThreadLocalRandom.advanceProbe(h);
            }
	    //如果counterCells 为空,初始化counterCells 
	    //CELLSBUSY标识用来保证在初始化过程中的线程安全性
            else if (cellsBusy == 0 && counterCells == as &&
                     U.compareAndSwapInt(this, CELLSBUSY, 0, 1)) {
                boolean init = false;
                try {                           // Initialize table
                    if (counterCells == as) {
			//取模&1,将x保存到CounterCell的0下标或者1下标
                        CounterCell[] rs = new CounterCell[2];
                        rs[h & 1] = new CounterCell(x);
                        counterCells = rs;
                        init = true;
                    }
                } finally {
		    //释放锁
                    cellsBusy = 0;
                }
                if (init)
                    break;
            }
	    //兜底方案,如果CounterCell修改不成功,尝试修改baseCount
            else if (U.compareAndSwapLong(this, BASECOUNT, v = baseCount, v + x))
                break;                          // Fall back on using base
        }
    }

HashMap 里面是有一个成员变量 size 来统计个数

transient int size;

ConcurrentHashMap里面,保证最终一致性

如果竞争不激烈的情况下,直接用 cas 将 baseCount+1

private transient volatile long baseCount;//基本元素统计

如果竞争激烈的情况下,采用 counterCells 数组进行计数

counterCells 长度为 2,随机负载,落到 0 下标就对 0 下标的 value 进行 cas 操作,落到 1 下标就对 1 下标进行 cas 操作,

这样就将竞争程度下降了一倍,统计 size()的时候,遍历 counterCells 数组,将数组值进行累加,然后 baseCount+counterCells 数组累加的数。

如果 counterCells 长度为 2 不够,counterCells 会动态扩容。

private transient volatile CounterCell[] counterCells;
final long sumCount() {
        CounterCell[] as = counterCells; CounterCell a;
        long sum = baseCount;
        if (as != null) {
            for (int i = 0; i < as.length; ++i) {
                if ((a = as[i]) != null)
                    sum += a.value;
            }
        }
        return sum;
    }
总结
  • 计算 key 的哈希值

  • for 自旋保证 put 成功

    • 如果没有初始化就初始化 table

      • 有可能多个线程去调用 initTable()方法去初始化,用 cas 加锁就行了,成功一次就行了
    • 通过与哈希取模计算数组下标,如果下标节点为 null,就通过 cas 放进数组当前下标的位置

    • 如果当前下标有值,并且发现当前节点正在做扩容迁移操作,就去帮助扩容

    • 如果既有值,又没在扩容,就锁住这个数组下标节点,开始进行 put 操作

      • 第一种情况当前节点是一个链表

        • 遍历整个链表
        • 判断 hash 相同,并且 key 也相同,则覆盖
        • 如果 hash 不存在,此时已经遍历到了最后一个节点 e,然后把当前的 key/value 添加到链表 e 节点的后 i 面,尾插法
      • 第二种情况当前节点是红黑树

        • 将节点放入红黑树,具体怎么放的参考我另一篇同系列下的文章之红黑树
    • put 进去之后,会对链表长度进行判断,如果链表的长度大于等于 8,进行扩容或者转化为红黑树

      • 链表的扩容

        • 如果 tab 的长度小于 64,则调用 tryPresize()方法进行扩容

        • 链表的扩容的本质是 16->32,将数组扩容一倍,然后将老数组的数据迁移到新的数组

        • 如果为空就初始化数组,跟之前的 initTable()方法一样

        • 如果已经是最大容量了,直接返回

        • 判断 sizeCtl 是否小于 0,因为只有在扩容中的时候 sizeCtl 才会小于 0 变成-1,

        • 多线程扩容,高 16 位表示当前的扩容标记,保证唯一性,低 16 位表示当前扩容的线程数量,每增加一个扩容线程,就会在低 16 位 +1

          • 实现数据转移 transfer()

          • 计算每个线程处理数据的区间大小,默认最小是 16,当数组长度大时,会扩大区间大小

            • 链表的情况

              • 遍历旧链表,使用 hash&新数组长度重新计算数组下标位置,
              • ln 表示低位链表,hn 表示高位链表
              • 低位链表表示遍历到某一个链表节点时发现这个节点及其后方节点都不需要变,直接将后面的所有节点变为一个链表
            • 红黑树的情况

              • 左旋右旋...
  • 扩容完成之后,统计个数,table 的 size+1

    • HashMap 里面是有一个成员变量 size 来统计个数
    • 如果竞争不激烈的情况下,直接用 cas 将 baseCount+1
    • 如果竞争激烈的情况下,采用 counterCells 数组进行计数
    • counterCells 长度为 2,随机负载,落到 0 下标就对 0 下标的 value 进行 cas 操作,落到 1 下标就对 1 下标进行 cas 操作
    • 这样就将竞争程度下降了一倍,统计 size()的时候,遍历 counterCells 数组,将数组值进行累加,然后 baseCount+counterCells 数组累加的数。
    • 如果 counterCells 长度为 2 不够,counterCells 会动态扩容。

下面这两个连接可以再加上拳打

震惊!ConcurrentHashMap 里面也有死循环,作者留下的“彩蛋”了解一下? - 掘金
这道面试题我真不知道面试官想要的回答是什么


标题:面试官:讲讲你对ConcurrentHashMap的理解
作者:cuijianzhe
地址:https://cjzshilong.cn/articles/2022/04/03/1648949875595.html