从开源框架学一致性哈希算法
前言
你也知道我最近在看 Seata 这个分布式事务框架,还给官方提了一篇 博客,有兴趣可以看看。
当然,这篇文章也不是用来打广告的,事情的起因也是我在 Seata 的源码中发现了一个叫做 ConsistentHashLoadBalance 的东西,这个叫做“一致性哈希负载均衡”的小东西吸引力我的目光,所以,这不就来研究一下了吗。
哈希负载均衡
先来看看什么是“哈希负载均衡”。
假设你们公司有 5 个服务器节点,当一个用户请求到来的时候,它该去访问哪一台服务器呢?
选择哪个服务器节点处理用户请求 的这个过程就叫做“负载均衡”,“负载均衡”的算法有很多,像什么随机、轮询、加权轮询、最少连接、加权最少连接等。
这里我们所说的“哈希负载均衡”也是一种负载均衡算法,算法也很简单:
hash(param) % node_num -> node
就是对请求的某个参数进行哈希运算,将得到的结果 % 节点数量,就可以知道这个请求应该被哪个服务器节点处理了。
这里参数的选择有很多,比如客户端 ip、请求 url、或者是请求参数的某个变量。
我想你应该发现这种方式的弊端了,即:当节点上/下线导致节点数量变化时,请求的路由关系将被打乱,需要重新计算。
这一点或许在这种情况下(服务器节点路由的情况)还能接受,但换个例子,在分库分表的时候,如果我们也采用“哈希”算法进行路由,由于分片键是不会变的,所有一条数据在分表/库数量一定的情况下,它存储的表和库是固定的。
但是一旦我们需要对数据库进行扩容,那就有的玩了,我们就需要重新计算每一行数据的新的路由表/库,这你受得了吗?
但是现在大多数分库分表还是采用的这种“哈希”算法,你说这又是为什么呢?这里面的关键是分表/库的数量。
一般来说,我们会将分表或者分库的数量设定为 2^n,这样当进行翻倍扩容时,每一行数据要么在原来的位置,要么在原来的位置 + 原分表/库数量。
这里面的逻辑和 HashMap 的扩容是一致的,这样做了之后,就可以将数据迁移的单位由行变为表,甚至是库,这里给个关键词:“2n 平滑扩容”,你可以自行了解一下。
所以,总的来说,这种“哈希”算法的负载均衡策略会涉及到大量数据的 rehash,而要想尽量减少这种情况的发生,就要看看我们接下来提到的“一致性哈希算法”了。
一致性哈希算法
先整一套官方的解释:
一致性哈希算法在 1997 年由麻省理工学院提出,是一种特殊的哈希算法,在移除或者添加一个服务器时,能够尽可能小地改变已存在的服务请求与处理请求服务器之间的映射关系。一致性哈希解决了简单哈希算法在分布式哈希表(Distributed Hash Table,DHT)中存在的动态伸缩等问题。
那具体是怎么做的呢?看下去。
假设你有一个环,我们暂且称它为“哈希环”,它由 -2^31 ~ 2^31 - 1 的点组成,就像下面这样:
与此同时,我们也有 5 台服务器,我们需要将这 5 台服务器分布到这个哈希环上,比如我们对服务器的 ip + port 进行哈希计算,得到一个 int 值,这个值就对应了哈希环上的位置,这下你知道为啥这个环上的点的范围是 -2^31 ~ 2^31 - 1 了吧。
当用户请求来了之后,还是对某个参数(客户端 ip、请求 url、请求参数)进行哈希计算,得到一个 int 哈希值,这个哈希值在环上也有对应的位置,找到这个位置,它的顺时针第一个节点就是经过一致性哈希负载均衡(路由)之后的目标节点了。
如果出现节点上/下线的情况呢?参考上图,比如此时 4 号节点下线,那么原本由 4 号节点处理的 C 区域的请求就会交给 5 号节点来处理,再来,如果在 C 区域新加了一个 6 号节点呢?结果仅仅会导致 C 区域的请求被分摊给 4 号和 6 号节点。
所以,对于节点的上/下线,或者说动态扩容,影响的数据仅仅是一小部分的。
但是,这样还存在问题,比如下面的情况:
当这些服务器节点在哈希环上分布不均匀时,就会带来“数据倾斜”的问题,就上面的图来说,就是大部分的请求(D 区域)都被 1 号节点承担了。
所以,解决的办法就是加入一定数量的 虚拟节点,所谓虚拟节点,简单理解就是给每个服务器节点编个号,而进行哈希计算时带上这个编号,比如对 ip + port + number 进行哈希计算,指向哈希环上的一个位置。
所以,目前的情况就是一个物理节点对应多个虚拟节点,这些虚拟节点分布在哈希环上,这样即使物理节点数量有限,哈希环上的节点也会变成密集,进而就提高了哈希分布的均匀性。
Seata 如何落地
理论说清楚了,那如何进行落地呢,这就要看看开源框架是怎么做的了。
我们知道,在 Seata 中,如果部署了多个 TC 服务,那么 TM、RM 则会从所有的 TC 服务中选择一个节点进行远程 RPC,其中一种选择 TC 节点的策略就是基于 XID 进行一致性哈希负载均衡。
核心源码如下:
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.SortedMap;
import java.util.TreeMap;
public class ConsistentHashLoadBalance implements LoadBalance {
/**
* 虚拟节点数量,Seata 默认是 10 个
*/
private static final int VIRTUAL_NODES_NUM = 10;
@Override
public <T> T select(List<T> invokers, String xid) {
return new ConsistentHashSelector<>(invokers, VIRTUAL_NODES_NUM).select(xid);
}
private static final class ConsistentHashSelector<T> {
private final SortedMap<Long, T> virtualInvokers = new TreeMap<>();
private final HashFunction hashFunction = new SHA256Hash();
ConsistentHashSelector(List<T> invokers, int virtualNodes) {
for (T invoker : invokers) {
for (int i = 0; i < virtualNodes; i++) {
virtualInvokers.put(hashFunction.hash(invoker.toString() + i), invoker);
}
}
}
public T select(String key) {
SortedMap<Long, T> tailMap = virtualInvokers.tailMap(hashFunction.hash(key));
Long nodeHashVal = tailMap.isEmpty() ? virtualInvokers.firstKey() : tailMap.firstKey();
return virtualInvokers.get(nodeHashVal);
}
}
private static class SHA256Hash implements HashFunction {
MessageDigest instance;
public SHA256Hash() {
try {
instance = MessageDigest.getInstance("SHA-256");
} catch (NoSuchAlgorithmException e) { throw new IllegalStateException(e.getMessage(), e); }
}
@Override
public long hash(String key) {
instance.reset();
instance.update(key.getBytes());
byte[] digest = instance.digest(key.getBytes(StandardCharsets.UTF_8));
long hash = 0;
for (int i = 0; i < 8 && i < digest.length; i++) {
hash <<= 8;
hash |= digest[i] & 0xff;
}
return hash;
}
}
public interface HashFunction { long hash(String key); }
}
首先 Seata 采用的是 SHA-256 哈希算法,但是 SHA-256 算法产生的哈希是 256 位,所以这里只取了前 8 个字节。
下面的代码描述了进行一致性哈希的流程:
public <T> T select(List<T> invokers, String xid) {
return new ConsistentHashSelector<>(invokers, VIRTUAL_NODES_NUM).select(xid);
}
也就是,首先根据实际的节点集合(invokers)和给定的虚拟节点数量构建哈希环,然后在哈希环上查找 xid 应该路由的节点。
下面我们就来详细看看:
private final SortedMap<Long, T> virtualInvokers = new TreeMap<>();
private final HashFunction hashFunction = new SHA256Hash();
ConsistentHashSelector(List<T> invokers, int virtualNodes) {
// 枚举每一个节点
for (T invoker : invokers) {
// 构建虚拟节点
for (int i = 0; i < virtualNodes; i++) {
virtualInvokers.put(hashFunction.hash(invoker.toString() + i), invoker);
}
}
}
这里一个关键的集合是 TreeMap,它是一个有序表结构,底层由红黑树实现,可用于对 Key 进行排序,并且这里 TreeMap 的 Key 和哈希函数的返回值都是一个 long 类型,所以我们可以理解为整个哈希环点的范围是 -2^63 ~ 2^63 - 1。
简单测试一下吧:
import java.util.*;
public class LoadBalanceTest {
public static void main(String[] args) {
Node n1 = new Node("1.1.1.1", 8081);
Node n2 = new Node("2.2.2.2", 8082);
Node n3 = new Node("3.3.3.3", 8083);
Node n4 = new Node("4.4.4.4", 8084);
Node n5 = new Node("5.5.5.5", 8085);
List<Node> nodes = new ArrayList<>(Arrays.asList(n1, n2, n3, n4, n5));
ConsistentHashLoadBalance.ConsistentHashSelector<Node> selector =
new ConsistentHashLoadBalance.ConsistentHashSelector<>(nodes, 10);
SortedMap<Long, Node> virtualInvokers = selector.virtualInvokers;
int maxLength = virtualInvokers.keySet().stream()
.map(node -> String.valueOf(node).length())
.max(Integer::compare).orElse(0);
for (Map.Entry<Long, Node> entry : virtualInvokers.entrySet()) {
System.out.printf("%" + maxLength + "s: %s\n", entry.getKey(), entry.getValue());
}
}
static class Node {
private String ip;
private int port;
public Node(String ip, int port) {
this.ip = ip;
this.port = port;
}
@Override
public String toString() {
return "{ip='" + ip + '\'' + ", port=" + port + "}";
}
}
}
结果:
-8765100378696268369: {ip='2.2.2.2', port=8082}
-7309610807049095245: {ip='5.5.5.5', port=8085}
-7154494275960100736: {ip='1.1.1.1', port=8081}
-6499503475465088347: {ip='1.1.1.1', port=8081}
-5431081118827408279: {ip='4.4.4.4', port=8084}
-4976262585019046723: {ip='3.3.3.3', port=8083}
-4615142504536031791: {ip='1.1.1.1', port=8081}
-3927884372318786534: {ip='4.4.4.4', port=8084}
-3651902831028064285: {ip='2.2.2.2', port=8082}
-3222011103493482997: {ip='3.3.3.3', port=8083}
-3208981219407080150: {ip='3.3.3.3', port=8083}
-2938198424423537517: {ip='5.5.5.5', port=8085}
-2909268249183675921: {ip='2.2.2.2', port=8082}
-2087306210261307181: {ip='1.1.1.1', port=8081}
-2017735560730504623: {ip='3.3.3.3', port=8083}
-1836979937240161177: {ip='5.5.5.5', port=8085}
-1686538238368548846: {ip='2.2.2.2', port=8082}
-1539332660961059016: {ip='4.4.4.4', port=8084}
-273018401611904073: {ip='4.4.4.4', port=8084}
-270633464619608925: {ip='4.4.4.4', port=8084}
323040766389448319: {ip='1.1.1.1', port=8081}
328007061058545004: {ip='4.4.4.4', port=8084}
460746675228006480: {ip='4.4.4.4', port=8084}
1478789356108103694: {ip='5.5.5.5', port=8085}
1587256342345534307: {ip='5.5.5.5', port=8085}
2184318476501135040: {ip='3.3.3.3', port=8083}
2650482749933920208: {ip='2.2.2.2', port=8082}
2739324027403965993: {ip='5.5.5.5', port=8085}
2770497952838825852: {ip='4.4.4.4', port=8084}
2865147171845233319: {ip='2.2.2.2', port=8082}
3569253519880407118: {ip='2.2.2.2', port=8082}
3641213276315232007: {ip='2.2.2.2', port=8082}
3681628225113916252: {ip='4.4.4.4', port=8084}
4085289147239250468: {ip='2.2.2.2', port=8082}
4145873723130151432: {ip='1.1.1.1', port=8081}
4303476820718951657: {ip='1.1.1.1', port=8081}
4616179267042897811: {ip='3.3.3.3', port=8083}
4946644748369368976: {ip='5.5.5.5', port=8085}
5254334165476013142: {ip='2.2.2.2', port=8082}
5649678429821113543: {ip='5.5.5.5', port=8085}
6356958223267464975: {ip='3.3.3.3', port=8083}
6892812445865672169: {ip='5.5.5.5', port=8085}
7284331290158892296: {ip='3.3.3.3', port=8083}
7960744547236797364: {ip='1.1.1.1', port=8081}
8100283097378057042: {ip='3.3.3.3', port=8083}
8250635747428575132: {ip='4.4.4.4', port=8084}
8457071997675064186: {ip='1.1.1.1', port=8081}
8533842698737907148: {ip='1.1.1.1', port=8081}
8622533820610121188: {ip='5.5.5.5', port=8085}
8799698339406741248: {ip='3.3.3.3', port=8083}
接着是根据指定的 key 去路由到某一节点的逻辑,也就是 select 方法:
public T select(String objectKey) {
// 返回所有 >= hashFunction.hash(objectKey) 的键值对
SortedMap<Long, T> tailMap = virtualInvokers.tailMap(hashFunction.hash(objectKey));
Long nodeHashVal = tailMap.isEmpty() ? virtualInvokers.firstKey() : tailMap.firstKey();
return virtualInvokers.get(nodeHashVal);
}
这里有几个不太常见的 API。
首先 tailMap(key) 方法会返回所有 ≥ key 的一个新的 SortedMap,而 firstKey 方法则是找出当前 Map 中最小的 Key。
所以,select 方法的逻辑就应该很简单了,就是找出 ≥ key 的最小的键值对,如果给定的 key 就是最大的,那么就返回最小的键值对。
我们再以哈希环的方式来理解一下,如果从起点到终点(顺时针)增大,那么 select 是不是就是找顺时针方向的第一个 Key。
当然,这样的写法也是可以的:
public T select(String objectKey) {
Map.Entry<Long, T> entry = virtualInvokers.ceilingEntry(hashFunction.hash(objectKey));
if (entry == null) {
entry = virtualInvokers.firstEntry();
}
return entry.getValue();
}
这个 ceilingEntry 会返回 ≥ key 的第一个键值对。
还是上面的例子:
public static void main(String[] args) {
Node n1 = new Node("1.1.1.1", 8081);
Node n2 = new Node("2.2.2.2", 8082);
Node n3 = new Node("3.3.3.3", 8083);
Node n4 = new Node("4.4.4.4", 8084);
Node n5 = new Node("5.5.5.5", 8085);
List<Node> nodes = new ArrayList<>(Arrays.asList(n1, n2, n3, n4, n5));
ConsistentHashLoadBalance.ConsistentHashSelector<Node> selector =
new ConsistentHashLoadBalance.ConsistentHashSelector<>(nodes, 10);
String key = "key";
long hash = selector.hashFunction.hash(key);
System.out.println("hash: " + hash);
Node node = selector.select(key);
System.out.println(node);
}
结果:
hash: -4804355588323245734
{ip='1.1.1.1', port=8081}
可以看到,key 这个字符串的哈希值是 -4804355588323245734,我们假想中的这个 Key 的哈希值在哈希环的顺时针方向的第一个节点(比它大的第一个节点)是不是就是哈希值为 -4615142504536031791 的节点,ip 为 1.1.1.1,端口为 8081。
不过,最后还是要吐槽一下,Seata 的这种一致性哈希算法实现,在每一次路由节点时都需要构建哈希环,一个改进的地方就是,我们能不能将构建好的哈希环缓存下来,只有当节点上/下线,也就是 invokers 发生变化时再重新映射。
于是,我改出了下面的代码:
package org.apache.seata.discovery.loadbalance;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.*;
import org.apache.seata.common.loader.LoadLevel;
import org.apache.seata.config.ConfigurationFactory;
import static org.apache.seata.common.DefaultValues.VIRTUAL_NODES_DEFAULT;
/**
* The type consistent hash load balance.
*/
@LoadLevel(name = LoadBalanceFactory.CONSISTENT_HASH_LOAD_BALANCE)
public class ConsistentHashLoadBalance implements LoadBalance {
/**
* The constant LOAD_BALANCE_CONSISTENT_HASH_VIRTUAL_NODES.
*/
public static final String LOAD_BALANCE_CONSISTENT_HASH_VIRTUAL_NODES = LoadBalanceFactory.LOAD_BALANCE_PREFIX
+ "virtualNodes";
/**
* The constant VIRTUAL_NODES_NUM.
*/
private static final int VIRTUAL_NODES_NUM = ConfigurationFactory.getInstance().getInt(
LOAD_BALANCE_CONSISTENT_HASH_VIRTUAL_NODES, VIRTUAL_NODES_DEFAULT);
/**
* The ConsistentHashSelectorWrapper that caches a {@link ConsistentHashSelector}.
*/
private volatile ConsistentHashSelectorWrapper selectorWrapper;
@SuppressWarnings("unchecked")
@Override
public <T> T select(List<T> invokers, String xid) {
if (selectorWrapper == null) {
synchronized (this) {
if (selectorWrapper == null) {
selectorWrapper = new ConsistentHashSelectorWrapper(
new ConsistentHashSelector<>(invokers, VIRTUAL_NODES_NUM), invokers);
}
}
}
return (T) selectorWrapper.getSelector(invokers).select(xid);
}
@SuppressWarnings({"rawtypes", "unchecked"})
private static final class ConsistentHashSelectorWrapper {
private volatile ConsistentHashSelector selector;
// only shared with read
private volatile Set invokers;
public ConsistentHashSelectorWrapper(ConsistentHashSelector selector, List invokers) {
this.selector = selector;
this.invokers = new HashSet<>(invokers);
}
public ConsistentHashSelector getSelector(List invokers) {
if (!equals(invokers)) {
synchronized (this) {
if (!equals(invokers)) {
selector = new ConsistentHashSelector(invokers, VIRTUAL_NODES_NUM);
this.invokers = new HashSet<>(invokers);
}
}
}
return selector;
}
private boolean equals(List invokers) {
if (invokers.size() != this.invokers.size()) {
return false;
}
for (Object invoker : invokers) {
if (!this.invokers.contains(invoker)) {
return false;
}
}
return true;
}
}
private static final class ConsistentHashSelector<T> {
// ...
}
private static class SHA256Hash implements HashFunction {
// ...
}
public interface HashFunction { long hash(String key); }
}
并且鼓起勇气向官方提交了一个 PR:
年轻人,有想法就去做!!!