Decision Tree: How to find the path from the root to the desired terminal node
Prepare a fitted random forest
import random
import pandas as pd
# FIX: sklearn.ensemble.forest was a private module path, deprecated in 0.22
# and removed in scikit-learn 0.24 -- import from the public package instead.
from sklearn.ensemble import RandomForestRegressor
from sklearn import tree

# Toy data set: one categorical feature (X_1, one-hot encoded below) and one
# numeric feature (X_2), with target Y.
data = pd.DataFrame({"Y": [1, 5, 3, 4, 3, 4, 2],
                     "X_1": ["red", "blue", "blue", "red", "red", "blue", "red"],
                     "X_2": [18.4, 7.5, 9.3, 3.7, 5.2, 3.2, 5.2]})
data = pd.get_dummies(data)  # one-hot encode X_1 into X_1_blue / X_1_red
X = data.drop(["Y"], axis=1)
y = data["Y"]
# Fixed random_state so the fitted trees (and the extracted paths) are reproducible.
rf = RandomForestRegressor(n_estimators=10, random_state=1234)
rf.fit(X, y)
output:
RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,
max_features='auto', max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,
oob_score=False, random_state=1234, verbose=0, warm_start=False)
Find the path to desired terminal node
import pydotplus
import re
def return_node_path_to_max_prediction(onetree, verbose=True):
    """
    Trace the split path from the root of one sklearn tree down to the
    terminal (leaf) node with the maximum predicted value.

    @input: a tree from the sklearn randomforest
    @output: the node path to maximum terminal node
        [[split_node_1], [split_node_2], ...]
        [split_node_1] = [var_index, cutoff, direction]
    """
    if verbose:
        print("Generating Tree Graph, it may take a while...")
    # Render the tree to Graphviz DOT text so its node/edge structure can be
    # walked with pydotplus.
    dot_data = tree.export_graphviz(onetree,
                                    out_file = None,
                                    filled = True,
                                    rounded = True,
                                    special_characters = True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    # Map each edge's source node name to (one of) its children.  Only the
    # KEYS matter below: a node that never appears as a source is a leaf.
    graph_ = {}
    for edge in graph.get_edge_list():
        graph_[edge.get_source()] = edge.get_destination()
    # find all terminal nodes (leaves) and their predicted values
    terminal_node = {}
    non_decimal = re.compile(r'[^\d.]+')  # strips everything except digits and '.'
    for node in graph.get_node_list():
        if node.get_name() not in graph_:
            # "node" and "edge" are Graphviz default-attribute pseudo-nodes,
            # not real tree nodes -- skip them.
            if node.get_name() not in ["node", "edge"]:
                value = node.get_label()
                # Cut the label down to its "value = ..." tail, then keep only
                # the numeric prediction of the leaf.
                value = re.sub(r'.*v', 'v', value)
                terminal_node[node.get_name()] = float(non_decimal.sub('', value))
    # find the path down to the terminal with maximum prediction value by
    # walking edges BACKWARDS from that leaf up to the root (node "0")
    flag = True
    destination = max(terminal_node, key=terminal_node.get)
    edge_list = graph.get_edge_list()
    node_list = graph.get_node_list()
    split_node = []
    while flag:
        # the unique edge that ends at the current node
        myedge = [edge for edge in edge_list if edge.get_destination() == destination][0]
        # sklearn's export numbers a node's left child as parent_id + 1
        # (preorder), so an id gap greater than 1 means we descended the
        # right branch -- TODO confirm this holds for the sklearn version used.
        if int(myedge.get_destination()) - int(myedge.get_source()) > 1:
            direction = "Right"
        else:
            direction = "Left"
        mynode = [node for node in node_list if node.get_name() == myedge.get_source()][0]
        # First two numbers in the split-node label: feature index and cutoff.
        var_val = re.findall(r"[-+]?\d*\.\d+|\d+", mynode.get_label())[:2]
        # record the growing path:
        # var_val[0]: Index of variable participating in splitting
        # var_val[1]: cutoff point of the splitting
        # direction: If Right, means greater than var_val[1];
        # If Left, means no greater than var_val[1]
        split_node.append([int(var_val[0]),float(var_val[1]),direction])
        if verbose:
            print(myedge.get_destination() + "<-" + myedge.get_source() +
                  ": Split at Variable X" + var_val[0] + "; The cutoff is " + var_val[1] +
                  "; Turn " + direction)
        # step one level up; stop once the root has been reached
        destination = myedge.get_source()
        if destination == "0":
            flag = False
    # collected leaf-to-root, so reverse to read root -> leaf
    return [*reversed(split_node)]
Test:
return_node_path_to_max_prediction(rf[1], verbose=True)
Outputs:
Generating Tree Graph, it may take a while...
3<-1: Split at Variable X0; The cutoff is 5.6; Turn Right
1<-0: Split at Variable X0; The cutoff is 12.95; Turn Left
From the output above, we know the path from the root to the desired terminal node is :
Root [X0 (<= 12.95)] -> X0 (> 5.6) -> Terminal Node
Collect Paths in the random forest
def collect_path(rf, verbose=True):
    """
    Collect, for every tree in a fitted forest, the split path that leads
    to its maximum-prediction terminal node.

    @input: rf -- a fitted RandomForestRegressor (its rf.estimators_ is used)
            verbose -- if True, print progress for each tree
    @output: a list with one path per tree, each path as returned by
             return_node_path_to_max_prediction
    """
    # Use estimators_ consistently (the original mixed len(rf) with
    # rf.estimators_[i]); enumerate avoids the range(len(...)) anti-pattern.
    estimators = rf.estimators_
    n_tree = len(estimators)
    result = []
    for i, onetree in enumerate(estimators):
        if verbose:
            print("Construct the %s tree graph out of %s trees" %(i+1, n_tree))
        result.append(return_node_path_to_max_prediction(onetree, verbose=False))
    return result
Test:
result = collect_path(rf)
print(result)
Outputs:
Construct the 1 tree graph out of 10 trees
Construct the 2 tree graph out of 10 trees
Construct the 3 tree graph out of 10 trees
Construct the 4 tree graph out of 10 trees
Construct the 5 tree graph out of 10 trees
Construct the 6 tree graph out of 10 trees
Construct the 7 tree graph out of 10 trees
Construct the 8 tree graph out of 10 trees
Construct the 9 tree graph out of 10 trees
Construct the 10 tree graph out of 10 trees
[[[0, 4.2, 'Left']], [[0, 12.95, 'Left'], [0, 5.6, 'Right']], [[1, 0.5, 'Right'], [0, 8.4, 'Left']], [[0, 13.85, 'Left'], [0, 8.4, 'Left'], [1, 0.5, 'Right']], [[0, 8.4, 'Left'], [0, 6.35, 'Right']], [[0, 12.95, 'Left'], [0, 5.6, 'Right']], [[2, 0.5, 'Left'], [0, 5.35, 'Right']], [[1, 0.5, 'Right'], [0, 5.35, 'Right']], [[0, 13.85, 'Left'], [1, 0.5, 'Right'], [0, 8.4, 'Left'], [0, 5.35, 'Right']], [[0, 13.85, 'Left'], [0, 6.35, 'Right'], [0, 8.4, 'Left']]]
Summarize the decision region
def summarize_region(result, features):
decision_region = {k: [[] for _ in range(2)] for k in features}
for i in range(len(result)):
for j in range(len(result[i])):
if result[i][j][2] == "Left":
decision_region[features[result[i][j][0]]][0].append(result[i][j][1])
else:
decision_region[features[result[i][j][0]]][1].append(result[i][j][1])
decision_region_ = {}
for k in features:
try:
upper_bound = min(decision_region[k][0])
except ValueError:
upper_bound = "Unknown"
try:
lower_bound = max(decision_region[k][1])
except ValueError:
lower_bound = "Unknown"
decision_region_[k] = [lower_bound, upper_bound]
value_to_remove = ['Unknown', 'Unknown']
decision_region_ = {key: value for key, value in decision_region_.items() if value != value_to_remove}
value_to_remove = [0.5, 0.5]
decision_region_ = {key: value for key, value in decision_region_.items() if value != value_to_remove}
return (decision_region_)
Test:
features = X.columns
summarize_region(result, features)
Outputs:
{'X_1_blue': [0.5, 'Unknown'], 'X_1_red': ['Unknown', 0.5], 'X_2': [6.35, 4.2]}
From the output above, we know that the decision region:
{blue} * [6.35, 4.2]
It seems that the region [6.35, 4.2] is not reasonable (the lower bound exceeds the upper bound) due to the poorly generated data. Nevertheless, this may happen in some situations, which may require us to come up with new ways to ensemble these terminal nodes.