diff --git a/lib/decisiontree/id3_tree.rb b/lib/decisiontree/id3_tree.rb index 39e2b70..43926e3 100755 --- a/lib/decisiontree/id3_tree.rb +++ b/lib/decisiontree/id3_tree.rb @@ -69,9 +69,14 @@ module DecisionTree # return classification if all examples have the same classification return data.first.last if data.classification.uniq.size == 1 - # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute) + # Choose best attribute: + # 1. enumerate all attributes + # 2. Pick best attribute + # 3. If attributes all score the same, then pick a random one to avoid infinite recursion. performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) } max = performance.max { |a,b| a[0] <=> b[0] } + min = performance.min { |a,b| a[0] <=> b[0] } + max = performance.shuffle.first if max[0] == min[0] best = Node.new(attributes[performance.index(max)], max[1], max[0]) best.threshold = nil if @type == :discrete @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold] diff --git a/spec/id3_spec.rb b/spec/id3_spec.rb index 9e011ce..5eb67ba 100644 --- a/spec/id3_spec.rb +++ b/spec/id3_spec.rb @@ -74,4 +74,19 @@ describe describe DecisionTree::ID3Tree do Then { tree.predict([2, "blue"]).should == "not angry" } end + describe "infinite recursion case" do + Given(:labels) { [:a, :b, :c] } + Given(:data) do + [ + ["a1", "b0", "c0", "RED"], + ["a1", "b1", "c1", "RED"], + ["a1", "b1", "c0", "BLUE"], + ["a1", "b0", "c1", "BLUE"] + ] + end + Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "RED", :discrete) } + When { tree.train } + Then { tree.predict(["a1","b0","c0"]).should == "RED" } + end + end