diff --git a/examples/continuous-id3.rb b/examples/continuous-id3.rb
index 26f9ee9..c980d1a 100644
--- a/examples/continuous-id3.rb
+++ b/examples/continuous-id3.rb
@@ -2,15 +2,26 @@
 require 'rubygems'
 require 'decisiontree'
 include DecisionTree
 
-# ---Continuous-----------------------------------------------------------------------------------------
+# ---Continuous---
 
 # Read in the training data
-training, attributes = [], nil
-File.open('data/continuous-training.txt','r').each_line { |line|
-  data = line.strip.chomp('.').split(',')
+training = []
+attributes = nil
+File.open('data/continuous-training.txt', 'r').each_line do |line|
+  data = line.strip.chomp('.').split(',')
   attributes ||= data
-  training.push(data.collect {|v| (v == 'healthy') || (v == 'colic') ? (v == 'healthy' ? 1 : 0) : v.to_f})
-}
+  training_data = data.collect do |v|
+    case v
+    when 'healthy'
+      1
+    when 'colic'
+      0
+    else
+      v.to_f
+    end
+  end
+  training.push(training_data)
+end
 
 # Remove the attribute row from the training data
 training.shift
@@ -19,15 +30,25 @@ training.shift
 dec_tree = ID3Tree.new(attributes, training, 1, :continuous)
 dec_tree.train
 
-#---- Test the tree....
+# ---Test the tree---
 # Read in the test cases
-# Note: omit the attribute line (first line), we know the labels from the training data
+# Note: omit the attribute line (first line); we know the labels from the training data
 test = []
-File.open('data/continuous-test.txt','r').each_line { |line|
-  data = line.strip.chomp('.').split(',')
-  test.push(data.collect {|v| (v == 'healthy') || (v == 'colic') ? (v == 'healthy' ? 1 : 0) : v.to_f})
-}
+File.open('data/continuous-test.txt', 'r').each_line do |line|
+  data = line.strip.chomp('.').split(',')
+  test_data = data.collect do |v|
+    if v == 'healthy' || v == 'colic'
+      v == 'healthy' ? 1 : 0
+    else
+      v.to_f
+    end
+  end
+  test.push(test_data)
+end
 
 # Let the tree predict the output and compare it to the true specified value
-test.each { |t| predict = dec_tree.predict(t); puts "Predict: #{predict} ... True: #{t.last}"}
+test.each do |t|
+  predict = dec_tree.predict(t)
+  puts "Predict: #{predict} ... True: #{t.last}"
+end
diff --git a/examples/discrete-id3.rb b/examples/discrete-id3.rb
index d4dcf6d..ef44020 100644
--- a/examples/discrete-id3.rb
+++ b/examples/discrete-id3.rb
@@ -1,15 +1,26 @@
 require 'rubygems'
 require 'decisiontree'
 
-# ---Discrete-----------------------------------------------------------------------------------------
+# ---Discrete---
 
 # Read in the training data
-training, attributes = [], nil
-File.open('data/discrete-training.txt','r').each_line { |line|
+training = []
+attributes = nil
+File.open('data/discrete-training.txt', 'r').each_line do |line|
   data = line.strip.split(',')
   attributes ||= data
-  training.push(data.collect {|v| (v == 'will buy') || (v == "won't buy") ? (v == 'will buy' ? 1 : 0) : v})
-}
+  training_data = data.collect do |v|
+    case v
+    when 'will buy'
+      1
+    when "won't buy"
+      0
+    else
+      v
+    end
+  end
+  training.push(training_data)
+end
 
 # Remove the attribute row from the training data
 training.shift
@@ -18,17 +29,31 @@ training.shift
 dec_tree = DecisionTree::ID3Tree.new(attributes, training, 1, :discrete)
 dec_tree.train
 
-#---- Test the tree....
+# ---Test the tree---
 # Read in the test cases
-# Note: omit the attribute line (first line), we know the labels from the training data
+# Note: omit the attribute line (first line); we know the labels from the training data
 test = []
-File.open('data/discrete-test.txt','r').each_line { |line| data = line.strip.split(',')
-  test.push(data.collect {|v| (v == 'will buy') || (v == "won't buy") ? (v == 'will buy' ? 1 : 0) : v})
-}
+File.open('data/discrete-test.txt', 'r').each_line do |line|
+  data = line.strip.split(',')
+  test_data = data.collect do |v|
+    case v
+    when 'will buy'
+      1
+    when "won't buy"
+      0
+    else
+      v
+    end
+  end
+  test.push(test_data)
+end
 
 # Let the tree predict the output and compare it to the true specified value
-test.each { |t| predict = dec_tree.predict(t); puts "Predict: #{predict} ... True: #{t.last}"; }
+test.each do |t|
+  predict = dec_tree.predict(t)
+  puts "Predict: #{predict} ... True: #{t.last}"
+end
 
 # Graph the tree, save to 'discrete.png'
-dec_tree.graph("discrete")
+dec_tree.graph('discrete')
diff --git a/examples/simple.rb b/examples/simple.rb
index 8a0982e..e023675 100755
--- a/examples/simple.rb
+++ b/examples/simple.rb
@@ -2,7 +2,7 @@
 require 'rubygems'
 require 'decisiontree'
- 
+
 attributes = ['Temperature']
 training = [
   [36.6, 'healthy'],
@@ -10,19 +10,17 @@ training = [
   [38, 'sick'],
   [36.7, 'healthy'],
   [40, 'sick'],
-  [50, 'really sick'],
+  [50, 'really sick']
 ]
- 
+
 # Instantiate the tree, and train it based on the data (set default to '1')
 dec_tree = DecisionTree::ID3Tree.new(attributes, training, 'sick', :continuous)
 dec_tree.train
 
 test = [37, 'sick']
- 
+
 decision = dec_tree.predict(test)
-puts "Predicted: #{decision} ... True decision: #{test.last}";
-
+puts "Predicted: #{decision} ... True decision: #{test.last}"
+
 # Graph the tree, save to 'tree.png'
-dec_tree.graph("tree")
-
-
+dec_tree.graph('tree')
diff --git a/lib/core_extensions/array.rb b/lib/core_extensions/array.rb
new file mode 100644
index 0000000..5a70756
--- /dev/null
+++ b/lib/core_extensions/array.rb
@@ -0,0 +1,29 @@
+class Array
+  def classification
+    collect(&:last)
+  end
+
+  # calculate information entropy
+  def entropy
+    return 0 if empty?
+
+    info = {}
+    each do |i|
+      info[i] = (info[i] || 0) + 1
+    end
+
+    result(info, length)
+  end
+
+  private
+
+  def result(info, total)
+    final = 0
+    info.each do |_symbol, count|
+      next unless count > 0
+      percentage = count.to_f / total
+      final += -percentage * Math.log(percentage) / Math.log(2.0)
+    end
+    final
+  end
+end
diff --git a/lib/core_extensions/object.rb b/lib/core_extensions/object.rb
new file mode 100644
index 0000000..0b79fd9
--- /dev/null
+++ b/lib/core_extensions/object.rb
@@ -0,0 +1,9 @@
+class Object
+  def save_to_file(filename)
+    File.open(filename, 'w+') { |f| f << Marshal.dump(self) }
+  end
+
+  def self.load_from_file(filename)
+    Marshal.load(File.read(filename))
+  end
+end
diff --git a/lib/decisiontree.rb b/lib/decisiontree.rb
index 5583923..3da0b47 100644
--- a/lib/decisiontree.rb
+++ b/lib/decisiontree.rb
@@ -1 +1,3 @@
 require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb'
+require 'core_extensions/object'
+require 'core_extensions/array'
diff --git a/lib/decisiontree/id3_tree.rb b/lib/decisiontree/id3_tree.rb
index afeceb1..4e4c340 100755
--- a/lib/decisiontree/id3_tree.rb
+++ b/lib/decisiontree/id3_tree.rb
@@ -3,50 +3,33 @@
 ### Copyright (c) 2007 Ilya Grigorik
 ### Modifed at 2007 by José Ignacio Fernández
 
-class Object
-  def save_to_file(filename)
-    File.open(filename, 'w+' ) { |f| f << Marshal.dump(self) }
-  end
-
-  def self.load_from_file(filename)
-    Marshal.load( File.read( filename ) )
-  end
-end
-
-class Array
-  def classification; collect { |v| v.last }; end
-
-  # calculate information entropy
-  def entropy
-    return 0 if empty?
-
-    info = {}
-    total = 0
-    each {|i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1}
-
-    result = 0
-    info.each do |symbol, count|
-      result += -count.to_f/total*Math.log(count.to_f/total)/Math.log(2.0) if (count > 0)
-    end
-    result
-  end
-end
-
 module DecisionTree
   Node = Struct.new(:attribute, :threshold, :gain)
 
   class ID3Tree
     def initialize(attributes, data, default, type)
-      @used, @tree, @type = {}, {}, type
-      @data, @attributes, @default = data, attributes, default
+      @used = {}
+      @tree = {}
+      @type = type
+      @data = data
+      @attributes = attributes
+      @default = default
     end
 
-    def train(data=@data, attributes=@attributes, default=@default)
-      attributes = attributes.map {|e| e.to_s}
+    def train(data = @data, attributes = @attributes, default = @default)
+      attributes = attributes.map(&:to_s)
       initialize(attributes, data, default, @type)
 
       # Remove samples with same attributes leaving most common classification
-      data2 = data.inject({}) {|hash, d| hash[d.slice(0..-2)] ||= Hash.new(0); hash[d.slice(0..-2)][d.last] += 1; hash }.map{|key,val| key + [val.sort_by{ |k, v| v }.last.first]}
+      data2 = data.inject({}) do |hash, d|
+        hash[d.slice(0..-2)] ||= Hash.new(0)
+        hash[d.slice(0..-2)][d.last] += 1
+        hash
+      end
+
+      data2 = data2.map do |key, val|
+        key + [val.sort_by { |_k, v| v }.last.first]
+      end
 
       @tree = id3_train(data2, attributes, default)
     end
@@ -57,12 +40,14 @@ module DecisionTree
     def fitness_for(attribute)
       case type(attribute)
-      when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
-      when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)}
+      when :discrete
+        proc { |a, b, c| id3_discrete(a, b, c) }
+      when :continuous
+        proc { |a, b, c| id3_continuous(a, b, c) }
       end
     end
 
-    def id3_train(data, attributes, default, used={})
+    def id3_train(data, attributes, default, _used={})
       return default if data.empty?
 
       # return classification if all examples have the same classification
@@ -75,7 +60,7 @@ module DecisionTree
       performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
       max = performance.max { |a,b| a[0] <=> b[0] }
       min = performance.min { |a,b| a[0] <=> b[0] }
-      max = performance.shuffle.first if max[0] == min[0]
+      max = performance.sample if max[0] == min[0]
       best = Node.new(attributes[performance.index(max)], max[1], max[0])
       best.threshold = nil if @type == :discrete
       @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
@@ -83,36 +68,47 @@ module DecisionTree
       fitness = fitness_for(best.attribute)
       case type(best.attribute)
-      when :continuous
-        data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i|
-          tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
-        }
-      when :discrete
-        values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
-        partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } }
-        partitions.each_with_index { |examples, i|
-          tree[best][values[i]] = id3_train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness)
-        }
-      end
+      when :continuous
+        partitioned_data = data.partition do |d|
+          d[attributes.index(best.attribute)] >= best.threshold
+        end
+        partitioned_data.each_with_index do |examples, i|
+          # '>=' and '<' are the branch labels #descend expects
+          tree[best][String.new(['>=', '<'][i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
+        end
+      when :discrete
+        values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
+        partitions = values.collect do |val|
+          data.select do |d|
+            d[attributes.index(best.attribute)] == val
+          end
+        end
+        partitions.each_with_index do |examples, i|
+          tree[best][values[i]] = id3_train(examples, attributes - [values[i]], (data.classification.mode rescue 0), &fitness)
+        end
+      end
 
       tree
     end
 
     # ID3 for binary classification of continuous variables (e.g. healthy / sick based on temperature thresholds)
     def id3_continuous(data, attributes, attribute)
-      values, thresholds = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort, []
+      values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
+      thresholds = []
       return [-1, -1] if values.size == 1
-      values.each_index { |i| thresholds.push((values[i]+(values[i+1].nil? ? values[i] : values[i+1])).to_f / 2) }
+      values.each_index do |i|
+        thresholds.push((values[i] + (values[i + 1].nil? ? values[i] : values[i + 1])).to_f / 2)
+      end
       thresholds.pop
 
       #thresholds -= used[attribute] if used.has_key? attribute
-      gain = thresholds.collect { |threshold|
+      gain = thresholds.collect do |threshold|
         sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
         pos = (sp[0].size).to_f / data.size
         neg = (sp[1].size).to_f / data.size
-        [data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold]
-      }.max { |a,b| a[0] <=> b[0] }
+        [data.classification.entropy - pos * sp[0].classification.entropy - neg * sp[1].classification.entropy, threshold]
+      end
+      gain = gain.max { |a, b| a[0] <=> b[0] }
 
       return [-1, -1] if gain.size == 0
       gain
@@ -122,7 +118,7 @@ module DecisionTree
     def id3_discrete(data, attributes, attribute)
       values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
       partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
-      remainder = partitions.collect {|p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) {|i,s| s+=i }
+      remainder = partitions.collect { |p| (p.size.to_f / data.size) * p.classification.entropy }.inject(0) { |sum, e| sum + e }
 
       [data.classification.entropy - remainder, attributes.index(attribute)]
     end
@@ -131,7 +127,7 @@ module DecisionTree
       descend(@tree, test)
     end
 
-    def graph(filename, file_type = "png")
+    def graph(filename, file_type = 'png')
       require 'graphr'
       dgp = DotGraphPrinter.new(build_tree)
       dgp.write_to_file("#{filename}.#{file_type}", file_type)
@@ -143,12 +139,12 @@ module DecisionTree
       rs
     end
 
-    def build_rules(tree=@tree)
+    def build_rules(tree = @tree)
      attr = tree.to_a.first
       cases = attr[1].to_a
       rules = []
-      cases.each do |c,child|
-        if child.is_a?(Hash) then
+      cases.each do |c, child|
+        if child.is_a?(Hash)
           build_rules(child).each do |r|
             r2 = r.clone
             r2.premises.unshift([attr.first, c])
@@ -161,43 +157,47 @@ module DecisionTree
       rules
     end
 
-    private
+    private
+
     def descend(tree, test)
       attr = tree.to_a.first
-      return @default if !attr
+      return @default unless attr
       if type(attr.first.attribute) == :continuous
-        return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
-        return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
-        return descend(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
-        return descend(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
+        return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
+        return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] < attr.first.threshold
+        return descend(attr[1]['>='], test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
+        return descend(attr[1]['<'], test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
       else
         return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
-        return descend(attr[1][test[@attributes.index(attr[0].attribute)]],test)
+        return descend(attr[1][test[@attributes.index(attr[0].attribute)]], test)
      end
    end
 
     def build_tree(tree = @tree)
       return [] unless tree.is_a?(Hash)
-      return [["Always", @default]] if tree.empty?
+      return [['Always', @default]] if tree.empty?
 
       attr = tree.to_a.first
       links = attr[1].keys.collect do |key|
         parent_text = "#{attr[0].attribute}\n(#{attr[0].object_id})"
-        if attr[1][key].is_a?(Hash) then
+        if attr[1][key].is_a?(Hash)
           child = attr[1][key].to_a.first[0]
           child_text = "#{child.attribute}\n(#{child.object_id})"
         else
           child = attr[1][key]
           child_text = "#{child}\n(#{child.to_s.clone.object_id})"
         end
-        label_text = "#{key} #{type(attr[0].attribute) == :continuous ? attr[0].threshold : ""}"
+        label_text = key.to_s
+        if type(attr[0].attribute) == :continuous
+          label_text += " #{attr[0].threshold}"
+        end
 
         [parent_text, child_text, label_text]
       end
 
       attr[1].keys.each { |key| links += build_tree(attr[1][key]) }
-      return links
+      links
     end
   end
 
@@ -206,48 +206,56 @@ module DecisionTree
     attr_accessor :conclusion
     attr_accessor :attributes
 
-    def initialize(attributes,premises=[],conclusion=nil)
-      @attributes, @premises, @conclusion = attributes, premises, conclusion
+    def initialize(attributes, premises = [], conclusion = nil)
+      @attributes = attributes
+      @premises = premises
+      @conclusion = conclusion
     end
 
     def to_s
       str = ''
       @premises.each do |p|
-        str += "#{p.first.attribute} #{p.last} #{p.first.threshold}" if p.first.threshold
-        str += "#{p.first.attribute} = #{p.last}" if !p.first.threshold
+        if p.first.threshold
+          str += "#{p.first.attribute} #{p.last} #{p.first.threshold}"
+        else
+          str += "#{p.first.attribute} = #{p.last}"
+        end
         str += "\n"
       end
       str += "=> #{@conclusion} (#{accuracy})"
     end
 
     def predict(test)
-      verifies = true;
+      verifies = true
      @premises.each do |p|
-        if p.first.threshold then # Continuous
-          if !(p.last == '>=' && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == '<' && test[@attributes.index(p.first.attribute)] < p.first.threshold) then
-            verifies = false; break
+        if p.first.threshold # Continuous
+          if !(p.last == '>=' && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == '<' && test[@attributes.index(p.first.attribute)] < p.first.threshold)
+            verifies = false
+            break
           end
         else # Discrete
-          if test[@attributes.index(p.first.attribute)] != p.last then
-            verifies = false; break
+          if test[@attributes.index(p.first.attribute)] != p.last
+            verifies = false
+            break
          end
        end
      end
      return @conclusion if verifies
-      return nil
+      nil
    end
 
     def get_accuracy(data)
-      correct = 0; total = 0
+      correct = 0
+      total = 0
      data.each do |d|
        prediction = predict(d)
        correct += 1 if d.last == prediction
-        total += 1 if !prediction.nil?
+        total += 1 unless prediction.nil?
      end
      (correct.to_f + 1) / (total.to_f + 2)
    end
 
-    def accuracy(data=nil)
+    def accuracy(data = nil)
       data.nil? ? @accuracy : @accuracy = get_accuracy(data)
     end
   end
 
@@ -256,14 +264,16 @@ module DecisionTree
     attr_accessor :rules
 
     def initialize(attributes, data, default, type)
-      @attributes, @default, @type = attributes, default, type
-      mixed_data = data.sort_by {rand}
+      @attributes = attributes
+      @default = default
+      @type = type
+      mixed_data = data.shuffle
       cut = (mixed_data.size.to_f * 0.67).to_i
-      @train_data = mixed_data.slice(0..cut-1)
+      @train_data = mixed_data.slice(0..cut - 1)
       @prune_data = mixed_data.slice(cut..-1)
     end
 
-    def train(train_data=@train_data, attributes=@attributes, default=@default)
+    def train(train_data = @train_data, attributes = @attributes, default = @default)
       dec_tree = DecisionTree::ID3Tree.new(attributes, train_data, default, @type)
       dec_tree.train
       @rules = dec_tree.build_rules
@@ -271,21 +281,23 @@ module DecisionTree
       prune
     end
 
-    def prune(data=@prune_data)
+    def prune(data = @prune_data)
       @rules.each do |r|
         (1..r.premises.size).each do
           acc1 = r.accuracy(data)
           p = r.premises.pop
-          if acc1 > r.get_accuracy(data) then
-            r.premises.push(p); break
+          if acc1 > r.get_accuracy(data)
+            r.premises.push(p)
+            break
          end
        end
      end
-      @rules = @rules.sort_by{|r| -r.accuracy(data)}
+      @rules = @rules.sort_by { |r| -r.accuracy(data) }
    end
 
     def to_s
-      str = ''; @rules.each { |rule| str += "#{rule}\n\n" }
+      str = ''
+      @rules.each { |rule| str += "#{rule}\n\n" }
       str
     end
 
@@ -294,18 +306,21 @@ module DecisionTree
       prediction = r.predict(test)
       return prediction, r.accuracy unless prediction.nil?
     end
-      return @default, 0.0
+      [@default, 0.0]
    end
  end
 
  class Bagging
    attr_accessor :classifiers
 
    def initialize(attributes, data, default, type)
-      @classifiers, @type = [], type
-      @data, @attributes, @default = data, attributes, default
+      @classifiers = []
+      @type = type
+      @data = data
+      @attributes = attributes
+      @default = default
    end
 
-    def train(data=@data, attributes=@attributes, default=@default)
+    def train(data = @data, attributes = @attributes, default = @default)
      @classifiers = []
      10.times { @classifiers << Ruleset.new(attributes, data, default, @type) }
      @classifiers.each do |c|
@@ -320,8 +335,8 @@ module DecisionTree
       predictions[p] += accuracy unless p.nil?
     end
     return @default, 0.0 if predictions.empty?
-      winner = predictions.sort_by {|k,v| -v}.first
-      return winner[0], winner[1].to_f / @classifiers.size.to_f
+      winner = predictions.sort_by { |_k, v| -v }.first
+      [winner[0], winner[1].to_f / @classifiers.size.to_f]
    end
  end
end
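
As a quick sanity check of the refactor, here is a minimal smoke-test sketch (not part of the change set). It assumes `lib/` is on the load path, e.g. running with `ruby -Ilib` from the repo root; the temperature rows simply mirror `examples/simple.rb`, and `tree.dump` is an arbitrary illustrative filename.

```ruby
require 'decisiontree'

# Array#entropy from the extracted lib/core_extensions/array.rb:
puts [1, 1, 0, 0].entropy # => 1.0 (a 50/50 split carries one bit)
puts [1, 1, 1].entropy    # => 0.0 (a pure set carries no information)

# Train on the same toy data as examples/simple.rb.
attributes = ['Temperature']
training = [
  [36.6, 'healthy'],
  [37.0, 'sick'],
  [38.0, 'sick'],
  [36.7, 'healthy']
]
dec_tree = DecisionTree::ID3Tree.new(attributes, training, 'sick', :continuous)
dec_tree.train
puts dec_tree.predict([36.8, 'healthy']) # => healthy (below the learned threshold)

# Object#save_to_file / Object.load_from_file from lib/core_extensions/object.rb
# round-trip the trained tree through Marshal.
dec_tree.save_to_file('tree.dump')
reloaded = Object.load_from_file('tree.dump')
puts reloaded.predict([39.0, 'sick'])    # => sick
```

The Marshal round-trip works because a trained tree holds only plain hashes, strings, and the named `DecisionTree::Node` struct; no procs are stored on the instance.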