
# Original code: http://aima.cs.berkeley.edu/python/learning.html
# (slightly rearranged).

class DataSet(object):
    """A data set for a machine learning problem.  It has the following fields:

    d.examples    A list of examples.  Each one is a list of attribute values.
    d.attrs       A list of integers to index into an example, so example[attr]
                  gives a value. Normally the same as
                  range(len(d.examples[0])).
    d.attrnames   Optional list of mnemonic names for corresponding attrs.
    d.target      The attribute that a learning algorithm will try to predict.
                  By default the final attribute.
    d.inputs      The list of attrs without the target.
    d.values      A list of lists, each sublist is the set of possible
                  values for the corresponding attribute. If None (or
                  empty), examples are not validated. If given, an
                  erroneous value raises ValueError.
    d.name        Name of the data set (for output display only).
    d.source      URL or other source where the data came from.

    Normally, you call the constructor and you're done; then you just
    access fields like d.examples and d.target and d.inputs."""

    def update(self, x, **entries):
        """Update x in place from the keyword entries and return it.

        x may be a dict (its keys are updated) or any object that has a
        __dict__ (its attributes are updated).
        """
        if isinstance(x, dict):
            x.update(entries)
        else:
            x.__dict__.update(entries)
        return x

    def __init__(self, examples=None, attrs=None, target=-1, values=None,
                 attrnames=None, name='', source='',
                 inputs=None, exclude=(), doc=''):
        """Accepts any of DataSet's fields.

        examples   A list of examples, or a string that is handed to
                   parse_csv to produce that list. None means "no examples".
        attrs      Attribute indices; derived from the first example (or,
                   with no examples, from attrnames) when not given.
        attrnames  A list of names, or one whitespace-separated string.
        target, inputs, exclude
                   Forwarded to setproblem.
        doc        Accepted for backward compatibility; not used.

        Raises ValueError if values is given and an example contains a
        value that is not listed for its attribute.
        """
        self.name = name
        self.source = source
        self.values = values

        # Initialize .examples from a string or a list.
        if examples is None:
            examples = []
        if isinstance(examples, str):
            # parse_csv is not defined in this class; it is expected to be
            # supplied by the enclosing module.
            examples = parse_csv(examples)
        self.examples = examples

        # Split the names first so that the fallback below counts names,
        # not characters, when attrnames is a string.
        if isinstance(attrnames, str):
            attrnames = attrnames.split()

        # Attrs are the indices of examples, unless otherwise stated.
        if not attrs:
            if len(self.examples) > 0:
                attrs = list(range(len(self.examples[0])))
            else:
                attrs = list(range(len(attrnames or ())))
        self.attrs = attrs
        self.attrnames = attrnames or list(attrs)

        # Validate only now: check_example reads self.attrs and
        # self.attrnames.  An explicit loop (rather than map) guarantees the
        # checks actually run on Python 3, where map is lazy.
        for example in self.examples:
            self.check_example(example)

        self.setproblem(target, inputs=inputs, exclude=exclude)

    def setproblem(self, target, inputs=None, exclude=()):
        """Set (or change) the target and/or inputs.

        This way, one DataSet can be used multiple ways. inputs, if
        specified, is a list of attributes to use as inputs (the target is
        removed from it); otherwise every attr except the target and those
        listed in exclude is used. Attributes can be -n .. n, or an attrname.

        NOTE: self.values is not inferred from the examples here; when it is
        None, check_example performs no validation.
        """
        self.target = self.attrnum(target)
        # Materialise as a list: it is searched once per attribute below.
        excluded = [self.attrnum(a) for a in exclude]
        if inputs:
            resolved = [self.attrnum(a) for a in inputs]
            self.inputs = [a for a in resolved if a != self.target]
        else:
            self.inputs = [a for a in self.attrs
                           if a != self.target and a not in excluded]

    def add_example(self, example):
        """Add an example to the list of examples, checking it first.

        Raises ValueError (from check_example) without appending if the
        example holds a value that self.values does not allow.
        """
        self.check_example(example)
        self.examples.append(example)

    def check_example(self, example):
        """Raise ValueError if example has any invalid values.

        Does nothing when self.values is None or empty.
        """
        if not self.values:
            return
        for a in self.attrs:
            if example[a] not in self.values[a]:
                raise ValueError('Bad value %s for attribute %s in %s' %
                                 (example[a], self.attrnames[a], example))

    def attrnum(self, attr):
        """Return the number used for attr, which can be a name, or -n .. n.

        Raises ValueError (from list.index) if attr is a name that is not in
        self.attrnames.
        """
        # Test for a name first: comparing a str with 0 is a TypeError on
        # Python 3.
        if isinstance(attr, str):
            return self.attrnames.index(attr)
        if attr < 0:
            return len(self.attrs) + attr
        return attr

    def sanitize(self, example):
        """Return a copy of example, with non-input attributes replaced by 0."""
        return [value if i in self.inputs else 0
                for i, value in enumerate(example)]

    def __repr__(self):
        return '<DataSet(%s): %d examples, %d attributes>' % (
            self.name, len(self.examples), len(self.attrs))

class NearestNeighborLearner(object):
    """k-nearest-neighbor classifier.

    train() stores a dataset object exposing .examples (a list of examples)
    and .target (the index of the attribute to predict); predict() returns
    the target value of the closest stored example (k == 1) or the most
    common target value among the k closest (k > 1).
    """

    def __init__(self, k=1):
        "k-NearestNeighbor: the k nearest neighbors vote."
        self.k = k

    def mean_boolean_error(self, predictions, targets):
        """Return the fraction of positions at which the two sequences differ.

        Positions are paired with zip, so the longer sequence is truncated
        to the length of the shorter one.
        """
        return self.mean([p != t for p, t in zip(predictions, targets)])

    def predict(self, example):
        """With k=1, find the point closest to example.
        With k>1, find k closest, and have them vote for the best.

        Distance ties are resolved in favour of the example that appears
        first in the dataset.  train() must have been called first.
        """
        examples = self.dataset.examples
        target = self.dataset.target

        def dist(e):
            return self.distance(e, example)

        if self.k == 1:
            return self.argmin(examples, dist)[target]

        # nsmallest keeps the k genuinely closest examples (the previous
        # hand-rolled list only compared against its last entry, which was
        # not necessarily the farthest).  With key=, examples themselves are
        # never compared, only their distances.
        import heapq
        nearest = heapq.nsmallest(self.k, examples, key=dist)
        return self.mode([e[target] for e in nearest])

    def distance(self, e1, e2):
        """Return the fraction of paired positions at which e1 and e2 differ.

        NOTE: every position present in both examples is compared, so if
        both carry a value at the target position it contributes too.
        """
        return self.mean_boolean_error(e1, e2)

    def train(self, dataset):
        """Remember dataset; prediction is done lazily against its examples."""
        self.dataset = dataset

    def histogram(self, values, mode=0, bin_function=None):
        """Return a list of (value, count) pairs, summarizing the input values.
        Sorted by increasing value, or if mode=1, by decreasing count.
        If bin_function is given, map it over values first."""
        if bin_function:
            values = map(bin_function, values)
        bins = {}
        for val in values:
            bins[val] = bins.get(val, 0) + 1
        if mode:
            # sorted() is stable, so equal counts keep the dict's own order.
            return sorted(bins.items(), key=lambda v: v[1], reverse=True)
        return sorted(bins.items())

    def log2(self, x):
        """Return the base 2 logarithm of x (log2(1024) is 10.0)."""
        # Imported here because the module has no top-level import of math.
        import math
        return math.log(x, 2)

    def mode(self, values):
        """Return the most common value in the list of values
        (mode([1, 2, 3, 2]) is 2).

        Raises IndexError if values is empty.
        """
        res = self.histogram(values, mode=1)
        return res[0][0]

    def mean(self, values):
        """Return the arithmetic average of the values.

        Raises ZeroDivisionError if values is empty.
        """
        return sum(values) / float(len(values))

    def argmin(self, seq, fn):
        """Return an element with lowest fn(seq[i]) score;
        tie goes to first one (argmin(['one', 'to', 'three'], len) is 'to').

        Raises IndexError if seq is empty.
        """
        if len(seq) == 0:
            raise IndexError('argmin() of an empty sequence')
        # min() returns the first minimal element and scores each element
        # exactly once.
        return min(seq, key=fn)


