开发者博客 – IT技术 尽在开发者博客

开发者博客 – 科技是第一生产力


  • 首页

  • 归档

  • 搜索

从头开始在Python中开发深度学习字幕生成模型

发表于 2017-12-12

本文从数据预处理开始详细地描述了如何使用 VGG 和循环神经网络构建图像描述系统,对读者使用 Keras 和 TensorFlow 理解与实现自动图像描述很有帮助。本文的代码都有解释,非常适合图像描述任务的入门读者详细了解这一过程。

图像描述是一个有挑战性的人工智能问题,涉及为给定图像生成文本描述。

字幕生成是一个有挑战性的人工智能问题,涉及为给定图像生成文本描述。

一般图像描述或字幕生成需要使用计算机视觉方法来了解图像内容,也需要自然语言处理模型将对图像的理解转换成正确顺序的文字。近期,深度学习方法在该问题的多个示例上获得了顶尖结果。

深度学习方法在字幕生成问题上展现了顶尖的结果。这些方法最令人印象深刻的地方:给定一个图像,我们无需复杂的数据准备和特殊设计的流程,就可以使用端到端的方式预测字幕。

本教程将介绍如何从头开发能生成图像字幕的深度学习模型。

完成本教程,你将学会:

  • 如何为训练深度学习模型准备图像和文本数据。
  • 如何设计和训练深度学习字幕生成模型。
  • 如何评估一个训练后的字幕生成模型,并使用它为全新的图像生成字幕。

教程概览

该教程共分为 6 部分:

  1. 图像和字幕数据集
  1. 准备图像数据
  1. 准备文本数据
  1. 开发深度学习模型
  1. 评估模型
  1. 生成新的字幕

Python 环境

本教程假设你已经安装了 Python SciPy 环境,该环境完美适合 Python 3。你必须安装 Keras(2.0 版本或更高),TensorFlow 或 Theano 后端。本教程还假设你已经安装了 scikit-learn、Pandas、NumPy 和 Matplotlib 等科学计算与绘图软件库。

我推荐在 GPU 系统上运行代码。你可以在 Amazon Web Services 上用廉价的方式获取 GPU:如何在 AWS GPU 上运行 Jupyter noterbook?

图像和字幕数据集

图像字幕生成可使用的优秀数据集有 Flickr8K 数据集。原因在于它逼真且相对较小,即使你的工作站使用的是 CPU 也可以下载它,并用于构建模型。

对该数据集的明确描述见 2013 年的论文《Framing Image Description as a Ranking Task: Data, Models and Evaluation Metrics》。

作者对该数据集的描述如下:

我们介绍了一种用于基于句子的图像描述和搜索的新型基准集合,包括 8000 张图像,每个图像有五个不同的字幕描述对突出实体和事件提供清晰描述。

图像选自六个不同的 Flickr 组,往往不包含名人或有名的地点,而是手动选择多种场景和情形。

该数据集可免费获取。你必须填写一份申请表,然后就可以通过电子邮箱收到数据集。申请表链接:https://illinois.edu/fb/sec/1713398。

很快,你会收到电子邮件,包含以下两个文件的链接:

  • Flickr8k_Dataset.zip(1 Gigabyte)包含所有图像。
  • Flickr8k_text.zip(2.2 Megabytes)包含所有图像文本描述。

下载数据集,并在当前工作文件夹里进行解压缩。你将得到两个目录:

  • Flicker8k_Dataset:包含 8092 张 JPEG 格式图像。
  • Flickr8k_text:包含大量不同来源的图像描述文件。

该数据集包含一个预制训练数据集(6000 张图像)、开发数据集(1000 张图像)和测试数据集(1000 张图像)。

用于评估模型技能的一个指标是 BLEU 值。对于推断,下面是一些精巧的模型在测试数据集上进行评估时获得的大概 BLEU 值(来源:2017 年论文《Where to put the Image in an Image Caption Generator》):

  • BLEU-1: 0.401 to 0.578.
  • BLEU-2: 0.176 to 0.390.
  • BLEU-3: 0.099 to 0.260.
  • BLEU-4: 0.059 to 0.170.

稍后在评估模型部分将详细介绍 BLEU 值。下面,我们来看一下如何加载图像。

准备图像数据

我们将使用预训练模型解析图像内容,且目前有很多可选模型。在这种情况下,我们将使用 Oxford Visual Geometry Group 或 VGG(该模型赢得了 2014 年 ImageNet 竞赛冠军)。

Keras 可直接提供该预训练模型。注意,第一次使用该模型时,Keras 将从互联网上下载模型权重,大概 500Megabytes。这可能需要一段时间(时间长度取决于你的网络连接)。

我们可以将该模型作为更大的图像字幕生成模型的一部分。问题在于模型太大,每次我们想测试新语言模型配置(下行)时在该网络中运行每张图像非常冗余。

我们可以使用预训练模型对「图像特征」进行预计算,并保存至文件中。然后加载这些特征,将其馈送至模型中作为数据集中给定图像的描述。在完整的 VGG 模型中运行图像也是这样,我们需要提前运行该步骤。

优化可以加快模型训练过程,消耗更少内存。我们可以使用 VGG class 在 Keras 中运行 VGG 模型。我们将移除加载模型的最后一层,因为该层用于预测图像的分类。我们对图像分类不感兴趣,我们感兴趣的是分类之前图像的内部表征。这些就是模型从图像中提取出的「特征」。

Keras 还提供工具将加载图像改造成模型的偏好大小(如 3 通道 224 x 224 像素图像)。

下面是 extract_features() 函数,即给出一个目录名,该函数将加载每个图像、为 VGG 准备图像数据,并从 VGG 模型中收集预测到的特征。图像特征是包含 4096 个元素的向量,该函数向图像特征返回一个图像标识符(identifier)词典。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
复制代码# extract features from each photo in the directory
def extract_features(directory):
# load the model
model = VGG16()
# re-structure the model
model.layers.pop()
model = Model(inputs=model.inputs, outputs=model.layers[-1].output)
# summarize
print(model.summary())
# extract features from each photo
features = dict()
for name in listdir(directory):
# load an image from file
filename = directory + '/' + name
image = load_img(filename, target_size=(224, 224))
# convert the image pixels to a numpy array
image = img_to_array(image)
# reshape data for the model
image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
# prepare the image for the VGG model
image = preprocess_input(image)
# get features
feature = model.predict(image, verbose=0)
# get image id
image_id = name.split('.')[0]
# store feature
features[image_id] = feature
print('>%s' % name)
return features

我们调用该函数为模型测试准备图像数据,然后将词典保存至 features.pkl 文件。

完整示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
复制代码from os import listdir
from pickle import dump
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.vgg16 import preprocess_input
from keras.models import Model

# extract features from each photo in the directory
def extract_features(directory):
# load the model
model = VGG16()
# re-structure the model
model.layers.pop()
model = Model(inputs=model.inputs, outputs=model.layers[-1].output)
# summarize
print(model.summary())
# extract features from each photo
features = dict()
for name in listdir(directory):
# load an image from file
filename = directory + '/' + name
image = load_img(filename, target_size=(224, 224))
# convert the image pixels to a numpy array
image = img_to_array(image)
# reshape data for the model
image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
# prepare the image for the VGG model
image = preprocess_input(image)
# get features
feature = model.predict(image, verbose=0)
# get image id
image_id = name.split('.')[0]
# store feature
features[image_id] = feature
print('>%s' % name)
return features

# extract features from all images
directory = 'Flicker8k_Dataset'
features = extract_features(directory)
print('Extracted Features: %d' % len(features))
# save to file
dump(features, open('features.pkl', 'wb'))

运行该数据准备步骤可能需要一点时间,时间长度取决于你的硬件,带有 CPU 的现代工作站可能需要一个小时。

运行结束时,提取出的特征将存储在 features.pkl 文件中以备后用。该文件大概 127 Megabytes 大小。

准备文本数据

该数据集中每个图像有多个描述,文本描述需要进行最低限度的清洗。首先,加载包含所有文本描述的文件。

1
2
3
4
5
6
7
8
9
10
11
12
13
复制代码# load doc into memory
def load_doc(filename):
# open the file as read only
file = open(filename, 'r')
# read all text
text = file.read()
# close the file
file.close()
return text

filename = 'Flickr8k_text/Flickr8k.token.txt'
# load descriptions
doc = load_doc(filename)

每个图像有一个独有的标识符,该标识符出现在文件名和文本描述文件中。

接下来,我们将逐步对图像描述进行操作。下面定义一个 load_descriptions() 函数:给出一个需要加载的文本文档,该函数将返回图像标识符词典。每个图像标识符映射到一或多个文本描述。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
复制代码# extract descriptions for images
def load_descriptions(doc):
mapping = dict()
# process lines
for line in doc.split('\n'):
# split line by white space
tokens = line.split()
if len(line) < 2:
continue
# take the first token as the image id, the rest as the description
image_id, image_desc = tokens[0], tokens[1:]
# remove filename from image id
image_id = image_id.split('.')[0]
# convert description tokens back to string
image_desc = ' '.join(image_desc)
# create the list if needed
if image_id not in mapping:
mapping[image_id] = list()
# store description
mapping[image_id].append(image_desc)
return mapping

# parse descriptions
descriptions = load_descriptions(doc)
print('Loaded: %d ' % len(descriptions))

下面,我们需要清洗描述文本。因为描述已经经过符号化,所以它十分易于处理。

我们将用以下方式清洗文本,以减少需要处理的词汇量:

  • 所有单词全部转换成小写。
  • 移除所有标点符号。
  • 移除所有少于或等于一个字符的单词(如 a)。
  • 移除所有带数字的单词。

下面定义了 clean_descriptions() 函数:给出描述的图像标识符词典,遍历每个描述,清洗文本。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
复制代码import string

def clean_descriptions(descriptions):
# prepare translation table for removing punctuation
table = str.maketrans('', '', string.punctuation)
for key, desc_list in descriptions.items():
for i in range(len(desc_list)):
desc = desc_list[i]
# tokenize
desc = desc.split()
# convert to lower case
desc = [word.lower() for word in desc]
# remove punctuation from each token
desc = [w.translate(table) for w in desc]
# remove hanging 's' and 'a'
desc = [word for word in desc if len(word)>1]
# remove tokens with numbers in them
desc = [word for word in desc if word.isalpha()]
# store as string
desc_list[i] = ' '.join(desc)

# clean descriptions
clean_descriptions(descriptions)

清洗后,我们可以总结词汇量。

理想情况下,我们希望使用尽可能少的词汇而得到强大的表达性。词汇越少则模型越小、训练速度越快。

对于推断,我们可以将干净的描述转换成一个集,将它的规模打印出来,这样就可以了解我们的数据集词汇量的大小了。

1
2
3
4
5
6
7
8
9
10
11
复制代码# convert the loaded descriptions into a vocabulary of words
def to_vocabulary(descriptions):
# build a list of all description strings
all_desc = set()
for key in descriptions.keys():
[all_desc.update(d.split()) for d in descriptions[key]]
return all_desc

# summarize vocabulary
vocabulary = to_vocabulary(descriptions)
print('Vocabulary Size: %d' % len(vocabulary))

最后,我们保存图像标识符词典和描述至一个新文本 descriptions.txt,该文件中每行只有一个图像和一个描述。

下面我们定义了 save_doc() 函数,即给出一个包含标识符和描述之间映射的词典和文件名,将该映射保存至文件中。

1
2
3
4
5
6
7
8
9
10
11
12
13
复制代码# save descriptions to file, one per line
def save_descriptions(descriptions, filename):
lines = list()
for key, desc_list in descriptions.items():
for desc in desc_list:
lines.append(key + ' ' + desc)
data = '\n'.join(lines)
file = open(filename, 'w')
file.write(data)
file.close()

# save descriptions
save_doc(descriptions, 'descriptions.txt')

汇总起来,完整的函数定义如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
复制代码import string

# load doc into memory
def load_doc(filename):
# open the file as read only
file = open(filename, 'r')
# read all text
text = file.read()
# close the file
file.close()
return text

# extract descriptions for images
def load_descriptions(doc):
mapping = dict()
# process lines
for line in doc.split('\n'):
# split line by white space
tokens = line.split()
if len(line) < 2:
continue
# take the first token as the image id, the rest as the description
image_id, image_desc = tokens[0], tokens[1:]
# remove filename from image id
image_id = image_id.split('.')[0]
# convert description tokens back to string
image_desc = ' '.join(image_desc)
# create the list if needed
if image_id not in mapping:
mapping[image_id] = list()
# store description
mapping[image_id].append(image_desc)
return mapping

def clean_descriptions(descriptions):
# prepare translation table for removing punctuation
table = str.maketrans('', '', string.punctuation)
for key, desc_list in descriptions.items():
for i in range(len(desc_list)):
desc = desc_list[i]
# tokenize
desc = desc.split()
# convert to lower case
desc = [word.lower() for word in desc]
# remove punctuation from each token
desc = [w.translate(table) for w in desc]
# remove hanging 's' and 'a'
desc = [word for word in desc if len(word)>1]
# remove tokens with numbers in them
desc = [word for word in desc if word.isalpha()]
# store as string
desc_list[i] = ' '.join(desc)

# convert the loaded descriptions into a vocabulary of words
def to_vocabulary(descriptions):
# build a list of all description strings
all_desc = set()
for key in descriptions.keys():
[all_desc.update(d.split()) for d in descriptions[key]]
return all_desc

# save descriptions to file, one per line
def save_descriptions(descriptions, filename):
lines = list()
for key, desc_list in descriptions.items():
for desc in desc_list:
lines.append(key + ' ' + desc)
data = '\n'.join(lines)
file = open(filename, 'w')
file.write(data)
file.close()

filename = 'Flickr8k_text/Flickr8k.token.txt'
# load descriptions
doc = load_doc(filename)
# parse descriptions
descriptions = load_descriptions(doc)
print('Loaded: %d ' % len(descriptions))
# clean descriptions
clean_descriptions(descriptions)
# summarize vocabulary
vocabulary = to_vocabulary(descriptions)
print('Vocabulary Size: %d' % len(vocabulary))
# save to file
save_descriptions(descriptions, 'descriptions.txt')

运行示例首先打印出加载图像描述的数量(8092)和干净词汇量的规模(8763 个单词)。

1
2
复制代码Loaded: 8,092
Vocabulary Size: 8,763

最后,把干净的描述写入 descriptions.txt。

查看文件,我们能够看到该描述可用于建模。文件中描述的顺序可能会发生改变。

1
2
3
4
5
6
复制代码2252123185_487f21e336 bunch on people are seated in stadium
2252123185_487f21e336 crowded stadium is full of people watching an event
2252123185_487f21e336 crowd of people fill up packed stadium
2252123185_487f21e336 crowd sitting in an indoor stadium
2252123185_487f21e336 stadium full of people watch game
...

开发深度学习模型

本节我们将定义深度学习模型,在训练数据集上进行拟合。本节分为以下几部分:

  1. 加载数据。
  1. 定义模型。
  1. 拟合模型。
  1. 完成示例。

加载数据

首先,我们必须加载准备好的图像和文本数据来拟合模型。

我们将在训练数据集中的所有图像和描述上训练数据。训练过程中,我们计划在开发数据集上监控模型性能,使用该性能确定什么时候保存模型至文件。

训练和开发数据集已经预制好,并分别保存在 Flickr_8k.trainImages.txt 和 Flickr_8k.devImages.txt 文件中,二者均包含图像文件名列表。从这些文件名中,我们可以提取图像标识符,并使用它们为每个集过滤图像和描述。

如下所示,load_set() 函数将根据训练或开发集文件名加载一个预定义标识符集。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
复制代码# load doc into memory
def load_doc(filename):
# open the file as read only
file = open(filename, 'r')
# read all text
text = file.read()
# close the file
file.close()
return text

# load a pre-defined list of photo identifiers
def load_set(filename):
doc = load_doc(filename)
dataset = list()
# process line by line
for line in doc.split('\n'):
# skip empty lines
if len(line) < 1:
continue
# get the image identifier
identifier = line.split('.')[0]
dataset.append(identifier)
return set(dataset)

现在,我们可以使用预定义训练或开发标识符集加载图像和描述了。

下面是 load_clean_descriptions() 函数,该函数从给定标识符集的 descriptions.txt 中加载干净的文本描述,并向文本描述列表返回标识符词典。

我们将要开发的模型能够生成给定图像的字幕,一次生成一个单词。先前生成的单词序列作为输入。因此,我们需要一个 first word 来开启生成步骤和一个 last word 来表示字幕生成结束。

我们将使用字符串 startseq 和 endseq 完成该目的。这些标记被添加至加载描述,像它们本身就是加载出的那样。在对文本进行编码之前进行该操作非常重要,这样这些标记才能得到正确编码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
复制代码# load clean descriptions into memory
def load_clean_descriptions(filename, dataset):
# load document
doc = load_doc(filename)
descriptions = dict()
for line in doc.split('\n'):
# split line by white space
tokens = line.split()
# split id from description
image_id, image_desc = tokens[0], tokens[1:]
# skip images not in the set
if image_id in dataset:
# create list
if image_id not in descriptions:
descriptions[image_id] = list()
# wrap description in tokens
desc = 'startseq ' + ' '.join(image_desc) + ' endseq'
# store
descriptions[image_id].append(desc)
return descriptions

接下来,我们可以为给定数据集加载图像特征。

下面定义了 load_photo_features() 函数,该函数加载了整个图像描述集,然后返回给定图像标识符集你感兴趣的子集。

这不是很高效,但是,这可以帮助我们启动,快速运行。

1
2
3
4
5
6
7
复制代码# load photo features
def load_photo_features(filename, dataset):
# load all features
all_features = load(open(filename, 'rb'))
# filter features
features = {k: all_features[k] for k in dataset}
return features

我们可以在这里暂停一下,测试目前开发的所有内容。

完整的代码示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
复制代码# load doc into memory
def load_doc(filename):
# open the file as read only
file = open(filename, 'r')
# read all text
text = file.read()
# close the file
file.close()
return text

# load a pre-defined list of photo identifiers
def load_set(filename):
doc = load_doc(filename)
dataset = list()
# process line by line
for line in doc.split('\n'):
# skip empty lines
if len(line) < 1:
continue
# get the image identifier
identifier = line.split('.')[0]
dataset.append(identifier)
return set(dataset)

# load clean descriptions into memory
def load_clean_descriptions(filename, dataset):
# load document
doc = load_doc(filename)
descriptions = dict()
for line in doc.split('\n'):
# split line by white space
tokens = line.split()
# split id from description
image_id, image_desc = tokens[0], tokens[1:]
# skip images not in the set
if image_id in dataset:
# create list
if image_id not in descriptions:
descriptions[image_id] = list()
# wrap description in tokens
desc = 'startseq ' + ' '.join(image_desc) + ' endseq'
# store
descriptions[image_id].append(desc)
return descriptions

# load photo features
def load_photo_features(filename, dataset):
# load all features
all_features = load(open(filename, 'rb'))
# filter features
features = {k: all_features[k] for k in dataset}
return features

# load training dataset (6K)
filename = 'Flickr8k_text/Flickr_8k.trainImages.txt'
train = load_set(filename)
print('Dataset: %d' % len(train))
# descriptions
train_descriptions = load_clean_descriptions('descriptions.txt', train)
print('Descriptions: train=%d' % len(train_descriptions))
# photo features
train_features = load_photo_features('features.pkl', train)
print('Photos: train=%d' % len(train_features))

运行该示例首先在测试数据集中加载 6000 张图像标识符。这些特征之后将用于加载干净描述文本和预计算的图像特征。

1
2
3
复制代码Dataset: 6,000
Descriptions: train=6,000
Photos: train=6,000

描述文本在作为输入馈送至模型或与模型预测进行对比之前需要先编码成数值。

编码数据的第一步是创建单词到唯一整数值之间的持续映射。Keras 提供 Tokenizer class,可根据加载的描述数据学习该映射。

下面定义了用于将描述词典转换成字符串列表的 to_lines() 函数,和对加载图像描述文本拟合 Tokenizer 的 create_tokenizer() 函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
复制代码# convert a dictionary of clean descriptions to a list of descriptions
def to_lines(descriptions):
all_desc = list()
for key in descriptions.keys():
[all_desc.append(d) for d in descriptions[key]]
return all_desc

# fit a tokenizer given caption descriptions
def create_tokenizer(descriptions):
lines = to_lines(descriptions)
tokenizer = Tokenizer()
tokenizer.fit_on_texts(lines)
return tokenizer

# prepare tokenizer
tokenizer = create_tokenizer(train_descriptions)
vocab_size = len(tokenizer.word_index) + 1
print('Vocabulary Size: %d' % vocab_size)

我们现在对文本进行编码。

每个描述将被分割成单词。我们向该模型提供一个单词和图像,然后模型生成下一个单词。描述的前两个单词和图像将作为模型输入以生成下一个单词,这就是该模型的训练方式。

例如,输入序列「a little girl running in field」将被分割成 6 个输入-输出对来训练该模型:

1
2
3
4
5
6
7
复制代码X1,		X2 (text sequence), 						y (word)
photo startseq, little
photo startseq, little, girl
photo startseq, little, girl, running
photo startseq, little, girl, running, in
photo startseq, little, girl, running, in, field
photo startseq, little, girl, running, in, field, endseq

稍后,当模型用于生成描述时,生成的单词将被连结起来,递归地作为输入以生成图像字幕。

下面是 create_sequences() 函数,给出 tokenizer、最大序列长度和所有描述和图像的词典,该函数将这些数据转换成输入-输出对来训练模型。该模型有两个输入数组:一个用于图像特征,一个用于编码文本。模型输出是文本序列中编码的下一个单词。

输入文本被编码为整数,被馈送至词嵌入层。图像特征将被直接馈送至模型的另一部分。该模型输出的预测是所有单词在词汇表中的概率分布。

因此,输出数据是每个单词的 one-hot 编码,它表示一种理想化的概率分布,即除了实际词位置之外所有词位置的值都为 0,实际词位置的值为 1。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
复制代码# create sequences of images, input sequences and output words for an image
def create_sequences(tokenizer, max_length, descriptions, photos):
X1, X2, y = list(), list(), list()
# walk through each image identifier
for key, desc_list in descriptions.items():
# walk through each description for the image
for desc in desc_list:
# encode the sequence
seq = tokenizer.texts_to_sequences([desc])[0]
# split one sequence into multiple X,y pairs
for i in range(1, len(seq)):
# split into input and output pair
in_seq, out_seq = seq[:i], seq[i]
# pad input sequence
in_seq = pad_sequences([in_seq], maxlen=max_length)[0]
# encode output sequence
out_seq = to_categorical([out_seq], num_classes=vocab_size)[0]
# store
X1.append(photos[key][0])
X2.append(in_seq)
y.append(out_seq)
return array(X1), array(X2), array(y)

我们需要计算最长描述中单词的最大数量。下面是一个有帮助的函数 max_length()。

1
2
3
4
复制代码# calculate the length of the description with the most words
def max_length(descriptions):
lines = to_lines(descriptions)
return max(len(d.split()) for d in lines)

现在我们可以为训练和开发数据集加载数据,并将加载数据转换成输入-输出对来拟合深度学习模型。

定义模型

我们将根据 Marc Tanti, et al. 在 2017 年论文中描述的「merge-model」定义深度学习模型。

  • Where to put the Image in an Image Caption Generator,2017
  • What is the Role of Recurrent Neural Networks (RNNs) in an Image Caption Generator?,2017

论文作者提供了该模型的简图,如下所示:

我们将从三部分描述该模型:

  • 图像特征提取器:这是一个在 ImageNet 数据集上预训练的 16 层 VGG 模型。我们已经使用 VGG 模型(没有输出层)对图像进行预处理,并将使用该模型预测的提取特征作为输入。
  • 序列处理器:合适一个词嵌入层,用于处理文本输入,后面是长短期记忆(LSTM)循环神经网络层。
  • 解码器:特征提取器和序列处理器输出一个固定长度向量。这些向量由密集层(Dense layer)融合和处理,来进行最终预测。

图像特征提取器模型的输入图像特征是维度为 4096 的向量,这些向量经过全连接层处理并生成图像的 256 元素表征。

序列处理器模型期望馈送至嵌入层的预定义长度(34 个单词)输入序列使用掩码来忽略 padded 值。之后是具备 256 个循环单元的 LSTM 层。

两个输入模型均输出 256 元素的向量。此外,输入模型以 50% 的 dropout 率使用正则化,旨在减少训练数据集的过拟合情况,因为该模型配置学习非常快。

解码器模型使用额外的操作融合来自两个输入模型的向量。然后将其馈送至 256 个神经元的密集层,然后输送至最终输出密集层,从而在所有输出词汇上对序列中的下一个单词进行 softmax 预测。

下面的 define_model() 函数定义和返回要拟合的模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
复制代码# define the captioning model
def define_model(vocab_size, max_length):
# feature extractor model
inputs1 = Input(shape=(4096,))
fe1 = Dropout(0.5)(inputs1)
fe2 = Dense(256, activation='relu')(fe1)
# sequence model
inputs2 = Input(shape=(max_length,))
se1 = Embedding(vocab_size, 256, mask_zero=True)(inputs2)
se2 = Dropout(0.5)(se1)
se3 = LSTM(256)(se2)
# decoder model
decoder1 = add([fe2, se3])
decoder2 = Dense(256, activation='relu')(decoder1)
outputs = Dense(vocab_size, activation='softmax')(decoder2)
# tie it together [image, seq] [word]
model = Model(inputs=[inputs1, inputs2], outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer='adam')
# summarize model
print(model.summary())
plot_model(model, to_file='model.png', show_shapes=True)
return model

要了解模型结构,特别是层的形状,请参考下表中的总结。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
复制代码____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_2 (InputLayer) (None, 34) 0
____________________________________________________________________________________________________
input_1 (InputLayer) (None, 4096) 0
____________________________________________________________________________________________________
embedding_1 (Embedding) (None, 34, 256) 1940224 input_2[0][0]
____________________________________________________________________________________________________
dropout_1 (Dropout) (None, 4096) 0 input_1[0][0]
____________________________________________________________________________________________________
dropout_2 (Dropout) (None, 34, 256) 0 embedding_1[0][0]
____________________________________________________________________________________________________
dense_1 (Dense) (None, 256) 1048832 dropout_1[0][0]
____________________________________________________________________________________________________
lstm_1 (LSTM) (None, 256) 525312 dropout_2[0][0]
____________________________________________________________________________________________________
add_1 (Add) (None, 256) 0 dense_1[0][0]
lstm_1[0][0]
____________________________________________________________________________________________________
dense_2 (Dense) (None, 256) 65792 add_1[0][0]
____________________________________________________________________________________________________
dense_3 (Dense) (None, 7579) 1947803 dense_2[0][0]
====================================================================================================
Total params: 5,527,963
Trainable params: 5,527,963
Non-trainable params: 0
____________________________________________________________________________________________________

我们还创建了一幅图来可视化网络结构,帮助理解两个输入流。

图像字幕生成深度学习模型示意图。

拟合模型

现在我们已经了解如何定义模型了,那么接下来我们要在训练数据集上拟合模型。

该模型学习速度快,很快就会对训练数据集产生过拟合。因此,我们需要在留出的开发数据集上监控训练模型的泛化情况。如果模型在开发数据集上的技能在每个 epoch 结束时有所提升,则我们将整个模型保存至文件。

在运行结束时,我们能够使用训练数据集上具备最优技能的模型作为最终模型。

通过在 Keras 中定义 ModelCheckpoint,使之监控验证数据集上的最小损失,我们可以实现以上目的。然后将该模型保存至文件名中包含训练损失和验证损失的文件中。

1
2
3
复制代码# define checkpoint callback
filepath = 'model-ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5'
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')

之后,通过 fit() 中的 callbacks 参数指定检查点。我们还需要 fit() 中的 validation_data 参数指定开发数据集。

我们仅拟合模型 20 epoch,给出一定量的训练数据,在一般硬件上每个 epoch 可能需要 30 分钟。

1
2
复制代码# fit model
model.fit([X1train, X2train], ytrain, epochs=20, verbose=2, callbacks=[checkpoint], validation_data=([X1test, X2test], ytest))

完成示例

在训练数据上拟合模型的完整示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
复制代码from numpy import array
from pickle import load
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Embedding
from keras.layers import Dropout
from keras.layers.merge import add
from keras.callbacks import ModelCheckpoint

# load doc into memory
def load_doc(filename):
# open the file as read only
file = open(filename, 'r')
# read all text
text = file.read()
# close the file
file.close()
return text

# load a pre-defined list of photo identifiers
def load_set(filename):
doc = load_doc(filename)
dataset = list()
# process line by line
for line in doc.split('\n'):
# skip empty lines
if len(line) < 1:
continue
# get the image identifier
identifier = line.split('.')[0]
dataset.append(identifier)
return set(dataset)

# load clean descriptions into memory
def load_clean_descriptions(filename, dataset):
# load document
doc = load_doc(filename)
descriptions = dict()
for line in doc.split('\n'):
# split line by white space
tokens = line.split()
# split id from description
image_id, image_desc = tokens[0], tokens[1:]
# skip images not in the set
if image_id in dataset:
# create list
if image_id not in descriptions:
descriptions[image_id] = list()
# wrap description in tokens
desc = 'startseq ' + ' '.join(image_desc) + ' endseq'
# store
descriptions[image_id].append(desc)
return descriptions

# load photo features
def load_photo_features(filename, dataset):
# load all features
all_features = load(open(filename, 'rb'))
# filter features
features = {k: all_features[k] for k in dataset}
return features

# covert a dictionary of clean descriptions to a list of descriptions
def to_lines(descriptions):
all_desc = list()
for key in descriptions.keys():
[all_desc.append(d) for d in descriptions[key]]
return all_desc

# fit a tokenizer given caption descriptions
def create_tokenizer(descriptions):
lines = to_lines(descriptions)
tokenizer = Tokenizer()
tokenizer.fit_on_texts(lines)
return tokenizer

# calculate the length of the description with the most words
def max_length(descriptions):
lines = to_lines(descriptions)
return max(len(d.split()) for d in lines)

# create sequences of images, input sequences and output words for an image
def create_sequences(tokenizer, max_length, descriptions, photos):
X1, X2, y = list(), list(), list()
# walk through each image identifier
for key, desc_list in descriptions.items():
# walk through each description for the image
for desc in desc_list:
# encode the sequence
seq = tokenizer.texts_to_sequences([desc])[0]
# split one sequence into multiple X,y pairs
for i in range(1, len(seq)):
# split into input and output pair
in_seq, out_seq = seq[:i], seq[i]
# pad input sequence
in_seq = pad_sequences([in_seq], maxlen=max_length)[0]
# encode output sequence
out_seq = to_categorical([out_seq], num_classes=vocab_size)[0]
# store
X1.append(photos[key][0])
X2.append(in_seq)
y.append(out_seq)
return array(X1), array(X2), array(y)

# define the captioning model
def define_model(vocab_size, max_length):
# feature extractor model
inputs1 = Input(shape=(4096,))
fe1 = Dropout(0.5)(inputs1)
fe2 = Dense(256, activation='relu')(fe1)
# sequence model
inputs2 = Input(shape=(max_length,))
se1 = Embedding(vocab_size, 256, mask_zero=True)(inputs2)
se2 = Dropout(0.5)(se1)
se3 = LSTM(256)(se2)
# decoder model
decoder1 = add([fe2, se3])
decoder2 = Dense(256, activation='relu')(decoder1)
outputs = Dense(vocab_size, activation='softmax')(decoder2)
# tie it together [image, seq] [word]
model = Model(inputs=[inputs1, inputs2], outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer='adam')
# summarize model
print(model.summary())
plot_model(model, to_file='model.png', show_shapes=True)
return model

# train dataset

# load training dataset (6K)
filename = 'Flickr8k_text/Flickr_8k.trainImages.txt'
train = load_set(filename)
print('Dataset: %d' % len(train))
# descriptions
train_descriptions = load_clean_descriptions('descriptions.txt', train)
print('Descriptions: train=%d' % len(train_descriptions))
# photo features
train_features = load_photo_features('features.pkl', train)
print('Photos: train=%d' % len(train_features))
# prepare tokenizer
tokenizer = create_tokenizer(train_descriptions)
vocab_size = len(tokenizer.word_index) + 1
print('Vocabulary Size: %d' % vocab_size)
# determine the maximum sequence length
max_length = max_length(train_descriptions)
print('Description Length: %d' % max_length)
# prepare sequences
X1train, X2train, ytrain = create_sequences(tokenizer, max_length, train_descriptions, train_features)

# dev dataset

# load test set
filename = 'Flickr8k_text/Flickr_8k.devImages.txt'
test = load_set(filename)
print('Dataset: %d' % len(test))
# descriptions
test_descriptions = load_clean_descriptions('descriptions.txt', test)
print('Descriptions: test=%d' % len(test_descriptions))
# photo features
test_features = load_photo_features('features.pkl', test)
print('Photos: test=%d' % len(test_features))
# prepare sequences
X1test, X2test, ytest = create_sequences(tokenizer, max_length, test_descriptions, test_features)

# fit model

# define the model
model = define_model(vocab_size, max_length)
# define checkpoint callback
filepath = 'model-ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5'
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
# fit model
model.fit([X1train, X2train], ytrain, epochs=20, verbose=2, callbacks=[checkpoint], validation_data=([X1test, X2test], ytest))

运行该示例首先打印加载训练和开发数据集的摘要。

1
2
3
4
5
6
7
8
复制代码Dataset: 6,000
Descriptions: train=6,000
Photos: train=6,000
Vocabulary Size: 7,579
Description Length: 34
Dataset: 1,000
Descriptions: test=1,000
Photos: test=1,000

之后,我们可以了解训练和验证(开发)输入-输出对的整体数量。

1
复制代码Train on 306,404 samples, validate on 50,903 samples

然后运行模型,将最优模型保存至.h5 文件。

在运行过程中,我把最优验证结果的模型保存至文件中:

  • model-ep002-loss3.245-val_loss3.612.h5

该模型在第 2 个 epoch 中结束时被保存,在训练数据集上的损失为 3.245,在开发数据集上的损失为 3.612,每个人的具体结果不同。如果你在 AWS 中运行上述示例,那么将模型文件复制回你当前的工作文件夹。

评估模型

模型拟合之后,我们可以在留出的测试数据集上评估它的预测技能。

使模型对测试数据集中的所有图像生成描述,使用标准代价函数评估预测,从而评估模型。

首先,我们需要使用训练模型对图像生成描述。输入开始描述的标记 『startseq『,生成一个单词,然后递归地用生成单词作为输入启用模型直到序列标记到 『endseq『或达到最大描述长度。

下面的 generate_desc() 函数实现该行为,并基于给定训练模型和作为输入的准备图像生成文本描述。它启用 word_for_id() 函数以映射整数预测至单词。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
复制代码# map an integer to a word
def word_for_id(integer, tokenizer):
for word, index in tokenizer.word_index.items():
if index == integer:
return word
return None

# generate a description for an image
def generate_desc(model, tokenizer, photo, max_length):
# seed the generation process
in_text = 'startseq'
# iterate over the whole length of the sequence
for i in range(max_length):
# integer encode input sequence
sequence = tokenizer.texts_to_sequences([in_text])[0]
# pad input
sequence = pad_sequences([sequence], maxlen=max_length)
# predict next word
yhat = model.predict([photo,sequence], verbose=0)
# convert probability to integer
yhat = argmax(yhat)
# map integer to word
word = word_for_id(yhat, tokenizer)
# stop if we cannot map the word
if word is None:
break
# append as input for generating the next word
in_text += ' ' + word
# stop if we predict the end of the sequence
if word == 'endseq':
break
return in_text

我们将为测试数据集和训练数据集中的所有图像生成预测。

下面的 evaluate_model() 基于给定图像描述数据集和图像特征评估训练模型。收集实际和预测描述,使用语料库 BLEU 值对它们进行评估。语料库 BLEU 值总结了生成文本和期望文本之间的相似度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
复制代码# evaluate the skill of the model
def evaluate_model(model, descriptions, photos, tokenizer, max_length):
actual, predicted = list(), list()
# step over the whole set
for key, desc_list in descriptions.items():
# generate description
yhat = generate_desc(model, tokenizer, photos[key], max_length)
# store actual and predicted
references = [d.split() for d in desc_list]
actual.append(references)
predicted.append(yhat.split())
# calculate BLEU score
print('BLEU-1: %f' % corpus_bleu(actual, predicted, weights=(1.0, 0, 0, 0)))
print('BLEU-2: %f' % corpus_bleu(actual, predicted, weights=(0.5, 0.5, 0, 0)))
print('BLEU-3: %f' % corpus_bleu(actual, predicted, weights=(0.3, 0.3, 0.3, 0)))
print('BLEU-4: %f' % corpus_bleu(actual, predicted, weights=(0.25, 0.25, 0.25, 0.25)))

BLEU 值用于在文本翻译中评估译文和一或多个参考译文的相似度。

这里,我们将每个生成描述与该图像的所有参考描述进行对比,然后计算 1、2、3、4 等 n 元语言模型的 BLEU 值。

NLTK Python 库在 corpus_bleu() 函数中实现了 BLEU 值计算。分值越接近 1.0 越好,越接近 0 越差。

我们可以结合前面加载数据部分中的函数。首先加载训练数据集来准备 Tokenizer,以使我们将生成单词编码成模型的输入序列。使用模型训练时使用的编码机制对生成单词进行编码非常关键。

然后使用这些函数加载测试数据集。完整示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
复制代码from numpy import argmax
from pickle import load
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.models import load_model
from nltk.translate.bleu_score import corpus_bleu

# load doc into memory
def load_doc(filename):
# open the file as read only
file = open(filename, 'r')
# read all text
text = file.read()
# close the file
file.close()
return text

# load a pre-defined list of photo identifiers
def load_set(filename):
doc = load_doc(filename)
dataset = list()
# process line by line
for line in doc.split('\n'):
# skip empty lines
if len(line) < 1:
continue
# get the image identifier
identifier = line.split('.')[0]
dataset.append(identifier)
return set(dataset)

# load clean descriptions into memory
def load_clean_descriptions(filename, dataset):
# load document
doc = load_doc(filename)
descriptions = dict()
for line in doc.split('\n'):
# split line by white space
tokens = line.split()
# split id from description
image_id, image_desc = tokens[0], tokens[1:]
# skip images not in the set
if image_id in dataset:
# create list
if image_id not in descriptions:
descriptions[image_id] = list()
# wrap description in tokens
desc = 'startseq ' + ' '.join(image_desc) + ' endseq'
# store
descriptions[image_id].append(desc)
return descriptions

# load photo features
def load_photo_features(filename, dataset):
# load all features
all_features = load(open(filename, 'rb'))
# filter features
features = {k: all_features[k] for k in dataset}
return features

# covert a dictionary of clean descriptions to a list of descriptions
def to_lines(descriptions):
all_desc = list()
for key in descriptions.keys():
[all_desc.append(d) for d in descriptions[key]]
return all_desc

# fit a tokenizer given caption descriptions
def create_tokenizer(descriptions):
lines = to_lines(descriptions)
tokenizer = Tokenizer()
tokenizer.fit_on_texts(lines)
return tokenizer

# calculate the length of the description with the most words
def max_length(descriptions):
lines = to_lines(descriptions)
return max(len(d.split()) for d in lines)

# map an integer to a word
def word_for_id(integer, tokenizer):
for word, index in tokenizer.word_index.items():
if index == integer:
return word
return None

# generate a description for an image
def generate_desc(model, tokenizer, photo, max_length):
# seed the generation process
in_text = 'startseq'
# iterate over the whole length of the sequence
for i in range(max_length):
# integer encode input sequence
sequence = tokenizer.texts_to_sequences([in_text])[0]
# pad input
sequence = pad_sequences([sequence], maxlen=max_length)
# predict next word
yhat = model.predict([photo,sequence], verbose=0)
# convert probability to integer
yhat = argmax(yhat)
# map integer to word
word = word_for_id(yhat, tokenizer)
# stop if we cannot map the word
if word is None:
break
# append as input for generating the next word
in_text += ' ' + word
# stop if we predict the end of the sequence
if word == 'endseq':
break
return in_text

# evaluate the skill of the model
def evaluate_model(model, descriptions, photos, tokenizer, max_length):
actual, predicted = list(), list()
# step over the whole set
for key, desc_list in descriptions.items():
# generate description
yhat = generate_desc(model, tokenizer, photos[key], max_length)
# store actual and predicted
references = [d.split() for d in desc_list]
actual.append(references)
predicted.append(yhat.split())
# calculate BLEU score
print('BLEU-1: %f' % corpus_bleu(actual, predicted, weights=(1.0, 0, 0, 0)))
print('BLEU-2: %f' % corpus_bleu(actual, predicted, weights=(0.5, 0.5, 0, 0)))
print('BLEU-3: %f' % corpus_bleu(actual, predicted, weights=(0.3, 0.3, 0.3, 0)))
print('BLEU-4: %f' % corpus_bleu(actual, predicted, weights=(0.25, 0.25, 0.25, 0.25)))

# prepare tokenizer on train set

# load training dataset (6K)
filename = 'Flickr8k_text/Flickr_8k.trainImages.txt'
train = load_set(filename)
print('Dataset: %d' % len(train))
# descriptions
train_descriptions = load_clean_descriptions('descriptions.txt', train)
print('Descriptions: train=%d' % len(train_descriptions))
# prepare tokenizer
tokenizer = create_tokenizer(train_descriptions)
vocab_size = len(tokenizer.word_index) + 1
print('Vocabulary Size: %d' % vocab_size)
# determine the maximum sequence length
max_length = max_length(train_descriptions)
print('Description Length: %d' % max_length)

# prepare test set

# load test set
filename = 'Flickr8k_text/Flickr_8k.testImages.txt'
test = load_set(filename)
print('Dataset: %d' % len(test))
# descriptions
test_descriptions = load_clean_descriptions('descriptions.txt', test)
print('Descriptions: test=%d' % len(test_descriptions))
# photo features
test_features = load_photo_features('features.pkl', test)
print('Photos: test=%d' % len(test_features))

# load the model
filename = 'model-ep002-loss3.245-val_loss3.612.h5'
model = load_model(filename)
# evaluate model
evaluate_model(model, test_descriptions, test_features, tokenizer, max_length)

运行示例打印 BLEU 值。我们可以看到 BLEU 值处于该问题较优的期望范围内,且接近最优水平。并且我们并没有对选择的模型配置进行特别的优化。

1
2
3
4
复制代码BLEU-1: 0.579114
BLEU-2: 0.344856
BLEU-3: 0.252154
BLEU-4: 0.131446

生成新的图像字幕

现在我们了解了如何开发和评估字幕生成模型,那么我们如何使用它呢?

我们需要模型文件中全新的图像,还需要 Tokenizer 用于对模型生成单词进行编码,生成序列和定义模型时使用的输入序列最大长度。

我们可以对最大序列长度进行硬编码。文本编码后,我们就可以创建 tokenizer,并将其保存至文件,这样我们可以在需要的时候快速加载,无需整个 Flickr8K 数据集。另一个方法是使用我们自己的词汇文件,在训练过程中将其映射到取整函数。

我们可以按照之前的方式创建 Tokenizer,并将其保存为 pickle 文件 tokenizer.pkl。完整示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
复制代码from keras.preprocessing.text import Tokenizer
from pickle import dump

# load doc into memory
def load_doc(filename):
# open the file as read only
file = open(filename, 'r')
# read all text
text = file.read()
# close the file
file.close()
return text

# load a pre-defined list of photo identifiers
def load_set(filename):
doc = load_doc(filename)
dataset = list()
# process line by line
for line in doc.split('\n'):
# skip empty lines
if len(line) < 1:
continue
# get the image identifier
identifier = line.split('.')[0]
dataset.append(identifier)
return set(dataset)

# load clean descriptions into memory
def load_clean_descriptions(filename, dataset):
# load document
doc = load_doc(filename)
descriptions = dict()
for line in doc.split('\n'):
# split line by white space
tokens = line.split()
# split id from description
image_id, image_desc = tokens[0], tokens[1:]
# skip images not in the set
if image_id in dataset:
# create list
if image_id not in descriptions:
descriptions[image_id] = list()
# wrap description in tokens
desc = 'startseq ' + ' '.join(image_desc) + ' endseq'
# store
descriptions[image_id].append(desc)
return descriptions

# covert a dictionary of clean descriptions to a list of descriptions
def to_lines(descriptions):
all_desc = list()
for key in descriptions.keys():
[all_desc.append(d) for d in descriptions[key]]
return all_desc

# fit a tokenizer given caption descriptions
def create_tokenizer(descriptions):
lines = to_lines(descriptions)
tokenizer = Tokenizer()
tokenizer.fit_on_texts(lines)
return tokenizer

# load training dataset (6K)
filename = 'Flickr8k_text/Flickr_8k.trainImages.txt'
train = load_set(filename)
print('Dataset: %d' % len(train))
# descriptions
train_descriptions = load_clean_descriptions('descriptions.txt', train)
print('Descriptions: train=%d' % len(train_descriptions))
# prepare tokenizer
tokenizer = create_tokenizer(train_descriptions)
# save the tokenizer
dump(tokenizer, open('tokenizer.pkl', 'wb'))

现在我们可以在需要的时候加载 tokenizer,无需加载整个标注训练数据集。下面,我们来为一个新图像生成描述,下面这张图是我从 Flickr 中随机选的一张图像。

海滩上的狗

我们将使用模型为它生成描述。首先下载图像,保存至本地文件夹,文件名设置为「example.jpg」。然后,我们必须从 tokenizer.pkl 中加载 Tokenizer,定义生成序列的最大长度,在对输入数据进行填充时需要该信息。

1
2
3
4
复制代码# load the tokenizer
tokenizer = load(open('tokenizer.pkl', 'rb'))
# pre-define the max sequence length (from training)
max_length = 34

然后我们必须加载模型,如前所述。

1
2
复制代码# load the model
model = load_model('model-ep002-loss3.245-val_loss3.612.h5')

接下来,我们必须加载要描述和提取特征的图像。

重定义该模型、向其中添加 VGG-16 模型,或者使用 VGG 模型来预测特征,使用这些特征作为现有模型的输入。我们将使用后一种方法,使用数据准备阶段所用的 extract_features() 函数的修正版本,该版本适合处理单个图像。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
复制代码# extract features from each photo in the directory
def extract_features(filename):
# load the model
model = VGG16()
# re-structure the model
model.layers.pop()
model = Model(inputs=model.inputs, outputs=model.layers[-1].output)
# load the photo
image = load_img(filename, target_size=(224, 224))
# convert the image pixels to a numpy array
image = img_to_array(image)
# reshape data for the model
image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
# prepare the image for the VGG model
image = preprocess_input(image)
# get features
feature = model.predict(image, verbose=0)
return feature

# load and prepare the photograph
photo = extract_features('example.jpg')

之后使用评估模型定义的 generate_desc() 函数生成图像描述。为单个全新图像生成描述的完整示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
复制代码from pickle import load
from numpy import argmax
from keras.preprocessing.sequence import pad_sequences
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.vgg16 import preprocess_input
from keras.models import Model
from keras.models import load_model

# extract features from each photo in the directory
def extract_features(filename):
# load the model
model = VGG16()
# re-structure the model
model.layers.pop()
model = Model(inputs=model.inputs, outputs=model.layers[-1].output)
# load the photo
image = load_img(filename, target_size=(224, 224))
# convert the image pixels to a numpy array
image = img_to_array(image)
# reshape data for the model
image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
# prepare the image for the VGG model
image = preprocess_input(image)
# get features
feature = model.predict(image, verbose=0)
return feature

# map an integer to a word
def word_for_id(integer, tokenizer):
for word, index in tokenizer.word_index.items():
if index == integer:
return word
return None

# generate a description for an image
def generate_desc(model, tokenizer, photo, max_length):
# seed the generation process
in_text = 'startseq'
# iterate over the whole length of the sequence
for i in range(max_length):
# integer encode input sequence
sequence = tokenizer.texts_to_sequences([in_text])[0]
# pad input
sequence = pad_sequences([sequence], maxlen=max_length)
# predict next word
yhat = model.predict([photo,sequence], verbose=0)
# convert probability to integer
yhat = argmax(yhat)
# map integer to word
word = word_for_id(yhat, tokenizer)
# stop if we cannot map the word
if word is None:
break
# append as input for generating the next word
in_text += ' ' + word
# stop if we predict the end of the sequence
if word == 'endseq':
break
return in_text

# load the tokenizer
tokenizer = load(open('tokenizer.pkl', 'rb'))
# pre-define the max sequence length (from training)
max_length = 34
# load the model
model = load_model('model-ep002-loss3.245-val_loss3.612.h5')
# load and prepare the photograph
photo = extract_features('example.jpg')
# generate description
description = generate_desc(model, tokenizer, photo, max_length)
print(description)

这种情况下,生成的描述如下:

1
复制代码startseq dog is running across the beach endseq

移除开始和结束的标记,或许这就是我们希望模型生成的语句。至此,我们现在已经完整地使用模型为图像生成文本描述,虽然这一实现非常基础与简单,但它是我们继续学习强大图像描述模型的基础。我们也希望本文能带领给为读者实操地理解图像描述模型。

原文链接:machinelearningmastery.com/develop-a-d…

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

使用 Hyperledger Composer 创建强大的区

发表于 2017-12-12

使用 Hyperledger Composer 创建强大的区块链查询

原生查询语言和 REST API 让查询业务网络的分布式账本变得很轻松

对于应用程序开发人员,查询资源(比如区块链上各个区块中的资产和交易数据)可能是一个复杂的任务。您不禁思考:如何编写可被区块链理解的业务查询?如何获取结果?如何使用结果?您不该纠结于各种格式数据的不同性质、它们之间的转换以及提取出直观的结果集。相反,您应该只需简单地提交查询并处理查询结果。

考虑以下逻辑查询,例如:

  • “显示贸易方 A 创建的商品列表。”
  • “显示贸易方 B 上周创建的 EURONEXT 类型的商品列表。”
  • “显示特定贸易方在指定时间范围对特定商品实例的账本的更改历史。”

如何将这些查询转换为“区块链语言”呢?

通过使用 Hyperledger Composer 的富查询特性,可以轻松地编写诸如此类的强大查询。查询特性填补了空白,使您不再需要考虑构建语法,解析、转换和验证查询,以及查询的类型和您从 Hyperledger Fabric 区块链获得的结果。也无需担心如何保证格式一致。 Hyperledger Composer 的查询语言消除了这些麻烦,使得可以通过简单的表示法轻松定义查询,使您的应用程序能够基于业务定义的条件而执行操作。 然后可以直接在您的应用程序中使用查询,或者通过生成的查询 REST
API 来使用查询,都非常简单。

Hyperledger Composer 基于 Hyperledger Fabric v1.0,但在
Hyperledger Fabric v1.0 之前
,查询需要 Hyperledger Fabric 生成的 ID。现在借助 Hyperledger Composer 的富查询功能和 Hyperledger Fabric v1.0,可以使用任何属性或变量参数来查询资源。通过对查询使用参数,您不需要构造复杂的专门查询,而且可以构造模板查询,在每次使用时再设置参数。

关于本教程

本教程演示 Hyperledger Composer 查询语言的强大功能,它使用区块链资源的属性来返回查询结果。我们将以使用 Hyperledger Composer 部署的一个现有的 trade-network 示例 为例,展示如何结合使用富查询和
REST 服务器,以及如何对结果集执行操作,比如使用交易处理程序更新或删除资产。trade-network 示例展示了两个贸易方之间的商品所有权转移。

要学习本教程,您需要下面列出的软件,以及 JSON 和基本查询定义的实用知识。完成本教程后,您就可以自由使用 Hyperledger Composer 构建网络和尝试自己的查询!

软件需求

  • Hyperledger Composer 开发工具。(下一节将安装它们。)
    • Hyperledger Composer 命令行 (composer-cli)
    • Hyperledger Composer REST 服务器 (composer-rest-server)
  • 一个正在运行的 Hyperledger Fabric v1 正式发布版运行时。(下一节将提供下载和启动 Hyperledger Fabric 环境的操作说明。)
  • 请参阅安装开发工具的 其他前提条件。

1
使用 CouchDB 设置 Hyperledger Composer 运行时


  1. 安装 Hyperledger Composer 开发工具(使用非 root 用户身份):
    1. 要安装 composer-cli,请运行以下命令:
      npm install -g composer-cli
      composer-cli 包含用于开发业务网络的所有命令行操作。
    2. 要安装 composer-rest-server,请运行以下命令:
      npm install -g composer-rest-server
      Hyperledger Composer REST 服务器使用 Hyperledger Composer Loopback Connector 连接到业务网络,提取网络的模型和模式,然后生成一个 REST Explorer Web 页面,其中包含已为该模型生成的 REST API。
对于本教程,只需要使用这两个工具。
  1. 按照 安装 Hyperledger Composer 并进行开发 中的操作说明,启动 Hyperledger Fabric 运行时。
    启动后,您应该能够通过这个命令查看下面 4 个 Docker 服务:
1
复制代码docker ps -a


点击查看大图

为了使用富查询,需要设置和启用一个 CouchDB 映像和 CouchDB 配置。二者都已在 Hyperledger Composer v0.11+ 中通过 Hyperledger Fabric V1 运行时完成。您可以为每个对等节点或一个特定对等节点启用 CouchDB,并将该对等节点设置为依赖于它。对于启用了 CouchDB 的对等节点,已为其配置了一个 CouchDB 实例。

一个具有合适配置的 docker-compose.yml 文件类似于这个示例:

2
向业务网络添加查询


Hyperledger Composer 查询语言支持 CouchDB Mango 查询语言标准。查询是业务网络定义中查询文件 (.qry) 定义的 JSON 对象。查询可用于返回以下资源:

  • 资产
  • 参与者
  • 历史数据

对于本教程,我们将使用 Hyperledger Composer trade-network 示例 中定义的简单查询,可从 GitHub 下载该示例。

查询使用的语法很简单,由初始 query 关键字后跟查询名称(在本例中为 selectCommoditiesByOwner)来定义。description 字段应包含查询功能的有意义、人类可读的描述。statement 字段包含 SELECT 和 WHERE 属性,前者定义要查询的注册表或资源集,后者定义要返回资源而必须满足的条件。WHERE 属性可以包含语义 AND/OR 修饰符。有关更多细节,请查阅 Hyperledger Composer 查询语言文档。

要添加查询,可以打开一个现有的业务网络,创建一个 queries.qry 文件。在这个文件中,可以描述多个查询。下面的示例查询将返回其 owner 属性与在发送查询时提供的 _$owner 变量参数匹配的所有商品。

Hyperledger Composer 中的 Historian 特性维护业务网络中发生的账本更新历史。提交交易时,会更新 HistorianRecord,而且随着时间的推移,交易注册表会逐渐变大,既包含特定交易的交易输入,也包含提交这些交易所涉及的参与者和身份。为了演示 Hyperledger Composer 中的 Historian 特性,我们向 queries.qry 文件添加了两个额外的交易:

清单 1.包含更多交易的示例查询
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
复制代码query showCommodityAllHistorians{ 
  description: "Select commodity all historians"
  statement: 
    SELECT org.hyperledger.composer.system.HistorianRecord FROM
HistorianRegistry
          WHERE (transactionType == 'AddAsset' OR transactionType ==
'UpdateAsset' OR transactionType == 'RemoveAsset')
}
 
query findCommmodityHistoriansWithTime{
  description: "Find commodity historians after a specified time"
  statement: 
    SELECT org.hyperledger.composer.system.HistorianRecord FROM
HistorianRegistry WHERE (transactionTimestamp > _$justnow)
}

可以将这两个查询剪切并粘贴到 queries.qry 文件中,然后保存该文件。

3
将业务网络部署到 Hyperledger Fabric


现在您应该有一个包含 queries.qry 文件的业务网络。 如果您拥有 Hyperledger Fabric 的本地实例,可以使用以下命令部署该业务网络:

1
复制代码composer network deploy -p hlfv1 -a ./trade-network.bna -i PeerAdmin -s adminpw

可以将完整路径传递给 trade-network.bna;hlfv1 是与 Hyperledger Fabric 运行时连接的概要文件名称。有关更多细节,请查阅 Hyperledger Composer 命令行文档。

如果您已部署了 trade-network.bna,则应该使用此命令更新业务网络:

1
复制代码composer network update -p hlfv1 -a ./trade-network.bna -i PeerAdmin -s adminpw

您应该看到一条成功部署消息:


部署业务网络后,以下命令应该显示一个链代码容器:

1
复制代码docker ps –a

4
通过 REST API 探索富查询


部署包含查询的业务网络后,现在启动 Composer REST 服务器来通过 REST API 公开这些查询。 Hyperledger Composer REST 服务器基于业务网络而生成一组预定的 REST API。

  1. 从命令行,使用以下命令启动 REST 服务器:
1
复制代码composer-rest-server
  1. 提供如下所示的信息:
  2. 从 Web 浏览器键入以下地址来打开 REST 服务器:
    http://localhost:3000/explorer

您应该看到以下查询 API:


目前,您的业务网络还没有任何要查询的数据。在下一步,您将创建一些参与者、资产和一个 Historian 注册表来测试这些查询。

5
在 World State 数据库中创建参与者、资产和 Historian 注册表


在本节中,我们将创建两个参与者和两个资产,然后提交一个贸易交易。

创建两个参与者

可以从 REST 服务器界面创建参与者。单击 Trader,然后单击 POST 来创建一个 Trader 参与者。Trader 参与者需要 3 个字段,如下所示。执行此操作两次,以创建两个参与者。

清单 2.一个 Trader 参与者的示例
1
2
3
4
5
6
复制代码{
    "$class": "org.acme.trading.Trader",
    "tradeId": "fenglian@email.com",
    "firstName": "Fenglian",
    "lastName": "Xu"
}

创建两个 Commodity 资产

接下来,我们用创建参与者的相同方式创建两个 Commodity 资产。执行此操作两次,以创建两个商品。

清单 3.一个 Commodity 资产的示例
1
2
3
4
5
6
7
8
复制代码{
  "$class": "org.acme.trading.Commodity",
  "tradingSymbol": "XYZ",
  "description": "Soya",
  "mainExchange": "Chicago",
  "quantity": 50,
  "owner": "dan@email.com"
}

单击 Commodity 下的 GET 显示创建的所有商品。

清单 4.资产注册表中创建的两个资产
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
复制代码[
  {
    "$class": "org.acme.trading.Commodity",
    "tradingSymbol": "EMA",
    "description": "Corn",
    "mainExchange": "Euronext",
    "quantity": 100,
    "owner": "resource:org.acme.trading.Trader#dan@email.com"
  },
  {
    "$class": "org.acme.trading.Commodity",
    "tradingSymbol": "XYZ",
    "description": "Soya",
    "mainExchange": "Chicago",
    "quantity": 50,
    "owner": "resource:org.acme.trading.Trader#dan@email.com"
  }
]

请注意,owner 属性与一个 Participant 实例相关联。也就是说,它是在 trade-network 模型中建模的。Participant 注册表中的 id dan@email.com(完全限定的关系引用了它)包含所关注的参与者的一个 URI 前缀 resource:org.acme.trading.Trader#。

创建贸易交易

贸易交易(在我们的示例中)是定义来改变商品所有权的交易。所有交易都记录在这个商品贸易业务网络的 Historian 注册表中。它维护交易的历史、它们的类型,以及特定交易所添加的更改/增量(或者使用 create 时添加的交易,比如 create asset)。

导航到 Trade 节,打开 POST。

在 data 字段中,放入以下数据:

清单 5.一个贸易交易的示例
1
2
3
4
5
6
7
复制代码{
  "$class": "org.acme.trading.Trade",
  "commodity": "EMA",
  "newOwner": "fenglian@email.com",
  "transactionId": "",
  "timestamp": "2017-08-07T15:04:33.790Z"
}

然后单击 Try it out! 按钮提交该贸易交易。此交易将 EMA 商品的所有者从 dan 更改为 fenglian。可以单击 Commodity GET 查询下方的 Try it out! 来验证此结果。

清单 6.更新的 EMA 商品所有者
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
复制代码[
  {
    "$class": "org.acme.trading.Commodity",
    "tradingSymbol": "EMA",
    "description": "Corn",
    "mainExchange": "Euronext",
    "quantity": 100,
    "owner": "resource:org.acme.trading.Trader#fenglian@email.com"
  },
  {
    "$class": "org.acme.trading.Commodity",
    "tradingSymbol": "XYZ",
    "description": "Soya",
    "mainExchange": "Chicago",
    "quantity": 50,
    "owner": "resource:org.acme.trading.Trader#dan@email.com"
  }
]

6
通过 REST 服务器查询资源


本节将通过 REST 服务器,使用富查询函数和变量参数来查询业务网络数据。

Hyperledger Composer 和 REST 服务器支持的富查询数据类型

由于 Hyperledger Composer Loopback Connector 的原因,通过 REST API 公开的查询中的参数数据类型不同于 Composer 数据类型,如表 1 所示。

表 1 :Composer 与 REST 服务器之间的基本数据类型对应关系
组件
Composer 运行时 String Double/Integer/Long DateTime Boolean
REST 服务器 string number date boolean

除了上述 Composer 运行时基本数据类型,Composer 运行时还支持以下数据类型:

  • 一种关联关系:此关系始终由一个 String 类型的 key 属性来标识。
  • 一种类包含关系:属性数据类型是一个类,这个类又可以拥有另一个类包含关系,以此类推。这是一种嵌套关系,而且是一种基本类型。
  • 枚举:此类型用于一组预定义的数据。

使用参数来查询资产

在 REST 服务器 UI 中展开该查询。queries.qry 文件中之前定义了一个查询列表。可以使用任何这些查询来查询 CouchDB 数据库中现在包含的业务网络数据。

selectComoditiesByExchange 查询允许您通过交换符号属性来返回商品。


单击 Try it out! 发送该查询。结果将显示在 Try it out! 按钮下方:


响应页面显示了 Curl 和针对包含一个参数的查询格式的 Web 浏览器 URL。

selectCommoditiesByOwner 查询演示了一个使用 owner 属性的查询。您可以打开 selectCommoditiesByOwner 查询的 GET 面板,定义您想要查询的所有者:


下面显示了针对指定所有者的查询结果。

查询交易历史

Historian 是一个 Hyperledger Composer 特性,用于跟踪交易和资产更新。提交一个交易时,HistorianRecord 显示对一个业务网络中资产的更改,以及提交这些交易所涉及的参与者和身份。

Historian 是 Hyperledger Composer 系统名称空间中定义的一种资产,如清单 7 所示。

如果您想以编程方式使用 Historian,请参阅 Hyperledger Composer HistorianRecord 文档。

清单 7.Historian 模型
1
2
3
4
5
6
7
8
9
复制代码asset HistorianRecord identified by transactionId {
  o String      transactionId
  o String      transactionType
  --> Transaction transactionInvoked
  --> Participant participantInvoking  optional
  --> Identity    identityUsed         optional
  o Event[]       eventsEmitted        optional
  o DateTime      transactionTimestamp
}

导航到 findCommodityHistoriansWithTime 查询,以 UTC 格式指定一个日期时间值,然后单击 Try it out!。


此查询会找到 Historian 注册表中在指定日期时间后发生的交易的所有历史记录,如清单 8 所示。交易类型同时包含系统交易和用户定义的交易。例如,Trade 是用户定义的交易,AddAsset 是系统交易。

清单 8.交易的历史记录
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
复制代码[
  {
    "$class": "org.hyperledger.composer.system.HistorianRecord",
    "transactionId": "0392a2e7-e056-4442-82c7-7d020cd0ee7a",
    "transactionType": "AddParticipant",
    "transactionInvoked": "resource:org.hyperledger.composer.system.Transaction#0392a2e7-e056-4442-82c7-7d020cd0ee7a",
    "eventsEmitted": [],
    "transactionTimestamp": "2017-08-07T15:08:19.047Z"
  },
  {
    "$class": "org.hyperledger.composer.system.HistorianRecord",
    "transactionId": "04f4410c-3470-4319-a6c5-3f5681d86488",
    "transactionType": "AddAsset",
    "transactionInvoked": "resource:org.hyperledger.composer.system.Transaction#04f4410c-3470-4319-a6c5-3f5681d86488",
    "eventsEmitted": [],
    "transactionTimestamp": "2017-08-07T15:13:56.223Z"
  },
  {
    "$class": "org.hyperledger.composer.system.HistorianRecord",
    "transactionId": "2529902c-1393-4537-98ca-f3e4d46a3164",
    "transactionType": "Trade",
    "transactionInvoked": "resource:org.hyperledger.composer.system.Transaction#2529902c-1393-4537-98ca-f3e4d46a3164",
    "eventsEmitted": [
      {
        "$class": "org.acme.trading.TradeNotification",
        "commodity": "resource:org.acme.trading.Commodity#EMA",
        "eventId": "2529902c-1393-4537-98ca-f3e4d46a3164#0",
        "timestamp": "2017-08-07T15:04:33.790Z"
      }
    ],
    "transactionTimestamp": "2017-08-07T15:04:33.790Z"
  },
  {
    "$class": "org.hyperledger.composer.system.HistorianRecord",
    "transactionId": "3f31add4-5e4d-40cd-84a5-e5553a14cb50",
    "transactionType": "AddAsset",
    "transactionInvoked": "resource:org.hyperledger.composer.system.Transaction#3f31add4-5e4d-40cd-84a5-e5553a14cb50",
    "eventsEmitted": [],
    "transactionTimestamp": "2017-08-07T15:13:03.584Z"
  },
  {
    "$class": "org.hyperledger.composer.system.HistorianRecord",
    "transactionId": "a50145ba-df31-4def-932e-cfc9707131ec",
    "transactionType": "AddParticipant",
    "transactionInvoked": "resource:org.hyperledger.composer.system.Transaction#a50145ba-df31-4def-932e-cfc9707131ec",
    "eventsEmitted": [],
    "transactionTimestamp": "2017-08-07T15:07:15.985Z"
  }
]

结束语

内置于 Hyperledger Composer 中的富查询语言,让对资产和交易执行复杂且强大的查询变得很容易。 如果没有原生查询语言,您需要在代码中构造专门查询,然后努力理解返回的结果。

本教程展示了如何在 Hyperledger Composer 中定义查询,对启用了 CouchDB 的 Hyperledger Fabric 进行业务网络部署或更新,设置一个 REST 服务器,以及使用 REST 服务器查询业务网络。Hyperledger Composer 查询语言使查询任何属性都变得非常灵活而轻松。

后续行动

  • 在这篇 深入介绍 文章中进一步了解 Hyperledger Composer。
  • 加入 Hyperledger Composer 社区以解决您的难题,为该项目做贡献,以及参加每周开放社区电话会议。
  • 尝试使用 Hyperledger Composer 创建一个 去中心化的能源网络。

致谢

感谢 Hyperledger Composer 开发团队,特别感谢 Daniel Selman 和 Simon Stone 提供技术支持。还要感谢 Edward Prosser 和 Rachel Jackson 提供技术评审和支持。

相关主题

  • IBM Blockchain 开发人员中心
  • 安装 Hyperledger Composer 并进行开发
  • Hyperledger Composer 业务网络定义示例(在 GitHub 上)
  • Hyperledger Composer 查询语言
  • Hyperledger Composer 命令行
  • HyperLedger Composer Historian
  • IBM Blockchain Platform
  • 面向开发人员的 IBM Blockchain Platform
  • 区块链基础课程(面向开发人员的免费课程)

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

痛入爽出 HTTP/2:代码实战2 结语

发表于 2017-12-12

一个写文档的开发者,其实就是个 Docker


正文

上一期我们熟悉了应用场景和测试,这一期我们实现receive函数。

先重温一下 API:

1
2
3
4
5
复制代码class HTTP2Protocol:
def receive(self, data: bytes):
pass
def send(self, stream: Stream):
pass

我们的整体设计思路是 Event Driven + Mutable State.

Event Driven:gethy 内部自定义一些事件(Event),HTTP2Protocol的 Public API 只会返回这些 Event 而已。

Mutable State:HTTP2Protocol内部会管理两个缓冲(Buffer),一个inbound_buffer储存接收的数据,一个outbound_buffer储存需要发送的数据。这两个 Buffer 都是私有的,用户不应该使用。根据不同的事件,HTTP2Protocol会向 Buffer 添加数据或者清除数据。

HTTP2Protocol 类

现在,我们来看更具体的函数签名:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
复制代码# http2protocol.py
from typing import List

import h2.config
import h2.connection
import h2.events
import h2.exceptions
from h2.events import (
RequestReceived,
DataReceived,
WindowUpdated,
StreamEnded
)

from gethy.event import H2Event


class HTTP2Protocol:
def __init__(self):
self.current_events = []

self.request_buffer = {} # input buffer
self.response_buffer = {} # output buffer, not used in this tutorial

config = h2.config.H2Configuration(client_side=False, header_encoding='utf-8')
self.http2_connection = h2.connection.H2Connection(config=config)

def receive(self, data: bytes) -> List[H2Event]:
pass

current_events:顾名思义,用来存放目前已知的事件。

request_buffer:存放没有接收完整的 Request Stream。

response_buffer:存放没有完全发送的 Response Stream。

Stream 类

当然,我们还需要一个Stream来表示一个数据流。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
复制代码class Stream:
def __init__(self, stream_id: int, headers):
self.stream_id = stream_id
self.headers = headers # as the name indicates

# when stream_ended is True
# buffered_data has to be None
# and data has to be a bytes
#
# if buffered_data is empty
# then both buffered_data and data have to be None when stream_ended is True
#
# should write a value enforcement contract decorator for it
self.stream_ended = False
self.buffered_data = []
self.data = None

流程图

在实现之前,我们先来看看流程图。

Receive 逻辑

如图所示,我们的工作流程是纯线性的,所以也使其逻辑简明,容易实现。

receive

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
复制代码def receive(self, data: bytes):
"""
receive bytes, return HTTP Request object if any stream is ready
else return None

:param data: bytes, received from a socket
:return: list, of Request
"""
# First, proceed incoming data
# handle any events emitted from h2
events = self.http2_connection.receive_data(data)
for event in events:
self._handle_event(event)

self._parse_request_buffer()

events = self.current_events # assign all current events to an events variable and return this variable
self.current_events = [] # empty current event list by assign a newly allocated list

return events

这里就将receive函数写好了,接下来实现_handle_event和_parse_request_buffer。

_handle_event

Handle events 的部分由几个重要的函数组成。

1
2
3
4
5
6
7
8
9
10
11
12
复制代码def _handle_event(self, event: h2.events.Event):
# RequestReceived 的命名可能产生误解。
# 这里不是说一个完整的 Request 收到了。
# 而是说,Headers 收到了。
if isinstance(event, h2.events.RequestReceived):
self._request_headers_received(event)

elif isinstance(event, h2.events.DataReceived):
self._data_received(event)

else:
logging.info("Has not implement %s handler" % type(event))

首先_handle_event要判断是哪种 h2 事件。我们用if/else来将事件导流到相应的函数去。本期只关心 Request(Headers&Data),其余事件简单地打印出来。

注:这里的 h2 事件其实和 HTTP/2 的 frame 有直接的关系。一个 Request 事件其实就是一个 Request Frame。一个 Data 事件其实就是一个 Data Frame。

参考文档:

Hyper-h2 API

http2 FramingLayer

_request_headers_received

1
2
3
4
5
6
7
8
复制代码def _request_headers_received(self, event: RequestReceived):
self.request_buffer[event.stream_id] = Stream(event.stream_id, event.headers)

if event.priority_updated:
logging.warning("RequestReceived.priority_updated is not handled")

if event.stream_ended:
self._stream_ended(event.stream_ended)

这个 event 里有stream_id&headers,将其拿到并构造一个Stream实例。如果数据流结束,则调用_stream_ended。这里stream_ended == True的意思就是这个 Request 只有 Headers。通常的GET或者POST url param encoded就属于这个类型。很多框架甚至不允许GET带有 Request Body/Data。

_data_received

1
2
3
4
5
复制代码def _data_received(self, event: DataReceived):
self.request_buffer[event.stream_id].buffered_data.append(event.data)

if event.stream_ended:
self._stream_ended(event.stream_ended)

Request 也可以带有 Data,所以就会触发这个事件。这里request_buffer[event.stream_id]是一定不能触发KeyError的,因为只有可能先接收 Headers,再接收 Data。如果有 KeyError,那么八阿哥一定潜伏于某处。这里stream_ended == True就说明 Request 完整接收了。

_stream_ended

1
2
3
4
5
复制代码def _stream_ended(self, event: StreamEnded):
stream = self.request_buffer[event.stream_id]
stream.stream_ended = True
stream.data = b''.join(stream.buffered_data)
stream.buffered_data = None

当接收完一个 Request 数据流后,将Stream实例的状态做一些调整。

_parse_request_buffer

这样,我们就将所有数据都处理好了。现在的任务就是将缓冲扫描一遍,看有没有指的返回的东西。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
复制代码def _parse_request_buffer(self):
"""
exercise all inbound streams
"""
# This is a list of stream ids
streams_to_delete_from_request_buffer = []

# inbound_streams is a dictionary with schema {stream_id: stream_obj}
# therefore use .values()
for stream in self.request_buffer.values():
if stream.stream_ended:
# create a HTTP Request event, add it to current event list
event = RequestEvent(stream)
self.current_events.append(event)

# Emitting an event means to clear the cached inbound data
# The caller has to handle all returned events. Otherwise bad
streams_to_delete_from_request_buffer.append(stream.stream_id)

# clear the inbound cache
for stream_id in streams_to_delete_from_request_buffer:
del self.request_buffer[stream_id]

这里的逻辑也简单明了,检查有没有完整的 Request,有的话就构造一个完整的RequestEvent,然后将其放到self.current_events中。最后从缓冲中删除相应的Stream。

RequestEvent 类

RequestEvent定义如下:

1
2
3
4
5
6
7
复制代码# events.py
class H2Event:
pass

class RequestEvent(H2Event):
def __init__(self, stream):
self.stream = stream

纯粹为了代码可读性而定义的。

仔细的同学可能会看到两点:

  • 在_stream_ended中就可以完成这个函数中的所有操作,没有必要再 loop 一遍浪费时间。
  • 如果非要再 loop 一遍,可以写成函数式的,returncurrent_events,而不是更改对象的值。

完全正确,这里我为了大家看得简单明了,所以选择了更简洁,但是效率稍微慢一点的实现。

结语

到这里你就实现了一个完全正确可用的 HTTP/2 服务器端的接收功能。下一期就要实现发送了。

视频对文章进行补充,感兴趣就去看看吧!代码在 GitHub,喜欢给个🌟呗!

代码

GitHub

视频

B 站

油腻的管道(你留言我就上传)

文章

上期

下期(还没写)

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

Java—多线程基础

发表于 2017-12-12

基本概念

进程

所谓进程就是运行在操作系统的一个任务,进程是计算机任务调度的一个单位,操作系统在启动一个程序的时候,会为其创建一个进程,JVM就是一个进程。进程与进程之间是相互隔离的,每个进程都有独立的内存空间。

计算机实现并发的原理是:CPU分时间片,交替执行,宏观并行,微观串行。同理,在进程的基础上分出更小的任务调度单元就是线程,我们所谓的多线程就是一个进程并发多个线程。

线程

在上面我们提到,一个进程可以并发出多个线程,而线程就是最小的任务执行单元,具体来说,一个程序顺序执行的流程就是一个线程,我们常见的main就是一个线程(主线程)。

线程的组成

想要拥有一个线程,有这样的一些不可或缺的部分,主要有:CPU时间片,数据存储空间,代码。

CPU时间片都是有操作系统进行分配的,数据存储空间就是我们常说的堆空间和栈空间,在线程之间,堆空间是多线程共享的,栈空间是互相独立的,这样做的好处不仅在于方便,也减少了很多资源的浪费。代码就不做过多解释了,没有代码搞个毛的多线程。

线程的创建和启动
传统创建线程有两种方式
  1. 继承Thread类,覆盖run方法
  2. 实现Runnable接口,覆盖run方法

Runnable并不是线程对象,而是一个任务对象。那么Runnable和Thread有什么样的关系呢?

通过查阅API,我们发现创建一个线程除了使用Thread的无参构造方法以外有一个有参构造方法是这样 :Thread(Runnable target),通过这个方法会分配一个新的Thread 对象。

其中的参数是一个类型为Runnable的target属性。

Runnable接口最大的作用就是为非Thread子类的类提供了一种实现线程的方式,只需要实现Runnable接口就可以借助Thread创建一个线程;另一方面,如果只想重写run方法,不想得到其他的Thread的方法,实现Runnable是一个好的选择。

JDK1.5

线程池

ExecutorService(线程池 interface)

1
2
复制代码//通过工具类中的方法能够新建一个线程池,用ExecutorService接受
ExecutorService es = Executors.newFixedThreadPool(2);

Callable对象

类似于Runnable(描述任务的interface)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
复制代码//创建一个Callable的实现类
Callable<Integer> task1 = new Callable<Integer>(){
public Integer call() throws Exception{
int result = 0;
for(int i=2;i<=100;i+=2){
result += i;
Thread.sleep;
}
return result;
}

}

//用Future对象接收fask1的返回值 将任务提交给线程池
Future<Integer> f = es.submit(task1);
//通过get方法获取Future中的值 在这个时候主线程主动的调取get 如果分支线程还没有结束,主线程会在这里阻塞
int result = f1.get();
//关闭线程池
es.shutdown();

从以上这段代码我们可以看到很多不一样的地方,首先在Callable对象中是可以抛出异常的,其次有返回值,在这个基础上也就引出了一个新的问题,如果接收该线程的对象?JDK1.5中也给出了解决的方法是Future对象.

启动线程

在这里我们需要明白,上面两种方式并不会让我们得到真正的线程,只是得到了线程对象,只有启动线程,才算得到了真正的线程。

通过执行start()方法能够启动一个线程,但是启动线程并不是立即执行,成功启动的线程会处于就绪状态,什么时候执行需要等到拿到时间片之后。

线程的分类

用户线程和守护(Daemon)线程。

守护线程:守护线程会一直运行,直到其他非守护线程都结束的时候,才会结束。有一个典型的守护线程就是:垃圾回收线程,和虚拟机共存亡,直到虚拟机中没有任何线程的时候虚拟机关闭的时候才会终止,简单说就是虚拟机在,它就在,虚拟机亡便亡。

线程的状态

线程的状态

上面我们提到过,一个线程在启动之后不会立马执行,而是处于就绪状态(Ready),就绪状态就是线程的状态的一种,处于这种状态的线程意味着一切准备就绪, 需要等待系统分配到时间片。为什么没有立马运行呢,因为同一时间只有一个线程能够拿到时间片运行,新线程启动的时候让它启动的线程(主线程)正在运行,只有等主线程结束,它才有机会拿到时间片运行。

**线程的状态:**初始状态(New),就绪状态(Ready),运行状态(Running)(特别说明:在语法的定义中,就绪状态和运行状态是一个状态Runable),等待状态(Waitering),终止状态(Terminated)

RUNNABLE),等待状态(Waitering),终止状态(Terminated)

初始状态(New)

线程对象被创建出来,便是初始状态,这时候线程对象只是一个普通的对象,并不是一个线程。

Runable

**就绪状态(Ready):**执行start方法之后,进入就绪状态,等待被分配到时间片。

**运行状态(Running):**拿到CPU的线程开始执行。处于运行时间的线程并不是永久的持有CPU直到运行结束,很可能没有执行完毕时间片到期,就被收回CPU的使用权了,之后将会处于等待状态。

等待状态(Waiting)

等待状态分为有限期等待和无限期等待,所谓有限期等待是线程使用sleep方法主动进入休眠,有一定的时间限制,时间到期就重新进入就绪状态,再次等待被CPU选中。

而无限期等待就有些不同了,无限期并不是指永远的等待下去,而是指没有时间限制,可能等待一秒也可能很多秒。至于进入等待的原因也不尽相同,可能是因为CPU时间片到期,也可能是因为一个比较耗时的操作(数据库),或者主动的调用join方法。

wait和sleep的区别

wait sleep
wait()方法是Object类里的方法 sleep()是Thread类的static(静态)的方法
wait()睡眠时,释放对象锁 sleep()睡眠时,保持对象锁,仍然占有该锁
常用于线程间通信 常用于暂停执行
wait和notify/notifyAll是成对出现的, 必须在synchronize块中被调用
阻塞状态(Blocked)

在我看来,阻塞状态实际上是一种比较特殊的等待状态,处于其他等待状态的线程是在等着别的线程执行结束,等着拿CPU的使用权;而处于阻塞状态的线程等待的不仅仅是CPU的使用权,主要是锁标记,没有拿到锁标记,即便是CPU有空也没有办法执行。(关于锁见下节:线程同步)

等待和阻塞的区别

等待 阻塞
已经拿到锁对象,或者说不存在拿不到执行不了的情况 等待拿到锁对象
等待被唤醒 等待拿到锁对象
终止线程(Terminated)

已经终止的线程会处于该种状态。

总结

总体上来说,作为一个线程挺倒霉的,首先,不会知道自己什么时候被选中;其次在执行过程中随时可能被打断让出CPU,最后碰到数据库等耗时的操作也要让出CPU去等待,并且就算数据准备好了, 仍然需要等着被挑选。

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

用python写通用restful api service(

发表于 2017-12-12

今天项目已经能够做一个简单的后端服务了,在mysql中新建一个表,就能自动提供restful api的CURD服务了。

关键点

  • 根据REST的四种动词形式,动态调用相应的CURD方法;
  • 编写REST与基础数据库访问类之间的中间层(baseDao),实现从REST到数据访问接口之间能用业务逻辑处理;
  • 编写基础数据库访问类(dehelper),实现从字典形式的参数向SQL语句的转换;

实现的rest-api

实现了如下形式的rest-api

1
2
3
4
5
复制代码[GET]/rs/users/{id}
[GET]/rs/users/key1/value1/key2/value2/.../keyn/valuen
[POST]/rs/users
[PUT]/rs/users/{id}
[DELETE]/rs/users/{id}

基础数据库访问类

该类实现与pymysql库的对接,提供标准CURD接口。

准备数据库表

在数据库对应建立users表,脚本如下:

1
2
3
4
5
6
7
8
9
10
复制代码CREATE TABLE `users` (
`_id` int(11) NOT NULL AUTO_INCREMENT,
`name` varchar(32) CHARACTER SET utf8mb4 DEFAULT '' COMMENT '标题名称',
`phone` varchar(1024) DEFAULT '',
`address` varchar(1024) DEFAULT NULL,
`status` tinyint(4) DEFAULT '1' COMMENT '状态:0-禁;1-有效;9删除',
`create_time` datetime DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
PRIMARY KEY (`_id`),
UNIQUE KEY `uuid` (`_id`) USING BTREE
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8 COMMENT='表';

新建数据库配置文件(configs.json)

数据连接配置,不入版本库。

1
2
3
4
5
6
7
8
9
10
复制代码{
"db_config": {
"db_host": "ip",
"db_port": 1234,
"db_username": "root",
"db_password": "******",
"db_database": "name",
"db_charset": "utf8mb4"
}
}

对接pymysql接口

用函数exec_sql封装pymysql,提供统一访问mysql的接口。is_query函数用来区分是查询(R)还是执行(CUD)操作。出错处理折腾了好久,插入异常返回的错误形式与其它的竟然不一样!返回参数是一个三元组(执行是否成功,查询结果或错误对象,查询结果数或受影响的行数)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
复制代码with open("./configs.json", 'r', encoding='utf-8') as json_file:
dbconf = json.load(json_file)['db_config']


def exec_sql(sql, values, is_query=False):
try:
flag = False #是否有异常
error = {} #若异常,保存错误信息
conn = pymysql.connect(host=dbconf['db_host'], port=dbconf['db_port'], user=dbconf['db_username'],
passwd=dbconf['db_password'], db=dbconf['db_database'], charset=dbconf['db_charset'])
with conn.cursor(pymysql.cursors.DictCursor) as cursor:
num = cursor.execute(sql, values) #查询结果集数量或执行影响行数
if is_query: #查询取所有结果
result = cursor.fetchall()
else: #执行提交
conn.commit()
print('Sql: ', sql, ' Values: ', values)
except Exception as err:
flag = True
error = err
print('Error: ', err)
finally:
conn.close()
if flag:
return False, error, num if 'num' in dir() else 0
return True, result if 'result' in dir() else '', num

查询接口

pymysql的查询接口,可以接受数组,元组和字典,本查询接口使用数组形式来调用。现在此接口只支持与条件组合参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
复制代码def select(tablename, params={}, fields=[]):
sql = "select %s from %s " % ('*' if len(fields) == 0 else ','.join(fields), tablename)
ks = params.keys()
where = ""
ps = []
pvs = []
if len(ks) > 0: #存在查询条件时,以与方式组合
for al in ks:
ps.append(al + " =%s ")
pvs.append(params[al])
where += ' where ' + ' and '.join(ps)

rs = exec_sql(sql+where, pvs, True)
print('Result: ', rs)
if rs[0]:
return {"code": 200, "rows": rs[1], "total": rs[2]}
else:
return {"code": rs[1].args[0], "error": rs[1].args[1], "total": rs[2]}

插入接口

以数组形式提供参数,错误信息解析与其它接口不同。

1
2
3
4
5
6
7
8
9
10
11
复制代码def insert(tablename, params={}):
sql = "insert into %s " % tablename
ks = params.keys()
sql += "(`" + "`,`".join(ks) + "`)" #字段组合
vs = list(params.values()) #值组合,由元组转换为数组
sql += " values (%s)" % ','.join(['%s']*len(vs)) #配置相应的占位符
rs = exec_sql(sql, vs)
if rs[0]:
return {"code": 200, "info": "create success.", "total": rs[2]}
else:
return {"code": 204, "error": rs[1].args[0], "total": rs[2]}

修改接口

以字典形式提供参数,占位符的形式为:%(keyname)s,只支持按主键进行修改。

1
2
3
4
5
6
7
8
9
10
11
12
复制代码def update(tablename, params={}):
sql = "update %s set " % tablename
ks = params.keys()
for al in ks: #字段与占位符拼接
sql += "`" + al + "` = %(" + al + ")s,"
sql = sql[:-1] #去掉最后一个逗号
sql += " where _id = %(_id)s " #只支持按主键进行修改
rs = exec_sql(sql, params) #提供字典参数
if rs[0]:
return {"code": 200, "info": "update success.", "total": rs[2]}
else:
return {"code": rs[1].args[0], "error": rs[1].args[1], "total": rs[2]}

删除接口

以字典形式提供参数,占位符的形式为:%(keyname)s,只支持按主键进行删除。

1
2
3
4
5
6
7
8
复制代码def delete(tablename, params={}):
sql = "delete from %s " % tablename
sql += " where _id = %(_id)s "
rs = exec_sql(sql, params)
if rs[0]:
return {"code": 200, "info": "delete success.", "total": rs[2]}
else:
return {"code": rs[1].args[0], "error": rs[1].args[1], "total": rs[2]}

中间层(baseDao)

提供默认的操作数据库接口,实现基础的业务逻辑,单表的CURD有它就足够了。有复杂业务逻辑时,继承它,进行扩展就可以了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
复制代码import dbhelper


class BaseDao(object):

def __init__(self, table):
self.table = table

def retrieve(self, params={}, fields=[], session={}):
return dbhelper.select(self.table, params)

def create(self, params={}, fields=[], session={}):
if '_id' in params and len(params) < 2 or '_id' not in params and len(params) < 1: #检测参数是否合法
return {"code": 301, "err": "The params is error."}
return dbhelper.insert(self.table, params)

def update(self, params={}, fields=[], session={}):
if '_id' not in params or len(params) < 2: #_id必须提供且至少有一修改项
return {"code": 301, "err": "The params is error."}
return dbhelper.update(self.table, params)

def delete(self, params={}, fields=[], session={}):
if '_id' not in params: #_id必须提供
return {"code": 301, "err": "The params is error."}
return dbhelper.delete(self.table, params)

动态调用CURD

根据客户调用的rest方式不同,动态调用baseDao的相应方法,这个很关键,实现了它才能自动分配方法调用,才能只需要建立一个数据表,就自动提供CURD基本访问功能。还好,动态语言能很方便的实现这种功能,感慨一下,node.js更方便且符合习惯^_^

1
2
3
4
5
6
7
8
复制代码    method = {
"GET": "retrieve",
"POST": "create",
"PUT": "update",
"DELETE": "delete"
}

getattr(BaseDao(table), method[request.method])(params, [], {})

说明:

  • table是前一章中解析出来的数据表名,这块就是users;
  • method应该是定义一个常量对象,对应rest的动词,因为对ypthon不熟,定义了一个变量先用着,查了下常量说明,看着好复杂;
  • request.method 客户请求的实际rest动词;
  • params是前一章中解析出来的参数对象;

完整代码

1
2
3
4
复制代码git clone https://github.com/zhoutk/pyrest.git
cd pyrest
export FLASK_APP=index.py
flask run

小结

至此,我们已经实现了基本的框架功能,以后就是丰富它的羽翼。比如:session、文件上传、跨域、路由改进(支持无缝切换操作数据库的基类与子类)、参数验证、基础查询功能增强(分页、排序、模糊匹配等)。感慨一下,好怀念在node.js中json对象的写法,不用在key外加引号。

补丁

刚把基础数据库访问类中的insert方法的参数形式改成了字典,结果异常信息也正常了,文章不再改动,有兴趣者请自行查阅源代码。

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

dubbo源码解析-逻辑层设计之服务降级

发表于 2017-12-12


前言
–

在dubbo服务暴露系列完结之后,按计划来说是应该要开启dubbo服务引用的讲解.但是现在到了年尾,一些朋友也和我谈起了明年跳槽的事.跳槽这件事,无非也就两个原因,一个是钱没给够,另一个是心里委屈了.首先钱没给够这件事我就不多说什么了,因为每个人都觉得自己钱没给够.那心里为啥委屈了?作为一个技术人,我认为心里委屈,无非也就是两个原因,一个是在公司得不到重视,另一个是感觉学不到东西,得不到成长.感觉我所了解到的情况,往往是后者居多.这道理也简单,我们小时候看重的是兴趣和爱好,长大后看重的是投资和回报.

互联网行业加班是常态,首先我不反对加班.但是加班带来的影响也不得不重视.首先对于老板而言,员工加班越多,获得的利润更大.比如:

我们老板来上班事开了一辆崭新的兰博基尼。
我说:“哇喔,这辆车好牛逼”
他回答:“如果你努力工作,全身心投入,力求卓越,那么我明年还会再有一辆”。

再次,一旦你加班多了,学习新技术的时间就少,这样你就会变得越来越不自信,自然不敢随便跳槽.但是加班和学习其实并不冲突.即使加班再多(比如今天周六我也还在加班),肥朝每周一篇dubbo源码解析与你不见不散,请放心保持密切关注肥朝.

如果你是因为学不到东西跳槽,那么往往会遇到一个问题,那就是公司的项目太low,找一下家的时候拿不出手.技术上,增删改查这种东西肯定是拿不出手的,我用freemarker + mybatis generator做代码生成器直接一键生成就可以直接运行跑起来.业务上,每个公司的业务都不同,讨论意义不大.想去网上找点资料装逼,但是却发现无从下手.好不容易找了点技术沙龙逼格高的PPT,不幸的是都是空谈理论,无法和自己公司的项目衔接起来.但是幸运的是无意中关注了肥朝.

首先我们来看看,我在三张图看清项目结构提到的中小型公司比较典型的Java项目的架构.

从中可以看出,典型的就是三层结构,

  • 接入层,逻辑层,数据存储层.

当然也可以分成四层

  • 接入层,逻辑层,原子服务层,数据存储层.

当然是可以分成五层

  • 接入层,序列化层(异步消息队列),原子服务层,数据层,数据存储层.

当然分几层都要根据自身业务,好的架构并不是一蹴而就,而是逐渐演变的过程.从标题就可以知道,本篇着重介绍逻辑层的设计(那剩下的什么时候讲?反正每周一篇,一年也就48篇.dubbo系列完结之后下一个系列由你来定,你可以自己估算一下时间).既然是设计,那么就不能纸上谈兵,必须站在巨人的肩膀上,比如孙玄老师分享的58同城架构设计就很有参考意义.我简单用思维导图做了个总结.如下:


看到这里有朋友可能就不乐意了,不要扯这些原则,老子拿起键盘就是干.


我想说的是,如果没有读万卷书,即使行了万里路,也不过是个邮差.知道了理论,下面我们直入主题,开始实战.

插播面试题

  • 谈一下你们项目架构设计(很多人在回答这个的时候都容易回答SSH或者SSM,注意,所谓是SSH这些是技术选型,不是架构的设计)

  • 既然你们项目用到了dubbo,那你讲讲你们是怎么通过dubbo实现服务降级的,降级的方式有哪些,又有什么区别?

  • dubbo监控平台能够动态改变接口的一些设置,其原理是怎样的?

  • 既然你说你看过dubbo源码,那讲一下有没有遇到过什么坑?(区分度高,也是检验是否看过源码的试金石)


    直入主题


我们从两个角度来分析,一个是为什么需要服务降级,一个是怎么做服务降级

为什么需要服务降级

引进一个新技术,必须要看这个新技术解决了什么问题.比如服务降级,他解决了什么问题?从上面的思维导图我们就知道,当网站处于高峰期时,并发量大,服务能力有限,那么我们只能暂时屏蔽边缘业务.那么具体的例子是什么?

比如在某宝某东购物,当支付完成,会向你推荐一些商品.但是在11大促中,并发量过大.我们就要保证”支付”这些核心业务的正常运行,因此像”推荐商品”这些边缘业务,我们就可以不调用,从而减少一定的并发.但是如果双11我先把”推荐商品”接口的代码屏蔽起来,等过后我再打开.这种太简单粗暴的方法肯定不是我们的理想追求,这时候我们就需要一个”服务开关”一样的东西.这个开关,就是服务降级

怎么做服务降级

空谈误国,实战兴邦,光知道思维导图上的这些设计原则还不行,我们以dubbo为例,实战一下服务降级.首先dubbo中的服务降级分成两个

  • 屏蔽(mock=force)
  • 容错(mock=fail)

这两个有什么区别呢?我们引用文档介绍

mock=force:return+null 表示消费方对该服务的方法调用都直接返回 null 值,不发起远程调用。用来屏蔽不重要服务不可用时对调用方的影响。

还可以改为 mock=fail:return+null 表示消费方对该服务的方法调用在失败后,再返回 null 值,不抛异常。用来容忍不重要服务不稳定时对调用方的影响。

那么下面分别演示一下容错的使用方法

首先我们打上断点,造成请求超时,报错如下


配置容错


报错信息立刻消除,结果如下


屏蔽就不在演示,配置方式类似,效果自己调试.

其实从文档介绍我们就能回答出两者的区别.但是老司机可能更享受的是扒光原理的快感.

源码分析

首先我假设你之前看过肥朝每周一篇dubbo源码解析,那么对MockClusterInvoker这个类就不会陌生,那么我们直接看核心代码(应群友反馈,尝试一下代码不贴图)

从no mock(正常情况),force:direct mock(屏蔽),fail-mock(容错)三种情况我们也可以看出,普通情况是直接调用,容错的情况是调用失败后,返回一个设置的值.而屏蔽就很暴力了,直接连调用都不调用,就直接返回一个之前设置的值.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
复制代码public Result invoke(Invocation invocation) throws RpcException {
Result result = null;

String value = directory.getUrl().getMethodParameter(invocation.getMethodName(), Constants.MOCK_KEY, Boolean.FALSE.toString()).trim();
if (value.length() == 0 || value.equalsIgnoreCase("false")){
//no mock
result = this.invoker.invoke(invocation);
} else if (value.startsWith("force")) {
if (logger.isWarnEnabled()) {
logger.info("force-mock: " + invocation.getMethodName() + " force-mock enabled , url : " + directory.getUrl());
}
//force:direct mock
result = doMockInvoke(invocation, null);
} else {
//fail-mock
try {
result = this.invoker.invoke(invocation);
}catch (RpcException e) {
if (e.isBiz()) {
throw e;
} else {
if (logger.isWarnEnabled()) {
logger.info("fail-mock: " + invocation.getMethodName() + " fail-mock enabled , url : " + directory.getUrl(), e);
}
result = doMockInvoke(invocation, e);
}
}
}
return result;
}

敲黑板画重点

有句话叫尽信书不如无书,dubbo中也难免存在一些bug,比如我之前在dubbo源码解析-router就提到过,这个监控平台是有bug的,如今又出现了


你会发现点删除或者点启用和禁用后,会出现多条.解决办法还是和之前一样.清除zookeeper上的节点信息(不懂的可以点回去看看router这篇,这就是我之前反复强调的,一定要系统学习,因为之前提到的我往往一笔带过)

从这个解决bug中,我们也应该有一些逆向思维.为什么这个监控平台这么神奇,能动态改变接口的一些默认设置?你清除了zookeeper节点,监控平台上的一些配置信息就消失了,很明显,他这个原理就是改变注册在zookeeper上的节点信息.从而zookeeper通知重新生成invoker(这些具体细节在zookeeper创建节点,zookeeper连接,zookeeper订阅中都详细讲了,这里不再重复)

当然除了这些坑外,dubbo在集群容错算法中的轮询就有个坑,需要调节当前时间解决(因为这个使用不多,这里暂时不细说),但是当当网的dubbox有一个坑就比较明显.如下图,当你传的参数为null时,这里就有很明显的空指针


在后面当当网也解决了这个问题


当然他这个修复的代码也可以出一个面试题.当然这道题我就不解答了

java中 || 和 | 有什么区别

把握重点

看到这里你就必须要把握一下重点.从标题你就知道,本篇是有三个关键词,分别是dubbo,逻辑层设计,服务降级

我用dubbo演示一种服务降级的方式,并不代表是只有这一种方式,你如果仔细看思维导图就明白,其实也还有很多实现方式.另外如果你觉得你们项目比较low,那么你可以设想,假如用思维导图的这些原则来设计,那么要怎么设计,有什么优缺点?然后自己尝试改造一下,这思考和行动的过程,才是你最宝贵的收获,也是我想传达的学习思想.如果你把重点当成了dubbo如何配置服务降级,那么可能再好的项目,你都只看到了增删改查.

写在最后

写到这里的时候,不知不觉已经是凌晨四点.因为996的模式下,要坚持每周一篇对我来说也是一个挑战,但是同时我也享受着这种挑战的感觉.每次下班的时候,遇到熟人都会问我怎么这么晚才下班,加班这么多,一定很多加班费吧.我说,没有加班费,这个时候都会很自然的反驳到,没有加班费那干嘛加班.同样的道理,一些朋友看到我写博客,也会问,你每周都坚持写博客有钱赚吗,我说没有,这个时候正常的逻辑也是反驳到,没有钱干嘛要做.其实有时候,博客既是写给别人看的,也是写给自己看的.最重要的是从一件事上,看到一个人做事的决心

期待下周与你相遇.鉴于本人才疏学浅,不对的地方还望斧正,也欢迎关注我的简书,名称为肥朝

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

laravel使用技巧之查询构造器Query Builder

发表于 2017-12-12

今天给大家介绍一下laravel查询构造器的一个小技巧,在官方文档示例中没有详细提到,也不是啥高端技巧,可能很多人在用了,不知道的同学可以看看。

在业务代码中经常会根据不同条件来查询,举个简单例子,我们现在要查询用户列表,按时间倒序排列,可能会有status和type作为限定条件。

一开始我是这样写的

1
2
3
4
5
6
7
8
9
复制代码    if($status && $type) {
$users = User::where('status', $status)->where('type', $type)->latest()->get();
} else if ($status) {
$users = User::where('status', $status)->latest()->get();
} else if ($type) {
$users = User::where('status', $type)->latest()->get();
} else {
$users = User::latest()->get();
}

这个代码真的很丑陋,很多公共代码,比如->latest()->get(),写了四遍,如果产品说今天我们要正序排列,那你得改四个地方。虽然借助编辑器改一下也很快,不过要知道这只是个最简单的例子。

看了下文档有个when方法进行条件判断,一堆闭包也不是很理想。我坚信肯定有更优雅的写法,于是上stackoverflow搜了一波,果然万能的歪果仁给了我答案。

改进后的写法:

1
2
3
4
5
6
7
8
9
10
11
12
复制代码    $query = User::query();
// 如果用DB: $query = DB::table('user');

if ($status) {
$query->where('status', $status);
}

if ($type) {
$query->where('type', $type);
}

$users = $query->latest()->get();

用变量保存查询构造器实例,然后在其上叠加约束条件,最后get集合。公共部分放在首尾,结构清晰,是不是高下立判啊?

而且我们还可以把$query当成参数传入方法或函数中,将公共逻辑封装在一起,方便多处调用:

1
2
3
4
5
6
7
8
9
复制代码    function foo($query) {
$query->with(['girl', 'gay'])
->latest()
->get();
}


$query = User::query();
$users = foo($query);

这种写法有一个注意事项,一旦你在$query上调用where等约束方法,就会改变此query,有时候我们需要提前clone一个query。

举例说明,比如我们同时要拿到type为1和2的users

1
2
3
4
5
6
7
8
复制代码    $query_1 = User::query();
$query_2 = clone $query_1;


$users_1 = $query_1->where('type', 1)->latest()->get();
$users_2 = $query_2->where('type', 2)->latest()->get();
// 错误 $users_2 = $query_1->where('type', 1)->latest()->get();
// 这样写得到得是type = 1 and $type = 2

laravel的文档里虽然没有写这种示例,但是提了一下:

你可以使用 DB facade 的 table 方法开始查询。这个 table 方法针对查询表返回一个查询构造器实例,允许你在查询时链式调用更多约束,并使用 get 方法获取最终结果

题外话

以前听一些老前辈说他们不要只会百度的程序员,当时感觉真装哔,不都是搜索引擎,因为我那时不用google。现在我也不愿意和只会百度的共事了,百度只是个广告搜索嘛,搜出来的都是些啥玩意。

google、stackoverflow真是个好东西,很多歪果仁知识丰富,解答专业,从计算机历史到操作系统、数据库、各种编程语言,帮我de了好多bug。在segmentfault这么打广告是不是不好,溜了!

Reference:
How to create multiple where clause query using Laravel Eloquent? - stackoverflow
Model::query - laravelAPI

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

MySQL基础入门——MySQL与R语言、Python交互

发表于 2017-12-12

MySQL作为z最为流行的关系型数据库管理平台之一,与绝大多数数据分析工具或者编程语言都有接口,今天这一篇分享如何将MySQL与R语言、Python进行连接。

R语言中与SQL管理平台通讯的接口包有很多,可以根据自己使用的数据库平台类型以及习惯,挑选合适的接口包。因为我个人笔记本使用的MySQL平台,所以本篇仅以MySQL为例分享。(如果你需要其他平台的接口导入方案,可以直接在csdn博客上搜关键字,有很多博客资料可以参考)。

我习惯使用的接口包是RMySQL,里面的核心函数主要涉及数据库连接,数据读写,数据查询三个方面,以下是三个方面的内容实例。

R与数据库的连接: library(“RMySQL”)library(“magrittr”)

数据库连接语句:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
复制代码conn <- dbConnect(
MySQL(), #数据库平台类型
dbname=”db1”, #数据库名称
username=”root”, #登录账号(MySQL初始安装时设置的账号)
password=”**“, #登录密码(MySQL初始安装时设置的密码)
host=”127.0.0.1”, #地址
port=3306 #端口号
)
summary(conn) #查看连接信息:
User: root
Host: 127.0.0.1
Dbname: db1
Connection type: 127.0.0.1 via TCP/IP
dbGetInfo(conn) #查看连接详细信息(列表形式) $host
‘127.0.0.1’
$user
‘root’
$dbname
‘db1’
$conType
‘127.0.0.1 via TCP/IP’
$serverVersion
‘5.7.17-log’
$protocolVersion
10
$threadId
11
$rsId
dbListTables(conn) #查看该数据库连接内的表信息 ‘birthdays’ ‘company’ ‘dataanalyst’ ‘foodranking’ ‘foodtypes’ ‘orderinfo’ ‘str_date’ ‘userinfo’
dbDisconnect(conn) #关闭连接(数据通讯完成之后再运行)
R语言与MySQL数据库读写:

(mydata <- iris)
dbWriteTable(
conn = conn, #连接名称
name = "mydata", #指定导入后的表名
value = iris, #指定要导入的R内存空间数据对象
row.names = FALSE #忽略行名
) #写表
dbListTables(conn)
'birthdays' 'company' 'dataanalyst' 'foodranking' 'foodtypes' 'mydata' 'orderinfo' 'str_date' 'userinfo'

mydata1 <- dbReadTable(
conn = conn, #连接名称
name = "mydata" #数据库中的表名
) #读表
head(mydata1,10)



以上读写都是一次性操作,不能在读写的同时执行条件筛选等步骤,通常我们需要使用查询方式来获取指定条件的数据并返回数据框。

1
2
3
4
5
6
7
8
复制代码result1 <- dbSendQuery(conn = conn,  
statement = "SELECT * from mydata where `Sepal.Length` between 4 and 5
and Species = 'setosa' "
#查询条件
) %>% dbFetch()
#将查询结果返回数据框
head(result1,10)
dbClearResult(result1) #清除查询(释放内存)

这一句清除的是查询,即上一句中的dbSendQuery部分(布包含后面的dbFetch,我只是为了方便一次性输出了)。

1
2
3
4
复制代码dbRemoveTable(conn,"mydata")   #删除表    
dbListTables(conn)
'birthdays' 'company' 'dataanalyst' 'foodranking' 'foodtypes' 'orderinfo' 'str_date' 'userinfo'
dbDisconnect(conn) #断开连接

Python:

Python与MySQL连接:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
复制代码from sklearn.datasets import load_iris
import pandas as pd
from sqlalchemy import create_engine
import MySQLdb
conn=MySQLdb.connect(
host="localhost", #地址
user="root", #登录名(同上)
passwd="******", #登录密码(同上)
db="db1", #要连接的数据库名称
charset="utf8" #声明数据编码
)

engine = create_engine('mysql+mysqldb://root:password@localhost:3306/db1?charset=utf8')
#使用 sqlalchemy接口连接连接

Python与MySQL数据读写操作:

Pandas库中有封装过的数据读写函数,可以直接针对连接后的数据进行数据读写,非常方便。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
复制代码iris = load_iris()
mydata = pd.DataFrame(
iris.data[:,:],
columns=['sepal_length','sepal_width','petal_length','petal_width']
)
mydata.head(10)
#将数据框直接写入MySQL
mydata.to_sql(
name = "iris",
con = engine
)
#直接读取MySQL中的表:
mydata1 = pd.read_sql_table(
table_name= "str_date",
con =engine
)
#通过查询过滤条件获取表数据:
mydata1 = pd.read_sql_query(
sql = "SELECT * from iris where sepal_length between 4 and 5 and petal_width != 0.2 ",
con =engine
)




你可以通过以上MySQLlb接口建立的连接来执行查询操作!

1
2
3
4
5
6
复制代码cursor = conn.cursor()         #获取操作游标 
sql = "SELECT * from iris where sepal_length between 4 and 5 and petal_width != 0.2"
cursor.execute(sql) # 使用execute方法执行SQL语句
cursor.fetchall() #获取查询数据
cursor.close() # 关闭游标
conn.close() # 关闭数据库连接


总觉得MySQLlb的接口使用起来过于复杂,不直观,输出数据也不友好,还好pandas支持sqlalchemy的链接,使用pandas里面的函数可以基本满足写表、读表、执行查询的需要。

以上仅仅是MySQL与R语言、Python交互的基础函数,当然还有更为复杂的增删以及插入命令,如果需要了解详细内容可以参考RMySQL、sqlalchemy库的官方文档。

在线课程请点击文末原文链接:

Hellobi Live | R语言可视化在商务场景中的应用

往期案例数据请移步本人GitHub: github.com/ljtyduyu/Da…

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

如何编写简单的linux内核模块

发表于 2017-12-12

获取Ring-0权限

尽管Linux系统为应用程序提供了强大而丰富的API,但是有的时候,这些还远远不够。当我们需要与硬件进行交互,或者需要访问系统中特权信息的时候,应用API就爱莫能助了,这时我们必须借助内核模块。

Linux内核模块是一段编译好的二进制代码,可以直接插入到Linux内核空间中,在ring 0级别运行,该级别不仅仅是x86-64处理器的最低层的运行级别,同时也是安全限制最少的一个级别。由于这里的代码完全不受限制,所以能够以令人难以置信的速度飞速运行,同时,它们还可访问系统中的任意内容。

来到内核世界

编写一个Linux内核模块并不是一件容易的事情。在修改内核的时候,您将面临数据丢失和系统损坏的风险。对于常规Linux应用程序来说,系统为它们提供了相应的安全网作为保护,但是内核代码却完全不是这样:内核代码一旦出现故障,将会锁定整个系统。

更糟糕的是,内核代码中出现的问题可能不会马上显现出来。如果内核模块加载后,系统立即被锁定的话,这还算是“最佳情况”。随着向模块添加的代码越来越多,我们将面临引入死循环和内存泄漏的风险。如果你不小心犯了这样的错误,随着机器的继续运行,这些代码占用的内存会持续增长。那么,最终会导致重要的内存结构,甚至缓冲区都被覆盖掉。

对于内核模块来说,传统应用程序的大部分开发范式都不适用。除了加载和卸载模块外,我们还需编写代码来响应系统事件,因为这里的代码并非以串行的模式运行。对于内核开发来说,我们要编写的是供应用程序使用的API,而非应用程序本身。

除此之外,我们在内核空间也无法访问各种标准库。虽然内核提供了一些常用的函数,比如printk(用作printf的替代品)和kmalloc(作用与malloc类似),但是大部分情况下,都需要我们亲自跟设备打交道。此外,在卸载模块时,我们必须亲自完成相关的清理工作,因为这里没有提供垃圾收集功能。

先决条件

在开始编写内核模块之前,我们需要确保已经准备好了得心应手的工具。最重要的是,你需要有一台Linux机器。虽然任何Linux发行版都可以满足我们的要求,但是在本文中,我使用的是Ubuntu 16.04 LTS,所以,如果你使用了其他版本的话,在安装的过程中,可能需要稍微调整一下相关的安装命令。

其次,你还需要一台单独的物理机器或虚拟机。虽然我更喜欢在虚拟机上完成这些工作,但是读者完全可以根据自己的喜好来作出决定。我不建议使用您的工作主机,因为一旦出错,就很可能会发生数据丢失的情况。同时,我们在编写内核模块的过程中,一般至少会锁定机器许多次,这个是不用怀疑的。内核出乱子的时候,最近更新的代码很可能还在向缓冲区中写入内容,所以,这就可能导致源文件损坏。如果在虚拟机上进行测试的话,就能够消除这种风险。

最后,您至少需要对C语言有一些基本的了解。由于C++运行时对于内核来说占用的空间太多了,因此,编写C代码对于内核开发来说是非常重要的。此外,为了与硬件进行交互,了解一些汇编语言方面的知识也是非常有帮助的。

安装开发环境

在Ubuntu上,我们需要运行下列命令:

1
apt-get install build-essential linux-headers-`uname -r`

上面的命令将安装必要的开发工具,以及这个示例内核模块所需的内核头文件。

对于下面的示例内核模块,我们假设读者是以普通用户身份运行的,而不是root用户,但是,要求读者拥有sudo权限。对于非root用户来说,sudo在加载内核模块时是必须的,尽管这样有些麻烦,但我们希望尽可能以root之外的身份来完成内核模块开发工作。

踏上征程

从现在开始,我们就要开始编写代码了。好了,让我们先准备好工作环境:

1
2
mkdir ~/src/lkm_example
cd ~/src/lkm_example

你可以启动自己最喜欢的编辑器(对于我来说,就是VIM),创建文件lkm_example.c,并输入以下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#include <linux/init.h>
#include <linux/module.h>
#include <linux/kernel.h>
MODULE_LICENSE(“GPL”);
MODULE_AUTHOR(“Robert W. Oliver II”);
MODULE_DESCRIPTION(“A simple example Linux module.”);
MODULE_VERSION(“0.01”);
static int __init lkm_example_init(void) {
 printk(KERN_INFO “Hello, World!\n”);
 return 0;
}
static void __exit lkm_example_exit(void) {
 printk(KERN_INFO “Goodbye, World!\n”);
}
module_init(lkm_example_init);
module_exit(lkm_example_exit);

现在,我们已经做好了一个最简单的内核模块,接下来,我会对一些重点内容加以详细说明:

·“includes”用于包含Linux内核开发所需的头文件。

· 根据模块的许可证的不同,MODULE_LICENSE可以设置为不同的值。要查看许可证的完整列表,请运行:

1
grep“MODULE_LICENSE”-B 27 / usr / src / linux-headers -`uname -r` / include / linux / module.h

· 我们将init(加载)和exit(卸载)函数定义为static类型,并让它返回一个int型数据。

· 注意,这里要使用printk函数,而不是printf函数。此外,printk与printf使用的参数也各不相同。例如,KERN_INFO(这是一个标志,用以声明相应的消息记录等级)在定义的时候并没有使用逗号。

· 在文件的最后部分,我们调用了module_init和module_exit函数,来告诉内核哪些是加载函数和卸载函数。这样的话,我们就能够给这些函数自由命名了。

当目前为止,我们仍然无法编译这个文件:我们还需要一个Makefile文件。有了它,这个简单的示例模块就算就绪了。请注意,make会严格区分空格和制表符,因此,在应该使用tab的地方千万不要使用空格。

1
2
3
4
5
obj-m += lkm_example.o
all:
 make -C /lib/modules/$(shell uname -r)/build M=$(PWD) modules
clean:
 make -C /lib/modules/$(shell uname -r)/build M=$(PWD) clean

如果我们运行“make”,正常情况下应该成功编译我们的模块。最后得到的文件是“lkm_example.ko”。如果在此过程中出现了错误消息的话,请检查示例源文件中的引号是否正确,并确保没有意外粘贴为UTF-8字符。

现在,可以将我们的模块插入内核空间进行测试了。为此,我们可以运行如下所示的命令:

1
sudo insmod lkm_example.ko

如果一切顺利的话,屏幕上面是不会显示任何内容的。这是因为,printk函数不会将运行结果输出到控制台,相反,它会把运行结果输出到内核日志。为了查看内核模块的运行结果,我们需要运行下列命令:

1
sudo dmesg

正常情况下,这里应该看到带有时间戳前缀的“Hello,World!”行。这意味着我们的内核模块已经加载,并成功向内核日志输出了相关的字符串。我们还可以通过下面的命令,来检查该模块是否仍然处于加载状态:

1
lsmod | grep “lkm_example”

要删除该模块,请运行下列命令:

1
sudo rmmod lkm_example

如果您再次运行dmesg,则会在日志中看到字符串“Goodbye, World!”。同时,您也可以再次使用lsmod来确认它是否已被卸载。

正如你所看到的那样,这个测试工作流程有点繁琐而乏味,为了实现自动化,我们可以在Makefile文件末尾添加下列内容:

1
2
3
4
5
test:
 sudo dmesg -C
 sudo insmod lkm_example.ko
 sudo rmmod lkm_example.ko
 dmesg

然后,运行下列命令:

1
make test

这样的话,要想测试模块并查看内核日志的输出的话,就不必专门来运行相应的命令了。

现在,我们已经打造好了一个五脏俱全,但是没有什么用处的内核模块!

打造更有趣的内核模块

接下来,让我们通过具体的例子来进一步了解内核模块的开发。虽然内核模块可以完成各种任务,但最常见的用途,恐怕就是与应用程序进行交互了。

由于应用程序无法直接查看内核空间内存的内容,因此,它们必须借助API与其进行通信。虽然从技术上来说有多种方法可以实现这一点,但最常见的方法却是创建一个设备文件。

实际上,您很可能早就跟设备文件打过交道了。比如,涉及/dev/zero、/dev/null或类似文件的命令,实际上就是在跟名为“zero”和“null”的设备进行交互,以返回相应的内容。

在我们的例子中,我们将返回“Hello,World”。虽然对于应用程序来说,这一功能没有多大的用途,但它却为我们详细展示了通过设备文件响应应用程序的具体过程。

下面是完整的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include <linux/init.h>
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/fs.h>
#include <asm/uaccess.h>
MODULE_LICENSE(“GPL”);
MODULE_AUTHOR(“Robert W. Oliver II”);
MODULE_DESCRIPTION(“A simple example Linux module.”);
MODULE_VERSION(“0.01”);
#define DEVICE_NAME “lkm_example”
#define EXAMPLE_MSG “Hello, World!n”
#define MSG_BUFFER_LEN 15
/* Prototypes for device functions */
static int device_open(struct inode *, struct file *);
static int device_release(struct inode *, struct file *);
static ssize_t device_read(struct file *, char *, size_t, loff_t *);
static ssize_t device_write(struct file *, const char *, size_t, loff_t *);
static int major_num;
static int device_open_count = 0;
static char msg_buffer[MSG_BUFFER_LEN];
static char *msg_ptr;
/* This structure points to all of the device functions */
static struct file_operations file_ops = {
 .read = device_read,
 .write = device_write,
 .open = device_open,
 .release = device_release
};
/* When a process reads from our device, this gets called. */
static ssize_t device_read(struct file *flip, char *buffer, size_t len, loff_t *offset) {
 int bytes_read = 0;
 /* If we’re at the end, loop back to the beginning */
 if (*msg_ptr == 0) {
 msg_ptr = msg_buffer;
 }
 /* Put data in the buffer */
 while (len && *msg_ptr) {
 /* Buffer is in user data, not kernel, so you can’t just reference
 * with a pointer. The function put_user handles this for us */
 put_user(*(msg_ptr++), buffer++);
 len--;
 bytes_read++;
 }
 return bytes_read;
}
/* Called when a process tries to write to our device */
static ssize_t device_write(struct file *flip, const char *buffer, size_t len, loff_t *offset) {
 /* This is a read-only device */
 printk(KERN_ALERT “This operation is not supported.\n”);
 return -EINVAL;
}
/* Called when a process opens our device */
static int device_open(struct inode *inode, struct file *file) {
 /* If device is open, return busy */
 if (device_open_count) {
 return -EBUSY;
 }
 device_open_count++;
 try_module_get(THIS_MODULE);
 return 0;
}
/* Called when a process closes our device */
static int device_release(struct inode *inode, struct file *file) {
 /* Decrement the open counter and usage count. Without this, the module would not unload. */
 device_open_count--;
 module_put(THIS_MODULE);
 return 0;
}
static int __init lkm_example_init(void) {
 /* Fill buffer with our message */
 strncpy(msg_buffer, EXAMPLE_MSG, MSG_BUFFER_LEN);
 /* Set the msg_ptr to the buffer */
 msg_ptr = msg_buffer;
 /* Try to register character device */
 major_num = register_chrdev(0, “lkm_example”, &file_ops);
 if (major_num < 0) {
 printk(KERN_ALERT “Could not register device: %d\n”, major_num);
 return major_num;
 } else {
 printk(KERN_INFO “lkm_example module loaded with device major number %d\n”, major_num);
 return 0;
 }
}
static void __exit lkm_example_exit(void) {
 /* Remember — we have to clean up after ourselves. Unregister the character device. */
 unregister_chrdev(major_num, DEVICE_NAME);
 printk(KERN_INFO “Goodbye, World!\n”);
}
/* Register module functions */
module_init(lkm_example_init);
module_exit(lkm_example_exit);

测试加强版的示例代码

现在,我们的示例代码已经不仅限于在加载和卸载过程中输出相应的消息了,所以,我们需要一个限制性较小的测试例程。接下来,让我们修改Makefile,让它只加载模块,而不进行卸载。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
obj-m += lkm_example.o
all:
  make -C /lib/modules/$(shell uname -r)/build M=$(PWD) modules
clean:
  make -C /lib/modules/$(shell uname -r)/build M=$(PWD) clean
test:
  # We put a — in front of the rmmod command to tell make to ignore
  # an error in case the module isn’t loaded.
  -sudo rmmod lkm_example
  # Clear the kernel log without echo
  sudo dmesg -C
  # Insert the module
  sudo insmod lkm_example.ko
  # Display the kernel log
  dmesg

现在,运行“make test”的时候,该示例代码将会输出设备的主编号。在本例中,这个编号是由内核自动分配的。但是,您后面会用到这个值来创建设备。

接下来,我们需要利用运行“make test”后得到的值,创建一个设备文件,以便从用户空间与内核模块进行通信。

1
sudo mknod /dev/lkm_example c MAJOR 0

(在上面的例子中,请用运行“make test”或“dmesg”时获得的值来替换MAJOR)

mknod命令中的“c”的作用是,告诉mknod我们要创建一个字符设备文件。

现在,我们就可以通过该设备来获取相应的内容了:

1
cat /dev/lkm_example

甚至可以借助“dd”命令:

1
dd if=/dev/lkm_example of=test bs=14 count=100

此外,您也可以通过应用程序来访问该设备。当然,这些应用不要求一定是编译型的程序——即使Python、Ruby和PHP这样的脚本程序,也照样可以访问这些数据。

当我们使用完这些设备时,可以将其删除,并卸载相应的内核模块:

1
2
sudo rm /dev/lkm_example
sudo rmmod lkm_example

结束语

在这篇文章中,我们向读者介绍了如何编写简单的内核模块。虽然文章中提供的示例代码非常的简单,但是展示的内核模块构建过程却是通用的,读者完全可以据此编写出功能复杂的内核模块。

但是一定要牢记,在内核模式下,整台机器都是你的地盘,你就是这里的王者。你的代码没有保护网,也没有重新来过的机会。如果您正在跟客户洽谈一个内核模块有关的项目,请务必将预期的调试时间加倍(即使不是三倍的话)。这是因为,内核代码必须尽可能完美无瑕,只有这样才能确保系统运行的完整性和可靠性。

本文翻译自:https://blog.sourcerer.io/writing-a-simple-linux-kernel-module-d9dc3762c234 ,如若转载,请注明原文地址:
www.4hou.com/system/9053…

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

用 C# 自己动手编写一个 Web 服务器,第四部分——Se

发表于 2017-12-12

Session

在 上一篇文章 中,我们实现了 Web 服务器的路由功能,并实现了控制器的基本支持。本来,我们应该高高兴兴的继续向其中添加功能,不过马上就发现一个尴尬的问题————我们还没有 Session。更具体的说,我们一直在使用的 HttpListenerContext 只提供了 Request/Response,却没有 Session 属性。这意味着我们的服务器毫无记性,只能把每次请求都当作新的用户。

出现这种情况也是情理之中的。基础类库之中的 HttpListenerXXX 系列类为我们创建 Web 应用提供了一个很好的起点,但实现 Sesssion 则是 Web 框架的事情,并不是 Web 服务器的职责。有的同学可能会问,Web 服务器和 Web 框架的区别在哪?嗯,其实这个问题也没有严格的定义,不过一般来讲,Web 服务器通常是独立于编程语言和框架的,比如 Apache/Nginx 都有支持多种语言/框架的能力;IIS 通过插件也可以运行 PHP,并且 Web 服务器通常更关心基础设施方面的问题,包括站点管理、HTTP
压缩、证书、性能和吞吐量等。而 Web 框架一般是和具体的语言或平台绑定的,希望充分利用语言本身的特性来更好的支持业务逻辑,例如 Express(Nodejs)、Django(Python)、ASP.NET MVC(.Net)等。Session 这个东西,对于后端业务是非常必要的(区分用户是绝大多数后台系统的基本要求),但对于 Web 服务器却不是绝对必需的,而且会在一定程度上影响服务器的吞吐量,所以一般会把它放到 Web Framework 的层面去实现它。

Session 要求服务器有一定的机制去记住当前请求的用户。目前绝大多数的 Session 实现都是基于 Cookie 的。在具体实现层面,又需要考虑把多少内容放在 cookie 里的问题。主流的实现会把绝大多数内容放在服务器端,客户端只记录一个用于鉴别的 key,这种实现在网络流量以及安全性方面都是极好的,缺点是会占据较多的服务器空间。也有一些实现为了减轻服务器压力以及方便客户端处理,会把部分数据放到客户端,但这样又需要考虑安全性和数据丢失的问题。我们这里不讨论方案的优劣问题,为了示例的目的,采用第一种方案——即将所有内容保存在服务端。

此外,请允许我再多说一句:Session 是一种机制,没有什么规定要求它一定是位于内存中的。许多同学似乎误解了这一点,他们似乎认为只要 Session 就一定是使用内存的。事实当然不是这样,用其他的存储机制来保存 Session 是完全合法的。之所以有这样的误会,可能是因为大多数 Session 实现默认使用内存——因为这是最简单的方式。但许多 Web 框架都提供了诸如 Session Storage 或 Session Provider 这样的扩展点,以便将 Session 保存在其他地方,比如数据库或远程
Redis/Memcached。如果要实现跨多个服务器的分布式 Session,那么内存肯定不是一个好的选择。我们在这里的实现为了简化问题也使用了内存,但请务必清楚这一点:即 Session 并非一定要保存在内存中。

代码

本文的示例代码已经全部放到 Github,每篇文章关联的代码放在独立的分支,方便读者参考。因此,要获取本文示例代码,请使用如下命令:

1
2
复制代码git clone https://github.com/shuhari/web-server-succinctly-example.git
git checkout -b 04-session origin/04-session

实现

在开头部分我们说过,HttpListenerContext 并没有提供给我们一个 Session 接口,所以我们必须在它之上再封装一层,提供 Web 框架所需的功能。

首先声明 Session 接口。对于大多数典型使用场景,Session 可以当作一个字典:

1
2
3
4
5
6
复制代码public interface ISession
{
object this[string name] { get; set; }

void Remove(string name);
}

接下来,对 HttpListenerContext 进行再次封装(为了避免和 ASP.NET MVC 混淆,这里我们称为 HttpServerContext):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
复制代码public class HttpServerContext
{
public HttpServerContext(HttpListenerContext context)
{
_innerContext = context;
}

private readonly HttpListenerContext _innerContext;

public HttpListenerRequest Request => _innerContext.Request;

public HttpListenerResponse Response => _innerContext.Response;

public IPrincipal User { get; internal set; }

public ISession Session { get; internal set; }
}

对于已有的属性,我们可以直接委托过去。Session 则是需要我们声明的。另外,我们也重新声明了 User,这是因为默认的实现是只读的,并没有设置用户的方法(后续的用户验证部分我们还会用到它)。

接下来,我们需要把所有对 HttpListenerContext 的引用替换为 HttpServerContext。这涉及了大多数代码文件,但只是简单的替换动作,相信你可以自己完成。

在 MiddlewarePipeline 中的代码也需要稍作改动:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
复制代码    internal class MiddlewarePipeline
{
internal void Execute(HttpListenerContext context)
{
var serverContext = new HttpServerContext(context);

try
{
foreach (var middleware in _middlewares)
{
var result = middleware.Execute(serverContext);
... // 下同
}
}
catch (Exception ex)
{
...
}
}
}

控制器增加几个辅助方法,方便访问 Session:

1
2
3
4
5
6
7
8
复制代码public abstract class Controller
{
public HttpServerContext HttpContext { get; internal set; }

protected ISession Session => HttpContext.Session;

protected IPrincipal User => HttpContext.User;
}

一切就绪,我们实现一个处理 Session 的中间件:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
复制代码public class SessionManager : IMiddleware
{
public SessionManager()
{
_sessions = new ConcurrentDictionary<string, Session>();
}

private const string _cookieName = "__sessionid__";

private ConcurrentDictionary<string, Session> _sessions;

public MiddlewareResult Execute(HttpServerContext context)
{
var cookie = context.Request.Cookies[_cookieName];
Session session = null;
if (cookie != null)
{
_sessions.TryGetValue(cookie.Value, out session);
}
if (session == null)
{
session = new Session();
var sessionId = GenerateSessionId();
_sessions[sessionId] = session;
cookie = new Cookie(_cookieName, sessionId);
context.Response.SetCookie(cookie);
}
context.Session = session;
return MiddlewareResult.Continue;
}

private string GenerateSessionId()
{
return Guid.NewGuid().ToString();
}
}

Session 的实现原理非常简单:用 Cookie 记录一个 key,对应服务器端中的数据,如果没有的话就新建一个。如果用于生产服务器的话,Cookie 是必须加密的,并且还要其他一些保护手段。由于实现加密需要引入很多代码,这里就不去实现了。郑重声明:虽然自己实现一个 Session 从原理上来讲并不复杂,要实现真正安全、正确且健壮的 Session 并非易事,并且 Session 也是很多黑客的攻击点。但除非你自认是安全方面的高手,请勿试图手造轮子,否则很容易引入未知的缺陷。

Session 中间件已经实现,我们可以把它加入处理管线中去:

1
2
3
4
5
6
7
8
9
10
11
复制代码class Program
{
static void RegisterMiddlewares(IWebServerBuilder builder)
{
builder.Use(new HttpLog());
// builder.Use(new BlockIp("::1", "127.0.0.1"));
builder.Use(new SessionManager());

// ... 下同
}
}

最后,对控制器代码稍作修改,看看是否真的生效了:

1
2
3
4
5
6
7
8
9
10
复制代码public class HomeController : Controller
{
public ActionResult Index()
{
int counter = (Session["counter"] != null) ? (int)Session["counter"] : 0;
counter++;
Session["counter"] = counter;
return "counter=" + counter;
}
}

打开浏览器,多刷新几次,你会看到计数器确实在增长,说明 Session 生效了。

我们已经实现了 Session,让服务器不再患有记忆丧失症。不过你或许没有意识到的是,这里为 HttpListenerContext 的封装也为后续的其他功能提供了一个很好的起点。在下一篇文章中,我们将引入视图引擎(View Engine)的支持,从而让框架能够输出真正的 HTML 页面,而不是硬编码的字符串。

系列文章

  • 用 C# 自己动手编写一个 Web 服务器 (索引)
  • 用 C# 自己动手编写一个 Web 服务器,第一部分——基础
  • 用 C# 自己动手编写一个 Web 服务器,第二部分——中间件
  • 用 C# 自己动手编写一个 Web 服务器,第三部分——路由
  • 用 C# 自己动手编写一个 Web 服务器,第四部分——Session
  • 用 C# 自己动手编写一个 Web 服务器,第五部分——视图引擎
  • 用 C# 自己动手编写一个 Web 服务器,第六部分——用户验证

本文转载自: 掘金

开发者博客 – 和开发相关的 这里全都有

1…908909910…956

开发者博客

9558 日志
1953 标签
RSS
© 2025 开发者博客
本站总访问量次
由 Hexo 强力驱动
|
主题 — NexT.Muse v5.1.4
0%