现在时间是:
当前位置:首 页 >> 数据分析>> 文章列表

NLP 分类及其综合评估(11)

作者:   发布时间:2016-03-23 13:14:39   浏览次数:1456

 

分类评估流程:

1. 将包含多个分类的整个测试集打包:data_set

2. 加载全部测试语料:data_set

   在TextPreprocess.py的TextPreprocess类中增加导出方法:
  1.  
  2.         #导出训练语料集
  3.         def load_trainset(self):
  4.                 file_obj = open(self.wordbag_path+self.trainset_name,'rb')
  5.                 self.data_set = pickle.load(file_obj)
  6.                 file_obj.close()                        
  7.  
复制代码

3. 导入停用词表和训练词袋模型

4. 计算测试集的tfidf_value特征

5. 应用linear_svm算法

6. 计算分类各种参数指标

在text_mining.py中增加评估方法:
  1.  
  2. # -*- coding: utf-8 -*-
  3.  
  4. import sys  
  5. import os 
  6. import warnings
  7. import numpy as np
  8. from sklearn import metrics
  9.  
  10. warnings.filterwarnings("ignore")
  11.  
  12. # 精度测试
  13. def calculate_accurate(actual,predict):
  14.         m_precision = metrics.accuracy_score(actual,predict)
  15.         print '结果计算:'
  16.         print '精度:{0:.3f}'.format(m_precision) 
  17.  
  18. # 召回,精度,f1测试
  19. def calculate_result(actual,predict):
  20.         m_precision = metrics.precision_score(actual,predict)
  21.         m_recall = metrics.recall_score(actual,predict)
  22.         print '结果计算:'  
  23.         print '精度:{0:.3f}'.format(m_precision)  
  24.         print '召回:{0:0.3f}'.format(m_recall)  
  25.         print 'f1-score:{0:.3f}'.format(metrics.f1_score(actual,predict))  
  26.  
  27. # 综合测试报告
  28. def test_report(actual,predicted,category):
  29.         print(metrics.classification_report(actual, predicted,target_names=category))
  30.  
复制代码


7. 输出综合测试报告


主程序:
  1.  
  2. # -*- coding: utf-8 -*-
  3.  
  4. import sys  
  5. import os 
  6. import numpy as np
  7. #引入Bunch类
  8. from sklearn.datasets.base import Bunch
  9. #引入持久化类
  10. import pickle
  11. from sklearn import feature_extraction  
  12. from sklearn.feature_extraction.text import TfidfTransformer  
  13. from sklearn.feature_extraction.text import TfidfVectorizer  
  14. from TextPreprocess import TextPreprocess  # 第一个是文件名,第二个是类名
  15. #导入线性核svm算法
  16. from sklearn.svm import LinearSVC
  17.  
  18. from text_mining import calculate_result,calculate_accurate,test_report
  19.  
  20.  
  21. # 配置utf-8输出环境
  22. reload(sys)
  23. sys.setdefaultencoding('utf-8')
  24.  
  25. # 测试语料预处理
  26. testsamp = TextPreprocess()
  27. #testsamp.corpus_path = "test_corpus1_small/"    #原始语料路径
  28. #testsamp.pos_path = "test_corpus1_pos/"       #预处理后语料路径
  29. # 测试语料预处理
  30. #testsamp.preprocess()
  31.  
  32.  
  33. testsamp.segment_path = "test_corpus1_segment/"   #分词后语料路径
  34. testsamp.stopword_path = "extra_dict/hlt_stop_words.txt"  #停止词路径
  35. # 为测试语料分词
  36. #testsamp.segment()
  37.  
  38. # 实际应用中可直接导入分词后测试语料
  39. testsamp.wordbag_path = "test_corpus1_wordbag/"   #词袋模型路径
  40. testsamp.trainset_name = "test_set.dat"      #训练集文件名
  41. # testsamp.train_bag()
  42.  
  43. # 加载全部测试语料
  44. testsamp.load_trainset()
  45.  
  46. #对测试文本进行tf-idf计算
  47. #从文件导入停用词表
  48. stpwrdlst = testsamp.getStopword(testsamp.stopword_path)
  49. print len(testsamp.data_set.contents)
  50.  
  51. # 导入训练词袋模型
  52. train_set = TextPreprocess()
  53. train_set.wordbag_path = "text_corpus1_wordbag/"
  54. train_set.wordbag_name = "word_bag.data"#词袋文件名
  55. train_set.load_wordbag()
  56. print train_set.wordbag.tdm.shape
  57.  
  58. # 计算测试集的tfidf_value特征
  59. fea_test = testsamp.tfidf_value(testsamp.data_set.contents,stpwrdlst,train_set.wordbag.vocabulary)
  60. print fea_test.shape
  61.  
  62. #应用linear_svm算法 输入词袋向量和分类标签
  63. #svclf = SVC(kernel = 'linear')   # default with 'rbf'
  64. svclf = LinearSVC(penalty="l1",dual=False, tol=1e-4)
  65. # 训练分类器
  66. svclf.fit(train_set.wordbag.tdm, train_set.wordbag.label)
  67. # 预测分类结果
  68. predicted = svclf.predict(fea_test)
  69.  
  70. # 测试集与训练集详细比较
  71. #i=0
  72. #for file_name,expct_cate in zip(testsamp.data_set.label,predicted):
  73. #        print "测试语料文件名:",testsamp.data_set.filenames[i],": 实际类别:",testsamp.data_set.target_name[testsamp.data_set.label[i]],"<-->预测类别:",train_set.wordbag.target_name[expct_cate]
  74. #        i +=1
  75.  
  76. # list转np.array
  77. actual=np.array(testsamp.data_set.label)
  78.  
  79. # 计算分类各种参数指标
  80. calculate_result(actual,predicted)
  81.  
  82. #综合测试报告
  83. test_report(actual,predicted,testsamp.data_set.target_name)
  84.  
复制代码


输出结果:

  1.  
  2.  
  3. 950
  4. (951, 42590)
  5. (950, 42590)
  6. 结果计算:
  7. 精度:0.888
  8. 召回:0.888
  9. f1-score:0.885
  10.  
  11.                precision    recall  f1-score   support
  12.  
  13.    automobile       0.95      0.90      0.93        42
  14.      computer       0.88      0.96      0.92       210
  15.     education       0.88      0.86      0.87        58
  16. entertainment       0.90      0.93      0.91       107
  17.        estate       0.95      0.93      0.94        67
  18.       finance       0.79      0.54      0.64        48
  19.        health       0.80      0.91      0.85       100
  20.     personnel       0.73      0.77      0.75        43
  21.        sports       0.98      0.98      0.98       201
  22.    technology       0.85      0.64      0.73        74
  23.  
  24.   avg / total       0.89      0.89      0.89       950
  25.  
  26.  
复制代码
 
 
 






上一篇:没有了    下一篇:没有了

Copyright ©2018    易一网络科技|www.yeayee.com All Right Reserved.

技术支持:自助建站 | 领地网站建设 |短信接口 版权所有 © 2005-2018 lingw.net.粤ICP备16125321号 -5