Working Around Sharding JDBC's Lack of Support for Complex SQL with One Annotation

Background

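Our company recently started splitting databases and tables (sharding) and adopted Sharding JDBC for it. Once the integration was done, regression testing turned up several SQL statements that failed to execute. The catch: none of the tables involved were sharded tables.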

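That was bad news. Anyone familiar with Sharding JDBC knows that there is quite a lot of SQL it does not support, and the official documentation keeps a list of the unsupported cases.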

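Rewriting all of these complex SQL statements could take a lot of time. So what now? The only way out was to find an opening in Sharding JDBC itself. After two days of digging, I arrived at the solution below, where adding a single annotation is enough to get around Sharding JDBC's lack of support for complex SQL.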

Reproducing the problem

I wrote a complex SQL query locally to test with:

public List<Map<String, Object>> queryOrder() {
    List<Map<String, Object>> orders = borderRepository.findOrders();
    return orders;
}

public interface BOrderRepository extends JpaRepository<BOrder, Long> {

    // Native query: a derived table (subquery in FROM) with a CASE WHEN expression, read from the physical table b_order0.
    // The string values are data: '小' = small, '中' = medium, '大' = large.
    @Query(value = "SELECT * FROM (SELECT id,CASE WHEN company_id =1 THEN '小' WHEN company_id=4 THEN '中' ELSE '大' END AS com,user_id as userId FROM b_order0) t WHERE t.com ='中'", nativeQuery = true)
    List<Map<String, Object>> findOrders();
}

I wrote a test controller to call it, and sure enough, the call failed.


Approach

Since the tables used by these complex queries are not sharded, could these particular queries simply skip the Sharding JDBC data source?

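  1. Where the Sharding JDBC data source is injected, inject a custom data source of our own instead.
  2. That way, when a connection is requested, we can hand out one from the original (non-sharding) data source.
  3. In addition, declare an annotation: calls marked with it get the original data source, and everything else still gets the Sharding data source.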
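The three points above come down to a single routing rule that is applied every time a connection is requested. The plain-Java sketch below models only that rule, so it runs without Spring or ShardingSphere. `ConnectionSource` and `RoutingSketch` are hypothetical stand-ins for `javax.sql.DataSource` and the wrapper class. Failing fast on an unknown key is a choice made for this sketch.

```java
import java.util.Map;

// Hypothetical stand-in for javax.sql.DataSource; a name is enough to show which source was picked.
interface ConnectionSource {
    String name();
}

final class RoutingSketch {

    private final ConnectionSource shardingSource;          // plays the ShardingSphere data source
    private final Map<String, ConnectionSource> rawSources;  // plays the original data sources, keyed by name

    RoutingSketch(ConnectionSource shardingSource, Map<String, ConnectionSource> rawSources) {
        this.shardingSource = shardingSource;
        this.rawSources = rawSources;
    }

    // No key (null or empty): keep using the sharding source.
    // A known key: bypass sharding and use the original source.
    // An unknown key: fail fast instead of returning null.
    ConnectionSource route(String key) {
        if (key == null || key.isEmpty()) {
            return shardingSource;
        }
        ConnectionSource raw = rawSources.get(key);
        if (raw == null) {
            throw new IllegalStateException("No data source named '" + key + "'");
        }
        return raw;
    }
}
```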

Implementation

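  1. Write an auto-configuration class to replace ShardingSphereAutoConfiguration. It creates the same ShardingSphere beans, but it wraps every DataSource it returns in WrapShardingDataSource and also registers the AOP advisor for @DS.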
/**
 * Core auto-configuration for the dynamic data source.
 */
@Configuration
@ComponentScan("org.apache.shardingsphere.spring.boot.converter")
@EnableConfigurationProperties(SpringBootPropertiesConfiguration.class)
@ConditionalOnProperty(prefix = "spring.shardingsphere", name = "enabled", havingValue = "true", matchIfMissing = true)
@AutoConfigureBefore(DataSourceAutoConfiguration.class)
public class DynamicDataSourceAutoConfiguration implements EnvironmentAware {

    private String databaseName;

    private final SpringBootPropertiesConfiguration props;

    // Original data sources declared under spring.shardingsphere.datasource, keyed by name
    private final Map<String, DataSource> dataSourceMap = new LinkedHashMap<>();

    public DynamicDataSourceAutoConfiguration(SpringBootPropertiesConfiguration props) {
        this.props = props;
    }

    /**
     * Get mode configuration.
     *
     * @return mode configuration
     */
    @Bean
    public ModeConfiguration modeConfiguration() {
        return null == props.getMode() ? null : new ModeConfigurationYamlSwapper().swapToObject(props.getMode());
    }

    /**
     * Get ShardingSphere data source bean.
     *
     * @param rules rules configuration
     * @param modeConfig mode configuration
     * @return data source bean
     * @throws SQLException SQL exception
     */
    @Bean
    @Conditional(LocalRulesCondition.class)
    @Autowired(required = false)
    public DataSource shardingSphereDataSource(final ObjectProvider<List<RuleConfiguration>> rules, final ObjectProvider<ModeConfiguration> modeConfig) throws SQLException {
        Collection<RuleConfiguration> ruleConfigs = Optional.ofNullable(rules.getIfAvailable()).orElseGet(Collections::emptyList);
        DataSource dataSource = ShardingSphereDataSourceFactory.createDataSource(databaseName, modeConfig.getIfAvailable(), dataSourceMap, ruleConfigs, props.getProps());
        // Wrap the ShardingSphere data source so that @DS-annotated calls can bypass it
        return new WrapShardingDataSource((ShardingSphereDataSource) dataSource, dataSourceMap);
    }

    /**
     * Get data source bean from registry center.
     *
     * @param modeConfig mode configuration
     * @return data source bean
     * @throws SQLException SQL exception
     */
    @Bean
    @ConditionalOnMissingBean(DataSource.class)
    public DataSource dataSource(final ModeConfiguration modeConfig) throws SQLException {
        DataSource dataSource = !dataSourceMap.isEmpty() ? ShardingSphereDataSourceFactory.createDataSource(databaseName, modeConfig, dataSourceMap, Collections.emptyList(), props.getProps())
                : ShardingSphereDataSourceFactory.createDataSource(databaseName, modeConfig);
        return new WrapShardingDataSource((ShardingSphereDataSource) dataSource, dataSourceMap);
    }

    /**
     * Create transaction type scanner.
     *
     * @return transaction type scanner
     */
    @Bean
    public TransactionTypeScanner transactionTypeScanner() {
        return new TransactionTypeScanner();
    }

    @Override
    public final void setEnvironment(final Environment environment) {
        dataSourceMap.putAll(DataSourceMapSetter.getDataSourceMap(environment));
        databaseName = DatabaseNameSetter.getDatabaseName(environment);
    }

    // Advisor that applies the @DS interceptor; disabled by setting spring.datasource.dynamic.aop.enabled=false
    @Role(BeanDefinition.ROLE_INFRASTRUCTURE)
    @Bean
    @ConditionalOnProperty(prefix = "spring.datasource.dynamic.aop", name = "enabled", havingValue = "true", matchIfMissing = true)
    public Advisor dynamicDatasourceAnnotationAdvisor() {
        DynamicDataSourceAnnotationInterceptor interceptor = new DynamicDataSourceAnnotationInterceptor(true);
        return new DynamicDataSourceAnnotationAdvisor(interceptor, DS.class);
    }
}
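  2. A custom data source, which wraps the ShardingSphere data source and decides on every connection request which data source to use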
public class WrapShardingDataSource extends AbstractDataSourceAdapter implements AutoCloseable {

    // The ShardingSphere data source, used by default
    private final ShardingSphereDataSource dataSource;

    // The original data sources, keyed by name, used while a @DS key is active
    private final Map<String, DataSource> dataSourceMap;

    public WrapShardingDataSource(ShardingSphereDataSource dataSource, Map<String, DataSource> dataSourceMap) {
        this.dataSource = dataSource;
        this.dataSourceMap = dataSourceMap;
    }

    /**
     * Returns the original data source named by the current thread's @DS key,
     * or the ShardingSphere data source when no key is set.
     */
    public DataSource getTargetDataSource() {
        String peek = DynamicDataSourceContextHolder.peek();
        if (peek == null || peek.isEmpty()) {
            return dataSource;
        }
        DataSource target = dataSourceMap.get(peek);
        if (target == null) {
            // A misspelled @DS value fails here with a clear message instead of a later NullPointerException
            throw new IllegalStateException("No data source named '" + peek + "', available: " + dataSourceMap.keySet());
        }
        return target;
    }

    @Override
    public Connection getConnection() throws SQLException {
        return getTargetDataSource().getConnection();
    }

    @Override
    public Connection getConnection(final String username, final String password) throws SQLException {
        // Credentials come from the data source configuration, so they are ignored here
        return getConnection();
    }

    @Override
    public void close() throws Exception {
        // Always close the ShardingSphere data source, whichever @DS key the closing thread happens to hold
        dataSource.close();
    }

    @Override
    public int getLoginTimeout() throws SQLException {
        return getTargetDataSource().getLoginTimeout();
    }

    @Override
    public void setLoginTimeout(final int seconds) throws SQLException {
        getTargetDataSource().setLoginTimeout(seconds);
    }
}
  3. Declare an annotation that names the data source to use
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface DS {

    /**
     * Data source name; must match a key of the configured data source map (e.g. "slave0")
     */
    String value();
}
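`@Retention(RetentionPolicy.RUNTIME)` is what lets the interceptor see the annotation, because the interceptor reads it through reflection while the program runs. The small, Spring-free check below shows the difference. `ClassRetained`, `Repo` and `RetentionCheck` are hypothetical names used only here.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// Same shape as @DS: retained at runtime, so reflection can see it.
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@interface DS {
    String value();
}

// Hypothetical counterpart kept only in the class file; reflection cannot see it.
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.CLASS)
@interface ClassRetained {
    String value();
}

interface Repo {
    @DS("slave0")
    @ClassRetained("slave0")
    void findOrders();
}

final class RetentionCheck {
    private RetentionCheck() {
    }

    private static Method findOrders() {
        try {
            return Repo.class.getMethod("findOrders");
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }

    // The RUNTIME-retained annotation is visible, so this returns "slave0".
    static String dsValue() {
        DS ds = findOrders().getAnnotation(DS.class);
        return ds == null ? null : ds.value();
    }

    // The CLASS-retained annotation is not visible to reflection, so this returns false.
    static boolean classRetainedVisible() {
        return findOrders().getAnnotation(ClassRetained.class) != null;
    }
}
```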
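  4. Use AOP to intercept classes and methods that carry the annotation, and record the data source key they specify so it can be read back when a connection is obtained. Because getConnection() is later called on the same thread, the key is kept in a ThreadLocal. The ThreadLocal holds a stack of keys so that nested annotated calls also work.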

AOP interceptor:

public class DynamicDataSourceAnnotationInterceptor implements MethodInterceptor {

    private final DataSourceClassResolver dataSourceClassResolver;

    public DynamicDataSourceAnnotationInterceptor(Boolean allowedPublicOnly) {
        dataSourceClassResolver = new DataSourceClassResolver(allowedPublicOnly);
    }

    @Override
    public Object invoke(MethodInvocation invocation) throws Throwable {
        // "" when no @DS is found, which WrapShardingDataSource treats as "use the sharding data source"
        String dsKey = determineDatasourceKey(invocation);
        DynamicDataSourceContextHolder.push(dsKey);
        try {
            return invocation.proceed();
        } finally {
            // Always pop, so the key never leaks into later calls on this thread
            DynamicDataSourceContextHolder.poll();
        }
    }

    private String determineDatasourceKey(MethodInvocation invocation) {
        return dataSourceClassResolver.findKey(invocation.getMethod(), invocation.getThis());
    }
}
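The interceptor follows a push, proceed, poll pattern. The `finally` block guarantees that the key is removed even when the query throws. The sketch below shows the same pattern without Spring, using `java.lang.reflect.Proxy` in place of Spring AOP. It is a simplified model: it checks only the invoked interface method for `@DS`, whereas the real resolver also checks classes, bridge methods and mapper interfaces. `OrderRepo` and `DsProxySketch` are hypothetical names.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.util.ArrayDeque;
import java.util.Deque;

@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@interface DS {
    String value();
}

interface OrderRepo {
    @DS("slave0")
    String findOrders();   // annotated: the call should see "slave0"

    String findAll();      // not annotated: the call should see "" (sharding)
}

final class DsProxySketch {
    // Per-thread stack of keys, mirroring DynamicDataSourceContextHolder.
    static final ThreadLocal<Deque<String>> KEYS = ThreadLocal.withInitial(ArrayDeque::new);

    static String currentKey() {
        String key = KEYS.get().peek();
        return key == null ? "" : key;
    }

    // Wraps target in a JDK proxy that pushes the @DS key around every call.
    @SuppressWarnings("unchecked")
    static <T> T wrap(T target, Class<T> iface) {
        return (T) Proxy.newProxyInstance(iface.getClassLoader(), new Class<?>[]{iface}, (proxy, method, args) -> {
            DS ds = method.getAnnotation(DS.class);
            KEYS.get().push(ds == null ? "" : ds.value());    // push before the call
            try {
                return method.invoke(target, args);
            } catch (InvocationTargetException e) {
                throw e.getCause();                            // rethrow the target's own exception
            } finally {
                KEYS.get().pop();                              // always pop, even on failure
                if (KEYS.get().isEmpty()) {
                    KEYS.remove();                             // drop the per-thread deque when done
                }
            }
        });
    }
}
```

Inside the target, `DsProxySketch.currentKey()` plays the role of `DynamicDataSourceContextHolder.peek()` as seen from `getConnection()`.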

AOP advisor (pointcut) definition:

/**
 * AOP advisor
 */
public class DynamicDataSourceAnnotationAdvisor extends AbstractPointcutAdvisor implements BeanFactoryAware {

    private final Advice advice;

    private final Pointcut pointcut;

    private final Class<? extends Annotation> annotation;

    public DynamicDataSourceAnnotationAdvisor(MethodInterceptor advice,
                                              Class<? extends Annotation> annotation) {
        this.advice = advice;
        this.annotation = annotation;
        this.pointcut = buildPointcut();
    }

    @Override
    public Pointcut getPointcut() {
        return this.pointcut;
    }

    @Override
    public Advice getAdvice() {
        return this.advice;
    }

    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        if (this.advice instanceof BeanFactoryAware) {
            ((BeanFactoryAware) this.advice).setBeanFactory(beanFactory);
        }
    }

    private Pointcut buildPointcut() {
        // Match if either the class or the method carries the annotation
        Pointcut cpc = new AnnotationMatchingPointcut(annotation, true);
        Pointcut mpc = new AnnotationMethodPoint(annotation);
        return new ComposablePointcut(cpc).union(mpc);
    }

    /**
     * For compatibility with Spring versions earlier than 5.0
     */
    private static class AnnotationMethodPoint implements Pointcut {

        private final Class<? extends Annotation> annotationType;

        public AnnotationMethodPoint(Class<? extends Annotation> annotationType) {
            Assert.notNull(annotationType, "Annotation type must not be null");
            this.annotationType = annotationType;
        }

        @Override
        public ClassFilter getClassFilter() {
            return ClassFilter.TRUE;
        }

        @Override
        public MethodMatcher getMethodMatcher() {
            return new AnnotationMethodMatcher(annotationType);
        }

        private static class AnnotationMethodMatcher extends StaticMethodMatcher {
            private final Class<? extends Annotation> annotationType;

            public AnnotationMethodMatcher(Class<? extends Annotation> annotationType) {
                this.annotationType = annotationType;
            }

            @Override
            public boolean matches(Method method, Class<?> targetClass) {
                if (matchesMethod(method)) {
                    return true;
                }
                // Proxy classes never have annotations on their redeclared methods.
                if (Proxy.isProxyClass(targetClass)) {
                    return false;
                }
                // The method may be on an interface, so let's check on the target class as well.
                Method specificMethod = AopUtils.getMostSpecificMethod(method, targetClass);
                return (specificMethod != method && matchesMethod(specificMethod));
            }

            private boolean matchesMethod(Method method) {
                return AnnotatedElementUtils.hasAnnotation(method, this.annotationType);
            }
        }
    }
}
Data source resolver, which works out the @DS value that applies to an invocation:

/**
 * Data source resolver
 */
public class DataSourceClassResolver {

    private static boolean mpEnabled = false;

    private static Field mapperInterfaceField;

    static {
        // Detect MyBatis-Plus / MyBatis mapper proxies so that @DS on mapper interfaces can be resolved
        Class<?> proxyClass = null;
        try {
            proxyClass = Class.forName("com.baomidou.mybatisplus.core.override.MybatisMapperProxy");
        } catch (ClassNotFoundException e1) {
            try {
                proxyClass = Class.forName("com.baomidou.mybatisplus.core.override.PageMapperProxy");
            } catch (ClassNotFoundException e2) {
                try {
                    proxyClass = Class.forName("org.apache.ibatis.binding.MapperProxy");
                } catch (ClassNotFoundException ignored) {
                    // Neither MyBatis-Plus nor MyBatis is on the classpath (e.g. a pure JPA project)
                }
            }
        }
        if (proxyClass != null) {
            try {
                mapperInterfaceField = proxyClass.getDeclaredField("mapperInterface");
                mapperInterfaceField.setAccessible(true);
                mpEnabled = true;
            } catch (NoSuchFieldException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Cache of the data source resolved for each method
     */
    private final Map<Object, String> dsCache = new ConcurrentHashMap<>();
    private final boolean allowedPublicOnly;

    /**
     * Extension point that lets callers change the AOP matching condition
     *
     * @param allowedPublicOnly only consider public methods; defaults to true
     */
    public DataSourceClassResolver(boolean allowedPublicOnly) {
        this.allowedPublicOnly = allowedPublicOnly;
    }

    /**
     * Resolve the data source key, using the cache when possible
     *
     * @param method       method
     * @param targetObject target object
     * @return ds
     */
    public String findKey(Method method, Object targetObject) {
        if (method.getDeclaringClass() == Object.class) {
            return "";
        }
        Object cacheKey = new MethodClassKey(method, targetObject.getClass());
        String ds = this.dsCache.get(cacheKey);
        if (ds == null) {
            ds = computeDatasource(method, targetObject);
            if (ds == null) {
                ds = "";
            }
            this.dsCache.put(cacheKey, ds);
        }
        return ds;
    }

    /**
     * Order in which the annotation is searched for:
     * 1. the current method
     * 2. the bridged method
     * 3. the current class, then up the hierarchy to Object
     * 4. mybatis-plus / mybatis-spring mapper interfaces
     *
     * @param method       method
     * @param targetObject target object
     * @return ds
     */
    private String computeDatasource(Method method, Object targetObject) {
        if (allowedPublicOnly && !Modifier.isPublic(method.getModifiers())) {
            return null;
        }
        // 1. Look on the invoked method (for a JDK proxy, this is the interface method)
        String dsAttr = findDataSourceAttribute(method);
        if (dsAttr != null) {
            return dsAttr;
        }
        Class<?> targetClass = targetObject.getClass();
        Class<?> userClass = ClassUtils.getUserClass(targetClass);
        // With a JDK proxy, get the implementation's declaration. method: interface method, specificMethod: implementation method
        Method specificMethod = ClassUtils.getMostSpecificMethod(method, userClass);
        specificMethod = BridgeMethodResolver.findBridgedMethod(specificMethod);
        // 2. Look on the bridged (implementation) method
        dsAttr = findDataSourceAttribute(specificMethod);
        if (dsAttr != null) {
            return dsAttr;
        }
        // Look on the class that declares the method
        dsAttr = findDataSourceAttribute(userClass);
        if (dsAttr != null && ClassUtils.isUserLevelMethod(method)) {
            return dsAttr;
        }
        // since 3.4.1: look on the interfaces, taking only the first match
        for (Class<?> interfaceClazz : ClassUtils.getAllInterfacesForClassAsSet(userClass)) {
            dsAttr = findDataSourceAttribute(interfaceClazz);
            if (dsAttr != null) {
                return dsAttr;
            }
        }
        // If a bridge method exists
        if (specificMethod != method) {
            // Look on the bridge method
            dsAttr = findDataSourceAttribute(method);
            if (dsAttr != null) {
                return dsAttr;
            }
            // Look on the class that declares the bridge method
            dsAttr = findDataSourceAttribute(method.getDeclaringClass());
            if (dsAttr != null && ClassUtils.isUserLevelMethod(method)) {
                return dsAttr;
            }
        }
        return getDefaultDataSourceAttr(targetObject);
    }

    /**
     * Default way of resolving the data source name
     *
     * @param targetObject target object
     * @return ds
     */
    private String getDefaultDataSourceAttr(Object targetObject) {
        Class<?> targetClass = targetObject.getClass();
        // Not a proxy class: start from the current class and keep walking up the superclasses
        if (!Proxy.isProxyClass(targetClass)) {
            Class<?> currentClass = targetClass;
            while (currentClass != Object.class) {
                String datasourceAttr = findDataSourceAttribute(currentClass);
                if (datasourceAttr != null) {
                    return datasourceAttr;
                }
                currentClass = currentClass.getSuperclass();
            }
        }
        // mybatis-plus / mybatis-spring: read the mapper interface from the proxy
        if (mpEnabled) {
            final Class<?> clazz = getMapperInterfaceClass(targetObject);
            if (clazz != null) {
                String datasourceAttr = findDataSourceAttribute(clazz);
                if (datasourceAttr != null) {
                    return datasourceAttr;
                }
                // Try the mapper's parent interfaces (getSuperclass() is always null for an interface)
                for (Class<?> parentInterface : clazz.getInterfaces()) {
                    datasourceAttr = findDataSourceAttribute(parentInterface);
                    if (datasourceAttr != null) {
                        return datasourceAttr;
                    }
                }
            }
        }
        return null;
    }

    /**
     * Unwraps nested proxies
     *
     * @param target JDK proxy object
     * @return the mapper interface held by the proxy's InvocationHandler, or null
     */
    private Class<?> getMapperInterfaceClass(Object target) {
        Object current = target;
        while (Proxy.isProxyClass(current.getClass())) {
            Object currentRefObject = AopProxyUtils.getSingletonTarget(current);
            if (currentRefObject == null) {
                break;
            }
            current = currentRefObject;
        }
        try {
            if (Proxy.isProxyClass(current.getClass())) {
                return (Class<?>) mapperInterfaceField.get(Proxy.getInvocationHandler(current));
            }
        } catch (IllegalAccessException | IllegalArgumentException ignore) {
            // The handler is not a MyBatis mapper proxy (e.g. a Spring Data JPA repository proxy)
        }
        return null;
    }

    /**
     * Finds the @DS annotation on an AnnotatedElement and returns its value
     *
     * @param ae AnnotatedElement
     * @return the data source name, or null if not annotated
     */
    private String findDataSourceAttribute(AnnotatedElement ae) {
        AnnotationAttributes attributes = AnnotatedElementUtils.getMergedAnnotationAttributes(ae, DS.class);
        if (attributes != null) {
            return attributes.getString("value");
        }
        return null;
    }
}
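If you set aside the bridge-method and MyBatis handling, the resolver's search order can be modelled with plain reflection. The order is: the invoked method, then the same method on the implementation class, then the implementation class, then its interfaces, then its superclasses. Results are cached per (class, method). This is a simplified sketch under those assumptions, not a drop-in replacement. `ResolverSketch`, `OrderService`, `AnnotatedImpl` and `PlainImpl` are hypothetical, and only directly implemented interfaces are checked here.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@interface DS {
    String value();
}

final class ResolverSketch {
    // Cache per (target class, method); "" means "no @DS found, use the sharding source".
    private final Map<String, String> cache = new ConcurrentHashMap<>();

    String findKey(Method method, Class<?> targetClass) {
        String cacheKey = targetClass.getName() + "#" + method.toGenericString();
        return cache.computeIfAbsent(cacheKey, k -> compute(method, targetClass));
    }

    private static String compute(Method method, Class<?> targetClass) {
        // 1. The invoked method itself (for a JDK proxy, the interface method).
        String key = valueOf(method.getAnnotation(DS.class));
        if (key != null) return key;
        // 2. The same method as seen on the implementation class.
        try {
            Method specific = targetClass.getMethod(method.getName(), method.getParameterTypes());
            key = valueOf(specific.getAnnotation(DS.class));
            if (key != null) return key;
        } catch (NoSuchMethodException ignored) {
            // Not a public method of the target class; continue with type-level lookup.
        }
        // 3. The implementation class.
        key = valueOf(targetClass.getAnnotation(DS.class));
        if (key != null) return key;
        // 4. Directly implemented interfaces; the first match wins.
        for (Class<?> iface : targetClass.getInterfaces()) {
            key = valueOf(iface.getAnnotation(DS.class));
            if (key != null) return key;
        }
        // 5. Superclasses, up to (but excluding) Object.
        for (Class<?> c = targetClass.getSuperclass(); c != null && c != Object.class; c = c.getSuperclass()) {
            key = valueOf(c.getAnnotation(DS.class));
            if (key != null) return key;
        }
        return "";
    }

    private static String valueOf(DS ds) {
        return ds == null ? null : ds.value();
    }
}

// Hypothetical fixtures: @DS on an interface method, and a class-level default on one implementation.
interface OrderService {
    @DS("slave0")
    String complexQuery();

    String plainQuery();
}

@DS("ds0")
class AnnotatedImpl implements OrderService {
    public String complexQuery() { return "complex"; }
    public String plainQuery() { return "plain"; }
}

class PlainImpl implements OrderService {
    public String complexQuery() { return "complex"; }
    public String plainQuery() { return "plain"; }
}
```

With these fixtures, `complexQuery` resolves to `slave0` on both implementations. `plainQuery` falls back to the class-level `ds0` on `AnnotatedImpl` and to `""` (the sharding data source) on `PlainImpl`.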

ThreadLocal:

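public final class DynamicDataSourceContextHolder {

    /**
     * Why store the keys in a linked structure (more precisely, a stack)?
     * <pre>
     * To support nested switching. Suppose services A, B and C each use a different data source,
     * and some business logic in A calls a method in B, which in turn calls a method in C.
     * Each call switches the data source one level deeper, forming a chain.
     * Setting a single value for the current thread cannot handle this; a last-in, first-out stack is required.
     * </pre>
     */
    private static final ThreadLocal<Deque<String>> LOOKUP_KEY_HOLDER = new NamedThreadLocal<Deque<String>>("dynamic-datasource") {
        @Override
        protected Deque<String> initialValue() {
            return new ArrayDeque<>();
        }
    };

    private DynamicDataSourceContextHolder() {
    }

    /**
     * Get the current thread's data source
     *
     * @return data source name
     */
    public static String peek() {
        return LOOKUP_KEY_HOLDER.get().peek();
    }

    /**
     * Set the current thread's data source
     * <p>
     * Avoid calling this manually; if you must, make sure the value is cleared afterwards
     * </p>
     *
     * @param ds data source name
     * @return the name actually pushed ("" for null)
     */
    public static String push(String ds) {
        String dataSourceStr = ds == null ? "" : ds;
        LOOKUP_KEY_HOLDER.get().push(dataSourceStr);
        return dataSourceStr;
    }

    /**
     * Clear the current thread's data source
     * <p>
     * If the thread has switched data sources several times in a row, only the most recent name is removed
     * </p>
     */
    public static void poll() {
        Deque<String> deque = LOOKUP_KEY_HOLDER.get();
        deque.poll();
        if (deque.isEmpty()) {
            LOOKUP_KEY_HOLDER.remove();
        }
    }

    /**
     * Force-clear the thread-local state
     * <p>
     * Prevents memory leaks; if push was called manually, call this to make sure everything is cleared
     * </p>
     */
    public static void clear() {
        LOOKUP_KEY_HOLDER.remove();
    }
}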
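The A → B → C example in the javadoc can be checked with a stripped-down copy of the holder. The hypothetical `KeyStack` keeps the same push/peek/poll semantics but uses a plain `ThreadLocal` instead of Spring's `NamedThreadLocal`. The nested calls show why a stack is needed instead of a single value: when C returns, B's key has to be restored, not cleared.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Objects;

final class KeyStack {
    private static final ThreadLocal<Deque<String>> HOLDER = ThreadLocal.withInitial(ArrayDeque::new);

    static String peek() {
        return HOLDER.get().peek();
    }

    static void push(String ds) {
        HOLDER.get().push(ds == null ? "" : ds);
    }

    static void poll() {
        Deque<String> deque = HOLDER.get();
        deque.poll();
        if (deque.isEmpty()) {
            HOLDER.remove();   // drop the per-thread deque once the outermost call finishes
        }
    }
}

final class NestedCalls {
    // Records which key is active at each step of A -> B -> C and back.
    static List<String> run() {
        List<String> seen = new ArrayList<>();
        KeyStack.push("dsA");
        try {
            seen.add(KeyStack.peek());              // inside A
            KeyStack.push("dsB");
            try {
                seen.add(KeyStack.peek());          // inside B
                KeyStack.push("dsC");
                try {
                    seen.add(KeyStack.peek());      // inside C
                } finally {
                    KeyStack.poll();
                }
                seen.add(KeyStack.peek());          // back in B: B's key is restored
            } finally {
                KeyStack.poll();
            }
            seen.add(KeyStack.peek());              // back in A
        } finally {
            KeyStack.poll();
        }
        seen.add(Objects.toString(KeyStack.peek(), "<none>"));  // outside all calls
        return seen;
    }
}
```

`NestedCalls.run()` returns `[dsA, dsB, dsC, dsB, dsA, <none>]`. A single per-thread value would instead lose `dsB` and `dsA` as soon as C finished.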
  5. Configure the startup class as follows:

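Import our own auto-configuration class and exclude Sharding JDBC's auto-configuration class (ShardingSphereAutoConfiguration), so that only the wrapped data source is registered.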

@SpringBootApplication(exclude = ShardingSphereAutoConfiguration.class)
@Import({DynamicDataSourceAutoConfiguration.class})
public class ShardingRunApplication {

    public static void main(String[] args) {
        SpringApplication.run(ShardingRunApplication.class, args);
    }
}

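Finally, add the annotation to the Repository from earlier. The value ("slave0" here) must be one of the data source names configured for ShardingSphere, i.e. a key of dataSourceMap: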

public interface BOrderRepository extends JpaRepository<BOrder, Long> {

    @DS("slave0")
    @Query(value = "SELECT * FROM (SELECT id,CASE WHEN company_id =1 THEN '小' WHEN company_id=4 THEN '中' ELSE '大' END AS com,user_id as userId FROM b_order0) t WHERE t.com ='中'", nativeQuery = true)
    List<Map<String, Object>> findOrders();
}

Call it again, and this time the query succeeds!
