Spring 中解决 RocketMQ 发消息与 MySQL 事务的一致性问题

场景

  1. 用户下单并支付；
  2. 支付成功后发送消息，为用户开通查看文章的权限。

    // 伪代码
    @Transactional(rollbackFor=Exception.class)
    public void pay(long uid, String orderNO) {
    Order order = orderService.selectOrder(uid, orderNO)
    if (order != null) {
    String status = "paid";
    orderDao.updateStatus(uid, orderNo, status);

    rocketMQTemplate.send("order:status", message(uid, orderNo, order.itemId, status));
    }
    }
    public class OrderStatusArticleListener implements RocketMQListener {
    public void onMessage(message) {
    Order order = orderService.selectOrder(message.uid, message.orderNo)
    if (order == null) {
    throw new RuntimeException("order not found. " + message.orderNo)
    }
    if (order.status != "paid") {
    throw new RuntimeException("order not paid. " + message.orderNo)
    }
    // 授权
    articleService.authorize(message.uid, message.itemId)
    }
    }

在上面的例子中，消费者查询订单时，可能会查到订单仍处于未支付状态。

为什么会这样呢？

这是因为我们在 Spring 事务内同步发送了消息：此时事务还没有提交，消息却已经到达消费者端并开始被消费，消费者自然读不到尚未提交的“已支付”状态。

解决思路：

  1. 增加一张消息表，在业务事务中同步写入消息记录，状态标记为待处理（NEW）；
  2. 事务提交成功后，再发送 MQ 消息；
  3. 发送成功后，将消息记录标记为处理完成（下面的实现中直接删除该记录）；发送失败则保留记录，由定时任务重试。

由于运行在 Spring 环境中，我们可以使用 Spring 的 TransactionSynchronizationManager#registerSynchronization 注册事务同步回调：

// Inside a transaction: defer the send until the transaction has committed.
// Outside one there is nothing to wait for, so send immediately instead of
// silently dropping the message.
if (TransactionSynchronizationManager.isSynchronizationActive()) {
    TransactionSynchronizationManager.registerSynchronization(new MQTransactionSynchronization(
            rocketMQTemplate, destination, message, timeout, delayLevel
    ));
} else {
    rocketMQTemplate.syncSend(destination, message, timeout, delayLevel);
}

我们自定义一个 TransactionSynchronization 的实现，命名为 MQTransactionSynchronization：

@Slf4j
public class MQTransactionSynchronization implements TransactionSynchronization {
private DataSource dataSource;
private ConnectionHolder connectionHolder;
private String id;
private RocketMQTemplate rocketMQTemplate;
private String destination;
private Message message;
private long timeout;
private int delayLevel;

public MQTransactionSynchronization(RocketMQTemplate rocketMQTemplate, String destination, Message message, long timeout, int delayLevel) {
this.rocketMQTemplate = rocketMQTemplate;
this.destination = destination;
this.message = message;
this.timeout = timeout;
this.delayLevel = delayLevel;
}

@Override
public void beforeCompletion() {}

@Override
public void beforeCommit(boolean readOnly) {
Map<Object, Object> resourceMap = TransactionSynchronizationManager.getResourceMap();
for (Map.Entry<Object, Object> entry : resourceMap.entrySet()) {
Object key = entry.getKey();
Object value = entry.getValue();
if (value instanceof ConnectionHolder) {
this.dataSource = (DataSource) key;
this.connectionHolder = (ConnectionHolder) value;
break;
}
}
if (connectionHolder == null) {
log.warn("connectionHolder is null");
return;
}
this.id = UUID.randomUUID().toString();
final String mqTemplateName = ApplicationContextUtils.findBeanName(rocketMQTemplate.getClass(), rocketMQTemplate);
MqMsgDao.insertMsg(connectionHolder, id, mqTemplateName, destination, message, timeout, delayLevel);
}
@Override
public void afterCommit() {
log.debug("afterCommit {}", TransactionSynchronizationManager.getCurrentTransactionName());
try {
rocketMQTemplate.syncSend(destination, message, timeout, delayLevel);
MqMsgDao.deleteMsgById(dataSource, this.id);
} catch (Exception e) {
log.error("mq send message failed. topic:[{}], message:[{}]", destination, message, e);
}
}
@Override
public void afterCompletion(int status) {
log.debug("afterCompletion {} : {}", TransactionSynchronizationManager.getCurrentTransactionName(), status);
rocketMQTemplate = null;
destination = null;
message = null;
connectionHolder = null;
dataSource = null;
id = null;
}
}
/**
 * Plain-JDBC access to the {@code tb_mq_msg} table that holds messages waiting to be sent.
 *
 * <p>A message is stored as JSON of the form {@code {"payload": ..., "headers": ...}} in the
 * {@code payload} column. NOTE(review): the payload's Java type is not recorded, so on reload a
 * POJO payload comes back as the Map that Jackson produces. Confirm that consumers tolerate this.
 */
@Slf4j
public class MqMsgDao {
    /** Status of a record that still has to be sent. */
    public static final String STATUS_NEW = "NEW";
    /** Records whose retry_times reached this value are no longer returned by {@link #listMsg}. */
    public static final Integer MAX_RETRY_TIMES = 5;
    private static final JsonMapper MAPPER = JsonMapper.builder()
            .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
            .enable(MapperFeature.PROPAGATE_TRANSIENT_MARKER)
            .build();

    // Explicit column list, so the insert does not depend on the table's physical column order.
    private static final String INSERT_SQL = "insert into tb_mq_msg"
            + " (id, status, mq_template_name, mq_destination, mq_timeout, mq_delay_level,"
            + " payload, retry_times, create_time, update_time)"
            + " values (?,?,?,?,?,?,?,?,?,?)";

    private MqMsgDao() {
        // static utility class
    }

    /**
     * Loads up to 100 records that are still {@link #STATUS_NEW} and below {@link #MAX_RETRY_TIMES}.
     * Uses (and closes) a connection obtained from {@code dataSource}.
     *
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public static List<MqMsg> listMsg(DataSource dataSource) {
        Connection conn = null;
        PreparedStatement ps = null;
        ResultSet rs = null;
        try {
            conn = dataSource.getConnection();
            ps = conn.prepareStatement("select * from tb_mq_msg where status = ? and retry_times < ? limit 100");
            int i = 0;
            ps.setObject(++i, STATUS_NEW);
            ps.setObject(++i, MAX_RETRY_TIMES);
            rs = ps.executeQuery();
            List<MqMsg> list = new ArrayList<>(100);
            while (rs.next()) {
                MqMsg mqMsg = new MqMsg();
                mqMsg.setId(rs.getString("id"));
                mqMsg.setStatus(rs.getString("status"));
                mqMsg.setMqTemplateName(rs.getString("mq_template_name"));
                mqMsg.setMqDestination(rs.getString("mq_destination"));
                mqMsg.setMqTimeout(rs.getLong("mq_timeout"));
                mqMsg.setMqDelayLevel(rs.getInt("mq_delay_level"));
                Map<String, Object> map = fromJson(rs.getString("payload"));
                // insertMsg stores MessageHeaders (a Map<String, Object>) under "headers".
                @SuppressWarnings("unchecked")
                Map<String, Object> headers = (Map<String, Object>) map.get("headers");
                GenericMessage<Object> message = new GenericMessage<>(map.get("payload"), headers);
                mqMsg.setMessage(message);
                mqMsg.setRetryTimes(rs.getInt("retry_times"));
                mqMsg.setCreateTime(rs.getTimestamp("create_time"));
                mqMsg.setUpdateTime(rs.getTimestamp("update_time"));
                list.add(mqMsg);
            }
            return list;
        } catch (SQLException e) {
            throw new RuntimeException(e);
        } finally {
            close(rs, ps, conn);
        }
    }

    /**
     * Inserts a new {@link #STATUS_NEW} record on the caller's connection. The connection is not
     * closed here, so the insert belongs to whatever transaction that connection is in.
     *
     * @throws RuntimeException if the insert fails or affects no rows
     */
    public static void insertMsg(ConnectionHolder connectionHolder,
                                 String id,
                                 String mqTemplateName,
                                 String mqDestination,
                                 Message message,
                                 long mqTimeout,
                                 int mqDelayLevel) {
        Connection connection = connectionHolder.getConnection();
        PreparedStatement ps = null;
        Map<String, Object> payload = new HashMap<>();
        payload.put("payload", message.getPayload());
        payload.put("headers", message.getHeaders());
        try {
            ps = connection.prepareStatement(INSERT_SQL);
            // java.sql.Timestamp is the standard JDBC mapping for DATETIME; java.util.Date is not.
            java.sql.Timestamp now = new java.sql.Timestamp(System.currentTimeMillis());
            int i = 0;
            ps.setObject(++i, id);
            ps.setObject(++i, STATUS_NEW);
            ps.setObject(++i, mqTemplateName);
            ps.setObject(++i, mqDestination);
            ps.setObject(++i, mqTimeout);
            ps.setObject(++i, mqDelayLevel);
            ps.setObject(++i, toJson(payload));
            ps.setObject(++i, 0);
            ps.setObject(++i, now);
            ps.setObject(++i, now);
            int affect = ps.executeUpdate();
            if (affect <= 0) {
                throw new RuntimeException("insert mq msg affect : " + affect);
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        } finally {
            close(ps);
        }
    }

    /**
     * Increments retry_times of a record and refreshes update_time.
     *
     * @throws RuntimeException if no row was updated or on SQL errors
     */
    public static void updateMsgRetryTimes(DataSource dataSource, String id) {
        Connection conn = null;
        PreparedStatement ps = null;
        try {
            conn = dataSource.getConnection();
            ps = conn.prepareStatement("update tb_mq_msg set retry_times = retry_times + 1, update_time = ? where id = ?");
            int i = 0;
            ps.setObject(++i, new java.sql.Timestamp(System.currentTimeMillis()));
            ps.setObject(++i, id);
            int affect = ps.executeUpdate();
            if (affect <= 0) {
                log.error("update mq msg retry_times failed. id:[{}]", id);
                throw new RuntimeException("update mq msg retry_times failed. id:" + id);
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        } finally {
            close(ps, conn);
        }
    }

    /**
     * Deletes a record by id.
     *
     * @throws RuntimeException if no row was deleted or on SQL errors
     */
    public static void deleteMsgById(DataSource dataSource, String id) {
        Connection conn = null;
        PreparedStatement ps = null;
        try {
            conn = dataSource.getConnection();
            ps = conn.prepareStatement("delete from tb_mq_msg where id = ?");
            int i = 0;
            ps.setObject(++i, id);
            int affect = ps.executeUpdate();
            if (affect <= 0) {
                log.error("delete mq msg failed. id:[{}]", id);
                throw new RuntimeException("delete mq msg failed. id:" + id);
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        } finally {
            close(ps, conn);
        }
    }

    /** Best-effort close: null entries are skipped and close failures are deliberately ignored. */
    private static void close(AutoCloseable... closeables) {
        if (closeables != null && closeables.length > 0) {
            for (AutoCloseable closeable : closeables) {
                if (closeable != null) {
                    try {
                        closeable.close();
                    } catch (Exception ignored) {
                        // nothing useful to do if closing a JDBC resource fails
                    }
                }
            }
        }
    }

    private static String toJson(Object payload) {
        try {
            return MAPPER.writeValueAsString(payload);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    private static Map<String, Object> fromJson(String payload) {
        try {
            return MAPPER.readValue(payload, new TypeReference<Map<String, Object>>() {
            });
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }
}

/**
 * Static access to the Spring {@link ApplicationContext} for code that is not itself a bean,
 * such as transaction synchronizations and the mq-msg retry thread.
 */
@Slf4j
@Component
public class ApplicationContextUtils implements ApplicationContextAware {
    // volatile: written by the container during startup, read from other threads
    // (e.g. the mq-msg scheduler thread).
    private static volatile ApplicationContext applicationContext;

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        ApplicationContextUtils.applicationContext = applicationContext;
        log.info("=== ApplicationContextUtils init ===");
    }

    /** Returns the context, or {@code null} before the container has injected it. */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    public static Object getBean(String name) {
        return requireContext().getBean(name);
    }

    public static <T> T getBean(Class<T> clazz) {
        return requireContext().getBean(clazz);
    }

    public static <T> T getBean(String name, Class<T> clazz) {
        return requireContext().getBean(name, clazz);
    }

    /**
     * Finds the bean name under which {@code obj} (compared by identity) is registered.
     *
     * @return the bean name, or {@code null} if {@code obj} is not a bean of type {@code clazz}
     */
    public static String findBeanName(Class<?> clazz, Object obj) {
        Map<String, ?> beans = requireContext().getBeansOfType(clazz);
        for (Map.Entry<String, ?> entry : beans.entrySet()) {
            if (entry.getValue() == obj) {
                return entry.getKey();
            }
        }
        return null;
    }

    /** Context accessor for internal lookups; fails with a clear message instead of an NPE. */
    private static ApplicationContext requireContext() {
        ApplicationContext ctx = applicationContext;
        if (ctx == null) {
            throw new IllegalStateException("ApplicationContextUtils has not been initialized yet");
        }
        return ctx;
    }
}

消息发送失败时，使用定时任务进行重试：

@Slf4j
@Component
pubilc class MqMsgSchedule implements InitializingBean {
private static final ScheduledThreadPoolExecutor EXECUTOR =
new ScheduledThreadPoolExecutor(1, new ThreadFactory() {
AtomicInteger threadCount = new AtomicInteger(0);
@Override
public Thread newThread(Runnable r) {
return new Thread(r, "mq-msg-" + threadCount.getAndIncrement() + "-" + r.hashCode());
}
}, new ThreadPoolExecutor.DiscardPolicy());
@Override
public void afterPropertiesSet() throws Exception {
EXECUTOR.scheduleAtFixedRate(new Runnable() {
@Override
public void run() {
retrySendTask();
}
}, 0, 5000, TimeUnit.MILLISECONDS);
}
public void retrySendTask() {
try {
Map<String, DataSource> beans = ApplicationContextUtils.getApplicationContext().getBeansOfType(DataSource.class);
for (Map.Entry<String, DataSource> entry : beans.entrySet()) {
List<MqMsg> mqMsgList = MqMsgDao.listMsg(entry.getValue());
for (MqMsg mqMsg : mqMsgList) {
if (mqMsg.getRetryTimes() >= MqMsgDao.MAX_RETRY_TIMES) {
log.error("mqMsg retry times reach {}, id:[{}]", MqMsgDao.MAX_RETRY_TIMES, mqMsg.getId());
} else {
RocketMQTemplate rocketMQTemplate = (RocketMQTemplate) ApplicationContextUtils.getBean(mqMsg.getMqTemplateName());
try {
rocketMQTemplate.syncSend(mqMsg.getMqDestination(),
mqMsg.getMessage(),
mqMsg.getMqTimeout(),
mqMsg.getMqDelayLevel());
MqMsgDao.deleteMsgById(entry.getValue(), mqMsg.getId());
} catch (Exception e) {
MqMsgDao.updateMsgRetryTimes(entry.getValue(), mqMsg.getId());
log.error("[task] mq send failed. mqMsg:[{}]", mqMsg, e);
}
}
}
}
} catch (Exception e) {
log.error("task error.", e);
}
}
}

对外提供的调用入口类：

/**
 * Entry point for sending a RocketMQ message that must not be consumed before the current
 * transaction commits.
 *
 * <p>Inside an active transaction synchronization, the send is deferred to an
 * {@link MQTransactionSynchronization}, which persists it before commit and sends it after commit.
 * Without one, the message is sent immediately.
 */
@Slf4j
public final class MQTransactionHelper {
    private MQTransactionHelper() {
        // static utility class
    }

    /**
     * Same as {@link #syncSend(RocketMQTemplate, String, Message, long, int)} using the
     * producer's configured send timeout and delay level 0.
     */
    public static void syncSend(final RocketMQTemplate rocketMQTemplate,
                                final String destination,
                                final Message message) {
        syncSend(rocketMQTemplate, destination, message,
                rocketMQTemplate.getProducer().getSendMsgTimeout(), 0);
    }

    /**
     * Sends {@code message} to {@code destination}, deferring the send until after commit when a
     * transaction synchronization is active.
     *
     * @param timeout    send timeout in milliseconds
     * @param delayLevel RocketMQ delay level passed through to {@code syncSend}
     */
    public static void syncSend(final RocketMQTemplate rocketMQTemplate,
                                final String destination,
                                final Message message,
                                final long timeout,
                                final int delayLevel) {
        if (TransactionSynchronizationManager.isSynchronizationActive()) {
            TransactionSynchronizationManager.registerSynchronization(new MQTransactionSynchronization(
                    rocketMQTemplate, destination, message, timeout, delayLevel
            ));
        } else {
            // No transaction to wait for: send now instead of silently dropping the message.
            log.debug("no active transaction synchronization, sending directly. destination:[{}]", destination);
            rocketMQTemplate.syncSend(destination, message, timeout, delayLevel);
        }
    }
}

数据库表结构：

-- Messages waiting to be sent to RocketMQ. A row is inserted inside the business
-- transaction and deleted once the message has been sent. Until then, the scheduled
-- task retries it while retry_times is below the application's maximum.
-- NOTE: column order is kept as-is for code that inserts without a column list.
CREATE TABLE `tb_mq_msg` (
`id` VARCHAR(64) NOT NULL,
`status` VARCHAR(20) NOT NULL COMMENT '事件状态(待发布NEW)',
`mq_template_name` VARCHAR(1000) NOT NULL,
`mq_destination` VARCHAR(1000) NOT NULL,
`mq_timeout` BIGINT NOT NULL,
`mq_delay_level` INT NOT NULL,
`payload` TEXT NOT NULL,
`retry_times` INT NOT NULL,
`create_time` DATETIME NOT NULL,
`update_time` DATETIME NOT NULL,
PRIMARY KEY (`id`),
-- Matches the retry query (status = ? AND retry_times < ?). Rows that exhausted their
-- retries stay NEW forever, so a status-only index cannot skip them.
KEY `idx_status_retry_times` (`status`, `retry_times`)
) ENGINE=INNODB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;

源码:https://github.com/jsbxyyx/rmq-transaction