HuffmanTree的python实现 -- 潘登同学的图论笔记
哈夫曼树
当用 n 个结点(都做叶子结点且都有各自的权值)试图构建一棵树时,如果构建的这棵树的带权路径长度最小,称这棵树为“最优二叉树”,
在构建哈弗曼树时,要使树的带权路径长度最小,只需要遵循一个原则,那就是:权重越大的结点离树根越近。在图 1 中,因为结点 a 的权值最大,所以理应直接作为根结点的孩子结点。
构建哈夫曼树的过程
- 在 n 个权值中选出两个最小的权值,对应的两个结点组成一个新的二叉树,且新二叉树的根结点的权值为左右孩子权值的和
- 在原有的 n 个权值中删除那两个最小的权值,同时将新的权值加入到 n–2 个权值的行列中,以此类推
- 重复 1 和 2 ,直到所以的结点构建成了一棵二叉树为止,这棵树就是哈夫曼树
话不多说,直接看代码
树节点实现
树节点基本上都是大同小异的
- root: 该节点是否为叶节点(不是则为None)
- value: 记录这个词
- frq: 记录这个词出现的频次(或者是某个父节点下所有frq之和)
- size: 记录这个某个父节点下的节点总数(主要用于画图)
class HuffmanTreeNode:
def __init__(self,
root=None,
value:str=None,
frq:int=0,
) -> None:
self.root=root
self.value = value
self.frq = frq
self.left = None
self.right = None
self.size = 1
def Setleft(self,left):
self.left = left
self.frq += left.Getfrq()
self.size += left.GetSize()
return self
def Setright(self,right):
self.right = right
self.frq += right.Getfrq()
self.size += right.GetSize()
return self
def Getfrq(self):
return self.frq
def Getvalue(self):
return self.value
def GetSize(self):
return self.size
def Hasright(self):
return self.right
def Hasleft(self):
return self.left
def Isroot(self):
return self.root
def __str__(self) -> str:
if self.root:
return f'root, sum of frequency:{self.frq}'
else:
return f'value: {self.value}, frequency: {self.frq}'
HuffmanTree实现
HuffmanTree主要有两个方法
- _buildHuffmanTree: 将词频字典输入,进行树的构建
- _iter_node: 在构建好的树中,获得某个词的编码(因为哈夫曼树就是用于解决编码的,但是后来有很多的作用,我就是从CBOW模型过来的)
class HuffmanTree:
def __init__(self,
num:dict) -> None:
# 对字典按照其values进行排序
self.num = sorted(num.items(),key=lambda x:x[1],reverse=False)
self.list = [] # 一个储存列表
self.coding = {} # 编码结果
self._buildHuffmanTree()
self._iter_node(self.list[0])
def _buildHuffmanTree(self):
self.list = [HuffmanTreeNode(root=False,value=i[0],frq=i[1]) for i in self.num]
while len(self.list) > 1:
# 将两个小的节点合并 小的放左边
right_node = self.list[1]
left_node = self.list[0]
# 注意pop顺序
self.list.pop(1)
self.list.pop(0)
temp_node = HuffmanTreeNode(root=True)
temp_node.Setright(right_node)
temp_node.Setleft(left_node)
# 将合并后的根节点放回list中
if len(self.list) == 1:
if temp_node.Getfrq() < self.list[0].Getfrq():
self.list.insert(0,temp_node)
else:
self.list.insert(1,temp_node)
elif len(self.list) == 0:
self.list.insert(0,temp_node)
else:
for i in range(len(self.list)-1):
if i == 0 and temp_node.Getfrq() <= self.list[i].Getfrq():
self.list.insert(i,temp_node)
continue
elif self.list[i].Getfrq() < temp_node.Getfrq() <= self.list[i+1].Getfrq():
self.list.insert(i+1,temp_node)
continue
elif i == len(self.list)-2 and temp_node.Getfrq() > self.list[i+1].Getfrq():
self.list.insert(i+2,temp_node)
continue
def getTree(self):
return self.list[0]
def _iter_node(self,node,code=''):
if node:
if not node.Isroot():
self.coding[node.Getvalue()] = code
self._iter_node(node.Hasleft(),code='0'+code)
self._iter_node(node.Hasright(),code='1'+code)
def getCode(self):
return self.coding
绘制HuffmanTree
画图函数与之前画红黑树的区别不大,改一改拿来用就行
class Draw_RBTree:
def __init__(self, tree):
self.tree = tree
def show_node(self, node, ax, height, index, font_size):
if not node:
return
x1, y1 = None, None
if node.left:
x1, y1, index = self.show_node(node.left, ax, height-1, index, font_size)
x = 100 * index - 50
y = 100 * height - 50
if x1:
plt.plot((x1, x), (y1, y), linewidth=2.0,color='b')
circle_color = 'mediumspringgreen'
text_color = 'black'
ax.add_artist(plt.Circle((x, y), 50, color=circle_color))
text = str(node.Getfrq()) if node.Isroot() else node.Getvalue() + '\n' + str(node.Getfrq())
ax.add_artist(plt.Text(x, y, text, color= text_color, fontsize=font_size, horizontalalignment="center",verticalalignment="center"))
# print(str(node.val), (height, index))
index += 1
if node.right:
x1, y1, index = self.show_node(node.right, ax, height-1, index, font_size)
plt.plot((x1, x), (y1, y), linewidth=2.0, color='b')
return x, y, index
def show_hf_tree(self, title):
fig = plt.figure(figsize=(10,6))
ax = fig.add_subplot(111)
left, right = self.get_left_length(), self.get_right_length(),
height = 2 * np.log2(self.tree.size + 1)
# print(left, right, height)
plt.ylim(0, height*100 + 50)
plt.xlim(0, 100 * self.tree.size + 100)
self.show_node(self.tree, ax, height, 1, self.get_fontsize())
plt.axis('off')
plt.title(title)
plt.show()
def get_left_length(self):
temp = self.tree
len = 1
while temp:
temp = temp.left
len += 1
return len
def get_right_length(self):
temp = self.tree
len = 1
while temp:
temp = temp.right
len += 1
return len
def get_fontsize(self):
count = self.tree.size
if count < 10:
return 30
if count < 20:
return 20
return 16
测试代码
if __name__ == '__main__':
num = {'a':10,'b':15,'c':12,'d':3,'e':4,'f':13,'g':1}
h = HuffmanTree(num)
tree = h.getTree()
d = Draw_RBTree(tree)
d.show_hf_tree('HuffmanTree')
print(h.getCode())