用GO写一个RPC框架 s04 (编写服务端核心) 前言

前言

通过上两篇的学习 我们已经了解了 服务端本地服务的注册, 服务端配置,协议 现在我们开始写服务端的核心逻辑

github.com/dollarkille…

默认配置

我们先看下默认的配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
go复制代码func defaultOptions() *Options {
return &Options{
Protocol: transport.TCP, // default TCP
Uri: "0.0.0.0:8397",
UseHttp: false,
readTimeout: time.Minute * 3, // 心跳包 默认 3min
writeTimeout: time.Second * 30,
ctx: context.Background(), // ctx 是控制服务退出的
options: map[string]interface{}{
"TCPKeepAlivePeriod": time.Minute * 3,
},
processChanSize: 1000,
Trace: false,
RSAPublicKey: []byte(`-----BEGIN PUBLIC KEY-----
-----END PUBLIC KEY-----`),
RSAPrivateKey: []byte(`-----BEGIN RSA PRIVATE KEY-----
-----END RSA PRIVATE KEY-----`),
Discovery: &discovery.SimplePeerToPeer{},
}
}

run

服务注册完毕之后 调用Run方法 启动服务

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
go复制代码func (s *Server) Run(options ...Option) error {
// 初始化 服务端配置
for _, fn := range options {
fn(s.options)
}

var err error
// 更具配置传入的protocol 获取到 网络插件 (KCP UDP TCP) 我们等下细讲
s.options.nl, err = transport.Transport.Gen(s.options.Protocol, s.options.Uri)
if err != nil {
return err
}

log.Printf("LightRPC: %s %s \n", s.options.Protocol, s.options.Uri)

// 这里是服务注册 我们这里先跳过
if s.options.Discovery != nil {
// 读取服务配置文件
sIdb, err := ioutil.ReadFile("./light.conf")
if err != nil {
// 如果没有 就生成 分布式ID
id, err := utils.DistributedID()
if err != nil {
return err
}
sIdb = []byte(id)
}
// 进行服务注册
sId := string(sIdb)
for k := range s.serviceMap { // 进行服务注册
err := s.options.Discovery.Registry(k, s.options.registryAddr, s.options.weights, s.options.Protocol, s.options.MaximumLoad, &sId)
if err != nil {
return err
}
log.Printf("Discovery Registry: %s addr: %s SUCCESS", k, s.options.registryAddr)
}

ioutil.WriteFile("./light.conf", sIdb, 00666)
}

// 启动服务
return s.run()
}



func (s *Server) run() error {
loop:
for {
select {
case <-s.options.ctx.Done(): // 检查是否需要退出服务
break loop
default:
accept, err := s.options.nl.Accept() // 获取一个链接
if err != nil {
log.Println(err)
continue
}
if s.options.Trace {
log.Println("connect: ", accept.RemoteAddr())
}

go s.process(accept) // 开一个协程去处理 该 链接
}

}

return nil
}

我们先回顾一下 上章讲的 握手逻辑

  1. 建立链接 通过非对称加密 传输 aes 密钥给服务端 (携带token)
  2. 服务端 验证 token 并记录 aes 密钥 后面与客户端交互 都采用对称加密

具体处理 链接 process (重点!!!)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
go复制代码func (s *Server) process(conn net.Conn) {

defer func() {
// 网络不可靠
if err := recover(); err != nil {
utils.PrintStack()
log.Println("Recover Err: ", err)
}
}()

// 每进来一个请求这里就ADD
s.options.Discovery.Add(1)
defer func() {
s.options.Discovery.Less(1) // 处理完 请求就退出
// 退出 回收句柄
err := conn.Close()
if err != nil {
log.Println(err)
return
}

if s.options.Trace {
log.Println("close connect: ", conn.RemoteAddr())
}
}()

// 这里定义一个xChannel 用于分离 请求和返回
xChannel := utils.NewXChannel(s.options.processChanSize)

// 握手
handshake := protocol.Handshake{}
err := handshake.Handshake(conn)
if err != nil {
return
}

// 非对称加密 解密 AES KEY
aesKey, err := cryptology.RsaDecrypt(handshake.Key, s.options.RSAPrivateKey)
if err != nil {
encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))
conn.Write(encodeHandshake)
return
}

// 检测 AES KEY 是否正确
if len(aesKey) != 32 && len(aesKey) != 16 {
encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte("aes key != 32 && key != 16"))
conn.Write(encodeHandshake)
return
}

// 解密 TOKEN
token, err := cryptology.RsaDecrypt(handshake.Token, s.options.RSAPrivateKey)
if err != nil {
encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))
conn.Write(encodeHandshake)
return
}
// 对TOKEN进行校验
if s.options.AuthFunc != nil {
err := s.options.AuthFunc(light.DefaultCtx(), string(token))
if err != nil {
encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(err.Error()))
conn.Write(encodeHandshake)
return
}
}

// limit 限流
if s.options.Discovery.Limit() {
// 熔断
encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(pkg.ErrCircuitBreaker.Error()))
conn.Write(encodeHandshake)
log.Println(s.options.Discovery.Limit())
return
}

// 如果握手没有问题 则返回握手成功
encodeHandshake := protocol.EncodeHandshake([]byte(""), []byte(""), []byte(""))
_, err = conn.Write(encodeHandshake)
if err != nil {
return
}

// send
go func() {
loop:
for {
select {
// 这就是刚刚的xChannel 对读写进行分离
case msg, ex := <-xChannel.Ch:
if !ex {
if s.options.Trace {
log.Printf("ip: %s close send server", conn.RemoteAddr())
}
break loop
}
now := time.Now()
if s.options.writeTimeout > 0 {
conn.SetWriteDeadline(now.Add(s.options.writeTimeout))
}
// send message
_, err := conn.Write(msg)
if err != nil {
if s.options.Trace {
log.Printf("ip: %s err: %s", conn.RemoteAddr(), err)
}
break loop
}
}
}
}()

defer func() {
xChannel.Close()
}()
loop:
for { // 具体消息获取
now := time.Now()
if s.options.readTimeout > 0 {
conn.SetReadDeadline(now.Add(s.options.readTimeout))
}

proto := protocol.NewProtocol()
msg, err := proto.IODecode(conn) // 获取一个消息
if err != nil {
if err == io.EOF {
if s.options.Trace {
log.Printf("ip: %s close", conn.RemoteAddr())
}
break loop
}

// 遇到错误关闭链接
if s.options.Trace {
log.Printf("ip: %s err: %s", conn.RemoteAddr(), err)
}
break loop
}

go s.processResponse(xChannel, msg, conn.RemoteAddr().String(), aesKey)
}
}

具体处理 (重点!!!)

注意此RPC传输消息都是编码过的 要进行转码

  • 第一层 为压缩编码
  • 第二层 为加密编码
  • 第三层 为序列化
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
go复制代码func (s *Server) processResponse(xChannel *utils.XChannel, msg *protocol.Message, addr string, aesKey []byte) {
var err error
s.options.Discovery.Add(1)
defer func() {
s.options.Discovery.Less(1)
if err != nil {
if s.options.Trace {
log.Println("ProcessResponse Error: ", err, " ID: ", addr)
}
xChannel.Close()
}
}()

// heartBeat 判断
if msg.Header.RespType == byte(protocol.HeartBeat) {
// 心跳返回
if s.options.Trace {
log.Println("HeartBeat: ", addr)
}

// 4. 打包
_, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), []byte(""), byte(protocol.HeartBeat), msg.Header.CompressorType, msg.Header.SerializationType, []byte(""))
if err != nil {
return
}
// 5. 回写
err = xChannel.Send(message)
if err != nil {
return
}

return
}

// 限流
if s.options.Discovery.Limit() {
serialization, _ := codes.SerializationManager.Get(codes.MsgPack)
metaData := make(map[string]string)
metaData["RespError"] = pkg.ErrCircuitBreaker.Error()
meta, err := serialization.Encode(metaData)
if err != nil {
return
}
decrypt, err := cryptology.AESDecrypt(aesKey, meta)
if err != nil {
return
}
_, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), decrypt, byte(protocol.Response), byte(codes.RawData), byte(codes.MsgPack), []byte(""))
if err != nil {
return
}
// 5. 回写
err = xChannel.Send(message)
if err != nil {
return
}

log.Println(s.options.Discovery.Limit())
log.Println("限流/////////////")

return
}

// 1. 解压缩
compressor, ex := codes.CompressorManager.Get(codes.CompressorType(msg.Header.CompressorType))
if !ex {
err = errors.New("compressor 404")
return
}
msg.MetaData, err = compressor.Unzip(msg.MetaData)
if err != nil {
return
}

msg.Payload, err = compressor.Unzip(msg.Payload)
if err != nil {
return
}
// 2. 解密
msg.MetaData, err = cryptology.AESDecrypt(aesKey, msg.MetaData)
if err != nil {
return
}

msg.Payload, err = cryptology.AESDecrypt(aesKey, msg.Payload)
if err != nil {
return
}

// 3. 反序列化
serialization, ex := codes.SerializationManager.Get(codes.SerializationType(msg.Header.SerializationType))
if !ex {
err = errors.New("serialization 404")
return
}

metaData := make(map[string]string)
err = serialization.Decode(msg.MetaData, &metaData)
if err != nil {
return
}

// 初始化context
ctx := light.DefaultCtx()
ctx.SetMetaData(metaData)

// 1.3 auth
if s.options.AuthFunc != nil {
auth := metaData["Light_AUTH"]
err := s.options.AuthFunc(ctx, auth)
if err != nil {
ctx.SetValue("RespError", err.Error())
var metaDataByte []byte
metaDataByte, _ = serialization.Encode(ctx.GetMetaData())
metaDataByte, _ = cryptology.AESEncrypt(aesKey, metaDataByte)
metaDataByte, _ = compressor.Zip(metaDataByte)
// 4. 打包
_, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), metaDataByte, byte(protocol.Response), msg.Header.CompressorType, msg.Header.SerializationType, []byte(""))
if err != nil {
return
}
// 5. 回写
err = xChannel.Send(message)
if err != nil {
return
}
return
}
}

// 找到具体调用的服务
ser, ex := s.serviceMap[msg.ServiceName]
if !ex {
err = errors.New("service does not exist")
return
}

// 找到具体调用的方法
method, ex := ser.methodType[msg.ServiceMethod]
if !ex {
err = errors.New("method does not exist")
return
}

// 初始化 req, resp
req := utils.RefNew(method.RequestType)
resp := utils.RefNew(method.ResponseType)

err = serialization.Decode(msg.Payload, req)
if err != nil {
return
}

// 定义ctx paht 为 服务名称.服务方法
path := fmt.Sprintf("%s.%s", msg.ServiceName, msg.ServiceMethod)
ctx.SetPath(path)

// 前置middleware
if len(s.beforeMiddleware) != 0 {
for idx := range s.beforeMiddleware {
err := s.beforeMiddleware[idx](ctx, req, resp)
if err != nil {
return
}
}
}
funcs, ex := s.beforeMiddlewarePath[path]
if ex {
if len(funcs) != 0 {
for idx := range funcs {
err := funcs[idx](ctx, req, resp)
if err != nil {
return
}
}
}
}

// 核心调用
callErr := ser.call(ctx, method, reflect.ValueOf(req), reflect.ValueOf(resp))
if callErr != nil {
ctx.SetValue("RespError", callErr.Error())
}

// 后置middleware
if len(s.afterMiddleware) != 0 {
for idx := range s.afterMiddleware {
err := s.afterMiddleware[idx](ctx, req, resp)
if err != nil {
return
}
}
}
funcs, ex = s.afterMiddlewarePath[path]
if ex {
if len(funcs) != 0 {
for idx := range funcs {
err := funcs[idx](ctx, req, resp)
if err != nil {
return
}
}
}
}
// response

// 1. 序列化
var respBody []byte
respBody, err = serialization.Encode(resp)

var metaDataByte []byte
metaDataByte, _ = serialization.Encode(ctx.GetMetaData())
// 2. 加密
metaDataByte, err = cryptology.AESEncrypt(aesKey, metaDataByte)
if err != nil {
return
}
respBody, err = cryptology.AESEncrypt(aesKey, respBody)
if err != nil {
return
}
// 3. 压缩
metaDataByte, err = compressor.Zip(metaDataByte)
if err != nil {
return
}
respBody, err = compressor.Zip(respBody)
if err != nil {
return
}
// 4. 打包
_, message, err := protocol.EncodeMessage(msg.MagicNumber, []byte(msg.ServiceName), []byte(msg.ServiceMethod), metaDataByte, byte(protocol.Response), msg.Header.CompressorType, msg.Header.SerializationType, respBody)
if err != nil {
return
}
// 5. 回写
err = xChannel.Send(message)
if err != nil {
return
}
}

调用具体方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
go复制代码func (s *service) call(ctx *light.Context, mType *methodType, request, response reflect.Value) (err error) {
// recover 捕获堆栈消息
defer func() {
if r := recover(); r != nil {
buf := make([]byte, 4096)
n := runtime.Stack(buf, false)
buf = buf[:n]

err = fmt.Errorf("[painc service internal error]: %v, method: %s, argv: %+v, stack: %s",
r, mType.method.Name, request.Interface(), buf)
log.Println(err)
}
}()

fn := mType.method.Func
returnValue := fn.Call([]reflect.Value{s.refVal, reflect.ValueOf(ctx), request, response})
errInterface := returnValue[0].Interface()
if errInterface != nil {
return errInterface.(error)
}

return nil
}

这里就完成了服务端的基础逻辑了

专栏: juejin.cn/column/6986…

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

0%