开发者博客 – IT技术 尽在开发者博客

开发者博客 – 科技是第一生产力


  • 首页

  • 归档

  • 搜索

Scala 技术周刊 第 23 期

发表于 2017-10-15

这里有最新的 Scala 社区动态、技术博文。
微信搜索 「scalacool」关注我们,及时获取最新资讯。


深度阅读

  • Getting Started with Elastic4s programming in Scala
    Elastic4s 介绍
  • “Bootstrapping the Web with Scala Native” by Richard Whaling
    Scala Native 在 Web 方面的应用
  • Solving Dynamic Programming problems using Functional Programming (Part 2)
    函数式编程
  • We need a good name for Scala programmers. Scala doesn’t have a mascot and it’s not the most punnable name.
    为 Scala 程序员征集名字
  • Akka Typed: New Cluster Tools API
    Akka Typed

一周速递

  • Play 2.6.6 发布
  • Lagom 1.3.9 发布

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

Netty线程模型及EventLoop详解

发表于 2017-10-15

线程模型与并发

什么是线程模型呢?线程模型指定了线程管理的模型。在进行并发编程的过程中,我们需要小心的处理多个线程之间的同步关系,而一个好的线程模型可以大大减少管理多个线程的成本。在阅读本文之前,你可以选择性的阅读下面列出的文章,来快速了解和回顾java中的并发编程内容:

  • Java线程池详解(一)
  • Java线程池详解(二)
  • Java调度线程池ScheduleExecutorService
  • Java调度线程池ScheduleExecutorService(续)
  • Java中的ThreadLocal和 InheritableThreadLocal
  • Java AQS
  • Java可重入锁详解

Reactor线程模型

Reactor是一种经典的线程模型,Reactor线程模型分为单线程模型、多线程模型以及主从多线程模型。下面分别分析一下各个Reactor线程模型的优缺点。首先是Reactor单线程模型,下面的图片展示了这个线程模型的结构:

Reactor单线程模型

Reactor单线程模型

Reactor单线程模型仅使用一个线程来处理所有的事情,包括客户端的连接和到服务器的连接,以及所有连接产生的读写事件,这种线程模型需要使用异步非阻塞I/O,使得每一个操作都不会发生阻塞,Handler为具体的处理事件的处理器,而Acceptor为连接的接收者,作为服务端接收来自客户端的链接请求。这样的线程模型理论上可以仅仅使用一个线程就完成所有的事件处理,显得线程的利用率非常高,而且因为只有一个线程在工作,所有不会产生在多线程环境下会发生的各种多线程之间的并发问题,架构简单明了,线程模型的简单性决定了线程管理工作的简单性。但是这样的线程模型存在很多不足,比如:

  • 仅利用一个线程来处理事件,对于目前普遍多核心的机器来说太过浪费资源
  • 一个线程同时处理N个连接,管理起来较为复杂,而且性能也无法得到保证,这是以线程管理的简洁换取来的事件管理的复杂性,而且是在性能无 法得到保证的前提下换取的,在大流量的应用场景下根本没有实用性
  • 根据第二条,当处理的这个线程负载过重之后,处理速度会变慢,会有大量的事件堆积,甚至超时,而超时的情况下,客户端往往会重新发送请求,这样的情况下,这个单线程的模型就会成为整个系统的瓶颈
  • 单线程模型的一个致命缺钱就是可靠性问题,因为仅有一个线程在工作,如果这个线程出错了无法正常执行任务了,那么整个系统就会停止响应,也就是系统会因为这个单线程模型而变得不可用,这在绝大部分场景(所有)下是不允许出现的

介于上面的种种缺陷,Reactor演变出了第二种模型,也就是Reactor多线程模型,下面展示了这种模型:

Reactor多线程模型

Reactor多线程模型

可以发现,多线程模型下,接收链接和处理请求作为两部分分离了,而Acceptor使用单独的线程来接收请求,做好准备后就交给事件处理的handler来处理,而handler使用了一个线程池来实现,这个线程池可以使用Executor框架实现的线程池来实现,所以,一个连接会交给一个handler线程来复杂其上面的所有事件,需要注意,一个连接只会由一个线程来处理,而多个连接可能会由一个handler线程来处理,关键在于一个连接上的所有事件都只会由一个线程来处理,这样的好处就是消除了不必要的并发同步的麻烦。Reactor多线程模型似乎已经可以很好的工作在我们的项目中了,但是还有一个问题没有解决,那就是,多线程模型下任然只有一个线程来处理客户端的连接请求,那如果这个线程挂了,那整个系统任然会变为不可用,而且,因为仅仅由一个线程来负责客户端的连接请求,如果连接之后要做一些验证之类复杂耗时操作再提交给handler线程来处理的话,就会出现性能问题。

Reactor多线程模型对Reactor单线程模型做了一些改进,但是在某些场景下任然有所缺陷,所以就有了第三种Reactor模型,Reactor主从多线程模型,下面展示了这种模型的架构:

Reactor主从多线程模型

Reactor主从多线程模型

Reactor多线程模型解决了Reactor单线程模型和Reactor多线程模型中存在的问题,解决了handler的性能问题,以及Acceptor的安全以及性能问题,Netty就使用了这种线程模型来处理事件。

Netty线程模型

在了解了线程模型以及Reactor线程模型之后,我们来看一下Netty的线程模型是怎么样的。首先,Netty使用EventLoop来处理连接上的读写事件,而一个连接上的所有请求都保证在一个EventLoop中被处理,一个EventLoop中只有一个Thread,所以也就实现了一个连接上的所有事件只会在一个线程中被执行。一个EventLoopGroup包含多个EventLoop,可以把一个EventLoop当做是Reactor线程模型中的一个线程,而一个EventLoopGroup类似于一个ExecutorService,当然,这只是为了更好的理解Netty的线程模型,它们之间是没有等价关系的,后面的分析中会详细讲到。下面的图片展示了Netty的线程模型:

Netty线程模型

Netty线程模型

首先看一下Netty服务端启动的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
复制代码
// Configure the server.
EventLoopGroup bossGroup = new NioEventLoopGroup(1);
EventLoopGroup workerGroup = new NioEventLoopGroup();
try {
ServerBootstrap b = new ServerBootstrap();
b.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.option(ChannelOption.SO_BACKLOG, 100)
.handler(new LoggingHandler(LogLevel.INFO))
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast(your_handler_name, your_handler_instance);
}
});

// Start the server.
ChannelFuture f = b.bind(PORT).sync();

// Wait until the server socket is closed.
f.channel().closeFuture().sync();
} finally {
// Shut down all event loops to terminate all threads.
bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
}

Netty的服务端使用了两个EventLoopGroup,而第一个EventLoopGroup通常只有一个EventLoop,通常叫做bossGroup,负责客户端的连接请求,然后打开Channel,交给后面的EventLoopGroup中的一个EventLoop来负责这个Channel上的所有读写事件,一个Channel只会被一个EventLoop处理,而一个EventLoop可能会被分配给多个Channel来负责上面的事件,当然,Netty不仅支持NI/O,还支持OI/O,所以两者的EventLoop分配方式有所区别,下面分别展示了NI/O和OI/O的分配方式:

Netty NIO分配EventLoop模型

Netty NIO分配EventLoop模型

Netty OIO分配EventLoop模型

Netty OIO分配EventLoop模型

在NI/O非阻塞模式下,Netty将负责为每个Channel分配一个EventLoop,一旦一个EventLoop呗分配给了一个Channel,那么在它的整个生命周期中都使用这个EventLoop,但是多个Channel将可能共享一个EventLoop,所以和Thread相关的ThreadLocal的使用就要特别注意,因为有多个Channel在使用该Thread来处理读写时间。在阻塞IO模式下,考虑到一个Channel将会阻塞,所以不太可能将一个EventLoop共用于多个Channel之间,所以,每一个Channel都将被分配一个EventLoop,并且反过来也成立,也就是一个EventLoop将只会被绑定到一个Channel上来处理这个Channel上的读写事件。无论是非阻塞模式还是阻塞模式,一个Channel都将会保证一个Channel上的所有读写事件都只会在一个EventLoop上被处理。

Netty EventLoop

上文中分析了Reactor线程模型以及Netty的线程模型,在Netty中,EventLoop是一个极为重要的组件,它翻译过来称为事件循环,一个EventLoop将被分配给一个Channel,来负责这个Channel的整个生命周期之内的所有事件,下面来分析一下EventLoop的结构和实现细节。首先展示了EventLoop的类图:

EventLoop类图

EventLoop类图

从EventLoop的类图中可以发现,其实EventLoop继承了Java的ScheduledExecutorService,也就是调度线程池,所以,EventLoop应当有ScheduledExecutorService提供的所有功能。那为什么需要继承ScheduledExecutorService呢,也就是为什么需要延时调度功能,那是因为,在Netty中,有可能用户线程和Netty的I/O线程同时操作网络资源,而为了减少并发锁竞争,Netty将用户线程的任务包装成Netty的task,然后向Netty的I/O任务一样去执行它们。有些时候我们需要延时执行任务,或者周期性执行任务,那么就需要调度功能。这是Netty在设计上的考虑,为我们极大的简化的编程方法。

EventLoop是一个接口,它在继承了ScheduledExecutorService等多个类的同时,仅仅提供了一个方法parent,这个方法返回它属于哪个EventLoopGroup。本文只分析非阻塞模式,而阻塞模式留到未来某个合适的时候再做分析总结。在上文中展示的服务端启动的代码中我们发现我们使用的EventLoop是一个子类NioEventLoopGroup,下面就来分析一下NioEventLoopGroup这个类。首先展示一下NioEventLoopGroup的类图:

 NioEventLoopGroup类图

NioEventLoopGroup类图

可以发现,NioEventLoopGroup的实现非常的复杂,但是只要我们清楚了Netty的线程模型,我们就可以有入口去分析它的代码。首先,我们知道每个EventLoop只要一个Thread来处理事件,那我们就来找到那个Thread在什么地方。可以在SingleThreadEventExecutor类中找到thread,它的初始化在doStartThread这个方法中,而这个方法被startThread方法调用,而startThread 这个方法被execute方法调用,也就是提交任务的入口,这个方法是Executor接口的唯一方法。也就是说,所有我们通过EventLoop的execute方法提交的任务都将被这个Thread线程来执行。我们还知道一个事实,EventLoop是一个循环执行来消耗Channel事件的类,那么它必然会有一个类似循环的方法来作为任务,来提交给这个Thread来执行,而这可以在doStartThread方法中被发现,因为这个方法非常重要,所以下面展示了它的实现细节,但是去掉了一些代码来减少代码量:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
复制代码
private void doStartThread() {
assert thread == null;
executor.execute(new Runnable() {
@Override
public void run() {
thread = Thread.currentThread();
if (interrupted) {
thread.interrupt();
}

boolean success = false;
updateLastExecutionTime();
try {
SingleThreadEventExecutor.this.run();
success = true;
} catch (Throwable t) {
logger.warn("Unexpected exception from an event executor: ", t);
} finally {
for (;;) {
int oldState = state;
if (oldState >= ST_SHUTTING_DOWN || STATE_UPDATER.compareAndSet(
SingleThreadEventExecutor.this, oldState, ST_SHUTTING_DOWN)) {
break;
}
}

try {
// Run all remaining tasks and shutdown hooks.
for (;;) {
if (confirmShutdown()) {
break;
}
}
} finally {
try {
cleanup();
} finally {
STATE_UPDATER.set(SingleThreadEventExecutor.this, ST_TERMINATED);
threadLock.release();
terminationFuture.setSuccess(null);
}
}
}
}
});
}

上面所提到的事件循环就是通过SingleThreadEventExecutor.this.run()这句话来触发的。这个run方法的具体实现在NioEventLoop中,下面展示了它的实现代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
复制代码
protected void run() {
for (;;) {
try {
switch (selectStrategy.calculateStrategy(selectNowSupplier, hasTasks())) {
case SelectStrategy.CONTINUE:
continue;
case SelectStrategy.SELECT:
select(wakenUp.getAndSet(false));
if (wakenUp.get()) {
selector.wakeup();
}
// fall through
default:
}

cancelledKeys = 0;
needsToSelectAgain = false;
final int ioRatio = this.ioRatio;
if (ioRatio == 100) {
try {
processSelectedKeys();
} finally {
// Ensure we always run tasks.
runAllTasks();
}
} else {
final long ioStartTime = System.nanoTime();
try {
processSelectedKeys();
} finally {
// Ensure we always run tasks.
final long ioTime = System.nanoTime() - ioStartTime;
runAllTasks(ioTime * (100 - ioRatio) / ioRatio);
}
}
} catch (Throwable t) {
handleLoopException(t);
}
// Always handle shutdown even if the loop processing threw an exception.
try {
if (isShuttingDown()) {
closeAll();
if (confirmShutdown()) {
return;
}
}
} catch (Throwable t) {
handleLoopException(t);
}
}
}

首先,我们来分析一下NioEventLoop的相关细节,在一个无限循环里面,只有在遇到shutdown的情况下才会停止循环。然后在循环里会询问是否有事件,如果没有,则继续循环,如果有事件,那么就开始处理时间。上文中我们提到,在事件循环中我们不仅要处理IO事件,还要处理非I/O事件。Netty中可以设置用于I/O操作和非I/O操作的时间占比,默认各位50%,也就是说,如果某次I/O操作的时间花了100ms,那么这次循环中非I/O得任务也可以花费100ms。Netty中的I/O时间处理通过processSelectedKeys方法来进行,而非I/O操作通过runAllTasks反复来进行,首先来看runAllTasks方法,虽然设定了一个可以运行的时间参数,但是实际上Netty并不保证能精确的确保非I/O任务只运行设定的毫秒,下面来看下runAllTasks的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
复制代码
protected boolean runAllTasks(long timeoutNanos) {
fetchFromScheduledTaskQueue();
Runnable task = pollTask();
if (task == null) {
afterRunningAllTasks();
return false;
}

final long deadline = ScheduledFutureTask.nanoTime() + timeoutNanos;
long runTasks = 0;
long lastExecutionTime;
for (;;) {
safeExecute(task);

runTasks ++;

// Check timeout every 64 tasks because nanoTime() is relatively expensive.
// XXX: Hard-coded value - will make it configurable if it is really a problem.
if ((runTasks & 0x3F) == 0) {
lastExecutionTime = ScheduledFutureTask.nanoTime();
if (lastExecutionTime >= deadline) {
break;
}
}

task = pollTask();
if (task == null) {
lastExecutionTime = ScheduledFutureTask.nanoTime();
break;
}
}

afterRunningAllTasks();
this.lastExecutionTime = lastExecutionTime;
return true;
}

// 将任务运行起来
protected static void safeExecute(Runnable task) {
try {
task.run();
} catch (Throwable t) {
logger.warn("A task raised an exception. Task: {}", task, t);
}
}

可以看到,这个方法是在每运行了64个任务之后再进行比较的,如果超出了设定的运行时间则退出,否则再运行64个任务再比较。所以,Netty强烈要求不要在I/O线程中运行阻塞任务,因为阻塞任务将会阻塞住Netty的事件循环,从而造成事件堆积的现象。现在回头看处理I/O任务的processSelectedKeys方法,跟踪代码之后发现最后实际处理I/O事件的一个方法为processSelectedKey,下面展示了它的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
复制代码
private void processSelectedKey(SelectionKey k, AbstractNioChannel ch) {
final AbstractNioChannel.NioUnsafe unsafe = ch.unsafe();
if (!k.isValid()) {
final EventLoop eventLoop;
try {
eventLoop = ch.eventLoop();
} catch (Throwable ignored) {
// If the channel implementation throws an exception because there is no event loop, we ignore this
// because we are only trying to determine if ch is registered to this event loop and thus has authority
// to close ch.
return;
}
// Only close ch if ch is still registered to this EventLoop. ch could have deregistered from the event loop
// and thus the SelectionKey could be cancelled as part of the deregistration process, but the channel is
// still healthy and should not be closed.
// See https://github.com/netty/netty/issues/5125
if (eventLoop != this || eventLoop == null) {
return;
}
// close the channel if the key is not valid anymore
unsafe.close(unsafe.voidPromise());
return;
}

try {
int readyOps = k.readyOps();
// We first need to call finishConnect() before try to trigger a read(...) or write(...) as otherwise
// the NIO JDK channel implementation may throw a NotYetConnectedException.
if ((readyOps & SelectionKey.OP_CONNECT) != 0) {
// remove OP_CONNECT as otherwise Selector.select(..) will always return without blocking
// See https://github.com/netty/netty/issues/924
int ops = k.interestOps();
ops &= ~SelectionKey.OP_CONNECT;
k.interestOps(ops);

unsafe.finishConnect();
}

// Process OP_WRITE first as we may be able to write some queued buffers and so free memory.
if ((readyOps & SelectionKey.OP_WRITE) != 0) {
// Call forceFlush which will also take care of clear the OP_WRITE once there is nothing left to write
ch.unsafe().forceFlush();
}

// Also check for readOps of 0 to workaround possible JDK bug which may otherwise lead
// to a spin loop
if ((readyOps & (SelectionKey.OP_READ | SelectionKey.OP_ACCEPT)) != 0 || readyOps == 0) {
unsafe.read();
}
} catch (CancelledKeyException ignored) {
unsafe.close(unsafe.voidPromise());
}
}

这个方法运行的流程为:

  1. 从Channel上获取一个unsafe对象,这个对象 是用来进行NIO操作的一系列系统级API,关于Netty的Channel的深层次分析将在另外的篇章中进行
  2. 从Channel上获取了eventLoop,而这个eventLoop是什么时候分配给Channel的细节在后文中进行分析
  3. 根据事件调用底层API来处理事件

下面,我们分析一下是什么时候将一个EventLoop分配给一个Channel的,并且这个EventLoop的那个唯一的Thread是什么时候被赋值的。在这个问题上,服务端的流程和客户端的流程可能不太一样,对于服务端来说,首先需要bind一个端口,然后在进行Accept进来的连接,而客户端需要进行connect到服务端。先来分析一下服务端。

还是看上面提供的服务端的示例代码,其中启动的代码为下面这句代码:

1
2
3
复制代码
// Start the server.
ChannelFuture f = b.bind(PORT).sync();

也就是我们网络编程中的bind操作,这个操作会发生什么呢?追踪代码如下:

1
2
3
4
5
6
7
8
9
复制代码 
-> AbstractBootstrap.bind(port)
-> AbstractBootstrap.bind(address)
-> AbstractBootstrap.doBind(final SocketAddress localAddress)
-> AbstractBootstrap.initAndRegister
-> AbstractBootstrap.doBind0
-> SingleThreadEventExecutor.execute
-> SingleThreadEventExecutor.startThread()
-> SingleThreadEventExecutor.doStartThread

EventLoop在AbstractBootstrap.initAndRegister中获得了一个新的Channel,然后在AbstractBootstrap.doBind0 方法里面调用接下来的方法来初始化EventLoop的Thread的工作,并且将EventLoop的时间循环打开了,可以开始接收客户端的连接请求了。下面来分析一下客户端的流程。

一个客户端的启动代码示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
复制代码
// Configure the client.
EventLoopGroup group = new NioEventLoopGroup();
try {
Bootstrap b = new Bootstrap();
b.group(group)
.channel(NioSocketChannel.class)
.option(ChannelOption.TCP_NODELAY, true)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
if (sslCtx != null) {
p.addLast(sslCtx.newHandler(ch.alloc(), HOST, PORT));
}
//p.addLast(new LoggingHandler(LogLevel.INFO));
p.addLast(new EchoClientHandler());
}
});

// Start the client.
ChannelFuture f = b.connect(HOST, PORT).sync();

// Wait until the connection is closed.
f.channel().closeFuture().sync();
} finally {
// Shut down the event loop to terminate all threads.
group.shutdownGracefully();
}

其中启动的关键代码为:

1
2
3
复制代码
// Start the client.
ChannelFuture f = b.connect(HOST, PORT).sync();

下面是connect的调用流程:

1
2
3
4
5
6
7
复制代码 -> Bootstrap.doResolveAndConnect
-> AbstractBootstrap.initAndRegister
-> Bootstrap.doResolveAndConnect0
-> Bootstrap.doConnect
-> SingleThreadEventExecutor.execute
-> SingleThreadEventExecutor.startThread()
-> SingleThreadEventExecutor.doStartThread

后半部分和服务端的启动过程是一致的,而区别在于服务端是通过bind操作来启动的,而客户端是通过connect操作来启动的。执行到此,客户端和服务端的EventLoop都已经启动起来,服务端可以接受客户端的连接并且处理Channel上的读写事件,而客户端可以去连接远程服务端来请求数据。

EventLoopGroup

到目前为止,我们已经知道了Reactor多个线程模型,并且知道了一个EventLoop会负责一个Channel的生命周期内的所有事件,并且知道了服务端和客户端是如何启动这个EventLoop得,但是还有一个问题没有解决,那就是一个EventLoop是如何被分配给一个Channel的。下文就来分析这个分配的原理和过程。而对于阻塞I/O模型的分配和非阻塞I/O模型的分配是不一样的,在上文中也提到这个内容,所以本文只分析对于非阻塞I/O模型的分配。

EventLoopGroup是用来管理EventLoop的对象,一个EventLoopGroup里面有多个EventLoop,下面展示了EventLoopGroup的类图:

EventLoopGroup类图

EventLoopGroup类图

我们从实际的代码出发来分析EventLoopGroup。上文中已经展示了客户端和服务端的启动代码,其中有类似的代码如下:

1
2
3
复制代码
EventLoopGroup bossGroup = new NioEventLoopGroup(1);
EventLoopGroup workerGroup = new NioEventLoopGroup();

上文中我们分析了EventLoop被启动的过程,我们肯定,EventLoop是在分配之后启动的,因为对于服务端而言,bind是一个最开始的网络操作,对于客户端来说,connect也是最开始的网络操作,在这之前是没有关于网络I/O的操作的,所以,EventLoop的分配和启动是在这两个过程或者之后的流程中进行的,但是EventLoop的分配肯定是在启动之前的,但是EventLoop的分配和启动在bind和connect中进行,那么我们可以肯定,EventLoop的分配也是在这两个方法中进行的。为了证明这个假设,回头再看一下服务端的EventLoop的启动过程,其中有一个方法值得我们注意:AbstractBootstrap.initAndRegister,我们进行了init部分的分析,而register部分我们还没有分析,下面就对服务端来进行register部分的分析,下面展示了register的调用链路:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
复制代码
-> Bootstrap.doResolveAndConnect
-> AbstractBootstrap.initAndRegister
-> EventLoopGroup.register
-> MultithreadEventLoopGroup.register
-> SingleThreadEventLoop.register
-> Channel.register
-> AbstractUnsafe.register

public final void register(EventLoop eventLoop, final ChannelPromise promise) {
if (eventLoop == null) {
throw new NullPointerException("eventLoop");
}
if (isRegistered()) {
promise.setFailure(new IllegalStateException("registered to an event loop already"));
return;
}
if (!isCompatible(eventLoop)) {
promise.setFailure(
new IllegalStateException("incompatible event loop type: " + eventLoop.getClass().getName()));
return;
}

AbstractChannel.this.eventLoop = eventLoop;

if (eventLoop.inEventLoop()) {
register0(promise);
} else {
try {
eventLoop.execute(new Runnable() {
@Override
public void run() {
register0(promise);
}
});
} catch (Throwable t) {
logger.warn(
"Force-closing a channel whose registration task was not accepted by an event loop: {}",
AbstractChannel.this, t);
closeForcibly();
closeFuture.setClosed();
safeSetFailure(promise, t);
}
}
}

最后展示了AbstractUnsafe.register这个方法,在这里初始化了一个EventLoop,需要记住的一点是,EventLoopGroup中的是EventLoop,不然在追踪代码的时候会迷失。现在来正式看一下NioEventLoopGroup这个类,它的它继承了MultithreadEventExecutorGroup这个类,而我们在初始化EventLoopGroup的时候传递进去的参数,也就是我们希望这个EventLoopGroup拥有的EventLoop数量,会在MultithreadEventExecutorGroup这个类中初始化,并且是在构造函数中初始化的,如果在new
EventLoopGroup的时候没有任何参数,那么默认的EventLoop的数量是机器CPU数量的两倍。现在我们来看一下MultithreadEventExecutorGroup这个类的一个重要的构造函数,这个构造函数初始化了EventLoopGroup的EventLoop。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
复制代码
protected MultithreadEventExecutorGroup(int nThreads, Executor executor,
EventExecutorChooserFactory chooserFactory, Object... args) {
if (nThreads <= 0) {
throw new IllegalArgumentException(String.format("nThreads: %d (expected: > 0)", nThreads));
}

if (executor == null) {
executor = new ThreadPerTaskExecutor(newDefaultThreadFactory());
}

children = new EventExecutor[nThreads];

for (int i = 0; i < nThreads; i ++) {
boolean success = false;
try {
children[i] = newChild(executor, args);
success = true;
} catch (Exception e) {
// TODO: Think about if this is a good exception type
throw new IllegalStateException("failed to create a child event loop", e);
} finally {
if (!success) {
for (int j = 0; j < i; j ++) {
children[j].shutdownGracefully();
}

for (int j = 0; j < i; j ++) {
EventExecutor e = children[j];
try {
while (!e.isTerminated()) {
e.awaitTermination(Integer.MAX_VALUE, TimeUnit.SECONDS);
}
} catch (InterruptedException interrupted) {
// Let the caller handle the interruption.
Thread.currentThread().interrupt();
break;
}
}
}
}
}

chooser = chooserFactory.newChooser(children);

final FutureListener<Object> terminationListener = new FutureListener<Object>() {
@Override
public void operationComplete(Future<Object> future) throws Exception {
if (terminatedChildren.incrementAndGet() == children.length) {
terminationFuture.setSuccess(null);
}
}
};

for (EventExecutor e: children) {
e.terminationFuture().addListener(terminationListener);
}

Set<EventExecutor> childrenSet = new LinkedHashSet<EventExecutor>(children.length);
Collections.addAll(childrenSet, children);
readonlyChildren = Collections.unmodifiableSet(childrenSet);
}

一个较为重要的方法为newChild,这是初始化一个EventLoop的方法,下面是它的具体实现,假设我们使用NioEventLoop:

1
2
3
4
5
复制代码
protected EventLoop newChild(Executor executor, Object... args) throws Exception {
return new NioEventLoop(this, executor, (SelectorProvider) args[0],
((SelectStrategyFactory) args[1]).newSelectStrategy(), (RejectedExecutionHandler) args[2]);
}

我们现在知道了EventLoopGroup管理着很多的EventLoop,上文中我们仅仅分析了分配的流程,但是分配的策略还没有分析,现在来分析一下EventLoopGroup是如何分配EventLoop给Channel的,我们仅分析非阻塞I/O下的分配策略,阻塞模式下的分配策略可以参考非阻塞下的分配策略。

在MultithreadEventLoopGroup.register方法中,调用了next()方法,我们来看一下这个流程:

1
2
3
4
5
6
复制代码
-> MultithreadEventExecutorGroup.next()

public EventExecutor next() {
return chooser.next();
}

chooser是什么东西?

1
2
复制代码
private final EventExecutorChooserFactory.EventExecutorChooser chooser;

它是怎么初始化的呢?

1
2
3
4
5
6
7
8
复制代码
public EventExecutorChooser newChooser(EventExecutor[] executors) {
if (isPowerOfTwo(executors.length)) {
return new PowerOfTwoEventExecutorChooser(executors);
} else {
return new GenericEventExecutorChooser(executors);
}
}

这是它初始化最后调用的方法,这个方法在DefaultEventExecutorChooserFactory中被实现,这个参数是MultithreadEventExecutorGroup类中的children,也就是EventLoopGroup中的所有EventLoop,那这个newChooser得分配方法就是如果EventLoop的数量是2的n次方,那么就使用PowerOfTwoEventExecutorChooser来分配,否则使用GenericEventExecutorChooser来分配。这两个策略类的分配方法实现分别如下:

1
2
3
4
5
6
7
8
9
10
复制代码     
1、PowerOfTwoEventExecutorChooser
public EventExecutor next() {
return executors[idx.getAndIncrement() & executors.length - 1];
}

2、GenericEventExecutorChooser
public EventExecutor next() {
return executors[Math.abs(idx.getAndIncrement() % executors.length)];
}

所以,到此为止,我们可以解决为什么一个EventLoop会被分配给多个Channel的疑惑。本文到此也就结束了。篇幅较长,内容涉及到Reactor的三种线程模型,然后分析了Netty的线程模型,然后分析了Netty的EventLoop,以及EventLoopGroup,以及分析了EventLoop是怎么被分配给一个Channel的,和一个EventLoop是如何启动起来来处理事件的。最后分析了EventLoopGroup分配EventLoop的策略,对于本文涉及的内容的更为深入的分析总结,将在未来的某个适宜的时刻进行。

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

Java中的ThreadLocal和 Inheritable

发表于 2017-10-15

ThreadLocal

ThreadLocal从字面理解就是线程本地变量,貌似是一种线程私有的缓存变量的容器。为了说明ThreadLocal的特点,举个例子:比如有三个人,每个人比作一个线程,它们都需要一个袋子来装捡到的东西,也就是每个线程都希望自己有一个容器,当然,自己的捡到的东西肯定不希望和别人分享啊,也就是希望这个容器对其他人(线程)是不可见的,如果现在只有一个袋子,那怎么办?

  1. 每个人在捡东西之前一定会先抢到那个唯一的袋子,然后再捡东西,如果使用袋子的时间到了,就会马上把里面的东西消费掉,然后把袋子放到原来的地方,然后再次去抢袋子。这个方案是使用锁来避免线程竞争问题的,三个线程需要竞争同一个共享变量。
  2. 我们假设现在不是只有一个袋子了,而是有三个袋子,那么就可以给每个人安排一个袋子,然后每个人的袋子里面的对象是对其他人不可见的,这样的好处是解决了多个人竞争同一个袋子的问题。这个方案就是使用ThreadLocal来避免不必要的线程竞争的。

大概了解了ThreadLocal,下面来看看它的使用方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
复制代码
private static class UnsafeThreadClass {

private int i;

UnsafeThreadClass(int i) {
this.i = i;
}

int getAndIncrement() {
return ++ i;
}

@Override
public String toString() {
return "[" + Thread.currentThread().getName() + "]" + i;
}
}

private static ThreadLocal<UnsafeThreadClass> threadLocal = new ThreadLocal<>();

static class ThreadLocalRunner extends Thread {
@Override
public void run() {

UnsafeThreadClass unsafeThreadClass = threadLocal.get();

if (unsafeThreadClass == null) {
unsafeThreadClass = new UnsafeThreadClass(0);
threadLocal.set(unsafeThreadClass);
}

unsafeThreadClass.getAndIncrement();

System.out.println(unsafeThreadClass);
}

}

上面的例子仅仅是为了说明ThreadLocal可以为每个线程保存一个本地变量,这个变量不会受到其他线程的干扰,你可以使用多个ThreadLocal来让线程保存多个变量,下面我们分析一下ThreadLocal的具体实现细节,首先展示了ThreadLocal提供的一些方法,我们重点关注的是get、set、remove方法。

ThreadLocal方法

ThreadLocal方法

首先,我们需要new一个ThreadLocal对象,那么ThreadLocal的构造函数做了什么呢?

1
2
3
4
5
6
7
复制代码
/**
* Creates a thread local variable.
* @see #withInitial(java.util.function.Supplier)
*/
public ThreadLocal() {
}

很遗憾它什么都没做,那么初始化的过程势必是在首次set的时候做的,我们来看一下set方法的细节:

1
2
3
4
5
6
7
8
9
复制代码
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}

看起来首先根据当前线程获取到了一个ThreadLocalMap,getMap方法是做了什么?

1
2
3
4
复制代码
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}

非常的简洁,是和Thread与生俱来的,我们看一下Thread中的相关定义:

1
2
3
4
5
6
7
8
9
10
复制代码
/* ThreadLocal values pertaining to this thread. This map is maintained
* by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null;

/*
* InheritableThreadLocal values pertaining to this thread. This map is
* maintained by the InheritableThreadLocal class.
*/
ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

关于inheritableThreadLocals将在下一小节再学习总结。

获得了线程的ThreadLocalMap之后,如果不为null,说明不是首次set,直接set就可以了,注意key是this,也就是当前的ThreadLocal啊不是Thread。如果为空呢?说明还没有初始化,那么就需要执行createMap这个方法:

1
2
3
4
复制代码
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}

没什么特别的,就是初始化线程的threadLocals,然后设定key-value。

下面分析一下get的逻辑:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
复制代码
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}

和set一样,首先根据当前线程获取ThreadLocalMap,然后判断是否为null,如果为null,说明ThreadLocalMap还没有被初始化啊,那么就返回方法setInitialValue的结果,这个方法做了什么?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
复制代码
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}

protected T initialValue() {
return null;
}

最后会返回null,但是会做一些初始化的工作,和set一样。在get里面,如果返回的ThreadLocalMap不为null,则说明ThreadLocalMap已经被初始化了,那么就可以正常根据ThreadLocal作为key获取了。

当线程退出时,会清理ThreadLocal,可以看下面的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
复制代码
/**
* This method is called by the system to give a Thread
* a chance to clean up before it actually exits.
*/
private void exit() {
if (group != null) {
group.threadTerminated(this);
group = null;
}
/* Aggressively null out all reference fields: see bug 4006245 */
target = null;
/* Speed the release of some of these resources */
threadLocals = null;
inheritableThreadLocals = null;
inheritedAccessControlContext = null;
blocker = null;
uncaughtExceptionHandler = null;
}

这里做了大量“Help GC”的工作。包括我们本节所讲的threadLocals和下一小节要讲的inheritableThreadLocals都会被清理。

如果我们想要显示的清理ThreadLocal,可以使用remove方法:

1
2
3
4
5
6
复制代码
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}

逻辑较为直接,很好理解。

InheritableThreadLocal

ThreadLocal固然很好,但是子线程并不能取到父线程的ThreadLocal的变量,比如下面的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
复制代码
private static ThreadLocal<Integer> integerThreadLocal = new ThreadLocal<>();
private static InheritableThreadLocal<Integer> inheritableThreadLocal =
new InheritableThreadLocal<>();

public static void main(String[] args) throws InterruptedException {

integerThreadLocal.set(1001); // father
inheritableThreadLocal.set(1002); // father

new Thread(() -> System.out.println(Thread.currentThread().getName() + ":"
+ integerThreadLocal.get() + "/"
+ inheritableThreadLocal.get())).start();

}

//output:
Thread-0:null/1002

使用ThreadLocal不能继承父线程的ThreadLocal的内容,而使用InheritableThreadLocal时可以做到的,这就可以很好的在父子线程之间传递数据了。下面我们分析一下InheritableThreadLocal的实现细节,下面展示了InheritableThreadLocal提供的方法:

InheritableThreadLocal方法

InheritableThreadLocal方法

InheritableThreadLocal继承了ThreadLocal,然后重写了上面三个方法,所以除了上面三个方法之外,其他所有对InheritableThreadLocal的调用都是对ThreadLocal的调用,没有什么特别的。我们上文中提到了Thread类,里面有我们本文关心的两个成员,我们来看一下再Thread中做了哪些工作,我们跟踪一下new一个Thread的调用路径:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
复制代码
new Thread()

init(ThreadGroup g, Runnable target, String name, long stackSize)


init(ThreadGroup g, Runnable target, String name,
long stackSize, AccessControlContext acc,
boolean inheritThreadLocals)

->
if (inheritThreadLocals && parent.inheritableThreadLocals != null)
this.inheritableThreadLocals =
ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);

createInheritedMap(ThreadLocalMap parentMap)


ThreadLocalMap(ThreadLocalMap parentMap)

上面列出了最为关键的代码,可以看到,最后会调用ThreadLocal的createInheritedMap方法,而该方法会新建一个ThreadLocalMap,看一下构造函数的内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
复制代码
private ThreadLocalMap(ThreadLocalMap parentMap) {
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
setThreshold(len);
table = new Entry[len];

for (int j = 0; j < len; j++) {
Entry e = parentTable[j];
if (e != null) {
@SuppressWarnings("unchecked")
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
Object value = key.childValue(e.value);
Entry c = new Entry(key, value);
int h = key.threadLocalHashCode & (len - 1);
while (table[h] != null)
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}

parentMap就是父线程的ThreadLocalMap,这个构造函数的意思大概就是将父线程的ThreadLocalMap复制到自己的ThreadLocalMap里面来,这样我们就可以使用InheritableThreadLocal访问到父线程中的变量了。

对ThreadLocal更为具体和深入的分析将在其他的篇章中进行,本文点到即可,为了深入理解ThreadLocal,可以阅读ThreadLocalMap的源码,以及可以在项目中多思考是否可以使用ThreadLocal来做一些事情,比如,如果我们具有这样一种线程模型,一个任务从始至终只会被一个线程执行,那么可以使用ThreadLocal来计算运行该任务的时间。

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

SSH 详解 - 不挑食的程序员 - SegmentFaul

发表于 2017-09-29

文章同步于Github blog

简介

SSH(Secure Shell)是一个提供数据通信安全、远程登录、远程指令执行等功能的安全网络协议,由芬兰赫尔辛基大学研究员Tatu Ylönen,于1995年提出,其目的是用于替代非安全的Telnet、rsh、rexec等远程Shell协议。之后SSH发展了两个大版本SSH-1和SSH-2。

通过使用SSH,你可以把所有传输的数据进行加密,这样”中间人”这种攻击方式就不可能实现了,而且也能够防止 DNS欺骗 和 IP欺骗。使用SSH,还有一个额外的好处就是传输的数据是经过压缩的,所以可以加快传输的速度。SSH有很多功能,它既可以代替Telnet,又可以为FTP、Pop、甚至为PPP提供一个安全的”通道”。

SSH为一项创建在应用层和传输层基础上的安全协议,为计算机上的 Shell 提供安全的传输和使用环境。

SSH的基本框架

SSH协议框架中最主要的部分是三个协议:

  • 传输层协议(The Transport Layer Protocol):传输层协议提供服务器认证,数据机密性,信息完整性等的支持。
  • 用户认证协议(The User Authentication Protocol):用户认证协议为服务器提供客户端的身份鉴别。
  • 连接协议(The Connection Protocol):连接协议将加密的信息隧道复用成若干个逻辑通道,提供给更高层的应用协议使用。

同时SSH协议框架中还为许多高层的网络安全应用协议提供扩展的支持。它们之间的层次关系可以用如下图来表示:

SSH的加密

SSH从安全和性能两方面综合考虑,结合使用了 Public Key/Private key(公钥/私钥) 和 Secret Key(密钥)。

  • Public Key/Private key:非对称加密,安全,但效率低,不适合大规模进行数据的加密和解密操作
  • Secret Key:对称机密,高效,但安全性相对较低,Key的分发尤其不方便

对密码学基础知识和数字签名了解,可以参考阮一峰的博文

  • 密码学笔记
  • 数字签名是什么

SSH的主要特性

  • 加密:避免数据内容泄漏
  • 通信的完整性:避免数据被篡改,以及发送或接受地址伪装
    (检查数据是否被篡改,数据是否来自发送者而非攻击者) SSH-2 通过 MD5 和 SHA-1 实现该功能,SSH-1 使用 CRC-32
  • 认证:识别数据发送者和接收者身份 客户端验证SSH服务端的身份:防止攻击者仿冒SSH服务端身份,避免中介人攻击和重定向请求的攻击;OpenSSH 通过在 know-hosts 中存储主机名和 host key 对服务端身份进行认证 服务端验证请求者身份:提供安全性较弱的用户密码方式,和安全性更强的 per-user public-key signatures;此外SSH还支持与第三方安全服务系统的集成,如
    Kerberos等
  • 授权:用户访问控制
  • Forwarding or tunneling to encrypt other TCP/IP-based sessions 可以通过SSH为Telnet、FTP等提供通信安全保障,支持三种类型的 Forwarding 操作:Port Forwarding;X Forwarding;Agent Forwarding

SSH中的Key

SSH结合使用了Public Key/Private key和Secret Key:

  • Public Key/Private key(非对称加密)用于在建立安全通道前在客户端和服务端之间传输 Secret Key和进行身份认证;
  • Secret Key(对称加密)则用来作为SSH会话的安全保证,对数据进行加密和解密。

SSH可以处理4种密钥:

名称 生命周期 创建 类型 描述
Host Key 持久化 服务端 Public Key Host Key是服务器用来证明自己身份的一个永久性的非对称密钥
User Key 持久化 用户 Public Key User Key 是客户端用来证明用户身份的一个永久性的非对称密钥(一个用户可以有多个密钥/身份标识)
Server Key 默认为1小时 服务端 Public Key Server Key 是SSH-1协议中使用的一个临时的非对称密钥,每隔一定的间隔(默认是一个小时)都会在服务器重新生成。用于对Session Key进行加密(仅SSH-1协议有,SSH-2对其进行了增强,这里Server Key作为一个概念便于在流程中进行描述)
Session Key 客户端 会话(Session) Secret Key Session Key是一个随机生成的对称密钥,用户SSH客户端和服务器之间的通信进行加密,会话结束时,被销毁

SSH框架:

安全连接的建立

在进行有意义的会话之前,SSH客户端和服务器必须首先建立一条安全连接。该连接可以允许双方共享密钥、密码,最后可以相互传输任何数据。

现在我们介绍SSH-1协议是如何确保网络连接的安全性的。SSH-1客户端和服务器从阿卡似乎经过很多个步骤,协商使用加密算法,生成并共享一个会话密钥,最终建立一条安全连接:

  1. 客户端连接到服务器上
  2. 客户端和服务器交换自己支持的SSH协议版本号
  3. 客户端和服务器切换到基于报文的协议
  4. 服务器向客户端提供自己的身份证明和会话参数
  5. 客户端给服务器发送一个(会话)密钥
  6. 双方启用加密并完成服务器认证
  7. 建立安全连接

每个阶段均涉及到客户端与服务端的多次交互,通过这些交互过程完成包括证书传输、算法协商、通道加密等过程。

1 客户端连接到服务器上

这个步骤没什么好说的,就是向服务器的TCP端口(约定是22)发送连接请求。

2 客户端和服务器交换自己支持的协议版本号

这些协议是以 ASCII 字符串表示,例如:SSH-1.5-1.2.27,其意义为SSH协议,版本号是V1.5,SSH1实现版本为1.2.27。可以使用 Telnet 客户端连接到一个SSH服务器端口是看到这个字符串:

1
2
3
4
5
复制代码➜ telnet 192.168.1.200 22
Trying 192.168.1.200...
Connected to doc.dinghuo123.com.
Escape character is '^]'.
SSH-2.0-OpenSSH_6.0p1 Debian-4+deb7u6

如果客户端和服务器确定其协议版本号是兼容的,那么连按就继续进行,否则,双方都可能决定中断连接。例如,如果一个只使用 SSH-1 的客户端连接到一个只使用 SSH-2 的服务器上,那么客户端就会断开连接并打印一条错误消息。实际上还可能执行其他操作:例如,只使用SSH-2的服务器可以调用SSH-1服务器来处理这次连接请求。

3 客户端和服务器切换基于报文的协议

协议版本号交换过程一旦完成,客户端和服务器都立即从下层的 TCP 连接切换到基于子报文的协议。每个报文都包含一个32位的字段,1 - 8字节的填充位[ 用来防止已知明文攻击unknown-plaintext attack ],一个1字节的报文类型代码, 报文有效数据和一个4字节的完整性检査字段。

4 服务器向客户提洪自己的身份证明和会话参数

服务器向客户端发送以下信息(现在还沒有加密):

  • 主机密钥(Host Key),用于后面证明服务器主机的身份
  • 服务器密钥(Server Key),用来帮助建立安全连接
  • 8个随机字节序列,称为检测字节(check bytes)。客户端在下一次响应中必须包括这些检测字节,否則服务器就会拒绝接收响应信息,这种方法可以防止某些 IP伪装攻击(IP spoofing attack)。
  • 该服务器支持的加密、压缩和认证方法

此时,双方都要计算一个通用的 128 位会话标识符(Session ID)。它在某些协议中用来惟一标识这个 SSH 会话。该值是 主机密钥(Host Key)、服务器密钥(Server Key)和检测字节(check bytes)一起应用 MD5散列函数 得到的结果。

当客户端接收到 主机密钥(Host Key)时,它要进行询问:“之前我和这个服务器通信过吗?如果通信过,那么它的主机密钥是什么呢?”要回答这个问题,客户端就要査阅自己的已知名主机数据库。如果新近到达的主机密钥可以和数据库中以前的一个密钥匹,那么就没有问题了。

但是,此时还存在两种可能:已知名主机数据库中没有这个服务器,也可能有这个服务器但是其主机密钥不同。在这两种情况中,客户端要选择是信任这个新近到达的密钥还是拒绝接受该密钥。此时就需要人的指导参与了,例如,客户端用户可能被提示要求确定是接受还是拒绝该密钥。

1
2
3
复制代码The authenticity of host 'ssh-server.example.com (12.18.429.21)' can't be established.
RSA key fingerprint is 98:2e:d7:e0:de:9f:ac:67:28:c2:42:2d:37:16:58:4d.
Are you sure you want to continue connecting (yes/no)?

如果客户端拒绝接受这个主机密钥,那么连接就中止了。让我们假设客户端接受该密钥,现在继续介绍。

5 客户端给眼务器发送一个(会话)密钥

现在客户端为双方都支持的 bulk箅法 随机生成一个新密钥,称为 会话密钥(Session Key)。其目的是对客户端和服务器之间发送的数据进行加密和解密。所需要做的工作是把这个 会话密钥(Session Key)发送给服务器,双方就可以启用加密并开始安全通信了。

当然,客户端不能简单地把会话密钥(Session Key)发送给服务器。此时数据还没有进行加密,如果第三方中途截获了这个密钥,那么他就可以解密客户端和服务器之间的消息。此后你就和安全性无缘了。因此客户端必须安全地发送会话密钥(Session Key)。 这是通过两次加密实现的:一次使用服务器的公共主机密钥(Host Key),一次使用服务器密钥(Server Key)。

这个步骤确保只有服务器可以读取会话密钥(Session Key)。在会话密钥(Session Key)经过两次加密之后,客户端就将其发送给服务器,同时还会发送检测字节和所选定的算法(这些算法是从第4步中服务器支持的算法列表中挑选出来的)。

6 双方启用加密并完成服务器认证

在发送会话密钥之后,双方开始使用密钥和所选定的 bulk算法 对会话数据进行加密,但是在开始继续发送其他数据之前,客户端要等待服务器发来一个确认消息,该消息(以及之后的所有数据)都必须使用这个会话密钥(Session Key)加密。这是最后一歩,它提供了服务器认证:只有目的服务器才可以解密 会话密钥(Session Key),因为它是使用前面的 主机密钥(Host Key)(这个密钥已经对已知名主机列表进行了验证)进行加密的。

如果没有会话密钥(Session Key),假冒的服务器就不能解密以后的协议通信,也就不能生成有效的通信,客户端会注意到这一点并中断连接。

注意服务器认址是隐含的;并没有显式交换来验证服务器主机密钥(Host Key)。因此客户端在继续发送数椐之前,必须等待服务器使用新会话密钥(Session Key)作出有意义的响应。 从而在处理之前验证服务的身份,虽然 SSH-1 协议在这点上并没有什么特殊 . 但是 SSH-2 需要服务器认证时显示地地交换会话密钥(Session Key)。

使用服务器密钥(Server Key)对会话密钥(Session Key)再进行一次加密就提供了一种称为完美转发安全性的特性。这就是说不存在永久性密钥泄露的可能,因为它不会危害到其他部分和以后SSH会话的安全性。如果我们只使用服务器主机密钥(Host Key)来保护会话密钥(Session Key), 那么主机密钥(Host Key)的泄露就会危害到以后的通倍,并允许解密原来记录下来的会话。使用服务器密钥(Server Key)再加密次就消除了这种缺点,因为服务器密钥(Server Key)是临时的,它不会保存到磁盘上,而且会周期性地更新(缺省情况下,一小时更新一次)。如果一个入侵者已经获取了服务器的私钥,那么他必须还要执行中间人攻击或服务器欺骗攻击才能对会话造成损害。

7 建立安全连接

由于客户端和服务器现在都知道会话密钥(Session Key),而其他人都不知道,因此他们就可以相互发送加密消息(使用他们一致同意的 bulk算法 )并对其进行解密了。而且,客户端还可以完成服务器认证。我们现在就已经准备好开始客户端认证了。

客户端认证

SSH提供多种客户端认证方式。

SSH-1:

  • Password
  • Public Key
  • Kerberos
  • Rhosts && RhostsRSA
  • TIS

SSH-2:

  • Password
  • Public Key
  • hostbased 在SSH-2中考虑 Rhosts 存在安全漏洞,废弃了这种方式。

这里之讨论我们经常使用的的 Password 和 Public Key 方式。

此时安全通道已经及建立,之后的所有内容都通过 Session Key 加密后进行传输。

Password

Password 方式既客户端提供用户和密码,服务端对用户和密码进行匹配,完成认证。类Unix系统中,如 OpenSSH 的框架,一般通过系统的本地接口完成认证。

Password 的优势是简单,无需任何而外的配置就可以使用。缺点密码不便于记忆,过于简单的密码容易被暴力破解。

Public Key

Public Key 认证的基本原理是基于非对称加密方式,分别在服务端对一段数据通过公钥进行加密,如果客户端能够证明其可以使用私钥对这段数据进行解密,则可以说明客户端的身份。因为服务端需要使用客户端生成的密钥对的公钥对数据首先加密,所以需要先将公钥存储到服务端的密钥库(Auhtorized Key)。还记得Github中使用git协议push代码前需要先添加SSH KEY吗?

下面详细介绍一个通过 Public Key 进行客户端认证的过程。

  1. 客户端发起一个 Public Key 的认证请求,并发送 RSA Key 的模数作为标识符。(如果想深入了解RSA Key详细 –> 维基百科)
  2. 服务端检查是否存在请求帐号的公钥(Linux中存储在 ~/.ssh/authorized_keys 文件中),以及其拥有的访问权限。如果没有则断开连接
  3. 服务端使用对应的公钥对一个随机的256位的字符串进行加密,并发送给客户端
  4. 客户端使用私钥对字符串进行解密,并将其结合 Session ID 生成一个MD5值发送给服务端。 结合 Session ID 的目的是为了避免攻击者采用 重放攻击(replay attack)。
  5. 服务端采用同样的方式生成 MD5值 与客户端返回的 MD5值 进行比较,完成对客户端的认证。

图解SSH

参考

  • SSH原理简介
  • SSH协议介绍
  • O‘RELLY的《SSH: The Secure Shell - The Definitive Guide》

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

《JavaScript 正则表达式迷你书》问世了!

发表于 2017-09-28

1.1版,下载链接:github.com/qdlaoyao/js…之前在本站发表了一篇文章,《JS正则表达式完整教程(略长)》,正如你所看到的那样确实很长,也获得了近1000人的喜欢。但文章太长,想必有很多同学看不下去,大都只是收藏罢了。因此我整理成一本pdf。既然只是收藏,那么收藏文章就不如收藏书来的好。其实,整理成pdf的灵感也来自本站读者的反馈。

同时,我也相信我们不可能只做一个收藏家,有人8小时看完,有人花了一周看完,也有个把小时就能看完的。有很多读者看完反馈说,表示弄懂正则了。对此,个人表示非常欣慰,我的这一丁点儿付出,能让别人有所收获,真是没有比此更让人开心的事情了,也算我也为前端界做出的一点小小的贡献。

本书是第一版,对文章部分内容都修改了一下,当然也难免有笔误(勘误信息在此处)。欢迎大家挑毛病,不管是笔误、还是没有说清楚的地方,都欢迎读者留言。一段时间后,我会整理再出了新的版本。目前此书只有pdf格式,我最近也在学习mobi格式制作方法。

后续的大版本也会出的。可能会添加一些新的章节和专题。只是目前个人再找工作,等安心之后再说哈。

下面的内容是我的感谢和后记(有人已经在本站帮我转发过了,感谢!)

感谢

由于本书是由个人文章修改而成,感谢各平台读者的支持。感谢湖山,是他说我该把这些东西写出来的。

感谢小不,他在多方面给予了我帮助,封面是他设计的。感谢小鱼二,他对全书进行了仔细地校对,并提出了相应的修改意见。

感谢丹迪的支持,他为我设计了多个封面,风格比较前卫,留给后续版本。最后,尤其要感谢各位大佬帮我写的推荐序。他们的名字不分先后如下:大漠穷秋、小鱼二、Jack Lo、程序猿DD、江湖人称向前兄、文蔺、_周末、Dark_Night。

后记

我竟然写了一本书!想想就挺开心的。这是个人的第一本书,虽然不厚,但也算是完成了个人的一个小梦想了。

说起正则表达式,我之所以会去详细地研究它,最初的动机是,当我分析前端常见的框架和库的源码时,发现一般被卡住的地方就是它。后来逐渐学习并看懂了“天书”,仿佛进入了一个新世界。有些工具就是这样,当你没有它时,可能并未觉得有啥不好,可是一旦你拥有了它,再也放不下手了。掌握正则了后,对字符串一些复杂操作,竟然能很快地实现。看待问题的角度也发生了改变,每次看着精炼的正则代码,总是感觉真是妙不可言。

当然,对我而言,正则表达式不仅应用在代码里。生活中也会经常使用它。比如个人平时回答网友问题时,一些网站私信里贴的代码中字符都是转义的。此时我都会贴到某个编辑器里,然后写个正则,再一次性替换,真方便。另外一个例子是,一些代码编辑器的代码格式化功能,总有让人不舒服的地方,此时我都会用写好正则表达式,再格式化一下。

还有一个很应景的例子,在编辑本书时,经常要在指定位置插入特定的语法格式,比如代码段前面要插入

1
2
复制代码[source,javascript]
----

这样的字符,此时,我发现我的大部分代码段,都是var开头的,并且前面有一空行。此时我打开查找替换功能,查找

1
复制代码(^\r\n)var

替换为

1
复制代码[source,javascript]\n----\nvar

这确实也帮我解决一部分工作。当然,正则表达式是跟具体语言(比如JavaScript)无关的。因为正则表达式是用来处理字符串问题的,基本上每门语言都有字符串类型,那么也都会支持正则表达式的。正则表达式是分流派的,也跟实现引擎有关。而JavaScript用到的正则表达式的语法,是市面常见语言都支持的核心子集。关于API,各语言基本大同小异,想用的话,应该很快就能熟悉起来。

关于正则表达式就说到这里,下面说一说自己写这本书的收获。有人说最好的学习方法就是写一本书。其实,要想把知识掌握牢固,归根到底就是用起来。写书或者说写作是一种很好的以教为学的手段。毕竟,形成文字,教给别人算是对知识的最直接的应用了。看似为了教,其实是为了学。只有教会别人才说明你掌握了。“以教为学”的手段除了写东西之外,还有翻译、以及面对面的辅导等。

以目标为导向的做中学,是比较有效的学习手段。本书是用Asciidoc写成的。它类似于Markdown,但在此书之前本人都没有用过。以需求为驱动,逐步百度检索,自己才逐渐把书整理好了。其中遇到了很多与语法无关的问题,比如转换pdf的过程中用的工具运行不起来,自己寻找原因,凭着感觉修改版本号等。又比如导出的pdf有缺字的问题,百度明白后才发现跟字体有关。边干边学,每解决掉一个问题,都挺有满足感的。带着问题去研究去学习,这是一种问题思维。然而一时的解决方案还不够,后来我详细地阅读了Asciidoc使用手册,也经常有“原来,还可以这样写!”的体会。这点跟我们平常工作很像,以项目为导向,用啥学啥。比如初学一个框架,先干起来,边看文档,边敲代码。代码敲完了,还要详细地看一遍文档,届时会发现还有更好的实现方式。不只有眼前的苟且,还会有明天的迭代。

另外一点,我深深体会到了,干着简单繁杂的工作是怎样的体验。一遍遍校对,一遍遍修改。每次,看都会发现新的待完善的地方。以至于现在我感觉已经能把本书背下来了,单调的工作确实考验人的耐心。

就写到这里吧。如果你觉得此书不错的话,欢迎赞赏(书中有微信二维码的,看完之后再决定赞赏也不迟)。

最后,我们该想起陆游诗人对前端做出的贡献:

纸上得来终觉浅,觉知此事要躬行。本文完。

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

可能是国内最火的开源项目 —— C/C++ 篇

发表于 2017-09-26

推荐阅读:

  • 可能是最火的开源项目 —— Java 篇
  • 可能是国内最火的开源项目 —— PHP 篇
  • 可能是国内最火的开源项目 —— Python 篇

截止目前开源中国收录了 44513 款开源项目,囊括了最热门的各类开源项目,而软件的评分在一定程度上代表了软件的质量和热度,而 C 和 C++ 语言作为最基础的语言,在各类编程语言排行榜中高居不下,因此本文整理了 C/C++ 语言中评分最高并且收藏量超过 100 的几款项目,以供开发者选择和交流,排名如下:

高性能 TCP/UDP/HTTP 通信框架 HP-Socket

评分:9.8,收藏:1404

HP-Socket 是一套通用的高性能 TCP/UDP/HTTP 通信框架,包含服务端组件、客户端组件和Agent组件,广泛适用于各种不同应用场景的 TCP/UDP/HTTP 通信系统,提供 C/C++、C#、Delphi、E(易语言)、Java、Python 等编程语言接口。HP-Socket 对通信层实现完全封装,应用程序不必关注通信层的任何细节;HP-Socket 提供基于事件通知模型的 API 接口,能非常简单高效地整合到新旧应用程序中。

为了让使用者能方便快速地学习和使用 HP-Socket ,迅速掌握框架的设计思想和使用方法,特此精心制作了大量 Demo 示例(如:PUSH 模型示例、PULL 模型示例、PACK 模型示例、性能测试示例以及其它编程语言示例)。

基于 C++/Python 的开源量化交易研究框架 Hikyuu

评分:8.3,收藏:144

Hikyuu Quant Framework是一款基于C++/Python的开源量化交易研究框架,用于策略分析及回测。其核心思想基于当前成熟的系统化交易方法,将整个系统化交易抽象为由市场环境判断策略、系统有效条件、信号指示器、止损/止盈策略、资金管理策略、盈利目标策略、移滑价差算法七大组件,你可以分别构建这些组件的策略资产库,在实际研究中对它们自由组合来观察系统的有效性、稳定性以及单一种类策略的效果。

10000-overview.png

开源自动驾驶平台 ApolloAuto

评分:8.1,收藏:156

Apollo (阿波罗)是一个开放的、完整的、安全的平台,将帮助汽车行业及自动驾驶领域的合作伙伴结合车辆和硬件系统,快速搭建一套属于自己的自动驾驶系统。

Apollo 是百度重点打造的 AI 开放平台之一,计划主要包含 4 个技术模块:定位/感知模块、车辆规划与运营(AI+大数据,精准控制车辆,适合不同路况)、软件运营框架(支持英特尔、英伟达等多种芯片)。

分布式图片实时动态压缩 ngx-fastdfs

评分:8.1,收藏:215

ngx-fastdfs 是 nginx + lua +fastdfs 实现分布式图片实时动态压缩。

cut.png

高性能 RPC 开发框架 Tars

评分:8.0,收藏:296

Tars 是基于名字服务使用 Tars 协议的高性能 RPC 开发框架,同时配套一体化的服务治理平台,帮助个人或者企业快速的以微服务的方式构建自己稳定可靠的分布式应用。它是将腾讯内部使用的微服务架构 TAF(Total Application Framework)多年的实践成果总结而成的开源项目。

目前该框架在腾讯内部,有 100 多个业务(如手机浏览器、应用宝、手机管家、手机QQ、手机游戏等)、1.6 多万台服务器上运行使用。

Go语言开发工具 LiteIDE

评分:7.9,收藏:384

LiteIDE是一款开源、跨平台的轻量级Go语言集成开发环境(IDE)。

分布式TCP压力测试工具 tcpcopy

评分:7.9,收藏:380

tcpcopy是一种应用请求复制(基于tcp的packets)工具,其应用领域较广,目前已经应用于国内各大互联网公司。总体说来,tcpcopy主要有如下功能:

  • 分布式压力测试工具,利用在线数据,可以测试系统能够承受的压力大小(远比ab压力测试工具真实地多),也可以提前发现一些bug
  • 普通上线测试,可以发现新系统是否稳定,提前发现上线过程中会出现的诸多问题,让开发者有信心上线
  • 对比试验,同样请求,针对不同或不同版本程序,可以做性能对比等试验
  • 利用多种手段,构造无限在线压力,满足中小网站压力测试要求
  • 实战演习(架构师必备)

tcpcopy可以用于实时和离线回放领域,并且tcpcopy支持mysql协议的复制,开源二年以来,功能上越来越完善。如果你对上线没有信心,如果你的单元测试不够充分,如果你对新系统不够有把握,如果你对未来的请求压力无法预测,tcpcopy可以帮助你解决上述难题。

中文文本转语音引擎 Ekho

评分:7.9,收藏:393

Ekho(余音)是一个把文字转换成声音的软件。它目前支持粤语、普通话(国语)、诏安客语、藏语、雅言(中国古代通用语)和韩语(试验中),英文则通过Festival间接实现。支持Linux、Windows、Android.

在 Linux 系统中运行 Android 应用 Anbox

评分:7.8,收藏:191

Anbox 可让你在任何 GNU/Linux 操作系统上运行 Android 应用程序。具有以下特性:

  • 没有限制:由于 Anbox 运行着整个 Android 系统,所以理论上任何应用都可以在其中运行
  • 安全:Anbox 将 Android APP 放进一个密封的盒子中,无需直接访问硬件或数据
  • 性能:无需虚拟化硬件而运行 Android,可以无缝桥接硬件加速功能
  • 集成:与主机操作系统紧密集成,以提供丰富的功能集

机器学习系统 TensorFlow

评分:7.8,收藏:602

TensorFlow 是谷歌的第二代机器学习系统,按照谷歌所说,在某些基准测试中,TensorFlow的表现比第一代的DistBelief快了2倍。

TensorFlow 内建深度学习的扩展支持,任何能够用计算流图形来表达的计算,都可以使用TensorFlow。任何基于梯度的机器学习算法都能够受益于TensorFlow的自动分 化(auto-differentiation)。通过灵活的Python接口,要在TensorFlow中表达想法也会很容易。TensorFlow 对于实际的产品也是很有意义的。将思路从桌面GPU训练无缝搬迁到手机中运行。

MySQL衍生版 Percona Server

评分:7.8,收藏:426

Percona 为 MySQL 数据库服务器进行了改进,在功能和性能上较 MySQL 有着很显著的提升。该版本提升了在高负载情况下的 InnoDB 的性能、为 DBA 提供一些非常有用的性能诊断工具;另外有更多的参数和命令来控制服务器行为。

Percona Server 只包含 MySQL 的服务器版,并没有提供相应对 MySQL 的 Connector 和 GUI 工具进行改进。Percona Server 使用了一些 google-mysql-tools, Proven Scaling, Open Query 对
MySQL 进行改造。

数据中间层项目 ProxySQL

评分:7.8,收藏:128

ProxySQL 是一个高性能,高可用性,的数据中间层项目。它具有先进的多核架构。 它从根本上构建,支持数十万个并发连接,复用到可能数百个后端服务器。 最大的 ProxySQL 部署跨越了几百个代理。

开源网盘云存储 Seafile

评分:7.8,收藏:1499

Seafile 是一款安全、高性能的开源网盘(云存储)软件。Seafile 提供了主流网盘(云盘)产品所具有的功能,包括文件同步、文件共享等。在此基础上,Seafile 还提供了高级的安全保护功能以及群组协作功能。由于 Seafile 是开源的,你可以把它部署在私有云的环境中,作为私有的企业网盘。Seafile 支持 Mac、Linux、Windows 三个桌面平台,支持 Android 和 iOS 两个移动平台。

Seafile 是由国内团队开发的国际型项目,目前已有50万左右的用户,以欧洲用户为多。自发布以来,Seafile 一直保持开放、国际化、高质量的宗旨,受到国内外大型机构的信赖。目前主要的大型客户包括卡巴斯基、中国平安,以及欧美多家知名大学和科研机构。你可以把它想象成是面向团队的开源Dropbox。

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

公司编程竞赛之最长路径问题 前言 前置条件 深度优先搜索算法

发表于 2017-09-25

算法源码地址:https://github.com/wudashan/longest-path-problem

前言

最近产品线举办了一个软件编程大赛,题目非常的有趣,就是在一个9 × 9的格子里,你要和另一个敌人PK,在PK的过程中,你可以吃格子里的果实来提升攻击力。每次可以往正上、正下、正左、正右、左上、左下、右上、右下八个方向走。每次要么连续吃果实要么连续走空白区域,且不能走重复的位置。初始状态如下图所示:

为了提升攻击力,我们需要尽可能地一次吃最多的果实,所以路线可以这样规划:

至此,我们可以对这个问题进行描述:已知空白区域不能走,每次可以往正上、正下、正左、正右、左上、左下、右上、右下八个方向走,走过的位置不能再走,求能吃最多果实的路线(最长路径问题)?


前置条件

地图表示

首先我们将上面的地图使用布尔类型的二维数组表示,其中true表示可以行走的格子,false表示不能行走的格子:

1
2
3
4
5
6
7
8
9
10
11
复制代码boolean[][] simpleMap = new boolean[][] {
{false, false, false, false, false, false, false, false, false},
{false, false, false, false, false, false, true , true , false},
{false, false, false, true , false, false, true , true , false},
{false, false, true , false, false, false, false, false, false},
{false, false, true , false, false, false, false, false, false},
{false, false, true , false, false, false, false, false, false},
{false, false, false, true , false, true , false, false, false},
{false, false, false, false, true , true , false, false, false},
{false, false, false, false, false, false, false, false, false}
};

格子表示

对于地图上的每一个格子,我们用一个简单类来表示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
复制代码public class Pos {

private int x; // 横坐标
private int y; // 纵坐标

// get、set、construct方法省略

@Override
public String toString() {
final StringBuffer sb = new StringBuffer("Pos{");
sb.append("x=").append(x);
sb.append(", y=").append(y);
sb.append('}');
return sb.toString();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

Pos pos = (Pos) o;

if (x != pos.x) return false;
return y == pos.y;
}

@Override
public int hashCode() {
int result = x;
result = 31 * result + y;
return result;
}


}

由于我们是使用横纵坐标而不是几行几列来表示一个格子(没错,我就是这么傲娇),那么我们就需要给地图定义横纵坐标方向。方向如下图所示:

那么起点上方的果实坐标就是[3, 2](横坐标为3,纵坐标为2),但是对应着二维数组为map[2][3](第二行,第三列),即横坐标对应着二维数组的列,纵坐标对应着二维数组的行。

移动表示

为了程序简洁,我们给八个方向的移动定义对应的偏移量,这样每次行走只要对偏移量数组进行for循环就可以了。

1
2
3
4
5
6
7
8
9
10
复制代码Pos[] moveOffset = new Pos[] {
new Pos(-1, 0), // 向左移动
new Pos(-1, -1), // 向左上移动
new Pos( 0, -1), // 向上移动
new Pos( 1, -1), // 向右上移动
new Pos( 1, 0), // 向右移动
new Pos( 1, 1), // 向右下移动
new Pos( 0, 1), // 向下移动
new Pos(-1, 1) // 向左下移动
};

深度优先搜索算法

算法思想

拿到这道题,脑袋里第一个想到的就是深度优先搜索算法,其思想为每次往八个方向递归,当不能继续走下去的时候保存最长路径,并回退到能继续行走的格子,继续递归直到结束。

代码示例

接下来就是我们的深度优先搜索代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
复制代码/**
* 通过深度优先搜索算法获取最长路径
* @param map 地图
* @param start 起点
* @param moveOffset 移动偏移量
* @return 最长路径
*/
public static List<Pos> getLongestPathByDFS(boolean[][] map, Pos start, Pos[] moveOffset) {

List<Pos> longestPath = new ArrayList<>();
dfs(start, map, new ArrayList<>(), longestPath, moveOffset);
return longestPath;

}

/**
* 递归实现深度优先搜索
*/
private static void dfs(Pos pos, boolean[][] map, List<Pos> path, List<Pos> result, Pos[] moveOffset) {

// 记录当前位置向周围格子移动的记录
List<Pos> visited = new ArrayList<>();

// 保存当前位置的周围格子
Pos[] neighbours = new Pos[moveOffset.length];

// 依次向周围移动
for (int i = 0; i < moveOffset.length; i++) {
Pos next = new Pos(pos.getX() + moveOffset[i].getX(), pos.getY() + moveOffset[i].getY());
neighbours[i] = next;
if (inMap(map, next) && !path.contains(next) && map[next.getY()][next.getX()]) {
path.add(next);
visited.add(next);
dfs(next, map, path, result, moveOffset);
}
}

// 若在当前位置下,没有向周围的格子移动过时,保存最长路径
if (visited.isEmpty()) {
if (path.size() > result.size()) {
result.clear();
result.addAll(path);
}
}

// 周围的格子都不可以移动时回退到上一格子
for (Pos neighbour : neighbours) {
if (canPath(map, path, neighbour, visited)) {
return;
}
}
path.remove(pos);

}

/**
* 判断格子是否可以移动
*/
private static boolean canPath(boolean[][] map, List<Pos> path, Pos pos, List<Pos> visited) {

// 不在地图里,不能移动
if (!inMap(map, pos)) {
return false;
}

// 空白格子,不能移动
if (!map[pos.getY()][pos.getX()]) {
return false;
}

// 已经在路径中或经过,不能移动
if (path.contains(pos) || visited.contains(pos)) {
return false;
}

return true;
}

/**
* 判断格子是否在地图内
*/
private static boolean inMap(boolean[][] map, Pos pos) {

if (pos.getY() < 0 || pos.getY() >= map.length) {
return false;
}

if (pos.getX() < 0 || pos.getX() >= map[0].length) {
return false;
}

return true;

}

接下来,就让我们在主函数里验证一下结果吧!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
复制代码public static void main(String[] args) {

// 初始化参数
boolean[][] simpleMap = new boolean[][] {
{false, false, false, false, false, false, false, false, false},
{false, false, false, false, false, false, true , true , false},
{false, false, false, true , false, false, true , true , false},
{false, false, true , false, false, false, false, false, false},
{false, false, true , false, false, false, false, false, false},
{false, false, true , false, false, false, false, false, false},
{false, false, false, true , false, true , false, false, false},
{false, false, false, false, true , true , false, false, false},
{false, false, false, false, false, false, false, false, false}
};
Pos[] moveOffset = new Pos[] {
new Pos(-1, 0), // 向左移动
new Pos(-1, -1), // 向左上移动
new Pos( 0, -1), // 向上移动
new Pos( 1, -1), // 向右上移动
new Pos( 1, 0), // 向右移动
new Pos( 1, 1), // 向右下移动
new Pos( 0, 1), // 向下移动
new Pos(-1, 1) // 向左下移动
};
Pos start = new Pos(3, 3);

// 执行深度优先算法
List<Pos> longestPath = getLongestPathByDFS(simpleMap, start, moveOffset);

// 打印路径
System.out.println(longestPath);

}

执行Main函数之后,控制台将输出[Pos{x=3, y=2}, Pos{x=2, y=3}, Pos{x=2, y=4}, Pos{x=2, y=5}, Pos{x=3, y=6}, Pos{x=4, y=7}, Pos{x=5, y=6}, Pos{x=5, y=7}],即行走的最长路径。

虽然深度优先搜索算法可以计算出最长路径,但是它的时间复杂度却高得惊人!已知每次可以向8个方向移动,最多可以走m × n步(地图的长和宽),那么时间复杂度就是 O(8mn)。由于我们上面的地图可以走的选择比较单一,所以在我的电脑上1ms就可以算出结果。感兴趣的童鞋可以试试下面这个地图在你们的电脑上需要多久出结果:

1
2
3
4
5
6
7
8
9
10
11
复制代码boolean[][] complexMap = new boolean[][] {
{false, true, true, false, false, true, true, false, true},
{true, false, false, false, true, false, false, false, true},
{true, true, false, false, true, true, false, false, false},
{false, true, true, false, false, true, true, true, false},
{false, true, true, false, false, true, true, false, true},
{true, false, false, false, true, false, false, false, true},
{true, true, false, false, true, true, false, false, false},
{false, true, true, true, false, true, true, true, false},
{false, true, true, false, false, true, true, false, true}
};

至少,在我的电脑上需要5471ms才能得出结果,非常的夸张!由于产品线比赛要求每次计算时间不能超过1000ms,所以使用该算法基本不可行。那么是否有时间更快的算法呢?别走开,答案就在下面。


贪心算法

算法思想

贪心算法采用的是这样一种思想:每次都走出路最少的格子,这样后面可以选择的余地就比较大,最优解的概率也就大的多。

代码示例

下面便是贪心算法的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
复制代码/**
* 通过贪心算法获取最长路径
* @param map 地图
* @param start 起点
* @param moveOffset 移动偏移量
* @return 最长路径
*/
public static List<Pos> getLongestPathByChain(boolean[][] map, Pos start, Pos[] moveOffset) {

List<Pos> longestPath = new ArrayList<>();
chain(start, map, new ArrayList<>(), longestPath, moveOffset);
return longestPath;

}


/**
* 递归实现贪心算法
*/
private static void chain(Pos pos, boolean[][] map, List<Pos> path, List<Pos> result, Pos[] moveOffset) {

// 获取出路最小的格子
Pos minWayPos = getMinWayPos(pos, map, moveOffset);

if (minWayPos != null) {
// 递归搜寻路径
path.add(minWayPos);
map[minWayPos.getY()][minWayPos.getX()] = false;
chain(minWayPos, map, path, result, moveOffset);
} else {
// 当前无路可走时保存最长路径
if (path.size() > result.size()) {
result.clear();
result.addAll(path);
}
}

}

/**
* 获取当前格子周围出路最小的格子
*/
private static Pos getMinWayPos(Pos pos, boolean[][] map, Pos[] moveOffset) {

int minWayCost = Integer.MAX_VALUE;
List<Pos> minWayPoss = new ArrayList<>();

for (int i = 0; i < moveOffset.length; i++) {
Pos next = new Pos(pos.getX() + moveOffset[i].getX(), pos.getY() + moveOffset[i].getY());
if (inMap(map, next) && map[next.getY()][next.getX()]) {
int w = -1;
for (int j = 0; j < moveOffset.length; j++) {
Pos nextNext = new Pos(next.getX() + moveOffset[j].getX(), next.getY() + moveOffset[j].getY());
if (inMap(map, nextNext) && map[nextNext.getY()][nextNext.getX()]) {
w++;
}
}
if (minWayCost > w) {
minWayCost = w;
minWayPoss.clear();
minWayPoss.add(next);
} else if (minWayCost == w) {
minWayPoss.add(next);
}
}
}

if (minWayPoss.size() != 0) {
// 随机返回一个出路最小的格子
return minWayPoss.get((int) (Math.random() * minWayPoss.size()));
} else {
return null;
}

}

写好算法之后,再验证一下结果!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
复制代码public static void main(String[] args) {

// 初始化参数
boolean[][] simpleMap = new boolean[][] {
{false, false, false, false, false, false, false, false, false},
{false, false, false, false, false, false, true , true , false},
{false, false, false, true , false, false, true , true , false},
{false, false, true , false, false, false, false, false, false},
{false, false, true , false, false, false, false, false, false},
{false, false, true , false, false, false, false, false, false},
{false, false, false, true , false, true , false, false, false},
{false, false, false, false, true , true , false, false, false},
{false, false, false, false, false, false, false, false, false}
};
Pos[] moveOffset = new Pos[] {
new Pos(-1, 0), // 向左移动
new Pos(-1, -1), // 向左上移动
new Pos( 0, -1), // 向上移动
new Pos( 1, -1), // 向右上移动
new Pos( 1, 0), // 向右移动
new Pos( 1, 1), // 向右下移动
new Pos( 0, 1), // 向下移动
new Pos(-1, 1) // 向左下移动
};
Pos start = new Pos(3, 3);

// 执行贪心算法
List<Pos> longestPath = getLongestPathByChain(simpleMap, start, moveOffset);

// 打印路径
System.out.println(longestPath);

}

执行Main函数之后,控制台将输出[Pos{x=3, y=2}, Pos{x=2, y=3}, Pos{x=2, y=4}, Pos{x=2, y=5}, Pos{x=3, y=6}, Pos{x=4, y=7}, Pos{x=5, y=7}, Pos{x=5, y=6}],路径长度与深度优先搜索算法一致,即也能找到最长路径。

那么在复杂一点的地图上,与深度优先搜索相比,贪心算法的结果怎么样呢?在我的机器上,计算结果如下:

simpleMap complexMap
深度优先搜索算法 最长路径为8步,计算时间为1ms 最长路径为33步,计算时间为5254ms
贪心算法 最长路径为8步,计算时间为1ms 最长路径为4/9/20步,计算时间为1ms

从结果上可以发现,由于贪心算法并没有遍历所有路径,而是每次都往出路最少的格子走,所以时间上快很多,但是其结果却非常地不稳定,这是因为贪心算法容易陷入局部最优解的情况!

显然如果用贪心算法来寻路吃果实,那么能不能打败敌人就要靠运气了。怎么办?是否有折中一点的算法,既耗费可接受的时间,又可以计算出较好的结果呢?答案还是肯定的,接下来就介绍更高端的算法——模拟退火算法。


模拟退火算法

算法思想

模拟退火算法的灵感是来自物理学里的固体退火原理:将固体加热时,固体内部粒子随温度上升变为无序状态,内能不断增大;当慢慢冷却时内部粒子逐渐有序,在每个温度都达到平衡态,最后在常温时达到基态,内能减为最小。

用计算机语言来描述的话就是:在函数不断迭代的过程中,以一定的概率来接受一个比当前解要差的新解,因此有可能会跳出这个局部最优解,从而达到全局最优解。

代码示例

由于是高端算法,所以代码会比较多,但据说能看完模拟退火算法代码的人智商都超过180!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
复制代码/**
* 通过模拟退火算法获取最长路径
* @param map 地图
* @param start 起点
* @param moveOffset 移动偏移量
* @return 最长路径
*/
public static List<Pos> getLongestPathBySA(boolean[][] map, Pos start, Pos[] moveOffset) {

// 初始化退火参数
double temperature = 100.0;
double endTemperature = 1e-8;
double descentRate = 0.98;
double count = 0;
double total = Math.log(endTemperature / temperature) / Math.log(descentRate);
int iterations = map.length * map[0].length;
List<Pos> longestPath = new ArrayList<>();
List<List<Pos>> paths = new ArrayList<>();
for (int i = 0; i < iterations; i++) {
boolean[][] cloneMap = deepCloneMap(map);
List<Pos> path = initPath(cloneMap, start, moveOffset);
paths.add(path);
}

// 降温过程
while (temperature > endTemperature) {

// 迭代过程
for (int i = 0; i < iterations; i++) {

// 取出当前解,并计算函数结果
List<Pos> path = paths.get(i);
int result = caculateResult(path);

// 在邻域内产生新的解,并计算函数结果
boolean[][] cloneMap = deepCloneMap(map);
List<Pos> newPath = getNewPath(cloneMap, path, moveOffset, count / total);
int newResult = caculateResult(newPath);

// 根据函数结果判断是否替换解
if (newResult - result < 0) {
// 替换
path.clear();
path.addAll(newPath);
} else {
// 以一定的概率替换
double p = 1 / (1 + Math.exp(-(newResult - result) / temperature));
if (Math.random() < p) {
path.clear();
path.addAll(newPath);
}
}

}

count++;
temperature = temperature * descentRate;

}

// 返回一条最长路径
for (int i = 0; i < paths.size(); i++) {
if (paths.get(i).size() > longestPath.size()) {
longestPath = paths.get(i);
}
}
return longestPath;

}

/**
* 深拷贝地图
*/
private static boolean[][] deepCloneMap(boolean[][] map) {
boolean[][] cloneMap = new boolean[map.length][];
for (int i = 0; i < map.length; i++) {
cloneMap[i] = map[i].clone();
}
return cloneMap;
}

/**
* 初始化路径
*/
private static List<Pos> initPath(boolean[][] map, Pos start, Pos[] moveOffset) {
List<Pos> path = new ArrayList<>();
getPath(map, start, path, moveOffset);
return path;
}

/**
* 根据当前路径继续移动到底,采用随机移动策略
*/
private static void getPath(boolean[][] map, Pos current, List<Pos> path, Pos[] moveOffset) {

boolean end = true;
List<Pos> neighbours = new ArrayList<>();
for (int i = 0; i < moveOffset.length; i++) {
Pos neighbour = new Pos(current.getX() + moveOffset[i].getX(), current.getY() + moveOffset[i].getY());
if (inMap(map, neighbour) && map[neighbour.getY()][neighbour.getX()]) {
end = false;
neighbours.add(neighbour);
}
}
if (end) {
return;
} else {
Pos random = neighbours.get((int) (Math.random() * neighbours.size()));
map[random.getY()][random.getX()] = false;
path.add(random);
getPath(map, random, path, moveOffset);
}

}

/**
* 计算函数结果,函数结果为路径负长度
*/
private static int caculateResult(List<Pos> path) {
return -path.size();
}


/**
* 根据当前路径和降温进度,生成一条新路径
*/
private static List<Pos> getNewPath(boolean[][] map, List<Pos> path, Pos[] moveOffset, double ratio) {

int size = (int) (path.size() * ratio);
if (size == 0) {
size = 1;
}
if (size > path.size()) {
size = path.size();
}

List<Pos> newPath = new ArrayList<>();
for (int i = 0; i < size; i++) {
Pos pos = path.get(i);
newPath.add(pos);
map[pos.getY()][pos.getX()] = false;
}

getPath(map, newPath.get(newPath.size() - 1), newPath, moveOffset);
return newPath;

}

测试代码我就不再列出了,最后让我们看一下这三种算法在两个地图上的执行结果:

simpleMap complexMap
深度优先搜索算法 最长路径为8步,计算时间为1ms 最长路径为33步,计算时间为5254ms
贪心算法 最长路径为8步,计算时间为1ms 最长路径为4/9/20步,计算时间为1ms
模拟退火算法 最长路径为8步,计算时间为147ms 最长路径为30~33步,计算时间为212ms

总结

求最长路径问题可以看成是哈密尔顿路径问题,由于寻找哈密尔顿路径是一个典型的NPC问题,所以不能在多项式时间内得到最优解。感兴趣的小伙伴可以去了解一下相关的知识,我在参考阅读章节给出了几个相应的链接。

解决这类问题,我们可以通过深度优先搜索算法得到最优解,但是时间复杂度是指数级的;也可以通过贪心算法得到一个局部最优解,其时间复杂度是线性级的,但得到的解时好时坏;还可以通过模拟退火算法得到一个近似解,这个时间复杂度也是线性级的,只要退火参数配置得当,其解是稳定地,且是一个趋向最优解的近似解。


参考阅读

[1] 哈密顿图 - 维基百科

[2] 贪婪算法求解哈密尔顿路径问题 - 51CTO博客

[3] 什么是P问题、NP问题和NPC问题

[4] 最长路径问题 - 维基百科

[5] 模拟退火 - 维基百科

[6] 大白话解析模拟退火算法 - 博客园


本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

【Python】在一段Python程序中使用多次事件循环 背

发表于 2017-09-21

背景

我们在Python异步程序编写中经常要用到如下的结构

1
2
3
4
5
6
7
8
复制代码import asyncio
async def doAsync():
await asyncio.sleep(0)
#...
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(doAsync())
loop.close()

这当然是很不错的,但当你第二次使用loop的时候程序就会抛出异常RuntimeError: Event loop is closed,这也无可厚非,理想的程序也应该是在一个时间循环中解决掉各种异步IO的问题。
但放在终端环境如Ipython中,如果想要练习Python的异步程序的编写的话每次都要重新开启终端未免太过于麻烦,这时候要探寻有没有更好的解决方案。

解决方案

我们可以使用asyncio.new_event_loop函数建立一个新的事件循环,并使用asyncio.set_event_loop设置全局的事件循环,这时候就可以多次运行异步的事件循环了,不过最好保存默认的asyncio.get_event_loop并在事件循环结束的时候还原回去。
最终我们的代码就像这样。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
复制代码import asyncio
async def doAsync():
await asyncio.sleep(0)
#...
def runEventLoop()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(doAsync())
loop.close()
if __name__ == "__main__":
oldloop = asyncio.get_event_loop()
runEventLoop()
runEventLoop()
asyncio.set_event_loop(oldloop)

感想

事件循环本来就是要一起做很多事情,在正式的Python代码中还是只用一个默认的事件循环比较好,平时的学习练习的话倒是随意了。

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

java垃圾回收算法之-coping复制 概述 应用场景 优

发表于 2017-09-20

之前的java垃圾回收算法之-标记清除 会导致内存碎片。下文的介绍的coping算法可以解决内存碎片问题。

概述


如果jvm使用了coping算法,一开始就会将可用内存分为两块,from域和to域, 每次只是使用from域,to域则空闲着。当from域内存不够了,开始执行GC操作,这个时候,会把from域存活的对象拷贝到to域,然后直接把from域进行内存清理。

应用场景


coping算法一般是使用在新生代中,因为新生代中的对象一般都是朝生夕死的,存活对象的数量并不多,这样使用coping算法进行拷贝时效率比较高。

jvm将Heap 内存划分为新生代与老年代,又将新生代划分为Eden(伊甸园) 与2块Survivor Space(幸存者区) ,然后在Eden –>Survivor Space 以及From Survivor Space 与To Survivor Space 之间实行Copying 算法。 不过jvm在应用coping算法时,并不是把内存按照1:1来划分的,这样太浪费内存空间了。一般的jvm都是8:1。也即是说,Eden区:From区:To区域的比例是

8:1:1

始终有90%的空间是可以用来创建对象的,而剩下的10%用来存放回收后存活的对象。

这里写图片描述

这里写图片描述

1、当Eden区满的时候,会触发第一次young gc,把还活着的对象拷贝到Survivor From区;当Eden区再次触发young gc的时候,会扫描Eden区和From区域,对两个区域进行垃圾回收,经过这次回收后还存活的对象,则直接复制到To区域,并将Eden和From区域清空。
2、当后续Eden又发生young gc的时候,会对Eden和To区域进行垃圾回收,存活的对象复制到From区域,并将Eden和To区域清空。
3、可见部分对象会在From和To区域中复制来复制去,如此交换15次(由JVM参数MaxTenuringThreshold决定,这个参数默认是15),最终如果还是存活,就存入到老年代

注意:

1
复制代码万一存活对象数量比较多,那么To域的内存可能不够存放,这个时候会借助老年代的空间。

优点


在存活对象不多的情况下,性能高,能解决内存碎片和java垃圾回收算法之-标记清除 中导致的引用更新问题。

缺点


  • 会造成一部分的内存浪费。不过可以根据实际情况,将内存块大小比例适当调整;
  • 如果存活对象的数量比较大,coping的性能会变得很差。

原文链接


java垃圾回收算法之-coping复制

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

吴恩达Coursera Deep Learning学习笔记

发表于 2017-09-20

【写在前面】作为一名终身学习实践者,我持之以恒地学习各种深度学习和机器学习的新知识。一个无聊的假日里我突然想到:反正一样要记笔记,为什么不把我学习的笔记写成博客供大家一起交流呢?于是,我就化身数据女侠,打算用博客分享的方式开启我的深度学习深度学习之旅。

本文仅供学习交流使用,侵权必删,不用于商业目的,转载须注明出处。

【学习笔记】

之前我们以逻辑回归 (logistic regression) 为例介绍了神经网络(),但它并没有隐藏层,所以并不算严格意义上的神经网络。在本文中,让我们随着Andrew一起深化神经网络,在sigmoid之前再增加一些ReLU神经元。最后我们会以疯狂收到AI科学家迷恋的可爱猫咪图案为例,用深度学习建立一个猫咪图案的识别模型。不过在这之前,让我们来看一下除了上次说到的sigmoid和ReLU之外,还有什么激活函数、他们之间又各有什么优劣——

激活函数 Activation Functions

1) sigmoid
除了在输出层(当输出是{0,1}的binary classification时)可能会用到之外,隐藏层中很少用到sigmoid,因为它的mean是0.5,同等情况下用均值为0的tanh函数取代。
2) tanh其实就是sigmoid的shifted版本,但输出从(0, 1)变为(-1, 1),所以下一层的输入是centered,更受欢迎。
3) ReLU (Rectified Linear Unit)在
吴恩达Coursera Deep Learning学习笔记 1 (上)中提到过ReLU激活函数,它在深度学习中比sigmoid和tanh更常用。这是因为当激活函数的输入z的值很大或很小的时候,sigmoid和tanh的梯度非常小,这会大大减缓梯度下降的学习速度。所以与sigmoid和tanh相比,ReLU的训练速度要快很多。
4) Leaky ReLULeaky ReLU比ReLU要表现得稍微好一丢丢但是实践中大家往往都用ReLU。


从浅神经网络到深神经网络

最开始的逻辑回归的例子中,并没有隐藏层。在接下来的课程中,Andrew分别介绍了1个隐藏层、2个隐藏层和L个隐藏层的神经网络。但是只要把前向传播和反向传播搞明白了,再加上之后会讲述的一些小撇步,就会发现其实都是换汤不换药。大家准备好了吗?和我一起深呼吸再一头扎进深度学习的海洋吧~

一般输入层是不计入神经网络的层数的,如果一个神经网络有L层,那么就意味着它有L-1个隐藏层和1个输出层。我们可以观察到输入层和输出层,但是不容易观察到中间数据是怎么变化的,因此在输入和输出层之间的部分叫隐藏层。

训练一个深神经网络大致分为以下步骤(吴恩达Coursera Deep Learning学习笔记 1 (上)也有详细说明):

  1. 定义神经网络的结构(超参数)
  2. 初始化参数
  3. 循环迭代
    3.1 在前向传播中,分为linear forward和activation forward。在linear
    forward中,Z[l]=W[l]A[l−1]+b[l],其中A[0]=X;在activation forward中,A[l]=g(Z[l])。期间要储存W、b、Z的值。最后算出当前的Loss。
    3.2 在反向传播中,分为activation backward和linear backward。
    3.3 更新参数。
    下图展现了一个L层的深度神经网络、其中L-1个隐藏层都用的是ReLU激活函数的训练步骤和过程。

一个L-1个ReLU隐藏层的训练过程

标注声明 Notations:
随着神经网络的模型越来越深,我们会有L-1个隐藏层,每一层都用小写的L即[l]标注为上标。z是上一层的输出(即这一层的输入)的线性组合,a是这一层通过激活函数对z的非线性变化,也叫激活值 (activation values)。训练数据中有m个样本,每一个样本都用(i)来标注为上标。每一个隐藏层l里都有n[l](上标内是小写的L不是1)个神经元,每一个神经元都用i标注为下标。

每一层、样本、单元的标注

超参数

随着我们的深度学习模型越来越复杂,我们要学习区分普通参数 (Parameters) 和超参数 (Hyperparameters)。 在上图的标注声明中出现的W和b是普通的参数,而超参数是指:
学习速率 (learning rate) alpha、循环次数 (# iterations)、隐藏层层数 (# hidden layers) L、每隐藏层中的神经元数量 (size of the hidden layers) n[l] ——上标内是小写的L不是1、激活函数 (choice of activation functions) g[l] ——上标内是小写的L不是1。除了这些超参数之外,之后还会学习到以下超参数:动量 (momentum)、小批量更新的样本数
(minibatch size)、正则化系数 (regularization weight),等等。

随机初始化 Random Initialization

吴恩达Coursera Deep Learning学习笔记 1 (上)中的逻辑回归 (logistic regression) 中没有隐藏层,所以将W直接初始化为0并无大碍。但是在初始化隐藏层的W时如果将每个神经元的权重都初始化为0,那么在之后的梯度下降中,每一个神经元的权重都会有相同的梯度和更新,这样的对称在梯度下降中永远无法打破,如此就算隐藏层中有一千一万个神经元,也只同于一个神经元。所以,为了打破这种对称的魔咒,在初始化参数时往往会加入一些微小的抖动,即用较小的随机数来初始化W,偏置项b可以初始化为0或者同样是较小的随机数。在Python中,可以用np.random.randn(a,b)
* 0.01来随机地初始化a*b的W矩阵,并用np.zeros((a, 1))来初始化a*1的b矩阵。

为什么是0.01呢?同sigmoid和tanh中所说,数据科学家通常会从将W initialize为很小的随机数,防止训练的速度过缓。但是如果神经网络很深的话,0.01这样小的数字未必最理想。但是总体来说,人们还是倾向于从较小的参数开始训练。

承接上面的超参数,对于每个隐藏层中的神经元数量,我们可以将这几个超参数设定为layer_dims的array,如layer_dims = [n_x, 4,3,2,1] 说明输入的X有n_x个特征,第一层有4个神经元,第二层有3个神经元,第三层有2个,最后一个输出单元。有一个容易搞错的地方,就是W[l]是n[l]*n[l-1]的矩阵,b[l]是n[l]*1的矩阵。所以初始化W和b就可以写成:
for l in range(1, L):
parameters[“W”+str(l)]
= np.random.randn(layer_dims[l],layer_dims[l-1])*0.01
parameters[“b”+str(l)] =np.zeros((layer_dims[l],1))
详见例2中的initialize_parameters_deep函数。

并不是很复杂有没有!那么,下面我们一起跟着Andrew来看几个神经网络的例子——


【例 1】用单个隐藏层的神经网络来分类平面数据

本例:4个神经元的隐藏层 (tanh) 加一个sigmoid的输出层

第三课的例子是Planar data classification with one hidden layer,即帮助大家搭建一个上图所示的浅神经网络(Shallow Neural Networks):一个4个单元的隐藏层 (tanh) 加一个sigmoid的输出层。

本例反向传播中各个梯度的计算(左为一个样本的情况,右为矢量化运算)

最后一步的prediction是用了0.5的cutoff,很简单:

将A2转化为Y-hat

最后的决策边界如下图,在训练数据上的精确度为90%,是不是比logistic regression表现强多啦?可见logistic regression不能学习到数据中的非线性关系,而神经网络可以(哪怕是本例中一个非常浅的神经网络)。

决策边界:一个有四个神经元的隐藏层的神经网络

其实本例中的模型也很简单,如果再复杂些可以做到更精确,但是可能会overfit,毕竟从上图中可以看出现有的模型已经抓住了数据中的大趋势。下面尝试了在隐藏层中设置不同个数的神经元,来看模型的精确度和决策边界是如何变化的:
Accuracy for 1 hidden units: 67.5 %
Accuracy for 2 hidden units: 67.25 %
Accuracy for 3 hidden units: 90.75 %
Accuracy
for 4 hidden units: 90.5 %
Accuracy for 5 hidden units: 91.25 %
Accuracy for 20 hidden units: 90.0 %
Accuracy for 50 hidden units: 90.25 %

隐藏层有5个神经元vs.20个神经元的决策边界

可以看到,在训练数据上,5个神经元的精确度是最高的,而当神经元数超过20时,决策边界就显示有overfitting的情况了。不过没事,之后会学习正则化 (regularization),能使很复杂的神经网络都不会出现overfitting。

这个例子的代码很简单,就不贴了。


【例 2】L层深度神经网络

第四节课的例子有三个:第一是一个ReLU+sigmoid的浅层神经网络,是为了后面的例子做铺垫;第二个将其深化,用了L-1个ReLU层,输出层也是sigmoid;第三个例子就是用前两个神经网络训练猫咪识别模型[吐血]。我将L层的模型和其猫咪识别器的训练过程精简地说一下。

设计神经网络与随机初始化参数
下图就是我们要搭建的L层神经网络,不过在这之前,让我们先挑个lucky number方便以后重复训练结果^_^
np.random.seed(1)

本例:[线性组合 -> ReLU激活] \ L-1次 -> 线性组合 -> sigmoid输出*

其次,让我们设计一下我们的神经网络。因为激活函数已经确定用ReLU了,所以在本例中我们只需设计layer_dims,就能确定输入的维度、层数和每层的神经元数。

def initialize_parameters_deep(layer_dims):
parameters = {}
L = len(layer_dims) # 根据我们一开始设计的模型超参数,读取L(其实是L+1)

for l in range(1, L): # 设定W1到W(L-1)和b1和b(L-1),一共有L-1层(其实是L)
parameters[‘W’ + str(l)] = np.random.randn(layer_dims[l], layer_dims[l-1])
* 0.01
parameters[‘b’ + str(l)] = np.zeros((layer_dims[l], 1))
assert(parameters[‘W’ + str(l)].shape == (layer_dims[l], layer_dims[l-1]))
assert(parameters[‘b’ +
str(l)].shape == (layer_dims[l], 1))
return parameters

和具体的数据结合,就知道了输入的维度和样本的数量。假设我们的训练数据中有209张图片,每张都是64*64像素,那么输入特征数n_x就是64*64*3 = 12288,m就是209,如下图:

每一层参数的维度

如果我们将W和b参数设为parameters,每一个初始化的W和b都是parameters这个list中的一个元素,那么L-1个循环隐藏层其实就是len(parameters)//2。下图是一个ReLU层加一个sigmoid层的一个loop,怎么将同样的计算复制到我们的L层深度神经网络中呢?

线性组合->ReLU->线性组合->sigmoid的前向与反向传播的例子

前向传播

前向传播分为linear forward和activation forward,前者计算Z,后者计算A=g(Z),g视激活函数的不同而不同。因为activation forward这步中包括了linear的值,所以名为linear_activation_forward函数。由于反向传播的梯度计算中会用到W、b、Z的值,所以我们将每一个iteration中将每个神经元的这些值暂时储存在caches这个大列表中,再在下一轮循环中覆盖掉。代码如下:

def linear_forward(A, W, b):
Z = np.dot(W, A) + b
assert(Z.shape == (W.shape[0], A.shape[1]))
cache = (A, W, b)
return Z, cache

def linear_activation_forward(A_prev, W, b, activation):
if activation == “sigmoid”:
Z, linear_cache = linear_forward(A_prev, W, b)
A, activation_cache
= sigmoid(Z)
elif activation == “relu”:
Z, linear_cache = linear_forward(A_prev, W, b)
A, activation_cache = relu(Z)
assert (A.shape
== (W.shape[0], A_prev.shape[1]))
cache = (linear_cache, activation_cache)
return A, cache

在定义了每一个神经单元的linear-activation forward之后,我们来定义这个L层神经网络的前向传播:

def L_model_forward(X, parameters):
caches = []
A = X
L = len(parameters) // 2 # 因为之前设定的parameters包含了每一层W和b的初始值,所以层数是这个列表长度的一半

for l in range(1, L): # L-1个隐藏层用ReLU激活函数
A_prev = A
A, cache = linear_activation_forward(A_prev,
parameters[‘W’ + str(l)], parameters[‘b’ + str(l)], activation = “relu”)
caches.append(cache)
AL, cache = linear_activation_forward(A, parameters[‘W’ + str(L)], parameters[‘b’ + str(L)],
activation = “sigmoid”) # 第L个层用sigmoid函数
caches.append(cache)
assert(AL.shape == (1,X.shape[1]))
return AL, caches

前向传播的尽头是计算当前参数下的损失~不过正如在后面L_model_backward函数中看到的,我们这里直接计算dL/dAL,并不计算L,这里计算cost是为了在训练过程检查代价是不是在稳定下降,以确保我们使用了合适的学习率。

def compute_cost(AL, Y):
m = Y.shape[1]
cost = -np.sum(Y*np.log(AL) + (1-Y)*np.log(1-AL))/m
cost = np.squeeze(cost)
# 将类似于 [[17]] 的cost变成 17
assert(cost.shape == ())
return cost

反向传播

反向传播和前向传播的函数设计是对称的,但是会比前向传播复杂一丢丢,需要小心各种线性代数中的运算规则——这也是为什么在前向传播中我们都在return前加入了维度检查(assert + shape)。下图显示了每一个神经元在反向传播中的输入和输出。现在我们看到之前在前向传播中缓存的用处了。如果我不储存W和Z的值,我就没有办法在反向线性传播中计算dW,db同理。

反向线性传播中的输入和输出

def linear_backward(dZ, cache):
A_prev, W, b = cache
m = A_prev.shape[1]
dW = np.dot(dZ, A_prev.T)/m
db = np.sum(dZ, axis=1, keepdims=True)/m
dA_prev
= np.dot(W.T, dZ)
assert (dA_prev.shape == A_prev.shape)
assert (dW.shape == W.shape)
assert (db.shape == b.shape)
return dA_prev, dW, db

上述的公式用线性代数表示为下图:

反向传播中参数梯度的计算

Andrew贴心地为大家提供了写好的函数:relu_backward和sigmoid_backward,如果我们自己写的话,需要在前向传播中储存A的值,否则在很多反向传播中就不知道dA/dZ,因为有些激活函数的导数是A的函数,比如sigmoid函数和tanh函数。

def linear_activation_backward(dA, cache, activation):
linear_cache, activation_cache = cache
if activation == “relu”:
dZ = relu_backward(dA, activation_cache)

   dA\_prev, dW, db = linear\_backward(dZ, linear\_cache)  
elif activation == "sigmoid":  
    dZ = sigmoid\_backward(dA, activation\_cache)  
    dA\_prev,

dW, db = linear_backward(dZ, linear_cache)
return dA_prev, dW, db

同前向传播一样,我们将dL/dAL反向传播,通过每一层的linear-activation backward构建整个完整的反向传播体系:

def L_model_backward(AL, Y, caches):
grads = {}
L = len(caches) # 层数
m = AL.shape[1]
Y = Y.reshape(AL.shape) # 改变Y的维度,确保其与AL的维度统一
dAL = - (np.divide(Y,
AL) - np.divide(1 - Y, 1 - AL)) # 代价函数对输出层输出AL的导数,就不计算具体的cost了
current_cache = caches[L-1]
grads[“dA” + str(L)], grads[“dW” + str(L)], grads[“db” + str(L)] = linear_activation_backward(dAL, current_cache,
activation = “sigmoid”)
for l in reversed(range(L-1)):
current_cache = caches[l]
dA_prev_temp, dW_temp, db_temp = linear_activation_backward(grads[“dA”+str(l+2)],
current_cache, activation = “relu”)
grads[“dA” + str(l + 1)] = dA_prev_temp
grads[“dW” + str(l + 1)] = dW_temp
grads[“db” + str(l + 1)]
= db_temp
return grads

一般现实工作中不会用线性代数如此折磨你,就算要自己一步一步这么写,也可以加入梯度检查等等来为你增添信心,具体以后再分享~

参数更新

至此我们已经在一个循环中计算出了当前W和b的梯度,最后就是用梯度下降的定义更新参数。在下一个例子中我们会看到如何用我们已经写好的每一步的函数,使用for loop执行梯度下降,最后得到训练好的模型。

def update_parameters(parameters, grads, learning_rate):
L = len(parameters) // 2
for l in range(L):
parameters[“W” + str(l+1)] = parameters[“W” + str(l+1)] - learning_rate*grads[“dW”

  • str(l + 1)]
    parameters["b" + str(l+1)] = parameters["b" + str(l+1)] - learning\_rate\*grads["db" + str(l + 1)]  
    return parameters

【例 3】继续AI科学家对猫的执念……

下面我们用例2中的L层深度神经网络来识别一张图是不是猫咪[捂脸],因为代码有点多所以分成了2和3的两个例子。假设train_x_orig是我们原始的输入,已经将图片像素数据提取并flatten为适合训练的数据,这里我们将每一个样本从64*64*3的输入变成一个12288*1的矢量,然后将值标准化到0-1之间:

train_x_flatten = train_x_orig.reshape(train_x_orig.shape[0], -1).Ttrain_x = train_x_flatten/255.

image2vector conversion

设计一个神经网络:layers_dims = [12288, 20, 7, 5, 1],即每个样本有12288个像素输入,第一层20个ReLU神经元,第二层7个,第三层5个,最后一个sigmoid。

L层的神经网络来识别图像中的喵喵

终于可以调用我们之前辛辛苦苦写好的函数啦!之前写的函数都是每一个iteration中的每一步骤,现在我们将每一个loop循环num_iterations次。

parameters = initialize_parameters_deep(layers_dims)
for i in range(0, num_iterations):
AL, caches = L_model_forward(X, parameters)
cost = compute_cost(AL, Y)
grads = L_model_backward(AL,
Y, caches)
parameters = update_parameters(parameters, grads, learning_rate)

这里的parameters就是训练好的参数,我们就可以用它来预测新的萌萌哒猫猫啦。读取一张num_px*num_px图片的像素再将其RGB转换为num_px*num_px*3的方法,请注意这里的图片尺寸需和训练数据中的一样:

fname = “images/“ + my_image
np.array(ndimage.imread(fname, flatten=False))
scipy.misc.imresize(image, size=(num_px,num_px)).reshape((num_px*num_px*3,1))

最后我们的模型在训练数据上的精确度为98.6%。然后我们就可以用类似于predict(test_x, test_y, parameters)这样的方法就能预测这个图片是不是一个喵喵啦!最后得到在训练数据上的精确度为80%,但是让我们来看看剩下20%没有正确预测的样本是什么样子的……

基础模型没有正确预测的样本例子

除了第五张姿势扭捏的猫猫外,2和4中的猫猫我们也没有很好地识别出来。不过不用担心,卷积神经网络 (Convolutional Neural Networks) 会比image2vector更适合于处理图片数据,所以敬请期待以后的更新!

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

1…396397398399

开发者博客

3990 日志
1304 标签
RSS
© 2024 开发者博客
本站总访问量次
由 Hexo 强力驱动
|
主题 — NexT.Muse v5.1.4
0%