Skip to content

Commit

Permalink
add gateway ip priority load balancer
Browse files Browse the repository at this point in the history
  • Loading branch information
peacewong committed Apr 16, 2024
1 parent 2eefd7c commit 7624aa8
Show file tree
Hide file tree
Showing 12 changed files with 180 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@

package org.apache.linkis.rpc.conf;

import org.apache.linkis.rpc.BaseRPCSender;
import org.apache.linkis.rpc.constant.RpcConstant;
import org.apache.linkis.server.BDPJettyServerHelper;
import org.apache.linkis.server.Message;
import org.apache.linkis.server.security.SSOUtils$;
import org.apache.linkis.server.security.SecurityFilter$;

import org.springframework.stereotype.Component;

import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Collection;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,5 @@ public class NacosClientCacheManualRefresher implements CacheManualRefresher {
private static final Logger logger =
LoggerFactory.getLogger(NacosClientCacheManualRefresher.class);

public void refresh() {

}
public void refresh() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,11 @@
package org.apache.linkis.rpc.loadbalancer;

import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClients;
import org.springframework.cloud.loadbalancer.core.ReactorLoadBalancer;
import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.support.LoadBalancerClientFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.env.Environment;


public class LinkisLoadBalancerClientConfiguration {
public ReactorLoadBalancer<ServiceInstance> customLoadBalancer(
Environment environment, LoadBalancerClientFactory loadBalancerClientFactory) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ public class ServiceInstancePriorityLoadBalancer implements ReactorServiceInstan
final AtomicInteger position;
private final ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider;

private final Long maxWaitTime = RPCConfiguration.RPC_SERVICE_REFRESH_MAX_WAIT_TIME().getValue().toLong();
private final Long maxWaitTime =
RPCConfiguration.RPC_SERVICE_REFRESH_MAX_WAIT_TIME().getValue().toLong();

public ServiceInstancePriorityLoadBalancer(
ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider,
Expand Down Expand Up @@ -131,7 +132,8 @@ && isRPC(linkisLoadBalancerType)
if (null == serviceInstanceResponse && StringUtils.isNotBlank(clientIp)) {
throw new NoInstanceExistsException(
LinkisRpcErrorCodeSummary.INSTANCE_NOT_FOUND_ERROR.getErrorCode(),
MessageFormat.format(LinkisRpcErrorCodeSummary.INSTANCE_NOT_FOUND_ERROR.getErrorDesc(), clientIp));
MessageFormat.format(
LinkisRpcErrorCodeSummary.INSTANCE_NOT_FOUND_ERROR.getErrorDesc(), clientIp));
}

if (supplier instanceof SelectedInstanceCallback && serviceInstanceResponse.hasServer()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,13 @@ import org.springframework.web.bind.annotation.{RequestBody, RequestMapping, Req

private[rpc] trait RPCReceiveRemote {

@RequestMapping(
value = Array("${spring.mvc.servlet.path}/rpc/receive"),
method = Array(RequestMethod.POST)
)
@RequestMapping(value = Array("/rpc/receive"), method = Array(RequestMethod.POST))
def receive(@RequestBody message: Message): Message

@RequestMapping(
value = Array("${spring.mvc.servlet.path}/rpc/receiveAndReply"),
method = Array(RequestMethod.POST)
)
@RequestMapping(value = Array("/rpc/receiveAndReply"), method = Array(RequestMethod.POST))
def receiveAndReply(@RequestBody message: Message): Message

@RequestMapping(
value = Array("${spring.mvc.servlet.path}/rpc/replyInMills"),
method = Array(RequestMethod.POST)
)
@RequestMapping(value = Array("/rpc/replyInMills"), method = Array(RequestMethod.POST))
def receiveAndReplyInMills(@RequestBody message: Message): Message

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ private[rpc] object RPCSpringBeanCache extends Logging {
private var rpcServerLoader: RPCServerLoader = _
private var senderBuilders: Array[BroadcastSenderBuilder] = _
private var rpcReceiveRestful: RPCReceiveRestful = _
private var rpcReceiveRemote: RPCReceiveRemote = _

def registerReceiver(receiverName: String, receiver: Receiver): Unit = {
if (beanNameToReceivers == null) {
Expand Down Expand Up @@ -64,13 +63,6 @@ private[rpc] object RPCSpringBeanCache extends Logging {
rpcReceiveRestful
}

def getRPCReceiveRemote: RPCReceiveRemote = {
if (rpcReceiveRemote == null) {
rpcReceiveRemote = getApplicationContext.getBean(classOf[RPCReceiveRemote])
}
rpcReceiveRemote
}

private[rpc] def getReceivers: util.Map[String, Receiver] = {
if (beanNameToReceivers == null) {
beanNameToReceivers = getApplicationContext.getBeansOfType(classOf[Receiver])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

package org.apache.linkis.rpc.sender

import feign._
import org.apache.commons.lang3.StringUtils
import org.apache.linkis.common.ServiceInstance
import org.apache.linkis.rpc.interceptor.{RPCInterceptor, ServiceInstanceRPCInterceptorChain}
import org.apache.linkis.rpc.{BaseRPCSender, RPCMessageEvent, RPCSpringBeanCache}
import org.apache.linkis.rpc.interceptor.{RPCInterceptor, ServiceInstanceRPCInterceptorChain}

import org.apache.commons.lang3.StringUtils

import feign._

private[rpc] class SpringMVCRPCSender private[rpc] (
private[rpc] val serviceInstance: ServiceInstance
Expand All @@ -38,7 +40,7 @@ private[rpc] class SpringMVCRPCSender private[rpc] (
override protected def doBuilder(builder: Feign.Builder): Unit = {
if (serviceInstance != null && StringUtils.isNotBlank(serviceInstance.getInstance)) {
builder.requestInterceptor(new RequestInterceptor() {
def apply(template: RequestTemplate ): Unit = {
def apply(template: RequestTemplate): Unit = {
// Fixed instance
template.target(s"http://${serviceInstance.getInstance}")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,28 @@ class StaticAuthenticationStrategy(override protected val sessionMaxAliveTime: L
override def isTimeout(authentication: Authentication): Boolean =
System.currentTimeMillis() - authentication.getLastAccessTime >= serverSessionTimeout

/**
* Forced login needs to consider the situation of multiple calls at the same time. If there are
* simultaneous calls, it should not be updated. request time < last creatTime and last createTime
* - currentTime < 1s
* @param requestAction
* @param serverUrl
* @return
*/
override def enforceLogin(requestAction: Action, serverUrl: String): Authentication = {
val key = getKey(requestAction, serverUrl)
if (key == null) return null
val requestTime = System.currentTimeMillis()
key.intern() synchronized {
val authentication = tryLogin(requestAction, serverUrl)
putSession(key, authentication)
var authentication = getAuthenticationActionByKey(key)
if (
authentication == null || (authentication.getCreateTime < requestTime && (System
.currentTimeMillis() - authentication.getCreateTime) > 1000)
) {
authentication = tryLogin(requestAction, serverUrl)
putSession(key, authentication)
logger.info(s"$key try enforceLogin")
}
authentication
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class DWSAuthenticationResult(response: HttpResponse, serverUrl: String)
override def getAuthentication: Authentication = if (getStatus == 0) new HttpAuthentication {
private var lastAccessTime: Long = System.currentTimeMillis

private val createTime: Long = System.currentTimeMillis()

override def authToCookies: Array[Cookie] = Array.empty

override def authToHeaders: util.Map[String, String] = new util.HashMap[String, String]()
Expand All @@ -61,6 +63,9 @@ class DWSAuthenticationResult(response: HttpResponse, serverUrl: String)
override def getLastAccessTime: Long = lastAccessTime

override def updateLastAccessTime(): Unit = lastAccessTime = System.currentTimeMillis

override def getCreateTime: Long = createTime

}
else {
throw new HttpMessageParseException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.apache.linkis.gateway.security.LinkisPreFilter$;
import org.apache.linkis.gateway.security.SecurityFilter;
import org.apache.linkis.gateway.springcloud.SpringCloudGatewayConfiguration;
import org.apache.linkis.rpc.constant.RpcConstant;
import org.apache.linkis.server.Message;

import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -132,10 +131,7 @@ private Route getRealRoute(
}
String uri = scheme + serviceInstance.getApplicationName();
if (StringUtils.isNotBlank(serviceInstance.getInstance())) {
exchange
.getRequest()
.mutate()
.header(RpcConstant.FIXED_INSTANCE, serviceInstance.getInstance());
exchange.getRequest().mutate().header("FIXED_INSTANCE", serviceInstance.getInstance());
}
return Route.async()
.id(route.getId())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.linkis.gateway.springcloud.http;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;

import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.*;
import org.springframework.cloud.loadbalancer.core.NoopServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.core.ReactorServiceInstanceLoadBalancer;
import org.springframework.cloud.loadbalancer.core.SelectedInstanceCallback;
import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;

import java.util.List;
import java.util.Objects;
import java.util.concurrent.ThreadLocalRandom;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;

public class IpPriorityLoadBalancer implements ReactorServiceInstanceLoadBalancer {

private static final Logger logger = LoggerFactory.getLogger(IpPriorityLoadBalancer.class);

private final String serviceId;
private final ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider;

public IpPriorityLoadBalancer(
String serviceId,
ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider) {
this.serviceId = serviceId;
this.serviceInstanceListSupplierProvider = serviceInstanceListSupplierProvider;
}

@Override
public Mono<Response<ServiceInstance>> choose(Request request) {
List<String> clientIpList =
((RequestDataContext) request.getContext())
.getClientRequest()
.getHeaders()
.get("client-ip");
String clientIp = CollectionUtils.isNotEmpty(clientIpList) ? clientIpList.get(0) : null;
ServiceInstanceListSupplier supplier =
serviceInstanceListSupplierProvider.getIfAvailable(NoopServiceInstanceListSupplier::new);
return supplier
.get(request)
.next()
.map(serviceInstances -> processInstanceResponse(supplier, serviceInstances, clientIp));
}

private Response<ServiceInstance> processInstanceResponse(
ServiceInstanceListSupplier supplier,
List<ServiceInstance> serviceInstances,
String clientIp) {
Response<ServiceInstance> serviceInstanceResponse =
getInstanceResponse(serviceInstances, clientIp);
if (supplier instanceof SelectedInstanceCallback && serviceInstanceResponse.hasServer()) {
((SelectedInstanceCallback) supplier)
.selectedServiceInstance(serviceInstanceResponse.getServer());
}
return serviceInstanceResponse;
}

private Response<ServiceInstance> getInstanceResponse(
List<ServiceInstance> instances, String clientIp) {
if (instances.isEmpty()) {
logger.warn("No servers available for service: " + serviceId);
return new EmptyResponse();
}
if (StringUtils.isEmpty(clientIp)) {
return new DefaultResponse(
instances.get(ThreadLocalRandom.current().nextInt(instances.size())));
}
String[] ipAndPort = clientIp.split(":");
if (ipAndPort.length != 2) {
return new DefaultResponse(
instances.get(ThreadLocalRandom.current().nextInt(instances.size())));
}
ServiceInstance chooseInstance = null;
for (ServiceInstance instance : instances) {
if (Objects.equals(ipAndPort[0], instance.getHost())
&& Objects.equals(ipAndPort[1], String.valueOf(instance.getPort()))) {
return new DefaultResponse(instance);
}
}
return new DefaultResponse(
instances.get(ThreadLocalRandom.current().nextInt(instances.size())));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.linkis.gateway.springcloud.http;

import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.loadbalancer.core.ReactorLoadBalancer;
import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.support.LoadBalancerClientFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.core.env.Environment;

public class LinkisLoadBalancerClientConfiguration {
@Bean
public ReactorLoadBalancer<ServiceInstance> customLoadBalancer(
Environment environment, LoadBalancerClientFactory loadBalancerClientFactory) {
String name = environment.getProperty(LoadBalancerClientFactory.PROPERTY_NAME);
return new IpPriorityLoadBalancer(
name, loadBalancerClientFactory.getLazyProvider(name, ServiceInstanceListSupplier.class));
}
}

0 comments on commit 7624aa8

Please sign in to comment.