003《Python数据分析、挖掘与可视化(第2版)》/关联规则分析.py
from itertools import chain, combinations
from openpyxl import load_workbook

def loadDataSet():
    '''加载数据,返回包含若干集合的列表'''
    # 返回的数据格式为 [{1, 3, 4}, {2, 3, 5}, {1, 2, 3, 5}, {2, 5}]
    result = []
    # xlsx文件中有3列,分别为电影名称、导演名称、演员清单
    # 同一个电影的多个主演演员使用逗号分隔
    ws = load_workbook('电影导演演员.xlsx').worksheets[0]
    for index, row in enumerate(ws.rows):
        # 跳过第一行表头
        if index==0:
            continue
        result.append(set(row[2].value.split(',')))
    return result

def createC1(dataSet):
    '''dataSet为包含集合的列表,每个集合表示一个项集
       返回包含若干元组的列表,
       每个元组为只包含一个物品的项集,所有项集不重复'''
    return sorted(map(lambda i:(i,), set(chain(*dataSet))))

def scanD(dataSet, Ck, Lk, minSupport):
    '''dataSet为包含集合的列表,每个集合表示一个项集
       ck为候选项集列表,每个元素为元组
       minSupport为最小支持度阈值
       返回Ck中支持度大于等于minSupport的那些项集'''
    # 数据集总数量
    total = len(dataSet)
    supportData = {}
    for candidate in Ck:
        # 加速,k-频繁项集的所有k-1子集都应该是频繁项集
        if Lk and (not all(map(lambda item: item in Lk,
                                 combinations(candidate,
                                               len(candidate)-1)))):
            continue
        # 遍历每个候选项集,统计该项集在所有数据集中出现的次数
        # 这里隐含了一个技巧:True在内部存储为1
        set_candidate = set(candidate)
        frequencies = sum(map(lambda item: set_candidate<=item,
                                dataSet))
        # 计算支持度
        t = frequencies/total
        # 大于等于最小支持度,保留该项集及其支持度
        if t >= minSupport:
            supportData[candidate] = t
    return supportData

def aprioriGen(Lk, k):
    '''根据k项集生成k+1项集'''
    result = []
    for index, item1 in enumerate(Lk):
        for item2 in Lk[index+1:]:
            # 只合并前k-2项相同的项集,避免生成重复项集
            # 例如,(1,3)和(2,5)不会合并,
            # (2,3)和(2,5)会合并为(2,3,5),
            # (2,3)和(3,5)不会合并,
            # (2,3)、(2,5)、(3,5)只能得到一个项集(2,3,5)
            if sorted(item1[:k-2]) == sorted(item2[:k-2]):
                result.append(tuple(set(item1)|set(item2)))
    return result

def apriori(dataSet, minSupport=0.5):
    '''根据给定数据集dataSet,
       返回所有支持度>=minSupport的频繁项集'''
    C1 = createC1(dataSet)
    supportData = scanD(dataSet, C1, None, minSupport)
    k = 2
    while True:
        # 获取满足最小支持度的k项集
        Lk = [key for key in supportData if len(key)==k-1]
        # 合并生成k+1项集
        Ck = aprioriGen(Lk, k)
        # 筛选满足最小支持度的k+1项集
        supK = scanD(dataSet, Ck, Lk, minSupport)
        # 无法再生成包含更多项的项集,算法结束
        if not supK:
            break
        supportData.update(supK)
        k = k+1
    return supportData

def findRules(supportData, minConfidence=0.5):
    '''查找满足最小置信度的关联规则'''
    # 对频繁项集按长度降序排列
    supportDataL = sorted(supportData.items(),
                            key=lambda item:len(item[0]),
                            reverse=True)
    rules = []
    for index, pre in enumerate(supportDataL):
        for aft in supportDataL[index+1:]:
            # 只查找k-1项集到k项集的关联规则
            if len(aft[0]) < len(pre[0])-1:
                break
            # 当前项集aft[0]是pre[0]的子集
            # 且aft[0]==>pre[0]的置信度大于等于最小置信度阈值
            if set(aft[0])<set(pre[0]) and\
               pre[1]/aft[1]>=minConfidence:
                rules.append([pre[0],aft[0]])
    return rules

# 加载数据
dataSet = loadDataSet()
# 获取所有支持度大于0.2的项集
supportData = apriori(dataSet, 0.2)
# 在所有频繁项集中查找并输出关系较好的演员二人组合
bestPair = [item for item in supportData if len(item)==2]
print(bestPair)

# 查找支持度大于0.6的强关联规则
for item in findRules(supportData, 0.6):
    pre, aft = map(set, item)
    print(aft, pre-aft, sep='==>')