| 1 | |
|---|
| 2 | # original code, http://aima.cs.berkeley.edu/python/learning.html |
|---|
| 3 | # a bit arranged |
|---|
| 4 | |
|---|
class DataSet(object):
    """A data set for a machine learning problem.

    It has the following fields:

    d.examples   A list of examples. Each one is a list of attribute values.
    d.attrs      A list of integers to index into an example, so
                 example[attr] gives a value. Normally the same as
                 range(len(d.examples[0])).
    d.attrnames  Optional list of mnemonic names for corresponding attrs.
    d.target     The attribute that a learning algorithm will try to
                 predict. By default the final attribute.
    d.inputs     The list of attrs without the target.
    d.values     A list of lists, each sublist is the set of possible
                 values for the corresponding attribute. If not None,
                 an erroneous value raises ValueError. If None, examples
                 are not validated (the values are NOT inferred from the
                 examples).
    d.name       Name of the data set (for output display only).
    d.source     URL or other source where the data came from.

    Normally, you call the constructor and you're done; then you just
    access fields like d.examples and d.target and d.inputs.
    """

    def update(self, x, **entries):
        """Update a dict, or an object with a __dict__, according to entries.

        x is modified in place and also returned.
        """
        if isinstance(x, dict):
            x.update(entries)
        else:
            x.__dict__.update(entries)
        return x

    def __init__(self, examples=None, attrs=None, target=-1, values=None,
                 attrnames=None, name='', source='',
                 inputs=None, exclude=(), doc=''):
        """Accepts any of DataSet's fields.

        examples can also be a string from which to parse examples using
        parse_csv. Ex: DataSet(examples='1 2 3')

        attrnames may be a list of names or one whitespace-separated
        string. The doc argument is accepted but not used.

        Raises ValueError if values is given and an example holds a value
        that is not listed for its attribute.
        """
        self.update(self, name=name, source=source, values=values)
        # Initialize .examples from a string or a list.
        if examples is None:
            examples = []
        if isinstance(examples, str):
            # parse_csv is not defined in this class; it has to be
            # provided at module level for string input to work.
            self.examples = parse_csv(examples)
        else:
            self.examples = examples
        # Split a string of names first, so that the number of names (not
        # the number of characters) can serve as the attribute count.
        if isinstance(attrnames, str):
            attrnames = attrnames.split()
        # Attrs are the indices of examples, unless otherwise stated.
        if not attrs:
            if len(self.examples) > 0:
                attrs = list(range(len(self.examples[0])))
            else:
                # No examples yet: fall back on the names, if any.
                attrs = list(range(len(attrnames or ())))
        self.attrs = attrs
        self.attrnames = attrnames or attrs
        # Validate only now: check_example reads self.attrs and
        # self.attrnames. An explicit loop is used because map() is lazy
        # on Python 3 and would silently skip the checks.
        for example in self.examples:
            self.check_example(example)
        self.setproblem(target, inputs=inputs, exclude=exclude)

    def setproblem(self, target, inputs=None, exclude=()):
        """Set (or change) the target and/or inputs.

        This way, one DataSet can be used multiple ways. inputs, if
        specified, is a list of attributes; otherwise specify exclude as a
        list of attributes to leave out of the inputs. Attributes can be
        -n .. n, or an attrname. The target is never part of the inputs.

        NOTE: the list of possible values (self.values) is not computed
        here; automatic inference from the examples is disabled.
        """
        self.target = self.attrnum(target)
        # Materialize as a list: a map() iterator would be consumed by the
        # first membership test on Python 3.
        excluded = [self.attrnum(a) for a in exclude]
        if inputs:
            # Resolve names / negative indices so that inputs always holds
            # plain attribute numbers, then drop the target.
            resolved = [self.attrnum(a) for a in inputs]
            self.inputs = [a for a in resolved if a != self.target]
        else:
            self.inputs = [a for a in self.attrs
                           if a != self.target and a not in excluded]

    def add_example(self, example):
        """Add an example to the list of examples, checking it first."""
        self.check_example(example)
        self.examples.append(example)

    def check_example(self, example):
        """Raise ValueError if example has any invalid values.

        Nothing is checked when self.values is empty or None.
        """
        if self.values:
            for a in self.attrs:
                if example[a] not in self.values[a]:
                    raise ValueError('Bad value %s for attribute %s in %s' %
                                     (example[a], self.attrnames[a], example))

    def attrnum(self, attr):
        """Return the number used for attr, which can be a name, or -n .. n.

        Raises ValueError if attr is a name that is not in self.attrnames.
        """
        # Test for a name first: comparing a str with 0 raises TypeError
        # on Python 3.
        if isinstance(attr, str):
            return self.attrnames.index(attr)
        if attr < 0:
            return len(self.attrs) + attr
        return attr

    def sanitize(self, example):
        """Return a copy of example, with non-input attributes replaced by 0."""
        return [value if i in self.inputs else 0
                for i, value in enumerate(example)]

    def __repr__(self):
        return '<DataSet(%s): %d examples, %d attributes>' % (
            self.name, len(self.examples), len(self.attrs))
|---|
| 117 | |
|---|
class NearestNeighborLearner(object):
    """k-nearest-neighbor classifier over a data set of discrete examples.

    train() just stores the data set; predict() finds the k stored
    examples closest to the query and lets them vote on the target value.
    """

    def __init__(self, k=1):
        """k-NearestNeighbor: the k nearest neighbors vote."""
        self.k = k

    def mean_boolean_error(self, predictions, targets):
        """Return the fraction of positions where the two sequences differ.

        Pairs are formed with zip(), so a longer sequence is truncated to
        the length of the shorter one. Raises ZeroDivisionError if there
        is nothing to compare.
        """
        return self.mean([p != t for p, t in zip(predictions, targets)])

    def predict(self, example):
        """With k=1, find the point closest to example.
        With k>1, find k closest, and have them vote for the best.

        Returns the value of the target attribute of the winning
        neighbor(s). Raises IndexError if the data set has no examples.
        train() must have been called first.
        """
        import heapq  # function-scope: the module has no import section

        examples = self.dataset.examples
        target = self.dataset.target

        def score(e):
            return self.distance(e, example)

        if self.k == 1:
            return self.argmin(examples, score)[target]
        # nsmallest keeps the k globally closest examples. It orders by
        # distance only, so examples themselves are never compared, and
        # ties keep their original order.
        nearest = heapq.nsmallest(self.k, examples, key=score)
        return self.mode([e[target] for e in nearest])

    def distance(self, e1, e2):
        """Return the fraction of attribute positions where e1 and e2 differ.

        Every position is compared, including the target attribute.
        """
        return self.mean_boolean_error(e1, e2)

    def train(self, dataset):
        """Store dataset; predict() reads its examples and target."""
        self.dataset = dataset

    def histogram(self, values, mode=0, bin_function=None):
        """Return a list of (value, count) pairs, summarizing the input values.

        Sorted by increasing value, or if mode=1, by decreasing count.
        If bin_function is given, map it over values first.
        """
        if bin_function:
            values = map(bin_function, values)
        bins = {}
        for val in values:
            bins[val] = bins.get(val, 0) + 1
        if mode:
            return sorted(bins.items(), key=lambda v: v[1], reverse=True)
        return sorted(bins.items())

    def log2(self, x):
        """Return the base 2 logarithm of x, e.g. about 10.0 for 1024."""
        import math  # function-scope: the module has no import section

        return math.log10(x) / math.log10(2)

    def mode(self, values):
        """Return the most common value in the list of values.

        For example, the mode of [1, 2, 3, 2] is 2.
        Raises IndexError if values is empty.
        """
        res = self.histogram(values, mode=1)
        return res[0][0]

    def mean(self, values):
        """Return the arithmetic average of the values.

        Raises ZeroDivisionError if values is empty.
        """
        return sum(values) / float(len(values))

    def argmin(self, seq, fn):
        """Return an element with lowest fn(seq[i]) score;
        tie goes to first one.

        For example, argmin(['one', 'to', 'three'], len) is 'to'.
        Raises IndexError if seq is empty.
        """
        if len(seq) == 0:
            raise IndexError('argmin() of an empty sequence')
        # min() returns the first minimal element and scores each element
        # exactly once.
        return min(seq, key=fn)
|---|
| 199 | |
|---|