springboot分页查询并行优化实践

            ——基于异步优化与 MyBatis-Plus 分页插件思想的实践

适用场景

  • 数据量较大的单表分页查询
  • 较复杂的多表关联查询,包含group by等无法进行count优化较耗时的分页查询

技术栈

  • 核心框架:Spring Boot + MyBatis-Plus

  • 异步编程:JDK 8+ 的 CompletableFuture 

  • 数据库:MySQL 8.0

  • 线程池:自定义线程池管理并行任务(如 ThreadPoolTaskExecutor

实现思路

解决传统分页查询中 串行执行 COUNT 与数据查询 的性能瓶颈,通过 并行化 减少总耗时,同时兼容复杂查询场景(如多表关联、DISTINCT 等)

兼容mybatisPlus分页参数,复用 IPage 接口定义分页参数(当前页、每页条数),

借鉴 MyBatis-Plus 的 PaginationInnerInterceptor,通过实现 MyBatis 的 Interceptor 接口,拦截 Executor#query 方法,动态修改 SQL,

sql优化适配:COUNT 优化:自动移除 ORDER BY,保留 GROUP BY 和 DISTINCT(需包裹子查询),数据查询:保留完整 SQL 逻辑,仅追加 LIMIT 和 OFFSET。

直接上代码

使用简单

调用查询方法前赋值page对象属性total大于0数值则可进入自定义分页查询方案。

//示例代码 Page<User> page = new Page<>(1,10); page.setTotal(1L);

线程池配置

@Configuration public class ThreadPoolTaskExecutorConfig {      public static final Integer CORE_POOL_SIZE = 20;     public static final Integer MAX_POOL_SIZE = 40;     public static final Integer QUEUE_CAPACITY = 200;     public static final Integer KEEP_ALIVE_SECONDS = 60;      @Bean("threadPoolTaskExecutor")     public ThreadPoolTaskExecutor getThreadPoolTaskExecutor() {         ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();         //核心线程数         threadPoolTaskExecutor.setCorePoolSize(CORE_POOL_SIZE);         //线程池最大线程数         threadPoolTaskExecutor.setMaxPoolSize(MAX_POOL_SIZE);         //队列容量         threadPoolTaskExecutor.setQueueCapacity(QUEUE_CAPACITY);         //线程空闲存活时间         threadPoolTaskExecutor.setKeepAliveSeconds(KEEP_ALIVE_SECONDS);         //线程前缀         threadPoolTaskExecutor.setThreadNamePrefix("commonTask-");         //拒绝策略         threadPoolTaskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());         //线程池初始化         threadPoolTaskExecutor.initialize();          return threadPoolTaskExecutor;     }      @Bean("countAsyncThreadPool")     public ThreadPoolTaskExecutor getCountAsyncThreadPool() {         ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();         //核心线程数,根据负载动态调整         threadPoolTaskExecutor.setCorePoolSize(6);         //线程池最大线程数,根据负载动态调整         threadPoolTaskExecutor.setMaxPoolSize(12);         //队列容量  队列容量不宜过多,根据负载动态调整         threadPoolTaskExecutor.setQueueCapacity(2);         //线程空闲存活时间         threadPoolTaskExecutor.setKeepAliveSeconds(KEEP_ALIVE_SECONDS);         //线程前缀         threadPoolTaskExecutor.setThreadNamePrefix("countAsync-");         //拒绝策略  队列满时由调用者主线程执行         threadPoolTaskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());         //线程池初始化         threadPoolTaskExecutor.initialize();          return threadPoolTaskExecutor;     } }

mybatis-plus配置类

@Configuration @MapperScan("com.xxx.mapper") public class MybatisPlusConfig {      @Resource     ThreadPoolTaskExecutor countAsyncThreadPool;     @Resource     ApplicationContext applicationContext;      @Bean     public MybatisPlusInterceptor mybatisPlusInterceptor() {         MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();         interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));         return interceptor;     }       @Bean     public PageParallelQueryInterceptor pageParallelQueryInterceptor() {         PageParallelQueryInterceptor pageParallelQueryInterceptor = new PageParallelQueryInterceptor();         pageParallelQueryInterceptor.setCountAsyncThreadPool(countAsyncThreadPool);         pageParallelQueryInterceptor.setApplicationContext(applicationContext);         return pageParallelQueryInterceptor;     } }

自定义mybatis拦截器

package com.example.dlock_demo.interceptor;  import com.baomidou.mybatisplus.core.toolkit.CollectionUtils; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.statement.select.*; import org.apache.ibatis.builder.StaticSqlSource; import org.apache.ibatis.cache.CacheKey; import org.apache.ibatis.executor.Executor; import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.mapping.ResultMap; import org.apache.ibatis.plugin.Interceptor; import org.apache.ibatis.plugin.Intercepts; import org.apache.ibatis.plugin.Invocation; import org.apache.ibatis.plugin.Signature; import org.apache.ibatis.session.*; import org.springframework.context.ApplicationContext; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;  import java.lang.reflect.Field; import java.lang.reflect.Method; import java.sql.SQLException; import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentHashMap;   /**  * Mybatis-分页并行查询拦截器  *  * @author shf  */ @Intercepts({         @Signature(type = Executor.class, method = "query",                 args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),         @Signature(type = Executor.class, method = "query",                 args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}) }) @Slf4j public class PageParallelQueryInterceptor implements Interceptor {     /**      * 用于数据库并行查询线程池      */     private ThreadPoolTaskExecutor countAsyncThreadPool;     /**      * 容器上下文      */     private ApplicationContext applicationContext;      private static final String LONG_RESULT_MAP_ID = "twoPhase-Long-ResultMap";     private static final Map<String, MappedStatement> twoPhaseMsCache = new ConcurrentHashMap();      public void setCountAsyncThreadPool(ThreadPoolTaskExecutor countAsyncThreadPool) {         this.countAsyncThreadPool = countAsyncThreadPool;     }      public void setApplicationContext(ApplicationContext applicationContext) {         this.applicationContext = applicationContext;     }      @Override     public Object intercept(Invocation invocation) throws Throwable {         Object[] args = invocation.getArgs();         MappedStatement ms = (MappedStatement) args[0];         Object parameter = args[1];          //获取分页参数         Page<?> page = getPageParameter(parameter);         if (page == null || page.getSize() <= 0 || !page.searchCount() || page.getTotal() == 0) {             return invocation.proceed();         }         //获取Mapper方法(注解形式 需利用反射且只能应用在mapper接口层,不推荐使用)         /*Method method = getMapperMethod(ms);         if (method == null || !method.isAnnotationPresent(PageParallelQuery.class)) {             return invocation.proceed();         }*/          BoundSql boundSql = ms.getBoundSql(parameter);         String originalSql = boundSql.getSql();         //禁用mybatis plus PaginationInnerInterceptor count查询         page.setSearchCount(false);         page.setTotal(0);         args[2] = RowBounds.DEFAULT;         CompletableFuture<Long> countFuture = resolveCountCompletableFuture(invocation, originalSql);         //limit查询         long startTime = System.currentTimeMillis();         Object proceed = invocation.proceed();         log.info("原SQL数据查询-耗时={}", System.currentTimeMillis() - startTime);         page.setTotal(countFuture.get());          return proceed;     }      private CompletableFuture<Long> resolveCountCompletableFuture(Invocation invocation, String originalSql) {         return CompletableFuture.supplyAsync(() -> {             try {                 //查询总条数                 long startTime = System.currentTimeMillis();                 long total = executeCountQuery(originalSql, invocation);                 log.info("分页并行查询COUNT总条数[{}]-耗时={}", total, System.currentTimeMillis() - startTime);                 return total;             } catch (Throwable e) {                 log.error("page parallel query exception:", e);                 throw new CompletionException(e);             }         }, countAsyncThreadPool).exceptionally(throwable -> {             log.error("page parallel query exception:", throwable);             return 0L;         });     }      private CompletableFuture<Object> resolveOriginalProceedCompletableFuture(Invocation invocation) {         return CompletableFuture.supplyAsync(() -> {             try {                 long startTime = System.currentTimeMillis();                 Object proceed = invocation.proceed();                 log.info("原SQL数据查询-耗时={}", System.currentTimeMillis() - startTime);                 return proceed;             } catch (Throwable e) {                 throw new CompletionException(e);             }         }, countAsyncThreadPool).exceptionally(throwable -> {             log.error("page parallel query original proceed exception:", throwable);             return null;         });     }      /**      * 执行count查询      */     private long executeCountQuery(String originalSql, Invocation invocation)             throws JSQLParserException, SQLException {          //解析并修改SQL为count查询         Select countSelect = (Select) CCJSqlParserUtil.parse(originalSql);         PlainSelect plainSelect = (PlainSelect) countSelect.getSelectBody();          //修改select为count(*)         /*plainSelect.setSelectItems(Collections.singletonList(                 new SelectExpressionItem(new Function("COUNT", new Column("*")))         );*/         // 移除排序和分页         Distinct distinct = plainSelect.getDistinct();         GroupByElement groupBy = plainSelect.getGroupBy();         String countSql = "";         if (groupBy == null && distinct == null) {             Expression countFuncExpression = CCJSqlParserUtil.parseExpression("COUNT(*)");             plainSelect.setSelectItems(Collections.singletonList(                     new SelectExpressionItem(countFuncExpression)));             plainSelect.setOrderByElements(null);             countSql = plainSelect.toString();         } else if (groupBy != null) {             plainSelect.setLimit(null);             plainSelect.setOffset(null);             countSql = "SELECT COUNT(*) FROM (" + plainSelect + ") TOTAL";         } else {             plainSelect.setOrderByElements(null);             plainSelect.setLimit(null);             plainSelect.setOffset(null);             countSql = "SELECT COUNT(*) FROM (" + plainSelect + ") TOTAL";         }         //执行count查询         return doCountQuery(invocation, countSql);     }      /**      * 执行修改后的COUNT(*)-SQL查询      */     @SuppressWarnings("unchecked")     private Long doCountQuery(Invocation invocation, String modifiedSql) {         //Executor executor = (Executor) invocation.getTarget();         //创建新会话(自动获取新连接)         Executor executor;         SqlSessionFactory sqlSessionFactory = applicationContext.getBean(SqlSessionFactory.class);         try (SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.SIMPLE)) {             //com.alibaba.druid.pool.DruidPooledConnection             System.out.println("新会话Connection class: " + sqlSession.getConnection().getClass().getName());             Field executorField = sqlSession.getClass().getDeclaredField("executor");             executorField.setAccessible(true);             executor = (Executor) executorField.get(sqlSession);              Object[] args = invocation.getArgs();             MappedStatement originalMs = (MappedStatement) args[0];             Object parameter = args[1];             //创建新的查询参数             Map<String, Object> newParameter = new HashMap<>();             if (parameter instanceof Map) {                 // 复制原始参数但移除分页参数                 Map<?, ?> originalParams = (Map<?, ?>) parameter;                 originalParams.forEach((k, v) -> {                     if (!(v instanceof Page)) {                         newParameter.put(k.toString(), v);                     }                 });             }             //创建新的BoundSql             BoundSql originalBoundSql = originalMs.getBoundSql(newParameter);             BoundSql newBoundSql = new BoundSql(originalMs.getConfiguration(), modifiedSql, originalBoundSql.getParameterMappings(), newParameter);             //复制原始参数值             originalBoundSql.getParameterMappings().forEach(mapping -> {                 String prop = mapping.getProperty();                 if (mapping.getJavaType().isInstance(newParameter)) {                     newBoundSql.setAdditionalParameter(prop, newParameter);                 } else if (newParameter instanceof Map) {                     Object value = ((Map<?, ?>) newParameter).get(prop);                     newBoundSql.setAdditionalParameter(prop, value);                 }             });             //创建新的BoundSql             /*BoundSql originalBoundSql = originalMs.getBoundSql(parameter);             BoundSql newBoundSql = new BoundSql(originalMs.getConfiguration(), modifiedSql,                     originalBoundSql.getParameterMappings(), parameter);*/             Configuration configuration = originalMs.getConfiguration();             //创建临时ResultMap             ResultMap resultMap = new ResultMap.Builder(                     configuration,                     LONG_RESULT_MAP_ID,                     //强制指定结果类型                     Long.class,                     //自动映射列到简单类型                     Collections.emptyList()             ).build();             if (!configuration.hasResultMap(LONG_RESULT_MAP_ID)) {                 configuration.addResultMap(resultMap);             }              String countMsId = originalMs.getId() + "_countMsId";             MappedStatement mappedStatement = twoPhaseMsCache.computeIfAbsent(countMsId, (key) ->                     this.getNewMappedStatement(modifiedSql, originalMs, newBoundSql, resultMap, countMsId));             //执行查询             List<Object> result = executor.query(mappedStatement, newParameter, RowBounds.DEFAULT, (ResultHandler<?>) args[3]);             long total = 0L;             if (CollectionUtils.isNotEmpty(result)) {                 Object o = result.get(0);                 if (o != null) {                     total = Long.parseLong(o.toString());                 }             }             return total;         } catch (Throwable e) {             log.error("分页并行查询-executeCountQuery异常:", e);         }         return 0L;     }      private MappedStatement getNewMappedStatement(String modifiedSql, MappedStatement originalMs, BoundSql newBoundSql,                                                   ResultMap resultMap, String msId) {         //创建新的MappedStatement         MappedStatement.Builder builder = new MappedStatement.Builder(                 originalMs.getConfiguration(),                 msId,                 new StaticSqlSource(originalMs.getConfiguration(), modifiedSql, newBoundSql.getParameterMappings()),                 originalMs.getSqlCommandType()         );         //复制重要属性         builder.resource(originalMs.getResource())                 .fetchSize(originalMs.getFetchSize())                 .timeout(originalMs.getTimeout())                 .statementType(originalMs.getStatementType())                 .keyGenerator(originalMs.getKeyGenerator())                 .keyProperty(originalMs.getKeyProperties() == null ? null : String.join(",", originalMs.getKeyProperties()))                 .resultMaps(resultMap == null ? originalMs.getResultMaps() : Collections.singletonList(resultMap))                 .parameterMap(originalMs.getParameterMap())                 .resultSetType(originalMs.getResultSetType())                 .cache(originalMs.getCache())                 .flushCacheRequired(originalMs.isFlushCacheRequired())                 .useCache(originalMs.isUseCache());         return builder.build();     }      /**      * 获取分页参数      */     private Page<?> getPageParameter(Object parameter) {         if (parameter instanceof Map) {             Map<?, ?> paramMap = (Map<?, ?>) parameter;             return (Page<?>) paramMap.values().stream()                     .filter(p -> p instanceof Page)                     .findFirst()                     .orElse(null);         }         return parameter instanceof Page ? (Page<?>) parameter : null;     }      /**      * 获取Mapper方法      */     private Method getMapperMethod(MappedStatement ms) {         try {             String methodName = ms.getId().substring(ms.getId().lastIndexOf(".") + 1);             Class<?> mapperClass = Class.forName(ms.getId().substring(0, ms.getId().lastIndexOf(".")));             return Arrays.stream(mapperClass.getMethods())                     .filter(m -> m.getName().equals(methodName))                     .findFirst()                     .orElse(null);         } catch (ClassNotFoundException e) {             return null;         }     } }

注意事项

有人可能会担心并行查询,在高并发场景可能会导致count查询与limit数据查询不一致,但其实只要没有锁,只要是分开的两条sql查询,原mybatisplus分页插件也一样面临这个问题。

count优化没有进行join语句判断优化,相当于主动关闭了page.setOptimizeJoinOfCountSql(false);在一对多等场景可能会造成count查询有误,Mybatisplus官网也有相关提示,所以这里干脆舍弃了。

mybatisplus版本不同,可能会导致JsqlParser所使用的api有所不同,需要自己对应版本修改下。本篇版本使用的3.5.1

关于线程池的线程数设置顺便提一下:

网上流行一个说法:

1. CPU 密集型任务

特点:任务主要消耗 CPU 资源(如复杂计算、图像处理)。

线程数建议:

  • 核心线程数:CPU 核心数 + 1(或等于CPU核心数,避免上下文切换过多)。
  • 最大线程数:与核心线程数相同(防止过多线程竞争 CPU)。

2. I/O 密集型任务

特点:任务涉及大量等待(如网络请求、数据库读写)。

线程数建议:

  • 核心线程数:2 * CPU 核心数(确保正常负载下的高效处理)。
  • 最大线程数:根据系统资源调整(用于应对突发高并发)。

其实这个说法来源于一个经验公式推导而来:

threads = CPU核心数 * (1 + 平均等待时间 / 平均计算时间)

《Java 虚拟机并发编程》中介绍

springboot分页查询并行优化实践

springboot分页查询并行优化实践

 

另一篇:《Java Concurrency in Practice》即《java并发编程实践》,给出的线程池大小的估算公式:

 

springboot分页查询并行优化实践

Nthreads=Ncpu*Ucpu*(1+w/c),其中 Ncpu=CPU核心数,Ucpu=cpu使用率,0~1;W/C=等待时间与计算时间的比率

仔细推导两个公式,其实类似,在cpu使用率达100%时,其实结论是一致的,这时候计算线程数的公式就成了,Nthreads=Ncpu*100%*(1+w/c) =Ncpu*(1+w/c)。

那么在实践应用中计算的公式就出来了,【以下推算,不考虑内存消耗等方面】,如下:

1、针对IO密集型,阻塞耗时w一般都是计算耗时几倍c,假设阻塞耗时=计算耗时的情况下,Nthreads=Ncpu*(1+1)=2Ncpu,所以这种情况下,建议考虑2倍的CPU核心数做为线程数

2、对于计算密集型,阻塞耗时趋于0,即w/c趋于0,公式Nthreads = Ncpu。

实际应用时要考虑同时设置了几个隔离线程池,另外tomcat自带的线程池也会共享宿主机公共资源。

 

发表评论

评论已关闭。

相关文章