Mirror of https://github.com/dkam/decisiontree.git, synced 2025-12-28 15:14:52 +00:00
Tidy code style
@@ -17,4 +17,3 @@ module ArrayClassification
     end
   end
 end
-
@@ -1,6 +1,6 @@
 class Object
   def save_to_file(filename)
-    File.open(filename, 'w+') { |f| f << Marshal.dump(self) }
+    File.open(filename, "w+") { |f| f << Marshal.dump(self) }
   end
 
   def self.load_from_file(filename)
@@ -1,3 +1,3 @@
-require 'core_extensions/object'
-require 'core_extensions/array'
-require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb'
+require "core_extensions/object"
+require "core_extensions/array"
+require File.dirname(__FILE__) + "/decisiontree/id3_tree.rb"
@@ -23,14 +23,13 @@ module DecisionTree
       initialize(attributes, data, default, @type)
 
       # Remove samples with same attributes leaving most common classification
-      data2 = data.inject({}) do |hash, d|
+      data2 = data.each_with_object({}) do |d, hash|
         hash[d.slice(0..-2)] ||= Hash.new(0)
         hash[d.slice(0..-2)][d.last] += 1
-        hash
       end
 
       data2 = data2.map do |key, val|
-        key + [val.sort_by { |_, v| v }.last.first]
+        key + [val.max_by { |_, v| v }.first]
       end
 
       @tree = id3_train(data2, attributes, default)
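The inject → each_with_object change above drops the trailing hash line because each_with_object hands the same accumulator object to every iteration (note the swapped block arguments, |d, hash| instead of |hash, d|), whereas inject feeds each iteration the previous block's return value. A minimal sketch of the difference, using made-up rows rather than the gem's data:

rows = [["sunny", "hot", "no"], ["sunny", "hot", "yes"], ["rainy", "mild", "yes"]]

# inject: the block's return value becomes the accumulator for the next pass,
# so the hash has to be the last expression in the block.
counts = rows.inject({}) do |hash, row|
  hash[row[0..-2]] ||= Hash.new(0)
  hash[row[0..-2]][row.last] += 1
  hash # easy to forget; without it the next pass receives the += result instead
end

# each_with_object: the same hash is passed to every pass, so no trailing line.
counts = rows.each_with_object({}) do |row, hash|
  hash[row[0..-2]] ||= Hash.new(0)
  hash[row[0..-2]][row.last] += 1
end

p counts # => {["sunny", "hot"]=>{"no"=>1, "yes"=>1}, ["rainy", "mild"]=>{"yes"=>1}}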
@@ -49,7 +48,7 @@ module DecisionTree
       end
     end
 
-    def id3_train(data, attributes, default, _used={})
+    def id3_train(data, attributes, default, _used = {})
       return default if data.empty?
 
       # return classification if all examples have the same classification
@@ -60,13 +59,13 @@ module DecisionTree
       # 2. Pick best attribute
       # 3. If attributes all score the same, then pick a random one to avoid infinite recursion.
       performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
-      max = performance.max { |a,b| a[0] <=> b[0] }
-      min = performance.min { |a,b| a[0] <=> b[0] }
+      max = performance.max_by { |a| a[0] }
+      min = performance.min_by { |a| a[0] }
       max = performance.sample if max[0] == min[0]
       best = Node.new(attributes[performance.index(max)], max[1], max[0])
       best.threshold = nil if @type == :discrete
       @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
-      tree, l = {best => {}}, ['>=', '<']
+      tree, l = {best => {}}, [">=", "<"]
 
       case type(best.attribute)
       when :continuous
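A note on the max/min → max_by/min_by swap above: the _by forms take a key-extraction block instead of a two-argument comparator, so ranking entries by their first element reads directly and the <=> plumbing disappears. A small sketch with sample scores (not the gem's data):

performance = [[0.12, 5.0], [0.40, 7.5], [0.33, 3.0]]

# Comparator form: the block receives two candidates and must return -1, 0 or 1.
best_old = performance.max { |a, b| a[0] <=> b[0] }

# Key form: the block returns the value to rank by; Ruby does the comparing.
best_new = performance.max_by { |a| a[0] }

p best_old # => [0.4, 7.5]
p best_new # => [0.4, 7.5]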
@@ -74,7 +73,11 @@ module DecisionTree
           d[attributes.index(best.attribute)] >= best.threshold
         end
         partitioned_data.each_with_index do |examples, i|
-          tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0))
+          tree[best][String.new(l[i])] = id3_train(examples, attributes, begin
+            data.classification.mode
+          rescue
+            0
+          end)
         end
       when :discrete
         values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
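The inline (data.classification.mode rescue 0) and the expanded begin ... rescue ... end form above are behaviourally equivalent: both rescue StandardError and fall back to 0, the explicit block just makes the rescued region visible. A stand-alone sketch of the same pattern (safe_ratio is a hypothetical helper, not part of the gem):

# Both forms fall back to a default when the expression raises a StandardError.
def safe_ratio(a, b)
  Rational(a, b) rescue 0 # modifier form: terse, but the rescued region is implicit
end

def safe_ratio_explicit(a, b)
  begin
    Rational(a, b) # explicit form: the rescued region is spelled out
  rescue # a bare rescue catches StandardError and its subclasses
    0
  end
end

p safe_ratio(1, 0)          # => 0 (ZeroDivisionError is rescued)
p safe_ratio_explicit(1, 0) # => 0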
@@ -84,7 +87,11 @@ module DecisionTree
           end
         end
         partitions.each_with_index do |examples, i|
-          tree[best][values[i]] = id3_train(examples, attributes - [values[i]], (data.classification.mode rescue 0))
+          tree[best][values[i]] = id3_train(examples, attributes - [values[i]], begin
+            data.classification.mode
+          rescue
+            0
+          end)
         end
       end
 
@@ -100,16 +107,16 @@ module DecisionTree
         thresholds.push((values[i] + (values[i + 1].nil? ? values[i] : values[i + 1])).to_f / 2)
       end
       thresholds.pop
-      #thresholds -= used[attribute] if used.has_key? attribute
+      # thresholds -= used[attribute] if used.has_key? attribute
 
       gain = thresholds.collect do |threshold|
         sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
-        pos = (sp[0].size).to_f / data.size
-        neg = (sp[1].size).to_f / data.size
+        pos = sp[0].size.to_f / data.size
+        neg = sp[1].size.to_f / data.size
 
         [data.classification.entropy - pos * sp[0].classification.entropy - neg * sp[1].classification.entropy, threshold]
       end
-      gain = gain.max { |a, b| a[0] <=> b[0] }
+      gain = gain.max_by { |a| a[0] }
 
       return [-1, -1] if gain.size == 0
       gain
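For context, the gain block in this hunk computes the information gain of each candidate threshold: the entropy of the full classification minus the size-weighted entropies of the two partitions, keeping the threshold with the largest gain. A self-contained worked sketch (the entropy helper and the sample rows are stand-ins; in the gem, classification and entropy come from the Array core extension):

# Shannon entropy of a list of labels, in bits.
def entropy(labels)
  labels.tally.values.sum do |count|
    p = count.to_f / labels.size
    -p * Math.log2(p)
  end
end

# Rows of [feature_value, label]; candidate split at feature >= 5.
data = [[2, "no"], [3, "no"], [6, "yes"], [8, "yes"], [9, "no"]]
threshold = 5

above, below = data.partition { |value, _| value >= threshold }
pos = above.size.to_f / data.size
neg = below.size.to_f / data.size

gain = entropy(data.map(&:last)) -
  pos * entropy(above.map(&:last)) -
  neg * entropy(below.map(&:last))

p gain.round(3) # => 0.42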
@@ -135,16 +142,16 @@ module DecisionTree
       descend(@tree, test)
     end
 
-    def graph(filename, file_type = 'png')
-      require 'graphr'
+    def graph(filename, file_type = "png")
+      require "graphr"
       dgp = DotGraphPrinter.new(build_tree)
-      dgp.size = ''
+      dgp.size = ""
       dgp.node_labeler = proc { |n| n.split("\n").first }
       dgp.write_to_file("#{filename}.#{file_type}", file_type)
     rescue LoadError
-      STDERR.puts "Error: Cannot generate graph."
-      STDERR.puts " The 'graphr' gem doesn't seem to be installed."
-      STDERR.puts " Run 'gem install graphr' or add it to your Gemfile."
+      warn "Error: Cannot generate graph."
+      warn " The 'graphr' gem doesn't seem to be installed."
+      warn " Run 'gem install graphr' or add it to your Gemfile."
     end
 
     def ruleset
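Kernel#warn, which replaces STDERR.puts above, still ends up on standard error, but it goes through the reassignable $stderr global (so tests can capture or redirect it) and is suppressed when Ruby runs with warnings disabled (-W0). A quick sketch of that difference, separate from the gem's own code:

require "stringio"

captured = StringIO.new
$stderr = captured

warn "Error: Cannot generate graph." # routed through the reassigned $stderr
STDERR.puts "This still goes to the real standard error stream."

$stderr = STDERR # restore the default
puts captured.string # => "Error: Cannot generate graph.\n"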
@@ -177,19 +184,19 @@ module DecisionTree
       attr = tree.to_a.first
       return @default unless attr
       if type(attr.first.attribute) == :continuous
-        return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
-        return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] < attr.first.threshold
-        return descend(attr[1]['>='], test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
-        return descend(attr[1]['<'], test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
+        return attr[1][">="] if !attr[1][">="].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
+        return attr[1]["<"] if !attr[1]["<"].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] < attr.first.threshold
+        return descend(attr[1][">="], test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
+        return descend(attr[1]["<"], test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
       else
-        return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
-        return descend(attr[1][test[@attributes.index(attr[0].attribute)]], test)
+        return attr[1][test[@attributes.index(attr[0].attribute)]] unless attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
+        descend(attr[1][test[@attributes.index(attr[0].attribute)]], test)
       end
     end
 
     def build_tree(tree = @tree)
       return [] unless tree.is_a?(Hash)
-      return [['Always', @default]] if tree.empty?
+      return [["Always", @default]] if tree.empty?
 
       attr = tree.to_a.first
 
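In the else branch above, if !... becomes unless ... and the final return disappears, relying on the fact that a Ruby method returns the value of its last expression. A tiny sketch of the same pattern, using a hypothetical fetch_leaf helper that is not part of the gem:

def fetch_leaf(tree, keys)
  value = tree[keys.first]
  return value unless value.is_a?(Hash) # unless reads better than if !...
  fetch_leaf(value, keys.drop(1)) # last expression, so no explicit return needed
end

p fetch_leaf({ "outlook" => { "sunny" => "no" } }, ["outlook", "sunny"]) # => "no"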
@@ -203,10 +210,10 @@ module DecisionTree
         child_text = "#{child}\n(#{child.to_s.clone.object_id})"
       end
 
-      if type(attr[0].attribute) == :continuous
-        label_text = "#{key} #{attr[0].threshold}"
+      label_text = if type(attr[0].attribute) == :continuous
+        "#{key} #{attr[0].threshold}"
       else
-        label_text = key
+        key
       end
 
       [parent_text, child_text, label_text]
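The label_text rewrite above uses the fact that if is an expression in Ruby: the conditional evaluates to the value of whichever branch runs, so the variable is assigned once instead of in each branch. A minimal sketch with made-up values:

key = ">="
threshold = 70.0
continuous = true

# Assign the result of the conditional rather than assigning inside each branch.
label_text = if continuous
  "#{key} #{threshold}"
else
  key
end

p label_text # => ">= 70.0"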
@@ -229,12 +236,12 @@ module DecisionTree
     end
 
     def to_s
-      str = ''
+      str = ""
       @premises.each do |p|
-        if p.first.threshold
-          str += "#{p.first.attribute} #{p.last} #{p.first.threshold}"
+        str += if p.first.threshold
+          "#{p.first.attribute} #{p.last} #{p.first.threshold}"
         else
-          str += "#{p.first.attribute} = #{p.last}"
+          "#{p.first.attribute} = #{p.last}"
         end
         str += "\n"
       end
@@ -245,15 +252,13 @@ module DecisionTree
       verifies = true
       @premises.each do |p|
         if p.first.threshold # Continuous
-          if !(p.last == '>=' && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == '<' && test[@attributes.index(p.first.attribute)] < p.first.threshold)
-            verifies = false
-            break
-          end
-        else # Discrete
-          if test[@attributes.index(p.first.attribute)] != p.last
+          if !(p.last == ">=" && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == "<" && test[@attributes.index(p.first.attribute)] < p.first.threshold)
             verifies = false
             break
           end
+        elsif test[@attributes.index(p.first.attribute)] != p.last # Discrete
+          verifies = false
+          break
         end
       end
       return @conclusion if verifies
@@ -312,7 +317,7 @@ module DecisionTree
     end
 
     def to_s
-      str = ''
+      str = ""
       @rules.each { |rule| str += "#{rule}\n\n" }
       str
     end
@@ -355,9 +360,8 @@ module DecisionTree
         predictions[p] += accuracy unless p.nil?
       end
       return @default, 0.0 if predictions.empty?
-      winner = predictions.sort_by { |_k, v| -v }.first
+      winner = predictions.min_by { |_k, v| -v }
       [winner[0], winner[1].to_f / @classifiers.size.to_f]
     end
   end
 end
-
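The last hunk swaps sort_by { -v }.first for min_by { -v }: taking the minimum of the negated value is the same as taking the maximum of the value, and min_by does it in a single pass instead of sorting the whole hash. max_by { |_k, v| v } would be an equivalent, arguably clearer spelling. A quick sketch with sample prediction weights (made-up values):

predictions = { "yes" => 2.4, "no" => 3.1, "maybe" => 0.5 }

p predictions.sort_by { |_k, v| -v }.first # => ["no", 3.1] (sorts everything first)
p predictions.min_by { |_k, v| -v }        # => ["no", 3.1] (single pass)
p predictions.max_by { |_k, v| v }         # => ["no", 3.1] (equivalent, reads directly)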