diff --git a/..gemspec b/..gemspec new file mode 100644 index 0000000..67e63f2 --- /dev/null +++ b/..gemspec @@ -0,0 +1,19 @@ +# -*- encoding: utf-8 -*- +lib = File.expand_path('../lib', __FILE__) +$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib) +require './version' + +Gem::Specification.new do |gem| + gem.name = "." + gem.version = .::VERSION + gem.authors = ["Chris Nelson"] + gem.email = ["chris@gaslightsoftware.com"] + gem.description = %q{TODO: Write a gem description} + gem.summary = %q{TODO: Write a gem summary} + gem.homepage = "" + + gem.files = `git ls-files`.split($/) + gem.executables = gem.files.grep(%r{^bin/}).map{ |f| File.basename(f) } + gem.test_files = gem.files.grep(%r{^(test|spec|features)/}) + gem.require_paths = ["lib"] +end diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d87d4be --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +*.gem +*.rbc +.bundle +.config +.yardoc +Gemfile.lock +InstalledFiles +_yardoc +coverage +doc/ +lib/bundler/man +pkg +rdoc +spec/reports +test/tmp +test/version_tmp +tmp diff --git a/Gemfile b/Gemfile new file mode 100644 index 0000000..cc56bff --- /dev/null +++ b/Gemfile @@ -0,0 +1,4 @@ +source 'https://rubygems.org' + +# Specify your gem's dependencies in ..gemspec +gemspec diff --git a/decisiontree.gemspec b/decisiontree.gemspec index 168768a..6c38bf0 100644 --- a/decisiontree.gemspec +++ b/decisiontree.gemspec @@ -47,7 +47,10 @@ Gem::Specification.new do |s| "examples/simple.rb" ] s.add_runtime_dependency "graphr" - + s.add_development_dependency "rspec" + s.add_development_dependency "rspec-given" + s.add_development_dependency "pry" + if s.respond_to? :specification_version then current_version = Gem::Specification::CURRENT_SPECIFICATION_VERSION s.specification_version = 3 diff --git a/lib/decisiontree/id3_tree.rb b/lib/decisiontree/id3_tree.rb index 2656ac1..dc3558a 100755 --- a/lib/decisiontree/id3_tree.rb +++ b/lib/decisiontree/id3_tree.rb @@ -15,9 +15,9 @@ class Object end end -class Array - def classification; collect { |v| v.last }; end - +class Array + def classification; collect { |v| v.last }; end + # calculate information entropy def entropy return 0 if empty? @@ -51,28 +51,34 @@ module DecisionTree @tree = id3_train(data2, attributes, default) end - - def id3_train(data, attributes, default, used={}) - # Choose a fitness algorithm - case @type - when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)} + + def type(attribute) + @type.is_a?(Hash) ? @type[attribute.to_sym] : @type + end + + def fitness_for(attribute) + case type(attribute) + when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)} when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)} end - - return default if data.empty? + end + + def id3_train(data, attributes, default, used={}) + return default if data.empty? # return classification if all examples have the same classification return data.first.last if data.classification.uniq.size == 1 # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute) - performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) } + performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) } max = performance.max { |a,b| a[0] <=> b[0] } best = Node.new(attributes[performance.index(max)], max[1], max[0]) best.threshold = nil if @type == :discrete - @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold] + @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold] tree, l = {best => {}}, ['>=', '<'] - - case @type + + fitness = fitness_for(best.attribute) + case type(best.attribute) when :continuous data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i| tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness) @@ -82,7 +88,7 @@ module DecisionTree partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } } partitions.each_with_index { |examples, i| tree[best][values[i]] = id3_train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness) - } + } end tree @@ -96,32 +102,32 @@ module DecisionTree thresholds.pop #thresholds -= used[attribute] if used.has_key? attribute - gain = thresholds.collect { |threshold| + gain = thresholds.collect { |threshold| sp = data.partition { |d| d[attributes.index(attribute)] >= threshold } pos = (sp[0].size).to_f / data.size neg = (sp[1].size).to_f / data.size - + [data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold] }.max { |a,b| a[0] <=> b[0] } return [-1, -1] if gain.size == 0 gain end - + # ID3 for discrete label cases def id3_discrete(data, attributes, attribute) values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } } remainder = partitions.collect {|p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) {|i,s| s+=i } - + [data.classification.entropy - remainder, attributes.index(attribute)] end def predict(test) - return (@type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test)) + descend(@tree, test) end - def graph(filename) + def graph(filename) dgp = DotGraphPrinter.new(build_tree) dgp.write_to_file("#{filename}.png", "png") end @@ -151,22 +157,20 @@ module DecisionTree end private - def descend_continuous(tree, test) + def descend(tree, test) attr = tree.to_a.first return @default if !attr - return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold - return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold - return descend_continuous(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold - return descend_continuous(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold - end - - def descend_discrete(tree, test) - attr = tree.to_a.first - return @default if !attr - return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash) - return descend_discrete(attr[1][test[@attributes.index(attr[0].attribute)]],test) + if type(attr.first.attribute) == :continuous + return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold + return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold + return descend(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold + return descend(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold + else + return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash) + return descend(attr[1][test[@attributes.index(attr[0].attribute)]],test) + end end - + def build_tree(tree = @tree) return [] unless tree.is_a?(Hash) return [["Always", @default]] if tree.empty? @@ -282,7 +286,7 @@ module DecisionTree def predict(test) @rules.each do |r| - prediction = r.predict(test) + prediction = r.predict(test) return prediction, r.accuracy unless prediction.nil? end return @default, 0.0 diff --git a/spec/id3_spec.rb b/spec/id3_spec.rb new file mode 100644 index 0000000..03b2ffa --- /dev/null +++ b/spec/id3_spec.rb @@ -0,0 +1,64 @@ +require 'spec_helper' + +describe describe DecisionTree::ID3Tree do + + describe "discrete attributes" do + Given(:labels) { ["hungry", "color"] } + Given(:data) do + [ + ["yes", "red", "angry"], + ["no", "blue", "not angry"], + ["yes", "blue", "not angry"], + ["no", "red", "not angry"] + ] + end + Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", :discrete) } + When { tree.train } + Then { tree.predict(["yes", "red"]).should == "angry" } + Then { tree.predict(["no", "red"]).should == "not angry" } + end + + describe "discrete attributes" do + Given(:labels) { ["hunger", "happiness"] } + Given(:data) do + [ + [8, 7, "angry"], + [6, 7, "angry"], + [7, 9, "angry"], + [7, 1, "not angry"], + [2, 9, "not angry"], + [3, 2, "not angry"], + [2, 3, "not angry"], + [1, 4, "not angry"] + ] + end + Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", :continuous) } + When { tree.train } + Then { tree.graph("continuous") } + Then { tree.predict([7, 7]).should == "angry" } + Then { tree.predict([2, 3]).should == "not angry" } + end + + describe "a mixture" do + Given(:labels) { ["hunger", "color"] } + Given(:data) do + [ + [8, "red", "angry"], + [6, "red", "angry"], + [7, "red", "angry"], + [7, "blue", "not angry"], + [2, "red", "not angry"], + [3, "blue", "not angry"], + [2, "blue", "not angry"], + [1, "red", "not angry"] + ] + end + Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", color: :discrete, hunger: :continuous) } + When { tree.train } + Then { tree.graph("continuous") } + Then { tree.predict([7, "red"]).should == "angry" } + Then { tree.predict([2, "blue"]).should == "not angry" } + end + + +end diff --git a/spec/spec_helper.rb b/spec/spec_helper.rb new file mode 100644 index 0000000..7d706da --- /dev/null +++ b/spec/spec_helper.rb @@ -0,0 +1,3 @@ +require 'rspec/given' +require 'decisiontree' +require 'pry'