# Listing metadata: 53 lines, 1.6 KiB, Python
class TreeNode:
    """A node of a rooted tree: a value, a parent link and an ordered child list."""

    def __init__(self, value, parent=None):
        """Create a node.

        Args:
            value: Payload stored in the node.
            parent: Optional parent node. Only the back-link is set here; the
                node is NOT appended to ``parent.children`` (use
                ``parent.add_child(node)`` for that).
        """
        self.value = value  # payload
        self.parent = parent  # parent TreeNode, or None for a root
        self.children = []  # ordered child TreeNodes

    def add_child(self, child_node):
        """Attach *child_node* as the last child of this node.

        If *child_node* already has a parent (including this node), it is first
        detached from it, so a node never appears in two ``children`` lists and
        re-adding an existing child moves it to the end instead of duplicating it.

        Raises:
            ValueError: if *child_node* is this node or one of its ancestors,
                because attaching it would create a cycle.
        """
        # Walk up from self (O(depth)); meeting child_node means the new edge
        # would close a loop, which would make __repr__/get_ancestors run forever.
        node = self
        while node is not None:
            if node is child_node:
                raise ValueError(
                    "cannot add a node as a child of itself or of its descendant"
                )
            node = node.parent
        if child_node.parent is not None:
            child_node.parent.remove_child(child_node)
        child_node.parent = self
        self.children.append(child_node)

    def remove_child(self, child_node):
        """Detach *child_node* from this node.

        Every occurrence of *child_node* (compared by identity) is removed from
        ``children``. Its ``parent`` link is cleared only if it points at this
        node, so passing a node that belongs to another parent leaves that
        node's real parent link intact. Removing a non-child is a no-op.
        """
        # Rebind rather than mutate in place so callers iterating over the old
        # list while removing children are not disturbed.
        self.children = [child for child in self.children if child is not child_node]
        if child_node.parent is self:
            child_node.parent = None

    def __repr__(self, level=0):
        """Return an indented outline of the subtree rooted at this node.

        Nodes are listed in pre-order, one per line: ``level + depth`` tab
        characters, ``repr(value)``, then a newline. The outline is built with
        an explicit stack so deep trees do not hit the recursion limit.
        """
        lines = []
        stack = [(self, level)]
        while stack:
            node, depth = stack.pop()
            lines.append("\t" * depth + repr(node.value) + "\n")
            # Push children reversed so they are emitted in their stored order.
            stack.extend((child, depth + 1) for child in reversed(node.children))
        return "".join(lines)

    def get_ancestors(self):
        """Return this node followed by all of its ancestors.

        The list runs from this node (index 0) up to the root (last item),
        following ``parent`` links.
        """
        ancestors = []
        node = self
        while node is not None:
            ancestors.append(node)
            node = node.parent
        return ancestors


class Tree:
    """A rooted tree of TreeNode objects with value lookup."""

    def __init__(self, root_value):
        """Create a tree whose root node holds *root_value*."""
        self.root = TreeNode(root_value)

    def add_child(self, parent_node, child_value):
        """Create a node holding *child_value* and attach it under *parent_node*.

        Args:
            parent_node: Node to attach to; ``None`` means the root.
            child_value: Value for the new node.

        Returns:
            The newly created TreeNode, so callers can attach further children.
        """
        if parent_node is None:
            parent_node = self.root
        child = TreeNode(child_value)
        parent_node.add_child(child)
        return child

    def find(self, value):
        """Return the first node in pre-order whose value equals *value*, or None."""
        return self._find_in_node(self.root, value)

    def _find_in_node(self, node, value):
        """Search the subtree rooted at *node* in pre-order for *value*.

        Uses an explicit stack instead of recursion so very deep trees do not
        raise RecursionError. Returns the first matching node, or None.
        """
        stack = [node]
        while stack:
            current = stack.pop()
            if current.value == value:
                return current
            # Reverse so the leftmost child is examined first (pre-order).
            stack.extend(reversed(current.children))
        return None

    def __repr__(self):
        """Return the indented outline of the whole tree, starting at the root."""
        return repr(self.root)