本篇文章基于 ShardingSphere-JDBC 5.3.0 版本,通过 Java API 的方式,对用户账户记录表(account_record)按月进行水平分片配置;该记录表目前每日增长约 50 万条。

引入依赖

<!-- ShardingSphere-JDBC core: provides ShardingSphereDataSourceFactory and the sharding rule / algorithm APIs used below. -->
<dependency>
<groupId>org.apache.shardingsphere</groupId>
<artifactId>shardingsphere-jdbc-core</artifactId>
<version>5.3.0</version>
</dependency>

<!-- SnakeYAML pinned explicitly to 1.33.
     NOTE(review): presumably needed to override an older snakeyaml version managed by the
     Spring Boot dependency BOM that shardingsphere-jdbc-core 5.3.0 does not work with;
     confirm against the project's dependency tree before changing. -->
<dependency>
<groupId>org.yaml</groupId>
<artifactId>snakeyaml</artifactId>
<version>1.33</version>
</dependency>

MyBatis 配置

对现有数据源进行改造,将数据源配置为 ShardingSphereDataSource。

@Configuration
@MapperScan(basePackages = "xxxx", sqlSessionTemplateRef = "mallSqlSessionTemplate")
public class MybatisMallConfig {

public static final String DATA_SOURCE_PREFIX = "spring.datasource.mall";

@Bean(name = "mallDataSource")
@ConfigurationProperties(prefix = DATA_SOURCE_PREFIX)
public DataSource mallDataSource() {
return DataSourceBuilder.create().build();
}

@Bean(name = "mallSqlSessionFactory")
public SqlSessionFactory mallSqlSessionFactory(@Qualifier("shardingMallDataSource") DataSource dataSource) throws Exception {
SqlSessionFactoryBean bean = new SqlSessionFactoryBean();
bean.setDataSource(dataSource);
bean.setMapperLocations(new PathMatchingResourcePatternResolver().getResources("classpath:mappers/mall/*.xml"));
return bean.getObject();
}

@Bean(name = "mallTransactionManager")
public DataSourceTransactionManager mallTransactionManager(@Qualifier("shardingMallDataSource") DataSource dataSource) {
return new DataSourceTransactionManager(dataSource);
}

@Bean(name = "mallSqlSessionTemplate")
public SqlSessionTemplate mallSqlSessionTemplate(@Qualifier("mallSqlSessionFactory") SqlSessionFactory sqlSessionFactory) throws Exception {
return new SqlSessionTemplate(sqlSessionFactory);
}

@Bean(name = "shardingMallDataSource")
public DataSource shardingMallDataSource() throws SQLException {
Properties prop = new Properties();
prop.setProperty("sql-show", "true");

Map<String, DataSource> dataSourceMap = new HashMap<>();
dataSourceMap.put("ds_mall0", mallDataSource());
return ShardingSphereDataSourceFactory.createDataSource(dataSourceMap, Collections.singleton(shardingRuleConfiguration()), prop);
}

@Bean
public ShardingRuleConfiguration shardingRuleConfiguration() {
ShardingRuleConfiguration shardingRuleConfig = new ShardingRuleConfiguration();

// add sharding algorithms
Properties addDateAlgProp = new Properties();
addDateAlgProp.put(DateShardingAlgorithm.PROP_DS_PREFIX, DATA_SOURCE_PREFIX);
shardingRuleConfig.getShardingAlgorithms().put("addDateAlg", new AlgorithmConfiguration("DATE_ALG", addDateAlgProp));
// primary key generator
Properties keyGeneratorProp = new Properties();
keyGeneratorProp.setProperty("worker-id", ShardingUtils.getWorkerId());
shardingRuleConfig.getKeyGenerators().put("snowflake", new AlgorithmConfiguration("SNOWFLAKE", keyGeneratorProp));

// table rule config
shardingRuleConfig.getTables().add(getAccountRecordTableRuleConfiguration());
return shardingRuleConfig;
}

private ShardingTableRuleConfiguration getAccountRecordTableRuleConfiguration() {
ShardingTableRuleConfiguration tableRuleConfiguration = new ShardingTableRuleConfiguration("account_record", "ds_mall0.account_record");
tableRuleConfiguration.setTableShardingStrategy(new StandardShardingStrategyConfiguration("add_date", "addDateAlg"));
tableRuleConfiguration.setKeyGenerateStrategy(new KeyGenerateStrategyConfiguration("ID", "snowflake"));
return tableRuleConfiguration;
}
}

Snowflake workerId 生成

/**
 * Derives the SNOWFLAKE {@code worker-id} for this process from the host MAC address and the JVM pid.
 */
@Slf4j
public class ShardingUtils {

    /** Bits reserved for the per-process part of the worker id. */
    private static final long WORKER_ID_BITS = 5L;

    /** Bits reserved for the host (MAC-derived) part of the worker id. */
    private static final long DATACENTER_ID_BITS = 5L;

    /** Largest host part: 2^5 - 1 = 31. */
    private static final long MAX_DATACENTER_ID = -1L ^ (-1L << DATACENTER_ID_BITS);

    /** Largest per-process part: 2^5 - 1 = 31. */
    private static final long MAX_WORKER_ID = -1L ^ (-1L << WORKER_ID_BITS);

    /**
     * Returns the worker id to configure on the SNOWFLAKE key generator.
     *
     * <p>The host part occupies the high 5 bits and the per-process part the low 5 bits, giving a
     * value in [0, 1023]. Previously only the per-process part (0..31) was returned, so the
     * MAC-derived part only fed a hash and collisions between instances were far more likely.
     *
     * @return the worker id as a decimal string
     */
    public static String getWorkerId() {
        long datacenterId = getDatacenterId(MAX_DATACENTER_ID);
        long workerId = getMaxWorkerId(datacenterId, MAX_WORKER_ID);
        return String.valueOf((datacenterId << WORKER_ID_BITS) | workerId);
    }

    /**
     * Computes the per-process part of the worker id.
     *
     * <p>Despite the historical name, this returns a worker id, not a maximum: the low 16 bits of
     * the hash code of {@code datacenterId + pid}, reduced modulo {@code maxWorkerId + 1}.
     *
     * @param datacenterId host part mixed into the hash
     * @param maxWorkerId  largest allowed result
     * @return a value in [0, maxWorkerId]
     */
    protected static long getMaxWorkerId(long datacenterId, long maxWorkerId) {
        StringBuilder mpid = new StringBuilder().append(datacenterId);
        String name = ManagementFactory.getRuntimeMXBean().getName();
        if (!name.isEmpty()) {
            // NOTE(review): the RuntimeMXBean name is usually "pid@hostname"; the part before '@' is taken as the pid.
            mpid.append(name.split("@")[0]);
        }
        return (mpid.toString().hashCode() & 0xffff) % (maxWorkerId + 1);
    }

    /**
     * Computes the host part of the worker id from the last two bytes of the MAC address of the
     * interface bound to the local host address.
     *
     * @param maxDatacenterId largest allowed result
     * @return a value in [0, maxDatacenterId]; 1 when no interface or MAC address is available;
     *         0 when the lookup fails with an exception
     */
    protected static long getDatacenterId(long maxDatacenterId) {
        try {
            InetAddress ip = InetAddress.getLocalHost();
            NetworkInterface network = NetworkInterface.getByInetAddress(ip);
            // Loopback/virtual interfaces (common in containers) return a null MAC; treat them like a missing interface
            byte[] mac = network == null ? null : network.getHardwareAddress();
            if (mac == null || mac.length < 2) {
                return 1L;
            }
            long id = ((0x000000FF & (long) mac[mac.length - 1])
                    | (0x0000FF00 & (((long) mac[mac.length - 2]) << 8))) >> 6;
            return id % (maxDatacenterId + 1);
        } catch (Exception e) {
            log.error("getDatacenterId error: ", e);
            return 0L;
        }
    }
}

自定义分片算法(自动建表)

自定义类,实现 StandardShardingAlgorithm 和 ShardingAutoTableAlgorithm 接口

/**
 * Monthly table-sharding algorithm.
 *
 * <p>A row whose sharding value falls in month {@code yyyyMM} is routed to the physical table
 * {@code <logicTable>_yyyyMM}. Tables that are not yet known are created on demand through
 * {@link ShardingAlgorithmUtils}. Registered via SPI under type {@code DATE_ALG}. It expects the
 * algorithm property {@value #PROP_DS_PREFIX} to hold the Spring property prefix of the physical
 * datasource.
 */
@Slf4j
public class DateShardingAlgorithm implements StandardShardingAlgorithm<Comparable<?>>, ShardingAutoTableAlgorithm {

    /** Algorithm property key holding the datasource property prefix. */
    public static final String PROP_DS_PREFIX = "ds";

    /** Month suffix format of physical tables, e.g. 202406. */
    private static final DateTimeFormatter TABLE_SHARD_DATE_FORMATTER = DateTimeFormatter.ofPattern("yyyyMM");

    /** Separator between the logic table name and the month suffix. */
    private static final String TABLE_SPLIT_SYMBOL = "_";

    /** Length of the yyyyMM suffix. */
    private static final int MONTH_SUFFIX_LENGTH = 6;

    private static final int AUTO_TABLES_AMOUNT = 1;

    private Properties props;

    /** Datasource property prefix read from {@link #PROP_DS_PREFIX} in {@link #init(Properties)}. */
    private String dsPrefix;

    @Override
    public int getAutoTablesAmount() {
        return AUTO_TABLES_AMOUNT;
    }

    /**
     * Routes a single value (=, IN) to its monthly table, creating the table if needed.
     */
    @Override
    public String doSharding(Collection<String> availableTargetNames, PreciseShardingValue<Comparable<?>> preciseShardingValue) {
        String logicTableName = preciseShardingValue.getLogicTableName();
        LocalDateTime dateTime = getLocalDateTime(preciseShardingValue.getValue());
        String actualTableName = toActualTableName(logicTableName, java.time.YearMonth.from(dateTime));
        // debug, not info: this runs for every routed statement (hundreds of thousands of inserts per day)
        log.debug("doSharding logicTable = {}, actualTable = {}", logicTableName, actualTableName);

        // NOTE(review): overwrites the framework-supplied collection so the computed table counts as
        // "available"; presumably ShardingSphere rejects targets outside it. Confirm before removing.
        availableTargetNames.clear();
        availableTargetNames.add(actualTableName);
        return ShardingAlgorithmUtils.getShardingTableAndCreate(logicTableName, actualTableName, dsPrefix);
    }

    /**
     * Routes a range condition (BETWEEN, &lt;, &gt;=, ...) to every monthly table it overlaps.
     *
     * <p>An open lower/upper side falls back to the oldest/newest month found among
     * {@code availableTargetNames}. An empty range (lower after upper) routes to no table.
     */
    @Override
    public Collection<String> doSharding(Collection<String> availableTargetNames, RangeShardingValue<Comparable<?>> rangeShardingValue) {
        String logicTableName = rangeShardingValue.getLogicTableName();
        Range<Comparable<?>> valueRange = rangeShardingValue.getValueRange();

        // Range#lowerEndpoint()/upperEndpoint() throw on an unbounded side, so read only the bounds that exist
        LocalDateTime min = valueRange.hasLowerBound()
                ? getLocalDateTime(valueRange.lowerEndpoint()) : getLowerEndpoint(availableTargetNames);
        LocalDateTime max = valueRange.hasUpperBound()
                ? getLocalDateTime(valueRange.upperEndpoint()) : getUpperEndpoint(availableTargetNames);

        // One table per calendar month: step month by month (stepping by minute was ~525k iterations per year)
        Set<String> actualTableNames = new LinkedHashSet<>();
        if (!min.isAfter(max)) {
            java.time.YearMonth last = java.time.YearMonth.from(max);
            for (java.time.YearMonth month = java.time.YearMonth.from(min); !month.isAfter(last); month = month.plusMonths(1)) {
                actualTableNames.add(toActualTableName(logicTableName, month));
            }
        }
        return ShardingAlgorithmUtils.getShardingTablesAndCreate(logicTableName, actualTableNames, dsPrefix);
    }

    @Override
    public Properties getProps() {
        return props;
    }

    @Override
    public void init(Properties properties) {
        this.props = properties;
        this.dsPrefix = properties.getProperty(PROP_DS_PREFIX);
    }

    @Override
    public String getType() {
        return "DATE_ALG";
    }

    /** Builds {@code <logicTable>_yyyyMM}. */
    private static String toActualTableName(String logicTableName, java.time.YearMonth month) {
        return logicTableName + TABLE_SPLIT_SYMBOL + month.format(TABLE_SHARD_DATE_FORMATTER);
    }

    /**
     * Converts a sharding value to a {@link LocalDateTime} in the system default zone.
     *
     * @throws IllegalArgumentException for unparsable strings or unsupported types
     */
    private static LocalDateTime getLocalDateTime(Comparable<?> value) {
        if (value instanceof LocalDateTime) {
            return (LocalDateTime) value;
        }
        if (value instanceof java.time.LocalDate) {
            return ((java.time.LocalDate) value).atStartOfDay();
        }
        if (value instanceof Date) {
            return toLocalDateTime((Date) value);
        }
        if (value instanceof String) {
            try {
                return toLocalDateTime(DateUtils.parseDate((String) value, "yyyy-MM-dd", "yyyy-MM-dd HH:mm:ss", "yyyyMMdd"));
            } catch (ParseException e) {
                throw new IllegalArgumentException(e);
            }
        }
        throw new IllegalArgumentException("datetime Only Support String, java.util.Date, java.time.LocalDate, java.time.LocalDateTime");
    }

    /** Goes through epoch millis because java.sql.Date#toInstant() throws UnsupportedOperationException. */
    private static LocalDateTime toLocalDateTime(Date date) {
        return java.time.Instant.ofEpochMilli(date.getTime()).atZone(ZoneId.systemDefault()).toLocalDateTime();
    }

    /** Start of the oldest month among tables carrying a yyyyMM suffix. */
    private LocalDateTime getLowerEndpoint(Collection<String> tableNames) {
        Optional<java.time.YearMonth> oldest = tableMonths(tableNames).min(Comparator.naturalOrder());
        if (!oldest.isPresent()) {
            log.error("getLowerEndpoint failure,tableName:{}", tableNames);
            throw new IllegalArgumentException("getLowerEndpoint failure");
        }
        return oldest.get().atDay(1).atStartOfDay();
    }

    /** End of the newest month among tables carrying a yyyyMM suffix, so the newest table stays in range. */
    private LocalDateTime getUpperEndpoint(Collection<String> tableNames) {
        Optional<java.time.YearMonth> newest = tableMonths(tableNames).max(Comparator.naturalOrder());
        if (!newest.isPresent()) {
            log.error("getUpperEndpoint failure,tableName:{}", tableNames);
            throw new IllegalArgumentException("getUpperEndpoint failure");
        }
        return newest.get().atEndOfMonth().atTime(java.time.LocalTime.MAX);
    }

    /** Months of the tables that carry a valid yyyyMM suffix; other names (e.g. the bare logic table) are skipped. */
    private static java.util.stream.Stream<java.time.YearMonth> tableMonths(Collection<String> tableNames) {
        return tableNames.stream()
                .map(DateShardingAlgorithm::parseTableMonth)
                .filter(Optional::isPresent)
                .map(Optional::get);
    }

    /** Parses the text after the last separator as yyyyMM, or returns empty when it is not one. */
    private static Optional<java.time.YearMonth> parseTableMonth(String tableName) {
        int separator = tableName.lastIndexOf(TABLE_SPLIT_SYMBOL);
        String suffix = separator < 0 ? "" : tableName.substring(separator + TABLE_SPLIT_SYMBOL.length());
        if (suffix.length() != MONTH_SUFFIX_LENGTH || !suffix.chars().allMatch(Character::isDigit)) {
            return Optional.empty();
        }
        try {
            return Optional.of(java.time.YearMonth.parse(suffix, TABLE_SHARD_DATE_FORMATTER));
        } catch (java.time.format.DateTimeParseException e) {
            return Optional.empty();
        }
    }
}

创建 ShardingAlgorithmUtils 工具类,实现当表在缓存中不存在时,自动创建表:

@Slf4j
public class ShardingAlgorithmUtils {

private static final Map<String, Set<String>> DS_TABLE_NAME_CACHE = new ConcurrentHashMap<>();

/** 表分片符号,例:account_record_202406 中,分片符号为 "_" */
private static final String TABLE_SPLIT_SYMBOL = "_";

private static final Environment ENV = SpringUtils.getApplicationContext().getEnvironment();


public static String getShardingTableAndCreate(String logicTableName, String actualTableName, String dsPrefix) {
synchronized (logicTableName.intern()) {
// get tableName from cache
Set<String> set = DS_TABLE_NAME_CACHE.get(dsPrefix);
if (CollectionUtils.isNotEmpty(set) && set.contains(actualTableName)) {
return actualTableName;
}

// create table && add cache
boolean result = false;
try {
result = createShardingTable(logicTableName, actualTableName, dsPrefix);
if (result) {
set.add(actualTableName);
}
} catch (Exception e) {
if (e instanceof SQLSyntaxErrorException && e.getMessage().contains("already exists")) {
set.add(actualTableName);
}
}
}
return actualTableName;
}

public static Set<String> getShardingTablesAndCreate(String logicTableName, Collection<String> actualTableNames, String dsPrefix) {
return actualTableNames.stream().map(o -> getShardingTableAndCreate(logicTableName, o, dsPrefix)).collect(Collectors.toSet());
}

/**
* load all table names
* @param dsPrefix datasource prefix
* @param logicTableNames logic table names
*/
public static void loadAllTableNameBySchema(List<String> dsPrefix, List<String> logicTableNames) {
DS_TABLE_NAME_CACHE.clear();

for (String prefix : dsPrefix) {

String jdbcUrl = ENV.getProperty(prefix + ".jdbc-url");
String username = ENV.getProperty(prefix + ".username");
String password = ENV.getProperty(prefix + ".password");

if (StringUtils.isEmpty(jdbcUrl) || StringUtils.isEmpty(username) || StringUtils.isEmpty(password)) {
log.error("jdbc properties invalid,URL:{}, username:{}, password:{}", jdbcUrl, username, password);
throw new IllegalArgumentException("jdbc properties invalid!");
}
try (Connection conn = DriverManager.getConnection(jdbcUrl, username, password);
Statement st = conn.createStatement()) {
Set<String> set = new HashSet<>();
for (String logicTableName : logicTableNames) {
try (ResultSet rs = st.executeQuery("SHOW TABLES LIKE '" + logicTableName + TABLE_SPLIT_SYMBOL + "%'")) {
while (rs.next()) {
String tableName = rs.getString(1);
// 匹配分表格式 例:^(t\_contract_\d{6})$
if (tableName != null && tableName.matches(String.format("^(%s\\d{6})$", logicTableName + TABLE_SPLIT_SYMBOL))) {
set.add(rs.getString(1));
}
}
DS_TABLE_NAME_CACHE.put(prefix, set);
}
}

} catch (SQLException e) {
log.error("Database connection failure:", e);
throw new IllegalArgumentException("Database connection failure!");
}
}

}

/**
* create table
*
* @param logicTableName logic table name
* @param actualTableName actual table name,eg:accountrecord_202405
* @param dsPrefix datasource prefix
* @return true-success,false-failure
*/
public static boolean createShardingTable(String logicTableName, String actualTableName, String dsPrefix) throws SQLException {
// 根据日期判断,当前月份之后分表不提前创建
String month = actualTableName.replace(logicTableName + TABLE_SPLIT_SYMBOL,"");
YearMonth shardingMonth = YearMonth.parse(month, DateTimeFormatter.ofPattern("yyyyMM"));
if (shardingMonth.isAfter(YearMonth.now())) {
return false;
}

synchronized (logicTableName.intern()) {
// 缓存中无此表,则建表并添加缓存
executeSql(Collections.singletonList("CREATE TABLE `" + actualTableName + "` LIKE `" + logicTableName + "`;"), dsPrefix);
}
return true;
}

/**
* execute sql
*
* @param sqlList sql list
* @param dsPrefix datasource
*/
private static void executeSql(List<String> sqlList, String dsPrefix) throws SQLException {
String jdbcUrl = ENV.getProperty(dsPrefix + ".jdbc-url");
String username = ENV.getProperty(dsPrefix + ".username");
String password = ENV.getProperty(dsPrefix + ".password");

if (StringUtils.isEmpty(jdbcUrl) || StringUtils.isEmpty(username) || StringUtils.isEmpty(password)) {
log.error("jdbc properties invalid,URL:{}, username:{}, password:{}", jdbcUrl, username, password);
throw new IllegalArgumentException("jdbc properties invalid!");
}
try (Connection conn = DriverManager.getConnection(jdbcUrl, username, password)) {
try (Statement st = conn.createStatement()) {
conn.setAutoCommit(false);
for (String sql : sqlList) {
st.execute(sql);
}
} catch (Exception e) {
conn.rollback();
log.error("table create failure:", e);
throw e;
}
} catch (SQLException e) {
log.error("database connection failure:", e);
throw e;
}
}
}

当 Spring Boot 程序启动时,读取所有已存在的分片表名,存入上面 ShardingAlgorithmUtils 中的 ConcurrentHashMap 缓存(DS_TABLE_NAME_CACHE)。

/**
 * Startup hook that preloads the names of existing monthly shard tables into the
 * ShardingAlgorithmUtils cache. Tables already in the cache are not re-created during routing.
 */
@Order(1)
@Component
public class ShardingTableRunner implements CommandLineRunner {

    /** Property prefix of the mall datasource whose tables are scanned. */
    private static final String MALL_DS_PREFIX = "spring.datasource.mall";

    /** Logic table whose {@code _yyyyMM} shard tables are loaded. */
    private static final String ACCOUNT_RECORD_TABLE = "account_record";

    @Override
    public void run(String... args) throws Exception {
        List<String> dsPrefixes = new ArrayList<>(1);
        dsPrefixes.add(MALL_DS_PREFIX);

        List<String> logicTableNames = new ArrayList<>(1);
        logicTableNames.add(ACCOUNT_RECORD_TABLE);

        ShardingAlgorithmUtils.loadAllTableNameBySchema(dsPrefixes, logicTableNames);
    }
}

通过 SPI 配置算法类

创建对应目录及文件:src/main/resources/META-INF/services/org.apache.shardingsphere.sharding.spi.ShardingAlgorithm

将自定义算法类路径写入文件

com.xxx.sharding.DateShardingAlgorithm