root / neighbors / nearest.py

Revision 128:5e724235253b, 7.6 kB (checked in by Tarek Ziad?? <tarek@…>, 13 months ago)

added db level

Line 
1
2# original code, http://aima.cs.berkeley.edu/python/learning.html
3# a bit arranged
4
5class DataSet(object):
6    """A data set for a machine learning problem.  It has the following fields:
7
8    d.examples    A list of examples.  Each one is a list of attribute values.
9    d.attrs       A list of integers to index into an example, so example[attr]
10                  gives a value. Normally the same as range(len(d.examples)).
11    d.attrnames   Optional list of mnemonic names for corresponding attrs.
12    d.target      The attribute that a learning algorithm will try to predict.
13                  By default the final attribute.
14    d.inputs      The list of attrs without the target.
15    d.values      A list of lists, each sublist is the set of possible
16                  values for the corresponding attribute. If None, it
17                  is computed from the known examples by self.setproblem.
18                  If not None, an erroneous value raises ValueError.
19    d.name        Name of the data set (for output display only).
20    d.source      URL or other source where the data came from.
21
22    Normally, you call the constructor and you're done; then you just
23    access fields like d.examples and d.target and d.inputs."""
24
25    def update(self, x, **entries):
26        """Update a dict; or an object with slots; according to entries.
27        >>> update({'a': 1}, a=10, b=20)
28        {'a': 10, 'b': 20}
29        >>> update(Struct(a=1), a=10, b=20)
30        Struct(a=10, b=20)
31        """
32        if isinstance(x, dict):
33            x.update(entries)
34        else:
35            x.__dict__.update(entries)
36        return x
37
38    def __init__(self, examples=None, attrs=None, target=-1, values=None,
39                 attrnames=None, name='', source='',
40                 inputs=None, exclude=(), doc=''):
41        """Accepts any of DataSet's fields.  Examples can
42        also be a string or file from which to parse examples using parse_csv.
43        Ex: DataSet(examples='1 2 3')"""
44        self.update(self, name=name, source=source, values=values)
45        # Initialize .examples from string or list or data directory
46        if examples is None:
47            examples = []
48
49        if isinstance(examples, str):
50            self.examples = parse_csv(examples)
51
52        #elif examples is None:
53        #    self.examples = parse_csv(DataFile(name+'.csv').read())
54        else:
55            self.examples = examples
56        if len(self.examples) > 0:
57            map(self.check_example, self.examples)
58        # Attrs are the indicies of examples, unless otherwise stated.
59        if not attrs:
60            if len(self.examples) > 0:
61                attrs = range(len(self.examples[0]))
62            else:
63                attrs = range(len(attrnames))
64        self.attrs = attrs
65        # Initialize .attrnames from string, list, or by default
66        if isinstance(attrnames, str):
67            self.attrnames = attrnames.split()
68        else:
69            self.attrnames = attrnames or attrs
70        self.setproblem(target, inputs=inputs, exclude=exclude)
71
72    def setproblem(self, target, inputs=None, exclude=()):
73        """Set (or change) the target and/or inputs.
74        This way, one DataSet can be used multiple ways. inputs, if specified,
75        is a list of attributes, or specify exclude as a list of attributes
76        to not put use in inputs. Attributes can be -n .. n, or an attrname.
77        Also computes the list of possible values, if that wasn't done yet."""
78        self.target = self.attrnum(target)
79        exclude = map(self.attrnum, exclude)
80        if inputs:
81            self.inputs = removall(self.target, inputs)
82        else:
83            self.inputs = [a for a in self.attrs
84                           if a is not self.target and a not in exclude]
85        #if not self.values: # amir
86        #    self.values = map(unique, zip(*self.examples))
87
88    def add_example(self, example):
89        """Add an example to the list of examples, checking it first."""
90        self.check_example(example)
91        self.examples.append(example)
92
93    def check_example(self, example):
94        """Raise ValueError if example has any invalid values."""
95        if self.values:
96            for a in self.attrs:
97                if example[a] not in self.values[a]:
98                    raise ValueError('Bad value %s for attribute %s in %s' %
99                                     (example[a], self.attrnames[a], example))
100
101    def attrnum(self, attr):
102        "Returns the number used for attr, which can be a name, or -n .. n."
103        if attr < 0:
104            return len(self.attrs) + attr
105        elif isinstance(attr, str):
106            return self.attrnames.index(attr)
107        else:
108            return attr
109
110    def sanitize(self, example):
111       "Return a copy of example, with non-input attributes replaced by 0."
112       return [i in self.inputs and example[i] for i in range(len(example))]
113
114    def __repr__(self):
115        return '<DataSet(%s): %d examples, %d attributes>' % (
116            self.name, len(self.examples), len(self.attrs))
117
118class NearestNeighborLearner(object):
119
120    def __init__(self, k=1):
121        "k-NearestNeighbor: the k nearest neighbors vote."
122        self.k = k
123
124    def mean_boolean_error(self, predictions, targets):
125        return self.mean([(p!=t) for p, t in zip(predictions, targets)])
126
127    def predict(self, example):
128        """With k=1, find the point closest to example.
129        With k>1, find k closest, and have them vote for the best."""
130        if self.k == 1:
131            neighbor = self.argmin(self.dataset.examples,
132                              lambda e: self.distance(e, example))
133            return neighbor[self.dataset.target]
134        else:
135            ## Maintain a sorted list of (distance, example) pairs.
136            ## For very large k, a PriorityQueue would be better
137            best = []
138            for e in self.dataset.examples:
139                d = self.distance(e, example)
140                if len(best) < self.k:
141                    best.append((d, e))
142                elif d < best[-1][0]:
143                    best[-1] = (d, e)
144                    best.sort()
145
146            return self.mode([e[self.dataset.target] for (d, e) in best])
147
148    def distance(self, e1, e2):
149        return self.mean_boolean_error(e1, e2)
150
151    def train(self, dataset):
152        self.dataset = dataset
153
154    def histogram(self, values, mode=0, bin_function=None):
155        """Return a list of (value, count) pairs, summarizing the input values.
156        Sorted by increasing value, or if mode=1, by decreasing count.
157        If bin_function is given, map it over values first."""
158        if bin_function: values = map(bin_function, values)
159        bins = {}
160        for val in values:
161            bins[val] = bins.get(val, 0) + 1
162        if mode:
163            return sorted(bins.items(), key=lambda v: v[1], reverse=True)
164        else:
165            return sorted(bins.items())
166
167    def log2(self, x):
168        """Base 2 logarithm.
169        >>> log2(1024)
170        10.0
171        """
172        return math.log10(x) / math.log10(2)
173
174    def mode(self, values):
175        """Return the most common value in the list of values.
176        >>> mode([1, 2, 3, 2])
177        2
178        """
179        res = self.histogram(values, mode=1)
180        return res[0][0]
181
182    def mean(self, values):
183        """Return the arithmetic average of the values."""
184        return sum(values) / float(len(values))
185
186    def argmin(self, seq, fn):
187        """Return an element with lowest fn(seq[i]) score;
188        tie goes to first one.
189
190        >>> argmin(['one', 'to', 'three'], len)
191        'to'
192        """
193        best = seq[0]; best_score = fn(best)
194        for x in seq:
195            x_score = fn(x)
196            if x_score < best_score:
197                best, best_score = x, x_score
198        return best
199
Note: See TracBrowser for help on using the browser.