diff --git a/lib/decisiontree/id3_tree.rb b/lib/decisiontree/id3_tree.rb index 8f9a7e8..04c9191 100755 --- a/lib/decisiontree/id3_tree.rb +++ b/lib/decisiontree/id3_tree.rb @@ -42,6 +42,7 @@ module DecisionTree end def train(data=@data, attributes=@attributes, default=@default) + attributes = attributes.map {|e| e.to_s} initialize(attributes, data, default, @type) # Remove samples with same attributes leaving most common classification diff --git a/spec/id3_spec.rb b/spec/id3_spec.rb index 5eb67ba..b86b802 100644 --- a/spec/id3_spec.rb +++ b/spec/id3_spec.rb @@ -89,4 +89,20 @@ describe describe DecisionTree::ID3Tree do Then { tree.predict(["a1","b0","c0"]).should == "RED" } end + describe "numerical labels case" do + Given(:labels) { [1, 2] } + Given(:data) do + [ + [1, 1, true], + [1, 2, false], + [2, 1, false], + [2, 2, true] + ] + end + Given(:tree) { DecisionTree::ID3Tree.new labels, data, nil, :discrete } + When { tree.train } + Then { + lambda { tree.predict([1, 1]) }.should_not raise_error + } + end end