mirror of
https://github.com/dkam/decisiontree.git
synced 2025-12-28 15:14:52 +00:00
added support for continuous and discrete attributes in the same dataset
This commit is contained in:
@@ -15,9 +15,9 @@ class Object
|
||||
end
|
||||
end
|
||||
|
||||
class Array
|
||||
def classification; collect { |v| v.last }; end
|
||||
|
||||
class Array
|
||||
def classification; collect { |v| v.last }; end
|
||||
|
||||
# calculate information entropy
|
||||
def entropy
|
||||
return 0 if empty?
|
||||
@@ -51,28 +51,34 @@ module DecisionTree
|
||||
|
||||
@tree = id3_train(data2, attributes, default)
|
||||
end
|
||||
|
||||
def id3_train(data, attributes, default, used={})
|
||||
# Choose a fitness algorithm
|
||||
case @type
|
||||
when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
|
||||
|
||||
def type(attribute)
|
||||
@type.is_a?(Hash) ? @type[attribute.to_sym] : @type
|
||||
end
|
||||
|
||||
def fitness_for(attribute)
|
||||
case type(attribute)
|
||||
when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
|
||||
when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)}
|
||||
end
|
||||
|
||||
return default if data.empty?
|
||||
end
|
||||
|
||||
def id3_train(data, attributes, default, used={})
|
||||
return default if data.empty?
|
||||
|
||||
# return classification if all examples have the same classification
|
||||
return data.first.last if data.classification.uniq.size == 1
|
||||
|
||||
# Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
|
||||
performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) }
|
||||
performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
|
||||
max = performance.max { |a,b| a[0] <=> b[0] }
|
||||
best = Node.new(attributes[performance.index(max)], max[1], max[0])
|
||||
best.threshold = nil if @type == :discrete
|
||||
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
||||
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
||||
tree, l = {best => {}}, ['>=', '<']
|
||||
|
||||
case @type
|
||||
|
||||
fitness = fitness_for(best.attribute)
|
||||
case type(best.attribute)
|
||||
when :continuous
|
||||
data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i|
|
||||
tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
|
||||
@@ -82,7 +88,7 @@ module DecisionTree
|
||||
partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } }
|
||||
partitions.each_with_index { |examples, i|
|
||||
tree[best][values[i]] = id3_train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness)
|
||||
}
|
||||
}
|
||||
end
|
||||
|
||||
tree
|
||||
@@ -96,32 +102,32 @@ module DecisionTree
|
||||
thresholds.pop
|
||||
#thresholds -= used[attribute] if used.has_key? attribute
|
||||
|
||||
gain = thresholds.collect { |threshold|
|
||||
gain = thresholds.collect { |threshold|
|
||||
sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
|
||||
pos = (sp[0].size).to_f / data.size
|
||||
neg = (sp[1].size).to_f / data.size
|
||||
|
||||
|
||||
[data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold]
|
||||
}.max { |a,b| a[0] <=> b[0] }
|
||||
|
||||
return [-1, -1] if gain.size == 0
|
||||
gain
|
||||
end
|
||||
|
||||
|
||||
# ID3 for discrete label cases
|
||||
def id3_discrete(data, attributes, attribute)
|
||||
values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
|
||||
partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
|
||||
remainder = partitions.collect {|p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) {|i,s| s+=i }
|
||||
|
||||
|
||||
[data.classification.entropy - remainder, attributes.index(attribute)]
|
||||
end
|
||||
|
||||
def predict(test)
|
||||
return (@type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test))
|
||||
descend(@tree, test)
|
||||
end
|
||||
|
||||
def graph(filename)
|
||||
def graph(filename)
|
||||
dgp = DotGraphPrinter.new(build_tree)
|
||||
dgp.write_to_file("#{filename}.png", "png")
|
||||
end
|
||||
@@ -151,22 +157,20 @@ module DecisionTree
|
||||
end
|
||||
|
||||
private
|
||||
def descend_continuous(tree, test)
|
||||
def descend(tree, test)
|
||||
attr = tree.to_a.first
|
||||
return @default if !attr
|
||||
return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
||||
return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
||||
return descend_continuous(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
||||
return descend_continuous(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
||||
end
|
||||
|
||||
def descend_discrete(tree, test)
|
||||
attr = tree.to_a.first
|
||||
return @default if !attr
|
||||
return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
|
||||
return descend_discrete(attr[1][test[@attributes.index(attr[0].attribute)]],test)
|
||||
if type(attr.first.attribute) == :continuous
|
||||
return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
||||
return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
||||
return descend(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
||||
return descend(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
||||
else
|
||||
return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
|
||||
return descend(attr[1][test[@attributes.index(attr[0].attribute)]],test)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
def build_tree(tree = @tree)
|
||||
return [] unless tree.is_a?(Hash)
|
||||
return [["Always", @default]] if tree.empty?
|
||||
@@ -282,7 +286,7 @@ module DecisionTree
|
||||
|
||||
def predict(test)
|
||||
@rules.each do |r|
|
||||
prediction = r.predict(test)
|
||||
prediction = r.predict(test)
|
||||
return prediction, r.accuracy unless prediction.nil?
|
||||
end
|
||||
return @default, 0.0
|
||||
|
||||
Reference in New Issue
Block a user