# -*- coding: utf8 -*-
# Copyright (c) 2006 Nuxeo SAS <http://nuxeo.com>
# Authors : Tarek Ziadé <tziade@nuxeo.com>
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as published
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
# 02111-1307, USA.
#
# $Id: classifier.py 47012 2006-07-10 08:57:07Z tziade $
"""Robinson-fisher method was taken from PopF
"""
import logging
import math
import operator

# Module-level logger, registered under the 'BayesCore.classifier' name.
logger = logging.getLogger('BayesCore.classifier')

class BayesClassifier(object):
    """Bayesian classifier that stores tokens per category in a backend.

    Text is turned into tokens by ``tokenizer.transform``; tokens are stored
    in (and counted by) ``backend``.  ``guess`` scores a text against every
    category known to the backend, combining per-token probabilities with
    the Robinson-Fisher method.
    """

    # Tokens whose probability is closer than this to 0.5 are not kept.
    PROBABILITY_THRESHOLD = 0.1
    # Bounds applied to every stored token probability.
    GOOD_PROB = 0.0001
    BAD_PROB = 0.9999
    # Maximum number of token probabilities combined for one category.
    MAX_TOKENS = 2048

    def __init__(self, language, backend, tokenizer, **options):
        """Create a classifier.

        language  -- language passed to the backend and to the tokenizer
                     (as ``options['lang']``)
        backend   -- storage object; the methods used here are add_category,
                     add_word, del_word, list_words, list_categories and
                     word_count
        tokenizer -- object providing ``transform(data, options)``
        options   -- extra keyword options forwarded to the tokenizer
        """
        self.language = language
        self.backend = backend
        self.tokenizer = tokenizer
        # True when the stored words may have changed since the
        # probabilities were last built.
        self._learnt = True
        self._probs = None

        # ``**options`` is always a dict (possibly empty), never None.
        self.options = options
        self.options['lang'] = self.language

    def learn(self, data, category):
        """Learn data for a given category.

        The category is registered in the backend, then every token of
        ``data`` is added to it.
        """
        self._learnt = True
        self.backend.add_category(name=category)
        for element in self.tokenizer.transform(data, self.options):
            self.backend.add_word(element, self.language, category)

    def unlearn(self, data, category):
        """Unlearn data for a given category.

        Every token of ``data`` is removed from the category in the backend.
        """
        self._learnt = True
        for element in self.tokenizer.transform(data, self.options):
            self.backend.del_word(element, category)

    def guess(self, data):
        """Guess a category for the given data.

        Returns a list of ``(category_name, score)`` tuples sorted by
        decreasing score.  Categories for which no token of ``data`` has a
        stored probability are left out.
        """
        # Materialise the tokens: they are scanned once per category, so a
        # one-shot iterator would be exhausted after the first category.
        tokens = list(self.tokenizer.transform(data, self.options))

        # Rebuild the probability tables only when something was
        # learnt/unlearnt since the last build, or on first use.
        if self._learnt or self._probs is None:
            self._probs = self._buildWordProbabilities()
        self._learnt = False

        scores = {}
        for category_name, category_probs in self._probs.items():
            found = self.getProbs(category_probs, tokens)
            if found:
                scores[category_name] = self._robinsonFisher(found,
                                                             category_name)

        # reverse=True keeps the sort stable, like the former cmp-based sort.
        return sorted(scores.items(), key=operator.itemgetter(1),
                      reverse=True)

    def getProbs(self, category_probs, words):
        """Extract the probabilities of the tokens of a message.

        Returns ``(word, probability)`` tuples for the words present in
        ``category_probs``, sorted by decreasing probability and limited to
        the first ``MAX_TOKENS`` entries.
        """
        probs = [(word, category_probs[word])
                 for word in words if word in category_probs]
        probs.sort(key=operator.itemgetter(1), reverse=True)
        return probs[:self.MAX_TOKENS]

    def _robinsonFisher(self, probs, ignore):
        """Combine token probabilities with the Robinson-Fisher method.

            H = C-1( -2.ln(prod(p)), 2*n )
            S = C-1( -2.ln(prod(1-p)), 2*n )
            I = (1 + H - S) / 2
            Courtesy of http://christophe.delord.free.fr/en/index.html

        probs  -- list of ``(word, probability)`` tuples
        ignore -- unused, kept for interface compatibility

        Returns I as a float; an empty ``probs`` gives 0.5.
        """
        def _chi2P(chi, df):
            """Return P(chisq >= chi, with df degrees of freedom).

            df must be even.
            """
            assert df & 1 == 0
            m = chi / 2.0
            total = term = math.exp(-m)
            for i in range(1, df // 2):
                term *= m / i
                total += term
            return min(total, 1.0)

        def _combine(values):
            # ln(prod(v)) is computed as sum(ln(v)): the plain product
            # underflows to 0.0 for long messages and math.log(0.0) raises.
            try:
                mlog = sum(math.log(value) for value in values)
                return _chi2P(-2.0 * mlog, 2 * len(values))
            except (OverflowError, ValueError):
                # A probability outside the log domain (or an overflowing
                # exponential) falls back to 0.0.
                return 0.0

        H = _combine([p[1] for p in probs])
        S = _combine([1.0 - p[1] for p in probs])
        return (1 + H - S) / 2

    def corpusSize(self, language=None):
        """Return the backend word count for ``language``."""
        return self.backend.word_count(language=language)

    def categorySize(self, category, language=None):
        """Return the backend word count for ``category`` and ``language``."""
        return self.backend.word_count(category=category, language=language)

    def _buildWordProbabilities(self, language=None):
        """Build the probability table of every category.

        Returns a dict mapping each category listed by the backend to the
        dict produced by ``_buildCategoryWordProbabilities``.  The corpus
        size and the word list are fetched once and shared by all
        categories.
        """
        corpus_size = self.corpusSize(language)
        words = list(self.backend.list_words(language, complete=True))
        probs = {}
        for cat in self.backend.list_categories():
            probs[cat] = self._buildCategoryWordProbabilities(
                cat, language, corpus_size, words)
        return probs

    def _buildCategoryWordProbabilities(self, category, language=None,
                                        corpus_size=None, words=None):
        """Merge corpora and compute the word probabilities of a category.

        category    -- category name
        language    -- language used for the category size; defaults to
                       ``self.language``
        corpus_size -- total word count; asked to the backend when None
        words       -- word records; asked to the backend when None.  Each
                       record is indexed as ``record[0]`` (the word),
                       ``record[1][1]`` (mapping category -> count) and
                       ``record[1][2]`` (total count of the word).

        Returns a dict mapping words to a probability bounded by
        ``GOOD_PROB`` and ``BAD_PROB``; words closer than
        ``PROBABILITY_THRESHOLD`` to 0.5 are left out.

        XXX to be cached later (invalidation on word adding)
        """
        # NOTE(review): corpus size and word list are fetched with the
        # caller-supplied language (possibly None) *before* the default is
        # applied below; order kept as is -- confirm against the backend.
        if corpus_size is None:
            corpus_size = self.corpusSize(language)
        if words is None:
            words = self.backend.list_words(language, complete=True)
        if language is None:
            language = self.language

        category_size = float(self.categorySize(category, language))
        # Clamped to at least 1, so the division below can never be by zero.
        them_count = float(max(corpus_size - category_size, 1))
        probabilities = {}

        for word in words:
            counts_by_category = word[1][1]
            if category not in counts_by_category:
                continue

            word_count = float(word[1][2])
            if word_count == 0.0:
                continue

            cat_word_count = float(counts_by_category[category])
            other_count = word_count - cat_word_count

            if category_size == 0:
                good_metric = 1.0
            else:
                good_metric = min(1.0, other_count / category_size)
            bad_metric = min(1.0, cat_word_count / them_count)

            denominator = good_metric + bad_metric
            if denominator == 0:
                # Nothing to compute a ratio from; skip the word.
                continue
            f = bad_metric / denominator

            if abs(f - 0.5) >= self.PROBABILITY_THRESHOLD:
                probabilities[word[0]] = max(self.GOOD_PROB,
                                             min(self.BAD_PROB, f))

        return probabilities

