mirror of
https://github.com/dkam/decisiontree.git
synced 2025-12-28 07:04:53 +00:00
Speed improvements for discrete
This commit is contained in:
@@ -3,6 +3,8 @@
|
||||
### Copyright (c) 2007 Ilya Grigorik <ilya AT igvita DOT com>
|
||||
### Modifed at 2007 by José Ignacio Fernández <joseignacio.fernandez AT gmail DOT com>
|
||||
|
||||
require 'set'
|
||||
|
||||
module DecisionTree
|
||||
Node = Struct.new(:attribute, :threshold, :gain)
|
||||
|
||||
@@ -28,7 +30,7 @@ module DecisionTree
|
||||
end
|
||||
|
||||
data2 = data2.map do |key, val|
|
||||
key + [val.sort_by { |_k, v| v }.last.first]
|
||||
key + [val.sort_by { |_, v| v }.last.first]
|
||||
end
|
||||
|
||||
@tree = id3_train(data2, attributes, default)
|
||||
@@ -41,9 +43,9 @@ module DecisionTree
|
||||
def fitness_for(attribute)
|
||||
case type(attribute)
|
||||
when :discrete
|
||||
proc { |a, b, c| id3_discrete(a, b, c) }
|
||||
proc { |*args| id3_discrete(*args) }
|
||||
when :continuous
|
||||
proc { |a, b, c| id3_continuous(a, b, c) }
|
||||
proc { |*args| id3_continuous(*args) }
|
||||
end
|
||||
end
|
||||
|
||||
@@ -66,14 +68,13 @@ module DecisionTree
|
||||
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
||||
tree, l = {best => {}}, ['>=', '<']
|
||||
|
||||
fitness = fitness_for(best.attribute)
|
||||
case type(best.attribute)
|
||||
when :continuous
|
||||
partitioned_data = data.partition do |d|
|
||||
d[attributes.index(best.attribute)] >= best.threshold
|
||||
end
|
||||
partitioned_data.each_with_index do |examples, i|
|
||||
tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
|
||||
tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0))
|
||||
end
|
||||
when :discrete
|
||||
values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
|
||||
@@ -83,7 +84,7 @@ module DecisionTree
|
||||
end
|
||||
end
|
||||
partitions.each_with_index do |examples, i|
|
||||
tree[best][values[i]] = id3_train(examples, attributes - [values[i]], (data.classification.mode rescue 0), &fitness)
|
||||
tree[best][values[i]] = id3_train(examples, attributes - [values[i]], (data.classification.mode rescue 0))
|
||||
end
|
||||
end
|
||||
|
||||
@@ -116,11 +117,14 @@ module DecisionTree
|
||||
|
||||
# ID3 for discrete label cases
|
||||
def id3_discrete(data, attributes, attribute)
|
||||
values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
|
||||
partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
|
||||
index = attributes.index(attribute)
|
||||
|
||||
values = Set.new
|
||||
data.each { |d| values << d[index] }
|
||||
partitions = values.to_a.sort.collect { |val| data.select { |d| d[index] == val } }
|
||||
remainder = partitions.collect { |p| (p.size.to_f / data.size) * p.classification.entropy }.inject(0) { |a, e| e += a }
|
||||
|
||||
[data.classification.entropy - remainder, attributes.index(attribute)]
|
||||
[data.classification.entropy - remainder, index]
|
||||
end
|
||||
|
||||
def predict(test)
|
||||
|
||||
Reference in New Issue
Block a user