diff --git a/lib/decisiontree/id3_tree.rb b/lib/decisiontree/id3_tree.rb index c14f238..1d3658a 100755 --- a/lib/decisiontree/id3_tree.rb +++ b/lib/decisiontree/id3_tree.rb @@ -3,6 +3,8 @@ ### Copyright (c) 2007 Ilya Grigorik ### Modifed at 2007 by José Ignacio Fernández +require 'set' + module DecisionTree Node = Struct.new(:attribute, :threshold, :gain) @@ -28,7 +30,7 @@ module DecisionTree end data2 = data2.map do |key, val| - key + [val.sort_by { |_k, v| v }.last.first] + key + [val.sort_by { |_, v| v }.last.first] end @tree = id3_train(data2, attributes, default) @@ -41,9 +43,9 @@ module DecisionTree def fitness_for(attribute) case type(attribute) when :discrete - proc { |a, b, c| id3_discrete(a, b, c) } + proc { |*args| id3_discrete(*args) } when :continuous - proc { |a, b, c| id3_continuous(a, b, c) } + proc { |*args| id3_continuous(*args) } end end @@ -66,14 +68,13 @@ module DecisionTree @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold] tree, l = {best => {}}, ['>=', '<'] - fitness = fitness_for(best.attribute) case type(best.attribute) when :continuous partitioned_data = data.partition do |d| d[attributes.index(best.attribute)] >= best.threshold end partitioned_data.each_with_index do |examples, i| - tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness) + tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0)) end when :discrete values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort @@ -83,7 +84,7 @@ module DecisionTree end end partitions.each_with_index do |examples, i| - tree[best][values[i]] = id3_train(examples, attributes - [values[i]], (data.classification.mode rescue 0), &fitness) + tree[best][values[i]] = id3_train(examples, attributes - [values[i]], (data.classification.mode rescue 0)) end end @@ -116,11 +117,14 @@ module DecisionTree # ID3 for discrete label cases def id3_discrete(data, attributes, attribute) - values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort - partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } } + index = attributes.index(attribute) + + values = Set.new + data.each { |d| values << d[index] } + partitions = values.to_a.sort.collect { |val| data.select { |d| d[index] == val } } remainder = partitions.collect { |p| (p.size.to_f / data.size) * p.classification.entropy }.inject(0) { |a, e| e += a } - [data.classification.entropy - remainder, attributes.index(attribute)] + [data.classification.entropy - remainder, index] end def predict(test)