Java在SpringCloud中自定义Gateway负载均衡策略

一、前言

spring-cloud-starter-netflix-ribbon已经不再更新了,最新版本是2.2.10.RELEASE,最后更新时间是2021年11月18日,详细信息可以看maven官方仓库:org.springframework.cloud/spring-cloud-starter-netflix-ribbon,SpringCloud官方推荐使用spring-cloud-starter-loadbalancer进行负载均衡。

背景:大文件上传做切片文件上传;

流程:先将切片文件上传到服务器,再执行合并任务,合并完成后上传到对象存储。服务改为多节点部署后,网关默认采用轮询策略,同一个服务的请求会被分发到不同的机器上,导致切片文件散落在不同的服务器上,最终文件合并失败。因此,可以根据请求中的一个标识,为网关中对应的服务自定义负载均衡策略,让同一客户端的切片请求始终落到同一节点,从而解决这个问题。

我的版本如下:

Spring Boot 2.7.3、Spring Cloud 2021.0.4、Spring Cloud Alibaba 2021.0.4.0

二、参考默认实现

Spring Cloud LoadBalancer 默认的负载均衡策略是这个类:

org.springframework.cloud.loadbalancer.core.RoundRobinLoadBalancer

我们参考这个类实现自己的负载均衡策略即可,RoundRobinLoadBalancer实现了ReactorServiceInstanceLoadBalancer这个接口,实现了choose这个方法,如下图:

在choose方法中调用了processInstanceResponse方法,processInstanceResponse方法中又调用了getInstanceResponse方法。所以我们可以复制RoundRobinLoadBalancer整个类,只修改getInstanceResponse方法里的内容,就可以实现自定义负载均衡策略。

三、实现代码

原理:网关从请求头中取出设备的唯一标识,通过Reactor上下文传递给下游的负载均衡器;对唯一标识做哈希取余,即可把同一设备的请求固定路由到某个服务节点。需要的服务配置自定义负载策略,不需要的服务使用默认的轮询机制即可。我这里是根据具体的接口请求路径来决定是否走自定义策略,也可以改为根据服务名称决定。

package com.wondertek.gateway.loadBalancer;

import cn.hutool.core.util.ObjectUtil;
import com.wondertek.web.exception.enums.HttpRequestHeaderEnum;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.util.regex.Pattern;

/**
 * Global gateway filter that copies the client device identifier (request header) and the
 * request path, with its leading service prefix removed, into the Reactor context.
 *
 * <p>The values are written with {@code contextWrite} on the {@code Mono} returned by
 * {@code chain.filter(exchange)}, so only filters that run inside that chain (i.e. with a
 * higher order value, such as the gateway's load-balancer client filter) can read them.
 */
@Slf4j
@Component
public class RequestFilter implements GlobalFilter, Ordered {

    /** Context key for the client device identifier. */
    private static final String IDENTIFICATION_KEY = "identification";

    /** Context key for the request path without its service prefix. */
    private static final String PATH_URL_KEY = "pathUrl";

    /**
     * Matches a leading {@code /oms-api}, {@code /unity-api} or {@code /cloud-api} prefix.
     * Compiled once here; {@code String.replaceFirst} would recompile it on every request.
     */
    private static final Pattern SERVICE_PREFIX = Pattern.compile("^/(oms-api|unity-api|cloud-api)");

    @Override
    public int getOrder() {
        // Must run before (have a lower order than) the load-balancer client filter,
        // otherwise the load balancer is not inside the chain that sees this context.
        return Ordered.HIGHEST_PRECEDENCE;
    }

    /**
     * Stores the identifier and prefix-stripped path in the Reactor context when both are
     * non-empty; otherwise the context is passed through unchanged.
     *
     * @param exchange current exchange
     * @param chain    remaining filter chain
     * @return completion signal of the remaining chain
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        String clientDeviceUniqueCode = request.getHeaders()
                .getFirst(HttpRequestHeaderEnum.CLIENT_DEVICE_UNIQUE_CODE.getCode());
        // Remove the first matching service prefix so downstream matching sees the bare API path
        String resultPathUrl = SERVICE_PREFIX.matcher(request.getURI().getPath()).replaceFirst("");

        return chain.filter(exchange).contextWrite(context -> {
            if (ObjectUtil.isNotEmpty(clientDeviceUniqueCode) && ObjectUtil.isNotEmpty(resultPathUrl)) {
                log.info("开始将request中的唯一标识封装到上下游中:{},请求path是:{}", clientDeviceUniqueCode, resultPathUrl);
                return context.put(IDENTIFICATION_KEY, clientDeviceUniqueCode).put(PATH_URL_KEY, resultPathUrl);
            }
            // Nothing to propagate; other handling can be added here if required
            return context;
        });
    }
}

package com.wondertek.gateway.loadBalancer;

import cn.hutool.core.util.ObjectUtil;

import com.wondertek.center.constants.BusinessCenterApi;

import lombok.extern.slf4j.Slf4j;

import org.springframework.beans.factory.ObjectProvider;

import org.springframework.cloud.client.ServiceInstance;

import org.springframework.cloud.client.loadbalancer.DefaultResponse;

import org.springframework.cloud.client.loadbalancer.EmptyResponse;

import org.springframework.cloud.client.loadbalancer.Request;

import org.springframework.cloud.client.loadbalancer.Response;

import org.springframework.cloud.loadbalancer.core.NoopServiceInstanceListSupplier;

import org.springframework.cloud.loadbalancer.core.ReactorServiceInstanceLoadBalancer;

import org.springframework.cloud.loadbalancer.core.SelectedInstanceCallback;

import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;

import reactor.core.publisher.Mono;

import java.util.ArrayList;

import java.util.Collections;

import java.util.Comparator;

import java.util.List;

import java.util.concurrent.atomic.AtomicInteger;

@Slf4j

public class ClientDeviceUniqueCodeInstanceLoadBalancer implements ReactorServiceInstanceLoadBalancer {

private final String serviceId;

final AtomicInteger position;

private ObjectProvider serviceInstanceListSupplierProvider;

public ClientDeviceUniqueCodeInstanceLoadBalancer(ObjectProvider serviceInstanceListSupplierProvider, String serviceId, AtomicInteger position) {

this.serviceId = serviceId;

this.serviceInstanceListSupplierProvider = serviceInstanceListSupplierProvider;

this.position = position;

}

@Override

public Mono> choose(Request request) {

//在 choose 方法中,使用 deferContextual 方法来访问上下文并提取客户端标识。这里的 getOrDefault 方法尝试从上下文中获取一个键为 "identification" 的值,如果不存在则返回 "default-identification"

return Mono.deferContextual(contextView -> {

String identification = contextView.getOrDefault("identification", "");

log.info("上下游获取到的identification的值为:{}", identification);

String pathUrl = contextView.getOrDefault("pathUrl", "");

log.info("上下游获取到的pathUrl的值为:{}", pathUrl);

ServiceInstanceListSupplier supplier = serviceInstanceListSupplierProvider

.getIfAvailable(NoopServiceInstanceListSupplier::new);

return supplier.get(request).next()

.map(serviceInstances -> processInstanceResponse(supplier, serviceInstances, identification, pathUrl));

});

}

private Response processInstanceResponse(ServiceInstanceListSupplier supplier, List serviceInstances, String identification, String pathUrl) {

Response serviceInstanceResponse;

//特定接口走自定义负载策略

Boolean status = ObjectUtil.isNotEmpty(identification) && ObjectUtil.isNotEmpty(pathUrl) &&

(pathUrl.contains(BusinessCenterApi.WEB_UPLOAD_SLICE_FILE) ||

pathUrl.contains(BusinessCenterApi.WEB_MERGE_SLICE_FILE) ||

pathUrl.contains(BusinessCenterApi.UNITY_UPLOAD_SLICE_FILE) ||

pathUrl.contains(BusinessCenterApi.UNITY_MERGE_SLICE_FILE) ||

pathUrl.contains(BusinessCenterApi.CLOUD_UPLOAD_SLICE_FILE) ||

pathUrl.contains(BusinessCenterApi.CLOUD_MERGE_SLICE_FILE));

if (status) {

serviceInstanceResponse = this.getIpInstanceResponse(serviceInstances, identification);

} else {

serviceInstanceResponse = this.getInstanceResponse(serviceInstances);

}

if (supplier instanceof SelectedInstanceCallback && serviceInstanceResponse.hasServer()) {

((SelectedInstanceCallback) supplier).selectedServiceInstance((ServiceInstance) serviceInstanceResponse.getServer());

}

return serviceInstanceResponse;

}

private Response getInstanceResponse(List instances) {

if (instances.isEmpty()) {

if (log.isWarnEnabled()) {

log.warn("No servers available for service: " + this.serviceId);

}

return new EmptyResponse();

} else {

//创建一个新的列表以避免在原始列表上排序,避免了修改共享状态可能带来的线程安全问题

List sortedInstances = new ArrayList<>(instances);

// 现在对新列表进行排序,保持原始列表的顺序不变

Collections.sort(sortedInstances, Comparator.comparing(ServiceInstance::getHost));

//log.info("获取到的实例个数的值为:{}", sortedInstances.size());

sortedInstances.forEach(instance -> log.info("排序后的实例: {},{}", instance.getHost(), instance.getPort()));

int pos = Math.abs(this.position.incrementAndGet());

//log.info("默认轮循机制,pos递加后的值为:{}", pos);

int positionIndex = pos % instances.size();

//log.info("取余后的positionIndex的值为:{}", positionIndex);

ServiceInstance instance = instances.get(positionIndex);

//log.info("instance.getUri()的值为:{}", instance.getUri());

log.info("特殊服务,默认轮循机制,routed to instance: {}:{}", instance.getHost(), instance.getPort());

return new DefaultResponse(instance);

}

}

private Response getIpInstanceResponse(List instances, String identification) {

if (instances.isEmpty()) {

log.warn("No servers available for service: " + this.serviceId);

return new EmptyResponse();

} else {

//创建一个新的列表以避免在原始列表上排序,避免了修改共享状态可能带来的线程安全问题

List sortedInstances = new ArrayList<>(instances);

// 现在对新列表进行排序,保持原始列表的顺序不变

Collections.sort(sortedInstances, Comparator.comparing(ServiceInstance::getHost));

//log.info("获取到的实例个数的值为:{}", sortedInstances.size());

sortedInstances.forEach(instance -> log.info("排序后的实例: {},{}", instance.getHost(), instance.getPort()));

//log.info("多个服务实例,使用客户端 identification 地址的哈希值来选择服务实例");

// 使用排序后的列表来找到实例

int ipHashCode = Math.abs(identification.hashCode());

//log.info("identificationHashCode的值为:{}", ipHashCode);

int instanceIndex = ipHashCode % sortedInstances.size();

//log.info("instanceIndex的值为:{}", instanceIndex);

ServiceInstance instanceToReturn = sortedInstances.get(instanceIndex);

//log.info("instanceToReturn.getUri()的值为:{}", instanceToReturn.getUri());

log.info("特殊服务,自定义identification负载机制,Client identification: {} is routed to instance: {}:{}", identification, instanceToReturn.getHost(), instanceToReturn.getPort());

return new DefaultResponse(instanceToReturn);

}

}

}

package com.wondertek.gateway.loadBalancer;

import lombok.extern.slf4j.Slf4j;

import org.springframework.beans.factory.ObjectProvider;

import org.springframework.cloud.client.ServiceInstance;

import org.springframework.cloud.client.loadbalancer.DefaultResponse;

import org.springframework.cloud.client.loadbalancer.EmptyResponse;

import org.springframework.cloud.client.loadbalancer.Request;

import org.springframework.cloud.client.loadbalancer.Response;

import org.springframework.cloud.loadbalancer.core.NoopServiceInstanceListSupplier;

import org.springframework.cloud.loadbalancer.core.ReactorServiceInstanceLoadBalancer;

import org.springframework.cloud.loadbalancer.core.SelectedInstanceCallback;

import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;

import reactor.core.publisher.Mono;

import java.util.ArrayList;

import java.util.Collections;

import java.util.Comparator;

import java.util.List;

import java.util.concurrent.atomic.AtomicInteger;

@Slf4j

public class DefaultInstanceLoadBalancer implements ReactorServiceInstanceLoadBalancer {

private final String serviceId;

private ObjectProvider serviceInstanceListSupplierProvider;

final AtomicInteger position;

public DefaultInstanceLoadBalancer(ObjectProvider serviceInstanceListSupplierProvider, String serviceId, AtomicInteger position) {

this.serviceId = serviceId;

this.serviceInstanceListSupplierProvider = serviceInstanceListSupplierProvider;

this.position = position;

}

@Override

public Mono> choose(Request request) {

ServiceInstanceListSupplier supplier = serviceInstanceListSupplierProvider

.getIfAvailable(NoopServiceInstanceListSupplier::new);

return supplier.get(request).next()

.map(serviceInstances -> processInstanceResponse(supplier, serviceInstances));

}

private Response processInstanceResponse(ServiceInstanceListSupplier supplier,

List serviceInstances) {

Response serviceInstanceResponse = getInstanceResponse(serviceInstances);

if (supplier instanceof SelectedInstanceCallback && serviceInstanceResponse.hasServer()) {

((SelectedInstanceCallback) supplier).selectedServiceInstance(serviceInstanceResponse.getServer());

}

return serviceInstanceResponse;

}

private Response getInstanceResponse(List instances) {

if (instances.isEmpty()) {

if (log.isWarnEnabled()) {

log.warn("No servers available for service: " + serviceId);

}

return new EmptyResponse();

}

//创建一个新的列表以避免在原始列表上排序,避免了修改共享状态可能带来的线程安全问题

List sortedInstances = new ArrayList<>(instances);

// 现在对新列表进行排序,保持原始列表的顺序不变

Collections.sort(sortedInstances, Comparator.comparing(ServiceInstance::getHost));

//log.info("获取到的实例个数的值为:{}", sortedInstances.size());

sortedInstances.forEach(instance -> log.info("排序后的实例: {},{}", instance.getHost(), instance.getPort()));

int pos = Math.abs(this.position.incrementAndGet());

//log.info("默认轮循机制,pos递加后的值为:{}", pos);

int positionIndex = pos % instances.size();

//log.info("取余后的positionIndex的值为:{}", positionIndex);

ServiceInstance instance = instances.get(positionIndex);

//log.info("instance.getUri()的值为:{}", instance.getUri());

log.info("默认轮循机制,routed to instance: {}:{}",instance.getHost(), instance.getPort());

return new DefaultResponse(instance);

}

}

package com.wondertek.gateway.loadBalancer;

import lombok.extern.slf4j.Slf4j;

import org.springframework.beans.factory.ObjectProvider;

import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClient;

import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClients;

import org.springframework.cloud.loadbalancer.core.ReactorServiceInstanceLoadBalancer;

import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;

import org.springframework.cloud.loadbalancer.support.LoadBalancerClientFactory;

import org.springframework.context.annotation.Bean;

import org.springframework.context.annotation.Configuration;

import org.springframework.core.env.Environment;

import java.util.concurrent.atomic.AtomicInteger;

/**
 * Per-client Spring Cloud LoadBalancer configuration.
 *
 * <p>For {@code oms-api}, {@code unity-api} and {@code cloud-api} it creates a
 * {@link ClientDeviceUniqueCodeInstanceLoadBalancer}, which keeps a client's slice-file uploads
 * on one node. For the other registered services it creates a round-robin
 * {@link DefaultInstanceLoadBalancer}.
 *
 * <p>Deliberately NOT annotated with {@code @Configuration}. Spring Cloud LoadBalancer loads
 * this class into each client's child context. If component scanning also loaded it into the
 * main context, its beans would be created there too, without the client name property (a
 * balancer with a null serviceId and no instance supplier). That main-context bean would also
 * suppress the default round-robin balancer of clients not listed below.
 */
@Slf4j
public class CustomLoadBalancerConfig {

    /**
     * Component-scanned registration that applies {@link CustomLoadBalancerConfig} to the listed
     * services. It is kept separate so that only the registration lives in the main context.
     * For a single service, {@code @LoadBalancerClient(name = "oms-api", configuration = ...)}
     * can be used instead.
     */
    @Configuration
    @LoadBalancerClients({
            @LoadBalancerClient(name = "oms-api", configuration = CustomLoadBalancerConfig.class),
            @LoadBalancerClient(name = "unity-api", configuration = CustomLoadBalancerConfig.class),
            @LoadBalancerClient(name = "cloud-api", configuration = CustomLoadBalancerConfig.class),
            @LoadBalancerClient(name = "open-api", configuration = CustomLoadBalancerConfig.class),
            @LoadBalancerClient(name = "server-api", configuration = CustomLoadBalancerConfig.class),
            @LoadBalancerClient(name = "center-service", configuration = CustomLoadBalancerConfig.class)
    })
    public static class LoadBalancerClientsRegistration {
    }

    /** Round-robin counter; created once per client child context that loads this class. */
    @Bean
    public AtomicInteger positionTracker() {
        return new AtomicInteger(0);
    }

    /**
     * Chooses the balancer implementation by the client name of the current child context.
     *
     * @param serviceInstanceListSupplierProvider provider of the client's instance list
     * @param environment                         child-context environment holding the client name
     * @param positionTracker                     round-robin counter for this client
     * @return the load balancer for this client
     */
    @Bean
    public ReactorServiceInstanceLoadBalancer customPriorityLoadBalancer(ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider,
            Environment environment, AtomicInteger positionTracker) {
        String serviceId = environment.getProperty(LoadBalancerClientFactory.PROPERTY_NAME);
        // Sticky routing prevents slice files of one upload from being spread across nodes
        if ("oms-api".equals(serviceId) || "unity-api".equals(serviceId) || "cloud-api".equals(serviceId)) {
            return new ClientDeviceUniqueCodeInstanceLoadBalancer(serviceInstanceListSupplierProvider, serviceId, positionTracker);
        }
        return new DefaultInstanceLoadBalancer(serviceInstanceListSupplierProvider, serviceId, positionTracker);
    }
}

【SpringCloud系列】开发环境下重写Loadbalancer实现自定义负载均衡

相关阅读

评论可见,请评论后查看内容,谢谢!!!
 您阅读本篇文章共花了: