class TreeNode:
    """A node in a general (n-ary) tree: a value plus parent and child links."""

    def __init__(self, value, parent: "TreeNode | None" = None):
        """Create a node.

        Args:
            value: Payload stored on the node.
            parent: Optional parent link. This only sets the back reference;
                it does NOT register the node in ``parent.children``. Use
                ``parent.add_child(node)`` to attach a node consistently.
        """
        self.value = value
        self.parent = parent
        self.children: "list[TreeNode]" = []

    def add_child(self, child_node: "TreeNode") -> None:
        """Attach *child_node* as the last child of this node.

        If *child_node* already has a parent it is detached from that parent
        first, so a node is never listed under two parents.

        Raises:
            ValueError: if *child_node* is this node or one of its ancestors,
                which would turn the tree into a cycle.
        """
        # Walk up from self: attaching any node on this path would create a
        # cycle, and traversals such as __repr__ / Tree.find would never end.
        ancestor = self
        while ancestor is not None:
            if ancestor is child_node:
                raise ValueError(
                    "cannot add a node as a child of itself or of its descendant"
                )
            ancestor = ancestor.parent
        if child_node.parent is not None:
            # Keep the previous parent's children list consistent with the
            # parent link we are about to overwrite.
            child_node.parent.remove_child(child_node)
        child_node.parent = self
        self.children.append(child_node)

    def remove_child(self, child_node: "TreeNode") -> None:
        """Remove *child_node* (compared by identity) from this node's children.

        The child's parent link is cleared only when it points at this node,
        so passing a node that belongs to a different parent leaves it intact.
        """
        self.children = [child for child in self.children if child is not child_node]
        if child_node.parent is self:
            child_node.parent = None

    def __repr__(self, level=0) -> str:
        """Render the subtree rooted here, one value per line.

        Each line is ``repr(value)`` prefixed by one tab per depth level and
        terminated by a newline; children follow their parent in pre-order.

        Args:
            level: Indentation depth (number of tabs) used for this node.
        """
        # Iterative pre-order walk: avoids RecursionError on deep trees and
        # the quadratic cost of repeated string concatenation.
        lines = []
        stack = [(self, level)]
        while stack:
            node, depth = stack.pop()
            lines.append("\t" * depth + repr(node.value) + "\n")
            # Reversed so the leftmost child is popped (rendered) first.
            stack.extend((child, depth + 1) for child in reversed(node.children))
        return "".join(lines)

    def get_ancestors(self) -> "list[TreeNode]":
        """Return this node followed by each of its ancestors, ending at the root."""
        ancestors = [self]
        node = self.parent
        while node is not None:
            ancestors.append(node)
            node = node.parent
        return ancestors


class Tree:
    """A tree rooted at a single TreeNode."""

    def __init__(self, root_value):
        """Create a tree whose root node holds *root_value*."""
        self.root = TreeNode(root_value)

    def add_child(self, parent_node: "TreeNode | None", child_value) -> "TreeNode":
        """Create a node holding *child_value* under *parent_node* and return it.

        If *parent_node* is None, the new node is attached to the root.
        """
        # `is None` rather than truthiness: a node subclass that defines
        # __len__/__bool__ could make a real node falsy and silently redirect
        # the insertion to the root.
        if parent_node is None:
            parent_node = self.root
        child = TreeNode(child_value)
        parent_node.add_child(child)
        return child

    def find(self, value) -> "TreeNode | None":
        """Return the first node (depth-first pre-order) whose value == *value*, else None."""
        return self._find_in_node(self.root, value)

    def _find_in_node(self, node, value) -> "TreeNode | None":
        """Pre-order search of the subtree rooted at *node*; first match or None.

        Uses an explicit stack instead of recursion so very deep trees do not
        hit Python's recursion limit.
        """
        stack = [node]
        while stack:
            current = stack.pop()
            if current.value == value:
                return current
            # Push children reversed so the leftmost child is examined next,
            # matching the visiting order of a recursive pre-order search.
            stack.extend(reversed(current.children))
        return None

    def __repr__(self) -> str:
        """Render the whole tree; see TreeNode.__repr__ for the format."""
        return repr(self.root)