手写RPC-升级版

概述

手写RPC-简易版 ,实现了一个很简单的RPC调用示例,其中还存在很多问题及可优化的点,
这次做个升级,完全重写之前的代码,使得代码逻辑更加规范,同时引入Zookeeper辅助完成服务治理。在代码展示前,先介绍下Zookeeper、服务治理等概念。

Zookeeper

ZooKeeper(简称zk)是一个分布式、开源的应用协调服务,利用和Paxos类似的ZAB选举算法实现分布式一致性服务。有类似于Unix文件目录的节点信息,同时可以针对节点的变更添加watcher监听以能够及时感知到节点信息变更。可提供的功能例如数据发布/订阅、负载均衡、命名服务、分布式协调/通知、集群管理、Master选举、分布式锁和分布式队列等功能。如下图就是DUBBO存储在ZooKeeper的节点数据情况:
img

在本地启动服务后通过zk客户端连接后也可通过命令查看节点信息,如下图所示。
img

ZooKeeper包含了4种不同含义的功能节点,在每次创建节点之前都需要明确声明节点类型:

类型 定义 描述
PERSISTENT 持久化目录节点 客户端与zookeeper断开连接后,该节点依旧存在
PERSISTENT_SEQUENTIAL 持久化顺序编号目录节点 客户端与zookeeper断开连接后,该节点依旧存在,只是Zookeeper给该节点名称进行顺序编号
EPHEMERAL 临时目录节点 客户端与zookeeper断开连接后,该节点被删除
EPHEMERAL_SEQUENTIAL 临时顺序编号目录节点 客户端与zookeeper断开连接后,该节点被删除,只是Zookeeper给该节点名称进行顺序编号

ZooKeeper使用之前需要先进行安装,后开启服务端的服务, 我们的服务作为客户端连接ZooKeeper以便于后续的操作。具体可参考官网文档Zookeeper3.5.5 官方文档,在实际的java项目开发中也是可以通过maven引入ZkClient或者Curator开源的客户端,在本文学习笔记中是使用的Curator,因为其已经封装了原始的节点注册、数据获取、添加watcher等功能。具体maven引入的版本如下:

1
2
3
4
5
6
7
8
9
10
<dependency>
<groupId>org.apache.curator</groupId>
<artifactId>curator-framework</artifactId>
<version>4.2.0</version>
</dependency>
<dependency>
<groupId>org.apache.curator</groupId>
<artifactId>curator-recipes</artifactId>
<version>4.2.0</version>
</dependency>

服务治理

服务治理也就是针对服务进行管理的措施,例如 服务发现、 服务暴露、 负载均衡、 快速上下线等都是服务治理的具体体现。

服务发现:从服务管理中心获取到需要的服务相关信息,例如可以从zk中获取相关服务的机器信息,然后就可以和具体机器直连完成相关功能。

服务暴露:服务提供方可以提供什么样子的功能,经过服务暴露暴露出去,其他使用方就可以通过服务发现发现具体的服务提供方信息。

负载均衡:一般针对的是服务提供方,避免大量请求同时打到一台机器上,采用随机、轮询等措施让请求均分到各个机器上,提供服务效率, 限流, 灰度等也都是类似的操作,通过动态路由、软负载的形式处理分发请求。

快速上线下:以往需要上下线可能需要杀掉机器上的进程,现在只需要让该服务停止暴露即可,实现服务的灵活上下线。

数据处理流程

服务端:服务的提供方,接受网络传输的请求数据、通过网络把应答数据发送给客户端;

客户端:服务的调用方,使用本地代理,通过网络把请求数据发送出去,接受服务端返回的应答数据.
img

所有的数据传输都是按照上面图片说的流程来的,如果需要添加自定义的序列化工具,则需要在把数据提交到socket的输出流缓冲区之前按照序列化工具完成序列化操作,反序列化则进行反向操作即可。

RPC V2版本

文件夹目录如下图所示:
img
img

  • balance文件夹:负载均衡有关;
  • config文件夹:网络套接字传输的数据模型以及服务暴露、服务发现的数据模型;
  • core文件夹:核心文件夹,包含了服务端和客户端的请求处理、代理生成等;
  • demo文件夹:测试使用;
  • domain文件夹:模型、枚举常量;
  • io.protocol文件夹:目前是只有具体的请求对象和网络io的封装;
  • register文件夹:服务注册使用,实现了使用zk进行服务注册和服务发现的操作;
  • serialize文件夹:序列化、反序列化,实现了Java和Hessian两种。

服务注册&服务发现

ServiceRegister:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
package com.springboot.whb.study.rpc.rpc_v2.register;

import com.springboot.whb.study.rpc.rpc_v2.config.BasicConfig;
import com.springboot.whb.study.rpc.rpc_v2.core.RpcRequest;
import com.springboot.whb.study.rpc.rpc_v2.domain.ServiceType;

import java.net.InetSocketAddress;

/**
* @author: whb
* @description: 服务注册
*/
public interface ServiceRegister {

/**
* 服务注册
*
* @param config
*/
void register(BasicConfig config);

/**
* 服务发现,从注册中心获取可用的服务提供方信息
*
* @param request
* @param nodeType
* @return
*/
InetSocketAddress discovery(RpcRequest request, ServiceType nodeType);
}

ZkServiceRegister

默认使用了CuratorFramework客户端完成zk数据的操作.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package com.springboot.whb.study.rpc.rpc_v2.register;

import com.springboot.whb.study.rpc.rpc_v2.balance.DefaultLoadBalance;
import com.springboot.whb.study.rpc.rpc_v2.balance.LoadBalance;
import com.springboot.whb.study.rpc.rpc_v2.config.BasicConfig;
import com.springboot.whb.study.rpc.rpc_v2.core.RpcRequest;
import com.springboot.whb.study.rpc.rpc_v2.domain.ServiceType;
import lombok.extern.slf4j.Slf4j;
import org.apache.curator.RetryPolicy;
import org.apache.curator.framework.CuratorFramework;
import org.apache.curator.framework.CuratorFrameworkFactory;
import org.apache.curator.retry.ExponentialBackoffRetry;
import org.apache.zookeeper.CreateMode;

import java.net.InetSocketAddress;
import java.util.List;

/**
* @author: whb
* @description: Zookeeper服务注册实现类
*/
@Slf4j
public class ZkServiceRegister implements ServiceRegister {

private CuratorFramework client;

private static final String ROOT_PATH = "whb/demo-rpc";

private LoadBalance loadBalance = new DefaultLoadBalance();

public ZkServiceRegister() {
//重试策略
RetryPolicy policy = new ExponentialBackoffRetry(1000, 3);

this.client = CuratorFrameworkFactory
.builder()
.connectString("127.0.0.1:2181")
.sessionTimeoutMs(50000)
.retryPolicy(policy)
.namespace(ROOT_PATH)
.build();
// 业务的根路径是 /whb/demo-rpc ,其他的都会默认挂载在这里

this.client.start();
System.out.println("zk启动正常");
}

/**
* 服务注册
*
* @param config
*/
@Override
public void register(BasicConfig config) {
String interfacePath = "/" + config.getInterfaceName();
try {
if (this.client.checkExists().forPath(interfacePath) == null) {
// 创建 服务的永久节点
this.client.create()
.creatingParentsIfNeeded()
.withMode(CreateMode.PERSISTENT)
.forPath(interfacePath);
}

config.getMethods().forEach(method -> {
String methodPath = null;
try {
ServiceType serviceType = config.getType();
if (serviceType == ServiceType.PROVIDER) {
// 服务提供方,需要暴露自身的ip、port信息,而消费端则不需要
String address = getServiceAddress(config);
methodPath = String.format("%s/%s/%s/%s", interfacePath, serviceType.getType(), method.getMethodName(), address);
} else {
methodPath = String.format("%s/%s/%s", interfacePath, serviceType.getType(), method.getMethodName());
}
log.info("zk path: [" + ROOT_PATH + methodPath + "]");
// 创建临时节点,节点包含了服务提供段的信息
this.client.create()
.creatingParentsIfNeeded()
.withMode(CreateMode.EPHEMERAL)
.forPath(methodPath, "0".getBytes());
} catch (Exception e) {
log.error("创建临时节点[" + methodPath + "]失败,error:{}", e);
}
});
} catch (Exception e) {
log.error("创建服务节点失败,error:{}", e);
}
}

/**
* 服务发现
*
* @param request
* @param nodeType
* @return
*/
@Override
public InetSocketAddress discovery(RpcRequest request, ServiceType nodeType) {
String path = String.format("/%s/%s/%s", request.getClassName(), nodeType.getType(), request.getMethodName());
try {
List<String> addressList = this.client.getChildren().forPath(path);
// 采用负载均衡的方式获取服务提供方信息,不过并没有添加watcher监听模式
String address = loadBalance.balance(addressList);
if (address == null) {
return null;
}
return parseAddress(address);
} catch (Exception e) {
log.error("服务发现接口异常,error:{}", e);
}
return null;
}

/**
* 获取服务地址
*
* @param config
* @return
*/
private String getServiceAddress(BasicConfig config) {
String hostInfo = new StringBuilder()
.append(config.getHost())
.append(":")
.append(config.getPort())
.toString();
return hostInfo;
}

/**
* 封装端口
*
* @param address
* @return
*/
private InetSocketAddress parseAddress(String address) {
String[] result = address.split(":");
return new InetSocketAddress(result[0], Integer.valueOf(result[1]));
}

/**
* 设置负载均衡策略
*
* @param loadBalance
*/
public void setLoadBalance(LoadBalance loadBalance) {
this.loadBalance = loadBalance;
}
}

负载均衡

LoadBalance

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
package com.springboot.whb.study.rpc.rpc_v2.balance;

import java.util.List;

/**
* @author: whb
* @description: 负载均衡接口定义
*/
public interface LoadBalance {

/**
* 负载均衡
*
* @param addressList
* @return
*/
String balance(List<String> addressList);
}

AbstractLoadBalance

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
package com.springboot.whb.study.rpc.rpc_v2.balance;

import java.util.List;

/**
* @author: whb
* @description: 抽象负载均衡
*/
public abstract class AbstractLoadBalance implements LoadBalance {

@Override
public String balance(List<String> addressList) {
if (addressList == null || addressList.isEmpty()) {
return null;
}
if (addressList.size() == 1) {
return addressList.get(0);
}
return doLoad(addressList);
}

/**
* 抽象接口,让子类去实现
*
* @param addressList
* @return
*/
abstract String doLoad(List<String> addressList);
}

DefaultLoadBalance

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
package com.springboot.whb.study.rpc.rpc_v2.balance;

import java.util.List;
import java.util.Random;

/**
* @author: whb
* @description: 默认负载均衡--随机负载均衡
*/
public class DefaultLoadBalance extends AbstractLoadBalance {

@Override
String doLoad(List<String> addressList) {
//随机
Random random = new Random();
return addressList.get(random.nextInt(addressList.size()));
}
}

消息协议

MessageProtocol

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
package com.springboot.whb.study.rpc.rpc_v2.io.protocol;

import com.springboot.whb.study.rpc.rpc_v2.core.RpcRequest;
import com.springboot.whb.study.rpc.rpc_v2.core.RpcResponse;

import java.io.InputStream;
import java.io.OutputStream;

/**
* @author: whb
* @description: 请求、应答 解析和反解析,包含了序列化以及反序列化操作
*/
public interface MessageProtocol {

/**
* 服务端解析从网络传输的数据,转变成request对象
*
* @param inputStream
* @return
*/
RpcRequest serviceToRequest(InputStream inputStream);

/**
* 服务端把计算的结果包装好,通过输出流返回给客户端
*
* @param response
* @param outputStream
* @param <T>
*/
<T> void serviceGetResponse(RpcResponse<T> response, OutputStream outputStream);

/**
* 客户端把请求拼接好,通过输出流发送到服务端
*
* @param request
* @param outputStream
*/
void clientToRequest(RpcRequest request, OutputStream outputStream);

/**
* 客户端接收到服务端响应的结果,转变成response对象
*
* @param inputStream
*/
<T> RpcResponse<T> clientGetResponse(InputStream inputStream);
}

DefaultMessageProtocol

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
package com.springboot.whb.study.rpc.rpc_v2.io.protocol;

import com.springboot.whb.study.rpc.rpc_v2.core.RpcRequest;
import com.springboot.whb.study.rpc.rpc_v2.core.RpcResponse;
import com.springboot.whb.study.rpc.rpc_v2.serialize.HessianSerialize;
import com.springboot.whb.study.rpc.rpc_v2.serialize.SerializeProtocol;
import lombok.extern.slf4j.Slf4j;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Arrays;

/**
* @author: whb
* @description: 套接字的io流和服务端、客户端的数据传输
*/
@Slf4j
public class DefaultMessageProtocol implements MessageProtocol {

/**
* 序列化协议
*/
private SerializeProtocol serializeProtocol;

public DefaultMessageProtocol() {
this.serializeProtocol = new HessianSerialize();
//this.serializeProtocol = new JavaInnerSerialize();
}

public void setSerializeProtocol(SerializeProtocol serializeProtocol) {
// 可替换序列化协议
this.serializeProtocol = serializeProtocol;
}

/**
* 服务端解析从网络传输的数据,转变成request对象
*
* @param inputStream
* @return
*/
@Override
public RpcRequest serviceToRequest(InputStream inputStream) {
try {
// 2、bytes -> request 反序列化
byte[] bytes = readBytes(inputStream);
System.out.println("[2]服务端反序列化出obj:[" + new String(bytes) + "], length:" + bytes.length);
//System.out.println("[2]服务端反序列化出obj length:" + bytes.length);
RpcRequest request = serializeProtocol.deserialize(RpcRequest.class, bytes);
return request;
} catch (Exception e) {
log.error("[2]服务端反序列化从网络传输的数据转变成request对象失败,error:{}", e);
}
return null;
}

/**
* 服务端把计算的结果包装好,通过输出流返回给客户端
*
* @param response
* @param outputStream
* @param <T>
*/
@Override
public <T> void serviceGetResponse(RpcResponse<T> response, OutputStream outputStream) {
try {
// 3、把response 序列化成bytes 传给客户端
byte[] bytes = serializeProtocol.serialize(RpcResponse.class, response);
System.out.println("[3]服务端序列化出bytes:[" + new String(bytes) + "], length:" + bytes.length);
//System.out.println("[3]服务端序列化出bytes length:" + bytes.length);
outputStream.write(bytes);
} catch (Exception e) {
log.error("[3]服务端序列化计算的结果出输给客户端失败,error:{}", e);
}
}

/**
* 客户端把请求拼接好,通过输出流发送到服务端
*
* @param request
* @param outputStream
*/
@Override
public void clientToRequest(RpcRequest request, OutputStream outputStream) {
try {
// 1、先把这个request -> bytes 序列化掉
byte[] bytes = serializeProtocol.serialize(RpcRequest.class, request);
System.out.println("[1]客户端序列化出bytes:[" + new String(bytes) + "], length:" + bytes.length);
//System.out.println("[1]客户端序列化出bytes length:" + bytes.length);
outputStream.write(bytes);
} catch (IOException e) {
log.error("[1]客户端序列化请求参数失败,error:{}", e);
}
}

/**
* 客户端接收到服务端响应的结果,转变成response对象
*
* @param inputStream
*/
@Override
public <T> RpcResponse<T> clientGetResponse(InputStream inputStream) {
try {
// 4、bytes 反序列化成response
byte[] bytes = readBytes(inputStream);
System.out.println("[4]客户端反序列化出bytes:[" + new String(bytes) + "], length:" + bytes.length);
//System.out.println("[4]客户端反序列化出bytes length:" + bytes.length);
RpcResponse response = serializeProtocol.deserialize(RpcResponse.class, bytes);

return response;
} catch (Exception e) {
log.error("[4]客户端反序列化计算结果失败,error:{}", e);
}
return null;
}

/**
* 流转二进制数组
*
* @param inputStream
* @return
* @throws IOException
*/
private byte[] readBytes(InputStream inputStream) throws IOException {
if (inputStream == null) {
throw new RuntimeException("输入流为空");
}
return inputStreamToByteArr2(inputStream);
}

/**
* 流转二进制数组方法1
*
* @param inputStream
* @return
* @throws IOException
*/
private byte[] inputStreamToByteArr1(InputStream inputStream) throws IOException {
// 有个前提是数据最大是1024,并没有迭代读取数据
byte[] bytes = new byte[1024];
int count = inputStream.read(bytes, 0, 1024);
return Arrays.copyOf(bytes, count);
}

/**
* 流转二进制数组方法2
*
* @param inputStream
* @return
* @throws IOException
*/
private byte[] inputStreamToByteArr2(InputStream inputStream) throws IOException {
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
int bufesize = 1024;
while (true) {
byte[] data = new byte[bufesize];
int count = inputStream.read(data, 0, bufesize);
byteArrayOutputStream.write(data, 0, count);
if (count < bufesize) {
break;
}
}
return byteArrayOutputStream.toByteArray();
}

/**
* 流转二进制数组方法3,调用该方法之后会阻塞在read,可通过jstack查看相关信息
*
* @param inputStream
* @return
* @throws IOException
*/
private byte[] inputStreamToByteArr3(InputStream inputStream) throws IOException {
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
int bufesize = 1024;

byte[] buff = new byte[bufesize];
int rc = 0;
while ((rc = inputStream.read(buff, 0, bufesize)) > 0) {
byteArrayOutputStream.write(buff, 0, rc);
buff = new byte[bufesize];
}
byte[] bytes = byteArrayOutputStream.toByteArray();
return bytes;
}

}

数据传输模型

ArgumentConfig

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
package com.springboot.whb.study.rpc.rpc_v2.config;

import lombok.Data;

import java.io.Serializable;

/**
* @author: whb
* @description: 参数配置
*/
@Data
public class ArgumentConfig implements Serializable {

private static final long serialVersionUID = 1L;

/**
* 第几个参数
*/
private int index;

/**
* 参数类型
*/
private String type;
}

BasicConfig

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
package com.springboot.whb.study.rpc.rpc_v2.config;

import com.springboot.whb.study.rpc.rpc_v2.domain.ServiceType;
import lombok.Data;

import java.util.List;

/**
* @author: whb
* @description: 基础配置
*/
@Data
public class BasicConfig {

/**
* 地址
*/
private String host;
/**
* 端口号
*/
private int port;

/**
* 服务提供方还是服务消费方
*/
private ServiceType type;

/**
* 接口名
*/
private String interfaceName;

/**
* 接口类
*/
private Class<?> interfaceClass;

/**
* 方法集合
*/
private List<MethodConfig> methods;

/**
* 分组
*/
private String group;

/**
* 默认版本号是default
*/
private String version = "default";

}

ClientConfig

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
package com.springboot.whb.study.rpc.rpc_v2.config;

import com.springboot.whb.study.rpc.rpc_v2.core.ProxyInstance;
import com.springboot.whb.study.rpc.rpc_v2.domain.ServiceType;
import lombok.Data;

import java.io.Serializable;
import java.lang.reflect.Proxy;

/**
* @author: whb
* @description: 客户端配置
*/
@Data
public class ClientConfig<T> extends BasicConfig implements Serializable {

private static final long serialVersionUID = 1L;

private T proxy;

/**
* 反射包装成客户端参数配置对象
*
* @param interfaceClass
* @param invocationHandler
* @param <T>
* @return
*/
public static <T> ClientConfig<T> convert(Class<T> interfaceClass, ProxyInstance invocationHandler) {
ClientConfig<T> config = new ClientConfig<>();

config.setVersion("default");
config.setInterfaceClass(interfaceClass);
config.setInterfaceName(interfaceClass.getName());
config.setMethods(MethodConfig.convert(interfaceClass.getMethods()));
config.setType(ServiceType.CONSUMER);

Object proxy = Proxy.newProxyInstance(ClientConfig.class.getClassLoader(),
new Class<?>[]{interfaceClass},
invocationHandler);
config.setProxy((T) proxy);
return config;
}
}

MethodConfig

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
package com.springboot.whb.study.rpc.rpc_v2.config;

import lombok.Data;

import java.io.Serializable;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.List;

/**
* @author: whb
* @description: 方法配置
*/
@Data
public class MethodConfig implements Serializable {

private static final long serialVersionUID = 1L;

/**
* 方法名
*/
private String methodName;

/**
* 参数
*/
private List<ArgumentConfig> argumentConfigs;

/**
* 是否需要返回
*/
private Boolean isReturn;

/**
* 返回值类型
*/
private Class<?> returnType;

/**
* 方法数组转方法配置集合
*
* @param methods
* @return
*/
public static List<MethodConfig> convert(Method[] methods) {
List<MethodConfig> methodConfigList = new ArrayList<>(methods.length);
MethodConfig methodConfig = null;
for (Method method : methods) {
methodConfig = new MethodConfig();
methodConfig.setMethodName(method.getName());

Class<?> returnType = method.getReturnType();
String returnName = returnType.getName();
if ("void".equals(returnName)) {
methodConfig.setIsReturn(false);
} else {
methodConfig.setIsReturn(true);
}
methodConfig.setReturnType(returnType);
methodConfig.setArgumentConfigs(convert(method.getParameters()));

methodConfigList.add(methodConfig);
}
return methodConfigList;
}

/**
* 参数数组转参数配置集合
*
* @param parameters
* @return
*/
private static List<ArgumentConfig> convert(Parameter[] parameters) {
List<ArgumentConfig> argumentConfigs = new ArrayList<>(parameters.length);
int start = 0;
ArgumentConfig argumentConfig = null;
for (Parameter parameter : parameters) {
argumentConfig = new ArgumentConfig();
argumentConfig.setIndex(start);
argumentConfig.setType(parameter.getType().getName());
argumentConfigs.add(argumentConfig);
start += 1;
}
return argumentConfigs;
}
}

ServiceConfig

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
package com.springboot.whb.study.rpc.rpc_v2.config;

import com.alibaba.fastjson.JSON;
import com.springboot.whb.study.rpc.rpc_v2.core.RpcService;
import com.springboot.whb.study.rpc.rpc_v2.domain.ServiceType;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;

import java.io.Serializable;
import java.net.InetAddress;
import java.net.UnknownHostException;

/**
* @author: whb
* @description: 服务方配置
*/
@Data
@Slf4j
public class ServiceConfig<T> extends BasicConfig implements Serializable {

private static final long serialVersionUID = 1L;

private T ref;

/**
* 统计调用次数使用
*/
private int count;

@Override
public String toString() {
return JSON.toJSONString(this);
}

public static <T> ServiceConfig<T> convert(String interfaceName,
Class<T> interfaceClass,
T ref, RpcService rpcService) {
ServiceConfig<T> serviceConfig = new ServiceConfig<>();

serviceConfig.setRef(ref);
serviceConfig.setInterfaceName(interfaceName);
serviceConfig.setInterfaceClass(interfaceClass);
serviceConfig.setCount(0);
serviceConfig.setMethods(MethodConfig.convert(interfaceClass.getMethods()));
serviceConfig.setPort(rpcService.getPort());
serviceConfig.setType(ServiceType.PROVIDER);

try {
InetAddress addr = InetAddress.getLocalHost();
serviceConfig.setHost(addr.getHostAddress());
} catch (UnknownHostException e) {
log.error("服务方获取本机地址失败,error:{}", e);
}
return serviceConfig;
}
}

枚举常量

ServiceType

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
package com.springboot.whb.study.rpc.rpc_v2.domain;

/**
* @author: whb
* @description: 服务类型枚举常量
*/
public enum ServiceType {
/**
* 服务提供者
*/
PROVIDER("provider"),

/**
* 服务消费者
*/
CONSUMER("consumer");

private String type;

ServiceType(String type) {
this.type = type;
}

public String getType() {
return type;
}
}

序列化、反序列化

SerializeProtocol

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
package com.springboot.whb.study.rpc.rpc_v2.serialize;

/**
* @author: whb
* @description: 序列化协议接口
*/
public interface SerializeProtocol {

/**
* 序列化
*/
<T> byte[] serialize(Class<T> clazz, T t);

/**
* 反序列化
*/
<T> T deserialize(Class<T> clazz, byte[] bytes);
}

JavaInnerSerialize

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
package com.springboot.whb.study.rpc.rpc_v2.serialize;

import lombok.extern.slf4j.Slf4j;

import java.io.*;

/**
* @author: whb
* @description: Java序列化
*/
@Slf4j
public class JavaInnerSerialize implements SerializeProtocol {

/**
* 序列化
*
* @param clazz
* @param t
* @param <T>
* @return
*/
@Override
public <T> byte[] serialize(Class<T> clazz, T t) {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
ObjectOutputStream objectOutputStream = null;
try {
objectOutputStream = new ObjectOutputStream(outputStream);
objectOutputStream.writeObject(t);
objectOutputStream.flush();
byte[] bytes = outputStream.toByteArray();
return bytes;
} catch (Exception e) {
log.error("Java 序列化失败,error:{}", e);
} finally {
if (outputStream != null) {
try {
outputStream.close();
} catch (IOException e) {
log.error("Java 序列化关闭二进制输出流失败,error:{}", e);
}
}
if (objectOutputStream != null) {
try {
objectOutputStream.close();
} catch (IOException e) {
log.error("Java 序列化关闭对象流失败,error:{}", e);
}
}
}
return null;
}

/**
* 反序列化
*
* @param clazz
* @param bytes
* @param <T>
* @return
*/
@Override
public <T> T deserialize(Class<T> clazz, byte[] bytes) {
ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes);
ObjectInputStream objectInputStream = null;
try {
objectInputStream = new ObjectInputStream(inputStream);
T obj = (T) objectInputStream.readObject();
return obj;
} catch (Exception e) {
log.error("Java 反序列化失败,error:{}", e);
} finally {
if (inputStream != null) {
try {
inputStream.close();
} catch (IOException e) {
log.error("Java 反序列化关闭二进制输入流失败,error:{}", e);
}
}
if (objectInputStream != null) {
try {
objectInputStream.close();
} catch (IOException e) {
log.error("Java 反序列化关闭对象输入流失败,error:{}", e);
}
}
}
return null;
}
}

HessianSerialize

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
package com.springboot.whb.study.rpc.rpc_v2.serialize;

import com.alibaba.com.caucho.hessian.io.Hessian2Input;
import com.alibaba.com.caucho.hessian.io.Hessian2Output;
import lombok.extern.slf4j.Slf4j;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

/**
* @author: whb
* @description: Hessian二进制序列化
*/
@Slf4j
public class HessianSerialize implements SerializeProtocol {

/**
* 序列化
*
* @param clazz
* @param t
* @param <T>
* @return
*/
@Override
public <T> byte[] serialize(Class<T> clazz, T t) {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
Hessian2Output hessian2Output = new Hessian2Output(outputStream);
try {
//验证过,一定需要在flush之前关闭掉hessian2Output,否则获取的bytes字段信息为空
hessian2Output.writeObject(t);
} catch (IOException e) {
throw new RuntimeException(e.getMessage());
} finally {
try {
hessian2Output.close();
} catch (IOException e) {
log.error("Hessian 二进制序列化,关闭流失败,error:{}", e);
}
}
try {
outputStream.flush();
byte[] bytes = outputStream.toByteArray();
return bytes;
} catch (IOException e) {
throw new RuntimeException(e.getMessage());
} finally {
try {
outputStream.close();
} catch (IOException e) {
log.error("Hessian 二进制序列化,关闭输出流失败,error:{}", e);
}
}
}

/**
* 反序列化
*
* @param clazz
* @param bytes
* @param <T>
* @return
*/
@Override
public <T> T deserialize(Class<T> clazz, byte[] bytes) {
ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes);
Hessian2Input hessian2Input = new Hessian2Input(inputStream);
try {
T t = (T) hessian2Input.readObject();
return t;
} catch (IOException e) {
throw new RuntimeException(e.getMessage());
} finally {
try {
hessian2Input.close();
} catch (IOException e) {
log.error("Hessian 反序列化,流关闭失败,error:{}", e);
}
try {
inputStream.close();
} catch (IOException e) {
log.error("Hessian 反序列化,输入流关闭失败,error:{}", e);
}
}
}
}

请求、响应对象

RpcRequest

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
package com.springboot.whb.study.rpc.rpc_v2.core;

import com.alibaba.fastjson.JSON;
import lombok.Data;

import java.io.Serializable;

/**
* @author: whb
* @description: RPC请求对象
*/
@Data
public class RpcRequest implements Serializable {

private static final long serialVersionUID = 1L;

/**
* 类名
*/
private String className;

/**
* 方法名
*/
private String methodName;

/**
* 参数
*/
private Object[] arguments;

/**
* 参数类型
*/
private Class<?>[] parameterTypes;

@Override
public String toString() {
return JSON.toJSONString(this);
}
}

RpcResponse

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
package com.springboot.whb.study.rpc.rpc_v2.core;

import com.alibaba.fastjson.JSON;
import lombok.Data;

import java.io.Serializable;

/**
* @author: whb
* @description: RPC响应对象
*/
@Data
public class RpcResponse<T> implements Serializable {

private static final long serialVersionUID = 1L;

/**
* 响应结果
*/
private T result;

/**
* 是否出错
*/
private Boolean isError;

/**
* 错误信息
*/
private String errorMsg;

@Override
public String toString() {
return JSON.toJSONString(this);
}
}

服务端处理

ServiceConnection

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
package com.springboot.whb.study.rpc.rpc_v2.core;

import lombok.Data;
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.net.ServerSocket;
import java.net.Socket;

/**
* @author: whb
* @description: 服务连接
*/
@Slf4j
@Data
public class ServiceConnection implements Runnable {

/**
* 端口号
*/
private int port;

/**
* 服务关闭标记位
*/
private volatile boolean flag = true;

/**
* 服务端套接字
*/
private ServerSocket serverSocket;

/**
* 服务处理器
*/
private ServiceHandler serviceHandler;

/**
* 初始化
*
* @param port
* @param serviceHandler
*/
public void init(int port, ServiceHandler serviceHandler) {
try {
this.port = port;
this.serverSocket = new ServerSocket(this.port);
} catch (IOException e) {
throw new RuntimeException("启动失败:" + e.getMessage());
}
this.serviceHandler = serviceHandler;
log.info("服务启动了...");
}

@Override
public void run() {
while (flag) {
try {
Socket socket = serverSocket.accept();
serviceHandler.handler(socket);
} catch (IOException e) {
try {
Thread.sleep(100);
} catch (InterruptedException e1) {
log.error("服务处理异常,error:{}", e);
}
}
}
}

/**
* 关闭连接
*/
public void destory() {
log.info("服务端套接字关闭...");
this.flag = false;
}
}

ServiceHandler

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
package com.springboot.whb.study.rpc.rpc_v2.core;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.springboot.whb.study.rpc.rpc_v2.io.protocol.MessageProtocol;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
* @author: whb
* @description: 服务端处理器
*/
@Slf4j
@Data
public class ServiceHandler {

/**
* 线程池
*/
private ThreadPoolExecutor executor = null;

/**
* 服务接口
*/
private RpcService rpcService;

/**
* 消息协议
*/
private MessageProtocol messageProtocol;

public ServiceHandler(RpcService rpcService) {
this.rpcService = rpcService;
//创建线程的线程工厂
ThreadFactory commonThreadName = new ThreadFactoryBuilder()
.setNameFormat("Parse-Task-%d")
.build();
//构造线程池
this.executor = new ThreadPoolExecutor(
10,
10,
2,
TimeUnit.SECONDS,
new ArrayBlockingQueue<>(200),
commonThreadName,
(Runnable r, ThreadPoolExecutor executor) -> {
SocketTask socketTask = (SocketTask) r;
Socket socket = socketTask.getSocket();
if (socket != null) {
try {
//无法及时处理和响应就快速拒绝掉
socket.close();
log.info("reject socket:" + socketTask + ", and closed.");
} catch (IOException e) {
log.error("socket关闭失败,error:{}", e);
}
}
}
);
}

/**
* 服务处理:接收到新的套接字,包装成为一个runnable提交给线程去执行
*
* @param socket
*/
public void handler(Socket socket) {
this.executor.execute(new SocketTask(socket));
}

class SocketTask implements Runnable {
private Socket socket;

public SocketTask(Socket socket) {
this.socket = socket;
}

public Socket getSocket() {
return socket;
}

@Override
public void run() {
try {
InputStream inputStream = socket.getInputStream();
OutputStream outputStream = socket.getOutputStream();
// 获取客户端请求数据,统一包装成RpcRequest
RpcRequest request = messageProtocol.serviceToRequest(inputStream);
RpcResponse response = rpcService.invoke(request);
log.info("request:[" + request + "],response:[" + response + "]");
// 反射调用,得到具体的返回值
messageProtocol.serviceGetResponse(response, outputStream);
} catch (Exception e) {
log.error("服务端处理出现异常,error:{}", e);
} finally {
if (socket != null) {
try {
socket.close();
} catch (IOException e) {
log.error("socket关闭失败,error:{}", e);
}
}
}
}
}
}

RpcService

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
package com.springboot.whb.study.rpc.rpc_v2.core;

import com.google.common.base.Joiner;
import com.springboot.whb.study.rpc.rpc_v2.config.ServiceConfig;
import com.springboot.whb.study.rpc.rpc_v2.io.protocol.DefaultMessageProtocol;
import com.springboot.whb.study.rpc.rpc_v2.io.protocol.MessageProtocol;
import com.springboot.whb.study.rpc.rpc_v2.register.ServiceRegister;
import com.springboot.whb.study.rpc.rpc_v2.register.ZkServiceRegister;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;

import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
* @author: whb
* @description: RPC服务
*/
@Slf4j
@Data
public class RpcService {

/**
* k 是接口全名称
* v 是对应的对象包含的详细信息
*/
private Map<String, ServiceConfig> serviceConfigMap = new HashMap<>();

/**
* 端口号
*/
private int port;

/**
* 服务注册
*/
private ServiceRegister serviceRegister;

/**
* 连接器还未抽象处理,使用的还是BIO模型
*/
private ServiceConnection serviceConnection;

/**
* 服务处理器
*/
private ServiceHandler serviceHandler;

/**
* 线程池
*/
ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(10, 100, 60, TimeUnit.SECONDS, new ArrayBlockingQueue<Runnable>(1000),
new BasicThreadFactory.Builder().namingPattern(Joiner.on("-").join("service-thread-pool-", "%s")).build());

public RpcService(int port) {
this.port = port;
this.serviceHandler = new ServiceHandler(this);
this.serviceHandler.setMessageProtocol(new DefaultMessageProtocol());
this.serviceRegister = new ZkServiceRegister();
}

/**
* 设置消息协议
*
* @param messageProtocol
*/
public void setMessageProtocol(MessageProtocol messageProtocol) {
if (this.serviceHandler == null) {
throw new RuntimeException("套接字处理器无效");
}
this.serviceHandler.setMessageProtocol(messageProtocol);
}

/**
* 添加服务接口
*
* @param interfaceClass
* @param ref
* @param <T>
*/
public <T> void addService(Class<T> interfaceClass, T ref) {
String interfaceName = interfaceClass.getName();
ServiceConfig<T> serviceConfig = ServiceConfig.convert(interfaceName, interfaceClass, ref, this);
serviceConfigMap.put(interfaceName, serviceConfig);
}

/**
* 注册服务
*/
private void register() {
//服务注册,在网络监听启动之前就需要完成
serviceConfigMap.values().forEach(serviceRegister::register);
}

/**
* 服务启动
*/
public void start() {
this.register();
log.info("服务注册完成");

this.serviceConnection = new ServiceConnection();
this.serviceConnection.init(port, serviceHandler);
threadPoolExecutor.execute(serviceConnection);

//优雅关闭
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
RpcService.this.destroy();
}));
}

/**
* 通过反射执行,执行结果封装RpcResponse
*
* @param request
* @param <K>
* @param <V>
* @return
*/
public <K, V> RpcResponse invoke(RpcRequest request) {
if (request == null) {
RpcResponse<V> response = new RpcResponse<>();
response.setResult(null);
response.setIsError(true);
response.setErrorMsg("未知异常");
return response;
}
String className = request.getClassName();
//暂时不考虑没有对应的serviceConfig的情况
ServiceConfig<K> serviceConfig = serviceConfigMap.get(className);
K ref = serviceConfig.getRef();
try {
Method method = ref.getClass().getMethod(request.getMethodName(), request.getParameterTypes());
V result = (V) method.invoke(ref, request.getArguments());
RpcResponse<V> response = new RpcResponse<>();
response.setResult(result);
response.setIsError(false);
response.setErrorMsg("");
return response;
} catch (Exception e) {

}
return null;
}

/**
* 关闭服务
*/
public void destroy() {
this.serviceConnection.destory();
log.info("服务端关闭了");
}
}

客户端处理

代理对象ProxyInstance

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
package com.springboot.whb.study.rpc.rpc_v2.core;

import lombok.extern.slf4j.Slf4j;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;

/**
* @author: whb
* @description: 客户端代理对象
*/
@Slf4j
public class ProxyInstance implements InvocationHandler {

/**
* RPC调用方
*/
private RpcClient rpcClient;

private Class clazz;

public ProxyInstance(RpcClient client, Class clazz) {
this.rpcClient = client;
this.clazz = clazz;
}

@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
RpcRequest request = new RpcRequest();
request.setClassName(clazz.getName());
request.setMethodName(method.getName());
request.setArguments(args);
request.setParameterTypes(method.getParameterTypes());

//获取服务提供方信息
InetSocketAddress address = rpcClient.discovery(request);
log.info("[" + Thread.currentThread().getName() + "] discovery service: " + address);

//发起网络请求,得到请求数据
RpcResponse response = rpcClient.invoke(request, address);
return response.getResult();
}
}

ClientHandler

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
package com.springboot.whb.study.rpc.rpc_v2.core;

import com.springboot.whb.study.rpc.rpc_v2.io.protocol.MessageProtocol;
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;

/**
* @author: whb
* @description: 客户端处理器
*/
@Slf4j
public class ClientHandler {

private RpcClient rpcClient;

private MessageProtocol messageProtocol;

public ClientHandler(RpcClient rpcClient) {
this.rpcClient = rpcClient;
}

public void setMessageProtocol(MessageProtocol messageProtocol) {
this.messageProtocol = messageProtocol;
}

public <T> RpcResponse<T> invoke(RpcRequest request, InetSocketAddress address) {
RpcResponse<T> response = new RpcResponse<>();

Socket socket = getSocketInstance(address);
if (socket == null) {
// 套接字链接失败
response.setIsError(true);
response.setErrorMsg("套接字链接失败");
return response;
}

try {
InputStream inputStream = socket.getInputStream();
OutputStream outputStream = socket.getOutputStream();

messageProtocol.clientToRequest(request, outputStream);

response = messageProtocol.clientGetResponse(inputStream);
} catch (IOException e) {
log.error("客户端处理异常,error:{}", e);
} finally {
if (socket != null) {
try {
socket.close();
} catch (IOException e) {
log.error("客户端关闭套接字失败,error:{}", e);
}
}
}
return response;
}

/**
* 获取对象实例
*
* @param address
* @return
*/
private Socket getSocketInstance(InetSocketAddress address) {
try {
return new Socket(address.getHostString(), address.getPort());
} catch (IOException e) {
log.error("客户端获取套接字失败,error:{}", e);
}
return null;
}
}

RpcClient

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package com.springboot.whb.study.rpc.rpc_v2.core;

import com.springboot.whb.study.rpc.rpc_v2.config.ClientConfig;
import com.springboot.whb.study.rpc.rpc_v2.domain.ServiceType;
import com.springboot.whb.study.rpc.rpc_v2.io.protocol.DefaultMessageProtocol;
import com.springboot.whb.study.rpc.rpc_v2.register.ServiceRegister;
import com.springboot.whb.study.rpc.rpc_v2.register.ZkServiceRegister;

import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.Map;

/**
* @author: whb
* @description: RPC客户端
*/
public class RpcClient {
/**
* k 是接口的全名称
* v 是对应的对象包含的详细信息
*/
private Map<String, ClientConfig> clientConfigMap = new HashMap<>();

/**
* 服务注册
*/
private ServiceRegister serviceRegister;

/**
* 客户端处理器
*/
private ClientHandler clientHandler;

public RpcClient() {
this.serviceRegister = new ZkServiceRegister();
this.clientHandler = new ClientHandler(this);
// 设置默认的消息处理协议
this.clientHandler.setMessageProtocol(new DefaultMessageProtocol());
}

/**
* 订阅服务
*
* @param clazz
* @param <T>
*/
public <T> void subscribe(Class<T> clazz) {
String interfaceName = clazz.getName();
ProxyInstance invocationHandler = new ProxyInstance(this, clazz);
ClientConfig<T> clientConfig = ClientConfig.convert(clazz, invocationHandler);
clientConfigMap.put(interfaceName, clientConfig);
}

/**
* 服务注册
*/
private void register() {
// 服务注册,在网络监听启动之前就需要完成
clientConfigMap.values().forEach(serviceRegister::register);
}

/**
* 服务启动
*/
public void start() {
this.register();
}

/**
* 服务发现
*
* @param request
* @return
*/
public InetSocketAddress discovery(RpcRequest request) {
return serviceRegister.discovery(request, ServiceType.PROVIDER);
}

/**
* 反射调用
*
* @param request
* @param address
* @return
*/
public RpcResponse invoke(RpcRequest request, InetSocketAddress address) {
return this.clientHandler.invoke(request, address);
}

/**
* 获取对象实例
*
* @param clazz
* @param <T>
* @return
*/
public <T> T getInstance(Class<T> clazz) {
return (T) (clientConfigMap.get(clazz.getName()).getProxy());
}

}

测试

测试接口定义Calculate

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
package com.springboot.whb.study.rpc.rpc_v2.demo;

/**
* @author: whb
* @description: 测试接口定义
*/
public interface Calculate<T> {

/**
* 求和
*
* @param a
* @param b
* @return
*/
T add(T a, T b);

/**
* 求差
*
* @param a
* @param b
* @return
*/
T sub(T a, T b);
}

测试接口实现类SimpleCalculate

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
package com.springboot.whb.study.rpc.rpc_v2.demo;

import java.util.Random;

/**
* @author: whb
* @description: 测试接口实现类
*/
public class SimpleCalculate implements Calculate<Integer> {

@Override
public Integer add(Integer a, Integer b) {
long start = System.currentTimeMillis();
try {
Thread.sleep(new Random().nextInt(1000));
} catch (InterruptedException e) {
e.printStackTrace();
}
int c = a + b;
System.out.println(Thread.currentThread().getName() + " 耗时:" + (System.currentTimeMillis() - start));
return c;
}

@Override
public Integer sub(Integer a, Integer b) {
return a - b;
}
}

测试-服务端Service

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
package com.springboot.whb.study.rpc.rpc_v2.demo;

import com.springboot.whb.study.rpc.rpc_v2.core.RpcService;

/**
* @author: whb
* @description: 测试服务端
*/
public class Service {
public static void main(String[] args) {
RpcService rpcService = new RpcService(10001);
rpcService.addService(Calculate.class, new SimpleCalculate());

rpcService.start();
}
}

测试-客户端Client

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
package com.springboot.whb.study.rpc.rpc_v2.demo;

import com.google.common.base.Joiner;
import com.springboot.whb.study.rpc.rpc_v2.core.RpcClient;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;

import java.util.Random;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
* @author: whb
* @description: 测试客户端
*/
public class Client {
public static final ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(10, 100, 60, TimeUnit.SECONDS, new ArrayBlockingQueue<Runnable>(1000),
new BasicThreadFactory.Builder().namingPattern(Joiner.on("-").join("client-thread-pool-", "%s")).build());

public static void main(String[] args) {
RpcClient rpcClient = new RpcClient();

rpcClient.subscribe(Calculate.class);
rpcClient.start();

Calculate<Integer> calculateProxy = rpcClient.getInstance(Calculate.class);

for (int i = 0; i < 200; i++) {
threadPoolExecutor.execute(() -> {
long start = System.currentTimeMillis();
int s1 = new Random().nextInt(100);
int s2 = new Random().nextInt(100);
int s3 = calculateProxy.add(s1, s2);
System.out.println("[" + Thread.currentThread().getName() + "]a: " + s1 + ", b:" + s2 + ", c=" + s3 + ", 耗时:" + (System.currentTimeMillis() - start));
});
}
}
}

测试结果

zookeeper

img

服务端

img
img

客户端

img

总结

v2版本相比v1版本修改了整个代码结构,使得结构能够更加明确,引入zookeeper作为服务治理功能,大致介绍了zookeeper的特点以及功能,给服务注册、服务发现、序列化协议等均留下了口子,以便实现自定义的协议,v1的io模型是BIO,v2并没有变化,只是由单线程改造成多线程。

整体而言符合一个简单的rpc框架,依旧还是有很多点可以完善、优化的点,如:

  • io模型还是没有替换,后面考虑直接整体接入netty;

  • 不应该每次实时从zk获取节点信息,应该先设置一个本地缓存,再利用zookeeper的watcher功能,开启一个异步线程去监听更新本地缓存,降低和zk交互带来的性能损耗;

  • 没有快速失败、重试的功能,客观情况下存在网络抖动的问题,重试就可以了。

  • 整体的各种协议约定并没有明确规范,比较混乱。

本文标题:手写RPC-升级版

文章作者:王洪博

发布时间:2019年05月19日 - 15:05

最后更新:2019年09月12日 - 10:09

原始链接:http://whb1990.github.io/posts/a6a74f01.html

▄︻┻═┳一如果你喜欢这篇文章,请点击下方"打赏"按钮请我喝杯 ☕
0%