主键生成策略源码

开发者手册

KeyGenerateAlgorithm

全限定类名org.apache.shardingsphere.sharding.spi.KeyGenerateAlgorithm

分布式主键生成算法,已知实现

配置标识详细说明全限定类名
SNOWFLAKE基于雪花算法的分布式主键生成算法org.apache.shardingsphere.sharding.algorithm.keygen.SnowflakeKeyGenerateAlgorithm
UUID基于 UUID 的分布式主键生成算法org.apache.shardingsphere.sharding.algorithm.keygen.UUIDKeyGenerateAlgorithm
NANOID基于 NanoId 的分布式主键生成算法org.apache.shardingsphere.sharding.nanoid.algorithm.keygen.NanoIdKeyGenerateAlgorithm
COSID基于 CosId 的分布式主键生成算法org.apache.shardingsphere.sharding.cosid.algorithm.keygen.CosIdKeyGenerateAlgorithm
COSID_SNOWFLAKE基于 CosId 的雪花算法分布式主键生成算法org.apache.shardingsphere.sharding.cosid.algorithm.keygen.CosIdSnowflakeKeyGenerateAlgorithm



源码入口

package org.apache.shardingsphere.sharding.factory;

/**
 * Key generate algorithm factory.
 */
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class KeyGenerateAlgorithmFactory {
    //加载所有的主键生成策略
    static {
        ShardingSphereServiceLoader.register(KeyGenerateAlgorithm.class);
    }
    
    /**
     * 根据配置的主键生成策略,获取一个主键生成算法
     * 例如:spring.shardingsphere.rules.sharding.key-generators.usercourse_keygen.type=SNOWFLAKE
     */
    public static KeyGenerateAlgorithm newInstance(final AlgorithmConfiguration keyGenerateAlgorithmConfig) {
        return ShardingSphereAlgorithmFactory.createAlgorithm(keyGenerateAlgorithmConfig, KeyGenerateAlgorithm.class);
    }
    
    /**
     * 判断是否包含配置的算法
     */
    public static boolean contains(final String keyGenerateAlgorithmType) {
        return TypedSPIRegistry.findRegisteredService(KeyGenerateAlgorithm.class, keyGenerateAlgorithmType).isPresent();
    }
}



先来看主键生成策略是如何加载的:ShardingSphereServiceLoader.register(KeyGenerateAlgorithm.class);

public final class ShardingSphereServiceLoader {
    //线程安全Map,缓存所有主键生成器
    private static final Map<Class<?>, Collection<Object>> SERVICES = new ConcurrentHashMap<>();
    
    // 进入到register()方法中
    public static void register(final Class<?> serviceInterface) {
        if (!SERVICES.containsKey(serviceInterface)) {
            // 调用下方的load()方法
            SERVICES.put(serviceInterface, load(serviceInterface));
        }
    }
    
    //使用java的SPI机制加载接口的所有实现类
    private static <T> Collection<Object> load(final Class<T> serviceInterface) {
        Collection<Object> result = new LinkedList<>();
        for (T each : ServiceLoader.load(serviceInterface)) {
            result.add(each);
        }
        return result;
    }
}



实现

ShardingJDBC是通过SPI机制,加载org.apache.shardingsphere.sharding.spi.KeyGenerateAlgorithm接口的实现类,也就是上方表格中的内容

我们就可以直接在yml配置文件中进行配置分布式主键生成算法



接下来就以SNOWFLAKE雪花算法举例,下方就列举出了几个关键方法

// 实现了KeyGenerateAlgorithm接口
public final class SnowflakeKeyGenerateAlgorithm implements KeyGenerateAlgorithm, InstanceContextAware {
    
    // 在init方法中,会把我们yml配置文件中定义的props的配置项,保存在下面方法的形参中,并赋值给props成员属性
    // 其他地方再用props对象获取我们的配置项
    @Override
    public void init(final Properties props) {
        this.props = props;
        maxVibrationOffset = getMaxVibrationOffset(props);
        maxTolerateTimeDifferenceMilliseconds = getMaxTolerateTimeDifferenceMilliseconds(props);
    }
    
    // 实现KeyGenerateAlgorithm接口中的抽象方法generateKey()
    // 也就是在这个方法中具体生成分布式主键值的
    @Override
    public synchronized Long generateKey() {
        long currentMilliseconds = timeService.getCurrentMillis();
        if (waitTolerateTimeDifferenceIfNeed(currentMilliseconds)) {
            currentMilliseconds = timeService.getCurrentMillis();
        }
        if (lastMilliseconds == currentMilliseconds) {
            if (0L == (sequence = (sequence + 1) & SEQUENCE_MASK)) {
                currentMilliseconds = waitUntilNextTime(currentMilliseconds);
            }
        } else {
            vibrateSequenceOffset();
            sequence = sequenceOffset;
        }
        lastMilliseconds = currentMilliseconds;
        return ((currentMilliseconds - EPOCH) << TIMESTAMP_LEFT_SHIFT_BITS) | (getWorkerId() << WORKER_ID_LEFT_SHIFT_BITS) | sequence;
    }
    
    // getType() 方法中返回的字符串就是我们上方yml配置文件中type配置项填写的值
    @Override
    public String getType() {
        return "SNOWFLAKE";
    }
}



其他几个实现类也是一样的格式

在这里插入图片描述



扩展 自定义分布式主键生成策略

package com.hs.sharding.algorithm;

import com.google.common.base.Preconditions;
import org.apache.shardingsphere.infra.instance.InstanceContext;
import org.apache.shardingsphere.infra.instance.InstanceContextAware;
import org.apache.shardingsphere.sharding.algorithm.keygen.TimeService;
import org.apache.shardingsphere.sharding.spi.KeyGenerateAlgorithm;

import java.util.Calendar;
import java.util.Properties;

/**
 * 改进雪花算法,让他能够 %4 均匀分布。
 * @auth hs
 */
public final class MySnowFlakeAlgorithm implements KeyGenerateAlgorithm, InstanceContextAware {

    public static final long EPOCH;

    private static final String MAX_VIBRATION_OFFSET_KEY = "max-vibration-offset";

    private static final String MAX_TOLERATE_TIME_DIFFERENCE_MILLISECONDS_KEY = "max-tolerate-time-difference-milliseconds";

    private static final long SEQUENCE_BITS = 12L;

    private static final long WORKER_ID_BITS = 10L;

    private static final long SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1;

    private static final long WORKER_ID_LEFT_SHIFT_BITS = SEQUENCE_BITS;

    private static final long TIMESTAMP_LEFT_SHIFT_BITS = WORKER_ID_LEFT_SHIFT_BITS + WORKER_ID_BITS;

    private static final int DEFAULT_VIBRATION_VALUE = 1;

    private static final int MAX_TOLERATE_TIME_DIFFERENCE_MILLISECONDS = 10;

    private static final long DEFAULT_WORKER_ID = 0;

    private static TimeService timeService = new TimeService();

    public static void setTimeService(TimeService timeService) {
        MySnowFlakeAlgorithm.timeService = timeService;
    }

    private Properties props;

    @Override
    public Properties getProps() {
        return props;
    }

    private int maxVibrationOffset;

    private int maxTolerateTimeDifferenceMilliseconds;

    private volatile int sequenceOffset = -1;

    private volatile long sequence;

    private volatile long lastMilliseconds;

    private volatile InstanceContext instanceContext;

    static {
        Calendar calendar = Calendar.getInstance();
        calendar.set(2016, Calendar.NOVEMBER, 1);
        calendar.set(Calendar.HOUR_OF_DAY, 0);
        calendar.set(Calendar.MINUTE, 0);
        calendar.set(Calendar.SECOND, 0);
        calendar.set(Calendar.MILLISECOND, 0);
        EPOCH = calendar.getTimeInMillis();
    }

    @Override
    public void init(final Properties props) {
        this.props = props;
        maxVibrationOffset = getMaxVibrationOffset(props);
        maxTolerateTimeDifferenceMilliseconds = getMaxTolerateTimeDifferenceMilliseconds(props);
    }

    @Override
    public void setInstanceContext(final InstanceContext instanceContext) {
        this.instanceContext = instanceContext;
        if (null != instanceContext) {
            instanceContext.generateWorkerId(props);
        }
    }

    private int getMaxVibrationOffset(final Properties props) {
        int result = Integer.parseInt(props.getOrDefault(MAX_VIBRATION_OFFSET_KEY, DEFAULT_VIBRATION_VALUE).toString());
        Preconditions.checkArgument(result >= 0 && result <= SEQUENCE_MASK, "Illegal max vibration offset.");
        return result;
    }

    private int getMaxTolerateTimeDifferenceMilliseconds(final Properties props) {
        return Integer.parseInt(props.getOrDefault(MAX_TOLERATE_TIME_DIFFERENCE_MILLISECONDS_KEY, MAX_TOLERATE_TIME_DIFFERENCE_MILLISECONDS).toString());
    }

    @Override
    public synchronized Long generateKey() {
        long currentMilliseconds = timeService.getCurrentMillis();
        if (waitTolerateTimeDifferenceIfNeed(currentMilliseconds)) {
            currentMilliseconds = timeService.getCurrentMillis();
        }
        if (lastMilliseconds == currentMilliseconds) {
//            if (0L == (sequence = (sequence + 1) & SEQUENCE_MASK)) {
                currentMilliseconds = waitUntilNextTime(currentMilliseconds);
//            }
        } else {
            vibrateSequenceOffset();
//            sequence = sequenceOffset;
            sequence = sequence >= SEQUENCE_MASK ? 0:sequence+1;
        }
        lastMilliseconds = currentMilliseconds;
        return ((currentMilliseconds - EPOCH) << TIMESTAMP_LEFT_SHIFT_BITS) | (getWorkerId() << WORKER_ID_LEFT_SHIFT_BITS) | sequence;
    }

    private boolean waitTolerateTimeDifferenceIfNeed(final long currentMilliseconds) {
        if (lastMilliseconds <= currentMilliseconds) {
            return false;
        }
        long timeDifferenceMilliseconds = lastMilliseconds - currentMilliseconds;
        Preconditions.checkState(timeDifferenceMilliseconds < maxTolerateTimeDifferenceMilliseconds,
                "Clock is moving backwards, last time is %d milliseconds, current time is %d milliseconds", lastMilliseconds, currentMilliseconds);
        try {
            Thread.sleep(timeDifferenceMilliseconds);
        } catch (InterruptedException e) {
        }
        return true;
    }

    private long waitUntilNextTime(final long lastTime) {
        long result = timeService.getCurrentMillis();
        while (result <= lastTime) {
            result = timeService.getCurrentMillis();
        }
        return result;
    }

    @SuppressWarnings("NonAtomicOperationOnVolatileField")
    private void vibrateSequenceOffset() {
        sequenceOffset = sequenceOffset >= maxVibrationOffset ? 0 : sequenceOffset + 1;
    }

    private long getWorkerId() {
        return null == instanceContext ? DEFAULT_WORKER_ID : instanceContext.getWorkerId();
    }

    @Override
    public String getType() {
        return "MYSNOWFLAKE";
    }

    @Override
    public boolean isDefault() {
        return true;
    }
}



使用spi机制加载我们上方定义的类

在这里插入图片描述



yml配置文件中使用我们自己定义的类

在这里插入图片描述



分片算法

开发者手册

ShardingAlgorithm

全限定类名org.apache.shardingsphere.sharding.spi.ShardingAlgorithm

分片算法,已知实现

配置标识自动分片算法详细说明类名
MODY基于取模的分片算法ModShardingAlgorithm
HASH_MODY基于哈希取模的分片算法HashModShardingAlgorithm
BOUNDARY_RANGEY基于分片边界的范围分片算法BoundaryBasedRangeShardingAlgorithm
VOLUME_RANGEY基于分片容量的范围分片算法VolumeBasedRangeShardingAlgorithm
AUTO_INTERVALY基于可变时间范围的分片算法AutoIntervalShardingAlgorithm
INTERVALN基于固定时间范围的分片算法IntervalShardingAlgorithm
CLASS_BASEDN基于自定义类的分片算法ClassBasedShardingAlgorithm
INLINEN基于行表达式的分片算法InlineShardingAlgorithm
COMPLEX_INLINEN基于行表达式的复合分片算法ComplexInlineShardingAlgorithm
HINT_INLINEN基于行表达式的 Hint 分片算法HintInlineShardingAlgorithm
COSID_MODN基于 CosId 的取模分片算法CosIdModShardingAlgorithm
COSID_INTERVALN基于 CosId 的固定时间范围的分片算法CosIdIntervalShardingAlgorithm
COSID_INTERVAL_SNOWFLAKEN基于 CosId 的雪花ID固定时间范围的分片算法CosIdSnowflakeIntervalShardingAlgorithm



实现

这里就拿CLASS_BASED自定义分片策略来举例。我们之前的配置项如下所示。

这里就有一个问题,props的值我怎么知道写什么,我又怎么知道我自定义的类需要实现什么接口?

在这里插入图片描述



我们现在进入到CLASS_BASED分片算法的实现类中ClassBasedShardingAlgorithm去看看它的源码

public final class ClassBasedShardingAlgorithm implements StandardShardingAlgorithm<Comparable<?>>, ComplexKeysShardingAlgorithm<Comparable<?>>, HintShardingAlgorithm<Comparable<?>> {
    
    // 定义两个常量,我们会发现这里就是props中我们进行配置的值
    private static final String STRATEGY_KEY = "strategy";
    
    private static final String ALGORITHM_CLASS_NAME_KEY = "algorithmClassName";
    
    @Getter
    private Properties props;
    
    private ClassBasedShardingAlgorithmStrategyType strategy;
    
    private String algorithmClassName;
    
    private StandardShardingAlgorithm standardShardingAlgorithm;
    
    private ComplexKeysShardingAlgorithm complexKeysShardingAlgorithm;
    
    private HintShardingAlgorithm hintShardingAlgorithm;
    
    // init()方法中会获取到props对象,props对象中保存了我们yml配置文件中的配置内容
    // 这里就会取出来,赋值给 strategy  和  algorithmClassName 成员属性
    @Override
    public void init(final Properties props) {
        this.props = props;
        strategy = getStrategy(props);
        algorithmClassName = getAlgorithmClassName(props);
        initAlgorithmInstance(props);
    }
    
    private ClassBasedShardingAlgorithmStrategyType getStrategy(final Properties props) {
        String strategy = props.getProperty(STRATEGY_KEY);
        Preconditions.checkNotNull(strategy, "Properties `%s` can not be null when uses class based sharding strategy.", STRATEGY_KEY);
        return ClassBasedShardingAlgorithmStrategyType.valueOf(strategy.toUpperCase().trim());
    }
    
    private String getAlgorithmClassName(final Properties props) {
        String result = props.getProperty(ALGORITHM_CLASS_NAME_KEY);
        Preconditions.checkNotNull(result, "Properties `%s` can not be null when uses class based sharding strategy.", ALGORITHM_CLASS_NAME_KEY);
        return result;
    }
    
    // 这里就会判断 strategy 属性是哪一个  STANDARD、COMPLEX、HINT
    // 然后在进行具体的实例 StandardShardingAlgorithm、ComplexKeysShardingAlgorithm、HintShardingAlgorithm
    private void initAlgorithmInstance(final Properties props) {
        switch (strategy) {
            case STANDARD:
                standardShardingAlgorithm = ClassBasedShardingAlgorithmFactory.newInstance(algorithmClassName, StandardShardingAlgorithm.class, props);
                break;
            case COMPLEX:
                complexKeysShardingAlgorithm = ClassBasedShardingAlgorithmFactory.newInstance(algorithmClassName, ComplexKeysShardingAlgorithm.class, props);
                break;
            case HINT:
                hintShardingAlgorithm = ClassBasedShardingAlgorithmFactory.newInstance(algorithmClassName, HintShardingAlgorithm.class, props);
                break;
            default:
                break;
        }
    }
    
    
    // doSharding()方法,具体的分片算法逻辑
    @SuppressWarnings("unchecked")
    @Override
    public String doSharding(final Collection<String> availableTargetNames, final PreciseShardingValue<Comparable<?>> shardingValue) {
        return standardShardingAlgorithm.doSharding(availableTargetNames, shardingValue);
    }
    
    @SuppressWarnings("unchecked")
    @Override
    public Collection<String> doSharding(final Collection<String> availableTargetNames, final RangeShardingValue<Comparable<?>> shardingValue) {
        return standardShardingAlgorithm.doSharding(availableTargetNames, shardingValue);
    }
    
    @SuppressWarnings("unchecked")
    @Override
    public Collection<String> doSharding(final Collection<String> availableTargetNames, final ComplexKeysShardingValue<Comparable<?>> shardingValue) {
        return complexKeysShardingAlgorithm.doSharding(availableTargetNames, shardingValue);
    }
    
    @SuppressWarnings("unchecked")
    @Override
    public Collection<String> doSharding(final Collection<String> availableTargetNames, final HintShardingValue<Comparable<?>> shardingValue) {
        return hintShardingAlgorithm.doSharding(availableTargetNames, shardingValue);
    }
    
    // 返回trye为 CLASS_BASED  这里也就是和yml配置文件中的type对应上了
    @Override
    public String getType() {
        return "CLASS_BASED";
    }
}



其他的分片算法也是类似的实现

在这里插入图片描述



扩展 自定义分片算法

自定义一个java类,实现ShardingAlgorithm接口,或者是它的子接口 StandardShardingAlgorithm、ComplexKeysShardingAlgorithm、HintShardingAlgorithm都行,重写其中的doSharding()方法,我们自己指定分片逻辑



重写getType()方法,返回一个字符串,能够让我们在yml配置文件中进行配置

@Override
public String getType() {
    return "MY_COMPLEX_ALGORITHM";
}



例如我现在自定义的分片类如下

package com.hs.sharding.algorithm;

import com.google.common.base.Preconditions;
import com.google.common.collect.Range;
import org.apache.shardingsphere.sharding.api.sharding.standard.PreciseShardingValue;
import org.apache.shardingsphere.sharding.api.sharding.standard.RangeShardingValue;
import org.apache.shardingsphere.sharding.api.sharding.standard.StandardShardingAlgorithm;


import java.util.*;

/**
 * 自定义分片策略 , 我们这里实现标准的分片算法接口StandardShardingAlgorithm
 * 我这里是分片逻辑就是按照个数取模,在分发到sys_user1   sys_user2数据表中
 */
public class HsComplexAlgorithm implements StandardShardingAlgorithm<Long> {

    /**
     * 数据库个数
     */
    private final String DB_COUNT = "db-count";
    /**
     * 数据表个数
     */
    private final String TAB_COUNT = "tab-count";
    /**
     * 真实数据表前缀
     */
    private final String PERTAB = "pertab";


    private Integer dbCount;
    private Integer tabCount;
    private String pertab;

    private Properties props;

    @Override
    public void init(Properties props) {
        this.props = props;
        this.dbCount = getDbCount(props);
        this.tabCount = getTabCount(props);
        this.pertab = getPertab(props);
        // 校验条件
        Preconditions.checkState(null != pertab && !pertab.isEmpty(),
                "Inline hsComplex algorithm expression cannot be null or empty.");
    }

    /**
     * 精确查询分片执行接口(对应的sql是where ??=值)
     * @param collection 可用的分片名集合(分库就是库名,分表就是表名)
     * @param preciseShardingValue 分片键
     */
    @Override
    public String doSharding(Collection<String> collection, PreciseShardingValue<Long> preciseShardingValue) {

        Long uid = preciseShardingValue.getValue();
        String resultTableName = pertab + ((uid + 1) % (dbCount * tabCount) / tabCount + 1);
        if (collection.contains(resultTableName)){
            return resultTableName;
        }
        throw new UnsupportedOperationException("route: " + resultTableName + " is not supported, please check your config");
    }

    /**
     * 范围分片规则(对应的是where ??>='XXX' and ??<='XXX')
     * 范围查询分片算法(分片键涉及区间查询时会进入该方法进行分片计算)
     */
    @Override
    public Collection<String> doSharding(Collection<String> collection, RangeShardingValue<Long> rangeShardingValue) {
        List<String> result = new ArrayList<>();
        Range<Long> valueRange = rangeShardingValue.getValueRange();
        Long upperEndpoint = valueRange.upperEndpoint();
        Long aLong = valueRange.lowerEndpoint();

        // TODO 进行相应的分片判断
//        return result;

        return collection;
    }

    private String getPertab(Properties props) {
        return props.getProperty(PERTAB);
    }

    private Integer getDbCount(Properties props) {
        String count = props.getProperty(DB_COUNT);
        return count == null || count.isEmpty() ? 0 : Integer.valueOf(count);
    }

    private Integer getTabCount(Properties props) {
        String count = props.getProperty(TAB_COUNT);
        return count == null || count.isEmpty() ? 0 : Integer.valueOf(count);
    }


    @Override
    public Properties getProps() {
        return props;
    }

    @Override
    public String getType() {
        return "HS";
    }
}



需要添加一个SPI的配置文件org.apache.shardingsphere.sharding.spi.ShardingAlgorithm,在该文件中指定我们上方创建的java类

在这里插入图片描述



yml配置文件中进行相应的更改

在这里插入图片描述



踩的坑

我先是自定义的类实现的是ComplexKeysShardingAlgorithm接口,但是我们yml配置类中还是一直按照standard的配置,导致我自定义的类中的doSharding()方法所以就一直没有调用到

在这里插入图片描述



之后我修改了complex就能调用了

在这里插入图片描述



在配置还是standard时,我通过debug,发现init()getType()方法都能够调用,证明SPI机制相关的文件没问题。

我就想会不会是单分片键、精确查询、范围查询相关问题导致的?

我修改了实现接口,改为了StandardShardingAlgorithm,然后就进入了其中单分片键的doSharding()方法。最后就一点一点的排查,再到了这上面的配置

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐