1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
import operator
from collections import Counter
from math import log
def calcShannonent(dataSet):
    """Return the Shannon entropy (in bits) of the class labels in *dataSet*.

    Args:
        dataSet: Sequence of rows; the last element of each row is taken as
            that row's class label.

    Returns:
        The entropy ``-sum(p * log2(p))`` over the label frequencies. An
        empty data set has no labels and yields ``0.0``.
    """
    numEntries = len(dataSet)
    # Counter keeps first-seen order, so the summation order below matches
    # a manual dict-based tally.
    labelCounts = Counter(featVec[-1] for featVec in dataSet)
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
def createDataSet_temp():
    """Load the sample data set from ``TreeGrowth_ID3.txt``.

    The first line of the file is treated as a header and skipped. Every
    following line becomes one row made of the recognised single-character
    values found on that line, in the order they appear. Lines containing
    none of the recognised values (e.g. blank lines) are ignored.

    With the sample file the rows look like ``['长', '粗', '男']``: hair,
    voice, and the class label last.

    Returns:
        A ``(dataSet, labels)`` tuple: the list of rows and the list of
        feature names for the non-label columns.

    Raises:
        OSError: If the data file cannot be opened.
    """
    labels = ['头发', '声音']
    value = {'长', '短', '粗', '细', '男', '女'}
    dataSet = []
    with open('TreeGrowth_ID3.txt', 'r', encoding='UTF-8') as f:
        # Drop the header row; the default avoids StopIteration on an
        # empty file.
        next(f, None)
        for line in f:
            row = [ch for ch in line if ch in value]
            # Rows are built on demand, so the file may have any number of
            # lines; empty rows would break the tree builder downstream.
            if row:
                dataSet.append(row)
    return dataSet, labels
def splitDataSet(dataSet, axis, value):
    """Select the rows whose column *axis* equals *value*, minus that column.

    Args:
        dataSet: List of list rows.
        axis: Index of the column to test and then drop.
        value: Value the column must equal for a row to be kept.

    Returns:
        A new list of new rows; the input rows are not modified.
    """
    selected = []
    matching_rows = (row for row in dataSet if row[axis] == value)
    for row in matching_rows:
        # Everything before the tested column, then everything after it.
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        selected.append(trimmed)
    return selected
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature column with the highest information gain.

    Each row's last element is the class label; every other column is a
    candidate feature. The gain of a column is the entropy of the whole set
    minus the size-weighted entropy of the subsets obtained by splitting on
    each distinct value of that column.

    Returns:
        The winning column index, or ``-1`` when no column yields a gain
        strictly greater than zero.
    """
    feature_count = len(dataSet[0]) - 1
    baseline = calcShannonent(dataSet)
    total_rows = float(len(dataSet))
    winner, top_gain = -1, 0
    for column in range(feature_count):
        distinct_values = set(row[column] for row in dataSet)
        weighted_entropy = 0
        for candidate in distinct_values:
            subset = splitDataSet(dataSet, column, candidate)
            weight = len(subset) / total_rows
            weighted_entropy += weight * calcShannonent(subset)
        gain = baseline - weighted_entropy
        # Strict comparison: ties keep the earlier column.
        if gain > top_gain:
            winner, top_gain = column, gain
    return winner
def majorityCnt(classList):
    """Return the most frequent label in *classList*.

    Args:
        classList: Non-empty sequence of hashable class labels.

    Returns:
        The label with the highest count. When several labels tie, the one
        that appears first in *classList* wins.

    Raises:
        IndexError: If *classList* is empty.
    """
    # most_common(1) finds the maximum without sorting every label and
    # orders equal counts by first occurrence.
    return Counter(classList).most_common(1)[0][0]
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dictionaries.

    Args:
        dataSet: Non-empty list of rows; the last element of each row is the
            class label and the preceding elements are feature values.
        labels: Feature names, one per feature column of *dataSet*. The list
            is not modified.

    Returns:
        A class label when the node is a leaf, otherwise a dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # Leaf: every sample carries the same class. The count must be compared
    # with the length; counting the comparison result never matches string
    # labels and miscounts 0/1 labels.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Leaf: only the label column is left, so fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # Leaf: no remaining feature separates the classes. The classes are
    # mixed at this point, so use the majority rather than the first label.
    if bestFeat == -1:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    # Build the reduced label list as a copy so the caller's list stays intact.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    myTree = {bestFeatLabel: {}}
    uniqueVals = set(example[bestFeat] for example in dataSet)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels
        )
    return myTree
if __name__ == '__main__':
    # Script entry point: load the sample data from disk, grow the tree,
    # and print the resulting nested-dict structure.
    dataSet, labels = createDataSet_temp()
    tree = createTree(dataSet, labels)
    print(tree)
|