mirror of
https://github.com/dkam/decisiontree.git
synced 2025-12-28 15:14:52 +00:00
added support for continuous and discrete attributes in the same dataset
This commit is contained in:
19
..gemspec
Normal file
19
..gemspec
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# -*- encoding: utf-8 -*-
|
||||||
|
lib = File.expand_path('../lib', __FILE__)
|
||||||
|
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
|
||||||
|
require './version'
|
||||||
|
|
||||||
|
Gem::Specification.new do |gem|
|
||||||
|
gem.name = "."
|
||||||
|
gem.version = .::VERSION
|
||||||
|
gem.authors = ["Chris Nelson"]
|
||||||
|
gem.email = ["chris@gaslightsoftware.com"]
|
||||||
|
gem.description = %q{TODO: Write a gem description}
|
||||||
|
gem.summary = %q{TODO: Write a gem summary}
|
||||||
|
gem.homepage = ""
|
||||||
|
|
||||||
|
gem.files = `git ls-files`.split($/)
|
||||||
|
gem.executables = gem.files.grep(%r{^bin/}).map{ |f| File.basename(f) }
|
||||||
|
gem.test_files = gem.files.grep(%r{^(test|spec|features)/})
|
||||||
|
gem.require_paths = ["lib"]
|
||||||
|
end
|
||||||
17
.gitignore
vendored
Normal file
17
.gitignore
vendored
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
*.gem
|
||||||
|
*.rbc
|
||||||
|
.bundle
|
||||||
|
.config
|
||||||
|
.yardoc
|
||||||
|
Gemfile.lock
|
||||||
|
InstalledFiles
|
||||||
|
_yardoc
|
||||||
|
coverage
|
||||||
|
doc/
|
||||||
|
lib/bundler/man
|
||||||
|
pkg
|
||||||
|
rdoc
|
||||||
|
spec/reports
|
||||||
|
test/tmp
|
||||||
|
test/version_tmp
|
||||||
|
tmp
|
||||||
4
Gemfile
Normal file
4
Gemfile
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
source 'https://rubygems.org'
|
||||||
|
|
||||||
|
# Specify your gem's dependencies in ..gemspec
|
||||||
|
gemspec
|
||||||
@@ -47,6 +47,9 @@ Gem::Specification.new do |s|
|
|||||||
"examples/simple.rb"
|
"examples/simple.rb"
|
||||||
]
|
]
|
||||||
s.add_runtime_dependency "graphr"
|
s.add_runtime_dependency "graphr"
|
||||||
|
s.add_development_dependency "rspec"
|
||||||
|
s.add_development_dependency "rspec-given"
|
||||||
|
s.add_development_dependency "pry"
|
||||||
|
|
||||||
if s.respond_to? :specification_version then
|
if s.respond_to? :specification_version then
|
||||||
current_version = Gem::Specification::CURRENT_SPECIFICATION_VERSION
|
current_version = Gem::Specification::CURRENT_SPECIFICATION_VERSION
|
||||||
|
|||||||
@@ -52,27 +52,33 @@ module DecisionTree
|
|||||||
@tree = id3_train(data2, attributes, default)
|
@tree = id3_train(data2, attributes, default)
|
||||||
end
|
end
|
||||||
|
|
||||||
def id3_train(data, attributes, default, used={})
|
def type(attribute)
|
||||||
# Choose a fitness algorithm
|
@type.is_a?(Hash) ? @type[attribute.to_sym] : @type
|
||||||
case @type
|
end
|
||||||
|
|
||||||
|
def fitness_for(attribute)
|
||||||
|
case type(attribute)
|
||||||
when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
|
when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
|
||||||
when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)}
|
when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)}
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def id3_train(data, attributes, default, used={})
|
||||||
return default if data.empty?
|
return default if data.empty?
|
||||||
|
|
||||||
# return classification if all examples have the same classification
|
# return classification if all examples have the same classification
|
||||||
return data.first.last if data.classification.uniq.size == 1
|
return data.first.last if data.classification.uniq.size == 1
|
||||||
|
|
||||||
# Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
|
# Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
|
||||||
performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) }
|
performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
|
||||||
max = performance.max { |a,b| a[0] <=> b[0] }
|
max = performance.max { |a,b| a[0] <=> b[0] }
|
||||||
best = Node.new(attributes[performance.index(max)], max[1], max[0])
|
best = Node.new(attributes[performance.index(max)], max[1], max[0])
|
||||||
best.threshold = nil if @type == :discrete
|
best.threshold = nil if @type == :discrete
|
||||||
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
||||||
tree, l = {best => {}}, ['>=', '<']
|
tree, l = {best => {}}, ['>=', '<']
|
||||||
|
|
||||||
case @type
|
fitness = fitness_for(best.attribute)
|
||||||
|
case type(best.attribute)
|
||||||
when :continuous
|
when :continuous
|
||||||
data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i|
|
data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i|
|
||||||
tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
|
tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
|
||||||
@@ -118,7 +124,7 @@ module DecisionTree
|
|||||||
end
|
end
|
||||||
|
|
||||||
def predict(test)
|
def predict(test)
|
||||||
return (@type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test))
|
descend(@tree, test)
|
||||||
end
|
end
|
||||||
|
|
||||||
def graph(filename)
|
def graph(filename)
|
||||||
@@ -151,20 +157,18 @@ module DecisionTree
|
|||||||
end
|
end
|
||||||
|
|
||||||
private
|
private
|
||||||
def descend_continuous(tree, test)
|
def descend(tree, test)
|
||||||
attr = tree.to_a.first
|
attr = tree.to_a.first
|
||||||
return @default if !attr
|
return @default if !attr
|
||||||
return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
if type(attr.first.attribute) == :continuous
|
||||||
return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
||||||
return descend_continuous(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
||||||
return descend_continuous(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
return descend(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
||||||
end
|
return descend(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
||||||
|
else
|
||||||
def descend_discrete(tree, test)
|
return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
|
||||||
attr = tree.to_a.first
|
return descend(attr[1][test[@attributes.index(attr[0].attribute)]],test)
|
||||||
return @default if !attr
|
end
|
||||||
return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
|
|
||||||
return descend_discrete(attr[1][test[@attributes.index(attr[0].attribute)]],test)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def build_tree(tree = @tree)
|
def build_tree(tree = @tree)
|
||||||
|
|||||||
64
spec/id3_spec.rb
Normal file
64
spec/id3_spec.rb
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
require 'spec_helper'
|
||||||
|
|
||||||
|
describe describe DecisionTree::ID3Tree do
|
||||||
|
|
||||||
|
describe "discrete attributes" do
|
||||||
|
Given(:labels) { ["hungry", "color"] }
|
||||||
|
Given(:data) do
|
||||||
|
[
|
||||||
|
["yes", "red", "angry"],
|
||||||
|
["no", "blue", "not angry"],
|
||||||
|
["yes", "blue", "not angry"],
|
||||||
|
["no", "red", "not angry"]
|
||||||
|
]
|
||||||
|
end
|
||||||
|
Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", :discrete) }
|
||||||
|
When { tree.train }
|
||||||
|
Then { tree.predict(["yes", "red"]).should == "angry" }
|
||||||
|
Then { tree.predict(["no", "red"]).should == "not angry" }
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "discrete attributes" do
|
||||||
|
Given(:labels) { ["hunger", "happiness"] }
|
||||||
|
Given(:data) do
|
||||||
|
[
|
||||||
|
[8, 7, "angry"],
|
||||||
|
[6, 7, "angry"],
|
||||||
|
[7, 9, "angry"],
|
||||||
|
[7, 1, "not angry"],
|
||||||
|
[2, 9, "not angry"],
|
||||||
|
[3, 2, "not angry"],
|
||||||
|
[2, 3, "not angry"],
|
||||||
|
[1, 4, "not angry"]
|
||||||
|
]
|
||||||
|
end
|
||||||
|
Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", :continuous) }
|
||||||
|
When { tree.train }
|
||||||
|
Then { tree.graph("continuous") }
|
||||||
|
Then { tree.predict([7, 7]).should == "angry" }
|
||||||
|
Then { tree.predict([2, 3]).should == "not angry" }
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "a mixture" do
|
||||||
|
Given(:labels) { ["hunger", "color"] }
|
||||||
|
Given(:data) do
|
||||||
|
[
|
||||||
|
[8, "red", "angry"],
|
||||||
|
[6, "red", "angry"],
|
||||||
|
[7, "red", "angry"],
|
||||||
|
[7, "blue", "not angry"],
|
||||||
|
[2, "red", "not angry"],
|
||||||
|
[3, "blue", "not angry"],
|
||||||
|
[2, "blue", "not angry"],
|
||||||
|
[1, "red", "not angry"]
|
||||||
|
]
|
||||||
|
end
|
||||||
|
Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", color: :discrete, hunger: :continuous) }
|
||||||
|
When { tree.train }
|
||||||
|
Then { tree.graph("continuous") }
|
||||||
|
Then { tree.predict([7, "red"]).should == "angry" }
|
||||||
|
Then { tree.predict([2, "blue"]).should == "not angry" }
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
end
|
||||||
3
spec/spec_helper.rb
Normal file
3
spec/spec_helper.rb
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
require 'rspec/given'
|
||||||
|
require 'decisiontree'
|
||||||
|
require 'pry'
|
||||||
Reference in New Issue
Block a user