# -*- coding: utf-8 -*-
"""
Utilities that facilitate the propagation of hierarchical information (for example, endmembers from clustering).
Created on Wed Aug 31 17:03:29 2016
@author: Suhas Somnath
"""
from __future__ import division, print_function, absolute_import, unicode_literals
import numpy as np
# TODO: Test and debug the Node and ClusterTree classes for agglomerative clustering, etc.
class Node(object):
    """
    Basic unit of a tree - a node. Keeps track of its value, labels, parent, children, level in the tree etc.
    """

    def __init__(self, name, value=None, parent=None, dist=0, labels=None, children=None, compute_mean=False,
                 verbose=False):
        """
        Parameters
        ----------
        name : unsigned int
            ID of this node
        value : (Optional) 1D numpy array
            Response corresponding to this Node.
        parent : (Optional) unsigned int or Node object
            Parent for this Node. Overwritten when this node is later passed as a child to another Node.
        dist : (Optional) float
            Distance between the children nodes
        labels : (Optional) list or 1D numpy array of unsigned integers
            Positions / instances in a main dataset within this cluster. If not provided, the sorted
            concatenation of the children's labels is used (empty if there are no children).
        children : (Optional) list of Node objects
            Children for this node. Defaults to no children (a leaf node).
        compute_mean : (Optional) Boolean
            Whether or not to compute the value attribute from the provided children. Only applies when
            ``value`` is not provided and there is at least one child.
        verbose : (Optional) Boolean
            Whether or not to print debugging information
        """
        # A fresh list per node: a shared mutable default would alias the children of every leaf
        self.children = [] if children is None else list(children)
        self.parent = parent
        self.name = name
        self.value = value
        self.dist = dist
        self.level = 0

        # Assign this node as the parent for all its children
        for child in self.children:
            child.parent = self

        if labels is None:
            # Labels were not provided (tree node): collect them from the children
            temp_labels = []
            for child in self.children:
                if verbose:
                    print('Child #{} had the following labels:'.format(child.name))
                    print(child.labels)
                temp_labels.append(np.array(child.labels))
            if verbose:
                print('Labels (unsorted) derived from children for node #{}:'.format(name))
                print(temp_labels)
            if temp_labels:
                self.labels = np.hstack(temp_labels).ravel()
            else:
                # np.hstack raises on an empty list; a childless node without labels owns no positions
                self.labels = np.array([], dtype=np.uint32)
            self.labels.sort()
        else:
            if verbose:
                print('Labels for leaf node #{}:'.format(name))
                print(labels)
            # ravel() flattens inputs such as the tuple returned by np.where into a 1D array
            self.labels = np.array(labels, dtype=np.uint32).ravel()

        # level: height above the deepest leaf below; num_nodes: number of leaves in this subtree
        if self.children:
            self.num_nodes = sum(child.num_nodes for child in self.children)
            # this node has to be one level higher than its highest child
            self.level = max(child.level for child in self.children) + 1
        else:
            self.num_nodes = 1
        if verbose:
            print('Parent node:', str(name), 'has', str(self.num_nodes), 'children')

        if compute_mean and value is None and self.children:
            total = self.labels.size
            resp = []
            for child in self.children:
                if verbose:
                    print(' Child node', str(child.name), 'has', str(child.num_nodes), 'children')
                if total > 0:
                    # weighted mean: each child contributes in proportion to its share of the positions
                    resp.append(child.value * child.labels.size / total)
                else:
                    # no positions at all, so the weighted mean would be 0/0 (NaN); use equal weights
                    resp.append(child.value / len(self.children))
            self.value = np.sum(np.array(resp), axis=0)

    def __str__(self):
        """
        Returns
        -------
        str
            '(name)' for a leaf, otherwise '(name) --> child1,child2,...' listing the direct children names.
        """
        if not self.children:
            return '({})'.format(self.name)
        return '({}) --> {}'.format(self.name, ','.join(str(child.name) for child in self.children))
class ClusterTree(object):
    """
    Creates a tree representation from the provided linkage pairing. Useful for clustering
    """

    def __init__(self, linkage_pairing, labels, distances=None, centroids=None):
        """
        Parameters
        ----------
        linkage_pairing : 2D unsigned int numpy array or list
            Linkage pairing that describes a tree structure. The matrix should result in a single tree apex.
            Row ``i`` merges the nodes whose IDs appear in that row into a new node with ID
            ``i + num_leaves``, where ``num_leaves`` is the number of rows + 1. IDs below ``num_leaves``
            are leaves (cluster IDs); larger IDs refer to nodes created by earlier rows.
        labels : 1D unsigned int numpy array or list
            Labels assigned to each of the positions in the main dataset. Eg. Labels from clustering
        distances : (Optional) 1D numpy float array or list
            Distances between clusters. ``distances[i]`` is stored as the ``dist`` of the node created by row ``i``.
        centroids : (Optional) 2D numpy array
            Mean responses for each of the clusters. ``centroids[k]`` becomes the value of leaf ``k``,
            and values are propagated up to the merged nodes.

        Raises
        ------
        ValueError
            If ``linkage_pairing`` is not two-dimensional.
        """
        # Accept lists as documented; np.asarray returns an existing ndarray unchanged
        linkage_pairing = np.asarray(linkage_pairing)
        if linkage_pairing.ndim != 2:
            raise ValueError('linkage_pairing must be a 2D array, got {} dimension(s)'
                             .format(linkage_pairing.ndim))
        # assumes binary merges: n - 1 merges join n leaves into a single apex
        self.num_leaves = linkage_pairing.shape[0] + 1
        self.linkage = linkage_pairing
        self.centroids = centroids
        # the cluster label assigned to each of the positions in the main dataset
        self.labels = np.array(labels, dtype=np.uint32)
        # Look-up table from node ID to Node object: leaves take IDs 0 .. num_leaves - 1 and merged
        # nodes are appended in the order their linkage rows are processed
        self.nodes = list()

        # Populate the lowest level nodes / leaves first; their labels are computed from self.labels
        for clust_id in range(self.num_leaves):
            # np.where returns a tuple of index arrays; take the single axis so the labels stay 1D
            which_pos = np.where(self.labels == clust_id)[0]
            value = None if centroids is None else centroids[clust_id]
            self.nodes.append(Node(clust_id, value=value, labels=which_pos))

        for row, pairing in enumerate(linkage_pairing):
            # every ID in this row refers to a node already present in the look-up table
            children = [self.nodes[int(child_id)] for child_id in pairing]
            new_node = Node(row + self.num_leaves, children=children, compute_mean=centroids is not None)
            # If distances are provided, record the distance between this node's children
            if distances is not None:
                new_node.dist = distances[row]
            self.nodes.append(new_node)

        # the apex: the node created by the last row (or the only leaf when there are no merges)
        self.tree = self.nodes[-1]

    def __str__(self):
        """
        Overrides the to string representation. Prints the name of the apex node and its direct children.
        Not very useful for large trees

        Returns
        --------
        String representation of the tree structure
        """
        return str(self.tree)