开发者博客 – IT技术 尽在开发者博客

开发者博客 – 科技是第一生产力


  • 首页

  • 归档

  • 搜索

Java & Android 集合框架

发表于 2022-11-15

⭐️ 本文已收录到 AndroidFamily,技术和职场问题,请关注公众号 [彭旭锐] 和 [BaguTree Pro] 知识星球提问。

学习数据结构与算法的关键在于掌握问题背后的算法思维框架,你的思考越抽象,它能覆盖的问题域就越广,理解难度也更复杂。在实际的业务开发中,往往不需要我们手写数据结构,而是直接使用标准库的数据结构 / 容器类。

本文是 Java & Android 集合框架系列的第 9 篇文章,完整文章目录请移步到文章末尾~

前言

大家好,我是小彭。

在前面的文章里,我们聊到了散列表的开放寻址法和分离链表法,也聊到了 HashMap、LinkedHashMap 和 WeakHashMap 等基于分离链表法实现的散列表。

今天,我们来讨论 Java 标准库中一个使用开放寻址法的散列表结构,也是 Java & Android “面试八股文” 的标准题库之一 —— ThreadLocal。

本文源码基于 Java 8 ThreadLocal。

  • Java & Android 集合框架 #9 全网最全的 ThreadLocal 原理详细解析 —— 原理篇
  • Java & Android 集合框架 #10 全网最全的 ThreadLocal 原理详细解析 —— 源码篇

思维导图:


  1. 回顾散列表的工作原理

在开始分析 ThreadLocal 的实现原理之前,我们先回顾散列表的工作原理。

散列表是基于散列思想实现的 Map 数据结构,将散列思想应用到散列表数据结构时,就是通过 hash 函数提取键(Key)的特征值(散列值),再将键值对映射到固定的数组下标中,利用数组支持随机访问的特性,实现 O(1) 时间的存储和查询操作。

散列表示意图

在从键值对映射到数组下标的过程中,散列表会存在 2 次散列冲突:

  • 第 1 次 - hash 函数的散列冲突: 这是一般意义上的散列冲突;
  • 第 2 次 - 散列值取余转数组下标: 本质上,将散列值转数组下标也是一次 Hash 算法,也会存在散列冲突。

事实上,由于散列表是压缩映射,所以我们无法避免散列冲突,只能保证散列表不会因为散列冲突而失去正确性。常用的散列冲突解决方法有 2 类:

  • 开放寻址法: 例如 ThreadLocalMap;
  • 分离链表法: 例如 HashMap。

开放寻址(Open Addressing)的核心思想是: 在出现散列冲突时,在数组上重新探测出一个空闲位置。 经典的探测方法有线性探测、平方探测和双散列探测。线性探测是最基本的探测方法,我们今天要分析的 ThreadLocal 中的 ThreadLocalMap 散列表就是采用线性探测的开放寻址法。


  1. 认识 ThreadLocal 线程局部存储

2.1 说一下 ThreadLocal 的特点?

ThreadLocal 提供了一种特殊的线程安全方式。

使用 ThreadLocal 时,每个线程可以通过 ThreadLocal#get 或 ThreadLocal#set 方法访问资源在当前线程的副本,而不会与其他线程产生资源竞争。这意味着 ThreadLocal 并不考虑如何解决资源竞争,而是为每个线程分配独立的资源副本,从根本上避免发生资源冲突,是一种无锁的线程安全方法。

用一个表格总结 ThreadLocal 的 API:

public API 描述
set(T) 设置当前线程的副本
T get() 获取当前线程的副本
void remove() 移除当前线程的副本
ThreadLocal withInitial(Supplier) 创建 ThreadLocal 并指定缺省值创建工厂
protected API 描述
T initialValue() 设置缺省值

2.2 ThreadLocal 如何实现线程隔离?(重点理解)

ThreadLocal 在每个线程的 Thread 对象实例数据中分配独立的内存区域,当我们访问 ThreadLocal 时,本质上是在访问当前线程的 Thread 对象上的实例数据,不同线程访问的是不同的实例数据,因此实现线程隔离。

Thread 对象中这块数据就是一个使用线性探测的 ThreadLocalMap 散列表,ThreadLocal 对象本身就作为散列表的 Key ,而 Value 是资源的副本。当我们访问 ThreadLocal 时,就是先获取当前线程实例数据中的 ThreadLocalMap 散列表,再通过当前 ThreadLocal 作为 Key 去匹配键值对。

ThreadLocal.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
java复制代码// 获取当前线程的副本
public T get() {
// 先获取当前线程实例数据中的 ThreadLocalMap 散列表
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
// 通过当前 ThreadLocal 作为 Key 去匹配键值对
ThreadLocalMap.Entry e = map.getEntry(this);
// 详细源码分析见下文 ...
}

// 获取线程 t 的 threadLocals 字段,即 ThreadLocalMap 散列表
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}

// 静态内部类
static class ThreadLocalMap {
// 详细源码分析见下文 ...
}

Thread.java

1
2
3
4
5
6
7
8
9
10
11
12
java复制代码// Thread 对象的实例数据
ThreadLocal.ThreadLocalMap threadLocals = null;
ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

// 线程退出之前,会置空threadLocals变量,以便随后GC
private void exit() {
// ...
threadLocals = null;
inheritableThreadLocals = null;
inheritedAccessControlContext = null;
// ...
}

ThreadLocal 示意图

2.3 使用 InheritableThreadLocal 继承父线程的局部存储

在业务开发的过程中,我们可能希望子线程可以访问主线程中的 ThreadLocal 数据,然而 ThreadLocal 是线程隔离的,包括在父子线程之间也是线程隔离的。为此,ThreadLocal 提供了一个相似的子类 InheritableThreadLocal,ThreadLocal 和 InheritableThreadLocal 分别对应于线程对象上的两块内存区域:

  • 1、ThreadLocal 字段: 在所有线程间隔离;
  • 2、InheritableThreadLocal 字段: 子线程会继承父线程的 InheritableThreadLocal 数据。父线程在创建子线程时,会批量将父线程的有效键值对数据拷贝到子线程的 InheritableThreadLocal,因此子线程可以复用父线程的局部存储。

在 InheritableThreadLocal 中,可以重写 childValue() 方法修改拷贝到子线程的数据。

1
2
3
4
5
6
7
8
java复制代码public class InheritableThreadLocal<T> extends ThreadLocal<T> {

// 参数:父线程的数据
// 返回值:拷贝到子线程的数据,默认为直接传递
protected T childValue(T parentValue) {
return parentValue;
}
}

需要特别注意:

  • 注意 1 - InheritableThreadLocal 区域在拷贝后依然是线程隔离的: 在完成拷贝后,父子线程对 InheritableThreadLocal 的操作依然是相互独立的。子线程对 InheritableThreadLocal 的写不会影响父线程的 InheritableThreadLocal,反之亦然;
  • 注意 2 - 拷贝过程在父线程执行: 这是容易混淆的点,虽然拷贝数据的代码写在子线程的构造方法中,但是依然是在父线程执行的。子线程是在调用 start() 后才开始执行的。

InheritableThreadLocal 示意图

2.4 ThreadLocal 的自动清理与内存泄漏问题

ThreadLocal 提供具有自动清理数据的能力,具体分为 2 个颗粒度:

  • 1、自动清理散列表: ThreadLocal 数据是 Thread 对象的实例数据,当线程执行结束后,就会跟随 Thread 对象 GC 而被清理;
  • 2、自动清理无效键值对: ThreadLocal 是使用弱键的动态散列表,当 Key 对象不再被持有强引用时,垃圾收集器会按照弱引用策略自动回收 Key 对象,并在下次访问 ThreadLocal 时清理无效键值对。

引用关系示意图

然而,自动清理无效键值对会存在 “滞后性”,在滞后的这段时间内,无效的键值对数据没有及时回收,就发生内存泄漏。

  • 举例 1: 如果创建 ThreadLocal 的线程一直持续运行,整个散列表的数据就会一致存在。比如线程池中的线程(大体)是复用的,这部分复用线程中的 ThreadLocal 数据就不会被清理;
  • 举例 2: 如果在数据无效后没有再访问过 ThreadLocal 对象,那么自然就没有机会触发清理;
  • 举例 3: 即使访问 ThreadLocal 对象,也不一定会触发清理(原因见下文源码分析)。

综上所述:虽然 ThreadLocal 提供了自动清理无效数据的能力,但是为了避免内存泄漏,在业务开发中应该及时调用 ThreadLocal#remove 清理无效的局部存储。

2.5 ThreadLocal 的使用场景

  • 场景 1 - 无锁线程安全: ThreadLocal 提供了一种特殊的线程安全方式,从根本上避免资源竞争,也体现了空间换时间的思想;
  • 场景 2 - 线程级别单例: 一般的单例对象是对整个进程可见的,使用 ThreadLocal 也可以实现线程级别的单例;
  • 场景 3 - 共享参数: 如果一个模块有非常多地方需要使用同一个变量,相比于在每个方法中重复传递同一个参数,使用一个 ThreadLocal 全局变量也是另一种传递参数方式。

2.6 ThreadLocal 使用示例

我们采用 Android Handler 机制中的 Looper 消息循环作为 ThreadLocal 的学习案例:

android.os.Looper.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
java复制代码// /frameworks/base/core/java/android/os/Looper.java

public class Looper {

// 静态 ThreadLocal 变量,全局共享同一个 ThreadLocal 对象
static final ThreadLocal<Looper> sThreadLocal = new ThreadLocal<Looper>();

private static void prepare(boolean quitAllowed) {
if (sThreadLocal.get() != null) {
throw new RuntimeException("Only one Looper may be created per thread");
}
// 设置 ThreadLocal 变量的值,即设置当前线程关联的 Looper 对象
sThreadLocal.set(new Looper(quitAllowed));
}

public static Looper myLooper() {
// 获取 ThreadLocal 变量的值,即获取当前线程关联的 Looper 对象
return sThreadLocal.get();
}

public static void prepare() {
prepare(true);
}
...
}

示例代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
java复制代码new Thread(new Runnable() {
@Override
public void run() {
Looper.prepare();
// 两个线程独立访问不同的 Looper 对象
System.out.println(Looper.myLooper());
}
}).start();

new Thread(new Runnable() {
@Override
public void run() {
Looper.prepare();
// 两个线程独立访问不同的 Looper 对象
System.out.println(Looper.myLooper());
}
}).start();

要点如下:

  • 1、Looper 中的 ThreadLocal 被声明为静态类型,泛型参数为 Looper,全局共享同一个 ThreadLocal 对象;
  • 2、Looper#prepare() 中调用 ThreadLocal#set() 设置当前线程关联的 Looper 对象;
  • 3、Looper#myLooper() 中调用 ThreadLocal#get() 获取当前线程关联的 Looper 对象。

我们可以画出 Looper 中访问 ThreadLocal 的 Timethreads 图,可以看到不同线程独立访问不同的 Looper 对象,即线程间不存在资源竞争。

Looper ThreadLocal 示意图

2.7 阿里巴巴 ThreadLocal 编程规约

在《阿里巴巴 Java 开发手册》中,亦有关于 ThreadLocal API 的编程规约:

  • 【强制】 SimpleDateFormate 是线程不安全的类,一般不要定义为 static ****变量。如果定义为 static,必须加锁,或者使用 DateUtils 工具类(使用 ThreadLocal 做线程隔离)。

DataFormat.java

1
2
3
4
5
6
7
8
9
10
java复制代码private static final ThreadLocal<DataFormat> df = new ThreadLocal<DateFormat>(){
// 设置缺省值 / 初始值
@Override
protected DateFormat initialValue(){
return new SimpleDateFormat("yyyy-MM-dd");
}
};

// 使用:
DateUtils.df.get().format(new Date());
  • 【参考】 (原文过于啰嗦,以下是小彭翻译转述)ThreadLocal 变量建议使用 static 全局变量,可以保证变量在类初始化时创建,所有类实例可以共享同一个静态变量(例如,在 Android Looper 的案例中,ThreadLocal 就是使用 static 修饰的全局变量)。
  • 【强制】 必须回收自定义的 ThreadLocal 变量,尤其在线程池场景下,线程经常被反复用,如果不清理自定义的 ThreadLocal 变量,则可能会影响后续业务逻辑和造成内存泄漏等问题。尽量在代码中使用 try-finally 块回收,在 finally 中调用 remove() 方法。

  1. ThreadLocal 源码分析

这一节,我们来分析 ThreadLocal 中主要流程的源码。

3.1 ThreadLocal 的属性

ThreadLocal 只有一个 threadLocalHashCode 散列值属性:

  • 1、threadLocalHashCode 相当于 ThreadLocal 的自定义散列值,在创建 ThreadLocal 对象时,会调用 nextHashCode() 方法分配一个散列值;
  • 2、ThreadLocal 每次调用 nextHashCode() 方法都会将散列值追加 HASH_INCREMENT,并记录在一个全局的原子整型 nextHashCode 中。

提示: ThreadLocal 的散列值序列为:0、HASH_INCREMENT、HASH_INCREMENT * 2、HASH_INCREMENT * 3、…

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
java复制代码public class ThreadLocal<T> {

// 疑问 1:OK,threadLocalHashCode 类似于 hashCode(),那为什么 ThreadLocal 不重写 hashCode()
// ThreadLocal 的散列值,类似于重写 Object#hashCode()
private final int threadLocalHashCode = nextHashCode();

// 全局原子整型,每调用一次 nextHashCode() 累加一次
private static AtomicInteger nextHashCode = new AtomicInteger();

// 疑问:为什么 ThreadLocal 散列值的增量是 0x61c88647?
private static final int HASH_INCREMENT = 0x61c88647;

private static int nextHashCode() {
// 返回上一次 nextHashCode 的值,并累加 HASH_INCREMENT
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
}

static class ThreadLocalMap {
// 详细源码分析见下文 ...
}

不出意外的话又有小朋友出来举手提问了🙋🏻‍♀️:

  • 🙋🏻‍♀️疑问 1:OK,threadLocalHashCode 类似于 hashCode(),那为什么 ThreadLocal 不重写 hashCode()?

如果重写 Object#hashCode(),那么 threadLocalHashCode 散列值就会对所有散列表生效。而 threadLocalHashCode 散列值是专门针对数组为 2 的整数幂的散列表设计的,在其他散列表中不一定表现良好。因此 ThreadLocal 没有重写 Object#hashCode(),让 threadLocalHashCode 散列值只在 ThreadLocal 内部的 ThreadLocalMap 使用。

常规做法

1
2
3
4
5
6
7
8
java复制代码public class ThreadLocal<T> {

// ThreadLocal 未重写 hashCode()
@Override
public int hashCode() {
return threadLocalHashCode;
}
}
  • 🙋🏻‍♀️疑问 2:为什么使用 ThreadLocal 作为散列表的 Key,而不是常规思维用 Thread Id 作为 Key?

如果使用 Thread Id 作为 Key,那么就需要在每个 ThreadLocal 对象中维护散列表,而不是每个线程维护一个散列表。此时,当多个线程并发访问同一个 ThreadLocal 对象中的散列表时,就需要通过加锁保证线程安全。而 ThreadLocal 的方案让每个线程访问独立的散列表,就可以从根本上规避线程竞争。

3.2 ThreadLocal 的 API

分析代码,可以总结出 ThreadLocal API 的用法和注意事项:

  • 1、ThreadLocal#get: 获取当前线程的副本;
  • 2、ThreadLocal#set: 设置当前线程的副本;
  • 3、ThreadLocal#remove: 移除当前线程的副本;
  • 4、ThreadLocal#initialValue: 由子类重写来设置缺省值:
    • 4.1 如果未命中(Map 取值为 nul),则会调用 initialValue() 创建并设置缺省值;
    • 4.2 ThreadLocal 的缺省值只会在缓存未命中时创建,即缺省值采用懒初始化策略;
    • 4.3 如果先设置后又移除副本,再次 get 获取副本未命中时依然会调用 initialValue() 创建并设置缺省值。
  • 5、ThreadLocal#withInitial: 方便设置缺省值,而不需要实现子类。

在 ThreadLocal 的 API 会通过 getMap() 方法获取当前线程的 Thread 对象中的 threadLocals 字段,这是线程隔离的关键。

ThreadLocal.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
java复制代码public ThreadLocal() {
// do nothing
}

// 子类可重写此方法设置缺省值(方法命名为 defaultValue 获取更贴切)
protected T initialValue() {
// 默认不提供缺省值
return null;
}

// 帮助方法:不重写 ThreadLocal 也可以设置缺省值
// supplier:缺省值创建工厂
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new SuppliedThreadLocal<>(supplier);
}

// 1. 获取当前线程的副本
public T get() {
Thread t = Thread.currentThread();
// ThreadLocalMap 详细源码分析见下文
ThreadLocalMap map = getMap(t);
if (map != null) {
// 存在匹配的Entry
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
T result = (T)e.value;
return result;
}
}
// 未命中,则获取并设置缺省值(即缺省值采用懒初始化策略)
return setInitialValue();
}

// 获取并设置缺省值
private T setInitialValue() {
T value = initialValue();
// 其实源码中是并不是直接调用set(),而是复制了一份 set() 方法的源码
// 这是为了防止子类重写 set() 方法后改变缺省值逻辑
set(value);
return value;
}

// 2. 设置当前线程的副本
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
// 直到设置值的时候才创建(即 ThreadLocalMap 采用懒初始化策略)
createMap(t, value);
}

// 3. 移除当前线程的副本
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}

ThreadLocalMap getMap(Thread t) {
// 重点:获取当前线程的 threadLocals 字段
return t.threadLocals;
}

// ThreadLocal 缺省值帮助类
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

private final Supplier<? extends T> supplier;

SuppliedThreadLocal(Supplier<? extends T> supplier) {
this.supplier = Objects.requireNonNull(supplier);
}

// 重写 initialValue() 以设置缺省值
@Override
protected T initialValue() {
return supplier.get();
}
}

3.3 InheritableThreadLocal 如何继承父线程的局部存储?

父线程在创建子线程时,在子线程的构造方法中会批量将父线程的有效键值对数据拷贝到子线程,因此子线程可以复用父线程的局部存储。

Thread.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
java复制代码// Thread 对象的实例数据
ThreadLocal.ThreadLocalMap threadLocals = null;
ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

// 构造方法
public Thread() {
init(null, null, "Thread-" + nextThreadNum(), 0);
}

private void init(ThreadGroup g, Runnable target, String name, long stackSize, AccessControlContext acc, boolean inheritThreadLocals) {
...
if (inheritThreadLocals && parent.inheritableThreadLocals != null)
// 拷贝父线程的 InheritableThreadLocal 散列表
this.inheritableThreadLocals = ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
...
}

ThreadLocal.java

1
2
3
4
5
6
7
8
9
10
11
12
13
java复制代码// 带 Map 的构造方法
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
return new ThreadLocalMap(parentMap);
}

static class ThreadLocalMap {

private ThreadLocalMap(ThreadLocalMap parentMap) {
// 详细源码分析见下文 ...
Object value = key.childValue(e.value);
...
}
}

InheritableThreadLocal 在拷贝父线程散列表的过程中,会调用 InheritableThreadLocal#childValue() 尝试转换为子线程需要的数据,默认是直接传递,可以重写这个方法修改拷贝的数据。

InheritableThreadLocal.java

1
2
3
4
5
6
7
java复制代码public class InheritableThreadLocal<T> extends ThreadLocal<T> {

// 参数:父线程的数据
// 返回值:拷贝到子线程的数据,默认为直接传递
protected T childValue(T parentValue) {
return parentValue;
}

下面,我们来分析 ThreadLocalMap 的源码。


后续源码分析,见下一篇文章:Java & Android 集合框架 #10 全网最全的 ThreadLocal 原理详细解析 —— 源码篇。


版权声明

本文为稀土掘金技术社区首发签约文章,14天内禁止转载,14天后未获授权禁止转载,侵权必究!

参考资料

  • 数据结构与算法分析 · Java 语言描述(第 5 章 · 散列)—— [美] Mark Allen Weiss 著
  • 算法导论(第 11 章 · 散列表)—— [美] Thomas H. Cormen 等 著
  • 《阿里巴巴Java开发手册》 杨冠宝 编著
  • 数据结构与算法之美(第 18~22 讲) —— 王争 著,极客时间 出品
  • ThreadLocal 和 ThreadLocalMap源码分析 —— KingJack 著
  • Why 0x61c88647? —— Dr. Heinz M. Kabutz 著

推荐阅读

Java & Android 集合框架系列文章目录(2023/07/08 更新):

  • #1 ArrayList 可以完全替代数组吗?
  • #2 说一下 ArrayList 和 LinkedList 的区别?
  • #3 CopyOnWriteArrayList 是如何保证线程安全的?
  • #4 ArrayDeque:如何用数组实现栈和队列?
  • #5 万字 HashMap 详解,基础(优雅)永不过时 —— 原理篇
  • #6 万字 HashMap 详解,基础(优雅)永不过时 —— 源码篇
  • #7 如何使用 LinkedHashMap 实现 LRU 缓存?
  • #8 说一下 WeakHashMap 如何清理无效数据的?
  • #9 全网最全的 ThreadLocal 原理详细解析 —— 原理篇
  • #10 全网最全的 ThreadLocal 原理详细解析 —— 源码篇

数据结构与算法系列文章:跳转阅读

⭐️ 永远相信美好的事情即将发生,欢迎加入小彭的 Android 交流社群~

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

【AI】浅析恶意文件静态检测及部分问题解决思路 前言 分析

发表于 2022-11-14

本文正在参加「金石计划 . 瓜分6万现金大奖」

前言

随着互联网的繁荣和发展,海量的核心数据和网络应用也不断向云端、数据中心等关键信息基础设施整合和迁移,主机安全也因此成为网络攻防战的焦点。恶意文件 是指由攻击者专门设计的,在未经所有者许可的情况下用来访问计算机、损害或破坏系统,对保密性、完整性或可用性进行攻击的文件,是当前互联网安全的主要威胁之一。目前,比较主流的恶意文件包括恶意脚本、漏洞利用、蠕虫、木马和间谍软件以及他们的组合或变体。

为了应对挑战,恶意文件静态检测的思想被提了出来。基于机器学习算法的防护技术为实现高准确率、自动化的未知恶意文件检测提供了行之有效的技术途径,已逐渐成为业内研究的热点。

接下来博主将简单介绍其中一种恶意文件静态检测模型的部分内容;

番外:对于想了解梯度下降算法的小伙伴,也可以看看博主的往期博文:

  • 【AI】浅谈梯度下降算法(理论篇)
  • 【AI】浅谈梯度下降算法(实战篇)
  • 【AI】浅谈梯度下降算法(拓展篇)

分析

这里的恶意文件静态检测是将恶意文件的二进制转成灰度图,作为 CNN 模型的输入,经过一系列的过程得到输出,然后进行对比、评估等;

考虑到每个样本的大小是不固定的,本来是以 1M 大小作为区分界限的,这里的话,使用 PadSequence 来确保数据长度的一致性;

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
py复制代码class PadSequence(object):

...

def pltexe(self, arr):
arr_n = len(arr) // (1024*1024)
arr_end_len = len(arr) % (1024*1024)
re_arr = []
siz = 1024

# 矩阵转换:行列变化,总数不变
for ite in range(arr_n):
st = ite * 1024 * 1024
pggg0 = np.array(arr[st : st+1024*1024])
re_arr.append(pggg0.reshape(siz,siz) / 255)

# 用 0 补足
if arr_end_len!= 0 :
arr_ = (1024*1024-arr_end_len) * [0]
pggg0 = np.array(arr[1024*1024*arr_n:] + arr_)
re_arr.append(pggg0.reshape(siz,siz) / 255)

return re_arr

def doooo_(self, filelist):
...

# 设定列表长度不超过20
if len(featurelist) > 20:
re_feature_lab = random.sample(list(zip(featurelist,labellist)), 20)
featurelist = [x[0] for x in re_feature_lab]
labellist = [x[1] for x in re_feature_lab]

...

return featurelist_batch, labellist_batch

def __call__(self, batch):
return self.doooo_(batch)

然后进行数据加载:

1
2
3
4
5
6
7
py复制代码with open(path, 'rb') as f:
train_data = pickle.load(f)

train_loader = DataLoader(train_data, batch_size=10, shuffle=True, num_workers = 20, collate_fn=PadSequence(maxlen = 0))

pad = PadSequence()
pad.__call__(train_data[:4])[0]

image.png

最后进入模型进行训练以及验证;

image.png

TIP

在模型训练中,可以使用 try...excpet 模块,即使因为意外中断训练,之前的训练结果也都保存下来了,下次训练就不用重头开始了:

1
2
3
4
5
6
py复制代码try:
...

except:
model = model.eval()
torch.save(model.state_dict(), 'error.pth')

问题解决

OOM

在启动项目时,可能会出现以下报错:

1
2
3
4
5
6
7
8
sql复制代码(sid10t) bash-4.2# python model_run.py 
Traceback (most recent call last):
...
RuntimeError: Caught RuntimeError in replica 0 on device 0.

Original Traceback (most recent call last):
...
RuntimeError: CUDA out of memory. Tried to allocate 24.00 MiB (GPU 0; 7.80 GiB total capacity; 6.31 GiB already allocated; 6.56 MiB free; 6.46 GiB reserved in total by PyTorch)

这是因为我们将 DataLoader 里的 batch_size 参数设置的过大了,从而导致了显存溢出;

那么无非就是两个解决方案:

  • 多选定几个 CUDA;
  • 将 batch_size 参数调小;

对于第一种方案,可以一股脑的将机子上的所有 CUDA 全部选上:

1
2
3
py复制代码device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model= nn.DataParallel(model)

对于第二种方案,将 batch_size 参数调小,也是有讲究的,我们要尽可能的提高资源的利用率,因此需要做一些操作:

  1. 首先是查看模型占用了多少 GPU,watch -n 1 nvidia-smi:
    image.png
  2. 然后折半减少 batch_size,查看显存占用率,调节至合适大小:
    image.png

Socket exception

由于模型跑在空闲的机子上,而样本却在另一台机子上,因此,需要通过 SFTP 进行读取,不出意外的话,要出意外了;

首先是在 pad 函数里构建 SFTP 连接:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
py复制代码class PadSequence(object):
def __init__(self, maxlen = 8000):
self.maxlen = maxlen
...
self.client = paramiko.SSHClient()
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.client.connect(self.hostname, self.port ,self.username, self.password, compress=True)
self.sftp_client = self.client.open_sftp()

def getfile_ftp(self, file_path):
remote_file = self.sftp_client.open(file_path, 'rb')
try:
str_object_with_pe_file_data = remote_file.read()
finally:
remote_file.close()
return str_object_with_pe_file_data

def __call__(self, batch):
return self.doooo_(batch)

然后在 DataLoader 中使用到它的回显函数:

1
py复制代码train_loader = DataLoader(train_data, batch_size=4, shuffle=True, num_workers = 20, collate_fn=PadSequence(maxlen = 0))

好的,做完上面的之后,坑爹的来了,报错了:Socket exception: Connection reset by peer (104)

1668414245075_3148F4BD-D120-4d3c-8640-BB3DDCEFC492.png

不知道是因为 DataLoader 底层逻辑问题,还是这台服务器的问题,症结就是在于不能使用多进程进行 SFTP 读取,因此这里的解决方案就是将参数 num_workers 置为 0;

后记

以上就是 浅析恶意文件静态检测及部分问题解决思路 的全部内容了,大致讲述了恶意文件静态检测的其中一种思路,以及图文结合的分析了部分问题的解决思路,希望大家有所收获!

📝 上篇精讲:【AI】浅谈梯度下降算法(实战篇)

💖 我是 𝓼𝓲𝓭𝓲𝓸𝓽,期待你的关注;

👍 创作不易,请多多支持;

🔥 系列专栏:AI 项目实战

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

分析 vant4 源码,如何用 vue3 + ts 开发一个

发表于 2022-11-14

本文为稀土掘金技术社区首发签约文章,14天内禁止转载,14天后未获授权禁止转载,侵权必究!

  1. 前言

大家好,我是若川。我倾力持续组织了一年每周大家一起学习200行左右的源码共读活动,感兴趣的可以点此扫码加我微信 ruochuan02 参与。另外,想学源码,极力推荐关注我写的专栏《学习源码整体架构系列》,目前是掘金关注人数(4.1k+人)第一的专栏,写有20余篇源码文章。

我们开发业务时经常会使用到组件库,一般来说,很多时候我们不需要关心内部实现。但是如果希望学习和深究里面的原理,这时我们可以分析自己使用的组件库实现。有哪些优雅实现、最佳实践、前沿技术等都可以值得我们借鉴。

相比于原生 JS 等源码。我们或许更应该学习,正在使用的组件库的源码,因为有助于帮助我们写业务和写自己的组件。

如果是 Vue 技术栈,开发移动端的项目,大多会选用 vant 组件库,目前(2022-11-13) star 多达 20.4k。我们可以挑选 vant 组件库学习,我会写一个组件库源码系列专栏,欢迎大家关注。

vant 组件库源码分析系列:

  • 1.vant 4 即将正式发布,支持暗黑主题,那么是如何实现的呢
  • 2.跟着 vant 4 源码学习如何用 vue3+ts 开发一个 loading 组件,仅88行代码
  • 3.分析 vant 4 源码,如何用 vue3 + ts 开发一个瀑布流滚动加载的列表组件?
  • 4.分析 vant 4 源码,学会用 vue3 + ts 开发毫秒级渲染的倒计时组件,真是妙啊
  • 5.vant 4.0 正式发布了,分析其源码学会用 vue3 写一个图片懒加载组件!

学完本文,你将学到:

1
2
3
bash复制代码1. 学会如何用 vue3 + ts 开发一个 List 组件
2. 学会封装各种组合式 `API`
3. 等等
  1. 准备工作

看一个开源项目,第一步应该是先看 README.md 再看贡献文档 github/CONTRIBUTING.md。

2.1 克隆源码 && 跑起来

You will need Node.js >= 14 and pnpm.

1
2
3
4
5
6
7
8
9
10
11
12
13
bash复制代码# 推荐克隆我的项目
git clone https://github.com/lxchuan12/vant-analysis
cd vant-analysis/vant

# 或者克隆官方仓库
git clone git@github.com:vant-ui/vant.git
cd vant

# 安装依赖,会运行所有 packages 下仓库的 pnpm i 钩子 pnpm prepare 和 pnpm i
pnpm i

# Start development
pnpm dev

我们先来看 pnpm dev 最终执行的什么命令。

vant 项目使用的是 monorepo 结构。查看根路径下的 package.json。

vant/package.json => "dev": "pnpm --dir ./packages/vant dev"
vant/packages/vant/package.json => "dev": "vant-cli dev"

pnpm dev 最终执行的是:vant-cli dev 执行测试用例。本文主要是学习 List 组件 的实现,所以我们就不深入 vant-cli dev 命令了。

  1. List 组件

List 组件文档

瀑布流滚动加载,用于展示长列表,当列表即将滚动到底部时,会触发事件并加载更多列表项。

从这个描述和我们自己体验 demo 来。
至少有以下三个问题值得去了解学习。

  • 如何监听滚动
  • 如何计算滚动到了底部
  • 如何触发事件加载更多

带着问题我们直接找到 list demo 文件:vant/packages/vant/src/list/demo/index.vue。为什么是这个文件,我在上篇文章跟着 vant4 源码学习如何用 vue3+ts 开发一个 loading 组件,仅88行代码分析了其原理,感兴趣的小伙伴点击查看。这里就不赘述了。

3.1 利用 demo 调试

组件源码中的 TS 代码我不会过多解释。没学过 TS 的小伙伴,推荐学这个TypeScript 入门教程。
另外,vant 使用了 @vue/babel-plugin-jsx 插件来支持 JSX、TSX。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
js复制代码// vant/packages/vant/src/list/demo/index.vue
// 代码有删减
<script setup lang="ts">
import VanList from '..';
import { ref } from 'vue';

const t = useTranslate({
'zh-CN': {
errorInfo: '错误提示',
errorText: '请求失败,点击重新加载',
pullRefresh: '下拉刷新',
finishedText: '没有更多了',
},
'en-US': {
errorInfo: 'Error Info',
errorText: 'Request failed. Click to reload',
pullRefresh: 'PullRefresh',
finishedText: 'Finished',
},
});

const list = ref([
{
items: [] as string[],
refreshing: false,
loading: false,
error: false,
finished: false,
},
]);

// 加载数据
const onLoad = (index: number) => {
const currentList = list.value[index];
currentList.loading = true;

setTimeout(() => {
if (currentList.refreshing) {
currentList.items = [];
currentList.refreshing = false;
}

for (let i = 0; i < 10; i++) {
const text = currentList.items.length + 1;
currentList.items.push(text < 10 ? '0' + text : String(text));
}

currentList.loading = false;
currentList.refreshing = false;

// show error info in second demo
if (index === 1 && currentList.items.length === 10 && !currentList.error) {
currentList.error = true;
} else {
currentList.error = false;
}

if (currentList.items.length >= 40) {
currentList.finished = true;
}
}, 1000);
};
</script>
<template>
<van-tabs>
<van-tab :title="t('basicUsage')">
<van-list
v-model:loading="list[0].loading"
:finished="list[0].finished"
:finished-text="t('finishedText')"
@load="onLoad(0)"
>
<van-cell v-for="item in list[0].items" :key="item" :title="item" />
</van-list>
</van-tab>
<template>
  1. 入口文件

主要就是导出一下类型和变量等。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
js复制代码// vant/packages/vant/src/list/index.ts
import { withInstall } from '../utils';
import _List, { ListProps } from './List';

export const List = withInstall(_List);
export default List;
export { listProps } from './List';
export type { ListProps };
export type { ListInstance, ListDirection, ListThemeVars } from './types';

declare module 'vue' {
export interface GlobalComponents {
VanList: typeof List;
}
}

withInstall 函数在上篇文章5.1 withInstall 给组件对象添加 install 方法 也有分析,这里就不赘述了。

我们可以在这些文件,任意位置加上 debugger 调试源码。

  1. 主文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
js复制代码import {
ref,
watch,
nextTick,
onUpdated,
onMounted,
defineComponent,
type ExtractPropTypes,
} from 'vue';

// Utils
import {
isHidden,
truthProp,
makeStringProp,
makeNumericProp,
createNamespace,
} from '../utils';

// Composables
import { useRect, useScrollParent, useEventListener } from '@vant/use';
import { useExpose } from '../composables/use-expose';
import { useTabStatus } from '../composables/use-tab-status';

// Components
import { Loading } from '../loading';

// Types
import type { ListExpose, ListDirection } from './types';

const [name, bem, t] = createNamespace('list');

export const listProps = {
error: Boolean,
offset: makeNumericProp(300),
loading: Boolean,
finished: Boolean,
errorText: String,
direction: makeStringProp<ListDirection>('down'),
loadingText: String,
finishedText: String,
immediateCheck: truthProp,
};

export type ListProps = ExtractPropTypes<typeof listProps>;

List 组件 api

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
js复制代码export default defineComponent({
name,
props: listProps,

emits: ['load', 'update:error', 'update:loading'],

setup(props, { emit, slots }) {
// TODODEL: 可以在这里打上断点调试,或者其他地方。
debugger;
// 省略若干代码
const loading = ref(false);
const root = ref<HTMLElement>();
const placeholder = ref<HTMLElement>();
const tabStatus = useTabStatus();
const scrollParent = useScrollParent(root);
// 省略若干代码
return () => {
const Content = slots.default?.();
const Placeholder = <div ref={placeholder} class={bem('placeholder')} />;

return (
<div ref={root} role="feed" class={bem()} aria-busy={loading.value}>
{props.direction === 'down' ? Content : Placeholder}
// 比如:加载中
{renderLoading()}
// 结束文字 比如:没有更多了
{renderFinishedText()}
// 加载错误文字:比如加载失败
{renderErrorText()}
{props.direction === 'up' ? Content : Placeholder}
</div>
);
};
}
}

debugger 调试截图。

debugger 调试截图

接着我们来看其他一些事件。

5.1 一些事件 useExpose、useEventListener

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
js复制代码// 省略若干代码
setup(props, { emit, slots }) {
// 省略 check 函数,后文讲述
const check = () => {}

// 监听参数变更,执行 check
watch(() => [props.loading, props.finished, props.error], check);

// van-tabs tab 切换状态变更时 执行 check
if (tabStatus) {
watch(tabStatus, (tabActive) => {
if (tabActive) {
check();
}
});
}

onUpdated(() => {
// !是 ts中的非空断言,很多人问过
loading.value = props.loading!;
});

// 如果参数是立即检测,执行 check 函数
onMounted(() => {
if (props.immediateCheck) {
check();
}
});

// 导出 check 函数,让 refs.xxx 可以使用
useExpose<ListExpose>({ check });

// 监听滚动事件,执行 check 函数
useEventListener('scroll', check, {
target: scrollParent,
passive: true,
});
}

由上面代码可以看出,check 函数非常重要,我们在下文分析它。

我们先分析上面代码用到的 useExpose、useEventListener 组合式 API。

5.2 useExpose 暴露

1
2
3
4
5
6
7
8
9
10
js复制代码import { getCurrentInstance } from 'vue';
import { extend } from '../utils';

// expose public api
export function useExpose<T = Record<string, any>>(apis: T) {
const instance = getCurrentInstance();
if (instance) {
extend(instance.proxy as object, apis);
}
}

通过 ref 可以获取到 List 实例并调用实例方法,详见组件实例方法。

Vant 中的许多组件提供了实例方法,调用实例方法时,我们需要通过 ref 来注册组件引用信息,引用信息将会注册在父组件的 $refs 对象上。注册完成后,我们可以通过 this.$refs.xxx 访问到对应的组件实例,并调用上面的实例方法。

5.3 useEventListener 绑定事件

方便地进行事件绑定,在组件 mounted 和 activated 时绑定事件,unmounted 和 deactivated 时解绑事件。

useEventListener

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
js复制代码import { Ref, watch, isRef, unref, onUnmounted, onDeactivated } from 'vue';
import { onMountedOrActivated } from '../onMountedOrActivated';
import { inBrowser } from '../utils';

type TargetRef = EventTarget | Ref<EventTarget | undefined>;

export type UseEventListenerOptions = {
target?: TargetRef;
capture?: boolean;
passive?: boolean;
};

// TS 函数重载
// 重载 可以参考这里:http://ts.xcatliu.com/basics/type-of-function.html#%E9%87%8D%E8%BD%BD
export function useEventListener<K extends keyof DocumentEventMap>(
type: K,
listener: (event: DocumentEventMap[K]) => void,
options?: UseEventListenerOptions
): void;
export function useEventListener(
type: string,
listener: EventListener,
options?: UseEventListenerOptions
): void;
export function useEventListener(
type: string,
listener: EventListener,
options: UseEventListenerOptions = {}
) {
// 如果不是浏览器环境,直接返回,比如 SSR
if (!inBrowser) {
return;
}

const { target = window, passive = false, capture = false } = options;

let attached: boolean;

// 添加事件
const add = (target?: TargetRef) => {
const element = unref(target);

if (element && !attached) {
element.addEventListener(type, listener, {
capture,
passive,
});
attached = true;
}
};

// 移除事件
const remove = (target?: TargetRef) => {
const element = unref(target);

if (element && attached) {
element.removeEventListener(type, listener, capture);
attached = false;
}
};

// 移除事件
onUnmounted(() => remove(target));
onDeactivated(() => remove(target));
onMountedOrActivated(() => add(target));

if (isRef(target)) {
watch(target, (val, oldVal) => {
remove(oldVal);
add(val);
});
}
}
  1. steup check 函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
js复制代码const check = () => {
nextTick(() => {
// 正在 loading 或者已经完成加载
// 或者加载失败,或者tab的状态不是激活时,返回。
if (
loading.value ||
props.finished ||
props.error ||
// skip check when inside an inactive tab
tabStatus?.value === false
) {
return;
}

// offset 默认 300
const { offset, direction } = props;
// 滚动的父级元素的位置
const scrollParentRect = useRect(scrollParent);

if (!scrollParentRect.height || isHidden(root)) {
return;
}

// 触底计算
// 滚动父元素 和 占位元素
let isReachEdge = false;
const placeholderRect = useRect(placeholder);

if (direction === 'up') {
isReachEdge = scrollParentRect.top - placeholderRect.top <= offset;
} else {
isReachEdge =
placeholderRect.bottom - scrollParentRect.bottom <= offset;
}

// 触底了
if (isReachEdge) {
loading.value = true;
emit('update:loading', true);
emit('load');
}
});
};

从 check 函数可以看出,主要就是利用滚动高度,接下来我们看这个函数中,使用到的组合式 API,useTabStatus、useScrollParent、useRect。

6.1 useTabStatus tab 组件的状态

1
2
3
4
5
6
js复制代码import { inject, ComputedRef, InjectionKey } from 'vue';

// eslint-disable-next-line
export const TAB_STATUS_KEY: InjectionKey<ComputedRef<boolean>> = Symbol();

export const useTabStatus = () => inject(TAB_STATUS_KEY, null);

代码根据 commit 可以发现 useTabStatus 有这样一次提交。

fix(List): skip check when inside an inactive tab

主要是在 van-tabs 组件中,provide(TAB_STATUS_KEY, active); 提供了一个状态。tab 不活跃时,跳过 check 函数,不执行。

6.2 useScrollParent 获取元素最近的可滚动父元素

获取元素最近的可滚动父元素。

给定参数 el, root 节点,遍历父级节点查找 style 包含 scroll|auto|overlay 的元素,如果没找到,返回第二个 root 参数(没有第二个参数则是 window)。

useScrollParent 文档

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
js复制代码import { ref, Ref, onMounted } from 'vue';
import { inBrowser } from '../utils';

type ScrollElement = HTMLElement | Window;

const overflowScrollReg = /scroll|auto|overlay/i;
const defaultRoot = inBrowser ? window : undefined;

// 元素节点
function isElement(node: Element) {
const ELEMENT_NODE_TYPE = 1;
return (
node.tagName !== 'HTML' &&
node.tagName !== 'BODY' &&
node.nodeType === ELEMENT_NODE_TYPE
);
}

// https://github.com/vant-ui/vant/issues/3823
export function getScrollParent(
el: Element,
root: ScrollElement | undefined = defaultRoot
) {
let node = el;

// 遍历得到父级滚动的元素,style 样式包含 scroll|auto|overlay 的节点
while (node && node !== root && isElement(node)) {
const { overflowY } = window.getComputedStyle(node);
if (overflowScrollReg.test(overflowY)) {
return node;
}
node = node.parentNode as Element;
}

// 没找到返回参数 root,如果没传参,默认是 window
return root;
}

export function useScrollParent(
el: Ref<Element | undefined>,
root: ScrollElement | undefined = defaultRoot
) {
const scrollParent = ref<Element | Window>();

onMounted(() => {
if (el.value) {
scrollParent.value = getScrollParent(el.value, root);
}
});

return scrollParent;
}

6.3 useRect 获取元素的大小及其相对于视口的位置

vant-contrib.gitee.io/vant/#/zh-C…

获取元素的大小及其相对于视口的位置,等价于 Element.getBoundingClientRect。

getBoundingClientRect

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
js复制代码// vant/packages/vant-use/src/useRect/index.ts
import { Ref, unref } from 'vue';

const isWindow = (val: unknown): val is Window => val === window;

const makeDOMRect = (width: number, height: number) =>
({
top: 0,
left: 0,
right: width,
bottom: height,
width,
height,
} as DOMRect);

export const useRect = (
elementOrRef: Element | Window | Ref<Element | Window | undefined>
) => {
// unref():如果参数是 ref,则返回内部值,否则返回参数本身。这是 val = isRef(val) ? val.value : val 计算的一个语法糖。
const element = unref(elementOrRef);

// 如果是 window 直接返回 innerWidth 和 innerHeight
if (isWindow(element)) {
const width = element.innerWidth;
const height = element.innerHeight;
return makeDOMRect(width, height);
}

// 否则用 getBoundingClientRect api
if (element?.getBoundingClientRect) {
return element.getBoundingClientRect();
}

// 不支持的情况下返回 0 0
return makeDOMRect(0, 0);
};

6.4 isHidden 是否隐藏

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
js复制代码// vant/packages/vant/src/utils/dom.ts
export function isHidden(
elementRef: HTMLElement | Ref<HTMLElement | undefined>
) {
const el = unref(elementRef);
if (!el) {
return false;
}

const style = window.getComputedStyle(el);
const hidden = style.display === 'none';

// offsetParent returns null in the following situations:
// 1. The element or its parent element has the display property set to none.
// 2. The element has the position property set to fixed
const parentHidden = el.offsetParent === null && style.position !== 'fixed';

return hidden || parentHidden;
}

接着我们来分析开头的插槽部分。

  1. 插槽

插槽部分基本都是有插槽用插槽没有则用默认的。

插槽是函数,比如 slots.default()。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
js复制代码// setup 函数
return () => {
const Content = slots.default?.();
const Placeholder = <div ref={placeholder} class={bem('placeholder')} />;

return (
<div ref={root} role="feed" class={bem()} aria-busy={loading.value}>
{props.direction === 'down' ? Content : Placeholder}
// 比如:加载中
{renderLoading()}
// 结束文字 比如:没有更多了
{renderFinishedText()}
// 加载错误文字:比如加载失败
{renderErrorText()}
{props.direction === 'up' ? Content : Placeholder}
</div>
);
};

7.1 renderFinishedText 渲染加载完成文字

1
2
3
4
5
6
7
8
js复制代码const renderFinishedText = () => {
if (props.finished) {
const text = slots.finished ? slots.finished() : props.finishedText;
if (text) {
return <div class={bem('finished-text')}>{text}</div>;
}
}
};

7.2 renderErrorText 渲染加载失败文字

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
js复制代码const clickErrorText = () => {
emit('update:error', false);
check();
};

const renderErrorText = () => {
if (props.error) {
const text = slots.error ? slots.error() : props.errorText;
if (text) {
return (
<div
role="button"
class={bem('error-text')}
tabindex={0}
onClick={clickErrorText}
>
{text}
</div>
);
}
}
};

7.3 renderLoading 渲染 loading

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
js复制代码const renderLoading = () => {
if (loading.value && !props.finished) {
return (
<div class={bem('loading')}>
{slots.loading ? (
slots.loading()
) : (
<Loading class={bem('loading-icon')}>
{props.loadingText || t('loading')}
</Loading>
)}
</div>
);
}
};
  1. 总结

我们主要分析了 List 组件 实现原理。

原理:使用 addEventListener 监听父级元素的 sroll 事件,用 Element.getBoundingClientRect 获取元素的大小及其相对于视口的位置,(滚动父级元素和占位元素计算和组件属性 offset(默认300) 属性比较),检测是否触底,触底则加载更多。

1
2
js复制代码emit('update:loading', true);
emit('load');

同时分析了一些相关组合式 API

  • useExpose 暴露接口供 this.$refs.xxx 使用
  • useEventListener 绑定事件
  • useTabStatus 当前 tab 是否激活的状态
  • useScrollParent 获取元素最近的可滚动父元素
  • useRect 获取元素的大小及其相对于视口的位置

组件留有四个插槽,分别是:

  • default 列表内容
  • loading 自定义底部加载中提示
  • finished 自定义加载完成后的提示文案
  • error 自定义加载失败后的提示文案

至此,我们就分析完了 List 组件,主要与 DOM 操作会比较多。List 组件 主文件的代码仅有 100 多行,但封装了很多组合式 API 。看完这篇源码文章,再去看 List 组件文档,可能就会有豁然开朗的感觉。再看其他组件,可能就可以猜测出大概实现的代码了。

如果是使用 react、Taro 技术栈,感兴趣也可以看看 taroify List 组件的实现 文档,源码。

如果看完有收获,欢迎点赞、评论、分享支持。你的支持和肯定,是我写作的动力。

  1. 加源码共读群交流

最后可以持续关注我@若川。我会写一个组件库源码系列专栏,欢迎大家关注。

我倾力持续组织了一年每周大家一起学习200行左右的源码共读活动,感兴趣的可以点此扫码加我微信 ruochuan02 参与。

另外,想学源码,极力推荐关注我写的专栏《学习源码整体架构系列》,目前是掘金关注人数(4.1k+人)第一的专栏,写有20余篇源码文章。

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

Java & Android 集合框架

发表于 2022-11-12

⭐️ 本文已收录到 AndroidFamily,技术和职场问题,请关注公众号 [彭旭锐] 和 [BaguTree Pro] 知识星球提问。

学习数据结构与算法的关键在于掌握问题背后的算法思维框架,你的思考越抽象,它能覆盖的问题域就越广,理解难度也更复杂。在实际的业务开发中,往往不需要我们手写数据结构,而是直接使用标准库的数据结构 / 容器类。

本文是 Java & Android 集合框架系列的第 8 篇文章,完整文章目录请移步到文章末尾~

前言

大家好,我是小彭。

在之前的文章里,我们聊到了 Java 标准库中 HashMap 与 LinkedHashMap 的实现原理。HashMap 是一个标准的散列表数据结构,而 LinkedHashMap 是在 HashMap 的基础上实现的哈希链表。

今天,我们来讨论 WeakHashMap,其中的 “Weak” 是指什么,与前两者的使用场景有何不同?我们就围绕这些问题展开。

提示: 本文源码基于 JDK 1.2 WeakHashMap。


思维导图:


  1. 回顾 HashMap 和 LinkedHashMap

其实,WeakHashMap 与 HashMap 和 LinkedHashMap 的数据结构大同小异,所以我们先回顾后者的实现原理。

1.1 说一下 HashMap 的实现结构

HashMap 是基于分离链表法解决散列冲突的动态散列表。

  • 1、HashMap 在 Java 7 中使用的是 “数组 + 链表”,发生散列冲突的键值对会用头插法添加到单链表中;
  • 2、HashMap 在 Java 8 中使用的是 “数组 + 链表 + 红黑树”,发生散列冲突的键值对会用尾插法添加到单链表中。如果链表的长度大于 8 时且散列表容量大于 64,会将链表树化为红黑树。

HashMap 实现示意图

1.2 说一下 LinkedHashMap 的实现结构

LinkedHashMap 是继承于 HashMap 实现的哈希链表。

  • 1、LinkedHashMap 同时具备双向链表和散列表的特点。当 LinkedHashMap 作为散列表时,主要体现出 O(1) 时间复杂度的查询效率。当 LinkedHashMap 作为双向链表时,主要体现出有序的特性;
  • 2、LinkedHashMap 支持 FIFO 和 LRU 两种排序模式,默认是 FIFO 排序模式,即按照插入顺序排序。Android 中的 LruCache 内存缓存和 DiskLruCache 磁盘缓存也是直接复用 LinkedHashMap 提供的缓存管理能力。

LinkedHashMap 示意图


  1. 认识 WeakHashMap

2.1 WeakReference 弱引用的特点

WeakHashMap 中的 “Weak” 指键 Key 是弱引用,也叫弱键。弱引用是 Java 四大引用类型之一,一共有四种引用类型,分别是强引用、软引用、弱引用和虚引用。我将它们的区别概括为 3 个维度:

  • 维度 1 - 对象可达性状态的区别: 强引用指向的对象是强可达的,只有强可达的对象才会认为是存活的对象,才能保证在垃圾收集的过程中不会被回收;
  • 维度 2 - 垃圾回收策略的区别: 不同的引用类型的回收激进程度不同,
    • 强引用指向的对象不会被回收;
    • 软引用指向的对象在内存充足时不会被回收,在内存不足时会被回收;
    • 弱引用和虚引用指向的对象无论在内存是否充足的时候都会被回收;
  • 维度 3 - 感知垃圾回收时机: 当引用对象关联的实际对象被垃圾回收时,引用对象会进入关联的引用队列,程序可以通过观察引用队列的方式,感知对象被垃圾回收的时机。

感知垃圾回收示意图

提示: 关于 “Java 四种引用类型” 的区别,在小彭的 Java 专栏中深入讨论过 《吊打面试官:说一下 Java 的四种引用类型》,去看看。

2.2 WeakHashMap 的特点

WeakHashMap 是使用弱键的动态散列表,用于实现 “自动清理” 的内存缓存。

  • 1、WeakHashMap 使用与 Java 7 HashMap 相同的 “数组 + 链表” 解决散列冲突,发生散列冲突的键值对会用头插法添加到单链表中;
  • 2、WeakHashMap 依赖于 Java 垃圾收集器自动清理不可达对象的特性。当 Key 对象不再被持有强引用时,垃圾收集器会按照弱引用策略自动回收 Key 对象,并在下次访问 WeakHashMap 时清理全部无效的键值对。因此,WeakHashMap 特别适合实现 “自动清理” 的内存活动缓存,当键值对有效时保留,在键值对无效时自动被垃圾收集器清理;
  • 3、需要注意,因为 WeakHashMap 会持有 Value 对象的强引用,所以在 Value 对象中一定不能持有 key 的强引用。否则,会阻止垃圾收集器回收 “本该不可达” 的 Key 对象,使得 WeakHashMap 失去作用。
  • 4、与 HashMap 相同,LinkedHashMap 也不考虑线程同步,也会存在线程安全问题。可以使用 Collections.synchronizedMap 包装类,其原理也是在所有方法上增加 synchronized 关键字。

WeakHashMap 示意图

自动清理数据

2.3 说一下 WeakHashMap 与 HashMap 和 LinkedHashMap 的区别?

WeakHashMap 与 HashMap 都是基于分离链表法解决散列冲突的动态散列表,两者的主要区别在 键 Key 的引用类型上:

  • HashMap 会持有键 Key 的强引用,除非手动移除,否则键值对会长期存在于散列表中;
  • WeakHashMap 只持有键 Key 的弱引用,当 Key 对象不再被外部持有强引用时,键值对会被自动被清理。

WeakHashMap 与 LinkedHashMap 都有自动清理的能力,两者的主要区别在于 淘汰数据的策略上:

  • LinkedHashMap 会按照 FIFO 或 LRU 的策略 “尝试” 淘汰数据,需要开发者重写 removeEldestEntry() 方法实现是否删除最早节点的判断逻辑;
  • WeakHashMap 会按照 Key 对象的可达性淘汰数据,当 Key 对象不再被持有强引用时,会自动清理无效数据。

2.4 重建 Key 对象不等价的问题

WeakHashMap 的 Key 使用弱引用,也就是以 Key 作为清理数据的判断锚点,当 Key 变得不可达时会自动清理数据。此时,如果使用多个 equals 相等的 Key 对象访问键值对,就会出现第 1 个 Key 对象不可达导致键值对被回收,而第 2 个 Key 查询键值对为 null 的问题。 这说明 equals 相等的 Key 对象在 HashMap 等散列表中是等价的,但是在 WeakHashMap 散列表中是不等价的。

因此,如果 Key 类型没有重写 equals 方法,那么 WeakHashMap 就表现良好,否则会存在歧义。例如下面这个 Demo 中,首先创建了指向 image_url1 的图片 Key1,再重建了同样指向 image_url1 的图片 Key2。在 HashMap 中,Key1 和 Key2 等价,但在 WeakHashMap 中,Key1 和 Key2 不等价。

Demo

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
java复制代码class ImageKey {
private String url;

ImageKey(String url) {
this.url = url;
}

public boolean equals(Object obj) {
return (obj instanceOf ImageKey) && Objects.equals(((ImageKey)obj).url, this.url);
}
}

WeakHashMap<ImageKey, Bitmap> map = new WeakHashMap<>();
ImageKey key1 = new ImageKey("image_url1");
ImageKey key2 = new ImageKey("image_url2");
// key1 equalsTo key3
ImageKey key3 = new ImageKey("image_url1");

map.put(key1, bitmap1);
map.put(key2, bitmap2);

System.out.println(map.get(key1)); // 输出 bitmap1
System.out.println(map.get(key2)); // 输出 bitmap2
System.out.println(map.get(key3)); // 输出 bitmap1

// 使 key1 不可达,key3 保持
key1 = null;

// 说明重建 Key 与原始 Key 不等价
System.out.println(map.get(key1)); // 输出 null
System.out.println(map.get(key2)); // 输出 bitmap2
System.out.println(map.get(key3)); // 输出 null

默认的 Object#equals 是判断两个变量是否指向同一个对象:

Object.java

1
2
3
java复制代码public boolean equals(Object obj) {
return (this == obj);
}

2.5 Key 弱引用和 Value 弱引用的区别

不管是 Key 还是 Value 使用弱引用都可以实现自动清理,至于使用哪一种方法各有优缺点,适用场景也不同。

  • Key 弱引用: 以 Key 作为清理数据的判断锚点,当 Key 不可达时清理数据。优点是容器外不需要持有 Value 的强引用,缺点是重建的 Key 与原始 Key 不等价,重建 Key 无法阻止数据被清理;
  • Value 弱引用: 以 Value 作为清理数据的判断锚点,当 Value 不可达时清理数据。优点是重建 Key 与与原始 Key 等价,缺点是容器外需要持有 Value 的强引用。
类型 优点 缺点 场景
Key 弱引用 外部不需要持有 Value 的强引用,使用更简单 重建 Key 不等价 未重写 equals
Value 弱引用 重建 Key 等价 外部需要持有 Value 的强引用 重写 equals

举例 1: 在 Android Glide 图片框架的多级缓存中,因为图片的 EngineKey 是可重建的,存在多个 EngineKey 对象指向同一个图片 Bitmap,所以 Glide 最顶层的活动缓存采用的是 Value 弱引用。

EngineKey.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
java复制代码class EngineKey implements Key {

// 重写 equals
@Override
public boolean equals(Object o) {
if (o instanceof EngineKey) {
EngineKey other = (EngineKey) o;
return model.equals(other.model)
&& signature.equals(other.signature)
&& height == other.height
&& width == other.width
&& transformations.equals(other.transformations)
&& resourceClass.equals(other.resourceClass)
&& transcodeClass.equals(other.transcodeClass)
&& options.equals(other.options);
}
return false;
}
}

举例 2: 在 ThreadLocal 的 ThreadLocalMap 线程本地存储中,因为 ThreadLocal 没有重写 equals,不存在多个 ThreadLocal 对象指向同一个键值对的情况,所以 ThreadLocal 采用的是 Key 弱引用。

ThreadLocal.java

1
2
3
4
5
6
7
8
9
10
11
java复制代码static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}

// 未重写 equals
}

  1. WeakHashMap 源码分析

这一节,我们来分析 WeakHashMap 中主要流程的源码。

事实上,WeakHashMap 就是照着 Java 7 版本的 HashMap 依葫芦画瓢的,没有树化的逻辑。考虑到我们已经对 HashMap 做过详细分析,所以我们没有必要重复分析 WeakHashMap 的每个细节,而是把重心放在 WeakHashMap 与 HashMap 不同的地方。

3.1 WeakHashMap 的属性

先用一个表格整理 WeakHashMap 的属性:

版本 数据结构 节点实现类 属性
Java 7 HashMap 数组 + 链表 Entry(单链表) 1、table(数组)2、size(尺寸)3、threshold(扩容阈值)4、loadFactor(装载因子上限)5、modCount(修改计数)6、默认数组容量 167、最大数组容量 2^308、默认负载因子 0.75
WeakHashMap 数组 + 链表 Entry(单链表,弱引用的子类型) 9、queue(引用队列)

WeakHashMap.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
java复制代码public class WeakHashMap<K,V> extends AbstractMap<K,V> implements Map<K,V> {

// 默认数组容量
private static final int DEFAULT_INITIAL_CAPACITY = 16;

// 数组最大容量:2^30(高位 0100,低位都是 0)
private static final int MAXIMUM_CAPACITY = 1 << 30;

// 默认装载因子上限:0.75
private static final float DEFAULT_LOAD_FACTOR = 0.75f;

// 底层数组
Entry<K,V>[] table;

// 键值对数量
private int size;

// 扩容阈值(容量 * 装载因子)
private int threshold;

// 装载因子上限
private final float loadFactor;

// 引用队列
private final ReferenceQueue<Object> queue = new ReferenceQueue<>();

// 修改计数
int modCount;

// 链表节点(一个 Entry 等于一个键值对)
private static class Entry<K,V> extends WeakReference<Object> implements Map.Entry<K,V> {
// Key:与 HashMap 和 LinkedHashMap 相比,少了 key 的强引用
// final K key;
// Value(强引用)
V value;
// 哈希值
final int hash;
Entry<K,V> next;

Entry(Object key, V value, ReferenceQueue<Object> queue, int hash, Entry<K,V> next) {
super(key /*注意:只有 Key 是弱引用*/, queue);
this.value = value;
this.hash = hash;
this.next = next;
}
}
}

WeakHashMap 与 HashMap 的属性几乎相同,主要区别有 2 个:

  • 1、ReferenceQueue: WeakHashMap 的属性里多了一个 queue 引用队列;
  • 2、Entry: WeakHashMap#Entry 节点继承于 WeakReference,表面看是 WeakHashMap 持有了 Entry 的强引用,其实不是。注意看 Entry 的构造方法,WeakReference 关联的实际对象是 Key。 所以,WeakHashMap 依然持有 Entry 和 Value 的强引用,仅持有 Key 的弱引用。

引用关系示意图

不出意外的话又有小朋友出来举手提问了🙋🏻‍♀️:

  • 🙋🏻‍♀️疑问 1:说一下 ReferenceQueue queue 的作用?

ReferenceQueue 与 Reference 配合能够实现感知对象被垃圾回收的能力。在创建引用对象时可以关联一个实际对象和一个引用队列,当实现对象被垃圾回收后,引用对象会被添加到这个引用队列中。在 WeakHashMap 中,就是根据这个引用队列来自动清理无效键值对。

  • 🙋🏻‍♀️疑问 2:为什么 Key 是弱引用,而不是 Entry 或 Value 是弱引用?

首先,Entry 一定要持有强引用,而不能持有弱引用。这是因为 Entry 是 WeakHashMap 内部维护数据结构的实现细节,并不会暴露到 WeakHashMap 外部,即除了 WeakHashMap 本身之外没有其它地方持有 Entry 的强引用。所以,如果持有 Entry 的弱引用,即使 WeakHashMap 外部依然在使用 Key 对象,WeakHashMap 内部依然会回收键值对,这与预期不符。

其次,不管是 Key 还是 Value 使用弱引用都可以实现自动清理。至于使用哪一种方法各有优缺点,适用场景也不同,这个在前文分析过了。

3.2 WeakHashMap 如何清理无效数据?

在通过 put / get /size 等方法访问 WeakHashMap 时,其内部会调用 expungeStaleEntries() 方法清理 Key 对象已经被回收的无效键值对。其中会遍历 ReferenceQueue 中持有的弱引用对象(即 Entry 节点),并将该结点从散列表中移除。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
java复制代码private final ReferenceQueue<Object> queue = new ReferenceQueue<>();

// 添加键值对
public V put(K key, V value) {
...
// 间接 expungeStaleEntries()
Entry<K,V>[] tab = getTable();
...
}

// 扩容
void resize(int newCapacity) {
// 间接 expungeStaleEntries()
Entry<K,V>[] oldTable = getTable();
...
}

// 获取键值对
public V get(Object key) {
...
// 间接 expungeStaleEntries()
Entry<K,V>[] tab = getTable();
...
}

private Entry<K,V>[] getTable() {
// 清理无效键值对
expungeStaleEntries();
return table;
}

// ->清理无效键值对
private void expungeStaleEntries() {
// 遍历引用队列
for (Object x; (x = queue.poll()) != null; ) {
// 疑问 3:既然 WeakHashMap 不考虑线程同步,为什么这里要做加锁,岂不是突兀?
synchronized (queue) {
Entry<K,V> e = (Entry<K,V>) x;
// 根据散列值定位数组下标
int i = indexFor(e.hash /*散列值*/, table.length);
// 遍历桶寻找节点 e 的前驱结点
Entry<K,V> prev = table[i];
Entry<K,V> p = prev;
while (p != null) {
Entry<K,V> next = p.next;
if (p == e) {
// 删除节点 e
if (prev == e)
// 节点 e 是根节点
table[i] = next;
else
// 节点 e 是中间节点
prev.next = next;
// Must not null out e.next;
// stale entries may be in use by a HashIterator
e.value = null; // Help GC
size--;
break;
}
prev = p;
p = next;
}
}
}
}

  1. 总结

  • 1、WeakHashMap 使用与 Java 7 HashMap 相同的 “数组 + 链表” 解决散列冲突,发生散列冲突的键值对会用头插法添加到单链表中;
  • 2、WeakHashMap 能够实现 “自动清理” 的内存缓存,其中的 “Weak” 指键 Key 是弱引用。当 Key 对象不再被持有强引用时,垃圾收集器会按照弱引用策略自动回收 Key 对象,并在下次访问 WeakHashMap 时清理全部无效的键值对;
  • 3、WeakHashMap 和 LinkedHashMap 都具备 “自动清理” 的 能力,WeakHashMap 根据 Key 对象的可达性淘汰数据,而 LinkedHashMap 根据 FIFO 或 LRU 策略尝试淘汰数据;
  • 4、WeakHashMap 使用 Key 弱引用,会存在重建 Key 对象不等价问题。

版权声明

本文为稀土掘金技术社区首发签约文章,14天内禁止转载,14天后未获授权禁止转载,侵权必究!

推荐阅读

Java & Android 集合框架系列文章目录(2023/07/08 更新):

  • #1 ArrayList 可以完全替代数组吗?
  • #2 说一下 ArrayList 和 LinkedList 的区别?
  • #3 CopyOnWriteArrayList 是如何保证线程安全的?
  • #4 ArrayDeque:如何用数组实现栈和队列?
  • #5 万字 HashMap 详解,基础(优雅)永不过时 —— 原理篇
  • #6 万字 HashMap 详解,基础(优雅)永不过时 —— 源码篇
  • #7 如何使用 LinkedHashMap 实现 LRU 缓存?
  • #8 说一下 WeakHashMap 如何清理无效数据的?
  • #9 全网最全的 ThreadLocal 原理详细解析 —— 原理篇
  • #10 全网最全的 ThreadLocal 原理详细解析 —— 源码篇

数据结构与算法系列文章:跳转阅读

⭐️ 永远相信美好的事情即将发生,欢迎加入小彭的 Android 交流社群~

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

Java & Android 集合框架

发表于 2022-11-10

⭐️ 本文已收录到 AndroidFamily,技术和职场问题,请关注公众号 [彭旭锐] 和 [BaguTree Pro] 知识星球提问。

学习数据结构与算法的关键在于掌握问题背后的算法思维框架,你的思考越抽象,它能覆盖的问题域就越广,理解难度也更复杂。在实际的业务开发中,往往不需要我们手写数据结构,而是直接使用标准库的数据结构 / 容器类。

本文是 Java & Android 集合框架系列的第 7 篇文章,完整文章目录请移步到文章末尾~

前言

大家好,我是小彭。

在上一篇文章里,我们聊到了 HashMap 的实现原理和源码分析,在源码分析的过程中,我们发现一些 LinkedHashMap 相关的源码,当时没有展开,现在它来了。

那么,LinkedHashMap 与 HashMap 有什么区别呢?其实,LinkedHashMap 的使用场景非常明确 —— LRU 缓存。今天,我们就来讨论 LinkedHashMap 是如何实现 LRU 缓存的。

本文源码基于 Java 8 LinkedHashMap。


思维导图:


  1. 认识 LRU 缓存淘汰算法

1.1 什么是缓存淘汰算法?

缓存是提高数据读取性能的通用技术,在硬件和软件设计中被广泛使用,例如 CPU 缓存、Glide 内存缓存,数据库缓存等。由于缓存空间不可能无限大,当缓存容量占满时,就需要利用某种策略将部分数据换出缓存,这就是缓存的替换策略 / 淘汰问题。常见缓存淘汰策略有:

  • 1、随机策略: 使用一个随机数生成器随机地选择要被淘汰的数据块;
  • 2、FIFO 先进先出策略: 记录各个数据块的访问时间,最早访问的数据最先被淘汰;
  • 3、LRU (Least Recently Used)最近最少策略: 记录各个数据块的访问 “时间戳” ,最近最久未使用的数据最先被淘汰。与前 2 种策略相比,LRU 策略平均缓存命中率更高,这是因为 LRU 策略利用了 “局部性原理”:最近被访问过的数据,将来被访问的几率较大,最近很久未访问的数据,将来访问的几率也较小;
  • 4、LFU (Least Frequently Used)最不经常使用策略: 与 LRU 相比,LFU 更加注重使用的 “频率” 。LFU 会记录每个数据块的访问次数,最少访问次数的数据最先被淘汰。但是有些数据在开始时使用次数很高,以后不再使用,这些数据就会长时间污染缓存。可以定期将计数器右移一位,形成指数衰减。

FIFO 与 LRU 策略

1.2 向外看:LRU 的变型

其实,在标准的 LRU 算法上还有一些变型实现,这是因为 LRU 算法本身也存在一些不足。例如,当数据中热点数据较多时,LRU 能够保证较高的命中率。但是当有偶发的批量的非热点数据产生时,就会将热点数据寄出缓存,使得缓存被污染。因此,LRU 也有一些变型:

  • LRU-K: 提供两个 LRU 队列,一个是访问计数队列,一个是标准的 LRU 队列,两个队列都按照 LRU 规则淘汰数据。当访问一个数据时,数据先进入访问计数队列,当数据访问次数超过 K 次后,才会进入标准 LRU 队列。标准的 LRU 算法相当于 LRU-1;
  • Two Queue: 相当于 LRU-2 的变型,将访问计数队列替换为 FIFO 队列淘汰数据数据。当访问一个数据时,数据先进入 FIFO 队列,当第 2 次访问数据时,才会进入标准 LRU 队列;
  • Multi Queue: 在 LRU-K 的基础上增加更多队列,提供多个级别的缓冲。

小彭在 Redis 和 Vue 中有看到这些 LRU 变型的应用,在 Android 领域的框架中还没有看到具体应用,你知道的话可以提醒我。

1.3 如何实现 LRU 缓存淘汰算法?

这一小节,我们尝试找到 LRU 缓存淘汰算法的实现方案。经过总结,我们可以定义一个缓存系统的基本操作:

  • 操作 1 - 添加数据: 先查询数据是否存在,不存在则添加数据,存在则更新数据,并尝试淘汰数据;
  • 操作 2 - 删除数据: 先查询数据是否存在,存在则删除数据;
  • 操作 3 - 查询数据: 如果数据不存在则返回 null;
  • 操作 4 - 淘汰数据: 添加数据时如果容量已满,则根据缓存淘汰策略一个数据。

我们发现,前 3 个操作都有 “查询” 操作, 所以缓存系统的性能主要取决于查找数据和淘汰数据是否高效。 下面,我们用递推的思路推导 LRU 缓存的实现方案,主要分为 3 种方案:

  • 方案 1 - 基于时间戳的数组: 在每个数据块中记录最近访问的时间戳,当数据被访问(添加、更新或查询)时,将数据的时间戳更新到当前时间。当数组空间已满时,则扫描数组淘汰时间戳最小的数据。
+ 查找数据: 需要遍历整个数组找到目标数据,时间复杂度为 O(n);
+ 淘汰数据: 需要遍历整个数组找到时间戳最小的数据,且在移除数组元素时需要搬运数据,整体时间复杂度为 O(n)。
  • 方案 2 - 基于双向链表: 不再直接维护时间戳,而是利用链表的顺序隐式维护时间戳的先后顺序。当数据被访问(添加、更新或查询)时,将数据插入到链表头部。当空间已满时,直接淘汰链表的尾节点。
+ 查询数据:需要遍历整个链表找到目标数据,时间复杂度为 O(n);
+ 淘汰数据:直接淘汰链表尾节点,时间复杂度为 O(1)。
  • 方案 3 - 基于双向链表 + 散列表: 使用双向链表可以将淘汰数据的时间复杂度降低为 O(1),但是查询数据的时间复杂度还是 O(n),我们可以在双向链表的基础上增加散列表,将查询操作的时间复杂度降低为 O(1)。
+ 查询数据:通过散列表定位数据,时间复杂度为 O(1);
+ 淘汰数据:直接淘汰链表尾节点,时间复杂度为 O(1)。

方案 3 这种数据结构就叫 “哈希链表或链式哈希表”,我更倾向于称为哈希链表,因为当这两个数据结构相结合时,我们更看重的是它作为链表的排序能力。

我们今天要讨论的 Java LinkedHashMap 就是基于哈希链表的数据结构。


  1. 认识 LinkedHashMap 哈希链表

2.1 说一下 LinkedHashMap 的特点

需要注意:LinkedHashMap 中的 “Linked” 实际上是指双向链表,并不是指解决散列冲突中的分离链表法。

  • 1、LinkedHashMap 是继承于 HashMap 实现的哈希链表,它同时具备双向链表和散列表的特点。事实上,LinkedHashMap 继承了 HashMap 的主要功能,并通过 HashMap 预留的 Hook 点维护双向链表的逻辑。
+ 1.1 当 LinkedHashMap 作为散列表时,主要体现出 O(1) 时间复杂度的查询效率;
+ 1.2 当 LinkedHashMap 作为双向链表时,主要体现出有序的特性。
  • 2、LinkedHashMap 支持 2 种排序模式,这是通过构造器参数 accessOrder 标记位控制的,表示是否按照访问顺序排序,默认为 false 按照插入顺序。
+ **2.1 插入顺序(默认):** 按照数据添加到 LinkedHashMap 的顺序排序,即 FIFO 策略;
+ **2.2 访问顺序:** 按照数据被访问(包括插入、更新、查询)的顺序排序,即 LRU 策略。
  • 3、在有序性的基础上,LinkedHashMap 提供了维护了淘汰数据能力,并开放了淘汰判断的接口 removeEldestEntry()。在每次添加数据时,会回调 removeEldestEntry() 接口,开发者可以重写这个接口决定是否移除最早的节点(在 FIFO 策略中是最早添加的节点,在 LRU 策略中是最早未访问的节点);
  • 4、与 HashMap 相同,LinkedHashMap 也不考虑线程同步,也会存在线程安全问题。可以使用 Collections.synchronizedMap 包装类,其原理也是在所有方法上增加 synchronized 关键字。

2.2 说一下 HashMap 和 LinkedHashMap 的区别?

事实上,HashMap 和 LinkedHashMap 并不是平行的关系,而是继承的关系,LinkedHashMap 是继承于 HashMap 实现的哈希链表。

两者主要的区别在于有序性: LinkedHashMap 会维护数据的插入顺序或访问顺序,而且封装了淘汰数据的能力。在迭代器遍历时,HashMap 会按照数组顺序遍历桶节点,从开发者的视角看是无序的。而是按照双向链表的顺序从 head 节点开始遍历,从开发者的视角是可以感知到的插入顺序或访问顺序。

LinkedHashMap 示意图


  1. HashMap 预留的 Hook 点

LinkedHashMap 继承于 HashMap,在后者的基础上通过双向链表维护节点的插入顺序或访问顺序。因此,我们先回顾下 HashMap 为 LinkedHashMap 预留的 Hook 点:

  • afterNodeAccess: 在节点被访问时回调;
  • afterNodeInsertion: 在节点被插入时回调,其中有参数 evict 标记是否淘汰最早的节点。在初始化、反序列化或克隆等构造过程中,evict 默认为 false,表示在构造过程中不淘汰。
  • afterNodeRemoval: 在节点被移除时回调。

HashMap.java

1
2
3
4
5
6
7
java复制代码// 节点访问回调
void afterNodeAccess(Node<K,V> p) { }
// 节点插入回调
// evict:是否淘汰最早的节点
void afterNodeInsertion(boolean evict) { }
// 节点移除回调
void afterNodeRemoval(Node<K,V> p) { }

除此了这 3 个空方法外,LinkedHashMap 也重写了部分 HashMap 的方法,在其中插入双链表的维护逻辑,也相当于 Hook 点。在 HashMap 的添加、获取、移除方法中,与 LinkedHashMap 有关的 Hook 点如下:

3.1 HashMap 的添加方法中的 Hook 点

LinkedHashMap 直接复用 HashMap 的添加方法,也支持批量添加:

  • HashMap#put: 逐个添加或更新键值对;
  • HashMap#putAll: 批量添加或更新键值对。

不管是逐个添加还是批量添加,最终都会先通过 hash 函数计算键(Key)的散列值,再通过 HashMap#putVal 添加或更新键值对,这些都是 HashMap 的行为。关键的地方在于:LinkedHashMap 在 HashMap#putVal 的 Hook 点中加入了双线链表的逻辑。区分 2 种情况:

  • 添加数据: 如果数据不存在散列表中,则调用 newNode() 或 newTreeNode() 创建节点,并回调 afterNodeInsertion();
  • 更新数据: 如果数据存在散列表中,则更新 Value,并回调 afterNodeAccess()。

HashMap.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
java复制代码// 添加或更新键值对
public V put(K key, V value) {
return putVal(hash(key) /*计算散列值*/, key, value, false, true);
}

// hash:Key 的散列值(经过扰动)
final V putVal(int hash, K key, V value, boolean onlyIfAbsent, boolean evict) {
Node<K,V>[] tab;
Node<K,V> p;
int n;
int i;
if ((tab = table) == null || (n = tab.length) == 0)
n = (tab = resize()).length;
// (n - 1) & hash:散列值转数组下标
if ((p = tab[i = (n - 1) & hash]) == null)
// 省略遍历桶的代码,具体分析见 HashMap 源码讲解

// 1.1 如果节点不存在,则新增节点
p.next = newNode(hash, key, value, null);
// 2.1 如果节点存在更新节点 Value
if (e != null) {
V oldValue = e.value;
if (!onlyIfAbsent || oldValue == null)
e.value = value;
// 2.2 Hook:访问节点回调
afterNodeAccess(e);
return oldValue;
}
}
++modCount;
// 扩容
if (++size > threshold)
resize();
// 1.2 Hook:新增节点回调
afterNodeInsertion(evict);
return null;
}

HashMap#put 示意图

3.2 HashMap 的获取方法中的 Hook 点

LinkedHashMap 重写了 HashMap#get 方法,在 HashMap 版本的基础上,增加了 afterNodeAccess() 回调。

HashMap.java

1
2
3
4
java复制代码public V get(Object key) {
Node<K,V> e;
return (e = getNode(hash(key), key)) == null ? null : e.value;
}

LinkedHashMap.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
java复制代码public V get(Object key) {
Node<K,V> e;
if ((e = getNode(hash(key), key)) == null)
return null;
// Hook:节点访问回调
if (accessOrder)
afterNodeAccess(e);
return e.value;
}

public V getOrDefault(Object key, V defaultValue) {
Node<K,V> e;
if ((e = getNode(hash(key), key)) == null)
return defaultValue;
// Hook:节点访问回调
if (accessOrder)
afterNodeAccess(e);
return e.value;
}

HashMap#get 示意图

3.3 HashMap 的移除方法中的 Hook 点

LinkedHashMap 直接复用 HashMap 的移除方法,在移除节点后,增加 afterNodeRemoval() 回调。

HashMap.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
java复制代码// 移除节点
public V remove(Object key) {
Node<K,V> e;
return (e = removeNode(hash(key)/*计算散列值*/, key, null, false, true)) == null ? null : e.value;
}

final Node<K,V> removeNode(int hash, Object key, Object value,
boolean matchValue, boolean movable) {
Node<K,V>[] tab;
Node<K,V> p;
int n, index;
// (n - 1) & hash:散列值转数组下标
if ((tab = table) != null && (n = tab.length) > 0 && (p = tab[index = (n - 1) & hash]) != null) {
Node<K,V> node = null, e; K k; V v;
// 省略遍历桶的代码,具体分析见 HashMap 源码讲解
// 删除 node 节点
if (node != null && (!matchValue || (v = node.value) == value || (value != null && value.equals(v)))) {
// 省略删除节点的代码,具体分析见 HashMap 源码讲解
++modCount;
--size;
// Hook:删除节点回调
afterNodeRemoval(node);
return node;
}
}
return null;
}

HashMap#remove 示意图


  1. LinkedHashMap 源码分析

这一节,我们来分析 LinkedHashMap 中主要流程的源码。

4.1 LinkedHashMap 的属性

  • LinkedHashMap 继承于 HashMap,并且新增 head 和 tail 指针指向链表的头尾节点(与 LinkedList 类似的头尾节点);
  • LinkedHashMap 的双链表节点 Entry 继承于 HashMap 的单链表节点 Node,而 HashMap 的红黑树节点 TreeNode 继承于 LinkedHashMap 的双链表节点 Entry。

节点继承关系

LinkedHashMap.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
java复制代码public class LinkedHashMap<K,V> extends HashMap<K,V> implements Map<K,V> {
// 头指针
transient LinkedHashMap.Entry<K,V> head;
// 尾指针
transient LinkedHashMap.Entry<K,V> tail;
// 是否按照访问顺序排序
final boolean accessOrder;

// 双向链表节点
static class Entry<K,V> extends HashMap.Node<K,V> {
// 前驱指针和后继指针(用于双向链表)
Entry<K,V> before, after;
Entry(int hash, K key, V value, Node<K,V> next/*单链表指针(用于散列表的冲突解决)*/) {
super(hash, key, value, next);
}
}
}

LinkedList.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
java复制代码public class LinkedList<E> extends AbstractSequentialList<E> implements List<E>, Deque<E>, Cloneable, java.io.Serializable {
// 头指针(// LinkedList 中也有类似的头尾节点)
transient Node<E> first;
// 尾指针
transient Node<E> last;

// 双向链表节点
private static class Node<E> {
// 节点数据
// (类型擦除后:Object item;)
E item;
// 前驱指针
Node<E> next;
// 后继指针
Node<E> prev;

Node(Node<E> prev, E element, Node<E> next) {
this.item = element;
this.next = next;
this.prev = prev;
}
}
}

LinkedHashMap 的属性很好理解的,不出意外的话又有小朋友出来举手提问了:

  • 🙋🏻‍♀️疑问 1:HashMap.TreeNode 和 LinkedHashMap.Entry 的继承顺序是不是反了?

我的理解是作者希望简化节点类型,所以采用了非常规的做法(不愧是标准库)。由于 Java 是单继承的,如果按照常规的做法让 HashMap.TreeNode 直接继承 HashMap.Node,那么在 LinkedHashMap 中就需要区分 LinkedHashMap.Entry 和 LinkedHashMap.TreeEntry,再使用接口统一两种类型。

常规实现

4.2 LinkedHashMap 的构造方法

LinkedHashMap 有 5 个构造方法,作用与 HashMap 的构造方法基本一致,区别只在于对 accessOrder 字段的初始化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
java复制代码// 带初始容量和装载因子的构造方法
public LinkedHashMap(int initialCapacity, float loadFactor) {
super(initialCapacity, loadFactor);
accessOrder = false;
}

// 带初始容量的构造方法
public LinkedHashMap(int initialCapacity) {
super(initialCapacity);
accessOrder = false;
}

// 无参构造方法
public LinkedHashMap() {
super();
accessOrder = false;
}

// 带 Map 的构造方法
public LinkedHashMap(Map<? extends K, ? extends V> m) {
super();
accessOrder = false;
putMapEntries(m, false);
}

// 带初始容量、装载因子和 accessOrder 的构造方法
// 是否按照访问顺序排序,为 true 表示按照访问顺序排序,默认为 false
public LinkedHashMap(int initialCapacity, float loadFactor, boolean accessOrder) {
super(initialCapacity, loadFactor);
this.accessOrder = accessOrder;
}

4.3 LinkedHashMap 如何维护双链表

现在,我们看下 LinkedHashMap 是如何维护双链表的。其实,我们将上一节所有的 Hook 点汇总,会发现这些 Hook 点正好组成了 LinkedHashMap 双向链表的行为:

  • 添加数据: 将数据链接到双向链表的尾节点,时间复杂度为 O(1);
  • 访问数据(包括添加、查询、更新): 将数据移动到双向链表的尾节点,亦相当于先移除再添加到尾节点,时间复杂度为 O(1);
  • 删除数据: 将数据从双向链表中移除,时间复杂度为 O(1);
  • 淘汰数据: 直接淘汰双向链表的头节点,时间复杂度为 O(1)。

LinkedHashMap.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
java复制代码// -> 1.1 如果节点不存在,则新增节点
Node<K,V> newNode(int hash, K key, V value, Node<K,V> e) {
// 新建双向链表节点
LinkedHashMap.Entry<K,V> p = new LinkedHashMap.Entry<K,V>(hash, key, value, e);
// 添加到双向链表尾部,等价于 LinkedList#linkLast
linkNodeLast(p);
return p;
}

// -> 1.1 如果节点不存在,则新增节点
TreeNode<K,V> newTreeNode(int hash, K key, V value, Node<K,V> next) {
// 新建红黑树节点(继承于双向链表节点)
TreeNode<K,V> p = new TreeNode<K,V>(hash, key, value, next);
// 添加到双向链表尾部,等价于 LinkedList#linkLast
linkNodeLast(p);
return p;
}

// 添加到双向链表尾部,等价于 LinkedList#linkLast
private void linkNodeLast(LinkedHashMap.Entry<K,V> p) {
LinkedHashMap.Entry<K,V> last = tail;
tail = p;
if (last == null)
// last 为 null 说明首个添加的元素,需要修改 first 指针
head = p;
else {
// 将新节点的前驱指针指向 last
p.before = last;
// 将 last 的 next 指针指向新节点
last.after = p;
}
}

// 节点插入回调
// evict:是否淘汰最早的节点
void afterNodeInsertion(boolean evict) { // possibly remove eldest
LinkedHashMap.Entry<K,V> first;
// removeEldestEntry:是否淘汰最早的节点,即是否淘汰头节点(由子类实现)
if (evict && (first = head) != null && removeEldestEntry(first)) {
// 移除 first 节点,腾出缓存空间
K key = first.key;
removeNode(hash(key), key, null, false, true);
}
}

// 移除节点回调
void afterNodeRemoval(Node<K,V> e) { // unlink
// 实现了标准的双链表移除
LinkedHashMap.Entry<K,V> p = (LinkedHashMap.Entry<K,V>)e, b = p.before, a = p.after;
p.before = p.after = null;
if (b == null)
// 删除的是头节点,则修正 head 指针
head = a;
else
// 修正前驱节点的后继指针,指向被删除节点的后继节点
b.after = a;
if (a == null)
// 删除的是尾节点,则修正 tail 指针
tail = b;
else
// 修正后继节点的前驱指针,指向被删除节点的前驱节点
a.before = b;
}

// 节点访问回调
void afterNodeAccess(Node<K,V> e) { // move node to last
// 先将节点 e 移除,再添加到链表尾部
LinkedHashMap.Entry<K,V> last;
// accessOrder:是否按照访问顺序排序,为 false 则保留插入顺序
if (accessOrder && (last = tail) != e) {
// 这两个 if 语句块就是 afterNodeRemoval 的逻辑
LinkedHashMap.Entry<K,V> p = (LinkedHashMap.Entry<K,V>)e, b = p.before, a = p.after;
p.after = null;
if (b == null)
head = a;
else
b.after = a;
if (a != null)
a.before = b;
else
last = b;
// 这个 if 语句块就是 linkNodeLast 的逻辑
if (last == null)
head = p;
else {
p.before = last;
last.after = p;
}
tail = p;
++modCount;
}
}

// 淘汰判断接口,由子类实现
protected boolean removeEldestEntry(Map.Entry<K,V> eldest) {
return false;
}

4.4 LinkedHashMap 的迭代器

与 HashMap 类似,LinkedHashMap 也提供了 3 个迭代器:

  • LinkedEntryIterator: 键值对迭代器
  • LinkedKeyIterator: 键迭代器
  • LinkedValueIterator: 值迭代器

区别在于 LinkedHashMap 自己实现了 LinkedHashIterator。在迭代器遍历时,HashMap 会按照数组顺序遍历桶节点,从开发者的视角看是无序的。而 LinkedHashMap 是按照双向链表的顺序从 head 节点开始遍历,从开发者的视角是可以感知到的插入顺序或访问顺序。

LinkedHashMap.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
java复制代码abstract class LinkedHashIterator {
LinkedHashMap.Entry<K,V> next;
LinkedHashMap.Entry<K,V> current;
// 修改计数
int expectedModCount;

LinkedHashIterator() {
// 从头结点开始遍历
next = head;
// 修改计数
expectedModCount = modCount;
current = null;
}

public final boolean hasNext() {
return next != null;
}

final LinkedHashMap.Entry<K,V> nextNode() {
LinkedHashMap.Entry<K,V> e = next;
// 检查修改计数
if (modCount != expectedModCount)
throw new ConcurrentModificationException();
if (e == null)
throw new NoSuchElementException();
current = e;
next = e.after;
return e;
}
...
}

4.5 LinkedHashMap 的序列化过程

与 HashMap 相同,LinkedHashMap 也重写了 JDK 序列化的逻辑,并保留了 HashMap 中序列化的主体结构。LinkedHashMap 只是重写了 internalWriteEntries(),按照双向链表的顺序进行序列化,这样在反序列化时就能够恢复双向链表顺序。

HashMap.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
java复制代码// 序列化过程
private void writeObject(java.io.ObjectOutputStream s) throws IOException {
int buckets = capacity();
s.defaultWriteObject();
// 写入容量
s.writeInt(buckets);
// 写入有效元素个数
s.writeInt(size);
// 写入有效元素
internalWriteEntries(s);
}

// 不关心键值对所在的桶,在反序列化会重新映射
void internalWriteEntries(java.io.ObjectOutputStream s) throws IOException {
Node<K,V>[] tab;
if (size > 0 && (tab = table) != null) {
for (int i = 0; i < tab.length; ++i) {
for (Node<K,V> e = tab[i]; e != null; e = e.next) {
s.writeObject(e.key);
s.writeObject(e.value);
}
}
}
}

LinkedHashMap.java

1
2
3
4
5
6
7
java复制代码// 重写:按照双向链表顺序写入
void internalWriteEntries(java.io.ObjectOutputStream s) throws IOException {
for (LinkedHashMap.Entry<K,V> e = head; e != null; e = e.after) {
s.writeObject(e.key);
s.writeObject(e.value);
}
}

  1. 基于 LinkedHashMap 实现 LRU 缓存

这一节,我们来实现一个简单的 LRU 缓存。理解了 LinkedHashMap 维护插入顺序和访问顺序的原理后,相信你已经知道如何实现 LRU 缓存了。

  • 首先,我们已经知道,LinkedHashMap 支持 2 种排序模式,这是通过构造器参数 accessOrder 标记位控制的。所以,这里我们需要将 accessOrder 设置为 true 表示使用 LRU 模式的访问顺序排序。
  • 其次,我们不需要实现淘汰数据的逻辑,只需要重写淘汰判断接口 removeEldestEntry(),当缓存数量大于缓存容量时返回 true,表示移除最早的节点。

MaxSizeLruCacheDemo.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
java复制代码public class MaxSizeLruCacheDemo extends LinkedHashMap {

private int maxElements;

public LRUCache(int maxSize) {
super(maxSize, 0.75F, true);
maxElements = maxSize;
}

protected boolean removeEldestEntry(java.util.Map.Entry eldest) {
// 超出容量
return size() > maxElements;
}
}

  1. 总结

  • 1、LRU 是一种缓存淘汰算法,与其他淘汰算法相比,LRU 算法利用了 “局部性原理”,缓存的平均命中率更高;
  • 2、使用双向链表 + 散列表实现的 LRU,在添加、查询、移除和淘汰数据的时间复杂度都是 O(1),这种数据结构也叫哈希链表;
+ **查询数据:** 通过散列表定位数据,时间复杂度为 O(1);
+ **淘汰数据:** 直接淘汰链表尾节点,时间复杂度为 O(1)。
  • 3、使用 LinkedHashMap 时,主要关注 2 个 API:
+ **`accessOrder` 标记位:** LinkedHashMap 同时实现了 FIFO 和 LRU 两种淘汰策略,默认为 FIFO 排序,可以使用 accessOrder 标记位修改排序模式。
+ **`removeEldestEntry()` 接口:** 每次添加数据时,LinkedHashMap 会回调 removeEldestEntry() 接口。开发者可以重写 removeEldestEntry() 接口决定是否移除最早的节点(在 FIFO 策略中是最早添加的节点,在 LRU 策略中是最久未访问的节点)。
  • 4、Android 的 LruCache 内存缓存和 DiskLruCache 磁盘缓存中,都直接复用了 LinkedHashMap 的 LRU 能力。

今天,我们分析了 LinkedHashMap 的实现原理。在下篇文章里,我们来分析 LRU 的具体实现应用,例如 Android 标准库中的 LruCache 内存缓存。

可以思考一个问题,LinkedHashMap 是非线程安全的,Android 的 LruCache 是如何解决线程安全问题的?请关注 小彭说 · Android 开源组件 专栏。


版权声明

本文为稀土掘金技术社区首发签约文章,14天内禁止转载,14天后未获授权禁止转载,侵权必究!

参考资料

  • 数据结构与算法分析 · Java 语言描述(第 5 章 · 散列)—— [美] Mark Allen Weiss 著
  • 算法导论(第 11 章 · 散列表)—— [美] Thomas H. Cormen 等 著
  • 数据结构与算法之美(第 6、18~22 讲) —— 王争 著,极客时间 出品
  • LinkedHashMap 源码详细分析(JDK1.8)—— 田小波 著
  • LRU 算法及其优化策略——算法篇 —— 豆豉辣椒炒腊肉 著
  • 缓冲池(buffer pool),这次彻底懂了! —— 58 沈剑 著
  • LeetCode 146. LRU 缓存 —— LeetCode
  • Cache replacement policies —— Wikipedia

推荐阅读

Java & Android 集合框架系列文章目录(2023/07/08 更新):

  • #1 ArrayList 可以完全替代数组吗?
  • #2 说一下 ArrayList 和 LinkedList 的区别?
  • #3 CopyOnWriteArrayList 是如何保证线程安全的?
  • #4 ArrayDeque:如何用数组实现栈和队列?
  • #5 万字 HashMap 详解,基础(优雅)永不过时 —— 原理篇
  • #6 万字 HashMap 详解,基础(优雅)永不过时 —— 源码篇
  • #7 如何使用 LinkedHashMap 实现 LRU 缓存?
  • #8 说一下 WeakHashMap 如何清理无效数据的?
  • #9 全网最全的 ThreadLocal 原理详细解析 —— 原理篇
  • #10 全网最全的 ThreadLocal 原理详细解析 —— 源码篇

数据结构与算法系列文章:跳转阅读

⭐️ 永远相信美好的事情即将发生,欢迎加入小彭的 Android 交流社群~

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

【AI】浅谈梯度下降算法(拓展篇) 前言 学习率 批量梯度下

发表于 2022-11-10

本文正在参加「金石计划 . 瓜分6万现金大奖」

前言

前导博文:

  • 【AI】浅谈梯度下降算法(理论篇)
  • 【AI】浅谈梯度下降算法(实战篇)

通过前导博文的学习,想必大家对于梯度下降也有所掌握了,其中在 【AI】浅谈梯度下降算法(实战篇) 博文中有粗略的提到过梯度下降的三大家族,本博文将结合代码实现来细细讲解;

学习率

学习率,想必大家在前导博文中也见过该词,为什么要提到学习率呢,且听我慢慢分析;

首先要知道,梯度下降的中心思想就是迭代地调整参数从而使成本函数最小化。

具体来说,首先使用一个随机的 θ\thetaθ 值(这被称为随机初始化),然后逐步改进,每次踏出一步,每一步都尝试降低一点成本函数(如在线性回归中采用 MSE),直到算法收敛出一个最小值,如下图所示:

image.png

然而,这其中有一个十分重要的超参数,学习率 Learning Rate,它影响了每一步的步长;

如果学习率太低,算法需要经过大量迭代才能收敛,这将耗费很长时间:

image.png

反之,如果学习率太高,这会导致算法发散,值越来越大,最后无法找到好的解决方案:

image.png

而且,并不是所有的成本函数看起来都像一个漂亮的碗。有的可能看着像洞、像山脉、像高原或者是各种不规则的地形,导致很难收敛到最小值。

image.png

如果随机初始化,算法从左侧起步,那么会收敛到一个局部最小值,而不是全局最小值。如果从右侧起步,那么需要经过很长时间才能越过整片高原,如果停下来太早,将永远达不到全局最小值。

因此,对于学习率的设置显得尤为重要,下面来看看不同的学习率所带来的影响:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
py复制代码def plot_gradient_descent(theta, eta, theta_path=None):
m = len(X_b)
plt.plot(X, y, "b.")
for i in range(max_iters):
if i < 10:
y_predict = X_new_b.dot(theta)
style = "g-" if i > 0 else "r--"
plt.plot(X_new, y_predict, style)
gradients = 2/m * X_b.T.dot(X_b.dot(theta) - y)
theta = theta - eta * gradients
if theta_path is not None:
theta_path.append(theta)
plt.xlabel("$X$", fontsize=18)
plt.title(r"$\eta = {}$".format(eta), fontsize=16)

绘制了梯度下降的前十步,虚线表示起点,η\etaη 表示为学习率:

image.png

得出结论:

  • 左图:在前十步无法找到解决方案,但是只要长时间的迭代就一定可以找到解决方案;
  • 中图:效果看起来不错,比较符合预期,几次迭代就收敛出了最终解;
  • 右图:算法发散,直接跳过了数据区域,并且每一步都离实际解决方案越来越远;

要找到合适的学习率,可以使用网络搜索。但是可能需要限制迭代次数,这样网络搜索就可以淘汰掉那些收敛耗时太长的模型。

然而怎么限制迭代次数呢?如果设置太低,算法可能在离最优解还很远时就停止了;但是如果设置得太高,模型到达最优解后,继续迭代参数不再变化,又会浪费时间。

一个简单的方法是在开始设置一个非常大的迭代次数,但是当梯度向量的值变得很微小时中断算法,也就是当他的范数变得低于 ε\varepsilonε(称为容差)时,因为这是梯度下降已经(几乎)到达了最小值。

收敛率:当成本函数为凸函数,并且斜率没有陡峭的变化时(如 MSE 成本函数),通过批量梯度下降可以看出一个固定的学习率有一个收敛率,为 O(1迭代次数)O(\frac{1}{迭代次数})O(迭代次数1)。换句话说,如果将容差 ε\varepsilonε 缩小为原来的 110\frac{1}{10}101(以得到更精确的解),算法将不得不运行 10 倍的迭代次数。

批量梯度下降 BGD

梯度下降法最常用的形式,具体做法也就是在更新参数时使用所有的样本来进行更新;

θi=θi−α∑j=0m(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)θ_i=θ_i−α∑_{j=0}^m(h_θ(x^{(j)}_0,x^{(j)}_1,…,x^{(j)}_n)−y_j)x_i^{(j)}θi=θi−αj=0∑m(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)

  • 优点:全局最优解,易于并行实现;
  • 缺点:计算代价大,数据量大时,训练过程慢;

image.png

要实现梯度下降,需要计算每个模型关于参数 θj\theta_jθj 的成本函数的梯度。换言之,需要计算的是,如果改变 θj\theta_jθj,成本函数会改变多少,即偏导数。以线性回归的成本函数 MSEMSEMSE 为例,其偏导数为:

∂∂θjMSE(θ)=∂∂θj(1m∑i=1m(θT⋅X(i)−y(i))2)=2m∑i=1m(θT⋅x(i)−y(i))xj(i)\frac{∂}{∂θ_j} MSE(θ) = \frac{∂}{∂θ_j}(\frac{1}{m}∑_{i=1}^m(θ^T⋅X^{(i)}−y^{(i)})^2) \
\quad\quad\quad\quad
= \frac{2}{m}∑_{i=1}^m(θ^T⋅x^{(i)}−y^{(i)})x^{(i)}_j∂θj∂MSE(θ)=∂θj∂(m1i=1∑m(θT⋅X(i)−y(i))2)=m2i=1∑m(θT⋅x(i)−y(i))xj(i)
如果不想单独计算这些梯度,可以使用下面的公式对其进行一次性计算。梯度向量 ∇θMSE(θ)\nabla_\theta MSE(\theta)∇θMSE(θ),包含所有成本函数(每个模型参数一个)的偏导数。

∇θMSE(θ)=(∂∂θ0MSE(θ)∂∂θ1MSE(θ)⋮∂∂θnMSE(θ))=2mXT⋅(X⋅θ−y)∇_θMSE(θ)=
\begin{pmatrix}
\frac{∂}{∂θ_0} MSE(θ) \
\frac{∂}{∂θ_1} MSE(θ) \
\vdots \
\frac{∂}{∂θ_n} MSE(θ)
\end{pmatrix}
=\frac{2}{m}X^T⋅(X⋅θ-y)∇θMSE(θ)=⎝⎛∂θ0∂MSE(θ)∂θ1∂MSE(θ)⋮∂θn∂MSE(θ)⎠⎞=m2XT⋅(X⋅θ−y)
代码实现:

1、随机初始化:

1
2
3
py复制代码X = 2 * np.random.rand(100, 1)
y = 21 + 5 * X + np.random.randn(100, 1)
X_b = np.c_[np.ones((100, 1)), X]

2、设置各参数:

1
2
3
4
py复制代码m = 100
eta = 0.1
n_epochs = 1000
theta = np.random.randn(2,1)

3、BGD 算法实现:

1
2
3
4
py复制代码for epoch in range(n_epochs):
gradients = 2/m * X_b.T.dot(X_b.dot(theta) - y)
theta = theta - eta * gradients
print('theta:\n{}\n'.format(theta))

4、绘制图像:

1
2
3
4
5
6
7
8
9
py复制代码X_new = np.array([[0], [2]])
X_new_b = np.c_[np.ones((2, 1)), X_new]
y_predict = X_new_b.dot(theta)

plt.title(f"$\eta = {0.1}$")
plt.plot(X_new, y_predict, "r-")
plt.plot(X, y, "b.")

plt.show()

image.png

批量梯度下降法由于使用了全部样本进行训练,所以当损失函数是凸函数时,理论上可以找到全局最优解,但当训练样本很大时,其训练速度会非常慢,不适用于在线学习的一些项目。为了解决这个问题,随机梯度下降算法被提出。

随机梯度下降 SGD

和批量梯度下降法原理类似,区别在于求梯度时,没有用所有的 mmm 个样本的数据,而是仅仅选取一个样本 jjj 来求梯度;

θi=θi−α(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)θ_i=θ_i−α(h_θ(x^{(j)}_0,x^{(j)}_1,…,x^{(j)}_n)−y_j)x_i^{(j)}θi=θi−α(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)

  • 优点:训练速度快;
  • 缺点:准确度下降,并不是全局最优,不易于并行实现;

由于算法的随机性质,它比批量梯度下降要不规则得多。成本函数将不再是缓缓降低直到抵达最小值,而是不断上上下下,但是从整体来看,还是在慢慢下降。随着时间的推移,最终会非常接近最小值,但是即使它到达了最小值,依然还会持续反弹,永远不会停止。所以算法停下来的参数值肯定是足够好的,但不是最优的。

image.png

当成本函数非常不规则时,随机梯度下降其实可以帮助算法跳出局部最小值,所以 相比批量梯度下降,它对找到全局最小值更有优势 。

因为,随机性的好处在于可以逃离局部最优,但缺点是永远定位不出最小值。 要解决这个困境,有一个办法是逐步降低学习率。开始的步长比较大(这有助于快速进展和逃离局部最小值),然后越来越小,让算法尽量靠近全局最小值,这个过程叫做模拟退火:因为它类似于冶金时融化的金属慢慢冷却的退火过程。

确定每个迭代学习率的函数叫作学习计划。如果学习率降得太快,可能会陷入局部最小值,甚至是停留在走向最小值的半途中。如果学习率太慢,你可能需要太长时间太能跳到差不多最小值附近,如果提早结束训练,可能只得到一个次优的解决方案。

代码实现:

1、初始化过程同上;

2、SGD 算法实现:

为了避免训练速度过慢,随机梯度下降法在训练过程中每次仅针对一个样本进行训练,但进行多次更新。在每一轮新的更新之前,需要对数据样本进行重新洗牌(shuffle)。

1
2
3
4
5
6
7
8
9
10
11
12
13
py复制代码for epoch in range(n_epochs):
# 每个轮次开始分批次迭代之前打乱数据索引顺序
arr = np.arange(len(X_b))
np.random.shuffle(arr)
X_b = X_b[arr]
y = y[arr]
for i in range(m):
xi = X_b[i:i+1]
yi = y[i:i+1]
gradients = xi.T.dot(xi.dot(theta)-yi)
theta = theta - eta * gradients

print('theta:\n{}\n'.format(theta))

3、绘制图像:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
py复制代码# 在上述代码段内部加入
for i in range(m):
xi = X_b[i:i+1]
yi = y[i:i+1]
gradients = xi.T.dot(xi.dot(theta)-yi)
theta = theta - eta * gradients
if epoch == 0 and i % 5 == 0:
y_predict = X_new_b.dot(theta)
style = "m--" if i > 0 else "r--"
plt.plot(X_new, y_predict, style)
if epoch == 1 and i % 5 == 0:
y_predict = X_new_b.dot(theta)
style = "g--" if i > 0 else "r--"
plt.plot(X_new, y_predict, style)


X_new = np.array([[0], [2]])
X_new_b = np.c_[np.ones((2, 1)), X_new]
y_predict = X_new_b.dot(theta)

plt.title(f"$\eta = {0.1}$")
plt.plot(X_new, y_predict, "k-")
plt.plot(X, y, "b.")
plt.show()

image.png

随机梯度下降法在更新过程中由于是针对单个样本,所以其迭代的方向有时候并不是整体最优的方向,同时其方差较大,导致损失函数值的变动并不是规律的递减,更多的情况可能是波动形状的下降。

为了解决批量梯度下降的速度太慢以及随机梯度下降方差变动过大的情况,一种折中的算法–小批量梯度下降算法被提出,其从全部样本中选取部分样本进行迭代训练。并且在每一轮新的迭代开始之前,对全部样本进行 Shuffle 处理。

小批量梯度下降 MBGD

小批量梯度下降法是批量梯度下降法和随机梯度下降法的折中,也就是对于 mmm 个样本,我们采用 xxx 个样本来迭代,1<x<m1<x<m1<x<m。一般可以取 x=10x=10x=10,当然根据样本的数据量,可以调整这个 xxx 的值;

θi=θi−α∑j=tt+x−1(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)θ_i=θ_i−α∑_{j=t}^{t+x-1}(h_θ(x^{(j)}_0,x^{(j)}_1,…,x^{(j)}_n)−y_j)x_i^{(j)}θi=θi−αj=t∑t+x−1(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)
相对于随机梯度下降算法,小批量梯度下降算法降低了收敛波动性, 即降低了参数更新的方差,使得更新更加稳定。相对于全量梯度下降,其提高了每次学习的速度。并且其不用担心内存瓶颈从而可以利用矩阵运算进行高效计算。一般而言每次更新随 机选择[50,256]个样本进行学习,但是也要根据具体问题而选择,实践中可以进行多次试验, 选择一个更新速度与更次次数都较适合的样本数。

image.png

代码实现:

1、初始化过程同上;

2、动态调整学习率:

1
2
3
py复制代码t0, t1 = 5, 500
def learning_rate_schedule(t):
return t0 / (t + t1)

3、MBGD 算法实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
py复制代码theta = np.random.randn(2,1)
for epoch in range(n_epochs):
arr = np.arange(len(X_b))
np.random.shuffle(arr)
X_b = X_b[arr]
y = y[arr]
for i in range(n_batches):
x_batch = X_b[i * batch: i * batch + batch]
y_batch = y[i * batch: i * batch + batch]
gradients = x_batch.T.dot(x_batch.dot(theta)-y_batch)
eta = learning_rate_schedule(epoch * m + i)
theta = theta - eta * gradients

print('theta:\n{}\n'.format(theta))

4、绘制图像:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
py复制代码# 在上述代码段内部加入
for i in range(n_batches):
x_batch = X_b[i * batch: i * batch + batch]
y_batch = y[i * batch: i * batch + batch]
gradients = x_batch.T.dot(x_batch.dot(theta)-y_batch)
eta = learning_rate_schedule(epoch * m + i)
theta = theta - eta * gradients
if epoch == 0 and i < 5:
y_predict = X_new_b.dot(theta)
style = "m--" if i > 0 else "r--"
plt.plot(X_new, y_predict, style)
if epoch == 1 and i < 5:
y_predict = X_new_b.dot(theta)
style = "g--" if i > 0 else "r--"
plt.plot(X_new, y_predict, style)


X_new = np.array([[0], [2]])
X_new_b = np.c_[np.ones((2, 1)), X_new]
y_predict = X_new_b.dot(theta)

plt.title(f"$\eta = {0.1}$")
plt.plot(X_new, y_predict, "k-")
plt.plot(X, y, "b.")
plt.show()

image.png

后记

以上三种梯度下降算法仅局限于对训练样本进行变更,且每次迭代更新权重时使用的梯度仅作用于当前状态。由于每一期的样本有好有坏,导致迭代过程是曲折波动的,影响了收敛速度。

下图显示了训练期间参数空间中三种梯度下降算法所采用的路径。它们最终都接近最小值,但是批量梯度下降的路径实际上是在最小值处停止,而随机梯度下降和小批量梯度下降都继续走动。但是,不要忘记批量梯度下降每步需要花费很多时间,如果你使用良好的学习率调度,随机梯度下降和小批量梯度下降也会达到最小值。

image.png

让我们比较到目前为止讨论的线性回归算法,mmm 是训练实例的数量,nnn 是特征的数量。

算法 mmm 很大 核外支持 nnn 很大 超参数 要求缩放
标准方程 快 否 慢 0 否
SVD 快 否 慢 0 否
BGD 慢 否 快 2 是
SGD 快 是 快 ≥\geq≥2 是
MBGD 快 是 快 ≥\geq≥2 是

image.png

参考:

  • 不同梯度下降算法的比较及Python实现
  • 利用python实现3种梯度下降算法

📝 上篇精讲:【AI】浅谈梯度下降算法(实战篇)

💖 我是 𝓼𝓲𝓭𝓲𝓸𝓽,期待你的关注;

👍 创作不易,请多多支持;

🔥 系列专栏:AI

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

Java & Android 集合框架

发表于 2022-11-09

⭐️ 本文已收录到 AndroidFamily,技术和职场问题,请关注公众号 [彭旭锐] 和 [BaguTree Pro] 知识星球提问。

学习数据结构与算法的关键在于掌握问题背后的算法思维框架,你的思考越抽象,它能覆盖的问题域就越广,理解难度也更复杂。在实际的业务开发中,往往不需要我们手写数据结构,而是直接使用标准库的数据结构 / 容器类。

本文是 Java & Android 集合框架系列的第 6 篇文章,完整文章目录请移步到文章末尾~

前言

大家好,我是小彭。

在上一篇文章里,我们聊到了 HashMap 的基本原理,这一节我们来结合 HashMap 的源码做分析。

本文源码基于 Java 8 HashMap,并关联分析部分 Java 7 HashMap。

  • Java & Android 集合框架 #5 万字 HashMap 详解,基础(优雅)永不过时 —— 原理篇
  • Java & Android 集合框架 #6 万字 HashMap 详解,基础(优雅)永不过时 —— 源码篇

思维导图:


  1. HashMap 源码分析

3.1 HashMap 的构造方法

HashMap 有 4 个构造方法:

  • 1、带初始容量和装载因子的构造方法: 检查初始容量和装载因子的有效性,并计算初始容量最近的 2 的整数幂;
  • 2、带初始容量的构造方法: 使用默认负载因子 0.75 调用上一个构造方法;
  • 3、无参构造方法: 设置默认装载因子 0.75;
  • 4、带 Map 参数的构造方法: 设置默认装载因子 0.75,并逐个添加 Map 中的映射关系。

可以看到,在 HashMap 的构造方法中并没有创建底层数组,而是延迟到 put 操作中触发的 resize 扩容操作中创建数组。另外,在可以已知存储的数据量时,可以在构造器中预先设置初始容量,避免在添加数据的过程中多次触发扩容。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
java复制代码// 带初始容量和装载因子的构造方法
public HashMap(int initialCapacity, float loadFactor) {
if (initialCapacity < 0)
throw new IllegalArgumentException("Illegal initial capacity: " + initialCapacity);
if (initialCapacity > MAXIMUM_CAPACITY)
// 最大容量限制
initialCapacity = MAXIMUM_CAPACITY;
if (loadFactor <= 0 || Float.isNaN(loadFactor))
throw new IllegalArgumentException("Illegal load factor: " + loadFactor);
// 装载因子上限
this.loadFactor = loadFactor;
// 扩容阈值(此处不是真正的阈值,仅仅只是将传入的容量转化最近的 2 的整数幂,该阈值后面会重新计算)
this.threshold = tableSizeFor(initialCapacity);
}

// 带初始容量的构造方法
public HashMap(int initialCapacity) {
this(initialCapacity, DEFAULT_LOAD_FACTOR /*0.75*/);
}

// 无参构造方法
public HashMap() {
this.loadFactor = DEFAULT_LOAD_FACTOR /*0.75*/;
}

// 带 Map 的构造方法
public HashMap(Map<? extends K, ? extends V> m) {
this.loadFactor = DEFAULT_LOAD_FACTOR /*0.75*/;
// 疑问 7:为什么不使用 Arrays 工具类整体复制,而是使用 putMapEntries 批量添加?
// 批量添加
putMapEntries(m, false);
}

// 疑问 8:tableSizeFor() 的函数体解释一下?
// 获取最近的 2 的整数幂
static final int tableSizeFor(int cap) {
// 先减 1,让 8、16 这种本身就是 2 的整数幂的容量保持不变
// 在 ArrayDeque 中没有先减 1,所以容量 8 会转为 16
int n = cap - 1;
n |= n >>> 1;
n |= n >>> 2;
n |= n >>> 4;
n |= n >>> 8;
n |= n >>> 16;
return (n < 0) ? 1 /*tableSizeFor() 方法外层已经检查过超过 2^30 的值,应该不存在整型溢出的情况*/
: (n >= MAXIMUM_CAPACITY) ? MAXIMUM_CAPACITY : n + 1;
}

小朋友总是有太多问号,举手提问🙋🏻‍♀️:

🙋🏻‍♀️疑问 7:为什么带集合的构造方法不使用 Arrays 工具类整体复制,而是使用 putMapEntries 批量添加?

首先,参数 Map 不一定是基于散列表的 Map,所以不能整体复制。其次,就算参数 Map 也是 HashMap,如果两个散列表的 length 长度不同,键值对映射到的数组下标也会不同。因此不能用 Arrays 工具类整体复制,必须逐个再散列到新的散列表中。

🙋🏻‍♀️疑问 8:tableSizeFor() 的函数体解释一下?

其实,HashMap#tableSizeFor() 函数体与 ArrayDeque#calculateSize() 函数体相似,也是求最近的 2 的整数幂,即 nextPow2 问题。区别在于 HashMap 在第一步对参数 cap - 1,而 ArrayDeque 没有这一步,会将 8、16 这种本身就是 2 的整数幂的容量翻倍。

tableSizeFor() 中经过五轮无符号右移和或运算,将 cap 转换为从最高位开始后面都是 1 的数。再执行 +1 运算,就求出了最近的 2 的整数幂(最高有效位是 1,低位都是 0)。

1
2
3
4
5
6
7
java复制代码n = 0 0 0 0 1 x x x x x     //n
n = 0 0 0 0 1 1 x x x x //n |= n >>> 1;
n = 0 0 0 0 1 1 1 1 x x //n |= n >>> 2;
n = 0 0 0 0 1 1 1 1 1 1 //n |= n >>> 4;
n = 0 0 0 0 1 1 1 1 1 1 //n |= n >>> 8;(这一步对 n 没有影响了)
n = 0 0 0 0 1 1 1 1 1 1 //n |= n >>> 16;(这一步对 n 没有影响了)
n = 0 0 0 1 0 0 0 0 0 0 //n + 1(进位,得到最近 2 的整数幂)

3.2 HashMap 的哈希函数

将 HashMap#put 方法中,有一个重要的步骤就是使用 Hash 函数计算键值对中键(Key)的散列值。HashMap#put 的执行流程非常复杂,为了降低理解难度,我们先分析 HashMap#hash 方法。

Hash 函数是散列表的核心特性,Hash 函数是否足够随机,会直接影响散列表的查询性能。在 Java 7 和 Java 8 中,HashMap 会在 Object#hashCode() 的基础上增加 “扰动”:

  • Java 7: 做 4 次扰动,通过无符号右移,让散列值的高位与低位做异或;
  • Java 8: 做 1 次扰动,通过无符号右移,让高 16 位与低 16 位做异或。在 Java 8 只做一次扰动,是为了在随机性和计算效率之间的权衡。

HashMap#hash

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
java复制代码public V put(K key, V value) {
return putVal(hash(key) /*计算散列值*/, key, value, false, true);
}

// Java 7:4 次位运算 + 5次异或运算
static final int hash(int h) {
h ^= k.hashCode();
h ^= (h >>> 20) ^ (h >>> 12);
return h ^ (h >>> 7) ^ (h >>> 4);
}

// 疑问 9:为什么 HashMap 要在 Object#hashCode() 上增加扰动,而不是要求 Object#hashCode() 尽可能随机?
// 为什么让高位与低位做异或就可以提高随机性?
// Java 8:1 次位运算 + 1次异或运算
static final int hash(Object key) {
int h;
return (key == null) ? 0 : (h = key.hashCode()) ^ (h >>> 16);
}

小朋友总是有太多问号,举手提问🙋🏻‍♀️:

  • 🙋🏻‍♀️疑问 9:为什么 HashMap 要在 Object#hashCode() 上增加扰动,而不是要求 Object#hashCode() 尽可能随机?

这是兜下限,以保证所有使用 HashMap 的开发者都能获得良好的性能。而且,由于数组的长度有限,在将散列值映射到数组下标时,会使用数组的长度做取余运算,最终影响下标位置的只有散列值的低几位元素,会破坏映射的随机性(即散列值随机,但映射到下标后不随机)。

因此,HashMap 会对散列值做位移和异或运算,让高 16 位与低 16 位做异或运算。等于说在低位中加入了高位的特性,让高位的数值也会影响到数组下标的计算。

到这里,基本可以回答上一节剩下的疑问 4:

  • 🙋🏻‍♀️疑问 4:为什么 HashMap 要求数组的容量是 2 的整数幂?

这是为了提高散列值映射到数组下标的计算效率和随机性,原因有 3 个:

1、提高取余操作的计算效率:

如果数组的容量是 2 的整数幂,那么就可以将取余运算 |hash % length| 替换为位运算 hash & (length - 1) ,不管被除数是正负结果都是正数。 不仅将取余运算替换为位运算,而且减少了一次取绝对值运算,提高了索引的计算效率。

1
2
3
4
java复制代码10  % 4 = 2
-10 % 4 = -2 // 负数
10 & (4 - 1) = 2
-10 & (4 - 1) = 2 // 正数

2、数组长度是偶数能避免散列值都映射到偶数下标上:

如果数组的长度是奇数,那么 (length - 1) 的结果一定是偶数,即二进制最低 1 位是 0。这就会导致 hash & (length - 1) 的结果一定是偶数,即始终会映射到偶数下标中,不仅浪费了一般数组空间,也会增大冲突概率。

3、保留所有的低位特征:

数组长度 length 为 2 的整数幂对应 (length - 1) 正好是高位为 0,低位都是 1 的低位掩码,能够让影响映射的因素全部归结到散列值上。

3.3 HashMap 的添加方法

HashMap 直接添加一个键值对,也支持批量添加键值对:

  • put: 逐个添加或更新键值对
  • putAll: 批量添加或更新键值对

不管是逐个添加还是批量添加,最终都会先通过 hash 函数计算键(Key)的散列值,再通过 putVal 添加或更新键值对。

putValue 的流程非常复杂,我将主要步骤概括为 5 步:

  • 1、如果数组为空,则使用扩容函数创建(说明数组的创建时机在首次 put 操作时);
  • 2、(n - 1) & hash:散列值转数组下标,与 Java 7 的 indexFor() 方法相似;
  • 3、如果是桶中的第一个节点,则创建并插入 Node 节点;
  • 4、如果不是桶中的第一个节点(即发生哈希冲突),需要插入链表或红黑树。在添加到链表的过程中,遍历链表找到 Key 相等(equals)的节点,如果不存在则使用尾插法添加新节点。如果链表节点数超过树化阈值 8,则将链表转为红黑树。
  • 5、如果键值对数量大于扩容阈值,则触发扩容。

HashMap#put

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
java复制代码// 添加或更新键值对
public V put(K key, V value) {
return putVal(hash(key) /*计算散列值*/, key, value, false, true);
}

// 批量添加或更新键值对
public void putAll(Map<? extends K, ? extends V> m) {
putMapEntries(m, true);
}

// 批量添加或更新键值对
// evict:是否驱逐最早的节点(在 LinkedHashMap 中使用,我们先忽略)
final void putMapEntries(Map<? extends K, ? extends V> m, boolean evict) {
int s = m.size();
if (s > 0) {
if (table == null) {
// 如果数组为空,则先初始化 threshold 扩容阈值
float ft = ((float)s / loadFactor) + 1.0F;
// 扩容阈值上限
int t = ((ft < (float)MAXIMUM_CAPACITY) ? (int)ft : MAXIMUM_CAPACITY);
if (t > threshold)
threshold = tableSizeFor(t);
} else if (s > threshold)
// 参数 Map 的长度大于扩容阈值,先扩容(如果扩容后依然不足,在下面的 putVal 中会再次扩容)
// 这里应该有优化空间,批量添加时可以直接扩容到满足要求的容量,避免在 for 循环中多次扩容
resize();
// 逐个添加 Map 中的键值对
for (Map.Entry<? extends K, ? extends V> e : m.entrySet()) {
K key = e.getKey();
V value = e.getValue();
// hash(key):计算 Key 的哈希值
// pubVal:添加或更新键值对
putVal(hash(key), key, value, false, evict);
}
}
}

// 最终都会走到 putVal方法:

// hash:Key 的散列值(经过扰动)
// onlyIfAbsent:如果为 true,不会覆盖旧值
// evict:是否驱逐最早的节点(在 LinkedHashMap 中使用,我们先忽略)
final V putVal(int hash, K key, V value, boolean onlyIfAbsent, boolean evict) {
// 数组
Node<K,V>[] tab;
// 目标桶(同一个桶中节点的散列值有可能不同)
Node<K,V> p;
// 数组长度
int n;
// 桶的位置
int i;
// 1. 如果数组为空,则使用扩容函数创建(说明数组的创建时机在首次 put 操作时)
if ((tab = table) == null || (n = tab.length) == 0)
n = (tab = resize()).length;
// 2. (n - 1) & hash:散列值转数组下标,与 Java 7 的 indexFor() 方法相似
if ((p = tab[i = (n - 1) & hash]) == null)
// 3. 如果是桶中的第一个节点,则创建并插入 Node 节点
tab[i] = newNode(hash, key, value, null);
else {
// 4. 如果不是桶中的第一个节点(即发生哈希冲突),需要插入链表或红黑树
// e:最终匹配的节点
Node<K,V> e;
// 节点上的 Key
K k;
if (p.hash == hash && ((k = p.key) == key || (key != null && key.equals(k))))
// 4.1 如果桶的根节点与 Key 相等,则将匹配到根节点
// p.hash == hash:快捷比较(同一个桶中节点的散列值有可能不同,如果散列值不同,键不可能相同)
// (k = p.key) == key:快捷比较(同一个对象)
// key != null && key.equals(k):判断两个对象 equals 相同
e = p;
else if (p instanceof TreeNode)
// 4.2 如果桶是红黑树结构,则采用红黑树的插入方式
e = ((TreeNode<K,V>)p).putTreeVal(this, tab, hash, key, value);
else {
// 4.3 如果桶是链表结构,则采用链表的插入方式:
// 4.3.1 遍历链表找到 Key 相等的节点
// 4.3.2 否则使用尾插法添加新节点
// 4.3.3 链表节点数超过树化阈值,则将链表转为红黑树
for (int binCount = 0; ; ++binCount) {
// 尾插法(Java 7 使用头插法)
if ((e = p.next) == null) {
p.next = newNode(hash, key, value, null);
if (binCount >= TREEIFY_THRESHOLD - 1) // -1 for 1st
// 链表节点数超过树化阈值,则将链表转为红黑树
treeifyBin(tab, hash);
break;
}
// 找到 Key 相等的节点
if (e.hash == hash && ((k = e.key) == key || (key != null && key.equals(k))))
break;
p = e;
}
}
// 4.4 新 Value 替换旧 Value(新增节点时不会走到这个分支)
if (e != null) {
V oldValue = e.value;
if (!onlyIfAbsent || oldValue == null)
e.value = value;
// 访问节点回(用于 LinkedHashMap,默认为空实现)
afterNodeAccess(e);
return oldValue;
}
}
// 修改记录
++modCount;
// 5. 如果键值对数量大于扩容阈值,则触发扩容
if (++size > threshold)
resize();
// 新增节点回调(用于 LinkedHashMap,默认为空实现)
afterNodeInsertion(evict);
return null;
}

// -> 4.2 如果桶是红黑树结构,则采用红黑树的插入方式
final TreeNode<K,V> putTreeVal(HashMap<K,V> map, Node<K,V>[] tab,
int h, K k, V v) {
...
}

// -> 链表节点数超过树化阈值,则将链表转为红黑树
final void treeifyBin(Node<K,V>[] tab, int hash) {
int n, index; Node<K,V> e;
if (tab == null || (n = tab.length) < MIN_TREEIFY_CAPACITY)
resize();
else if ((e = tab[index = (n - 1) & hash]) != null) {
TreeNode<K,V> hd = null, tl = null;
do {
TreeNode<K,V> p = replacementTreeNode(e, null);
if (tl == null)
hd = p;
else {
p.prev = tl;
tl.next = p;
}
tl = p;
} while ((e = e.next) != null);
if ((tab[index] = hd) != null)
hd.treeify(tab);
}
}

小朋友总是有太多问号,举手提问🙋🏻‍♀️:

  • 🙋🏻‍♀️疑问 10:为什么 Java 8 要将头插法改为尾插法?

HashMap 不考虑多线程同步,会存在多线程安全问题。当多个线程同时执行 put 操作并且触发扩容时,Java 7 的头插法会翻转链表的顺序,有可能会引起指针混乱形成环形链表,而 Java 8 使用尾插法,在扩容时会保持链表原本的顺序。

  • 🙋🏻‍♀️疑问 11:解释一下 p.hash == hash && ((k = p.key) == key || (key != null && key.equals(k)))?

这个问题等价于问 HashMap 如何确定键值对的位置:

1、首先,HashMap 会对键 Key 计算 hashCode() 并添加扰动,得到扰动后的散列值 hash。随后通过对数组长度取余映射到数组下标中;

2、然后,当数组下标的桶中存在多个节点时,HashMap 需要遍历桶找到与 Key 相等的节点,以区分是更新还是添加。为了提高效率,就有了 if 语句中的多次判断:

2.1 p.hash == hash 快捷判断: 同一个桶中节点的散列值有可能不同,如果散列值不同,键一定相等:

2.2 (k = p.key) == key 快捷判断:同一个对象;

2.3 key != null && key.equals(k) 最终判断:判断两个键 Key 是否相等,即 equals 相等。

综上所述,HashMap 是通过 hashCode() 定位桶,通过 equals() 确定键值对。

HashMap#put 执行流程

3.4 HashMap 的扩容方法

在 putVal 方法中,如果添加键值对后散列值的长度超过扩容阈值,就会调用 resize() 扩容,主体流程分为 3步:

  • 1、计算扩容后的新容量和新扩容阈值;
  • 2、创建新数组;
  • 3、将旧数组上的键值对再散列到新数组上。

扩容分为 2 种情况:

  • 1、首次添加元素: 会根据构造方法中设置的初始容量和装载因子确定新数组的容量和扩容阈值在无参构造方法中,会使用 16 的数组容量和 0.75 的扩容阈值;
  • 2、非首次添加: 将底层数组和扩容阈值扩大为原来的 2 倍,如果旧容量大于等于 2^30 次幂,则无法扩容。此时,将扩容阈值调整到整数最大值。

再散列的步骤不好理解,这里解释下:

  • 3.1 桶的根节点,直接再散列;
  • 3.2 以红黑树的方式再散列,思路与 3.3 链表的方式相似;
  • 3.3 以链表的形式再散列:hash & oldCap 就是获取 hash 在扩容后新参与映射的 1 个最高有效位。如果这一位是 0,那么映射后的位置还是在原来的桶中,如果这一位是 1,那么映射后的位置就是原始位置 + 旧数组的容量。
1
2
3
4
5
6
java复制代码oldCap     = 0 0 0 0 1 0 0 0 0 0 // 32
oldCap - 1 = 0 0 0 0 0 1 1 1 1 1 // 32
newCap = 0 0 0 1 0 0 0 0 0 0 // 64
newCap - 1 = 0 0 0 0 1 1 1 1 1 1 // 64
^
增加 1 个有效位参与映射

HashMap#resize

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
java复制代码// 扩容
final Node<K,V>[] resize() {
// 旧数组
Node<K,V>[] oldTab = table;
// 旧容量
int oldCap = (oldTab == null) ? 0 : oldTab.length;
// 旧扩容阈值
int oldThr = threshold;
// 新容量
int newCap = 0;
// 新扩容阈值
int newThr = 0;
// 1. 计算扩容后的新容量和新扩容阈值
// 旧容量大于 0,说明不是第一次添加元素
if (oldCap > 0) {
// 如果旧容量大于等于 2^30 次幂,则无法扩容。此时,将扩容阈值调整到整数最大值
if (oldCap >= MAXIMUM_CAPACITY) {
threshold = Integer.MAX_VALUE;
return oldTab;
}
// 数组容量和扩容阈值扩大为原来的 2 倍
else if ((newCap = oldCap << 1) < MAXIMUM_CAPACITY && oldCap >= DEFAULT_INITIAL_CAPACITY)
newThr = oldThr << 1; // double threshold
}
// 旧容量为 0,需要初始化数组
else if (oldThr > 0)
// (带初始容量和负载因子的构造方法走这里)
// 使用构造方法中计算的最近 2 的整数幂作为数组容量
newCap = oldThr;
else {
// (无参构造方法走这里)
// 使用默认 16 长度作为初始容量
newCap = DEFAULT_INITIAL_CAPACITY;
// 使用默认的负载因子乘以容量计算扩容阈值
newThr = (int)(DEFAULT_LOAD_FACTOR * DEFAULT_INITIAL_CAPACITY);
}
if (newThr == 0) {
//(带初始容量和负载因子的构造方法走这里)
// 使用负载因子乘以容量计算扩容阈值
float ft = (float)newCap * loadFactor;
newThr = (newCap < MAXIMUM_CAPACITY && ft < (float)MAXIMUM_CAPACITY ? (int)ft : Integer.MAX_VALUE);
}
// 最终计算的扩容阈值
threshold = newThr;
// 2. 创建新数组
Node<K,V>[] newTab = (Node<K,V>[])new Node[newCap];
table = newTab;
// 3. 将旧数组上的键值对再散列到新数组上
if (oldTab != null) {
// 遍历旧数组上的每个桶
for (int j = 0; j < oldCap; ++j) {
// 桶的根节点
Node<K,V> e;
// 桶的根节点不为 null
if ((e = oldTab[j]) != null) {
oldTab[j] = null;
if (e.next == null)
// 3.1 桶的根节点,直接再散列
newTab[e.hash & (newCap - 1)] = e;
else if (e instanceof TreeNode)
// 3.2 以红黑树的方式再散列,思路与 3.3 链表的方式相似
((TreeNode<K,V>)e).split(this, newTab, j, oldCap);
else {
// 3.3 以链表的形式再散列
Node<K,V> loHead = null, loTail = null;
Node<K,V> hiHead = null, hiTail = null;
Node<K,V> next;
do {
next = e.next;
// 3.3.1 若散列值新参与映射的位为 0,那么映射到原始位置上
if ((e.hash & oldCap) == 0) {
if (loTail == null)
loHead = e;
else
loTail.next = e;
loTail = e;
}
// 3.3.2 若散列值新参与映射的位为 0,那么映射到原始位置 + 旧数组容量的位置上
else {
if (hiTail == null)
hiHead = e;
else
hiTail.next = e;
hiTail = e;
}
} while ((e = next) != null);
if (loTail != null) {
loTail.next = null;
newTab[j] = loHead;
}
if (hiTail != null) {
hiTail.next = null;
newTab[j + oldCap] = hiHead;
}
}
}
}
}
return newTab;
}

3.5 HashMap 的获取方法

HashMap 的获取方法相对简单,与 put 方法类似:先通过 hash 函数计算散列值,再通过 hash 取余映射到数组下标的桶中,最后遍历桶中的节点,找到与键(Key)相等(equals)的节点。

HashMap#get

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
java复制代码// 获取 Key 映射的键值对
public V get(Object key) {
Node<K,V> e;
return (e = getNode(hash(key)/*计算散列值*/, key)) == null ? null : e.value;
}

// 通过 Key 的散列值和 Key 获取映射的键值对
final Node<K,V> getNode(int hash, Object key) {
Node<K,V>[] tab; Node<K,V> first, e; int n; K k;
if ((tab = table) != null && (n = tab.length) > 0 && (first = tab[(n - 1) & hash]) != null) {
// 先检查根节点
if (first.hash == hash && ((k = first.key) == key || (key != null && key.equals(k))))
return first;
if ((e = first.next) != null) {
// 以红黑树的方式检索
if (first instanceof TreeNode)
return ((TreeNode<K,V>)first).getTreeNode(hash, key);
// 以链表的方式检索
do {
if (e.hash == hash &&
((k = e.key) == key || (key != null && key.equals(k))))
return e;
} while ((e = e.next) != null);
}
}
return null;
}

HashMap#get 示意图

3.6 HashMap 的移除方法

HashMap 的移除方法是添加方法的逆运算,HashMap 没有做动态缩容。

HashMap#remove

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
java复制代码public V remove(Object key) {
Node<K,V> e;
return (e = removeNode(hash(key)/*计算散列值*/, key, null, false, true)) == null ? null : e.value;
}

final Node<K,V> removeNode(int hash, Object key, Object value,
boolean matchValue, boolean movable) {
// 底层数组
Node<K,V>[] tab;
// 目标桶(同一个桶中节点的散列值有可能不同)
Node<K,V> p;
int n, index;
// 定位到散列值对应的数组下标
if ((tab = table) != null && (n = tab.length) > 0 && (p = tab[index = (n - 1) & hash]) != null) {
Node<K,V> node = null, e; K k; V v;
if (p.hash == hash && ((k = p.key) == key || (key != null && key.equals(k))))
// 先检查根节点
node = p;
else if ((e = p.next) != null) {
if (p instanceof TreeNode)
// 以红黑树的方式查询节点
node = ((TreeNode<K,V>)p).getTreeNode(hash, key);
else {
// 以链表的方式查询节点
do {
if (e.hash == hash && ((k = e.key) == key || (key != null && key.equals(k)))) {
node = e;
break;
}
p = e;
} while ((e = e.next) != null);
}
}
// node 不为 null,删除 node 节点
if (node != null && (!matchValue || (v = node.value) == value || (value != null && value.equals(v)))) {
if (node instanceof TreeNode)
// 以红黑树的方式删除
((TreeNode<K,V>)node).removeTreeNode(this, tab, movable);
else if (node == p)
// 以链表的方式删除(删除跟节点)
tab[index] = node.next;
else
// 以链表的方式删除(删除中间节点)
p.next = node.next;
++modCount;
--size;
// 删除节点回调(用于 LinkedHashMap,默认为空实现)
afterNodeRemoval(node);
return node;
}
}
return null;
}

HashMap#remove 示意图

3.7 HashMap 的迭代器

Java 的 foreach 是语法糖,本质上也是采用 iterator 的方式。HashMap 提供了 3 个迭代器:

  • EntryIterator: 键值对迭代器
  • KeyIterator: 键迭代器
  • ValueIterator: 值迭代器

在迭代器遍历数组的过程中,有可能出现多个线程并发修改数组的情况,Java 很多容器类的迭代器中都有 fail-fast 机制。如果在迭代的过程中发现 expectedModCount 变化,说明数据被修改,此时就会提前抛出 ConcurrentModificationException 异常(当然也不一定是被其他线程修改)。

其实,这 3 个迭代器都是 HashIterator 的子类,每个子类在 HashIterator#nextNode() 中获取不同的值:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
java复制代码final class KeyIterator extends HashIterator implements Iterator<K> {
public final K next() { return nextNode().key; }
}

final class ValueIterator extends HashIterator implements Iterator<V> {
public final V next() { return nextNode().value; }
}

final class EntryIterator extends HashIterator implements Iterator<Map.Entry<K,V>> {
public final Map.Entry<K,V> next() { return nextNode(); }
}

// 非静态内部类
abstract class HashIterator {
Node<K,V> next; // next entry to return
Node<K,V> current; // current entry
int expectedModCount; // for fast-fail
int index; // current slot

HashIterator() {
// 记录外部类的修改计数
expectedModCount = modCount;
// 记录底层数组
Node<K,V>[] t = table;
current = next = null;
index = 0;
if (t != null && size > 0) { // advance to first entry
do {} while (index < t.length && (next = t[index++]) == null);
}
}

public final boolean hasNext() {
return next != null;
}

final Node<K,V> nextNode() {
Node<K,V>[] t;
Node<K,V> e = next;
// 检查修改记录
if (modCount != expectedModCount)
throw new ConcurrentModificationException();
if (e == null)
throw new NoSuchElementException();
// TreeNode 也会用 next 指针串联
if ((next = (current = e).next) == null && (t = table) != null) {
do {} while (index < t.length && (next = t[index++]) == null);
}
return e;
}
...
}

基于这 3 个迭代器,HashMap 的遍历方式就分为 3 种:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
java复制代码// 1. 直接遍历节点
Iterator<Entry<String, Integer>> iterator = map.entrySet().iterator();
while (iterator.hasNext()) {
Entry<String, Integer> next = iterator.next();
}

// 2. 遍历 Key,再通过 Key 查询 Value(性能最差,多一次查询)
Iterator<String> keyIterator = map.keySet().iterator();
while (keyIterator.hasNext()) {
String key = keyIterator.next();
}

// 3. 直接遍历 Value
Iterator<Integer> valueIterator = map.values().iterator();
while (valueIterator.hasNext()) {
Integer value = valueIterator.next();
}

// foreach 是语法糖
for (Map.Entry<String, Integer> entry : map.entrySet()) {
}
// 编译后:
Iterator var2 = map.entrySet().iterator();
while(var2.hasNext()) {
Entry<String, Integer> entry = (Entry)var2.next();
}

3.8 HashMap 的序列化过程

HashMap 重写了 JDK 序列化的逻辑,只把 table 数组中有效元素的部分序列化,而不会序列化整个数组。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
java复制代码// 序列化过程
private void writeObject(java.io.ObjectOutputStream s) throws IOException {
int buckets = capacity();
s.defaultWriteObject();
// 写入容量
s.writeInt(buckets);
// 写入有效元素个数
s.writeInt(size);
// 写入有效元素
internalWriteEntries(s);
}

// 不关心键值对所在的桶,在反序列化会重新映射
void internalWriteEntries(java.io.ObjectOutputStream s) throws IOException {
Node<K,V>[] tab;
if (size > 0 && (tab = table) != null) {
for (int i = 0; i < tab.length; ++i) {
for (Node<K,V> e = tab[i]; e != null; e = e.next) {
s.writeObject(e.key);
s.writeObject(e.value);
}
}
}
}

3.9 HashMap 的 clone() 过程

HashMap 中的 table 数组是引用类型,因此在 clone() 中需要实现深拷贝,否则原对象与克隆对象会相互影响:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
java复制代码public Object clone() {
HashMap<K,V> result;
try {
result = (HashMap<K,V>)super.clone();
} catch (CloneNotSupportedException e) {
// this shouldn't happen, since we are Cloneable
throw new InternalError(e);
}
// 重置变量
result.reinitialize();
// 深拷贝
result.putMapEntries(this, false);
return result;
}
  1. 总结

今天,我们分析了 HashMap 的设计思路和核心源码,内容很多,收获也很多。其中,红黑树的部分我们没有展开讨论,这部分我们留到下一篇文章里讨论。请关注。


一道题目:

在网上看到一道题目,问题挺有迷惑性的:

  • 准备用 HashMap 存 1w 条数据,在构造时传 1w 容量,在添加时还会触发扩容吗?(答案是不会)
  • 准备用 HashMap 存 1k 条数据,在构造时传 1k 容量,在添加时还会触发扩容吗?(答案是会)

这是想考对 HashMap 容量和扩容阈值的理解了。在构造器中传递的 initialCapacity 并不一定是最终的容量,因为 HashMap 会使用 tableSizeFor() 方法计算一个最近的 2 的整数幂,而扩容阈值是在容量的基础上乘以默认的 0.75 装载因子上限。

因此,以上两种情况中,实际的容量和扩容阈值是:

  • 1w: 10000 转最近的 2 的整数幂是 16384,再乘以装载因子上限得出扩容阈值为 12288,所以不会触发扩容;
  • 1k: 1000 转最近的 2 的整数幂是 1024,再乘以装载因子上限得出扩容阈值为 768,所以会触发扩容;

版权声明

本文为稀土掘金技术社区首发签约文章,14天内禁止转载,14天后未获授权禁止转载,侵权必究!

参考资料

  • 数据结构与算法分析 · Java 语言描述(第 5 章 · 散列)—— [美] Mark Allen Weiss 著
  • 算法导论(第 11 章 · 散列表)—— [美] Thomas H. Cormen 等 著
  • 数据结构与算法之美(第 18~22 讲) —— 王争 著,极客时间 出品
  • Java:这是一份详细&全面的HashMap 1.7 源码分析 —— Carson 著
  • Java源码分析:HashMap 1.8 相对于1.7 到底更新了什么? —— Carson 著
  • 都说 HashMap 是线程不安全的,到底体现在哪儿? —— developer 著
  • 漫画:高并发下的HashMap —— 程序员小灰 著
  • 面试官:”准备用HashMap存1w条数据,构造时传10000还会触发扩容吗?” —— 承香墨影 著
  • 散列算法 —— Wikipedia
  • Poisson Distribution —— Wikipedia

推荐阅读

Java & Android 集合框架系列文章目录(2023/07/08 更新):

  • #1 ArrayList 可以完全替代数组吗?
  • #2 说一下 ArrayList 和 LinkedList 的区别?
  • #3 CopyOnWriteArrayList 是如何保证线程安全的?
  • #4 ArrayDeque:如何用数组实现栈和队列?
  • #5 万字 HashMap 详解,基础(优雅)永不过时 —— 原理篇
  • #6 万字 HashMap 详解,基础(优雅)永不过时 —— 源码篇
  • #7 如何使用 LinkedHashMap 实现 LRU 缓存?
  • #8 说一下 WeakHashMap 如何清理无效数据的?
  • #9 全网最全的 ThreadLocal 原理详细解析 —— 原理篇
  • #10 全网最全的 ThreadLocal 原理详细解析 —— 源码篇

数据结构与算法系列文章:跳转阅读

⭐️ 永远相信美好的事情即将发生,欢迎加入小彭的 Android 交流社群~

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

Java & Android 集合框架

发表于 2022-11-09

⭐️ 本文已收录到 AndroidFamily,技术和职场问题,请关注公众号 [彭旭锐] 和 [BaguTree Pro] 知识星球提问。

学习数据结构与算法的关键在于掌握问题背后的算法思维框架,你的思考越抽象,它能覆盖的问题域就越广,理解难度也更复杂。在实际的业务开发中,往往不需要我们手写数据结构,而是直接使用标准库的数据结构 / 容器类。

本文是 Java & Android 集合框架系列的第 5 篇文章,完整文章目录请移步到文章末尾~

前言

大家好,我是小彭。

在上一篇文章里,我们聊到了散列表的整体设计思想,在后续几篇文章里,我们将以 Java 语言为例,分析标准库中实现的散列表实现,包括 HashMap、ThreadLocalMap、LinkedHashMap 和 ConcurrentHashMap。

今天,我们来讨论 Java 标准库中非常典型的散列表结构,也是 “面试八股文” 的标准题库之一 —— HashMap。

本文源码基于 Java 8 HashMap,并关联分析部分 Java 7 HashMap。

  • Java & Android 集合框架 #5 万字 HashMap 详解,基础(优雅)永不过时 —— 原理篇
  • Java & Android 集合框架 #6 万字 HashMap 详解,基础(优雅)永不过时 —— 源码篇

思维导图:


  1. 回顾散列表工作原理

在分析 HashMap 的实现原理之前,我们先来回顾散列表的工作原理。

散列表是基于散列思想实现的 Map 数据结构,将散列思想应用到散列表数据结构时,就是通过 hash 函数提取键(Key)的特征值(散列值),再将键值对映射到固定的数组下标中,利用数组支持随机访问的特性,实现 O(1) 时间的存储和查询操作。

散列表示意图

在从键值对映射到数组下标的过程中,散列表会存在 2 次散列冲突:

  • 第 1 次 - hash 函数的散列冲突: 这是一般意义上的散列冲突;
  • 第 2 次 - 散列值取余转数组下标: 本质上,将散列值转数组下标也是一次 Hash 算法,也会存在散列冲突。同时,这也说明 HashMap 中同一个桶中节点的散列值不一定是相同的。

事实上,由于散列表是压缩映射,所以我们无法避免散列冲突,只能保证散列表不会因为散列冲突而失去正确性。常用的散列冲突解决方法有 2 类:

  • 开放寻址法: 例如 ThreadLocalMap;
  • 分离链表法: 例如今天要分析的 HashMap 散列表。

分离链表法(Separate Chaining)的核心思想是: 在出现散列冲突时,将冲突的元素添加到同一个桶(Bucket / Slot)中,桶中的元素会组成一个链表,或者跳表、红黑树等动态数据结构。相比于开放寻址法,链表法是更常用且更稳定的冲突解决方法。

分离链表法示意图

影响散列表性能的关键在于 “散列冲突的发生概率”,冲突概率越低,时间复杂度越接近于 O(1)。 那么,哪些因素会影响冲突概率呢?主要有 3 个:

  • 因素 1 - 装载因子: 装载因子 (Load Factor) = 散列表中键值对数目 / 散列表的长度。随着散列表中元素越来越多,空闲位置越来越少,就会导致散列冲突的发生概率越来越大,使得散列表操作的平均时间会越来越大;
  • 因素 2 - 采用的冲突解决方法: 开放寻址法的冲突概率天然比分离链表法高,适合于小数据量且装载因子较小的场景;分离链表法对装载因子的容忍度更高,适合于大数据量且大对象(相对于一个指针)的场景;
  • 因素 3 - 散列函数设计: 散列算法随机性和高效性也会影响散列表的性能。如果散列值不够随机,即使散列表整体的装载因子不高,也会使得数据聚集在某一个区域或桶内,依然会影响散列表的性能。

  1. 认识 HashMap 散列表

2.1 说一下 HashMap 的底层结构?

HashMap 是基于分离链表法解决散列冲突的动态散列表:

  • 在 Java 7 中使用的是 “数组 + 链表”,发生散列冲突的键值对会用头插法添加到单链表中;
  • 在 Java 8 中使用的是 “数组 + 链表 + 红黑树”,发生散列冲突的键值对会用尾插法添加到单链表中。如果链表的长度大于 8 时且散列表容量大于 64,会将链表树化为红黑树。在扩容再散列时,如果红黑树的长度低于 6 则会还原为链表;
  • HashMap 的数组长度保证是 2 的整数幂,默认的数组容量是 16,默认装载因子上限是 0.75,扩容阈值是 12(16*0.75);
  • 在创建 HashMap 对象时,并不会创建底层数组,这是一种懒初始化机制,直到第一次 put 操作才会通过 resize() 扩容操作初始化数组;
  • HashMap 的 Key 和 Value 都支持 null,Key 为 null 的键值对会映射到数组下标为 0 的桶中。

2.2 为什么 HashMap 采用拉链法而不是开放地址法?

我认为 Java 给予 HashMap 的定位是一个相对 “通用” 的散列表容器,它应该在面对各种输入场景中都表现稳定。

开放地址法的散列冲突发生概率天然比分离链表法更高,所以基于开放地址法的散列表不能把装载因子的上限设置得很高。在存储相同的数据量时,开放地址法需要预先申请更大的数组空间,内存利用率也不会高。因此,开放地址法只适合小数据量且装载因子较小的场景。

而分离链表法对于装载因子的容忍度更高,能够适合大数据量且更高的装载因子上限,内存利用率更高。虽然链表节点会多消耗一个指针内存,但在一般的业务场景中可以忽略不计。

我们可以举个反例,在 Java 原生的数据结构中,也存在使用开放地址法的散列表 —— 就是 ThreadlLocal。因为项目中不会大量使用 ThreadLocal 线程局部存储,所以它是一个小规模数据场景,这里使用开放地址法是没问题的。

2.3 为什么 HashMap 在 Java 8 要引入红黑树呢?

因为当散列冲突加剧的时候,在链表中寻找对应元素的时间复杂度是 O(K),K 是链表长度。在极端情况下,当所有数据都映射到相同链表时,时间复杂度会 “退化” 到 O(n)。

而使用红黑树(近似平衡的二叉搜索树)的话,树形结构的时间复杂度与树的高度有关, 查找复杂度是 O(lgK),最坏情况下时间复杂度是 O(lgn),时间复杂度更低。

2.4 为什么 HashMap 使用红黑树而不是平衡二叉树?

这是在查询性能和维护成本上的权衡,红黑树和平衡二叉树的区别在于它们的平衡程度的强弱不同:

平衡二叉树追求的是一种 “完全平衡” 状态:任何结点的左右子树的高度差不会超过 1。优势是树的结点是很平均分配的;

红黑树不追求这种完全平衡状态,而是追求一种 “弱平衡” 状态:整个树最长路径不会超过最短路径的 2 倍。优势是虽然牺牲了一部分查找的性能效率,但是能够换取一部分维持树平衡状态的成本。

2.5 为什么经常使用 String 作为 HashMap 的 Key?

  • 1、不可变类 String 可以避免修改后无法定位键值对: 假设 String 是可变类,当我们在 HashMap 中构建起一个以 String 为 Key 的键值对时,此时对 String 进行修改,那么通过修改后的 String 是无法匹配到刚才构建过的键值对的,因为修改后的 hashCode 可能会变化,而不可变类可以规避这个问题。
  • 2、String 能够满足 Java 对于 hashCode() 和 equals() 的通用约定: 既两个对象 equals() 相同,则 hashCode() 相同,如果 hashCode() 相同,则 equals() 不一定相同。这个约定是为了避免两个 equals() 相同的 Key 在 HashMap 中存储两个独立的键值对,引起矛盾。

2.6 HashMap 的多线程程序中会出现什么问题?

  • 数据覆盖问题:如果两个线程并发执行 put 操作,并且两个数据的 hash 值冲突,就可能出现数据覆盖(线程 A 判断 hash 值位置为 null,还未写入数据时挂起,此时线程 B 正常插入数据。接着线程 A 获得时间片,由于线程 A 不会重新判断该位置是否为空,就会把刚才线程 B 写入的数据覆盖掉)。事实上,这个未同步数据在任意多线程环境中都会存在这个问题。
  • 环形链表问题: 在 HashMap 触发扩容时,并且正好两个线程同时在操作同一个链表时,就可能引起指针混乱,形成环型链条(因为 Java 7 版本采用头插法,在扩容时会翻转链表的顺序,而 Java 8 采用尾插法,再扩容时会保持链表原本的顺序)。

2.7 HashMap 如何实现线程安全?

有 3 种方式:

  • 方式 1 - 使用 hashTable 容器类(过时): hashTable 是线程安全版本的散列表,它会在所有方法上增加 synchronized 关键字,且不支持 null 作为 Key。
  • 方法 2 - 使用 Collections.synchronizedMap 包装类: 原理也是在所有方法上增加 synchronized 关键字;
  • 方法 3 - 使用 ConcurrentHashMap 容器类: 基于 CAS 无锁 + 分段实现的线程安全散列表;

  1. HashMap 的属性

在分析 HashMap 的执行流程之前,我们先用一个表格整理 HashMap 的属性:

版本 数据结构 节点实现类 属性
Java 7 数组 + 链表 Entry(单链表) 1、table(数组)2、size(尺寸)3、threshold(扩容阈值)4、loadFactor(装载因子上限)5、modCount(修改计数)6、默认数组容量 167、最大数组容量 2^308、默认负载因子 0.75
Java 8 数组 + 链表 + 红黑树 1、Node(单链表)2、TreeNode(红黑树) 9、桶的树化阈值 810、桶的还原阈值 611、最小树化容量阈值 64

HashMap.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
java复制代码public class HashMap<K,V> extends AbstractMap<K,V>
implements Map<K,V>, Cloneable, Serializable {

// 默认数组容量
static final int DEFAULT_INITIAL_CAPACITY = 1 << 4; // aka 16

// 疑问 3:为什么最大容量是 2^30 次幂?
// 疑问 4:为什么 HashMap 要求数组的容量是 2 的整数幂?
// 数组最大容量:2^30(高位 0100,低位都是 0)
static final int MAXIMUM_CAPACITY = 1 << 30;

// 默认负载因子:0.75
static final float DEFAULT_LOAD_FACTOR = 0.75f;

// 疑问 5:为什么要设置桶的树化阈值,而不是直接使用数组 + 红黑树?
// (Java 8 新增)桶的树化阈值:8
static final int TREEIFY_THRESHOLD = 8;

// (Java 8 新增)桶的还原阈值:6(在扩容时,当原有的红黑树内数量 <= 6时,则将红黑树还原成链表)
static final int UNTREEIFY_THRESHOLD = 6;

// 疑问 6:为什么要在设置桶的树化阈值后,还要设置树化的最小容量?
// (Java 8 新增)树化的最小容量:64(只有整个散列表的长度满足最小容量要求时才允许链表树化,否则会直接扩容,而不是树化)
static final int MIN_TREEIFY_CAPACITY = 64;

// 底层数组(每个元素是一个单链表或红黑树)
transient Node<K,V>[] table;

// entrySet() 返回值缓存
transient Set<Map.Entry<K,V>> entrySet;

// 有效键值对数量
transient int size;

// 扩容阈值(容量 * 装载因子)
int threshold;

// 装载因子上限
final float loadFactor;

// 修改计数
transient int modCount;

// 链表节点(一个 Node 等于一个键值对)
static class Node<K,V> implements Map.Entry<K,V> {
// 哈希值(相同链表上 Key 的哈希值可能相同)
final int hash;
// Key(一个散列表上 Key 的 equals() 一定不同)
final K key;
// Value(Value 不影响节点位置)
V value;
Node<K,V> next;

Node(int hash, K key, V value, Node<K,V> next) {
this.hash = hash;
this.key = key;
this.value = value;
this.next = next;
}

// Node 的 hashCode 取 Key 和 Value 的 hashCode
public final int hashCode() {
return Objects.hashCode(key) ^ Objects.hashCode(value);
}

// 两个 Node 的 Key 和 Value 都相等,才认为相等
public final boolean equals(Object o) {
if (o == this)
return true;
if (o instanceof Map.Entry) {
Map.Entry<?,?> e = (Map.Entry<?,?>)o;
if (Objects.equals(key, e.getKey()) &&
Objects.equals(value, e.getValue()))
return true;
}
return false;
}
}

// (Java 8 新增)红黑树节点
static final class TreeNode<K,V> extends LinkedHashMap.Entry<K,V> {
// 父节点
TreeNode<K,V> parent;
// 左子节点
TreeNode<K,V> left;
// 右子节点
TreeNode<K,V> right;
// 删除辅助节点
TreeNode<K,V> prev;
// 颜色
boolean red;

TreeNode(int hash, K key, V val, Node<K,V> next) {
super(hash, key, val, next);
}

// 返回树的根节点
final TreeNode<K,V> root() {
for (TreeNode<K,V> r = this, p;;) {
if ((p = r.parent) == null)
return r;
r = p;
}
}
}
}

LinkedHashMap.java

1
2
3
4
5
6
java复制代码static class Entry<K,V> extends HashMap.Node<K,V> {
Entry<K,V> before, after;
Entry(int hash, K key, V value, Node<K,V> next) {
super(hash, key, value, next);
}
}

相比于线性表,HashMap 的属性可算是上难度了,HashMap 真卷。不出意外的话又有小朋友出来举手提问了🙋🏻‍♀️:

  • 🙋🏻‍♀️疑问 1: 为什么字段不声明 private 关键字?(回答过多少次了,把手给我放下)
  • 🙋🏻‍♀️疑问 2: 为什么字段声明 transient 关键字?(回答过多少次了,把手给我放下)
  • 🙋🏻‍♀️疑问 3:为什么最大容量是 2^30?

因为 HashMap 要求散列表的数组容量是 2 的整数幂 ,而 int 类型能够表示的最大 2 的整数幂就是 2^30,即高位第 31 位是 1,低位都是 0。

  • 🙋🏻‍♀️疑问 4:为什么 HashMap 要求数组的容量是 2 的整数幂?

这个问题我们下面再回答。

  • 🙋🏻‍♀️疑问 5:为什么要设置桶的树化阈值,而不是直接使用数组 + 红黑树?

其实,红黑树是 “兜底” 策略,而不一定是最优策略。

首先,红黑树节点本身的内存消耗是链表节点的 2 倍。其次,红黑树在添加和删除数据时需要维护红黑树的性质,会增加旋转等操作。所以,当桶的节点数很低时,并不能体现出红黑树的优势(类似于 Arrays.sort 在子数组长度小于 47 时用插入排序而不是快速排序)。

再结合散列分析的数据统计,在装载因子上限为 0.75 且平均负载因子为 0.5 HashMap 中,桶长度的出现频率符合泊松分布,大部分的桶分布在 0 ~ 3 的长度上,长度大于 8 的桶的出现频率低于千万分之一。

综上所述,为了避免在小桶中使用红黑树,HashMap 在桶的长度大于等于 8 时才会树化为红黑树。并且在扩容再散列时,如果桶的长度小于等于 6,也会还原为链表。

散列冲突数据统计

1
2
3
4
5
6
7
8
9
10
11
bash复制代码# 装载因子上限为 0.75、平均负载因子为 0.5,且散列函数随机性良好时,不同长度桶的出现频率
0: 0.60653066
1: 0.30326533
2: 0.07581633
3: 0.01263606
4: 0.00157952
5: 0.00015795
6: 0.00001316
7: 0.00000094
8: 0.00000006
more: less than 1 in ten million # 低于千万分之一
  • 🙋🏻‍♀️疑问 6:为什么要在设置桶的树化阈值后,还要设置树化的最小容量?

这是为了避免无效的树化。

在散列表的容量较低时,添加数据时很容易会触发扩容。此时,一部分原本已经树化的桶会由于长度下降而退还回链表。因此,红黑树为树化操作设置了最小容量要求:如果链表长度达到树化阈值,但散列表整体的长度未达到最小容量要求,那么就直接扩容,而不是在桶上树化。


后续源码分析,见下一篇文章:Java & Android 集合框架 #6 万字 HashMap 详解,基础(优雅)永不过时 —— 源码篇。


版权声明

本文为稀土掘金技术社区首发签约文章,14天内禁止转载,14天后未获授权禁止转载,侵权必究!

参考资料

  • 数据结构与算法分析 · Java 语言描述(第 5 章 · 散列)—— [美] Mark Allen Weiss 著
  • 算法导论(第 11 章 · 散列表)—— [美] Thomas H. Cormen 等 著
  • 数据结构与算法之美(第 18~22 讲) —— 王争 著,极客时间 出品
  • Java:这是一份详细&全面的HashMap 1.7 源码分析 —— Carson 著
  • Java源码分析:HashMap 1.8 相对于1.7 到底更新了什么? —— Carson 著
  • 都说 HashMap 是线程不安全的,到底体现在哪儿? —— developer 著
  • 漫画:高并发下的HashMap —— 程序员小灰 著
  • 面试官:”准备用HashMap存1w条数据,构造时传10000还会触发扩容吗?” —— 承香墨影 著
  • 散列算法 —— Wikipedia
  • Poisson Distribution —— Wikipedia

推荐阅读

Java & Android 集合框架系列文章目录(2023/07/08 更新):

  • #1 ArrayList 可以完全替代数组吗?
  • #2 说一下 ArrayList 和 LinkedList 的区别?
  • #3 CopyOnWriteArrayList 是如何保证线程安全的?
  • #4 ArrayDeque:如何用数组实现栈和队列?
  • #5 万字 HashMap 详解,基础(优雅)永不过时 —— 原理篇
  • #6 万字 HashMap 详解,基础(优雅)永不过时 —— 源码篇
  • #7 如何使用 LinkedHashMap 实现 LRU 缓存?
  • #8 说一下 WeakHashMap 如何清理无效数据的?
  • #9 全网最全的 ThreadLocal 原理详细解析 —— 原理篇
  • #10 全网最全的 ThreadLocal 原理详细解析 —— 源码篇

数据结构与算法系列文章:跳转阅读

⭐️ 永远相信美好的事情即将发生,欢迎加入小彭的 Android 交流社群~

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

【AI】浅谈梯度下降算法(实战篇) 前言 大家族 一维问题

发表于 2022-11-08

本文正在参加「金石计划 . 瓜分6万现金大奖」

前言

在求解机器学习算法的模型参数,即无约束优化问题时,梯度下降(Gradient Descent) 是最常采用的方法之一,另一种常用的方法是最小二乘法。

在 【AI】浅谈梯度下降算法(理论篇) 这篇博文中,我们已经学习了梯度下降算法的一些基本概念以及理论推导,接下来,我们将通过结合代码进行实战,理论与实践相结合,加深对知识点的理解;

大家族

尽管说是梯度下降,但其实它还是个庞大的家族,就类似于编程语言有 C、Java、Python 等之分,梯度下降算法也被分为了几大类,主要的有 BGD、SGD、MBGD:

  • 批量梯度下降法(Batch Gradient Descent) : 梯度下降法最常用的形式,具体做法也就是在更新参数时使用所有的样本来进行更新;

θi=θi−α∑j=0m(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)θ_i=θ_i−α∑_{j=0}^m(h_θ(x^{(j)}_0,x^{(j)}_1,…,x^{(j)}_n)−y_j)x_i^{(j)}θi=θi−αj=0∑m(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)
优点:全局最优解,易于并行实现;

缺点:计算代价大,数据量大时,训练过程慢;

  • 随机梯度下降法(Stochastic Gradient Descent) : 和批量梯度下降法原理类似,区别在于求梯度时,没有用所有的 mmm 个样本的数据,而是仅仅选取一个样本 jjj 来求梯度;

θi=θi−α(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)θ_i=θ_i−α(h_θ(x^{(j)}_0,x^{(j)}_1,…,x^{(j)}_n)−y_j)x_i^{(j)}θi=θi−α(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)
优点:训练速度快;

缺点:准确度下降,并不是全局最优,不易于并行实现;

  • 小批量梯度下降法(Mini-batch Gradient Descent) : 小批量梯度下降法是批量梯度下降法和随机梯度下降法的折中,也就是对于 mmm 个样本,我们采用 xxx 个样本来迭代,1<x<m1<x<m1<x<m。一般可以取 x=10x=10x=10,当然根据样本的数据量,可以调整这个 xxx 的值;

θi=θi−α∑j=tt+x−1(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)θ_i=θ_i−α∑_{j=t}^{t+x-1}(h_θ(x^{(j)}_0,x^{(j)}_1,…,x^{(j)}_n)−y_j)x_i^{(j)}θi=θi−αj=t∑t+x−1(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)
前两种方法的性能折中;

一维问题

例1:求 f(x)=x2+1f(x) = x^2 + 1f(x)=x2+1 的最小值

image.png

f(x)=x2+1f(x) = x^2 + 1f(x)=x2+1
使用梯度下降法求 f(x)=x2+1(−10≤x≤10)f(x) = x^2 + 1 \quad (-10 \leq x \leq 10)f(x)=x2+1(−10≤x≤10) 的最小值

因为 f(x)=x2+1f(x) = x^2 + 1f(x)=x2+1 是凸函数,从图中也可以一眼看出,其最小值就在 x=0x=0x=0 处;

接下来就使用梯度下降法进行求解:

1、目标函数,即 f(x)=x2+1f(x) = x^2 + 1f(x)=x2+1 :

1
2
py复制代码def func_target(x):
return x ** 2 + 1

2、求解梯度,即 f(x)′=2xf(x)^{‘} = 2xf(x)′=2x :

1
2
py复制代码def func_gradient(x):
return x * 2

3、梯度下降算法,需要注意几个参数的意义:

  • x : 当前 x 的值,可以通过参数提供初始值;
  • learn_rate : 学习率,相当于设置的步长;
  • precision : 收敛精度;
  • max_iters : 最大迭代次数;
1
2
3
4
5
6
7
8
9
10
11
12
13
py复制代码def SGD(x=1, learn_rate=0.1, precision=1e-5, max_iters=10000):
for i in range(max_iters):
grad_cur = func_gradient(x)
if abs(grad_cur) < precision:
break
x = x - learn_rate * grad_cur
print(f"第 {i+1} 次迭代: x 值为 {x}, y 值为 {func_target(x)}")

print(f"\n最小值 x = {x}, y = {func_target(x)}")
return x

if __name__ == '__main__':
SGD(x=10, learn_rate=0.2)

image.png

例2:求 12[(x1+x2−4)2+(2x1+3x2−7)2]\frac{1}{2}[(x_1+x_2-4)^2 + (2x_1+3x_2-7)^2]21[(x1+x2−4)2+(2x1+3x2−7)2] 的极值

通过梯度下降的方法成功求得了 的最小值之后是不是信心大增呢,接下来让我们逐步加深难度:使用梯度下降法求多项式 12[(x1+x2−4)2+(2x1+3x2−7)2]\frac{1}{2}[(x_1+x_2-4)^2 + (2x_1+3x_2-7)^2]21[(x1+x2−4)2+(2x1+3x2−7)2] 的极值;

在使用梯度下降求解这道题的过程中,就不得不注意到一个问题:梯度下降可能在局部最小的点收敛;

image.png

1、目标函数,即 12[(x1+x2−4)2+(2x1+3x2−7)2\frac{1}{2}[(x_1+x_2-4)^2 + (2x_1+3x_2-7)^221[(x1+x2−4)2+(2x1+3x2−7)2 :

1
2
py复制代码def func_target(x1, x2):
return ((x1 + x2 - 4) ** 2 + (2*x1 + 3*x2 -7) ** 2) * 0.5

2、求解梯度,即 ∂f∂x1=(x1+x2−4)+2(2x1+3x2−7)\frac{∂f}{∂x_1} = (x_1+x_2-4)+2(2x_1+3x_2-7)∂x1∂f=(x1+x2−4)+2(2x1+3x2−7) 和 ∂f∂x2=(x1+x2−4)+3(2x1+3x2−7)\frac{∂f}{∂x_2} = (x_1+x_2-4)+3(2x_1+3x_2-7)∂x2∂f=(x1+x2−4)+3(2x1+3x2−7):

1
2
3
4
py复制代码def func_gradient(x1, x2):
grad_x1 = (x1 + x2 - 4) + 2 * (2*x1 + 3*x2 -7)
grad_x2 = (x1 + x2 - 4) + 3 * (2*x1 + 3*x2 -7)
return grad_x1, grad_x2

3、梯度下降算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
py复制代码def SGD(x1=0, x2=0, learn_rate=0.01, precision=1e-6, max_iters=10000):
y1 = func_target(x1, x2)
for i in range(max_iters):
grad_x1, grad_x2 = func_gradient(x1, x2)
x1 = x1 - learn_rate * grad_x1
x2 = x2 - learn_rate * grad_x2
y2 = func_target(x1, x2)
if (y1 - y2) < precision:
break
if y2 < y1: y1 = y2
print(f"第 {i+1} 次迭代: x1 值为 {x1}, x2 值为 {x2}, 输出值为 {y2}")

print(f"该多项式的极小值为 {y2}, ({x1}, {x2})")
return x1, x2, y2

if __name__ == '__main__':
SGD()

image.png

image.png

中间的迭代过程就省略了;

二维问题

当你通过自己的努力完成前两个例子之后,你是不是已经不满足于一维问题了呢,那么接下来我们进入二维问题:使用梯度下降法求 f(x,y)=−e−(x2+y2)f(x,y) = -e^{-(x^2+y^2)}f(x,y)=−e−(x2+y2) 在 [0,0][0,0][0,0] 处有最小值;

image.png

f(x,y)=−e−(x2+y2)f(x,y) = -e^{-(x^2+y^2)}f(x,y)=−e−(x2+y2)
通过这个例子,我们会发现梯度下降的局限性,先在这里留个悬念;

1、目标函数,即 f(x,y)=−e−(x2+y2)f(x,y) = -e^{-(x^2+y^2)}f(x,y)=−e−(x2+y2) :

1
2
3
py复制代码def func_target(cell):
:param cell: 二维向量
return -math.exp(-(cell[0] ** 2 + cell[1] ** 2))

2、求解梯度,即 ∂f∂x=2xe−(x2+y2)\frac{∂f}{∂x} = 2xe^{-(x^2+y^2)}∂x∂f=2xe−(x2+y2) 和 ∂f∂y=2ye−(x2+y2)\frac{∂f}{∂y} = 2ye^{-(x^2+y^2)}∂y∂f=2ye−(x2+y2):

1
2
3
4
5
py复制代码def func_gradient(cell):
:param cell: 二维向量
grad_x = 2 * cell[0] * math.exp(-(cell[0] ** 2 + cell[1] ** 2))
grad_y = 2 * cell[1] * math.exp(-(cell[0] ** 2 + cell[1] ** 2))
return np.array([grad_x, grad_y])

3、梯度下降算法:

1
2
3
4
5
6
7
8
9
10
py复制代码def SGD(x=np.array([0.1, 0.1]), learn_rate=0.1, precision=1e-6, max_iters=10000):
for i in range(max_iters):
grad_cur = func_gradient(x)
if np.linalg.norm(grad_cur, ord=2) < precision:
break
x = x - learn_rate * grad_cur
print(f"第 {i+1} 次迭代: x 值为 {x}, y 值为 {func_target(x)}")

print(f"\n最小值 x = {x}, y = {func_target(x)}")
return x

4、当 x0x0x0 的初始值设为 [1,−1][1,−1][1,−1] 时,一切都显得很正常:

image.png

5、但当我们把 x0x0x0 的初始值设为 [3,−3][3,−3][3,−3] 时,结果是出乎意料的:
image.png

梯度下降法没有找到真正的极小值点!

局限性

继续讲述上面的非预期结果:

如果仔细观察目标函数的图像,以及梯度下降法的算法原理,你就很容易发现问题所在了。在 [3,−3][3,−3][3,−3] 处的梯度就几乎为 0 了!

由于“梯度过小”,梯度下降法可能无法确定前进的方向了。即使人为增加收敛条件中的精度,也会由于梯度过小,导致迭代中前进的步长距离过小,循环时间过长。

梯度下降法实现简单,原理也易于理解,但它有自身的局限性,因此有了后面很多算法对它的改进。

对于梯度过小的情况,梯度下降法可能难以求解。

此外,梯度下降法适合求解只有一个局部最优解的目标函数,对于存在多个局部最优解的目标函数,一般情况下梯度下降法不保证得到全局最优解(由于凸函数有个性质是只存在一个局部最优解,所有也有文献的提法是:当目标函数是凸函数时,梯度下降法的解才是全局最优解)。

由于泰勒公式的展开是近似公式,要求迭代步长要足够小,因此梯度下降法的收敛速度并非很快的。

后记

上述就是本篇博文的所有内容了,通过实战对梯度下降知识点进行巩固和加深印象,并且层层收入,希望读者能有所收获!

对于理论还不是很清楚的读者,可以回看上篇博文:【AI】浅谈梯度下降算法(理论篇);

参考:

  • 梯度下降(Gradient Descent)
  • Python 实现简单的梯度下降法
  • 梯度下降法原理与python实现

📝 上篇精讲:【AI】浅谈梯度下降算法(理论篇)

💖 我是 𝓼𝓲𝓭𝓲𝓸𝓽,期待你的关注;

👍 创作不易,请多多支持;

🔥 系列专栏:AI

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

【AI】浅谈梯度下降算法(理论篇) 前言 梯度 梯度下降 算

发表于 2022-11-08

本文正在参加「金石计划 . 瓜分6万现金大奖」

前言

在求解机器学习算法的模型参数,即无约束优化问题时,梯度下降(Gradient Descent) 是最常采用的方法之一,另一种常用的方法是最小二乘法。

目前正在学习这方面相关的内容,因此简单谈谈与梯度下降法相关的内容。

梯度

在微积分里面,对多元函数的参数求 ∂∂∂ 偏导数,把求得的各个参数的偏导数以向量的形式写出来,就是梯度。

比如函数 f(x,y)f(x, y)f(x,y),分别对 xxx, yyy 求偏导数,求得的梯度向量就是 (∂f∂x, ∂f∂y)T(\frac{∂f}{∂x}, \frac{∂f}{∂y})^T(∂x∂f, ∂y∂f)T,简称 gradf(x,y)grad \quad f(x,y)gradf(x,y) 或者 ▽f(x,y)▽f(x,y)▽f(x,y)。对于在点 (x0,y0)(x_0,y_0)(x0,y0) 的具体梯度向量就是 (∂f∂x0, ∂f∂y0)T(\frac{∂f}{∂x_0}, \frac{∂f}{∂y_0})^T(∂x0∂f, ∂y0∂f)T 或者 ▽f(x0,y0)▽f(x_0,y_0)▽f(x0,y0),如果是3个参数的向量梯度,就是 (∂f∂x, ∂f∂y,∂f∂z)T(\frac{∂f}{∂x}, \frac{∂f}{∂y},\frac{∂f}{∂z})^T(∂x∂f, ∂y∂f,∂z∂f)T,以此类推。

那么这个梯度向量求出来有什么意义呢?他的意义从几何意义上讲,就是函数变化增加最快的地方。具体来说,对于函数 f(x,y)f(x,y)f(x,y),在点 (x0,y0)(x_0,y_0)(x0,y0),沿着梯度向量的方向,即 (∂f∂x0, ∂f∂y0)T(\frac{∂f}{∂x_0}, \frac{∂f}{∂y_0})^T(∂x0∂f, ∂y0∂f)T 的方向,是 f(x,y)f(x,y)f(x,y) 增加最快的地方。或者说,沿着梯度向量的方向,更加容易找到函数的最大值。反过来说,沿着梯度向量相反的方向,也就是 −(∂f∂x0, ∂f∂y0)T-(\frac{∂f}{∂x_0}, \frac{∂f}{∂y_0})^T−(∂x0∂f, ∂y0∂f)T 的方向,梯度减少最快,也就是更加容易找到函数的最小值。

梯度下降

image.png

梯度下降法(英语:Gradient descent)是一个一阶最优化算法,通常也称为最陡下降法,但是不该与近似积分的最陡下降法(英语:Method of steepest descent)混淆。 要使用梯度下降法找到一个函数的局部极小值,必须向函数上当前点对应梯度(或者是近似梯度)的 反方向 的规定步长距离点进行迭代搜索。如果相反地向梯度 正方向 迭代进行搜索,则会接近函数的局部极大值点;这个过程则被称为梯度上升法。

上述对梯度下降法的描述来自于维基百科,简单概括一下就是 选取适当的初值 x0x_0x0,不断迭代更新 xxx 的值,极小化目标函数,最终收敛;

在进行算法推导时,我们还需要注意一些概念:

  1. 步长(Learning rate):步长决定了在梯度下降迭代的过程中,每一步沿梯度负方向前进的长度。
  2. 特征(feature):指的是样本中输入部分,比如2个单特征的样本 (x(0),y(0))(x^{(0)},y^{(0)})(x(0),y(0)),(x(1),y(1))(x^{(1)},y^{(1)})(x(1),y(1)),则第一个样本特征为 x(0)x^{(0)}x(0),第一个样本输出为 y(0)y^{(0)}y(0)。
  3. 假设函数(hypothesis function):在监督学习中,为了拟合输入样本,而使用的假设函数,记为 hθ(x)h_θ(x)hθ(x)。比如对于单个特征的 m 个样本 (x(i),y(i))(i=1,2,…,m)(x^{(i)},y^{(i)})(i=1,2,…,m)(x(i),y(i))(i=1,2,…,m),可以采用拟合函数如下: hθ(x)=θ0+θ1xh_θ(x)=θ_0+θ_1xhθ(x)=θ0+θ1x。
  4. 损失函数(loss function):为了评估模型拟合的好坏,通常用损失函数来度量拟合的程度。损失函数极小化,意味着拟合程度最好,对应的模型参数即为最优参数。在线性回归中,损失函数通常为样本输出和假设函数的差取平方。 比如对于 m 个样本 (xi,yi)(i=1,2,…,m)(x_i,y_i)(i=1,2,…,m)(xi,yi)(i=1,2,…,m),采用线性回归,损失函数为:
    J(θ0,θ1)=∑i=1m(hθ(xi)−yi)2J(θ_0,θ_1)=∑_{i=1}^m(h_θ(x_i)−y_i)^2J(θ0,θ1)=i=1∑m(hθ(xi)−yi)2
    其中 xix_ixi 表示第 iii 个样本特征,yiy_iyi 表示第 iii 个样本对应的输出,hθ(xi)h_θ(x_i)hθ(xi) 为假设函数。

算法推导

先决条件: 在线性回归的前提下,确认优化模型的假设函数和损失函数。

1、确定当前位置的损失函数的梯度,对于 θiθ_iθi,其梯度表达式如下:

∂∂θiJ(θ0,θ1…,θn)\frac{∂}{∂θ_i}J(θ_0,θ_1…,θ_n)∂θi∂J(θ0,θ1…,θn)
2、用步长 ααα (这里指机器学习中的学习率更为合适) 乘以损失函数的梯度,得到当前位置下降的距离,即

α∂∂θiJ(θ0,θ1…,θn)α\frac{∂}{∂θ_i}J(θ_0,θ_1…,θ_n)α∂θi∂J(θ0,θ1…,θn)
3、确定是否所有的 θiθ_iθi,梯度下降的距离都小于 εεε,如果小于 εεε 则算法终止,当前所有的 θi(i=0,1,…n)θ_i(i=0,1,…n)θi(i=0,1,…n) 即为最终结果,否则进入步骤4;

4、更新所有的 θθθ,对于 θiθ_iθi,其更新表达式如下,更新完毕后继续转入步骤1;

θi=θi−α∂∂θiJ(θ0,θ1…,θn)θ_i=θ_i−α\frac{∂}{∂θ_i}J(θ_0,θ_1…,θ_n)θi=θi−α∂θi∂J(θ0,θ1…,θn)

TIP

损失函数如前面先决条件所述:

J(θ0,θ1…,θn)=12m∑j=0m(hθ(x0(j),x1(j),…,xn(j))−yj)2J(θ_0,θ_1…,θ_n)=\frac{1}{2m}∑_{j=0}^m(h_θ(x^{(j)}_0,x^{(j)}_1,…,x^{(j)}_n)−y_j)^2J(θ0,θ1…,θn)=2m1j=0∑m(hθ(x0(j),x1(j),…,xn(j))−yj)2
则在算法过程步骤1中对于 θiθ_iθi 的偏导数计算如下:

∂∂θiJ(θ0,θ1…,θn)=1m∑j=0m(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)\frac{∂}{∂θ_i}J(θ_0,θ_1…,θ_n)=\frac{1}{m}∑_{j=0}^m(h_θ(x^{(j)}_0,x^{(j)}_1,…,x^{(j)}_n)−y_j)x_i^{(j)}∂θi∂J(θ0,θ1…,θn)=m1j=0∑m(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)
由于样本中没有 x0x_0x0,上式中令所有的 x0jx^j_0x0j 为1,步骤4中 θiθ_iθi 的表达式更新如下:

θi=θi−α1m∑j=0m(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)θ_i=θ_i−α\frac{1}{m}∑_{j=0}^m(h_θ(x^{(j)}_0,x^{(j)}_1,…,x^{(j)}_n)−y_j)x_i^{(j)}θi=θi−αm1j=0∑m(hθ(x0(j),x1(j),…,xn(j))−yj)xi(j)
从这个例子可以看出当前点的梯度方向是由所有的样本决定的;

后记

上述就是本篇博文的所有内容了,比较细致的介绍了梯度以及梯度下降算法相关的内容,下一篇博文 【AI】浅谈梯度下降算法(实战篇) 我们将结合代码,通过实战对梯度下降知识点进行巩固和加深印象,深入理解其中的奥义!

参考:

  • 梯度下降(Gradient Descent)
  • Python 实现简单的梯度下降法
  • 梯度下降法原理与python实现

📝 上篇精讲:【项目实战】MNIST 手写数字识别(下)

💖 我是 𝓼𝓲𝓭𝓲𝓸𝓽,期待你的关注;

👍 创作不易,请多多支持;

🔥 系列专栏:AI

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

1…838485…956

开发者博客

9558 日志
1953 标签
RSS
© 2025 开发者博客
本站总访问量次
由 Hexo 强力驱动
|
主题 — NexT.Muse v5.1.4
0%