diff --git a/search/search_index.json b/search/search_index.json index 4f9e66e..5396582 100644 --- a/search/search_index.json +++ b/search/search_index.json @@ -1 +1 @@ -{"config":{"indexing":"full","lang":["en"],"min_search_length":3,"prebuild_index":false,"separator":"[\\s\\-]+"},"docs":[{"location":"","text":"AAAMLP-CN \u65b0\u7279\u6027 2023.10.25 \ud83d\ude0e\u5b8c\u6210\u5168\u90e8\u7ffb\u8bd1 \ud83d\udcdd\u8ba1\u5212\u5bf9kaggle\u6e38\u4e50\u56ed\u7cfb\u5217\u4f18\u79c0\u89e3\u51b3\u65b9\u6848\u4ee3\u7801\u8fdb\u884c\u89e3\u6790 2023.09.07 \u26a1\u4fee\u6b63\u90e8\u5206\u5df2\u77e5\u6587\u5b57\u9519\u8bef\u548c\u4ee3\u7801\u9519\u8bef \ud83e\udd17\u6dfb\u52a0 \u5728\u7ebf\u6587\u6863 \u7ffb\u8bd1\u8fdb\u7a0b 2023.09.12 \u6dfb\u52a0\u7ae0\u8282\uff1a \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5 \u3001 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5 \u7b80\u4ecb Abhishek Thakur\uff0c\u5f88\u591a kaggler \u5bf9\u4ed6\u90fd\u975e\u5e38\u719f\u6089\uff0c2017 \u5e74\uff0c\u4ed6\u5728 Linkedin \u53d1\u8868\u4e86\u4e00\u7bc7\u540d\u4e3a Approaching (Almost) Any Machine Learning Problem \u7684\u6587\u7ae0\uff0c\u4ecb\u7ecd\u4ed6\u5efa\u7acb\u7684\u4e00\u4e2a\u81ea\u52a8\u7684\u673a\u5668\u5b66\u4e60\u6846\u67b6\uff0c\u51e0\u4e4e\u53ef\u4ee5\u89e3\u51b3\u4efb\u4f55\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u8fd9\u7bc7\u6587\u7ae0\u66fe\u706b\u904d Kaggle\u3002 Abhishek \u5728 Kaggle \u4e0a\u7684\u6210\u5c31\uff1a Competitions Grandmaster\uff0817 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 3\uff09 Kernels Expert \uff08Kagglers \u6392\u540d\u524d 1\uff05\uff09 Discussion Grandmaster\uff0865 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 2\uff09 \u76ee\u524d\uff0cAbhishek \u5728\u632a\u5a01 boost \u516c\u53f8\u62c5\u4efb\u9996\u5e2d\u6570\u636e\u79d1\u5b66\u5bb6\u7684\u804c\u4f4d\uff0c\u8fd9\u662f\u4e00\u5bb6\u4e13\u95e8\u4ece\u4e8b\u4f1a\u8bdd\u4eba\u5de5\u667a\u80fd\u7684\u8f6f\u4ef6\u516c\u53f8\u3002 \u672c\u6587\u5bf9 Approaching (Almost) Any Machine Learning Problem \u8fdb\u884c\u4e86 \u4e2d\u6587\u7ffb\u8bd1 \uff0c\u7531\u4e8e\u672c\u4eba\u6c34\u5e73\u6709\u9650\uff0c\u4e14\u672a\u4f7f\u7528\u673a\u5668\u7ffb\u8bd1\uff0c\u53ef\u80fd\u6709\u90e8\u5206\u8a00\u8bed\u4e0d\u901a\u987a\u6216\u672c\u571f\u5316\u7a0b\u5ea6\u4e0d\u8db3\uff0c\u4e5f\u8bf7\u5927\u5bb6\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u591a\u63d0\u4f9b\u5b9d\u8d35\u610f\u89c1\u3002\u53e6\u9644\u4e0a\u4e66\u7c4d\u539f \u9879\u76ee\u5730\u5740 \uff0c \u8f6c\u8f7d\u8bf7\u4e00\u5b9a\u6807\u660e\u51fa\u5904\uff01 \u672c\u9879\u76ee \u652f\u6301\u5728\u7ebf\u9605\u8bfb \uff0c\u65b9\u4fbf\u60a8\u968f\u65f6\u968f\u5730\u8fdb\u884c\u67e5\u9605\u3002 \u56e0\u4e3a\u6709\u51e0\u7ae0\u5185\u5bb9\u592a\u8fc7\u57fa\u7840\uff0c\u6240\u4ee5\u672a\u8fdb\u884c\u7ffb\u8bd1\uff0c\u8be6\u7ec6\u60c5\u51b5\u8bf7\u53c2\u7167\u4e66\u7c4d\u76ee\u5f55\uff1a \u51c6\u5907\u73af\u5883\uff08\u5df2\u7ffb\u8bd1\uff09 \u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60\uff08\u5df2\u7ffb\u8bd1\uff09 \u4ea4\u53c9\u68c0\u9a8c\uff08\u5df2\u7ffb\u8bd1\uff09 \u8bc4\u4f30\u6307\u6807\uff08\u5df2\u7ffb\u8bd1\uff09 - \u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\uff08\u5df2\u7ffb\u8bd1\uff09 \u5904\u7406\u5206\u7c7b\u53d8\u91cf\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u5de5\u7a0b\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u9009\u62e9\uff08\u5df2\u7ffb\u8bd1\uff09 \u8d85\u53c2\u6570\u4f18\u5316\uff08\u5df2\u7ffb\u8bd1\uff09 \u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u6211\u5c06\u4f1a\u628a\u5b8c\u6574\u7684\u7ffb\u8bd1\u7248 Markdown \u6587\u4ef6\u4e0a\u4f20\u5230 GitHub\uff0c\u4ee5\u4f9b\u5927\u5bb6\u514d\u8d39\u4e0b\u8f7d\u548c\u9605\u8bfb\u3002\u4e3a\u4e86\u6700\u4f73\u7684\u9605\u8bfb\u4f53\u9a8c\uff0c\u63a8\u8350\u4f7f\u7528 PDF \u683c\u5f0f\u6216\u662f\u5728\u7ebf\u9605\u8bfb\u8fdb\u884c\u67e5\u770b \u82e5\u60a8\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u53d1\u73b0\u4efb\u4f55\u9519\u8bef\u6216\u4e0d\u51c6\u786e\u4e4b\u5904\uff0c\u975e\u5e38\u6b22\u8fce\u901a\u8fc7\u63d0\u4ea4 Issue \u6216 Pull Request \u6765\u534f\u52a9\u6211\u8fdb\u884c\u4fee\u6b63\u3002 \u968f\u7740\u65f6\u95f4\u63a8\u79fb\uff0c\u6211\u53ef\u80fd\u4f1a \u7ee7\u7eed\u7ffb\u8bd1\u5c1a\u672a\u5b8c\u6210\u7684\u7ae0\u8282 \u3002\u5982\u679c\u60a8\u89c9\u5f97\u8fd9\u4e2a\u9879\u76ee\u5bf9\u60a8\u6709\u5e2e\u52a9\uff0c\u8bf7\u4e0d\u541d\u7ed9\u4e88 Star \u6216\u8005\u8fdb\u884c\u5173\u6ce8\u3002","title":"\u524d\u8a00"},{"location":"#aaamlp-cn","text":"","title":"AAAMLP-CN"},{"location":"#_1","text":"","title":"\u65b0\u7279\u6027"},{"location":"#20231025","text":"\ud83d\ude0e\u5b8c\u6210\u5168\u90e8\u7ffb\u8bd1 \ud83d\udcdd\u8ba1\u5212\u5bf9kaggle\u6e38\u4e50\u56ed\u7cfb\u5217\u4f18\u79c0\u89e3\u51b3\u65b9\u6848\u4ee3\u7801\u8fdb\u884c\u89e3\u6790","title":"2023.10.25"},{"location":"#20230907","text":"\u26a1\u4fee\u6b63\u90e8\u5206\u5df2\u77e5\u6587\u5b57\u9519\u8bef\u548c\u4ee3\u7801\u9519\u8bef \ud83e\udd17\u6dfb\u52a0 \u5728\u7ebf\u6587\u6863","title":"2023.09.07"},{"location":"#_2","text":"2023.09.12 \u6dfb\u52a0\u7ae0\u8282\uff1a \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5 \u3001 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5","title":"\u7ffb\u8bd1\u8fdb\u7a0b"},{"location":"#_3","text":"Abhishek Thakur\uff0c\u5f88\u591a kaggler \u5bf9\u4ed6\u90fd\u975e\u5e38\u719f\u6089\uff0c2017 \u5e74\uff0c\u4ed6\u5728 Linkedin \u53d1\u8868\u4e86\u4e00\u7bc7\u540d\u4e3a Approaching (Almost) Any Machine Learning Problem \u7684\u6587\u7ae0\uff0c\u4ecb\u7ecd\u4ed6\u5efa\u7acb\u7684\u4e00\u4e2a\u81ea\u52a8\u7684\u673a\u5668\u5b66\u4e60\u6846\u67b6\uff0c\u51e0\u4e4e\u53ef\u4ee5\u89e3\u51b3\u4efb\u4f55\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u8fd9\u7bc7\u6587\u7ae0\u66fe\u706b\u904d Kaggle\u3002 Abhishek \u5728 Kaggle \u4e0a\u7684\u6210\u5c31\uff1a Competitions Grandmaster\uff0817 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 3\uff09 Kernels Expert \uff08Kagglers \u6392\u540d\u524d 1\uff05\uff09 Discussion Grandmaster\uff0865 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 2\uff09 \u76ee\u524d\uff0cAbhishek \u5728\u632a\u5a01 boost \u516c\u53f8\u62c5\u4efb\u9996\u5e2d\u6570\u636e\u79d1\u5b66\u5bb6\u7684\u804c\u4f4d\uff0c\u8fd9\u662f\u4e00\u5bb6\u4e13\u95e8\u4ece\u4e8b\u4f1a\u8bdd\u4eba\u5de5\u667a\u80fd\u7684\u8f6f\u4ef6\u516c\u53f8\u3002 \u672c\u6587\u5bf9 Approaching (Almost) Any Machine Learning Problem \u8fdb\u884c\u4e86 \u4e2d\u6587\u7ffb\u8bd1 \uff0c\u7531\u4e8e\u672c\u4eba\u6c34\u5e73\u6709\u9650\uff0c\u4e14\u672a\u4f7f\u7528\u673a\u5668\u7ffb\u8bd1\uff0c\u53ef\u80fd\u6709\u90e8\u5206\u8a00\u8bed\u4e0d\u901a\u987a\u6216\u672c\u571f\u5316\u7a0b\u5ea6\u4e0d\u8db3\uff0c\u4e5f\u8bf7\u5927\u5bb6\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u591a\u63d0\u4f9b\u5b9d\u8d35\u610f\u89c1\u3002\u53e6\u9644\u4e0a\u4e66\u7c4d\u539f \u9879\u76ee\u5730\u5740 \uff0c \u8f6c\u8f7d\u8bf7\u4e00\u5b9a\u6807\u660e\u51fa\u5904\uff01 \u672c\u9879\u76ee \u652f\u6301\u5728\u7ebf\u9605\u8bfb \uff0c\u65b9\u4fbf\u60a8\u968f\u65f6\u968f\u5730\u8fdb\u884c\u67e5\u9605\u3002 \u56e0\u4e3a\u6709\u51e0\u7ae0\u5185\u5bb9\u592a\u8fc7\u57fa\u7840\uff0c\u6240\u4ee5\u672a\u8fdb\u884c\u7ffb\u8bd1\uff0c\u8be6\u7ec6\u60c5\u51b5\u8bf7\u53c2\u7167\u4e66\u7c4d\u76ee\u5f55\uff1a \u51c6\u5907\u73af\u5883\uff08\u5df2\u7ffb\u8bd1\uff09 \u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60\uff08\u5df2\u7ffb\u8bd1\uff09 \u4ea4\u53c9\u68c0\u9a8c\uff08\u5df2\u7ffb\u8bd1\uff09 \u8bc4\u4f30\u6307\u6807\uff08\u5df2\u7ffb\u8bd1\uff09 - \u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\uff08\u5df2\u7ffb\u8bd1\uff09 \u5904\u7406\u5206\u7c7b\u53d8\u91cf\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u5de5\u7a0b\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u9009\u62e9\uff08\u5df2\u7ffb\u8bd1\uff09 \u8d85\u53c2\u6570\u4f18\u5316\uff08\u5df2\u7ffb\u8bd1\uff09 \u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u6211\u5c06\u4f1a\u628a\u5b8c\u6574\u7684\u7ffb\u8bd1\u7248 Markdown \u6587\u4ef6\u4e0a\u4f20\u5230 GitHub\uff0c\u4ee5\u4f9b\u5927\u5bb6\u514d\u8d39\u4e0b\u8f7d\u548c\u9605\u8bfb\u3002\u4e3a\u4e86\u6700\u4f73\u7684\u9605\u8bfb\u4f53\u9a8c\uff0c\u63a8\u8350\u4f7f\u7528 PDF \u683c\u5f0f\u6216\u662f\u5728\u7ebf\u9605\u8bfb\u8fdb\u884c\u67e5\u770b \u82e5\u60a8\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u53d1\u73b0\u4efb\u4f55\u9519\u8bef\u6216\u4e0d\u51c6\u786e\u4e4b\u5904\uff0c\u975e\u5e38\u6b22\u8fce\u901a\u8fc7\u63d0\u4ea4 Issue \u6216 Pull Request \u6765\u534f\u52a9\u6211\u8fdb\u884c\u4fee\u6b63\u3002 \u968f\u7740\u65f6\u95f4\u63a8\u79fb\uff0c\u6211\u53ef\u80fd\u4f1a \u7ee7\u7eed\u7ffb\u8bd1\u5c1a\u672a\u5b8c\u6210\u7684\u7ae0\u8282 \u3002\u5982\u679c\u60a8\u89c9\u5f97\u8fd9\u4e2a\u9879\u76ee\u5bf9\u60a8\u6709\u5e2e\u52a9\uff0c\u8bf7\u4e0d\u541d\u7ed9\u4e88 Star \u6216\u8005\u8fdb\u884c\u5173\u6ce8\u3002","title":"\u7b80\u4ecb"},{"location":"%E4%BA%A4%E5%8F%89%E6%A3%80%E9%AA%8C/","text":"\u4ea4\u53c9\u68c0\u9a8c \u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u6ca1\u6709\u5efa\u7acb\u4efb\u4f55\u6a21\u578b\u3002\u539f\u56e0\u5f88\u7b80\u5355\uff0c\u5728\u521b\u5efa\u4efb\u4f55\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u77e5\u9053\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4ee5\u53ca\u5982\u4f55\u6839\u636e\u6570\u636e\u96c6\u9009\u62e9\u6700\u4f73\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e\u96c6\u3002 \u90a3\u4e48\uff0c\u4ec0\u4e48\u662f \u4ea4\u53c9\u68c0\u9a8c \uff0c\u6211\u4eec\u4e3a\u4ec0\u4e48\u8981\u5173\u6ce8\u5b83\uff1f \u5173\u4e8e\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u53ef\u4ee5\u627e\u5230\u591a\u79cd\u5b9a\u4e49\u3002\u6211\u7684\u5b9a\u4e49\u53ea\u6709\u4e00\u53e5\u8bdd\uff1a\u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fc7\u7a0b\u4e2d\u7684\u4e00\u4e2a\u6b65\u9aa4\uff0c\u5b83\u53ef\u4ee5\u5e2e\u52a9\u6211\u4eec\u786e\u4fdd\u6a21\u578b\u51c6\u786e\u62df\u5408\u6570\u636e\uff0c\u540c\u65f6\u786e\u4fdd\u6211\u4eec\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u4f46\u8fd9\u53c8\u5f15\u51fa\u4e86\u53e6\u4e00\u4e2a\u8bcd\uff1a \u8fc7\u62df\u5408 \u3002 \u8981\u89e3\u91ca\u8fc7\u62df\u5408\uff0c\u6211\u8ba4\u4e3a\u6700\u597d\u5148\u770b\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6709\u4e00\u4e2a\u76f8\u5f53\u6709\u540d\u7684\u7ea2\u9152\u8d28\u91cf\u6570\u636e\u96c6\uff08 red wine quality dataset \uff09\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 11 \u4e2a\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u8fd9\u4e9b\u7279\u5f81\u51b3\u5b9a\u4e86\u7ea2\u9152\u7684\u8d28\u91cf\u3002 \u8fd9\u4e9b\u5c5e\u6027\u5305\u62ec\uff1a \u56fa\u5b9a\u9178\u5ea6\uff08fixed acidity\uff09 \u6325\u53d1\u6027\u9178\u5ea6\uff08volatile acidity\uff09 \u67e0\u6aac\u9178\uff08citric acid\uff09 \u6b8b\u7559\u7cd6\uff08residual sugar\uff09 \u6c2f\u5316\u7269\uff08chlorides\uff09 \u6e38\u79bb\u4e8c\u6c27\u5316\u786b\uff08free sulfur dioxide\uff09 \u4e8c\u6c27\u5316\u786b\u603b\u91cf\uff08total sulfur dioxide\uff09 \u5bc6\u5ea6\uff08density\uff09 PH \u503c\uff08pH\uff09 \u786b\u9178\u76d0\uff08sulphates\uff09 \u9152\u7cbe\uff08alcohol\uff09 \u6839\u636e\u8fd9\u4e9b\u4e0d\u540c\u7279\u5f81\uff0c\u6211\u4eec\u9700\u8981\u9884\u6d4b\u7ea2\u8461\u8404\u9152\u7684\u8d28\u91cf\uff0c\u8d28\u91cf\u503c\u4ecb\u4e8e 0 \u5230 10 \u4e4b\u95f4\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u6570\u636e\u662f\u600e\u6837\u7684\u3002 import pandas as pd df = pd . read_csv ( \"winequality-red.csv\" ) \u56fe 1:\u7ea2\u8461\u8404\u9152\u8d28\u91cf\u6570\u636e\u96c6\u7b80\u5355\u5c55\u793a \u6211\u4eec\u53ef\u4ee5\u5c06\u8fd9\u4e2a\u95ee\u9898\u89c6\u4e3a\u5206\u7c7b\u95ee\u9898\uff0c\u4e5f\u53ef\u4ee5\u89c6\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u9009\u62e9\u5206\u7c7b\u3002\u7136\u800c\uff0c\u8fd9\u4e2a\u6570\u636e\u96c6\u503c\u5305\u542b 6 \u79cd\u8d28\u91cf\u503c\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u6240\u6709\u8d28\u91cf\u503c\u6620\u5c04\u5230 0 \u5230 5 \u4e4b\u95f4\u3002 # \u4e00\u4e2a\u6620\u5c04\u5b57\u5178\uff0c\u7528\u4e8e\u5c06\u8d28\u91cf\u503c\u4ece 0 \u5230 5 \u8fdb\u884c\u6620\u5c04 quality_mapping = { 3 : 0 , 4 : 1 , 5 : 2 , 6 : 3 , 7 : 4 , 8 : 5 } # \u4f60\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 map \u51fd\u6570\u4ee5\u53ca\u4efb\u4f55\u5b57\u5178\uff0c # \u6765\u8f6c\u6362\u7ed9\u5b9a\u5217\u4e2d\u7684\u503c\u4e3a\u5b57\u5178\u4e2d\u7684\u503c df . loc [:, \"quality\" ] = df . quality . map ( quality_mapping ) \u5f53\u6211\u4eec\u770b\u5927\u8fd9\u4e9b\u6570\u636e\u5e76\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u65f6\uff0c\u6211\u4eec\u8111\u6d77\u4e2d\u4f1a\u6d6e\u73b0\u51fa\u5f88\u591a\u53ef\u4ee5\u5e94\u7528\u7684\u7b97\u6cd5\uff0c\u4e5f\u8bb8\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u4ece\u4e00\u5f00\u59cb\u5c31\u6df1\u5165\u7814\u7a76\u795e\u7ecf\u7f51\u7edc\uff0c\u90a3\u5c31\u6709\u70b9\u7275\u5f3a\u4e86\u3002\u6240\u4ee5\uff0c\u8ba9\u6211\u4eec\u4ece\u7b80\u5355\u7684\u3001\u6211\u4eec\u4e5f\u80fd\u53ef\u89c6\u5316\u7684\u4e1c\u897f\u5f00\u59cb\uff1a\u51b3\u7b56\u6811\u3002 \u5728\u5f00\u59cb\u4e86\u89e3\u4ec0\u4e48\u662f\u8fc7\u62df\u5408\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 1599 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u4fdd\u7559 1000 \u4e2a\u6837\u672c\u7528\u4e8e\u8bad\u7ec3\uff0c599 \u4e2a\u6837\u672c\u4f5c\u4e3a\u4e00\u4e2a\u5355\u72ec\u7684\u96c6\u5408\u3002 \u4ee5\u4e0b\u4ee3\u7801\u53ef\u4ee5\u8f7b\u677e\u5b8c\u6210\u5212\u5206\uff1a # \u4f7f\u7528 frac=1 \u7684 sample \u65b9\u6cd5\u6765\u6253\u4e71 dataframe # \u7531\u4e8e\u6253\u4e71\u540e\u7d22\u5f15\u4f1a\u6539\u53d8\uff0c\u6240\u4ee5\u6211\u4eec\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u9009\u53d6\u524d 1000 \u884c\u4f5c\u4e3a\u8bad\u7ec3\u6570\u636e df_train = df . head ( 1000 ) # \u9009\u53d6\u6700\u540e\u7684 599 \u884c\u4f5c\u4e3a\u6d4b\u8bd5/\u9a8c\u8bc1\u6570\u636e df_test = df . tail ( 599 ) \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u5728\u8bad\u7ec3\u96c6\u4e0a\u4f7f\u7528 scikit-learn \u8bad\u7ec3\u4e00\u4e2a\u51b3\u7b56\u6811\u6a21\u578b\u3002 # \u4ece scikit-learn \u5bfc\u5165\u9700\u8981\u7684\u6a21\u5757 from sklearn import tree from sklearn import metrics # \u521d\u59cb\u5316\u4e00\u4e2a\u51b3\u7b56\u6811\u5206\u7c7b\u5668\uff0c\u8bbe\u7f6e\u6700\u5927\u6df1\u5ea6\u4e3a 3 clf = tree . DecisionTreeClassifier ( max_depth = 3 ) # \u9009\u62e9\u4f60\u60f3\u8981\u8bad\u7ec3\u6a21\u578b\u7684\u5217 # \u8fd9\u4e9b\u5217\u4f5c\u4e3a\u6a21\u578b\u7684\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u4f7f\u7528\u4e4b\u524d\u6620\u5c04\u7684\u8d28\u91cf\u4ee5\u53ca\u63d0\u4f9b\u7684\u7279\u5f81\u6765\u8bad\u7ec3\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5c06\u51b3\u7b56\u6811\u5206\u7c7b\u5668\u7684\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u8bbe\u4e3a 3\u3002\u8be5\u6a21\u578b\u7684\u6240\u6709\u5176\u4ed6\u53c2\u6570\u5747\u4fdd\u6301\u9ed8\u8ba4\u503c\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u6d4b\u8bd5\u8be5\u6a21\u578b\u7684\u51c6\u786e\u6027\uff1a # \u5728\u8bad\u7ec3\u96c6\u4e0a\u751f\u6210\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) # \u5728\u6d4b\u8bd5\u96c6\u4e0a\u751f\u6210\u9884\u6d4b test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) # \u8ba1\u7b97\u6d4b\u8bd5\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 test_accuracy = metrics . accuracy_score ( df_test . quality , test_predictions ) \u8bad\u7ec3\u548c\u6d4b\u8bd5\u7684\u51c6\u786e\u7387\u5206\u522b\u4e3a 58.9%\u548c 54.25%\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5c06\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u589e\u52a0\u5230 7\uff0c\u5e76\u91cd\u590d\u4e0a\u8ff0\u8fc7\u7a0b\u3002\u8fd9\u6837\uff0c\u8bad\u7ec3\u51c6\u786e\u7387\u4e3a 76.6%\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4e3a 57.3%\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4f7f\u7528\u51c6\u786e\u7387\uff0c\u4e3b\u8981\u662f\u56e0\u4e3a\u5b83\u662f\u6700\u76f4\u63a5\u7684\u6307\u6807\u3002\u5bf9\u4e8e\u8fd9\u4e2a\u95ee\u9898\u6765\u8bf4\uff0c\u5b83\u53ef\u80fd\u4e0d\u662f\u6700\u597d\u7684\u6307\u6807\u3002\u6211\u4eec\u53ef\u4ee5\u6839\u636e\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u4e0d\u540c\u503c\u6765\u8ba1\u7b97\u8fd9\u4e9b\u51c6\u786e\u7387\uff0c\u5e76\u7ed8\u5236\u66f2\u7ebf\u56fe\u3002 # \u6ce8\u610f\uff1a\u8fd9\u6bb5\u4ee3\u7801\u5728 Jupyter \u7b14\u8bb0\u672c\u4e2d\u7f16\u5199 # \u5bfc\u5165 scikit-learn \u7684 tree \u548c metrics from sklearn import tree from sklearn import metrics # \u5bfc\u5165 matplotlib \u548c seaborn # \u7528\u4e8e\u7ed8\u56fe import matplotlib import matplotlib.pyplot as plt import seaborn as sns # \u8bbe\u7f6e\u5168\u5c40\u6807\u7b7e\u6587\u672c\u7684\u5927\u5c0f matplotlib . rc ( 'xtick' , labelsize = 20 ) matplotlib . rc ( 'ytick' , labelsize = 20 ) # \u786e\u4fdd\u56fe\u8868\u76f4\u63a5\u5728\u7b14\u8bb0\u672c\u5185\u663e\u793a % matplotlib inline # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6\u7684\u5217\u8868 # \u6211\u4eec\u4ece 50% \u7684\u51c6\u786e\u5ea6\u5f00\u59cb train_accuracies = [ 0.5 ] test_accuracies = [ 0.5 ] # \u904d\u5386\u51e0\u4e2a\u4e0d\u540c\u7684\u6811\u6df1\u5ea6\u503c for depth in range ( 1 , 25 ): # \u521d\u59cb\u5316\u6a21\u578b clf = tree . DecisionTreeClassifier ( max_depth = depth ) # \u9009\u62e9\u7528\u4e8e\u8bad\u7ec3\u7684\u5217/\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u5728\u7ed9\u5b9a\u7279\u5f81\u4e0a\u62df\u5408\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) # \u521b\u5efa\u8bad\u7ec3\u548c\u6d4b\u8bd5\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) test_accuracy = metrics . accuracy_score ( df_test . quality , test_predictions ) # \u6dfb\u52a0\u51c6\u786e\u5ea6\u5230\u5217\u8868 train_accuracies . append ( train_accuracy ) test_accuracies . append ( test_accuracy ) # \u4f7f\u7528 matplotlib \u548c seaborn \u521b\u5efa\u4e24\u4e2a\u56fe plt . figure ( figsize = ( 10 , 5 )) sns . set_style ( \"whitegrid\" ) plt . plot ( train_accuracies , label = \"train accuracy\" ) plt . plot ( test_accuracies , label = \"test accuracy\" ) plt . legend ( loc = \"upper left\" , prop = { 'size' : 15 }) plt . xticks ( range ( 0 , 26 , 5 )) plt . xlabel ( \"max_depth\" , size = 20 ) plt . ylabel ( \"accuracy\" , size = 20 ) plt . show () \u8fd9\u5c06\u751f\u6210\u5982\u56fe 2 \u6240\u793a\u7684\u66f2\u7ebf\u56fe\u3002 \u56fe 2\uff1a\u4e0d\u540c max_depth \u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u7387\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5f53\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u503c\u4e3a 14 \u65f6\uff0c\u6d4b\u8bd5\u6570\u636e\u7684\u5f97\u5206\u6700\u9ad8\u3002\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u589e\u52a0\u8fd9\u4e2a\u53c2\u6570\u7684\u503c\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4f1a\u4fdd\u6301\u4e0d\u53d8\u6216\u53d8\u5dee\uff0c\u4f46\u8bad\u7ec3\u51c6\u786e\u7387\u4f1a\u4e0d\u65ad\u63d0\u9ad8\u3002\u8fd9\u8bf4\u660e\uff0c\u968f\u7740\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u589e\u52a0\uff0c\u51b3\u7b56\u6811\u6a21\u578b\u5bf9\u8bad\u7ec3\u6570\u636e\u7684\u5b66\u4e60\u6548\u679c\u8d8a\u6765\u8d8a\u597d\uff0c\u4f46\u6d4b\u8bd5\u6570\u636e\u7684\u6027\u80fd\u5374\u4e1d\u6beb\u6ca1\u6709\u63d0\u9ad8\u3002 \u8fd9\u5c31\u662f\u6240\u8c13\u7684\u8fc7\u62df\u5408 \u3002 \u6a21\u578b\u5728\u8bad\u7ec3\u96c6\u4e0a\u5b8c\u5168\u62df\u5408\uff0c\u800c\u5728\u6d4b\u8bd5\u96c6\u4e0a\u5374\u8868\u73b0\u4e0d\u4f73\u3002\u8fd9\u610f\u5473\u7740\u6a21\u578b\u53ef\u4ee5\u5f88\u597d\u5730\u5b66\u4e60\u8bad\u7ec3\u6570\u636e\uff0c\u4f46\u65e0\u6cd5\u6cdb\u5316\u5230\u672a\u89c1\u8fc7\u7684\u6837\u672c\u4e0a\u3002\u5728\u4e0a\u9762\u7684\u6570\u636e\u96c6\u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u975e\u5e38\u9ad8\u7684\u6a21\u578b\uff0c\u5b83\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u4f1a\u6709\u51fa\u8272\u7684\u7ed3\u679c\uff0c\u4f46\u8fd9\u79cd\u6a21\u578b\u5e76\u4e0d\u5b9e\u7528\uff0c\u56e0\u4e3a\u5b83\u5728\u771f\u5b9e\u4e16\u754c\u7684\u6837\u672c\u6216\u5b9e\u65f6\u6570\u636e\u4e0a\u4e0d\u4f1a\u63d0\u4f9b\u7c7b\u4f3c\u7684\u7ed3\u679c\u3002 \u6709\u4eba\u53ef\u80fd\u4f1a\u8bf4\uff0c\u8fd9\u79cd\u65b9\u6cd5\u5e76\u6ca1\u6709\u8fc7\u62df\u5408\uff0c\u56e0\u4e3a\u6d4b\u8bd5\u96c6\u7684\u51c6\u786e\u7387\u57fa\u672c\u4fdd\u6301\u4e0d\u53d8\u3002\u8fc7\u62df\u5408\u7684\u53e6\u4e00\u4e2a\u5b9a\u4e49\u662f\uff0c\u5f53\u6211\u4eec\u4e0d\u65ad\u63d0\u9ad8\u8bad\u7ec3\u635f\u5931\u65f6\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u5728\u589e\u52a0\u3002\u8fd9\u79cd\u60c5\u51b5\u5728\u795e\u7ecf\u7f51\u7edc\u4e2d\u975e\u5e38\u5e38\u89c1\u3002 \u6bcf\u5f53\u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u795e\u7ecf\u7f51\u7edc\u65f6\uff0c\u90fd\u5fc5\u987b\u5728\u8bad\u7ec3\u671f\u95f4\u76d1\u63a7\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u3002\u5982\u679c\u6211\u4eec\u6709\u4e00\u4e2a\u975e\u5e38\u5927\u7684\u7f51\u7edc\u6765\u5904\u7406\u4e00\u4e2a\u975e\u5e38\u5c0f\u7684\u6570\u636e\u96c6\uff08\u5373\u6837\u672c\u6570\u975e\u5e38\u5c11\uff09\uff0c\u6211\u4eec\u5c31\u4f1a\u89c2\u5bdf\u5230\uff0c\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u8bad\u7ec3\uff0c\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u90fd\u4f1a\u51cf\u5c11\u3002\u4f46\u662f\uff0c\u5728\u67d0\u4e2a\u65f6\u523b\uff0c\u6d4b\u8bd5\u635f\u5931\u4f1a\u8fbe\u5230\u6700\u5c0f\u503c\uff0c\u4e4b\u540e\uff0c\u5373\u4f7f\u8bad\u7ec3\u635f\u5931\u8fdb\u4e00\u6b65\u51cf\u5c11\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u4f1a\u5f00\u59cb\u589e\u52a0\u3002\u6211\u4eec\u5fc5\u987b\u5728\u9a8c\u8bc1\u635f\u5931\u8fbe\u5230\u6700\u5c0f\u503c\u65f6\u505c\u6b62\u8bad\u7ec3\u3002 \u8fd9\u662f\u5bf9\u8fc7\u62df\u5408\u6700\u5e38\u89c1\u7684\u89e3\u91ca \u3002 \u5965\u5361\u59c6\u5243\u5200\u7528\u7b80\u5355\u7684\u8bdd\u8bf4\uff0c\u5c31\u662f\u4e0d\u8981\u8bd5\u56fe\u628a\u53ef\u4ee5\u7528\u7b80\u5355\u5f97\u591a\u7684\u65b9\u6cd5\u89e3\u51b3\u7684\u4e8b\u60c5\u590d\u6742\u5316\u3002\u6362\u53e5\u8bdd\u8bf4\uff0c\u6700\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6848\u5c31\u662f\u6700\u5177\u901a\u7528\u6027\u7684\u89e3\u51b3\u65b9\u6848\u3002\u4e00\u822c\u6765\u8bf4\uff0c\u53ea\u8981\u4f60\u7684\u6a21\u578b\u4e0d\u7b26\u5408\u5965\u5361\u59c6\u5243\u5200\u539f\u5219\uff0c\u5c31\u5f88\u53ef\u80fd\u662f\u8fc7\u62df\u5408\u3002 \u56fe 3\uff1a\u8fc7\u62df\u5408\u7684\u6700\u4e00\u822c\u5b9a\u4e49 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u56de\u5230\u4ea4\u53c9\u68c0\u9a8c\u3002 \u5728\u89e3\u91ca\u8fc7\u62df\u5408\u65f6\uff0c\u6211\u51b3\u5b9a\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u6211\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u53e6\u4e00\u90e8\u5206\u4e0a\u68c0\u67e5\u5176\u6027\u80fd\u3002\u8fd9\u4e5f\u662f\u4ea4\u53c9\u68c0\u9a8c\u7684\u4e00\u79cd\uff0c\u901a\u5e38\u88ab\u79f0\u4e3a \"\u6682\u7559\u96c6\"\uff08 hold-out set \uff09\u3002\u5f53\u6211\u4eec\u62e5\u6709\u5927\u91cf\u6570\u636e\uff0c\u800c\u6a21\u578b\u63a8\u7406\u662f\u4e00\u4e2a\u8017\u65f6\u7684\u8fc7\u7a0b\u65f6\uff0c\u6211\u4eec\u5c31\u4f1a\u4f7f\u7528\u8fd9\u79cd\uff08\u4ea4\u53c9\uff09\u9a8c\u8bc1\u3002 \u4ea4\u53c9\u68c0\u9a8c\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\uff0c\u5b83\u662f\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u6b65\u9aa4\u3002 \u9009\u62e9\u6b63\u786e\u7684\u4ea4\u53c9\u68c0\u9a8c \u53d6\u51b3\u4e8e\u6240\u5904\u7406\u7684\u6570\u636e\u96c6\uff0c\u5728\u4e00\u4e2a\u6570\u636e\u96c6\u4e0a\u9002\u7528\u7684\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u53ef\u80fd\u4e0d\u9002\u7528\u4e8e\u5176\u4ed6\u6570\u636e\u96c6\u3002\u4e0d\u8fc7\uff0c\u6709\u51e0\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u6700\u4e3a\u6d41\u884c\u548c\u5e7f\u6cdb\u4f7f\u7528\u3002 \u5176\u4e2d\u5305\u62ec\uff1a k \u6298\u4ea4\u53c9\u68c0\u9a8c \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c \u5206\u7ec4 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u4ea4\u53c9\u68c0\u9a8c\u662f\u5c06\u8bad\u7ec3\u6570\u636e\u5206\u5c42\u51e0\u4e2a\u90e8\u5206\uff0c\u6211\u4eec\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u5176\u4f59\u90e8\u5206\u4e0a\u8fdb\u884c\u6d4b\u8bd5\u3002\u8bf7\u770b\u56fe 4\u3002 \u56fe 4\uff1a\u5c06\u6570\u636e\u96c6\u62c6\u5206\u4e3a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u56fe 4 \u548c\u56fe 5 \u8bf4\u660e\uff0c\u5f53\u4f60\u5f97\u5230\u4e00\u4e2a\u6570\u636e\u96c6\u6765\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u65f6\uff0c\u4f60\u4f1a\u628a\u5b83\u4eec\u5206\u6210 \u4e24\u4e2a\u4e0d\u540c\u7684\u96c6\uff1a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u3002\u5f88\u591a\u4eba\u8fd8\u4f1a\u5c06\u5176\u5206\u6210\u7b2c\u4e09\u7ec4\uff0c\u79f0\u4e4b\u4e3a\u6d4b\u8bd5\u96c6\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u5c06\u53ea\u4f7f\u7528\u4e24\u4e2a\u96c6\u3002\u5982\u4f60\u6240\u89c1\uff0c\u6211\u4eec\u5c06\u6837\u672c\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\u8fdb\u884c\u4e86\u5212\u5206\u3002\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a k \u4e2a\u4e92\u4e0d\u5173\u8054\u7684\u4e0d\u540c\u96c6\u5408\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002 \u56fe 5\uff1aK \u6298\u4ea4\u53c9\u68c0\u9a8c \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 KFold \u5c06\u4efb\u4f55\u6570\u636e\u5206\u5272\u6210 k \u4e2a\u76f8\u7b49\u7684\u90e8\u5206\u3002\u6bcf\u4e2a\u6837\u672c\u5206\u914d\u4e00\u4e2a\u4ece 0 \u5230 k-1 \u7684\u503c\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u5b58\u50a8\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6211\u4eec\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a kfold \u7684\u65b0\u5217\uff0c\u5e76\u7528 -1 \u586b\u5145 df [ \"kfold\" ] = - 1 # \u63a5\u4e0b\u6765\u7684\u6b65\u9aa4\u662f\u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4ece model_selection \u6a21\u5757\u521d\u59cb\u5316 kfold \u7c7b kf = model_selection . KFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684 kfold \u5217\uff08enumerate\u7684\u4f5c\u7528\u662f\u8fd4\u56de\u4e00\u4e2a\u8fed\u4ee3\u5668\uff09 for fold , ( trn_ , val_ ) in enumerate ( kf . split ( X = df )): df . loc [ val_ , 'kfold' ] = fold # \u4fdd\u5b58\u5e26\u6709 kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . to_csv ( \"train_folds.csv\" , index = False ) \u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u636e\u96c6\u90fd\u53ef\u4ee5\u4f7f\u7528\u6b64\u6d41\u7a0b\u3002\u4f8b\u5982\uff0c\u5f53\u6570\u636e\u56fe\u50cf\u65f6\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u56fe\u50cf ID\u3001\u56fe\u50cf\u4f4d\u7f6e\u548c\u56fe\u50cf\u6807\u7b7e\u7684 CSV\uff0c\u7136\u540e\u4f7f\u7528\u4e0a\u8ff0\u6d41\u7a0b\u3002 \u53e6\u4e00\u79cd\u91cd\u8981\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u662f \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u6570\u636e\u96c6\uff0c\u5176\u4e2d\u6b63\u6837\u672c\u5360 90%\uff0c\u8d1f\u6837\u672c\u53ea\u5360 10%\uff0c\u90a3\u4e48\u4f60\u5c31\u4e0d\u5e94\u8be5\u4f7f\u7528\u968f\u673a k \u6298\u4ea4\u53c9\u3002\u5bf9\u8fd9\u6837\u7684\u6570\u636e\u96c6\u4f7f\u7528\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u5bfc\u81f4\u6298\u53e0\u6837\u672c\u5168\u90e8\u4e3a\u8d1f\u6837\u672c\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u66f4\u503e\u5411\u4e8e\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u4ee5\u4fdd\u6301\u6bcf\u4e2a\u6298\u4e2d\u6807\u7b7e\u7684\u6bd4\u4f8b\u4e0d\u53d8\u3002\u56e0\u6b64\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u90fd\u4f1a\u6709\u76f8\u540c\u7684 90% \u6b63\u6837\u672c\u548c 10% \u8d1f\u6837\u672c\u3002\u56e0\u6b64\uff0c\u65e0\u8bba\u60a8\u9009\u62e9\u4ec0\u4e48\u6307\u6807\u8fdb\u884c\u8bc4\u4f30\uff0c\u90fd\u4f1a\u5728\u6240\u6709\u6298\u53e0\u4e2d\u5f97\u5230\u76f8\u4f3c\u7684\u7ed3\u679c\u3002 \u4fee\u6539\u521b\u5efa k \u6298\u4ea4\u53c9\u68c0\u9a8c\u7684\u4ee3\u7801\u4ee5\u521b\u5efa\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5f88\u5bb9\u6613\u3002\u6211\u4eec\u53ea\u9700\u5c06 model_selection.KFold \u66f4\u6539\u4e3a model_selection.StratifiedKFold \uff0c\u5e76\u5728 kf.split(...) \u51fd\u6570\u4e2d\u6307\u5b9a\u8981\u5206\u5c42\u7684\u76ee\u6807\u5217\u3002\u6211\u4eec\u5047\u8bbe CSV \u6570\u636e\u96c6\u6709\u4e00\u5217\u540d\u4e3a \"target\" \uff0c\u5e76\u4e14\u662f\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u4fdd\u5b58\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6dfb\u52a0\u4e00\u4e2a\u65b0\u5217 kfold\uff0c\u5e76\u7528 -1 \u521d\u59cb\u5316 df [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u83b7\u53d6\u76ee\u6807\u53d8\u91cf y = df . target . values # \u521d\u59cb\u5316 StratifiedKFold \u7c7b\uff0c\u8bbe\u7f6e\u6298\u6570\uff08folds\uff09\u4e3a 5 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u4f7f\u7528 StratifiedKFold \u5bf9\u8c61\u7684 split \u65b9\u6cd5\u6765\u83b7\u53d6\u8bad\u7ec3\u548c\u9a8c\u8bc1\u7d22\u5f15 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u5305\u542b kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . to_csv ( \"train_folds.csv\" , index = False ) \u5bf9\u4e8e\u8461\u8404\u9152\u6570\u636e\u96c6\uff0c\u6211\u4eec\u6765\u770b\u770b\u6807\u7b7e\u7684\u5206\u5e03\u60c5\u51b5\u3002 b = sns . countplot ( x = 'quality' , data = df ) b . set_xlabel ( \"quality\" , fontsize = 20 ) b . set_ylabel ( \"count\" , fontsize = 20 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u7ee7\u7eed\u4e0a\u9762\u7684\u4ee3\u7801\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u8f6c\u6362\u4e86\u76ee\u6807\u503c\u3002\u4ece\u56fe 6 \u4e2d\u6211\u4eec\u53ef\u4ee5\u770b\u51fa\uff0c\u8d28\u91cf\u504f\u5dee\u5f88\u5927\u3002\u6709\u4e9b\u7c7b\u522b\u6709\u5f88\u591a\u6837\u672c\uff0c\u6709\u4e9b\u5219\u6ca1\u6709\u90a3\u4e48\u591a\u3002\u5982\u679c\u6211\u4eec\u8fdb\u884c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u90a3\u4e48\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u76ee\u6807\u503c\u5206\u5e03\u90fd\u4e0d\u4f1a\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u56fe 6\uff1a\u8461\u8404\u9152\u6570\u636e\u96c6\u4e2d \"\u8d28\u91cf\" \u5206\u5e03\u60c5\u51b5 \u89c4\u5219\u5f88\u7b80\u5355\uff0c\u5982\u679c\u662f\u6807\u51c6\u5206\u7c7b\u95ee\u9898\uff0c\u5c31\u76f2\u76ee\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u4f46\u5982\u679c\u6570\u636e\u91cf\u5f88\u5927\uff0c\u8be5\u600e\u4e48\u529e\u5462\uff1f\u5047\u8bbe\u6211\u4eec\u6709 100 \u4e07\u4e2a\u6837\u672c\u30025 \u500d\u4ea4\u53c9\u68c0\u9a8c\u610f\u5473\u7740\u5728 800k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u5728 200k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u9a8c\u8bc1\u3002\u6839\u636e\u6211\u4eec\u9009\u62e9\u7684\u7b97\u6cd5\uff0c\u5bf9\u4e8e\u8fd9\u6837\u89c4\u6a21\u7684\u6570\u636e\u96c6\u6765\u8bf4\uff0c\u8bad\u7ec3\u751a\u81f3\u9a8c\u8bc1\u90fd\u53ef\u80fd\u975e\u5e38\u6602\u8d35\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u3002 \u521b\u5efa\u4fdd\u6301\u7ed3\u679c\u7684\u8fc7\u7a0b\u4e0e\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u76f8\u540c\u3002\u5bf9\u4e8e\u62e5\u6709 100 \u4e07\u4e2a\u6837\u672c\u7684\u6570\u636e\u96c6\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa 10 \u4e2a\u6298\u53e0\u800c\u4e0d\u662f 5 \u4e2a\uff0c\u5e76\u4fdd\u7559\u5176\u4e2d\u4e00\u4e2a\u6298\u53e0\u4f5c\u4e3a\u4fdd\u7559\u6837\u672c\u3002\u8fd9\u610f\u5473\u7740\uff0c\u6211\u4eec\u5c06\u6709 10 \u4e07\u4e2a\u6837\u672c\u88ab\u4fdd\u7559\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u59cb\u7ec8\u5728\u8fd9\u4e2a\u6837\u672c\u96c6\u4e0a\u8ba1\u7b97\u635f\u5931\u3001\u51c6\u786e\u7387\u548c\u5176\u4ed6\u6307\u6807\uff0c\u5e76\u5728 90 \u4e07\u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\u3002 \u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u6682\u7559\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u975e\u5e38\u5e38\u7528\u3002\u5047\u8bbe\u6211\u4eec\u8981\u89e3\u51b3\u7684\u95ee\u9898\u662f\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97 2020 \u5e74\u7684\u9500\u552e\u989d\uff0c\u800c\u6211\u4eec\u5f97\u5230\u7684\u662f 2015-2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f60\u53ef\u4ee5\u9009\u62e9 2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u4f5c\u4e3a\u4fdd\u7559\u6570\u636e\uff0c\u7136\u540e\u5728 2015 \u5e74\u81f3 2018 \u5e74\u7684\u6240\u6709\u6570\u636e\u4e0a\u8bad\u7ec3\u4f60\u7684\u6a21\u578b\u3002 \u56fe 7\uff1a\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u793a\u4f8b \u5728\u56fe 7 \u6240\u793a\u7684\u793a\u4f8b\u4e2d\uff0c\u5047\u8bbe\u6211\u4eec\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u4ece\u65f6\u95f4\u6b65\u9aa4 31 \u5230 40 \u7684\u9500\u552e\u989d\u3002\u6211\u4eec\u53ef\u4ee5\u4fdd\u7559 21 \u81f3 30 \u6b65\u7684\u6570\u636e\uff0c\u7136\u540e\u4ece 0 \u6b65\u5230 20 \u6b65\u8bad\u7ec3\u6a21\u578b\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5728\u9884\u6d4b 31 \u6b65\u81f3 40 \u6b65\u65f6\uff0c\u5e94\u5c06 21 \u6b65\u81f3 30 \u6b65\u7684\u6570\u636e\u7eb3\u5165\u6a21\u578b\uff0c\u5426\u5219\uff0c\u6a21\u578b\u7684\u6027\u80fd\u5c06\u5927\u6253\u6298\u6263\u3002 \u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u5fc5\u987b\u5904\u7406\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u800c\u521b\u5efa\u5927\u578b\u9a8c\u8bc1\u96c6\u610f\u5473\u7740\u6a21\u578b\u5b66\u4e60\u4f1a\u4e22\u5931\u5927\u91cf\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9\u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c\uff0c\u76f8\u5f53\u4e8e\u7279\u6b8a\u7684 k \u5219\u4ea4\u53c9\u68c0\u9a8c\u5176\u4e2d k=N \uff0cN \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8fd9\u610f\u5473\u7740\u5728\u6240\u6709\u7684\u8bad\u7ec3\u6298\u53e0\u4e2d\uff0c\u6211\u4eec\u5c06\u5bf9\u9664 1 \u4e4b\u5916\u7684\u6240\u6709\u6570\u636e\u6837\u672c\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u7684\u6298\u53e0\u6570\u4e0e\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u76f8\u540c\u3002 \u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u6a21\u578b\u7684\u901f\u5ea6\u4e0d\u591f\u5feb\uff0c\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\uff0c\u4f46\u7531\u4e8e\u8fd9\u79cd\u4ea4\u53c9\u68c0\u9a8c\u53ea\u9002\u7528\u4e8e\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u56e0\u6b64\u5e76\u4e0d\u91cd\u8981\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u95ee\u9898\u4e86\u3002\u56de\u5f52\u95ee\u9898\u7684\u597d\u5904\u5728\u4e8e\uff0c\u9664\u4e86\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e4b\u5916\uff0c\u6211\u4eec\u53ef\u4ee5\u5728\u56de\u5f52\u95ee\u9898\u4e0a\u4f7f\u7528\u4e0a\u8ff0\u6240\u6709\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u6211\u4eec\u4e0d\u80fd\u76f4\u63a5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u6709\u4e00\u4e9b\u65b9\u6cd5\u53ef\u4ee5\u7a0d\u7a0d\u6539\u53d8\u95ee\u9898\uff0c\u4ece\u800c\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u9002\u7528\u4e8e\u4efb\u4f55\u56de\u5f52\u95ee\u9898\u3002\u4f46\u662f\uff0c\u5982\u679c\u53d1\u73b0\u76ee\u6807\u5206\u5e03\u4e0d\u4e00\u81f4\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u8981\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u5fc5\u987b\u5148\u5c06\u76ee\u6807\u5212\u5206\u4e3a\u82e5\u5e72\u4e2a\u5206\u5c42\uff0c\u7136\u540e\u518d\u4ee5\u5904\u7406\u5206\u7c7b\u95ee\u9898\u7684\u76f8\u540c\u65b9\u5f0f\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u9009\u62e9\u5408\u9002\u7684\u5206\u5c42\u6570\u6709\u51e0\u79cd\u9009\u62e9\u3002\u5982\u679c\u6837\u672c\u91cf\u5f88\u5927\uff08> 10k\uff0c> 100k\uff09\uff0c\u90a3\u4e48\u5c31\u4e0d\u9700\u8981\u8003\u8651\u5206\u5c42\u7684\u6570\u91cf\u3002\u53ea\u9700\u5c06\u6570\u636e\u5206\u4e3a 10 \u6216 20 \u5c42\u5373\u53ef\u3002\u5982\u679c\u6837\u672c\u6570\u4e0d\u591a\uff0c\u5219\u53ef\u4ee5\u4f7f\u7528 Sturge's Rule \u8fd9\u6837\u7684\u7b80\u5355\u89c4\u5219\u6765\u8ba1\u7b97\u9002\u5f53\u7684\u5206\u5c42\u6570\u3002 Sturge's Rule\uff1a \\[ Number of Bins = 1 + log_2(N) \\] \u5176\u4e2d \\(N\\) \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8be5\u51fd\u6570\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u5229\u7528\u65af\u7279\u683c\u6cd5\u5219\u7ed8\u5236\u6837\u672c\u4e0e\u7bb1\u6570\u5bf9\u6bd4\u56fe \u8ba9\u6211\u4eec\u5236\u4f5c\u4e00\u4e2a\u56de\u5f52\u6570\u636e\u96c6\u6837\u672c\uff0c\u5e76\u5c1d\u8bd5\u5e94\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 # stratified-kfold for regression # \u4e3a\u56de\u5f52\u95ee\u9898\u8fdb\u884c\u5206\u5c42K-\u6298\u4ea4\u53c9\u9a8c\u8bc1 # \u5bfc\u5165\u9700\u8981\u7684\u5e93 import numpy as np import pandas as pd from sklearn import datasets from sklearn import model_selection # \u521b\u5efa\u5206\u6298\uff08folds\uff09\u7684\u51fd\u6570 def create_folds ( data ): # \u521b\u5efa\u4e00\u4e2a\u65b0\u5217\u53eb\u505akfold\uff0c\u5e76\u7528-1\u6765\u586b\u5145 data [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c data = data . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4f7f\u7528Sturge\u89c4\u5219\u8ba1\u7b97bin\u7684\u6570\u91cf num_bins = int ( np . floor ( 1 + np . log2 ( len ( data )))) # \u4f7f\u7528pandas\u7684cut\u51fd\u6570\u8fdb\u884c\u76ee\u6807\u53d8\u91cf\uff08target\uff09\u7684\u5206\u7bb1 data . loc [:, \"bins\" ] = pd . cut ( data [ \"target\" ], bins = num_bins , labels = False ) # \u521d\u59cb\u5316StratifiedKFold\u7c7b kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684kfold\u5217 # \u6ce8\u610f\uff1a\u6211\u4eec\u4f7f\u7528\u7684\u662fbins\u800c\u4e0d\u662f\u5b9e\u9645\u7684\u76ee\u6807\u53d8\u91cf\uff08target\uff09\uff01 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = data , y = data . bins . values )): data . loc [ v_ , 'kfold' ] = f # \u5220\u9664bins\u5217 data = data . drop ( \"bins\" , axis = 1 ) # \u8fd4\u56de\u5305\u542bfolds\u7684\u6570\u636e return data # \u4e3b\u7a0b\u5e8f\u5f00\u59cb if __name__ == \"__main__\" : # \u521b\u5efa\u4e00\u4e2a\u5e26\u670915000\u4e2a\u6837\u672c\u3001100\u4e2a\u7279\u5f81\u548c1\u4e2a\u76ee\u6807\u53d8\u91cf\u7684\u6837\u672c\u6570\u636e\u96c6 X , y = datasets . make_regression ( n_samples = 15000 , n_features = 100 , n_targets = 1 ) # \u4f7f\u7528numpy\u6570\u7ec4\u521b\u5efa\u4e00\u4e2a\u6570\u636e\u6846 df = pd . DataFrame ( X , columns = [ f \"f_ { i } \" for i in range ( X . shape [ 1 ])] ) df . loc [:, \"target\" ] = y # \u521b\u5efafolds df = create_folds ( df ) \u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u7b2c\u4e00\u6b65\uff0c\u4e5f\u662f\u6700\u57fa\u672c\u7684\u4e00\u6b65\u3002\u5982\u679c\u8981\u505a\u7279\u5f81\u5de5\u7a0b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u8981\u5efa\u7acb\u6a21\u578b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u65b9\u6848\uff0c\u5176\u4e2d\u9a8c\u8bc1\u6570\u636e\u80fd\u591f\u4ee3\u8868\u8bad\u7ec3\u6570\u636e\u548c\u771f\u5b9e\u4e16\u754c\u7684\u6570\u636e\uff0c\u90a3\u4e48\u4f60\u5c31\u80fd\u5efa\u7acb\u4e00\u4e2a\u5177\u6709\u9ad8\u5ea6\u901a\u7528\u6027\u7684\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u3002 \u672c\u7ae0\u4ecb\u7ecd\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u6570\u636e\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u6839\u636e\u4f60\u7684\u95ee\u9898\u548c\u6570\u636e\u91c7\u7528\u65b0\u7684\u4ea4\u53c9\u68c0\u9a8c\u5f62\u5f0f\u3002 \u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u95ee\u9898\uff0c\u5e0c\u671b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u4ece\u60a3\u8005\u7684\u76ae\u80a4\u56fe\u50cf\u4e2d\u68c0\u6d4b\u51fa\u76ae\u80a4\u764c\u3002\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u5668\uff0c\u8be5\u5206\u7c7b\u5668\u63a5\u6536\u8f93\u5165\u56fe\u50cf\u5e76\u9884\u6d4b\u5176\u826f\u6027\u6216\u6076\u6027\u7684\u6982\u7387\u3002 \u5728\u8fd9\u7c7b\u6570\u636e\u96c6\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u96c6\u4e2d\u53ef\u80fd\u6709\u540c\u4e00\u60a3\u8005\u7684\u591a\u5f20\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u8981\u5728\u8fd9\u91cc\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u7cfb\u7edf\uff0c\u5fc5\u987b\u6709\u5206\u5c42\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u4e5f\u5fc5\u987b\u786e\u4fdd\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u60a3\u8005\u4e0d\u4f1a\u51fa\u73b0\u5728\u9a8c\u8bc1\u6570\u636e\u4e2d\u3002\u5e78\u8fd0\u7684\u662f\uff0cscikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u79f0\u4e3a GroupKFold \u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u3002 \u5728\u8fd9\u91cc\uff0c\u60a3\u8005\u53ef\u4ee5\u88ab\u89c6\u4e3a\u7ec4\u3002 \u4f46\u9057\u61be\u7684\u662f\uff0cscikit-learn \u65e0\u6cd5\u5c06 GroupKFold \u4e0e StratifiedKFold \u7ed3\u5408\u8d77\u6765\u3002\u6240\u4ee5\u4f60\u9700\u8981\u81ea\u5df1\u52a8\u624b\u3002\u6211\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u8bfb\u8005\u7684\u7ec3\u4e60\u3002","title":"\u4ea4\u53c9\u68c0\u9a8c"},{"location":"%E4%BA%A4%E5%8F%89%E6%A3%80%E9%AA%8C/#_1","text":"\u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u6ca1\u6709\u5efa\u7acb\u4efb\u4f55\u6a21\u578b\u3002\u539f\u56e0\u5f88\u7b80\u5355\uff0c\u5728\u521b\u5efa\u4efb\u4f55\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u77e5\u9053\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4ee5\u53ca\u5982\u4f55\u6839\u636e\u6570\u636e\u96c6\u9009\u62e9\u6700\u4f73\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e\u96c6\u3002 \u90a3\u4e48\uff0c\u4ec0\u4e48\u662f \u4ea4\u53c9\u68c0\u9a8c \uff0c\u6211\u4eec\u4e3a\u4ec0\u4e48\u8981\u5173\u6ce8\u5b83\uff1f \u5173\u4e8e\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u53ef\u4ee5\u627e\u5230\u591a\u79cd\u5b9a\u4e49\u3002\u6211\u7684\u5b9a\u4e49\u53ea\u6709\u4e00\u53e5\u8bdd\uff1a\u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fc7\u7a0b\u4e2d\u7684\u4e00\u4e2a\u6b65\u9aa4\uff0c\u5b83\u53ef\u4ee5\u5e2e\u52a9\u6211\u4eec\u786e\u4fdd\u6a21\u578b\u51c6\u786e\u62df\u5408\u6570\u636e\uff0c\u540c\u65f6\u786e\u4fdd\u6211\u4eec\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u4f46\u8fd9\u53c8\u5f15\u51fa\u4e86\u53e6\u4e00\u4e2a\u8bcd\uff1a \u8fc7\u62df\u5408 \u3002 \u8981\u89e3\u91ca\u8fc7\u62df\u5408\uff0c\u6211\u8ba4\u4e3a\u6700\u597d\u5148\u770b\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6709\u4e00\u4e2a\u76f8\u5f53\u6709\u540d\u7684\u7ea2\u9152\u8d28\u91cf\u6570\u636e\u96c6\uff08 red wine quality dataset \uff09\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 11 \u4e2a\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u8fd9\u4e9b\u7279\u5f81\u51b3\u5b9a\u4e86\u7ea2\u9152\u7684\u8d28\u91cf\u3002 \u8fd9\u4e9b\u5c5e\u6027\u5305\u62ec\uff1a \u56fa\u5b9a\u9178\u5ea6\uff08fixed acidity\uff09 \u6325\u53d1\u6027\u9178\u5ea6\uff08volatile acidity\uff09 \u67e0\u6aac\u9178\uff08citric acid\uff09 \u6b8b\u7559\u7cd6\uff08residual sugar\uff09 \u6c2f\u5316\u7269\uff08chlorides\uff09 \u6e38\u79bb\u4e8c\u6c27\u5316\u786b\uff08free sulfur dioxide\uff09 \u4e8c\u6c27\u5316\u786b\u603b\u91cf\uff08total sulfur dioxide\uff09 \u5bc6\u5ea6\uff08density\uff09 PH \u503c\uff08pH\uff09 \u786b\u9178\u76d0\uff08sulphates\uff09 \u9152\u7cbe\uff08alcohol\uff09 \u6839\u636e\u8fd9\u4e9b\u4e0d\u540c\u7279\u5f81\uff0c\u6211\u4eec\u9700\u8981\u9884\u6d4b\u7ea2\u8461\u8404\u9152\u7684\u8d28\u91cf\uff0c\u8d28\u91cf\u503c\u4ecb\u4e8e 0 \u5230 10 \u4e4b\u95f4\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u6570\u636e\u662f\u600e\u6837\u7684\u3002 import pandas as pd df = pd . read_csv ( \"winequality-red.csv\" ) \u56fe 1:\u7ea2\u8461\u8404\u9152\u8d28\u91cf\u6570\u636e\u96c6\u7b80\u5355\u5c55\u793a \u6211\u4eec\u53ef\u4ee5\u5c06\u8fd9\u4e2a\u95ee\u9898\u89c6\u4e3a\u5206\u7c7b\u95ee\u9898\uff0c\u4e5f\u53ef\u4ee5\u89c6\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u9009\u62e9\u5206\u7c7b\u3002\u7136\u800c\uff0c\u8fd9\u4e2a\u6570\u636e\u96c6\u503c\u5305\u542b 6 \u79cd\u8d28\u91cf\u503c\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u6240\u6709\u8d28\u91cf\u503c\u6620\u5c04\u5230 0 \u5230 5 \u4e4b\u95f4\u3002 # \u4e00\u4e2a\u6620\u5c04\u5b57\u5178\uff0c\u7528\u4e8e\u5c06\u8d28\u91cf\u503c\u4ece 0 \u5230 5 \u8fdb\u884c\u6620\u5c04 quality_mapping = { 3 : 0 , 4 : 1 , 5 : 2 , 6 : 3 , 7 : 4 , 8 : 5 } # \u4f60\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 map \u51fd\u6570\u4ee5\u53ca\u4efb\u4f55\u5b57\u5178\uff0c # \u6765\u8f6c\u6362\u7ed9\u5b9a\u5217\u4e2d\u7684\u503c\u4e3a\u5b57\u5178\u4e2d\u7684\u503c df . loc [:, \"quality\" ] = df . quality . map ( quality_mapping ) \u5f53\u6211\u4eec\u770b\u5927\u8fd9\u4e9b\u6570\u636e\u5e76\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u65f6\uff0c\u6211\u4eec\u8111\u6d77\u4e2d\u4f1a\u6d6e\u73b0\u51fa\u5f88\u591a\u53ef\u4ee5\u5e94\u7528\u7684\u7b97\u6cd5\uff0c\u4e5f\u8bb8\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u4ece\u4e00\u5f00\u59cb\u5c31\u6df1\u5165\u7814\u7a76\u795e\u7ecf\u7f51\u7edc\uff0c\u90a3\u5c31\u6709\u70b9\u7275\u5f3a\u4e86\u3002\u6240\u4ee5\uff0c\u8ba9\u6211\u4eec\u4ece\u7b80\u5355\u7684\u3001\u6211\u4eec\u4e5f\u80fd\u53ef\u89c6\u5316\u7684\u4e1c\u897f\u5f00\u59cb\uff1a\u51b3\u7b56\u6811\u3002 \u5728\u5f00\u59cb\u4e86\u89e3\u4ec0\u4e48\u662f\u8fc7\u62df\u5408\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 1599 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u4fdd\u7559 1000 \u4e2a\u6837\u672c\u7528\u4e8e\u8bad\u7ec3\uff0c599 \u4e2a\u6837\u672c\u4f5c\u4e3a\u4e00\u4e2a\u5355\u72ec\u7684\u96c6\u5408\u3002 \u4ee5\u4e0b\u4ee3\u7801\u53ef\u4ee5\u8f7b\u677e\u5b8c\u6210\u5212\u5206\uff1a # \u4f7f\u7528 frac=1 \u7684 sample \u65b9\u6cd5\u6765\u6253\u4e71 dataframe # \u7531\u4e8e\u6253\u4e71\u540e\u7d22\u5f15\u4f1a\u6539\u53d8\uff0c\u6240\u4ee5\u6211\u4eec\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u9009\u53d6\u524d 1000 \u884c\u4f5c\u4e3a\u8bad\u7ec3\u6570\u636e df_train = df . head ( 1000 ) # \u9009\u53d6\u6700\u540e\u7684 599 \u884c\u4f5c\u4e3a\u6d4b\u8bd5/\u9a8c\u8bc1\u6570\u636e df_test = df . tail ( 599 ) \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u5728\u8bad\u7ec3\u96c6\u4e0a\u4f7f\u7528 scikit-learn \u8bad\u7ec3\u4e00\u4e2a\u51b3\u7b56\u6811\u6a21\u578b\u3002 # \u4ece scikit-learn \u5bfc\u5165\u9700\u8981\u7684\u6a21\u5757 from sklearn import tree from sklearn import metrics # \u521d\u59cb\u5316\u4e00\u4e2a\u51b3\u7b56\u6811\u5206\u7c7b\u5668\uff0c\u8bbe\u7f6e\u6700\u5927\u6df1\u5ea6\u4e3a 3 clf = tree . DecisionTreeClassifier ( max_depth = 3 ) # \u9009\u62e9\u4f60\u60f3\u8981\u8bad\u7ec3\u6a21\u578b\u7684\u5217 # \u8fd9\u4e9b\u5217\u4f5c\u4e3a\u6a21\u578b\u7684\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u4f7f\u7528\u4e4b\u524d\u6620\u5c04\u7684\u8d28\u91cf\u4ee5\u53ca\u63d0\u4f9b\u7684\u7279\u5f81\u6765\u8bad\u7ec3\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5c06\u51b3\u7b56\u6811\u5206\u7c7b\u5668\u7684\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u8bbe\u4e3a 3\u3002\u8be5\u6a21\u578b\u7684\u6240\u6709\u5176\u4ed6\u53c2\u6570\u5747\u4fdd\u6301\u9ed8\u8ba4\u503c\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u6d4b\u8bd5\u8be5\u6a21\u578b\u7684\u51c6\u786e\u6027\uff1a # \u5728\u8bad\u7ec3\u96c6\u4e0a\u751f\u6210\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) # \u5728\u6d4b\u8bd5\u96c6\u4e0a\u751f\u6210\u9884\u6d4b test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) # \u8ba1\u7b97\u6d4b\u8bd5\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 test_accuracy = metrics . accuracy_score ( df_test . quality , test_predictions ) \u8bad\u7ec3\u548c\u6d4b\u8bd5\u7684\u51c6\u786e\u7387\u5206\u522b\u4e3a 58.9%\u548c 54.25%\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5c06\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u589e\u52a0\u5230 7\uff0c\u5e76\u91cd\u590d\u4e0a\u8ff0\u8fc7\u7a0b\u3002\u8fd9\u6837\uff0c\u8bad\u7ec3\u51c6\u786e\u7387\u4e3a 76.6%\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4e3a 57.3%\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4f7f\u7528\u51c6\u786e\u7387\uff0c\u4e3b\u8981\u662f\u56e0\u4e3a\u5b83\u662f\u6700\u76f4\u63a5\u7684\u6307\u6807\u3002\u5bf9\u4e8e\u8fd9\u4e2a\u95ee\u9898\u6765\u8bf4\uff0c\u5b83\u53ef\u80fd\u4e0d\u662f\u6700\u597d\u7684\u6307\u6807\u3002\u6211\u4eec\u53ef\u4ee5\u6839\u636e\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u4e0d\u540c\u503c\u6765\u8ba1\u7b97\u8fd9\u4e9b\u51c6\u786e\u7387\uff0c\u5e76\u7ed8\u5236\u66f2\u7ebf\u56fe\u3002 # \u6ce8\u610f\uff1a\u8fd9\u6bb5\u4ee3\u7801\u5728 Jupyter \u7b14\u8bb0\u672c\u4e2d\u7f16\u5199 # \u5bfc\u5165 scikit-learn \u7684 tree \u548c metrics from sklearn import tree from sklearn import metrics # \u5bfc\u5165 matplotlib \u548c seaborn # \u7528\u4e8e\u7ed8\u56fe import matplotlib import matplotlib.pyplot as plt import seaborn as sns # \u8bbe\u7f6e\u5168\u5c40\u6807\u7b7e\u6587\u672c\u7684\u5927\u5c0f matplotlib . rc ( 'xtick' , labelsize = 20 ) matplotlib . rc ( 'ytick' , labelsize = 20 ) # \u786e\u4fdd\u56fe\u8868\u76f4\u63a5\u5728\u7b14\u8bb0\u672c\u5185\u663e\u793a % matplotlib inline # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6\u7684\u5217\u8868 # \u6211\u4eec\u4ece 50% \u7684\u51c6\u786e\u5ea6\u5f00\u59cb train_accuracies = [ 0.5 ] test_accuracies = [ 0.5 ] # \u904d\u5386\u51e0\u4e2a\u4e0d\u540c\u7684\u6811\u6df1\u5ea6\u503c for depth in range ( 1 , 25 ): # \u521d\u59cb\u5316\u6a21\u578b clf = tree . DecisionTreeClassifier ( max_depth = depth ) # \u9009\u62e9\u7528\u4e8e\u8bad\u7ec3\u7684\u5217/\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u5728\u7ed9\u5b9a\u7279\u5f81\u4e0a\u62df\u5408\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) # \u521b\u5efa\u8bad\u7ec3\u548c\u6d4b\u8bd5\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) test_accuracy = metrics . accuracy_score ( df_test . quality , test_predictions ) # \u6dfb\u52a0\u51c6\u786e\u5ea6\u5230\u5217\u8868 train_accuracies . append ( train_accuracy ) test_accuracies . append ( test_accuracy ) # \u4f7f\u7528 matplotlib \u548c seaborn \u521b\u5efa\u4e24\u4e2a\u56fe plt . figure ( figsize = ( 10 , 5 )) sns . set_style ( \"whitegrid\" ) plt . plot ( train_accuracies , label = \"train accuracy\" ) plt . plot ( test_accuracies , label = \"test accuracy\" ) plt . legend ( loc = \"upper left\" , prop = { 'size' : 15 }) plt . xticks ( range ( 0 , 26 , 5 )) plt . xlabel ( \"max_depth\" , size = 20 ) plt . ylabel ( \"accuracy\" , size = 20 ) plt . show () \u8fd9\u5c06\u751f\u6210\u5982\u56fe 2 \u6240\u793a\u7684\u66f2\u7ebf\u56fe\u3002 \u56fe 2\uff1a\u4e0d\u540c max_depth \u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u7387\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5f53\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u503c\u4e3a 14 \u65f6\uff0c\u6d4b\u8bd5\u6570\u636e\u7684\u5f97\u5206\u6700\u9ad8\u3002\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u589e\u52a0\u8fd9\u4e2a\u53c2\u6570\u7684\u503c\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4f1a\u4fdd\u6301\u4e0d\u53d8\u6216\u53d8\u5dee\uff0c\u4f46\u8bad\u7ec3\u51c6\u786e\u7387\u4f1a\u4e0d\u65ad\u63d0\u9ad8\u3002\u8fd9\u8bf4\u660e\uff0c\u968f\u7740\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u589e\u52a0\uff0c\u51b3\u7b56\u6811\u6a21\u578b\u5bf9\u8bad\u7ec3\u6570\u636e\u7684\u5b66\u4e60\u6548\u679c\u8d8a\u6765\u8d8a\u597d\uff0c\u4f46\u6d4b\u8bd5\u6570\u636e\u7684\u6027\u80fd\u5374\u4e1d\u6beb\u6ca1\u6709\u63d0\u9ad8\u3002 \u8fd9\u5c31\u662f\u6240\u8c13\u7684\u8fc7\u62df\u5408 \u3002 \u6a21\u578b\u5728\u8bad\u7ec3\u96c6\u4e0a\u5b8c\u5168\u62df\u5408\uff0c\u800c\u5728\u6d4b\u8bd5\u96c6\u4e0a\u5374\u8868\u73b0\u4e0d\u4f73\u3002\u8fd9\u610f\u5473\u7740\u6a21\u578b\u53ef\u4ee5\u5f88\u597d\u5730\u5b66\u4e60\u8bad\u7ec3\u6570\u636e\uff0c\u4f46\u65e0\u6cd5\u6cdb\u5316\u5230\u672a\u89c1\u8fc7\u7684\u6837\u672c\u4e0a\u3002\u5728\u4e0a\u9762\u7684\u6570\u636e\u96c6\u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u975e\u5e38\u9ad8\u7684\u6a21\u578b\uff0c\u5b83\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u4f1a\u6709\u51fa\u8272\u7684\u7ed3\u679c\uff0c\u4f46\u8fd9\u79cd\u6a21\u578b\u5e76\u4e0d\u5b9e\u7528\uff0c\u56e0\u4e3a\u5b83\u5728\u771f\u5b9e\u4e16\u754c\u7684\u6837\u672c\u6216\u5b9e\u65f6\u6570\u636e\u4e0a\u4e0d\u4f1a\u63d0\u4f9b\u7c7b\u4f3c\u7684\u7ed3\u679c\u3002 \u6709\u4eba\u53ef\u80fd\u4f1a\u8bf4\uff0c\u8fd9\u79cd\u65b9\u6cd5\u5e76\u6ca1\u6709\u8fc7\u62df\u5408\uff0c\u56e0\u4e3a\u6d4b\u8bd5\u96c6\u7684\u51c6\u786e\u7387\u57fa\u672c\u4fdd\u6301\u4e0d\u53d8\u3002\u8fc7\u62df\u5408\u7684\u53e6\u4e00\u4e2a\u5b9a\u4e49\u662f\uff0c\u5f53\u6211\u4eec\u4e0d\u65ad\u63d0\u9ad8\u8bad\u7ec3\u635f\u5931\u65f6\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u5728\u589e\u52a0\u3002\u8fd9\u79cd\u60c5\u51b5\u5728\u795e\u7ecf\u7f51\u7edc\u4e2d\u975e\u5e38\u5e38\u89c1\u3002 \u6bcf\u5f53\u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u795e\u7ecf\u7f51\u7edc\u65f6\uff0c\u90fd\u5fc5\u987b\u5728\u8bad\u7ec3\u671f\u95f4\u76d1\u63a7\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u3002\u5982\u679c\u6211\u4eec\u6709\u4e00\u4e2a\u975e\u5e38\u5927\u7684\u7f51\u7edc\u6765\u5904\u7406\u4e00\u4e2a\u975e\u5e38\u5c0f\u7684\u6570\u636e\u96c6\uff08\u5373\u6837\u672c\u6570\u975e\u5e38\u5c11\uff09\uff0c\u6211\u4eec\u5c31\u4f1a\u89c2\u5bdf\u5230\uff0c\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u8bad\u7ec3\uff0c\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u90fd\u4f1a\u51cf\u5c11\u3002\u4f46\u662f\uff0c\u5728\u67d0\u4e2a\u65f6\u523b\uff0c\u6d4b\u8bd5\u635f\u5931\u4f1a\u8fbe\u5230\u6700\u5c0f\u503c\uff0c\u4e4b\u540e\uff0c\u5373\u4f7f\u8bad\u7ec3\u635f\u5931\u8fdb\u4e00\u6b65\u51cf\u5c11\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u4f1a\u5f00\u59cb\u589e\u52a0\u3002\u6211\u4eec\u5fc5\u987b\u5728\u9a8c\u8bc1\u635f\u5931\u8fbe\u5230\u6700\u5c0f\u503c\u65f6\u505c\u6b62\u8bad\u7ec3\u3002 \u8fd9\u662f\u5bf9\u8fc7\u62df\u5408\u6700\u5e38\u89c1\u7684\u89e3\u91ca \u3002 \u5965\u5361\u59c6\u5243\u5200\u7528\u7b80\u5355\u7684\u8bdd\u8bf4\uff0c\u5c31\u662f\u4e0d\u8981\u8bd5\u56fe\u628a\u53ef\u4ee5\u7528\u7b80\u5355\u5f97\u591a\u7684\u65b9\u6cd5\u89e3\u51b3\u7684\u4e8b\u60c5\u590d\u6742\u5316\u3002\u6362\u53e5\u8bdd\u8bf4\uff0c\u6700\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6848\u5c31\u662f\u6700\u5177\u901a\u7528\u6027\u7684\u89e3\u51b3\u65b9\u6848\u3002\u4e00\u822c\u6765\u8bf4\uff0c\u53ea\u8981\u4f60\u7684\u6a21\u578b\u4e0d\u7b26\u5408\u5965\u5361\u59c6\u5243\u5200\u539f\u5219\uff0c\u5c31\u5f88\u53ef\u80fd\u662f\u8fc7\u62df\u5408\u3002 \u56fe 3\uff1a\u8fc7\u62df\u5408\u7684\u6700\u4e00\u822c\u5b9a\u4e49 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u56de\u5230\u4ea4\u53c9\u68c0\u9a8c\u3002 \u5728\u89e3\u91ca\u8fc7\u62df\u5408\u65f6\uff0c\u6211\u51b3\u5b9a\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u6211\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u53e6\u4e00\u90e8\u5206\u4e0a\u68c0\u67e5\u5176\u6027\u80fd\u3002\u8fd9\u4e5f\u662f\u4ea4\u53c9\u68c0\u9a8c\u7684\u4e00\u79cd\uff0c\u901a\u5e38\u88ab\u79f0\u4e3a \"\u6682\u7559\u96c6\"\uff08 hold-out set \uff09\u3002\u5f53\u6211\u4eec\u62e5\u6709\u5927\u91cf\u6570\u636e\uff0c\u800c\u6a21\u578b\u63a8\u7406\u662f\u4e00\u4e2a\u8017\u65f6\u7684\u8fc7\u7a0b\u65f6\uff0c\u6211\u4eec\u5c31\u4f1a\u4f7f\u7528\u8fd9\u79cd\uff08\u4ea4\u53c9\uff09\u9a8c\u8bc1\u3002 \u4ea4\u53c9\u68c0\u9a8c\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\uff0c\u5b83\u662f\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u6b65\u9aa4\u3002 \u9009\u62e9\u6b63\u786e\u7684\u4ea4\u53c9\u68c0\u9a8c \u53d6\u51b3\u4e8e\u6240\u5904\u7406\u7684\u6570\u636e\u96c6\uff0c\u5728\u4e00\u4e2a\u6570\u636e\u96c6\u4e0a\u9002\u7528\u7684\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u53ef\u80fd\u4e0d\u9002\u7528\u4e8e\u5176\u4ed6\u6570\u636e\u96c6\u3002\u4e0d\u8fc7\uff0c\u6709\u51e0\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u6700\u4e3a\u6d41\u884c\u548c\u5e7f\u6cdb\u4f7f\u7528\u3002 \u5176\u4e2d\u5305\u62ec\uff1a k \u6298\u4ea4\u53c9\u68c0\u9a8c \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c \u5206\u7ec4 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u4ea4\u53c9\u68c0\u9a8c\u662f\u5c06\u8bad\u7ec3\u6570\u636e\u5206\u5c42\u51e0\u4e2a\u90e8\u5206\uff0c\u6211\u4eec\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u5176\u4f59\u90e8\u5206\u4e0a\u8fdb\u884c\u6d4b\u8bd5\u3002\u8bf7\u770b\u56fe 4\u3002 \u56fe 4\uff1a\u5c06\u6570\u636e\u96c6\u62c6\u5206\u4e3a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u56fe 4 \u548c\u56fe 5 \u8bf4\u660e\uff0c\u5f53\u4f60\u5f97\u5230\u4e00\u4e2a\u6570\u636e\u96c6\u6765\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u65f6\uff0c\u4f60\u4f1a\u628a\u5b83\u4eec\u5206\u6210 \u4e24\u4e2a\u4e0d\u540c\u7684\u96c6\uff1a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u3002\u5f88\u591a\u4eba\u8fd8\u4f1a\u5c06\u5176\u5206\u6210\u7b2c\u4e09\u7ec4\uff0c\u79f0\u4e4b\u4e3a\u6d4b\u8bd5\u96c6\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u5c06\u53ea\u4f7f\u7528\u4e24\u4e2a\u96c6\u3002\u5982\u4f60\u6240\u89c1\uff0c\u6211\u4eec\u5c06\u6837\u672c\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\u8fdb\u884c\u4e86\u5212\u5206\u3002\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a k \u4e2a\u4e92\u4e0d\u5173\u8054\u7684\u4e0d\u540c\u96c6\u5408\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002 \u56fe 5\uff1aK \u6298\u4ea4\u53c9\u68c0\u9a8c \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 KFold \u5c06\u4efb\u4f55\u6570\u636e\u5206\u5272\u6210 k \u4e2a\u76f8\u7b49\u7684\u90e8\u5206\u3002\u6bcf\u4e2a\u6837\u672c\u5206\u914d\u4e00\u4e2a\u4ece 0 \u5230 k-1 \u7684\u503c\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u5b58\u50a8\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6211\u4eec\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a kfold \u7684\u65b0\u5217\uff0c\u5e76\u7528 -1 \u586b\u5145 df [ \"kfold\" ] = - 1 # \u63a5\u4e0b\u6765\u7684\u6b65\u9aa4\u662f\u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4ece model_selection \u6a21\u5757\u521d\u59cb\u5316 kfold \u7c7b kf = model_selection . KFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684 kfold \u5217\uff08enumerate\u7684\u4f5c\u7528\u662f\u8fd4\u56de\u4e00\u4e2a\u8fed\u4ee3\u5668\uff09 for fold , ( trn_ , val_ ) in enumerate ( kf . split ( X = df )): df . loc [ val_ , 'kfold' ] = fold # \u4fdd\u5b58\u5e26\u6709 kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . to_csv ( \"train_folds.csv\" , index = False ) \u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u636e\u96c6\u90fd\u53ef\u4ee5\u4f7f\u7528\u6b64\u6d41\u7a0b\u3002\u4f8b\u5982\uff0c\u5f53\u6570\u636e\u56fe\u50cf\u65f6\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u56fe\u50cf ID\u3001\u56fe\u50cf\u4f4d\u7f6e\u548c\u56fe\u50cf\u6807\u7b7e\u7684 CSV\uff0c\u7136\u540e\u4f7f\u7528\u4e0a\u8ff0\u6d41\u7a0b\u3002 \u53e6\u4e00\u79cd\u91cd\u8981\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u662f \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u6570\u636e\u96c6\uff0c\u5176\u4e2d\u6b63\u6837\u672c\u5360 90%\uff0c\u8d1f\u6837\u672c\u53ea\u5360 10%\uff0c\u90a3\u4e48\u4f60\u5c31\u4e0d\u5e94\u8be5\u4f7f\u7528\u968f\u673a k \u6298\u4ea4\u53c9\u3002\u5bf9\u8fd9\u6837\u7684\u6570\u636e\u96c6\u4f7f\u7528\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u5bfc\u81f4\u6298\u53e0\u6837\u672c\u5168\u90e8\u4e3a\u8d1f\u6837\u672c\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u66f4\u503e\u5411\u4e8e\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u4ee5\u4fdd\u6301\u6bcf\u4e2a\u6298\u4e2d\u6807\u7b7e\u7684\u6bd4\u4f8b\u4e0d\u53d8\u3002\u56e0\u6b64\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u90fd\u4f1a\u6709\u76f8\u540c\u7684 90% \u6b63\u6837\u672c\u548c 10% \u8d1f\u6837\u672c\u3002\u56e0\u6b64\uff0c\u65e0\u8bba\u60a8\u9009\u62e9\u4ec0\u4e48\u6307\u6807\u8fdb\u884c\u8bc4\u4f30\uff0c\u90fd\u4f1a\u5728\u6240\u6709\u6298\u53e0\u4e2d\u5f97\u5230\u76f8\u4f3c\u7684\u7ed3\u679c\u3002 \u4fee\u6539\u521b\u5efa k \u6298\u4ea4\u53c9\u68c0\u9a8c\u7684\u4ee3\u7801\u4ee5\u521b\u5efa\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5f88\u5bb9\u6613\u3002\u6211\u4eec\u53ea\u9700\u5c06 model_selection.KFold \u66f4\u6539\u4e3a model_selection.StratifiedKFold \uff0c\u5e76\u5728 kf.split(...) \u51fd\u6570\u4e2d\u6307\u5b9a\u8981\u5206\u5c42\u7684\u76ee\u6807\u5217\u3002\u6211\u4eec\u5047\u8bbe CSV \u6570\u636e\u96c6\u6709\u4e00\u5217\u540d\u4e3a \"target\" \uff0c\u5e76\u4e14\u662f\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u4fdd\u5b58\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6dfb\u52a0\u4e00\u4e2a\u65b0\u5217 kfold\uff0c\u5e76\u7528 -1 \u521d\u59cb\u5316 df [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u83b7\u53d6\u76ee\u6807\u53d8\u91cf y = df . target . values # \u521d\u59cb\u5316 StratifiedKFold \u7c7b\uff0c\u8bbe\u7f6e\u6298\u6570\uff08folds\uff09\u4e3a 5 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u4f7f\u7528 StratifiedKFold \u5bf9\u8c61\u7684 split \u65b9\u6cd5\u6765\u83b7\u53d6\u8bad\u7ec3\u548c\u9a8c\u8bc1\u7d22\u5f15 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u5305\u542b kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . to_csv ( \"train_folds.csv\" , index = False ) \u5bf9\u4e8e\u8461\u8404\u9152\u6570\u636e\u96c6\uff0c\u6211\u4eec\u6765\u770b\u770b\u6807\u7b7e\u7684\u5206\u5e03\u60c5\u51b5\u3002 b = sns . countplot ( x = 'quality' , data = df ) b . set_xlabel ( \"quality\" , fontsize = 20 ) b . set_ylabel ( \"count\" , fontsize = 20 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u7ee7\u7eed\u4e0a\u9762\u7684\u4ee3\u7801\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u8f6c\u6362\u4e86\u76ee\u6807\u503c\u3002\u4ece\u56fe 6 \u4e2d\u6211\u4eec\u53ef\u4ee5\u770b\u51fa\uff0c\u8d28\u91cf\u504f\u5dee\u5f88\u5927\u3002\u6709\u4e9b\u7c7b\u522b\u6709\u5f88\u591a\u6837\u672c\uff0c\u6709\u4e9b\u5219\u6ca1\u6709\u90a3\u4e48\u591a\u3002\u5982\u679c\u6211\u4eec\u8fdb\u884c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u90a3\u4e48\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u76ee\u6807\u503c\u5206\u5e03\u90fd\u4e0d\u4f1a\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u56fe 6\uff1a\u8461\u8404\u9152\u6570\u636e\u96c6\u4e2d \"\u8d28\u91cf\" \u5206\u5e03\u60c5\u51b5 \u89c4\u5219\u5f88\u7b80\u5355\uff0c\u5982\u679c\u662f\u6807\u51c6\u5206\u7c7b\u95ee\u9898\uff0c\u5c31\u76f2\u76ee\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u4f46\u5982\u679c\u6570\u636e\u91cf\u5f88\u5927\uff0c\u8be5\u600e\u4e48\u529e\u5462\uff1f\u5047\u8bbe\u6211\u4eec\u6709 100 \u4e07\u4e2a\u6837\u672c\u30025 \u500d\u4ea4\u53c9\u68c0\u9a8c\u610f\u5473\u7740\u5728 800k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u5728 200k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u9a8c\u8bc1\u3002\u6839\u636e\u6211\u4eec\u9009\u62e9\u7684\u7b97\u6cd5\uff0c\u5bf9\u4e8e\u8fd9\u6837\u89c4\u6a21\u7684\u6570\u636e\u96c6\u6765\u8bf4\uff0c\u8bad\u7ec3\u751a\u81f3\u9a8c\u8bc1\u90fd\u53ef\u80fd\u975e\u5e38\u6602\u8d35\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u3002 \u521b\u5efa\u4fdd\u6301\u7ed3\u679c\u7684\u8fc7\u7a0b\u4e0e\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u76f8\u540c\u3002\u5bf9\u4e8e\u62e5\u6709 100 \u4e07\u4e2a\u6837\u672c\u7684\u6570\u636e\u96c6\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa 10 \u4e2a\u6298\u53e0\u800c\u4e0d\u662f 5 \u4e2a\uff0c\u5e76\u4fdd\u7559\u5176\u4e2d\u4e00\u4e2a\u6298\u53e0\u4f5c\u4e3a\u4fdd\u7559\u6837\u672c\u3002\u8fd9\u610f\u5473\u7740\uff0c\u6211\u4eec\u5c06\u6709 10 \u4e07\u4e2a\u6837\u672c\u88ab\u4fdd\u7559\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u59cb\u7ec8\u5728\u8fd9\u4e2a\u6837\u672c\u96c6\u4e0a\u8ba1\u7b97\u635f\u5931\u3001\u51c6\u786e\u7387\u548c\u5176\u4ed6\u6307\u6807\uff0c\u5e76\u5728 90 \u4e07\u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\u3002 \u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u6682\u7559\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u975e\u5e38\u5e38\u7528\u3002\u5047\u8bbe\u6211\u4eec\u8981\u89e3\u51b3\u7684\u95ee\u9898\u662f\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97 2020 \u5e74\u7684\u9500\u552e\u989d\uff0c\u800c\u6211\u4eec\u5f97\u5230\u7684\u662f 2015-2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f60\u53ef\u4ee5\u9009\u62e9 2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u4f5c\u4e3a\u4fdd\u7559\u6570\u636e\uff0c\u7136\u540e\u5728 2015 \u5e74\u81f3 2018 \u5e74\u7684\u6240\u6709\u6570\u636e\u4e0a\u8bad\u7ec3\u4f60\u7684\u6a21\u578b\u3002 \u56fe 7\uff1a\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u793a\u4f8b \u5728\u56fe 7 \u6240\u793a\u7684\u793a\u4f8b\u4e2d\uff0c\u5047\u8bbe\u6211\u4eec\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u4ece\u65f6\u95f4\u6b65\u9aa4 31 \u5230 40 \u7684\u9500\u552e\u989d\u3002\u6211\u4eec\u53ef\u4ee5\u4fdd\u7559 21 \u81f3 30 \u6b65\u7684\u6570\u636e\uff0c\u7136\u540e\u4ece 0 \u6b65\u5230 20 \u6b65\u8bad\u7ec3\u6a21\u578b\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5728\u9884\u6d4b 31 \u6b65\u81f3 40 \u6b65\u65f6\uff0c\u5e94\u5c06 21 \u6b65\u81f3 30 \u6b65\u7684\u6570\u636e\u7eb3\u5165\u6a21\u578b\uff0c\u5426\u5219\uff0c\u6a21\u578b\u7684\u6027\u80fd\u5c06\u5927\u6253\u6298\u6263\u3002 \u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u5fc5\u987b\u5904\u7406\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u800c\u521b\u5efa\u5927\u578b\u9a8c\u8bc1\u96c6\u610f\u5473\u7740\u6a21\u578b\u5b66\u4e60\u4f1a\u4e22\u5931\u5927\u91cf\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9\u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c\uff0c\u76f8\u5f53\u4e8e\u7279\u6b8a\u7684 k \u5219\u4ea4\u53c9\u68c0\u9a8c\u5176\u4e2d k=N \uff0cN \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8fd9\u610f\u5473\u7740\u5728\u6240\u6709\u7684\u8bad\u7ec3\u6298\u53e0\u4e2d\uff0c\u6211\u4eec\u5c06\u5bf9\u9664 1 \u4e4b\u5916\u7684\u6240\u6709\u6570\u636e\u6837\u672c\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u7684\u6298\u53e0\u6570\u4e0e\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u76f8\u540c\u3002 \u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u6a21\u578b\u7684\u901f\u5ea6\u4e0d\u591f\u5feb\uff0c\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\uff0c\u4f46\u7531\u4e8e\u8fd9\u79cd\u4ea4\u53c9\u68c0\u9a8c\u53ea\u9002\u7528\u4e8e\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u56e0\u6b64\u5e76\u4e0d\u91cd\u8981\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u95ee\u9898\u4e86\u3002\u56de\u5f52\u95ee\u9898\u7684\u597d\u5904\u5728\u4e8e\uff0c\u9664\u4e86\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e4b\u5916\uff0c\u6211\u4eec\u53ef\u4ee5\u5728\u56de\u5f52\u95ee\u9898\u4e0a\u4f7f\u7528\u4e0a\u8ff0\u6240\u6709\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u6211\u4eec\u4e0d\u80fd\u76f4\u63a5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u6709\u4e00\u4e9b\u65b9\u6cd5\u53ef\u4ee5\u7a0d\u7a0d\u6539\u53d8\u95ee\u9898\uff0c\u4ece\u800c\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u9002\u7528\u4e8e\u4efb\u4f55\u56de\u5f52\u95ee\u9898\u3002\u4f46\u662f\uff0c\u5982\u679c\u53d1\u73b0\u76ee\u6807\u5206\u5e03\u4e0d\u4e00\u81f4\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u8981\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u5fc5\u987b\u5148\u5c06\u76ee\u6807\u5212\u5206\u4e3a\u82e5\u5e72\u4e2a\u5206\u5c42\uff0c\u7136\u540e\u518d\u4ee5\u5904\u7406\u5206\u7c7b\u95ee\u9898\u7684\u76f8\u540c\u65b9\u5f0f\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u9009\u62e9\u5408\u9002\u7684\u5206\u5c42\u6570\u6709\u51e0\u79cd\u9009\u62e9\u3002\u5982\u679c\u6837\u672c\u91cf\u5f88\u5927\uff08> 10k\uff0c> 100k\uff09\uff0c\u90a3\u4e48\u5c31\u4e0d\u9700\u8981\u8003\u8651\u5206\u5c42\u7684\u6570\u91cf\u3002\u53ea\u9700\u5c06\u6570\u636e\u5206\u4e3a 10 \u6216 20 \u5c42\u5373\u53ef\u3002\u5982\u679c\u6837\u672c\u6570\u4e0d\u591a\uff0c\u5219\u53ef\u4ee5\u4f7f\u7528 Sturge's Rule \u8fd9\u6837\u7684\u7b80\u5355\u89c4\u5219\u6765\u8ba1\u7b97\u9002\u5f53\u7684\u5206\u5c42\u6570\u3002 Sturge's Rule\uff1a \\[ Number of Bins = 1 + log_2(N) \\] \u5176\u4e2d \\(N\\) \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8be5\u51fd\u6570\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u5229\u7528\u65af\u7279\u683c\u6cd5\u5219\u7ed8\u5236\u6837\u672c\u4e0e\u7bb1\u6570\u5bf9\u6bd4\u56fe \u8ba9\u6211\u4eec\u5236\u4f5c\u4e00\u4e2a\u56de\u5f52\u6570\u636e\u96c6\u6837\u672c\uff0c\u5e76\u5c1d\u8bd5\u5e94\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 # stratified-kfold for regression # \u4e3a\u56de\u5f52\u95ee\u9898\u8fdb\u884c\u5206\u5c42K-\u6298\u4ea4\u53c9\u9a8c\u8bc1 # \u5bfc\u5165\u9700\u8981\u7684\u5e93 import numpy as np import pandas as pd from sklearn import datasets from sklearn import model_selection # \u521b\u5efa\u5206\u6298\uff08folds\uff09\u7684\u51fd\u6570 def create_folds ( data ): # \u521b\u5efa\u4e00\u4e2a\u65b0\u5217\u53eb\u505akfold\uff0c\u5e76\u7528-1\u6765\u586b\u5145 data [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c data = data . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4f7f\u7528Sturge\u89c4\u5219\u8ba1\u7b97bin\u7684\u6570\u91cf num_bins = int ( np . floor ( 1 + np . log2 ( len ( data )))) # \u4f7f\u7528pandas\u7684cut\u51fd\u6570\u8fdb\u884c\u76ee\u6807\u53d8\u91cf\uff08target\uff09\u7684\u5206\u7bb1 data . loc [:, \"bins\" ] = pd . cut ( data [ \"target\" ], bins = num_bins , labels = False ) # \u521d\u59cb\u5316StratifiedKFold\u7c7b kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684kfold\u5217 # \u6ce8\u610f\uff1a\u6211\u4eec\u4f7f\u7528\u7684\u662fbins\u800c\u4e0d\u662f\u5b9e\u9645\u7684\u76ee\u6807\u53d8\u91cf\uff08target\uff09\uff01 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = data , y = data . bins . values )): data . loc [ v_ , 'kfold' ] = f # \u5220\u9664bins\u5217 data = data . drop ( \"bins\" , axis = 1 ) # \u8fd4\u56de\u5305\u542bfolds\u7684\u6570\u636e return data # \u4e3b\u7a0b\u5e8f\u5f00\u59cb if __name__ == \"__main__\" : # \u521b\u5efa\u4e00\u4e2a\u5e26\u670915000\u4e2a\u6837\u672c\u3001100\u4e2a\u7279\u5f81\u548c1\u4e2a\u76ee\u6807\u53d8\u91cf\u7684\u6837\u672c\u6570\u636e\u96c6 X , y = datasets . make_regression ( n_samples = 15000 , n_features = 100 , n_targets = 1 ) # \u4f7f\u7528numpy\u6570\u7ec4\u521b\u5efa\u4e00\u4e2a\u6570\u636e\u6846 df = pd . DataFrame ( X , columns = [ f \"f_ { i } \" for i in range ( X . shape [ 1 ])] ) df . loc [:, \"target\" ] = y # \u521b\u5efafolds df = create_folds ( df ) \u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u7b2c\u4e00\u6b65\uff0c\u4e5f\u662f\u6700\u57fa\u672c\u7684\u4e00\u6b65\u3002\u5982\u679c\u8981\u505a\u7279\u5f81\u5de5\u7a0b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u8981\u5efa\u7acb\u6a21\u578b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u65b9\u6848\uff0c\u5176\u4e2d\u9a8c\u8bc1\u6570\u636e\u80fd\u591f\u4ee3\u8868\u8bad\u7ec3\u6570\u636e\u548c\u771f\u5b9e\u4e16\u754c\u7684\u6570\u636e\uff0c\u90a3\u4e48\u4f60\u5c31\u80fd\u5efa\u7acb\u4e00\u4e2a\u5177\u6709\u9ad8\u5ea6\u901a\u7528\u6027\u7684\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u3002 \u672c\u7ae0\u4ecb\u7ecd\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u6570\u636e\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u6839\u636e\u4f60\u7684\u95ee\u9898\u548c\u6570\u636e\u91c7\u7528\u65b0\u7684\u4ea4\u53c9\u68c0\u9a8c\u5f62\u5f0f\u3002 \u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u95ee\u9898\uff0c\u5e0c\u671b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u4ece\u60a3\u8005\u7684\u76ae\u80a4\u56fe\u50cf\u4e2d\u68c0\u6d4b\u51fa\u76ae\u80a4\u764c\u3002\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u5668\uff0c\u8be5\u5206\u7c7b\u5668\u63a5\u6536\u8f93\u5165\u56fe\u50cf\u5e76\u9884\u6d4b\u5176\u826f\u6027\u6216\u6076\u6027\u7684\u6982\u7387\u3002 \u5728\u8fd9\u7c7b\u6570\u636e\u96c6\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u96c6\u4e2d\u53ef\u80fd\u6709\u540c\u4e00\u60a3\u8005\u7684\u591a\u5f20\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u8981\u5728\u8fd9\u91cc\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u7cfb\u7edf\uff0c\u5fc5\u987b\u6709\u5206\u5c42\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u4e5f\u5fc5\u987b\u786e\u4fdd\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u60a3\u8005\u4e0d\u4f1a\u51fa\u73b0\u5728\u9a8c\u8bc1\u6570\u636e\u4e2d\u3002\u5e78\u8fd0\u7684\u662f\uff0cscikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u79f0\u4e3a GroupKFold \u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u3002 \u5728\u8fd9\u91cc\uff0c\u60a3\u8005\u53ef\u4ee5\u88ab\u89c6\u4e3a\u7ec4\u3002 \u4f46\u9057\u61be\u7684\u662f\uff0cscikit-learn \u65e0\u6cd5\u5c06 GroupKFold \u4e0e StratifiedKFold \u7ed3\u5408\u8d77\u6765\u3002\u6240\u4ee5\u4f60\u9700\u8981\u81ea\u5df1\u52a8\u624b\u3002\u6211\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u8bfb\u8005\u7684\u7ec3\u4e60\u3002","title":"\u4ea4\u53c9\u68c0\u9a8c"},{"location":"%E5%87%86%E5%A4%87%E7%8E%AF%E5%A2%83/","text":"\u51c6\u5907\u73af\u5883 \u5728\u6211\u4eec\u5f00\u59cb\u7f16\u7a0b\u4e4b\u524d\uff0c\u5728\u4f60\u7684\u673a\u5668\u4e0a\u8bbe\u7f6e\u597d\u4e00\u5207\u662f\u975e\u5e38\u91cd\u8981\u7684\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Ubuntu 18.04 \u548c Python 3.7.6\u3002\u5982\u679c\u4f60\u662f Windows \u7528\u6237\uff0c\u53ef\u4ee5\u901a\u8fc7\u591a\u79cd\u65b9\u5f0f\u5b89\u88c5 Ubuntu\u3002\u4f8b\u5982\uff0c\u5728\u865a\u62df\u673a\u4e0a\u5b89\u88c5\u7531Oracle\u516c\u53f8\u63d0\u4f9b\u7684\u514d\u8d39\u8f6f\u4ef6 Virtual Box\u3002\u4e0eWindows\u4e00\u8d77\u4f5c\u4e3a\u53cc\u542f\u52a8\u7cfb\u7edf\u3002\u6211\u66f4\u559c\u6b22\u53cc\u542f\u52a8\uff0c\u56e0\u4e3a\u5b83\u662f\u539f\u751f\u7684\u3002\u5982\u679c\u4f60\u4e0d\u662fUbuntu\u7528\u6237\uff0c\u5728\u4f7f\u7528\u672c\u4e66\u4e2d\u7684\u67d0\u4e9bbash\u811a\u672c\u65f6\u53ef\u80fd\u4f1a\u9047\u5230\u95ee\u9898\u3002\u4e3a\u4e86\u907f\u514d\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u53ef\u4ee5\u5728\u865a\u62df\u673a\u4e2d\u5b89\u88c5Ubuntu\uff0c\u6216\u8005\u5728Windows\u4e0a\u5b89\u88c5Linux shell\u3002 \u7528 Anaconda \u5728\u4efb\u4f55\u673a\u5668\u4e0a\u5b89\u88c5 Python \u90fd\u5f88\u7b80\u5355\u3002\u6211\u7279\u522b\u559c\u6b22 Miniconda \uff0c\u5b83\u662f conda \u7684\u6700\u5c0f\u5b89\u88c5\u7a0b\u5e8f\u3002\u5b83\u9002\u7528\u4e8e Linux\u3001OSX \u548c Windows\u3002\u7531\u4e8e Python 2 \u652f\u6301\u5df2\u4e8e 2019 \u5e74\u5e95\u7ed3\u675f\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Python 3 \u53d1\u884c\u7248\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0cminiconda \u5e76\u4e0d\u50cf\u666e\u901a Anaconda \u9644\u5e26\u6240\u6709\u8f6f\u4ef6\u5305\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u968f\u65f6\u5b89\u88c5\u65b0\u8f6f\u4ef6\u5305\u3002\u5b89\u88c5 miniconda \u975e\u5e38\u7b80\u5355\u3002 \u9996\u5148\u8981\u505a\u7684\u662f\u5c06 Miniconda3 \u4e0b\u8f7d\u5230\u7cfb\u7edf\u4e2d\u3002 cd ~/Downloads wget https://repo.anaconda.com/miniconda/... \u5176\u4e2d wget \u547d\u4ee4\u540e\u7684 URL \u662f miniconda3 \u7f51\u9875\u7684 URL\u3002\u5bf9\u4e8e 64 \u4f4d Linux \u7cfb\u7edf\uff0c\u7f16\u5199\u672c\u4e66\u65f6\u7684 URL \u662f https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \u4e0b\u8f7d miniconda3 \u540e\uff0c\u53ef\u4ee5\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\uff1a sh Miniconda3-latest-Linux-x86_64.sh \u63a5\u4e0b\u6765\uff0c\u8bf7\u9605\u8bfb\u5e76\u6309\u7167\u5c4f\u5e55\u4e0a\u7684\u8bf4\u660e\u64cd\u4f5c\u3002\u5982\u679c\u5b89\u88c5\u6b63\u786e\uff0c\u4f60\u5e94\u8be5\u53ef\u4ee5\u901a\u8fc7\u5728\u7ec8\u7aef\u8f93\u5165 conda init \u6765\u542f\u52a8 conda \u73af\u5883\u3002\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u5728\u672c\u4e66\u4e2d\u4e00\u76f4\u4f7f\u7528\u7684 conda \u73af\u5883\u3002\u8981\u521b\u5efa conda \u73af\u5883\uff0c\u53ef\u4ee5\u8f93\u5165\uff1a conda create -n environment_name python = 3 .7.6 \u6b64\u547d\u4ee4\u5c06\u521b\u5efa\u540d\u4e3a environment_name \u7684 conda \u73af\u5883\uff0c\u53ef\u4ee5\u4f7f\u7528\uff1a conda activate environment_name \u73b0\u5728\u6211\u4eec\u7684\u73af\u5883\u5df2\u7ecf\u642d\u5efa\u5b8c\u6bd5\u3002\u662f\u65f6\u5019\u5b89\u88c5\u4e00\u4e9b\u6211\u4eec\u4f1a\u7528\u5230\u7684\u8f6f\u4ef6\u5305\u4e86\u3002\u5728 conda \u73af\u5883\u4e2d\uff0c\u5b89\u88c5\u8f6f\u4ef6\u5305\u6709\u4e24\u79cd\u4e0d\u540c\u7684\u65b9\u5f0f\u3002 \u4f60\u53ef\u4ee5\u4ece conda \u4ed3\u5e93\u6216 PyPi \u5b98\u65b9\u4ed3\u5e93\u5b89\u88c5\u8f6f\u4ef6\u5305\u3002 conda/pip install package_name \u6ce8\u610f\uff1a\u67d0\u4e9b\u8f6f\u4ef6\u5305\u53ef\u80fd\u65e0\u6cd5\u5728 conda \u8f6f\u4ef6\u4ed3\u5e93\u4e2d\u627e\u5230\u3002\u56e0\u6b64\uff0c\u5728\u672c\u4e66\u4e2d\uff0c\u4f7f\u7528 pip \u5b89\u88c5\u662f\u6700\u53ef\u53d6\u7684\u65b9\u6cd5\u3002\u6211\u5df2\u7ecf\u521b\u5efa\u4e86\u4e00\u4e2a\u7f16\u5199\u672c\u4e66\u65f6\u4f7f\u7528\u7684\u8f6f\u4ef6\u5305\u5217\u8868\uff0c\u4fdd\u5b58\u5728 environment.yml \u4e2d\u3002 \u4f60\u53ef\u4ee5\u5728\u6211\u7684 GitHub \u4ed3\u5e93\u4e2d\u7684\u989d\u5916\u8d44\u6599\u4e2d\u627e\u5230\u5b83\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u521b\u5efa\u73af\u5883\uff1a conda env create -f environment.yml \u8be5\u547d\u4ee4\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a ml \u7684\u73af\u5883\u3002\u8981\u6fc0\u6d3b\u8be5\u73af\u5883\u5e76\u5f00\u59cb\u4f7f\u7528\uff0c\u5e94\u8fd0\u884c\uff1a conda activate ml \u73b0\u5728\u6211\u4eec\u5df2\u7ecf\u51c6\u5907\u5c31\u7eea\uff0c\u53ef\u4ee5\u8fdb\u884c\u4e00\u4e9b\u5e94\u7528\u673a\u5668\u5b66\u4e60\u7684\u5de5\u4f5c\u4e86\uff01\u5728\u4f7f\u7528\u672c\u4e66\u8fdb\u884c\u7f16\u7801\u65f6\uff0c\u8bf7\u59cb\u7ec8\u8bb0\u4f4f\u8981\u5728 \"ml \"\u73af\u5883\u4e0b\u8fdb\u884c\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5f00\u59cb\u5b66\u4e60\u771f\u6b63\u7684\u7b2c\u4e00\u7ae0\u3002","title":"\u51c6\u5907\u73af\u5883"},{"location":"%E5%87%86%E5%A4%87%E7%8E%AF%E5%A2%83/#_1","text":"\u5728\u6211\u4eec\u5f00\u59cb\u7f16\u7a0b\u4e4b\u524d\uff0c\u5728\u4f60\u7684\u673a\u5668\u4e0a\u8bbe\u7f6e\u597d\u4e00\u5207\u662f\u975e\u5e38\u91cd\u8981\u7684\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Ubuntu 18.04 \u548c Python 3.7.6\u3002\u5982\u679c\u4f60\u662f Windows \u7528\u6237\uff0c\u53ef\u4ee5\u901a\u8fc7\u591a\u79cd\u65b9\u5f0f\u5b89\u88c5 Ubuntu\u3002\u4f8b\u5982\uff0c\u5728\u865a\u62df\u673a\u4e0a\u5b89\u88c5\u7531Oracle\u516c\u53f8\u63d0\u4f9b\u7684\u514d\u8d39\u8f6f\u4ef6 Virtual Box\u3002\u4e0eWindows\u4e00\u8d77\u4f5c\u4e3a\u53cc\u542f\u52a8\u7cfb\u7edf\u3002\u6211\u66f4\u559c\u6b22\u53cc\u542f\u52a8\uff0c\u56e0\u4e3a\u5b83\u662f\u539f\u751f\u7684\u3002\u5982\u679c\u4f60\u4e0d\u662fUbuntu\u7528\u6237\uff0c\u5728\u4f7f\u7528\u672c\u4e66\u4e2d\u7684\u67d0\u4e9bbash\u811a\u672c\u65f6\u53ef\u80fd\u4f1a\u9047\u5230\u95ee\u9898\u3002\u4e3a\u4e86\u907f\u514d\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u53ef\u4ee5\u5728\u865a\u62df\u673a\u4e2d\u5b89\u88c5Ubuntu\uff0c\u6216\u8005\u5728Windows\u4e0a\u5b89\u88c5Linux shell\u3002 \u7528 Anaconda \u5728\u4efb\u4f55\u673a\u5668\u4e0a\u5b89\u88c5 Python \u90fd\u5f88\u7b80\u5355\u3002\u6211\u7279\u522b\u559c\u6b22 Miniconda \uff0c\u5b83\u662f conda \u7684\u6700\u5c0f\u5b89\u88c5\u7a0b\u5e8f\u3002\u5b83\u9002\u7528\u4e8e Linux\u3001OSX \u548c Windows\u3002\u7531\u4e8e Python 2 \u652f\u6301\u5df2\u4e8e 2019 \u5e74\u5e95\u7ed3\u675f\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Python 3 \u53d1\u884c\u7248\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0cminiconda \u5e76\u4e0d\u50cf\u666e\u901a Anaconda \u9644\u5e26\u6240\u6709\u8f6f\u4ef6\u5305\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u968f\u65f6\u5b89\u88c5\u65b0\u8f6f\u4ef6\u5305\u3002\u5b89\u88c5 miniconda \u975e\u5e38\u7b80\u5355\u3002 \u9996\u5148\u8981\u505a\u7684\u662f\u5c06 Miniconda3 \u4e0b\u8f7d\u5230\u7cfb\u7edf\u4e2d\u3002 cd ~/Downloads wget https://repo.anaconda.com/miniconda/... \u5176\u4e2d wget \u547d\u4ee4\u540e\u7684 URL \u662f miniconda3 \u7f51\u9875\u7684 URL\u3002\u5bf9\u4e8e 64 \u4f4d Linux \u7cfb\u7edf\uff0c\u7f16\u5199\u672c\u4e66\u65f6\u7684 URL \u662f https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \u4e0b\u8f7d miniconda3 \u540e\uff0c\u53ef\u4ee5\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\uff1a sh Miniconda3-latest-Linux-x86_64.sh \u63a5\u4e0b\u6765\uff0c\u8bf7\u9605\u8bfb\u5e76\u6309\u7167\u5c4f\u5e55\u4e0a\u7684\u8bf4\u660e\u64cd\u4f5c\u3002\u5982\u679c\u5b89\u88c5\u6b63\u786e\uff0c\u4f60\u5e94\u8be5\u53ef\u4ee5\u901a\u8fc7\u5728\u7ec8\u7aef\u8f93\u5165 conda init \u6765\u542f\u52a8 conda \u73af\u5883\u3002\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u5728\u672c\u4e66\u4e2d\u4e00\u76f4\u4f7f\u7528\u7684 conda \u73af\u5883\u3002\u8981\u521b\u5efa conda \u73af\u5883\uff0c\u53ef\u4ee5\u8f93\u5165\uff1a conda create -n environment_name python = 3 .7.6 \u6b64\u547d\u4ee4\u5c06\u521b\u5efa\u540d\u4e3a environment_name \u7684 conda \u73af\u5883\uff0c\u53ef\u4ee5\u4f7f\u7528\uff1a conda activate environment_name \u73b0\u5728\u6211\u4eec\u7684\u73af\u5883\u5df2\u7ecf\u642d\u5efa\u5b8c\u6bd5\u3002\u662f\u65f6\u5019\u5b89\u88c5\u4e00\u4e9b\u6211\u4eec\u4f1a\u7528\u5230\u7684\u8f6f\u4ef6\u5305\u4e86\u3002\u5728 conda \u73af\u5883\u4e2d\uff0c\u5b89\u88c5\u8f6f\u4ef6\u5305\u6709\u4e24\u79cd\u4e0d\u540c\u7684\u65b9\u5f0f\u3002 \u4f60\u53ef\u4ee5\u4ece conda \u4ed3\u5e93\u6216 PyPi \u5b98\u65b9\u4ed3\u5e93\u5b89\u88c5\u8f6f\u4ef6\u5305\u3002 conda/pip install package_name \u6ce8\u610f\uff1a\u67d0\u4e9b\u8f6f\u4ef6\u5305\u53ef\u80fd\u65e0\u6cd5\u5728 conda \u8f6f\u4ef6\u4ed3\u5e93\u4e2d\u627e\u5230\u3002\u56e0\u6b64\uff0c\u5728\u672c\u4e66\u4e2d\uff0c\u4f7f\u7528 pip \u5b89\u88c5\u662f\u6700\u53ef\u53d6\u7684\u65b9\u6cd5\u3002\u6211\u5df2\u7ecf\u521b\u5efa\u4e86\u4e00\u4e2a\u7f16\u5199\u672c\u4e66\u65f6\u4f7f\u7528\u7684\u8f6f\u4ef6\u5305\u5217\u8868\uff0c\u4fdd\u5b58\u5728 environment.yml \u4e2d\u3002 \u4f60\u53ef\u4ee5\u5728\u6211\u7684 GitHub \u4ed3\u5e93\u4e2d\u7684\u989d\u5916\u8d44\u6599\u4e2d\u627e\u5230\u5b83\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u521b\u5efa\u73af\u5883\uff1a conda env create -f environment.yml \u8be5\u547d\u4ee4\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a ml \u7684\u73af\u5883\u3002\u8981\u6fc0\u6d3b\u8be5\u73af\u5883\u5e76\u5f00\u59cb\u4f7f\u7528\uff0c\u5e94\u8fd0\u884c\uff1a conda activate ml \u73b0\u5728\u6211\u4eec\u5df2\u7ecf\u51c6\u5907\u5c31\u7eea\uff0c\u53ef\u4ee5\u8fdb\u884c\u4e00\u4e9b\u5e94\u7528\u673a\u5668\u5b66\u4e60\u7684\u5de5\u4f5c\u4e86\uff01\u5728\u4f7f\u7528\u672c\u4e66\u8fdb\u884c\u7f16\u7801\u65f6\uff0c\u8bf7\u59cb\u7ec8\u8bb0\u4f4f\u8981\u5728 \"ml \"\u73af\u5883\u4e0b\u8fdb\u884c\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5f00\u59cb\u5b66\u4e60\u771f\u6b63\u7684\u7b2c\u4e00\u7ae0\u3002","title":"\u51c6\u5907\u73af\u5883"},{"location":"%E5%8F%AF%E9%87%8D%E5%A4%8D%E4%BB%A3%E7%A0%81%E5%92%8C%E6%A8%A1%E5%9E%8B%E6%96%B9%E6%B3%95/","text":"\u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5 \u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u5230\u4e86\u53ef\u4ee5\u5c06\u6a21\u578b/\u8bad\u7ec3\u4ee3\u7801\u5206\u53d1\u7ed9\u4ed6\u4eba\u4f7f\u7528\u7684\u9636\u6bb5\u3002\u60a8\u53ef\u4ee5\u7528\u8f6f\u76d8\u5206\u53d1\u6216\u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\uff0c\u4f46\u8fd9\u5e76\u4e0d\u7406\u60f3\u3002\u662f\u8fd9\u6837\u5417\uff1f\u4e5f\u8bb8\u5f88\u591a\u5e74\u524d\uff0c\u8fd9\u662f\u7406\u60f3\u7684\u505a\u6cd5\uff0c\u4f46\u73b0\u5728\u4e0d\u662f\u4e86\u3002 \u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\u548c\u534f\u4f5c\u7684\u9996\u9009\u65b9\u5f0f\u662f\u4f7f\u7528\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u3002Git \u662f\u6700\u6d41\u884c\u7684\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u4e4b\u4e00\u3002\u90a3\u4e48\uff0c\u5047\u8bbe\u4f60\u5df2\u7ecf\u5b66\u4f1a\u4e86 Git\uff0c\u5e76\u6b63\u786e\u5730\u683c\u5f0f\u5316\u4e86\u4ee3\u7801\uff0c\u7f16\u5199\u4e86\u9002\u5f53\u7684\u6587\u6863\uff0c\u8fd8\u5f00\u6e90\u4e86\u4f60\u7684\u9879\u76ee\u3002\u8fd9\u5c31\u591f\u4e86\u5417\uff1f\u4e0d\uff0c\u8fd8\u4e0d\u591f\u3002\u56e0\u4e3a\u4f60\u5728\u81ea\u5df1\u7684\u7535\u8111\u4e0a\u5199\u7684\u4ee3\u7801\uff0c\u5728\u522b\u4eba\u7684\u7535\u8111\u4e0a\u53ef\u80fd\u4f1a\u56e0\u4e3a\u5404\u79cd\u539f\u56e0\u800c\u65e0\u6cd5\u8fd0\u884c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u5728\u53d1\u5e03\u4ee3\u7801\u65f6\u80fd\u590d\u5236\u81ea\u5df1\u7684\u7535\u8111\uff0c\u800c\u5176\u4ed6\u4eba\u5728\u5b89\u88c5\u60a8\u7684\u8f6f\u4ef6\u6216\u8fd0\u884c\u60a8\u7684\u4ee3\u7801\u65f6\u4e5f\u80fd\u590d\u5236\u60a8\u7684\u7535\u8111\uff0c\u90a3\u5c31\u518d\u597d\u4e0d\u8fc7\u4e86\u3002\u4e3a\u6b64\uff0c\u5982\u4eca\u6700\u6d41\u884c\u7684\u65b9\u6cd5\u662f\u4f7f\u7528 Docker \u5bb9\u5668\uff08Docker Containers\uff09\u3002\u8981\u4f7f\u7528 Docker \u5bb9\u5668\uff0c\u4f60\u9700\u8981\u5b89\u88c5 Docker\u3002 \u8ba9\u6211\u4eec\u7528\u4e0b\u9762\u7684\u547d\u4ee4\u6765\u5b89\u88c5 Docker\u3002 sudo apt install docker.io sudo systemctl start docker sudo systemctl enable docker sudo groupadd docker sudo usermod -aG docker $USER \u8fd9\u4e9b\u547d\u4ee4\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e0a\u8fd0\u884c\u3002Docker \u6700\u68d2\u7684\u5730\u65b9\u5728\u4e8e\u5b83\u53ef\u4ee5\u5b89\u88c5\u5728\u4efb\u4f55\u673a\u5668\u4e0a\uff1a Linux\u3001Windows\u3001OSX\u3002\u56e0\u6b64\uff0c\u5982\u679c\u4f60\u4e00\u76f4\u5728 Docker \u5bb9\u5668\u4e2d\u5de5\u4f5c\uff0c\u54ea\u53f0\u673a\u5668\u90fd\u6ca1\u5173\u7cfb\uff01 Docker \u5bb9\u5668\u53ef\u4ee5\u88ab\u89c6\u4e3a\u5c0f\u578b\u865a\u62df\u673a\u3002\u4f60\u53ef\u4ee5\u4e3a\u4f60\u7684\u4ee3\u7801\u521b\u5efa\u4e00\u4e2a\u5bb9\u5668\uff0c\u7136\u540e\u6bcf\u4e2a\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u548c\u8bbf\u95ee\u5b83\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u521b\u5efa\u53ef\u7528\u4e8e\u8bad\u7ec3\u6a21\u578b\u7684\u5bb9\u5668\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e00\u7ae0\u4e2d\u8bad\u7ec3\u7684 BERT \u6a21\u578b\uff0c\u5e76\u5c1d\u8bd5\u5c06\u8bad\u7ec3\u4ee3\u7801\u5bb9\u5668\u5316\u3002 \u9996\u5148\uff0c\u4f60\u9700\u8981\u4e00\u4e2a\u5305\u542b python \u9879\u76ee\u9700\u6c42\u7684\u6587\u4ef6\u3002\u9700\u6c42\u5305\u542b\u5728\u540d\u4e3a requirements.txt \u7684\u6587\u4ef6\u4e2d\u3002\u6587\u4ef6\u540d\u662f thestandard\u3002\u8be5\u6587\u4ef6\u5305\u542b\u9879\u76ee\u4e2d\u4f7f\u7528\u7684\u6240\u6709 python \u5e93\u3002\u4e5f\u5c31\u662f\u53ef\u4ee5\u901a\u8fc7 PyPI (pip) \u4e0b\u8f7d\u7684 python \u5e93\u3002\u7528\u4e8e \u8bad\u7ec3 BERT \u6a21\u578b\u4ee5\u68c0\u6d4b\u6b63/\u8d1f\u60c5\u611f\uff0c\u6211\u4eec\u4f7f\u7528\u4e86 torch\u3001transformers\u3001tqdm\u3001scikit-learn\u3001pandas \u548c numpy\u3002 \u8ba9\u6211\u4eec\u628a\u5b83\u4eec\u5199\u5165 requirements.txt \u4e2d\u3002\u4f60\u53ef\u4ee5\u53ea\u5199\u540d\u79f0\uff0c\u4e5f\u53ef\u4ee5\u5305\u62ec\u7248\u672c\u3002\u5305\u542b\u7248\u672c\u603b\u662f\u6700\u597d\u7684\uff0c\u8fd9\u4e5f\u662f\u4f60\u5e94\u8be5\u505a\u7684\u3002\u5305\u542b\u7248\u672c\u540e\uff0c\u53ef\u4ee5\u786e\u4fdd\u5176\u4ed6\u4eba\u4f7f\u7528\u7684\u7248\u672c\u4e0e\u4f60\u7684\u7248\u672c\u76f8\u540c\uff0c\u800c\u4e0d\u662f\u6700\u65b0\u7248\u672c\uff0c\u56e0\u4e3a\u6700\u65b0\u7248\u672c\u53ef\u80fd\u4f1a\u66f4\u6539\u67d0\u4e9b\u5185\u5bb9\uff0c\u5982\u679c\u662f\u8fd9\u6837\u7684\u8bdd\uff0c\u6a21\u578b\u7684\u8bad\u7ec3\u65b9\u5f0f\u5c31\u4e0d\u4f1a\u4e0e\u4f60\u7684\u76f8\u540c\u4e86\u3002 \u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u663e\u793a\u4e86 requirements.txt\u3002 # requirements.txt pandas == 1.0.4 scikit - learn == 0.22.1 torch == 1.5.0 transformers == 2.11.0 \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a Dockerfile \u7684 Docker \u6587\u4ef6\u3002\u6ca1\u6709\u6269\u5c55\u540d\u3002Dockerfile \u6709\u51e0\u4e2a\u5143\u7d20\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 # Dockerfile # First of all, we include where we are getting the image # from. Image can be thought of as an operating system. # You can do \"FROM ubuntu:18.04\" # this will start from a clean ubuntu 18.04 image. # All images are downloaded from dockerhub # Here are we grabbing image from nvidia's repo # they created a docker image using ubuntu 18.04 # and installed cuda 10.1 and cudnn7 in it. Thus, we don't have to # install it. Makes our life easy. FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 # this is the same apt-get command that you are used to # except the fact that, we have -y argument. Its because # when we build this container, we cannot press Y when asked for RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* # We add a new user called \"abhishek\" # this can be anything. Anything you want it # to be. Usually, we don't use our own name, # you can use \"user\" or \"ubuntu\" RUN useradd -m abhishek # make our user own its own home directory RUN chown -R abhishek:abhishek /home/abhishek/ # copy all files from this direrctory to a # directory called app inside the home of abhishek # and abhishek owns it. COPY --chown = abhishek *.* /home/abhishek/app/ # change to user abhishek USER abhishek RUN mkdir /home/abhishek/data/ # Now we install all the requirements # after moving to the app directory # PLEASE NOTE that ubuntu 18.04 image # has python 3.6.9 and not python 3.7.6 # you can also install conda python here and use that # however, to simplify it, I will be using python 3.6.9 # inside the docker container!!!! RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt # install mkl. its needed for transformers RUN pip3 install mkl # when we log into the docker container, # we will go inside this directory automatically WORKDIR /home/abhishek/app \u521b\u5efa\u597d Docker \u6587\u4ef6\u540e\uff0c\u6211\u4eec\u5c31\u9700\u8981\u6784\u5efa\u5b83\u3002\u6784\u5efa Docker \u5bb9\u5668\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u547d\u4ee4\u3002 docker build -f Dockerfile -t bert:train . \u8be5\u547d\u4ee4\u6839\u636e\u63d0\u4f9b\u7684 Dockerfile \u6784\u5efa\u4e00\u4e2a\u5bb9\u5668\u3002Docker \u5bb9\u5668\u7684\u540d\u79f0\u662f bert:train\u3002\u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a \u276f docker build -f Dockerfile -t bert:train . Sending build context to Docker daemon 19.97kB Step 1/7 : FROM nvidia/cuda:10.1-cudnn7-ubuntu18.04 ---> 3b55548ae91f Step 2/7 : RUN apt-get update && apt-get install -y git curl ca- certificates python3 python3-pip sudo && rm -rf /var/lib/apt/lists/* . . . . Removing intermediate container 8f6975dd08ba ---> d1802ac9f1b4 Step 7/7 : WORKDIR /home/abhishek/app ---> Running in 257ff09502ed Removing intermediate container 257ff09502ed ---> e5f6eb4cddd7 Successfully built e5f6eb4cddd7 Successfully tagged bert:train \u8bf7\u6ce8\u610f\uff0c\u6211\u5220\u9664\u4e86\u8f93\u51fa\u4e2d\u7684\u8bb8\u591a\u884c\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u767b\u5f55\u5bb9\u5668\u3002 docker run -ti bert:train /bin/bash \u4f60\u9700\u8981\u8bb0\u4f4f\uff0c\u4e00\u65e6\u9000\u51fa shell\uff0c\u4f60\u5728 shell \u4e2d\u6240\u505a\u7684\u4e00\u5207\u90fd\u5c06\u4e22\u5931\u3002\u4f60\u8fd8\u53ef\u4ee5\u5728 Docker \u5bb9\u5668\u4e2d\u4f7f\u7528\u3002 docker run -ti bert:train python3 train.py \u8f93\u51fa\u60c5\u51b5\uff1a Traceback (most recent call last): File \"train.py\", line 2, in import config File \"/home/abhishek/app/config.py\", line 28, in do_lower_case=True File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 393, in from_pretrained return cls._from_pretrained(*inputs, **kwargs) File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 496, in _from_pretrained list(cls.vocab_files_names.values()), OSError: Model name '../input/bert_base_uncased/' was not found in tokenizers model name list (bert-base-uncased, bert-large-uncased, bert- base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base- multilingual-cased, bert-base-chinese, bert-base-german-cased, bert- large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased- whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert- base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base- finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). We assumed '../input/bert_base_uncased/' was a path, a model identifier, or url to a directory containing vocabulary files named ['vocab.txt'] but couldn't find such vocabulary files at this path or url. \u54ce\u5440\uff0c\u51fa\u9519\u4e86\uff01 \u6211\u4e3a\u4ec0\u4e48\u8981\u628a\u9519\u8bef\u5370\u5728\u4e66\u4e0a\u5462\uff1f \u56e0\u4e3a\u7406\u89e3\u8fd9\u4e2a\u9519\u8bef\u975e\u5e38\u91cd\u8981\u3002\u8fd9\u4e2a\u9519\u8bef\u8bf4\u660e\u4ee3\u7801\u65e0\u6cd5\u627e\u5230\u76ee\u5f55\".../input/bert_base_cased\"\u3002\u4e3a\u4ec0\u4e48\u4f1a\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u5462\uff1f\u6211\u4eec\u53ef\u4ee5\u5728\u6ca1\u6709 Docker \u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u8bad\u7ec3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u76ee\u5f55\u548c\u6240\u6709\u6587\u4ef6\u90fd\u5b58\u5728\u3002\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u662f\u56e0\u4e3a Docker \u5c31\u50cf\u4e00\u4e2a\u865a\u62df\u673a\uff01\u5b83\u6709\u81ea\u5df1\u7684\u6587\u4ef6\u7cfb\u7edf\uff0c\u672c\u5730\u673a\u5668\u4e0a\u7684\u6587\u4ef6\u4e0d\u4f1a\u5171\u4eab\u7ed9 Docker \u5bb9\u5668\u3002\u5982\u679c\u4f60\u60f3\u4f7f\u7528\u672c\u5730\u673a\u5668\u4e0a\u7684\u8def\u5f84\u5e76\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\uff0c\u4f60\u9700\u8981\u5728\u8fd0\u884c Docker \u65f6\u5c06\u5176\u6302\u8f7d\u5230 Docker \u5bb9\u5668\u4e0a\u3002\u5f53\u6211\u4eec\u67e5\u770b\u8fd9\u4e2a\u6587\u4ef6\u5939\u7684\u8def\u5f84\u65f6\uff0c\u6211\u4eec\u77e5\u9053\u5b83\u4f4d\u4e8e\u540d\u4e3a input \u7684\u6587\u4ef6\u5939\u7684\u4e0a\u4e00\u7ea7\u3002\u8ba9\u6211\u4eec\u7a0d\u5fae\u4fee\u6539\u4e00\u4e0b config.py \u6587\u4ef6\uff01 # config.py import os import transformers # fetch home directory # in our docker container, it is # /home/abhishek HOME_DIR = os . path . expanduser ( \"~\" ) # this is the maximum number of tokens in the sentence MAX_LEN = 512 # batch sizes is low because model is huge! TRAIN_BATCH_SIZE = 8 VALID_BATCH_SIZE = 4 # let's train for a maximum of 10 epochs EPOCHS = 10 # define path to BERT model files # Now we assume that all the data is stored inside # /home/abhishek/data BERT_PATH = os . path . join ( HOME_DIR , \"data\" , \"bert_base_uncased\" ) # this is where you want to save the model MODEL_PATH = os . path . join ( HOME_DIR , \"data\" , \"model.bin\" ) # training file TRAINING_FILE = os . path . join ( HOME_DIR , \"data\" , \"imdb.csv\" ) TOKENIZER = transformers . BertTokenizer . from_pretrained ( BERT_PATH , do_lower_case = True ) \u73b0\u5728\uff0c\u4ee3\u7801\u5047\u5b9a\u6240\u6709\u5185\u5bb9\u90fd\u5728\u4e3b\u76ee\u5f55\u4e0b\u540d\u4e3a data \u7684\u6587\u4ef6\u5939\u4e2d\u3002 \u8bf7\u6ce8\u610f\uff0c\u5982\u679c Python \u811a\u672c\u6709\u4efb\u4f55\u6539\u52a8\uff0c\u90fd\u610f\u5473\u7740\u9700\u8981\u91cd\u5efa Docker \u5bb9\u5668\uff01\u56e0\u6b64\uff0c\u6211\u4eec\u91cd\u5efa\u5bb9\u5668\uff0c\u7136\u540e\u91cd\u65b0\u8fd0\u884c Docker \u547d\u4ee4\uff0c\u4f46\u8fd9\u6b21\u8981\u6709\u6240\u6539\u53d8\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u6ca1\u6709\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u8fd9\u4e5f\u662f\u884c\u4e0d\u901a\u7684\u3002\u522b\u62c5\u5fc3\uff0c\u8fd9\u53ea\u662f\u4e00\u4e2a Docker \u5bb9\u5668\u3002\u4f60\u53ea\u9700\u8981\u505a\u4e00\u6b21\u3002\u8981\u5b89\u88c5\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e2d\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\u3002 distribution = $( . /etc/os-release ; echo $ID$VERSION_ID ) curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - curl -s -L https://nvidia.github.io/nvidia-docker/ $distribution /nvidia- docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit sudo systemctl restart docker \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u6784\u5efa\u6211\u4eec\u7684\u5bb9\u5668\uff0c\u5e76\u5f00\u59cb\u8bad\u7ec3\u8fc7\u7a0b\uff1a docker run --gpus 1 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:train python3 train.py \u5176\u4e2d\uff0c-gpus 1 \u8868\u793a\u6211\u4eec\u5728 docker \u5bb9\u5668\u4e2d\u4f7f\u7528 1 \u4e2a GPU\uff0c-v \u8868\u793a\u6302\u8f7d\u5377\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u8981\u5c06\u672c\u5730\u76ee\u5f55 /home/abhishek/workspace/approaching_almost/input/ \u6302\u8f7d\u5230 docker \u5bb9\u5668\u4e2d\u7684 /home/abhishek/data/\u3002\u8fd9\u4e00\u6b65\u8981\u82b1\u70b9\u65f6\u95f4\uff0c\u4f46\u5b8c\u6210\u540e\uff0c\u672c\u5730\u6587\u4ef6\u5939\u4e2d\u5c31\u4f1a\u6709 model.bin\u3002 \u8fd9\u6837\uff0c\u53ea\u9700\u505a\u4e00\u4e9b\u7b80\u5355\u7684\u6539\u52a8\uff0c\u4f60\u7684\u8bad\u7ec3\u4ee3\u7801\u5c31\u5df2\u7ecf \"dockerized \"\u4e86\u3002\u73b0\u5728\uff0c\u4f60\u53ef\u4ee5\u5728\uff08\u51e0\u4e4e\uff09\u4efb\u4f55\u4f60\u60f3\u8981\u7684\u7cfb\u7edf\u4e0a\u4f7f\u7528\u8fd9\u4e9b\u4ee3\u7801\u8fdb\u884c\u8bad\u7ec3\u3002 \u4e0b\u4e00\u90e8\u5206\u662f\u5c06\u6211\u4eec\u8bad\u7ec3\u597d\u7684\u6a21\u578b \"\u63d0\u4f9b \"\u7ed9\u6700\u7ec8\u7528\u6237\u3002\u5047\u8bbe\u60a8\u60f3\u4ece\u63a5\u6536\u5230\u7684\u63a8\u6587\u6d41\u4e2d\u63d0\u53d6\u60c5\u611f\u4fe1\u606f\u3002\u8981\u5b8c\u6210\u8fd9\u9879\u4efb\u52a1\uff0c\u60a8\u5fc5\u987b\u521b\u5efa\u4e00\u4e2a API\uff0c\u7528\u4e8e\u8f93\u5165\u53e5\u5b50\uff0c\u7136\u540e\u8fd4\u56de\u5e26\u6709\u60c5\u611f\u6982\u7387\u7684\u8f93\u51fa\u3002\u4f7f\u7528 Python \u6784\u5efa API \u7684\u6700\u5e38\u89c1\u65b9\u6cd5\u662f\u4f7f\u7528 Flask \uff0c\u5b83\u662f\u4e00\u4e2a\u5fae\u578b\u7f51\u7edc\u670d\u52a1\u6846\u67b6\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) MODEL = None DEVICE = \"cuda\" def sentence_prediction ( sentence ): tokenizer = config . TOKENIZER max_len = config . MAX_LEN review = str ( sentence ) review = \" \" . join ( review . split ()) inputs = tokenizer . encode_plus ( review , None , add_special_tokens = True , max_length = max_len ) ids = inputs [ \"input_ids\" ] mask = inputs [ \"attention_mask\" ] token_type_ids = inputs [ \"token_type_ids\" ] padding_length = max_len - len ( ids ) ids = ids + ([ 0 ] * padding_length ) mask = mask + ([ 0 ] * padding_length ) token_type_ids = token_type_ids + ([ 0 ] * padding_length ) ids = torch . tensor ( ids , dtype = torch . long ) . unsqueeze ( 0 ) mask = torch . tensor ( mask , dtype = torch . long ) . unsqueeze ( 0 ) token_type_ids = torch . tensor ( token_type_ids , dtype = torch . long ) . unsqueeze ( 0 ) ids = ids . to ( DEVICE , dtype = torch . long ) token_type_ids = token_type_ids . to ( DEVICE , dtype = torch . long ) mask = mask . to ( DEVICE , dtype = torch . long ) outputs = MODEL ( ids = ids , mask = mask , token_type_ids = token_type_ids ) outputs = torch . sigmoid ( outputs ) . cpu () . detach () . numpy () return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): sentence = request . args . get ( \"sentence\" ) start_time = time . time () positive_prediction = sentence_prediction ( sentence ) negative_prediction = 1 - positive_prediction response = {} response [ \"response\" ] = { \"positive\" : str ( positive_prediction ), \"negative\" : str ( negative_prediction ), \"sentence\" : str ( sentence ), \"time_taken\" : str ( time . time () - start_time ), } return flask . jsonify ( response ) if __name__ == \"__main__\" : MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ) )) MODEL . to ( DEVICE ) MODEL . eval () app . run ( host = \"0.0.0.0\" ) \u7136\u540e\u8fd0\u884c \"python api.py \"\u547d\u4ee4\u542f\u52a8 API\u3002API \u5c06\u5728\u7aef\u53e3 5000 \u7684 localhost \u4e0a\u542f\u52a8\u3002cURL \u8bf7\u6c42\u53ca\u5176\u54cd\u5e94\u793a\u4f8b\u5982\u4e0b\u3002 \u276f curl $'http://192.168.86.48:5000/predict?sentence=this%20is%20the%20best%20boo k%20ever' {\"response\":{\"negative\":\"0.0032927393913269043\",\"positive\":\"0.99670726\",\" sentence\":\"this is the best book ever\",\"time_taken\":\"0.029126882553100586\"}} \u53ef\u4ee5\u770b\u5230\uff0c\u6211\u4eec\u5f97\u5230\u7684\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002 \u60a8\u8fd8\u53ef\u4ee5\u8bbf\u95ee http://127.0.0.1:5000/predict?sentence=this%20book%20is%20too%20complicated%20for%20me\u3002\u8fd9\u5c06\u518d\u6b21\u8fd4\u56de\u4e00\u4e2a JSON \u6587\u4ef6\u3002 { response : { negative : \"0.8646619468927383\" , positive : \"0.13533805\" , sentence : \"this book is too complicated for me\" , time_taken : \"0.03852701187133789\" } } \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\uff0c\u53ef\u4ee5\u7528\u6765\u4e3a\u5c11\u91cf\u7528\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u4e3a\u4ec0\u4e48\u662f\u5c11\u91cf\uff1f\u56e0\u4e3a\u8fd9\u4e2a API \u4e00\u6b21\u53ea\u670d\u52a1\u4e00\u4e2a\u8bf7\u6c42\u3002gunicorn \u662f UNIX \u4e0a\u7684 Python WSGI HTTP \u670d\u52a1\u5668\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528\u5b83\u7684 CPU \u6765\u5904\u7406\u591a\u4e2a\u5e76\u884c\u8bf7\u6c42\u3002Gunicorn \u53ef\u4ee5\u4e3a API \u521b\u5efa\u591a\u4e2a\u8fdb\u7a0b\uff0c\u56e0\u6b64\u6211\u4eec\u53ef\u4ee5\u540c\u65f6\u4e3a\u591a\u4e2a\u5ba2\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 \"pip install gunicorn \"\u5b89\u88c5 gunicorn\u3002 \u4e3a\u4e86\u5c06\u4ee3\u7801\u8f6c\u6362\u4e3a\u4e0e gunicorn \u517c\u5bb9\uff0c\u6211\u4eec\u9700\u8981\u79fb\u9664 init main\uff0c\u5e76\u5c06\u5176\u4e2d\u7684\u6240\u6709\u5185\u5bb9\u79fb\u81f3\u5168\u5c40\u8303\u56f4\u3002\u6b64\u5916\uff0c\u6211\u4eec\u73b0\u5728\u4f7f\u7528\u7684\u662f CPU \u800c\u4e0d\u662f GPU\u3002\u4fee\u6539\u540e\u7684\u4ee3\u7801\u5982\u4e0b\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) DEVICE = \"cpu\" MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ))) MODEL . to ( DEVICE ) MODEL . eval () def sentence_prediction ( sentence ): return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): return flask . jsonify ( response ) \u6211\u4eec\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002 gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8fd9\u610f\u5473\u7740\u6211\u4eec\u5728\u63d0\u4f9b\u7684 IP \u5730\u5740\u548c\u7aef\u53e3\u4e0a\u4f7f\u7528 4 \u4e2a Worker \u8fd0\u884c\u6211\u4eec\u7684 flask \u5e94\u7528\u7a0b\u5e8f\u3002\u7531\u4e8e\u6709 4 \u4e2a Worker\uff0c\u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u540c\u65f6\u5904\u7406 4 \u4e2a\u8bf7\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u73b0\u5728\u6211\u4eec\u7684\u7ec8\u7aef\u4f7f\u7528\u7684\u662f CPU\uff0c\u56e0\u6b64\u4e0d\u9700\u8981 GPU \u673a\u5668\uff0c\u53ef\u4ee5\u5728\u4efb\u4f55\u6807\u51c6\u670d\u52a1\u5668/\u865a\u62df\u673a\u4e0a\u8fd0\u884c\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u8fd8\u6709\u4e00\u4e2a\u95ee\u9898\uff1a\u6211\u4eec\u5df2\u7ecf\u5728\u672c\u5730\u673a\u5668\u4e0a\u5b8c\u6210\u4e86\u6240\u6709\u5de5\u4f5c\uff0c\u56e0\u6b64\u5fc5\u987b\u5c06\u5176\u575e\u5316\u3002\u770b\u770b\u4e0b\u9762\u8fd9\u4e2a\u672a\u6ce8\u91ca\u7684 Dockerfile\uff0c\u5b83\u53ef\u4ee5\u7528\u6765\u90e8\u7f72\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002\u8bf7\u6ce8\u610f\u7528\u4e8e\u57f9\u8bad\u7684\u65e7 Dockerfile \u548c\u8fd9\u4e2a Dockerfile \u4e4b\u95f4\u7684\u533a\u522b\u3002\u533a\u522b\u4e0d\u5927\u3002 # CPU Dockerfile FROM ubuntu:18.04 RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* RUN useradd -m abhishek RUN chown -R abhishek:abhishek /home/abhishek/ COPY --chown = abhishek *.* /home/abhishek/app/ USER abhishek RUN mkdir /home/abhishek/data/ RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt RUN pip3 install mkl WORKDIR /home/abhishek/app \u8ba9\u6211\u4eec\u6784\u5efa\u4e00\u4e2a\u65b0\u7684 Docker \u5bb9\u5668\u3002 docker build -f Dockerfile -t bert:api \u5f53 Docker \u5bb9\u5668\u6784\u5efa\u5b8c\u6210\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u76f4\u63a5\u8fd0\u884c API \u4e86\u3002 docker run -p 5000 :5000 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:api /home/abhishek/.local/bin/gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5c06\u5bb9\u5668\u5185\u7684 5000 \u7aef\u53e3\u66b4\u9732\u7ed9\u5bb9\u5668\u5916\u7684 5000 \u7aef\u53e3\u3002\u5982\u679c\u4f7f\u7528 docker-compose\uff0c\u4e5f\u53ef\u4ee5\u5f88\u597d\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002Dockercompose \u662f\u4e00\u4e2a\u53ef\u4ee5\u8ba9\u4f60\u540c\u65f6\u5728\u4e0d\u540c\u6216\u76f8\u540c\u5bb9\u5668\u4e2d\u8fd0\u884c\u4e0d\u540c\u670d\u52a1\u7684\u5de5\u5177\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528 \"pip install docker-compose \"\u5b89\u88c5 docker-compose\uff0c\u7136\u540e\u5728\u6784\u5efa\u5bb9\u5668\u540e\u8fd0\u884c \"docker-compose up\"\u3002\u8981\u4f7f\u7528 docker-compose\uff0c\u4f60\u9700\u8981\u4e00\u4e2a docker-compose.yml \u6587\u4ef6\u3002 # docker-compose.yml # specify a version of the compose version : '3.7' # you can add multiple services services : # specify service name. we call our service: api api : # specify image name image : bert:api # the command that you would like to run inside the container command : /home/abhishek/.local/bin/gunicorn api:app --bind 0.0.0.0:5000 --workers 4 # mount the volume volumes : - /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ # this ensures that our ports from container will be # exposed as it is network_mode : host \u73b0\u5728\uff0c\u60a8\u53ea\u9700\u4f7f\u7528\u4e0a\u8ff0\u547d\u4ee4\u5373\u53ef\u91cd\u65b0\u8fd0\u884c API\uff0c\u5176\u8fd0\u884c\u65b9\u5f0f\u4e0e\u4e4b\u524d\u76f8\u540c\u3002\u606d\u559c\u4f60\uff0c\u73b0\u5728\uff0c\u4f60\u4e5f\u5df2\u7ecf\u6210\u529f\u5730\u5c06\u9884\u6d4b API \u8fdb\u884c\u4e86 Docker \u5316\uff0c\u53ef\u4ee5\u968f\u65f6\u968f\u5730\u90e8\u7f72\u4e86\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u5b66\u4e60\u4e86 Docker\u3001\u4f7f\u7528 flask \u6784\u5efa API\u3001\u4f7f\u7528 gunicorn \u548c Docker \u670d\u52a1 API \u4ee5\u53ca docker-compose\u3002\u5173\u4e8e docker \u7684\u77e5\u8bc6\u8fdc\u4e0d\u6b62\u8fd9\u4e9b\uff0c\u4f46\u8fd9\u5e94\u8be5\u662f\u4e00\u4e2a\u5f00\u59cb\u3002\u5176\u4ed6\u5185\u5bb9\u53ef\u4ee5\u5728\u5b66\u4e60\u8fc7\u7a0b\u4e2d\u9010\u6e10\u638c\u63e1\u3002 \u6211\u4eec\u8fd8\u8df3\u8fc7\u4e86\u8bb8\u591a\u5de5\u5177\uff0c\u5982 kubernetes\u3001bean-stalk\u3001sagemaker\u3001heroku \u548c\u8bb8\u591a\u5176\u4ed6\u5de5\u5177\uff0c\u8fd9\u4e9b\u5de5\u5177\u5982\u4eca\u88ab\u4eba\u4eec\u7528\u6765\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u6a21\u578b\u3002\"\u6211\u8981\u5199\u4ec0\u4e48\uff1f\u70b9\u51fb\u4fee\u6539\u56fe X \u4e2d\u7684 docker \u5bb9\u5668\"\uff1f\u5728\u4e66\u4e2d\u63cf\u8ff0\u8fd9\u4e9b\u662f\u4e0d\u53ef\u884c\u7684\uff0c\u4e5f\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u6240\u4ee5\u6211\u5c06\u4f7f\u7528\u4e0d\u540c\u7684\u5a92\u4ecb\u6765\u8d5e\u7f8e\u672c\u4e66\u7684\u8fd9\u4e00\u90e8\u5206\u3002\u8bf7\u8bb0\u4f4f\uff0c\u4e00\u65e6\u4f60\u5bf9\u5e94\u7528\u7a0b\u5e8f\u8fdb\u884c\u4e86 Docker \u5316\uff0c\u4f7f\u7528\u8fd9\u4e9b\u6280\u672f/\u5e73\u53f0\u8fdb\u884c\u90e8\u7f72\u5c31\u53d8\u5f97\u6613\u5982\u53cd\u638c\u4e86\u3002\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u8981\u8ba9\u4f60\u7684\u4ee3\u7801\u548c\u6a21\u578b\u5bf9\u4ed6\u4eba\u53ef\u7528\uff0c\u5e76\u505a\u597d\u6587\u6863\u8bb0\u5f55\uff0c\u8fd9\u6837\u4efb\u4f55\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u4f60\u5f00\u53d1\u7684\u4e1c\u897f\uff0c\u800c\u65e0\u9700\u591a\u6b21\u8be2\u95ee\u4f60\u3002\u8fd9\u4e0d\u4ec5\u80fd\u8282\u7701\u60a8\u7684\u65f6\u95f4\uff0c\u8fd8\u80fd\u8282\u7701\u4ed6\u4eba\u7684\u65f6\u95f4\u3002\u597d\u7684\u3001\u5f00\u6e90\u7684\u3001\u53ef\u91cd\u590d\u4f7f\u7528\u7684\u4ee3\u7801\u5728\u60a8\u7684\u4f5c\u54c1\u96c6\u4e2d\u4e5f\u975e\u5e38\u91cd\u8981\u3002","title":"\u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5"},{"location":"%E5%8F%AF%E9%87%8D%E5%A4%8D%E4%BB%A3%E7%A0%81%E5%92%8C%E6%A8%A1%E5%9E%8B%E6%96%B9%E6%B3%95/#_1","text":"\u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u5230\u4e86\u53ef\u4ee5\u5c06\u6a21\u578b/\u8bad\u7ec3\u4ee3\u7801\u5206\u53d1\u7ed9\u4ed6\u4eba\u4f7f\u7528\u7684\u9636\u6bb5\u3002\u60a8\u53ef\u4ee5\u7528\u8f6f\u76d8\u5206\u53d1\u6216\u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\uff0c\u4f46\u8fd9\u5e76\u4e0d\u7406\u60f3\u3002\u662f\u8fd9\u6837\u5417\uff1f\u4e5f\u8bb8\u5f88\u591a\u5e74\u524d\uff0c\u8fd9\u662f\u7406\u60f3\u7684\u505a\u6cd5\uff0c\u4f46\u73b0\u5728\u4e0d\u662f\u4e86\u3002 \u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\u548c\u534f\u4f5c\u7684\u9996\u9009\u65b9\u5f0f\u662f\u4f7f\u7528\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u3002Git \u662f\u6700\u6d41\u884c\u7684\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u4e4b\u4e00\u3002\u90a3\u4e48\uff0c\u5047\u8bbe\u4f60\u5df2\u7ecf\u5b66\u4f1a\u4e86 Git\uff0c\u5e76\u6b63\u786e\u5730\u683c\u5f0f\u5316\u4e86\u4ee3\u7801\uff0c\u7f16\u5199\u4e86\u9002\u5f53\u7684\u6587\u6863\uff0c\u8fd8\u5f00\u6e90\u4e86\u4f60\u7684\u9879\u76ee\u3002\u8fd9\u5c31\u591f\u4e86\u5417\uff1f\u4e0d\uff0c\u8fd8\u4e0d\u591f\u3002\u56e0\u4e3a\u4f60\u5728\u81ea\u5df1\u7684\u7535\u8111\u4e0a\u5199\u7684\u4ee3\u7801\uff0c\u5728\u522b\u4eba\u7684\u7535\u8111\u4e0a\u53ef\u80fd\u4f1a\u56e0\u4e3a\u5404\u79cd\u539f\u56e0\u800c\u65e0\u6cd5\u8fd0\u884c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u5728\u53d1\u5e03\u4ee3\u7801\u65f6\u80fd\u590d\u5236\u81ea\u5df1\u7684\u7535\u8111\uff0c\u800c\u5176\u4ed6\u4eba\u5728\u5b89\u88c5\u60a8\u7684\u8f6f\u4ef6\u6216\u8fd0\u884c\u60a8\u7684\u4ee3\u7801\u65f6\u4e5f\u80fd\u590d\u5236\u60a8\u7684\u7535\u8111\uff0c\u90a3\u5c31\u518d\u597d\u4e0d\u8fc7\u4e86\u3002\u4e3a\u6b64\uff0c\u5982\u4eca\u6700\u6d41\u884c\u7684\u65b9\u6cd5\u662f\u4f7f\u7528 Docker \u5bb9\u5668\uff08Docker Containers\uff09\u3002\u8981\u4f7f\u7528 Docker \u5bb9\u5668\uff0c\u4f60\u9700\u8981\u5b89\u88c5 Docker\u3002 \u8ba9\u6211\u4eec\u7528\u4e0b\u9762\u7684\u547d\u4ee4\u6765\u5b89\u88c5 Docker\u3002 sudo apt install docker.io sudo systemctl start docker sudo systemctl enable docker sudo groupadd docker sudo usermod -aG docker $USER \u8fd9\u4e9b\u547d\u4ee4\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e0a\u8fd0\u884c\u3002Docker \u6700\u68d2\u7684\u5730\u65b9\u5728\u4e8e\u5b83\u53ef\u4ee5\u5b89\u88c5\u5728\u4efb\u4f55\u673a\u5668\u4e0a\uff1a Linux\u3001Windows\u3001OSX\u3002\u56e0\u6b64\uff0c\u5982\u679c\u4f60\u4e00\u76f4\u5728 Docker \u5bb9\u5668\u4e2d\u5de5\u4f5c\uff0c\u54ea\u53f0\u673a\u5668\u90fd\u6ca1\u5173\u7cfb\uff01 Docker \u5bb9\u5668\u53ef\u4ee5\u88ab\u89c6\u4e3a\u5c0f\u578b\u865a\u62df\u673a\u3002\u4f60\u53ef\u4ee5\u4e3a\u4f60\u7684\u4ee3\u7801\u521b\u5efa\u4e00\u4e2a\u5bb9\u5668\uff0c\u7136\u540e\u6bcf\u4e2a\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u548c\u8bbf\u95ee\u5b83\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u521b\u5efa\u53ef\u7528\u4e8e\u8bad\u7ec3\u6a21\u578b\u7684\u5bb9\u5668\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e00\u7ae0\u4e2d\u8bad\u7ec3\u7684 BERT \u6a21\u578b\uff0c\u5e76\u5c1d\u8bd5\u5c06\u8bad\u7ec3\u4ee3\u7801\u5bb9\u5668\u5316\u3002 \u9996\u5148\uff0c\u4f60\u9700\u8981\u4e00\u4e2a\u5305\u542b python \u9879\u76ee\u9700\u6c42\u7684\u6587\u4ef6\u3002\u9700\u6c42\u5305\u542b\u5728\u540d\u4e3a requirements.txt \u7684\u6587\u4ef6\u4e2d\u3002\u6587\u4ef6\u540d\u662f thestandard\u3002\u8be5\u6587\u4ef6\u5305\u542b\u9879\u76ee\u4e2d\u4f7f\u7528\u7684\u6240\u6709 python \u5e93\u3002\u4e5f\u5c31\u662f\u53ef\u4ee5\u901a\u8fc7 PyPI (pip) \u4e0b\u8f7d\u7684 python \u5e93\u3002\u7528\u4e8e \u8bad\u7ec3 BERT \u6a21\u578b\u4ee5\u68c0\u6d4b\u6b63/\u8d1f\u60c5\u611f\uff0c\u6211\u4eec\u4f7f\u7528\u4e86 torch\u3001transformers\u3001tqdm\u3001scikit-learn\u3001pandas \u548c numpy\u3002 \u8ba9\u6211\u4eec\u628a\u5b83\u4eec\u5199\u5165 requirements.txt \u4e2d\u3002\u4f60\u53ef\u4ee5\u53ea\u5199\u540d\u79f0\uff0c\u4e5f\u53ef\u4ee5\u5305\u62ec\u7248\u672c\u3002\u5305\u542b\u7248\u672c\u603b\u662f\u6700\u597d\u7684\uff0c\u8fd9\u4e5f\u662f\u4f60\u5e94\u8be5\u505a\u7684\u3002\u5305\u542b\u7248\u672c\u540e\uff0c\u53ef\u4ee5\u786e\u4fdd\u5176\u4ed6\u4eba\u4f7f\u7528\u7684\u7248\u672c\u4e0e\u4f60\u7684\u7248\u672c\u76f8\u540c\uff0c\u800c\u4e0d\u662f\u6700\u65b0\u7248\u672c\uff0c\u56e0\u4e3a\u6700\u65b0\u7248\u672c\u53ef\u80fd\u4f1a\u66f4\u6539\u67d0\u4e9b\u5185\u5bb9\uff0c\u5982\u679c\u662f\u8fd9\u6837\u7684\u8bdd\uff0c\u6a21\u578b\u7684\u8bad\u7ec3\u65b9\u5f0f\u5c31\u4e0d\u4f1a\u4e0e\u4f60\u7684\u76f8\u540c\u4e86\u3002 \u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u663e\u793a\u4e86 requirements.txt\u3002 # requirements.txt pandas == 1.0.4 scikit - learn == 0.22.1 torch == 1.5.0 transformers == 2.11.0 \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a Dockerfile \u7684 Docker \u6587\u4ef6\u3002\u6ca1\u6709\u6269\u5c55\u540d\u3002Dockerfile \u6709\u51e0\u4e2a\u5143\u7d20\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 # Dockerfile # First of all, we include where we are getting the image # from. Image can be thought of as an operating system. # You can do \"FROM ubuntu:18.04\" # this will start from a clean ubuntu 18.04 image. # All images are downloaded from dockerhub # Here are we grabbing image from nvidia's repo # they created a docker image using ubuntu 18.04 # and installed cuda 10.1 and cudnn7 in it. Thus, we don't have to # install it. Makes our life easy. FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 # this is the same apt-get command that you are used to # except the fact that, we have -y argument. Its because # when we build this container, we cannot press Y when asked for RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* # We add a new user called \"abhishek\" # this can be anything. Anything you want it # to be. Usually, we don't use our own name, # you can use \"user\" or \"ubuntu\" RUN useradd -m abhishek # make our user own its own home directory RUN chown -R abhishek:abhishek /home/abhishek/ # copy all files from this direrctory to a # directory called app inside the home of abhishek # and abhishek owns it. COPY --chown = abhishek *.* /home/abhishek/app/ # change to user abhishek USER abhishek RUN mkdir /home/abhishek/data/ # Now we install all the requirements # after moving to the app directory # PLEASE NOTE that ubuntu 18.04 image # has python 3.6.9 and not python 3.7.6 # you can also install conda python here and use that # however, to simplify it, I will be using python 3.6.9 # inside the docker container!!!! RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt # install mkl. its needed for transformers RUN pip3 install mkl # when we log into the docker container, # we will go inside this directory automatically WORKDIR /home/abhishek/app \u521b\u5efa\u597d Docker \u6587\u4ef6\u540e\uff0c\u6211\u4eec\u5c31\u9700\u8981\u6784\u5efa\u5b83\u3002\u6784\u5efa Docker \u5bb9\u5668\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u547d\u4ee4\u3002 docker build -f Dockerfile -t bert:train . \u8be5\u547d\u4ee4\u6839\u636e\u63d0\u4f9b\u7684 Dockerfile \u6784\u5efa\u4e00\u4e2a\u5bb9\u5668\u3002Docker \u5bb9\u5668\u7684\u540d\u79f0\u662f bert:train\u3002\u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a \u276f docker build -f Dockerfile -t bert:train . Sending build context to Docker daemon 19.97kB Step 1/7 : FROM nvidia/cuda:10.1-cudnn7-ubuntu18.04 ---> 3b55548ae91f Step 2/7 : RUN apt-get update && apt-get install -y git curl ca- certificates python3 python3-pip sudo && rm -rf /var/lib/apt/lists/* . . . . Removing intermediate container 8f6975dd08ba ---> d1802ac9f1b4 Step 7/7 : WORKDIR /home/abhishek/app ---> Running in 257ff09502ed Removing intermediate container 257ff09502ed ---> e5f6eb4cddd7 Successfully built e5f6eb4cddd7 Successfully tagged bert:train \u8bf7\u6ce8\u610f\uff0c\u6211\u5220\u9664\u4e86\u8f93\u51fa\u4e2d\u7684\u8bb8\u591a\u884c\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u767b\u5f55\u5bb9\u5668\u3002 docker run -ti bert:train /bin/bash \u4f60\u9700\u8981\u8bb0\u4f4f\uff0c\u4e00\u65e6\u9000\u51fa shell\uff0c\u4f60\u5728 shell \u4e2d\u6240\u505a\u7684\u4e00\u5207\u90fd\u5c06\u4e22\u5931\u3002\u4f60\u8fd8\u53ef\u4ee5\u5728 Docker \u5bb9\u5668\u4e2d\u4f7f\u7528\u3002 docker run -ti bert:train python3 train.py \u8f93\u51fa\u60c5\u51b5\uff1a Traceback (most recent call last): File \"train.py\", line 2, in import config File \"/home/abhishek/app/config.py\", line 28, in do_lower_case=True File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 393, in from_pretrained return cls._from_pretrained(*inputs, **kwargs) File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 496, in _from_pretrained list(cls.vocab_files_names.values()), OSError: Model name '../input/bert_base_uncased/' was not found in tokenizers model name list (bert-base-uncased, bert-large-uncased, bert- base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base- multilingual-cased, bert-base-chinese, bert-base-german-cased, bert- large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased- whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert- base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base- finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). We assumed '../input/bert_base_uncased/' was a path, a model identifier, or url to a directory containing vocabulary files named ['vocab.txt'] but couldn't find such vocabulary files at this path or url. \u54ce\u5440\uff0c\u51fa\u9519\u4e86\uff01 \u6211\u4e3a\u4ec0\u4e48\u8981\u628a\u9519\u8bef\u5370\u5728\u4e66\u4e0a\u5462\uff1f \u56e0\u4e3a\u7406\u89e3\u8fd9\u4e2a\u9519\u8bef\u975e\u5e38\u91cd\u8981\u3002\u8fd9\u4e2a\u9519\u8bef\u8bf4\u660e\u4ee3\u7801\u65e0\u6cd5\u627e\u5230\u76ee\u5f55\".../input/bert_base_cased\"\u3002\u4e3a\u4ec0\u4e48\u4f1a\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u5462\uff1f\u6211\u4eec\u53ef\u4ee5\u5728\u6ca1\u6709 Docker \u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u8bad\u7ec3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u76ee\u5f55\u548c\u6240\u6709\u6587\u4ef6\u90fd\u5b58\u5728\u3002\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u662f\u56e0\u4e3a Docker \u5c31\u50cf\u4e00\u4e2a\u865a\u62df\u673a\uff01\u5b83\u6709\u81ea\u5df1\u7684\u6587\u4ef6\u7cfb\u7edf\uff0c\u672c\u5730\u673a\u5668\u4e0a\u7684\u6587\u4ef6\u4e0d\u4f1a\u5171\u4eab\u7ed9 Docker \u5bb9\u5668\u3002\u5982\u679c\u4f60\u60f3\u4f7f\u7528\u672c\u5730\u673a\u5668\u4e0a\u7684\u8def\u5f84\u5e76\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\uff0c\u4f60\u9700\u8981\u5728\u8fd0\u884c Docker \u65f6\u5c06\u5176\u6302\u8f7d\u5230 Docker \u5bb9\u5668\u4e0a\u3002\u5f53\u6211\u4eec\u67e5\u770b\u8fd9\u4e2a\u6587\u4ef6\u5939\u7684\u8def\u5f84\u65f6\uff0c\u6211\u4eec\u77e5\u9053\u5b83\u4f4d\u4e8e\u540d\u4e3a input \u7684\u6587\u4ef6\u5939\u7684\u4e0a\u4e00\u7ea7\u3002\u8ba9\u6211\u4eec\u7a0d\u5fae\u4fee\u6539\u4e00\u4e0b config.py \u6587\u4ef6\uff01 # config.py import os import transformers # fetch home directory # in our docker container, it is # /home/abhishek HOME_DIR = os . path . expanduser ( \"~\" ) # this is the maximum number of tokens in the sentence MAX_LEN = 512 # batch sizes is low because model is huge! TRAIN_BATCH_SIZE = 8 VALID_BATCH_SIZE = 4 # let's train for a maximum of 10 epochs EPOCHS = 10 # define path to BERT model files # Now we assume that all the data is stored inside # /home/abhishek/data BERT_PATH = os . path . join ( HOME_DIR , \"data\" , \"bert_base_uncased\" ) # this is where you want to save the model MODEL_PATH = os . path . join ( HOME_DIR , \"data\" , \"model.bin\" ) # training file TRAINING_FILE = os . path . join ( HOME_DIR , \"data\" , \"imdb.csv\" ) TOKENIZER = transformers . BertTokenizer . from_pretrained ( BERT_PATH , do_lower_case = True ) \u73b0\u5728\uff0c\u4ee3\u7801\u5047\u5b9a\u6240\u6709\u5185\u5bb9\u90fd\u5728\u4e3b\u76ee\u5f55\u4e0b\u540d\u4e3a data \u7684\u6587\u4ef6\u5939\u4e2d\u3002 \u8bf7\u6ce8\u610f\uff0c\u5982\u679c Python \u811a\u672c\u6709\u4efb\u4f55\u6539\u52a8\uff0c\u90fd\u610f\u5473\u7740\u9700\u8981\u91cd\u5efa Docker \u5bb9\u5668\uff01\u56e0\u6b64\uff0c\u6211\u4eec\u91cd\u5efa\u5bb9\u5668\uff0c\u7136\u540e\u91cd\u65b0\u8fd0\u884c Docker \u547d\u4ee4\uff0c\u4f46\u8fd9\u6b21\u8981\u6709\u6240\u6539\u53d8\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u6ca1\u6709\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u8fd9\u4e5f\u662f\u884c\u4e0d\u901a\u7684\u3002\u522b\u62c5\u5fc3\uff0c\u8fd9\u53ea\u662f\u4e00\u4e2a Docker \u5bb9\u5668\u3002\u4f60\u53ea\u9700\u8981\u505a\u4e00\u6b21\u3002\u8981\u5b89\u88c5\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e2d\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\u3002 distribution = $( . /etc/os-release ; echo $ID$VERSION_ID ) curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - curl -s -L https://nvidia.github.io/nvidia-docker/ $distribution /nvidia- docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit sudo systemctl restart docker \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u6784\u5efa\u6211\u4eec\u7684\u5bb9\u5668\uff0c\u5e76\u5f00\u59cb\u8bad\u7ec3\u8fc7\u7a0b\uff1a docker run --gpus 1 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:train python3 train.py \u5176\u4e2d\uff0c-gpus 1 \u8868\u793a\u6211\u4eec\u5728 docker \u5bb9\u5668\u4e2d\u4f7f\u7528 1 \u4e2a GPU\uff0c-v \u8868\u793a\u6302\u8f7d\u5377\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u8981\u5c06\u672c\u5730\u76ee\u5f55 /home/abhishek/workspace/approaching_almost/input/ \u6302\u8f7d\u5230 docker \u5bb9\u5668\u4e2d\u7684 /home/abhishek/data/\u3002\u8fd9\u4e00\u6b65\u8981\u82b1\u70b9\u65f6\u95f4\uff0c\u4f46\u5b8c\u6210\u540e\uff0c\u672c\u5730\u6587\u4ef6\u5939\u4e2d\u5c31\u4f1a\u6709 model.bin\u3002 \u8fd9\u6837\uff0c\u53ea\u9700\u505a\u4e00\u4e9b\u7b80\u5355\u7684\u6539\u52a8\uff0c\u4f60\u7684\u8bad\u7ec3\u4ee3\u7801\u5c31\u5df2\u7ecf \"dockerized \"\u4e86\u3002\u73b0\u5728\uff0c\u4f60\u53ef\u4ee5\u5728\uff08\u51e0\u4e4e\uff09\u4efb\u4f55\u4f60\u60f3\u8981\u7684\u7cfb\u7edf\u4e0a\u4f7f\u7528\u8fd9\u4e9b\u4ee3\u7801\u8fdb\u884c\u8bad\u7ec3\u3002 \u4e0b\u4e00\u90e8\u5206\u662f\u5c06\u6211\u4eec\u8bad\u7ec3\u597d\u7684\u6a21\u578b \"\u63d0\u4f9b \"\u7ed9\u6700\u7ec8\u7528\u6237\u3002\u5047\u8bbe\u60a8\u60f3\u4ece\u63a5\u6536\u5230\u7684\u63a8\u6587\u6d41\u4e2d\u63d0\u53d6\u60c5\u611f\u4fe1\u606f\u3002\u8981\u5b8c\u6210\u8fd9\u9879\u4efb\u52a1\uff0c\u60a8\u5fc5\u987b\u521b\u5efa\u4e00\u4e2a API\uff0c\u7528\u4e8e\u8f93\u5165\u53e5\u5b50\uff0c\u7136\u540e\u8fd4\u56de\u5e26\u6709\u60c5\u611f\u6982\u7387\u7684\u8f93\u51fa\u3002\u4f7f\u7528 Python \u6784\u5efa API \u7684\u6700\u5e38\u89c1\u65b9\u6cd5\u662f\u4f7f\u7528 Flask \uff0c\u5b83\u662f\u4e00\u4e2a\u5fae\u578b\u7f51\u7edc\u670d\u52a1\u6846\u67b6\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) MODEL = None DEVICE = \"cuda\" def sentence_prediction ( sentence ): tokenizer = config . TOKENIZER max_len = config . MAX_LEN review = str ( sentence ) review = \" \" . join ( review . split ()) inputs = tokenizer . encode_plus ( review , None , add_special_tokens = True , max_length = max_len ) ids = inputs [ \"input_ids\" ] mask = inputs [ \"attention_mask\" ] token_type_ids = inputs [ \"token_type_ids\" ] padding_length = max_len - len ( ids ) ids = ids + ([ 0 ] * padding_length ) mask = mask + ([ 0 ] * padding_length ) token_type_ids = token_type_ids + ([ 0 ] * padding_length ) ids = torch . tensor ( ids , dtype = torch . long ) . unsqueeze ( 0 ) mask = torch . tensor ( mask , dtype = torch . long ) . unsqueeze ( 0 ) token_type_ids = torch . tensor ( token_type_ids , dtype = torch . long ) . unsqueeze ( 0 ) ids = ids . to ( DEVICE , dtype = torch . long ) token_type_ids = token_type_ids . to ( DEVICE , dtype = torch . long ) mask = mask . to ( DEVICE , dtype = torch . long ) outputs = MODEL ( ids = ids , mask = mask , token_type_ids = token_type_ids ) outputs = torch . sigmoid ( outputs ) . cpu () . detach () . numpy () return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): sentence = request . args . get ( \"sentence\" ) start_time = time . time () positive_prediction = sentence_prediction ( sentence ) negative_prediction = 1 - positive_prediction response = {} response [ \"response\" ] = { \"positive\" : str ( positive_prediction ), \"negative\" : str ( negative_prediction ), \"sentence\" : str ( sentence ), \"time_taken\" : str ( time . time () - start_time ), } return flask . jsonify ( response ) if __name__ == \"__main__\" : MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ) )) MODEL . to ( DEVICE ) MODEL . eval () app . run ( host = \"0.0.0.0\" ) \u7136\u540e\u8fd0\u884c \"python api.py \"\u547d\u4ee4\u542f\u52a8 API\u3002API \u5c06\u5728\u7aef\u53e3 5000 \u7684 localhost \u4e0a\u542f\u52a8\u3002cURL \u8bf7\u6c42\u53ca\u5176\u54cd\u5e94\u793a\u4f8b\u5982\u4e0b\u3002 \u276f curl $'http://192.168.86.48:5000/predict?sentence=this%20is%20the%20best%20boo k%20ever' {\"response\":{\"negative\":\"0.0032927393913269043\",\"positive\":\"0.99670726\",\" sentence\":\"this is the best book ever\",\"time_taken\":\"0.029126882553100586\"}} \u53ef\u4ee5\u770b\u5230\uff0c\u6211\u4eec\u5f97\u5230\u7684\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002 \u60a8\u8fd8\u53ef\u4ee5\u8bbf\u95ee http://127.0.0.1:5000/predict?sentence=this%20book%20is%20too%20complicated%20for%20me\u3002\u8fd9\u5c06\u518d\u6b21\u8fd4\u56de\u4e00\u4e2a JSON \u6587\u4ef6\u3002 { response : { negative : \"0.8646619468927383\" , positive : \"0.13533805\" , sentence : \"this book is too complicated for me\" , time_taken : \"0.03852701187133789\" } } \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\uff0c\u53ef\u4ee5\u7528\u6765\u4e3a\u5c11\u91cf\u7528\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u4e3a\u4ec0\u4e48\u662f\u5c11\u91cf\uff1f\u56e0\u4e3a\u8fd9\u4e2a API \u4e00\u6b21\u53ea\u670d\u52a1\u4e00\u4e2a\u8bf7\u6c42\u3002gunicorn \u662f UNIX \u4e0a\u7684 Python WSGI HTTP \u670d\u52a1\u5668\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528\u5b83\u7684 CPU \u6765\u5904\u7406\u591a\u4e2a\u5e76\u884c\u8bf7\u6c42\u3002Gunicorn \u53ef\u4ee5\u4e3a API \u521b\u5efa\u591a\u4e2a\u8fdb\u7a0b\uff0c\u56e0\u6b64\u6211\u4eec\u53ef\u4ee5\u540c\u65f6\u4e3a\u591a\u4e2a\u5ba2\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 \"pip install gunicorn \"\u5b89\u88c5 gunicorn\u3002 \u4e3a\u4e86\u5c06\u4ee3\u7801\u8f6c\u6362\u4e3a\u4e0e gunicorn \u517c\u5bb9\uff0c\u6211\u4eec\u9700\u8981\u79fb\u9664 init main\uff0c\u5e76\u5c06\u5176\u4e2d\u7684\u6240\u6709\u5185\u5bb9\u79fb\u81f3\u5168\u5c40\u8303\u56f4\u3002\u6b64\u5916\uff0c\u6211\u4eec\u73b0\u5728\u4f7f\u7528\u7684\u662f CPU \u800c\u4e0d\u662f GPU\u3002\u4fee\u6539\u540e\u7684\u4ee3\u7801\u5982\u4e0b\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) DEVICE = \"cpu\" MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ))) MODEL . to ( DEVICE ) MODEL . eval () def sentence_prediction ( sentence ): return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): return flask . jsonify ( response ) \u6211\u4eec\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002 gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8fd9\u610f\u5473\u7740\u6211\u4eec\u5728\u63d0\u4f9b\u7684 IP \u5730\u5740\u548c\u7aef\u53e3\u4e0a\u4f7f\u7528 4 \u4e2a Worker \u8fd0\u884c\u6211\u4eec\u7684 flask \u5e94\u7528\u7a0b\u5e8f\u3002\u7531\u4e8e\u6709 4 \u4e2a Worker\uff0c\u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u540c\u65f6\u5904\u7406 4 \u4e2a\u8bf7\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u73b0\u5728\u6211\u4eec\u7684\u7ec8\u7aef\u4f7f\u7528\u7684\u662f CPU\uff0c\u56e0\u6b64\u4e0d\u9700\u8981 GPU \u673a\u5668\uff0c\u53ef\u4ee5\u5728\u4efb\u4f55\u6807\u51c6\u670d\u52a1\u5668/\u865a\u62df\u673a\u4e0a\u8fd0\u884c\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u8fd8\u6709\u4e00\u4e2a\u95ee\u9898\uff1a\u6211\u4eec\u5df2\u7ecf\u5728\u672c\u5730\u673a\u5668\u4e0a\u5b8c\u6210\u4e86\u6240\u6709\u5de5\u4f5c\uff0c\u56e0\u6b64\u5fc5\u987b\u5c06\u5176\u575e\u5316\u3002\u770b\u770b\u4e0b\u9762\u8fd9\u4e2a\u672a\u6ce8\u91ca\u7684 Dockerfile\uff0c\u5b83\u53ef\u4ee5\u7528\u6765\u90e8\u7f72\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002\u8bf7\u6ce8\u610f\u7528\u4e8e\u57f9\u8bad\u7684\u65e7 Dockerfile \u548c\u8fd9\u4e2a Dockerfile \u4e4b\u95f4\u7684\u533a\u522b\u3002\u533a\u522b\u4e0d\u5927\u3002 # CPU Dockerfile FROM ubuntu:18.04 RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* RUN useradd -m abhishek RUN chown -R abhishek:abhishek /home/abhishek/ COPY --chown = abhishek *.* /home/abhishek/app/ USER abhishek RUN mkdir /home/abhishek/data/ RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt RUN pip3 install mkl WORKDIR /home/abhishek/app \u8ba9\u6211\u4eec\u6784\u5efa\u4e00\u4e2a\u65b0\u7684 Docker \u5bb9\u5668\u3002 docker build -f Dockerfile -t bert:api \u5f53 Docker \u5bb9\u5668\u6784\u5efa\u5b8c\u6210\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u76f4\u63a5\u8fd0\u884c API \u4e86\u3002 docker run -p 5000 :5000 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:api /home/abhishek/.local/bin/gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5c06\u5bb9\u5668\u5185\u7684 5000 \u7aef\u53e3\u66b4\u9732\u7ed9\u5bb9\u5668\u5916\u7684 5000 \u7aef\u53e3\u3002\u5982\u679c\u4f7f\u7528 docker-compose\uff0c\u4e5f\u53ef\u4ee5\u5f88\u597d\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002Dockercompose \u662f\u4e00\u4e2a\u53ef\u4ee5\u8ba9\u4f60\u540c\u65f6\u5728\u4e0d\u540c\u6216\u76f8\u540c\u5bb9\u5668\u4e2d\u8fd0\u884c\u4e0d\u540c\u670d\u52a1\u7684\u5de5\u5177\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528 \"pip install docker-compose \"\u5b89\u88c5 docker-compose\uff0c\u7136\u540e\u5728\u6784\u5efa\u5bb9\u5668\u540e\u8fd0\u884c \"docker-compose up\"\u3002\u8981\u4f7f\u7528 docker-compose\uff0c\u4f60\u9700\u8981\u4e00\u4e2a docker-compose.yml \u6587\u4ef6\u3002 # docker-compose.yml # specify a version of the compose version : '3.7' # you can add multiple services services : # specify service name. we call our service: api api : # specify image name image : bert:api # the command that you would like to run inside the container command : /home/abhishek/.local/bin/gunicorn api:app --bind 0.0.0.0:5000 --workers 4 # mount the volume volumes : - /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ # this ensures that our ports from container will be # exposed as it is network_mode : host \u73b0\u5728\uff0c\u60a8\u53ea\u9700\u4f7f\u7528\u4e0a\u8ff0\u547d\u4ee4\u5373\u53ef\u91cd\u65b0\u8fd0\u884c API\uff0c\u5176\u8fd0\u884c\u65b9\u5f0f\u4e0e\u4e4b\u524d\u76f8\u540c\u3002\u606d\u559c\u4f60\uff0c\u73b0\u5728\uff0c\u4f60\u4e5f\u5df2\u7ecf\u6210\u529f\u5730\u5c06\u9884\u6d4b API \u8fdb\u884c\u4e86 Docker \u5316\uff0c\u53ef\u4ee5\u968f\u65f6\u968f\u5730\u90e8\u7f72\u4e86\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u5b66\u4e60\u4e86 Docker\u3001\u4f7f\u7528 flask \u6784\u5efa API\u3001\u4f7f\u7528 gunicorn \u548c Docker \u670d\u52a1 API \u4ee5\u53ca docker-compose\u3002\u5173\u4e8e docker \u7684\u77e5\u8bc6\u8fdc\u4e0d\u6b62\u8fd9\u4e9b\uff0c\u4f46\u8fd9\u5e94\u8be5\u662f\u4e00\u4e2a\u5f00\u59cb\u3002\u5176\u4ed6\u5185\u5bb9\u53ef\u4ee5\u5728\u5b66\u4e60\u8fc7\u7a0b\u4e2d\u9010\u6e10\u638c\u63e1\u3002 \u6211\u4eec\u8fd8\u8df3\u8fc7\u4e86\u8bb8\u591a\u5de5\u5177\uff0c\u5982 kubernetes\u3001bean-stalk\u3001sagemaker\u3001heroku \u548c\u8bb8\u591a\u5176\u4ed6\u5de5\u5177\uff0c\u8fd9\u4e9b\u5de5\u5177\u5982\u4eca\u88ab\u4eba\u4eec\u7528\u6765\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u6a21\u578b\u3002\"\u6211\u8981\u5199\u4ec0\u4e48\uff1f\u70b9\u51fb\u4fee\u6539\u56fe X \u4e2d\u7684 docker \u5bb9\u5668\"\uff1f\u5728\u4e66\u4e2d\u63cf\u8ff0\u8fd9\u4e9b\u662f\u4e0d\u53ef\u884c\u7684\uff0c\u4e5f\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u6240\u4ee5\u6211\u5c06\u4f7f\u7528\u4e0d\u540c\u7684\u5a92\u4ecb\u6765\u8d5e\u7f8e\u672c\u4e66\u7684\u8fd9\u4e00\u90e8\u5206\u3002\u8bf7\u8bb0\u4f4f\uff0c\u4e00\u65e6\u4f60\u5bf9\u5e94\u7528\u7a0b\u5e8f\u8fdb\u884c\u4e86 Docker \u5316\uff0c\u4f7f\u7528\u8fd9\u4e9b\u6280\u672f/\u5e73\u53f0\u8fdb\u884c\u90e8\u7f72\u5c31\u53d8\u5f97\u6613\u5982\u53cd\u638c\u4e86\u3002\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u8981\u8ba9\u4f60\u7684\u4ee3\u7801\u548c\u6a21\u578b\u5bf9\u4ed6\u4eba\u53ef\u7528\uff0c\u5e76\u505a\u597d\u6587\u6863\u8bb0\u5f55\uff0c\u8fd9\u6837\u4efb\u4f55\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u4f60\u5f00\u53d1\u7684\u4e1c\u897f\uff0c\u800c\u65e0\u9700\u591a\u6b21\u8be2\u95ee\u4f60\u3002\u8fd9\u4e0d\u4ec5\u80fd\u8282\u7701\u60a8\u7684\u65f6\u95f4\uff0c\u8fd8\u80fd\u8282\u7701\u4ed6\u4eba\u7684\u65f6\u95f4\u3002\u597d\u7684\u3001\u5f00\u6e90\u7684\u3001\u53ef\u91cd\u590d\u4f7f\u7528\u7684\u4ee3\u7801\u5728\u60a8\u7684\u4f5c\u54c1\u96c6\u4e2d\u4e5f\u975e\u5e38\u91cd\u8981\u3002","title":"\u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5"},{"location":"%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB%E5%92%8C%E5%88%86%E5%89%B2%E6%96%B9%E6%B3%95/","text":"\u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5 \u8bf4\u5230\u56fe\u50cf\uff0c\u8fc7\u53bb\u51e0\u5e74\u53d6\u5f97\u4e86\u5f88\u591a\u6210\u5c31\u3002\u8ba1\u7b97\u673a\u89c6\u89c9\u7684\u8fdb\u6b65\u76f8\u5f53\u5feb\uff0c\u611f\u89c9\u8ba1\u7b97\u673a\u89c6\u89c9\u7684\u8bb8\u591a\u95ee\u9898\u73b0\u5728\u90fd\u66f4\u5bb9\u6613\u89e3\u51b3\u4e86\u3002\u968f\u7740\u9884\u8bad\u7ec3\u6a21\u578b\u7684\u51fa\u73b0\u548c\u8ba1\u7b97\u6210\u672c\u7684\u964d\u4f4e\uff0c\u73b0\u5728\u5728\u5bb6\u91cc\u5c31\u80fd\u8f7b\u677e\u8bad\u7ec3\u51fa\u63a5\u8fd1\u6700\u5148\u8fdb\u6c34\u5e73\u7684\u6a21\u578b\uff0c\u89e3\u51b3\u5927\u591a\u6570\u4e0e\u56fe\u50cf\u76f8\u5173\u7684\u95ee\u9898\u3002\u4f46\u662f\uff0c\u56fe\u50cf\u95ee\u9898\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u7c7b\u578b\u3002\u4ece\u4e24\u4e2a\u6216\u591a\u4e2a\u7c7b\u522b\u7684\u6807\u51c6\u56fe\u50cf\u5206\u7c7b\uff0c\u5230\u50cf\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66\u8fd9\u6837\u5177\u6709\u6311\u6218\u6027\u7684\u95ee\u9898\u3002\u6211\u4eec\u4e0d\u4f1a\u5728\u672c\u4e66\u4e2d\u8ba8\u8bba\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66\uff0c\u4f46\u6211\u4eec\u663e\u7136\u4f1a\u5904\u7406\u4e00\u4e9b\u6700\u5e38\u89c1\u7684\u56fe\u50cf\u95ee\u9898\u3002 \u6211\u4eec\u53ef\u4ee5\u5bf9\u56fe\u50cf\u91c7\u7528\u54ea\u4e9b\u4e0d\u540c\u7684\u65b9\u6cd5\uff1f\u56fe\u50cf\u53ea\u4e0d\u8fc7\u662f\u4e00\u4e2a\u6570\u5b57\u77e9\u9635\u3002\u8ba1\u7b97\u673a\u65e0\u6cd5\u50cf\u4eba\u7c7b\u4e00\u6837\u770b\u5230\u56fe\u50cf\u3002\u5b83\u53ea\u80fd\u770b\u5230\u6570\u5b57\uff0c\u8fd9\u5c31\u662f\u56fe\u50cf\u3002\u7070\u5ea6\u56fe\u50cf\u662f\u4e00\u4e2a\u4e8c\u7ef4\u77e9\u9635\uff0c\u6570\u503c\u8303\u56f4\u4ece 0 \u5230 255\u30020 \u4ee3\u8868\u9ed1\u8272\uff0c255 \u4ee3\u8868\u767d\u8272\uff0c\u4ecb\u4e8e\u4e24\u8005\u4e4b\u95f4\u7684\u662f\u5404\u79cd\u7070\u8272\u3002\u4ee5\u524d\uff0c\u5728\u6ca1\u6709\u6df1\u5ea6\u5b66\u4e60\u7684\u65f6\u5019\uff08\u6216\u8005\u8bf4\u6df1\u5ea6\u5b66\u4e60\u8fd8\u4e0d\u6d41\u884c\u7684\u65f6\u5019\uff09\uff0c\u4eba\u4eec\u4e60\u60ef\u4e8e\u67e5\u770b\u50cf\u7d20\u3002\u6bcf\u4e2a\u50cf\u7d20\u90fd\u662f\u4e00\u4e2a\u7279\u5f81\u3002\u4f60\u53ef\u4ee5\u5728 Python \u4e2d\u8f7b\u677e\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u53ea\u9700\u4f7f\u7528 OpenCV \u6216 Python-PIL \u8bfb\u53d6\u7070\u5ea6\u56fe\u50cf\uff0c\u8f6c\u6362\u4e3a numpy \u6570\u7ec4\uff0c\u7136\u540e\u5c06\u77e9\u9635\u5e73\u94fa\uff08\u6241\u5e73\u5316\uff09\u5373\u53ef\u3002\u5982\u679c\u5904\u7406\u7684\u662f RGB \u56fe\u50cf\uff0c\u5219\u9700\u8981\u4e09\u4e2a\u77e9\u9635\uff0c\u800c\u4e0d\u662f\u4e00\u4e2a\u3002\u4f46\u601d\u8def\u662f\u4e00\u6837\u7684\u3002 import numpy as np import matplotlib.pyplot as plt # \u751f\u6210\u4e00\u4e2a 256x256 \u7684\u968f\u673a\u7070\u5ea6\u56fe\u50cf\uff0c\u50cf\u7d20\u503c\u57280\u5230255\u4e4b\u95f4\u968f\u673a\u5206\u5e03 random_image = np . random . randint ( 0 , 256 , ( 256 , 256 )) # \u521b\u5efa\u4e00\u4e2a\u65b0\u7684\u56fe\u50cf\u7a97\u53e3\uff0c\u8bbe\u7f6e\u7a97\u53e3\u5927\u5c0f\u4e3a7x7\u82f1\u5bf8 plt . figure ( figsize = ( 7 , 7 )) # \u663e\u793a\u751f\u6210\u7684\u968f\u673a\u56fe\u50cf # \u4f7f\u7528\u7070\u5ea6\u989c\u8272\u6620\u5c04 (colormap)\uff0c\u8303\u56f4\u4ece0\u5230255 plt . imshow ( random_image , cmap = 'gray' , vmin = 0 , vmax = 255 ) # \u663e\u793a\u56fe\u50cf\u7a97\u53e3 plt . show () \u4e0a\u9762\u7684\u4ee3\u7801\u4f7f\u7528 numpy \u751f\u6210\u4e00\u4e2a\u968f\u673a\u77e9\u9635\u3002\u8be5\u77e9\u9635\u7531 0 \u5230 255\uff08\u5305\u542b\uff09\u7684\u503c\u7ec4\u6210\uff0c\u5927\u5c0f\u4e3a 256x256\uff08\u4e5f\u79f0\u4e3a\u50cf\u7d20\uff09\u3002 \u56fe 1\uff1a\u4e8c\u7ef4\u56fe\u50cf\u9635\u5217\uff08\u5355\u901a\u9053\uff09\u53ca\u5176\u5c55\u5e73\u7248\u672c \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u62fc\u5199\u540e\u7684\u7248\u672c\u53ea\u662f\u4e00\u4e2a\u5927\u5c0f\u4e3a M \u7684\u5411\u91cf\uff0c\u5176\u4e2d M = N * N\uff0c\u5728\u672c\u4f8b\u4e2d\uff0c\u8fd9\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u4e3a 256 * 256 = 65536\u3002 \u73b0\u5728\uff0c\u5982\u679c\u6211\u4eec\u7ee7\u7eed\u5bf9\u6570\u636e\u96c6\u4e2d\u7684\u6240\u6709\u56fe\u50cf\u8fdb\u884c\u5904\u7406\uff0c\u6bcf\u4e2a\u6837\u672c\u5c31\u4f1a\u6709 65536 \u4e2a\u7279\u5f81\u3002\u6211\u4eec\u53ef\u4ee5\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u5feb\u901f\u5efa\u7acb \u51b3\u7b56\u6811\u6a21\u578b\u3001\u968f\u673a\u68ee\u6797\u6a21\u578b\u6216\u57fa\u4e8e SVM \u7684\u6a21\u578b \u3002\u8fd9\u4e9b\u6a21\u578b\u5c06\u57fa\u4e8e\u50cf\u7d20\u503c\uff0c\u5c1d\u8bd5\u5c06\u6b63\u6837\u672c\u4e0e\u8d1f\u6837\u672c\u533a\u5206\u5f00\u6765\uff08\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff09\u3002 \u4f60\u4eec\u4e00\u5b9a\u90fd\u542c\u8bf4\u8fc7\u732b\u4e0e\u72d7\u7684\u95ee\u9898\uff0c\u8fd9\u662f\u4e00\u4e2a\u7ecf\u5178\u7684\u95ee\u9898\u3002\u5982\u679c\u4f60\u4eec\u8fd8\u8bb0\u5f97\uff0c\u5728\u8bc4\u4f30\u6307\u6807\u4e00\u7ae0\u7684\u5f00\u5934\uff0c\u6211\u5411\u4f60\u4eec\u4ecb\u7ecd\u4e86\u4e00\u4e2a\u6c14\u80f8\u56fe\u50cf\u6570\u636e\u96c6\u3002\u90a3\u4e48\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u68c0\u6d4b\u80ba\u90e8\u7684 X \u5149\u56fe\u50cf\u662f\u5426\u5b58\u5728\u6c14\u80f8\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u8fd9\u662f\u4e00\u4e2a\uff08\u5e76\u4e0d\uff09\u7b80\u5355\u7684\u4e8c\u5143\u5206\u7c7b\u3002 \u56fe 2\uff1a\u975e\u6c14\u80f8\u4e0e\u6c14\u80f8 X \u5149\u56fe\u50cf\u5bf9\u6bd4 \u5728\u56fe 2 \u4e2d\uff0c\u60a8\u53ef\u4ee5\u770b\u5230\u975e\u6c14\u80f8\u548c\u6c14\u80f8\u56fe\u50cf\u7684\u5bf9\u6bd4\u3002\u60a8\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\u4e86\uff0c\u5bf9\u4e8e\u4e00\u4e2a\u975e\u4e13\u4e1a\u4eba\u58eb\uff08\u6bd4\u5982\u6211\uff09\u6765\u8bf4\uff0c\u8981\u5728\u8fd9\u4e9b\u56fe\u50cf\u4e2d\u8fa8\u522b\u51fa\u54ea\u4e2a\u662f\u6c14\u80f8\u662f\u76f8\u5f53\u56f0\u96be\u7684\u3002 \u6700\u521d\u7684\u6570\u636e\u96c6\u662f\u5173\u4e8e\u68c0\u6d4b\u6c14\u80f8\u7684\u5177\u4f53\u4f4d\u7f6e\uff0c\u4f46\u6211\u4eec\u5c06\u95ee\u9898\u4fee\u6539\u4e3a\u67e5\u627e\u7ed9\u5b9a\u7684 X \u5149\u56fe\u50cf\u662f\u5426\u5b58\u5728\u6c14\u80f8\u3002\u522b\u62c5\u5fc3\uff0c\u6211\u4eec\u5c06\u5728\u672c\u7ae0\u4ecb\u7ecd\u8fd9\u4e2a\u90e8\u5206\u3002\u6570\u636e\u96c6\u7531 10675 \u5f20\u72ec\u7279\u7684\u56fe\u50cf\u7ec4\u6210\uff0c\u5176\u4e2d 2379 \u5f20\u6709\u6c14\u80f8\uff08\u6ce8\u610f\uff0c\u8fd9\u4e9b\u6570\u5b57\u662f\u7ecf\u8fc7\u6570\u636e\u6e05\u7406\u540e\u5f97\u51fa\u7684\uff0c\u56e0\u6b64\u4e0e\u539f\u59cb\u6570\u636e\u96c6\u4e0d\u7b26\uff09\u3002\u6b63\u5982\u6570\u636e\u79d1\u5b66\u5bb6\u6240\u8bf4\uff1a\u8fd9\u662f\u4e00\u4e2a\u5178\u578b\u7684 \u504f\u659c\u4e8c\u5143\u5206\u7c7b\u6848\u4f8b \u3002\u56e0\u6b64\uff0c\u6211\u4eec\u9009\u62e9 AUC \u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\uff0c\u5e76\u91c7\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u9a8c\u8bc1\u65b9\u6848\u3002 \u60a8\u53ef\u4ee5\u5c06\u7279\u5f81\u6241\u5e73\u5316\uff0c\u7136\u540e\u5c1d\u8bd5\u4e00\u4e9b\u7ecf\u5178\u65b9\u6cd5\uff08\u5982 SVM\u3001RF\uff09\u6765\u8fdb\u884c\u5206\u7c7b\uff0c\u8fd9\u5b8c\u5168\u6ca1\u95ee\u9898\uff0c\u4f46\u5374\u65e0\u6cd5\u8ba9\u60a8\u8fbe\u5230\u6700\u5148\u8fdb\u7684\u6c34\u5e73\u3002\u6b64\u5916\uff0c\u56fe\u50cf\u5927\u5c0f\u4e3a 1024x1024\u3002\u5728\u8fd9\u4e2a\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u4e00\u4e2a\u6a21\u578b\u9700\u8981\u5f88\u957f\u65f6\u95f4\u3002\u4e0d\u7ba1\u600e\u6837\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u5efa\u7acb\u4e00\u4e2a\u7b80\u5355\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002\u7531\u4e8e\u56fe\u50cf\u662f\u7070\u5ea6\u7684\uff0c\u6211\u4eec\u4e0d\u9700\u8981\u8fdb\u884c\u4efb\u4f55\u8f6c\u6362\u3002\u6211\u4eec\u5c06\u628a\u56fe\u50cf\u5927\u5c0f\u8c03\u6574\u4e3a 256x256\uff0c\u4f7f\u5176\u66f4\u5c0f\uff0c\u5e76\u4f7f\u7528\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684 AUC \u4f5c\u4e3a\u8861\u91cf\u6307\u6807\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5b83\u7684\u8868\u73b0\u5982\u4f55\u3002 import os import numpy as np import pandas as pd from PIL import Image from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from tqdm import tqdm # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u6765\u521b\u5efa\u6570\u636e\u96c6 def create_dataset ( training_df , image_dir ): # \u521d\u59cb\u5316\u7a7a\u5217\u8868\u6765\u5b58\u50a8\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c images = [] targets = [] # \u8fed\u4ee3\u5904\u7406\u8bad\u7ec3\u6570\u636e\u96c6\u4e2d\u7684\u6bcf\u4e00\u884c for index , row in tqdm ( training_df . iterrows (), total = len ( training_df ), desc = \"processing images\" ): # \u83b7\u53d6\u56fe\u50cf\u6587\u4ef6\u540d image_id = row [ \"ImageId\" ] # \u6784\u5efa\u5b8c\u6574\u7684\u56fe\u50cf\u6587\u4ef6\u8def\u5f84 image_path = os . path . join ( image_dir , image_id ) # \u6253\u5f00\u56fe\u50cf\u6587\u4ef6\u5e76\u8fdb\u884c\u5927\u5c0f\u8c03\u6574\uff08resize\uff09\u4e3a 256x256 \u50cf\u7d20\uff0c\u4f7f\u7528\u53cc\u7ebf\u6027\u63d2\u503c\uff08BILINEAR\uff09 image = Image . open ( image_path + \".png\" ) image = image . resize (( 256 , 256 ), resample = Image . BILINEAR ) # \u5c06\u56fe\u50cf\u8f6c\u6362\u4e3aNumPy\u6570\u7ec4 image = np . array ( image ) # \u5c06\u56fe\u50cf\u6241\u5e73\u5316\u4e3a\u4e00\u7ef4\u6570\u7ec4\uff0c\u5e76\u5c06\u5176\u6dfb\u52a0\u5230\u56fe\u50cf\u5217\u8868 image = image . ravel () images . append ( image ) # \u5c06\u76ee\u6807\u503c\uff08target\uff09\u6dfb\u52a0\u5230\u76ee\u6807\u5217\u8868 targets . append ( int ( row [ \"target\" ])) # \u5c06\u56fe\u50cf\u5217\u8868\u8f6c\u6362\u4e3aNumPy\u6570\u7ec4 images = np . array ( images ) # \u6253\u5370\u56fe\u50cf\u6570\u7ec4\u7684\u5f62\u72b6 print ( images . shape ) # \u8fd4\u56de\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c return images , targets if __name__ == \"__main__\" : # \u5b9a\u4e49CSV\u6587\u4ef6\u8def\u5f84\u548c\u56fe\u50cf\u6587\u4ef6\u76ee\u5f55\u8def\u5f84 csv_path = \"/home/abhishek/workspace/siim_png/train.csv\" image_path = \"/home/abhishek/workspace/siim_png/train_png/\" # \u4eceCSV\u6587\u4ef6\u52a0\u8f7d\u6570\u636e df = pd . read_csv ( csv_path ) # \u6dfb\u52a0\u4e00\u4e2a\u540d\u4e3a'kfold'\u7684\u5217\uff0c\u5e76\u521d\u59cb\u5316\u4e3a-1 df [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u83b7\u53d6\u76ee\u6807\u503c\uff08target\uff09 y = df . target . values # \u4f7f\u7528\u5206\u5c42KFold\u4ea4\u53c9\u9a8c\u8bc1\u5c06\u6570\u636e\u96c6\u5206\u62105\u6298 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u904d\u5386\u6bcf\u4e2a\u6298\uff08fold\uff09 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f # \u904d\u5386\u6bcf\u4e2a\u6298 for fold_ in range ( 5 ): # \u83b7\u53d6\u8bad\u7ec3\u6570\u636e\u548c\u6d4b\u8bd5\u6570\u636e train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6\u7684\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c xtrain , ytrain = create_dataset ( train_df , image_path ) # \u521b\u5efa\u6d4b\u8bd5\u6570\u636e\u96c6\u7684\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c xtest , ytest = create_dataset ( test_df , image_path ) # \u521d\u59cb\u5316\u4e00\u4e2a\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668 clf = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408\u5206\u7c7b\u5668 clf . fit ( xtrain , ytrain ) # \u4f7f\u7528\u5206\u7c7b\u5668\u5bf9\u6d4b\u8bd5\u6570\u636e\u8fdb\u884c\u9884\u6d4b\uff0c\u5e76\u83b7\u53d6\u6982\u7387\u503c preds = clf . predict_proba ( xtest )[:, 1 ] # \u6253\u5370\u6298\u6570\uff08fold\uff09\u548cAUC\uff08ROC\u66f2\u7ebf\u4e0b\u7684\u9762\u79ef\uff09 print ( f \"FOLD: { fold_ } \" ) print ( f \"AUC = { metrics . roc_auc_score ( ytest , preds ) } \" ) print ( \"\" ) \u5e73\u5747 AUC \u503c\u7ea6\u4e3a 0.72\u3002\u8fd9\u8fd8\u4e0d\u9519\uff0c\u4f46\u6211\u4eec\u5e0c\u671b\u80fd\u505a\u5f97\u66f4\u597d\u3002\u4f60\u53ef\u4ee5\u5c06\u8fd9\u79cd\u65b9\u6cd5\u7528\u4e8e\u56fe\u50cf\uff0c\u8fd9\u4e5f\u662f\u5b83\u5728\u4ee5\u524d\u6700\u5e38\u7528\u7684\u65b9\u6cd5\u3002SVM \u5728\u56fe\u50cf\u6570\u636e\u96c6\u65b9\u9762\u76f8\u5f53\u6709\u540d\u3002\u6df1\u5ea6\u5b66\u4e60\u5df2\u88ab\u8bc1\u660e\u662f\u89e3\u51b3\u6b64\u7c7b\u95ee\u9898\u7684\u6700\u5148\u8fdb\u65b9\u6cd5\uff0c\u56e0\u6b64\u6211\u4eec\u4e0b\u4e00\u6b65\u53ef\u4ee5\u8bd5\u8bd5\u5b83\u3002 \u5173\u4e8e\u6df1\u5ea6\u5b66\u4e60\u7684\u5386\u53f2\u4ee5\u53ca\u8c01\u53d1\u660e\u4e86\u4ec0\u4e48\uff0c\u6211\u5c31\u4e0d\u591a\u8bf4\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b\u6700\u8457\u540d\u7684\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e4b\u4e00 AlexNet\u3002 \u56fe 3\uff1aAlexNet \u67b6\u67849 \u8bf7\u6ce8\u610f\uff0c\u672c\u56fe\u4e2d\u7684\u8f93\u5165\u5927\u5c0f\u4e0d\u662f 224x224 \u800c\u662f 227x227 \u5982\u4eca\uff0c\u4f60\u53ef\u80fd\u4f1a\u8bf4\u8fd9\u53ea\u662f\u4e00\u4e2a\u57fa\u672c\u7684 \u6df1\u5ea6\u5377\u79ef\u795e\u7ecf\u7f51\u7edc \uff0c\u4f46\u5b83\u5374\u662f\u8bb8\u591a\u65b0\u578b\u6df1\u5ea6\u7f51\u7edc\uff08\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff09\u7684\u57fa\u7840\u3002\u6211\u4eec\u770b\u5230\uff0c\u56fe 3 \u4e2d\u7684\u7f51\u7edc\u662f\u4e00\u4e2a\u5177\u6709\u4e94\u4e2a\u5377\u79ef\u5c42\u3001\u4e24\u4e2a\u5bc6\u96c6\u5c42\u548c\u4e00\u4e2a\u8f93\u51fa\u5c42\u7684\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u3002\u6211\u4eec\u770b\u5230\u8fd8\u6709\u6700\u5927\u6c60\u5316\u3002\u8fd9\u662f\u4ec0\u4e48\u610f\u601d\uff1f\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5728\u8fdb\u884c\u6df1\u5ea6\u5b66\u4e60\u65f6\u4f1a\u9047\u5230\u7684\u4e00\u4e9b\u672f\u8bed\u3002 \u56fe 4\uff1a\u56fe\u50cf\u5927\u5c0f\u4e3a 8x8\uff0c\u6ee4\u6ce2\u5668\u5927\u5c0f\u4e3a 3x3\uff0c\u6b65\u957f\u4e3a 2\u3002 \u56fe 4 \u5f15\u5165\u4e86\u4e24\u4e2a\u65b0\u672f\u8bed\uff1a\u6ee4\u6ce2\u5668\u548c\u6b65\u957f\u3002 \u6ee4\u6ce2\u5668 \u662f\u7531\u7ed9\u5b9a\u51fd\u6570\u521d\u59cb\u5316\u7684\u4e8c\u7ef4\u77e9\u9635\uff0c\u7531\u6307\u5b9a\u51fd\u6570\u521d\u59cb\u5316\u3002 Kaiming\u6b63\u6001\u521d\u59cb\u5316 \uff0c\u662f\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u7684\u6700\u4f73\u9009\u62e9\u3002\u8fd9\u662f\u56e0\u4e3a\u5927\u591a\u6570\u73b0\u4ee3\u7f51\u7edc\u90fd\u4f7f\u7528 ReLU \uff08\u6574\u6d41\u7ebf\u6027\u5355\u5143\uff09\u6fc0\u6d3b\u51fd\u6570\uff0c\u9700\u8981\u9002\u5f53\u7684\u521d\u59cb\u5316\u6765\u907f\u514d\u68af\u5ea6\u6d88\u5931\u95ee\u9898\uff08\u68af\u5ea6\u8d8b\u8fd1\u4e8e\u96f6\uff0c\u7f51\u7edc\u6743\u91cd\u4e0d\u53d8\uff09\u3002\u8be5\u6ee4\u6ce2\u5668\u4e0e\u56fe\u50cf\u8fdb\u884c\u5377\u79ef\u3002\u5377\u79ef\u4e0d\u8fc7\u662f\u6ee4\u6ce2\u5668\u4e0e\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u5f53\u524d\u91cd\u53e0\u50cf\u7d20\u4e4b\u95f4\u7684\u5143\u7d20\u76f8\u4e58\u7684\u603b\u548c\u3002\u60a8\u53ef\u4ee5\u5728\u4efb\u4f55\u9ad8\u4e2d\u6570\u5b66\u6559\u79d1\u4e66\u4e2d\u9605\u8bfb\u66f4\u591a\u5173\u4e8e\u5377\u79ef\u7684\u5185\u5bb9\u3002\u6211\u4eec\u4ece\u56fe\u50cf\u7684\u5de6\u4e0a\u89d2\u5f00\u59cb\u5bf9\u6ee4\u955c\u8fdb\u884c\u5377\u79ef\uff0c\u7136\u540e\u6c34\u5e73\u79fb\u52a8\u6ee4\u955c\u3002\u5982\u679c\u79fb\u52a8 1 \u4e2a\u50cf\u7d20\uff0c\u5219\u6b65\u957f\u4e3a 1\uff1b\u5982\u679c\u79fb\u52a8 2 \u4e2a\u50cf\u7d20\uff0c\u5219\u6b65\u957f\u4e3a 2\u3002 \u5373\u4f7f\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u4f8b\u5982\u5728\u95ee\u9898\u548c\u56de\u7b54\u7cfb\u7edf\u4e2d\u9700\u8981\u4ece\u5927\u91cf\u6587\u672c\u8bed\u6599\u4e2d\u7b5b\u9009\u7b54\u6848\u65f6\uff0c\u6b65\u957f\u4e5f\u662f\u4e00\u4e2a\u975e\u5e38\u6709\u7528\u7684\u6982\u5ff5\u3002\u5f53\u6211\u4eec\u5728\u6c34\u5e73\u65b9\u5411\u4e0a\u8d70\u5230\u5c3d\u5934\u65f6\uff0c\u5c31\u4f1a\u4ee5\u540c\u6837\u7684\u6b65\u957f\u5782\u76f4\u5411\u4e0b\u79fb\u52a8\u8fc7\u6ee4\u5668\uff0c\u4ece\u5de6\u4fa7\u5f00\u59cb\u3002\u56fe 4 \u8fd8\u663e\u793a\u4e86\u8fc7\u6ee4\u5668\u79fb\u51fa\u56fe\u50cf\u7684\u60c5\u51b5\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u65e0\u6cd5\u8ba1\u7b97\u5377\u79ef\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u8df3\u8fc7\u5b83\u3002\u5982\u679c\u4e0d\u60f3\u8df3\u8fc7\uff0c\u5219\u9700\u8981\u5bf9\u56fe\u50cf\u8fdb\u884c \u586b\u5145\uff08pad\uff09 \u3002\u8fd8\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5377\u79ef\u4f1a\u51cf\u5c0f\u56fe\u50cf\u7684\u5927\u5c0f\u3002\u586b\u5145\u4e5f\u662f\u4fdd\u6301\u56fe\u50cf\u5927\u5c0f\u4e0d\u53d8\u7684\u4e00\u79cd\u65b9\u6cd5\u3002\u5728\u56fe 4 \u4e2d\uff0c\u4e00\u4e2a 3x3 \u6ee4\u6ce2\u5668\u6b63\u5728\u6c34\u5e73\u548c\u5782\u76f4\u79fb\u52a8\uff0c\u6bcf\u6b21\u79fb\u52a8\u90fd\u4f1a\u5206\u522b\u8df3\u8fc7\u4e24\u5217\u548c\u4e24\u884c\uff08\u5373\u50cf\u7d20\uff09\u3002\u7531\u4e8e\u5b83\u8df3\u8fc7\u4e86\u4e24\u4e2a\u50cf\u7d20\uff0c\u6240\u4ee5\u6b65\u957f = 2\u3002\u56e0\u6b64\u56fe\u50cf\u5927\u5c0f\u4e3a [(8-3) / 2] + 1 = 3.5\u3002\u6211\u4eec\u53d6 3.5 \u7684\u4e0b\u9650\uff0c\u6240\u4ee5\u662f 3x3\u3002\u60a8\u53ef\u4ee5\u5728\u8349\u7a3f\u7eb8\u4e0a\u8fdb\u884c\u5c1d\u8bd5\u3002 \u56fe 5\uff1a\u901a\u8fc7\u586b\u5145\uff0c\u6211\u4eec\u53ef\u4ee5\u63d0\u4f9b\u4e0e\u8f93\u5165\u56fe\u50cf\u5927\u5c0f\u76f8\u540c\u7684\u56fe\u50cf \u6211\u4eec\u53ef\u4ee5\u4ece\u56fe 5 \u4e2d\u770b\u5230\u586b\u5145\u7684\u6548\u679c\u3002\u73b0\u5728\uff0c\u6211\u4eec\u6709\u4e00\u4e2a 3x3 \u7684\u6ee4\u6ce2\u5668\uff0c\u5b83\u4ee5 1 \u7684\u6b65\u957f\u79fb\u52a8\u3002\u539f\u59cb\u56fe\u50cf\u7684\u5927\u5c0f\u4e3a 6x6\uff0c\u6211\u4eec\u6dfb\u52a0\u4e86 1 \u7684 \u586b\u5145 \u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u751f\u6210\u7684\u56fe\u50cf\u5c06\u4e0e\u8f93\u5165\u56fe\u50cf\u5927\u5c0f\u76f8\u540c\uff0c\u5373 6x6\u3002\u5728\u5904\u7406\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u65f6\u53ef\u80fd\u4f1a\u9047\u5230\u7684\u53e6\u4e00\u4e2a\u76f8\u5173\u672f\u8bed\u662f \u81a8\u80c0\uff08dilation\uff09 \uff0c\u5982\u56fe 6 \u6240\u793a\u3002 \u56fe 6\uff1a\u81a8\u80c0\uff08dilation\uff09\u7684\u4f8b\u5b50 \u5728\u81a8\u80c0\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u5c06\u6ee4\u6ce2\u5668\u6269\u5927 N-1\uff0c\u5176\u4e2d N \u662f\u81a8\u80c0\u7387\u7684\u503c\uff0c\u6216\u7b80\u79f0\u4e3a\u81a8\u80c0\u3002\u5728\u8fd9\u79cd\u5e26\u81a8\u80c0\u7684\u5185\u6838\u4e2d\uff0c\u6bcf\u6b21\u5377\u79ef\u90fd\u4f1a\u8df3\u8fc7\u4e00\u4e9b\u50cf\u7d20\u3002\u8fd9\u5728\u5206\u5272\u4efb\u52a1\u4e2d\u5c24\u4e3a\u6709\u6548\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u53ea\u8ba8\u8bba\u4e86\u4e8c\u7ef4\u5377\u79ef\u3002 \u8fd8\u6709\u4e00\u7ef4\u5377\u79ef\u548c\u66f4\u9ad8\u7ef4\u5ea6\u7684\u5377\u79ef\u3002\u5b83\u4eec\u90fd\u57fa\u4e8e\u76f8\u540c\u7684\u57fa\u672c\u6982\u5ff5\u3002 \u63a5\u4e0b\u6765\u662f \u6700\u5927\u6c60\u5316\uff08Max pooling\uff09 \u3002\u6700\u5927\u503c\u6c60\u53ea\u662f\u4e00\u4e2a\u8fd4\u56de\u6700\u5927\u503c\u7684\u6ee4\u6ce2\u5668\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u63d0\u53d6\u7684\u4e0d\u662f\u5377\u79ef\uff0c\u800c\u662f\u50cf\u7d20\u7684\u6700\u5927\u503c\u3002\u540c\u6837\uff0c \u5e73\u5747\u6c60\u5316\uff08average pooling\uff09 \u6216 \u5747\u503c\u6c60\u5316\uff08mean pooling\uff09 \u4f1a\u8fd4\u56de\u50cf\u7d20\u7684\u5e73\u5747\u503c\u3002\u5b83\u4eec\u7684\u4f7f\u7528\u65b9\u6cd5\u4e0e\u5377\u79ef\u6838\u76f8\u540c\u3002\u6c60\u5316\u6bd4\u5377\u79ef\u66f4\u5feb\uff0c\u662f\u4e00\u79cd\u5bf9\u56fe\u50cf\u8fdb\u884c\u7f29\u51cf\u91c7\u6837\u7684\u65b9\u6cd5\u3002\u6700\u5927\u6c60\u5316\u53ef\u68c0\u6d4b\u8fb9\u7f18\uff0c\u5e73\u5747\u6c60\u5316\u53ef\u5e73\u6ed1\u56fe\u50cf\u3002 \u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u548c\u6df1\u5ea6\u5b66\u4e60\u7684\u6982\u5ff5\u592a\u591a\u4e86\u3002\u6211\u6240\u8ba8\u8bba\u7684\u662f\u4e00\u4e9b\u57fa\u7840\u77e5\u8bc6\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5165\u95e8\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u5728 PyTorch \u4e2d\u6784\u5efa\u7b2c\u4e00\u4e2a\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u505a\u597d\u4e86\u5145\u5206\u51c6\u5907\u3002PyTorch \u63d0\u4f9b\u4e86\u4e00\u79cd\u76f4\u89c2\u800c\u7b80\u5355\u7684\u65b9\u6cd5\u6765\u5b9e\u73b0\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff0c\u800c\u4e14\u4f60\u4e0d\u9700\u8981\u5173\u5fc3\u53cd\u5411\u4f20\u64ad\u3002\u6211\u4eec\u7528\u4e00\u4e2a python \u7c7b\u548c\u4e00\u4e2a\u524d\u9988\u51fd\u6570\u6765\u5b9a\u4e49\u7f51\u7edc\uff0c\u544a\u8bc9 PyTorch \u5404\u5c42\u4e4b\u95f4\u5982\u4f55\u8fde\u63a5\u3002\u5728 PyTorch \u4e2d\uff0c\u56fe\u50cf\u7b26\u53f7\u662f BS\u3001C\u3001H\u3001W\uff0c\u5176\u4e2d\uff0cBS \u662f\u6279\u5927\u5c0f\uff0cC \u662f\u901a\u9053\uff0cH \u662f\u9ad8\u5ea6\uff0cW \u662f\u5bbd\u5ea6\u3002\u8ba9\u6211\u4eec\u770b\u770b PyTorch \u662f\u5982\u4f55\u5b9e\u73b0 AlexNet \u7684\u3002 import torch import torch.nn as nn import torch.nn.functional as F class AlexNet ( nn . Module ): def __init__ ( self ): super ( AlexNet , self ) . __init__ () self . conv1 = nn . Conv2d ( in_channels = 3 , out_channels = 96 , kernel_size = 11 , stride = 4 , padding = 0 ) self . pool1 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . conv2 = nn . Conv2d ( in_channels = 96 , out_channels = 256 , kernel_size = 5 , stride = 1 , padding = 2 ) self . pool2 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . conv3 = nn . Conv2d ( in_channels = 256 , out_channels = 384 , kernel_size = 3 , stride = 1 , padding = 1 ) self . conv4 = nn . Conv2d ( in_channels = 384 , out_channels = 384 , kernel_size = 3 , stride = 1 , padding = 1 ) self . conv5 = nn . Conv2d ( in_channels = 384 , out_channels = 256 , kernel_size = 3 , stride = 1 , padding = 1 ) self . pool3 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . fc1 = nn . Linear ( in_features = 9216 , out_features = 4096 ) self . dropout1 = nn . Dropout ( 0.5 ) self . fc2 = nn . Linear ( in_features = 4096 , out_features = 4096 ) self . dropout2 = nn . Dropout ( 0.5 ) self . fc3 = nn . Linear ( in_features = 4096 , out_features = 1000 ) def forward ( self , image ): bs , c , h , w = image . size () x = F . relu ( self . conv1 ( image )) # size: (bs, 96, 55, 55) x = self . pool1 ( x ) # size: (bs, 96, 27, 27) x = F . relu ( self . conv2 ( x )) # size: (bs, 256, 27, 27) x = self . pool2 ( x ) # size: (bs, 256, 13, 13) x = F . relu ( self . conv3 ( x )) # size: (bs, 384, 13, 13) x = F . relu ( self . conv4 ( x )) # size: (bs, 384, 13, 13) x = F . relu ( self . conv5 ( x )) # size: (bs, 256, 13, 13) x = self . pool3 ( x ) # size: (bs, 256, 6, 6) x = x . view ( bs , - 1 ) # size: (bs, 9216) x = F . relu ( self . fc1 ( x )) # size: (bs, 4096) x = self . dropout1 ( x ) # size: (bs, 4096) # dropout does not change size # dropout is used for regularization # 0.3 dropout means that only 70% of the nodes # of the current layer are used for the next layer x = F . relu ( self . fc2 ( x )) # size: (bs, 4096) x = self . dropout2 ( x ) # size: (bs, 4096) x = F . relu ( self . fc3 ( x )) # size: (bs, 1000) # 1000 is number of classes in ImageNet Dataset # softmax is an activation function that converts # linear output to probabilities that add up to 1 # for each sample in the batch x = torch . softmax ( x , axis = 1 ) # size: (bs, 1000) return x \u5982\u679c\u60a8\u6709\u4e00\u5e45 3x227x227 \u7684\u56fe\u50cf\uff0c\u5e76\u5e94\u7528\u4e86\u4e00\u4e2a\u5927\u5c0f\u4e3a 11x11 \u7684\u5377\u79ef\u6ee4\u6ce2\u5668\uff0c\u8fd9\u610f\u5473\u7740\u60a8\u5e94\u7528\u4e86\u4e00\u4e2a\u5927\u5c0f\u4e3a 11x11x3 \u7684\u6ee4\u6ce2\u5668\uff0c\u5e76\u4e0e\u4e00\u4e2a\u5927\u5c0f\u4e3a 227x227x3 \u7684\u56fe\u50cf\u8fdb\u884c\u4e86\u5377\u79ef\u3002\u8f93\u51fa\u901a\u9053\u7684\u6570\u91cf\u5c31\u662f\u5206\u522b\u5e94\u7528\u4e8e\u56fe\u50cf\u7684\u76f8\u540c\u5927\u5c0f\u7684\u4e0d\u540c\u5377\u79ef\u6ee4\u6ce2\u5668\u7684\u6570\u91cf\u3002 \u56e0\u6b64\uff0c\u5728\u7b2c\u4e00\u4e2a\u5377\u79ef\u5c42\u4e2d\uff0c\u8f93\u5165\u901a\u9053\u662f 3\uff0c\u4e5f\u5c31\u662f\u539f\u59cb\u8f93\u5165\uff0c\u5373 R\u3001G\u3001B \u4e09\u901a\u9053\u3002PyTorch \u7684 torchvision \u63d0\u4f9b\u4e86\u8bb8\u591a\u4e0e AlexNet \u7c7b\u4f3c\u7684\u4e0d\u540c\u6a21\u578b\uff0c\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0cAlexNet \u7684\u5b9e\u73b0\u4e0e torchvision \u7684\u5b9e\u73b0\u5e76\u4e0d\u76f8\u540c\u3002Torchvision \u7684 AlexNet \u5b9e\u73b0\u662f\u4ece\u53e6\u4e00\u7bc7\u8bba\u6587\u4e2d\u4fee\u6539\u800c\u6765\u7684 AlexNet\uff1a Krizhevsky, A. One weird trick for parallelizing convolutional neural networks. CoRR, abs/1404.5997, 2014. \u4f60\u53ef\u4ee5\u4e3a\u81ea\u5df1\u7684\u4efb\u52a1\u8bbe\u8ba1\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff0c\u5f88\u591a\u65f6\u5019\uff0c\u4ece\u96f6\u505a\u8d77\u662f\u4e2a\u4e0d\u9519\u7684\u4e3b\u610f\u3002\u8ba9\u6211\u4eec\u6784\u5efa\u4e00\u4e2a\u7f51\u7edc\uff0c\u7528\u4e8e\u533a\u5206\u56fe\u50cf\u6709\u65e0\u6c14\u80f8\u3002\u9996\u5148\uff0c\u8ba9\u6211\u4eec\u51c6\u5907\u4e00\u4e9b\u6587\u4ef6\u3002\u7b2c\u4e00\u6b65\u662f\u521b\u5efa\u4e00\u4e2a\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e\u96c6\uff0c\u5373 train.csv\uff0c\u4f46\u589e\u52a0\u4e00\u5217 kfold\u3002\u6211\u4eec\u5c06\u521b\u5efa\u4e94\u4e2a\u6587\u4ef6\u5939\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u5df2\u7ecf\u6f14\u793a\u4e86\u5982\u4f55\u9488\u5bf9\u4e0d\u540c\u7684\u6570\u636e\u96c6\u521b\u5efa\u6298\u53e0\uff0c\u56e0\u6b64\u6211\u5c06\u8df3\u8fc7\u8fd9\u4e00\u90e8\u5206\uff0c\u7559\u4f5c\u7ec3\u4e60\u3002\u5bf9\u4e8e\u57fa\u4e8e PyTorch \u7684\u795e\u7ecf\u7f51\u7edc\uff0c\u6211\u4eec\u9700\u8981\u521b\u5efa\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\u3002\u6570\u636e\u96c6\u7c7b\u7684\u76ee\u7684\u662f\u8fd4\u56de\u4e00\u4e2a\u6570\u636e\u9879\u6216\u6570\u636e\u6837\u672c\u3002\u8fd9\u4e2a\u6570\u636e\u6837\u672c\u5e94\u8be5\u5305\u542b\u8bad\u7ec3\u6216\u8bc4\u4f30\u6a21\u578b\u6240\u9700\u7684\u6240\u6709\u5185\u5bb9\u3002 import torch import numpy as np from PIL import Image from PIL import ImageFile ImageFile . LOAD_TRUNCATED_IMAGES = True # \u5b9a\u4e49\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\uff0c\u7528\u4e8e\u5904\u7406\u56fe\u50cf\u5206\u7c7b\u4efb\u52a1 class ClassificationDataset : def __init__ ( self , image_paths , targets , resize = None , augmentations = None ): # \u56fe\u50cf\u6587\u4ef6\u8def\u5f84\u5217\u8868 self . image_paths = image_paths # \u76ee\u6807\u6807\u7b7e\u5217\u8868 self . targets = targets # \u56fe\u50cf\u5c3a\u5bf8\u8c03\u6574\u53c2\u6570\uff0c\u53ef\u4ee5\u4e3aNone self . resize = resize # \u6570\u636e\u589e\u5f3a\u51fd\u6570\uff0c\u53ef\u4ee5\u4e3aNone self . augmentations = augmentations def __len__ ( self ): # \u8fd4\u56de\u6570\u636e\u96c6\u7684\u5927\u5c0f\uff0c\u5373\u56fe\u50cf\u6570\u91cf return len ( self . image_paths ) def __getitem__ ( self , item ): # \u83b7\u53d6\u6570\u636e\u96c6\u4e2d\u7684\u4e00\u4e2a\u6837\u672c image = Image . open ( self . image_paths [ item ]) image = image . convert ( \"RGB\" ) # \u5c06\u56fe\u50cf\u8f6c\u6362\u4e3aRGB\u683c\u5f0f # \u83b7\u53d6\u8be5\u6837\u672c\u7684\u76ee\u6807\u6807\u7b7e targets = self . targets [ item ] if self . resize is not None : # \u5982\u679c\u6307\u5b9a\u4e86\u5c3a\u5bf8\u8c03\u6574\u53c2\u6570\uff0c\u5c06\u56fe\u50cf\u8fdb\u884c\u5c3a\u5bf8\u8c03\u6574 image = image . resize (( self . resize [ 1 ], self . resize [ 0 ]), resample = Image . BILINEAR ) image = np . array ( image ) if self . augmentations is not None : # \u5982\u679c\u6307\u5b9a\u4e86\u6570\u636e\u589e\u5f3a\u51fd\u6570\uff0c\u5e94\u7528\u6570\u636e\u589e\u5f3a augmented = self . augmentations ( image = image ) image = augmented [ \"image\" ] # \u5c06\u56fe\u50cf\u901a\u9053\u987a\u5e8f\u8c03\u6574\u4e3a(C, H, W)\u7684\u5f62\u5f0f\uff0c\u5e76\u8f6c\u6362\u4e3afloat32\u7c7b\u578b image = np . transpose ( image , ( 2 , 0 , 1 )) . astype ( np . float32 ) # \u8fd4\u56de\u6837\u672c\uff0c\u5305\u62ec\u56fe\u50cf\u548c\u5bf9\u5e94\u7684\u76ee\u6807\u6807\u7b7e return { \"image\" : torch . tensor ( image , dtype = torch . float ), \"targets\" : torch . tensor ( targets , dtype = torch . long ), } \u73b0\u5728\u6211\u4eec\u9700\u8981 engine.py\u3002engine.py \u5305\u542b\u8bad\u7ec3\u548c\u8bc4\u4f30\u529f\u80fd\u3002\u8ba9\u6211\u4eec\u770b\u770b engine.py \u662f\u5982\u4f55\u7f16\u5199\u7684\u3002 import torch import torch.nn as nn from tqdm import tqdm # \u7528\u4e8e\u8bad\u7ec3\u6a21\u578b\u7684\u51fd\u6570 def train ( data_loader , model , optimizer , device ): # \u5c06\u6a21\u578b\u8bbe\u7f6e\u4e3a\u8bad\u7ec3\u6a21\u5f0f model . train () for data in data_loader : # \u4ece\u6570\u636e\u52a0\u8f7d\u5668\u4e2d\u63d0\u53d6\u8f93\u5165\u56fe\u50cf\u548c\u76ee\u6807\u6807\u7b7e inputs = data [ \"image\" ] targets = data [ \"targets\" ] # \u5c06\u8f93\u5165\u548c\u76ee\u6807\u79fb\u52a8\u5230\u6307\u5b9a\u7684\u8bbe\u5907\uff08\u4f8b\u5982\uff0cGPU\uff09 inputs = inputs . to ( device , dtype = torch . float ) targets = targets . to ( device , dtype = torch . float ) # \u5c06\u4f18\u5316\u5668\u4e2d\u7684\u68af\u5ea6\u5f52\u96f6 optimizer . zero_grad () # \u524d\u5411\u4f20\u64ad\uff1a\u8ba1\u7b97\u6a21\u578b\u9884\u6d4b outputs = model ( inputs ) # \u4f7f\u7528\u5e26\u903b\u8f91\u65af\u8482\u51fd\u6570\u7684\u4e8c\u5143\u4ea4\u53c9\u71b5\u635f\u5931\u8ba1\u7b97\u635f\u5931 loss = nn . BCEWithLogitsLoss ()( outputs , targets . view ( - 1 , 1 )) # \u53cd\u5411\u4f20\u64ad\uff1a\u8ba1\u7b97\u68af\u5ea6\u5e76\u66f4\u65b0\u6a21\u578b\u6743\u91cd loss . backward () optimizer . step () # \u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u7684\u51fd\u6570 def evaluate ( data_loader , model , device ): # \u5c06\u6a21\u578b\u8bbe\u7f6e\u4e3a\u8bc4\u4f30\u6a21\u5f0f\uff08\u4e0d\u8fdb\u884c\u68af\u5ea6\u8ba1\u7b97\uff09 model . eval () # \u521d\u59cb\u5316\u5217\u8868\u4ee5\u5b58\u50a8\u771f\u5b9e\u76ee\u6807\u548c\u6a21\u578b\u9884\u6d4b final_targets = [] final_outputs = [] with torch . no_grad (): for data in data_loader : # \u4ece\u6570\u636e\u52a0\u8f7d\u5668\u4e2d\u63d0\u53d6\u8f93\u5165\u56fe\u50cf\u548c\u76ee\u6807\u6807\u7b7e inputs = data [ \"image\" ] targets = data [ \"targets\" ] # \u5c06\u8f93\u5165\u79fb\u52a8\u5230\u6307\u5b9a\u7684\u8bbe\u5907\uff08\u4f8b\u5982\uff0cGPU\uff09 inputs = inputs . to ( device , dtype = torch . float ) # \u83b7\u53d6\u6a21\u578b\u9884\u6d4b output = model ( inputs ) # \u5c06\u76ee\u6807\u548c\u8f93\u51fa\u8f6c\u6362\u4e3aCPU\u548cPython\u5217\u8868 targets = targets . detach () . cpu () . numpy () . tolist () output = output . detach () . cpu () . numpy () . tolist () # \u5c06\u5217\u8868\u6269\u5c55\u4ee5\u5305\u542b\u6279\u6b21\u6570\u636e final_targets . extend ( targets ) final_outputs . extend ( output ) # \u8fd4\u56de\u6700\u7ec8\u7684\u6a21\u578b\u9884\u6d4b\u548c\u771f\u5b9e\u76ee\u6807 return final_outputs , final_targets \u6709\u4e86 engine.py\uff0c\u5c31\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u65b0\u6587\u4ef6\uff1amodel.py\u3002model.py \u5c06\u5305\u542b\u6211\u4eec\u7684\u6a21\u578b\u3002\u628a\u6a21\u578b\u4e0e\u8bad\u7ec3\u5206\u5f00\u662f\u4e2a\u597d\u4e3b\u610f\uff0c\u56e0\u4e3a\u8fd9\u6837\u6211\u4eec\u5c31\u53ef\u4ee5\u8f7b\u677e\u5730\u8bd5\u9a8c\u4e0d\u540c\u7684\u6a21\u578b\u548c\u4e0d\u540c\u7684\u67b6\u6784\u3002\u540d\u4e3a pretrainedmodels \u7684 PyTorch \u5e93\u4e2d\u6709\u5f88\u591a\u4e0d\u540c\u7684\u6a21\u578b\u67b6\u6784\uff0c\u5982 AlexNet\u3001ResNet\u3001DenseNet \u7b49\u3002\u8fd9\u4e9b\u4e0d\u540c\u7684\u6a21\u578b\u67b6\u6784\u662f\u5728\u540d\u4e3a ImageNet \u7684\u5927\u578b\u56fe\u50cf\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u51fa\u6765\u7684\u3002\u5728 ImageNet \u4e0a\u8bad\u7ec3\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5b83\u4eec\u7684\u6743\u91cd\uff0c\u4e5f\u53ef\u4ee5\u4e0d\u4f7f\u7528\u8fd9\u4e9b\u6743\u91cd\u3002\u5982\u679c\u6211\u4eec\u4e0d\u4f7f\u7528 ImageNet \u6743\u91cd\u8fdb\u884c\u8bad\u7ec3\uff0c\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u7f51\u7edc\u5c06\u4ece\u5934\u5f00\u59cb\u5b66\u4e60\u4e00\u5207\u3002\u8fd9\u5c31\u662f model.py \u7684\u6837\u5b50\u3002 import torch.nn as nn import pretrainedmodels # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u4ee5\u83b7\u53d6\u6a21\u578b def get_model ( pretrained ): if pretrained : # \u4f7f\u7528\u9884\u8bad\u7ec3\u7684 AlexNet \u6a21\u578b\uff0c\u52a0\u8f7d\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u7684\u6743\u91cd model = pretrainedmodels . __dict__ [ \"alexnet\" ]( pretrained = 'imagenet' ) else : # \u4f7f\u7528\u672a\u7ecf\u9884\u8bad\u7ec3\u7684 AlexNet \u6a21\u578b model = pretrainedmodels . __dict__ [ \"alexnet\" ]( pretrained = None ) # \u4fee\u6539\u6a21\u578b\u7684\u6700\u540e\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u4ee5\u9002\u5e94\u7279\u5b9a\u4efb\u52a1 model . last_linear = nn . Sequential ( nn . BatchNorm1d ( 4096 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.25 ), # \u968f\u673a\u5931\u6d3b\u5c42\uff0c\u9632\u6b62\u8fc7\u62df\u5408 nn . Linear ( in_features = 4096 , out_features = 2048 ), # \u8fde\u63a5\u5c42 nn . ReLU (), # ReLU \u6fc0\u6d3b\u51fd\u6570 nn . BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.5 ), # \u968f\u673a\u5931\u6d3b\u5c42 nn . Linear ( in_features = 2048 , out_features = 1 ) # \u6700\u7ec8\u7684\u4e8c\u5143\u5206\u7c7b\u5c42 ) return model \u5982\u679c\u4f60\u6253\u5370\u4e86\u7f51\u7edc\uff0c\u4f1a\u5f97\u5230\u5982\u4e0b\u8f93\u51fa\uff1a AlexNet ( ( avgpool ): AdaptiveAvgPool2d ( output_size = ( 6 , 6 )) ( _features ): Sequential ( ( 0 ): Conv2d ( 3 , 64 , kernel_size = ( 11 , 11 ), stride = ( 4 , 4 ), padding = ( 2 , 2 )) ( 1 ): ReLU ( inplace = True ) ( 2 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , ceil_mode = False ) ( 3 ): Conv2d ( 64 , 192 , kernel_size = ( 5 , 5 ), stride = ( 1 , 1 ), padding = ( 2 , 2 )) ( 4 ): ReLU ( inplace = True ) ( 5 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , ceil_mode = False ) ( 6 ): Conv2d ( 192 , 384 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 7 ): ReLU ( inplace = True ) ( 8 ): Conv2d ( 384 , 256 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 9 ): ReLU ( inplace = True ) ( 10 ): Conv2d ( 256 , 256 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 11 ): ReLU ( inplace = True ) ( 12 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , eil_mode = False )) ( dropout0 ): Dropout ( p = 0.5 , inplace = False ) ( linear0 ): Linear ( in_features = 9216 , out_features = 4096 , bias = True ) ( relu0 ): ReLU ( inplace = True ) ( dropout1 ): Dropout ( p = 0.5 , inplace = False ) ( linear1 ): Linear ( in_features = 4096 , out_features = 4096 , bias = True ) ( relu1 ): ReLU ( inplace = True ) ( last_linear ): Sequential ( ( 0 ): BatchNorm1d ( 4096 , eps = 1e-05 , momentum = 0.1 , affine = True , rack_running_stats = True ) ( 1 ): Dropout ( p = 0.25 , inplace = False ) ( 2 ): Linear ( in_features = 4096 , out_features = 2048 , bias = True ) ( 3 ): ReLU () ( 4 ): BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 , affine = True , track_running_stats = True ) ( 5 ): Dropout ( p = 0.5 , inplace = False ) ( 6 ): Linear ( in_features = 2048 , out_features = 1 , bias = True ) ) ) \u73b0\u5728\uff0c\u4e07\u4e8b\u4ff1\u5907\uff0c\u53ef\u4ee5\u5f00\u59cb\u8bad\u7ec3\u4e86\u3002\u6211\u4eec\u5c06\u4f7f\u7528 train.py \u8bad\u7ec3\u6a21\u578b\u3002 import os import pandas as pd import numpy as np import albumentations import torch from sklearn import metrics from sklearn.model_selection import train_test_split import dataset import engine from model import get_model if __name__ == \"__main__\" : # \u5b9a\u4e49\u6570\u636e\u8def\u5f84\u3001\u8bbe\u5907\u3001\u8fed\u4ee3\u6b21\u6570 data_path = \"/home/abhishek/workspace/siim_png/\" device = \"cuda\" # \u4f7f\u7528GPU\u52a0\u901f epochs = 10 # \u4eceCSV\u6587\u4ef6\u8bfb\u53d6\u6570\u636e df = pd . read_csv ( os . path . join ( data_path , \"train.csv\" )) images = df . ImageId . values . tolist () images = [ os . path . join ( data_path , \"train_png\" , i + \".png\" ) for i in images ] targets = df . target . values # \u83b7\u53d6\u9884\u8bad\u7ec3\u7684\u6a21\u578b model = get_model ( pretrained = True ) model . to ( device ) # \u5b9a\u4e49\u5747\u503c\u548c\u6807\u51c6\u5dee\uff0c\u7528\u4e8e\u6570\u636e\u6807\u51c6\u5316 mean = ( 0.485 , 0.456 , 0.406 ) std = ( 0.229 , 0.224 , 0.225 ) # \u6570\u636e\u589e\u5f3a\uff0c\u5c06\u56fe\u50cf\u6807\u51c6\u5316 aug = albumentations . Compose ( [ albumentations . Normalize ( mean , std , max_pixel_value = 255.0 , always_apply = True ) ] ) # \u5212\u5206\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 train_images , valid_images , train_targets , valid_targets = train_test_split ( images , targets , stratify = targets , random_state = 42 ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6\u548c\u9a8c\u8bc1\u6570\u636e\u96c6 train_dataset = dataset . ClassificationDataset ( image_paths = train_images , targets = train_targets , resize = ( 227 , 227 ), augmentations = aug , ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = 16 , shuffle = True , num_workers = 4 ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u96c6 valid_dataset = dataset . ClassificationDataset ( image_paths = valid_images , targets = valid_targets , resize = ( 227 , 227 ), augmentations = aug , ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = 16 , shuffle = False , num_workers = 4 ) # \u5b9a\u4e49\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 5e-4 ) # \u8bad\u7ec3\u5faa\u73af for epoch in range ( epochs ): # \u8bad\u7ec3\u6a21\u578b engine . train ( train_loader , model , optimizer , device = device ) # \u8bc4\u4f30\u6a21\u578b\u6027\u80fd predictions , valid_targets = engine . evaluate ( valid_loader , model , device = device ) # \u8ba1\u7b97ROC AUC\u5206\u6570\u5e76\u6253\u5370 roc_auc = metrics . roc_auc_score ( valid_targets , predictions ) print ( f \"Epoch= { epoch } , Valid ROC AUC= { roc_auc } \" ) \u8ba9\u6211\u4eec\u5728\u6ca1\u6709\u9884\u8bad\u7ec3\u6743\u91cd\u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u8bad\u7ec3\uff1a Epoch = 0 , Valid ROC AUC = 0.5737161981475328 Epoch = 1 , Valid ROC AUC = 0.5362868001588292 Epoch = 2 , Valid ROC AUC = 0.6163448214387008 Epoch = 3 , Valid ROC AUC = 0.6119219143780944 Epoch = 4 , Valid ROC AUC = 0.6229718888519726 Epoch = 5 , Valid ROC AUC = 0.5983014999635341 Epoch = 6 , Valid ROC AUC = 0.5523236874306134 Epoch = 7 , Valid ROC AUC = 0.4717721611306046 Epoch = 8 , Valid ROC AUC = 0.6473408263980617 Epoch = 9 , Valid ROC AUC = 0.6639862888260415 AUC \u7ea6\u4e3a 0.66\uff0c\u751a\u81f3\u4f4e\u4e8e\u6211\u4eec\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002\u4f7f\u7528\u9884\u8bad\u7ec3\u6743\u91cd\u4f1a\u53d1\u751f\u4ec0\u4e48\u60c5\u51b5\uff1f Epoch = 0 , Valid ROC AUC = 0.5730387429803165 Epoch = 1 , Valid ROC AUC = 0.5319813942934937 Epoch = 2 , Valid ROC AUC = 0.627111577514323 Epoch = 3 , Valid ROC AUC = 0.6819736959393209 Epoch = 4 , Valid ROC AUC = 0.5747117168950512 Epoch = 5 , Valid ROC AUC = 0.5994619255609669 Epoch = 6 , Valid ROC AUC = 0.5080889443530546 Epoch = 7 , Valid ROC AUC = 0.6323792776512727 Epoch = 8 , Valid ROC AUC = 0.6685753182661686 Epoch = 9 , Valid ROC AUC = 0.6861802387300147 \u73b0\u5728\u7684 AUC \u597d\u4e86\u5f88\u591a\u3002\u4e0d\u8fc7\uff0c\u5b83\u4ecd\u7136\u8f83\u4f4e\u3002\u9884\u8bad\u7ec3\u6a21\u578b\u7684\u597d\u5904\u662f\u53ef\u4ee5\u8f7b\u677e\u5c1d\u8bd5\u591a\u79cd\u4e0d\u540c\u7684\u6a21\u578b\u3002\u8ba9\u6211\u4eec\u8bd5\u8bd5\u4f7f\u7528\u9884\u8bad\u7ec3\u6743\u91cd\u7684 resnet18 \u3002 import torch.nn as nn import pretrainedmodels # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u4ee5\u83b7\u53d6\u6a21\u578b def get_model ( pretrained ): if pretrained : # \u4f7f\u7528\u9884\u8bad\u7ec3\u7684 ResNet-18 \u6a21\u578b\uff0c\u52a0\u8f7d\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u7684\u6743\u91cd model = pretrainedmodels . __dict__ [ \"resnet18\" ]( pretrained = 'imagenet' ) else : # \u4f7f\u7528\u672a\u7ecf\u9884\u8bad\u7ec3\u7684 ResNet-18 \u6a21\u578b model = pretrainedmodels . __dict__ [ \"resnet18\" ]( pretrained = None ) # \u4fee\u6539\u6a21\u578b\u7684\u6700\u540e\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u4ee5\u9002\u5e94\u7279\u5b9a\u4efb\u52a1 model . last_linear = nn . Sequential ( nn . BatchNorm1d ( 512 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.25 ), # \u968f\u673a\u5931\u6d3b\u5c42\uff0c\u9632\u6b62\u8fc7\u62df\u5408 nn . Linear ( in_features = 512 , out_features = 2048 ), # \u8fde\u63a5\u5c42 nn . ReLU (), # ReLU \u6fc0\u6d3b\u51fd\u6570 nn . BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.5 ), # \u968f\u673a\u5931\u6d3b\u5c42 nn . Linear ( in_features = 2048 , out_features = 1 ) # \u6700\u7ec8\u7684\u4e8c\u5143\u5206\u7c7b\u5c42 ) return model \u5728\u5c1d\u8bd5\u8be5\u6a21\u578b\u65f6\uff0c\u6211\u8fd8\u5c06\u56fe\u50cf\u5927\u5c0f\u6539\u4e3a 512x512\uff0c\u5e76\u6dfb\u52a0\u4e86\u4e00\u4e2a\u5b66\u4e60\u7387\u8c03\u5ea6\u5668\uff0c\u6bcf 3 \u4e2aepochs\u540e\u5c06\u5b66\u4e60\u7387\u4e58\u4ee5 0.5\u3002 Epoch = 0 , Valid ROC AUC = 0.5988225569880796 Epoch = 1 , Valid ROC AUC = 0.730349343208836 Epoch = 2 , Valid ROC AUC = 0.5870943169939142 Epoch = 3 , Valid ROC AUC = 0.5775864444138311 Epoch = 4 , Valid ROC AUC = 0.7330502499939224 Epoch = 5 , Valid ROC AUC = 0.7500336296524395 Epoch = 6 , Valid ROC AUC = 0.7563722113724951 Epoch = 7 , Valid ROC AUC = 0.7987463837994215 Epoch = 8 , Valid ROC AUC = 0.798505708937384 Epoch = 9 , Valid ROC AUC = 0.8025477500546988 \u8fd9\u4e2a\u6a21\u578b\u4f3c\u4e4e\u8868\u73b0\u6700\u597d\u3002\u4e0d\u8fc7\uff0c\u60a8\u53ef\u4ee5\u8c03\u6574 AlexNet \u4e2d\u7684\u4e0d\u540c\u53c2\u6570\u548c\u56fe\u50cf\u5927\u5c0f\uff0c\u4ee5\u83b7\u5f97\u66f4\u597d\u7684\u5206\u6570\u3002 \u4f7f\u7528\u589e\u5f3a\u6280\u672f\u5c06\u8fdb\u4e00\u6b65\u63d0\u9ad8\u5f97\u5206\u3002\u4f18\u5316\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u5f88\u96be\uff0c\u4f46\u5e76\u975e\u4e0d\u53ef\u80fd\u3002\u9009\u62e9 Adam \u4f18\u5316\u5668\u3001\u4f7f\u7528\u4f4e\u5b66\u4e60\u7387\u3001\u5728\u9a8c\u8bc1\u635f\u5931\u8fbe\u5230\u9ad8\u70b9\u65f6\u964d\u4f4e\u5b66\u4e60\u7387\u3001\u5c1d\u8bd5\u4e00\u4e9b\u589e\u5f3a\u6280\u672f\u3001\u5c1d\u8bd5\u5bf9\u56fe\u50cf\u8fdb\u884c\u9884\u5904\u7406\uff08\u5982\u5728\u9700\u8981\u65f6\u8fdb\u884c\u88c1\u526a\uff0c\u8fd9\u4e5f\u53ef\u89c6\u4e3a\u9884\u5904\u7406\uff09\u3001\u6539\u53d8\u6279\u6b21\u5927\u5c0f\u7b49\u3002\u4f60\u53ef\u4ee5\u505a\u5f88\u591a\u4e8b\u60c5\u6765\u4f18\u5316\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u3002 \u4e0e AlexNet \u76f8\u6bd4\uff0c ResNet \u7684\u7ed3\u6784\u8981\u590d\u6742\u5f97\u591a\u3002ResNet \u662f\u6b8b\u5dee\u795e\u7ecf\u7f51\u7edc\uff08Residual Neural Network\uff09\u7684\u7f29\u5199\uff0c\u7531 K. He\u3001X. Zhang\u3001S. Ren \u548c J. Sun \u5728 2015 \u5e74\u53d1\u8868\u7684\u8bba\u6587\u4e2d\u63d0\u51fa\u3002ResNet \u7531 \u6b8b\u5dee\u5757 \uff08residual blocks\uff09\u7ec4\u6210\uff0c\u901a\u8fc7\u8df3\u8fc7\u67d0\u4e9b\u5c42\uff0c\u4f7f\u77e5\u8bc6\u80fd\u591f\u4e0d\u65ad\u5728\u5404\u5c42\u4e2d\u8fdb\u884c\u4f20\u9012\u3002\u8fd9\u4e9b\u5c42\u4e4b\u95f4\u7684 \u8fde\u63a5\u88ab\u79f0\u4e3a \u8df3\u8dc3\u8fde\u63a5 \uff08skip-connections\uff09\uff0c\u56e0\u4e3a\u6211\u4eec\u8df3\u8fc7\u4e86\u4e00\u5c42\u6216\u591a\u5c42\u3002\u8df3\u8dc3\u8fde\u63a5\u901a\u8fc7\u5c06\u68af\u5ea6\u4f20\u64ad\u5230\u66f4\u591a\u5c42\u6765\u5e2e\u52a9\u89e3\u51b3\u68af\u5ea6\u6d88\u5931\u95ee\u9898\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u8bad\u7ec3\u975e\u5e38\u5927\u7684\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff0c\u800c\u4e0d\u4f1a\u635f\u5931\u6027\u80fd\u3002\u901a\u5e38\u60c5\u51b5\u4e0b\uff0c\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7684\u662f\u5927\u578b\u795e\u7ecf\u7f51\u7edc\uff0c\u90a3\u4e48\u5f53\u8bad\u7ec3\u5230\u67d0\u4e00\u8282\u70b9\u4e0a\u65f6\u8bad\u7ec3\u635f\u5931\u53cd\u800c\u4f1a\u589e\u52a0\uff0c\u4f46\u8fd9\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u8df3\u8dc3\u8fde\u63a5\u6765\u907f\u514d\u3002\u901a\u8fc7\u56fe 7 \u53ef\u4ee5\u66f4\u597d\u5730\u7406\u89e3\u8fd9\u4e00\u70b9\u3002 \u56fe 7\uff1a\u7b80\u5355\u8fde\u63a5\u4e0e\u6b8b\u5dee\u8fde\u63a5\u7684\u6bd4\u8f83\u3002\u53c2\u89c1\u8df3\u8dc3\u8fde\u63a5\u3002\u8bf7\u6ce8\u610f\uff0c\u672c\u56fe\u7701\u7565\u4e86\u6700\u540e\u4e00\u5c42\u3002 \u6b8b\u5dee\u5757\u975e\u5e38\u5bb9\u6613\u7406\u89e3\u3002\u4f60\u4ece\u67d0\u4e00\u5c42\u83b7\u53d6\u8f93\u51fa\uff0c\u8df3\u8fc7\u4e00\u4e9b\u5c42\uff0c\u7136\u540e\u5c06\u8f93\u51fa\u6dfb\u52a0\u5230\u7f51\u7edc\u4e2d\u66f4\u8fdc\u7684\u4e00\u5c42\u3002\u865a\u7ebf\u8868\u793a\u8f93\u5165\u5f62\u72b6\u9700\u8981\u8c03\u6574\uff0c\u56e0\u4e3a\u4f7f\u7528\u4e86\u6700\u5927\u6c60\u5316\uff0c\u800c\u6700\u5927\u6c60\u5316\u7684\u4f7f\u7528\u4f1a\u6539\u53d8\u8f93\u51fa\u7684\u5927\u5c0f\u3002 ResNet \u6709\u591a\u79cd\u4e0d\u540c\u7684\u7248\u672c\uff1a \u6709 18 \u5c42\u300134 \u5c42\u300150 \u5c42\u3001101 \u5c42\u548c 152 \u5c42\uff0c\u6240\u6709\u8fd9\u4e9b\u5c42\u90fd\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8fdb\u884c\u4e86\u6743\u91cd\u9884\u8bad\u7ec3\u3002\u5982\u4eca\uff0c\u9884\u8bad\u7ec3\u6a21\u578b\uff08\u51e0\u4e4e\uff09\u9002\u7528\u4e8e\u6240\u6709\u60c5\u51b5\uff0c\u4f46\u8bf7\u786e\u4fdd\u60a8\u4ece\u8f83\u5c0f\u7684\u6a21\u578b\u5f00\u59cb\uff0c\u4f8b\u5982\uff0c\u4ece resnet-18 \u5f00\u59cb\uff0c\u800c\u4e0d\u662f resnet-50\u3002\u5176\u4ed6\u4e00\u4e9b ImageNet \u9884\u8bad\u7ec3\u6a21\u578b\u5305\u62ec\uff1a Inception DenseNet(different variations) NASNet PNASNet VGG Xception ResNeXt EfficientNet, etc. \u5927\u90e8\u5206\u9884\u8bad\u7ec3\u7684\u6700\u5148\u8fdb\u6a21\u578b\u53ef\u4ee5\u5728 GitHub \u4e0a\u7684 pytorch- pretrainedmodels \u8d44\u6e90\u5e93\u4e2d\u627e\u5230\uff1ahttps://github.com/Cadene/pretrained-models.pytorch\u3002\u8be6\u7ec6\u8ba8\u8bba\u8fd9\u4e9b\u6a21\u578b\u4e0d\u5728\u672c\u7ae0\uff08\u548c\u672c\u4e66\uff09\u8303\u56f4\u4e4b\u5185\u3002\u65e2\u7136\u6211\u4eec\u53ea\u5173\u6ce8\u5e94\u7528\uff0c\u90a3\u5c31\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u6837\u7684\u9884\u8bad\u7ec3\u6a21\u578b\u5982\u4f55\u7528\u4e8e\u5206\u5272\u4efb\u52a1\u3002 \u56fe 8\uff1aU-Net\u67b6\u6784 \u5206\u5272\uff08Segmentation\uff09\u662f\u8ba1\u7b97\u673a\u89c6\u89c9\u4e2d\u76f8\u5f53\u6d41\u884c\u7684\u4e00\u9879\u4efb\u52a1\u3002\u5728\u5206\u5272\u4efb\u52a1\u4e2d\uff0c\u6211\u4eec\u8bd5\u56fe\u4ece\u80cc\u666f\u4e2d\u79fb\u9664/\u63d0\u53d6\u524d\u666f\u3002 \u524d\u666f\u548c\u80cc\u666f\u53ef\u4ee5\u6709\u4e0d\u540c\u7684\u5b9a\u4e49\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u8bf4\uff0c\u8fd9\u662f\u4e00\u9879\u50cf\u7d20\u5206\u7c7b\u4efb\u52a1\uff0c\u4f60\u7684\u5de5\u4f5c\u662f\u7ed9\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u7684\u6bcf\u4e2a\u50cf\u7d20\u5206\u914d\u4e00\u4e2a\u7c7b\u522b\u3002\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u6b63\u5728\u5904\u7406\u7684\u6c14\u80f8\u6570\u636e\u96c6\u5c31\u662f\u4e00\u9879\u5206\u5272\u4efb\u52a1\u3002\u5728\u8fd9\u9879\u4efb\u52a1\u4e2d\uff0c\u6211\u4eec\u9700\u8981\u5bf9\u7ed9\u5b9a\u7684\u80f8\u90e8\u653e\u5c04\u56fe\u50cf\u8fdb\u884c\u6c14\u80f8\u5206\u5272\u3002\u7528\u4e8e\u5206\u5272\u4efb\u52a1\u7684\u6700\u5e38\u7528\u6a21\u578b\u662f U-Net\u3002\u5176\u7ed3\u6784\u5982\u56fe 8 \u6240\u793a\u3002 U-Net \u5305\u62ec\u4e24\u4e2a\u90e8\u5206\uff1a\u7f16\u7801\u5668\u548c\u89e3\u7801\u5668\u3002\u7f16\u7801\u5668\u4e0e\u60a8\u76ee\u524d\u6240\u89c1\u8fc7\u7684\u4efb\u4f55 U-Net \u90fd\u662f\u4e00\u6837\u7684\u3002\u89e3\u7801\u5668\u5219\u6709\u4e9b\u4e0d\u540c\u3002\u89e3\u7801\u5668\u7531\u4e0a\u5377\u79ef\u5c42\u7ec4\u6210\u3002\u5728\u4e0a\u5377\u79ef\uff08up-convolutions\uff09\uff08 \u8f6c\u7f6e\u5377\u79ef transposed convolutions\uff09\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u6ee4\u6ce2\u5668\uff0c\u5f53\u5e94\u7528\u5230\u4e00\u4e2a\u5c0f\u56fe\u50cf\u65f6\uff0c\u4f1a\u4ea7\u751f\u4e00\u4e2a\u5927\u56fe\u50cf\u3002\u5728 PyTorch \u4e2d\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528 ConvTranspose2d \u6765\u5b8c\u6210\u8fd9\u4e00\u64cd\u4f5c\u3002\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u4e0a\u5377\u79ef\u4e0e\u4e0a\u91c7\u6837\u5e76\u4e0d\u76f8\u540c\u3002\u4e0a\u91c7\u6837\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u8fc7\u7a0b\uff0c\u6211\u4eec\u5728\u56fe\u50cf\u4e0a\u5e94\u7528\u4e00\u4e2a\u51fd\u6570\u6765\u8c03\u6574\u5b83\u7684\u5927\u5c0f\u3002\u5728\u4e0a\u5377\u79ef\u4e2d\uff0c\u6211\u4eec\u8981\u5b66\u4e60\u6ee4\u6ce2\u5668\u3002\u6211\u4eec\u5c06\u7f16\u7801\u5668\u7684\u67d0\u4e9b\u90e8\u5206\u4f5c\u4e3a\u67d0\u4e9b\u89e3\u7801\u5668\u7684\u8f93\u5165\u3002\u8fd9\u5bf9 \u4e0a\u5377\u79ef\u5c42\u975e\u5e38\u91cd\u8981\u3002 \u8ba9\u6211\u4eec\u770b\u770b U-Net \u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import torch import torch.nn as nn from torch.nn import functional as F # \u5b9a\u4e49\u4e00\u4e2a\u53cc\u5377\u79ef\u5c42 def double_conv ( in_channels , out_channels ): conv = nn . Sequential ( nn . Conv2d ( in_channels , out_channels , kernel_size = 3 ), nn . ReLU ( inplace = True ), nn . Conv2d ( out_channels , out_channels , kernel_size = 3 ), nn . ReLU ( inplace = True ) ) return conv # \u5b9a\u4e49\u51fd\u6570\u7528\u4e8e\u88c1\u526a\u8f93\u5165\u5f20\u91cf def crop_tensor ( tensor , target_tensor ): target_size = target_tensor . size ()[ 2 ] tensor_size = tensor . size ()[ 2 ] delta = tensor_size - target_size delta = delta // 2 return tensor [:, :, delta : tensor_size - delta , delta : tensor_size - delta ] # \u5b9a\u4e49 U-Net \u6a21\u578b class UNet ( nn . Module ): def __init__ ( self ): super ( UNet , self ) . __init () # \u5b9a\u4e49\u6c60\u5316\u5c42\uff0c\u7f16\u7801\u5668\u548c\u89e3\u7801\u5668\u7684\u53cc\u5377\u79ef\u5c42 self . max_pool_2x2 = nn . MaxPool2d ( kernel_size = 2 , stride = 2 ) self . down_conv_1 = double_conv ( 1 , 64 ) self . down_conv_2 = double_conv ( 64 , 128 ) self . down_conv_3 = double_conv ( 128 , 256 ) self . down_conv_4 = double_conv ( 256 , 512 ) self . down_conv_5 = double_conv ( 512 , 1024 ) # \u5b9a\u4e49\u4e0a\u91c7\u6837\u5c42\u548c\u89e3\u7801\u5668\u7684\u53cc\u5377\u79ef\u5c42 self . up_trans_1 = nn . ConvTranspose2d ( in_channels = 1024 , out_channels = 512 , kernel_size = 2 , stride = 2 ) self . up_conv_1 = double_conv ( 1024 , 512 ) self . up_trans_2 = nn . ConvTranspose2d ( in_channels = 512 , out_channels = 256 , kernel_size = 2 , stride = 2 ) self . up_conv_2 = double_conv ( 512 , 256 ) self . up_trans_3 = nn . ConvTranspose2d ( in_channels = 256 , out_channels = 128 , kernel_size = 2 , stride = 2 ) self . up_conv_3 = double_conv ( 256 , 128 ) self . up_trans_4 = nn . ConvTranspose2d ( in_channels = 128 , out_channels = 64 , kernel_size = 2 , stride = 2 ) self . up_conv_4 = double_conv ( 128 , 64 ) # \u5b9a\u4e49\u8f93\u51fa\u5c42 self . out = nn . Conv2d ( in_channels = 64 , out_channels = 2 , kernel_size = 1 ) def forward ( self , image ): # \u7f16\u7801\u5668\u90e8\u5206 x1 = self . down_conv_1 ( image ) x2 = self . max_pool_2x2 ( x1 ) x3 = self . down_conv_2 ( x2 ) x4 = self . max_pool_2x2 ( x3 ) x5 = self . down_conv_3 ( x4 ) x6 = self . max_pool_2x2 ( x5 ) x7 = self . down_conv_4 ( x6 ) x8 = self . max_pool_2x2 ( x7 ) x9 = self . down_conv_5 ( x8 ) # \u89e3\u7801\u5668\u90e8\u5206 x = self . up_trans_1 ( x9 ) y = crop_tensor ( x7 , x ) x = self . up_conv_1 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_2 ( x ) y = crop_tensor ( x5 , x ) x = self . up_conv_2 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_3 ( x ) y = crop_tensor ( x3 , x ) x = self . up_conv_3 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_4 ( x ) y = crop_tensor ( x1 , x ) x = self . up_conv_4 ( torch . cat ([ x , y ], axis = 1 )) # \u8f93\u51fa\u5c42 out = self . out ( x ) return out if __name__ == \"__main__\" : image = torch . rand (( 1 , 1 , 572 , 572 )) model = UNet () print ( model ( image )) \u8bf7\u6ce8\u610f\uff0c\u6211\u4e0a\u9762\u5c55\u793a\u7684 U-Net \u5b9e\u73b0\u662f U-Net \u8bba\u6587\u7684\u539f\u59cb\u5b9e\u73b0\u3002\u4e92\u8054\u7f51\u4e0a\u6709\u5f88\u591a\u4e0d\u540c\u7684\u5b9e\u73b0\u65b9\u6cd5\u3002 \u6709\u4e9b\u4eba\u559c\u6b22\u4f7f\u7528\u53cc\u7ebf\u6027\u91c7\u6837\u4ee3\u66ff\u8f6c\u7f6e\u5377\u79ef\u8fdb\u884c\u4e0a\u91c7\u6837\uff0c\u4f46\u8fd9\u5e76\u4e0d\u662f\u8bba\u6587\u7684\u771f\u6b63\u5b9e\u73b0\u3002\u4e0d\u8fc7\uff0c\u5b83\u7684\u6027\u80fd\u53ef\u80fd\u4f1a\u66f4\u597d\u3002\u5728\u4e0a\u56fe\u6240\u793a\u7684\u539f\u59cb\u5b9e\u73b0\u4e2d\uff0c\u6709\u4e00\u4e2a\u5355\u901a\u9053\u56fe\u50cf\uff0c\u8f93\u51fa\u4e2d\u6709\u4e24\u4e2a\u901a\u9053\uff1a\u4e00\u4e2a\u662f\u524d\u666f\uff0c\u4e00\u4e2a\u662f\u80cc\u666f\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u4e3a\u4efb\u610f\u6570\u91cf\u7684\u7c7b\u548c\u4efb\u610f\u6570\u91cf\u7684\u8f93\u5165\u901a\u9053\u8fdb\u884c\u5b9a\u5236\u3002\u5728\u6b64\u5b9e\u73b0\u4e2d\uff0c\u8f93\u5165\u56fe\u50cf\u7684\u5927\u5c0f\u4e0e\u8f93\u51fa\u56fe\u50cf\u7684\u5927\u5c0f\u4e0d\u540c\uff0c\u56e0\u4e3a\u6211\u4eec\u4f7f\u7528\u7684\u662f\u65e0\u586b\u5145\u5377\u79ef\uff08convolutions without padding\uff09\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0cU-Net \u7684\u7f16\u7801\u5668\u90e8\u5206\u53ea\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u5377\u79ef\u7f51\u7edc\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u7528\u4efb\u4f55\u7f51\u7edc\uff08\u5982 ResNet\uff09\u6765\u66ff\u6362\u5b83\u3002 \u8fd9\u79cd\u66ff\u6362\u4e5f\u53ef\u4ee5\u901a\u8fc7\u9884\u8bad\u7ec3\u6743\u91cd\u6765\u5b8c\u6210\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u57fa\u4e8e ResNet \u7684\u7f16\u7801\u5668\uff0c\u8be5\u7f16\u7801\u5668\u5df2\u5728 ImageNet \u548c\u901a\u7528\u89e3\u7801\u5668\u4e0a\u8fdb\u884c\u4e86\u9884\u8bad\u7ec3\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u591a\u79cd\u4e0d\u540c\u7684\u7f51\u7edc\u67b6\u6784\u6765\u4ee3\u66ff ResNet\u3002Pavel Yakubovskiy \u6240\u8457\u7684\u300aSegmentation Models Pytorch\u300b\u5c31\u662f\u8bb8\u591a\u6b64\u7c7b\u53d8\u4f53\u7684\u5b9e\u73b0\uff0c\u5176\u4e2d\u7f16\u7801\u5668\u53ef\u4ee5\u88ab\u9884\u8bad\u7ec3\u6a21\u578b\u6240\u53d6\u4ee3\u3002\u8ba9\u6211\u4eec\u5e94\u7528\u57fa\u4e8e ResNet \u7684 U-Net \u6765\u89e3\u51b3\u6c14\u80f8\u68c0\u6d4b\u95ee\u9898\u3002 \u5927\u591a\u6570\u7c7b\u4f3c\u7684\u95ee\u9898\u90fd\u6709\u4e24\u4e2a\u8f93\u5165\uff1a\u539f\u59cb\u56fe\u50cf\u548c\u63a9\u7801\uff08mask\uff09\u3002 \u5982\u679c\u6709\u591a\u4e2a\u5bf9\u8c61\uff0c\u5c31\u4f1a\u6709\u591a\u4e2a\u63a9\u7801\u3002 \u5728\u6211\u4eec\u7684\u6c14\u80f8\u6570\u636e\u96c6\u4e2d\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f RLE\u3002RLE \u4ee3\u8868\u8fd0\u884c\u957f\u5ea6\u7f16\u7801\uff0c\u662f\u4e00\u79cd\u8868\u793a\u4e8c\u8fdb\u5236\u63a9\u7801\u4ee5\u8282\u7701\u7a7a\u95f4\u7684\u65b9\u6cd5\u3002\u6df1\u5165\u7814\u7a76 RLE \u8d85\u51fa\u4e86\u672c\u7ae0\u7684\u8303\u56f4\u3002\u56e0\u6b64\uff0c\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u5f20\u8f93\u5165\u56fe\u50cf\u548c\u76f8\u5e94\u7684\u63a9\u7801\u3002\u8ba9\u6211\u4eec\u5148\u8bbe\u8ba1\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\uff0c\u7528\u4e8e\u8f93\u51fa\u56fe\u50cf\u548c\u63a9\u7801\u56fe\u50cf\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u521b\u5efa\u7684\u811a\u672c\u51e0\u4e4e\u53ef\u4ee5\u5e94\u7528\u4e8e\u4efb\u4f55\u5206\u5272\u95ee\u9898\u3002\u8bad\u7ec3\u6570\u636e\u96c6\u662f\u4e00\u4e2a CSV \u6587\u4ef6\uff0c\u53ea\u5305\u542b\u56fe\u50cf ID\uff08\u4e5f\u662f\u6587\u4ef6\u540d\uff09\u3002 import os import glob import torch import numpy as np import pandas as pd from PIL import Image , ImageFile from tqdm import tqdm from collections import defaultdict from torchvision import transforms from albumentations import ( Compose , OneOf , RandomBrightnessContrast , RandomGamma , ShiftScaleRotate , ) # \u8bbe\u7f6ePIL\u56fe\u50cf\u52a0\u8f7d\u622a\u65ad\u7684\u5904\u7406 ImageFile . LOAD_TRUNCATED_IMAGES = True # \u521b\u5efaSIIM\u6570\u636e\u96c6\u7c7b class SIIMDataset ( torch . utils . data . Dataset ): def __init__ ( self , image_ids , transform = True , preprocessing_fn = None ): self . data = defaultdict ( dict ) self . transform = transform self . preprocessing_fn = preprocessing_fn # \u5b9a\u4e49\u6570\u636e\u589e\u5f3a self . aug = Compose ( [ ShiftScaleRotate ( shift_limit = 0.0625 , scale_limit = 0.1 , rotate_limit = 10 , p = 0.8 ), OneOf ( [ RandomGamma ( gamma_limit = ( 90 , 110 ) ), RandomBrightnessContrast ( brightness_limit = 0.1 , contrast_limit = 0.1 ), ], p = 0.5 , ), ] ) # \u6784\u5efa\u6570\u636e\u5b57\u5178\uff0c\u5176\u4e2d\u5305\u542b\u56fe\u50cf\u548c\u63a9\u7801\u7684\u8def\u5f84\u4fe1\u606f for imgid in image_ids : files = glob . glob ( os . path . join ( TRAIN_PATH , imgid , \"*.png\" )) self . data [ counter ] = { \"img_path\" : os . path . join ( TRAIN_PATH , imgid + \".png\" ), \"mask_path\" : os . path . join ( TRAIN_PATH , imgid + \"_mask.png\" ), } def __len__ ( self ): return len ( self . data ) def __getitem__ ( self , item ): img_path = self . data [ item ][ \"img_path\" ] mask_path = self . data [ item ][ \"mask_path\" ] # \u6253\u5f00\u56fe\u50cf\u5e76\u5c06\u5176\u8f6c\u6362\u4e3aRGB\u6a21\u5f0f img = Image . open ( img_path ) img = img . convert ( \"RGB\" ) img = np . array ( img ) # \u6253\u5f00\u63a9\u7801\u56fe\u50cf\uff0c\u5e76\u5c06\u5176\u8f6c\u6362\u4e3a\u6d6e\u70b9\u6570 mask = Image . open ( mask_path ) mask = ( mask >= 1 ) . astype ( \"float32\" ) # \u5982\u679c\u9700\u8981\u8fdb\u884c\u6570\u636e\u589e\u5f3a if self . transform is True : augmented = self . aug ( image = img , mask = mask ) img = augmented [ \"image\" ] mask = augmented [ \"mask\" ] # \u5e94\u7528\u9884\u5904\u7406\u51fd\u6570\uff08\u5982\u679c\u6709\uff09 img = self . preprocessing_fn ( img ) # \u8fd4\u56de\u56fe\u50cf\u548c\u63a9\u7801 return { \"image\" : transforms . ToTensor ()( img ), \"mask\" : transforms . ToTensor ()( mask ) . float (), } \u6709\u4e86\u6570\u636e\u96c6\u7c7b\u4e4b\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u8bad\u7ec3\u51fd\u6570\u3002 import os import sys import torch import numpy as np import pandas as pd import segmentation_models_pytorch as smp import torch.nn as nn import torch.optim as optim from apex import amp from collections import OrderedDict from sklearn import model_selection from tqdm import tqdm from torch.optim import lr_scheduler from dataset import SIIMDataset # \u5b9a\u4e49\u8bad\u7ec3\u6570\u636e\u96c6CSV\u6587\u4ef6\u8def\u5f84 TRAINING_CSV = \"../input/train_pneumothorax.csv\" # \u5b9a\u4e49\u8bad\u7ec3\u548c\u6d4b\u8bd5\u7684\u6279\u91cf\u5927\u5c0f TRAINING_BATCH_SIZE = 16 TEST_BATCH_SIZE = 4 # \u5b9a\u4e49\u8bad\u7ec3\u7684\u65f6\u671f\u6570 EPOCHS = 10 # \u6307\u5b9a\u4f7f\u7528\u7684\u7f16\u7801\u5668\u548c\u6743\u91cd ENCODER = \"resnet18\" ENCODER_WEIGHTS = \"imagenet\" # \u6307\u5b9a\u8bbe\u5907\uff08GPU\uff09 DEVICE = \"cuda\" # \u5b9a\u4e49\u8bad\u7ec3\u51fd\u6570 def train ( dataset , data_loader , model , criterion , optimizer ): model . train () num_batches = int ( len ( dataset ) / data_loader . batch_size ) tk0 = tqdm ( data_loader , total = num_batches ) for d in tk0 : inputs = d [ \"image\" ] targets = d [ \"mask\" ] inputs = inputs . to ( DEVICE , dtype = torch . float ) targets = targets . to ( DEVICE , dtype = torch . float ) optimizer . zero_grad () outputs = model ( inputs ) loss = criterion ( outputs , targets ) with amp . scale_loss ( loss , optimizer ) as scaled_loss : scaled_loss . backward () optimizer . step () tk0 . close () # \u5b9a\u4e49\u8bc4\u4f30\u51fd\u6570 def evaluate ( dataset , data_loader , model ): model . eval () final_loss = 0 num_batches = int ( len ( dataset ) / data_loader . batch_size ) tk0 = tqdm ( data_loader , total = num_batches ) with torch . no_grad (): for d in tk0 : inputs = d [ \"image\" ] targets = d [ \"mask\" ] inputs = inputs to ( DEVICE , dtype = torch . float ) targets = targets . to ( DEVICE , dtype = torch . float ) output = model ( inputs ) loss = criterion ( output , targets ) final_loss += loss tk0 . close () return final_loss / num_batches if __name__ == \"__main__\" : df = pd . read_csv ( TRAINING_CSV ) df_train , df_valid = model_selection . train_test_split ( df , random_state = 42 , test_size = 0.1 ) training_images = df_train . image_id . values validation_images = df_valid . image_id . values # \u521b\u5efa U-Net \u6a21\u578b model = smp . Unet ( encoder_name = ENCODER , encoder_weights = ENCODER_WEIGHTS , classes = 1 , activation = None , ) # \u83b7\u53d6\u6570\u636e\u9884\u5904\u7406\u51fd\u6570 prep_fn = smp . encoders . get_preprocessing_fn ( ENCODER , ENCODER_WEIGHTS ) # \u5c06\u6a21\u578b\u653e\u5728\u8bbe\u5907\u4e0a model . to ( DEVICE ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6 train_dataset = SIIMDataset ( training_images , transform = True , preprocessing_fn = prep_fn , ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = TRAINING_BATCH_SIZE , shuffle = True , num_workers = 12 ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u96c6 valid_dataset = SIIMDataset ( validation_images , transform = False , preprocessing_fn = prep_fn , ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = TEST_BATCH_SIZE , shuffle = True , num_workers = 4 ) # \u5b9a\u4e49\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 1e-3 ) # \u5b9a\u4e49\u5b66\u4e60\u7387\u8c03\u5ea6\u5668 scheduler = lr_scheduler . ReduceLROnPlateau ( optimizer , mode = \"min\" , patience = 3 , verbose = True ) # \u521d\u59cb\u5316 Apex \u6df7\u5408\u7cbe\u5ea6\u8bad\u7ec3 model , optimizer = amp . initialize ( model , optimizer , opt_level = \"O1\" , verbosity = 0 ) # \u5982\u679c\u6709\u591a\u4e2aGPU\uff0c\u5219\u4f7f\u7528 DataParallel \u8fdb\u884c\u5e76\u884c\u8bad\u7ec3 if torch . cuda . device_count () > 1 : print ( f \"Let's use { torch . cuda . device_count () } GPUs!\" ) model = nn . DataParallel ( model ) # \u8f93\u51fa\u8bad\u7ec3\u76f8\u5173\u7684\u4fe1\u606f print ( f \"Training batch size: { TRAINING_BATCH_SIZE } \" ) print ( f \"Test batch size: { TEST_BATCH_SIZE } \" ) print ( f \"Epochs: { EPOCHS } \" ) print ( f \"Image size: { IMAGE_SIZE } \" ) print ( f \"Number of training images: { len ( train_dataset ) } \" ) print ( f \"Number of validation images: { len ( valid_dataset ) } \" ) print ( f \"Encoder: { ENCODER } \" ) # \u5faa\u73af\u8bad\u7ec3\u591a\u4e2a\u65f6\u671f for epoch in range ( EPOCHS ): print ( f \"Training Epoch: { epoch } \" ) train ( train_dataset , train_loader , model , criterion , optimizer ) print ( f \"Validation Epoch: { epoch } \" ) val_log = evaluate ( valid_dataset , valid_loader , model ) scheduler . step ( val_log [ \"loss\" ]) print ( \" \\n \" ) \u5728\u5206\u5272\u95ee\u9898\u4e2d\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528\u5404\u79cd\u635f\u5931\u51fd\u6570\uff0c\u4f8b\u5982\u4e8c\u5143\u4ea4\u53c9\u71b5\u3001focal\u635f\u5931\u3001dice\u635f\u5931\u7b49\u3002\u6211\u628a\u8fd9\u4e2a\u95ee\u9898\u7559\u7ed9 \u8bfb\u8005\u6839\u636e\u8bc4\u4f30\u6307\u6807\u6765\u51b3\u5b9a\u5408\u9002\u7684\u635f\u5931\u3002\u5f53\u8bad\u7ec3\u8fd9\u6837\u4e00\u4e2a\u6a21\u578b\u65f6\uff0c\u60a8\u5c06\u5efa\u7acb\u9884\u6d4b\u6c14\u80f8\u4f4d\u7f6e\u7684\u6a21\u578b\uff0c\u5982\u56fe 9 \u6240\u793a\u3002\u5728\u4e0a\u8ff0\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u82f1\u4f1f\u8fbe apex \u8fdb\u884c\u4e86\u6df7\u5408\u7cbe\u5ea6\u8bad\u7ec3\u3002\u8bf7\u6ce8\u610f\uff0c\u4ece PyTorch 1.6.0+ \u7248\u672c\u5f00\u59cb\uff0cPyTorch \u672c\u8eab\u5c31\u63d0\u4f9b\u4e86\u8fd9\u4e00\u529f\u80fd\u3002 \u56fe 9\uff1a\u4ece\u8bad\u7ec3\u6709\u7d20\u7684\u6a21\u578b\u4e2d\u68c0\u6d4b\u5230\u6c14\u80f8\u7684\u793a\u4f8b\uff08\u53ef\u80fd\u4e0d\u662f\u6b63\u786e\u9884\u6d4b\uff09\u3002 \u6211\u5728\u4e00\u4e2a\u540d\u4e3a \"Well That's Fantastic Machine Learning (WTFML) \"\u7684 python \u8f6f\u4ef6\u5305\u4e2d\u6536\u5f55\u4e86\u4e00\u4e9b\u5e38\u7528\u51fd\u6570\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b83\u5982\u4f55\u5e2e\u52a9\u6211\u4eec\u4e3a FGVC 202013 \u690d\u7269\u75c5\u7406\u5b66\u6311\u6218\u8d5b\u4e2d\u7684\u690d\u7269\u56fe\u50cf\u5efa\u7acb\u591a\u7c7b\u5206\u7c7b\u6a21\u578b\u3002 import os import pandas as pd import numpy as np import albumentations import argparse import torch import torchvision import torch.nn as nn import torch.nn.functional as F from sklearn import metrics from sklearn.model_selection import train_test_split from wtfml.engine import Engine from wtfml.data_loaders.image import ClassificationDataLoader # \u81ea\u5b9a\u4e49\u635f\u5931\u51fd\u6570\uff0c\u5b9e\u73b0\u5bc6\u96c6\u4ea4\u53c9\u71b5 class DenseCrossEntropy ( nn . Module ): def __init__ ( self ): super ( DenseCrossEntropy , self ) . __init__ () def forward ( self , logits , labels ): logits = logits . float () labels = labels . float () logprobs = F . log_softmax ( logits , dim =- 1 ) loss = - labels * logprobs loss = loss . sum ( - 1 ) return loss . mean () # \u81ea\u5b9a\u4e49\u795e\u7ecf\u7f51\u7edc\u6a21\u578b class Model ( nn . Module ): def __init__ ( self ): super () . __init () self . base_model = torchvision . models . resnet18 ( pretrained = True ) in_features = self . base_model . fc . in_features self . out = nn . Linear ( in_features , 4 ) def forward ( self , image , targets = None ): batch_size , C , H , W = image . shape x = self . base_model . conv1 ( image ) x = self . base_model . bn1 ( x ) x = self . base_model . relu ( x ) x = self . base_model . maxpool ( x ) x = self . base_model . layer1 ( x ) x = self . base_model . layer2 ( x ) x = self . base_model . layer3 ( x ) x = self . base_model . layer4 ( x ) x = F . adaptive_avg_pool2d ( x , 1 ) . reshape ( batch_size , - 1 ) x = self . out ( x ) loss = None if targets is not None : loss = DenseCrossEntropy ()( x , targets . type_as ( x )) return x , loss if __name__ == \"__main__\" : # \u547d\u4ee4\u884c\u53c2\u6570\u89e3\u6790\u5668 parser = argparse . ArgumentParser () parser . add_argument ( \"--data_path\" , type = str , ) parser . add_argument ( \"--device\" , type = str ,) parser . add_argument ( \"--epochs\" , type = int ,) args = parser . parse_args () # \u4eceCSV\u6587\u4ef6\u52a0\u8f7d\u6570\u636e df = pd . read_csv ( os . path . join ( args . data_path , \"train.csv\" )) images = df . image_id . values . tolist () images = [ os . path . join ( args . data_path , \"images\" , i + \".jpg\" ) for i in images ] targets = df [[ \"healthy\" , \"multiple_diseases\" , \"rust\" , \"scab\" ]] . values # \u521b\u5efa\u795e\u7ecf\u7f51\u7edc\u6a21\u578b model = Model () model . to ( args . device ) # \u5b9a\u4e49\u5747\u503c\u548c\u6807\u51c6\u5dee\u4ee5\u53ca\u6570\u636e\u589e\u5f3a mean = ( 0.485 , 0.456 , 0.406 ) std = ( 0.229 , 0.224 , 0.225 ) aug = albumentations . Compose ( [ albumentations . Normalize ( mean , std , max_pixel_value = 255.0 , always_apply = True ) ] ) # \u5206\u5272\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 ( train_images , valid_images , train_targets , valid_targets ) = train_test_split ( images , targets ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = ClassificationDataLoader ( image_paths = train_images , targets = train_targets , resize = ( 128 , 128 ), augmentations = aug , ) . fetch ( batch_size = 16 , num_workers = 4 , drop_last = False , shuffle = True , tpu = False ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = ClassificationDataLoader ( image_paths = valid_images , targets = valid_targets , resize = ( 128 , 128 ), augmentations = aug , ) . fetch ( batch_size = 16 , num_workers = 4 , drop_last = False , shuffle = False , tpu = False ) # \u521b\u5efa\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 5e-4 ) # \u521b\u5efa\u5b66\u4e60\u7387\u8c03\u5ea6\u5668 scheduler = torch . optim . lr_scheduler . StepLR ( optimizer , step_size = 15 , gamma = 0.6 ) # \u5faa\u73af\u8bad\u7ec3\u591a\u4e2a\u65f6\u671f for epoch in range ( args . epochs ): # \u8bad\u7ec3\u6a21\u578b train_loss = Engine . train ( train_loader , model , optimizer , device = args . device ) # \u8bc4\u4f30\u6a21\u578b valid_loss = Engine . evaluate ( valid_loader , model , device = args . device ) # \u6253\u5370\u635f\u5931\u4fe1\u606f print ( f \" { epoch } , Train Loss= { train_loss } Valid Loss= { valid_loss } \" ) \u6709\u4e86\u6570\u636e\u540e\uff0c\u5c31\u53ef\u4ee5\u8fd0\u884c\u811a\u672c\u4e86\uff1a python plant.py --data_path ../../plant_pathology --device cuda -- epochs 2 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 86 /86 [ 00 :12< 00 :00, 6 .73it/s, loss = 0 .723 ] 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 29 /29 [ 00 :04< 00 :00, 6 .62it/s, loss = 0 .433 ] 0 , Train Loss = 0 .7228777609592261 Valid Loss = 0 .4327834551704341 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 86 /86 [ 00 :12< 00 :00, 6 .74it/s, loss = 0 .271 ] 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 29 /29 [ 00 :04< 00 :00, 6 .63it/s, loss = 0 .568 ] 1 , Train Loss = 0 .2708700496790021 Valid Loss = 0 .56841839541649 \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u8ba9\u6211\u4eec\u6784\u5efa\u6a21\u578b\u53d8\u5f97\u7b80\u5355\uff0c\u4ee3\u7801\u4e5f\u6613\u4e8e\u9605\u8bfb\u548c\u7406\u89e3\u3002\u6ca1\u6709\u4efb\u4f55\u5c01\u88c5\u7684 PyTorch \u6548\u679c\u6700\u597d\u3002\u56fe\u50cf\u4e2d\u4e0d\u4ec5\u4ec5\u6709\u5206\u7c7b\uff0c\u8fd8\u6709\u5f88\u591a\u5176\u4ed6\u7684\u5185\u5bb9\uff0c\u5982\u679c\u6211\u5f00\u59cb\u5199\u6240\u6709\u7684\u5185\u5bb9\uff0c\u5c31\u5f97\u518d\u5199\u4e00\u672c\u4e66\u4e86\uff0c \u63a5\u8fd1\uff08\u51e0\u4e4e\uff09\u4efb\u4f55\u56fe\u50cf\u95ee\u9898\uff08\u4f5c\u8005\u5728\u5f00\u73a9\u7b11\uff09\u3002","title":"\u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5"},{"location":"%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB%E5%92%8C%E5%88%86%E5%89%B2%E6%96%B9%E6%B3%95/#_1","text":"\u8bf4\u5230\u56fe\u50cf\uff0c\u8fc7\u53bb\u51e0\u5e74\u53d6\u5f97\u4e86\u5f88\u591a\u6210\u5c31\u3002\u8ba1\u7b97\u673a\u89c6\u89c9\u7684\u8fdb\u6b65\u76f8\u5f53\u5feb\uff0c\u611f\u89c9\u8ba1\u7b97\u673a\u89c6\u89c9\u7684\u8bb8\u591a\u95ee\u9898\u73b0\u5728\u90fd\u66f4\u5bb9\u6613\u89e3\u51b3\u4e86\u3002\u968f\u7740\u9884\u8bad\u7ec3\u6a21\u578b\u7684\u51fa\u73b0\u548c\u8ba1\u7b97\u6210\u672c\u7684\u964d\u4f4e\uff0c\u73b0\u5728\u5728\u5bb6\u91cc\u5c31\u80fd\u8f7b\u677e\u8bad\u7ec3\u51fa\u63a5\u8fd1\u6700\u5148\u8fdb\u6c34\u5e73\u7684\u6a21\u578b\uff0c\u89e3\u51b3\u5927\u591a\u6570\u4e0e\u56fe\u50cf\u76f8\u5173\u7684\u95ee\u9898\u3002\u4f46\u662f\uff0c\u56fe\u50cf\u95ee\u9898\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u7c7b\u578b\u3002\u4ece\u4e24\u4e2a\u6216\u591a\u4e2a\u7c7b\u522b\u7684\u6807\u51c6\u56fe\u50cf\u5206\u7c7b\uff0c\u5230\u50cf\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66\u8fd9\u6837\u5177\u6709\u6311\u6218\u6027\u7684\u95ee\u9898\u3002\u6211\u4eec\u4e0d\u4f1a\u5728\u672c\u4e66\u4e2d\u8ba8\u8bba\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66\uff0c\u4f46\u6211\u4eec\u663e\u7136\u4f1a\u5904\u7406\u4e00\u4e9b\u6700\u5e38\u89c1\u7684\u56fe\u50cf\u95ee\u9898\u3002 \u6211\u4eec\u53ef\u4ee5\u5bf9\u56fe\u50cf\u91c7\u7528\u54ea\u4e9b\u4e0d\u540c\u7684\u65b9\u6cd5\uff1f\u56fe\u50cf\u53ea\u4e0d\u8fc7\u662f\u4e00\u4e2a\u6570\u5b57\u77e9\u9635\u3002\u8ba1\u7b97\u673a\u65e0\u6cd5\u50cf\u4eba\u7c7b\u4e00\u6837\u770b\u5230\u56fe\u50cf\u3002\u5b83\u53ea\u80fd\u770b\u5230\u6570\u5b57\uff0c\u8fd9\u5c31\u662f\u56fe\u50cf\u3002\u7070\u5ea6\u56fe\u50cf\u662f\u4e00\u4e2a\u4e8c\u7ef4\u77e9\u9635\uff0c\u6570\u503c\u8303\u56f4\u4ece 0 \u5230 255\u30020 \u4ee3\u8868\u9ed1\u8272\uff0c255 \u4ee3\u8868\u767d\u8272\uff0c\u4ecb\u4e8e\u4e24\u8005\u4e4b\u95f4\u7684\u662f\u5404\u79cd\u7070\u8272\u3002\u4ee5\u524d\uff0c\u5728\u6ca1\u6709\u6df1\u5ea6\u5b66\u4e60\u7684\u65f6\u5019\uff08\u6216\u8005\u8bf4\u6df1\u5ea6\u5b66\u4e60\u8fd8\u4e0d\u6d41\u884c\u7684\u65f6\u5019\uff09\uff0c\u4eba\u4eec\u4e60\u60ef\u4e8e\u67e5\u770b\u50cf\u7d20\u3002\u6bcf\u4e2a\u50cf\u7d20\u90fd\u662f\u4e00\u4e2a\u7279\u5f81\u3002\u4f60\u53ef\u4ee5\u5728 Python \u4e2d\u8f7b\u677e\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u53ea\u9700\u4f7f\u7528 OpenCV \u6216 Python-PIL \u8bfb\u53d6\u7070\u5ea6\u56fe\u50cf\uff0c\u8f6c\u6362\u4e3a numpy \u6570\u7ec4\uff0c\u7136\u540e\u5c06\u77e9\u9635\u5e73\u94fa\uff08\u6241\u5e73\u5316\uff09\u5373\u53ef\u3002\u5982\u679c\u5904\u7406\u7684\u662f RGB \u56fe\u50cf\uff0c\u5219\u9700\u8981\u4e09\u4e2a\u77e9\u9635\uff0c\u800c\u4e0d\u662f\u4e00\u4e2a\u3002\u4f46\u601d\u8def\u662f\u4e00\u6837\u7684\u3002 import numpy as np import matplotlib.pyplot as plt # \u751f\u6210\u4e00\u4e2a 256x256 \u7684\u968f\u673a\u7070\u5ea6\u56fe\u50cf\uff0c\u50cf\u7d20\u503c\u57280\u5230255\u4e4b\u95f4\u968f\u673a\u5206\u5e03 random_image = np . random . randint ( 0 , 256 , ( 256 , 256 )) # \u521b\u5efa\u4e00\u4e2a\u65b0\u7684\u56fe\u50cf\u7a97\u53e3\uff0c\u8bbe\u7f6e\u7a97\u53e3\u5927\u5c0f\u4e3a7x7\u82f1\u5bf8 plt . figure ( figsize = ( 7 , 7 )) # \u663e\u793a\u751f\u6210\u7684\u968f\u673a\u56fe\u50cf # \u4f7f\u7528\u7070\u5ea6\u989c\u8272\u6620\u5c04 (colormap)\uff0c\u8303\u56f4\u4ece0\u5230255 plt . imshow ( random_image , cmap = 'gray' , vmin = 0 , vmax = 255 ) # \u663e\u793a\u56fe\u50cf\u7a97\u53e3 plt . show () \u4e0a\u9762\u7684\u4ee3\u7801\u4f7f\u7528 numpy \u751f\u6210\u4e00\u4e2a\u968f\u673a\u77e9\u9635\u3002\u8be5\u77e9\u9635\u7531 0 \u5230 255\uff08\u5305\u542b\uff09\u7684\u503c\u7ec4\u6210\uff0c\u5927\u5c0f\u4e3a 256x256\uff08\u4e5f\u79f0\u4e3a\u50cf\u7d20\uff09\u3002 \u56fe 1\uff1a\u4e8c\u7ef4\u56fe\u50cf\u9635\u5217\uff08\u5355\u901a\u9053\uff09\u53ca\u5176\u5c55\u5e73\u7248\u672c \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u62fc\u5199\u540e\u7684\u7248\u672c\u53ea\u662f\u4e00\u4e2a\u5927\u5c0f\u4e3a M \u7684\u5411\u91cf\uff0c\u5176\u4e2d M = N * N\uff0c\u5728\u672c\u4f8b\u4e2d\uff0c\u8fd9\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u4e3a 256 * 256 = 65536\u3002 \u73b0\u5728\uff0c\u5982\u679c\u6211\u4eec\u7ee7\u7eed\u5bf9\u6570\u636e\u96c6\u4e2d\u7684\u6240\u6709\u56fe\u50cf\u8fdb\u884c\u5904\u7406\uff0c\u6bcf\u4e2a\u6837\u672c\u5c31\u4f1a\u6709 65536 \u4e2a\u7279\u5f81\u3002\u6211\u4eec\u53ef\u4ee5\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u5feb\u901f\u5efa\u7acb \u51b3\u7b56\u6811\u6a21\u578b\u3001\u968f\u673a\u68ee\u6797\u6a21\u578b\u6216\u57fa\u4e8e SVM \u7684\u6a21\u578b \u3002\u8fd9\u4e9b\u6a21\u578b\u5c06\u57fa\u4e8e\u50cf\u7d20\u503c\uff0c\u5c1d\u8bd5\u5c06\u6b63\u6837\u672c\u4e0e\u8d1f\u6837\u672c\u533a\u5206\u5f00\u6765\uff08\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff09\u3002 \u4f60\u4eec\u4e00\u5b9a\u90fd\u542c\u8bf4\u8fc7\u732b\u4e0e\u72d7\u7684\u95ee\u9898\uff0c\u8fd9\u662f\u4e00\u4e2a\u7ecf\u5178\u7684\u95ee\u9898\u3002\u5982\u679c\u4f60\u4eec\u8fd8\u8bb0\u5f97\uff0c\u5728\u8bc4\u4f30\u6307\u6807\u4e00\u7ae0\u7684\u5f00\u5934\uff0c\u6211\u5411\u4f60\u4eec\u4ecb\u7ecd\u4e86\u4e00\u4e2a\u6c14\u80f8\u56fe\u50cf\u6570\u636e\u96c6\u3002\u90a3\u4e48\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u68c0\u6d4b\u80ba\u90e8\u7684 X \u5149\u56fe\u50cf\u662f\u5426\u5b58\u5728\u6c14\u80f8\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u8fd9\u662f\u4e00\u4e2a\uff08\u5e76\u4e0d\uff09\u7b80\u5355\u7684\u4e8c\u5143\u5206\u7c7b\u3002 \u56fe 2\uff1a\u975e\u6c14\u80f8\u4e0e\u6c14\u80f8 X \u5149\u56fe\u50cf\u5bf9\u6bd4 \u5728\u56fe 2 \u4e2d\uff0c\u60a8\u53ef\u4ee5\u770b\u5230\u975e\u6c14\u80f8\u548c\u6c14\u80f8\u56fe\u50cf\u7684\u5bf9\u6bd4\u3002\u60a8\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\u4e86\uff0c\u5bf9\u4e8e\u4e00\u4e2a\u975e\u4e13\u4e1a\u4eba\u58eb\uff08\u6bd4\u5982\u6211\uff09\u6765\u8bf4\uff0c\u8981\u5728\u8fd9\u4e9b\u56fe\u50cf\u4e2d\u8fa8\u522b\u51fa\u54ea\u4e2a\u662f\u6c14\u80f8\u662f\u76f8\u5f53\u56f0\u96be\u7684\u3002 \u6700\u521d\u7684\u6570\u636e\u96c6\u662f\u5173\u4e8e\u68c0\u6d4b\u6c14\u80f8\u7684\u5177\u4f53\u4f4d\u7f6e\uff0c\u4f46\u6211\u4eec\u5c06\u95ee\u9898\u4fee\u6539\u4e3a\u67e5\u627e\u7ed9\u5b9a\u7684 X \u5149\u56fe\u50cf\u662f\u5426\u5b58\u5728\u6c14\u80f8\u3002\u522b\u62c5\u5fc3\uff0c\u6211\u4eec\u5c06\u5728\u672c\u7ae0\u4ecb\u7ecd\u8fd9\u4e2a\u90e8\u5206\u3002\u6570\u636e\u96c6\u7531 10675 \u5f20\u72ec\u7279\u7684\u56fe\u50cf\u7ec4\u6210\uff0c\u5176\u4e2d 2379 \u5f20\u6709\u6c14\u80f8\uff08\u6ce8\u610f\uff0c\u8fd9\u4e9b\u6570\u5b57\u662f\u7ecf\u8fc7\u6570\u636e\u6e05\u7406\u540e\u5f97\u51fa\u7684\uff0c\u56e0\u6b64\u4e0e\u539f\u59cb\u6570\u636e\u96c6\u4e0d\u7b26\uff09\u3002\u6b63\u5982\u6570\u636e\u79d1\u5b66\u5bb6\u6240\u8bf4\uff1a\u8fd9\u662f\u4e00\u4e2a\u5178\u578b\u7684 \u504f\u659c\u4e8c\u5143\u5206\u7c7b\u6848\u4f8b \u3002\u56e0\u6b64\uff0c\u6211\u4eec\u9009\u62e9 AUC \u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\uff0c\u5e76\u91c7\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u9a8c\u8bc1\u65b9\u6848\u3002 \u60a8\u53ef\u4ee5\u5c06\u7279\u5f81\u6241\u5e73\u5316\uff0c\u7136\u540e\u5c1d\u8bd5\u4e00\u4e9b\u7ecf\u5178\u65b9\u6cd5\uff08\u5982 SVM\u3001RF\uff09\u6765\u8fdb\u884c\u5206\u7c7b\uff0c\u8fd9\u5b8c\u5168\u6ca1\u95ee\u9898\uff0c\u4f46\u5374\u65e0\u6cd5\u8ba9\u60a8\u8fbe\u5230\u6700\u5148\u8fdb\u7684\u6c34\u5e73\u3002\u6b64\u5916\uff0c\u56fe\u50cf\u5927\u5c0f\u4e3a 1024x1024\u3002\u5728\u8fd9\u4e2a\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u4e00\u4e2a\u6a21\u578b\u9700\u8981\u5f88\u957f\u65f6\u95f4\u3002\u4e0d\u7ba1\u600e\u6837\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u5efa\u7acb\u4e00\u4e2a\u7b80\u5355\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002\u7531\u4e8e\u56fe\u50cf\u662f\u7070\u5ea6\u7684\uff0c\u6211\u4eec\u4e0d\u9700\u8981\u8fdb\u884c\u4efb\u4f55\u8f6c\u6362\u3002\u6211\u4eec\u5c06\u628a\u56fe\u50cf\u5927\u5c0f\u8c03\u6574\u4e3a 256x256\uff0c\u4f7f\u5176\u66f4\u5c0f\uff0c\u5e76\u4f7f\u7528\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684 AUC \u4f5c\u4e3a\u8861\u91cf\u6307\u6807\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5b83\u7684\u8868\u73b0\u5982\u4f55\u3002 import os import numpy as np import pandas as pd from PIL import Image from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from tqdm import tqdm # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u6765\u521b\u5efa\u6570\u636e\u96c6 def create_dataset ( training_df , image_dir ): # \u521d\u59cb\u5316\u7a7a\u5217\u8868\u6765\u5b58\u50a8\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c images = [] targets = [] # \u8fed\u4ee3\u5904\u7406\u8bad\u7ec3\u6570\u636e\u96c6\u4e2d\u7684\u6bcf\u4e00\u884c for index , row in tqdm ( training_df . iterrows (), total = len ( training_df ), desc = \"processing images\" ): # \u83b7\u53d6\u56fe\u50cf\u6587\u4ef6\u540d image_id = row [ \"ImageId\" ] # \u6784\u5efa\u5b8c\u6574\u7684\u56fe\u50cf\u6587\u4ef6\u8def\u5f84 image_path = os . path . join ( image_dir , image_id ) # \u6253\u5f00\u56fe\u50cf\u6587\u4ef6\u5e76\u8fdb\u884c\u5927\u5c0f\u8c03\u6574\uff08resize\uff09\u4e3a 256x256 \u50cf\u7d20\uff0c\u4f7f\u7528\u53cc\u7ebf\u6027\u63d2\u503c\uff08BILINEAR\uff09 image = Image . open ( image_path + \".png\" ) image = image . resize (( 256 , 256 ), resample = Image . BILINEAR ) # \u5c06\u56fe\u50cf\u8f6c\u6362\u4e3aNumPy\u6570\u7ec4 image = np . array ( image ) # \u5c06\u56fe\u50cf\u6241\u5e73\u5316\u4e3a\u4e00\u7ef4\u6570\u7ec4\uff0c\u5e76\u5c06\u5176\u6dfb\u52a0\u5230\u56fe\u50cf\u5217\u8868 image = image . ravel () images . append ( image ) # \u5c06\u76ee\u6807\u503c\uff08target\uff09\u6dfb\u52a0\u5230\u76ee\u6807\u5217\u8868 targets . append ( int ( row [ \"target\" ])) # \u5c06\u56fe\u50cf\u5217\u8868\u8f6c\u6362\u4e3aNumPy\u6570\u7ec4 images = np . array ( images ) # \u6253\u5370\u56fe\u50cf\u6570\u7ec4\u7684\u5f62\u72b6 print ( images . shape ) # \u8fd4\u56de\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c return images , targets if __name__ == \"__main__\" : # \u5b9a\u4e49CSV\u6587\u4ef6\u8def\u5f84\u548c\u56fe\u50cf\u6587\u4ef6\u76ee\u5f55\u8def\u5f84 csv_path = \"/home/abhishek/workspace/siim_png/train.csv\" image_path = \"/home/abhishek/workspace/siim_png/train_png/\" # \u4eceCSV\u6587\u4ef6\u52a0\u8f7d\u6570\u636e df = pd . read_csv ( csv_path ) # \u6dfb\u52a0\u4e00\u4e2a\u540d\u4e3a'kfold'\u7684\u5217\uff0c\u5e76\u521d\u59cb\u5316\u4e3a-1 df [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u83b7\u53d6\u76ee\u6807\u503c\uff08target\uff09 y = df . target . values # \u4f7f\u7528\u5206\u5c42KFold\u4ea4\u53c9\u9a8c\u8bc1\u5c06\u6570\u636e\u96c6\u5206\u62105\u6298 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u904d\u5386\u6bcf\u4e2a\u6298\uff08fold\uff09 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f # \u904d\u5386\u6bcf\u4e2a\u6298 for fold_ in range ( 5 ): # \u83b7\u53d6\u8bad\u7ec3\u6570\u636e\u548c\u6d4b\u8bd5\u6570\u636e train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6\u7684\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c xtrain , ytrain = create_dataset ( train_df , image_path ) # \u521b\u5efa\u6d4b\u8bd5\u6570\u636e\u96c6\u7684\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c xtest , ytest = create_dataset ( test_df , image_path ) # \u521d\u59cb\u5316\u4e00\u4e2a\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668 clf = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408\u5206\u7c7b\u5668 clf . fit ( xtrain , ytrain ) # \u4f7f\u7528\u5206\u7c7b\u5668\u5bf9\u6d4b\u8bd5\u6570\u636e\u8fdb\u884c\u9884\u6d4b\uff0c\u5e76\u83b7\u53d6\u6982\u7387\u503c preds = clf . predict_proba ( xtest )[:, 1 ] # \u6253\u5370\u6298\u6570\uff08fold\uff09\u548cAUC\uff08ROC\u66f2\u7ebf\u4e0b\u7684\u9762\u79ef\uff09 print ( f \"FOLD: { fold_ } \" ) print ( f \"AUC = { metrics . roc_auc_score ( ytest , preds ) } \" ) print ( \"\" ) \u5e73\u5747 AUC \u503c\u7ea6\u4e3a 0.72\u3002\u8fd9\u8fd8\u4e0d\u9519\uff0c\u4f46\u6211\u4eec\u5e0c\u671b\u80fd\u505a\u5f97\u66f4\u597d\u3002\u4f60\u53ef\u4ee5\u5c06\u8fd9\u79cd\u65b9\u6cd5\u7528\u4e8e\u56fe\u50cf\uff0c\u8fd9\u4e5f\u662f\u5b83\u5728\u4ee5\u524d\u6700\u5e38\u7528\u7684\u65b9\u6cd5\u3002SVM \u5728\u56fe\u50cf\u6570\u636e\u96c6\u65b9\u9762\u76f8\u5f53\u6709\u540d\u3002\u6df1\u5ea6\u5b66\u4e60\u5df2\u88ab\u8bc1\u660e\u662f\u89e3\u51b3\u6b64\u7c7b\u95ee\u9898\u7684\u6700\u5148\u8fdb\u65b9\u6cd5\uff0c\u56e0\u6b64\u6211\u4eec\u4e0b\u4e00\u6b65\u53ef\u4ee5\u8bd5\u8bd5\u5b83\u3002 \u5173\u4e8e\u6df1\u5ea6\u5b66\u4e60\u7684\u5386\u53f2\u4ee5\u53ca\u8c01\u53d1\u660e\u4e86\u4ec0\u4e48\uff0c\u6211\u5c31\u4e0d\u591a\u8bf4\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b\u6700\u8457\u540d\u7684\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e4b\u4e00 AlexNet\u3002 \u56fe 3\uff1aAlexNet \u67b6\u67849 \u8bf7\u6ce8\u610f\uff0c\u672c\u56fe\u4e2d\u7684\u8f93\u5165\u5927\u5c0f\u4e0d\u662f 224x224 \u800c\u662f 227x227 \u5982\u4eca\uff0c\u4f60\u53ef\u80fd\u4f1a\u8bf4\u8fd9\u53ea\u662f\u4e00\u4e2a\u57fa\u672c\u7684 \u6df1\u5ea6\u5377\u79ef\u795e\u7ecf\u7f51\u7edc \uff0c\u4f46\u5b83\u5374\u662f\u8bb8\u591a\u65b0\u578b\u6df1\u5ea6\u7f51\u7edc\uff08\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff09\u7684\u57fa\u7840\u3002\u6211\u4eec\u770b\u5230\uff0c\u56fe 3 \u4e2d\u7684\u7f51\u7edc\u662f\u4e00\u4e2a\u5177\u6709\u4e94\u4e2a\u5377\u79ef\u5c42\u3001\u4e24\u4e2a\u5bc6\u96c6\u5c42\u548c\u4e00\u4e2a\u8f93\u51fa\u5c42\u7684\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u3002\u6211\u4eec\u770b\u5230\u8fd8\u6709\u6700\u5927\u6c60\u5316\u3002\u8fd9\u662f\u4ec0\u4e48\u610f\u601d\uff1f\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5728\u8fdb\u884c\u6df1\u5ea6\u5b66\u4e60\u65f6\u4f1a\u9047\u5230\u7684\u4e00\u4e9b\u672f\u8bed\u3002 \u56fe 4\uff1a\u56fe\u50cf\u5927\u5c0f\u4e3a 8x8\uff0c\u6ee4\u6ce2\u5668\u5927\u5c0f\u4e3a 3x3\uff0c\u6b65\u957f\u4e3a 2\u3002 \u56fe 4 \u5f15\u5165\u4e86\u4e24\u4e2a\u65b0\u672f\u8bed\uff1a\u6ee4\u6ce2\u5668\u548c\u6b65\u957f\u3002 \u6ee4\u6ce2\u5668 \u662f\u7531\u7ed9\u5b9a\u51fd\u6570\u521d\u59cb\u5316\u7684\u4e8c\u7ef4\u77e9\u9635\uff0c\u7531\u6307\u5b9a\u51fd\u6570\u521d\u59cb\u5316\u3002 Kaiming\u6b63\u6001\u521d\u59cb\u5316 \uff0c\u662f\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u7684\u6700\u4f73\u9009\u62e9\u3002\u8fd9\u662f\u56e0\u4e3a\u5927\u591a\u6570\u73b0\u4ee3\u7f51\u7edc\u90fd\u4f7f\u7528 ReLU \uff08\u6574\u6d41\u7ebf\u6027\u5355\u5143\uff09\u6fc0\u6d3b\u51fd\u6570\uff0c\u9700\u8981\u9002\u5f53\u7684\u521d\u59cb\u5316\u6765\u907f\u514d\u68af\u5ea6\u6d88\u5931\u95ee\u9898\uff08\u68af\u5ea6\u8d8b\u8fd1\u4e8e\u96f6\uff0c\u7f51\u7edc\u6743\u91cd\u4e0d\u53d8\uff09\u3002\u8be5\u6ee4\u6ce2\u5668\u4e0e\u56fe\u50cf\u8fdb\u884c\u5377\u79ef\u3002\u5377\u79ef\u4e0d\u8fc7\u662f\u6ee4\u6ce2\u5668\u4e0e\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u5f53\u524d\u91cd\u53e0\u50cf\u7d20\u4e4b\u95f4\u7684\u5143\u7d20\u76f8\u4e58\u7684\u603b\u548c\u3002\u60a8\u53ef\u4ee5\u5728\u4efb\u4f55\u9ad8\u4e2d\u6570\u5b66\u6559\u79d1\u4e66\u4e2d\u9605\u8bfb\u66f4\u591a\u5173\u4e8e\u5377\u79ef\u7684\u5185\u5bb9\u3002\u6211\u4eec\u4ece\u56fe\u50cf\u7684\u5de6\u4e0a\u89d2\u5f00\u59cb\u5bf9\u6ee4\u955c\u8fdb\u884c\u5377\u79ef\uff0c\u7136\u540e\u6c34\u5e73\u79fb\u52a8\u6ee4\u955c\u3002\u5982\u679c\u79fb\u52a8 1 \u4e2a\u50cf\u7d20\uff0c\u5219\u6b65\u957f\u4e3a 1\uff1b\u5982\u679c\u79fb\u52a8 2 \u4e2a\u50cf\u7d20\uff0c\u5219\u6b65\u957f\u4e3a 2\u3002 \u5373\u4f7f\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u4f8b\u5982\u5728\u95ee\u9898\u548c\u56de\u7b54\u7cfb\u7edf\u4e2d\u9700\u8981\u4ece\u5927\u91cf\u6587\u672c\u8bed\u6599\u4e2d\u7b5b\u9009\u7b54\u6848\u65f6\uff0c\u6b65\u957f\u4e5f\u662f\u4e00\u4e2a\u975e\u5e38\u6709\u7528\u7684\u6982\u5ff5\u3002\u5f53\u6211\u4eec\u5728\u6c34\u5e73\u65b9\u5411\u4e0a\u8d70\u5230\u5c3d\u5934\u65f6\uff0c\u5c31\u4f1a\u4ee5\u540c\u6837\u7684\u6b65\u957f\u5782\u76f4\u5411\u4e0b\u79fb\u52a8\u8fc7\u6ee4\u5668\uff0c\u4ece\u5de6\u4fa7\u5f00\u59cb\u3002\u56fe 4 \u8fd8\u663e\u793a\u4e86\u8fc7\u6ee4\u5668\u79fb\u51fa\u56fe\u50cf\u7684\u60c5\u51b5\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u65e0\u6cd5\u8ba1\u7b97\u5377\u79ef\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u8df3\u8fc7\u5b83\u3002\u5982\u679c\u4e0d\u60f3\u8df3\u8fc7\uff0c\u5219\u9700\u8981\u5bf9\u56fe\u50cf\u8fdb\u884c \u586b\u5145\uff08pad\uff09 \u3002\u8fd8\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5377\u79ef\u4f1a\u51cf\u5c0f\u56fe\u50cf\u7684\u5927\u5c0f\u3002\u586b\u5145\u4e5f\u662f\u4fdd\u6301\u56fe\u50cf\u5927\u5c0f\u4e0d\u53d8\u7684\u4e00\u79cd\u65b9\u6cd5\u3002\u5728\u56fe 4 \u4e2d\uff0c\u4e00\u4e2a 3x3 \u6ee4\u6ce2\u5668\u6b63\u5728\u6c34\u5e73\u548c\u5782\u76f4\u79fb\u52a8\uff0c\u6bcf\u6b21\u79fb\u52a8\u90fd\u4f1a\u5206\u522b\u8df3\u8fc7\u4e24\u5217\u548c\u4e24\u884c\uff08\u5373\u50cf\u7d20\uff09\u3002\u7531\u4e8e\u5b83\u8df3\u8fc7\u4e86\u4e24\u4e2a\u50cf\u7d20\uff0c\u6240\u4ee5\u6b65\u957f = 2\u3002\u56e0\u6b64\u56fe\u50cf\u5927\u5c0f\u4e3a [(8-3) / 2] + 1 = 3.5\u3002\u6211\u4eec\u53d6 3.5 \u7684\u4e0b\u9650\uff0c\u6240\u4ee5\u662f 3x3\u3002\u60a8\u53ef\u4ee5\u5728\u8349\u7a3f\u7eb8\u4e0a\u8fdb\u884c\u5c1d\u8bd5\u3002 \u56fe 5\uff1a\u901a\u8fc7\u586b\u5145\uff0c\u6211\u4eec\u53ef\u4ee5\u63d0\u4f9b\u4e0e\u8f93\u5165\u56fe\u50cf\u5927\u5c0f\u76f8\u540c\u7684\u56fe\u50cf \u6211\u4eec\u53ef\u4ee5\u4ece\u56fe 5 \u4e2d\u770b\u5230\u586b\u5145\u7684\u6548\u679c\u3002\u73b0\u5728\uff0c\u6211\u4eec\u6709\u4e00\u4e2a 3x3 \u7684\u6ee4\u6ce2\u5668\uff0c\u5b83\u4ee5 1 \u7684\u6b65\u957f\u79fb\u52a8\u3002\u539f\u59cb\u56fe\u50cf\u7684\u5927\u5c0f\u4e3a 6x6\uff0c\u6211\u4eec\u6dfb\u52a0\u4e86 1 \u7684 \u586b\u5145 \u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u751f\u6210\u7684\u56fe\u50cf\u5c06\u4e0e\u8f93\u5165\u56fe\u50cf\u5927\u5c0f\u76f8\u540c\uff0c\u5373 6x6\u3002\u5728\u5904\u7406\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u65f6\u53ef\u80fd\u4f1a\u9047\u5230\u7684\u53e6\u4e00\u4e2a\u76f8\u5173\u672f\u8bed\u662f \u81a8\u80c0\uff08dilation\uff09 \uff0c\u5982\u56fe 6 \u6240\u793a\u3002 \u56fe 6\uff1a\u81a8\u80c0\uff08dilation\uff09\u7684\u4f8b\u5b50 \u5728\u81a8\u80c0\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u5c06\u6ee4\u6ce2\u5668\u6269\u5927 N-1\uff0c\u5176\u4e2d N \u662f\u81a8\u80c0\u7387\u7684\u503c\uff0c\u6216\u7b80\u79f0\u4e3a\u81a8\u80c0\u3002\u5728\u8fd9\u79cd\u5e26\u81a8\u80c0\u7684\u5185\u6838\u4e2d\uff0c\u6bcf\u6b21\u5377\u79ef\u90fd\u4f1a\u8df3\u8fc7\u4e00\u4e9b\u50cf\u7d20\u3002\u8fd9\u5728\u5206\u5272\u4efb\u52a1\u4e2d\u5c24\u4e3a\u6709\u6548\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u53ea\u8ba8\u8bba\u4e86\u4e8c\u7ef4\u5377\u79ef\u3002 \u8fd8\u6709\u4e00\u7ef4\u5377\u79ef\u548c\u66f4\u9ad8\u7ef4\u5ea6\u7684\u5377\u79ef\u3002\u5b83\u4eec\u90fd\u57fa\u4e8e\u76f8\u540c\u7684\u57fa\u672c\u6982\u5ff5\u3002 \u63a5\u4e0b\u6765\u662f \u6700\u5927\u6c60\u5316\uff08Max pooling\uff09 \u3002\u6700\u5927\u503c\u6c60\u53ea\u662f\u4e00\u4e2a\u8fd4\u56de\u6700\u5927\u503c\u7684\u6ee4\u6ce2\u5668\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u63d0\u53d6\u7684\u4e0d\u662f\u5377\u79ef\uff0c\u800c\u662f\u50cf\u7d20\u7684\u6700\u5927\u503c\u3002\u540c\u6837\uff0c \u5e73\u5747\u6c60\u5316\uff08average pooling\uff09 \u6216 \u5747\u503c\u6c60\u5316\uff08mean pooling\uff09 \u4f1a\u8fd4\u56de\u50cf\u7d20\u7684\u5e73\u5747\u503c\u3002\u5b83\u4eec\u7684\u4f7f\u7528\u65b9\u6cd5\u4e0e\u5377\u79ef\u6838\u76f8\u540c\u3002\u6c60\u5316\u6bd4\u5377\u79ef\u66f4\u5feb\uff0c\u662f\u4e00\u79cd\u5bf9\u56fe\u50cf\u8fdb\u884c\u7f29\u51cf\u91c7\u6837\u7684\u65b9\u6cd5\u3002\u6700\u5927\u6c60\u5316\u53ef\u68c0\u6d4b\u8fb9\u7f18\uff0c\u5e73\u5747\u6c60\u5316\u53ef\u5e73\u6ed1\u56fe\u50cf\u3002 \u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u548c\u6df1\u5ea6\u5b66\u4e60\u7684\u6982\u5ff5\u592a\u591a\u4e86\u3002\u6211\u6240\u8ba8\u8bba\u7684\u662f\u4e00\u4e9b\u57fa\u7840\u77e5\u8bc6\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5165\u95e8\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u5728 PyTorch \u4e2d\u6784\u5efa\u7b2c\u4e00\u4e2a\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u505a\u597d\u4e86\u5145\u5206\u51c6\u5907\u3002PyTorch \u63d0\u4f9b\u4e86\u4e00\u79cd\u76f4\u89c2\u800c\u7b80\u5355\u7684\u65b9\u6cd5\u6765\u5b9e\u73b0\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff0c\u800c\u4e14\u4f60\u4e0d\u9700\u8981\u5173\u5fc3\u53cd\u5411\u4f20\u64ad\u3002\u6211\u4eec\u7528\u4e00\u4e2a python \u7c7b\u548c\u4e00\u4e2a\u524d\u9988\u51fd\u6570\u6765\u5b9a\u4e49\u7f51\u7edc\uff0c\u544a\u8bc9 PyTorch \u5404\u5c42\u4e4b\u95f4\u5982\u4f55\u8fde\u63a5\u3002\u5728 PyTorch \u4e2d\uff0c\u56fe\u50cf\u7b26\u53f7\u662f BS\u3001C\u3001H\u3001W\uff0c\u5176\u4e2d\uff0cBS \u662f\u6279\u5927\u5c0f\uff0cC \u662f\u901a\u9053\uff0cH \u662f\u9ad8\u5ea6\uff0cW \u662f\u5bbd\u5ea6\u3002\u8ba9\u6211\u4eec\u770b\u770b PyTorch \u662f\u5982\u4f55\u5b9e\u73b0 AlexNet \u7684\u3002 import torch import torch.nn as nn import torch.nn.functional as F class AlexNet ( nn . Module ): def __init__ ( self ): super ( AlexNet , self ) . __init__ () self . conv1 = nn . Conv2d ( in_channels = 3 , out_channels = 96 , kernel_size = 11 , stride = 4 , padding = 0 ) self . pool1 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . conv2 = nn . Conv2d ( in_channels = 96 , out_channels = 256 , kernel_size = 5 , stride = 1 , padding = 2 ) self . pool2 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . conv3 = nn . Conv2d ( in_channels = 256 , out_channels = 384 , kernel_size = 3 , stride = 1 , padding = 1 ) self . conv4 = nn . Conv2d ( in_channels = 384 , out_channels = 384 , kernel_size = 3 , stride = 1 , padding = 1 ) self . conv5 = nn . Conv2d ( in_channels = 384 , out_channels = 256 , kernel_size = 3 , stride = 1 , padding = 1 ) self . pool3 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . fc1 = nn . Linear ( in_features = 9216 , out_features = 4096 ) self . dropout1 = nn . Dropout ( 0.5 ) self . fc2 = nn . Linear ( in_features = 4096 , out_features = 4096 ) self . dropout2 = nn . Dropout ( 0.5 ) self . fc3 = nn . Linear ( in_features = 4096 , out_features = 1000 ) def forward ( self , image ): bs , c , h , w = image . size () x = F . relu ( self . conv1 ( image )) # size: (bs, 96, 55, 55) x = self . pool1 ( x ) # size: (bs, 96, 27, 27) x = F . relu ( self . conv2 ( x )) # size: (bs, 256, 27, 27) x = self . pool2 ( x ) # size: (bs, 256, 13, 13) x = F . relu ( self . conv3 ( x )) # size: (bs, 384, 13, 13) x = F . relu ( self . conv4 ( x )) # size: (bs, 384, 13, 13) x = F . relu ( self . conv5 ( x )) # size: (bs, 256, 13, 13) x = self . pool3 ( x ) # size: (bs, 256, 6, 6) x = x . view ( bs , - 1 ) # size: (bs, 9216) x = F . relu ( self . fc1 ( x )) # size: (bs, 4096) x = self . dropout1 ( x ) # size: (bs, 4096) # dropout does not change size # dropout is used for regularization # 0.3 dropout means that only 70% of the nodes # of the current layer are used for the next layer x = F . relu ( self . fc2 ( x )) # size: (bs, 4096) x = self . dropout2 ( x ) # size: (bs, 4096) x = F . relu ( self . fc3 ( x )) # size: (bs, 1000) # 1000 is number of classes in ImageNet Dataset # softmax is an activation function that converts # linear output to probabilities that add up to 1 # for each sample in the batch x = torch . softmax ( x , axis = 1 ) # size: (bs, 1000) return x \u5982\u679c\u60a8\u6709\u4e00\u5e45 3x227x227 \u7684\u56fe\u50cf\uff0c\u5e76\u5e94\u7528\u4e86\u4e00\u4e2a\u5927\u5c0f\u4e3a 11x11 \u7684\u5377\u79ef\u6ee4\u6ce2\u5668\uff0c\u8fd9\u610f\u5473\u7740\u60a8\u5e94\u7528\u4e86\u4e00\u4e2a\u5927\u5c0f\u4e3a 11x11x3 \u7684\u6ee4\u6ce2\u5668\uff0c\u5e76\u4e0e\u4e00\u4e2a\u5927\u5c0f\u4e3a 227x227x3 \u7684\u56fe\u50cf\u8fdb\u884c\u4e86\u5377\u79ef\u3002\u8f93\u51fa\u901a\u9053\u7684\u6570\u91cf\u5c31\u662f\u5206\u522b\u5e94\u7528\u4e8e\u56fe\u50cf\u7684\u76f8\u540c\u5927\u5c0f\u7684\u4e0d\u540c\u5377\u79ef\u6ee4\u6ce2\u5668\u7684\u6570\u91cf\u3002 \u56e0\u6b64\uff0c\u5728\u7b2c\u4e00\u4e2a\u5377\u79ef\u5c42\u4e2d\uff0c\u8f93\u5165\u901a\u9053\u662f 3\uff0c\u4e5f\u5c31\u662f\u539f\u59cb\u8f93\u5165\uff0c\u5373 R\u3001G\u3001B \u4e09\u901a\u9053\u3002PyTorch \u7684 torchvision \u63d0\u4f9b\u4e86\u8bb8\u591a\u4e0e AlexNet \u7c7b\u4f3c\u7684\u4e0d\u540c\u6a21\u578b\uff0c\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0cAlexNet \u7684\u5b9e\u73b0\u4e0e torchvision \u7684\u5b9e\u73b0\u5e76\u4e0d\u76f8\u540c\u3002Torchvision \u7684 AlexNet \u5b9e\u73b0\u662f\u4ece\u53e6\u4e00\u7bc7\u8bba\u6587\u4e2d\u4fee\u6539\u800c\u6765\u7684 AlexNet\uff1a Krizhevsky, A. One weird trick for parallelizing convolutional neural networks. CoRR, abs/1404.5997, 2014. \u4f60\u53ef\u4ee5\u4e3a\u81ea\u5df1\u7684\u4efb\u52a1\u8bbe\u8ba1\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff0c\u5f88\u591a\u65f6\u5019\uff0c\u4ece\u96f6\u505a\u8d77\u662f\u4e2a\u4e0d\u9519\u7684\u4e3b\u610f\u3002\u8ba9\u6211\u4eec\u6784\u5efa\u4e00\u4e2a\u7f51\u7edc\uff0c\u7528\u4e8e\u533a\u5206\u56fe\u50cf\u6709\u65e0\u6c14\u80f8\u3002\u9996\u5148\uff0c\u8ba9\u6211\u4eec\u51c6\u5907\u4e00\u4e9b\u6587\u4ef6\u3002\u7b2c\u4e00\u6b65\u662f\u521b\u5efa\u4e00\u4e2a\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e\u96c6\uff0c\u5373 train.csv\uff0c\u4f46\u589e\u52a0\u4e00\u5217 kfold\u3002\u6211\u4eec\u5c06\u521b\u5efa\u4e94\u4e2a\u6587\u4ef6\u5939\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u5df2\u7ecf\u6f14\u793a\u4e86\u5982\u4f55\u9488\u5bf9\u4e0d\u540c\u7684\u6570\u636e\u96c6\u521b\u5efa\u6298\u53e0\uff0c\u56e0\u6b64\u6211\u5c06\u8df3\u8fc7\u8fd9\u4e00\u90e8\u5206\uff0c\u7559\u4f5c\u7ec3\u4e60\u3002\u5bf9\u4e8e\u57fa\u4e8e PyTorch \u7684\u795e\u7ecf\u7f51\u7edc\uff0c\u6211\u4eec\u9700\u8981\u521b\u5efa\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\u3002\u6570\u636e\u96c6\u7c7b\u7684\u76ee\u7684\u662f\u8fd4\u56de\u4e00\u4e2a\u6570\u636e\u9879\u6216\u6570\u636e\u6837\u672c\u3002\u8fd9\u4e2a\u6570\u636e\u6837\u672c\u5e94\u8be5\u5305\u542b\u8bad\u7ec3\u6216\u8bc4\u4f30\u6a21\u578b\u6240\u9700\u7684\u6240\u6709\u5185\u5bb9\u3002 import torch import numpy as np from PIL import Image from PIL import ImageFile ImageFile . LOAD_TRUNCATED_IMAGES = True # \u5b9a\u4e49\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\uff0c\u7528\u4e8e\u5904\u7406\u56fe\u50cf\u5206\u7c7b\u4efb\u52a1 class ClassificationDataset : def __init__ ( self , image_paths , targets , resize = None , augmentations = None ): # \u56fe\u50cf\u6587\u4ef6\u8def\u5f84\u5217\u8868 self . image_paths = image_paths # \u76ee\u6807\u6807\u7b7e\u5217\u8868 self . targets = targets # \u56fe\u50cf\u5c3a\u5bf8\u8c03\u6574\u53c2\u6570\uff0c\u53ef\u4ee5\u4e3aNone self . resize = resize # \u6570\u636e\u589e\u5f3a\u51fd\u6570\uff0c\u53ef\u4ee5\u4e3aNone self . augmentations = augmentations def __len__ ( self ): # \u8fd4\u56de\u6570\u636e\u96c6\u7684\u5927\u5c0f\uff0c\u5373\u56fe\u50cf\u6570\u91cf return len ( self . image_paths ) def __getitem__ ( self , item ): # \u83b7\u53d6\u6570\u636e\u96c6\u4e2d\u7684\u4e00\u4e2a\u6837\u672c image = Image . open ( self . image_paths [ item ]) image = image . convert ( \"RGB\" ) # \u5c06\u56fe\u50cf\u8f6c\u6362\u4e3aRGB\u683c\u5f0f # \u83b7\u53d6\u8be5\u6837\u672c\u7684\u76ee\u6807\u6807\u7b7e targets = self . targets [ item ] if self . resize is not None : # \u5982\u679c\u6307\u5b9a\u4e86\u5c3a\u5bf8\u8c03\u6574\u53c2\u6570\uff0c\u5c06\u56fe\u50cf\u8fdb\u884c\u5c3a\u5bf8\u8c03\u6574 image = image . resize (( self . resize [ 1 ], self . resize [ 0 ]), resample = Image . BILINEAR ) image = np . array ( image ) if self . augmentations is not None : # \u5982\u679c\u6307\u5b9a\u4e86\u6570\u636e\u589e\u5f3a\u51fd\u6570\uff0c\u5e94\u7528\u6570\u636e\u589e\u5f3a augmented = self . augmentations ( image = image ) image = augmented [ \"image\" ] # \u5c06\u56fe\u50cf\u901a\u9053\u987a\u5e8f\u8c03\u6574\u4e3a(C, H, W)\u7684\u5f62\u5f0f\uff0c\u5e76\u8f6c\u6362\u4e3afloat32\u7c7b\u578b image = np . transpose ( image , ( 2 , 0 , 1 )) . astype ( np . float32 ) # \u8fd4\u56de\u6837\u672c\uff0c\u5305\u62ec\u56fe\u50cf\u548c\u5bf9\u5e94\u7684\u76ee\u6807\u6807\u7b7e return { \"image\" : torch . tensor ( image , dtype = torch . float ), \"targets\" : torch . tensor ( targets , dtype = torch . long ), } \u73b0\u5728\u6211\u4eec\u9700\u8981 engine.py\u3002engine.py \u5305\u542b\u8bad\u7ec3\u548c\u8bc4\u4f30\u529f\u80fd\u3002\u8ba9\u6211\u4eec\u770b\u770b engine.py \u662f\u5982\u4f55\u7f16\u5199\u7684\u3002 import torch import torch.nn as nn from tqdm import tqdm # \u7528\u4e8e\u8bad\u7ec3\u6a21\u578b\u7684\u51fd\u6570 def train ( data_loader , model , optimizer , device ): # \u5c06\u6a21\u578b\u8bbe\u7f6e\u4e3a\u8bad\u7ec3\u6a21\u5f0f model . train () for data in data_loader : # \u4ece\u6570\u636e\u52a0\u8f7d\u5668\u4e2d\u63d0\u53d6\u8f93\u5165\u56fe\u50cf\u548c\u76ee\u6807\u6807\u7b7e inputs = data [ \"image\" ] targets = data [ \"targets\" ] # \u5c06\u8f93\u5165\u548c\u76ee\u6807\u79fb\u52a8\u5230\u6307\u5b9a\u7684\u8bbe\u5907\uff08\u4f8b\u5982\uff0cGPU\uff09 inputs = inputs . to ( device , dtype = torch . float ) targets = targets . to ( device , dtype = torch . float ) # \u5c06\u4f18\u5316\u5668\u4e2d\u7684\u68af\u5ea6\u5f52\u96f6 optimizer . zero_grad () # \u524d\u5411\u4f20\u64ad\uff1a\u8ba1\u7b97\u6a21\u578b\u9884\u6d4b outputs = model ( inputs ) # \u4f7f\u7528\u5e26\u903b\u8f91\u65af\u8482\u51fd\u6570\u7684\u4e8c\u5143\u4ea4\u53c9\u71b5\u635f\u5931\u8ba1\u7b97\u635f\u5931 loss = nn . BCEWithLogitsLoss ()( outputs , targets . view ( - 1 , 1 )) # \u53cd\u5411\u4f20\u64ad\uff1a\u8ba1\u7b97\u68af\u5ea6\u5e76\u66f4\u65b0\u6a21\u578b\u6743\u91cd loss . backward () optimizer . step () # \u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u7684\u51fd\u6570 def evaluate ( data_loader , model , device ): # \u5c06\u6a21\u578b\u8bbe\u7f6e\u4e3a\u8bc4\u4f30\u6a21\u5f0f\uff08\u4e0d\u8fdb\u884c\u68af\u5ea6\u8ba1\u7b97\uff09 model . eval () # \u521d\u59cb\u5316\u5217\u8868\u4ee5\u5b58\u50a8\u771f\u5b9e\u76ee\u6807\u548c\u6a21\u578b\u9884\u6d4b final_targets = [] final_outputs = [] with torch . no_grad (): for data in data_loader : # \u4ece\u6570\u636e\u52a0\u8f7d\u5668\u4e2d\u63d0\u53d6\u8f93\u5165\u56fe\u50cf\u548c\u76ee\u6807\u6807\u7b7e inputs = data [ \"image\" ] targets = data [ \"targets\" ] # \u5c06\u8f93\u5165\u79fb\u52a8\u5230\u6307\u5b9a\u7684\u8bbe\u5907\uff08\u4f8b\u5982\uff0cGPU\uff09 inputs = inputs . to ( device , dtype = torch . float ) # \u83b7\u53d6\u6a21\u578b\u9884\u6d4b output = model ( inputs ) # \u5c06\u76ee\u6807\u548c\u8f93\u51fa\u8f6c\u6362\u4e3aCPU\u548cPython\u5217\u8868 targets = targets . detach () . cpu () . numpy () . tolist () output = output . detach () . cpu () . numpy () . tolist () # \u5c06\u5217\u8868\u6269\u5c55\u4ee5\u5305\u542b\u6279\u6b21\u6570\u636e final_targets . extend ( targets ) final_outputs . extend ( output ) # \u8fd4\u56de\u6700\u7ec8\u7684\u6a21\u578b\u9884\u6d4b\u548c\u771f\u5b9e\u76ee\u6807 return final_outputs , final_targets \u6709\u4e86 engine.py\uff0c\u5c31\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u65b0\u6587\u4ef6\uff1amodel.py\u3002model.py \u5c06\u5305\u542b\u6211\u4eec\u7684\u6a21\u578b\u3002\u628a\u6a21\u578b\u4e0e\u8bad\u7ec3\u5206\u5f00\u662f\u4e2a\u597d\u4e3b\u610f\uff0c\u56e0\u4e3a\u8fd9\u6837\u6211\u4eec\u5c31\u53ef\u4ee5\u8f7b\u677e\u5730\u8bd5\u9a8c\u4e0d\u540c\u7684\u6a21\u578b\u548c\u4e0d\u540c\u7684\u67b6\u6784\u3002\u540d\u4e3a pretrainedmodels \u7684 PyTorch \u5e93\u4e2d\u6709\u5f88\u591a\u4e0d\u540c\u7684\u6a21\u578b\u67b6\u6784\uff0c\u5982 AlexNet\u3001ResNet\u3001DenseNet \u7b49\u3002\u8fd9\u4e9b\u4e0d\u540c\u7684\u6a21\u578b\u67b6\u6784\u662f\u5728\u540d\u4e3a ImageNet \u7684\u5927\u578b\u56fe\u50cf\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u51fa\u6765\u7684\u3002\u5728 ImageNet \u4e0a\u8bad\u7ec3\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5b83\u4eec\u7684\u6743\u91cd\uff0c\u4e5f\u53ef\u4ee5\u4e0d\u4f7f\u7528\u8fd9\u4e9b\u6743\u91cd\u3002\u5982\u679c\u6211\u4eec\u4e0d\u4f7f\u7528 ImageNet \u6743\u91cd\u8fdb\u884c\u8bad\u7ec3\uff0c\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u7f51\u7edc\u5c06\u4ece\u5934\u5f00\u59cb\u5b66\u4e60\u4e00\u5207\u3002\u8fd9\u5c31\u662f model.py \u7684\u6837\u5b50\u3002 import torch.nn as nn import pretrainedmodels # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u4ee5\u83b7\u53d6\u6a21\u578b def get_model ( pretrained ): if pretrained : # \u4f7f\u7528\u9884\u8bad\u7ec3\u7684 AlexNet \u6a21\u578b\uff0c\u52a0\u8f7d\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u7684\u6743\u91cd model = pretrainedmodels . __dict__ [ \"alexnet\" ]( pretrained = 'imagenet' ) else : # \u4f7f\u7528\u672a\u7ecf\u9884\u8bad\u7ec3\u7684 AlexNet \u6a21\u578b model = pretrainedmodels . __dict__ [ \"alexnet\" ]( pretrained = None ) # \u4fee\u6539\u6a21\u578b\u7684\u6700\u540e\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u4ee5\u9002\u5e94\u7279\u5b9a\u4efb\u52a1 model . last_linear = nn . Sequential ( nn . BatchNorm1d ( 4096 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.25 ), # \u968f\u673a\u5931\u6d3b\u5c42\uff0c\u9632\u6b62\u8fc7\u62df\u5408 nn . Linear ( in_features = 4096 , out_features = 2048 ), # \u8fde\u63a5\u5c42 nn . ReLU (), # ReLU \u6fc0\u6d3b\u51fd\u6570 nn . BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.5 ), # \u968f\u673a\u5931\u6d3b\u5c42 nn . Linear ( in_features = 2048 , out_features = 1 ) # \u6700\u7ec8\u7684\u4e8c\u5143\u5206\u7c7b\u5c42 ) return model \u5982\u679c\u4f60\u6253\u5370\u4e86\u7f51\u7edc\uff0c\u4f1a\u5f97\u5230\u5982\u4e0b\u8f93\u51fa\uff1a AlexNet ( ( avgpool ): AdaptiveAvgPool2d ( output_size = ( 6 , 6 )) ( _features ): Sequential ( ( 0 ): Conv2d ( 3 , 64 , kernel_size = ( 11 , 11 ), stride = ( 4 , 4 ), padding = ( 2 , 2 )) ( 1 ): ReLU ( inplace = True ) ( 2 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , ceil_mode = False ) ( 3 ): Conv2d ( 64 , 192 , kernel_size = ( 5 , 5 ), stride = ( 1 , 1 ), padding = ( 2 , 2 )) ( 4 ): ReLU ( inplace = True ) ( 5 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , ceil_mode = False ) ( 6 ): Conv2d ( 192 , 384 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 7 ): ReLU ( inplace = True ) ( 8 ): Conv2d ( 384 , 256 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 9 ): ReLU ( inplace = True ) ( 10 ): Conv2d ( 256 , 256 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 11 ): ReLU ( inplace = True ) ( 12 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , eil_mode = False )) ( dropout0 ): Dropout ( p = 0.5 , inplace = False ) ( linear0 ): Linear ( in_features = 9216 , out_features = 4096 , bias = True ) ( relu0 ): ReLU ( inplace = True ) ( dropout1 ): Dropout ( p = 0.5 , inplace = False ) ( linear1 ): Linear ( in_features = 4096 , out_features = 4096 , bias = True ) ( relu1 ): ReLU ( inplace = True ) ( last_linear ): Sequential ( ( 0 ): BatchNorm1d ( 4096 , eps = 1e-05 , momentum = 0.1 , affine = True , rack_running_stats = True ) ( 1 ): Dropout ( p = 0.25 , inplace = False ) ( 2 ): Linear ( in_features = 4096 , out_features = 2048 , bias = True ) ( 3 ): ReLU () ( 4 ): BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 , affine = True , track_running_stats = True ) ( 5 ): Dropout ( p = 0.5 , inplace = False ) ( 6 ): Linear ( in_features = 2048 , out_features = 1 , bias = True ) ) ) \u73b0\u5728\uff0c\u4e07\u4e8b\u4ff1\u5907\uff0c\u53ef\u4ee5\u5f00\u59cb\u8bad\u7ec3\u4e86\u3002\u6211\u4eec\u5c06\u4f7f\u7528 train.py \u8bad\u7ec3\u6a21\u578b\u3002 import os import pandas as pd import numpy as np import albumentations import torch from sklearn import metrics from sklearn.model_selection import train_test_split import dataset import engine from model import get_model if __name__ == \"__main__\" : # \u5b9a\u4e49\u6570\u636e\u8def\u5f84\u3001\u8bbe\u5907\u3001\u8fed\u4ee3\u6b21\u6570 data_path = \"/home/abhishek/workspace/siim_png/\" device = \"cuda\" # \u4f7f\u7528GPU\u52a0\u901f epochs = 10 # \u4eceCSV\u6587\u4ef6\u8bfb\u53d6\u6570\u636e df = pd . read_csv ( os . path . join ( data_path , \"train.csv\" )) images = df . ImageId . values . tolist () images = [ os . path . join ( data_path , \"train_png\" , i + \".png\" ) for i in images ] targets = df . target . values # \u83b7\u53d6\u9884\u8bad\u7ec3\u7684\u6a21\u578b model = get_model ( pretrained = True ) model . to ( device ) # \u5b9a\u4e49\u5747\u503c\u548c\u6807\u51c6\u5dee\uff0c\u7528\u4e8e\u6570\u636e\u6807\u51c6\u5316 mean = ( 0.485 , 0.456 , 0.406 ) std = ( 0.229 , 0.224 , 0.225 ) # \u6570\u636e\u589e\u5f3a\uff0c\u5c06\u56fe\u50cf\u6807\u51c6\u5316 aug = albumentations . Compose ( [ albumentations . Normalize ( mean , std , max_pixel_value = 255.0 , always_apply = True ) ] ) # \u5212\u5206\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 train_images , valid_images , train_targets , valid_targets = train_test_split ( images , targets , stratify = targets , random_state = 42 ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6\u548c\u9a8c\u8bc1\u6570\u636e\u96c6 train_dataset = dataset . ClassificationDataset ( image_paths = train_images , targets = train_targets , resize = ( 227 , 227 ), augmentations = aug , ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = 16 , shuffle = True , num_workers = 4 ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u96c6 valid_dataset = dataset . ClassificationDataset ( image_paths = valid_images , targets = valid_targets , resize = ( 227 , 227 ), augmentations = aug , ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = 16 , shuffle = False , num_workers = 4 ) # \u5b9a\u4e49\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 5e-4 ) # \u8bad\u7ec3\u5faa\u73af for epoch in range ( epochs ): # \u8bad\u7ec3\u6a21\u578b engine . train ( train_loader , model , optimizer , device = device ) # \u8bc4\u4f30\u6a21\u578b\u6027\u80fd predictions , valid_targets = engine . evaluate ( valid_loader , model , device = device ) # \u8ba1\u7b97ROC AUC\u5206\u6570\u5e76\u6253\u5370 roc_auc = metrics . roc_auc_score ( valid_targets , predictions ) print ( f \"Epoch= { epoch } , Valid ROC AUC= { roc_auc } \" ) \u8ba9\u6211\u4eec\u5728\u6ca1\u6709\u9884\u8bad\u7ec3\u6743\u91cd\u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u8bad\u7ec3\uff1a Epoch = 0 , Valid ROC AUC = 0.5737161981475328 Epoch = 1 , Valid ROC AUC = 0.5362868001588292 Epoch = 2 , Valid ROC AUC = 0.6163448214387008 Epoch = 3 , Valid ROC AUC = 0.6119219143780944 Epoch = 4 , Valid ROC AUC = 0.6229718888519726 Epoch = 5 , Valid ROC AUC = 0.5983014999635341 Epoch = 6 , Valid ROC AUC = 0.5523236874306134 Epoch = 7 , Valid ROC AUC = 0.4717721611306046 Epoch = 8 , Valid ROC AUC = 0.6473408263980617 Epoch = 9 , Valid ROC AUC = 0.6639862888260415 AUC \u7ea6\u4e3a 0.66\uff0c\u751a\u81f3\u4f4e\u4e8e\u6211\u4eec\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002\u4f7f\u7528\u9884\u8bad\u7ec3\u6743\u91cd\u4f1a\u53d1\u751f\u4ec0\u4e48\u60c5\u51b5\uff1f Epoch = 0 , Valid ROC AUC = 0.5730387429803165 Epoch = 1 , Valid ROC AUC = 0.5319813942934937 Epoch = 2 , Valid ROC AUC = 0.627111577514323 Epoch = 3 , Valid ROC AUC = 0.6819736959393209 Epoch = 4 , Valid ROC AUC = 0.5747117168950512 Epoch = 5 , Valid ROC AUC = 0.5994619255609669 Epoch = 6 , Valid ROC AUC = 0.5080889443530546 Epoch = 7 , Valid ROC AUC = 0.6323792776512727 Epoch = 8 , Valid ROC AUC = 0.6685753182661686 Epoch = 9 , Valid ROC AUC = 0.6861802387300147 \u73b0\u5728\u7684 AUC \u597d\u4e86\u5f88\u591a\u3002\u4e0d\u8fc7\uff0c\u5b83\u4ecd\u7136\u8f83\u4f4e\u3002\u9884\u8bad\u7ec3\u6a21\u578b\u7684\u597d\u5904\u662f\u53ef\u4ee5\u8f7b\u677e\u5c1d\u8bd5\u591a\u79cd\u4e0d\u540c\u7684\u6a21\u578b\u3002\u8ba9\u6211\u4eec\u8bd5\u8bd5\u4f7f\u7528\u9884\u8bad\u7ec3\u6743\u91cd\u7684 resnet18 \u3002 import torch.nn as nn import pretrainedmodels # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u4ee5\u83b7\u53d6\u6a21\u578b def get_model ( pretrained ): if pretrained : # \u4f7f\u7528\u9884\u8bad\u7ec3\u7684 ResNet-18 \u6a21\u578b\uff0c\u52a0\u8f7d\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u7684\u6743\u91cd model = pretrainedmodels . __dict__ [ \"resnet18\" ]( pretrained = 'imagenet' ) else : # \u4f7f\u7528\u672a\u7ecf\u9884\u8bad\u7ec3\u7684 ResNet-18 \u6a21\u578b model = pretrainedmodels . __dict__ [ \"resnet18\" ]( pretrained = None ) # \u4fee\u6539\u6a21\u578b\u7684\u6700\u540e\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u4ee5\u9002\u5e94\u7279\u5b9a\u4efb\u52a1 model . last_linear = nn . Sequential ( nn . BatchNorm1d ( 512 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.25 ), # \u968f\u673a\u5931\u6d3b\u5c42\uff0c\u9632\u6b62\u8fc7\u62df\u5408 nn . Linear ( in_features = 512 , out_features = 2048 ), # \u8fde\u63a5\u5c42 nn . ReLU (), # ReLU \u6fc0\u6d3b\u51fd\u6570 nn . BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.5 ), # \u968f\u673a\u5931\u6d3b\u5c42 nn . Linear ( in_features = 2048 , out_features = 1 ) # \u6700\u7ec8\u7684\u4e8c\u5143\u5206\u7c7b\u5c42 ) return model \u5728\u5c1d\u8bd5\u8be5\u6a21\u578b\u65f6\uff0c\u6211\u8fd8\u5c06\u56fe\u50cf\u5927\u5c0f\u6539\u4e3a 512x512\uff0c\u5e76\u6dfb\u52a0\u4e86\u4e00\u4e2a\u5b66\u4e60\u7387\u8c03\u5ea6\u5668\uff0c\u6bcf 3 \u4e2aepochs\u540e\u5c06\u5b66\u4e60\u7387\u4e58\u4ee5 0.5\u3002 Epoch = 0 , Valid ROC AUC = 0.5988225569880796 Epoch = 1 , Valid ROC AUC = 0.730349343208836 Epoch = 2 , Valid ROC AUC = 0.5870943169939142 Epoch = 3 , Valid ROC AUC = 0.5775864444138311 Epoch = 4 , Valid ROC AUC = 0.7330502499939224 Epoch = 5 , Valid ROC AUC = 0.7500336296524395 Epoch = 6 , Valid ROC AUC = 0.7563722113724951 Epoch = 7 , Valid ROC AUC = 0.7987463837994215 Epoch = 8 , Valid ROC AUC = 0.798505708937384 Epoch = 9 , Valid ROC AUC = 0.8025477500546988 \u8fd9\u4e2a\u6a21\u578b\u4f3c\u4e4e\u8868\u73b0\u6700\u597d\u3002\u4e0d\u8fc7\uff0c\u60a8\u53ef\u4ee5\u8c03\u6574 AlexNet \u4e2d\u7684\u4e0d\u540c\u53c2\u6570\u548c\u56fe\u50cf\u5927\u5c0f\uff0c\u4ee5\u83b7\u5f97\u66f4\u597d\u7684\u5206\u6570\u3002 \u4f7f\u7528\u589e\u5f3a\u6280\u672f\u5c06\u8fdb\u4e00\u6b65\u63d0\u9ad8\u5f97\u5206\u3002\u4f18\u5316\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u5f88\u96be\uff0c\u4f46\u5e76\u975e\u4e0d\u53ef\u80fd\u3002\u9009\u62e9 Adam \u4f18\u5316\u5668\u3001\u4f7f\u7528\u4f4e\u5b66\u4e60\u7387\u3001\u5728\u9a8c\u8bc1\u635f\u5931\u8fbe\u5230\u9ad8\u70b9\u65f6\u964d\u4f4e\u5b66\u4e60\u7387\u3001\u5c1d\u8bd5\u4e00\u4e9b\u589e\u5f3a\u6280\u672f\u3001\u5c1d\u8bd5\u5bf9\u56fe\u50cf\u8fdb\u884c\u9884\u5904\u7406\uff08\u5982\u5728\u9700\u8981\u65f6\u8fdb\u884c\u88c1\u526a\uff0c\u8fd9\u4e5f\u53ef\u89c6\u4e3a\u9884\u5904\u7406\uff09\u3001\u6539\u53d8\u6279\u6b21\u5927\u5c0f\u7b49\u3002\u4f60\u53ef\u4ee5\u505a\u5f88\u591a\u4e8b\u60c5\u6765\u4f18\u5316\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u3002 \u4e0e AlexNet \u76f8\u6bd4\uff0c ResNet \u7684\u7ed3\u6784\u8981\u590d\u6742\u5f97\u591a\u3002ResNet \u662f\u6b8b\u5dee\u795e\u7ecf\u7f51\u7edc\uff08Residual Neural Network\uff09\u7684\u7f29\u5199\uff0c\u7531 K. He\u3001X. Zhang\u3001S. Ren \u548c J. Sun \u5728 2015 \u5e74\u53d1\u8868\u7684\u8bba\u6587\u4e2d\u63d0\u51fa\u3002ResNet \u7531 \u6b8b\u5dee\u5757 \uff08residual blocks\uff09\u7ec4\u6210\uff0c\u901a\u8fc7\u8df3\u8fc7\u67d0\u4e9b\u5c42\uff0c\u4f7f\u77e5\u8bc6\u80fd\u591f\u4e0d\u65ad\u5728\u5404\u5c42\u4e2d\u8fdb\u884c\u4f20\u9012\u3002\u8fd9\u4e9b\u5c42\u4e4b\u95f4\u7684 \u8fde\u63a5\u88ab\u79f0\u4e3a \u8df3\u8dc3\u8fde\u63a5 \uff08skip-connections\uff09\uff0c\u56e0\u4e3a\u6211\u4eec\u8df3\u8fc7\u4e86\u4e00\u5c42\u6216\u591a\u5c42\u3002\u8df3\u8dc3\u8fde\u63a5\u901a\u8fc7\u5c06\u68af\u5ea6\u4f20\u64ad\u5230\u66f4\u591a\u5c42\u6765\u5e2e\u52a9\u89e3\u51b3\u68af\u5ea6\u6d88\u5931\u95ee\u9898\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u8bad\u7ec3\u975e\u5e38\u5927\u7684\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff0c\u800c\u4e0d\u4f1a\u635f\u5931\u6027\u80fd\u3002\u901a\u5e38\u60c5\u51b5\u4e0b\uff0c\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7684\u662f\u5927\u578b\u795e\u7ecf\u7f51\u7edc\uff0c\u90a3\u4e48\u5f53\u8bad\u7ec3\u5230\u67d0\u4e00\u8282\u70b9\u4e0a\u65f6\u8bad\u7ec3\u635f\u5931\u53cd\u800c\u4f1a\u589e\u52a0\uff0c\u4f46\u8fd9\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u8df3\u8dc3\u8fde\u63a5\u6765\u907f\u514d\u3002\u901a\u8fc7\u56fe 7 \u53ef\u4ee5\u66f4\u597d\u5730\u7406\u89e3\u8fd9\u4e00\u70b9\u3002 \u56fe 7\uff1a\u7b80\u5355\u8fde\u63a5\u4e0e\u6b8b\u5dee\u8fde\u63a5\u7684\u6bd4\u8f83\u3002\u53c2\u89c1\u8df3\u8dc3\u8fde\u63a5\u3002\u8bf7\u6ce8\u610f\uff0c\u672c\u56fe\u7701\u7565\u4e86\u6700\u540e\u4e00\u5c42\u3002 \u6b8b\u5dee\u5757\u975e\u5e38\u5bb9\u6613\u7406\u89e3\u3002\u4f60\u4ece\u67d0\u4e00\u5c42\u83b7\u53d6\u8f93\u51fa\uff0c\u8df3\u8fc7\u4e00\u4e9b\u5c42\uff0c\u7136\u540e\u5c06\u8f93\u51fa\u6dfb\u52a0\u5230\u7f51\u7edc\u4e2d\u66f4\u8fdc\u7684\u4e00\u5c42\u3002\u865a\u7ebf\u8868\u793a\u8f93\u5165\u5f62\u72b6\u9700\u8981\u8c03\u6574\uff0c\u56e0\u4e3a\u4f7f\u7528\u4e86\u6700\u5927\u6c60\u5316\uff0c\u800c\u6700\u5927\u6c60\u5316\u7684\u4f7f\u7528\u4f1a\u6539\u53d8\u8f93\u51fa\u7684\u5927\u5c0f\u3002 ResNet \u6709\u591a\u79cd\u4e0d\u540c\u7684\u7248\u672c\uff1a \u6709 18 \u5c42\u300134 \u5c42\u300150 \u5c42\u3001101 \u5c42\u548c 152 \u5c42\uff0c\u6240\u6709\u8fd9\u4e9b\u5c42\u90fd\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8fdb\u884c\u4e86\u6743\u91cd\u9884\u8bad\u7ec3\u3002\u5982\u4eca\uff0c\u9884\u8bad\u7ec3\u6a21\u578b\uff08\u51e0\u4e4e\uff09\u9002\u7528\u4e8e\u6240\u6709\u60c5\u51b5\uff0c\u4f46\u8bf7\u786e\u4fdd\u60a8\u4ece\u8f83\u5c0f\u7684\u6a21\u578b\u5f00\u59cb\uff0c\u4f8b\u5982\uff0c\u4ece resnet-18 \u5f00\u59cb\uff0c\u800c\u4e0d\u662f resnet-50\u3002\u5176\u4ed6\u4e00\u4e9b ImageNet \u9884\u8bad\u7ec3\u6a21\u578b\u5305\u62ec\uff1a Inception DenseNet(different variations) NASNet PNASNet VGG Xception ResNeXt EfficientNet, etc. \u5927\u90e8\u5206\u9884\u8bad\u7ec3\u7684\u6700\u5148\u8fdb\u6a21\u578b\u53ef\u4ee5\u5728 GitHub \u4e0a\u7684 pytorch- pretrainedmodels \u8d44\u6e90\u5e93\u4e2d\u627e\u5230\uff1ahttps://github.com/Cadene/pretrained-models.pytorch\u3002\u8be6\u7ec6\u8ba8\u8bba\u8fd9\u4e9b\u6a21\u578b\u4e0d\u5728\u672c\u7ae0\uff08\u548c\u672c\u4e66\uff09\u8303\u56f4\u4e4b\u5185\u3002\u65e2\u7136\u6211\u4eec\u53ea\u5173\u6ce8\u5e94\u7528\uff0c\u90a3\u5c31\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u6837\u7684\u9884\u8bad\u7ec3\u6a21\u578b\u5982\u4f55\u7528\u4e8e\u5206\u5272\u4efb\u52a1\u3002 \u56fe 8\uff1aU-Net\u67b6\u6784 \u5206\u5272\uff08Segmentation\uff09\u662f\u8ba1\u7b97\u673a\u89c6\u89c9\u4e2d\u76f8\u5f53\u6d41\u884c\u7684\u4e00\u9879\u4efb\u52a1\u3002\u5728\u5206\u5272\u4efb\u52a1\u4e2d\uff0c\u6211\u4eec\u8bd5\u56fe\u4ece\u80cc\u666f\u4e2d\u79fb\u9664/\u63d0\u53d6\u524d\u666f\u3002 \u524d\u666f\u548c\u80cc\u666f\u53ef\u4ee5\u6709\u4e0d\u540c\u7684\u5b9a\u4e49\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u8bf4\uff0c\u8fd9\u662f\u4e00\u9879\u50cf\u7d20\u5206\u7c7b\u4efb\u52a1\uff0c\u4f60\u7684\u5de5\u4f5c\u662f\u7ed9\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u7684\u6bcf\u4e2a\u50cf\u7d20\u5206\u914d\u4e00\u4e2a\u7c7b\u522b\u3002\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u6b63\u5728\u5904\u7406\u7684\u6c14\u80f8\u6570\u636e\u96c6\u5c31\u662f\u4e00\u9879\u5206\u5272\u4efb\u52a1\u3002\u5728\u8fd9\u9879\u4efb\u52a1\u4e2d\uff0c\u6211\u4eec\u9700\u8981\u5bf9\u7ed9\u5b9a\u7684\u80f8\u90e8\u653e\u5c04\u56fe\u50cf\u8fdb\u884c\u6c14\u80f8\u5206\u5272\u3002\u7528\u4e8e\u5206\u5272\u4efb\u52a1\u7684\u6700\u5e38\u7528\u6a21\u578b\u662f U-Net\u3002\u5176\u7ed3\u6784\u5982\u56fe 8 \u6240\u793a\u3002 U-Net \u5305\u62ec\u4e24\u4e2a\u90e8\u5206\uff1a\u7f16\u7801\u5668\u548c\u89e3\u7801\u5668\u3002\u7f16\u7801\u5668\u4e0e\u60a8\u76ee\u524d\u6240\u89c1\u8fc7\u7684\u4efb\u4f55 U-Net \u90fd\u662f\u4e00\u6837\u7684\u3002\u89e3\u7801\u5668\u5219\u6709\u4e9b\u4e0d\u540c\u3002\u89e3\u7801\u5668\u7531\u4e0a\u5377\u79ef\u5c42\u7ec4\u6210\u3002\u5728\u4e0a\u5377\u79ef\uff08up-convolutions\uff09\uff08 \u8f6c\u7f6e\u5377\u79ef transposed convolutions\uff09\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u6ee4\u6ce2\u5668\uff0c\u5f53\u5e94\u7528\u5230\u4e00\u4e2a\u5c0f\u56fe\u50cf\u65f6\uff0c\u4f1a\u4ea7\u751f\u4e00\u4e2a\u5927\u56fe\u50cf\u3002\u5728 PyTorch \u4e2d\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528 ConvTranspose2d \u6765\u5b8c\u6210\u8fd9\u4e00\u64cd\u4f5c\u3002\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u4e0a\u5377\u79ef\u4e0e\u4e0a\u91c7\u6837\u5e76\u4e0d\u76f8\u540c\u3002\u4e0a\u91c7\u6837\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u8fc7\u7a0b\uff0c\u6211\u4eec\u5728\u56fe\u50cf\u4e0a\u5e94\u7528\u4e00\u4e2a\u51fd\u6570\u6765\u8c03\u6574\u5b83\u7684\u5927\u5c0f\u3002\u5728\u4e0a\u5377\u79ef\u4e2d\uff0c\u6211\u4eec\u8981\u5b66\u4e60\u6ee4\u6ce2\u5668\u3002\u6211\u4eec\u5c06\u7f16\u7801\u5668\u7684\u67d0\u4e9b\u90e8\u5206\u4f5c\u4e3a\u67d0\u4e9b\u89e3\u7801\u5668\u7684\u8f93\u5165\u3002\u8fd9\u5bf9 \u4e0a\u5377\u79ef\u5c42\u975e\u5e38\u91cd\u8981\u3002 \u8ba9\u6211\u4eec\u770b\u770b U-Net \u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import torch import torch.nn as nn from torch.nn import functional as F # \u5b9a\u4e49\u4e00\u4e2a\u53cc\u5377\u79ef\u5c42 def double_conv ( in_channels , out_channels ): conv = nn . Sequential ( nn . Conv2d ( in_channels , out_channels , kernel_size = 3 ), nn . ReLU ( inplace = True ), nn . Conv2d ( out_channels , out_channels , kernel_size = 3 ), nn . ReLU ( inplace = True ) ) return conv # \u5b9a\u4e49\u51fd\u6570\u7528\u4e8e\u88c1\u526a\u8f93\u5165\u5f20\u91cf def crop_tensor ( tensor , target_tensor ): target_size = target_tensor . size ()[ 2 ] tensor_size = tensor . size ()[ 2 ] delta = tensor_size - target_size delta = delta // 2 return tensor [:, :, delta : tensor_size - delta , delta : tensor_size - delta ] # \u5b9a\u4e49 U-Net \u6a21\u578b class UNet ( nn . Module ): def __init__ ( self ): super ( UNet , self ) . __init () # \u5b9a\u4e49\u6c60\u5316\u5c42\uff0c\u7f16\u7801\u5668\u548c\u89e3\u7801\u5668\u7684\u53cc\u5377\u79ef\u5c42 self . max_pool_2x2 = nn . MaxPool2d ( kernel_size = 2 , stride = 2 ) self . down_conv_1 = double_conv ( 1 , 64 ) self . down_conv_2 = double_conv ( 64 , 128 ) self . down_conv_3 = double_conv ( 128 , 256 ) self . down_conv_4 = double_conv ( 256 , 512 ) self . down_conv_5 = double_conv ( 512 , 1024 ) # \u5b9a\u4e49\u4e0a\u91c7\u6837\u5c42\u548c\u89e3\u7801\u5668\u7684\u53cc\u5377\u79ef\u5c42 self . up_trans_1 = nn . ConvTranspose2d ( in_channels = 1024 , out_channels = 512 , kernel_size = 2 , stride = 2 ) self . up_conv_1 = double_conv ( 1024 , 512 ) self . up_trans_2 = nn . ConvTranspose2d ( in_channels = 512 , out_channels = 256 , kernel_size = 2 , stride = 2 ) self . up_conv_2 = double_conv ( 512 , 256 ) self . up_trans_3 = nn . ConvTranspose2d ( in_channels = 256 , out_channels = 128 , kernel_size = 2 , stride = 2 ) self . up_conv_3 = double_conv ( 256 , 128 ) self . up_trans_4 = nn . ConvTranspose2d ( in_channels = 128 , out_channels = 64 , kernel_size = 2 , stride = 2 ) self . up_conv_4 = double_conv ( 128 , 64 ) # \u5b9a\u4e49\u8f93\u51fa\u5c42 self . out = nn . Conv2d ( in_channels = 64 , out_channels = 2 , kernel_size = 1 ) def forward ( self , image ): # \u7f16\u7801\u5668\u90e8\u5206 x1 = self . down_conv_1 ( image ) x2 = self . max_pool_2x2 ( x1 ) x3 = self . down_conv_2 ( x2 ) x4 = self . max_pool_2x2 ( x3 ) x5 = self . down_conv_3 ( x4 ) x6 = self . max_pool_2x2 ( x5 ) x7 = self . down_conv_4 ( x6 ) x8 = self . max_pool_2x2 ( x7 ) x9 = self . down_conv_5 ( x8 ) # \u89e3\u7801\u5668\u90e8\u5206 x = self . up_trans_1 ( x9 ) y = crop_tensor ( x7 , x ) x = self . up_conv_1 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_2 ( x ) y = crop_tensor ( x5 , x ) x = self . up_conv_2 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_3 ( x ) y = crop_tensor ( x3 , x ) x = self . up_conv_3 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_4 ( x ) y = crop_tensor ( x1 , x ) x = self . up_conv_4 ( torch . cat ([ x , y ], axis = 1 )) # \u8f93\u51fa\u5c42 out = self . out ( x ) return out if __name__ == \"__main__\" : image = torch . rand (( 1 , 1 , 572 , 572 )) model = UNet () print ( model ( image )) \u8bf7\u6ce8\u610f\uff0c\u6211\u4e0a\u9762\u5c55\u793a\u7684 U-Net \u5b9e\u73b0\u662f U-Net \u8bba\u6587\u7684\u539f\u59cb\u5b9e\u73b0\u3002\u4e92\u8054\u7f51\u4e0a\u6709\u5f88\u591a\u4e0d\u540c\u7684\u5b9e\u73b0\u65b9\u6cd5\u3002 \u6709\u4e9b\u4eba\u559c\u6b22\u4f7f\u7528\u53cc\u7ebf\u6027\u91c7\u6837\u4ee3\u66ff\u8f6c\u7f6e\u5377\u79ef\u8fdb\u884c\u4e0a\u91c7\u6837\uff0c\u4f46\u8fd9\u5e76\u4e0d\u662f\u8bba\u6587\u7684\u771f\u6b63\u5b9e\u73b0\u3002\u4e0d\u8fc7\uff0c\u5b83\u7684\u6027\u80fd\u53ef\u80fd\u4f1a\u66f4\u597d\u3002\u5728\u4e0a\u56fe\u6240\u793a\u7684\u539f\u59cb\u5b9e\u73b0\u4e2d\uff0c\u6709\u4e00\u4e2a\u5355\u901a\u9053\u56fe\u50cf\uff0c\u8f93\u51fa\u4e2d\u6709\u4e24\u4e2a\u901a\u9053\uff1a\u4e00\u4e2a\u662f\u524d\u666f\uff0c\u4e00\u4e2a\u662f\u80cc\u666f\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u4e3a\u4efb\u610f\u6570\u91cf\u7684\u7c7b\u548c\u4efb\u610f\u6570\u91cf\u7684\u8f93\u5165\u901a\u9053\u8fdb\u884c\u5b9a\u5236\u3002\u5728\u6b64\u5b9e\u73b0\u4e2d\uff0c\u8f93\u5165\u56fe\u50cf\u7684\u5927\u5c0f\u4e0e\u8f93\u51fa\u56fe\u50cf\u7684\u5927\u5c0f\u4e0d\u540c\uff0c\u56e0\u4e3a\u6211\u4eec\u4f7f\u7528\u7684\u662f\u65e0\u586b\u5145\u5377\u79ef\uff08convolutions without padding\uff09\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0cU-Net \u7684\u7f16\u7801\u5668\u90e8\u5206\u53ea\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u5377\u79ef\u7f51\u7edc\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u7528\u4efb\u4f55\u7f51\u7edc\uff08\u5982 ResNet\uff09\u6765\u66ff\u6362\u5b83\u3002 \u8fd9\u79cd\u66ff\u6362\u4e5f\u53ef\u4ee5\u901a\u8fc7\u9884\u8bad\u7ec3\u6743\u91cd\u6765\u5b8c\u6210\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u57fa\u4e8e ResNet \u7684\u7f16\u7801\u5668\uff0c\u8be5\u7f16\u7801\u5668\u5df2\u5728 ImageNet \u548c\u901a\u7528\u89e3\u7801\u5668\u4e0a\u8fdb\u884c\u4e86\u9884\u8bad\u7ec3\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u591a\u79cd\u4e0d\u540c\u7684\u7f51\u7edc\u67b6\u6784\u6765\u4ee3\u66ff ResNet\u3002Pavel Yakubovskiy \u6240\u8457\u7684\u300aSegmentation Models Pytorch\u300b\u5c31\u662f\u8bb8\u591a\u6b64\u7c7b\u53d8\u4f53\u7684\u5b9e\u73b0\uff0c\u5176\u4e2d\u7f16\u7801\u5668\u53ef\u4ee5\u88ab\u9884\u8bad\u7ec3\u6a21\u578b\u6240\u53d6\u4ee3\u3002\u8ba9\u6211\u4eec\u5e94\u7528\u57fa\u4e8e ResNet \u7684 U-Net \u6765\u89e3\u51b3\u6c14\u80f8\u68c0\u6d4b\u95ee\u9898\u3002 \u5927\u591a\u6570\u7c7b\u4f3c\u7684\u95ee\u9898\u90fd\u6709\u4e24\u4e2a\u8f93\u5165\uff1a\u539f\u59cb\u56fe\u50cf\u548c\u63a9\u7801\uff08mask\uff09\u3002 \u5982\u679c\u6709\u591a\u4e2a\u5bf9\u8c61\uff0c\u5c31\u4f1a\u6709\u591a\u4e2a\u63a9\u7801\u3002 \u5728\u6211\u4eec\u7684\u6c14\u80f8\u6570\u636e\u96c6\u4e2d\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f RLE\u3002RLE \u4ee3\u8868\u8fd0\u884c\u957f\u5ea6\u7f16\u7801\uff0c\u662f\u4e00\u79cd\u8868\u793a\u4e8c\u8fdb\u5236\u63a9\u7801\u4ee5\u8282\u7701\u7a7a\u95f4\u7684\u65b9\u6cd5\u3002\u6df1\u5165\u7814\u7a76 RLE \u8d85\u51fa\u4e86\u672c\u7ae0\u7684\u8303\u56f4\u3002\u56e0\u6b64\uff0c\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u5f20\u8f93\u5165\u56fe\u50cf\u548c\u76f8\u5e94\u7684\u63a9\u7801\u3002\u8ba9\u6211\u4eec\u5148\u8bbe\u8ba1\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\uff0c\u7528\u4e8e\u8f93\u51fa\u56fe\u50cf\u548c\u63a9\u7801\u56fe\u50cf\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u521b\u5efa\u7684\u811a\u672c\u51e0\u4e4e\u53ef\u4ee5\u5e94\u7528\u4e8e\u4efb\u4f55\u5206\u5272\u95ee\u9898\u3002\u8bad\u7ec3\u6570\u636e\u96c6\u662f\u4e00\u4e2a CSV \u6587\u4ef6\uff0c\u53ea\u5305\u542b\u56fe\u50cf ID\uff08\u4e5f\u662f\u6587\u4ef6\u540d\uff09\u3002 import os import glob import torch import numpy as np import pandas as pd from PIL import Image , ImageFile from tqdm import tqdm from collections import defaultdict from torchvision import transforms from albumentations import ( Compose , OneOf , RandomBrightnessContrast , RandomGamma , ShiftScaleRotate , ) # \u8bbe\u7f6ePIL\u56fe\u50cf\u52a0\u8f7d\u622a\u65ad\u7684\u5904\u7406 ImageFile . LOAD_TRUNCATED_IMAGES = True # \u521b\u5efaSIIM\u6570\u636e\u96c6\u7c7b class SIIMDataset ( torch . utils . data . Dataset ): def __init__ ( self , image_ids , transform = True , preprocessing_fn = None ): self . data = defaultdict ( dict ) self . transform = transform self . preprocessing_fn = preprocessing_fn # \u5b9a\u4e49\u6570\u636e\u589e\u5f3a self . aug = Compose ( [ ShiftScaleRotate ( shift_limit = 0.0625 , scale_limit = 0.1 , rotate_limit = 10 , p = 0.8 ), OneOf ( [ RandomGamma ( gamma_limit = ( 90 , 110 ) ), RandomBrightnessContrast ( brightness_limit = 0.1 , contrast_limit = 0.1 ), ], p = 0.5 , ), ] ) # \u6784\u5efa\u6570\u636e\u5b57\u5178\uff0c\u5176\u4e2d\u5305\u542b\u56fe\u50cf\u548c\u63a9\u7801\u7684\u8def\u5f84\u4fe1\u606f for imgid in image_ids : files = glob . glob ( os . path . join ( TRAIN_PATH , imgid , \"*.png\" )) self . data [ counter ] = { \"img_path\" : os . path . join ( TRAIN_PATH , imgid + \".png\" ), \"mask_path\" : os . path . join ( TRAIN_PATH , imgid + \"_mask.png\" ), } def __len__ ( self ): return len ( self . data ) def __getitem__ ( self , item ): img_path = self . data [ item ][ \"img_path\" ] mask_path = self . data [ item ][ \"mask_path\" ] # \u6253\u5f00\u56fe\u50cf\u5e76\u5c06\u5176\u8f6c\u6362\u4e3aRGB\u6a21\u5f0f img = Image . open ( img_path ) img = img . convert ( \"RGB\" ) img = np . array ( img ) # \u6253\u5f00\u63a9\u7801\u56fe\u50cf\uff0c\u5e76\u5c06\u5176\u8f6c\u6362\u4e3a\u6d6e\u70b9\u6570 mask = Image . open ( mask_path ) mask = ( mask >= 1 ) . astype ( \"float32\" ) # \u5982\u679c\u9700\u8981\u8fdb\u884c\u6570\u636e\u589e\u5f3a if self . transform is True : augmented = self . aug ( image = img , mask = mask ) img = augmented [ \"image\" ] mask = augmented [ \"mask\" ] # \u5e94\u7528\u9884\u5904\u7406\u51fd\u6570\uff08\u5982\u679c\u6709\uff09 img = self . preprocessing_fn ( img ) # \u8fd4\u56de\u56fe\u50cf\u548c\u63a9\u7801 return { \"image\" : transforms . ToTensor ()( img ), \"mask\" : transforms . ToTensor ()( mask ) . float (), } \u6709\u4e86\u6570\u636e\u96c6\u7c7b\u4e4b\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u8bad\u7ec3\u51fd\u6570\u3002 import os import sys import torch import numpy as np import pandas as pd import segmentation_models_pytorch as smp import torch.nn as nn import torch.optim as optim from apex import amp from collections import OrderedDict from sklearn import model_selection from tqdm import tqdm from torch.optim import lr_scheduler from dataset import SIIMDataset # \u5b9a\u4e49\u8bad\u7ec3\u6570\u636e\u96c6CSV\u6587\u4ef6\u8def\u5f84 TRAINING_CSV = \"../input/train_pneumothorax.csv\" # \u5b9a\u4e49\u8bad\u7ec3\u548c\u6d4b\u8bd5\u7684\u6279\u91cf\u5927\u5c0f TRAINING_BATCH_SIZE = 16 TEST_BATCH_SIZE = 4 # \u5b9a\u4e49\u8bad\u7ec3\u7684\u65f6\u671f\u6570 EPOCHS = 10 # \u6307\u5b9a\u4f7f\u7528\u7684\u7f16\u7801\u5668\u548c\u6743\u91cd ENCODER = \"resnet18\" ENCODER_WEIGHTS = \"imagenet\" # \u6307\u5b9a\u8bbe\u5907\uff08GPU\uff09 DEVICE = \"cuda\" # \u5b9a\u4e49\u8bad\u7ec3\u51fd\u6570 def train ( dataset , data_loader , model , criterion , optimizer ): model . train () num_batches = int ( len ( dataset ) / data_loader . batch_size ) tk0 = tqdm ( data_loader , total = num_batches ) for d in tk0 : inputs = d [ \"image\" ] targets = d [ \"mask\" ] inputs = inputs . to ( DEVICE , dtype = torch . float ) targets = targets . to ( DEVICE , dtype = torch . float ) optimizer . zero_grad () outputs = model ( inputs ) loss = criterion ( outputs , targets ) with amp . scale_loss ( loss , optimizer ) as scaled_loss : scaled_loss . backward () optimizer . step () tk0 . close () # \u5b9a\u4e49\u8bc4\u4f30\u51fd\u6570 def evaluate ( dataset , data_loader , model ): model . eval () final_loss = 0 num_batches = int ( len ( dataset ) / data_loader . batch_size ) tk0 = tqdm ( data_loader , total = num_batches ) with torch . no_grad (): for d in tk0 : inputs = d [ \"image\" ] targets = d [ \"mask\" ] inputs = inputs to ( DEVICE , dtype = torch . float ) targets = targets . to ( DEVICE , dtype = torch . float ) output = model ( inputs ) loss = criterion ( output , targets ) final_loss += loss tk0 . close () return final_loss / num_batches if __name__ == \"__main__\" : df = pd . read_csv ( TRAINING_CSV ) df_train , df_valid = model_selection . train_test_split ( df , random_state = 42 , test_size = 0.1 ) training_images = df_train . image_id . values validation_images = df_valid . image_id . values # \u521b\u5efa U-Net \u6a21\u578b model = smp . Unet ( encoder_name = ENCODER , encoder_weights = ENCODER_WEIGHTS , classes = 1 , activation = None , ) # \u83b7\u53d6\u6570\u636e\u9884\u5904\u7406\u51fd\u6570 prep_fn = smp . encoders . get_preprocessing_fn ( ENCODER , ENCODER_WEIGHTS ) # \u5c06\u6a21\u578b\u653e\u5728\u8bbe\u5907\u4e0a model . to ( DEVICE ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6 train_dataset = SIIMDataset ( training_images , transform = True , preprocessing_fn = prep_fn , ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = TRAINING_BATCH_SIZE , shuffle = True , num_workers = 12 ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u96c6 valid_dataset = SIIMDataset ( validation_images , transform = False , preprocessing_fn = prep_fn , ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = TEST_BATCH_SIZE , shuffle = True , num_workers = 4 ) # \u5b9a\u4e49\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 1e-3 ) # \u5b9a\u4e49\u5b66\u4e60\u7387\u8c03\u5ea6\u5668 scheduler = lr_scheduler . ReduceLROnPlateau ( optimizer , mode = \"min\" , patience = 3 , verbose = True ) # \u521d\u59cb\u5316 Apex \u6df7\u5408\u7cbe\u5ea6\u8bad\u7ec3 model , optimizer = amp . initialize ( model , optimizer , opt_level = \"O1\" , verbosity = 0 ) # \u5982\u679c\u6709\u591a\u4e2aGPU\uff0c\u5219\u4f7f\u7528 DataParallel \u8fdb\u884c\u5e76\u884c\u8bad\u7ec3 if torch . cuda . device_count () > 1 : print ( f \"Let's use { torch . cuda . device_count () } GPUs!\" ) model = nn . DataParallel ( model ) # \u8f93\u51fa\u8bad\u7ec3\u76f8\u5173\u7684\u4fe1\u606f print ( f \"Training batch size: { TRAINING_BATCH_SIZE } \" ) print ( f \"Test batch size: { TEST_BATCH_SIZE } \" ) print ( f \"Epochs: { EPOCHS } \" ) print ( f \"Image size: { IMAGE_SIZE } \" ) print ( f \"Number of training images: { len ( train_dataset ) } \" ) print ( f \"Number of validation images: { len ( valid_dataset ) } \" ) print ( f \"Encoder: { ENCODER } \" ) # \u5faa\u73af\u8bad\u7ec3\u591a\u4e2a\u65f6\u671f for epoch in range ( EPOCHS ): print ( f \"Training Epoch: { epoch } \" ) train ( train_dataset , train_loader , model , criterion , optimizer ) print ( f \"Validation Epoch: { epoch } \" ) val_log = evaluate ( valid_dataset , valid_loader , model ) scheduler . step ( val_log [ \"loss\" ]) print ( \" \\n \" ) \u5728\u5206\u5272\u95ee\u9898\u4e2d\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528\u5404\u79cd\u635f\u5931\u51fd\u6570\uff0c\u4f8b\u5982\u4e8c\u5143\u4ea4\u53c9\u71b5\u3001focal\u635f\u5931\u3001dice\u635f\u5931\u7b49\u3002\u6211\u628a\u8fd9\u4e2a\u95ee\u9898\u7559\u7ed9 \u8bfb\u8005\u6839\u636e\u8bc4\u4f30\u6307\u6807\u6765\u51b3\u5b9a\u5408\u9002\u7684\u635f\u5931\u3002\u5f53\u8bad\u7ec3\u8fd9\u6837\u4e00\u4e2a\u6a21\u578b\u65f6\uff0c\u60a8\u5c06\u5efa\u7acb\u9884\u6d4b\u6c14\u80f8\u4f4d\u7f6e\u7684\u6a21\u578b\uff0c\u5982\u56fe 9 \u6240\u793a\u3002\u5728\u4e0a\u8ff0\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u82f1\u4f1f\u8fbe apex \u8fdb\u884c\u4e86\u6df7\u5408\u7cbe\u5ea6\u8bad\u7ec3\u3002\u8bf7\u6ce8\u610f\uff0c\u4ece PyTorch 1.6.0+ \u7248\u672c\u5f00\u59cb\uff0cPyTorch \u672c\u8eab\u5c31\u63d0\u4f9b\u4e86\u8fd9\u4e00\u529f\u80fd\u3002 \u56fe 9\uff1a\u4ece\u8bad\u7ec3\u6709\u7d20\u7684\u6a21\u578b\u4e2d\u68c0\u6d4b\u5230\u6c14\u80f8\u7684\u793a\u4f8b\uff08\u53ef\u80fd\u4e0d\u662f\u6b63\u786e\u9884\u6d4b\uff09\u3002 \u6211\u5728\u4e00\u4e2a\u540d\u4e3a \"Well That's Fantastic Machine Learning (WTFML) \"\u7684 python \u8f6f\u4ef6\u5305\u4e2d\u6536\u5f55\u4e86\u4e00\u4e9b\u5e38\u7528\u51fd\u6570\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b83\u5982\u4f55\u5e2e\u52a9\u6211\u4eec\u4e3a FGVC 202013 \u690d\u7269\u75c5\u7406\u5b66\u6311\u6218\u8d5b\u4e2d\u7684\u690d\u7269\u56fe\u50cf\u5efa\u7acb\u591a\u7c7b\u5206\u7c7b\u6a21\u578b\u3002 import os import pandas as pd import numpy as np import albumentations import argparse import torch import torchvision import torch.nn as nn import torch.nn.functional as F from sklearn import metrics from sklearn.model_selection import train_test_split from wtfml.engine import Engine from wtfml.data_loaders.image import ClassificationDataLoader # \u81ea\u5b9a\u4e49\u635f\u5931\u51fd\u6570\uff0c\u5b9e\u73b0\u5bc6\u96c6\u4ea4\u53c9\u71b5 class DenseCrossEntropy ( nn . Module ): def __init__ ( self ): super ( DenseCrossEntropy , self ) . __init__ () def forward ( self , logits , labels ): logits = logits . float () labels = labels . float () logprobs = F . log_softmax ( logits , dim =- 1 ) loss = - labels * logprobs loss = loss . sum ( - 1 ) return loss . mean () # \u81ea\u5b9a\u4e49\u795e\u7ecf\u7f51\u7edc\u6a21\u578b class Model ( nn . Module ): def __init__ ( self ): super () . __init () self . base_model = torchvision . models . resnet18 ( pretrained = True ) in_features = self . base_model . fc . in_features self . out = nn . Linear ( in_features , 4 ) def forward ( self , image , targets = None ): batch_size , C , H , W = image . shape x = self . base_model . conv1 ( image ) x = self . base_model . bn1 ( x ) x = self . base_model . relu ( x ) x = self . base_model . maxpool ( x ) x = self . base_model . layer1 ( x ) x = self . base_model . layer2 ( x ) x = self . base_model . layer3 ( x ) x = self . base_model . layer4 ( x ) x = F . adaptive_avg_pool2d ( x , 1 ) . reshape ( batch_size , - 1 ) x = self . out ( x ) loss = None if targets is not None : loss = DenseCrossEntropy ()( x , targets . type_as ( x )) return x , loss if __name__ == \"__main__\" : # \u547d\u4ee4\u884c\u53c2\u6570\u89e3\u6790\u5668 parser = argparse . ArgumentParser () parser . add_argument ( \"--data_path\" , type = str , ) parser . add_argument ( \"--device\" , type = str ,) parser . add_argument ( \"--epochs\" , type = int ,) args = parser . parse_args () # \u4eceCSV\u6587\u4ef6\u52a0\u8f7d\u6570\u636e df = pd . read_csv ( os . path . join ( args . data_path , \"train.csv\" )) images = df . image_id . values . tolist () images = [ os . path . join ( args . data_path , \"images\" , i + \".jpg\" ) for i in images ] targets = df [[ \"healthy\" , \"multiple_diseases\" , \"rust\" , \"scab\" ]] . values # \u521b\u5efa\u795e\u7ecf\u7f51\u7edc\u6a21\u578b model = Model () model . to ( args . device ) # \u5b9a\u4e49\u5747\u503c\u548c\u6807\u51c6\u5dee\u4ee5\u53ca\u6570\u636e\u589e\u5f3a mean = ( 0.485 , 0.456 , 0.406 ) std = ( 0.229 , 0.224 , 0.225 ) aug = albumentations . Compose ( [ albumentations . Normalize ( mean , std , max_pixel_value = 255.0 , always_apply = True ) ] ) # \u5206\u5272\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 ( train_images , valid_images , train_targets , valid_targets ) = train_test_split ( images , targets ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = ClassificationDataLoader ( image_paths = train_images , targets = train_targets , resize = ( 128 , 128 ), augmentations = aug , ) . fetch ( batch_size = 16 , num_workers = 4 , drop_last = False , shuffle = True , tpu = False ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = ClassificationDataLoader ( image_paths = valid_images , targets = valid_targets , resize = ( 128 , 128 ), augmentations = aug , ) . fetch ( batch_size = 16 , num_workers = 4 , drop_last = False , shuffle = False , tpu = False ) # \u521b\u5efa\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 5e-4 ) # \u521b\u5efa\u5b66\u4e60\u7387\u8c03\u5ea6\u5668 scheduler = torch . optim . lr_scheduler . StepLR ( optimizer , step_size = 15 , gamma = 0.6 ) # \u5faa\u73af\u8bad\u7ec3\u591a\u4e2a\u65f6\u671f for epoch in range ( args . epochs ): # \u8bad\u7ec3\u6a21\u578b train_loss = Engine . train ( train_loader , model , optimizer , device = args . device ) # \u8bc4\u4f30\u6a21\u578b valid_loss = Engine . evaluate ( valid_loader , model , device = args . device ) # \u6253\u5370\u635f\u5931\u4fe1\u606f print ( f \" { epoch } , Train Loss= { train_loss } Valid Loss= { valid_loss } \" ) \u6709\u4e86\u6570\u636e\u540e\uff0c\u5c31\u53ef\u4ee5\u8fd0\u884c\u811a\u672c\u4e86\uff1a python plant.py --data_path ../../plant_pathology --device cuda -- epochs 2 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 86 /86 [ 00 :12< 00 :00, 6 .73it/s, loss = 0 .723 ] 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 29 /29 [ 00 :04< 00 :00, 6 .62it/s, loss = 0 .433 ] 0 , Train Loss = 0 .7228777609592261 Valid Loss = 0 .4327834551704341 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 86 /86 [ 00 :12< 00 :00, 6 .74it/s, loss = 0 .271 ] 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 29 /29 [ 00 :04< 00 :00, 6 .63it/s, loss = 0 .568 ] 1 , Train Loss = 0 .2708700496790021 Valid Loss = 0 .56841839541649 \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u8ba9\u6211\u4eec\u6784\u5efa\u6a21\u578b\u53d8\u5f97\u7b80\u5355\uff0c\u4ee3\u7801\u4e5f\u6613\u4e8e\u9605\u8bfb\u548c\u7406\u89e3\u3002\u6ca1\u6709\u4efb\u4f55\u5c01\u88c5\u7684 PyTorch \u6548\u679c\u6700\u597d\u3002\u56fe\u50cf\u4e2d\u4e0d\u4ec5\u4ec5\u6709\u5206\u7c7b\uff0c\u8fd8\u6709\u5f88\u591a\u5176\u4ed6\u7684\u5185\u5bb9\uff0c\u5982\u679c\u6211\u5f00\u59cb\u5199\u6240\u6709\u7684\u5185\u5bb9\uff0c\u5c31\u5f97\u518d\u5199\u4e00\u672c\u4e66\u4e86\uff0c \u63a5\u8fd1\uff08\u51e0\u4e4e\uff09\u4efb\u4f55\u56fe\u50cf\u95ee\u9898\uff08\u4f5c\u8005\u5728\u5f00\u73a9\u7b11\uff09\u3002","title":"\u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5"},{"location":"%E5%A4%84%E7%90%86%E5%88%86%E7%B1%BB%E5%8F%98%E9%87%8F/","text":"\u5904\u7406\u5206\u7c7b\u53d8\u91cf \u5f88\u591a\u4eba\u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u65f6\u90fd\u4f1a\u9047\u5230\u5f88\u591a\u56f0\u96be\uff0c\u56e0\u6b64\u8fd9\u503c\u5f97\u7528\u6574\u6574\u4e00\u7ae0\u7684\u7bc7\u5e45\u6765\u8ba8\u8bba\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u5c06\u8bb2\u8ff0\u4e0d\u540c\u7c7b\u578b\u7684\u5206\u7c7b\u6570\u636e\uff0c\u4ee5\u53ca\u5982\u4f55\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u3002 \u4ec0\u4e48\u662f\u5206\u7c7b\u53d8\u91cf\uff1f \u5206\u7c7b\u53d8\u91cf/\u7279\u5f81\u662f\u6307\u4efb\u4f55\u7279\u5f81\u7c7b\u578b\uff0c\u53ef\u5206\u4e3a\u4e24\u5927\u7c7b\uff1a - \u65e0\u5e8f - \u6709\u5e8f \u65e0\u5e8f\u53d8\u91cf \u662f\u6307\u6709\u4e24\u4e2a\u6216\u4e24\u4e2a\u4ee5\u4e0a\u7c7b\u522b\u7684\u53d8\u91cf\uff0c\u8fd9\u4e9b\u7c7b\u522b\u6ca1\u6709\u4efb\u4f55\u76f8\u5173\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u5982\u679c\u5c06\u6027\u522b\u5206\u4e3a\u4e24\u7ec4\uff0c\u5373\u7537\u6027\u548c\u5973\u6027\uff0c\u5219\u53ef\u5c06\u5176\u89c6\u4e3a\u540d\u4e49\u53d8\u91cf\u3002 \u6709\u5e8f\u53d8\u91cf \u5219\u6709 \"\u7b49\u7ea7 \"\u6216\u7c7b\u522b\uff0c\u5e76\u6709\u7279\u5b9a\u7684\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u4e00\u4e2a\u987a\u5e8f\u5206\u7c7b\u53d8\u91cf\u53ef\u4ee5\u662f\u4e00\u4e2a\u5177\u6709\u4f4e\u3001\u4e2d\u3001\u9ad8\u4e09\u4e2a\u4e0d\u540c\u7b49\u7ea7\u7684\u7279\u5f81\u3002\u987a\u5e8f\u5f88\u91cd\u8981\u3002 \u5c31\u5b9a\u4e49\u800c\u8a00\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06\u5206\u7c7b\u53d8\u91cf\u5206\u4e3a \u4e8c\u5143\u53d8\u91cf \uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u7c7b\u522b\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u6709\u4e9b\u4eba\u751a\u81f3\u628a\u5206\u7c7b\u53d8\u91cf\u79f0\u4e3a \" \u5faa\u73af \"\u53d8\u91cf\u3002\u5468\u671f\u53d8\u91cf\u4ee5 \"\u5468\u671f \"\u7684\u5f62\u5f0f\u5b58\u5728\uff0c\u4f8b\u5982\u4e00\u5468\u4e2d\u7684\u5929\u6570\uff1a \u5468\u65e5\u3001\u5468\u4e00\u3001\u5468\u4e8c\u3001\u5468\u4e09\u3001\u5468\u56db\u3001\u5468\u4e94\u548c\u5468\u516d\u3002\u5468\u516d\u8fc7\u540e\uff0c\u53c8\u662f\u5468\u65e5\u3002\u8fd9\u5c31\u662f\u4e00\u4e2a\u5faa\u73af\u3002\u53e6\u4e00\u4e2a\u4f8b\u5b50\u662f\u4e00\u5929\u4e2d\u7684\u5c0f\u65f6\u6570\uff0c\u5982\u679c\u6211\u4eec\u5c06\u5b83\u4eec\u89c6\u4e3a\u7c7b\u522b\u7684\u8bdd\u3002 \u5206\u7c7b\u53d8\u91cf\u6709\u5f88\u591a\u4e0d\u540c\u7684\u5b9a\u4e49\uff0c\u5f88\u591a\u4eba\u4e5f\u8c08\u5230\u8981\u6839\u636e\u5206\u7c7b\u53d8\u91cf\u7684\u7c7b\u578b\u6765\u5904\u7406\u4e0d\u540c\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u4e0d\u8fc7\uff0c\u6211\u8ba4\u4e3a\u6ca1\u6709\u5fc5\u8981\u8fd9\u6837\u505a\u3002\u6240\u6709\u6d89\u53ca\u5206\u7c7b\u53d8\u91cf\u7684\u95ee\u9898\u90fd\u53ef\u4ee5\u7528\u540c\u6837\u7684\u65b9\u6cd5\u5904\u7406\u3002 \u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e00\u4e2a\u6570\u636e\u96c6\uff08\u4e00\u5982\u65e2\u5f80\uff09\u3002\u8981\u4e86\u89e3\u5206\u7c7b\u53d8\u91cf\uff0c\u6700\u597d\u7684\u514d\u8d39\u6570\u636e\u96c6\u4e4b\u4e00\u662f Kaggle \u5206\u7c7b\u7279\u5f81\u7f16\u7801\u6311\u6218\u8d5b\u4e2d\u7684 cat-in-the-dat \u3002\u5171\u6709\u4e24\u4e2a\u6311\u6218\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6311\u6218\u7684\u6570\u636e\uff0c\u56e0\u4e3a\u5b83\u6bd4\u524d\u4e00\u4e2a\u7248\u672c\u6709\u66f4\u591a\u53d8\u91cf\uff0c\u96be\u5ea6\u4e5f\u66f4\u5927\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u3002 \u56fe 1\uff1aCat-in-the-dat-ii challenge\u90e8\u5206\u6570\u636e\u5c55\u793a \u6570\u636e\u96c6\u7531\u5404\u79cd\u5206\u7c7b\u53d8\u91cf\u7ec4\u6210\uff1a \u65e0\u5e8f \u6709\u5e8f \u5faa\u73af \u4e8c\u5143 \u5728\u56fe 1 \u4e2d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6240\u6709\u5b58\u5728\u7684\u53d8\u91cf\u548c\u76ee\u6807\u53d8\u91cf\u7684\u5b50\u96c6\u3002 \u8fd9\u662f\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002 \u76ee\u6807\u53d8\u91cf\u5bf9\u4e8e\u6211\u4eec\u5b66\u4e60\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\u5e76\u4e0d\u5341\u5206\u91cd\u8981\uff0c\u4f46\u6700\u7ec8\u6211\u4eec\u5c06\u5efa\u7acb\u4e00\u4e2a\u7aef\u5230\u7aef\u6a21\u578b\uff0c\u56e0\u6b64\u8ba9\u6211\u4eec\u770b\u770b\u56fe 2 \u4e2d\u7684\u76ee\u6807\u53d8\u91cf\u5206\u5e03\u3002\u6211\u4eec\u770b\u5230\u76ee\u6807\u662f \u504f\u659c \u7684\uff0c\u56e0\u6b64\u5bf9\u4e8e\u8fd9\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u6765\u8bf4\uff0c\u6700\u597d\u7684\u6307\u6807\u662f ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528\u7cbe\u786e\u5ea6\u548c\u53ec\u56de\u7387\uff0c\u4f46 AUC \u7ed3\u5408\u4e86\u8fd9\u4e24\u4e2a\u6307\u6807\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 AUC \u6765\u8bc4\u4f30\u6211\u4eec\u5728\u8be5\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u7684\u6a21\u578b\u3002 \u56fe 2\uff1a\u6807\u7b7e\u8ba1\u6570\u3002X \u8f74\u8868\u793a\u6807\u7b7e\uff0cY \u8f74\u8868\u793a\u6807\u7b7e\u8ba1\u6570 \u603b\u4f53\u800c\u8a00\uff0c\u6709\uff1a 5\u4e2a\u4e8c\u5143\u53d8\u91cf 10\u4e2a\u65e0\u5e8f\u53d8\u91cf 6\u4e2a\u6709\u5e8f\u53d8\u91cf 2\u4e2a\u5faa\u73af\u53d8\u91cf 1\u4e2a\u76ee\u6807\u53d8\u91cf \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u96c6\u4e2d\u7684 ord_2 \u7279\u5f81\u3002\u5b83\u5305\u62ec6\u4e2a\u4e0d\u540c\u7684\u7c7b\u522b\uff1a - \u51b0\u51bb - \u6e29\u6696 - \u5bd2\u51b7 - \u8f83\u70ed - \u70ed - \u975e\u5e38\u70ed \u6211\u4eec\u5fc5\u987b\u77e5\u9053\uff0c\u8ba1\u7b97\u673a\u65e0\u6cd5\u7406\u89e3\u6587\u672c\u6570\u636e\uff0c\u56e0\u6b64\u6211\u4eec\u9700\u8981\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002\u4e00\u4e2a\u7b80\u5355\u7684\u65b9\u6cd5\u662f\u521b\u5efa\u4e00\u4e2a\u5b57\u5178\uff0c\u5c06\u8fd9\u4e9b\u503c\u6620\u5c04\u4e3a\u4ece 0 \u5230 N-1 \u7684\u6570\u5b57\uff0c\u5176\u4e2d N \u662f\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7c7b\u522b\u7684\u603b\u6570\u3002 # \u6620\u5c04\u5b57\u5178 mapping = { \"Freezing\" : 0 , \"Warm\" : 1 , \"Cold\" : 2 , \"Boiling Hot\" : 3 , \"Hot\" : 4 , \"Lava Hot\" : 5 } \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8bfb\u53d6\u6570\u636e\u96c6\uff0c\u5e76\u8f7b\u677e\u5730\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002 import pandas as pd # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u53d6*ord_2*\u5217\uff0c\u5e76\u4f7f\u7528\u6620\u5c04\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57 df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. map ( mapping ) \u6620\u5c04\u524d\u7684\u6570\u503c\u8ba1\u6570\uff1a df .* ord_2 *. value_counts () Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : * ord_2 * , dtype : int64 \u6620\u5c04\u540e\u7684\u6570\u503c\u8ba1\u6570\uff1a 0.0 142726 1.0 124239 2.0 97822 3.0 84790 4.0 67508 5.0 64840 Name : * ord_2 * , dtype : int64 \u8fd9\u79cd\u5206\u7c7b\u53d8\u91cf\u7684\u7f16\u7801\u65b9\u5f0f\u88ab\u79f0\u4e3a\u6807\u7b7e\u7f16\u7801\uff08Label Encoding\uff09\u6211\u4eec\u5c06\u6bcf\u4e2a\u7c7b\u522b\u7f16\u7801\u4e3a\u4e00\u4e2a\u6570\u5b57\u6807\u7b7e\u3002 \u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 LabelEncoder \u8fdb\u884c\u7f16\u7801\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u5c06\u7f3a\u5931\u503c\u586b\u5145\u4e3a\"NONE\" df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. fillna ( \"NONE\" ) # LabelEncoder\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u8f6c\u6362\u6570\u636e df . loc [:, \"*ord_2*\" ] = lbl_enc . fit_transform ( df .* ord_2 *. values ) \u4f60\u4f1a\u770b\u5230\u6211\u4f7f\u7528\u4e86 pandas \u7684 fillna\u3002\u539f\u56e0\u662f scikit-learn \u7684 LabelEncoder \u65e0\u6cd5\u5904\u7406 NaN \u503c\uff0c\u800c ord_2 \u5217\u4e2d\u6709 NaN \u503c\u3002 \u6211\u4eec\u53ef\u4ee5\u5728\u8bb8\u591a\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u4e2d\u76f4\u63a5\u4f7f\u7528\u5b83\uff1a - \u51b3\u7b56\u6811 - \u968f\u673a\u68ee\u6797 - \u63d0\u5347\u6811 - \u6216\u4efb\u4f55\u4e00\u79cd\u63d0\u5347\u6811\u6a21\u578b - XGBoost - GBM - LightGBM \u8fd9\u79cd\u7f16\u7801\u65b9\u5f0f\u4e0d\u80fd\u7528\u4e8e\u7ebf\u6027\u6a21\u578b\u3001\u652f\u6301\u5411\u91cf\u673a\u6216\u795e\u7ecf\u7f51\u7edc\uff0c\u56e0\u4e3a\u5b83\u4eec\u5e0c\u671b\u6570\u636e\u662f\u6807\u51c6\u5316\u7684\u3002 \u5bf9\u4e8e\u8fd9\u4e9b\u7c7b\u578b\u7684\u6a21\u578b\uff0c\u6211\u4eec\u53ef\u4ee5\u5bf9\u6570\u636e\u8fdb\u884c\u4e8c\u503c\u5316\uff08binarize\uff09\u5904\u7406\u3002 Freezing --> 0 --> 0 0 0 Warm --> 1 --> 0 0 1 Cold --> 2 --> 0 1 0 Boiling Hot --> 3 --> 0 1 1 Hot --> 4 --> 1 0 0 Lava Hot --> 5 --> 1 0 1 \u8fd9\u53ea\u662f\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\uff0c\u7136\u540e\u518d\u8f6c\u6362\u4e3a\u4e8c\u503c\u5316\u8868\u793a\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u628a\u4e00\u4e2a\u7279\u5f81\u5206\u6210\u4e86\u4e09\u4e2a\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u7279\u5f81\uff08\u6216\u5217\uff09\u3002\u5982\u679c\u6211\u4eec\u6709\u66f4\u591a\u7684\u7c7b\u522b\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u5206\u6210\u66f4\u591a\u7684\u5217\u3002 \u5982\u679c\u6211\u4eec\u7528\u7a00\u758f\u683c\u5f0f\u5b58\u50a8\u5927\u91cf\u4e8c\u503c\u5316\u53d8\u91cf\uff0c\u5c31\u53ef\u4ee5\u8f7b\u677e\u5730\u5b58\u50a8\u8fd9\u4e9b\u53d8\u91cf\u3002\u7a00\u758f\u683c\u5f0f\u4e0d\u8fc7\u662f\u4e00\u79cd\u5728\u5185\u5b58\u4e2d\u5b58\u50a8\u6570\u636e\u7684\u8868\u793a\u6216\u65b9\u5f0f\uff0c\u5728\u8fd9\u79cd\u683c\u5f0f\u4e2d\uff0c\u4f60\u5e76\u4e0d\u5b58\u50a8\u6240\u6709\u7684\u503c\uff0c\u800c\u53ea\u5b58\u50a8\u91cd\u8981\u7684\u503c\u3002\u5728\u4e0a\u8ff0\u4e8c\u8fdb\u5236\u53d8\u91cf\u7684\u60c5\u51b5\u4e2d\uff0c\u6700\u91cd\u8981\u7684\u5c31\u662f\u6709 1 \u7684\u5730\u65b9\u3002 \u5f88\u96be\u60f3\u8c61\u8fd9\u6837\u7684\u683c\u5f0f\uff0c\u4f46\u4e3e\u4e2a\u4f8b\u5b50\u5c31\u4f1a\u660e\u767d\u3002 \u5047\u8bbe\u4e0a\u9762\u7684\u6570\u636e\u5e27\u4e2d\u53ea\u6709\u4e00\u4e2a\u7279\u5f81\uff1a ord_2 \u3002 Index Feature 0 Warm 1 Hot 2 Lava hot \u76ee\u524d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6570\u636e\u96c6\u4e2d\u7684\u4e09\u4e2a\u6837\u672c\u3002\u8ba9\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u4e8c\u503c\u8868\u793a\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u6709\u4e09\u4e2a\u9879\u76ee\u3002 \u8fd9\u4e09\u4e2a\u9879\u76ee\u5c31\u662f\u4e09\u4e2a\u7279\u5f81\u3002 Index Feature_0 Feature_1 Feature_2 0 0 0 1 1 1 0 0 2 1 0 1 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u7279\u5f81\u5b58\u50a8\u5728\u4e00\u4e2a\u6709 3 \u884c 3 \u5217\uff083x3\uff09\u7684\u77e9\u9635\u4e2d\u3002\u77e9\u9635\u7684\u6bcf\u4e2a\u5143\u7d20\u5360\u7528 8 \u4e2a\u5b57\u8282\u3002\u56e0\u6b64\uff0c\u8fd9\u4e2a\u6570\u7ec4\u7684\u603b\u5185\u5b58\u9700\u6c42\u4e3a 8x3x3 = 72 \u5b57\u8282\u3002 \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e2a\u7b80\u5355\u7684 python \u4ee3\u7801\u6bb5\u6765\u68c0\u67e5\u8fd9\u4e00\u70b9\u3002 import numpy as np example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) print ( example . nbytes ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u6253\u5370\u51fa 72\uff0c\u5c31\u50cf\u6211\u4eec\u4e4b\u524d\u8ba1\u7b97\u7684\u90a3\u6837\u3002\u4f46\u6211\u4eec\u9700\u8981\u5b58\u50a8\u8fd9\u4e2a\u77e9\u9635\u7684\u6240\u6709\u5143\u7d20\u5417\uff1f\u5982\u524d\u6240\u8ff0\uff0c\u6211\u4eec\u53ea\u5bf9 1 \u611f\u5174\u8da3\u30020 \u5e76\u4e0d\u91cd\u8981\uff0c\u56e0\u4e3a\u4efb\u4f55\u4e0e 0 \u76f8\u4e58\u7684\u5143\u7d20\u90fd\u662f 0\uff0c\u800c 0 \u4e0e\u4efb\u4f55\u5143\u7d20\u76f8\u52a0\u6216\u76f8\u51cf\u4e5f\u6ca1\u6709\u4efb\u4f55\u533a\u522b\u3002\u53ea\u7528 1 \u8868\u793a\u77e9\u9635\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u67d0\u79cd\u5b57\u5178\u65b9\u6cd5\uff0c\u5176\u4e2d\u952e\u662f\u884c\u548c\u5217\u7684\u7d22\u5f15\uff0c\u503c\u662f 1\uff1a ( 0 , 2 ) 1 ( 1 , 0 ) 1 ( 2 , 0 ) 1 ( 2 , 2 ) 1 \u8fd9\u6837\u7684\u7b26\u53f7\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u56e0\u4e3a\u5b83\u53ea\u9700\u5b58\u50a8\u56db\u4e2a\u503c\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u3002\u4f7f\u7528\u7684\u603b\u5185\u5b58\u4e3a 8x4 = 32 \u5b57\u8282\u3002\u4efb\u4f55 numpy \u6570\u7ec4\u90fd\u53ef\u4ee5\u901a\u8fc7\u7b80\u5355\u7684 python \u4ee3\u7801\u8f6c\u6362\u4e3a\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) sparse_example = sparse . csr_matrix ( example ) print ( sparse_example . data . nbytes ) \u8fd9\u5c06\u6253\u5370 32\uff0c\u6bd4\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u5c11\u4e86\u8fd9\u4e48\u591a\uff01\u7a00\u758f csr \u77e9\u9635\u7684\u603b\u5927\u5c0f\u662f\u4e09\u4e2a\u503c\u7684\u603b\u548c\u3002 print ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) \u8fd9\u5c06\u6253\u5370\u51fa 64 \u4e2a\u5143\u7d20\uff0c\u4ecd\u7136\u5c11\u4e8e\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u3002\u9057\u61be\u7684\u662f\uff0c\u6211\u4e0d\u4f1a\u8be6\u7ec6\u4ecb\u7ecd\u8fd9\u4e9b\u5143\u7d20\u3002\u4f60\u53ef\u4ee5\u5728 scipy \u6587\u6863\u4e2d\u4e86\u89e3\u66f4\u591a\u3002\u5f53\u6211\u4eec\u62e5\u6709\u66f4\u5927\u7684\u6570\u7ec4\u65f6\uff0c\u6bd4\u5982\u8bf4\u62e5\u6709\u6570\u5343\u4e2a\u6837\u672c\u548c\u6570\u4e07\u4e2a\u7279\u5f81\u7684\u6570\u7ec4\uff0c\u5927\u5c0f\u5dee\u5f02\u5c31\u4f1a\u53d8\u5f97\u975e\u5e38\u5927\u3002\u4f8b\u5982\uff0c\u6211\u4eec\u4f7f\u7528\u57fa\u4e8e\u8ba1\u6570\u7279\u5f81\u7684\u6587\u672c\u6570\u636e\u96c6\u3002 import numpy as np from scipy import sparse n_rows = 10000 n_cols = 100000 # \u751f\u6210\u7b26\u5408\u4f2f\u52aa\u5229\u5206\u5e03\u7684\u968f\u673a\u6570\u7ec4\uff0c\u7ef4\u5ea6\u4e3a[10000, 100000] example = np . random . binomial ( 1 , p = 0.05 , size = ( n_rows , n_cols )) print ( f \"Size of dense array: { example . nbytes } \" ) # \u5c06\u968f\u673a\u77e9\u9635\u8f6c\u6362\u4e3a\u6d17\u6f31\u77e9\u9635 sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u8fd9\u5c06\u6253\u5370\uff1a Size of dense array : 8000000000 Size of sparse array : 399932496 Full size of sparse array : 599938748 \u56e0\u6b64\uff0c\u5bc6\u96c6\u9635\u5217\u9700\u8981 ~8000MB \u6216\u5927\u7ea6 8GB \u5185\u5b58\u3002\u800c\u7a00\u758f\u9635\u5217\u53ea\u5360\u7528 399MB \u5185\u5b58\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u5f53\u6211\u4eec\u7684\u7279\u5f81\u4e2d\u6709\u5927\u91cf\u96f6\u65f6\uff0c\u6211\u4eec\u66f4\u559c\u6b22\u7a00\u758f\u9635\u5217\u800c\u4e0d\u662f\u5bc6\u96c6\u9635\u5217\u7684\u539f\u56e0\u3002 \u8bf7\u6ce8\u610f\uff0c\u7a00\u758f\u77e9\u9635\u6709\u591a\u79cd\u4e0d\u540c\u7684\u8868\u793a\u65b9\u6cd5\u3002\u8fd9\u91cc\u6211\u53ea\u5c55\u793a\u4e86\u5176\u4e2d\u4e00\u79cd\uff08\u53ef\u80fd\u4e5f\u662f\u6700\u5e38\u7528\u7684\uff09\u65b9\u6cd5\u3002\u6df1\u5165\u63a2\u8ba8\u8fd9\u4e9b\u65b9\u6cd5\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\uff0c\u56e0\u6b64\u7559\u7ed9\u8bfb\u8005\u4e00\u4e2a\u7ec3\u4e60\u3002 \u5c3d\u7ba1\u4e8c\u503c\u5316\u7279\u5f81\u7684\u7a00\u758f\u8868\u793a\u6bd4\u5176\u5bc6\u96c6\u8868\u793a\u6240\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u4f46\u5bf9\u4e8e\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\uff0c\u8fd8\u6709\u4e00\u79cd\u8f6c\u6362\u6240\u5360\u7528\u7684\u5185\u5b58\u66f4\u5c11\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \" \u72ec\u70ed\u7f16\u7801 \"\u3002 \u72ec\u70ed\u7f16\u7801\u4e5f\u662f\u4e00\u79cd\u4e8c\u503c\u7f16\u7801\uff0c\u56e0\u4e3a\u53ea\u6709 0 \u548c 1 \u4e24\u4e2a\u503c\u3002\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5b83\u5e76\u4e0d\u662f\u4e8c\u503c\u8868\u793a\u6cd5\u3002\u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0b\u9762\u7684\u4f8b\u5b50\u6765\u7406\u89e3\u5b83\u7684\u8868\u793a\u6cd5\u3002 \u5047\u8bbe\u6211\u4eec\u7528\u4e00\u4e2a\u5411\u91cf\u6765\u8868\u793a ord_2 \u53d8\u91cf\u7684\u6bcf\u4e2a\u7c7b\u522b\u3002\u8fd9\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u4e0e ord_2 \u53d8\u91cf\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u6bcf\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u90fd\u662f 6\uff0c\u5e76\u4e14\u9664\u4e86\u4e00\u4e2a\u4f4d\u7f6e\u5916\uff0c\u5176\u4ed6\u4f4d\u7f6e\u90fd\u662f 0\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u7279\u6b8a\u7684\u5411\u91cf\u8868\u3002 Freezing 0 0 0 0 0 1 Warm 0 0 0 0 1 0 Cold 0 0 0 1 0 0 Boiling Hot 0 0 1 0 0 0 Hot 0 1 0 0 0 0 Lava Hot 1 0 0 0 0 0 \u6211\u4eec\u770b\u5230\u5411\u91cf\u7684\u5927\u5c0f\u662f 1x6\uff0c\u5373\u5411\u91cf\u4e2d\u67096\u4e2a\u5143\u7d20\u3002\u8fd9\u4e2a\u6570\u5b57\u662f\u600e\u4e48\u6765\u7684\u5462\uff1f\u5982\u679c\u4f60\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u5982\u524d\u6240\u8ff0\uff0c\u67096\u4e2a\u7c7b\u522b\u3002\u5728\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u65f6\uff0c\u5411\u91cf\u7684\u5927\u5c0f\u5fc5\u987b\u4e0e\u6211\u4eec\u8981\u67e5\u770b\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u6bcf\u4e2a\u5411\u91cf\u90fd\u6709\u4e00\u4e2a 1\uff0c\u5176\u4f59\u6240\u6709\u503c\u90fd\u662f 0\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u7528\u8fd9\u4e9b\u7279\u5f81\u6765\u4ee3\u66ff\u4e4b\u524d\u7684\u4e8c\u503c\u5316\u7279\u5f81\uff0c\u770b\u770b\u80fd\u8282\u7701\u591a\u5c11\u5185\u5b58\u3002 \u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\u4ee5\u524d\u7684\u6570\u636e\uff0c\u5b83\u770b\u8d77\u6765\u5982\u4e0b\uff1a Index Feature 0 Warm 1 Hot 2 Lava hot \u6bcf\u4e2a\u6837\u672c\u67093\u4e2a\u7279\u5f81\u3002\u4f46\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u72ec\u70ed\u5411\u91cf\u7684\u5927\u5c0f\u4e3a 6\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u67096\u4e2a\u7279\u5f81\uff0c\u800c\u4e0d\u662f3\u4e2a\u3002 Index F_0 F_1 F_2 F_3 F_4 F_5 0 0 0 0 0 1 0 1 0 1 0 0 0 0 2 1 0 1 0 0 0 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 6 \u4e2a\u7279\u5f81\uff0c\u800c\u5728\u8fd9\u4e2a 3x6 \u6570\u7ec4\u4e2d\uff0c\u53ea\u6709 3 \u4e2a1\u3002\u4f7f\u7528 numpy \u8ba1\u7b97\u5927\u5c0f\u4e0e\u4e8c\u503c\u5316\u5927\u5c0f\u8ba1\u7b97\u811a\u672c\u975e\u5e38\u76f8\u4f3c\u3002\u4f60\u9700\u8981\u6539\u53d8\u7684\u53ea\u662f\u6570\u7ec4\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u6bb5\u4ee3\u7801\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 0 , 0 , 1 , 0 ], [ 0 , 1 , 0 , 0 , 0 , 0 ], [ 1 , 0 , 0 , 0 , 0 , 0 ] ] ) print ( f \"Size of dense array: { example . nbytes } \" ) sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u6253\u5370\u5185\u5b58\u5927\u5c0f\u4e3a\uff1a Size of dense array : 144 Size of sparse array : 24 Full size of sparse array : 52 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5bc6\u96c6\u77e9\u9635\u7684\u5927\u5c0f\u8fdc\u8fdc\u5927\u4e8e\u4e8c\u503c\u5316\u77e9\u9635\u7684\u5927\u5c0f\u3002\u4e0d\u8fc7\uff0c\u7a00\u758f\u6570\u7ec4\u7684\u5927\u5c0f\u8981\u66f4\u5c0f\u3002\u8ba9\u6211\u4eec\u7528\u66f4\u5927\u7684\u6570\u7ec4\u6765\u8bd5\u8bd5\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 OneHotEncoder \u5c06\u5305\u542b 1001 \u4e2a\u7c7b\u522b\u7684\u7279\u5f81\u6570\u7ec4\u8f6c\u6362\u4e3a\u5bc6\u96c6\u77e9\u9635\u548c\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from sklearn import preprocessing # \u751f\u6210\u7b26\u5408\u5747\u5300\u5206\u5e03\u7684\u968f\u673a\u6574\u6570\uff0c\u7ef4\u5ea6\u4e3a[1000000, 10000000] example = np . random . randint ( 1000 , size = 1000000 ) # \u72ec\u70ed\u7f16\u7801\uff0c\u975e\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = False ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of dense array: { ohe_example . nbytes } \" ) # \u72ec\u70ed\u7f16\u7801\uff0c\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = True ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of sparse array: { ohe_example . data . nbytes } \" ) full_size = ( ohe_example . data . nbytes + ohe_example . indptr . nbytes + ohe_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u4e0a\u9762\u4ee3\u7801\u6253\u5370\u7684\u8f93\u51fa\uff1a Size of dense array : 8000000000 Size of sparse array : 8000000 Full size of sparse array : 16000004 \u8fd9\u91cc\u7684\u5bc6\u96c6\u9635\u5217\u5927\u5c0f\u7ea6\u4e3a 8GB\uff0c\u7a00\u758f\u9635\u5217\u4e3a 8MB\u3002\u5982\u679c\u53ef\u4ee5\u9009\u62e9\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff1f\u5728\u6211\u770b\u6765\uff0c\u9009\u62e9\u5f88\u7b80\u5355\uff0c\u4e0d\u662f\u5417\uff1f \u8fd9\u4e09\u79cd\u65b9\u6cd5\uff08\u6807\u7b7e\u7f16\u7801\u3001\u7a00\u758f\u77e9\u9635\u3001\u72ec\u70ed\u7f16\u7801\uff09\u662f\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u7684\u6700\u91cd\u8981\u65b9\u6cd5\u3002\u4e0d\u8fc7\uff0c\u4f60\u8fd8\u53ef\u4ee5\u7528\u5f88\u591a\u5176\u4ed6\u4e0d\u540c\u7684\u65b9\u6cd5\u6765\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u3002\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u6570\u503c\u53d8\u91cf\u5c31\u662f\u5176\u4e2d\u7684\u4e00\u4e2a\u4f8b\u5b50\u3002 \u5047\u8bbe\u6211\u4eec\u56de\u5230\u4e4b\u524d\u7684\u5206\u7c7b\u7279\u5f81\u6570\u636e\uff08\u539f\u59cb\u6570\u636e\u4e2d\u7684 cat-in-the-dat-ii\uff09\u3002\u5728\u6570\u636e\u4e2d\uff0c ord_2 \u7684\u503c\u4e3a\u201c\u70ed\u201c\u7684 id \u6709\u591a\u5c11\uff1f \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u8ba1\u7b97\u6570\u636e\u7684\u5f62\u72b6\uff08shape\uff09\u8f7b\u677e\u8ba1\u7b97\u51fa\u8fd9\u4e2a\u503c\uff0c\u5176\u4e2d ord_2 \u5217\u7684\u503c\u4e3a Boiling Hot \u3002 In [ X ]: df [ df . ord_2 == \"Boiling Hot\" ] . shape Out [ X ]: ( 84790 , 25 ) \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 84790 \u6761\u8bb0\u5f55\u5177\u6709\u6b64\u503c\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 pandas \u4e2d\u7684 groupby \u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u8be5\u503c\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . count () Out [ X ]: ord_2 Boiling Hot 84790 Cold 97822 Freezing 142726 Hot 67508 Lava Hot 64840 Warm 124239 Name : id , dtype : int64 \u5982\u679c\u6211\u4eec\u53ea\u662f\u5c06 ord_2 \u5217\u66ff\u6362\u4e3a\u5176\u8ba1\u6570\u503c\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u5c06\u5176\u8f6c\u6362\u4e3a\u4e00\u79cd\u6570\u503c\u7279\u5f81\u4e86\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 transform \u51fd\u6570\u548c groupby \u6765\u521b\u5efa\u65b0\u5217\u6216\u66ff\u6362\u8fd9\u4e00\u5217\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . transform ( \"count\" ) Out [ X ]: 0 67508.0 1 124239.0 2 142726.0 3 64840.0 4 97822.0 ... 599995 142726.0 599996 84790.0 599997 142726.0 599998 124239.0 599999 84790.0 Name : id , Length : 600000 , dtype : float64 \u4f60\u53ef\u4ee5\u6dfb\u52a0\u6240\u6709\u7279\u5f81\u7684\u8ba1\u6570\uff0c\u4e5f\u53ef\u4ee5\u66ff\u6362\u5b83\u4eec\uff0c\u6216\u8005\u6839\u636e\u591a\u4e2a\u5217\u53ca\u5176\u8ba1\u6570\u8fdb\u884c\u5206\u7ec4\u3002\u4f8b\u5982\uff0c\u4ee5\u4e0b\u4ee3\u7801\u901a\u8fc7\u5bf9 ord_1 \u548c ord_2 \u5217\u5206\u7ec4\u8fdb\u884c\u8ba1\u6570\u3002 In [ X ]: df . groupby ( ... : [ ... : \"ord_1\" , ... : \"ord_2\" ... : ] ... : )[ \"id\" ] . count () . reset_index ( name = \"count\" ) Out [ X ]: ord_1 ord_2 count 0 Contributor Boiling Hot 15634 1 Contributor Cold 17734 2 Contributor Freezing 26082 3 Contributor Hot 12428 4 Contributor Lava Hot 11919 5 Contributor Warm 22774 6 Expert Boiling Hot 19477 7 Expert Cold 22956 8 Expert Freezing 33249 9 Expert Hot 15792 10 Expert Lava Hot 15078 11 Expert Warm 28900 12 Grandmaster Boiling Hot 13623 13 Grandmaster Cold 15464 14 Grandmaster Freezing 22818 15 Grandmaster Hot 10805 16 Grandmaster Lava Hot 10363 17 Grandmaster Warm 19899 18 Master Boiling Hot 10800 ... \u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u884c\uff0c\u4ee5\u4fbf\u5728\u4e00\u9875\u4e2d\u5bb9\u7eb3\u8fd9\u4e9b\u884c\u3002\u8fd9\u662f\u53e6\u4e00\u79cd\u53ef\u4ee5\u4f5c\u4e3a\u529f\u80fd\u6dfb\u52a0\u7684\u8ba1\u6570\u3002\u60a8\u73b0\u5728\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\uff0c\u6211\u4f7f\u7528 id \u5217\u8fdb\u884c\u8ba1\u6570\u3002\u4e0d\u8fc7\uff0c\u4f60\u4e5f\u53ef\u4ee5\u901a\u8fc7\u5bf9\u5217\u7684\u7ec4\u5408\u8fdb\u884c\u5206\u7ec4\uff0c\u5bf9\u5176\u4ed6\u5217\u8fdb\u884c\u8ba1\u6570\u3002 \u8fd8\u6709\u4e00\u4e2a\u5c0f\u7a8d\u95e8\uff0c\u5c31\u662f\u4ece\u8fd9\u4e9b\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\u3002\u4f60\u53ef\u4ee5\u4ece\u73b0\u6709\u7684\u7279\u5f81\u4e2d\u521b\u5efa\u65b0\u7684\u5206\u7c7b\u7279\u5f81\uff0c\u800c\u4e14\u53ef\u4ee5\u6beb\u4e0d\u8d39\u529b\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot 1 Grandmaster_Warm 2 nan_Freezing 3 Novice_Lava Hot 4 Grandmaster_Cold ... 599999 Contributor_Boiling Hot Name : new_feature , Length : 600000 , dtype : object \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u7528\u4e0b\u5212\u7ebf\u5c06 ord_1 \u548c ord_2 \u5408\u5e76\uff0c\u7136\u540e\u5c06\u8fd9\u4e9b\u5217\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u7c7b\u578b\u3002\u8bf7\u6ce8\u610f\uff0cNaN \u4e5f\u4f1a\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u3002\u4e0d\u8fc7\u6ca1\u5173\u7cfb\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06 NaN \u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u6709\u4e86\u4e00\u4e2a\u7531\u8fd9\u4e24\u4e2a\u7279\u5f81\u7ec4\u5408\u800c\u6210\u7684\u65b0\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u5c06\u4e09\u5217\u4ee5\u4e0a\u6216\u56db\u5217\u751a\u81f3\u66f4\u591a\u5217\u7ec4\u5408\u5728\u4e00\u8d77\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : + \"_\" ... : + df . ord_3 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot_c 1 Grandmaster_Warm_e 2 nan_Freezing_n 3 Novice_Lava Hot_a 4 Grandmaster_Cold_h ... 599999 Contributor_Boiling Hot_b Name : new_feature , Length : 600000 , dtype : object \u90a3\u4e48\uff0c\u6211\u4eec\u5e94\u8be5\u628a\u54ea\u4e9b\u7c7b\u522b\u7ed3\u5408\u8d77\u6765\u5462\uff1f\u8fd9\u5e76\u6ca1\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u7b54\u6848\u3002\u8fd9\u53d6\u51b3\u4e8e\u60a8\u7684\u6570\u636e\u548c\u7279\u5f81\u7c7b\u578b\u3002\u4e00\u4e9b\u9886\u57df\u77e5\u8bc6\u5bf9\u4e8e\u521b\u5efa\u8fd9\u6837\u7684\u7279\u5f81\u53ef\u80fd\u5f88\u6709\u7528\u3002\u4f46\u662f\uff0c\u5982\u679c\u4f60\u4e0d\u62c5\u5fc3\u5185\u5b58\u548c CPU \u7684\u4f7f\u7528\uff0c\u4f60\u53ef\u4ee5\u91c7\u7528\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\uff0c\u5373\u521b\u5efa\u8bb8\u591a\u8fd9\u6837\u7684\u7ec4\u5408\uff0c\u7136\u540e\u4f7f\u7528\u4e00\u4e2a\u6a21\u578b\u6765\u51b3\u5b9a\u54ea\u4e9b\u7279\u5f81\u662f\u6709\u7528\u7684\uff0c\u5e76\u4fdd\u7559\u5b83\u4eec\u3002\u6211\u4eec\u5c06\u5728\u672c\u4e66\u7a0d\u540e\u90e8\u5206\u4ecb\u7ecd\u8fd9\u79cd\u65b9\u6cd5\u3002 \u65e0\u8bba\u4f55\u65f6\u83b7\u5f97\u5206\u7c7b\u53d8\u91cf\uff0c\u90fd\u8981\u9075\u5faa\u4ee5\u4e0b\u7b80\u5355\u6b65\u9aa4\uff1a - \u586b\u5145 NaN \u503c\uff08\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff01\uff09\u3002 - \u4f7f\u7528 scikit-learn \u7684 LabelEncoder \u6216\u6620\u5c04\u5b57\u5178\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\uff0c\u5c06\u5b83\u4eec\u8f6c\u6362\u4e3a\u6574\u6570\u3002\u5982\u679c\u6ca1\u6709\u586b\u5145 NaN \u503c\uff0c\u53ef\u80fd\u9700\u8981\u5728\u8fd9\u4e00\u6b65\u4e2d\u8fdb\u884c\u5904\u7406 - \u521b\u5efa\u72ec\u70ed\u7f16\u7801\u3002\u662f\u7684\uff0c\u4f60\u53ef\u4ee5\u8df3\u8fc7\u4e8c\u503c\u5316\uff01 - \u5efa\u6a21\uff01\u6211\u6307\u7684\u662f\u673a\u5668\u5b66\u4e60\u3002 \u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u5904\u7406 NaN \u6570\u636e\u975e\u5e38\u91cd\u8981\uff0c\u5426\u5219\u60a8\u53ef\u80fd\u4f1a\u4ece scikit-learn \u7684 LabelEncoder \u4e2d\u5f97\u5230\u81ed\u540d\u662d\u8457\u7684\u9519\u8bef\u4fe1\u606f\uff1a ValueError: y \u5305\u542b\u4ee5\u524d\u672a\u89c1\u8fc7\u7684\u6807\u7b7e\uff1a [Nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan) \u8fd9\u4ec5\u4ec5\u610f\u5473\u7740\uff0c\u5728\u8f6c\u6362\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6570\u636e\u4e2d\u51fa\u73b0\u4e86 NaN \u503c\u3002\u8fd9\u662f\u56e0\u4e3a\u4f60\u5728\u8bad\u7ec3\u65f6\u5fd8\u8bb0\u4e86\u5904\u7406\u5b83\u4eec\u3002 \u5904\u7406 NaN \u503c \u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u4e22\u5f03\u5b83\u4eec\u3002\u867d\u7136\u7b80\u5355\uff0c\u4f46\u5e76\u4e0d\u7406\u60f3\u3002NaN \u503c\u4e2d\u53ef\u80fd\u5305\u542b\u5f88\u591a\u4fe1\u606f\uff0c\u5982\u679c\u53ea\u662f\u4e22\u5f03\u8fd9\u4e9b\u503c\uff0c\u5c31\u4f1a\u4e22\u5931\u8fd9\u4e9b\u4fe1\u606f\u3002\u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u5927\u90e8\u5206\u6570\u636e\u90fd\u662f NaN \u503c\uff0c\u56e0\u6b64\u4e0d\u80fd\u4e22\u5f03 NaN \u503c\u7684\u884c/\u6837\u672c\u3002\u5904\u7406 NaN \u503c\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u5c06\u5176\u4f5c\u4e3a\u4e00\u4e2a\u5168\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u662f\u5904\u7406 NaN \u503c\u6700\u5e38\u7528\u7684\u65b9\u6cd5\u3002\u5982\u679c\u4f7f\u7528 pandas\uff0c\u8fd8\u53ef\u4ee5\u901a\u8fc7\u975e\u5e38\u7b80\u5355\u7684\u65b9\u5f0f\u5b9e\u73b0\u3002 \u8bf7\u770b\u6211\u4eec\u4e4b\u524d\u67e5\u770b\u8fc7\u7684\u6570\u636e\u7684 ord_2 \u5217\u3002 In [ X ]: df . ord_2 . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : ord_2 , dtype : int64 \u586b\u5165 NaN \u503c\u540e\uff0c\u5c31\u53d8\u6210\u4e86 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u54c7\uff01\u8fd9\u4e00\u5217\u4e2d\u6709 18075 \u4e2a NaN \u503c\uff0c\u800c\u6211\u4eec\u4e4b\u524d\u751a\u81f3\u90fd\u6ca1\u6709\u8003\u8651\u4f7f\u7528\u5b83\u4eec\u3002\u589e\u52a0\u4e86\u8fd9\u4e2a\u65b0\u7c7b\u522b\u540e\uff0c\u7c7b\u522b\u603b\u6570\u4ece 6 \u4e2a\u589e\u52a0\u5230\u4e86 7 \u4e2a\u3002\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u73b0\u5728\u6211\u4eec\u5728\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u4e5f\u4f1a\u8003\u8651 NaN\u3002\u76f8\u5173\u4fe1\u606f\u8d8a\u591a\uff0c\u6a21\u578b\u5c31\u8d8a\u597d\u3002 \u5047\u8bbe ord_2 \u6ca1\u6709\u4efb\u4f55 NaN \u503c\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u4e00\u5217\u4e2d\u7684\u6240\u6709\u7c7b\u522b\u90fd\u6709\u663e\u8457\u7684\u8ba1\u6570\u3002\u5176\u4e2d\u6ca1\u6709 \"\u7f55\u89c1 \"\u7c7b\u522b\uff0c\u5373\u53ea\u5728\u6837\u672c\u603b\u6570\u4e2d\u5360\u5f88\u5c0f\u6bd4\u4f8b\u7684\u7c7b\u522b\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5047\u8bbe\u60a8\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u4e86\u4f7f\u7528\u8fd9\u4e00\u5217\u7684\u6a21\u578b\uff0c\u5f53\u6a21\u578b\u6216\u9879\u76ee\u4e0a\u7ebf\u65f6\uff0c\u60a8\u5728 ord_2 \u5217\u4e2d\u5f97\u5230\u4e86\u4e00\u4e2a\u5728\u8bad\u7ec3\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6a21\u578b\u7ba1\u9053\u4f1a\u629b\u51fa\u4e00\u4e2a\u9519\u8bef\uff0c\u60a8\u5bf9\u6b64\u65e0\u80fd\u4e3a\u529b\u3002\u5982\u679c\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\uff0c\u90a3\u4e48\u53ef\u80fd\u662f\u751f\u4ea7\u4e2d\u7684\u7ba1\u9053\u51fa\u4e86\u95ee\u9898\u3002\u5982\u679c\u8fd9\u662f\u9884\u6599\u4e4b\u4e2d\u7684\uff0c\u90a3\u4e48\u60a8\u5c31\u5fc5\u987b\u4fee\u6539\u60a8\u7684\u6a21\u578b\u7ba1\u9053\uff0c\u5e76\u5728\u8fd9\u516d\u4e2a\u7c7b\u522b\u4e2d\u52a0\u5165\u4e00\u4e2a\u65b0\u7c7b\u522b\u3002 \u8fd9\u4e2a\u65b0\u7c7b\u522b\u88ab\u79f0\u4e3a \"\u7f55\u89c1 \"\u7c7b\u522b\u3002\u7f55\u89c1\u7c7b\u522b\u662f\u4e00\u79cd\u4e0d\u5e38\u89c1\u7684\u7c7b\u522b\uff0c\u53ef\u4ee5\u5305\u62ec\u8bb8\u591a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8fd1\u90bb\u6a21\u578b\u6765 \"\u9884\u6d4b \"\u672a\u77e5\u7c7b\u522b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u4e86\u8fd9\u4e2a\u7c7b\u522b\uff0c\u5b83\u5c31\u4f1a\u6210\u4e3a\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u4e00\u4e2a\u7c7b\u522b\u3002 \u56fe 3\uff1a\u5177\u6709\u4e0d\u540c\u7279\u5f81\u4e14\u65e0\u6807\u7b7e\u7684\u6570\u636e\u96c6\u793a\u610f\u56fe\uff0c\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u53ef\u80fd\u4f1a\u5728\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u51fa\u73b0\u65b0\u503c \u5f53\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u56fe 3 \u6240\u793a\u7684\u6570\u636e\u96c6\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u7b80\u5355\u7684\u6a21\u578b\uff0c\u5bf9\u9664 \"f3 \"\u4e4b\u5916\u7684\u6240\u6709\u7279\u5f81\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u6837\uff0c\u4f60\u5c06\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u4e0d\u77e5\u9053\u6216\u8bad\u7ec3\u4e2d\u6ca1\u6709 \"f3 \"\u65f6\u9884\u6d4b\u5b83\u3002\u6211\u4e0d\u6562\u8bf4\u8fd9\u6837\u7684\u6a21\u578b\u662f\u5426\u80fd\u5e26\u6765\u51fa\u8272\u7684\u6027\u80fd\uff0c\u4f46\u4e5f\u8bb8\u80fd\u5904\u7406\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u7684\u7f3a\u5931\u503c\uff0c\u5c31\u50cf\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u5176\u4ed6\u4e8b\u60c5\u4e00\u6837\uff0c\u4e0d\u5c1d\u8bd5\u4e00\u4e0b\u662f\u8bf4\u4e0d\u51c6\u7684\u3002 \u5982\u679c\u4f60\u6709\u4e00\u4e2a\u56fa\u5b9a\u7684\u6d4b\u8bd5\u96c6\uff0c\u4f60\u53ef\u4ee5\u5c06\u6d4b\u8bd5\u6570\u636e\u6dfb\u52a0\u5230\u8bad\u7ec3\u4e2d\uff0c\u4ee5\u4e86\u89e3\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u7c7b\u522b\u3002\u8fd9\u4e0e\u534a\u76d1\u7763\u5b66\u4e60\u975e\u5e38\u76f8\u4f3c\uff0c\u5373\u4f7f\u7528\u65e0\u6cd5\u7528\u4e8e\u8bad\u7ec3\u7684\u6570\u636e\u6765\u6539\u8fdb\u6a21\u578b\u3002\u8fd9\u4e5f\u4f1a\u7167\u987e\u5230\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u51fa\u73b0\u6b21\u6570\u6781\u5c11\u4f46\u5728\u6d4b\u8bd5\u6570\u636e\u4e2d\u5927\u91cf\u5b58\u5728\u7684\u7a00\u6709\u503c\u3002\u4f60\u7684\u6a21\u578b\u5c06\u66f4\u52a0\u7a33\u5065\u3002 \u5f88\u591a\u4eba\u8ba4\u4e3a\u8fd9\u79cd\u60f3\u6cd5\u4f1a\u8fc7\u5ea6\u62df\u5408\u3002\u53ef\u80fd\u8fc7\u62df\u5408\uff0c\u4e5f\u53ef\u80fd\u4e0d\u8fc7\u62df\u5408\u3002\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6cd5\u3002\u5982\u679c\u4f60\u5728\u8bbe\u8ba1\u4ea4\u53c9\u9a8c\u8bc1\u65f6\uff0c\u80fd\u591f\u5728\u6d4b\u8bd5\u6570\u636e\u4e0a\u8fd0\u884c\u6a21\u578b\u65f6\u590d\u5236\u9884\u6d4b\u8fc7\u7a0b\uff0c\u90a3\u4e48\u5b83\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u8fd9\u610f\u5473\u7740\u7b2c\u4e00\u6b65\u5e94\u8be5\u662f\u5206\u79bb\u6298\u53e0\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u4f60\u5e94\u8be5\u5e94\u7528\u4e0e\u6d4b\u8bd5\u6570\u636e\u76f8\u540c\u7684\u9884\u5904\u7406\u3002\u5047\u8bbe\u60a8\u60f3\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u6d4b\u8bd5\u6570\u636e\uff0c\u90a3\u4e48\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u60a8\u5fc5\u987b\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\uff0c\u5e76\u786e\u4fdd\u9a8c\u8bc1\u6570\u636e\u96c6\u590d\u5236\u4e86\u6d4b\u8bd5\u96c6\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u60a8\u5fc5\u987b\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u8bbe\u8ba1\u9a8c\u8bc1\u96c6\uff0c\u4f7f\u5176\u5305\u542b\u8bad\u7ec3\u96c6\u4e2d \"\u672a\u89c1 \"\u7684\u7c7b\u522b\u3002 \u56fe 4\uff1a\u5bf9\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u8fdb\u884c\u7b80\u5355\u5408\u5e76\uff0c\u4ee5\u4e86\u89e3\u6d4b\u8bd5\u96c6\u4e2d\u5b58\u5728\u4f46\u8bad\u7ec3\u96c6\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u6216\u8bad\u7ec3\u96c6\u4e2d\u7f55\u89c1\u7684\u7c7b\u522b \u53ea\u8981\u770b\u4e00\u4e0b\u56fe 4 \u548c\u4e0b\u9762\u7684\u4ee3\u7801\uff0c\u5c31\u80fd\u5f88\u5bb9\u6613\u7406\u89e3\u5176\u5de5\u4f5c\u539f\u7406\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( \"../input/cat_train.csv\" ) # \u8bfb\u53d6\u6d4b\u8bd5\u96c6 test = pd . read_csv ( \"../input/cat_test.csv\" ) # \u5c06\u6d4b\u8bd5\u96c6\"target\"\u5217\u5168\u90e8\u7f6e\u4e3a-1 test . loc [:, \"target\" ] = - 1 # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u62fc\u63a5 data = pd . concat ([ train , test ]) . reset_index ( drop = True ) # \u5c06\u9664\"id\"\u548c\"target\"\u5217\u7684\u5176\u4ed6\u7279\u5f81\u5217\u540d\u53d6\u51fa features = [ x for x in train . columns if x not in [ \"id\" , \"target\" ]] # \u904d\u5386\u7279\u5f81 for feat in features : # \u6807\u7b7e\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\",\u5e76\u5c06\u8be5\u5217\u683c\u5f0f\u53d8\u4e3astr temp_col = data [ feat ] . fillna ( \"NONE\" ) . astype ( str ) . values # \u8f6c\u6362\u6570\u503c data . loc [:, feat ] = lbl_enc . fit_transform ( temp_col ) # \u6839\u636e\"target\"\u5217\u5c06\u8bad\u7ec3\u96c6\u4e0e\u6d4b\u8bd5\u96c6\u5206\u5f00 train = data [ data . target != - 1 ] . reset_index ( drop = True ) test = data [ data . target == - 1 ] . reset_index ( drop = True ) \u5f53\u60a8\u9047\u5230\u5df2\u7ecf\u6709\u6d4b\u8bd5\u6570\u636e\u96c6\u7684\u95ee\u9898\u65f6\uff0c\u8fd9\u4e2a\u6280\u5de7\u5c31\u4f1a\u8d77\u4f5c\u7528\u3002\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd9\u4e00\u62db\u5728\u5b9e\u65f6\u73af\u5883\u4e2d\u4e0d\u8d77\u4f5c\u7528\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u60a8\u6240\u5728\u7684\u516c\u53f8\u63d0\u4f9b\u5b9e\u65f6\u7ade\u4ef7\u89e3\u51b3\u65b9\u6848\uff08RTB\uff09\u3002RTB \u7cfb\u7edf\u4f1a\u5bf9\u5728\u7ebf\u770b\u5230\u7684\u6bcf\u4e2a\u7528\u6237\u8fdb\u884c\u7ade\u4ef7\uff0c\u4ee5\u8d2d\u4e70\u5e7f\u544a\u7a7a\u95f4\u3002\u8fd9\u79cd\u6a21\u5f0f\u53ef\u4f7f\u7528\u7684\u529f\u80fd\u53ef\u80fd\u5305\u62ec\u7f51\u7ad9\u4e2d\u6d4f\u89c8\u7684\u9875\u9762\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u7279\u5f81\u662f\u7528\u6237\u8bbf\u95ee\u7684\u6700\u540e\u4e94\u4e2a\u7c7b\u522b/\u9875\u9762\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5982\u679c\u7f51\u7ad9\u5f15\u5165\u4e86\u65b0\u7684\u7c7b\u522b\uff0c\u6211\u4eec\u5c06\u65e0\u6cd5\u518d\u51c6\u786e\u9884\u6d4b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u5c31\u4f1a\u5931\u6548\u3002\u8fd9\u79cd\u60c5\u51b5\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528 \"\u672a\u77e5 \"\u7c7b\u522b\u6765\u907f\u514d \u3002 \u5728\u6211\u4eec\u7684 cat-in-the-dat \u6570\u636e\u96c6\u4e2d\uff0c ord_2 \u5217\u4e2d\u5df2\u7ecf\u6709\u4e86\u672a\u77e5\u7c7b\u522b\u3002 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u6211\u4eec\u53ef\u4ee5\u5c06 \"NONE \"\u89c6\u4e3a\u672a\u77e5\u3002\u56e0\u6b64\uff0c\u5982\u679c\u5728\u5b9e\u65f6\u6d4b\u8bd5\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u83b7\u5f97\u4e86\u4ee5\u524d\u4ece\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\uff0c\u6211\u4eec\u5c31\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a \"NONE\"\u3002 \u8fd9\u4e0e\u81ea\u7136\u8bed\u8a00\u5904\u7406\u95ee\u9898\u975e\u5e38\u76f8\u4f3c\u3002\u6211\u4eec\u603b\u662f\u57fa\u4e8e\u56fa\u5b9a\u7684\u8bcd\u6c47\u5efa\u7acb\u6a21\u578b\u3002\u589e\u52a0\u8bcd\u6c47\u91cf\u5c31\u4f1a\u589e\u52a0\u6a21\u578b\u7684\u5927\u5c0f\u3002\u50cf BERT \u8fd9\u6837\u7684\u8f6c\u6362\u5668\u6a21\u578b\u662f\u5728 ~30000 \u4e2a\u5355\u8bcd\uff08\u82f1\u8bed\uff09\u7684\u57fa\u7840\u4e0a\u8bad\u7ec3\u7684\u3002\u56e0\u6b64\uff0c\u5f53\u6709\u65b0\u8bcd\u8f93\u5165\u65f6\uff0c\u6211\u4eec\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a UNK\uff08\u672a\u77e5\uff09\u3002 \u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u5047\u8bbe\u6d4b\u8bd5\u6570\u636e\u4e0e\u8bad\u7ec3\u6570\u636e\u5177\u6709\u76f8\u540c\u7684\u7c7b\u522b\uff0c\u4e5f\u53ef\u4ee5\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u5f15\u5165\u7f55\u89c1\u6216\u672a\u77e5\u7c7b\u522b\uff0c\u4ee5\u5904\u7406\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u65b0\u7c7b\u522b\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u586b\u5165 NaN \u503c\u540e ord_4 \u5217\u7684\u503c\u8ba1\u6570\uff1a In [ X ]: df . ord_4 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 . . . K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 G 3404 V 3107 J 1950 L 1657 Name : ord_4 , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u6709\u4e9b\u6570\u503c\u53ea\u51fa\u73b0\u4e86\u51e0\u5343\u6b21\uff0c\u6709\u4e9b\u5219\u51fa\u73b0\u4e86\u8fd1 40000 \u6b21\u3002NaN \u4e5f\u7ecf\u5e38\u51fa\u73b0\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u503c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5b9a\u4e49\u5c06\u4e00\u4e2a\u503c\u79f0\u4e3a \" \u7f55\u89c1\uff08rare\uff09 \"\u7684\u6807\u51c6\u4e86\u3002\u6bd4\u65b9\u8bf4\uff0c\u5728\u8fd9\u4e00\u5217\u4e2d\uff0c\u7a00\u6709\u503c\u7684\u8981\u6c42\u662f\u8ba1\u6570\u5c0f\u4e8e 2000\u3002\u8fd9\u6837\u770b\u6765\uff0cJ \u548c L \u5c31\u53ef\u4ee5\u88ab\u6807\u8bb0\u4e3a\u7a00\u6709\u503c\u4e86\u3002\u4f7f\u7528 pandas\uff0c\u6839\u636e\u8ba1\u6570\u9608\u503c\u66ff\u6362\u7c7b\u522b\u975e\u5e38\u7b80\u5355\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 In [ X ]: df . ord_4 = df . ord_4 . fillna ( \"NONE\" ) In [ X ]: df . loc [ ... : df [ \"ord_4\" ] . value_counts ()[ df [ \"ord_4\" ]] . values < 2000 , ... : \"ord_4\" ... : ] = \"RARE\" In [ X ]: df . ord_4 . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 M 32504 . . . B 25212 E 21871 K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 RARE 3607 G 3404 V 3107 Name : ord_4 , dtype : int64 \u6211\u4eec\u8ba4\u4e3a\uff0c\u53ea\u8981\u67d0\u4e2a\u7c7b\u522b\u7684\u503c\u5c0f\u4e8e 2000\uff0c\u5c31\u5c06\u5176\u66ff\u6362\u4e3a\u7f55\u89c1\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u5728\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6240\u6709\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"RARE\"\uff0c\u800c\u6240\u6709\u7f3a\u5931\u503c\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"NONE\"\u3002 \u8fd9\u79cd\u65b9\u6cd5\u8fd8\u80fd\u786e\u4fdd\u5373\u4f7f\u6709\u65b0\u7684\u7c7b\u522b\uff0c\u6a21\u578b\u4e5f\u80fd\u5728\u5b9e\u9645\u73af\u5883\u4e2d\u6b63\u5e38\u5de5\u4f5c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u5177\u5907\u4e86\u5904\u7406\u4efb\u4f55\u5e26\u6709\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u6240\u9700\u7684\u4e00\u5207\u6761\u4ef6\u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5efa\u7acb\u7b2c\u4e00\u4e2a\u6a21\u578b\uff0c\u5e76\u9010\u6b65\u63d0\u9ad8\u5176\u6027\u80fd\u3002 \u5728\u6784\u5efa\u4efb\u4f55\u7c7b\u578b\u7684\u6a21\u578b\u4e4b\u524d\uff0c\u4ea4\u53c9\u68c0\u9a8c\u81f3\u5173\u91cd\u8981\u3002\u6211\u4eec\u5df2\u7ecf\u770b\u5230\u4e86\u6807\u7b7e/\u76ee\u6807\u5206\u5e03\uff0c\u77e5\u9053\u8fd9\u662f\u4e00\u4e2a\u76ee\u6807\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 StratifiedKFold \u6765\u5206\u5272\u6570\u636e\u3002 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u6dfb\u52a0\"kfold\"\u5217\uff0c\u5e76\u7f6e\u4e3a-1 df [ \"kfold\" ] = - 1 # \u6253\u4e71\u6570\u636e\u987a\u5e8f\uff0c\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u5c06\u76ee\u6807\u5217\u53d6\u51fa y = df . target . values # \u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): # \u533a\u5206\u6298\u53e0 df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u6587\u4ef6 df . to_csv ( \"../input/cat_train_folds.csv\" , index = False ) \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u68c0\u67e5\u65b0\u7684\u6298\u53e0 csv\uff0c\u67e5\u770b\u6bcf\u4e2a\u6298\u53e0\u7684\u6837\u672c\u6570\uff1a In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) In [ X ]: df . kfold . value_counts () Out [ X ]: 4 120000 3 120000 2 120000 1 120000 0 120000 Name : kfold , dtype : int64 \u6240\u6709\u6298\u53e0\u90fd\u6709 120000 \u4e2a\u6837\u672c\u3002\u8fd9\u662f\u610f\u6599\u4e4b\u4e2d\u7684\uff0c\u56e0\u4e3a\u8bad\u7ec3\u6570\u636e\u6709 600000 \u4e2a\u6837\u672c\uff0c\u800c\u6211\u4eec\u505a\u4e865\u6b21\u6298\u53e0\u3002\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u4e00\u5207\u987a\u5229\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u68c0\u67e5\u6bcf\u4e2a\u6298\u53e0\u7684\u76ee\u6807\u5206\u5e03\u3002 In [ X ]: df [ df . kfold == 0 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 1 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 2 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 3 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 4 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u76ee\u6807\u7684\u5206\u5e03\u90fd\u662f\u4e00\u6837\u7684\u3002\u8fd9\u6b63\u662f\u6211\u4eec\u6240\u9700\u8981\u7684\u3002\u5b83\u4e5f\u53ef\u4ee5\u662f\u76f8\u4f3c\u7684\uff0c\u5e76\u4e0d\u4e00\u5b9a\u8981\u4e00\u76f4\u76f8\u540c\u3002\u73b0\u5728\uff0c\u5f53\u6211\u4eec\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u6807\u7b7e\u5206\u5e03\u90fd\u5c06\u76f8\u540c\u3002 \u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u7684\u6700\u7b80\u5355\u7684\u6a21\u578b\u4e4b\u4e00\u662f\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u5e76\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): # \u8bfb\u53d6\u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) # \u53d6\u9664\"id\", \"target\", \"kfold\"\u5916\u7684\u5176\u4ed6\u7279\u5f81\u5217 features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] # \u904d\u5386\u7279\u5f81\u5217\u8868 for col in features : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u9a8c\u8bc1\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u6d4b\u8bd5\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u903b\u8f91\u56de\u5f52 model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . target . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( auc ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6298\u53e00 run ( 0 ) \u90a3\u4e48\uff0c\u53d1\u751f\u4e86\u4ec0\u4e48\u5462\uff1f \u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u51fd\u6570\uff0c\u5c06\u6570\u636e\u5206\u4e3a\u8bad\u7ec3\u548c\u9a8c\u8bc1\u4e24\u90e8\u5206\uff0c\u7ed9\u5b9a\u6298\u53e0\u6570\uff0c\u5904\u7406 NaN \u503c\uff0c\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u5355\u6b21\u7f16\u7801\uff0c\u5e76\u8bad\u7ec3\u4e00\u4e2a\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\u3002 \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u90e8\u5206\u4ee3\u7801\u65f6\uff0c\u4f1a\u4ea7\u751f\u5982\u4e0b\u8f93\u51fa\uff1a \u276f python ohe_logres . py / home / abhishek / miniconda3 / envs / ml / lib / python3 .7 / site - packages / sklearn / linear_model / _logistic . py : 939 : ConvergenceWarning : lbfgs failed to converge ( status = 1 ): STOP : TOTAL NO . of ITERATIONS REACHED LIMIT . Increase the number of iterations ( max_iter ) or scale the data as shown in : https : // scikit - learn . org / stable / modules / preprocessing . html . Please also refer to the documentation for alternative solver options : https : // scikit - learn . org / stable / modules / linear_model . html #logistic- regression extra_warning_msg = _LOGISTIC_SOLVER_CONVERGENCE_MSG ) 0.7847865042255127 \u6709\u4e00\u4e9b\u8b66\u544a\u3002\u903b\u8f91\u56de\u5f52\u4f3c\u4e4e\u6ca1\u6709\u6536\u655b\u5230\u6700\u5927\u8fed\u4ee3\u6b21\u6570\u3002\u6211\u4eec\u6ca1\u6709\u8c03\u6574\u53c2\u6570\uff0c\u6240\u4ee5\u6ca1\u6709\u95ee\u9898\u3002\u6211\u4eec\u770b\u5230 AUC \u4e3a 0.785\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u5bf9\u4ee3\u7801\u8fdb\u884c\u7b80\u5355\u4fee\u6539\uff0c\u8fd0\u884c\u6240\u6709\u6298\u53e0\u3002 .... model = linear_model . LogisticRegression () model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u5faa\u73af\u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5e76\u6ca1\u6709\u505a\u5f88\u5927\u7684\u6539\u52a8\uff0c\u6240\u4ee5\u6211\u53ea\u663e\u793a\u4e86\u90e8\u5206\u4ee3\u7801\u884c\uff0c\u5176\u4e2d\u4e00\u4e9b\u4ee3\u7801\u884c\u6709\u6539\u52a8\u3002 \u8fd9\u5c31\u6253\u5370\u51fa\u4e86\uff1a python - W ignore ohe_logres . py Fold = 0 , AUC = 0.7847865042255127 Fold = 1 , AUC = 0.7853553605899214 Fold = 2 , AUC = 0.7879321942914885 Fold = 3 , AUC = 0.7870315929550808 Fold = 4 , AUC = 0.7864668243125608 \u8bf7\u6ce8\u610f\uff0c\u6211\u4f7f\u7528\"-W ignore \"\u5ffd\u7565\u4e86\u6240\u6709\u8b66\u544a\u3002 \u6211\u4eec\u770b\u5230\uff0cAUC \u5206\u6570\u5728\u6240\u6709\u8936\u76b1\u4e2d\u90fd\u76f8\u5f53\u7a33\u5b9a\u3002\u5e73\u5747 AUC \u4e3a 0.78631449527\u3002\u5bf9\u4e8e\u6211\u4eec\u7684\u7b2c\u4e00\u4e2a\u6a21\u578b\u6765\u8bf4\u76f8\u5f53\u4e0d\u9519\uff01 \u5f88\u591a\u4eba\u5728\u9047\u5230\u8fd9\u79cd\u95ee\u9898\u65f6\u4f1a\u9996\u5148\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6bd4\u5982\u968f\u673a\u68ee\u6797\u3002\u5728\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u5e94\u7528\u968f\u673a\u68ee\u6797\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\uff08label encoding\uff09\uff0c\u5c06\u6bcf\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81\u90fd\u8f6c\u6362\u4e3a\u6574\u6570\uff0c\u800c\u4e0d\u662f\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u72ec\u70ed\u7f16\u7801\u3002 \u8fd9\u79cd\u7f16\u7801\u4e0e\u72ec\u70ed\u7f16\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # \u968f\u673a\u68ee\u6797\u6a21\u578b model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u4f7f\u7528 scikit-learn \u4e2d\u7684\u968f\u673a\u68ee\u6797\uff0c\u5e76\u53d6\u6d88\u4e86\u72ec\u70ed\u7f16\u7801\u3002\u6211\u4eec\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u4ee3\u66ff\u72ec\u70ed\u7f16\u7801\u3002\u5f97\u5206\u5982\u4e0b \u276f python lbl_rf . py Fold = 0 , AUC = 0.7167390828113697 Fold = 1 , AUC = 0.7165459672958506 Fold = 2 , AUC = 0.7159709909587376 Fold = 3 , AUC = 0.7161589664189556 Fold = 4 , AUC = 0.7156020216155978 \u54c7 \u5de8\u5927\u7684\u5dee\u5f02\uff01 \u968f\u673a\u68ee\u6797\u6a21\u578b\u5728\u6ca1\u6709\u4efb\u4f55\u8d85\u53c2\u6570\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u8868\u73b0\u8981\u6bd4\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u5dee\u5f88\u591a\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u603b\u662f\u5e94\u8be5\u5148\u4ece\u7b80\u5355\u6a21\u578b\u5f00\u59cb\u7684\u539f\u56e0\u3002\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u7c89\u4e1d\u4f1a\u4ece\u8fd9\u91cc\u5f00\u59cb\uff0c\u800c\u5ffd\u7565\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u8ba4\u4e3a\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u6a21\u578b\uff0c\u4e0d\u80fd\u5e26\u6765\u6bd4\u968f\u673a\u68ee\u6797\u66f4\u597d\u7684\u4ef7\u503c\u3002\u8fd9\u79cd\u4eba\u5c06\u4f1a\u72af\u4e0b\u5927\u9519\u3002\u5728\u6211\u4eec\u5b9e\u73b0\u968f\u673a\u68ee\u6797\u7684\u8fc7\u7a0b\u4e2d\uff0c\u4e0e\u903b\u8f91\u56de\u5f52\u76f8\u6bd4\uff0c\u6298\u53e0\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u624d\u80fd\u5b8c\u6210\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e0d\u4ec5\u635f\u5931\u4e86 AUC\uff0c\u8fd8\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u6765\u5b8c\u6210\u8bad\u7ec3\u3002\u8bf7\u6ce8\u610f\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u8fdb\u884c\u63a8\u7406\u4e5f\u5f88\u8017\u65f6\uff0c\u800c\u4e14\u5360\u7528\u7684\u7a7a\u95f4\u4e5f\u66f4\u5927\u3002 \u5982\u679c\u6211\u4eec\u613f\u610f\uff0c\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u5728\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u6570\u636e\u4e0a\u8fd0\u884c\u968f\u673a\u68ee\u6797\uff0c\u4f46\u8fd9\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u5947\u5f02\u503c\u5206\u89e3\u6765\u51cf\u5c11\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u77e9\u9635\u3002\u8fd9\u662f\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u63d0\u53d6\u4e3b\u9898\u7684\u5e38\u7528\u65b9\u6cd5\u3002 import pandas as pd from scipy import sparse from sklearn import decomposition from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" )] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) x_train = ohe . transform ( df_train [ features ]) x_valid = ohe . transform ( df_valid [ features ]) # \u5947\u5f02\u503c\u5206\u89e3 svd = decomposition . TruncatedSVD ( n_components = 120 ) full_sparse = sparse . vstack (( x_train , x_valid )) svd . fit ( full_sparse ) x_train = svd . transform ( x_train ) x_valid = svd . transform ( x_valid ) model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u5bf9\u5168\u90e8\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\uff0c\u7136\u540e\u7528\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\u5728\u7a00\u758f\u77e9\u9635\u4e0a\u62df\u5408 scikit-learn \u7684 TruncatedSVD\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c06\u9ad8\u7ef4\u7a00\u758f\u77e9\u9635\u51cf\u5c11\u5230 120 \u4e2a\u7279\u5f81\uff0c\u7136\u540e\u62df\u5408\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u3002 \u4ee5\u4e0b\u662f\u8be5\u6a21\u578b\u7684\u8f93\u51fa\u7ed3\u679c\uff1a \u276f python ohe_svd_rf . py Fold = 0 , AUC = 0.7064863038754249 Fold = 1 , AUC = 0.706050102937374 Fold = 2 , AUC = 0.7086069243167242 Fold = 3 , AUC = 0.7066819080085971 Fold = 4 , AUC = 0.7058154015055585 \u6211\u4eec\u53d1\u73b0\u60c5\u51b5\u66f4\u7cdf\u3002\u770b\u6765\uff0c\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\u7684\u6700\u4f73\u65b9\u6cd5\u662f\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u3002\u968f\u673a\u68ee\u6797\u4f3c\u4e4e\u8017\u65f6\u592a\u591a\u3002\u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u8bd5\u8bd5 XGBoost\u3002\u5982\u679c\u4f60\u4e0d\u77e5\u9053 XGBoost\uff0c\u5b83\u662f\u6700\u6d41\u884c\u7684\u68af\u5ea6\u63d0\u5347\u7b97\u6cd5\u4e4b\u4e00\u3002\u7531\u4e8e\u5b83\u662f\u4e00\u79cd\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u6570\u636e\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 , n_estimators = 200 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0c\u6211\u5bf9 xgboost \u53c2\u6570\u505a\u4e86\u4e00\u4e9b\u4fee\u6539\u3002xgboost \u7684\u9ed8\u8ba4\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u662f 3\uff0c\u6211\u628a\u5b83\u6539\u6210\u4e86 7\uff0c\u8fd8\u628a\u4f30\u8ba1\u5668\u6570\u91cf\uff08n_estimators\uff09\u4ece 100 \u6539\u6210\u4e86 200\u3002 \u8be5\u6a21\u578b\u7684 5 \u6298\u4ea4\u53c9\u68c0\u9a8c\u5f97\u5206\u5982\u4e0b\uff1a \u276f python lbl_xgb . py Fold = 0 , AUC = 0.7656768851999011 Fold = 1 , AUC = 0.7633006564148015 Fold = 2 , AUC = 0.7654277821434345 Fold = 3 , AUC = 0.7663609758878182 Fold = 4 , AUC = 0.764914671468069 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u4e0d\u505a\u4efb\u4f55\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u5f97\u5206\u6bd4\u666e\u901a\u968f\u673a\u68ee\u6797\u8981\u9ad8\u5f97\u591a\u3002 \u60a8\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4e00\u4e9b\u7279\u5f81\u5de5\u7a0b\uff0c\u653e\u5f03\u67d0\u4e9b\u5bf9\u6a21\u578b\u6ca1\u6709\u4efb\u4f55\u4ef7\u503c\u7684\u5217\u7b49\u3002\u4f46\u4f3c\u4e4e\u6211\u4eec\u80fd\u505a\u7684\u4e0d\u591a\uff0c\u65e0\u6cd5\u8bc1\u660e\u6a21\u578b\u7684\u6539\u8fdb\u3002\u8ba9\u6211\u4eec\u628a\u6570\u636e\u96c6\u6362\u6210\u53e6\u4e00\u4e2a\u6709\u5927\u91cf\u5206\u7c7b\u53d8\u91cf\u7684\u6570\u636e\u96c6\u3002\u53e6\u4e00\u4e2a\u6709\u540d\u7684\u6570\u636e\u96c6\u662f \u7f8e\u56fd\u6210\u4eba\u4eba\u53e3\u666e\u67e5\u6570\u636e\uff08US adult census data\uff09 \u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u5305\u542b\u4e00\u4e9b\u7279\u5f81\uff0c\u800c\u4f60\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u5de5\u8d44\u7b49\u7ea7\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u3002\u56fe 5 \u663e\u793a\u4e86\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u4e00\u4e9b\u5217\u3002 \u56fe 5\uff1a\u90e8\u5206\u6570\u636e\u96c6\u5c55\u793a \u8be5\u6570\u636e\u96c6\u6709\u4ee5\u4e0b\u51e0\u5217\uff1a - \u5e74\u9f84\uff08age\uff09 \u5de5\u4f5c\u7c7b\u522b\uff08workclass\uff09 \u5b66\u5386\uff08fnlwgt\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education.num\uff09 \u5a5a\u59fb\u72b6\u51b5\uff08marital.status\uff09 \u804c\u4e1a\uff08occupation\uff09 \u5173\u7cfb\uff08relationship\uff09 \u79cd\u65cf\uff08race\uff09 \u6027\u522b\uff08sex\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 \u539f\u7c4d\u56fd\uff08native.country\uff09 \u6536\u5165\uff08income\uff09 \u8fd9\u4e9b\u7279\u5f81\u5927\u591a\u4e0d\u8a00\u81ea\u660e\u3002\u90a3\u4e9b\u4e0d\u660e\u767d\u7684\uff0c\u6211\u4eec\u53ef\u4ee5\u4e0d\u8003\u8651\u3002\u8ba9\u6211\u4eec\u5148\u5c1d\u8bd5\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u3002 \u6211\u4eec\u770b\u5230\u6536\u5165\u5217\u662f\u4e00\u4e2a\u5b57\u7b26\u4e32\u3002\u8ba9\u6211\u4eec\u5bf9\u8fd9\u4e00\u5217\u8fdb\u884c\u6570\u503c\u7edf\u8ba1\u3002 In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/adult.csv\" ) In [ X ]: df . income . value_counts () Out [ X ]: <= 50 K 24720 > 50 K 7841 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 7841 \u4e2a\u5b9e\u4f8b\u7684\u6536\u5165\u8d85\u8fc7 5 \u4e07\u7f8e\u5143\u3002\u8fd9\u5360\u6837\u672c\u603b\u6570\u7684 24%\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4fdd\u6301\u4e0e\u732b\u6570\u636e\u96c6\u76f8\u540c\u7684\u8bc4\u4f30\u65b9\u6cd5\uff0c\u5373 AUC\u3002 \u5728\u5f00\u59cb\u5efa\u6a21\u4e4b\u524d\uff0c\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u5c06\u53bb\u6389\u51e0\u5217\u7279\u5f81\uff0c\u5373 \u5b66\u5386\uff08fnlwgt\uff09 \u5e74\u9f84\uff08age\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 \u8ba9\u6211\u4eec\u8bd5\u7740\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u5668\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002\u7b2c\u4e00\u6b65\u603b\u662f\u8981\u8fdb\u884c\u4ea4\u53c9\u9a8c\u8bc1\u3002\u6211\u4e0d\u4f1a\u5728\u8fd9\u91cc\u5c55\u793a\u8fd9\u90e8\u5206\u4ee3\u7801\u3002\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) # \u9700\u8981\u5220\u9664\u7684\u5217 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) # \u6620\u5c04 target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } # \u4f7f\u7528\u6620\u5c04\u66ff\u6362 df . loc [:, \"income\" ] = df . income . map ( target_mapping ) # \u53d6\u9664\"kfold\", \"income\"\u5217\u7684\u5176\u4ed6\u5217\u540d features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u9a8c\u8bc1\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u6784\u5efa\u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . income . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u65f6\uff0c\u6211\u4eec\u4f1a\u5f97\u5230 \u276f python - W ignore ohe_logres . py Fold = 0 , AUC = 0.8794809708119079 Fold = 1 , AUC = 0.8875785068274882 Fold = 2 , AUC = 0.8852609687685753 Fold = 3 , AUC = 0.8681236223251438 Fold = 4 , AUC = 0.8728581541840037 \u5bf9\u4e8e\u4e00\u4e2a\u5982\u6b64\u7b80\u5355\u7684\u6a21\u578b\u6765\u8bf4\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u4e0d\u9519\u7684 AUC\uff01 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5728\u4e0d\u8c03\u6574\u4efb\u4f55\u8d85\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c1d\u8bd5\u4e00\u4e0b\u6807\u7b7e\u7f16\u7801\u7684xgboost\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8ba9\u6211\u4eec\u8fd0\u884c\u4e0a\u9762\u4ee3\u7801\uff1a \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8800810634234078 Fold = 1 , AUC = 0.886811884948154 Fold = 2 , AUC = 0.8854421433318472 Fold = 3 , AUC = 0.8676319549361007 Fold = 4 , AUC = 0.8714450054900602 \u8fd9\u770b\u8d77\u6765\u5df2\u7ecf\u76f8\u5f53\u4e0d\u9519\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b max_depth \u589e\u52a0\u5230 7 \u548c n_estimators \u589e\u52a0\u5230 200 \u65f6\u7684\u5f97\u5206\u3002 \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8764108944332032 Fold = 1 , AUC = 0.8840708537662638 Fold = 2 , AUC = 0.8816601162613102 Fold = 3 , AUC = 0.8662335762581732 Fold = 4 , AUC = 0.8698983461709926 \u770b\u8d77\u6765\u5e76\u6ca1\u6709\u6539\u5584\u3002 \u8fd9\u8868\u660e\uff0c\u4e00\u4e2a\u6570\u636e\u96c6\u7684\u53c2\u6570\u4e0d\u80fd\u79fb\u690d\u5230\u53e6\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6211\u4eec\u5fc5\u987b\u518d\u6b21\u5c1d\u8bd5\u8c03\u6574\u53c2\u6570\uff0c\u4f46\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u8bf4\u660e\u3002 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5728\u4e0d\u8c03\u6574\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c06\u6570\u503c\u7279\u5f81\u7eb3\u5165 xgboost \u6a21\u578b\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) # \u52a0\u5165\u6570\u503c\u7279\u5f81 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : if col not in num_cols : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u56e0\u6b64\uff0c\u6211\u4eec\u4fdd\u7559\u6570\u5b57\u5217\uff0c\u53ea\u662f\u4e0d\u5bf9\u5176\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u7684\u6700\u7ec8\u7279\u5f81\u77e9\u9635\u5c31\u7531\u6570\u5b57\u5217\uff08\u539f\u6837\uff09\u548c\u7f16\u7801\u5206\u7c7b\u5217\u7ec4\u6210\u4e86\u3002\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\u90fd\u80fd\u8f7b\u677e\u5904\u7406\u8fd9\u79cd\u6df7\u5408\u3002 \u8bf7\u6ce8\u610f\uff0c\u5728\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65f6\uff0c\u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u6570\u636e\u8fdb\u884c\u5f52\u4e00\u5316\u5904\u7406\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff0c\u5728\u4f7f\u7528\u7ebf\u6027\u6a21\u578b\uff08\u5982\u903b\u8f91\u56de\u5f52\uff09\u65f6\u4e0d\u5bb9\u5ffd\u89c6\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u4e2a\u811a\u672c\uff01 \u276f python lbl_xgb_num . py Fold = 0 , AUC = 0.9209790185449889 Fold = 1 , AUC = 0.9247157449144706 Fold = 2 , AUC = 0.9269329887598243 Fold = 3 , AUC = 0.9119349082169275 Fold = 4 , AUC = 0.9166408030141667 \u54c7\u54e6 \u8fd9\u662f\u4e00\u4e2a\u5f88\u597d\u7684\u5206\u6570\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c1d\u8bd5\u6dfb\u52a0\u4e00\u4e9b\u529f\u80fd\u3002\u6211\u4eec\u5c06\u63d0\u53d6\u6240\u6709\u5206\u7c7b\u5217\uff0c\u5e76\u521b\u5efa\u6240\u6709\u4e8c\u5ea6\u7ec4\u5408\u3002\u8bf7\u770b\u4e0b\u9762\u4ee3\u7801\u6bb5\u4e2d\u7684 feature_engineering \u51fd\u6570\uff0c\u4e86\u89e3\u5982\u4f55\u5b9e\u73b0\u8fd9\u4e00\u70b9\u3002 import itertools import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def feature_engineering ( df , cat_cols ): # \u751f\u6210\u4e24\u4e2a\u7279\u5f81\u7684\u7ec4\u5408 combi = list ( itertools . combinations ( cat_cols , 2 )) for c1 , c2 in combi : df . loc [:, c1 + \"_\" + c2 ] = df [ c1 ] . astype ( str ) + \"_\" + df [ c2 ] . astype ( str ) return df def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) cat_cols = [ c for c in df . columns if c not in num_cols and c not in ( \"kfold\" , \"income\" )] # \u7279\u5f81\u5de5\u7a0b df = feature_engineering ( df , cat_cols ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" )] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u8fd9\u662f\u4ece\u5206\u7c7b\u5217\u4e2d\u521b\u5efa\u7279\u5f81\u7684\u4e00\u79cd\u975e\u5e38\u5e7c\u7a1a\u7684\u65b9\u6cd5\u3002\u6211\u4eec\u5e94\u8be5\u4ed4\u7ec6\u7814\u7a76\u6570\u636e\uff0c\u770b\u770b\u54ea\u4e9b\u7ec4\u5408\u6700\u5408\u7406\u3002\u5982\u679c\u4f7f\u7528\u8fd9\u79cd\u65b9\u6cd5\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5c31\u9700\u8981\u4f7f\u7528\u67d0\u79cd\u7279\u5f81\u9009\u62e9\u6765\u9009\u51fa\u6700\u4f73\u7279\u5f81\u3002\u7a0d\u540e\u6211\u4eec\u5c06\u8be6\u7ec6\u4ecb\u7ecd\u7279\u5f81\u9009\u62e9\u3002\u73b0\u5728\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5206\u6570\u3002 \u276f python lbl_xgb_num_feat . py Fold = 0 , AUC = 0.9211483465031423 Fold = 1 , AUC = 0.9251499446866125 Fold = 2 , AUC = 0.9262344766486692 Fold = 3 , AUC = 0.9114264068794995 Fold = 4 , AUC = 0.9177914453099201 \u770b\u6765\uff0c\u5373\u4f7f\u4e0d\u6539\u53d8\u4efb\u4f55\u8d85\u53c2\u6570\uff0c\u53ea\u589e\u52a0\u4e00\u4e9b\u7279\u5f81\uff0c\u6211\u4eec\u4e5f\u80fd\u63d0\u9ad8\u4e00\u4e9b\u6298\u53e0\u5f97\u5206\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5c06 max_depth \u589e\u52a0\u5230 7 \u662f\u5426\u6709\u5e2e\u52a9\u3002 \u276f python lbl_xgb_num_feat . py Fold = 0 , AUC = 0.9286668430204137 Fold = 1 , AUC = 0.9329340656165378 Fold = 2 , AUC = 0.9319817543218744 Fold = 3 , AUC = 0.919046187194538 Fold = 4 , AUC = 0.9245692057162671 \u6211\u4eec\u518d\u6b21\u6539\u8fdb\u4e86\u6211\u4eec\u7684\u6a21\u578b\u3002 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u8fd8\u6ca1\u6709\u4f7f\u7528\u7a00\u6709\u503c\u3001\u4e8c\u503c\u5316\u3001\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u7279\u5f81\u7684\u7ec4\u5408\u4ee5\u53ca\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u3002 \u4ece\u5206\u7c7b\u7279\u5f81\u4e2d\u8fdb\u884c\u7279\u5f81\u5de5\u7a0b\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u4f7f\u7528 \u76ee\u6807\u7f16\u7801 \u3002\u4f46\u662f\uff0c\u60a8\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u8fd9\u53ef\u80fd\u4f1a\u4f7f\u60a8\u7684\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u76ee\u6807\u7f16\u7801\u662f\u4e00\u79cd\u5c06\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u6620\u5c04\u5230\u5176\u5e73\u5747\u76ee\u6807\u503c\u7684\u6280\u672f\uff0c\u4f46\u5fc5\u987b\u59cb\u7ec8\u4ee5\u4ea4\u53c9\u9a8c\u8bc1\u7684\u65b9\u5f0f\u8fdb\u884c\u3002\u8fd9\u610f\u5473\u7740\u9996\u5148\u8981\u521b\u5efa\u6298\u53e0\uff0c\u7136\u540e\u4f7f\u7528\u8fd9\u4e9b\u6298\u53e0\u4e3a\u6570\u636e\u7684\u4e0d\u540c\u5217\u521b\u5efa\u76ee\u6807\u7f16\u7801\u7279\u5f81\uff0c\u65b9\u6cd5\u4e0e\u5728\u6298\u53e0\u4e0a\u62df\u5408\u548c\u9884\u6d4b\u6a21\u578b\u7684\u65b9\u6cd5\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u521b\u5efa\u4e86 5 \u4e2a\u6298\u53e0\uff0c\u60a8\u5c31\u5fc5\u987b\u521b\u5efa 5 \u6b21\u76ee\u6807\u7f16\u7801\uff0c\u8fd9\u6837\u6700\u7ec8\uff0c\u60a8\u5c31\u53ef\u4ee5\u4e3a\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u53d8\u91cf\u521b\u5efa\u7f16\u7801\uff0c\u800c\u8fd9\u4e9b\u53d8\u91cf\u5e76\u975e\u6765\u81ea\u540c\u4e00\u4e2a\u6298\u53e0\u3002\u7136\u540e\u5728\u62df\u5408\u6a21\u578b\u65f6\uff0c\u5fc5\u987b\u518d\u6b21\u4f7f\u7528\u76f8\u540c\u7684\u6298\u53e0\u3002\u672a\u89c1\u6d4b\u8bd5\u6570\u636e\u7684\u76ee\u6807\u7f16\u7801\u53ef\u4ee5\u6765\u81ea\u5168\u90e8\u8bad\u7ec3\u6570\u636e\uff0c\u4e5f\u53ef\u4ee5\u662f\u6240\u6709 5 \u4e2a\u6298\u53e0\u7684\u5e73\u5747\u503c\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u5728\u540c\u4e00\u4e2a\u6210\u4eba\u6570\u636e\u96c6\u4e0a\u4f7f\u7528\u76ee\u6807\u7f16\u7801\uff0c\u4ee5\u4fbf\u8fdb\u884c\u6bd4\u8f83\u3002 import copy import pandas as pd from sklearn import metrics from sklearn import preprocessing import xgboost as xgb def mean_target_encoding ( data ): df = copy . deepcopy ( data ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) and f not in num_cols ] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) encoded_dfs = [] for fold in range ( 5 ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) for column in features : # \u76ee\u6807\u7f16\u7801 mapping_dict = dict ( df_train . groupby ( column )[ \"income\" ] . mean () ) df_valid . loc [:, column + \"_enc\" ] = df_valid [ column ] . map ( mapping_dict ) encoded_dfs . append ( df_valid ) encoded_df = pd . concat ( encoded_dfs , axis = 0 ) return encoded_df def run ( df , fold ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/adult_folds.csv\" ) df = mean_target_encoding ( df ) for fold_ in range ( 5 ): run ( df , fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u4e0a\u8ff0\u7247\u6bb5\u4e2d\uff0c\u6211\u5728\u8fdb\u884c\u76ee\u6807\u7f16\u7801\u65f6\u5e76\u6ca1\u6709\u5220\u9664\u5206\u7c7b\u5217\u3002\u6211\u4fdd\u7559\u4e86\u6240\u6709\u7279\u5f81\uff0c\u5e76\u5728\u6b64\u57fa\u7840\u4e0a\u6dfb\u52a0\u4e86\u76ee\u6807\u7f16\u7801\u7279\u5f81\u3002\u6b64\u5916\uff0c\u6211\u8fd8\u4f7f\u7528\u4e86\u5e73\u5747\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5e73\u5747\u503c\u3001\u4e2d\u4f4d\u6570\u3001\u6807\u51c6\u504f\u5dee\u6216\u76ee\u6807\u7684\u4efb\u4f55\u5176\u4ed6\u51fd\u6570\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u7ed3\u679c\u3002 Fold = 0 , AUC = 0.9332240662017529 Fold = 1 , AUC = 0.9363551625140347 Fold = 2 , AUC = 0.9375013544556173 Fold = 3 , AUC = 0.92237621307625 Fold = 4 , AUC = 0.9292131180445478 \u4e0d\u9519\uff01\u770b\u6765\u6211\u4eec\u53c8\u6709\u8fdb\u6b65\u4e86\u3002\u4e0d\u8fc7\uff0c\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u5b83\u592a\u5bb9\u6613\u51fa\u73b0\u8fc7\u5ea6\u62df\u5408\u3002\u5f53\u6211\u4eec\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\uff0c\u6700\u597d\u4f7f\u7528\u67d0\u79cd\u5e73\u6ed1\u65b9\u6cd5\u6216\u5728\u7f16\u7801\u503c\u4e2d\u6dfb\u52a0\u566a\u58f0\u3002 Scikit-learn \u7684\u8d21\u732e\u5e93\u4e2d\u6709\u5e26\u5e73\u6ed1\u7684\u76ee\u6807\u7f16\u7801\uff0c\u4f60\u4e5f\u53ef\u4ee5\u521b\u5efa\u81ea\u5df1\u7684\u5e73\u6ed1\u3002\u5e73\u6ed1\u4f1a\u5f15\u5165\u67d0\u79cd\u6b63\u5219\u5316\uff0c\u6709\u52a9\u4e8e\u907f\u514d\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u8fd9\u5e76\u4e0d\u96be\u3002 \u5904\u7406\u5206\u7c7b\u7279\u5f81\u662f\u4e00\u9879\u590d\u6742\u7684\u4efb\u52a1\u3002\u8bb8\u591a\u8d44\u6e90\u4e2d\u90fd\u6709\u5927\u91cf\u4fe1\u606f\u3002\u672c\u7ae0\u5e94\u8be5\u80fd\u5e2e\u52a9\u4f60\u5f00\u59cb\u89e3\u51b3\u5206\u7c7b\u53d8\u91cf\u7684\u4efb\u4f55\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u5bf9\u4e8e\u5927\u591a\u6570\u95ee\u9898\u6765\u8bf4\uff0c\u9664\u4e86\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u4e4b\u5916\uff0c\u4f60\u4e0d\u9700\u8981\u66f4\u591a\u7684\u4e1c\u897f\u3002 \u8981\u8fdb\u4e00\u6b65\u6539\u8fdb\u6a21\u578b\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u66f4\u591a\uff01 \u5728\u672c\u7ae0\u7684\u6700\u540e\uff0c\u6211\u4eec\u4e0d\u80fd\u4e0d\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e00\u79cd\u79f0\u4e3a \u5b9e\u4f53\u5d4c\u5165 \u7684\u6280\u672f\u3002\u5728\u5b9e\u4f53\u5d4c\u5165\u4e2d\uff0c\u7c7b\u522b\u7528\u5411\u91cf\u8868\u793a\u3002\u5728\u4e8c\u503c\u5316\u548c\u72ec\u70ed\u7f16\u7801\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u90fd\u662f\u7528\u5411\u91cf\u6765\u8868\u793a\u7c7b\u522b\u7684\u3002 \u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u6709\u6570\u4ee5\u4e07\u8ba1\u7684\u7c7b\u522b\u600e\u4e48\u529e\uff1f\u8fd9\u5c06\u4f1a\u4ea7\u751f\u5de8\u5927\u7684\u77e9\u9635\uff0c\u6211\u4eec\u5c06\u9700\u8981\u5f88\u957f\u65f6\u95f4\u6765\u8bad\u7ec3\u590d\u6742\u7684\u6a21\u578b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u7528\u5e26\u6709\u6d6e\u70b9\u503c\u7684\u5411\u91cf\u6765\u8868\u793a\u5b83\u4eec\u3002 \u8fd9\u4e2a\u60f3\u6cd5\u975e\u5e38\u7b80\u5355\u3002\u6bcf\u4e2a\u5206\u7c7b\u7279\u5f81\u90fd\u6709\u4e00\u4e2a\u5d4c\u5165\u5c42\u3002\u56e0\u6b64\uff0c\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u73b0\u5728\u90fd\u53ef\u4ee5\u6620\u5c04\u5230\u4e00\u4e2a\u5d4c\u5165\u5c42\uff08\u5c31\u50cf\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u5c06\u5355\u8bcd\u6620\u5c04\u5230\u5d4c\u5165\u5c42\u4e00\u6837\uff09\u3002\u7136\u540e\uff0c\u6839\u636e\u5176\u7ef4\u5ea6\u91cd\u5851\u8fd9\u4e9b\u5d4c\u5165\u5c42\uff0c\u4f7f\u5176\u6241\u5e73\u5316\uff0c\u7136\u540e\u5c06\u6240\u6709\u6241\u5e73\u5316\u7684\u8f93\u5165\u5d4c\u5165\u5c42\u8fde\u63a5\u8d77\u6765\u3002\u7136\u540e\u6dfb\u52a0\u4e00\u5806\u5bc6\u96c6\u5c42\u548c\u4e00\u4e2a\u8f93\u51fa\u5c42\uff0c\u5c31\u5927\u529f\u544a\u6210\u4e86\u3002 \u56fe 6\uff1a\u7c7b\u522b\u8f6c\u6362\u4e3a\u6d6e\u70b9\u6216\u5d4c\u5165\u5411\u91cf \u51fa\u4e8e\u67d0\u79cd\u539f\u56e0\uff0c\u6211\u53d1\u73b0\u4f7f\u7528 TF/Keras \u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528 TF/Keras \u5b9e\u73b0\u5b83\u3002\u6b64\u5916\uff0c\u8fd9\u662f\u672c\u4e66\u4e2d\u552f\u4e00\u4e00\u4e2a\u4f7f\u7528 TF/Keras \u7684\u793a\u4f8b\uff0c\u5c06\u5176\u8f6c\u6362\u4e3a PyTorch\uff08\u4f7f\u7528 cat-in-the-dat-ii \u6570\u636e\u96c6\uff09\u4e5f\u975e\u5e38\u5bb9\u6613 import os import gc import joblib import pandas as pd import numpy as np from sklearn import metrics , preprocessing from tensorflow.keras import layers from tensorflow.keras import optimizers from tensorflow.keras.models import Model , load_model from tensorflow.keras import callbacks from tensorflow.keras import backend as K from tensorflow.keras import utils def create_model ( data , catcols ): # \u521b\u5efa\u7a7a\u7684\u8f93\u5165\u5217\u8868\u548c\u8f93\u51fa\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6a21\u578b\u7684\u8f93\u5165\u548c\u8f93\u51fa inputs = [] outputs = [] # \u904d\u5386\u5206\u7c7b\u7279\u5f81\u5217\u8868\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81 for c in catcols : # \u8ba1\u7b97\u7279\u5f81\u4e2d\u552f\u4e00\u503c\u7684\u6570\u91cf num_unique_values = int ( data [ c ] . nunique ()) # \u8ba1\u7b97\u5d4c\u5165\u7ef4\u5ea6\uff0c\u6700\u5927\u4e0d\u8d85\u8fc750 embed_dim = int ( min ( np . ceil (( num_unique_values ) / 2 ), 50 )) # \u521b\u5efa\u6a21\u578b\u7684\u8f93\u5165\u5c42\uff0c\u6bcf\u4e2a\u7279\u5f81\u5bf9\u5e94\u4e00\u4e2a\u8f93\u5165 inp = layers . Input ( shape = ( 1 ,)) # \u521b\u5efa\u5d4c\u5165\u5c42\uff0c\u5c06\u5206\u7c7b\u7279\u5f81\u6620\u5c04\u5230\u4f4e\u7ef4\u5ea6\u7684\u8fde\u7eed\u5411\u91cf out = layers . Embedding ( num_unique_values + 1 , embed_dim , name = c )( inp ) # \u5bf9\u5d4c\u5165\u5c42\u8fdb\u884c\u7a7a\u95f4\u4e22\u5f03\uff08Dropout\uff09 out = layers . SpatialDropout1D ( 0.3 )( out ) # \u5c06\u5d4c\u5165\u5c42\u7684\u5f62\u72b6\u91cd\u65b0\u8c03\u6574\u4e3a\u4e00\u7ef4 out = layers . Reshape ( target_shape = ( embed_dim ,))( out ) # \u5c06\u8f93\u5165\u548c\u8f93\u51fa\u6dfb\u52a0\u5230\u5bf9\u5e94\u7684\u5217\u8868\u4e2d inputs . append ( inp ) outputs . append ( out ) # \u4f7f\u7528Concatenate\u5c42\u5c06\u6240\u6709\u7684\u5d4c\u5165\u5c42\u8f93\u51fa\u8fde\u63a5\u5728\u4e00\u8d77 x = layers . Concatenate ()( outputs ) # \u5bf9\u8fde\u63a5\u540e\u7684\u6570\u636e\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u6dfb\u52a0\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u6dfb\u52a0\u53e6\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u8f93\u51fa\u5c42\uff0c\u5177\u67092\u4e2a\u795e\u7ecf\u5143\uff08\u7528\u4e8e\u4e8c\u8fdb\u5236\u5206\u7c7b\uff09\uff0c\u5e76\u4f7f\u7528softmax\u6fc0\u6d3b\u51fd\u6570 y = layers . Dense ( 2 , activation = \"softmax\" )( x ) # \u521b\u5efa\u6a21\u578b\uff0c\u5c06\u8f93\u5165\u548c\u8f93\u51fa\u4f20\u9012\u7ed9Model\u6784\u9020\u51fd\u6570 model = Model ( inputs = inputs , outputs = y ) # \u7f16\u8bd1\u6a21\u578b\uff0c\u6307\u5b9a\u635f\u5931\u51fd\u6570\u548c\u4f18\u5316\u5668 model . compile ( loss = 'binary_crossentropy' , optimizer = 'adam' ) # \u8fd4\u56de\u521b\u5efa\u7684\u6a21\u578b return model def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for feat in features : lbl_enc = preprocessing . LabelEncoder () df . loc [:, feat ] = lbl_enc . fit_transform ( df [ feat ] . values ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) model = create_model ( df , features ) xtrain = [ df_train [ features ] . values [:, k ] for k in range ( len ( features ))] xvalid = [ df_valid [ features ] . values [:, k ] for k in range ( len ( features )) ] ytrain = df_train . target . values yvalid = df_valid . target . values ytrain_cat = utils . to_categorical ( ytrain ) yvalid_cat = utils . to_categorical ( yvalid ) model . fit ( xtrain , ytrain_cat , validation_data = ( xvalid , yvalid_cat ), verbose = 1 , batch_size = 1024 , epochs = 3 ) valid_preds = model . predict ( xvalid )[:, 1 ] print ( metrics . roc_auc_score ( yvalid , valid_preds )) K . clear_session () if __name__ == \"__main__\" : run ( 0 ) run ( 1 ) run ( 2 ) run ( 3 ) run ( 4 ) \u4f60\u4f1a\u53d1\u73b0\u8fd9\u79cd\u65b9\u6cd5\u6548\u679c\u6700\u597d\uff0c\u800c\u4e14\u5982\u679c\u4f60\u6709 GPU\uff0c\u901f\u5ea6\u4e5f\u8d85\u5feb\uff01\u8fd9\u79cd\u65b9\u6cd5\u8fd8\u53ef\u4ee5\u8fdb\u4e00\u6b65\u6539\u8fdb\uff0c\u800c\u4e14\u4f60\u65e0\u9700\u62c5\u5fc3\u7279\u5f81\u5de5\u7a0b\uff0c\u56e0\u4e3a\u795e\u7ecf\u7f51\u7edc\u4f1a\u81ea\u884c\u5904\u7406\u3002\u5728\u5904\u7406\u5927\u91cf\u5206\u7c7b\u7279\u5f81\u6570\u636e\u96c6\u65f6\uff0c\u8fd9\u7edd\u5bf9\u503c\u5f97\u4e00\u8bd5\u3002\u5f53\u5d4c\u5165\u5927\u5c0f\u4e0e\u552f\u4e00\u7c7b\u522b\u7684\u6570\u91cf\u76f8\u540c\u65f6\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u72ec\u70ed\u7f16\u7801\uff08one-hot-encoding\uff09\u3002 \u672c\u7ae0\u57fa\u672c\u4e0a\u90fd\u662f\u5173\u4e8e\u7279\u5f81\u5de5\u7a0b\u7684\u3002\u8ba9\u6211\u4eec\u5728\u4e0b\u4e00\u7ae0\u4e2d\u770b\u770b\u5982\u4f55\u5728\u6570\u5b57\u7279\u5f81\u548c\u4e0d\u540c\u7c7b\u578b\u7279\u5f81\u7684\u7ec4\u5408\u65b9\u9762\u8fdb\u884c\u66f4\u591a\u7684\u7279\u5f81\u5de5\u7a0b\u3002","title":"\u5904\u7406\u5206\u7c7b\u53d8\u91cf"},{"location":"%E5%A4%84%E7%90%86%E5%88%86%E7%B1%BB%E5%8F%98%E9%87%8F/#_1","text":"\u5f88\u591a\u4eba\u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u65f6\u90fd\u4f1a\u9047\u5230\u5f88\u591a\u56f0\u96be\uff0c\u56e0\u6b64\u8fd9\u503c\u5f97\u7528\u6574\u6574\u4e00\u7ae0\u7684\u7bc7\u5e45\u6765\u8ba8\u8bba\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u5c06\u8bb2\u8ff0\u4e0d\u540c\u7c7b\u578b\u7684\u5206\u7c7b\u6570\u636e\uff0c\u4ee5\u53ca\u5982\u4f55\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u3002 \u4ec0\u4e48\u662f\u5206\u7c7b\u53d8\u91cf\uff1f \u5206\u7c7b\u53d8\u91cf/\u7279\u5f81\u662f\u6307\u4efb\u4f55\u7279\u5f81\u7c7b\u578b\uff0c\u53ef\u5206\u4e3a\u4e24\u5927\u7c7b\uff1a - \u65e0\u5e8f - \u6709\u5e8f \u65e0\u5e8f\u53d8\u91cf \u662f\u6307\u6709\u4e24\u4e2a\u6216\u4e24\u4e2a\u4ee5\u4e0a\u7c7b\u522b\u7684\u53d8\u91cf\uff0c\u8fd9\u4e9b\u7c7b\u522b\u6ca1\u6709\u4efb\u4f55\u76f8\u5173\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u5982\u679c\u5c06\u6027\u522b\u5206\u4e3a\u4e24\u7ec4\uff0c\u5373\u7537\u6027\u548c\u5973\u6027\uff0c\u5219\u53ef\u5c06\u5176\u89c6\u4e3a\u540d\u4e49\u53d8\u91cf\u3002 \u6709\u5e8f\u53d8\u91cf \u5219\u6709 \"\u7b49\u7ea7 \"\u6216\u7c7b\u522b\uff0c\u5e76\u6709\u7279\u5b9a\u7684\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u4e00\u4e2a\u987a\u5e8f\u5206\u7c7b\u53d8\u91cf\u53ef\u4ee5\u662f\u4e00\u4e2a\u5177\u6709\u4f4e\u3001\u4e2d\u3001\u9ad8\u4e09\u4e2a\u4e0d\u540c\u7b49\u7ea7\u7684\u7279\u5f81\u3002\u987a\u5e8f\u5f88\u91cd\u8981\u3002 \u5c31\u5b9a\u4e49\u800c\u8a00\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06\u5206\u7c7b\u53d8\u91cf\u5206\u4e3a \u4e8c\u5143\u53d8\u91cf \uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u7c7b\u522b\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u6709\u4e9b\u4eba\u751a\u81f3\u628a\u5206\u7c7b\u53d8\u91cf\u79f0\u4e3a \" \u5faa\u73af \"\u53d8\u91cf\u3002\u5468\u671f\u53d8\u91cf\u4ee5 \"\u5468\u671f \"\u7684\u5f62\u5f0f\u5b58\u5728\uff0c\u4f8b\u5982\u4e00\u5468\u4e2d\u7684\u5929\u6570\uff1a \u5468\u65e5\u3001\u5468\u4e00\u3001\u5468\u4e8c\u3001\u5468\u4e09\u3001\u5468\u56db\u3001\u5468\u4e94\u548c\u5468\u516d\u3002\u5468\u516d\u8fc7\u540e\uff0c\u53c8\u662f\u5468\u65e5\u3002\u8fd9\u5c31\u662f\u4e00\u4e2a\u5faa\u73af\u3002\u53e6\u4e00\u4e2a\u4f8b\u5b50\u662f\u4e00\u5929\u4e2d\u7684\u5c0f\u65f6\u6570\uff0c\u5982\u679c\u6211\u4eec\u5c06\u5b83\u4eec\u89c6\u4e3a\u7c7b\u522b\u7684\u8bdd\u3002 \u5206\u7c7b\u53d8\u91cf\u6709\u5f88\u591a\u4e0d\u540c\u7684\u5b9a\u4e49\uff0c\u5f88\u591a\u4eba\u4e5f\u8c08\u5230\u8981\u6839\u636e\u5206\u7c7b\u53d8\u91cf\u7684\u7c7b\u578b\u6765\u5904\u7406\u4e0d\u540c\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u4e0d\u8fc7\uff0c\u6211\u8ba4\u4e3a\u6ca1\u6709\u5fc5\u8981\u8fd9\u6837\u505a\u3002\u6240\u6709\u6d89\u53ca\u5206\u7c7b\u53d8\u91cf\u7684\u95ee\u9898\u90fd\u53ef\u4ee5\u7528\u540c\u6837\u7684\u65b9\u6cd5\u5904\u7406\u3002 \u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e00\u4e2a\u6570\u636e\u96c6\uff08\u4e00\u5982\u65e2\u5f80\uff09\u3002\u8981\u4e86\u89e3\u5206\u7c7b\u53d8\u91cf\uff0c\u6700\u597d\u7684\u514d\u8d39\u6570\u636e\u96c6\u4e4b\u4e00\u662f Kaggle \u5206\u7c7b\u7279\u5f81\u7f16\u7801\u6311\u6218\u8d5b\u4e2d\u7684 cat-in-the-dat \u3002\u5171\u6709\u4e24\u4e2a\u6311\u6218\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6311\u6218\u7684\u6570\u636e\uff0c\u56e0\u4e3a\u5b83\u6bd4\u524d\u4e00\u4e2a\u7248\u672c\u6709\u66f4\u591a\u53d8\u91cf\uff0c\u96be\u5ea6\u4e5f\u66f4\u5927\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u3002 \u56fe 1\uff1aCat-in-the-dat-ii challenge\u90e8\u5206\u6570\u636e\u5c55\u793a \u6570\u636e\u96c6\u7531\u5404\u79cd\u5206\u7c7b\u53d8\u91cf\u7ec4\u6210\uff1a \u65e0\u5e8f \u6709\u5e8f \u5faa\u73af \u4e8c\u5143 \u5728\u56fe 1 \u4e2d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6240\u6709\u5b58\u5728\u7684\u53d8\u91cf\u548c\u76ee\u6807\u53d8\u91cf\u7684\u5b50\u96c6\u3002 \u8fd9\u662f\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002 \u76ee\u6807\u53d8\u91cf\u5bf9\u4e8e\u6211\u4eec\u5b66\u4e60\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\u5e76\u4e0d\u5341\u5206\u91cd\u8981\uff0c\u4f46\u6700\u7ec8\u6211\u4eec\u5c06\u5efa\u7acb\u4e00\u4e2a\u7aef\u5230\u7aef\u6a21\u578b\uff0c\u56e0\u6b64\u8ba9\u6211\u4eec\u770b\u770b\u56fe 2 \u4e2d\u7684\u76ee\u6807\u53d8\u91cf\u5206\u5e03\u3002\u6211\u4eec\u770b\u5230\u76ee\u6807\u662f \u504f\u659c \u7684\uff0c\u56e0\u6b64\u5bf9\u4e8e\u8fd9\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u6765\u8bf4\uff0c\u6700\u597d\u7684\u6307\u6807\u662f ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528\u7cbe\u786e\u5ea6\u548c\u53ec\u56de\u7387\uff0c\u4f46 AUC \u7ed3\u5408\u4e86\u8fd9\u4e24\u4e2a\u6307\u6807\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 AUC \u6765\u8bc4\u4f30\u6211\u4eec\u5728\u8be5\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u7684\u6a21\u578b\u3002 \u56fe 2\uff1a\u6807\u7b7e\u8ba1\u6570\u3002X \u8f74\u8868\u793a\u6807\u7b7e\uff0cY \u8f74\u8868\u793a\u6807\u7b7e\u8ba1\u6570 \u603b\u4f53\u800c\u8a00\uff0c\u6709\uff1a 5\u4e2a\u4e8c\u5143\u53d8\u91cf 10\u4e2a\u65e0\u5e8f\u53d8\u91cf 6\u4e2a\u6709\u5e8f\u53d8\u91cf 2\u4e2a\u5faa\u73af\u53d8\u91cf 1\u4e2a\u76ee\u6807\u53d8\u91cf \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u96c6\u4e2d\u7684 ord_2 \u7279\u5f81\u3002\u5b83\u5305\u62ec6\u4e2a\u4e0d\u540c\u7684\u7c7b\u522b\uff1a - \u51b0\u51bb - \u6e29\u6696 - \u5bd2\u51b7 - \u8f83\u70ed - \u70ed - \u975e\u5e38\u70ed \u6211\u4eec\u5fc5\u987b\u77e5\u9053\uff0c\u8ba1\u7b97\u673a\u65e0\u6cd5\u7406\u89e3\u6587\u672c\u6570\u636e\uff0c\u56e0\u6b64\u6211\u4eec\u9700\u8981\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002\u4e00\u4e2a\u7b80\u5355\u7684\u65b9\u6cd5\u662f\u521b\u5efa\u4e00\u4e2a\u5b57\u5178\uff0c\u5c06\u8fd9\u4e9b\u503c\u6620\u5c04\u4e3a\u4ece 0 \u5230 N-1 \u7684\u6570\u5b57\uff0c\u5176\u4e2d N \u662f\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7c7b\u522b\u7684\u603b\u6570\u3002 # \u6620\u5c04\u5b57\u5178 mapping = { \"Freezing\" : 0 , \"Warm\" : 1 , \"Cold\" : 2 , \"Boiling Hot\" : 3 , \"Hot\" : 4 , \"Lava Hot\" : 5 } \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8bfb\u53d6\u6570\u636e\u96c6\uff0c\u5e76\u8f7b\u677e\u5730\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002 import pandas as pd # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u53d6*ord_2*\u5217\uff0c\u5e76\u4f7f\u7528\u6620\u5c04\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57 df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. map ( mapping ) \u6620\u5c04\u524d\u7684\u6570\u503c\u8ba1\u6570\uff1a df .* ord_2 *. value_counts () Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : * ord_2 * , dtype : int64 \u6620\u5c04\u540e\u7684\u6570\u503c\u8ba1\u6570\uff1a 0.0 142726 1.0 124239 2.0 97822 3.0 84790 4.0 67508 5.0 64840 Name : * ord_2 * , dtype : int64 \u8fd9\u79cd\u5206\u7c7b\u53d8\u91cf\u7684\u7f16\u7801\u65b9\u5f0f\u88ab\u79f0\u4e3a\u6807\u7b7e\u7f16\u7801\uff08Label Encoding\uff09\u6211\u4eec\u5c06\u6bcf\u4e2a\u7c7b\u522b\u7f16\u7801\u4e3a\u4e00\u4e2a\u6570\u5b57\u6807\u7b7e\u3002 \u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 LabelEncoder \u8fdb\u884c\u7f16\u7801\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u5c06\u7f3a\u5931\u503c\u586b\u5145\u4e3a\"NONE\" df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. fillna ( \"NONE\" ) # LabelEncoder\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u8f6c\u6362\u6570\u636e df . loc [:, \"*ord_2*\" ] = lbl_enc . fit_transform ( df .* ord_2 *. values ) \u4f60\u4f1a\u770b\u5230\u6211\u4f7f\u7528\u4e86 pandas \u7684 fillna\u3002\u539f\u56e0\u662f scikit-learn \u7684 LabelEncoder \u65e0\u6cd5\u5904\u7406 NaN \u503c\uff0c\u800c ord_2 \u5217\u4e2d\u6709 NaN \u503c\u3002 \u6211\u4eec\u53ef\u4ee5\u5728\u8bb8\u591a\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u4e2d\u76f4\u63a5\u4f7f\u7528\u5b83\uff1a - \u51b3\u7b56\u6811 - \u968f\u673a\u68ee\u6797 - \u63d0\u5347\u6811 - \u6216\u4efb\u4f55\u4e00\u79cd\u63d0\u5347\u6811\u6a21\u578b - XGBoost - GBM - LightGBM \u8fd9\u79cd\u7f16\u7801\u65b9\u5f0f\u4e0d\u80fd\u7528\u4e8e\u7ebf\u6027\u6a21\u578b\u3001\u652f\u6301\u5411\u91cf\u673a\u6216\u795e\u7ecf\u7f51\u7edc\uff0c\u56e0\u4e3a\u5b83\u4eec\u5e0c\u671b\u6570\u636e\u662f\u6807\u51c6\u5316\u7684\u3002 \u5bf9\u4e8e\u8fd9\u4e9b\u7c7b\u578b\u7684\u6a21\u578b\uff0c\u6211\u4eec\u53ef\u4ee5\u5bf9\u6570\u636e\u8fdb\u884c\u4e8c\u503c\u5316\uff08binarize\uff09\u5904\u7406\u3002 Freezing --> 0 --> 0 0 0 Warm --> 1 --> 0 0 1 Cold --> 2 --> 0 1 0 Boiling Hot --> 3 --> 0 1 1 Hot --> 4 --> 1 0 0 Lava Hot --> 5 --> 1 0 1 \u8fd9\u53ea\u662f\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\uff0c\u7136\u540e\u518d\u8f6c\u6362\u4e3a\u4e8c\u503c\u5316\u8868\u793a\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u628a\u4e00\u4e2a\u7279\u5f81\u5206\u6210\u4e86\u4e09\u4e2a\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u7279\u5f81\uff08\u6216\u5217\uff09\u3002\u5982\u679c\u6211\u4eec\u6709\u66f4\u591a\u7684\u7c7b\u522b\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u5206\u6210\u66f4\u591a\u7684\u5217\u3002 \u5982\u679c\u6211\u4eec\u7528\u7a00\u758f\u683c\u5f0f\u5b58\u50a8\u5927\u91cf\u4e8c\u503c\u5316\u53d8\u91cf\uff0c\u5c31\u53ef\u4ee5\u8f7b\u677e\u5730\u5b58\u50a8\u8fd9\u4e9b\u53d8\u91cf\u3002\u7a00\u758f\u683c\u5f0f\u4e0d\u8fc7\u662f\u4e00\u79cd\u5728\u5185\u5b58\u4e2d\u5b58\u50a8\u6570\u636e\u7684\u8868\u793a\u6216\u65b9\u5f0f\uff0c\u5728\u8fd9\u79cd\u683c\u5f0f\u4e2d\uff0c\u4f60\u5e76\u4e0d\u5b58\u50a8\u6240\u6709\u7684\u503c\uff0c\u800c\u53ea\u5b58\u50a8\u91cd\u8981\u7684\u503c\u3002\u5728\u4e0a\u8ff0\u4e8c\u8fdb\u5236\u53d8\u91cf\u7684\u60c5\u51b5\u4e2d\uff0c\u6700\u91cd\u8981\u7684\u5c31\u662f\u6709 1 \u7684\u5730\u65b9\u3002 \u5f88\u96be\u60f3\u8c61\u8fd9\u6837\u7684\u683c\u5f0f\uff0c\u4f46\u4e3e\u4e2a\u4f8b\u5b50\u5c31\u4f1a\u660e\u767d\u3002 \u5047\u8bbe\u4e0a\u9762\u7684\u6570\u636e\u5e27\u4e2d\u53ea\u6709\u4e00\u4e2a\u7279\u5f81\uff1a ord_2 \u3002 Index Feature 0 Warm 1 Hot 2 Lava hot \u76ee\u524d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6570\u636e\u96c6\u4e2d\u7684\u4e09\u4e2a\u6837\u672c\u3002\u8ba9\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u4e8c\u503c\u8868\u793a\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u6709\u4e09\u4e2a\u9879\u76ee\u3002 \u8fd9\u4e09\u4e2a\u9879\u76ee\u5c31\u662f\u4e09\u4e2a\u7279\u5f81\u3002 Index Feature_0 Feature_1 Feature_2 0 0 0 1 1 1 0 0 2 1 0 1 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u7279\u5f81\u5b58\u50a8\u5728\u4e00\u4e2a\u6709 3 \u884c 3 \u5217\uff083x3\uff09\u7684\u77e9\u9635\u4e2d\u3002\u77e9\u9635\u7684\u6bcf\u4e2a\u5143\u7d20\u5360\u7528 8 \u4e2a\u5b57\u8282\u3002\u56e0\u6b64\uff0c\u8fd9\u4e2a\u6570\u7ec4\u7684\u603b\u5185\u5b58\u9700\u6c42\u4e3a 8x3x3 = 72 \u5b57\u8282\u3002 \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e2a\u7b80\u5355\u7684 python \u4ee3\u7801\u6bb5\u6765\u68c0\u67e5\u8fd9\u4e00\u70b9\u3002 import numpy as np example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) print ( example . nbytes ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u6253\u5370\u51fa 72\uff0c\u5c31\u50cf\u6211\u4eec\u4e4b\u524d\u8ba1\u7b97\u7684\u90a3\u6837\u3002\u4f46\u6211\u4eec\u9700\u8981\u5b58\u50a8\u8fd9\u4e2a\u77e9\u9635\u7684\u6240\u6709\u5143\u7d20\u5417\uff1f\u5982\u524d\u6240\u8ff0\uff0c\u6211\u4eec\u53ea\u5bf9 1 \u611f\u5174\u8da3\u30020 \u5e76\u4e0d\u91cd\u8981\uff0c\u56e0\u4e3a\u4efb\u4f55\u4e0e 0 \u76f8\u4e58\u7684\u5143\u7d20\u90fd\u662f 0\uff0c\u800c 0 \u4e0e\u4efb\u4f55\u5143\u7d20\u76f8\u52a0\u6216\u76f8\u51cf\u4e5f\u6ca1\u6709\u4efb\u4f55\u533a\u522b\u3002\u53ea\u7528 1 \u8868\u793a\u77e9\u9635\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u67d0\u79cd\u5b57\u5178\u65b9\u6cd5\uff0c\u5176\u4e2d\u952e\u662f\u884c\u548c\u5217\u7684\u7d22\u5f15\uff0c\u503c\u662f 1\uff1a ( 0 , 2 ) 1 ( 1 , 0 ) 1 ( 2 , 0 ) 1 ( 2 , 2 ) 1 \u8fd9\u6837\u7684\u7b26\u53f7\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u56e0\u4e3a\u5b83\u53ea\u9700\u5b58\u50a8\u56db\u4e2a\u503c\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u3002\u4f7f\u7528\u7684\u603b\u5185\u5b58\u4e3a 8x4 = 32 \u5b57\u8282\u3002\u4efb\u4f55 numpy \u6570\u7ec4\u90fd\u53ef\u4ee5\u901a\u8fc7\u7b80\u5355\u7684 python \u4ee3\u7801\u8f6c\u6362\u4e3a\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) sparse_example = sparse . csr_matrix ( example ) print ( sparse_example . data . nbytes ) \u8fd9\u5c06\u6253\u5370 32\uff0c\u6bd4\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u5c11\u4e86\u8fd9\u4e48\u591a\uff01\u7a00\u758f csr \u77e9\u9635\u7684\u603b\u5927\u5c0f\u662f\u4e09\u4e2a\u503c\u7684\u603b\u548c\u3002 print ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) \u8fd9\u5c06\u6253\u5370\u51fa 64 \u4e2a\u5143\u7d20\uff0c\u4ecd\u7136\u5c11\u4e8e\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u3002\u9057\u61be\u7684\u662f\uff0c\u6211\u4e0d\u4f1a\u8be6\u7ec6\u4ecb\u7ecd\u8fd9\u4e9b\u5143\u7d20\u3002\u4f60\u53ef\u4ee5\u5728 scipy \u6587\u6863\u4e2d\u4e86\u89e3\u66f4\u591a\u3002\u5f53\u6211\u4eec\u62e5\u6709\u66f4\u5927\u7684\u6570\u7ec4\u65f6\uff0c\u6bd4\u5982\u8bf4\u62e5\u6709\u6570\u5343\u4e2a\u6837\u672c\u548c\u6570\u4e07\u4e2a\u7279\u5f81\u7684\u6570\u7ec4\uff0c\u5927\u5c0f\u5dee\u5f02\u5c31\u4f1a\u53d8\u5f97\u975e\u5e38\u5927\u3002\u4f8b\u5982\uff0c\u6211\u4eec\u4f7f\u7528\u57fa\u4e8e\u8ba1\u6570\u7279\u5f81\u7684\u6587\u672c\u6570\u636e\u96c6\u3002 import numpy as np from scipy import sparse n_rows = 10000 n_cols = 100000 # \u751f\u6210\u7b26\u5408\u4f2f\u52aa\u5229\u5206\u5e03\u7684\u968f\u673a\u6570\u7ec4\uff0c\u7ef4\u5ea6\u4e3a[10000, 100000] example = np . random . binomial ( 1 , p = 0.05 , size = ( n_rows , n_cols )) print ( f \"Size of dense array: { example . nbytes } \" ) # \u5c06\u968f\u673a\u77e9\u9635\u8f6c\u6362\u4e3a\u6d17\u6f31\u77e9\u9635 sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u8fd9\u5c06\u6253\u5370\uff1a Size of dense array : 8000000000 Size of sparse array : 399932496 Full size of sparse array : 599938748 \u56e0\u6b64\uff0c\u5bc6\u96c6\u9635\u5217\u9700\u8981 ~8000MB \u6216\u5927\u7ea6 8GB \u5185\u5b58\u3002\u800c\u7a00\u758f\u9635\u5217\u53ea\u5360\u7528 399MB \u5185\u5b58\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u5f53\u6211\u4eec\u7684\u7279\u5f81\u4e2d\u6709\u5927\u91cf\u96f6\u65f6\uff0c\u6211\u4eec\u66f4\u559c\u6b22\u7a00\u758f\u9635\u5217\u800c\u4e0d\u662f\u5bc6\u96c6\u9635\u5217\u7684\u539f\u56e0\u3002 \u8bf7\u6ce8\u610f\uff0c\u7a00\u758f\u77e9\u9635\u6709\u591a\u79cd\u4e0d\u540c\u7684\u8868\u793a\u65b9\u6cd5\u3002\u8fd9\u91cc\u6211\u53ea\u5c55\u793a\u4e86\u5176\u4e2d\u4e00\u79cd\uff08\u53ef\u80fd\u4e5f\u662f\u6700\u5e38\u7528\u7684\uff09\u65b9\u6cd5\u3002\u6df1\u5165\u63a2\u8ba8\u8fd9\u4e9b\u65b9\u6cd5\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\uff0c\u56e0\u6b64\u7559\u7ed9\u8bfb\u8005\u4e00\u4e2a\u7ec3\u4e60\u3002 \u5c3d\u7ba1\u4e8c\u503c\u5316\u7279\u5f81\u7684\u7a00\u758f\u8868\u793a\u6bd4\u5176\u5bc6\u96c6\u8868\u793a\u6240\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u4f46\u5bf9\u4e8e\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\uff0c\u8fd8\u6709\u4e00\u79cd\u8f6c\u6362\u6240\u5360\u7528\u7684\u5185\u5b58\u66f4\u5c11\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \" \u72ec\u70ed\u7f16\u7801 \"\u3002 \u72ec\u70ed\u7f16\u7801\u4e5f\u662f\u4e00\u79cd\u4e8c\u503c\u7f16\u7801\uff0c\u56e0\u4e3a\u53ea\u6709 0 \u548c 1 \u4e24\u4e2a\u503c\u3002\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5b83\u5e76\u4e0d\u662f\u4e8c\u503c\u8868\u793a\u6cd5\u3002\u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0b\u9762\u7684\u4f8b\u5b50\u6765\u7406\u89e3\u5b83\u7684\u8868\u793a\u6cd5\u3002 \u5047\u8bbe\u6211\u4eec\u7528\u4e00\u4e2a\u5411\u91cf\u6765\u8868\u793a ord_2 \u53d8\u91cf\u7684\u6bcf\u4e2a\u7c7b\u522b\u3002\u8fd9\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u4e0e ord_2 \u53d8\u91cf\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u6bcf\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u90fd\u662f 6\uff0c\u5e76\u4e14\u9664\u4e86\u4e00\u4e2a\u4f4d\u7f6e\u5916\uff0c\u5176\u4ed6\u4f4d\u7f6e\u90fd\u662f 0\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u7279\u6b8a\u7684\u5411\u91cf\u8868\u3002 Freezing 0 0 0 0 0 1 Warm 0 0 0 0 1 0 Cold 0 0 0 1 0 0 Boiling Hot 0 0 1 0 0 0 Hot 0 1 0 0 0 0 Lava Hot 1 0 0 0 0 0 \u6211\u4eec\u770b\u5230\u5411\u91cf\u7684\u5927\u5c0f\u662f 1x6\uff0c\u5373\u5411\u91cf\u4e2d\u67096\u4e2a\u5143\u7d20\u3002\u8fd9\u4e2a\u6570\u5b57\u662f\u600e\u4e48\u6765\u7684\u5462\uff1f\u5982\u679c\u4f60\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u5982\u524d\u6240\u8ff0\uff0c\u67096\u4e2a\u7c7b\u522b\u3002\u5728\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u65f6\uff0c\u5411\u91cf\u7684\u5927\u5c0f\u5fc5\u987b\u4e0e\u6211\u4eec\u8981\u67e5\u770b\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u6bcf\u4e2a\u5411\u91cf\u90fd\u6709\u4e00\u4e2a 1\uff0c\u5176\u4f59\u6240\u6709\u503c\u90fd\u662f 0\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u7528\u8fd9\u4e9b\u7279\u5f81\u6765\u4ee3\u66ff\u4e4b\u524d\u7684\u4e8c\u503c\u5316\u7279\u5f81\uff0c\u770b\u770b\u80fd\u8282\u7701\u591a\u5c11\u5185\u5b58\u3002 \u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\u4ee5\u524d\u7684\u6570\u636e\uff0c\u5b83\u770b\u8d77\u6765\u5982\u4e0b\uff1a Index Feature 0 Warm 1 Hot 2 Lava hot \u6bcf\u4e2a\u6837\u672c\u67093\u4e2a\u7279\u5f81\u3002\u4f46\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u72ec\u70ed\u5411\u91cf\u7684\u5927\u5c0f\u4e3a 6\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u67096\u4e2a\u7279\u5f81\uff0c\u800c\u4e0d\u662f3\u4e2a\u3002 Index F_0 F_1 F_2 F_3 F_4 F_5 0 0 0 0 0 1 0 1 0 1 0 0 0 0 2 1 0 1 0 0 0 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 6 \u4e2a\u7279\u5f81\uff0c\u800c\u5728\u8fd9\u4e2a 3x6 \u6570\u7ec4\u4e2d\uff0c\u53ea\u6709 3 \u4e2a1\u3002\u4f7f\u7528 numpy \u8ba1\u7b97\u5927\u5c0f\u4e0e\u4e8c\u503c\u5316\u5927\u5c0f\u8ba1\u7b97\u811a\u672c\u975e\u5e38\u76f8\u4f3c\u3002\u4f60\u9700\u8981\u6539\u53d8\u7684\u53ea\u662f\u6570\u7ec4\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u6bb5\u4ee3\u7801\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 0 , 0 , 1 , 0 ], [ 0 , 1 , 0 , 0 , 0 , 0 ], [ 1 , 0 , 0 , 0 , 0 , 0 ] ] ) print ( f \"Size of dense array: { example . nbytes } \" ) sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u6253\u5370\u5185\u5b58\u5927\u5c0f\u4e3a\uff1a Size of dense array : 144 Size of sparse array : 24 Full size of sparse array : 52 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5bc6\u96c6\u77e9\u9635\u7684\u5927\u5c0f\u8fdc\u8fdc\u5927\u4e8e\u4e8c\u503c\u5316\u77e9\u9635\u7684\u5927\u5c0f\u3002\u4e0d\u8fc7\uff0c\u7a00\u758f\u6570\u7ec4\u7684\u5927\u5c0f\u8981\u66f4\u5c0f\u3002\u8ba9\u6211\u4eec\u7528\u66f4\u5927\u7684\u6570\u7ec4\u6765\u8bd5\u8bd5\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 OneHotEncoder \u5c06\u5305\u542b 1001 \u4e2a\u7c7b\u522b\u7684\u7279\u5f81\u6570\u7ec4\u8f6c\u6362\u4e3a\u5bc6\u96c6\u77e9\u9635\u548c\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from sklearn import preprocessing # \u751f\u6210\u7b26\u5408\u5747\u5300\u5206\u5e03\u7684\u968f\u673a\u6574\u6570\uff0c\u7ef4\u5ea6\u4e3a[1000000, 10000000] example = np . random . randint ( 1000 , size = 1000000 ) # \u72ec\u70ed\u7f16\u7801\uff0c\u975e\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = False ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of dense array: { ohe_example . nbytes } \" ) # \u72ec\u70ed\u7f16\u7801\uff0c\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = True ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of sparse array: { ohe_example . data . nbytes } \" ) full_size = ( ohe_example . data . nbytes + ohe_example . indptr . nbytes + ohe_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u4e0a\u9762\u4ee3\u7801\u6253\u5370\u7684\u8f93\u51fa\uff1a Size of dense array : 8000000000 Size of sparse array : 8000000 Full size of sparse array : 16000004 \u8fd9\u91cc\u7684\u5bc6\u96c6\u9635\u5217\u5927\u5c0f\u7ea6\u4e3a 8GB\uff0c\u7a00\u758f\u9635\u5217\u4e3a 8MB\u3002\u5982\u679c\u53ef\u4ee5\u9009\u62e9\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff1f\u5728\u6211\u770b\u6765\uff0c\u9009\u62e9\u5f88\u7b80\u5355\uff0c\u4e0d\u662f\u5417\uff1f \u8fd9\u4e09\u79cd\u65b9\u6cd5\uff08\u6807\u7b7e\u7f16\u7801\u3001\u7a00\u758f\u77e9\u9635\u3001\u72ec\u70ed\u7f16\u7801\uff09\u662f\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u7684\u6700\u91cd\u8981\u65b9\u6cd5\u3002\u4e0d\u8fc7\uff0c\u4f60\u8fd8\u53ef\u4ee5\u7528\u5f88\u591a\u5176\u4ed6\u4e0d\u540c\u7684\u65b9\u6cd5\u6765\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u3002\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u6570\u503c\u53d8\u91cf\u5c31\u662f\u5176\u4e2d\u7684\u4e00\u4e2a\u4f8b\u5b50\u3002 \u5047\u8bbe\u6211\u4eec\u56de\u5230\u4e4b\u524d\u7684\u5206\u7c7b\u7279\u5f81\u6570\u636e\uff08\u539f\u59cb\u6570\u636e\u4e2d\u7684 cat-in-the-dat-ii\uff09\u3002\u5728\u6570\u636e\u4e2d\uff0c ord_2 \u7684\u503c\u4e3a\u201c\u70ed\u201c\u7684 id \u6709\u591a\u5c11\uff1f \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u8ba1\u7b97\u6570\u636e\u7684\u5f62\u72b6\uff08shape\uff09\u8f7b\u677e\u8ba1\u7b97\u51fa\u8fd9\u4e2a\u503c\uff0c\u5176\u4e2d ord_2 \u5217\u7684\u503c\u4e3a Boiling Hot \u3002 In [ X ]: df [ df . ord_2 == \"Boiling Hot\" ] . shape Out [ X ]: ( 84790 , 25 ) \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 84790 \u6761\u8bb0\u5f55\u5177\u6709\u6b64\u503c\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 pandas \u4e2d\u7684 groupby \u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u8be5\u503c\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . count () Out [ X ]: ord_2 Boiling Hot 84790 Cold 97822 Freezing 142726 Hot 67508 Lava Hot 64840 Warm 124239 Name : id , dtype : int64 \u5982\u679c\u6211\u4eec\u53ea\u662f\u5c06 ord_2 \u5217\u66ff\u6362\u4e3a\u5176\u8ba1\u6570\u503c\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u5c06\u5176\u8f6c\u6362\u4e3a\u4e00\u79cd\u6570\u503c\u7279\u5f81\u4e86\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 transform \u51fd\u6570\u548c groupby \u6765\u521b\u5efa\u65b0\u5217\u6216\u66ff\u6362\u8fd9\u4e00\u5217\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . transform ( \"count\" ) Out [ X ]: 0 67508.0 1 124239.0 2 142726.0 3 64840.0 4 97822.0 ... 599995 142726.0 599996 84790.0 599997 142726.0 599998 124239.0 599999 84790.0 Name : id , Length : 600000 , dtype : float64 \u4f60\u53ef\u4ee5\u6dfb\u52a0\u6240\u6709\u7279\u5f81\u7684\u8ba1\u6570\uff0c\u4e5f\u53ef\u4ee5\u66ff\u6362\u5b83\u4eec\uff0c\u6216\u8005\u6839\u636e\u591a\u4e2a\u5217\u53ca\u5176\u8ba1\u6570\u8fdb\u884c\u5206\u7ec4\u3002\u4f8b\u5982\uff0c\u4ee5\u4e0b\u4ee3\u7801\u901a\u8fc7\u5bf9 ord_1 \u548c ord_2 \u5217\u5206\u7ec4\u8fdb\u884c\u8ba1\u6570\u3002 In [ X ]: df . groupby ( ... : [ ... : \"ord_1\" , ... : \"ord_2\" ... : ] ... : )[ \"id\" ] . count () . reset_index ( name = \"count\" ) Out [ X ]: ord_1 ord_2 count 0 Contributor Boiling Hot 15634 1 Contributor Cold 17734 2 Contributor Freezing 26082 3 Contributor Hot 12428 4 Contributor Lava Hot 11919 5 Contributor Warm 22774 6 Expert Boiling Hot 19477 7 Expert Cold 22956 8 Expert Freezing 33249 9 Expert Hot 15792 10 Expert Lava Hot 15078 11 Expert Warm 28900 12 Grandmaster Boiling Hot 13623 13 Grandmaster Cold 15464 14 Grandmaster Freezing 22818 15 Grandmaster Hot 10805 16 Grandmaster Lava Hot 10363 17 Grandmaster Warm 19899 18 Master Boiling Hot 10800 ... \u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u884c\uff0c\u4ee5\u4fbf\u5728\u4e00\u9875\u4e2d\u5bb9\u7eb3\u8fd9\u4e9b\u884c\u3002\u8fd9\u662f\u53e6\u4e00\u79cd\u53ef\u4ee5\u4f5c\u4e3a\u529f\u80fd\u6dfb\u52a0\u7684\u8ba1\u6570\u3002\u60a8\u73b0\u5728\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\uff0c\u6211\u4f7f\u7528 id \u5217\u8fdb\u884c\u8ba1\u6570\u3002\u4e0d\u8fc7\uff0c\u4f60\u4e5f\u53ef\u4ee5\u901a\u8fc7\u5bf9\u5217\u7684\u7ec4\u5408\u8fdb\u884c\u5206\u7ec4\uff0c\u5bf9\u5176\u4ed6\u5217\u8fdb\u884c\u8ba1\u6570\u3002 \u8fd8\u6709\u4e00\u4e2a\u5c0f\u7a8d\u95e8\uff0c\u5c31\u662f\u4ece\u8fd9\u4e9b\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\u3002\u4f60\u53ef\u4ee5\u4ece\u73b0\u6709\u7684\u7279\u5f81\u4e2d\u521b\u5efa\u65b0\u7684\u5206\u7c7b\u7279\u5f81\uff0c\u800c\u4e14\u53ef\u4ee5\u6beb\u4e0d\u8d39\u529b\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot 1 Grandmaster_Warm 2 nan_Freezing 3 Novice_Lava Hot 4 Grandmaster_Cold ... 599999 Contributor_Boiling Hot Name : new_feature , Length : 600000 , dtype : object \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u7528\u4e0b\u5212\u7ebf\u5c06 ord_1 \u548c ord_2 \u5408\u5e76\uff0c\u7136\u540e\u5c06\u8fd9\u4e9b\u5217\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u7c7b\u578b\u3002\u8bf7\u6ce8\u610f\uff0cNaN \u4e5f\u4f1a\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u3002\u4e0d\u8fc7\u6ca1\u5173\u7cfb\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06 NaN \u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u6709\u4e86\u4e00\u4e2a\u7531\u8fd9\u4e24\u4e2a\u7279\u5f81\u7ec4\u5408\u800c\u6210\u7684\u65b0\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u5c06\u4e09\u5217\u4ee5\u4e0a\u6216\u56db\u5217\u751a\u81f3\u66f4\u591a\u5217\u7ec4\u5408\u5728\u4e00\u8d77\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : + \"_\" ... : + df . ord_3 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot_c 1 Grandmaster_Warm_e 2 nan_Freezing_n 3 Novice_Lava Hot_a 4 Grandmaster_Cold_h ... 599999 Contributor_Boiling Hot_b Name : new_feature , Length : 600000 , dtype : object \u90a3\u4e48\uff0c\u6211\u4eec\u5e94\u8be5\u628a\u54ea\u4e9b\u7c7b\u522b\u7ed3\u5408\u8d77\u6765\u5462\uff1f\u8fd9\u5e76\u6ca1\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u7b54\u6848\u3002\u8fd9\u53d6\u51b3\u4e8e\u60a8\u7684\u6570\u636e\u548c\u7279\u5f81\u7c7b\u578b\u3002\u4e00\u4e9b\u9886\u57df\u77e5\u8bc6\u5bf9\u4e8e\u521b\u5efa\u8fd9\u6837\u7684\u7279\u5f81\u53ef\u80fd\u5f88\u6709\u7528\u3002\u4f46\u662f\uff0c\u5982\u679c\u4f60\u4e0d\u62c5\u5fc3\u5185\u5b58\u548c CPU \u7684\u4f7f\u7528\uff0c\u4f60\u53ef\u4ee5\u91c7\u7528\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\uff0c\u5373\u521b\u5efa\u8bb8\u591a\u8fd9\u6837\u7684\u7ec4\u5408\uff0c\u7136\u540e\u4f7f\u7528\u4e00\u4e2a\u6a21\u578b\u6765\u51b3\u5b9a\u54ea\u4e9b\u7279\u5f81\u662f\u6709\u7528\u7684\uff0c\u5e76\u4fdd\u7559\u5b83\u4eec\u3002\u6211\u4eec\u5c06\u5728\u672c\u4e66\u7a0d\u540e\u90e8\u5206\u4ecb\u7ecd\u8fd9\u79cd\u65b9\u6cd5\u3002 \u65e0\u8bba\u4f55\u65f6\u83b7\u5f97\u5206\u7c7b\u53d8\u91cf\uff0c\u90fd\u8981\u9075\u5faa\u4ee5\u4e0b\u7b80\u5355\u6b65\u9aa4\uff1a - \u586b\u5145 NaN \u503c\uff08\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff01\uff09\u3002 - \u4f7f\u7528 scikit-learn \u7684 LabelEncoder \u6216\u6620\u5c04\u5b57\u5178\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\uff0c\u5c06\u5b83\u4eec\u8f6c\u6362\u4e3a\u6574\u6570\u3002\u5982\u679c\u6ca1\u6709\u586b\u5145 NaN \u503c\uff0c\u53ef\u80fd\u9700\u8981\u5728\u8fd9\u4e00\u6b65\u4e2d\u8fdb\u884c\u5904\u7406 - \u521b\u5efa\u72ec\u70ed\u7f16\u7801\u3002\u662f\u7684\uff0c\u4f60\u53ef\u4ee5\u8df3\u8fc7\u4e8c\u503c\u5316\uff01 - \u5efa\u6a21\uff01\u6211\u6307\u7684\u662f\u673a\u5668\u5b66\u4e60\u3002 \u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u5904\u7406 NaN \u6570\u636e\u975e\u5e38\u91cd\u8981\uff0c\u5426\u5219\u60a8\u53ef\u80fd\u4f1a\u4ece scikit-learn \u7684 LabelEncoder \u4e2d\u5f97\u5230\u81ed\u540d\u662d\u8457\u7684\u9519\u8bef\u4fe1\u606f\uff1a ValueError: y \u5305\u542b\u4ee5\u524d\u672a\u89c1\u8fc7\u7684\u6807\u7b7e\uff1a [Nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan) \u8fd9\u4ec5\u4ec5\u610f\u5473\u7740\uff0c\u5728\u8f6c\u6362\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6570\u636e\u4e2d\u51fa\u73b0\u4e86 NaN \u503c\u3002\u8fd9\u662f\u56e0\u4e3a\u4f60\u5728\u8bad\u7ec3\u65f6\u5fd8\u8bb0\u4e86\u5904\u7406\u5b83\u4eec\u3002 \u5904\u7406 NaN \u503c \u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u4e22\u5f03\u5b83\u4eec\u3002\u867d\u7136\u7b80\u5355\uff0c\u4f46\u5e76\u4e0d\u7406\u60f3\u3002NaN \u503c\u4e2d\u53ef\u80fd\u5305\u542b\u5f88\u591a\u4fe1\u606f\uff0c\u5982\u679c\u53ea\u662f\u4e22\u5f03\u8fd9\u4e9b\u503c\uff0c\u5c31\u4f1a\u4e22\u5931\u8fd9\u4e9b\u4fe1\u606f\u3002\u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u5927\u90e8\u5206\u6570\u636e\u90fd\u662f NaN \u503c\uff0c\u56e0\u6b64\u4e0d\u80fd\u4e22\u5f03 NaN \u503c\u7684\u884c/\u6837\u672c\u3002\u5904\u7406 NaN \u503c\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u5c06\u5176\u4f5c\u4e3a\u4e00\u4e2a\u5168\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u662f\u5904\u7406 NaN \u503c\u6700\u5e38\u7528\u7684\u65b9\u6cd5\u3002\u5982\u679c\u4f7f\u7528 pandas\uff0c\u8fd8\u53ef\u4ee5\u901a\u8fc7\u975e\u5e38\u7b80\u5355\u7684\u65b9\u5f0f\u5b9e\u73b0\u3002 \u8bf7\u770b\u6211\u4eec\u4e4b\u524d\u67e5\u770b\u8fc7\u7684\u6570\u636e\u7684 ord_2 \u5217\u3002 In [ X ]: df . ord_2 . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : ord_2 , dtype : int64 \u586b\u5165 NaN \u503c\u540e\uff0c\u5c31\u53d8\u6210\u4e86 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u54c7\uff01\u8fd9\u4e00\u5217\u4e2d\u6709 18075 \u4e2a NaN \u503c\uff0c\u800c\u6211\u4eec\u4e4b\u524d\u751a\u81f3\u90fd\u6ca1\u6709\u8003\u8651\u4f7f\u7528\u5b83\u4eec\u3002\u589e\u52a0\u4e86\u8fd9\u4e2a\u65b0\u7c7b\u522b\u540e\uff0c\u7c7b\u522b\u603b\u6570\u4ece 6 \u4e2a\u589e\u52a0\u5230\u4e86 7 \u4e2a\u3002\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u73b0\u5728\u6211\u4eec\u5728\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u4e5f\u4f1a\u8003\u8651 NaN\u3002\u76f8\u5173\u4fe1\u606f\u8d8a\u591a\uff0c\u6a21\u578b\u5c31\u8d8a\u597d\u3002 \u5047\u8bbe ord_2 \u6ca1\u6709\u4efb\u4f55 NaN \u503c\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u4e00\u5217\u4e2d\u7684\u6240\u6709\u7c7b\u522b\u90fd\u6709\u663e\u8457\u7684\u8ba1\u6570\u3002\u5176\u4e2d\u6ca1\u6709 \"\u7f55\u89c1 \"\u7c7b\u522b\uff0c\u5373\u53ea\u5728\u6837\u672c\u603b\u6570\u4e2d\u5360\u5f88\u5c0f\u6bd4\u4f8b\u7684\u7c7b\u522b\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5047\u8bbe\u60a8\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u4e86\u4f7f\u7528\u8fd9\u4e00\u5217\u7684\u6a21\u578b\uff0c\u5f53\u6a21\u578b\u6216\u9879\u76ee\u4e0a\u7ebf\u65f6\uff0c\u60a8\u5728 ord_2 \u5217\u4e2d\u5f97\u5230\u4e86\u4e00\u4e2a\u5728\u8bad\u7ec3\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6a21\u578b\u7ba1\u9053\u4f1a\u629b\u51fa\u4e00\u4e2a\u9519\u8bef\uff0c\u60a8\u5bf9\u6b64\u65e0\u80fd\u4e3a\u529b\u3002\u5982\u679c\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\uff0c\u90a3\u4e48\u53ef\u80fd\u662f\u751f\u4ea7\u4e2d\u7684\u7ba1\u9053\u51fa\u4e86\u95ee\u9898\u3002\u5982\u679c\u8fd9\u662f\u9884\u6599\u4e4b\u4e2d\u7684\uff0c\u90a3\u4e48\u60a8\u5c31\u5fc5\u987b\u4fee\u6539\u60a8\u7684\u6a21\u578b\u7ba1\u9053\uff0c\u5e76\u5728\u8fd9\u516d\u4e2a\u7c7b\u522b\u4e2d\u52a0\u5165\u4e00\u4e2a\u65b0\u7c7b\u522b\u3002 \u8fd9\u4e2a\u65b0\u7c7b\u522b\u88ab\u79f0\u4e3a \"\u7f55\u89c1 \"\u7c7b\u522b\u3002\u7f55\u89c1\u7c7b\u522b\u662f\u4e00\u79cd\u4e0d\u5e38\u89c1\u7684\u7c7b\u522b\uff0c\u53ef\u4ee5\u5305\u62ec\u8bb8\u591a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8fd1\u90bb\u6a21\u578b\u6765 \"\u9884\u6d4b \"\u672a\u77e5\u7c7b\u522b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u4e86\u8fd9\u4e2a\u7c7b\u522b\uff0c\u5b83\u5c31\u4f1a\u6210\u4e3a\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u4e00\u4e2a\u7c7b\u522b\u3002 \u56fe 3\uff1a\u5177\u6709\u4e0d\u540c\u7279\u5f81\u4e14\u65e0\u6807\u7b7e\u7684\u6570\u636e\u96c6\u793a\u610f\u56fe\uff0c\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u53ef\u80fd\u4f1a\u5728\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u51fa\u73b0\u65b0\u503c \u5f53\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u56fe 3 \u6240\u793a\u7684\u6570\u636e\u96c6\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u7b80\u5355\u7684\u6a21\u578b\uff0c\u5bf9\u9664 \"f3 \"\u4e4b\u5916\u7684\u6240\u6709\u7279\u5f81\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u6837\uff0c\u4f60\u5c06\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u4e0d\u77e5\u9053\u6216\u8bad\u7ec3\u4e2d\u6ca1\u6709 \"f3 \"\u65f6\u9884\u6d4b\u5b83\u3002\u6211\u4e0d\u6562\u8bf4\u8fd9\u6837\u7684\u6a21\u578b\u662f\u5426\u80fd\u5e26\u6765\u51fa\u8272\u7684\u6027\u80fd\uff0c\u4f46\u4e5f\u8bb8\u80fd\u5904\u7406\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u7684\u7f3a\u5931\u503c\uff0c\u5c31\u50cf\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u5176\u4ed6\u4e8b\u60c5\u4e00\u6837\uff0c\u4e0d\u5c1d\u8bd5\u4e00\u4e0b\u662f\u8bf4\u4e0d\u51c6\u7684\u3002 \u5982\u679c\u4f60\u6709\u4e00\u4e2a\u56fa\u5b9a\u7684\u6d4b\u8bd5\u96c6\uff0c\u4f60\u53ef\u4ee5\u5c06\u6d4b\u8bd5\u6570\u636e\u6dfb\u52a0\u5230\u8bad\u7ec3\u4e2d\uff0c\u4ee5\u4e86\u89e3\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u7c7b\u522b\u3002\u8fd9\u4e0e\u534a\u76d1\u7763\u5b66\u4e60\u975e\u5e38\u76f8\u4f3c\uff0c\u5373\u4f7f\u7528\u65e0\u6cd5\u7528\u4e8e\u8bad\u7ec3\u7684\u6570\u636e\u6765\u6539\u8fdb\u6a21\u578b\u3002\u8fd9\u4e5f\u4f1a\u7167\u987e\u5230\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u51fa\u73b0\u6b21\u6570\u6781\u5c11\u4f46\u5728\u6d4b\u8bd5\u6570\u636e\u4e2d\u5927\u91cf\u5b58\u5728\u7684\u7a00\u6709\u503c\u3002\u4f60\u7684\u6a21\u578b\u5c06\u66f4\u52a0\u7a33\u5065\u3002 \u5f88\u591a\u4eba\u8ba4\u4e3a\u8fd9\u79cd\u60f3\u6cd5\u4f1a\u8fc7\u5ea6\u62df\u5408\u3002\u53ef\u80fd\u8fc7\u62df\u5408\uff0c\u4e5f\u53ef\u80fd\u4e0d\u8fc7\u62df\u5408\u3002\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6cd5\u3002\u5982\u679c\u4f60\u5728\u8bbe\u8ba1\u4ea4\u53c9\u9a8c\u8bc1\u65f6\uff0c\u80fd\u591f\u5728\u6d4b\u8bd5\u6570\u636e\u4e0a\u8fd0\u884c\u6a21\u578b\u65f6\u590d\u5236\u9884\u6d4b\u8fc7\u7a0b\uff0c\u90a3\u4e48\u5b83\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u8fd9\u610f\u5473\u7740\u7b2c\u4e00\u6b65\u5e94\u8be5\u662f\u5206\u79bb\u6298\u53e0\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u4f60\u5e94\u8be5\u5e94\u7528\u4e0e\u6d4b\u8bd5\u6570\u636e\u76f8\u540c\u7684\u9884\u5904\u7406\u3002\u5047\u8bbe\u60a8\u60f3\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u6d4b\u8bd5\u6570\u636e\uff0c\u90a3\u4e48\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u60a8\u5fc5\u987b\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\uff0c\u5e76\u786e\u4fdd\u9a8c\u8bc1\u6570\u636e\u96c6\u590d\u5236\u4e86\u6d4b\u8bd5\u96c6\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u60a8\u5fc5\u987b\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u8bbe\u8ba1\u9a8c\u8bc1\u96c6\uff0c\u4f7f\u5176\u5305\u542b\u8bad\u7ec3\u96c6\u4e2d \"\u672a\u89c1 \"\u7684\u7c7b\u522b\u3002 \u56fe 4\uff1a\u5bf9\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u8fdb\u884c\u7b80\u5355\u5408\u5e76\uff0c\u4ee5\u4e86\u89e3\u6d4b\u8bd5\u96c6\u4e2d\u5b58\u5728\u4f46\u8bad\u7ec3\u96c6\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u6216\u8bad\u7ec3\u96c6\u4e2d\u7f55\u89c1\u7684\u7c7b\u522b \u53ea\u8981\u770b\u4e00\u4e0b\u56fe 4 \u548c\u4e0b\u9762\u7684\u4ee3\u7801\uff0c\u5c31\u80fd\u5f88\u5bb9\u6613\u7406\u89e3\u5176\u5de5\u4f5c\u539f\u7406\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( \"../input/cat_train.csv\" ) # \u8bfb\u53d6\u6d4b\u8bd5\u96c6 test = pd . read_csv ( \"../input/cat_test.csv\" ) # \u5c06\u6d4b\u8bd5\u96c6\"target\"\u5217\u5168\u90e8\u7f6e\u4e3a-1 test . loc [:, \"target\" ] = - 1 # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u62fc\u63a5 data = pd . concat ([ train , test ]) . reset_index ( drop = True ) # \u5c06\u9664\"id\"\u548c\"target\"\u5217\u7684\u5176\u4ed6\u7279\u5f81\u5217\u540d\u53d6\u51fa features = [ x for x in train . columns if x not in [ \"id\" , \"target\" ]] # \u904d\u5386\u7279\u5f81 for feat in features : # \u6807\u7b7e\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\",\u5e76\u5c06\u8be5\u5217\u683c\u5f0f\u53d8\u4e3astr temp_col = data [ feat ] . fillna ( \"NONE\" ) . astype ( str ) . values # \u8f6c\u6362\u6570\u503c data . loc [:, feat ] = lbl_enc . fit_transform ( temp_col ) # \u6839\u636e\"target\"\u5217\u5c06\u8bad\u7ec3\u96c6\u4e0e\u6d4b\u8bd5\u96c6\u5206\u5f00 train = data [ data . target != - 1 ] . reset_index ( drop = True ) test = data [ data . target == - 1 ] . reset_index ( drop = True ) \u5f53\u60a8\u9047\u5230\u5df2\u7ecf\u6709\u6d4b\u8bd5\u6570\u636e\u96c6\u7684\u95ee\u9898\u65f6\uff0c\u8fd9\u4e2a\u6280\u5de7\u5c31\u4f1a\u8d77\u4f5c\u7528\u3002\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd9\u4e00\u62db\u5728\u5b9e\u65f6\u73af\u5883\u4e2d\u4e0d\u8d77\u4f5c\u7528\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u60a8\u6240\u5728\u7684\u516c\u53f8\u63d0\u4f9b\u5b9e\u65f6\u7ade\u4ef7\u89e3\u51b3\u65b9\u6848\uff08RTB\uff09\u3002RTB \u7cfb\u7edf\u4f1a\u5bf9\u5728\u7ebf\u770b\u5230\u7684\u6bcf\u4e2a\u7528\u6237\u8fdb\u884c\u7ade\u4ef7\uff0c\u4ee5\u8d2d\u4e70\u5e7f\u544a\u7a7a\u95f4\u3002\u8fd9\u79cd\u6a21\u5f0f\u53ef\u4f7f\u7528\u7684\u529f\u80fd\u53ef\u80fd\u5305\u62ec\u7f51\u7ad9\u4e2d\u6d4f\u89c8\u7684\u9875\u9762\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u7279\u5f81\u662f\u7528\u6237\u8bbf\u95ee\u7684\u6700\u540e\u4e94\u4e2a\u7c7b\u522b/\u9875\u9762\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5982\u679c\u7f51\u7ad9\u5f15\u5165\u4e86\u65b0\u7684\u7c7b\u522b\uff0c\u6211\u4eec\u5c06\u65e0\u6cd5\u518d\u51c6\u786e\u9884\u6d4b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u5c31\u4f1a\u5931\u6548\u3002\u8fd9\u79cd\u60c5\u51b5\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528 \"\u672a\u77e5 \"\u7c7b\u522b\u6765\u907f\u514d \u3002 \u5728\u6211\u4eec\u7684 cat-in-the-dat \u6570\u636e\u96c6\u4e2d\uff0c ord_2 \u5217\u4e2d\u5df2\u7ecf\u6709\u4e86\u672a\u77e5\u7c7b\u522b\u3002 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u6211\u4eec\u53ef\u4ee5\u5c06 \"NONE \"\u89c6\u4e3a\u672a\u77e5\u3002\u56e0\u6b64\uff0c\u5982\u679c\u5728\u5b9e\u65f6\u6d4b\u8bd5\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u83b7\u5f97\u4e86\u4ee5\u524d\u4ece\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\uff0c\u6211\u4eec\u5c31\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a \"NONE\"\u3002 \u8fd9\u4e0e\u81ea\u7136\u8bed\u8a00\u5904\u7406\u95ee\u9898\u975e\u5e38\u76f8\u4f3c\u3002\u6211\u4eec\u603b\u662f\u57fa\u4e8e\u56fa\u5b9a\u7684\u8bcd\u6c47\u5efa\u7acb\u6a21\u578b\u3002\u589e\u52a0\u8bcd\u6c47\u91cf\u5c31\u4f1a\u589e\u52a0\u6a21\u578b\u7684\u5927\u5c0f\u3002\u50cf BERT \u8fd9\u6837\u7684\u8f6c\u6362\u5668\u6a21\u578b\u662f\u5728 ~30000 \u4e2a\u5355\u8bcd\uff08\u82f1\u8bed\uff09\u7684\u57fa\u7840\u4e0a\u8bad\u7ec3\u7684\u3002\u56e0\u6b64\uff0c\u5f53\u6709\u65b0\u8bcd\u8f93\u5165\u65f6\uff0c\u6211\u4eec\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a UNK\uff08\u672a\u77e5\uff09\u3002 \u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u5047\u8bbe\u6d4b\u8bd5\u6570\u636e\u4e0e\u8bad\u7ec3\u6570\u636e\u5177\u6709\u76f8\u540c\u7684\u7c7b\u522b\uff0c\u4e5f\u53ef\u4ee5\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u5f15\u5165\u7f55\u89c1\u6216\u672a\u77e5\u7c7b\u522b\uff0c\u4ee5\u5904\u7406\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u65b0\u7c7b\u522b\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u586b\u5165 NaN \u503c\u540e ord_4 \u5217\u7684\u503c\u8ba1\u6570\uff1a In [ X ]: df . ord_4 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 . . . K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 G 3404 V 3107 J 1950 L 1657 Name : ord_4 , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u6709\u4e9b\u6570\u503c\u53ea\u51fa\u73b0\u4e86\u51e0\u5343\u6b21\uff0c\u6709\u4e9b\u5219\u51fa\u73b0\u4e86\u8fd1 40000 \u6b21\u3002NaN \u4e5f\u7ecf\u5e38\u51fa\u73b0\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u503c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5b9a\u4e49\u5c06\u4e00\u4e2a\u503c\u79f0\u4e3a \" \u7f55\u89c1\uff08rare\uff09 \"\u7684\u6807\u51c6\u4e86\u3002\u6bd4\u65b9\u8bf4\uff0c\u5728\u8fd9\u4e00\u5217\u4e2d\uff0c\u7a00\u6709\u503c\u7684\u8981\u6c42\u662f\u8ba1\u6570\u5c0f\u4e8e 2000\u3002\u8fd9\u6837\u770b\u6765\uff0cJ \u548c L \u5c31\u53ef\u4ee5\u88ab\u6807\u8bb0\u4e3a\u7a00\u6709\u503c\u4e86\u3002\u4f7f\u7528 pandas\uff0c\u6839\u636e\u8ba1\u6570\u9608\u503c\u66ff\u6362\u7c7b\u522b\u975e\u5e38\u7b80\u5355\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 In [ X ]: df . ord_4 = df . ord_4 . fillna ( \"NONE\" ) In [ X ]: df . loc [ ... : df [ \"ord_4\" ] . value_counts ()[ df [ \"ord_4\" ]] . values < 2000 , ... : \"ord_4\" ... : ] = \"RARE\" In [ X ]: df . ord_4 . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 M 32504 . . . B 25212 E 21871 K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 RARE 3607 G 3404 V 3107 Name : ord_4 , dtype : int64 \u6211\u4eec\u8ba4\u4e3a\uff0c\u53ea\u8981\u67d0\u4e2a\u7c7b\u522b\u7684\u503c\u5c0f\u4e8e 2000\uff0c\u5c31\u5c06\u5176\u66ff\u6362\u4e3a\u7f55\u89c1\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u5728\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6240\u6709\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"RARE\"\uff0c\u800c\u6240\u6709\u7f3a\u5931\u503c\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"NONE\"\u3002 \u8fd9\u79cd\u65b9\u6cd5\u8fd8\u80fd\u786e\u4fdd\u5373\u4f7f\u6709\u65b0\u7684\u7c7b\u522b\uff0c\u6a21\u578b\u4e5f\u80fd\u5728\u5b9e\u9645\u73af\u5883\u4e2d\u6b63\u5e38\u5de5\u4f5c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u5177\u5907\u4e86\u5904\u7406\u4efb\u4f55\u5e26\u6709\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u6240\u9700\u7684\u4e00\u5207\u6761\u4ef6\u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5efa\u7acb\u7b2c\u4e00\u4e2a\u6a21\u578b\uff0c\u5e76\u9010\u6b65\u63d0\u9ad8\u5176\u6027\u80fd\u3002 \u5728\u6784\u5efa\u4efb\u4f55\u7c7b\u578b\u7684\u6a21\u578b\u4e4b\u524d\uff0c\u4ea4\u53c9\u68c0\u9a8c\u81f3\u5173\u91cd\u8981\u3002\u6211\u4eec\u5df2\u7ecf\u770b\u5230\u4e86\u6807\u7b7e/\u76ee\u6807\u5206\u5e03\uff0c\u77e5\u9053\u8fd9\u662f\u4e00\u4e2a\u76ee\u6807\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 StratifiedKFold \u6765\u5206\u5272\u6570\u636e\u3002 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u6dfb\u52a0\"kfold\"\u5217\uff0c\u5e76\u7f6e\u4e3a-1 df [ \"kfold\" ] = - 1 # \u6253\u4e71\u6570\u636e\u987a\u5e8f\uff0c\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u5c06\u76ee\u6807\u5217\u53d6\u51fa y = df . target . values # \u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): # \u533a\u5206\u6298\u53e0 df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u6587\u4ef6 df . to_csv ( \"../input/cat_train_folds.csv\" , index = False ) \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u68c0\u67e5\u65b0\u7684\u6298\u53e0 csv\uff0c\u67e5\u770b\u6bcf\u4e2a\u6298\u53e0\u7684\u6837\u672c\u6570\uff1a In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) In [ X ]: df . kfold . value_counts () Out [ X ]: 4 120000 3 120000 2 120000 1 120000 0 120000 Name : kfold , dtype : int64 \u6240\u6709\u6298\u53e0\u90fd\u6709 120000 \u4e2a\u6837\u672c\u3002\u8fd9\u662f\u610f\u6599\u4e4b\u4e2d\u7684\uff0c\u56e0\u4e3a\u8bad\u7ec3\u6570\u636e\u6709 600000 \u4e2a\u6837\u672c\uff0c\u800c\u6211\u4eec\u505a\u4e865\u6b21\u6298\u53e0\u3002\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u4e00\u5207\u987a\u5229\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u68c0\u67e5\u6bcf\u4e2a\u6298\u53e0\u7684\u76ee\u6807\u5206\u5e03\u3002 In [ X ]: df [ df . kfold == 0 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 1 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 2 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 3 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 4 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u76ee\u6807\u7684\u5206\u5e03\u90fd\u662f\u4e00\u6837\u7684\u3002\u8fd9\u6b63\u662f\u6211\u4eec\u6240\u9700\u8981\u7684\u3002\u5b83\u4e5f\u53ef\u4ee5\u662f\u76f8\u4f3c\u7684\uff0c\u5e76\u4e0d\u4e00\u5b9a\u8981\u4e00\u76f4\u76f8\u540c\u3002\u73b0\u5728\uff0c\u5f53\u6211\u4eec\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u6807\u7b7e\u5206\u5e03\u90fd\u5c06\u76f8\u540c\u3002 \u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u7684\u6700\u7b80\u5355\u7684\u6a21\u578b\u4e4b\u4e00\u662f\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u5e76\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): # \u8bfb\u53d6\u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) # \u53d6\u9664\"id\", \"target\", \"kfold\"\u5916\u7684\u5176\u4ed6\u7279\u5f81\u5217 features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] # \u904d\u5386\u7279\u5f81\u5217\u8868 for col in features : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u9a8c\u8bc1\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u6d4b\u8bd5\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u903b\u8f91\u56de\u5f52 model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . target . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( auc ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6298\u53e00 run ( 0 ) \u90a3\u4e48\uff0c\u53d1\u751f\u4e86\u4ec0\u4e48\u5462\uff1f \u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u51fd\u6570\uff0c\u5c06\u6570\u636e\u5206\u4e3a\u8bad\u7ec3\u548c\u9a8c\u8bc1\u4e24\u90e8\u5206\uff0c\u7ed9\u5b9a\u6298\u53e0\u6570\uff0c\u5904\u7406 NaN \u503c\uff0c\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u5355\u6b21\u7f16\u7801\uff0c\u5e76\u8bad\u7ec3\u4e00\u4e2a\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\u3002 \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u90e8\u5206\u4ee3\u7801\u65f6\uff0c\u4f1a\u4ea7\u751f\u5982\u4e0b\u8f93\u51fa\uff1a \u276f python ohe_logres . py / home / abhishek / miniconda3 / envs / ml / lib / python3 .7 / site - packages / sklearn / linear_model / _logistic . py : 939 : ConvergenceWarning : lbfgs failed to converge ( status = 1 ): STOP : TOTAL NO . of ITERATIONS REACHED LIMIT . Increase the number of iterations ( max_iter ) or scale the data as shown in : https : // scikit - learn . org / stable / modules / preprocessing . html . Please also refer to the documentation for alternative solver options : https : // scikit - learn . org / stable / modules / linear_model . html #logistic- regression extra_warning_msg = _LOGISTIC_SOLVER_CONVERGENCE_MSG ) 0.7847865042255127 \u6709\u4e00\u4e9b\u8b66\u544a\u3002\u903b\u8f91\u56de\u5f52\u4f3c\u4e4e\u6ca1\u6709\u6536\u655b\u5230\u6700\u5927\u8fed\u4ee3\u6b21\u6570\u3002\u6211\u4eec\u6ca1\u6709\u8c03\u6574\u53c2\u6570\uff0c\u6240\u4ee5\u6ca1\u6709\u95ee\u9898\u3002\u6211\u4eec\u770b\u5230 AUC \u4e3a 0.785\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u5bf9\u4ee3\u7801\u8fdb\u884c\u7b80\u5355\u4fee\u6539\uff0c\u8fd0\u884c\u6240\u6709\u6298\u53e0\u3002 .... model = linear_model . LogisticRegression () model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u5faa\u73af\u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5e76\u6ca1\u6709\u505a\u5f88\u5927\u7684\u6539\u52a8\uff0c\u6240\u4ee5\u6211\u53ea\u663e\u793a\u4e86\u90e8\u5206\u4ee3\u7801\u884c\uff0c\u5176\u4e2d\u4e00\u4e9b\u4ee3\u7801\u884c\u6709\u6539\u52a8\u3002 \u8fd9\u5c31\u6253\u5370\u51fa\u4e86\uff1a python - W ignore ohe_logres . py Fold = 0 , AUC = 0.7847865042255127 Fold = 1 , AUC = 0.7853553605899214 Fold = 2 , AUC = 0.7879321942914885 Fold = 3 , AUC = 0.7870315929550808 Fold = 4 , AUC = 0.7864668243125608 \u8bf7\u6ce8\u610f\uff0c\u6211\u4f7f\u7528\"-W ignore \"\u5ffd\u7565\u4e86\u6240\u6709\u8b66\u544a\u3002 \u6211\u4eec\u770b\u5230\uff0cAUC \u5206\u6570\u5728\u6240\u6709\u8936\u76b1\u4e2d\u90fd\u76f8\u5f53\u7a33\u5b9a\u3002\u5e73\u5747 AUC \u4e3a 0.78631449527\u3002\u5bf9\u4e8e\u6211\u4eec\u7684\u7b2c\u4e00\u4e2a\u6a21\u578b\u6765\u8bf4\u76f8\u5f53\u4e0d\u9519\uff01 \u5f88\u591a\u4eba\u5728\u9047\u5230\u8fd9\u79cd\u95ee\u9898\u65f6\u4f1a\u9996\u5148\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6bd4\u5982\u968f\u673a\u68ee\u6797\u3002\u5728\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u5e94\u7528\u968f\u673a\u68ee\u6797\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\uff08label encoding\uff09\uff0c\u5c06\u6bcf\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81\u90fd\u8f6c\u6362\u4e3a\u6574\u6570\uff0c\u800c\u4e0d\u662f\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u72ec\u70ed\u7f16\u7801\u3002 \u8fd9\u79cd\u7f16\u7801\u4e0e\u72ec\u70ed\u7f16\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # \u968f\u673a\u68ee\u6797\u6a21\u578b model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u4f7f\u7528 scikit-learn \u4e2d\u7684\u968f\u673a\u68ee\u6797\uff0c\u5e76\u53d6\u6d88\u4e86\u72ec\u70ed\u7f16\u7801\u3002\u6211\u4eec\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u4ee3\u66ff\u72ec\u70ed\u7f16\u7801\u3002\u5f97\u5206\u5982\u4e0b \u276f python lbl_rf . py Fold = 0 , AUC = 0.7167390828113697 Fold = 1 , AUC = 0.7165459672958506 Fold = 2 , AUC = 0.7159709909587376 Fold = 3 , AUC = 0.7161589664189556 Fold = 4 , AUC = 0.7156020216155978 \u54c7 \u5de8\u5927\u7684\u5dee\u5f02\uff01 \u968f\u673a\u68ee\u6797\u6a21\u578b\u5728\u6ca1\u6709\u4efb\u4f55\u8d85\u53c2\u6570\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u8868\u73b0\u8981\u6bd4\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u5dee\u5f88\u591a\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u603b\u662f\u5e94\u8be5\u5148\u4ece\u7b80\u5355\u6a21\u578b\u5f00\u59cb\u7684\u539f\u56e0\u3002\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u7c89\u4e1d\u4f1a\u4ece\u8fd9\u91cc\u5f00\u59cb\uff0c\u800c\u5ffd\u7565\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u8ba4\u4e3a\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u6a21\u578b\uff0c\u4e0d\u80fd\u5e26\u6765\u6bd4\u968f\u673a\u68ee\u6797\u66f4\u597d\u7684\u4ef7\u503c\u3002\u8fd9\u79cd\u4eba\u5c06\u4f1a\u72af\u4e0b\u5927\u9519\u3002\u5728\u6211\u4eec\u5b9e\u73b0\u968f\u673a\u68ee\u6797\u7684\u8fc7\u7a0b\u4e2d\uff0c\u4e0e\u903b\u8f91\u56de\u5f52\u76f8\u6bd4\uff0c\u6298\u53e0\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u624d\u80fd\u5b8c\u6210\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e0d\u4ec5\u635f\u5931\u4e86 AUC\uff0c\u8fd8\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u6765\u5b8c\u6210\u8bad\u7ec3\u3002\u8bf7\u6ce8\u610f\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u8fdb\u884c\u63a8\u7406\u4e5f\u5f88\u8017\u65f6\uff0c\u800c\u4e14\u5360\u7528\u7684\u7a7a\u95f4\u4e5f\u66f4\u5927\u3002 \u5982\u679c\u6211\u4eec\u613f\u610f\uff0c\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u5728\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u6570\u636e\u4e0a\u8fd0\u884c\u968f\u673a\u68ee\u6797\uff0c\u4f46\u8fd9\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u5947\u5f02\u503c\u5206\u89e3\u6765\u51cf\u5c11\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u77e9\u9635\u3002\u8fd9\u662f\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u63d0\u53d6\u4e3b\u9898\u7684\u5e38\u7528\u65b9\u6cd5\u3002 import pandas as pd from scipy import sparse from sklearn import decomposition from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" )] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) x_train = ohe . transform ( df_train [ features ]) x_valid = ohe . transform ( df_valid [ features ]) # \u5947\u5f02\u503c\u5206\u89e3 svd = decomposition . TruncatedSVD ( n_components = 120 ) full_sparse = sparse . vstack (( x_train , x_valid )) svd . fit ( full_sparse ) x_train = svd . transform ( x_train ) x_valid = svd . transform ( x_valid ) model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u5bf9\u5168\u90e8\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\uff0c\u7136\u540e\u7528\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\u5728\u7a00\u758f\u77e9\u9635\u4e0a\u62df\u5408 scikit-learn \u7684 TruncatedSVD\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c06\u9ad8\u7ef4\u7a00\u758f\u77e9\u9635\u51cf\u5c11\u5230 120 \u4e2a\u7279\u5f81\uff0c\u7136\u540e\u62df\u5408\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u3002 \u4ee5\u4e0b\u662f\u8be5\u6a21\u578b\u7684\u8f93\u51fa\u7ed3\u679c\uff1a \u276f python ohe_svd_rf . py Fold = 0 , AUC = 0.7064863038754249 Fold = 1 , AUC = 0.706050102937374 Fold = 2 , AUC = 0.7086069243167242 Fold = 3 , AUC = 0.7066819080085971 Fold = 4 , AUC = 0.7058154015055585 \u6211\u4eec\u53d1\u73b0\u60c5\u51b5\u66f4\u7cdf\u3002\u770b\u6765\uff0c\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\u7684\u6700\u4f73\u65b9\u6cd5\u662f\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u3002\u968f\u673a\u68ee\u6797\u4f3c\u4e4e\u8017\u65f6\u592a\u591a\u3002\u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u8bd5\u8bd5 XGBoost\u3002\u5982\u679c\u4f60\u4e0d\u77e5\u9053 XGBoost\uff0c\u5b83\u662f\u6700\u6d41\u884c\u7684\u68af\u5ea6\u63d0\u5347\u7b97\u6cd5\u4e4b\u4e00\u3002\u7531\u4e8e\u5b83\u662f\u4e00\u79cd\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u6570\u636e\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 , n_estimators = 200 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0c\u6211\u5bf9 xgboost \u53c2\u6570\u505a\u4e86\u4e00\u4e9b\u4fee\u6539\u3002xgboost \u7684\u9ed8\u8ba4\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u662f 3\uff0c\u6211\u628a\u5b83\u6539\u6210\u4e86 7\uff0c\u8fd8\u628a\u4f30\u8ba1\u5668\u6570\u91cf\uff08n_estimators\uff09\u4ece 100 \u6539\u6210\u4e86 200\u3002 \u8be5\u6a21\u578b\u7684 5 \u6298\u4ea4\u53c9\u68c0\u9a8c\u5f97\u5206\u5982\u4e0b\uff1a \u276f python lbl_xgb . py Fold = 0 , AUC = 0.7656768851999011 Fold = 1 , AUC = 0.7633006564148015 Fold = 2 , AUC = 0.7654277821434345 Fold = 3 , AUC = 0.7663609758878182 Fold = 4 , AUC = 0.764914671468069 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u4e0d\u505a\u4efb\u4f55\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u5f97\u5206\u6bd4\u666e\u901a\u968f\u673a\u68ee\u6797\u8981\u9ad8\u5f97\u591a\u3002 \u60a8\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4e00\u4e9b\u7279\u5f81\u5de5\u7a0b\uff0c\u653e\u5f03\u67d0\u4e9b\u5bf9\u6a21\u578b\u6ca1\u6709\u4efb\u4f55\u4ef7\u503c\u7684\u5217\u7b49\u3002\u4f46\u4f3c\u4e4e\u6211\u4eec\u80fd\u505a\u7684\u4e0d\u591a\uff0c\u65e0\u6cd5\u8bc1\u660e\u6a21\u578b\u7684\u6539\u8fdb\u3002\u8ba9\u6211\u4eec\u628a\u6570\u636e\u96c6\u6362\u6210\u53e6\u4e00\u4e2a\u6709\u5927\u91cf\u5206\u7c7b\u53d8\u91cf\u7684\u6570\u636e\u96c6\u3002\u53e6\u4e00\u4e2a\u6709\u540d\u7684\u6570\u636e\u96c6\u662f \u7f8e\u56fd\u6210\u4eba\u4eba\u53e3\u666e\u67e5\u6570\u636e\uff08US adult census data\uff09 \u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u5305\u542b\u4e00\u4e9b\u7279\u5f81\uff0c\u800c\u4f60\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u5de5\u8d44\u7b49\u7ea7\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u3002\u56fe 5 \u663e\u793a\u4e86\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u4e00\u4e9b\u5217\u3002 \u56fe 5\uff1a\u90e8\u5206\u6570\u636e\u96c6\u5c55\u793a \u8be5\u6570\u636e\u96c6\u6709\u4ee5\u4e0b\u51e0\u5217\uff1a - \u5e74\u9f84\uff08age\uff09 \u5de5\u4f5c\u7c7b\u522b\uff08workclass\uff09 \u5b66\u5386\uff08fnlwgt\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education.num\uff09 \u5a5a\u59fb\u72b6\u51b5\uff08marital.status\uff09 \u804c\u4e1a\uff08occupation\uff09 \u5173\u7cfb\uff08relationship\uff09 \u79cd\u65cf\uff08race\uff09 \u6027\u522b\uff08sex\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 \u539f\u7c4d\u56fd\uff08native.country\uff09 \u6536\u5165\uff08income\uff09 \u8fd9\u4e9b\u7279\u5f81\u5927\u591a\u4e0d\u8a00\u81ea\u660e\u3002\u90a3\u4e9b\u4e0d\u660e\u767d\u7684\uff0c\u6211\u4eec\u53ef\u4ee5\u4e0d\u8003\u8651\u3002\u8ba9\u6211\u4eec\u5148\u5c1d\u8bd5\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u3002 \u6211\u4eec\u770b\u5230\u6536\u5165\u5217\u662f\u4e00\u4e2a\u5b57\u7b26\u4e32\u3002\u8ba9\u6211\u4eec\u5bf9\u8fd9\u4e00\u5217\u8fdb\u884c\u6570\u503c\u7edf\u8ba1\u3002 In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/adult.csv\" ) In [ X ]: df . income . value_counts () Out [ X ]: <= 50 K 24720 > 50 K 7841 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 7841 \u4e2a\u5b9e\u4f8b\u7684\u6536\u5165\u8d85\u8fc7 5 \u4e07\u7f8e\u5143\u3002\u8fd9\u5360\u6837\u672c\u603b\u6570\u7684 24%\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4fdd\u6301\u4e0e\u732b\u6570\u636e\u96c6\u76f8\u540c\u7684\u8bc4\u4f30\u65b9\u6cd5\uff0c\u5373 AUC\u3002 \u5728\u5f00\u59cb\u5efa\u6a21\u4e4b\u524d\uff0c\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u5c06\u53bb\u6389\u51e0\u5217\u7279\u5f81\uff0c\u5373 \u5b66\u5386\uff08fnlwgt\uff09 \u5e74\u9f84\uff08age\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 \u8ba9\u6211\u4eec\u8bd5\u7740\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u5668\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002\u7b2c\u4e00\u6b65\u603b\u662f\u8981\u8fdb\u884c\u4ea4\u53c9\u9a8c\u8bc1\u3002\u6211\u4e0d\u4f1a\u5728\u8fd9\u91cc\u5c55\u793a\u8fd9\u90e8\u5206\u4ee3\u7801\u3002\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) # \u9700\u8981\u5220\u9664\u7684\u5217 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) # \u6620\u5c04 target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } # \u4f7f\u7528\u6620\u5c04\u66ff\u6362 df . loc [:, \"income\" ] = df . income . map ( target_mapping ) # \u53d6\u9664\"kfold\", \"income\"\u5217\u7684\u5176\u4ed6\u5217\u540d features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u9a8c\u8bc1\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u6784\u5efa\u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . income . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u65f6\uff0c\u6211\u4eec\u4f1a\u5f97\u5230 \u276f python - W ignore ohe_logres . py Fold = 0 , AUC = 0.8794809708119079 Fold = 1 , AUC = 0.8875785068274882 Fold = 2 , AUC = 0.8852609687685753 Fold = 3 , AUC = 0.8681236223251438 Fold = 4 , AUC = 0.8728581541840037 \u5bf9\u4e8e\u4e00\u4e2a\u5982\u6b64\u7b80\u5355\u7684\u6a21\u578b\u6765\u8bf4\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u4e0d\u9519\u7684 AUC\uff01 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5728\u4e0d\u8c03\u6574\u4efb\u4f55\u8d85\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c1d\u8bd5\u4e00\u4e0b\u6807\u7b7e\u7f16\u7801\u7684xgboost\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8ba9\u6211\u4eec\u8fd0\u884c\u4e0a\u9762\u4ee3\u7801\uff1a \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8800810634234078 Fold = 1 , AUC = 0.886811884948154 Fold = 2 , AUC = 0.8854421433318472 Fold = 3 , AUC = 0.8676319549361007 Fold = 4 , AUC = 0.8714450054900602 \u8fd9\u770b\u8d77\u6765\u5df2\u7ecf\u76f8\u5f53\u4e0d\u9519\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b max_depth \u589e\u52a0\u5230 7 \u548c n_estimators \u589e\u52a0\u5230 200 \u65f6\u7684\u5f97\u5206\u3002 \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8764108944332032 Fold = 1 , AUC = 0.8840708537662638 Fold = 2 , AUC = 0.8816601162613102 Fold = 3 , AUC = 0.8662335762581732 Fold = 4 , AUC = 0.8698983461709926 \u770b\u8d77\u6765\u5e76\u6ca1\u6709\u6539\u5584\u3002 \u8fd9\u8868\u660e\uff0c\u4e00\u4e2a\u6570\u636e\u96c6\u7684\u53c2\u6570\u4e0d\u80fd\u79fb\u690d\u5230\u53e6\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6211\u4eec\u5fc5\u987b\u518d\u6b21\u5c1d\u8bd5\u8c03\u6574\u53c2\u6570\uff0c\u4f46\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u8bf4\u660e\u3002 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5728\u4e0d\u8c03\u6574\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c06\u6570\u503c\u7279\u5f81\u7eb3\u5165 xgboost \u6a21\u578b\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) # \u52a0\u5165\u6570\u503c\u7279\u5f81 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : if col not in num_cols : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u56e0\u6b64\uff0c\u6211\u4eec\u4fdd\u7559\u6570\u5b57\u5217\uff0c\u53ea\u662f\u4e0d\u5bf9\u5176\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u7684\u6700\u7ec8\u7279\u5f81\u77e9\u9635\u5c31\u7531\u6570\u5b57\u5217\uff08\u539f\u6837\uff09\u548c\u7f16\u7801\u5206\u7c7b\u5217\u7ec4\u6210\u4e86\u3002\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\u90fd\u80fd\u8f7b\u677e\u5904\u7406\u8fd9\u79cd\u6df7\u5408\u3002 \u8bf7\u6ce8\u610f\uff0c\u5728\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65f6\uff0c\u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u6570\u636e\u8fdb\u884c\u5f52\u4e00\u5316\u5904\u7406\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff0c\u5728\u4f7f\u7528\u7ebf\u6027\u6a21\u578b\uff08\u5982\u903b\u8f91\u56de\u5f52\uff09\u65f6\u4e0d\u5bb9\u5ffd\u89c6\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u4e2a\u811a\u672c\uff01 \u276f python lbl_xgb_num . py Fold = 0 , AUC = 0.9209790185449889 Fold = 1 , AUC = 0.9247157449144706 Fold = 2 , AUC = 0.9269329887598243 Fold = 3 , AUC = 0.9119349082169275 Fold = 4 , AUC = 0.9166408030141667 \u54c7\u54e6 \u8fd9\u662f\u4e00\u4e2a\u5f88\u597d\u7684\u5206\u6570\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c1d\u8bd5\u6dfb\u52a0\u4e00\u4e9b\u529f\u80fd\u3002\u6211\u4eec\u5c06\u63d0\u53d6\u6240\u6709\u5206\u7c7b\u5217\uff0c\u5e76\u521b\u5efa\u6240\u6709\u4e8c\u5ea6\u7ec4\u5408\u3002\u8bf7\u770b\u4e0b\u9762\u4ee3\u7801\u6bb5\u4e2d\u7684 feature_engineering \u51fd\u6570\uff0c\u4e86\u89e3\u5982\u4f55\u5b9e\u73b0\u8fd9\u4e00\u70b9\u3002 import itertools import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def feature_engineering ( df , cat_cols ): # \u751f\u6210\u4e24\u4e2a\u7279\u5f81\u7684\u7ec4\u5408 combi = list ( itertools . combinations ( cat_cols , 2 )) for c1 , c2 in combi : df . loc [:, c1 + \"_\" + c2 ] = df [ c1 ] . astype ( str ) + \"_\" + df [ c2 ] . astype ( str ) return df def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) cat_cols = [ c for c in df . columns if c not in num_cols and c not in ( \"kfold\" , \"income\" )] # \u7279\u5f81\u5de5\u7a0b df = feature_engineering ( df , cat_cols ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" )] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u8fd9\u662f\u4ece\u5206\u7c7b\u5217\u4e2d\u521b\u5efa\u7279\u5f81\u7684\u4e00\u79cd\u975e\u5e38\u5e7c\u7a1a\u7684\u65b9\u6cd5\u3002\u6211\u4eec\u5e94\u8be5\u4ed4\u7ec6\u7814\u7a76\u6570\u636e\uff0c\u770b\u770b\u54ea\u4e9b\u7ec4\u5408\u6700\u5408\u7406\u3002\u5982\u679c\u4f7f\u7528\u8fd9\u79cd\u65b9\u6cd5\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5c31\u9700\u8981\u4f7f\u7528\u67d0\u79cd\u7279\u5f81\u9009\u62e9\u6765\u9009\u51fa\u6700\u4f73\u7279\u5f81\u3002\u7a0d\u540e\u6211\u4eec\u5c06\u8be6\u7ec6\u4ecb\u7ecd\u7279\u5f81\u9009\u62e9\u3002\u73b0\u5728\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5206\u6570\u3002 \u276f python lbl_xgb_num_feat . py Fold = 0 , AUC = 0.9211483465031423 Fold = 1 , AUC = 0.9251499446866125 Fold = 2 , AUC = 0.9262344766486692 Fold = 3 , AUC = 0.9114264068794995 Fold = 4 , AUC = 0.9177914453099201 \u770b\u6765\uff0c\u5373\u4f7f\u4e0d\u6539\u53d8\u4efb\u4f55\u8d85\u53c2\u6570\uff0c\u53ea\u589e\u52a0\u4e00\u4e9b\u7279\u5f81\uff0c\u6211\u4eec\u4e5f\u80fd\u63d0\u9ad8\u4e00\u4e9b\u6298\u53e0\u5f97\u5206\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5c06 max_depth \u589e\u52a0\u5230 7 \u662f\u5426\u6709\u5e2e\u52a9\u3002 \u276f python lbl_xgb_num_feat . py Fold = 0 , AUC = 0.9286668430204137 Fold = 1 , AUC = 0.9329340656165378 Fold = 2 , AUC = 0.9319817543218744 Fold = 3 , AUC = 0.919046187194538 Fold = 4 , AUC = 0.9245692057162671 \u6211\u4eec\u518d\u6b21\u6539\u8fdb\u4e86\u6211\u4eec\u7684\u6a21\u578b\u3002 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u8fd8\u6ca1\u6709\u4f7f\u7528\u7a00\u6709\u503c\u3001\u4e8c\u503c\u5316\u3001\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u7279\u5f81\u7684\u7ec4\u5408\u4ee5\u53ca\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u3002 \u4ece\u5206\u7c7b\u7279\u5f81\u4e2d\u8fdb\u884c\u7279\u5f81\u5de5\u7a0b\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u4f7f\u7528 \u76ee\u6807\u7f16\u7801 \u3002\u4f46\u662f\uff0c\u60a8\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u8fd9\u53ef\u80fd\u4f1a\u4f7f\u60a8\u7684\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u76ee\u6807\u7f16\u7801\u662f\u4e00\u79cd\u5c06\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u6620\u5c04\u5230\u5176\u5e73\u5747\u76ee\u6807\u503c\u7684\u6280\u672f\uff0c\u4f46\u5fc5\u987b\u59cb\u7ec8\u4ee5\u4ea4\u53c9\u9a8c\u8bc1\u7684\u65b9\u5f0f\u8fdb\u884c\u3002\u8fd9\u610f\u5473\u7740\u9996\u5148\u8981\u521b\u5efa\u6298\u53e0\uff0c\u7136\u540e\u4f7f\u7528\u8fd9\u4e9b\u6298\u53e0\u4e3a\u6570\u636e\u7684\u4e0d\u540c\u5217\u521b\u5efa\u76ee\u6807\u7f16\u7801\u7279\u5f81\uff0c\u65b9\u6cd5\u4e0e\u5728\u6298\u53e0\u4e0a\u62df\u5408\u548c\u9884\u6d4b\u6a21\u578b\u7684\u65b9\u6cd5\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u521b\u5efa\u4e86 5 \u4e2a\u6298\u53e0\uff0c\u60a8\u5c31\u5fc5\u987b\u521b\u5efa 5 \u6b21\u76ee\u6807\u7f16\u7801\uff0c\u8fd9\u6837\u6700\u7ec8\uff0c\u60a8\u5c31\u53ef\u4ee5\u4e3a\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u53d8\u91cf\u521b\u5efa\u7f16\u7801\uff0c\u800c\u8fd9\u4e9b\u53d8\u91cf\u5e76\u975e\u6765\u81ea\u540c\u4e00\u4e2a\u6298\u53e0\u3002\u7136\u540e\u5728\u62df\u5408\u6a21\u578b\u65f6\uff0c\u5fc5\u987b\u518d\u6b21\u4f7f\u7528\u76f8\u540c\u7684\u6298\u53e0\u3002\u672a\u89c1\u6d4b\u8bd5\u6570\u636e\u7684\u76ee\u6807\u7f16\u7801\u53ef\u4ee5\u6765\u81ea\u5168\u90e8\u8bad\u7ec3\u6570\u636e\uff0c\u4e5f\u53ef\u4ee5\u662f\u6240\u6709 5 \u4e2a\u6298\u53e0\u7684\u5e73\u5747\u503c\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u5728\u540c\u4e00\u4e2a\u6210\u4eba\u6570\u636e\u96c6\u4e0a\u4f7f\u7528\u76ee\u6807\u7f16\u7801\uff0c\u4ee5\u4fbf\u8fdb\u884c\u6bd4\u8f83\u3002 import copy import pandas as pd from sklearn import metrics from sklearn import preprocessing import xgboost as xgb def mean_target_encoding ( data ): df = copy . deepcopy ( data ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) and f not in num_cols ] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) encoded_dfs = [] for fold in range ( 5 ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) for column in features : # \u76ee\u6807\u7f16\u7801 mapping_dict = dict ( df_train . groupby ( column )[ \"income\" ] . mean () ) df_valid . loc [:, column + \"_enc\" ] = df_valid [ column ] . map ( mapping_dict ) encoded_dfs . append ( df_valid ) encoded_df = pd . concat ( encoded_dfs , axis = 0 ) return encoded_df def run ( df , fold ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/adult_folds.csv\" ) df = mean_target_encoding ( df ) for fold_ in range ( 5 ): run ( df , fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u4e0a\u8ff0\u7247\u6bb5\u4e2d\uff0c\u6211\u5728\u8fdb\u884c\u76ee\u6807\u7f16\u7801\u65f6\u5e76\u6ca1\u6709\u5220\u9664\u5206\u7c7b\u5217\u3002\u6211\u4fdd\u7559\u4e86\u6240\u6709\u7279\u5f81\uff0c\u5e76\u5728\u6b64\u57fa\u7840\u4e0a\u6dfb\u52a0\u4e86\u76ee\u6807\u7f16\u7801\u7279\u5f81\u3002\u6b64\u5916\uff0c\u6211\u8fd8\u4f7f\u7528\u4e86\u5e73\u5747\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5e73\u5747\u503c\u3001\u4e2d\u4f4d\u6570\u3001\u6807\u51c6\u504f\u5dee\u6216\u76ee\u6807\u7684\u4efb\u4f55\u5176\u4ed6\u51fd\u6570\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u7ed3\u679c\u3002 Fold = 0 , AUC = 0.9332240662017529 Fold = 1 , AUC = 0.9363551625140347 Fold = 2 , AUC = 0.9375013544556173 Fold = 3 , AUC = 0.92237621307625 Fold = 4 , AUC = 0.9292131180445478 \u4e0d\u9519\uff01\u770b\u6765\u6211\u4eec\u53c8\u6709\u8fdb\u6b65\u4e86\u3002\u4e0d\u8fc7\uff0c\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u5b83\u592a\u5bb9\u6613\u51fa\u73b0\u8fc7\u5ea6\u62df\u5408\u3002\u5f53\u6211\u4eec\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\uff0c\u6700\u597d\u4f7f\u7528\u67d0\u79cd\u5e73\u6ed1\u65b9\u6cd5\u6216\u5728\u7f16\u7801\u503c\u4e2d\u6dfb\u52a0\u566a\u58f0\u3002 Scikit-learn \u7684\u8d21\u732e\u5e93\u4e2d\u6709\u5e26\u5e73\u6ed1\u7684\u76ee\u6807\u7f16\u7801\uff0c\u4f60\u4e5f\u53ef\u4ee5\u521b\u5efa\u81ea\u5df1\u7684\u5e73\u6ed1\u3002\u5e73\u6ed1\u4f1a\u5f15\u5165\u67d0\u79cd\u6b63\u5219\u5316\uff0c\u6709\u52a9\u4e8e\u907f\u514d\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u8fd9\u5e76\u4e0d\u96be\u3002 \u5904\u7406\u5206\u7c7b\u7279\u5f81\u662f\u4e00\u9879\u590d\u6742\u7684\u4efb\u52a1\u3002\u8bb8\u591a\u8d44\u6e90\u4e2d\u90fd\u6709\u5927\u91cf\u4fe1\u606f\u3002\u672c\u7ae0\u5e94\u8be5\u80fd\u5e2e\u52a9\u4f60\u5f00\u59cb\u89e3\u51b3\u5206\u7c7b\u53d8\u91cf\u7684\u4efb\u4f55\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u5bf9\u4e8e\u5927\u591a\u6570\u95ee\u9898\u6765\u8bf4\uff0c\u9664\u4e86\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u4e4b\u5916\uff0c\u4f60\u4e0d\u9700\u8981\u66f4\u591a\u7684\u4e1c\u897f\u3002 \u8981\u8fdb\u4e00\u6b65\u6539\u8fdb\u6a21\u578b\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u66f4\u591a\uff01 \u5728\u672c\u7ae0\u7684\u6700\u540e\uff0c\u6211\u4eec\u4e0d\u80fd\u4e0d\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e00\u79cd\u79f0\u4e3a \u5b9e\u4f53\u5d4c\u5165 \u7684\u6280\u672f\u3002\u5728\u5b9e\u4f53\u5d4c\u5165\u4e2d\uff0c\u7c7b\u522b\u7528\u5411\u91cf\u8868\u793a\u3002\u5728\u4e8c\u503c\u5316\u548c\u72ec\u70ed\u7f16\u7801\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u90fd\u662f\u7528\u5411\u91cf\u6765\u8868\u793a\u7c7b\u522b\u7684\u3002 \u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u6709\u6570\u4ee5\u4e07\u8ba1\u7684\u7c7b\u522b\u600e\u4e48\u529e\uff1f\u8fd9\u5c06\u4f1a\u4ea7\u751f\u5de8\u5927\u7684\u77e9\u9635\uff0c\u6211\u4eec\u5c06\u9700\u8981\u5f88\u957f\u65f6\u95f4\u6765\u8bad\u7ec3\u590d\u6742\u7684\u6a21\u578b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u7528\u5e26\u6709\u6d6e\u70b9\u503c\u7684\u5411\u91cf\u6765\u8868\u793a\u5b83\u4eec\u3002 \u8fd9\u4e2a\u60f3\u6cd5\u975e\u5e38\u7b80\u5355\u3002\u6bcf\u4e2a\u5206\u7c7b\u7279\u5f81\u90fd\u6709\u4e00\u4e2a\u5d4c\u5165\u5c42\u3002\u56e0\u6b64\uff0c\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u73b0\u5728\u90fd\u53ef\u4ee5\u6620\u5c04\u5230\u4e00\u4e2a\u5d4c\u5165\u5c42\uff08\u5c31\u50cf\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u5c06\u5355\u8bcd\u6620\u5c04\u5230\u5d4c\u5165\u5c42\u4e00\u6837\uff09\u3002\u7136\u540e\uff0c\u6839\u636e\u5176\u7ef4\u5ea6\u91cd\u5851\u8fd9\u4e9b\u5d4c\u5165\u5c42\uff0c\u4f7f\u5176\u6241\u5e73\u5316\uff0c\u7136\u540e\u5c06\u6240\u6709\u6241\u5e73\u5316\u7684\u8f93\u5165\u5d4c\u5165\u5c42\u8fde\u63a5\u8d77\u6765\u3002\u7136\u540e\u6dfb\u52a0\u4e00\u5806\u5bc6\u96c6\u5c42\u548c\u4e00\u4e2a\u8f93\u51fa\u5c42\uff0c\u5c31\u5927\u529f\u544a\u6210\u4e86\u3002 \u56fe 6\uff1a\u7c7b\u522b\u8f6c\u6362\u4e3a\u6d6e\u70b9\u6216\u5d4c\u5165\u5411\u91cf \u51fa\u4e8e\u67d0\u79cd\u539f\u56e0\uff0c\u6211\u53d1\u73b0\u4f7f\u7528 TF/Keras \u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528 TF/Keras \u5b9e\u73b0\u5b83\u3002\u6b64\u5916\uff0c\u8fd9\u662f\u672c\u4e66\u4e2d\u552f\u4e00\u4e00\u4e2a\u4f7f\u7528 TF/Keras \u7684\u793a\u4f8b\uff0c\u5c06\u5176\u8f6c\u6362\u4e3a PyTorch\uff08\u4f7f\u7528 cat-in-the-dat-ii \u6570\u636e\u96c6\uff09\u4e5f\u975e\u5e38\u5bb9\u6613 import os import gc import joblib import pandas as pd import numpy as np from sklearn import metrics , preprocessing from tensorflow.keras import layers from tensorflow.keras import optimizers from tensorflow.keras.models import Model , load_model from tensorflow.keras import callbacks from tensorflow.keras import backend as K from tensorflow.keras import utils def create_model ( data , catcols ): # \u521b\u5efa\u7a7a\u7684\u8f93\u5165\u5217\u8868\u548c\u8f93\u51fa\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6a21\u578b\u7684\u8f93\u5165\u548c\u8f93\u51fa inputs = [] outputs = [] # \u904d\u5386\u5206\u7c7b\u7279\u5f81\u5217\u8868\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81 for c in catcols : # \u8ba1\u7b97\u7279\u5f81\u4e2d\u552f\u4e00\u503c\u7684\u6570\u91cf num_unique_values = int ( data [ c ] . nunique ()) # \u8ba1\u7b97\u5d4c\u5165\u7ef4\u5ea6\uff0c\u6700\u5927\u4e0d\u8d85\u8fc750 embed_dim = int ( min ( np . ceil (( num_unique_values ) / 2 ), 50 )) # \u521b\u5efa\u6a21\u578b\u7684\u8f93\u5165\u5c42\uff0c\u6bcf\u4e2a\u7279\u5f81\u5bf9\u5e94\u4e00\u4e2a\u8f93\u5165 inp = layers . Input ( shape = ( 1 ,)) # \u521b\u5efa\u5d4c\u5165\u5c42\uff0c\u5c06\u5206\u7c7b\u7279\u5f81\u6620\u5c04\u5230\u4f4e\u7ef4\u5ea6\u7684\u8fde\u7eed\u5411\u91cf out = layers . Embedding ( num_unique_values + 1 , embed_dim , name = c )( inp ) # \u5bf9\u5d4c\u5165\u5c42\u8fdb\u884c\u7a7a\u95f4\u4e22\u5f03\uff08Dropout\uff09 out = layers . SpatialDropout1D ( 0.3 )( out ) # \u5c06\u5d4c\u5165\u5c42\u7684\u5f62\u72b6\u91cd\u65b0\u8c03\u6574\u4e3a\u4e00\u7ef4 out = layers . Reshape ( target_shape = ( embed_dim ,))( out ) # \u5c06\u8f93\u5165\u548c\u8f93\u51fa\u6dfb\u52a0\u5230\u5bf9\u5e94\u7684\u5217\u8868\u4e2d inputs . append ( inp ) outputs . append ( out ) # \u4f7f\u7528Concatenate\u5c42\u5c06\u6240\u6709\u7684\u5d4c\u5165\u5c42\u8f93\u51fa\u8fde\u63a5\u5728\u4e00\u8d77 x = layers . Concatenate ()( outputs ) # \u5bf9\u8fde\u63a5\u540e\u7684\u6570\u636e\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u6dfb\u52a0\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u6dfb\u52a0\u53e6\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u8f93\u51fa\u5c42\uff0c\u5177\u67092\u4e2a\u795e\u7ecf\u5143\uff08\u7528\u4e8e\u4e8c\u8fdb\u5236\u5206\u7c7b\uff09\uff0c\u5e76\u4f7f\u7528softmax\u6fc0\u6d3b\u51fd\u6570 y = layers . Dense ( 2 , activation = \"softmax\" )( x ) # \u521b\u5efa\u6a21\u578b\uff0c\u5c06\u8f93\u5165\u548c\u8f93\u51fa\u4f20\u9012\u7ed9Model\u6784\u9020\u51fd\u6570 model = Model ( inputs = inputs , outputs = y ) # \u7f16\u8bd1\u6a21\u578b\uff0c\u6307\u5b9a\u635f\u5931\u51fd\u6570\u548c\u4f18\u5316\u5668 model . compile ( loss = 'binary_crossentropy' , optimizer = 'adam' ) # \u8fd4\u56de\u521b\u5efa\u7684\u6a21\u578b return model def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for feat in features : lbl_enc = preprocessing . LabelEncoder () df . loc [:, feat ] = lbl_enc . fit_transform ( df [ feat ] . values ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) model = create_model ( df , features ) xtrain = [ df_train [ features ] . values [:, k ] for k in range ( len ( features ))] xvalid = [ df_valid [ features ] . values [:, k ] for k in range ( len ( features )) ] ytrain = df_train . target . values yvalid = df_valid . target . values ytrain_cat = utils . to_categorical ( ytrain ) yvalid_cat = utils . to_categorical ( yvalid ) model . fit ( xtrain , ytrain_cat , validation_data = ( xvalid , yvalid_cat ), verbose = 1 , batch_size = 1024 , epochs = 3 ) valid_preds = model . predict ( xvalid )[:, 1 ] print ( metrics . roc_auc_score ( yvalid , valid_preds )) K . clear_session () if __name__ == \"__main__\" : run ( 0 ) run ( 1 ) run ( 2 ) run ( 3 ) run ( 4 ) \u4f60\u4f1a\u53d1\u73b0\u8fd9\u79cd\u65b9\u6cd5\u6548\u679c\u6700\u597d\uff0c\u800c\u4e14\u5982\u679c\u4f60\u6709 GPU\uff0c\u901f\u5ea6\u4e5f\u8d85\u5feb\uff01\u8fd9\u79cd\u65b9\u6cd5\u8fd8\u53ef\u4ee5\u8fdb\u4e00\u6b65\u6539\u8fdb\uff0c\u800c\u4e14\u4f60\u65e0\u9700\u62c5\u5fc3\u7279\u5f81\u5de5\u7a0b\uff0c\u56e0\u4e3a\u795e\u7ecf\u7f51\u7edc\u4f1a\u81ea\u884c\u5904\u7406\u3002\u5728\u5904\u7406\u5927\u91cf\u5206\u7c7b\u7279\u5f81\u6570\u636e\u96c6\u65f6\uff0c\u8fd9\u7edd\u5bf9\u503c\u5f97\u4e00\u8bd5\u3002\u5f53\u5d4c\u5165\u5927\u5c0f\u4e0e\u552f\u4e00\u7c7b\u522b\u7684\u6570\u91cf\u76f8\u540c\u65f6\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u72ec\u70ed\u7f16\u7801\uff08one-hot-encoding\uff09\u3002 \u672c\u7ae0\u57fa\u672c\u4e0a\u90fd\u662f\u5173\u4e8e\u7279\u5f81\u5de5\u7a0b\u7684\u3002\u8ba9\u6211\u4eec\u5728\u4e0b\u4e00\u7ae0\u4e2d\u770b\u770b\u5982\u4f55\u5728\u6570\u5b57\u7279\u5f81\u548c\u4e0d\u540c\u7c7b\u578b\u7279\u5f81\u7684\u7ec4\u5408\u65b9\u9762\u8fdb\u884c\u66f4\u591a\u7684\u7279\u5f81\u5de5\u7a0b\u3002","title":"\u5904\u7406\u5206\u7c7b\u53d8\u91cf"},{"location":"%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%88%96%E5%9B%9E%E5%BD%92%E6%96%B9%E6%B3%95/","text":"\u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5 \u6587\u672c\u95ee\u9898\u662f\u6211\u7684\u6700\u7231\u3002\u4e00\u822c\u6765\u8bf4\uff0c\u8fd9\u4e9b\u95ee\u9898\u4e5f\u88ab\u79f0\u4e3a \u81ea\u7136\u8bed\u8a00\u5904\u7406\uff08NLP\uff09\u95ee\u9898 \u3002NLP \u95ee\u9898\u4e0e\u56fe\u50cf\u95ee\u9898\u4e5f\u6709\u5f88\u5927\u4e0d\u540c\u3002\u4f60\u9700\u8981\u521b\u5efa\u4ee5\u524d\u4ece\u672a\u4e3a\u8868\u683c\u95ee\u9898\u521b\u5efa\u8fc7\u7684\u6570\u636e\u7ba1\u9053\u3002\u4f60\u9700\u8981\u4e86\u89e3\u5546\u4e1a\u6848\u4f8b\uff0c\u624d\u80fd\u5efa\u7acb\u4e00\u4e2a\u597d\u7684\u6a21\u578b\u3002\u987a\u4fbf\u8bf4\u4e00\u53e5\uff0c\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u4efb\u4f55\u4e8b\u60c5\u90fd\u662f\u5982\u6b64\u3002\u5efa\u7acb\u6a21\u578b\u4f1a\u8ba9\u4f60\u8fbe\u5230\u4e00\u5b9a\u7684\u6c34\u5e73\uff0c\u4f46\u8981\u60f3\u6539\u5584\u548c\u4fc3\u8fdb\u4f60\u6240\u5efa\u7acb\u6a21\u578b\u7684\u4e1a\u52a1\uff0c\u4f60\u5fc5\u987b\u4e86\u89e3\u5b83\u5bf9\u4e1a\u52a1\u7684\u5f71\u54cd\u3002 NLP \u95ee\u9898\u6709\u5f88\u591a\u79cd\uff0c\u5176\u4e2d\u6700\u5e38\u89c1\u7684\u662f\u5b57\u7b26\u4e32\u5206\u7c7b\u3002\u5f88\u591a\u65f6\u5019\uff0c\u6211\u4eec\u4f1a\u770b\u5230\u4eba\u4eec\u5728\u5904\u7406\u8868\u683c\u6570\u636e\u6216\u56fe\u50cf\u65f6\u8868\u73b0\u51fa\u8272\uff0c\u4f46\u5728\u5904\u7406\u6587\u672c\u65f6\uff0c\u4ed6\u4eec\u751a\u81f3\u4e0d\u77e5\u9053\u4ece\u4f55\u5165\u624b\u3002\u6587\u672c\u6570\u636e\u4e0e\u5176\u4ed6\u7c7b\u578b\u7684\u6570\u636e\u96c6\u6ca1\u6709\u4ec0\u4e48\u4e0d\u540c\u3002\u5bf9\u4e8e\u8ba1\u7b97\u673a\u6765\u8bf4\uff0c\u4e00\u5207\u90fd\u662f\u6570\u5b57\u3002 \u5047\u8bbe\u6211\u4eec\u4ece\u60c5\u611f\u5206\u7c7b\u8fd9\u4e00\u57fa\u672c\u4efb\u52a1\u5f00\u59cb\u3002\u6211\u4eec\u5c06\u5c1d\u8bd5\u5bf9\u7535\u5f71\u8bc4\u8bba\u8fdb\u884c\u60c5\u611f\u5206\u7c7b\u3002\u56e0\u6b64\uff0c\u60a8\u6709\u4e00\u4e2a\u6587\u672c\uff0c\u5e76\u6709\u4e0e\u4e4b\u76f8\u5173\u7684\u60c5\u611f\u3002\u4f60\u5c06\u5982\u4f55\u5904\u7406\u8fd9\u7c7b\u95ee\u9898\uff1f\u662f\u5e94\u7528\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff1f \u4e0d\uff0c\u7edd\u5bf9\u9519\u4e86\u3002\u4f60\u8981\u4ece\u6700\u57fa\u672c\u7684\u5f00\u59cb\u3002\u8ba9\u6211\u4eec\u5148\u770b\u770b\u8fd9\u4e9b\u6570\u636e\u662f\u4ec0\u4e48\u6837\u5b50\u7684\u3002 \u6211\u4eec\u4ece IMDB \u7535\u5f71\u8bc4\u8bba\u6570\u636e\u96c6 \u5f00\u59cb\uff0c\u8be5\u6570\u636e\u96c6\u5305\u542b 25000 \u7bc7\u6b63\u9762\u60c5\u611f\u8bc4\u8bba\u548c 25000 \u7bc7\u8d1f\u9762\u60c5\u611f\u8bc4\u8bba\u3002 \u6211\u5c06\u5728\u6b64\u8ba8\u8bba\u7684\u6982\u5ff5\u51e0\u4e4e\u9002\u7528\u4e8e\u4efb\u4f55\u6587\u672c\u5206\u7c7b\u6570\u636e\u96c6\u3002 \u8fd9\u4e2a\u6570\u636e\u96c6\u975e\u5e38\u5bb9\u6613\u7406\u89e3\u3002\u4e00\u7bc7\u8bc4\u8bba\u5bf9\u5e94\u4e00\u4e2a\u76ee\u6807\u53d8\u91cf\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5199\u7684\u662f\u8bc4\u8bba\u800c\u4e0d\u662f\u53e5\u5b50\u3002\u8bc4\u8bba\u5c31\u662f\u4e00\u5806\u53e5\u5b50\u3002\u6240\u4ee5\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u4f60\u4e00\u5b9a\u53ea\u770b\u5230\u4e86\u5bf9\u5355\u53e5\u7684\u5206\u7c7b\uff0c\u4f46\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u6211\u4eec\u5c06\u5bf9\u591a\u4e2a\u53e5\u5b50\u8fdb\u884c\u5206\u7c7b\u3002\u7b80\u5355\u5730\u8bf4\uff0c\u8fd9\u610f\u5473\u7740\u4e0d\u4ec5\u4e00\u4e2a\u53e5\u5b50\u4f1a\u5bf9\u60c5\u611f\u4ea7\u751f\u5f71\u54cd\uff0c\u800c\u4e14\u60c5\u611f\u5f97\u5206\u662f\u591a\u4e2a\u53e5\u5b50\u5f97\u5206\u7684\u7ec4\u5408\u3002\u6570\u636e\u7b80\u4ecb\u5982\u56fe 1 \u6240\u793a\u3002 \u5982\u4f55\u7740\u624b\u89e3\u51b3\u8fd9\u6837\u7684\u95ee\u9898\uff1f\u4e00\u4e2a\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u624b\u5de5\u5236\u4f5c\u4e24\u4efd\u5355\u8bcd\u8868\u3002\u4e00\u4e2a\u5217\u8868\u5305\u542b\u4f60\u80fd\u60f3\u8c61\u5230\u7684\u6240\u6709\u6b63\u9762\u8bcd\u6c47\uff0c\u4f8b\u5982\u597d\u3001\u68d2\u3001\u597d\u7b49\uff1b\u53e6\u4e00\u4e2a\u5217\u8868\u5305\u542b\u6240\u6709\u8d1f\u9762\u8bcd\u6c47\uff0c\u4f8b\u5982\u574f\u3001\u6076\u7b49\u3002\u6211\u4eec\u5148\u4e0d\u8981\u4e3e\u4f8b\u8bf4\u660e\u574f\u8bcd\uff0c\u5426\u5219\u8fd9\u672c\u4e66\u5c31\u53ea\u80fd\u4f9b 18 \u5c81\u4ee5\u4e0a\u7684\u4eba\u9605\u8bfb\u4e86\u3002\u4e00\u65e6\u4f60\u6709\u4e86\u8fd9\u4e9b\u5217\u8868\uff0c\u4f60\u751a\u81f3\u4e0d\u9700\u8981\u4e00\u4e2a\u6a21\u578b\u6765\u8fdb\u884c\u9884\u6d4b\u3002\u8fd9\u4e9b\u5217\u8868\u4e5f\u88ab\u79f0\u4e3a\u60c5\u611f\u8bcd\u5178\u3002\u4f60\u53ef\u4ee5\u7528\u4e00\u4e2a\u7b80\u5355\u7684\u8ba1\u6570\u5668\u6765\u8ba1\u7b97\u53e5\u5b50\u4e2d\u6b63\u9762\u548c\u8d1f\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u3002\u5982\u679c\u6b63\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u8f83\u591a\uff0c\u5219\u8868\u793a\u8be5\u53e5\u5b50\u5177\u6709\u6b63\u9762\u60c5\u611f\uff1b\u5982\u679c\u8d1f\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u8f83\u591a\uff0c\u5219\u8868\u793a\u8be5\u53e5\u5b50\u5177\u6709\u8d1f\u9762\u60c5\u611f\u3002\u5982\u679c\u53e5\u5b50\u4e2d\u6ca1\u6709\u8fd9\u4e9b\u8bcd\uff0c\u5219\u53ef\u4ee5\u8bf4\u8be5\u53e5\u5b50\u5177\u6709\u4e2d\u6027\u60c5\u611f\u3002\u8fd9\u662f\u6700\u53e4\u8001\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u73b0\u5728\u4ecd\u6709\u4eba\u5728\u4f7f\u7528\u3002\u5b83\u4e5f\u4e0d\u9700\u8981\u592a\u591a\u4ee3\u7801\u3002 def find_sentiment ( sentence , pos , neg ): sentence = sentence . split () sentence = set ( sentence ) num_common_pos = len ( sentence . intersection ( pos )) num_common_neg = len ( sentence . intersection ( neg )) if num_common_pos > num_common_neg : return \"positive\" if num_common_pos < num_common_neg : return \"negative\" return \"neutral\" \u4e0d\u8fc7\uff0c\u8fd9\u79cd\u65b9\u6cd5\u8003\u8651\u7684\u56e0\u7d20\u5e76\u4e0d\u591a\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u6211\u4eec\u7684 split() \u4e5f\u5e76\u4e0d\u5b8c\u7f8e\u3002\u5982\u679c\u4f7f\u7528 split()\uff0c\u5c31\u4f1a\u51fa\u73b0\u8fd9\u6837\u7684\u53e5\u5b50\uff1a \"hi, how are you?\" \u7ecf\u8fc7\u5206\u5272\u540e\u53d8\u4e3a\uff1a [\"hi,\", \"how\",\"are\",\"you?\"] \u8fd9\u79cd\u65b9\u6cd5\u5e76\u4e0d\u7406\u60f3\uff0c\u56e0\u4e3a\u5355\u8bcd\u4e2d\u5305\u542b\u4e86\u9017\u53f7\u548c\u95ee\u53f7\uff0c\u5b83\u4eec\u5e76\u6ca1\u6709\u88ab\u5206\u5272\u3002\u56e0\u6b64\uff0c\u5982\u679c\u6ca1\u6709\u5728\u5206\u5272\u524d\u5bf9\u8fd9\u4e9b\u7279\u6b8a\u5b57\u7b26\u8fdb\u884c\u9884\u5904\u7406\uff0c\u4e0d\u5efa\u8bae\u4f7f\u7528\u8fd9\u79cd\u65b9\u6cd5\u3002\u5c06\u5b57\u7b26\u4e32\u62c6\u5206\u4e3a\u5355\u8bcd\u5217\u8868\u79f0\u4e3a\u6807\u8bb0\u5316\u3002\u6700\u6d41\u884c\u7684\u6807\u8bb0\u5316\u65b9\u6cd5\u4e4b\u4e00\u6765\u81ea NLTK\uff08\u81ea\u7136\u8bed\u8a00\u5de5\u5177\u5305\uff09 \u3002 In [ X ]: from nltk.tokenize import word_tokenize In [ X ]: sentence = \"hi, how are you?\" In [ X ]: sentence . split () Out [ X ]: [ 'hi,' , 'how' , 'are' , 'you?' ] In [ X ]: word_tokenize ( sentence ) Out [ X ]: [ 'hi' , ',' , 'how' , 'are' , 'you' , '?' ] \u6b63\u5982\u60a8\u6240\u770b\u5230\u7684\uff0c\u4f7f\u7528 NLTK \u7684\u5355\u8bcd\u6807\u8bb0\u5316\u529f\u80fd\uff0c\u540c\u4e00\u4e2a\u53e5\u5b50\u7684\u62c6\u5206\u6548\u679c\u8981\u597d\u5f97\u591a\u3002\u4f7f\u7528\u5355\u8bcd\u5217\u8868\u8fdb\u884c\u5bf9\u6bd4\u7684\u6548\u679c\u4e5f\u4f1a\u66f4\u597d\uff01\u8fd9\u5c31\u662f\u6211\u4eec\u5c06\u5e94\u7528\u4e8e\u7b2c\u4e00\u4e2a\u60c5\u611f\u68c0\u6d4b\u6a21\u578b\u7684\u65b9\u6cd5\u3002 \u5728\u5904\u7406 NLP \u5206\u7c7b\u95ee\u9898\u65f6\uff0c\u60a8\u5e94\u8be5\u7ecf\u5e38\u5c1d\u8bd5\u7684\u57fa\u672c\u6a21\u578b\u4e4b\u4e00\u662f \u8bcd\u888b\u6a21\u578b\uff08bag of words\uff09 \u3002\u5728\u8bcd\u888b\u6a21\u578b\u4e2d\uff0c\u6211\u4eec\u521b\u5efa\u4e00\u4e2a\u5de8\u5927\u7684\u7a00\u758f\u77e9\u9635\uff0c\u5b58\u50a8\u8bed\u6599\u5e93\uff08\u8bed\u6599\u5e93=\u6240\u6709\u6587\u6863=\u6240\u6709\u53e5\u5b50\uff09\u4e2d\u6240\u6709\u5355\u8bcd\u7684\u8ba1\u6570\u3002\u4e3a\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 CountVectorizer\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b83\u662f\u5982\u4f55\u5de5\u4f5c\u7684\u3002 from sklearn.feature_extraction.text import CountVectorizer corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] ctv = CountVectorizer () ctv . fit ( corpus ) corpus_transformed = ctv . transform ( corpus ) \u5982\u679c\u6211\u4eec\u6253\u5370 corpus_transformed\uff0c\u5c31\u4f1a\u5f97\u5230\u7c7b\u4f3c\u4e0b\u9762\u7684\u7ed3\u679c\uff1a ( 0 , 2 ) 1 ( 0 , 9 ) 1 ( 0 , 11 ) 1 ( 0 , 22 ) 1 ( 1 , 1 ) 1 ( 1 , 3 ) 1 ( 1 , 4 ) 1 ( 1 , 7 ) 1 ( 1 , 8 ) 1 ( 1 , 10 ) 1 ( 1 , 13 ) 1 ( 1 , 17 ) 1 ( 1 , 19 ) 1 ( 1 , 22 ) 2 ( 2 , 0 ) 1 ( 2 , 5 ) 1 ( 2 , 6 ) 1 ( 2 , 14 ) 1 ( 2 , 22 ) 1 ( 3 , 12 ) 1 ( 3 , 15 ) 1 ( 3 , 16 ) 1 ( 3 , 18 ) 1 ( 3 , 20 ) 1 ( 4 , 21 ) 1 \u5728\u524d\u9762\u7684\u7ae0\u8282\u4e2d\uff0c\u6211\u4eec\u5df2\u7ecf\u89c1\u8bc6\u8fc7\u8fd9\u79cd\u8868\u793a\u6cd5\u3002\u5373\u7a00\u758f\u8868\u793a\u6cd5\u3002\u56e0\u6b64\uff0c\u8bed\u6599\u5e93\u73b0\u5728\u662f\u4e00\u4e2a\u7a00\u758f\u77e9\u9635\uff0c\u5176\u4e2d\u7b2c\u4e00\u4e2a\u6837\u672c\u6709 4 \u4e2a\u5143\u7d20\uff0c\u7b2c\u4e8c\u4e2a\u6837\u672c\u6709 10 \u4e2a\u5143\u7d20\uff0c\u4ee5\u6b64\u7c7b\u63a8\uff0c\u7b2c\u4e09\u4e2a\u6837\u672c\u6709 5 \u4e2a\u5143\u7d20\uff0c\u4ee5\u6b64\u7c7b\u63a8\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u4e9b\u5143\u7d20\u90fd\u6709\u76f8\u5173\u7684\u8ba1\u6570\u3002\u6709\u4e9b\u5143\u7d20\u4f1a\u51fa\u73b0\u4e24\u6b21\uff0c\u6709\u4e9b\u5219\u53ea\u6709\u4e00\u6b21\u3002\u4f8b\u5982\uff0c\u5728\u6837\u672c 2\uff08\u7b2c 1 \u884c\uff09\u4e2d\uff0c\u6211\u4eec\u770b\u5230\u7b2c 22 \u5217\u7684\u6570\u503c\u662f 2\u3002\u8fd9\u662f\u4e3a\u4ec0\u4e48\u5462\uff1f\u7b2c 22 \u5217\u662f\u4ec0\u4e48\uff1f CountVectorizer \u7684\u5de5\u4f5c\u65b9\u5f0f\u662f\u9996\u5148\u5bf9\u53e5\u5b50\u8fdb\u884c\u6807\u8bb0\u5316\u5904\u7406\uff0c\u7136\u540e\u4e3a\u6bcf\u4e2a\u6807\u8bb0\u8d4b\u503c\u3002\u56e0\u6b64\uff0c\u6bcf\u4e2a\u6807\u8bb0\u90fd\u7531\u4e00\u4e2a\u552f\u4e00\u7d22\u5f15\u8868\u793a\u3002\u8fd9\u4e9b\u552f\u4e00\u7d22\u5f15\u5c31\u662f\u6211\u4eec\u770b\u5230\u7684\u5217\u3002CountVectorizer \u4f1a\u5b58\u50a8\u8fd9\u4e9b\u4fe1\u606f\u3002 print ( ctv . vocabulary_ ) { 'hello' : 9 , 'how' : 11 , 'are' : 2 , 'you' : 22 , 'im' : 13 , 'getting' : 8 , 'bored' : 4 , 'at' : 3 , 'home' : 10 , 'and' : 1 , 'what' : 19 , 'do' : 7 , 'think' : 17 , 'did' : 6 , 'know' : 14 , 'about' : 0 , 'counts' : 5 , 'let' : 15 , 'see' : 16 , 'if' : 12 , 'this' : 18 , 'works' : 20 , 'yes' : 21 } \u6211\u4eec\u770b\u5230\uff0c\u7d22\u5f15 22 \u5c5e\u4e8e \"you\"\uff0c\u800c\u5728\u7b2c\u4e8c\u53e5\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u4e86\u4e24\u6b21 \"you\"\u3002\u6211\u5e0c\u671b\u5927\u5bb6\u73b0\u5728\u5df2\u7ecf\u6e05\u695a\u4ec0\u4e48\u662f\u8bcd\u888b\u4e86\u3002\u4f46\u662f\u6211\u4eec\u8fd8\u7f3a\u5c11\u4e00\u4e9b\u7279\u6b8a\u5b57\u7b26\u3002\u6709\u65f6\uff0c\u8fd9\u4e9b\u7279\u6b8a\u5b57\u7b26\u4e5f\u5f88\u6709\u7528\u3002\u4f8b\u5982\uff0c\"? \"\u5728\u5927\u591a\u6570\u53e5\u5b50\u4e2d\u8868\u793a\u7591\u95ee\u53e5\u3002\u8ba9\u6211\u4eec\u628a scikit-learn \u7684 word_tokenize \u6574\u5408\u5230 CountVectorizer \u4e2d\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002 from sklearn.feature_extraction.text import CountVectorizer from nltk.tokenize import word_tokenize corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] ctv = CountVectorizer ( tokenizer = word_tokenize , token_pattern = None ) ctv . fit ( corpus ) corpus_transformed = ctv . transform ( corpus ) print ( ctv . vocabulary_ ) \u8fd9\u6837\uff0c\u6211\u4eec\u7684\u8bcd\u888b\u5c31\u53d8\u6210\u4e86\uff1a { 'hello' : 14 , ',' : 2 , 'how' : 16 , 'are' : 7 , 'you' : 27 , '?' : 4 , 'im' : 18 , 'getting' : 13 , 'bored' : 9 , 'at' : 8 , 'home' : 15 , '.' : 3 , 'and' : 6 , 'what' : 24 , 'do' : 12 , 'think' : 22 , 'did' : 11 , 'know' : 19 , 'about' : 5 , 'counts' : 10 , 'let' : 20 , \"'s\" : 1 , 'see' : 21 , 'if' : 17 , 'this' : 23 , 'works' : 25 , '!' : 0 , 'yes' : 26 } \u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u5229\u7528 IMDB \u6570\u636e\u96c6\u4e2d\u7684\u6240\u6709\u53e5\u5b50\u521b\u5efa\u4e00\u4e2a\u7a00\u758f\u77e9\u9635\uff0c\u5e76\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u3002\u8be5\u6570\u636e\u96c6\u4e2d\u6b63\u8d1f\u6837\u672c\u7684\u6bd4\u4f8b\u4e3a 1:1\uff0c\u56e0\u6b64\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u51c6\u786e\u7387\u4f5c\u4e3a\u8861\u91cf\u6807\u51c6\u3002\u6211\u4eec\u5c06\u4f7f\u7528 StratifiedKFold \u5e76\u521b\u5efa\u4e00\u4e2a\u811a\u672c\u6765\u8bad\u7ec35\u4e2a\u6298\u53e0\u3002\u4f60\u4f1a\u95ee\u4f7f\u7528\u54ea\u4e2a\u6a21\u578b\uff1f\u5bf9\u4e8e\u9ad8\u7ef4\u7a00\u758f\u6570\u636e\uff0c\u54ea\u4e2a\u6a21\u578b\u6700\u5feb\uff1f\u903b\u8f91\u56de\u5f52\u3002\u6211\u4eec\u5c06\u9996\u5148\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u6765\u5904\u7406\u8fd9\u4e2a\u6570\u636e\u96c6\uff0c\u5e76\u521b\u5efa\u7b2c\u4e00\u4e2a\u57fa\u51c6\u6a21\u578b\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import CountVectorizer if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df [ \"kfold\" ] = - 1 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f for fold_ in range ( 5 ): train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) count_vec = CountVectorizer ( tokenizer = word_tokenize , token_pattern = None ) count_vec . fit ( train_df . review ) xtrain = count_vec . transform ( train_df . review ) xtest = count_vec . transform ( test_df . review ) model = linear_model . LogisticRegression () model . fit ( xtrain , train_df . sentiment ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( test_df . sentiment , preds ) print ( f \"Fold: { fold_ } \" ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u8fd9\u6bb5\u4ee3\u7801\u7684\u8fd0\u884c\u9700\u8981\u4e00\u5b9a\u7684\u65f6\u95f4\uff0c\u4f46\u53ef\u4ee5\u5f97\u5230\u4ee5\u4e0b\u8f93\u51fa\u7ed3\u679c\uff1a Fold : 0 Accuracy = 0.8903 Fold : 1 Accuracy = 0.897 Fold : 2 Accuracy = 0.891 Fold : 3 Accuracy = 0.8914 Fold : 4 Accuracy = 0.8931 \u54c7\uff0c\u51c6\u786e\u7387\u5df2\u7ecf\u8fbe\u5230 89%\uff0c\u800c\u6211\u4eec\u6240\u505a\u7684\u53ea\u662f\u4f7f\u7528\u8bcd\u888b\u548c\u903b\u8f91\u56de\u5f52\uff01\u8fd9\u771f\u662f\u592a\u68d2\u4e86\uff01\u4e0d\u8fc7\uff0c\u8fd9\u4e2a\u6a21\u578b\u7684\u8bad\u7ec3\u82b1\u8d39\u4e86\u5f88\u591a\u65f6\u95f4\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u80fd\u5426\u901a\u8fc7\u4f7f\u7528\u6734\u7d20\u8d1d\u53f6\u65af\u5206\u7c7b\u5668\u6765\u7f29\u77ed\u8bad\u7ec3\u65f6\u95f4\u3002\u6734\u7d20\u8d1d\u53f6\u65af\u5206\u7c7b\u5668\u5728 NLP \u4efb\u52a1\u4e2d\u76f8\u5f53\u6d41\u884c\uff0c\u56e0\u4e3a\u7a00\u758f\u77e9\u9635\u975e\u5e38\u5e9e\u5927\uff0c\u800c\u6734\u7d20\u8d1d\u53f6\u65af\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6a21\u578b\u3002\u8981\u4f7f\u7528\u8fd9\u4e2a\u6a21\u578b\uff0c\u9700\u8981\u66f4\u6539\u4e00\u4e2a\u5bfc\u5165\u548c\u6a21\u578b\u7684\u884c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e2a\u6a21\u578b\u7684\u6027\u80fd\u5982\u4f55\u3002\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 MultinomialNB\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import naive_bayes from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import CountVectorizer model = naive_bayes . MultinomialNB () model . fit ( xtrain , train_df . sentiment ) \u5f97\u5230\u5982\u4e0b\u7ed3\u679c\uff1a Fold : 0 Accuracy = 0.8444 Fold : 1 Accuracy = 0.8499 Fold : 2 Accuracy = 0.8422 Fold : 3 Accuracy = 0.8443 Fold : 4 Accuracy = 0.8455 \u6211\u4eec\u770b\u5230\u8fd9\u4e2a\u5206\u6570\u5f88\u4f4e\u3002\u4f46\u6734\u7d20\u8d1d\u53f6\u65af\u6a21\u578b\u7684\u901f\u5ea6\u975e\u5e38\u5feb\u3002 NLP \u4e2d\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f TF-IDF\uff0c\u5982\u4eca\u5927\u591a\u6570\u4eba\u90fd\u503e\u5411\u4e8e\u5ffd\u7565\u6216\u4e0d\u5c51\u4e8e\u4e86\u89e3\u8fd9\u79cd\u65b9\u6cd5\u3002TF \u662f\u672f\u8bed\u9891\u7387\uff0cIDF \u662f\u53cd\u5411\u6587\u6863\u9891\u7387\u3002\u4ece\u8fd9\u4e9b\u672f\u8bed\u6765\u770b\uff0c\u8fd9\u4f3c\u4e4e\u6709\u4e9b\u56f0\u96be\uff0c\u4f46\u901a\u8fc7 TF \u548c IDF \u7684\u8ba1\u7b97\u516c\u5f0f\uff0c\u4e8b\u60c5\u5c31\u4f1a\u53d8\u5f97\u5f88\u660e\u663e\u3002 $$ TF(t) = \\frac{Number\\ of\\ times\\ a\\ term\\ t\\ appears\\ in\\ a\\ document}{Total\\ number\\ of\\ terms\\ in \\ the\\ document} $$ \\[ IDF(t) = LOG\\left(\\frac{Total\\ number\\ of\\ documents}{Number\\ of\\ documents with\\ term\\ t\\ in\\ it}\\right) \\] \u672f\u8bed t \u7684 TF-IDF \u5b9a\u4e49\u4e3a\uff1a $$ TF-IDF(t) = TF(t) \\times IDF(t) $$ \u4e0e scikit-learn \u4e2d\u7684 CountVectorizer \u7c7b\u4f3c\uff0c\u6211\u4eec\u4e5f\u6709 TfidfVectorizer\u3002\u8ba9\u6211\u4eec\u8bd5\u7740\u50cf\u4f7f\u7528 CountVectorizer \u4e00\u6837\u4f7f\u7528\u5b83\u3002 from sklearn.feature_extraction.text import TfidfVectorizer from nltk.tokenize import word_tokenize corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] tfv = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfv . fit ( corpus ) corpus_transformed = tfv . transform ( corpus ) print ( corpus_transformed ) \u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a ( 0 , 27 ) 0.2965698850220162 ( 0 , 16 ) 0.4428321995085722 ( 0 , 14 ) 0.4428321995085722 ( 0 , 7 ) 0.4428321995085722 ( 0 , 4 ) 0.35727423026525224 ( 0 , 2 ) 0.4428321995085722 ( 1 , 27 ) 0.35299699146792735 ( 1 , 24 ) 0.2635440111190765 ( 1 , 22 ) 0.2635440111190765 ( 1 , 18 ) 0.2635440111190765 ( 1 , 15 ) 0.2635440111190765 ( 1 , 13 ) 0.2635440111190765 ( 1 , 12 ) 0.2635440111190765 ( 1 , 9 ) 0.2635440111190765 ( 1 , 8 ) 0.2635440111190765 ( 1 , 6 ) 0.2635440111190765 ( 1 , 4 ) 0.42525129752567803 ( 1 , 3 ) 0.2635440111190765 ( 2 , 27 ) 0.31752680284846835 ( 2 , 19 ) 0.4741246485558491 ( 2 , 11 ) 0.4741246485558491 ( 2 , 10 ) 0.4741246485558491 ( 2 , 5 ) 0.4741246485558491 ( 3 , 25 ) 0.38775666010579296 ( 3 , 23 ) 0.38775666010579296 ( 3 , 21 ) 0.38775666010579296 ( 3 , 20 ) 0.38775666010579296 ( 3 , 17 ) 0.38775666010579296 ( 3 , 1 ) 0.38775666010579296 ( 3 , 0 ) 0.3128396318588854 ( 4 , 26 ) 0.2959842226518677 ( 4 , 0 ) 0.9551928286692534 \u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u6b21\u6211\u4eec\u5f97\u5230\u7684\u4e0d\u662f\u6574\u6570\u503c\uff0c\u800c\u662f\u6d6e\u70b9\u6570\u3002 \u7528 TfidfVectorizer \u4ee3\u66ff CountVectorizer \u4e5f\u662f\u5c0f\u83dc\u4e00\u789f\u3002Scikit-learn \u8fd8\u63d0\u4f9b\u4e86 TfidfTransformer\u3002\u5982\u679c\u4f60\u4f7f\u7528\u7684\u662f\u8ba1\u6570\u503c\uff0c\u53ef\u4ee5\u4f7f\u7528 TfidfTransformer \u5e76\u83b7\u5f97\u4e0e TfidfVectorizer \u76f8\u540c\u7684\u6548\u679c\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import TfidfVectorizer for fold_ in range ( 5 ): train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) tfidf_vec = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfidf_vec . fit ( train_df . review ) xtrain = tfidf_vec . transform ( train_df . review ) xtest = tfidf_vec . transform ( test_df . review ) model = linear_model . LogisticRegression () model . fit ( xtrain , train_df . sentiment ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( test_df . sentiment , preds ) print ( f \"Fold: { fold_ } \" ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u6211\u4eec\u53ef\u4ee5\u770b\u770b TF-IDF \u5728\u903b\u8f91\u56de\u5f52\u6a21\u578b\u4e0a\u7684\u8868\u73b0\u5982\u4f55\u3002 Fold : 0 Accuracy = 0.8976 Fold : 1 Accuracy = 0.8998 Fold : 2 Accuracy = 0.8948 Fold : 3 Accuracy = 0.8912 Fold : 4 Accuracy = 0.8995 \u6211\u4eec\u770b\u5230\uff0c\u8fd9\u4e9b\u5206\u6570\u90fd\u6bd4 CountVectorizer \u9ad8\u4e00\u4e9b\uff0c\u56e0\u6b64\u5b83\u6210\u4e3a\u4e86\u6211\u4eec\u60f3\u8981\u51fb\u8d25\u7684\u65b0\u57fa\u51c6\u3002 NLP \u4e2d\u53e6\u4e00\u4e2a\u6709\u8da3\u7684\u6982\u5ff5\u662f N-gram\u3002N-grams \u662f\u6309\u987a\u5e8f\u6392\u5217\u7684\u5355\u8bcd\u7ec4\u5408\u3002N-grams \u5f88\u5bb9\u6613\u521b\u5efa\u3002\u60a8\u53ea\u9700\u6ce8\u610f\u987a\u5e8f\u5373\u53ef\u3002\u4e3a\u4e86\u8ba9\u4e8b\u60c5\u53d8\u5f97\u66f4\u7b80\u5355\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 NLTK \u7684 N-gram \u5b9e\u73b0\u3002 from nltk import ngrams from nltk.tokenize import word_tokenize N = 3 sentence = \"hi, how are you?\" tokenized_sentence = word_tokenize ( sentence ) n_grams = list ( ngrams ( tokenized_sentence , N )) print ( n_grams ) \u7531\u6b64\u5f97\u5230\uff1a [( 'hi' , ',' , 'how' ), ( ',' , 'how' , 'are' ), ( 'how' , 'are' , 'you' ), ( 'are' , 'you' , '?' )] \u540c\u6837\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u521b\u5efa 2-gram \u6216 4-gram \u7b49\u3002\u73b0\u5728\uff0c\u8fd9\u4e9b n-gram \u5c06\u6210\u4e3a\u6211\u4eec\u8bcd\u6c47\u8868\u7684\u4e00\u90e8\u5206\uff0c\u5f53\u6211\u4eec\u8ba1\u7b97\u8ba1\u6570\u6216 tf-idf \u65f6\uff0c\u6211\u4eec\u4f1a\u5c06\u4e00\u4e2a n-gram \u89c6\u4e3a\u4e00\u4e2a\u5168\u65b0\u7684\u6807\u8bb0\u3002\u56e0\u6b64\uff0c\u5728\u67d0\u79cd\u7a0b\u5ea6\u4e0a\uff0c\u6211\u4eec\u662f\u5728\u7ed3\u5408\u4e0a\u4e0b\u6587\u3002scikit-learn \u7684 CountVectorizer \u548c TfidfVectorizer \u5b9e\u73b0\u90fd\u901a\u8fc7 ngram_range \u53c2\u6570\u63d0\u4f9b n-gram\uff0c\u8be5\u53c2\u6570\u6709\u6700\u5c0f\u548c\u6700\u5927\u9650\u5236\u3002\u9ed8\u8ba4\u60c5\u51b5\u4e0b\uff0c\u8be5\u53c2\u6570\u4e3a\uff081, 1\uff09\u3002\u5f53\u6211\u4eec\u5c06\u5176\u6539\u4e3a (1, 3) \u65f6\uff0c\u6211\u4eec\u5c06\u770b\u5230\u5355\u5b57\u5143\u3001\u53cc\u5b57\u5143\u548c\u4e09\u5b57\u5143\u3002\u4ee3\u7801\u6539\u52a8\u5f88\u5c0f\u3002 \u7531\u4e8e\u5230\u76ee\u524d\u4e3a\u6b62\u6211\u4eec\u4f7f\u7528 tf-idf \u5f97\u5230\u4e86\u6700\u597d\u7684\u7ed3\u679c\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5305\u542b n-grams \u76f4\u81f3 trigrams \u662f\u5426\u80fd\u6539\u8fdb\u6a21\u578b\u3002\u552f\u4e00\u9700\u8981\u4fee\u6539\u7684\u662f TfidfVectorizer \u7684\u521d\u59cb\u5316\u3002 tfidf_vec = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None , ngram_range = ( 1 , 3 ) ) \u8ba9\u6211\u4eec\u770b\u770b\u662f\u5426\u4f1a\u6709\u6539\u8fdb\u3002 Fold : 0 Accuracy = 0.8931 Fold : 1 Accuracy = 0.8941 Fold : 2 Accuracy = 0.897 Fold : 3 Accuracy = 0.8922 Fold : 4 Accuracy = 0.8847 \u770b\u8d77\u6765\u8fd8\u884c\uff0c\u4f46\u6211\u4eec\u770b\u4e0d\u5230\u4efb\u4f55\u6539\u8fdb\u3002 \u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u591a\u4f7f\u7528 bigrams \u6765\u83b7\u5f97\u6539\u8fdb\u3002 \u6211\u4e0d\u4f1a\u5728\u8fd9\u91cc\u5c55\u793a\u8fd9\u4e00\u90e8\u5206\u3002\u4e5f\u8bb8\u4f60\u53ef\u4ee5\u81ea\u5df1\u8bd5\u7740\u505a\u3002 NLP \u7684\u57fa\u7840\u77e5\u8bc6\u8fd8\u6709\u5f88\u591a\u3002\u4f60\u5fc5\u987b\u77e5\u9053\u7684\u4e00\u4e2a\u672f\u8bed\u662f\u8bcd\u5e72\u63d0\u53d6\uff08strmming\uff09\u3002\u53e6\u4e00\u4e2a\u662f\u8bcd\u5f62\u8fd8\u539f\uff08lemmatization\uff09\u3002 \u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f \u53ef\u4ee5\u5c06\u4e00\u4e2a\u8bcd\u51cf\u5c11\u5230\u6700\u5c0f\u5f62\u5f0f\u3002\u5728\u8bcd\u5e72\u63d0\u53d6\u7684\u60c5\u51b5\u4e0b\uff0c\u5904\u7406\u540e\u7684\u5355\u8bcd\u79f0\u4e3a\u8bcd\u5e72\u5355\u8bcd\uff0c\u800c\u5728\u8bcd\u5f62\u8fd8\u539f\u60c5\u51b5\u4e0b\uff0c\u5904\u7406\u540e\u7684\u5355\u8bcd\u79f0\u4e3a\u8bcd\u5f62\u3002\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u8bcd\u5f62\u8fd8\u539f\u6bd4\u8bcd\u5e72\u63d0\u53d6\u66f4\u6fc0\u8fdb\uff0c\u800c\u8bcd\u5e72\u63d0\u53d6\u66f4\u6d41\u884c\u548c\u5e7f\u6cdb\u3002\u8bcd\u5e72\u548c\u8bcd\u5f62\u90fd\u6765\u81ea\u8bed\u8a00\u5b66\u3002\u5982\u679c\u4f60\u6253\u7b97\u4e3a\u67d0\u79cd\u8bed\u8a00\u5236\u4f5c\u8bcd\u5e72\u6216\u8bcd\u578b\uff0c\u9700\u8981\u5bf9\u8be5\u8bed\u8a00\u6709\u6df1\u5165\u7684\u4e86\u89e3\u3002\u5982\u679c\u8981\u8fc7\u591a\u5730\u4ecb\u7ecd\u8fd9\u4e9b\u77e5\u8bc6\uff0c\u5c31\u610f\u5473\u7740\u8981\u5728\u672c\u4e66\u4e2d\u589e\u52a0\u4e00\u7ae0\u3002\u4f7f\u7528 NLTK \u8f6f\u4ef6\u5305\u53ef\u4ee5\u8f7b\u677e\u5b8c\u6210\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e24\u79cd\u65b9\u6cd5\u7684\u4e00\u4e9b\u793a\u4f8b\u3002\u6709\u8bb8\u591a\u4e0d\u540c\u7c7b\u578b\u7684\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u5668\u3002\u6211\u5c06\u7528\u6700\u5e38\u89c1\u7684 Snowball Stemmer \u548c WordNet Lemmatizer \u6765\u4e3e\u4f8b\u8bf4\u660e\u3002 from nltk.stem import WordNetLemmatizer from nltk.stem.snowball import SnowballStemmer lemmatizer = WordNetLemmatizer () stemmer = SnowballStemmer ( \"english\" ) words = [ \"fishing\" , \"fishes\" , \"fished\" ] for word in words : print ( f \"word= { word } \" ) print ( f \"stemmed_word= { stemmer . stem ( word ) } \" ) print ( f \"lemma= { lemmatizer . lemmatize ( word ) } \" ) print ( \"\" ) \u8fd9\u5c06\u6253\u5370\uff1a word = fishing stemmed_word = fish lemma = fishing word = fishes stemmed_word = fish lemma = fish word = fished stemmed_word = fish lemma = fished \u6b63\u5982\u60a8\u6240\u770b\u5230\u7684\uff0c\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u662f\u622a\u7136\u4e0d\u540c\u7684\u3002\u5f53\u6211\u4eec\u8fdb\u884c\u8bcd\u5e72\u63d0\u53d6\u65f6\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f\u4e00\u4e2a\u8bcd\u7684\u6700\u5c0f\u5f62\u5f0f\uff0c\u5b83\u53ef\u80fd\u662f\u4e5f\u53ef\u80fd\u4e0d\u662f\u8be5\u8bcd\u6240\u5c5e\u8bed\u8a00\u8bcd\u5178\u4e2d\u7684\u4e00\u4e2a\u8bcd\u3002\u4f46\u662f\uff0c\u5728\u8bcd\u5f62\u8fd8\u539f\u60c5\u51b5\u4e0b\uff0c\u8fd9\u5c06\u662f\u4e00\u4e2a\u8bcd\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u81ea\u5df1\u5c1d\u8bd5\u6dfb\u52a0\u8bcd\u5e72\u548c\u8bcd\u7d20\u5316\uff0c\u770b\u770b\u662f\u5426\u80fd\u6539\u5584\u7ed3\u679c\u3002 \u60a8\u8fd8\u5e94\u8be5\u4e86\u89e3\u7684\u4e00\u4e2a\u4e3b\u9898\u662f\u4e3b\u9898\u63d0\u53d6\u3002 \u4e3b\u9898\u63d0\u53d6 \u53ef\u4ee5\u4f7f\u7528\u975e\u8d1f\u77e9\u9635\u56e0\u5f0f\u5206\u89e3\uff08NMF\uff09\u6216\u6f5c\u5728\u8bed\u4e49\u5206\u6790\uff08LSA\uff09\u6765\u5b8c\u6210\uff0c\u540e\u8005\u4e5f\u88ab\u79f0\u4e3a\u5947\u5f02\u503c\u5206\u89e3\u6216 SVD\u3002\u8fd9\u4e9b\u5206\u89e3\u6280\u672f\u53ef\u5c06\u6570\u636e\u7b80\u5316\u4e3a\u7ed9\u5b9a\u6570\u91cf\u7684\u6210\u5206\u3002 \u60a8\u53ef\u4ee5\u5728\u4ece CountVectorizer \u6216 TfidfVectorizer \u4e2d\u83b7\u5f97\u7684\u7a00\u758f\u77e9\u9635\u4e0a\u5e94\u7528\u5176\u4e2d\u4efb\u4f55\u4e00\u79cd\u6280\u672f\u3002 \u8ba9\u6211\u4eec\u628a\u5b83\u5e94\u7528\u5230\u4e4b\u524d\u4f7f\u7528\u8fc7\u7684 TfidfVetorizer \u4e0a\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import decomposition from sklearn.feature_extraction.text import TfidfVectorizer corpus = pd . read_csv ( \"../input/imdb.csv\" , nrows = 10000 ) corpus = corpus . review . values tfv = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfv . fit ( corpus ) corpus_transformed = tfv . transform ( corpus ) svd = decomposition . TruncatedSVD ( n_components = 10 ) corpus_svd = svd . fit ( corpus_transformed ) sample_index = 0 feature_scores = dict ( zip ( tfv . get_feature_names (), corpus_svd . components_ [ sample_index ] ) ) N = 5 print ( sorted ( feature_scores , key = feature_scores . get , reverse = True )[: N ]) \u60a8\u53ef\u4ee5\u4f7f\u7528\u5faa\u73af\u6765\u8fd0\u884c\u591a\u4e2a\u6837\u672c\u3002 N = 5 for sample_index in range ( 5 ): feature_scores = dict ( zip ( tfv . get_feature_names (), corpus_svd . components_ [ sample_index ] ) ) print ( sorted ( feature_scores , key = feature_scores . get , reverse = True )[: N ] ) \u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a [ 'the' , ',' , '.' , 'a' , 'and' ] [ 'br' , '<' , '>' , '/' , '-' ] [ 'i' , 'movie' , '!' , 'it' , 'was' ] [ ',' , '!' , \"''\" , '``' , 'you' ] [ '!' , 'the' , '...' , \"''\" , '``' ] \u4f60\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u6839\u672c\u8bf4\u4e0d\u901a\u3002\u600e\u4e48\u529e\u5462\uff1f\u8ba9\u6211\u4eec\u8bd5\u7740\u6e05\u7406\u4e00\u4e0b\uff0c\u770b\u770b\u662f\u5426\u6709\u610f\u4e49\u3002\u8981\u6e05\u7406\u4efb\u4f55\u6587\u672c\u6570\u636e\uff0c\u5c24\u5176\u662f pandas \u6570\u636e\u5e27\u4e2d\u7684\u6587\u672c\u6570\u636e\uff0c\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u51fd\u6570\u3002 import re import string def clean_text ( s ): s = s . split () s = \" \" . join ( s ) s = re . sub ( f '[ { re . escape ( string . punctuation ) } ]' , '' , s ) return s \u8be5\u51fd\u6570\u4f1a\u5c06 \"hi, how are you????\" \u8fd9\u6837\u7684\u5b57\u7b26\u4e32\u8f6c\u6362\u4e3a \"hi how are you\"\u3002\u8ba9\u6211\u4eec\u628a\u8fd9\u4e2a\u51fd\u6570\u5e94\u7528\u5230\u65e7\u7684 SVD \u4ee3\u7801\u4e2d\uff0c\u770b\u770b\u5b83\u662f\u5426\u80fd\u7ed9\u63d0\u53d6\u7684\u4e3b\u9898\u5e26\u6765\u63d0\u5347\u3002\u4f7f\u7528 pandas\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528 apply \u51fd\u6570\u5c06\u6e05\u7406\u4ee3\u7801 \"\u5e94\u7528 \"\u5230\u4efb\u610f\u7ed9\u5b9a\u7684\u5217\u4e2d\u3002 import pandas as pd corpus = pd . read_csv ( \"../input/imdb.csv\" , nrows = 10000 ) corpus . loc [:, \"review\" ] = corpus . review . apply ( clean_text ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u53ea\u5728\u4e3b SVD \u811a\u672c\u4e2d\u6dfb\u52a0\u4e86\u4e00\u884c\u4ee3\u7801\uff0c\u8fd9\u5c31\u662f\u4f7f\u7528\u51fd\u6570\u548c pandas \u5e94\u7528\u7684\u597d\u5904\u3002\u8fd9\u6b21\u751f\u6210\u7684\u4e3b\u9898\u5982\u4e0b\u3002 [ 'the' , 'a' , 'and' , 'of' , 'to' ] [ 'i' , 'movie' , 'it' , 'was' , 'this' ] [ 'the' , 'was' , 'i' , 'were' , 'of' ] [ 'her' , 'was' , 'she' , 'i' , 'he' ] [ 'br' , 'to' , 'they' , 'he' , 'show' ] \u547c\uff01\u81f3\u5c11\u8fd9\u6bd4\u6211\u4eec\u4e4b\u524d\u597d\u591a\u4e86\u3002\u4f46\u4f60\u77e5\u9053\u5417\uff1f\u4f60\u53ef\u4ee5\u901a\u8fc7\u5728\u6e05\u7406\u529f\u80fd\u4e2d\u5220\u9664\u505c\u6b62\u8bcd\uff08stopwords\uff09\u6765\u4f7f\u5b83\u53d8\u5f97\u66f4\u597d\u3002\u4ec0\u4e48\u662fstopwords\uff1f\u5b83\u4eec\u662f\u5b58\u5728\u4e8e\u6bcf\u79cd\u8bed\u8a00\u4e2d\u7684\u9ad8\u9891\u8bcd\u3002\u4f8b\u5982\uff0c\u5728\u82f1\u8bed\u4e2d\uff0c\u8fd9\u4e9b\u8bcd\u5305\u62ec \"a\"\u3001\"an\"\u3001\"the\"\u3001\"for \"\u7b49\u3002\u5220\u9664\u505c\u6b62\u8bcd\u5e76\u975e\u603b\u662f\u660e\u667a\u7684\u9009\u62e9\uff0c\u8fd9\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u4e1a\u52a1\u95ee\u9898\u3002\u50cf \"I need a new dog\"\u8fd9\u6837\u7684\u53e5\u5b50\uff0c\u53bb\u6389\u505c\u6b62\u8bcd\u540e\u4f1a\u53d8\u6210 \"need new dog\"\uff0c\u6b64\u65f6\u6211\u4eec\u4e0d\u77e5\u9053\u8c01\u9700\u8981new dog\u3002 \u5982\u679c\u6211\u4eec\u603b\u662f\u5220\u9664\u505c\u6b62\u8bcd\uff0c\u5c31\u4f1a\u4e22\u5931\u5f88\u591a\u4e0a\u4e0b\u6587\u4fe1\u606f\u3002\u4f60\u53ef\u4ee5\u5728 NLTK \u4e2d\u627e\u5230\u8bb8\u591a\u8bed\u8a00\u7684\u505c\u6b62\u8bcd\uff0c\u5982\u679c\u6ca1\u6709\uff0c\u4f60\u4e5f\u53ef\u4ee5\u5728\u81ea\u5df1\u559c\u6b22\u7684\u641c\u7d22\u5f15\u64ce\u4e0a\u5feb\u901f\u641c\u7d22\u4e00\u4e0b\u3002 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u8f6c\u5230\u5927\u591a\u6570\u4eba\u90fd\u559c\u6b22\u4f7f\u7528\u7684\u65b9\u6cd5\uff1a\u6df1\u5ea6\u5b66\u4e60\u3002\u4f46\u9996\u5148\uff0c\u6211\u4eec\u5fc5\u987b\u77e5\u9053\u4ec0\u4e48\u662f\u8bcd\u5d4c\u5165\uff08embedings for words\uff09\u3002\u4f60\u5df2\u7ecf\u770b\u5230\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u5c06\u6807\u8bb0\u8f6c\u6362\u6210\u4e86\u6570\u5b57\u3002\u56e0\u6b64\uff0c\u5982\u679c\u67d0\u4e2a\u8bed\u6599\u5e93\u4e2d\u6709 N \u4e2a\u552f\u4e00\u7684\u8bcd\u5757\uff0c\u5b83\u4eec\u53ef\u4ee5\u7528 0 \u5230 N-1 \u4e4b\u95f4\u7684\u6574\u6570\u6765\u8868\u793a\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5c06\u7528\u5411\u91cf\u6765\u8868\u793a\u8fd9\u4e9b\u6574\u6570\u8bcd\u5757\u3002\u8fd9\u79cd\u5c06\u5355\u8bcd\u8868\u793a\u6210\u5411\u91cf\u7684\u65b9\u6cd5\u88ab\u79f0\u4e3a\u5355\u8bcd\u5d4c\u5165\u6216\u5355\u8bcd\u5411\u91cf\u3002\u8c37\u6b4c\u7684 Word2Vec \u662f\u5c06\u5355\u8bcd\u8f6c\u6362\u4e3a\u5411\u91cf\u7684\u6700\u53e4\u8001\u65b9\u6cd5\u4e4b\u4e00\u3002\u6b64\u5916\uff0c\u8fd8\u6709 Facebook \u7684 FastText \u548c\u65af\u5766\u798f\u5927\u5b66\u7684 GloVe\uff08\u7528\u4e8e\u5355\u8bcd\u8868\u793a\u7684\u5168\u5c40\u5411\u91cf\uff09\u3002\u8fd9\u4e9b\u65b9\u6cd5\u5f7c\u6b64\u5927\u76f8\u5f84\u5ead\u3002 \u5176\u57fa\u672c\u601d\u60f3\u662f\u5efa\u7acb\u4e00\u4e2a\u6d45\u5c42\u7f51\u7edc\uff0c\u901a\u8fc7\u91cd\u6784\u8f93\u5165\u53e5\u5b50\u6765\u5b66\u4e60\u5355\u8bcd\u7684\u5d4c\u5165\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u5468\u56f4\u7684\u6240\u6709\u5355\u8bcd\u6765\u8bad\u7ec3\u7f51\u7edc\u9884\u6d4b\u4e00\u4e2a\u7f3a\u5931\u7684\u5355\u8bcd\uff0c\u5728\u6b64\u8fc7\u7a0b\u4e2d\uff0c\u7f51\u7edc\u5c06\u5b66\u4e60\u5e76\u66f4\u65b0\u6240\u6709\u76f8\u5173\u5355\u8bcd\u7684\u5d4c\u5165\u3002\u8fd9\u79cd\u65b9\u6cd5\u4e5f\u88ab\u79f0\u4e3a\u8fde\u7eed\u8bcd\u888b\u6216 CBoW \u6a21\u578b\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u4e00\u4e2a\u5355\u8bcd\u6765\u9884\u6d4b\u4e0a\u4e0b\u6587\u4e2d\u7684\u5355\u8bcd\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684\u8df3\u683c\u6a21\u578b\u3002Word2Vec \u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e24\u79cd\u65b9\u6cd5\u5b66\u4e60\u5d4c\u5165\u3002 FastText \u53ef\u4ee5\u5b66\u4e60\u5b57\u7b26 n-gram \u7684\u5d4c\u5165\u3002\u548c\u5355\u8bcd n-gram \u4e00\u6837\uff0c\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7684\u662f\u5b57\u7b26\uff0c\u5219\u79f0\u4e3a\u5b57\u7b26 n-gram\uff0c\u6700\u540e\uff0cGloVe \u901a\u8fc7\u5171\u73b0\u77e9\u9635\u6765\u5b66\u4e60\u8fd9\u4e9b\u5d4c\u5165\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u8bf4\uff0c\u6240\u6709\u8fd9\u4e9b\u4e0d\u540c\u7c7b\u578b\u7684\u5d4c\u5165\u6700\u7ec8\u90fd\u4f1a\u8fd4\u56de\u4e00\u4e2a\u5b57\u5178\uff0c\u5176\u4e2d\u952e\u662f\u8bed\u6599\u5e93\uff08\u4f8b\u5982\u82f1\u8bed\u7ef4\u57fa\u767e\u79d1\uff09\u4e2d\u7684\u5355\u8bcd\uff0c\u503c\u662f\u5927\u5c0f\u4e3a N\uff08\u901a\u5e38\u4e3a 300\uff09\u7684\u5411\u91cf\u3002 \u56fe 1\uff1a\u53ef\u89c6\u5316\u4e8c\u7ef4\u5355\u8bcd\u5d4c\u5165\u3002 \u56fe 1 \u663e\u793a\u4e86\u4e8c\u7ef4\u5355\u8bcd\u5d4c\u5165\u7684\u53ef\u89c6\u5316\u6548\u679c\u3002\u5047\u8bbe\u6211\u4eec\u4ee5\u67d0\u79cd\u65b9\u5f0f\u5b8c\u6210\u4e86\u8bcd\u8bed\u7684\u4e8c\u7ef4\u8868\u793a\u3002\u56fe 1 \u663e\u793a\uff0c\u5982\u679c\u4eceBerlin\uff08\u5fb7\u56fd\u9996\u90fd\uff09\u7684\u5411\u91cf\u4e2d\u51cf\u53bb\u5fb7\u56fd\uff08Germany\uff09\u7684\u5411\u91cf\uff0c\u518d\u52a0\u4e0a\u6cd5\u56fd\uff08france\uff09\u7684\u5411\u91cf\uff0c\u5c31\u4f1a\u5f97\u5230\u4e00\u4e2a\u63a5\u8fd1Paris\uff08\u6cd5\u56fd\u9996\u90fd\uff09\u7684\u5411\u91cf\u3002\u7531\u6b64\u53ef\u89c1\uff0c\u5d4c\u5165\u5f0f\u4e5f\u80fd\u8fdb\u884c\u7c7b\u6bd4\u3002 \u8fd9\u5e76\u4e0d\u603b\u662f\u6b63\u786e\u7684\uff0c\u4f46\u8fd9\u6837\u7684\u4f8b\u5b50\u6709\u52a9\u4e8e\u7406\u89e3\u5355\u8bcd\u5d4c\u5165\u7684\u4f5c\u7528\u3002\u50cf \"\u55e8\uff0c\u4f60\u597d\u5417 \"\u8fd9\u6837\u7684\u53e5\u5b50\u53ef\u4ee5\u7528\u4e0b\u9762\u7684\u4e00\u5806\u5411\u91cf\u6765\u8868\u793a\u3002 hi \u2500> [vector (v1) of size 300] , \u2500> [vector (v2) of size 300] how \u2500> [vector (v3) of size 300] are \u2500> [vector (v4) of size 300] you \u2500> [vector (v5) of size 300] ? \u2500> [vector (v6) of size 300] \u4f7f\u7528\u8fd9\u4e9b\u4fe1\u606f\u6709\u591a\u79cd\u65b9\u6cd5\u3002\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u4e4b\u4e00\u5c31\u662f\u4f7f\u7528\u5d4c\u5165\u5411\u91cf\u3002\u5982\u4e0a\u4f8b\u6240\u793a\uff0c\u6bcf\u4e2a\u5355\u8bcd\u90fd\u6709\u4e00\u4e2a 1x300 \u7684\u5d4c\u5165\u5411\u91cf\u3002\u5229\u7528\u8fd9\u4e9b\u4fe1\u606f\uff0c\u6211\u4eec\u53ef\u4ee5\u8ba1\u7b97\u51fa\u6574\u4e2a\u53e5\u5b50\u7684\u5d4c\u5165\u3002\u8ba1\u7b97\u65b9\u6cd5\u6709\u591a\u79cd\u3002\u5176\u4e2d\u4e00\u79cd\u65b9\u6cd5\u5982\u4e0b\u6240\u793a\u3002\u5728\u8fd9\u4e2a\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u5c06\u7ed9\u5b9a\u53e5\u5b50\u4e2d\u7684\u6240\u6709\u5355\u8bcd\u5411\u91cf\u63d0\u53d6\u51fa\u6765\uff0c\u7136\u540e\u4ece\u6240\u6709\u6807\u8bb0\u8bcd\u7684\u5355\u8bcd\u5411\u91cf\u4e2d\u521b\u5efa\u4e00\u4e2a\u5f52\u4e00\u5316\u7684\u5355\u8bcd\u5411\u91cf\u3002\u8fd9\u6837\u5c31\u5f97\u5230\u4e86\u4e00\u4e2a\u53e5\u5b50\u5411\u91cf\u3002 import numpy as np def sentence_to_vec ( s , embedding_dict , stop_words , tokenizer ): words = str ( s ) . lower () words = tokenizer ( words ) words = [ w for w in words if not w in stop_words ] words = [ w for w in words if w . isalpha ()] M = [] for w in words : if w in embedding_dict : M . append ( embedding_dict [ w ]) if len ( M ) == 0 : return np . zeros ( 300 ) M = np . array ( M ) v = M . sum ( axis = 0 ) return v / np . sqrt (( v ** 2 ) . sum ()) \u6211\u4eec\u53ef\u4ee5\u7528\u8fd9\u79cd\u65b9\u6cd5\u5c06\u6240\u6709\u793a\u4f8b\u8f6c\u6362\u6210\u4e00\u4e2a\u5411\u91cf\u3002\u6211\u4eec\u80fd\u5426\u4f7f\u7528 fastText \u5411\u91cf\u6765\u6539\u8fdb\u4e4b\u524d\u7684\u7ed3\u679c\uff1f\u6bcf\u7bc7\u8bc4\u8bba\u90fd\u6709 300 \u4e2a\u7279\u5f81\u3002 import io import numpy as np import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import TfidfVectorizer def load_vectors ( fname ): fin = io . open ( fname , 'r' , encoding = 'utf-8' , newline = ' \\n ' , errors = 'ignore' ) n , d = map ( int , fin . readline () . split ()) data = {} for line in fin : tokens = line . rstrip () . split ( ' ' ) data [ tokens [ 0 ]] = list ( map ( float , tokens [ 1 :])) return data def sentence_to_vec ( s , embedding_dict , stop_words , tokenizer ): if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df = df . sample ( frac = 1 ) . reset_index ( drop = True ) print ( \"Loading embeddings\" ) embeddings = load_vectors ( \"../input/crawl-300d-2M.vec\" ) print ( \"Creating sentence vectors\" ) vectors = [] for review in df . review . values : vectors . append ( sentence_to_vec ( s = review , embedding_dict = embeddings , stop_words = [], tokenizer = word_tokenize ) ) vectors = np . array ( vectors ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for fold_ , ( t_ , v_ ) in enumerate ( kf . split ( X = vectors , y = y )): print ( f \"Training fold: { fold_ } \" ) xtrain = vectors [ t_ , :] ytrain = y [ t_ ] xtest = vectors [ v_ , :] ytest = y [ v_ ] model = linear_model . LogisticRegression () model . fit ( xtrain , ytrain ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( ytest , preds ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u8fd9\u5c06\u5f97\u5230\u5982\u4e0b\u7ed3\u679c\uff1a Loading embeddings Creating sentence vectors Training fold : 0 Accuracy = 0.8619 Training fold : 1 Accuracy = 0.8661 Training fold : 2 Accuracy = 0.8544 Training fold : 3 Accuracy = 0.8624 Training fold : 4 Accuracy = 0.8595 Wow\uff01\u771f\u662f\u51fa\u4e4e\u610f\u6599\u3002\u6211\u4eec\u6240\u505a\u7684\u4e00\u5207\u90fd\u662f\u4e3a\u4e86\u4f7f\u7528 FastText \u5d4c\u5165\u3002\u8bd5\u7740\u628a\u5d4c\u5165\u5f0f\u6362\u6210 GloVe\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002\u6211\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u5927\u5bb6\u3002 \u5f53\u6211\u4eec\u8c08\u8bba\u6587\u672c\u6570\u636e\u65f6\uff0c\u6211\u4eec\u5fc5\u987b\u7262\u8bb0\u4e00\u4ef6\u4e8b\u3002\u6587\u672c\u6570\u636e\u4e0e\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u975e\u5e38\u76f8\u4f3c\u3002\u5982\u56fe 2 \u6240\u793a\uff0c\u6211\u4eec\u8bc4\u8bba\u4e2d\u7684\u4efb\u4f55\u6837\u672c\u90fd\u662f\u5728\u4e0d\u540c\u65f6\u95f4\u6233\u4e0a\u6309\u9012\u589e\u987a\u5e8f\u6392\u5217\u7684\u6807\u8bb0\u5e8f\u5217\uff0c\u6bcf\u4e2a\u6807\u8bb0\u90fd\u53ef\u4ee5\u8868\u793a\u4e3a\u4e00\u4e2a\u5411\u91cf/\u5d4c\u5165\u3002 \u56fe 2\uff1a\u5c06\u6807\u8bb0\u8868\u793a\u4e3a\u5d4c\u5165\uff0c\u5e76\u5c06\u5176\u89c6\u4e3a\u65f6\u95f4\u5e8f\u5217 \u8fd9\u610f\u5473\u7740\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5e7f\u6cdb\u7528\u4e8e\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u7684\u6a21\u578b\uff0c\u4f8b\u5982\u957f\u77ed\u671f\u8bb0\u5fc6\uff08LSTM\uff09\u6216\u95e8\u63a7\u9012\u5f52\u5355\u5143\uff08GRU\uff09\uff0c\u751a\u81f3\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff08CNN\uff09\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u5728\u8be5\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u4e00\u4e2a\u7b80\u5355\u7684\u53cc\u5411 LSTM \u6a21\u578b\u3002 \u9996\u5148\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u9879\u76ee\u3002\u4f60\u53ef\u4ee5\u968f\u610f\u7ed9\u5b83\u547d\u540d\u3002\u7136\u540e\uff0c\u6211\u4eec\u7684\u7b2c\u4e00\u6b65\u5c06\u662f\u5206\u5272\u6570\u636e\u8fdb\u884c\u4ea4\u53c9\u9a8c\u8bc1\u3002 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df [ \"kfold\" ] = - 1 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f df . to_csv ( \"../input/imdb_folds.csv\" , index = False ) \u5c06\u6570\u636e\u96c6\u5212\u5206\u4e3a\u591a\u4e2a\u6298\u53e0\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5728 dataset.py \u4e2d\u521b\u5efa\u4e00\u4e2a\u7b80\u5355\u7684\u6570\u636e\u96c6\u7c7b\u3002\u6570\u636e\u96c6\u7c7b\u4f1a\u8fd4\u56de\u4e00\u4e2a\u8bad\u7ec3\u6216\u9a8c\u8bc1\u6570\u636e\u6837\u672c\u3002 import torch class IMDBDataset : def __init__ ( self , reviews , targets ): self . reviews = reviews self . target = targets def __len__ ( self ): return len ( self . reviews ) def __getitem__ ( self , item ): review = self . reviews [ item , :] target = self . target [ item ] return { \"review\" : torch . tensor ( review , dtype = torch . long ), \"target\" : torch . tensor ( target , dtype = torch . float ) } \u5b8c\u6210\u6570\u636e\u96c6\u5206\u7c7b\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u521b\u5efa lstm.py\uff0c\u5176\u4e2d\u5305\u542b\u6211\u4eec\u7684 LSTM \u6a21\u578b import torch import torch.nn as nn class LSTM ( nn . Module ): def __init__ ( self , embedding_matrix ): super ( LSTM , self ) . __init__ () num_words = embedding_matrix . shape [ 0 ] embed_dim = embedding_matrix . shape [ 1 ] self . embedding = nn . Embedding ( num_embeddings = num_words , embedding_dim = embed_dim ) self . embedding . weight = nn . Parameter ( torch . tensor ( embedding_matrix , dtype = torch . float32 ) ) self . embedding . weight . requires_grad = False self . lstm = nn . LSTM ( embed_dim , 128 , bidirectional = True , batch_first = True , ) self . out = nn . Linear ( 512 , 1 ) def forward ( self , x ): x = self . embedding ( x ) x , _ = self . lstm ( x ) avg_pool = torch . mean ( x , 1 ) max_pool , _ = torch . max ( x , 1 ) out = torch . cat (( avg_pool , max_pool ), 1 ) out = self . out ( out ) return out \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa engine.py\uff0c\u5176\u4e2d\u5305\u542b\u8bad\u7ec3\u548c\u8bc4\u4f30\u51fd\u6570\u3002 import torch import torch.nn as nn def train ( data_loader , model , optimizer , device ): model . train () for data in data_loader : reviews = data [ \"review\" ] targets = data [ \"target\" ] reviews = reviews . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) optimizer . zero_grad () predictions = model ( reviews ) loss = nn . BCEWithLogitsLoss ()( predictions , targets . view ( - 1 , 1 ) ) loss . backward () optimizer . step () def evaluate ( data_loader , model , device ): final_predictions = [] final_targets = [] model . eval () with torch . no_grad (): for data in data_loader : reviews = data [ \"review\" ] targets = data [ \"target\" ] reviews = reviews . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) predictions = model ( reviews ) predictions = predictions . cpu () . numpy () . tolist () targets = data [ \"target\" ] . cpu () . numpy () . tolist () final_predictions . extend ( predictions ) final_targets . extend ( targets ) return final_predictions , final_targets \u8fd9\u4e9b\u51fd\u6570\u5c06\u5728 train.py \u4e2d\u4e3a\u6211\u4eec\u63d0\u4f9b\u5e2e\u52a9\uff0c\u8be5\u51fd\u6570\u7528\u4e8e\u8bad\u7ec3\u591a\u4e2a\u6298\u53e0\u3002 import io import torch import numpy as np import pandas as pd import tensorflow as tf from sklearn import metrics import config import dataset import engine import lstm def load_vectors ( fname ): fin = io . open ( fname , 'r' , encoding = 'utf-8' , newline = ' \\n ' , errors = 'ignore' ) n , d = map ( int , fin . readline () . split ()) data = {} for line in fin : tokens = line . rstrip () . split ( ' ' ) data [ tokens [ 0 ]] = list ( map ( float , tokens [ 1 :])) return data def create_embedding_matrix ( word_index , embedding_dict ): embedding_matrix = np . zeros (( len ( word_index ) + 1 , 300 )) for word , i in word_index . items (): if word in embedding_dict : embedding_matrix [ i ] = embedding_dict [ word ] return embedding_matrix def run ( df , fold ): train_df = df [ df . kfold != fold ] . reset_index ( drop = True ) valid_df = df [ df . kfold == fold ] . reset_index ( drop = True ) print ( \"Fitting tokenizer\" ) tokenizer = tf . keras . preprocessing . text . Tokenizer () tokenizer . fit_on_texts ( df . review . values . tolist ()) xtrain = tokenizer . texts_to_sequences ( train_df . review . values ) xtest = tokenizer . texts_to_sequences ( valid_df . review . values ) xtrain = tf . keras . preprocessing . sequence . pad_sequences ( xtrain , maxlen = config . MAX_LEN ) xtest = tf . keras . preprocessing . sequence . pad_sequences ( xtest , maxlen = config . MAX_LEN ) train_dataset = dataset . IMDBDataset ( reviews = xtrain , targets = train_df . sentiment . values ) train_data_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = config . TRAIN_BATCH_SIZE , num_workers = 2 ) valid_dataset = dataset . IMDBDataset ( reviews = xtest , targets = valid_df . sentiment . values ) valid_data_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = config . VALID_BATCH_SIZE , num_workers = 1 ) print ( \"Loading embeddings\" ) embedding_dict = load_vectors ( \"../input/crawl-300d-2M.vec\" ) embedding_matrix = create_embedding_matrix ( tokenizer . word_index , embedding_dict ) device = torch . device ( \"cuda\" ) model = lstm . LSTM ( embedding_matrix ) model . to ( device ) optimizer = torch . optim . Adam ( model . parameters (), lr = 1e-3 ) print ( \"Training Model\" ) best_accuracy = 0 early_stopping_counter = 0 for epoch in range ( config . EPOCHS ): engine . train ( train_data_loader , model , optimizer , device ) outputs , targets = engine . evaluate ( valid_data_loader , model , device ) outputs = np . array ( outputs ) >= 0.5 accuracy = metrics . accuracy_score ( targets , outputs ) print ( f \"FOLD: { fold } , Epoch: { epoch } , Accuracy Score = { accuracy } \" ) if accuracy > best_accuracy : best_accuracy = accuracy else : early_stopping_counter += 1 if early_stopping_counter > 2 : break if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb_folds.csv\" ) run ( df , fold = 0 ) run ( df , fold = 1 ) run ( df , fold = 2 ) run ( df , fold = 3 ) run ( df , fold = 4 ) \u6700\u540e\u662f config.py\u3002 MAX_LEN = 128 TRAIN_BATCH_SIZE = 16 VALID_BATCH_SIZE = 8 EPOCHS = 10 \u8ba9\u6211\u4eec\u770b\u770b\u8f93\u51fa\uff1a FOLD : 0 , Epoch : 3 , Accuracy Score = 0.9015 FOLD : 1 , Epoch : 4 , Accuracy Score = 0.9007 FOLD : 2 , Epoch : 3 , Accuracy Score = 0.8924 FOLD : 3 , Epoch : 2 , Accuracy Score = 0.9 FOLD : 4 , Epoch : 1 , Accuracy Score = 0.878 \u8fd9\u662f\u8fc4\u4eca\u4e3a\u6b62\u6211\u4eec\u83b7\u5f97\u7684\u6700\u597d\u6210\u7ee9\u3002 \u8bf7\u6ce8\u610f\uff0c\u6211\u53ea\u663e\u793a\u4e86\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7cbe\u5ea6\u6700\u9ad8\u7684Epoch\u3002 \u4f60\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\uff0c\u6211\u4eec\u4f7f\u7528\u4e86\u9884\u5148\u8bad\u7ec3\u7684\u5d4c\u5165\u548c\u7b80\u5355\u7684\u53cc\u5411 LSTM\u3002 \u5982\u679c\u4f60\u60f3\u6539\u53d8\u6a21\u578b\uff0c\u4f60\u53ef\u4ee5\u53ea\u6539\u53d8 lstm.py \u4e2d\u7684\u6a21\u578b\u5e76\u4fdd\u6301\u4e00\u5207\u4e0d\u53d8\u3002 \u8fd9\u79cd\u4ee3\u7801\u53ea\u9700\u8981\u5f88\u5c11\u7684\u5b9e\u9a8c\u6539\u52a8\uff0c\u5e76\u4e14\u5f88\u5bb9\u6613\u7406\u89e3\u3002 \u4f8b\u5982\uff0c\u60a8\u53ef\u4ee5\u81ea\u5df1\u5b66\u4e60\u5d4c\u5165\u800c\u4e0d\u662f\u4f7f\u7528\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u5176\u4ed6\u4e00\u4e9b\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u7ec4\u5408\u591a\u4e2a\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528GRU\uff0c\u60a8\u53ef\u4ee5\u5728\u5d4c\u5165\u540e\u4f7f\u7528\u7a7a\u95f4dropout\uff0c\u60a8\u53ef\u4ee5\u6dfb\u52a0GRU LSTM \u5c42\u4e4b\u540e\uff0c\u60a8\u53ef\u4ee5\u6dfb\u52a0\u4e24\u4e2a LSTM \u5c42\uff0c\u60a8\u53ef\u4ee5\u8fdb\u884c LSTM-GRU-LSTM \u914d\u7f6e\uff0c\u60a8\u53ef\u4ee5\u7528\u5377\u79ef\u5c42\u66ff\u6362 LSTM \u7b49\uff0c\u800c\u65e0\u9700\u5bf9\u4ee3\u7801\u8fdb\u884c\u592a\u591a\u66f4\u6539\u3002 \u6211\u63d0\u5230\u7684\u5927\u90e8\u5206\u5185\u5bb9\u53ea\u9700\u8981\u66f4\u6539\u6a21\u578b\u7c7b\u3002 \u5f53\u60a8\u4f7f\u7528\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\u65f6\uff0c\u5c1d\u8bd5\u67e5\u770b\u6709\u591a\u5c11\u5355\u8bcd\u65e0\u6cd5\u627e\u5230\u5d4c\u5165\u4ee5\u53ca\u539f\u56e0\u3002 \u9884\u8bad\u7ec3\u5d4c\u5165\u7684\u5355\u8bcd\u8d8a\u591a\uff0c\u7ed3\u679c\u5c31\u8d8a\u597d\u3002 \u6211\u5411\u60a8\u5c55\u793a\u4ee5\u4e0b\u672a\u6ce8\u91ca\u7684 (!) \u51fd\u6570\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u5b83\u4e3a\u4efb\u4f55\u7c7b\u578b\u7684\u9884\u8bad\u7ec3\u5d4c\u5165\u521b\u5efa\u5d4c\u5165\u77e9\u9635\uff0c\u5176\u683c\u5f0f\u4e0e glove \u6216 fastText \u76f8\u540c\uff08\u53ef\u80fd\u9700\u8981\u8fdb\u884c\u4e00\u4e9b\u66f4\u6539\uff09\u3002 def load_embeddings ( word_index , embedding_file , vector_length = 300 ): max_features = len ( word_index ) + 1 words_to_find = list ( word_index . keys ()) more_words_to_find = [] for wtf in words_to_find : more_words_to_find . append ( wtf ) more_words_to_find . append ( str ( wtf ) . capitalize ()) more_words_to_find = set ( more_words_to_find ) def get_coefs ( word , * arr ): return word , np . asarray ( arr , dtype = 'float32' ) embeddings_index = dict ( get_coefs ( * o . strip () . split ( \" \" )) for o in open ( embedding_file ) if o . split ( \" \" )[ 0 ] in more_words_to_find and len ( o ) > 100 ) embedding_matrix = np . zeros (( max_features , vector_length )) for word , i in word_index . items (): if i >= max_features : continue embedding_vector = embeddings_index . get ( word ) if embedding_vector is None : embedding_vector = embeddings_index . get ( str ( word ) . capitalize () ) if embedding_vector is None : embedding_vector = embeddings_index . get ( str ( word ) . upper () ) if ( embedding_vector is not None and len ( embedding_vector ) == vector_length ): embedding_matrix [ i ] = embedding_vector return embedding_matrix \u9605\u8bfb\u5e76\u8fd0\u884c\u4e0a\u9762\u7684\u51fd\u6570\uff0c\u770b\u770b\u53d1\u751f\u4e86\u4ec0\u4e48\u3002 \u8be5\u51fd\u6570\u8fd8\u53ef\u4ee5\u4fee\u6539\u4e3a\u4f7f\u7528\u8bcd\u5e72\u8bcd\u6216\u8bcd\u5f62\u8fd8\u539f\u8bcd\u3002 \u6700\u540e\uff0c\u60a8\u5e0c\u671b\u8bad\u7ec3\u8bed\u6599\u5e93\u4e2d\u7684\u672a\u77e5\u5355\u8bcd\u6570\u91cf\u6700\u5c11\u3002 \u53e6\u4e00\u4e2a\u6280\u5de7\u662f\u5b66\u4e60\u5d4c\u5165\u5c42\uff0c\u5373\u4f7f\u5176\u53ef\u8bad\u7ec3\uff0c\u7136\u540e\u8bad\u7ec3\u7f51\u7edc\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u5206\u7c7b\u95ee\u9898\u6784\u5efa\u4e86\u5f88\u591a\u6a21\u578b\u3002 \u7136\u800c\uff0c\u73b0\u5728\u662f\u5e03\u5076\u65f6\u4ee3\uff0c\u8d8a\u6765\u8d8a\u591a\u7684\u4eba\u8f6c\u5411\u57fa\u4e8e\u53d8\u5f62\u91d1\u521a\u7684\u6a21\u578b\u3002 \u57fa\u4e8e Transformer \u7684\u7f51\u7edc\u80fd\u591f\u5904\u7406\u672c\u8d28\u4e0a\u957f\u671f\u7684\u4f9d\u8d56\u5173\u7cfb\u3002 LSTM \u4ec5\u5f53\u5b83\u770b\u5230\u524d\u4e00\u4e2a\u5355\u8bcd\u65f6\u624d\u67e5\u770b\u4e0b\u4e00\u4e2a\u5355\u8bcd\u3002 \u53d8\u538b\u5668\u7684\u60c5\u51b5\u5e76\u975e\u5982\u6b64\u3002 \u5b83\u53ef\u4ee5\u540c\u65f6\u67e5\u770b\u6574\u4e2a\u53e5\u5b50\u4e2d\u7684\u6240\u6709\u5355\u8bcd\u3002 \u56e0\u6b64\uff0c\u53e6\u4e00\u4e2a\u4f18\u70b9\u662f\u5b83\u53ef\u4ee5\u8f7b\u677e\u5e76\u884c\u5316\u5e76\u66f4\u6709\u6548\u5730\u4f7f\u7528 GPU\u3002 Transformers \u662f\u4e00\u4e2a\u975e\u5e38\u5e7f\u6cdb\u7684\u8bdd\u9898\uff0c\u6709\u592a\u591a\u7684\u6a21\u578b\uff1a BERT\u3001RoBERTa\u3001XLNet\u3001XLM-RoBERTa\u3001T5 \u7b49\u3002\u6211\u5c06\u5411\u60a8\u5c55\u793a\u4e00\u79cd\u53ef\u7528\u4e8e\u6240\u6709\u8fd9\u4e9b\u6a21\u578b\uff08T5 \u9664\u5916\uff09\u8fdb\u884c\u5206\u7c7b\u7684\u901a\u7528\u65b9\u6cd5 \u6211\u4eec\u4e00\u76f4\u5728\u8ba8\u8bba\u7684\u95ee\u9898\u3002 \u8bf7\u6ce8\u610f\uff0c\u8fd9\u4e9b\u53d8\u538b\u5668\u9700\u8981\u8bad\u7ec3\u5b83\u4eec\u6240\u9700\u7684\u8ba1\u7b97\u80fd\u529b\u3002 \u56e0\u6b64\uff0c\u5982\u679c\u60a8\u6ca1\u6709\u9ad8\u7aef\u7cfb\u7edf\uff0c\u4e0e\u57fa\u4e8e LSTM \u6216 TF-IDF \u7684\u6a21\u578b\u76f8\u6bd4\uff0c\u8bad\u7ec3\u6a21\u578b\u53ef\u80fd\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u3002 \u6211\u4eec\u8981\u505a\u7684\u7b2c\u4e00\u4ef6\u4e8b\u662f\u521b\u5efa\u4e00\u4e2a\u914d\u7f6e\u6587\u4ef6\u3002 import transformers MAX_LEN = 512 TRAIN_BATCH_SIZE = 8 VALID_BATCH_SIZE = 4 EPOCHS = 10 BERT_PATH = \"../input/bert_base_uncased/\" MODEL_PATH = \"model.bin\" TRAINING_FILE = \"../input/imdb.csv\" TOKENIZER = transformers . BertTokenizer . from_pretrained ( BERT_PATH , do_lower_case = True ) \u8fd9\u91cc\u7684\u914d\u7f6e\u6587\u4ef6\u662f\u6211\u4eec\u5b9a\u4e49\u5206\u8bcd\u5668\u548c\u5176\u4ed6\u6211\u4eec\u60f3\u8981\u7ecf\u5e38\u66f4\u6539\u7684\u53c2\u6570\u7684\u552f\u4e00\u5730\u65b9 \u2014\u2014 \u8fd9\u6837\u6211\u4eec\u5c31\u53ef\u4ee5\u505a\u5f88\u591a\u5b9e\u9a8c\u800c\u4e0d\u9700\u8981\u8fdb\u884c\u5927\u91cf\u66f4\u6539\u3002 \u4e0b\u4e00\u6b65\u662f\u6784\u5efa\u6570\u636e\u96c6\u7c7b\u3002 import config import torch class BERTDataset : def __init__ ( self , review , target ): self . review = review self . target = target self . tokenizer = config . TOKENIZER self . max_len = config . MAX_LEN def __len__ ( self ): return len ( self . review ) def __getitem__ ( self , item ): review = str ( self . review [ item ]) review = \" \" . join ( review . split ()) inputs = self . tokenizer . encode_plus ( review , None , add_special_tokens = True , max_length = self . max_len , pad_to_max_length = True , ) ids = inputs [ \"input_ids\" ] mask = inputs [ \"attention_mask\" ] token_type_ids = inputs [ \"token_type_ids\" ] return { \"ids\" : torch . tensor ( ids , dtype = torch . long ), \"mask\" : torch . tensor ( mask , dtype = torch . long ), \"token_type_ids\" : torch . tensor ( token_type_ids , dtype = torch . long ), \"targets\" : torch . tensor ( self . target [ item ], dtype = torch . float ) } \u73b0\u5728\u6211\u4eec\u6765\u5230\u4e86\u8be5\u9879\u76ee\u7684\u6838\u5fc3\uff0c\u5373\u6a21\u578b\u3002 import config import transformers import torch.nn as nn class BERTBaseUncased ( nn . Module ): def __init__ ( self ): super ( BERTBaseUncased , self ) . __init__ () self . bert = transformers . BertModel . from_pretrained ( config . BERT_PATH ) self . bert_drop = nn . Dropout ( 0.3 ) self . out = nn . Linear ( 768 , 1 ) def forward ( self , ids , mask , token_type_ids ): hidden state _ , o2 = self . bert ( ids , attention_mask = mask , token_type_ids = token_type_ids ) bo = self . bert_drop ( o2 ) output = self . out ( bo ) return output \u8be5\u6a21\u578b\u8fd4\u56de\u5355\u4e2a\u8f93\u51fa\u3002 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5e26\u6709 logits \u7684\u4e8c\u5143\u4ea4\u53c9\u71b5\u635f\u5931\uff0c\u5b83\u9996\u5148\u5e94\u7528 sigmoid\uff0c\u7136\u540e\u8ba1\u7b97\u635f\u5931\u3002 \u8fd9\u662f\u5728engine.py \u4e2d\u5b8c\u6210\u7684\u3002 import torch import torch.nn as nn def loss_fn ( outputs , targets ): return nn . BCEWithLogitsLoss ()( outputs , targets . view ( - 1 , 1 )) def train_fn ( data_loader , model , optimizer , device , scheduler ): model . train () for d in data_loader : ids = d [ \"ids\" ] token_type_ids = d [ \"token_type_ids\" ] mask = d [ \"mask\" ] targets = d [ \"targets\" ] ids = ids . to ( device , dtype = torch . long ) token_type_ids = token_type_ids . to ( device , dtype = torch . long ) mask = mask . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) optimizer . zero_grad () outputs = model ( ids = ids , mask = mask , token_type_ids = token_type_ids ) loss = loss_fn ( outputs , targets ) loss . backward () optimizer . step () scheduler . step () def eval_fn ( data_loader , model , device ): model . eval () fin_targets = [] fin_outputs = [] with torch . no_grad (): for d in data_loader : ids = d [ \"ids\" ] token_type_ids = d [ \"token_type_ids\" ] mask = d [ \"mask\" ] targets = d [ \"targets\" ] ids = ids . to ( device , dtype = torch . long ) token_type_ids = token_type_ids . to ( device , dtype = torch . long ) mask = mask . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) outputs = model ( ids = ids , mask = mask , token_type_ids = token_type_ids ) targets = targets . cpu () . detach () fin_targets . extend ( targets . numpy () . tolist ()) outputs = torch . sigmoid ( outputs ) . cpu () . detach () fin_outputs . extend ( outputs . numpy () . tolist ()) return fin_outputs , fin_targets \u6700\u540e\uff0c\u6211\u4eec\u51c6\u5907\u597d\u8bad\u7ec3\u4e86\u3002 \u6211\u4eec\u6765\u770b\u770b\u8bad\u7ec3\u811a\u672c\u5427\uff01 import config import dataset import engine import torch import pandas as pd import torch.nn as nn import numpy as np from model import BERTBaseUncased from sklearn import model_selection from sklearn import metrics from transformers import AdamW from transformers import get_linear_schedule_with_warmup def train (): dfx = pd . read_csv ( config . TRAINING_FILE ) . fillna ( \"none\" ) dfx . sentiment = dfx . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df_train , df_valid = model_selection . train_test_split ( dfx , test_size = 0.1 , random_state = 42 , stratify = dfx . sentiment . values ) df_train = df_train . reset_index ( drop = True ) df_valid = df_valid . reset_index ( drop = True ) train_dataset = dataset . BERTDataset ( review = df_train . review . values , target = df_train . sentiment . values ) train_data_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = config . TRAIN_BATCH_SIZE , num_workers = 4 ) valid_dataset = dataset . BERTDataset ( review = df_valid . review . values , target = df_valid . sentiment . values ) valid_data_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = config . VALID_BATCH_SIZE , num_workers = 1 ) device = torch . device ( \"cuda\" ) model = BERTBaseUncased () model . to ( device ) param_optimizer = list ( model . named_parameters ()) no_decay = [ \"bias\" , \"LayerNorm.bias\" , \"LayerNorm.weight\" ] optimizer_parameters = [ { \"params\" : [ p for n , p in param_optimizer if not any ( nd in n for nd in no_decay ) ], \"weight_decay\" : 0.001 , } \uff0c { \"params\" : [ p for n , p in param_optimizer if any ( nd in n for nd in no_decay ) ], \"weight_decay\" : 0.0 , }] num_train_steps = int ( len ( df_train ) / config . TRAIN_BATCH_SIZE * config . EPOCHS ) optimizer = AdamW ( optimizer_parameters , lr = 3e-5 ) scheduler = get_linear_schedule_with_warmup ( optimizer , num_warmup_steps = 0 , num_training_steps = num_train_steps ) model = nn . DataParallel ( model ) best_accuracy = 0 for epoch in range ( config . EPOCHS ): engine . train_fn ( train_data_loader , model , optimizer , device , scheduler ) outputs , targets = engine . eval_fn ( valid_data_loader , model , device ) outputs = np . array ( outputs ) >= 0.5 accuracy = metrics . accuracy_score ( targets , outputs ) print ( f \"Accuracy Score = { accuracy } \" ) if accuracy > best_accuracy : torch . save ( model . state_dict (), config . MODEL_PATH ) best_accuracy = accuracy if __name__ == \"__main__\" : train () \u4e4d\u4e00\u770b\u53ef\u80fd\u770b\u8d77\u6765\u5f88\u591a\uff0c\u4f46\u4e00\u65e6\u60a8\u4e86\u89e3\u4e86\u5404\u4e2a\u7ec4\u4ef6\uff0c\u5c31\u4e0d\u518d\u90a3\u4e48\u7b80\u5355\u4e86\u3002 \u60a8\u53ea\u9700\u66f4\u6539\u51e0\u884c\u4ee3\u7801\u5373\u53ef\u8f7b\u677e\u5c06\u5176\u66f4\u6539\u4e3a\u60a8\u60f3\u8981\u4f7f\u7528\u7684\u4efb\u4f55\u5176\u4ed6\u53d8\u538b\u5668\u6a21\u578b\u3002 \u8be5\u6a21\u578b\u7684\u51c6\u786e\u7387\u4e3a 93%\uff01 \u54c7\uff01 \u8fd9\u6bd4\u4efb\u4f55\u5176\u4ed6\u6a21\u578b\u90fd\u8981\u597d\u5f97\u591a\u3002 \u4f46\u662f\u8fd9\u503c\u5f97\u5417\uff1f \u6211\u4eec\u4f7f\u7528 LSTM \u80fd\u591f\u5b9e\u73b0 90% \u7684\u76ee\u6807\uff0c\u800c\u4e14\u5b83\u4eec\u66f4\u7b80\u5355\u3001\u66f4\u5bb9\u6613\u8bad\u7ec3\u5e76\u4e14\u63a8\u7406\u901f\u5ea6\u66f4\u5feb\u3002 \u901a\u8fc7\u4f7f\u7528\u4e0d\u540c\u7684\u6570\u636e\u5904\u7406\u6216\u8c03\u6574\u5c42\u3001\u8282\u70b9\u3001dropout\u3001\u5b66\u4e60\u7387\u3001\u66f4\u6539\u4f18\u5316\u5668\u7b49\u53c2\u6570\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u8be5\u6a21\u578b\u6539\u8fdb\u4e00\u4e2a\u767e\u5206\u70b9\u3002\u7136\u540e\u6211\u4eec\u5c06\u4ece BERT \u4e2d\u83b7\u5f97\u7ea6 2% \u7684\u6536\u76ca\u3002 \u53e6\u4e00\u65b9\u9762\uff0cBERT \u7684\u8bad\u7ec3\u65f6\u95f4\u8981\u957f\u5f97\u591a\uff0c\u53c2\u6570\u5f88\u591a\uff0c\u800c\u4e14\u63a8\u7406\u901f\u5ea6\u4e5f\u5f88\u6162\u3002 \u6700\u540e\uff0c\u60a8\u5e94\u8be5\u5ba1\u89c6\u81ea\u5df1\u7684\u4e1a\u52a1\u5e76\u505a\u51fa\u660e\u667a\u7684\u9009\u62e9\u3002 \u4e0d\u8981\u4ec5\u4ec5\u56e0\u4e3a BERT\u201c\u9177\u201d\u800c\u9009\u62e9\u5b83\u3002 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u6211\u4eec\u5728\u8fd9\u91cc\u8ba8\u8bba\u7684\u552f\u4e00\u4efb\u52a1\u662f\u5206\u7c7b\uff0c\u4f46\u5c06\u5176\u66f4\u6539\u4e3a\u56de\u5f52\u3001\u591a\u6807\u7b7e\u6216\u591a\u7c7b\u53ea\u9700\u8981\u66f4\u6539\u51e0\u884c\u4ee3\u7801\u3002 \u4f8b\u5982\uff0c\u591a\u7c7b\u5206\u7c7b\u8bbe\u7f6e\u4e2d\u7684\u540c\u4e00\u95ee\u9898\u5c06\u6709\u591a\u4e2a\u8f93\u51fa\u548c\u4ea4\u53c9\u71b5\u635f\u5931\u3002 \u5176\u4ed6\u4e00\u5207\u90fd\u5e94\u8be5\u4fdd\u6301\u4e0d\u53d8\u3002 \u81ea\u7136\u8bed\u8a00\u5904\u7406\u975e\u5e38\u5e9e\u5927\uff0c\u6211\u4eec\u53ea\u8ba8\u8bba\u4e86\u5176\u4e2d\u7684\u4e00\u5c0f\u90e8\u5206\u3002 \u663e\u7136\uff0c\u8fd9\u662f\u4e00\u4e2a\u5f88\u5927\u7684\u6bd4\u4f8b\uff0c\u56e0\u4e3a\u5927\u591a\u6570\u5de5\u4e1a\u6a21\u578b\u90fd\u662f\u5206\u7c7b\u6216\u56de\u5f52\u6a21\u578b\u3002 \u5982\u679c\u6211\u5f00\u59cb\u8be6\u7ec6\u5199\u6240\u6709\u5185\u5bb9\uff0c\u6211\u6700\u7ec8\u53ef\u80fd\u4f1a\u5199\u51e0\u767e\u9875\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u51b3\u5b9a\u5c06\u6240\u6709\u5185\u5bb9\u5305\u542b\u5728\u4e00\u672c\u5355\u72ec\u7684\u4e66\u4e2d\uff1a\u63a5\u8fd1\uff08\u51e0\u4e4e\uff09\u4efb\u4f55 NLP \u95ee\u9898\uff01","title":"\u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5"},{"location":"%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%88%96%E5%9B%9E%E5%BD%92%E6%96%B9%E6%B3%95/#_1","text":"\u6587\u672c\u95ee\u9898\u662f\u6211\u7684\u6700\u7231\u3002\u4e00\u822c\u6765\u8bf4\uff0c\u8fd9\u4e9b\u95ee\u9898\u4e5f\u88ab\u79f0\u4e3a \u81ea\u7136\u8bed\u8a00\u5904\u7406\uff08NLP\uff09\u95ee\u9898 \u3002NLP \u95ee\u9898\u4e0e\u56fe\u50cf\u95ee\u9898\u4e5f\u6709\u5f88\u5927\u4e0d\u540c\u3002\u4f60\u9700\u8981\u521b\u5efa\u4ee5\u524d\u4ece\u672a\u4e3a\u8868\u683c\u95ee\u9898\u521b\u5efa\u8fc7\u7684\u6570\u636e\u7ba1\u9053\u3002\u4f60\u9700\u8981\u4e86\u89e3\u5546\u4e1a\u6848\u4f8b\uff0c\u624d\u80fd\u5efa\u7acb\u4e00\u4e2a\u597d\u7684\u6a21\u578b\u3002\u987a\u4fbf\u8bf4\u4e00\u53e5\uff0c\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u4efb\u4f55\u4e8b\u60c5\u90fd\u662f\u5982\u6b64\u3002\u5efa\u7acb\u6a21\u578b\u4f1a\u8ba9\u4f60\u8fbe\u5230\u4e00\u5b9a\u7684\u6c34\u5e73\uff0c\u4f46\u8981\u60f3\u6539\u5584\u548c\u4fc3\u8fdb\u4f60\u6240\u5efa\u7acb\u6a21\u578b\u7684\u4e1a\u52a1\uff0c\u4f60\u5fc5\u987b\u4e86\u89e3\u5b83\u5bf9\u4e1a\u52a1\u7684\u5f71\u54cd\u3002 NLP \u95ee\u9898\u6709\u5f88\u591a\u79cd\uff0c\u5176\u4e2d\u6700\u5e38\u89c1\u7684\u662f\u5b57\u7b26\u4e32\u5206\u7c7b\u3002\u5f88\u591a\u65f6\u5019\uff0c\u6211\u4eec\u4f1a\u770b\u5230\u4eba\u4eec\u5728\u5904\u7406\u8868\u683c\u6570\u636e\u6216\u56fe\u50cf\u65f6\u8868\u73b0\u51fa\u8272\uff0c\u4f46\u5728\u5904\u7406\u6587\u672c\u65f6\uff0c\u4ed6\u4eec\u751a\u81f3\u4e0d\u77e5\u9053\u4ece\u4f55\u5165\u624b\u3002\u6587\u672c\u6570\u636e\u4e0e\u5176\u4ed6\u7c7b\u578b\u7684\u6570\u636e\u96c6\u6ca1\u6709\u4ec0\u4e48\u4e0d\u540c\u3002\u5bf9\u4e8e\u8ba1\u7b97\u673a\u6765\u8bf4\uff0c\u4e00\u5207\u90fd\u662f\u6570\u5b57\u3002 \u5047\u8bbe\u6211\u4eec\u4ece\u60c5\u611f\u5206\u7c7b\u8fd9\u4e00\u57fa\u672c\u4efb\u52a1\u5f00\u59cb\u3002\u6211\u4eec\u5c06\u5c1d\u8bd5\u5bf9\u7535\u5f71\u8bc4\u8bba\u8fdb\u884c\u60c5\u611f\u5206\u7c7b\u3002\u56e0\u6b64\uff0c\u60a8\u6709\u4e00\u4e2a\u6587\u672c\uff0c\u5e76\u6709\u4e0e\u4e4b\u76f8\u5173\u7684\u60c5\u611f\u3002\u4f60\u5c06\u5982\u4f55\u5904\u7406\u8fd9\u7c7b\u95ee\u9898\uff1f\u662f\u5e94\u7528\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff1f \u4e0d\uff0c\u7edd\u5bf9\u9519\u4e86\u3002\u4f60\u8981\u4ece\u6700\u57fa\u672c\u7684\u5f00\u59cb\u3002\u8ba9\u6211\u4eec\u5148\u770b\u770b\u8fd9\u4e9b\u6570\u636e\u662f\u4ec0\u4e48\u6837\u5b50\u7684\u3002 \u6211\u4eec\u4ece IMDB \u7535\u5f71\u8bc4\u8bba\u6570\u636e\u96c6 \u5f00\u59cb\uff0c\u8be5\u6570\u636e\u96c6\u5305\u542b 25000 \u7bc7\u6b63\u9762\u60c5\u611f\u8bc4\u8bba\u548c 25000 \u7bc7\u8d1f\u9762\u60c5\u611f\u8bc4\u8bba\u3002 \u6211\u5c06\u5728\u6b64\u8ba8\u8bba\u7684\u6982\u5ff5\u51e0\u4e4e\u9002\u7528\u4e8e\u4efb\u4f55\u6587\u672c\u5206\u7c7b\u6570\u636e\u96c6\u3002 \u8fd9\u4e2a\u6570\u636e\u96c6\u975e\u5e38\u5bb9\u6613\u7406\u89e3\u3002\u4e00\u7bc7\u8bc4\u8bba\u5bf9\u5e94\u4e00\u4e2a\u76ee\u6807\u53d8\u91cf\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5199\u7684\u662f\u8bc4\u8bba\u800c\u4e0d\u662f\u53e5\u5b50\u3002\u8bc4\u8bba\u5c31\u662f\u4e00\u5806\u53e5\u5b50\u3002\u6240\u4ee5\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u4f60\u4e00\u5b9a\u53ea\u770b\u5230\u4e86\u5bf9\u5355\u53e5\u7684\u5206\u7c7b\uff0c\u4f46\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u6211\u4eec\u5c06\u5bf9\u591a\u4e2a\u53e5\u5b50\u8fdb\u884c\u5206\u7c7b\u3002\u7b80\u5355\u5730\u8bf4\uff0c\u8fd9\u610f\u5473\u7740\u4e0d\u4ec5\u4e00\u4e2a\u53e5\u5b50\u4f1a\u5bf9\u60c5\u611f\u4ea7\u751f\u5f71\u54cd\uff0c\u800c\u4e14\u60c5\u611f\u5f97\u5206\u662f\u591a\u4e2a\u53e5\u5b50\u5f97\u5206\u7684\u7ec4\u5408\u3002\u6570\u636e\u7b80\u4ecb\u5982\u56fe 1 \u6240\u793a\u3002 \u5982\u4f55\u7740\u624b\u89e3\u51b3\u8fd9\u6837\u7684\u95ee\u9898\uff1f\u4e00\u4e2a\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u624b\u5de5\u5236\u4f5c\u4e24\u4efd\u5355\u8bcd\u8868\u3002\u4e00\u4e2a\u5217\u8868\u5305\u542b\u4f60\u80fd\u60f3\u8c61\u5230\u7684\u6240\u6709\u6b63\u9762\u8bcd\u6c47\uff0c\u4f8b\u5982\u597d\u3001\u68d2\u3001\u597d\u7b49\uff1b\u53e6\u4e00\u4e2a\u5217\u8868\u5305\u542b\u6240\u6709\u8d1f\u9762\u8bcd\u6c47\uff0c\u4f8b\u5982\u574f\u3001\u6076\u7b49\u3002\u6211\u4eec\u5148\u4e0d\u8981\u4e3e\u4f8b\u8bf4\u660e\u574f\u8bcd\uff0c\u5426\u5219\u8fd9\u672c\u4e66\u5c31\u53ea\u80fd\u4f9b 18 \u5c81\u4ee5\u4e0a\u7684\u4eba\u9605\u8bfb\u4e86\u3002\u4e00\u65e6\u4f60\u6709\u4e86\u8fd9\u4e9b\u5217\u8868\uff0c\u4f60\u751a\u81f3\u4e0d\u9700\u8981\u4e00\u4e2a\u6a21\u578b\u6765\u8fdb\u884c\u9884\u6d4b\u3002\u8fd9\u4e9b\u5217\u8868\u4e5f\u88ab\u79f0\u4e3a\u60c5\u611f\u8bcd\u5178\u3002\u4f60\u53ef\u4ee5\u7528\u4e00\u4e2a\u7b80\u5355\u7684\u8ba1\u6570\u5668\u6765\u8ba1\u7b97\u53e5\u5b50\u4e2d\u6b63\u9762\u548c\u8d1f\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u3002\u5982\u679c\u6b63\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u8f83\u591a\uff0c\u5219\u8868\u793a\u8be5\u53e5\u5b50\u5177\u6709\u6b63\u9762\u60c5\u611f\uff1b\u5982\u679c\u8d1f\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u8f83\u591a\uff0c\u5219\u8868\u793a\u8be5\u53e5\u5b50\u5177\u6709\u8d1f\u9762\u60c5\u611f\u3002\u5982\u679c\u53e5\u5b50\u4e2d\u6ca1\u6709\u8fd9\u4e9b\u8bcd\uff0c\u5219\u53ef\u4ee5\u8bf4\u8be5\u53e5\u5b50\u5177\u6709\u4e2d\u6027\u60c5\u611f\u3002\u8fd9\u662f\u6700\u53e4\u8001\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u73b0\u5728\u4ecd\u6709\u4eba\u5728\u4f7f\u7528\u3002\u5b83\u4e5f\u4e0d\u9700\u8981\u592a\u591a\u4ee3\u7801\u3002 def find_sentiment ( sentence , pos , neg ): sentence = sentence . split () sentence = set ( sentence ) num_common_pos = len ( sentence . intersection ( pos )) num_common_neg = len ( sentence . intersection ( neg )) if num_common_pos > num_common_neg : return \"positive\" if num_common_pos < num_common_neg : return \"negative\" return \"neutral\" \u4e0d\u8fc7\uff0c\u8fd9\u79cd\u65b9\u6cd5\u8003\u8651\u7684\u56e0\u7d20\u5e76\u4e0d\u591a\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u6211\u4eec\u7684 split() \u4e5f\u5e76\u4e0d\u5b8c\u7f8e\u3002\u5982\u679c\u4f7f\u7528 split()\uff0c\u5c31\u4f1a\u51fa\u73b0\u8fd9\u6837\u7684\u53e5\u5b50\uff1a \"hi, how are you?\" \u7ecf\u8fc7\u5206\u5272\u540e\u53d8\u4e3a\uff1a [\"hi,\", \"how\",\"are\",\"you?\"] \u8fd9\u79cd\u65b9\u6cd5\u5e76\u4e0d\u7406\u60f3\uff0c\u56e0\u4e3a\u5355\u8bcd\u4e2d\u5305\u542b\u4e86\u9017\u53f7\u548c\u95ee\u53f7\uff0c\u5b83\u4eec\u5e76\u6ca1\u6709\u88ab\u5206\u5272\u3002\u56e0\u6b64\uff0c\u5982\u679c\u6ca1\u6709\u5728\u5206\u5272\u524d\u5bf9\u8fd9\u4e9b\u7279\u6b8a\u5b57\u7b26\u8fdb\u884c\u9884\u5904\u7406\uff0c\u4e0d\u5efa\u8bae\u4f7f\u7528\u8fd9\u79cd\u65b9\u6cd5\u3002\u5c06\u5b57\u7b26\u4e32\u62c6\u5206\u4e3a\u5355\u8bcd\u5217\u8868\u79f0\u4e3a\u6807\u8bb0\u5316\u3002\u6700\u6d41\u884c\u7684\u6807\u8bb0\u5316\u65b9\u6cd5\u4e4b\u4e00\u6765\u81ea NLTK\uff08\u81ea\u7136\u8bed\u8a00\u5de5\u5177\u5305\uff09 \u3002 In [ X ]: from nltk.tokenize import word_tokenize In [ X ]: sentence = \"hi, how are you?\" In [ X ]: sentence . split () Out [ X ]: [ 'hi,' , 'how' , 'are' , 'you?' ] In [ X ]: word_tokenize ( sentence ) Out [ X ]: [ 'hi' , ',' , 'how' , 'are' , 'you' , '?' ] \u6b63\u5982\u60a8\u6240\u770b\u5230\u7684\uff0c\u4f7f\u7528 NLTK \u7684\u5355\u8bcd\u6807\u8bb0\u5316\u529f\u80fd\uff0c\u540c\u4e00\u4e2a\u53e5\u5b50\u7684\u62c6\u5206\u6548\u679c\u8981\u597d\u5f97\u591a\u3002\u4f7f\u7528\u5355\u8bcd\u5217\u8868\u8fdb\u884c\u5bf9\u6bd4\u7684\u6548\u679c\u4e5f\u4f1a\u66f4\u597d\uff01\u8fd9\u5c31\u662f\u6211\u4eec\u5c06\u5e94\u7528\u4e8e\u7b2c\u4e00\u4e2a\u60c5\u611f\u68c0\u6d4b\u6a21\u578b\u7684\u65b9\u6cd5\u3002 \u5728\u5904\u7406 NLP \u5206\u7c7b\u95ee\u9898\u65f6\uff0c\u60a8\u5e94\u8be5\u7ecf\u5e38\u5c1d\u8bd5\u7684\u57fa\u672c\u6a21\u578b\u4e4b\u4e00\u662f \u8bcd\u888b\u6a21\u578b\uff08bag of words\uff09 \u3002\u5728\u8bcd\u888b\u6a21\u578b\u4e2d\uff0c\u6211\u4eec\u521b\u5efa\u4e00\u4e2a\u5de8\u5927\u7684\u7a00\u758f\u77e9\u9635\uff0c\u5b58\u50a8\u8bed\u6599\u5e93\uff08\u8bed\u6599\u5e93=\u6240\u6709\u6587\u6863=\u6240\u6709\u53e5\u5b50\uff09\u4e2d\u6240\u6709\u5355\u8bcd\u7684\u8ba1\u6570\u3002\u4e3a\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 CountVectorizer\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b83\u662f\u5982\u4f55\u5de5\u4f5c\u7684\u3002 from sklearn.feature_extraction.text import CountVectorizer corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] ctv = CountVectorizer () ctv . fit ( corpus ) corpus_transformed = ctv . transform ( corpus ) \u5982\u679c\u6211\u4eec\u6253\u5370 corpus_transformed\uff0c\u5c31\u4f1a\u5f97\u5230\u7c7b\u4f3c\u4e0b\u9762\u7684\u7ed3\u679c\uff1a ( 0 , 2 ) 1 ( 0 , 9 ) 1 ( 0 , 11 ) 1 ( 0 , 22 ) 1 ( 1 , 1 ) 1 ( 1 , 3 ) 1 ( 1 , 4 ) 1 ( 1 , 7 ) 1 ( 1 , 8 ) 1 ( 1 , 10 ) 1 ( 1 , 13 ) 1 ( 1 , 17 ) 1 ( 1 , 19 ) 1 ( 1 , 22 ) 2 ( 2 , 0 ) 1 ( 2 , 5 ) 1 ( 2 , 6 ) 1 ( 2 , 14 ) 1 ( 2 , 22 ) 1 ( 3 , 12 ) 1 ( 3 , 15 ) 1 ( 3 , 16 ) 1 ( 3 , 18 ) 1 ( 3 , 20 ) 1 ( 4 , 21 ) 1 \u5728\u524d\u9762\u7684\u7ae0\u8282\u4e2d\uff0c\u6211\u4eec\u5df2\u7ecf\u89c1\u8bc6\u8fc7\u8fd9\u79cd\u8868\u793a\u6cd5\u3002\u5373\u7a00\u758f\u8868\u793a\u6cd5\u3002\u56e0\u6b64\uff0c\u8bed\u6599\u5e93\u73b0\u5728\u662f\u4e00\u4e2a\u7a00\u758f\u77e9\u9635\uff0c\u5176\u4e2d\u7b2c\u4e00\u4e2a\u6837\u672c\u6709 4 \u4e2a\u5143\u7d20\uff0c\u7b2c\u4e8c\u4e2a\u6837\u672c\u6709 10 \u4e2a\u5143\u7d20\uff0c\u4ee5\u6b64\u7c7b\u63a8\uff0c\u7b2c\u4e09\u4e2a\u6837\u672c\u6709 5 \u4e2a\u5143\u7d20\uff0c\u4ee5\u6b64\u7c7b\u63a8\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u4e9b\u5143\u7d20\u90fd\u6709\u76f8\u5173\u7684\u8ba1\u6570\u3002\u6709\u4e9b\u5143\u7d20\u4f1a\u51fa\u73b0\u4e24\u6b21\uff0c\u6709\u4e9b\u5219\u53ea\u6709\u4e00\u6b21\u3002\u4f8b\u5982\uff0c\u5728\u6837\u672c 2\uff08\u7b2c 1 \u884c\uff09\u4e2d\uff0c\u6211\u4eec\u770b\u5230\u7b2c 22 \u5217\u7684\u6570\u503c\u662f 2\u3002\u8fd9\u662f\u4e3a\u4ec0\u4e48\u5462\uff1f\u7b2c 22 \u5217\u662f\u4ec0\u4e48\uff1f CountVectorizer \u7684\u5de5\u4f5c\u65b9\u5f0f\u662f\u9996\u5148\u5bf9\u53e5\u5b50\u8fdb\u884c\u6807\u8bb0\u5316\u5904\u7406\uff0c\u7136\u540e\u4e3a\u6bcf\u4e2a\u6807\u8bb0\u8d4b\u503c\u3002\u56e0\u6b64\uff0c\u6bcf\u4e2a\u6807\u8bb0\u90fd\u7531\u4e00\u4e2a\u552f\u4e00\u7d22\u5f15\u8868\u793a\u3002\u8fd9\u4e9b\u552f\u4e00\u7d22\u5f15\u5c31\u662f\u6211\u4eec\u770b\u5230\u7684\u5217\u3002CountVectorizer \u4f1a\u5b58\u50a8\u8fd9\u4e9b\u4fe1\u606f\u3002 print ( ctv . vocabulary_ ) { 'hello' : 9 , 'how' : 11 , 'are' : 2 , 'you' : 22 , 'im' : 13 , 'getting' : 8 , 'bored' : 4 , 'at' : 3 , 'home' : 10 , 'and' : 1 , 'what' : 19 , 'do' : 7 , 'think' : 17 , 'did' : 6 , 'know' : 14 , 'about' : 0 , 'counts' : 5 , 'let' : 15 , 'see' : 16 , 'if' : 12 , 'this' : 18 , 'works' : 20 , 'yes' : 21 } \u6211\u4eec\u770b\u5230\uff0c\u7d22\u5f15 22 \u5c5e\u4e8e \"you\"\uff0c\u800c\u5728\u7b2c\u4e8c\u53e5\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u4e86\u4e24\u6b21 \"you\"\u3002\u6211\u5e0c\u671b\u5927\u5bb6\u73b0\u5728\u5df2\u7ecf\u6e05\u695a\u4ec0\u4e48\u662f\u8bcd\u888b\u4e86\u3002\u4f46\u662f\u6211\u4eec\u8fd8\u7f3a\u5c11\u4e00\u4e9b\u7279\u6b8a\u5b57\u7b26\u3002\u6709\u65f6\uff0c\u8fd9\u4e9b\u7279\u6b8a\u5b57\u7b26\u4e5f\u5f88\u6709\u7528\u3002\u4f8b\u5982\uff0c\"? \"\u5728\u5927\u591a\u6570\u53e5\u5b50\u4e2d\u8868\u793a\u7591\u95ee\u53e5\u3002\u8ba9\u6211\u4eec\u628a scikit-learn \u7684 word_tokenize \u6574\u5408\u5230 CountVectorizer \u4e2d\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002 from sklearn.feature_extraction.text import CountVectorizer from nltk.tokenize import word_tokenize corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] ctv = CountVectorizer ( tokenizer = word_tokenize , token_pattern = None ) ctv . fit ( corpus ) corpus_transformed = ctv . transform ( corpus ) print ( ctv . vocabulary_ ) \u8fd9\u6837\uff0c\u6211\u4eec\u7684\u8bcd\u888b\u5c31\u53d8\u6210\u4e86\uff1a { 'hello' : 14 , ',' : 2 , 'how' : 16 , 'are' : 7 , 'you' : 27 , '?' : 4 , 'im' : 18 , 'getting' : 13 , 'bored' : 9 , 'at' : 8 , 'home' : 15 , '.' : 3 , 'and' : 6 , 'what' : 24 , 'do' : 12 , 'think' : 22 , 'did' : 11 , 'know' : 19 , 'about' : 5 , 'counts' : 10 , 'let' : 20 , \"'s\" : 1 , 'see' : 21 , 'if' : 17 , 'this' : 23 , 'works' : 25 , '!' : 0 , 'yes' : 26 } \u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u5229\u7528 IMDB \u6570\u636e\u96c6\u4e2d\u7684\u6240\u6709\u53e5\u5b50\u521b\u5efa\u4e00\u4e2a\u7a00\u758f\u77e9\u9635\uff0c\u5e76\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u3002\u8be5\u6570\u636e\u96c6\u4e2d\u6b63\u8d1f\u6837\u672c\u7684\u6bd4\u4f8b\u4e3a 1:1\uff0c\u56e0\u6b64\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u51c6\u786e\u7387\u4f5c\u4e3a\u8861\u91cf\u6807\u51c6\u3002\u6211\u4eec\u5c06\u4f7f\u7528 StratifiedKFold \u5e76\u521b\u5efa\u4e00\u4e2a\u811a\u672c\u6765\u8bad\u7ec35\u4e2a\u6298\u53e0\u3002\u4f60\u4f1a\u95ee\u4f7f\u7528\u54ea\u4e2a\u6a21\u578b\uff1f\u5bf9\u4e8e\u9ad8\u7ef4\u7a00\u758f\u6570\u636e\uff0c\u54ea\u4e2a\u6a21\u578b\u6700\u5feb\uff1f\u903b\u8f91\u56de\u5f52\u3002\u6211\u4eec\u5c06\u9996\u5148\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u6765\u5904\u7406\u8fd9\u4e2a\u6570\u636e\u96c6\uff0c\u5e76\u521b\u5efa\u7b2c\u4e00\u4e2a\u57fa\u51c6\u6a21\u578b\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import CountVectorizer if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df [ \"kfold\" ] = - 1 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f for fold_ in range ( 5 ): train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) count_vec = CountVectorizer ( tokenizer = word_tokenize , token_pattern = None ) count_vec . fit ( train_df . review ) xtrain = count_vec . transform ( train_df . review ) xtest = count_vec . transform ( test_df . review ) model = linear_model . LogisticRegression () model . fit ( xtrain , train_df . sentiment ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( test_df . sentiment , preds ) print ( f \"Fold: { fold_ } \" ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u8fd9\u6bb5\u4ee3\u7801\u7684\u8fd0\u884c\u9700\u8981\u4e00\u5b9a\u7684\u65f6\u95f4\uff0c\u4f46\u53ef\u4ee5\u5f97\u5230\u4ee5\u4e0b\u8f93\u51fa\u7ed3\u679c\uff1a Fold : 0 Accuracy = 0.8903 Fold : 1 Accuracy = 0.897 Fold : 2 Accuracy = 0.891 Fold : 3 Accuracy = 0.8914 Fold : 4 Accuracy = 0.8931 \u54c7\uff0c\u51c6\u786e\u7387\u5df2\u7ecf\u8fbe\u5230 89%\uff0c\u800c\u6211\u4eec\u6240\u505a\u7684\u53ea\u662f\u4f7f\u7528\u8bcd\u888b\u548c\u903b\u8f91\u56de\u5f52\uff01\u8fd9\u771f\u662f\u592a\u68d2\u4e86\uff01\u4e0d\u8fc7\uff0c\u8fd9\u4e2a\u6a21\u578b\u7684\u8bad\u7ec3\u82b1\u8d39\u4e86\u5f88\u591a\u65f6\u95f4\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u80fd\u5426\u901a\u8fc7\u4f7f\u7528\u6734\u7d20\u8d1d\u53f6\u65af\u5206\u7c7b\u5668\u6765\u7f29\u77ed\u8bad\u7ec3\u65f6\u95f4\u3002\u6734\u7d20\u8d1d\u53f6\u65af\u5206\u7c7b\u5668\u5728 NLP \u4efb\u52a1\u4e2d\u76f8\u5f53\u6d41\u884c\uff0c\u56e0\u4e3a\u7a00\u758f\u77e9\u9635\u975e\u5e38\u5e9e\u5927\uff0c\u800c\u6734\u7d20\u8d1d\u53f6\u65af\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6a21\u578b\u3002\u8981\u4f7f\u7528\u8fd9\u4e2a\u6a21\u578b\uff0c\u9700\u8981\u66f4\u6539\u4e00\u4e2a\u5bfc\u5165\u548c\u6a21\u578b\u7684\u884c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e2a\u6a21\u578b\u7684\u6027\u80fd\u5982\u4f55\u3002\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 MultinomialNB\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import naive_bayes from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import CountVectorizer model = naive_bayes . MultinomialNB () model . fit ( xtrain , train_df . sentiment ) \u5f97\u5230\u5982\u4e0b\u7ed3\u679c\uff1a Fold : 0 Accuracy = 0.8444 Fold : 1 Accuracy = 0.8499 Fold : 2 Accuracy = 0.8422 Fold : 3 Accuracy = 0.8443 Fold : 4 Accuracy = 0.8455 \u6211\u4eec\u770b\u5230\u8fd9\u4e2a\u5206\u6570\u5f88\u4f4e\u3002\u4f46\u6734\u7d20\u8d1d\u53f6\u65af\u6a21\u578b\u7684\u901f\u5ea6\u975e\u5e38\u5feb\u3002 NLP \u4e2d\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f TF-IDF\uff0c\u5982\u4eca\u5927\u591a\u6570\u4eba\u90fd\u503e\u5411\u4e8e\u5ffd\u7565\u6216\u4e0d\u5c51\u4e8e\u4e86\u89e3\u8fd9\u79cd\u65b9\u6cd5\u3002TF \u662f\u672f\u8bed\u9891\u7387\uff0cIDF \u662f\u53cd\u5411\u6587\u6863\u9891\u7387\u3002\u4ece\u8fd9\u4e9b\u672f\u8bed\u6765\u770b\uff0c\u8fd9\u4f3c\u4e4e\u6709\u4e9b\u56f0\u96be\uff0c\u4f46\u901a\u8fc7 TF \u548c IDF \u7684\u8ba1\u7b97\u516c\u5f0f\uff0c\u4e8b\u60c5\u5c31\u4f1a\u53d8\u5f97\u5f88\u660e\u663e\u3002 $$ TF(t) = \\frac{Number\\ of\\ times\\ a\\ term\\ t\\ appears\\ in\\ a\\ document}{Total\\ number\\ of\\ terms\\ in \\ the\\ document} $$ \\[ IDF(t) = LOG\\left(\\frac{Total\\ number\\ of\\ documents}{Number\\ of\\ documents with\\ term\\ t\\ in\\ it}\\right) \\] \u672f\u8bed t \u7684 TF-IDF \u5b9a\u4e49\u4e3a\uff1a $$ TF-IDF(t) = TF(t) \\times IDF(t) $$ \u4e0e scikit-learn \u4e2d\u7684 CountVectorizer \u7c7b\u4f3c\uff0c\u6211\u4eec\u4e5f\u6709 TfidfVectorizer\u3002\u8ba9\u6211\u4eec\u8bd5\u7740\u50cf\u4f7f\u7528 CountVectorizer \u4e00\u6837\u4f7f\u7528\u5b83\u3002 from sklearn.feature_extraction.text import TfidfVectorizer from nltk.tokenize import word_tokenize corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] tfv = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfv . fit ( corpus ) corpus_transformed = tfv . transform ( corpus ) print ( corpus_transformed ) \u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a ( 0 , 27 ) 0.2965698850220162 ( 0 , 16 ) 0.4428321995085722 ( 0 , 14 ) 0.4428321995085722 ( 0 , 7 ) 0.4428321995085722 ( 0 , 4 ) 0.35727423026525224 ( 0 , 2 ) 0.4428321995085722 ( 1 , 27 ) 0.35299699146792735 ( 1 , 24 ) 0.2635440111190765 ( 1 , 22 ) 0.2635440111190765 ( 1 , 18 ) 0.2635440111190765 ( 1 , 15 ) 0.2635440111190765 ( 1 , 13 ) 0.2635440111190765 ( 1 , 12 ) 0.2635440111190765 ( 1 , 9 ) 0.2635440111190765 ( 1 , 8 ) 0.2635440111190765 ( 1 , 6 ) 0.2635440111190765 ( 1 , 4 ) 0.42525129752567803 ( 1 , 3 ) 0.2635440111190765 ( 2 , 27 ) 0.31752680284846835 ( 2 , 19 ) 0.4741246485558491 ( 2 , 11 ) 0.4741246485558491 ( 2 , 10 ) 0.4741246485558491 ( 2 , 5 ) 0.4741246485558491 ( 3 , 25 ) 0.38775666010579296 ( 3 , 23 ) 0.38775666010579296 ( 3 , 21 ) 0.38775666010579296 ( 3 , 20 ) 0.38775666010579296 ( 3 , 17 ) 0.38775666010579296 ( 3 , 1 ) 0.38775666010579296 ( 3 , 0 ) 0.3128396318588854 ( 4 , 26 ) 0.2959842226518677 ( 4 , 0 ) 0.9551928286692534 \u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u6b21\u6211\u4eec\u5f97\u5230\u7684\u4e0d\u662f\u6574\u6570\u503c\uff0c\u800c\u662f\u6d6e\u70b9\u6570\u3002 \u7528 TfidfVectorizer \u4ee3\u66ff CountVectorizer \u4e5f\u662f\u5c0f\u83dc\u4e00\u789f\u3002Scikit-learn \u8fd8\u63d0\u4f9b\u4e86 TfidfTransformer\u3002\u5982\u679c\u4f60\u4f7f\u7528\u7684\u662f\u8ba1\u6570\u503c\uff0c\u53ef\u4ee5\u4f7f\u7528 TfidfTransformer \u5e76\u83b7\u5f97\u4e0e TfidfVectorizer \u76f8\u540c\u7684\u6548\u679c\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import TfidfVectorizer for fold_ in range ( 5 ): train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) tfidf_vec = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfidf_vec . fit ( train_df . review ) xtrain = tfidf_vec . transform ( train_df . review ) xtest = tfidf_vec . transform ( test_df . review ) model = linear_model . LogisticRegression () model . fit ( xtrain , train_df . sentiment ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( test_df . sentiment , preds ) print ( f \"Fold: { fold_ } \" ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u6211\u4eec\u53ef\u4ee5\u770b\u770b TF-IDF \u5728\u903b\u8f91\u56de\u5f52\u6a21\u578b\u4e0a\u7684\u8868\u73b0\u5982\u4f55\u3002 Fold : 0 Accuracy = 0.8976 Fold : 1 Accuracy = 0.8998 Fold : 2 Accuracy = 0.8948 Fold : 3 Accuracy = 0.8912 Fold : 4 Accuracy = 0.8995 \u6211\u4eec\u770b\u5230\uff0c\u8fd9\u4e9b\u5206\u6570\u90fd\u6bd4 CountVectorizer \u9ad8\u4e00\u4e9b\uff0c\u56e0\u6b64\u5b83\u6210\u4e3a\u4e86\u6211\u4eec\u60f3\u8981\u51fb\u8d25\u7684\u65b0\u57fa\u51c6\u3002 NLP \u4e2d\u53e6\u4e00\u4e2a\u6709\u8da3\u7684\u6982\u5ff5\u662f N-gram\u3002N-grams \u662f\u6309\u987a\u5e8f\u6392\u5217\u7684\u5355\u8bcd\u7ec4\u5408\u3002N-grams \u5f88\u5bb9\u6613\u521b\u5efa\u3002\u60a8\u53ea\u9700\u6ce8\u610f\u987a\u5e8f\u5373\u53ef\u3002\u4e3a\u4e86\u8ba9\u4e8b\u60c5\u53d8\u5f97\u66f4\u7b80\u5355\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 NLTK \u7684 N-gram \u5b9e\u73b0\u3002 from nltk import ngrams from nltk.tokenize import word_tokenize N = 3 sentence = \"hi, how are you?\" tokenized_sentence = word_tokenize ( sentence ) n_grams = list ( ngrams ( tokenized_sentence , N )) print ( n_grams ) \u7531\u6b64\u5f97\u5230\uff1a [( 'hi' , ',' , 'how' ), ( ',' , 'how' , 'are' ), ( 'how' , 'are' , 'you' ), ( 'are' , 'you' , '?' )] \u540c\u6837\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u521b\u5efa 2-gram \u6216 4-gram \u7b49\u3002\u73b0\u5728\uff0c\u8fd9\u4e9b n-gram \u5c06\u6210\u4e3a\u6211\u4eec\u8bcd\u6c47\u8868\u7684\u4e00\u90e8\u5206\uff0c\u5f53\u6211\u4eec\u8ba1\u7b97\u8ba1\u6570\u6216 tf-idf \u65f6\uff0c\u6211\u4eec\u4f1a\u5c06\u4e00\u4e2a n-gram \u89c6\u4e3a\u4e00\u4e2a\u5168\u65b0\u7684\u6807\u8bb0\u3002\u56e0\u6b64\uff0c\u5728\u67d0\u79cd\u7a0b\u5ea6\u4e0a\uff0c\u6211\u4eec\u662f\u5728\u7ed3\u5408\u4e0a\u4e0b\u6587\u3002scikit-learn \u7684 CountVectorizer \u548c TfidfVectorizer \u5b9e\u73b0\u90fd\u901a\u8fc7 ngram_range \u53c2\u6570\u63d0\u4f9b n-gram\uff0c\u8be5\u53c2\u6570\u6709\u6700\u5c0f\u548c\u6700\u5927\u9650\u5236\u3002\u9ed8\u8ba4\u60c5\u51b5\u4e0b\uff0c\u8be5\u53c2\u6570\u4e3a\uff081, 1\uff09\u3002\u5f53\u6211\u4eec\u5c06\u5176\u6539\u4e3a (1, 3) \u65f6\uff0c\u6211\u4eec\u5c06\u770b\u5230\u5355\u5b57\u5143\u3001\u53cc\u5b57\u5143\u548c\u4e09\u5b57\u5143\u3002\u4ee3\u7801\u6539\u52a8\u5f88\u5c0f\u3002 \u7531\u4e8e\u5230\u76ee\u524d\u4e3a\u6b62\u6211\u4eec\u4f7f\u7528 tf-idf \u5f97\u5230\u4e86\u6700\u597d\u7684\u7ed3\u679c\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5305\u542b n-grams \u76f4\u81f3 trigrams \u662f\u5426\u80fd\u6539\u8fdb\u6a21\u578b\u3002\u552f\u4e00\u9700\u8981\u4fee\u6539\u7684\u662f TfidfVectorizer \u7684\u521d\u59cb\u5316\u3002 tfidf_vec = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None , ngram_range = ( 1 , 3 ) ) \u8ba9\u6211\u4eec\u770b\u770b\u662f\u5426\u4f1a\u6709\u6539\u8fdb\u3002 Fold : 0 Accuracy = 0.8931 Fold : 1 Accuracy = 0.8941 Fold : 2 Accuracy = 0.897 Fold : 3 Accuracy = 0.8922 Fold : 4 Accuracy = 0.8847 \u770b\u8d77\u6765\u8fd8\u884c\uff0c\u4f46\u6211\u4eec\u770b\u4e0d\u5230\u4efb\u4f55\u6539\u8fdb\u3002 \u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u591a\u4f7f\u7528 bigrams \u6765\u83b7\u5f97\u6539\u8fdb\u3002 \u6211\u4e0d\u4f1a\u5728\u8fd9\u91cc\u5c55\u793a\u8fd9\u4e00\u90e8\u5206\u3002\u4e5f\u8bb8\u4f60\u53ef\u4ee5\u81ea\u5df1\u8bd5\u7740\u505a\u3002 NLP \u7684\u57fa\u7840\u77e5\u8bc6\u8fd8\u6709\u5f88\u591a\u3002\u4f60\u5fc5\u987b\u77e5\u9053\u7684\u4e00\u4e2a\u672f\u8bed\u662f\u8bcd\u5e72\u63d0\u53d6\uff08strmming\uff09\u3002\u53e6\u4e00\u4e2a\u662f\u8bcd\u5f62\u8fd8\u539f\uff08lemmatization\uff09\u3002 \u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f \u53ef\u4ee5\u5c06\u4e00\u4e2a\u8bcd\u51cf\u5c11\u5230\u6700\u5c0f\u5f62\u5f0f\u3002\u5728\u8bcd\u5e72\u63d0\u53d6\u7684\u60c5\u51b5\u4e0b\uff0c\u5904\u7406\u540e\u7684\u5355\u8bcd\u79f0\u4e3a\u8bcd\u5e72\u5355\u8bcd\uff0c\u800c\u5728\u8bcd\u5f62\u8fd8\u539f\u60c5\u51b5\u4e0b\uff0c\u5904\u7406\u540e\u7684\u5355\u8bcd\u79f0\u4e3a\u8bcd\u5f62\u3002\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u8bcd\u5f62\u8fd8\u539f\u6bd4\u8bcd\u5e72\u63d0\u53d6\u66f4\u6fc0\u8fdb\uff0c\u800c\u8bcd\u5e72\u63d0\u53d6\u66f4\u6d41\u884c\u548c\u5e7f\u6cdb\u3002\u8bcd\u5e72\u548c\u8bcd\u5f62\u90fd\u6765\u81ea\u8bed\u8a00\u5b66\u3002\u5982\u679c\u4f60\u6253\u7b97\u4e3a\u67d0\u79cd\u8bed\u8a00\u5236\u4f5c\u8bcd\u5e72\u6216\u8bcd\u578b\uff0c\u9700\u8981\u5bf9\u8be5\u8bed\u8a00\u6709\u6df1\u5165\u7684\u4e86\u89e3\u3002\u5982\u679c\u8981\u8fc7\u591a\u5730\u4ecb\u7ecd\u8fd9\u4e9b\u77e5\u8bc6\uff0c\u5c31\u610f\u5473\u7740\u8981\u5728\u672c\u4e66\u4e2d\u589e\u52a0\u4e00\u7ae0\u3002\u4f7f\u7528 NLTK \u8f6f\u4ef6\u5305\u53ef\u4ee5\u8f7b\u677e\u5b8c\u6210\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e24\u79cd\u65b9\u6cd5\u7684\u4e00\u4e9b\u793a\u4f8b\u3002\u6709\u8bb8\u591a\u4e0d\u540c\u7c7b\u578b\u7684\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u5668\u3002\u6211\u5c06\u7528\u6700\u5e38\u89c1\u7684 Snowball Stemmer \u548c WordNet Lemmatizer \u6765\u4e3e\u4f8b\u8bf4\u660e\u3002 from nltk.stem import WordNetLemmatizer from nltk.stem.snowball import SnowballStemmer lemmatizer = WordNetLemmatizer () stemmer = SnowballStemmer ( \"english\" ) words = [ \"fishing\" , \"fishes\" , \"fished\" ] for word in words : print ( f \"word= { word } \" ) print ( f \"stemmed_word= { stemmer . stem ( word ) } \" ) print ( f \"lemma= { lemmatizer . lemmatize ( word ) } \" ) print ( \"\" ) \u8fd9\u5c06\u6253\u5370\uff1a word = fishing stemmed_word = fish lemma = fishing word = fishes stemmed_word = fish lemma = fish word = fished stemmed_word = fish lemma = fished \u6b63\u5982\u60a8\u6240\u770b\u5230\u7684\uff0c\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u662f\u622a\u7136\u4e0d\u540c\u7684\u3002\u5f53\u6211\u4eec\u8fdb\u884c\u8bcd\u5e72\u63d0\u53d6\u65f6\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f\u4e00\u4e2a\u8bcd\u7684\u6700\u5c0f\u5f62\u5f0f\uff0c\u5b83\u53ef\u80fd\u662f\u4e5f\u53ef\u80fd\u4e0d\u662f\u8be5\u8bcd\u6240\u5c5e\u8bed\u8a00\u8bcd\u5178\u4e2d\u7684\u4e00\u4e2a\u8bcd\u3002\u4f46\u662f\uff0c\u5728\u8bcd\u5f62\u8fd8\u539f\u60c5\u51b5\u4e0b\uff0c\u8fd9\u5c06\u662f\u4e00\u4e2a\u8bcd\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u81ea\u5df1\u5c1d\u8bd5\u6dfb\u52a0\u8bcd\u5e72\u548c\u8bcd\u7d20\u5316\uff0c\u770b\u770b\u662f\u5426\u80fd\u6539\u5584\u7ed3\u679c\u3002 \u60a8\u8fd8\u5e94\u8be5\u4e86\u89e3\u7684\u4e00\u4e2a\u4e3b\u9898\u662f\u4e3b\u9898\u63d0\u53d6\u3002 \u4e3b\u9898\u63d0\u53d6 \u53ef\u4ee5\u4f7f\u7528\u975e\u8d1f\u77e9\u9635\u56e0\u5f0f\u5206\u89e3\uff08NMF\uff09\u6216\u6f5c\u5728\u8bed\u4e49\u5206\u6790\uff08LSA\uff09\u6765\u5b8c\u6210\uff0c\u540e\u8005\u4e5f\u88ab\u79f0\u4e3a\u5947\u5f02\u503c\u5206\u89e3\u6216 SVD\u3002\u8fd9\u4e9b\u5206\u89e3\u6280\u672f\u53ef\u5c06\u6570\u636e\u7b80\u5316\u4e3a\u7ed9\u5b9a\u6570\u91cf\u7684\u6210\u5206\u3002 \u60a8\u53ef\u4ee5\u5728\u4ece CountVectorizer \u6216 TfidfVectorizer \u4e2d\u83b7\u5f97\u7684\u7a00\u758f\u77e9\u9635\u4e0a\u5e94\u7528\u5176\u4e2d\u4efb\u4f55\u4e00\u79cd\u6280\u672f\u3002 \u8ba9\u6211\u4eec\u628a\u5b83\u5e94\u7528\u5230\u4e4b\u524d\u4f7f\u7528\u8fc7\u7684 TfidfVetorizer \u4e0a\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import decomposition from sklearn.feature_extraction.text import TfidfVectorizer corpus = pd . read_csv ( \"../input/imdb.csv\" , nrows = 10000 ) corpus = corpus . review . values tfv = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfv . fit ( corpus ) corpus_transformed = tfv . transform ( corpus ) svd = decomposition . TruncatedSVD ( n_components = 10 ) corpus_svd = svd . fit ( corpus_transformed ) sample_index = 0 feature_scores = dict ( zip ( tfv . get_feature_names (), corpus_svd . components_ [ sample_index ] ) ) N = 5 print ( sorted ( feature_scores , key = feature_scores . get , reverse = True )[: N ]) \u60a8\u53ef\u4ee5\u4f7f\u7528\u5faa\u73af\u6765\u8fd0\u884c\u591a\u4e2a\u6837\u672c\u3002 N = 5 for sample_index in range ( 5 ): feature_scores = dict ( zip ( tfv . get_feature_names (), corpus_svd . components_ [ sample_index ] ) ) print ( sorted ( feature_scores , key = feature_scores . get , reverse = True )[: N ] ) \u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a [ 'the' , ',' , '.' , 'a' , 'and' ] [ 'br' , '<' , '>' , '/' , '-' ] [ 'i' , 'movie' , '!' , 'it' , 'was' ] [ ',' , '!' , \"''\" , '``' , 'you' ] [ '!' , 'the' , '...' , \"''\" , '``' ] \u4f60\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u6839\u672c\u8bf4\u4e0d\u901a\u3002\u600e\u4e48\u529e\u5462\uff1f\u8ba9\u6211\u4eec\u8bd5\u7740\u6e05\u7406\u4e00\u4e0b\uff0c\u770b\u770b\u662f\u5426\u6709\u610f\u4e49\u3002\u8981\u6e05\u7406\u4efb\u4f55\u6587\u672c\u6570\u636e\uff0c\u5c24\u5176\u662f pandas \u6570\u636e\u5e27\u4e2d\u7684\u6587\u672c\u6570\u636e\uff0c\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u51fd\u6570\u3002 import re import string def clean_text ( s ): s = s . split () s = \" \" . join ( s ) s = re . sub ( f '[ { re . escape ( string . punctuation ) } ]' , '' , s ) return s \u8be5\u51fd\u6570\u4f1a\u5c06 \"hi, how are you????\" \u8fd9\u6837\u7684\u5b57\u7b26\u4e32\u8f6c\u6362\u4e3a \"hi how are you\"\u3002\u8ba9\u6211\u4eec\u628a\u8fd9\u4e2a\u51fd\u6570\u5e94\u7528\u5230\u65e7\u7684 SVD \u4ee3\u7801\u4e2d\uff0c\u770b\u770b\u5b83\u662f\u5426\u80fd\u7ed9\u63d0\u53d6\u7684\u4e3b\u9898\u5e26\u6765\u63d0\u5347\u3002\u4f7f\u7528 pandas\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528 apply \u51fd\u6570\u5c06\u6e05\u7406\u4ee3\u7801 \"\u5e94\u7528 \"\u5230\u4efb\u610f\u7ed9\u5b9a\u7684\u5217\u4e2d\u3002 import pandas as pd corpus = pd . read_csv ( \"../input/imdb.csv\" , nrows = 10000 ) corpus . loc [:, \"review\" ] = corpus . review . apply ( clean_text ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u53ea\u5728\u4e3b SVD \u811a\u672c\u4e2d\u6dfb\u52a0\u4e86\u4e00\u884c\u4ee3\u7801\uff0c\u8fd9\u5c31\u662f\u4f7f\u7528\u51fd\u6570\u548c pandas \u5e94\u7528\u7684\u597d\u5904\u3002\u8fd9\u6b21\u751f\u6210\u7684\u4e3b\u9898\u5982\u4e0b\u3002 [ 'the' , 'a' , 'and' , 'of' , 'to' ] [ 'i' , 'movie' , 'it' , 'was' , 'this' ] [ 'the' , 'was' , 'i' , 'were' , 'of' ] [ 'her' , 'was' , 'she' , 'i' , 'he' ] [ 'br' , 'to' , 'they' , 'he' , 'show' ] \u547c\uff01\u81f3\u5c11\u8fd9\u6bd4\u6211\u4eec\u4e4b\u524d\u597d\u591a\u4e86\u3002\u4f46\u4f60\u77e5\u9053\u5417\uff1f\u4f60\u53ef\u4ee5\u901a\u8fc7\u5728\u6e05\u7406\u529f\u80fd\u4e2d\u5220\u9664\u505c\u6b62\u8bcd\uff08stopwords\uff09\u6765\u4f7f\u5b83\u53d8\u5f97\u66f4\u597d\u3002\u4ec0\u4e48\u662fstopwords\uff1f\u5b83\u4eec\u662f\u5b58\u5728\u4e8e\u6bcf\u79cd\u8bed\u8a00\u4e2d\u7684\u9ad8\u9891\u8bcd\u3002\u4f8b\u5982\uff0c\u5728\u82f1\u8bed\u4e2d\uff0c\u8fd9\u4e9b\u8bcd\u5305\u62ec \"a\"\u3001\"an\"\u3001\"the\"\u3001\"for \"\u7b49\u3002\u5220\u9664\u505c\u6b62\u8bcd\u5e76\u975e\u603b\u662f\u660e\u667a\u7684\u9009\u62e9\uff0c\u8fd9\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u4e1a\u52a1\u95ee\u9898\u3002\u50cf \"I need a new dog\"\u8fd9\u6837\u7684\u53e5\u5b50\uff0c\u53bb\u6389\u505c\u6b62\u8bcd\u540e\u4f1a\u53d8\u6210 \"need new dog\"\uff0c\u6b64\u65f6\u6211\u4eec\u4e0d\u77e5\u9053\u8c01\u9700\u8981new dog\u3002 \u5982\u679c\u6211\u4eec\u603b\u662f\u5220\u9664\u505c\u6b62\u8bcd\uff0c\u5c31\u4f1a\u4e22\u5931\u5f88\u591a\u4e0a\u4e0b\u6587\u4fe1\u606f\u3002\u4f60\u53ef\u4ee5\u5728 NLTK \u4e2d\u627e\u5230\u8bb8\u591a\u8bed\u8a00\u7684\u505c\u6b62\u8bcd\uff0c\u5982\u679c\u6ca1\u6709\uff0c\u4f60\u4e5f\u53ef\u4ee5\u5728\u81ea\u5df1\u559c\u6b22\u7684\u641c\u7d22\u5f15\u64ce\u4e0a\u5feb\u901f\u641c\u7d22\u4e00\u4e0b\u3002 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u8f6c\u5230\u5927\u591a\u6570\u4eba\u90fd\u559c\u6b22\u4f7f\u7528\u7684\u65b9\u6cd5\uff1a\u6df1\u5ea6\u5b66\u4e60\u3002\u4f46\u9996\u5148\uff0c\u6211\u4eec\u5fc5\u987b\u77e5\u9053\u4ec0\u4e48\u662f\u8bcd\u5d4c\u5165\uff08embedings for words\uff09\u3002\u4f60\u5df2\u7ecf\u770b\u5230\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u5c06\u6807\u8bb0\u8f6c\u6362\u6210\u4e86\u6570\u5b57\u3002\u56e0\u6b64\uff0c\u5982\u679c\u67d0\u4e2a\u8bed\u6599\u5e93\u4e2d\u6709 N \u4e2a\u552f\u4e00\u7684\u8bcd\u5757\uff0c\u5b83\u4eec\u53ef\u4ee5\u7528 0 \u5230 N-1 \u4e4b\u95f4\u7684\u6574\u6570\u6765\u8868\u793a\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5c06\u7528\u5411\u91cf\u6765\u8868\u793a\u8fd9\u4e9b\u6574\u6570\u8bcd\u5757\u3002\u8fd9\u79cd\u5c06\u5355\u8bcd\u8868\u793a\u6210\u5411\u91cf\u7684\u65b9\u6cd5\u88ab\u79f0\u4e3a\u5355\u8bcd\u5d4c\u5165\u6216\u5355\u8bcd\u5411\u91cf\u3002\u8c37\u6b4c\u7684 Word2Vec \u662f\u5c06\u5355\u8bcd\u8f6c\u6362\u4e3a\u5411\u91cf\u7684\u6700\u53e4\u8001\u65b9\u6cd5\u4e4b\u4e00\u3002\u6b64\u5916\uff0c\u8fd8\u6709 Facebook \u7684 FastText \u548c\u65af\u5766\u798f\u5927\u5b66\u7684 GloVe\uff08\u7528\u4e8e\u5355\u8bcd\u8868\u793a\u7684\u5168\u5c40\u5411\u91cf\uff09\u3002\u8fd9\u4e9b\u65b9\u6cd5\u5f7c\u6b64\u5927\u76f8\u5f84\u5ead\u3002 \u5176\u57fa\u672c\u601d\u60f3\u662f\u5efa\u7acb\u4e00\u4e2a\u6d45\u5c42\u7f51\u7edc\uff0c\u901a\u8fc7\u91cd\u6784\u8f93\u5165\u53e5\u5b50\u6765\u5b66\u4e60\u5355\u8bcd\u7684\u5d4c\u5165\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u5468\u56f4\u7684\u6240\u6709\u5355\u8bcd\u6765\u8bad\u7ec3\u7f51\u7edc\u9884\u6d4b\u4e00\u4e2a\u7f3a\u5931\u7684\u5355\u8bcd\uff0c\u5728\u6b64\u8fc7\u7a0b\u4e2d\uff0c\u7f51\u7edc\u5c06\u5b66\u4e60\u5e76\u66f4\u65b0\u6240\u6709\u76f8\u5173\u5355\u8bcd\u7684\u5d4c\u5165\u3002\u8fd9\u79cd\u65b9\u6cd5\u4e5f\u88ab\u79f0\u4e3a\u8fde\u7eed\u8bcd\u888b\u6216 CBoW \u6a21\u578b\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u4e00\u4e2a\u5355\u8bcd\u6765\u9884\u6d4b\u4e0a\u4e0b\u6587\u4e2d\u7684\u5355\u8bcd\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684\u8df3\u683c\u6a21\u578b\u3002Word2Vec \u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e24\u79cd\u65b9\u6cd5\u5b66\u4e60\u5d4c\u5165\u3002 FastText \u53ef\u4ee5\u5b66\u4e60\u5b57\u7b26 n-gram \u7684\u5d4c\u5165\u3002\u548c\u5355\u8bcd n-gram \u4e00\u6837\uff0c\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7684\u662f\u5b57\u7b26\uff0c\u5219\u79f0\u4e3a\u5b57\u7b26 n-gram\uff0c\u6700\u540e\uff0cGloVe \u901a\u8fc7\u5171\u73b0\u77e9\u9635\u6765\u5b66\u4e60\u8fd9\u4e9b\u5d4c\u5165\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u8bf4\uff0c\u6240\u6709\u8fd9\u4e9b\u4e0d\u540c\u7c7b\u578b\u7684\u5d4c\u5165\u6700\u7ec8\u90fd\u4f1a\u8fd4\u56de\u4e00\u4e2a\u5b57\u5178\uff0c\u5176\u4e2d\u952e\u662f\u8bed\u6599\u5e93\uff08\u4f8b\u5982\u82f1\u8bed\u7ef4\u57fa\u767e\u79d1\uff09\u4e2d\u7684\u5355\u8bcd\uff0c\u503c\u662f\u5927\u5c0f\u4e3a N\uff08\u901a\u5e38\u4e3a 300\uff09\u7684\u5411\u91cf\u3002 \u56fe 1\uff1a\u53ef\u89c6\u5316\u4e8c\u7ef4\u5355\u8bcd\u5d4c\u5165\u3002 \u56fe 1 \u663e\u793a\u4e86\u4e8c\u7ef4\u5355\u8bcd\u5d4c\u5165\u7684\u53ef\u89c6\u5316\u6548\u679c\u3002\u5047\u8bbe\u6211\u4eec\u4ee5\u67d0\u79cd\u65b9\u5f0f\u5b8c\u6210\u4e86\u8bcd\u8bed\u7684\u4e8c\u7ef4\u8868\u793a\u3002\u56fe 1 \u663e\u793a\uff0c\u5982\u679c\u4eceBerlin\uff08\u5fb7\u56fd\u9996\u90fd\uff09\u7684\u5411\u91cf\u4e2d\u51cf\u53bb\u5fb7\u56fd\uff08Germany\uff09\u7684\u5411\u91cf\uff0c\u518d\u52a0\u4e0a\u6cd5\u56fd\uff08france\uff09\u7684\u5411\u91cf\uff0c\u5c31\u4f1a\u5f97\u5230\u4e00\u4e2a\u63a5\u8fd1Paris\uff08\u6cd5\u56fd\u9996\u90fd\uff09\u7684\u5411\u91cf\u3002\u7531\u6b64\u53ef\u89c1\uff0c\u5d4c\u5165\u5f0f\u4e5f\u80fd\u8fdb\u884c\u7c7b\u6bd4\u3002 \u8fd9\u5e76\u4e0d\u603b\u662f\u6b63\u786e\u7684\uff0c\u4f46\u8fd9\u6837\u7684\u4f8b\u5b50\u6709\u52a9\u4e8e\u7406\u89e3\u5355\u8bcd\u5d4c\u5165\u7684\u4f5c\u7528\u3002\u50cf \"\u55e8\uff0c\u4f60\u597d\u5417 \"\u8fd9\u6837\u7684\u53e5\u5b50\u53ef\u4ee5\u7528\u4e0b\u9762\u7684\u4e00\u5806\u5411\u91cf\u6765\u8868\u793a\u3002 hi \u2500> [vector (v1) of size 300] , \u2500> [vector (v2) of size 300] how \u2500> [vector (v3) of size 300] are \u2500> [vector (v4) of size 300] you \u2500> [vector (v5) of size 300] ? \u2500> [vector (v6) of size 300] \u4f7f\u7528\u8fd9\u4e9b\u4fe1\u606f\u6709\u591a\u79cd\u65b9\u6cd5\u3002\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u4e4b\u4e00\u5c31\u662f\u4f7f\u7528\u5d4c\u5165\u5411\u91cf\u3002\u5982\u4e0a\u4f8b\u6240\u793a\uff0c\u6bcf\u4e2a\u5355\u8bcd\u90fd\u6709\u4e00\u4e2a 1x300 \u7684\u5d4c\u5165\u5411\u91cf\u3002\u5229\u7528\u8fd9\u4e9b\u4fe1\u606f\uff0c\u6211\u4eec\u53ef\u4ee5\u8ba1\u7b97\u51fa\u6574\u4e2a\u53e5\u5b50\u7684\u5d4c\u5165\u3002\u8ba1\u7b97\u65b9\u6cd5\u6709\u591a\u79cd\u3002\u5176\u4e2d\u4e00\u79cd\u65b9\u6cd5\u5982\u4e0b\u6240\u793a\u3002\u5728\u8fd9\u4e2a\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u5c06\u7ed9\u5b9a\u53e5\u5b50\u4e2d\u7684\u6240\u6709\u5355\u8bcd\u5411\u91cf\u63d0\u53d6\u51fa\u6765\uff0c\u7136\u540e\u4ece\u6240\u6709\u6807\u8bb0\u8bcd\u7684\u5355\u8bcd\u5411\u91cf\u4e2d\u521b\u5efa\u4e00\u4e2a\u5f52\u4e00\u5316\u7684\u5355\u8bcd\u5411\u91cf\u3002\u8fd9\u6837\u5c31\u5f97\u5230\u4e86\u4e00\u4e2a\u53e5\u5b50\u5411\u91cf\u3002 import numpy as np def sentence_to_vec ( s , embedding_dict , stop_words , tokenizer ): words = str ( s ) . lower () words = tokenizer ( words ) words = [ w for w in words if not w in stop_words ] words = [ w for w in words if w . isalpha ()] M = [] for w in words : if w in embedding_dict : M . append ( embedding_dict [ w ]) if len ( M ) == 0 : return np . zeros ( 300 ) M = np . array ( M ) v = M . sum ( axis = 0 ) return v / np . sqrt (( v ** 2 ) . sum ()) \u6211\u4eec\u53ef\u4ee5\u7528\u8fd9\u79cd\u65b9\u6cd5\u5c06\u6240\u6709\u793a\u4f8b\u8f6c\u6362\u6210\u4e00\u4e2a\u5411\u91cf\u3002\u6211\u4eec\u80fd\u5426\u4f7f\u7528 fastText \u5411\u91cf\u6765\u6539\u8fdb\u4e4b\u524d\u7684\u7ed3\u679c\uff1f\u6bcf\u7bc7\u8bc4\u8bba\u90fd\u6709 300 \u4e2a\u7279\u5f81\u3002 import io import numpy as np import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import TfidfVectorizer def load_vectors ( fname ): fin = io . open ( fname , 'r' , encoding = 'utf-8' , newline = ' \\n ' , errors = 'ignore' ) n , d = map ( int , fin . readline () . split ()) data = {} for line in fin : tokens = line . rstrip () . split ( ' ' ) data [ tokens [ 0 ]] = list ( map ( float , tokens [ 1 :])) return data def sentence_to_vec ( s , embedding_dict , stop_words , tokenizer ): if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df = df . sample ( frac = 1 ) . reset_index ( drop = True ) print ( \"Loading embeddings\" ) embeddings = load_vectors ( \"../input/crawl-300d-2M.vec\" ) print ( \"Creating sentence vectors\" ) vectors = [] for review in df . review . values : vectors . append ( sentence_to_vec ( s = review , embedding_dict = embeddings , stop_words = [], tokenizer = word_tokenize ) ) vectors = np . array ( vectors ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for fold_ , ( t_ , v_ ) in enumerate ( kf . split ( X = vectors , y = y )): print ( f \"Training fold: { fold_ } \" ) xtrain = vectors [ t_ , :] ytrain = y [ t_ ] xtest = vectors [ v_ , :] ytest = y [ v_ ] model = linear_model . LogisticRegression () model . fit ( xtrain , ytrain ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( ytest , preds ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u8fd9\u5c06\u5f97\u5230\u5982\u4e0b\u7ed3\u679c\uff1a Loading embeddings Creating sentence vectors Training fold : 0 Accuracy = 0.8619 Training fold : 1 Accuracy = 0.8661 Training fold : 2 Accuracy = 0.8544 Training fold : 3 Accuracy = 0.8624 Training fold : 4 Accuracy = 0.8595 Wow\uff01\u771f\u662f\u51fa\u4e4e\u610f\u6599\u3002\u6211\u4eec\u6240\u505a\u7684\u4e00\u5207\u90fd\u662f\u4e3a\u4e86\u4f7f\u7528 FastText \u5d4c\u5165\u3002\u8bd5\u7740\u628a\u5d4c\u5165\u5f0f\u6362\u6210 GloVe\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002\u6211\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u5927\u5bb6\u3002 \u5f53\u6211\u4eec\u8c08\u8bba\u6587\u672c\u6570\u636e\u65f6\uff0c\u6211\u4eec\u5fc5\u987b\u7262\u8bb0\u4e00\u4ef6\u4e8b\u3002\u6587\u672c\u6570\u636e\u4e0e\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u975e\u5e38\u76f8\u4f3c\u3002\u5982\u56fe 2 \u6240\u793a\uff0c\u6211\u4eec\u8bc4\u8bba\u4e2d\u7684\u4efb\u4f55\u6837\u672c\u90fd\u662f\u5728\u4e0d\u540c\u65f6\u95f4\u6233\u4e0a\u6309\u9012\u589e\u987a\u5e8f\u6392\u5217\u7684\u6807\u8bb0\u5e8f\u5217\uff0c\u6bcf\u4e2a\u6807\u8bb0\u90fd\u53ef\u4ee5\u8868\u793a\u4e3a\u4e00\u4e2a\u5411\u91cf/\u5d4c\u5165\u3002 \u56fe 2\uff1a\u5c06\u6807\u8bb0\u8868\u793a\u4e3a\u5d4c\u5165\uff0c\u5e76\u5c06\u5176\u89c6\u4e3a\u65f6\u95f4\u5e8f\u5217 \u8fd9\u610f\u5473\u7740\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5e7f\u6cdb\u7528\u4e8e\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u7684\u6a21\u578b\uff0c\u4f8b\u5982\u957f\u77ed\u671f\u8bb0\u5fc6\uff08LSTM\uff09\u6216\u95e8\u63a7\u9012\u5f52\u5355\u5143\uff08GRU\uff09\uff0c\u751a\u81f3\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff08CNN\uff09\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u5728\u8be5\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u4e00\u4e2a\u7b80\u5355\u7684\u53cc\u5411 LSTM \u6a21\u578b\u3002 \u9996\u5148\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u9879\u76ee\u3002\u4f60\u53ef\u4ee5\u968f\u610f\u7ed9\u5b83\u547d\u540d\u3002\u7136\u540e\uff0c\u6211\u4eec\u7684\u7b2c\u4e00\u6b65\u5c06\u662f\u5206\u5272\u6570\u636e\u8fdb\u884c\u4ea4\u53c9\u9a8c\u8bc1\u3002 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df [ \"kfold\" ] = - 1 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f df . to_csv ( \"../input/imdb_folds.csv\" , index = False ) \u5c06\u6570\u636e\u96c6\u5212\u5206\u4e3a\u591a\u4e2a\u6298\u53e0\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5728 dataset.py \u4e2d\u521b\u5efa\u4e00\u4e2a\u7b80\u5355\u7684\u6570\u636e\u96c6\u7c7b\u3002\u6570\u636e\u96c6\u7c7b\u4f1a\u8fd4\u56de\u4e00\u4e2a\u8bad\u7ec3\u6216\u9a8c\u8bc1\u6570\u636e\u6837\u672c\u3002 import torch class IMDBDataset : def __init__ ( self , reviews , targets ): self . reviews = reviews self . target = targets def __len__ ( self ): return len ( self . reviews ) def __getitem__ ( self , item ): review = self . reviews [ item , :] target = self . target [ item ] return { \"review\" : torch . tensor ( review , dtype = torch . long ), \"target\" : torch . tensor ( target , dtype = torch . float ) } \u5b8c\u6210\u6570\u636e\u96c6\u5206\u7c7b\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u521b\u5efa lstm.py\uff0c\u5176\u4e2d\u5305\u542b\u6211\u4eec\u7684 LSTM \u6a21\u578b import torch import torch.nn as nn class LSTM ( nn . Module ): def __init__ ( self , embedding_matrix ): super ( LSTM , self ) . __init__ () num_words = embedding_matrix . shape [ 0 ] embed_dim = embedding_matrix . shape [ 1 ] self . embedding = nn . Embedding ( num_embeddings = num_words , embedding_dim = embed_dim ) self . embedding . weight = nn . Parameter ( torch . tensor ( embedding_matrix , dtype = torch . float32 ) ) self . embedding . weight . requires_grad = False self . lstm = nn . LSTM ( embed_dim , 128 , bidirectional = True , batch_first = True , ) self . out = nn . Linear ( 512 , 1 ) def forward ( self , x ): x = self . embedding ( x ) x , _ = self . lstm ( x ) avg_pool = torch . mean ( x , 1 ) max_pool , _ = torch . max ( x , 1 ) out = torch . cat (( avg_pool , max_pool ), 1 ) out = self . out ( out ) return out \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa engine.py\uff0c\u5176\u4e2d\u5305\u542b\u8bad\u7ec3\u548c\u8bc4\u4f30\u51fd\u6570\u3002 import torch import torch.nn as nn def train ( data_loader , model , optimizer , device ): model . train () for data in data_loader : reviews = data [ \"review\" ] targets = data [ \"target\" ] reviews = reviews . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) optimizer . zero_grad () predictions = model ( reviews ) loss = nn . BCEWithLogitsLoss ()( predictions , targets . view ( - 1 , 1 ) ) loss . backward () optimizer . step () def evaluate ( data_loader , model , device ): final_predictions = [] final_targets = [] model . eval () with torch . no_grad (): for data in data_loader : reviews = data [ \"review\" ] targets = data [ \"target\" ] reviews = reviews . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) predictions = model ( reviews ) predictions = predictions . cpu () . numpy () . tolist () targets = data [ \"target\" ] . cpu () . numpy () . tolist () final_predictions . extend ( predictions ) final_targets . extend ( targets ) return final_predictions , final_targets \u8fd9\u4e9b\u51fd\u6570\u5c06\u5728 train.py \u4e2d\u4e3a\u6211\u4eec\u63d0\u4f9b\u5e2e\u52a9\uff0c\u8be5\u51fd\u6570\u7528\u4e8e\u8bad\u7ec3\u591a\u4e2a\u6298\u53e0\u3002 import io import torch import numpy as np import pandas as pd import tensorflow as tf from sklearn import metrics import config import dataset import engine import lstm def load_vectors ( fname ): fin = io . open ( fname , 'r' , encoding = 'utf-8' , newline = ' \\n ' , errors = 'ignore' ) n , d = map ( int , fin . readline () . split ()) data = {} for line in fin : tokens = line . rstrip () . split ( ' ' ) data [ tokens [ 0 ]] = list ( map ( float , tokens [ 1 :])) return data def create_embedding_matrix ( word_index , embedding_dict ): embedding_matrix = np . zeros (( len ( word_index ) + 1 , 300 )) for word , i in word_index . items (): if word in embedding_dict : embedding_matrix [ i ] = embedding_dict [ word ] return embedding_matrix def run ( df , fold ): train_df = df [ df . kfold != fold ] . reset_index ( drop = True ) valid_df = df [ df . kfold == fold ] . reset_index ( drop = True ) print ( \"Fitting tokenizer\" ) tokenizer = tf . keras . preprocessing . text . Tokenizer () tokenizer . fit_on_texts ( df . review . values . tolist ()) xtrain = tokenizer . texts_to_sequences ( train_df . review . values ) xtest = tokenizer . texts_to_sequences ( valid_df . review . values ) xtrain = tf . keras . preprocessing . sequence . pad_sequences ( xtrain , maxlen = config . MAX_LEN ) xtest = tf . keras . preprocessing . sequence . pad_sequences ( xtest , maxlen = config . MAX_LEN ) train_dataset = dataset . IMDBDataset ( reviews = xtrain , targets = train_df . sentiment . values ) train_data_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = config . TRAIN_BATCH_SIZE , num_workers = 2 ) valid_dataset = dataset . IMDBDataset ( reviews = xtest , targets = valid_df . sentiment . values ) valid_data_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = config . VALID_BATCH_SIZE , num_workers = 1 ) print ( \"Loading embeddings\" ) embedding_dict = load_vectors ( \"../input/crawl-300d-2M.vec\" ) embedding_matrix = create_embedding_matrix ( tokenizer . word_index , embedding_dict ) device = torch . device ( \"cuda\" ) model = lstm . LSTM ( embedding_matrix ) model . to ( device ) optimizer = torch . optim . Adam ( model . parameters (), lr = 1e-3 ) print ( \"Training Model\" ) best_accuracy = 0 early_stopping_counter = 0 for epoch in range ( config . EPOCHS ): engine . train ( train_data_loader , model , optimizer , device ) outputs , targets = engine . evaluate ( valid_data_loader , model , device ) outputs = np . array ( outputs ) >= 0.5 accuracy = metrics . accuracy_score ( targets , outputs ) print ( f \"FOLD: { fold } , Epoch: { epoch } , Accuracy Score = { accuracy } \" ) if accuracy > best_accuracy : best_accuracy = accuracy else : early_stopping_counter += 1 if early_stopping_counter > 2 : break if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb_folds.csv\" ) run ( df , fold = 0 ) run ( df , fold = 1 ) run ( df , fold = 2 ) run ( df , fold = 3 ) run ( df , fold = 4 ) \u6700\u540e\u662f config.py\u3002 MAX_LEN = 128 TRAIN_BATCH_SIZE = 16 VALID_BATCH_SIZE = 8 EPOCHS = 10 \u8ba9\u6211\u4eec\u770b\u770b\u8f93\u51fa\uff1a FOLD : 0 , Epoch : 3 , Accuracy Score = 0.9015 FOLD : 1 , Epoch : 4 , Accuracy Score = 0.9007 FOLD : 2 , Epoch : 3 , Accuracy Score = 0.8924 FOLD : 3 , Epoch : 2 , Accuracy Score = 0.9 FOLD : 4 , Epoch : 1 , Accuracy Score = 0.878 \u8fd9\u662f\u8fc4\u4eca\u4e3a\u6b62\u6211\u4eec\u83b7\u5f97\u7684\u6700\u597d\u6210\u7ee9\u3002 \u8bf7\u6ce8\u610f\uff0c\u6211\u53ea\u663e\u793a\u4e86\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7cbe\u5ea6\u6700\u9ad8\u7684Epoch\u3002 \u4f60\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\uff0c\u6211\u4eec\u4f7f\u7528\u4e86\u9884\u5148\u8bad\u7ec3\u7684\u5d4c\u5165\u548c\u7b80\u5355\u7684\u53cc\u5411 LSTM\u3002 \u5982\u679c\u4f60\u60f3\u6539\u53d8\u6a21\u578b\uff0c\u4f60\u53ef\u4ee5\u53ea\u6539\u53d8 lstm.py \u4e2d\u7684\u6a21\u578b\u5e76\u4fdd\u6301\u4e00\u5207\u4e0d\u53d8\u3002 \u8fd9\u79cd\u4ee3\u7801\u53ea\u9700\u8981\u5f88\u5c11\u7684\u5b9e\u9a8c\u6539\u52a8\uff0c\u5e76\u4e14\u5f88\u5bb9\u6613\u7406\u89e3\u3002 \u4f8b\u5982\uff0c\u60a8\u53ef\u4ee5\u81ea\u5df1\u5b66\u4e60\u5d4c\u5165\u800c\u4e0d\u662f\u4f7f\u7528\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u5176\u4ed6\u4e00\u4e9b\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u7ec4\u5408\u591a\u4e2a\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528GRU\uff0c\u60a8\u53ef\u4ee5\u5728\u5d4c\u5165\u540e\u4f7f\u7528\u7a7a\u95f4dropout\uff0c\u60a8\u53ef\u4ee5\u6dfb\u52a0GRU LSTM \u5c42\u4e4b\u540e\uff0c\u60a8\u53ef\u4ee5\u6dfb\u52a0\u4e24\u4e2a LSTM \u5c42\uff0c\u60a8\u53ef\u4ee5\u8fdb\u884c LSTM-GRU-LSTM \u914d\u7f6e\uff0c\u60a8\u53ef\u4ee5\u7528\u5377\u79ef\u5c42\u66ff\u6362 LSTM \u7b49\uff0c\u800c\u65e0\u9700\u5bf9\u4ee3\u7801\u8fdb\u884c\u592a\u591a\u66f4\u6539\u3002 \u6211\u63d0\u5230\u7684\u5927\u90e8\u5206\u5185\u5bb9\u53ea\u9700\u8981\u66f4\u6539\u6a21\u578b\u7c7b\u3002 \u5f53\u60a8\u4f7f\u7528\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\u65f6\uff0c\u5c1d\u8bd5\u67e5\u770b\u6709\u591a\u5c11\u5355\u8bcd\u65e0\u6cd5\u627e\u5230\u5d4c\u5165\u4ee5\u53ca\u539f\u56e0\u3002 \u9884\u8bad\u7ec3\u5d4c\u5165\u7684\u5355\u8bcd\u8d8a\u591a\uff0c\u7ed3\u679c\u5c31\u8d8a\u597d\u3002 \u6211\u5411\u60a8\u5c55\u793a\u4ee5\u4e0b\u672a\u6ce8\u91ca\u7684 (!) \u51fd\u6570\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u5b83\u4e3a\u4efb\u4f55\u7c7b\u578b\u7684\u9884\u8bad\u7ec3\u5d4c\u5165\u521b\u5efa\u5d4c\u5165\u77e9\u9635\uff0c\u5176\u683c\u5f0f\u4e0e glove \u6216 fastText \u76f8\u540c\uff08\u53ef\u80fd\u9700\u8981\u8fdb\u884c\u4e00\u4e9b\u66f4\u6539\uff09\u3002 def load_embeddings ( word_index , embedding_file , vector_length = 300 ): max_features = len ( word_index ) + 1 words_to_find = list ( word_index . keys ()) more_words_to_find = [] for wtf in words_to_find : more_words_to_find . append ( wtf ) more_words_to_find . append ( str ( wtf ) . capitalize ()) more_words_to_find = set ( more_words_to_find ) def get_coefs ( word , * arr ): return word , np . asarray ( arr , dtype = 'float32' ) embeddings_index = dict ( get_coefs ( * o . strip () . split ( \" \" )) for o in open ( embedding_file ) if o . split ( \" \" )[ 0 ] in more_words_to_find and len ( o ) > 100 ) embedding_matrix = np . zeros (( max_features , vector_length )) for word , i in word_index . items (): if i >= max_features : continue embedding_vector = embeddings_index . get ( word ) if embedding_vector is None : embedding_vector = embeddings_index . get ( str ( word ) . capitalize () ) if embedding_vector is None : embedding_vector = embeddings_index . get ( str ( word ) . upper () ) if ( embedding_vector is not None and len ( embedding_vector ) == vector_length ): embedding_matrix [ i ] = embedding_vector return embedding_matrix \u9605\u8bfb\u5e76\u8fd0\u884c\u4e0a\u9762\u7684\u51fd\u6570\uff0c\u770b\u770b\u53d1\u751f\u4e86\u4ec0\u4e48\u3002 \u8be5\u51fd\u6570\u8fd8\u53ef\u4ee5\u4fee\u6539\u4e3a\u4f7f\u7528\u8bcd\u5e72\u8bcd\u6216\u8bcd\u5f62\u8fd8\u539f\u8bcd\u3002 \u6700\u540e\uff0c\u60a8\u5e0c\u671b\u8bad\u7ec3\u8bed\u6599\u5e93\u4e2d\u7684\u672a\u77e5\u5355\u8bcd\u6570\u91cf\u6700\u5c11\u3002 \u53e6\u4e00\u4e2a\u6280\u5de7\u662f\u5b66\u4e60\u5d4c\u5165\u5c42\uff0c\u5373\u4f7f\u5176\u53ef\u8bad\u7ec3\uff0c\u7136\u540e\u8bad\u7ec3\u7f51\u7edc\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u5206\u7c7b\u95ee\u9898\u6784\u5efa\u4e86\u5f88\u591a\u6a21\u578b\u3002 \u7136\u800c\uff0c\u73b0\u5728\u662f\u5e03\u5076\u65f6\u4ee3\uff0c\u8d8a\u6765\u8d8a\u591a\u7684\u4eba\u8f6c\u5411\u57fa\u4e8e\u53d8\u5f62\u91d1\u521a\u7684\u6a21\u578b\u3002 \u57fa\u4e8e Transformer \u7684\u7f51\u7edc\u80fd\u591f\u5904\u7406\u672c\u8d28\u4e0a\u957f\u671f\u7684\u4f9d\u8d56\u5173\u7cfb\u3002 LSTM \u4ec5\u5f53\u5b83\u770b\u5230\u524d\u4e00\u4e2a\u5355\u8bcd\u65f6\u624d\u67e5\u770b\u4e0b\u4e00\u4e2a\u5355\u8bcd\u3002 \u53d8\u538b\u5668\u7684\u60c5\u51b5\u5e76\u975e\u5982\u6b64\u3002 \u5b83\u53ef\u4ee5\u540c\u65f6\u67e5\u770b\u6574\u4e2a\u53e5\u5b50\u4e2d\u7684\u6240\u6709\u5355\u8bcd\u3002 \u56e0\u6b64\uff0c\u53e6\u4e00\u4e2a\u4f18\u70b9\u662f\u5b83\u53ef\u4ee5\u8f7b\u677e\u5e76\u884c\u5316\u5e76\u66f4\u6709\u6548\u5730\u4f7f\u7528 GPU\u3002 Transformers \u662f\u4e00\u4e2a\u975e\u5e38\u5e7f\u6cdb\u7684\u8bdd\u9898\uff0c\u6709\u592a\u591a\u7684\u6a21\u578b\uff1a BERT\u3001RoBERTa\u3001XLNet\u3001XLM-RoBERTa\u3001T5 \u7b49\u3002\u6211\u5c06\u5411\u60a8\u5c55\u793a\u4e00\u79cd\u53ef\u7528\u4e8e\u6240\u6709\u8fd9\u4e9b\u6a21\u578b\uff08T5 \u9664\u5916\uff09\u8fdb\u884c\u5206\u7c7b\u7684\u901a\u7528\u65b9\u6cd5 \u6211\u4eec\u4e00\u76f4\u5728\u8ba8\u8bba\u7684\u95ee\u9898\u3002 \u8bf7\u6ce8\u610f\uff0c\u8fd9\u4e9b\u53d8\u538b\u5668\u9700\u8981\u8bad\u7ec3\u5b83\u4eec\u6240\u9700\u7684\u8ba1\u7b97\u80fd\u529b\u3002 \u56e0\u6b64\uff0c\u5982\u679c\u60a8\u6ca1\u6709\u9ad8\u7aef\u7cfb\u7edf\uff0c\u4e0e\u57fa\u4e8e LSTM \u6216 TF-IDF \u7684\u6a21\u578b\u76f8\u6bd4\uff0c\u8bad\u7ec3\u6a21\u578b\u53ef\u80fd\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u3002 \u6211\u4eec\u8981\u505a\u7684\u7b2c\u4e00\u4ef6\u4e8b\u662f\u521b\u5efa\u4e00\u4e2a\u914d\u7f6e\u6587\u4ef6\u3002 import transformers MAX_LEN = 512 TRAIN_BATCH_SIZE = 8 VALID_BATCH_SIZE = 4 EPOCHS = 10 BERT_PATH = \"../input/bert_base_uncased/\" MODEL_PATH = \"model.bin\" TRAINING_FILE = \"../input/imdb.csv\" TOKENIZER = transformers . BertTokenizer . from_pretrained ( BERT_PATH , do_lower_case = True ) \u8fd9\u91cc\u7684\u914d\u7f6e\u6587\u4ef6\u662f\u6211\u4eec\u5b9a\u4e49\u5206\u8bcd\u5668\u548c\u5176\u4ed6\u6211\u4eec\u60f3\u8981\u7ecf\u5e38\u66f4\u6539\u7684\u53c2\u6570\u7684\u552f\u4e00\u5730\u65b9 \u2014\u2014 \u8fd9\u6837\u6211\u4eec\u5c31\u53ef\u4ee5\u505a\u5f88\u591a\u5b9e\u9a8c\u800c\u4e0d\u9700\u8981\u8fdb\u884c\u5927\u91cf\u66f4\u6539\u3002 \u4e0b\u4e00\u6b65\u662f\u6784\u5efa\u6570\u636e\u96c6\u7c7b\u3002 import config import torch class BERTDataset : def __init__ ( self , review , target ): self . review = review self . target = target self . tokenizer = config . TOKENIZER self . max_len = config . MAX_LEN def __len__ ( self ): return len ( self . review ) def __getitem__ ( self , item ): review = str ( self . review [ item ]) review = \" \" . join ( review . split ()) inputs = self . tokenizer . encode_plus ( review , None , add_special_tokens = True , max_length = self . max_len , pad_to_max_length = True , ) ids = inputs [ \"input_ids\" ] mask = inputs [ \"attention_mask\" ] token_type_ids = inputs [ \"token_type_ids\" ] return { \"ids\" : torch . tensor ( ids , dtype = torch . long ), \"mask\" : torch . tensor ( mask , dtype = torch . long ), \"token_type_ids\" : torch . tensor ( token_type_ids , dtype = torch . long ), \"targets\" : torch . tensor ( self . target [ item ], dtype = torch . float ) } \u73b0\u5728\u6211\u4eec\u6765\u5230\u4e86\u8be5\u9879\u76ee\u7684\u6838\u5fc3\uff0c\u5373\u6a21\u578b\u3002 import config import transformers import torch.nn as nn class BERTBaseUncased ( nn . Module ): def __init__ ( self ): super ( BERTBaseUncased , self ) . __init__ () self . bert = transformers . BertModel . from_pretrained ( config . BERT_PATH ) self . bert_drop = nn . Dropout ( 0.3 ) self . out = nn . Linear ( 768 , 1 ) def forward ( self , ids , mask , token_type_ids ): hidden state _ , o2 = self . bert ( ids , attention_mask = mask , token_type_ids = token_type_ids ) bo = self . bert_drop ( o2 ) output = self . out ( bo ) return output \u8be5\u6a21\u578b\u8fd4\u56de\u5355\u4e2a\u8f93\u51fa\u3002 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5e26\u6709 logits \u7684\u4e8c\u5143\u4ea4\u53c9\u71b5\u635f\u5931\uff0c\u5b83\u9996\u5148\u5e94\u7528 sigmoid\uff0c\u7136\u540e\u8ba1\u7b97\u635f\u5931\u3002 \u8fd9\u662f\u5728engine.py \u4e2d\u5b8c\u6210\u7684\u3002 import torch import torch.nn as nn def loss_fn ( outputs , targets ): return nn . BCEWithLogitsLoss ()( outputs , targets . view ( - 1 , 1 )) def train_fn ( data_loader , model , optimizer , device , scheduler ): model . train () for d in data_loader : ids = d [ \"ids\" ] token_type_ids = d [ \"token_type_ids\" ] mask = d [ \"mask\" ] targets = d [ \"targets\" ] ids = ids . to ( device , dtype = torch . long ) token_type_ids = token_type_ids . to ( device , dtype = torch . long ) mask = mask . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) optimizer . zero_grad () outputs = model ( ids = ids , mask = mask , token_type_ids = token_type_ids ) loss = loss_fn ( outputs , targets ) loss . backward () optimizer . step () scheduler . step () def eval_fn ( data_loader , model , device ): model . eval () fin_targets = [] fin_outputs = [] with torch . no_grad (): for d in data_loader : ids = d [ \"ids\" ] token_type_ids = d [ \"token_type_ids\" ] mask = d [ \"mask\" ] targets = d [ \"targets\" ] ids = ids . to ( device , dtype = torch . long ) token_type_ids = token_type_ids . to ( device , dtype = torch . long ) mask = mask . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) outputs = model ( ids = ids , mask = mask , token_type_ids = token_type_ids ) targets = targets . cpu () . detach () fin_targets . extend ( targets . numpy () . tolist ()) outputs = torch . sigmoid ( outputs ) . cpu () . detach () fin_outputs . extend ( outputs . numpy () . tolist ()) return fin_outputs , fin_targets \u6700\u540e\uff0c\u6211\u4eec\u51c6\u5907\u597d\u8bad\u7ec3\u4e86\u3002 \u6211\u4eec\u6765\u770b\u770b\u8bad\u7ec3\u811a\u672c\u5427\uff01 import config import dataset import engine import torch import pandas as pd import torch.nn as nn import numpy as np from model import BERTBaseUncased from sklearn import model_selection from sklearn import metrics from transformers import AdamW from transformers import get_linear_schedule_with_warmup def train (): dfx = pd . read_csv ( config . TRAINING_FILE ) . fillna ( \"none\" ) dfx . sentiment = dfx . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df_train , df_valid = model_selection . train_test_split ( dfx , test_size = 0.1 , random_state = 42 , stratify = dfx . sentiment . values ) df_train = df_train . reset_index ( drop = True ) df_valid = df_valid . reset_index ( drop = True ) train_dataset = dataset . BERTDataset ( review = df_train . review . values , target = df_train . sentiment . values ) train_data_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = config . TRAIN_BATCH_SIZE , num_workers = 4 ) valid_dataset = dataset . BERTDataset ( review = df_valid . review . values , target = df_valid . sentiment . values ) valid_data_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = config . VALID_BATCH_SIZE , num_workers = 1 ) device = torch . device ( \"cuda\" ) model = BERTBaseUncased () model . to ( device ) param_optimizer = list ( model . named_parameters ()) no_decay = [ \"bias\" , \"LayerNorm.bias\" , \"LayerNorm.weight\" ] optimizer_parameters = [ { \"params\" : [ p for n , p in param_optimizer if not any ( nd in n for nd in no_decay ) ], \"weight_decay\" : 0.001 , } \uff0c { \"params\" : [ p for n , p in param_optimizer if any ( nd in n for nd in no_decay ) ], \"weight_decay\" : 0.0 , }] num_train_steps = int ( len ( df_train ) / config . TRAIN_BATCH_SIZE * config . EPOCHS ) optimizer = AdamW ( optimizer_parameters , lr = 3e-5 ) scheduler = get_linear_schedule_with_warmup ( optimizer , num_warmup_steps = 0 , num_training_steps = num_train_steps ) model = nn . DataParallel ( model ) best_accuracy = 0 for epoch in range ( config . EPOCHS ): engine . train_fn ( train_data_loader , model , optimizer , device , scheduler ) outputs , targets = engine . eval_fn ( valid_data_loader , model , device ) outputs = np . array ( outputs ) >= 0.5 accuracy = metrics . accuracy_score ( targets , outputs ) print ( f \"Accuracy Score = { accuracy } \" ) if accuracy > best_accuracy : torch . save ( model . state_dict (), config . MODEL_PATH ) best_accuracy = accuracy if __name__ == \"__main__\" : train () \u4e4d\u4e00\u770b\u53ef\u80fd\u770b\u8d77\u6765\u5f88\u591a\uff0c\u4f46\u4e00\u65e6\u60a8\u4e86\u89e3\u4e86\u5404\u4e2a\u7ec4\u4ef6\uff0c\u5c31\u4e0d\u518d\u90a3\u4e48\u7b80\u5355\u4e86\u3002 \u60a8\u53ea\u9700\u66f4\u6539\u51e0\u884c\u4ee3\u7801\u5373\u53ef\u8f7b\u677e\u5c06\u5176\u66f4\u6539\u4e3a\u60a8\u60f3\u8981\u4f7f\u7528\u7684\u4efb\u4f55\u5176\u4ed6\u53d8\u538b\u5668\u6a21\u578b\u3002 \u8be5\u6a21\u578b\u7684\u51c6\u786e\u7387\u4e3a 93%\uff01 \u54c7\uff01 \u8fd9\u6bd4\u4efb\u4f55\u5176\u4ed6\u6a21\u578b\u90fd\u8981\u597d\u5f97\u591a\u3002 \u4f46\u662f\u8fd9\u503c\u5f97\u5417\uff1f \u6211\u4eec\u4f7f\u7528 LSTM \u80fd\u591f\u5b9e\u73b0 90% \u7684\u76ee\u6807\uff0c\u800c\u4e14\u5b83\u4eec\u66f4\u7b80\u5355\u3001\u66f4\u5bb9\u6613\u8bad\u7ec3\u5e76\u4e14\u63a8\u7406\u901f\u5ea6\u66f4\u5feb\u3002 \u901a\u8fc7\u4f7f\u7528\u4e0d\u540c\u7684\u6570\u636e\u5904\u7406\u6216\u8c03\u6574\u5c42\u3001\u8282\u70b9\u3001dropout\u3001\u5b66\u4e60\u7387\u3001\u66f4\u6539\u4f18\u5316\u5668\u7b49\u53c2\u6570\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u8be5\u6a21\u578b\u6539\u8fdb\u4e00\u4e2a\u767e\u5206\u70b9\u3002\u7136\u540e\u6211\u4eec\u5c06\u4ece BERT \u4e2d\u83b7\u5f97\u7ea6 2% \u7684\u6536\u76ca\u3002 \u53e6\u4e00\u65b9\u9762\uff0cBERT \u7684\u8bad\u7ec3\u65f6\u95f4\u8981\u957f\u5f97\u591a\uff0c\u53c2\u6570\u5f88\u591a\uff0c\u800c\u4e14\u63a8\u7406\u901f\u5ea6\u4e5f\u5f88\u6162\u3002 \u6700\u540e\uff0c\u60a8\u5e94\u8be5\u5ba1\u89c6\u81ea\u5df1\u7684\u4e1a\u52a1\u5e76\u505a\u51fa\u660e\u667a\u7684\u9009\u62e9\u3002 \u4e0d\u8981\u4ec5\u4ec5\u56e0\u4e3a BERT\u201c\u9177\u201d\u800c\u9009\u62e9\u5b83\u3002 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u6211\u4eec\u5728\u8fd9\u91cc\u8ba8\u8bba\u7684\u552f\u4e00\u4efb\u52a1\u662f\u5206\u7c7b\uff0c\u4f46\u5c06\u5176\u66f4\u6539\u4e3a\u56de\u5f52\u3001\u591a\u6807\u7b7e\u6216\u591a\u7c7b\u53ea\u9700\u8981\u66f4\u6539\u51e0\u884c\u4ee3\u7801\u3002 \u4f8b\u5982\uff0c\u591a\u7c7b\u5206\u7c7b\u8bbe\u7f6e\u4e2d\u7684\u540c\u4e00\u95ee\u9898\u5c06\u6709\u591a\u4e2a\u8f93\u51fa\u548c\u4ea4\u53c9\u71b5\u635f\u5931\u3002 \u5176\u4ed6\u4e00\u5207\u90fd\u5e94\u8be5\u4fdd\u6301\u4e0d\u53d8\u3002 \u81ea\u7136\u8bed\u8a00\u5904\u7406\u975e\u5e38\u5e9e\u5927\uff0c\u6211\u4eec\u53ea\u8ba8\u8bba\u4e86\u5176\u4e2d\u7684\u4e00\u5c0f\u90e8\u5206\u3002 \u663e\u7136\uff0c\u8fd9\u662f\u4e00\u4e2a\u5f88\u5927\u7684\u6bd4\u4f8b\uff0c\u56e0\u4e3a\u5927\u591a\u6570\u5de5\u4e1a\u6a21\u578b\u90fd\u662f\u5206\u7c7b\u6216\u56de\u5f52\u6a21\u578b\u3002 \u5982\u679c\u6211\u5f00\u59cb\u8be6\u7ec6\u5199\u6240\u6709\u5185\u5bb9\uff0c\u6211\u6700\u7ec8\u53ef\u80fd\u4f1a\u5199\u51e0\u767e\u9875\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u51b3\u5b9a\u5c06\u6240\u6709\u5185\u5bb9\u5305\u542b\u5728\u4e00\u672c\u5355\u72ec\u7684\u4e66\u4e2d\uff1a\u63a5\u8fd1\uff08\u51e0\u4e4e\uff09\u4efb\u4f55 NLP \u95ee\u9898\uff01","title":"\u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5"},{"location":"%E6%97%A0%E7%9B%91%E7%9D%A3%E5%92%8C%E6%9C%89%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0/","text":"\u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60 \u5728\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u65f6\uff0c\u901a\u5e38\u6709\u4e24\u7c7b\u6570\u636e\uff08\u548c\u673a\u5668\u5b66\u4e60\u6a21\u578b\uff09\uff1a \u76d1\u7763\u6570\u636e\uff1a\u603b\u662f\u6709\u4e00\u4e2a\u6216\u591a\u4e2a\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807 \u65e0\u76d1\u7763\u6570\u636e\uff1a\u6ca1\u6709\u4efb\u4f55\u76ee\u6807\u53d8\u91cf\u3002 \u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\u3002\u6211\u4eec\u9700\u8981\u9884\u6d4b\u4e00\u4e2a\u503c\u7684\u95ee\u9898\u88ab\u79f0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u3002\u4f8b\u5982\uff0c\u5982\u679c\u95ee\u9898\u662f\u6839\u636e\u5386\u53f2\u623f\u4ef7\u9884\u6d4b\u623f\u4ef7\uff0c\u90a3\u4e48\u533b\u9662\u3001\u5b66\u6821\u6216\u8d85\u5e02\u7684\u5b58\u5728\uff0c\u4e0e\u6700\u8fd1\u516c\u5171\u4ea4\u901a\u7684\u8ddd\u79bb\u7b49\u7279\u5f81\u5c31\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u540c\u6837\uff0c\u5f53\u6211\u4eec\u5f97\u5230\u732b\u548c\u72d7\u7684\u56fe\u50cf\u65f6\uff0c\u6211\u4eec\u4e8b\u5148\u77e5\u9053\u54ea\u4e9b\u662f\u732b\uff0c\u54ea\u4e9b\u662f\u72d7\uff0c\u5982\u679c\u4efb\u52a1\u662f\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u6240\u63d0\u4f9b\u7684\u56fe\u50cf\u662f\u732b\u8fd8\u662f\u72d7\uff0c\u90a3\u4e48\u8fd9\u4e2a\u95ee\u9898\u5c31\u88ab\u8ba4\u4e3a\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002 \u56fe 1\uff1a\u6709\u76d1\u7763\u5b66\u4e60\u6570\u636e \u5982\u56fe 1 \u6240\u793a\uff0c\u6570\u636e\u7684\u6bcf\u4e00\u884c\u90fd\u4e0e\u4e00\u4e2a\u76ee\u6807\u6216\u6807\u7b7e\u76f8\u5173\u8054\u3002\u5217\u662f\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u884c\u4ee3\u8868\u4e0d\u540c\u7684\u6570\u636e\u70b9\uff0c\u901a\u5e38\u79f0\u4e3a\u6837\u672c\u3002\u793a\u4f8b\u4e2d\u7684\u5341\u4e2a\u6837\u672c\u6709\u5341\u4e2a\u7279\u5f81\u548c\u4e00\u4e2a\u76ee\u6807\u53d8\u91cf\uff0c\u76ee\u6807\u53d8\u91cf\u53ef\u4ee5\u662f\u6570\u5b57\u6216\u7c7b\u522b\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5206\u7c7b\u53d8\u91cf\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5b9e\u6570\uff0c\u95ee\u9898\u5c31\u88ab\u5b9a\u4e49\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6709\u76d1\u7763\u95ee\u9898\u53ef\u5206\u4e3a\u4e24\u4e2a\u5b50\u7c7b\uff1a \u5206\u7c7b\uff1a\u9884\u6d4b\u7c7b\u522b\uff0c\u5982\u732b\u6216\u72d7 \u56de\u5f52\uff1a\u9884\u6d4b\u503c\uff0c\u5982\u623f\u4ef7 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u6709\u65f6\u6211\u4eec\u53ef\u80fd\u4f1a\u5728\u5206\u7c7b\u8bbe\u7f6e\u4e2d\u4f7f\u7528\u56de\u5f52\uff0c\u8fd9\u53d6\u51b3\u4e8e\u7528\u4e8e\u8bc4\u4f30\u7684\u6307\u6807\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u7a0d\u540e\u4f1a\u8ba8\u8bba\u8fd9\u4e2a\u95ee\u9898\u3002 \u53e6\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u95ee\u9898\u662f\u65e0\u76d1\u7763\u7c7b\u578b\u3002 \u65e0\u76d1\u7763 \u6570\u636e\u96c6\u6ca1\u6709\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\uff0c\u4e00\u822c\u6765\u8bf4\uff0c\u4e0e\u6709\u76d1\u7763\u95ee\u9898\u76f8\u6bd4\uff0c\u5904\u7406\u65e0\u76d1\u7763\u6570\u636e\u96c6\u66f4\u5177\u6311\u6218\u6027\u3002 \u5047\u8bbe\u4f60\u5728\u4e00\u5bb6\u5904\u7406\u4fe1\u7528\u5361\u4ea4\u6613\u7684\u91d1\u878d\u516c\u53f8\u5de5\u4f5c\u3002\u6bcf\u79d2\u949f\u90fd\u6709\u5927\u91cf\u6570\u636e\u6d8c\u5165\u3002\u552f\u4e00\u7684\u95ee\u9898\u662f\uff0c\u5f88\u96be\u627e\u5230\u4e00\u4e2a\u4eba\u6765\u5c06\u6bcf\u7b14\u4ea4\u6613\u6807\u8bb0\u4e3a\u6709\u6548\u4ea4\u6613\u3001\u771f\u5b9e\u4ea4\u6613\u6216\u6b3a\u8bc8\u4ea4\u6613\u3002\u5f53\u6211\u4eec\u6ca1\u6709\u4efb\u4f55\u5173\u4e8e\u4ea4\u6613\u662f\u6b3a\u8bc8\u8fd8\u662f\u771f\u5b9e\u7684\u4fe1\u606f\u65f6\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u65e0\u76d1\u7763\u95ee\u9898\u3002\u8981\u89e3\u51b3\u8fd9\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5fc5\u987b\u8003\u8651\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a\u591a\u5c11\u4e2a \u805a\u7c7b \u3002\u805a\u7c7b\u662f\u89e3\u51b3\u6b64\u7c7b\u95ee\u9898\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd8\u6709\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u53ef\u4ee5\u5e94\u7528\u4e8e\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5bf9\u4e8e\u6b3a\u8bc8\u68c0\u6d4b\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u8bf4\u6570\u636e\u53ef\u4ee5\u5206\u4e3a\u4e24\u7c7b\uff08\u6b3a\u8bc8\u6216\u771f\u5b9e\uff09\u3002 \u5f53\u6211\u4eec\u77e5\u9053\u805a\u7c7b\u7684\u6570\u91cf\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u805a\u7c7b\u7b97\u6cd5\u6765\u89e3\u51b3\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5728\u56fe 2 \u4e2d\uff0c\u5047\u8bbe\u6570\u636e\u5206\u4e3a\u4e24\u7c7b\uff0c\u6df1\u8272\u4ee3\u8868\u6b3a\u8bc8\uff0c\u6d45\u8272\u4ee3\u8868\u771f\u5b9e\u4ea4\u6613\u3002\u7136\u800c\uff0c\u5728\u4f7f\u7528\u805a\u7c7b\u65b9\u6cd5\u4e4b\u524d\uff0c\u6211\u4eec\u5e76\u4e0d\u77e5\u9053\u8fd9\u4e9b\u7c7b\u522b\u3002\u5e94\u7528\u805a\u7c7b\u7b97\u6cd5\u540e\uff0c\u6211\u4eec\u5e94\u8be5\u80fd\u591f\u533a\u5206\u8fd9\u4e24\u4e2a\u5047\u5b9a\u76ee\u6807\u3002 \u4e3a\u4e86\u7406\u89e3\u65e0\u76d1\u7763\u95ee\u9898\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u8bb8\u591a\u5206\u89e3\u6280\u672f\uff0c\u5982 \u4e3b\u6210\u5206\u5206\u6790\uff08PCA\uff09\u3001t-\u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09 \u7b49\u3002 \u6709\u76d1\u7763\u7684\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\uff0c\u56e0\u4e3a\u5b83\u4eec\u5f88\u5bb9\u6613\u8bc4\u4f30\u3002\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u8bc4\u4f30\u6280\u672f\u3002\u7136\u800c\uff0c\u5bf9\u65e0\u76d1\u7763\u7b97\u6cd5\u7684\u7ed3\u679c\u8fdb\u884c\u8bc4\u4f30\u5177\u6709\u6311\u6218\u6027\uff0c\u9700\u8981\u5927\u91cf\u7684\u4eba\u4e3a\u5e72\u9884\u6216\u542f\u53d1\u5f0f\u65b9\u6cd5\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4e3b\u8981\u5173\u6ce8\u6709\u76d1\u7763\u6570\u636e\u548c\u6a21\u578b\uff0c\u4f46\u8fd9\u5e76\u4e0d\u610f\u5473\u7740\u6211\u4eec\u4f1a\u5ffd\u7565\u65e0\u76d1\u7763\u6570\u636e\u95ee\u9898\u3002 \u56fe 2\uff1a\u65e0\u76d1\u7763\u5b66\u4e60\u6570\u636e\u96c6 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u5f53\u4eba\u4eec\u5f00\u59cb\u5b66\u4e60\u6570\u636e\u79d1\u5b66\u6216\u673a\u5668\u5b66\u4e60\u65f6\uff0c\u90fd\u4f1a\u4ece\u975e\u5e38\u8457\u540d\u7684\u6570\u636e\u96c6\u5f00\u59cb\uff0c\u4f8b\u5982\u6cf0\u5766\u5c3c\u514b\u6570\u636e\u96c6\u6216\u9e22\u5c3e\u82b1\u6570\u636e\u96c6\uff0c\u8fd9\u4e9b\u90fd\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u5728\u6cf0\u5766\u5c3c\u514b\u53f7\u6570\u636e\u96c6\u4e2d\uff0c\u4f60\u5fc5\u987b\u6839\u636e\u8239\u7968\u7b49\u7ea7\u3001\u6027\u522b\u3001\u5e74\u9f84\u7b49\u56e0\u7d20\u9884\u6d4b\u6cf0\u5766\u5c3c\u514b\u53f7\u4e0a\u4e58\u5ba2\u7684\u5b58\u6d3b\u7387\u3002\u540c\u6837\uff0c\u5728\u9e22\u5c3e\u82b1\u6570\u636e\u96c6\u4e2d\uff0c\u60a8\u5fc5\u987b\u6839\u636e\u843c\u7247\u5bbd\u5ea6\u3001\u82b1\u74e3\u957f\u5ea6\u3001\u843c\u7247\u957f\u5ea6\u548c\u82b1\u74e3\u5bbd\u5ea6\u7b49\u56e0\u7d20\u9884\u6d4b\u82b1\u7684\u79cd\u7c7b\u3002 \u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u80fd\u5305\u62ec\u7528\u4e8e\u5ba2\u6237\u7ec6\u5206\u7684\u6570\u636e\u96c6\u3002 \u4f8b\u5982\uff0c\u60a8\u62e5\u6709\u8bbf\u95ee\u60a8\u7684\u7535\u5b50\u5546\u52a1\u7f51\u7ad9\u7684\u5ba2\u6237\u6570\u636e\uff0c\u6216\u8005\u8bbf\u95ee\u5546\u5e97\u6216\u5546\u573a\u7684\u5ba2\u6237\u6570\u636e\uff0c\u800c\u60a8\u5e0c\u671b\u5c06\u5b83\u4eec\u7ec6\u5206\u6216\u805a\u7c7b\u4e3a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u65e0\u76d1\u7763\u6570\u636e\u96c6\u7684\u53e6\u4e00\u4e2a\u4f8b\u5b50\u53ef\u80fd\u5305\u62ec\u4fe1\u7528\u5361\u6b3a\u8bc8\u68c0\u6d4b\u6216\u5bf9\u51e0\u5f20\u56fe\u7247\u8fdb\u884c\u805a\u7c7b\u7b49\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd8\u53ef\u4ee5\u5c06\u6709\u76d1\u7763\u6570\u636e\u96c6\u8f6c\u6362\u4e3a\u65e0\u76d1\u7763\u6570\u636e\u96c6\uff0c\u4ee5\u67e5\u770b\u5b83\u4eec\u5728\u7ed8\u5236\u65f6\u7684\u6548\u679c\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u56fe 3 \u4e2d\u7684\u6570\u636e\u96c6\u3002\u56fe 3 \u663e\u793a\u7684\u662f MNIST \u6570\u636e\u96c6\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u6d41\u884c\u7684\u624b\u5199\u6570\u5b57\u6570\u636e\u96c6\uff0c\u5b83\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\uff0c\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u4f60\u4f1a\u5f97\u5230\u6570\u5b57\u56fe\u50cf\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u6b63\u786e\u6807\u7b7e\u3002\u4f60\u5fc5\u987b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u53ea\u63d0\u4f9b\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\u8bc6\u522b\u51fa\u54ea\u4e2a\u6570\u5b57\u662f\u5b83\u3002 \u56fe 3\uff1aMNIST\u6570\u636e\u96c6 \u5982\u679c\u6211\u4eec\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c t \u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09\u5206\u89e3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u53ea\u9700\u5728\u56fe\u50cf\u50cf\u7d20\u4e0a\u964d\u7ef4\u81f3 2 \u4e2a\u7ef4\u5ea6\uff0c\u5c31\u80fd\u5728\u4e00\u5b9a\u7a0b\u5ea6\u4e0a\u5206\u79bb\u56fe\u50cf\u3002\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1aMNIST \u6570\u636e\u96c6\u7684 t-SNE \u53ef\u89c6\u5316\u3002\u4f7f\u7528\u4e86 3000 \u5e45\u56fe\u50cf\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002\u9996\u5148\u662f\u5bfc\u5165\u6240\u6709\u9700\u8981\u7684\u5e93\u3002 import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns from sklearn import datasets from sklearn import manifold % matplotlib inline \u6211\u4eec\u4f7f\u7528 matplotlib \u548c seaborn \u8fdb\u884c\u7ed8\u56fe\uff0c\u4f7f\u7528 numpy \u5904\u7406\u6570\u503c\u6570\u7ec4\uff0c\u4f7f\u7528 pandas \u4ece\u6570\u503c\u6570\u7ec4\u521b\u5efa\u6570\u636e\u5e27\uff0c\u4f7f\u7528 scikit-learn (sklearn) \u83b7\u53d6\u6570\u636e\u5e76\u6267\u884c t-SNE\u3002 \u5bfc\u5165\u540e\uff0c\u6211\u4eec\u9700\u8981\u4e0b\u8f7d\u6570\u636e\u5e76\u5355\u72ec\u8bfb\u53d6\uff0c\u6216\u8005\u4f7f\u7528 sklearn \u7684\u5185\u7f6e\u51fd\u6570\u6765\u63d0\u4f9b MNIST \u6570\u636e\u96c6\u3002 data = datasets . fetch_openml ( 'mnist_784' , version = 1 , return_X_y = True ) pixel_values , targets = data targets = targets . astype ( int ) \u5728\u8fd9\u90e8\u5206\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528 sklearn \u6570\u636e\u96c6\u83b7\u53d6\u4e86\u6570\u636e\uff0c\u5e76\u83b7\u5f97\u4e86\u4e00\u4e2a\u50cf\u7d20\u503c\u6570\u7ec4\u548c\u53e6\u4e00\u4e2a\u76ee\u6807\u6570\u7ec4\u3002\u7531\u4e8e\u76ee\u6807\u662f\u5b57\u7b26\u4e32\u7c7b\u578b\uff0c\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u3002 pixel_values \u662f\u4e00\u4e2a\u5f62\u72b6\u4e3a 70000x784 \u7684\u4e8c\u7ef4\u6570\u7ec4\u3002 \u5171\u6709 70000 \u5f20\u4e0d\u540c\u7684\u56fe\u50cf\uff0c\u6bcf\u5f20\u56fe\u50cf\u5927\u5c0f\u4e3a 28x28 \u50cf\u7d20\u3002\u5e73\u94fa 28x28 \u540e\u5f97\u5230 784 \u4e2a\u6570\u636e\u70b9\u3002 \u6211\u4eec\u53ef\u4ee5\u5c06\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u91cd\u5851\u4e3a\u539f\u6765\u7684\u5f62\u72b6\uff0c\u7136\u540e\u4f7f\u7528 matplotlib \u7ed8\u5236\u6210\u56fe\u8868\uff0c\u4ece\u800c\u5c06\u5176\u53ef\u89c6\u5316\u3002 single_image = pixel_values [ 1 , :] . reshape ( 28 , 28 ) plt . imshow ( single_image , cmap = 'gray' ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u7ed8\u5236\u5982\u4e0b\u56fe\u50cf\uff1a \u56fe 5\uff1a\u7ed8\u5236MNIST\u6570\u636e\u96c6\u5355\u5f20\u56fe\u7247 \u6700\u91cd\u8981\u7684\u4e00\u6b65\u662f\u5728\u6211\u4eec\u83b7\u53d6\u6570\u636e\u4e4b\u540e\u3002 tsne = manifold . TSNE ( n_components = 2 , random_state = 42 ) transformed_data = tsne . fit_transform ( pixel_values [: 3000 , :]) \u8fd9\u4e00\u6b65\u521b\u5efa\u4e86\u6570\u636e\u7684 t-SNE \u53d8\u6362\u3002\u6211\u4eec\u53ea\u4f7f\u7528 2 \u4e2a\u7ef4\u5ea6\uff0c\u56e0\u4e3a\u5728\u4e8c\u7ef4\u73af\u5883\u4e2d\u53ef\u4ee5\u5f88\u597d\u5730\u5c06\u5b83\u4eec\u53ef\u89c6\u5316\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u8f6c\u6362\u540e\u7684\u6570\u636e\u662f\u4e00\u4e2a 3000x2 \u5f62\u72b6\u7684\u6570\u7ec4\uff083000 \u884c 2 \u5217\uff09\u3002\u5728\u6570\u7ec4\u4e0a\u8c03\u7528 pd.DataFrame \u53ef\u4ee5\u5c06\u8fd9\u6837\u7684\u6570\u636e\u8f6c\u6362\u4e3a pandas \u6570\u636e\u5e27\u3002 tsne_df = pd . DataFrame ( np . column_stack (( transformed_data , targets [: 3000 ])), columns = [ \"x\" , \"y\" , \"targets\" ]) tsne_df . loc [:, \"targets\" ] = tsne_df . targets . astype ( int ) \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a numpy \u6570\u7ec4\u521b\u5efa\u4e00\u4e2a pandas \u6570\u636e\u5e27\u3002x \u548c y \u662f t-SNE \u5206\u89e3\u7684\u4e24\u4e2a\u7ef4\u5ea6\uff0ctarget \u662f\u5b9e\u9645\u6570\u5b57\u3002\u8fd9\u6837\u6211\u4eec\u5c31\u5f97\u5230\u4e86\u5982\u56fe 6 \u6240\u793a\u7684\u6570\u636e\u5e27\u3002 \u56fe 6\uff1at-SNE\u540e\u6570\u636e\u524d10\u884c \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 seaborn \u548c matplotlib \u7ed8\u5236\u5b83\u3002 grid = sns . FacetGrid ( tsne_df , hue = \"targets\" , size = 8 ) grid . map ( plt . scatter , \"x\" , \"y\" ) . add_legend () \u8fd9\u662f\u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u89c6\u5316\u7684\u4e00\u79cd\u65b9\u6cd5\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5728\u540c\u4e00\u6570\u636e\u96c6\u4e0a\u8fdb\u884c k-means \u805a\u7c7b\uff0c\u770b\u770b\u5b83\u5728\u65e0\u76d1\u7763\u73af\u5883\u4e0b\u7684\u8868\u73b0\u5982\u4f55\u3002\u4e00\u4e2a\u7ecf\u5e38\u51fa\u73b0\u7684\u95ee\u9898\u662f\uff0c\u5982\u4f55\u5728 k-means \u805a\u7c7b\u4e2d\u627e\u5230\u6700\u4f73\u7684\u7c07\u6570\u3002\u8fd9\u4e2a\u95ee\u9898\u6ca1\u6709\u6b63\u786e\u7b54\u6848\u3002\u4f60\u5fc5\u987b\u901a\u8fc7\u4ea4\u53c9\u9a8c\u8bc1\u6765\u627e\u5230\u6700\u4f73\u7c07\u6570\u3002\u672c\u4e66\u7a0d\u540e\u5c06\u8ba8\u8bba\u4ea4\u53c9\u9a8c\u8bc1\u3002\u8bf7\u6ce8\u610f\uff0c\u4e0a\u8ff0\u4ee3\u7801\u662f\u5728 jupyter \u7b14\u8bb0\u672c\u4e2d\u8fd0\u884c\u7684\u3002 \u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 jupyter \u505a\u4e00\u4e9b\u7b80\u5355\u7684\u4e8b\u60c5\uff0c\u6bd4\u5982\u4e0a\u9762\u7684\u4f8b\u5b50\u548c \u7ed8\u56fe\u3002\u5bf9\u4e8e\u672c\u4e66\u4e2d\u7684\u5927\u90e8\u5206\u5185\u5bb9\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 python \u811a\u672c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5176\u4ed6 IDE \u56e0\u4e3a\u7ed3\u679c\u90fd\u662f\u4e00\u6837\u7684\u3002 MNIST \u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u628a\u5b83\u8f6c\u6362\u6210\u4e00\u4e2a\u65e0\u76d1\u7763\u7684\u95ee\u9898\uff0c\u53ea\u662f\u4e3a\u4e86\u68c0\u67e5\u5b83\u662f\u5426\u80fd\u5e26\u6765\u4efb\u4f55\u597d\u7684\u7ed3\u679c\u3002\u5982\u679c\u6211\u4eec\u4f7f\u7528\u5206\u7c7b\u7b97\u6cd5\uff0c\u6548\u679c\u4f1a\u66f4\u597d\u3002\u8ba9\u6211\u4eec\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u4e00\u63a2\u7a76\u7adf\u3002","title":"\u6709\u76d1\u7763\u548c\u65e0\u76d1\u7763\u5b66\u4e60"},{"location":"%E6%97%A0%E7%9B%91%E7%9D%A3%E5%92%8C%E6%9C%89%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0/#_1","text":"\u5728\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u65f6\uff0c\u901a\u5e38\u6709\u4e24\u7c7b\u6570\u636e\uff08\u548c\u673a\u5668\u5b66\u4e60\u6a21\u578b\uff09\uff1a \u76d1\u7763\u6570\u636e\uff1a\u603b\u662f\u6709\u4e00\u4e2a\u6216\u591a\u4e2a\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807 \u65e0\u76d1\u7763\u6570\u636e\uff1a\u6ca1\u6709\u4efb\u4f55\u76ee\u6807\u53d8\u91cf\u3002 \u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\u3002\u6211\u4eec\u9700\u8981\u9884\u6d4b\u4e00\u4e2a\u503c\u7684\u95ee\u9898\u88ab\u79f0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u3002\u4f8b\u5982\uff0c\u5982\u679c\u95ee\u9898\u662f\u6839\u636e\u5386\u53f2\u623f\u4ef7\u9884\u6d4b\u623f\u4ef7\uff0c\u90a3\u4e48\u533b\u9662\u3001\u5b66\u6821\u6216\u8d85\u5e02\u7684\u5b58\u5728\uff0c\u4e0e\u6700\u8fd1\u516c\u5171\u4ea4\u901a\u7684\u8ddd\u79bb\u7b49\u7279\u5f81\u5c31\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u540c\u6837\uff0c\u5f53\u6211\u4eec\u5f97\u5230\u732b\u548c\u72d7\u7684\u56fe\u50cf\u65f6\uff0c\u6211\u4eec\u4e8b\u5148\u77e5\u9053\u54ea\u4e9b\u662f\u732b\uff0c\u54ea\u4e9b\u662f\u72d7\uff0c\u5982\u679c\u4efb\u52a1\u662f\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u6240\u63d0\u4f9b\u7684\u56fe\u50cf\u662f\u732b\u8fd8\u662f\u72d7\uff0c\u90a3\u4e48\u8fd9\u4e2a\u95ee\u9898\u5c31\u88ab\u8ba4\u4e3a\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002 \u56fe 1\uff1a\u6709\u76d1\u7763\u5b66\u4e60\u6570\u636e \u5982\u56fe 1 \u6240\u793a\uff0c\u6570\u636e\u7684\u6bcf\u4e00\u884c\u90fd\u4e0e\u4e00\u4e2a\u76ee\u6807\u6216\u6807\u7b7e\u76f8\u5173\u8054\u3002\u5217\u662f\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u884c\u4ee3\u8868\u4e0d\u540c\u7684\u6570\u636e\u70b9\uff0c\u901a\u5e38\u79f0\u4e3a\u6837\u672c\u3002\u793a\u4f8b\u4e2d\u7684\u5341\u4e2a\u6837\u672c\u6709\u5341\u4e2a\u7279\u5f81\u548c\u4e00\u4e2a\u76ee\u6807\u53d8\u91cf\uff0c\u76ee\u6807\u53d8\u91cf\u53ef\u4ee5\u662f\u6570\u5b57\u6216\u7c7b\u522b\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5206\u7c7b\u53d8\u91cf\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5b9e\u6570\uff0c\u95ee\u9898\u5c31\u88ab\u5b9a\u4e49\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6709\u76d1\u7763\u95ee\u9898\u53ef\u5206\u4e3a\u4e24\u4e2a\u5b50\u7c7b\uff1a \u5206\u7c7b\uff1a\u9884\u6d4b\u7c7b\u522b\uff0c\u5982\u732b\u6216\u72d7 \u56de\u5f52\uff1a\u9884\u6d4b\u503c\uff0c\u5982\u623f\u4ef7 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u6709\u65f6\u6211\u4eec\u53ef\u80fd\u4f1a\u5728\u5206\u7c7b\u8bbe\u7f6e\u4e2d\u4f7f\u7528\u56de\u5f52\uff0c\u8fd9\u53d6\u51b3\u4e8e\u7528\u4e8e\u8bc4\u4f30\u7684\u6307\u6807\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u7a0d\u540e\u4f1a\u8ba8\u8bba\u8fd9\u4e2a\u95ee\u9898\u3002 \u53e6\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u95ee\u9898\u662f\u65e0\u76d1\u7763\u7c7b\u578b\u3002 \u65e0\u76d1\u7763 \u6570\u636e\u96c6\u6ca1\u6709\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\uff0c\u4e00\u822c\u6765\u8bf4\uff0c\u4e0e\u6709\u76d1\u7763\u95ee\u9898\u76f8\u6bd4\uff0c\u5904\u7406\u65e0\u76d1\u7763\u6570\u636e\u96c6\u66f4\u5177\u6311\u6218\u6027\u3002 \u5047\u8bbe\u4f60\u5728\u4e00\u5bb6\u5904\u7406\u4fe1\u7528\u5361\u4ea4\u6613\u7684\u91d1\u878d\u516c\u53f8\u5de5\u4f5c\u3002\u6bcf\u79d2\u949f\u90fd\u6709\u5927\u91cf\u6570\u636e\u6d8c\u5165\u3002\u552f\u4e00\u7684\u95ee\u9898\u662f\uff0c\u5f88\u96be\u627e\u5230\u4e00\u4e2a\u4eba\u6765\u5c06\u6bcf\u7b14\u4ea4\u6613\u6807\u8bb0\u4e3a\u6709\u6548\u4ea4\u6613\u3001\u771f\u5b9e\u4ea4\u6613\u6216\u6b3a\u8bc8\u4ea4\u6613\u3002\u5f53\u6211\u4eec\u6ca1\u6709\u4efb\u4f55\u5173\u4e8e\u4ea4\u6613\u662f\u6b3a\u8bc8\u8fd8\u662f\u771f\u5b9e\u7684\u4fe1\u606f\u65f6\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u65e0\u76d1\u7763\u95ee\u9898\u3002\u8981\u89e3\u51b3\u8fd9\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5fc5\u987b\u8003\u8651\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a\u591a\u5c11\u4e2a \u805a\u7c7b \u3002\u805a\u7c7b\u662f\u89e3\u51b3\u6b64\u7c7b\u95ee\u9898\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd8\u6709\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u53ef\u4ee5\u5e94\u7528\u4e8e\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5bf9\u4e8e\u6b3a\u8bc8\u68c0\u6d4b\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u8bf4\u6570\u636e\u53ef\u4ee5\u5206\u4e3a\u4e24\u7c7b\uff08\u6b3a\u8bc8\u6216\u771f\u5b9e\uff09\u3002 \u5f53\u6211\u4eec\u77e5\u9053\u805a\u7c7b\u7684\u6570\u91cf\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u805a\u7c7b\u7b97\u6cd5\u6765\u89e3\u51b3\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5728\u56fe 2 \u4e2d\uff0c\u5047\u8bbe\u6570\u636e\u5206\u4e3a\u4e24\u7c7b\uff0c\u6df1\u8272\u4ee3\u8868\u6b3a\u8bc8\uff0c\u6d45\u8272\u4ee3\u8868\u771f\u5b9e\u4ea4\u6613\u3002\u7136\u800c\uff0c\u5728\u4f7f\u7528\u805a\u7c7b\u65b9\u6cd5\u4e4b\u524d\uff0c\u6211\u4eec\u5e76\u4e0d\u77e5\u9053\u8fd9\u4e9b\u7c7b\u522b\u3002\u5e94\u7528\u805a\u7c7b\u7b97\u6cd5\u540e\uff0c\u6211\u4eec\u5e94\u8be5\u80fd\u591f\u533a\u5206\u8fd9\u4e24\u4e2a\u5047\u5b9a\u76ee\u6807\u3002 \u4e3a\u4e86\u7406\u89e3\u65e0\u76d1\u7763\u95ee\u9898\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u8bb8\u591a\u5206\u89e3\u6280\u672f\uff0c\u5982 \u4e3b\u6210\u5206\u5206\u6790\uff08PCA\uff09\u3001t-\u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09 \u7b49\u3002 \u6709\u76d1\u7763\u7684\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\uff0c\u56e0\u4e3a\u5b83\u4eec\u5f88\u5bb9\u6613\u8bc4\u4f30\u3002\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u8bc4\u4f30\u6280\u672f\u3002\u7136\u800c\uff0c\u5bf9\u65e0\u76d1\u7763\u7b97\u6cd5\u7684\u7ed3\u679c\u8fdb\u884c\u8bc4\u4f30\u5177\u6709\u6311\u6218\u6027\uff0c\u9700\u8981\u5927\u91cf\u7684\u4eba\u4e3a\u5e72\u9884\u6216\u542f\u53d1\u5f0f\u65b9\u6cd5\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4e3b\u8981\u5173\u6ce8\u6709\u76d1\u7763\u6570\u636e\u548c\u6a21\u578b\uff0c\u4f46\u8fd9\u5e76\u4e0d\u610f\u5473\u7740\u6211\u4eec\u4f1a\u5ffd\u7565\u65e0\u76d1\u7763\u6570\u636e\u95ee\u9898\u3002 \u56fe 2\uff1a\u65e0\u76d1\u7763\u5b66\u4e60\u6570\u636e\u96c6 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u5f53\u4eba\u4eec\u5f00\u59cb\u5b66\u4e60\u6570\u636e\u79d1\u5b66\u6216\u673a\u5668\u5b66\u4e60\u65f6\uff0c\u90fd\u4f1a\u4ece\u975e\u5e38\u8457\u540d\u7684\u6570\u636e\u96c6\u5f00\u59cb\uff0c\u4f8b\u5982\u6cf0\u5766\u5c3c\u514b\u6570\u636e\u96c6\u6216\u9e22\u5c3e\u82b1\u6570\u636e\u96c6\uff0c\u8fd9\u4e9b\u90fd\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u5728\u6cf0\u5766\u5c3c\u514b\u53f7\u6570\u636e\u96c6\u4e2d\uff0c\u4f60\u5fc5\u987b\u6839\u636e\u8239\u7968\u7b49\u7ea7\u3001\u6027\u522b\u3001\u5e74\u9f84\u7b49\u56e0\u7d20\u9884\u6d4b\u6cf0\u5766\u5c3c\u514b\u53f7\u4e0a\u4e58\u5ba2\u7684\u5b58\u6d3b\u7387\u3002\u540c\u6837\uff0c\u5728\u9e22\u5c3e\u82b1\u6570\u636e\u96c6\u4e2d\uff0c\u60a8\u5fc5\u987b\u6839\u636e\u843c\u7247\u5bbd\u5ea6\u3001\u82b1\u74e3\u957f\u5ea6\u3001\u843c\u7247\u957f\u5ea6\u548c\u82b1\u74e3\u5bbd\u5ea6\u7b49\u56e0\u7d20\u9884\u6d4b\u82b1\u7684\u79cd\u7c7b\u3002 \u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u80fd\u5305\u62ec\u7528\u4e8e\u5ba2\u6237\u7ec6\u5206\u7684\u6570\u636e\u96c6\u3002 \u4f8b\u5982\uff0c\u60a8\u62e5\u6709\u8bbf\u95ee\u60a8\u7684\u7535\u5b50\u5546\u52a1\u7f51\u7ad9\u7684\u5ba2\u6237\u6570\u636e\uff0c\u6216\u8005\u8bbf\u95ee\u5546\u5e97\u6216\u5546\u573a\u7684\u5ba2\u6237\u6570\u636e\uff0c\u800c\u60a8\u5e0c\u671b\u5c06\u5b83\u4eec\u7ec6\u5206\u6216\u805a\u7c7b\u4e3a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u65e0\u76d1\u7763\u6570\u636e\u96c6\u7684\u53e6\u4e00\u4e2a\u4f8b\u5b50\u53ef\u80fd\u5305\u62ec\u4fe1\u7528\u5361\u6b3a\u8bc8\u68c0\u6d4b\u6216\u5bf9\u51e0\u5f20\u56fe\u7247\u8fdb\u884c\u805a\u7c7b\u7b49\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd8\u53ef\u4ee5\u5c06\u6709\u76d1\u7763\u6570\u636e\u96c6\u8f6c\u6362\u4e3a\u65e0\u76d1\u7763\u6570\u636e\u96c6\uff0c\u4ee5\u67e5\u770b\u5b83\u4eec\u5728\u7ed8\u5236\u65f6\u7684\u6548\u679c\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u56fe 3 \u4e2d\u7684\u6570\u636e\u96c6\u3002\u56fe 3 \u663e\u793a\u7684\u662f MNIST \u6570\u636e\u96c6\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u6d41\u884c\u7684\u624b\u5199\u6570\u5b57\u6570\u636e\u96c6\uff0c\u5b83\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\uff0c\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u4f60\u4f1a\u5f97\u5230\u6570\u5b57\u56fe\u50cf\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u6b63\u786e\u6807\u7b7e\u3002\u4f60\u5fc5\u987b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u53ea\u63d0\u4f9b\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\u8bc6\u522b\u51fa\u54ea\u4e2a\u6570\u5b57\u662f\u5b83\u3002 \u56fe 3\uff1aMNIST\u6570\u636e\u96c6 \u5982\u679c\u6211\u4eec\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c t \u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09\u5206\u89e3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u53ea\u9700\u5728\u56fe\u50cf\u50cf\u7d20\u4e0a\u964d\u7ef4\u81f3 2 \u4e2a\u7ef4\u5ea6\uff0c\u5c31\u80fd\u5728\u4e00\u5b9a\u7a0b\u5ea6\u4e0a\u5206\u79bb\u56fe\u50cf\u3002\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1aMNIST \u6570\u636e\u96c6\u7684 t-SNE \u53ef\u89c6\u5316\u3002\u4f7f\u7528\u4e86 3000 \u5e45\u56fe\u50cf\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002\u9996\u5148\u662f\u5bfc\u5165\u6240\u6709\u9700\u8981\u7684\u5e93\u3002 import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns from sklearn import datasets from sklearn import manifold % matplotlib inline \u6211\u4eec\u4f7f\u7528 matplotlib \u548c seaborn \u8fdb\u884c\u7ed8\u56fe\uff0c\u4f7f\u7528 numpy \u5904\u7406\u6570\u503c\u6570\u7ec4\uff0c\u4f7f\u7528 pandas \u4ece\u6570\u503c\u6570\u7ec4\u521b\u5efa\u6570\u636e\u5e27\uff0c\u4f7f\u7528 scikit-learn (sklearn) \u83b7\u53d6\u6570\u636e\u5e76\u6267\u884c t-SNE\u3002 \u5bfc\u5165\u540e\uff0c\u6211\u4eec\u9700\u8981\u4e0b\u8f7d\u6570\u636e\u5e76\u5355\u72ec\u8bfb\u53d6\uff0c\u6216\u8005\u4f7f\u7528 sklearn \u7684\u5185\u7f6e\u51fd\u6570\u6765\u63d0\u4f9b MNIST \u6570\u636e\u96c6\u3002 data = datasets . fetch_openml ( 'mnist_784' , version = 1 , return_X_y = True ) pixel_values , targets = data targets = targets . astype ( int ) \u5728\u8fd9\u90e8\u5206\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528 sklearn \u6570\u636e\u96c6\u83b7\u53d6\u4e86\u6570\u636e\uff0c\u5e76\u83b7\u5f97\u4e86\u4e00\u4e2a\u50cf\u7d20\u503c\u6570\u7ec4\u548c\u53e6\u4e00\u4e2a\u76ee\u6807\u6570\u7ec4\u3002\u7531\u4e8e\u76ee\u6807\u662f\u5b57\u7b26\u4e32\u7c7b\u578b\uff0c\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u3002 pixel_values \u662f\u4e00\u4e2a\u5f62\u72b6\u4e3a 70000x784 \u7684\u4e8c\u7ef4\u6570\u7ec4\u3002 \u5171\u6709 70000 \u5f20\u4e0d\u540c\u7684\u56fe\u50cf\uff0c\u6bcf\u5f20\u56fe\u50cf\u5927\u5c0f\u4e3a 28x28 \u50cf\u7d20\u3002\u5e73\u94fa 28x28 \u540e\u5f97\u5230 784 \u4e2a\u6570\u636e\u70b9\u3002 \u6211\u4eec\u53ef\u4ee5\u5c06\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u91cd\u5851\u4e3a\u539f\u6765\u7684\u5f62\u72b6\uff0c\u7136\u540e\u4f7f\u7528 matplotlib \u7ed8\u5236\u6210\u56fe\u8868\uff0c\u4ece\u800c\u5c06\u5176\u53ef\u89c6\u5316\u3002 single_image = pixel_values [ 1 , :] . reshape ( 28 , 28 ) plt . imshow ( single_image , cmap = 'gray' ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u7ed8\u5236\u5982\u4e0b\u56fe\u50cf\uff1a \u56fe 5\uff1a\u7ed8\u5236MNIST\u6570\u636e\u96c6\u5355\u5f20\u56fe\u7247 \u6700\u91cd\u8981\u7684\u4e00\u6b65\u662f\u5728\u6211\u4eec\u83b7\u53d6\u6570\u636e\u4e4b\u540e\u3002 tsne = manifold . TSNE ( n_components = 2 , random_state = 42 ) transformed_data = tsne . fit_transform ( pixel_values [: 3000 , :]) \u8fd9\u4e00\u6b65\u521b\u5efa\u4e86\u6570\u636e\u7684 t-SNE \u53d8\u6362\u3002\u6211\u4eec\u53ea\u4f7f\u7528 2 \u4e2a\u7ef4\u5ea6\uff0c\u56e0\u4e3a\u5728\u4e8c\u7ef4\u73af\u5883\u4e2d\u53ef\u4ee5\u5f88\u597d\u5730\u5c06\u5b83\u4eec\u53ef\u89c6\u5316\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u8f6c\u6362\u540e\u7684\u6570\u636e\u662f\u4e00\u4e2a 3000x2 \u5f62\u72b6\u7684\u6570\u7ec4\uff083000 \u884c 2 \u5217\uff09\u3002\u5728\u6570\u7ec4\u4e0a\u8c03\u7528 pd.DataFrame \u53ef\u4ee5\u5c06\u8fd9\u6837\u7684\u6570\u636e\u8f6c\u6362\u4e3a pandas \u6570\u636e\u5e27\u3002 tsne_df = pd . DataFrame ( np . column_stack (( transformed_data , targets [: 3000 ])), columns = [ \"x\" , \"y\" , \"targets\" ]) tsne_df . loc [:, \"targets\" ] = tsne_df . targets . astype ( int ) \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a numpy \u6570\u7ec4\u521b\u5efa\u4e00\u4e2a pandas \u6570\u636e\u5e27\u3002x \u548c y \u662f t-SNE \u5206\u89e3\u7684\u4e24\u4e2a\u7ef4\u5ea6\uff0ctarget \u662f\u5b9e\u9645\u6570\u5b57\u3002\u8fd9\u6837\u6211\u4eec\u5c31\u5f97\u5230\u4e86\u5982\u56fe 6 \u6240\u793a\u7684\u6570\u636e\u5e27\u3002 \u56fe 6\uff1at-SNE\u540e\u6570\u636e\u524d10\u884c \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 seaborn \u548c matplotlib \u7ed8\u5236\u5b83\u3002 grid = sns . FacetGrid ( tsne_df , hue = \"targets\" , size = 8 ) grid . map ( plt . scatter , \"x\" , \"y\" ) . add_legend () \u8fd9\u662f\u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u89c6\u5316\u7684\u4e00\u79cd\u65b9\u6cd5\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5728\u540c\u4e00\u6570\u636e\u96c6\u4e0a\u8fdb\u884c k-means \u805a\u7c7b\uff0c\u770b\u770b\u5b83\u5728\u65e0\u76d1\u7763\u73af\u5883\u4e0b\u7684\u8868\u73b0\u5982\u4f55\u3002\u4e00\u4e2a\u7ecf\u5e38\u51fa\u73b0\u7684\u95ee\u9898\u662f\uff0c\u5982\u4f55\u5728 k-means \u805a\u7c7b\u4e2d\u627e\u5230\u6700\u4f73\u7684\u7c07\u6570\u3002\u8fd9\u4e2a\u95ee\u9898\u6ca1\u6709\u6b63\u786e\u7b54\u6848\u3002\u4f60\u5fc5\u987b\u901a\u8fc7\u4ea4\u53c9\u9a8c\u8bc1\u6765\u627e\u5230\u6700\u4f73\u7c07\u6570\u3002\u672c\u4e66\u7a0d\u540e\u5c06\u8ba8\u8bba\u4ea4\u53c9\u9a8c\u8bc1\u3002\u8bf7\u6ce8\u610f\uff0c\u4e0a\u8ff0\u4ee3\u7801\u662f\u5728 jupyter \u7b14\u8bb0\u672c\u4e2d\u8fd0\u884c\u7684\u3002 \u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 jupyter \u505a\u4e00\u4e9b\u7b80\u5355\u7684\u4e8b\u60c5\uff0c\u6bd4\u5982\u4e0a\u9762\u7684\u4f8b\u5b50\u548c \u7ed8\u56fe\u3002\u5bf9\u4e8e\u672c\u4e66\u4e2d\u7684\u5927\u90e8\u5206\u5185\u5bb9\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 python \u811a\u672c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5176\u4ed6 IDE \u56e0\u4e3a\u7ed3\u679c\u90fd\u662f\u4e00\u6837\u7684\u3002 MNIST \u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u628a\u5b83\u8f6c\u6362\u6210\u4e00\u4e2a\u65e0\u76d1\u7763\u7684\u95ee\u9898\uff0c\u53ea\u662f\u4e3a\u4e86\u68c0\u67e5\u5b83\u662f\u5426\u80fd\u5e26\u6765\u4efb\u4f55\u597d\u7684\u7ed3\u679c\u3002\u5982\u679c\u6211\u4eec\u4f7f\u7528\u5206\u7c7b\u7b97\u6cd5\uff0c\u6548\u679c\u4f1a\u66f4\u597d\u3002\u8ba9\u6211\u4eec\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u4e00\u63a2\u7a76\u7adf\u3002","title":"\u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60"},{"location":"%E7%89%B9%E5%BE%81%E5%B7%A5%E7%A8%8B/","text":"\u7279\u5f81\u5de5\u7a0b \u7279\u5f81\u5de5\u7a0b\u662f\u6784\u5efa\u826f\u597d\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u90e8\u5206\u4e4b\u4e00\u3002\u5982\u679c\u6211\u4eec\u62e5\u6709\u6709\u7528\u7684\u7279\u5f81\uff0c\u6a21\u578b\u5c31\u4f1a\u8868\u73b0\u5f97\u66f4\u597d\u3002\u5728\u8bb8\u591a\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u907f\u514d\u4f7f\u7528\u5927\u578b\u590d\u6742\u6a21\u578b\uff0c\u800c\u4f7f\u7528\u5177\u6709\u5173\u952e\u5de5\u7a0b\u7279\u5f81\u7684\u7b80\u5355\u6a21\u578b\u3002\u6211\u4eec\u5fc5\u987b\u7262\u8bb0\uff0c\u53ea\u6709\u5f53\u4f60\u5bf9\u95ee\u9898\u7684\u9886\u57df\u6709\u4e00\u5b9a\u7684\u4e86\u89e3\uff0c\u5e76\u4e14\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u76f8\u5173\u6570\u636e\u65f6\uff0c\u624d\u80fd\u4ee5\u6700\u4f73\u65b9\u5f0f\u5b8c\u6210\u7279\u5f81\u5de5\u7a0b\u3002\u4e0d\u8fc7\uff0c\u60a8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u4e00\u4e9b\u901a\u7528\u6280\u672f\uff0c\u4ece\u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u7279\u5f81\u3002\u7279\u5f81\u5de5\u7a0b\u4e0d\u4ec5\u4ec5\u662f\u4ece\u6570\u636e\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\uff0c\u8fd8\u5305\u62ec\u4e0d\u540c\u7c7b\u578b\u7684\u5f52\u4e00\u5316\u548c\u8f6c\u6362\u3002 \u5728\u6709\u5173\u5206\u7c7b\u7279\u5f81\u7684\u7ae0\u8282\u4e2d\uff0c\u6211\u4eec\u5df2\u7ecf\u4e86\u89e3\u4e86\u7ed3\u5408\u4e0d\u540c\u5206\u7c7b\u53d8\u91cf\u7684\u65b9\u6cd5\u3001\u5982\u4f55\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u8ba1\u6570\u3001\u6807\u7b7e\u7f16\u7801\u548c\u4f7f\u7528\u5d4c\u5165\u3002\u8fd9\u4e9b\u51e0\u4e4e\u90fd\u662f\u5229\u7528\u5206\u7c7b\u53d8\u91cf\u8bbe\u8ba1\u7279\u5f81\u7684\u65b9\u6cd5\u3002\u56e0\u6b64\uff0c\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u7684\u91cd\u70b9\u5c06\u4ec5\u9650\u4e8e\u6570\u503c\u53d8\u91cf\u4ee5\u53ca\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u7684\u7ec4\u5408\u3002 \u8ba9\u6211\u4eec\u4ece\u6700\u7b80\u5355\u4f46\u5e94\u7528\u6700\u5e7f\u6cdb\u7684\u7279\u5f81\u5de5\u7a0b\u6280\u672f\u5f00\u59cb\u3002\u5047\u8bbe\u4f60\u6b63\u5728\u5904\u7406\u65e5\u671f\u548c\u65f6\u95f4\u6570\u636e\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709\u4e00\u4e2a\u5e26\u6709\u65e5\u671f\u7c7b\u578b\u5217\u7684 pandas \u6570\u636e\u5e27\u3002\u5229\u7528\u8fd9\u4e00\u5217\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4ee5\u4e0b\u7279\u5f81\uff1a \u5e74 \u5e74\u4e2d\u7684\u5468 \u6708 \u661f\u671f \u5468\u672b \u5c0f\u65f6 \u8fd8\u6709\u66f4\u591a \u800c\u4f7f\u7528pandas\u5c31\u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u6dfb\u52a0'year'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5e74\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'year' ] = df [ 'datetime_column' ] . dt . year # \u6dfb\u52a0'weekofyear'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5468\u6570\u63d0\u53d6\u51fa\u6765 df . loc [:, 'weekofyear' ] = df [ 'datetime_column' ] . dt . weekofyear # \u6dfb\u52a0'month'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u6708\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'month' ] = df [ 'datetime_column' ] . dt . month # \u6dfb\u52a0'dayofweek'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u661f\u671f\u51e0\u63d0\u53d6\u51fa\u6765 df . loc [:, 'dayofweek' ] = df [ 'datetime_column' ] . dt . dayofweek # \u6dfb\u52a0'weekend'\u5217\uff0c\u5224\u65ad\u5f53\u5929\u662f\u5426\u4e3a\u5468\u672b df . loc [:, 'weekend' ] = ( df . datetime_column . dt . weekday >= 5 ) . astype ( int ) # \u6dfb\u52a0 'hour' \u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5c0f\u65f6\u63d0\u53d6\u51fa\u6765 df . loc [:, 'hour' ] = df [ 'datetime_column' ] . dt . hour \u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u65e5\u671f\u65f6\u95f4\u5217\u521b\u5efa\u4e00\u7cfb\u5217\u65b0\u5217\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u53ef\u4ee5\u521b\u5efa\u7684\u4e00\u4e9b\u793a\u4f8b\u529f\u80fd\u3002 import pandas as pd # \u521b\u5efa\u65e5\u671f\u65f6\u95f4\u5e8f\u5217\uff0c\u5305\u542b\u4e86\u4ece '2020-01-06' \u5230 '2020-01-10' \u7684\u65e5\u671f\u65f6\u95f4\u70b9\uff0c\u65f6\u95f4\u95f4\u9694\u4e3a10\u5c0f\u65f6 s = pd . date_range ( '2020-01-06' , '2020-01-10' , freq = '10H' ) . to_series () # \u63d0\u53d6\u5bf9\u5e94\u65f6\u95f4\u7279\u5f81 features = { \"dayofweek\" : s . dt . dayofweek . values , \"dayofyear\" : s . dt . dayofyear . values , \"hour\" : s . dt . hour . values , \"is_leap_year\" : s . dt . is_leap_year . values , \"quarter\" : s . dt . quarter . values , \"weekofyear\" : s . dt . weekofyear . values } \u8fd9\u5c06\u4ece\u7ed9\u5b9a\u7cfb\u5217\u4e2d\u751f\u6210\u4e00\u4e2a\u7279\u5f81\u5b57\u5178\u3002\u60a8\u53ef\u4ee5\u5c06\u6b64\u5e94\u7528\u4e8e pandas \u6570\u636e\u4e2d\u7684\u4efb\u4f55\u65e5\u671f\u65f6\u95f4\u5217\u3002\u8fd9\u4e9b\u662f pandas \u63d0\u4f9b\u7684\u4f17\u591a\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u4e2d\u7684\u4e00\u90e8\u5206\u3002\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u975e\u5e38\u91cd\u8981\uff0c\u4f8b\u5982\uff0c\u5728\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97\u7684\u9500\u552e\u989d\u65f6\uff0c\u5982\u679c\u60f3\u5728\u805a\u5408\u7279\u5f81\u4e0a\u4f7f\u7528 xgboost \u7b49\u6a21\u578b\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u5c31\u975e\u5e38\u91cd\u8981\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u4e0b\u6240\u793a\u7684\u6570\u636e\uff1a \u56fe 1\uff1a\u5305\u542b\u5206\u7c7b\u548c\u65e5\u671f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e \u5728\u56fe 1 \u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u6709\u4e00\u4e2a\u65e5\u671f\u5217\uff0c\u4ece\u4e2d\u53ef\u4ee5\u8f7b\u677e\u63d0\u53d6\u5e74\u3001\u6708\u3001\u5b63\u5ea6\u7b49\u7279\u5f81\u3002\u7136\u540e\uff0c\u6211\u4eec\u6709\u4e00\u4e2a customer_id \u5217\uff0c\u8be5\u5217\u6709\u591a\u4e2a\u6761\u76ee\uff0c\u56e0\u6b64\u4e00\u4e2a\u5ba2\u6237\u4f1a\u88ab\u770b\u5230\u5f88\u591a\u6b21\uff08\u622a\u56fe\u4e2d\u770b\u4e0d\u5230\uff09\u3002\u6bcf\u4e2a\u65e5\u671f\u548c\u5ba2\u6237 ID \u90fd\u6709\u4e09\u4e2a\u5206\u7c7b\u7279\u5f81\u548c\u4e00\u4e2a\u6570\u5b57\u7279\u5f81\u3002\u6211\u4eec\u53ef\u4ee5\u4ece\u4e2d\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff1a - \u5ba2\u6237\u6700\u6d3b\u8dc3\u7684\u6708\u4efd\u662f\u51e0\u6708 - \u67d0\u4e2a\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u7684\u8ba1\u6570\u662f\u591a\u5c11 - \u67d0\u5e74\u67d0\u6708\u67d0\u5468\u67d0\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u6570\u91cf\u662f\u591a\u5c11\uff1f - \u67d0\u4e2a\u5ba2\u6237\u7684 num1 \u5e73\u5747\u503c\u662f\u591a\u5c11\uff1f - \u7b49\u7b49\u3002 \u4f7f\u7528 pandas \u4e2d\u7684\u805a\u5408\uff0c\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u521b\u5efa\u7c7b\u4f3c\u7684\u529f\u80fd\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u5b9e\u73b0\u3002 def generate_features ( df ): df . loc [:, 'year' ] = df [ 'date' ] . dt . year df . loc [:, 'weekofyear' ] = df [ 'date' ] . dt . weekofyear df . loc [:, 'month' ] = df [ 'date' ] . dt . month df . loc [:, 'dayofweek' ] = df [ 'date' ] . dt . dayofweek df . loc [:, 'weekend' ] = ( df [ 'date' ] . dt . weekday >= 5 ) . astype ( int ) aggs = {} # \u5bf9 'month' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'month' ] = [ 'nunique' , 'mean' ] # \u5bf9 'weekofyear' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'weekofyear' ] = [ 'nunique' , 'mean' ] # \u5bf9 'num1' \u5217\u8fdb\u884c sum\u3001max\u3001min\u3001mean \u805a\u5408 aggs [ 'num1' ] = [ 'sum' , 'max' , 'min' , 'mean' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c size \u805a\u5408 aggs [ 'customer_id' ] = [ 'size' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c nunique \u805a\u5408 aggs [ 'customer_id' ] = [ 'nunique' ] # \u5bf9\u6570\u636e\u5e94\u7528\u4e0d\u540c\u7684\u805a\u5408\u51fd\u6570 agg_df = df . groupby ( 'customer_id' ) . agg ( aggs ) # \u91cd\u7f6e\u7d22\u5f15 agg_df = agg_df . reset_index () return agg_df \u8bf7\u6ce8\u610f\uff0c\u5728\u4e0a\u8ff0\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u8df3\u8fc7\u4e86\u5206\u7c7b\u53d8\u91cf\uff0c\u4f46\u60a8\u53ef\u4ee5\u50cf\u4f7f\u7528\u5176\u4ed6\u805a\u5408\u53d8\u91cf\u4e00\u6837\u4f7f\u7528\u5b83\u4eec\u3002 \u56fe 2\uff1a\u603b\u4f53\u7279\u5f81\u548c\u5176\u4ed6\u7279\u5f81 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u56fe 2 \u4e2d\u7684\u6570\u636e\u4e0e\u5e26\u6709 customer_id \u5217\u7684\u539f\u59cb\u6570\u636e\u5e27\u8fde\u63a5\u8d77\u6765\uff0c\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5e76\u4e0d\u662f\u8981\u9884\u6d4b\u4ec0\u4e48\uff1b\u6211\u4eec\u53ea\u662f\u5728\u521b\u5efa\u901a\u7528\u7279\u5f81\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u8bd5\u56fe\u5728\u8fd9\u91cc\u9884\u6d4b\u4ec0\u4e48\uff0c\u521b\u5efa\u7279\u5f81\u4f1a\u66f4\u5bb9\u6613\u3002 \u4f8b\u5982\uff0c\u6709\u65f6\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u95ee\u9898\u65f6\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u7684\u7279\u5f81\u4e0d\u662f\u5355\u4e2a\u503c\uff0c\u800c\u662f\u4e00\u7cfb\u5217\u503c\u3002 \u4f8b\u5982\uff0c\u5ba2\u6237\u5728\u7279\u5b9a\u65f6\u95f4\u6bb5\u5185\u7684\u4ea4\u6613\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4f1a\u521b\u5efa\u4e0d\u540c\u7c7b\u578b\u7684\u7279\u5f81\uff0c\u4f8b\u5982\uff1a\u4f7f\u7528\u6570\u503c\u7279\u5f81\u65f6\uff0c\u5728\u5bf9\u5206\u7c7b\u5217\u8fdb\u884c\u5206\u7ec4\u65f6\uff0c\u4f1a\u5f97\u5230\u7c7b\u4f3c\u4e8e\u65f6\u95f4\u5206\u5e03\u503c\u5217\u8868\u7684\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u7cfb\u5217\u7edf\u8ba1\u7279\u5f81\uff0c\u4f8b\u5982 \u5e73\u5747\u503c \u6700\u5927\u503c \u6700\u5c0f\u503c \u72ec\u7279\u6027 \u504f\u659c \u5cf0\u5ea6 Kstat \u767e\u5206\u4f4d\u6570 \u5b9a\u91cf \u5cf0\u503c\u5230\u5cf0\u503c \u4ee5\u53ca\u66f4\u591a \u8fd9\u4e9b\u53ef\u4ee5\u4f7f\u7528\u7b80\u5355\u7684 numpy \u51fd\u6570\u521b\u5efa\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 import numpy as np # \u521b\u5efa\u5b57\u5178\uff0c\u7528\u4e8e\u5b58\u50a8\u4e0d\u540c\u7684\u7edf\u8ba1\u7279\u5f81 feature_dict = {} # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5e73\u5747\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'mean' \u952e\u4e0b feature_dict [ 'mean' ] = np . mean ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5927\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'max' \u952e\u4e0b feature_dict [ 'max' ] = np . max ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5c0f\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'min' \u952e\u4e0b feature_dict [ 'min' ] = np . min ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6807\u51c6\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'std' \u952e\u4e0b feature_dict [ 'std' ] = np . std ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u65b9\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'var' \u952e\u4e0b feature_dict [ 'var' ] = np . var ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5dee\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'ptp' \u952e\u4e0b feature_dict [ 'ptp' ] = np . ptp ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c10\u767e\u5206\u4f4d\u6570\uff08\u5373\u767e\u5206\u4e4b10\u5206\u4f4d\u6570\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_10' \u952e\u4e0b feature_dict [ 'percentile_10' ] = np . percentile ( x , 10 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c60\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_60' \u952e\u4e0b feature_dict [ 'percentile_60' ] = np . percentile ( x , 60 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c90\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_90' \u952e\u4e0b feature_dict [ 'percentile_90' ] = np . percentile ( x , 90 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u76845%\u5206\u4f4d\u6570\uff08\u53730.05\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_5' \u952e\u4e0b feature_dict [ 'quantile_5' ] = np . quantile ( x , 0.05 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768495%\u5206\u4f4d\u6570\uff08\u53730.95\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_95' \u952e\u4e0b feature_dict [ 'quantile_95' ] = np . quantile ( x , 0.95 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768499%\u5206\u4f4d\u6570\uff08\u53730.99\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_99' \u952e\u4e0b feature_dict [ 'quantile_99' ] = np . quantile ( x , 0.99 ) \u65f6\u95f4\u5e8f\u5217\u6570\u636e\uff08\u6570\u503c\u5217\u8868\uff09\u53ef\u4ee5\u8f6c\u6362\u6210\u8bb8\u591a\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4e00\u4e2a\u540d\u4e3a tsfresh \u7684 python \u5e93\u975e\u5e38\u6709\u7528\u3002 from tsfresh.feature_extraction import feature_calculators as fc # \u8ba1\u7b97 x \u6570\u5217\u7684\u7edd\u5bf9\u80fd\u91cf\uff08abs_energy\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'abs_energy' \u952e\u4e0b feature_dict [ 'abs_energy' ] = fc . abs_energy ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u9ad8\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_above_mean' \u952e\u4e0b feature_dict [ 'count_above_mean' ] = fc . count_above_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u4f4e\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_below_mean' \u952e\u4e0b feature_dict [ 'count_below_mean' ] = fc . count_below_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u7edd\u5bf9\u53d8\u5316\uff08mean_abs_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_abs_change' \u952e\u4e0b feature_dict [ 'mean_abs_change' ] = fc . mean_abs_change ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u53d8\u5316\u7387\uff08mean_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_change' \u952e\u4e0b feature_dict [ 'mean_change' ] = fc . mean_change ( x ) \u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\uff1btsfresh \u63d0\u4f9b\u4e86\u6570\u767e\u79cd\u7279\u5f81\u548c\u6570\u5341\u79cd\u4e0d\u540c\u7279\u5f81\u7684\u53d8\u4f53\uff0c\u4f60\u53ef\u4ee5\u5c06\u5b83\u4eec\u7528\u4e8e\u57fa\u4e8e\u65f6\u95f4\u5e8f\u5217\uff08\u503c\u5217\u8868\uff09\u7684\u7279\u5f81\u3002\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0cx \u662f\u4e00\u4e2a\u503c\u5217\u8868\u3002\u4f46\u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\u3002\u60a8\u8fd8\u53ef\u4ee5\u4e3a\u5305\u542b\u6216\u4e0d\u5305\u542b\u5206\u7c7b\u6570\u636e\u7684\u6570\u503c\u6570\u636e\u521b\u5efa\u8bb8\u591a\u5176\u4ed6\u7279\u5f81\u3002\u751f\u6210\u8bb8\u591a\u7279\u5f81\u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u521b\u5efa\u4e00\u5806\u591a\u9879\u5f0f\u7279\u5f81\u3002\u4f8b\u5982\uff0c\u4ece\u4e24\u4e2a\u7279\u5f81 \"a \"\u548c \"b \"\u751f\u6210\u7684\u4e8c\u7ea7\u591a\u9879\u5f0f\u7279\u5f81\u5305\u62ec \"a\"\u3001\"b\"\u3001\"ab\"\u3001\"a^2 \"\u548c \"b^2\"\u3002 import numpy as np df = pd . DataFrame ( np . random . rand ( 100 , 2 ), columns = [ f \"f_ { i } \" for i in range ( 1 , 3 )]) \u5982\u56fe 3 \u6240\u793a\uff0c\u5b83\u7ed9\u51fa\u4e86\u4e00\u4e2a\u6570\u636e\u8868\u3002 \u56fe 3\uff1a\u5305\u542b\u4e24\u4e2a\u6570\u5b57\u7279\u5f81\u7684\u968f\u673a\u6570\u636e\u8868 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u7684 PolynomialFeatures \u521b\u5efa\u4e24\u6b21\u591a\u9879\u5f0f\u7279\u5f81\u3002 from sklearn import preprocessing # \u6307\u5b9a\u591a\u9879\u5f0f\u7684\u6b21\u6570\u4e3a 2\uff0c\u4e0d\u4ec5\u8003\u8651\u4ea4\u4e92\u9879\uff0c\u4e0d\u5305\u62ec\u504f\u5dee\uff08include_bias=False\uff09 pf = preprocessing . PolynomialFeatures ( degree = 2 , interaction_only = False , include_bias = False ) # \u62df\u5408\uff0c\u521b\u5efa\u591a\u9879\u5f0f\u7279\u5f81 pf . fit ( df ) # \u8f6c\u6362\u6570\u636e poly_feats = pf . transform ( df ) # \u83b7\u53d6\u751f\u6210\u7684\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf num_feats = poly_feats . shape [ 1 ] # \u4e3a\u65b0\u751f\u6210\u7684\u7279\u5f81\u547d\u540d df_transformed = pd . DataFrame ( poly_feats , columns = [ f \"f_ { i } \" for i in range ( 1 , num_feats + 1 )] ) \u8fd9\u6837\u5c31\u5f97\u5230\u4e86\u4e00\u4e2a\u6570\u636e\u8868\uff0c\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1a\u5e26\u6709\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e\u8868 \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e9b\u591a\u9879\u5f0f\u7279\u5f81\u3002\u5982\u679c\u521b\u5efa\u7684\u662f\u4e09\u6b21\u591a\u9879\u5f0f\u7279\u5f81\uff0c\u6700\u7ec8\u603b\u5171\u4f1a\u6709\u4e5d\u4e2a\u7279\u5f81\u3002\u7279\u5f81\u7684\u6570\u91cf\u8d8a\u591a\uff0c\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf\u4e5f\u5c31\u8d8a\u591a\uff0c\u800c\u4e14\u4f60\u8fd8\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5982\u679c\u6570\u636e\u96c6\u4e2d\u6709\u5f88\u591a\u6837\u672c\uff0c\u90a3\u4e48\u521b\u5efa\u8fd9\u7c7b\u7279\u5f81\u5c31\u9700\u8981\u82b1\u8d39\u4e00\u4e9b\u65f6\u95f4\u3002 \u56fe 5\uff1a\u6570\u5b57\u7279\u5f81\u5217\u7684\u76f4\u65b9\u56fe \u53e6\u4e00\u4e2a\u6709\u8da3\u7684\u529f\u80fd\u662f\u5c06\u6570\u5b57\u8f6c\u6362\u4e3a\u7c7b\u522b\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \u5206\u7bb1 \u3002\u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u56fe 5\uff0c\u5b83\u663e\u793a\u4e86\u4e00\u4e2a\u968f\u673a\u6570\u5b57\u7279\u5f81\u7684\u6837\u672c\u76f4\u65b9\u56fe\u3002\u6211\u4eec\u5728\u8be5\u56fe\u4e2d\u4f7f\u7528\u4e8610\u4e2a\u5206\u7bb1\uff0c\u53ef\u4ee5\u770b\u5230\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a10\u4e2a\u90e8\u5206\u3002\u8fd9\u53ef\u4ee5\u4f7f\u7528 pandas \u7684cat\u51fd\u6570\u6765\u5b9e\u73b0\u3002 # \u521b\u5efa10\u4e2a\u5206\u7bb1 df [ \"f_bin_10\" ] = pd . cut ( df [ \"f_1\" ], bins = 10 , labels = False ) # \u521b\u5efa100\u4e2a\u5206\u7bb1 df [ \"f_bin_100\" ] = pd . cut ( df [ \"f_1\" ], bins = 100 , labels = False ) \u5982\u56fe 6 \u6240\u793a\uff0c\u8fd9\u5c06\u5728\u6570\u636e\u5e27\u4e2d\u751f\u6210\u4e24\u4e2a\u65b0\u7279\u5f81\u3002 \u56fe 6\uff1a\u6570\u503c\u7279\u5f81\u5206\u7bb1 \u5f53\u4f60\u8fdb\u884c\u5206\u7c7b\u65f6\uff0c\u53ef\u4ee5\u540c\u65f6\u4f7f\u7528\u5206\u7bb1\u548c\u539f\u59cb\u7279\u5f81\u3002\u6211\u4eec\u5c06\u5728\u672c\u7ae0\u540e\u534a\u90e8\u5206\u5b66\u4e60\u66f4\u591a\u5173\u4e8e\u9009\u62e9\u7279\u5f81\u7684\u77e5\u8bc6\u3002\u5206\u7bb1\u8fd8\u53ef\u4ee5\u5c06\u6570\u5b57\u7279\u5f81\u89c6\u4e3a\u5206\u7c7b\u7279\u5f81\u3002 \u53e6\u4e00\u79cd\u53ef\u4ee5\u4ece\u6570\u503c\u7279\u5f81\u4e2d\u521b\u5efa\u7684\u6709\u8da3\u7279\u5f81\u7c7b\u578b\u662f\u5bf9\u6570\u53d8\u6362\u3002\u8bf7\u770b\u56fe 7 \u4e2d\u7684\u7279\u5f81 f_3\u3002 \u4e0e\u5176\u4ed6\u65b9\u5dee\u8f83\u5c0f\u7684\u7279\u5f81\u76f8\u6bd4\uff08\u5047\u8bbe\u5982\u6b64\uff09\uff0cf_3 \u662f\u4e00\u79cd\u65b9\u5dee\u975e\u5e38\u5927\u7684\u7279\u6b8a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5e0c\u671b\u964d\u4f4e\u8fd9\u4e00\u5217\u7684\u65b9\u5dee\uff0c\u8fd9\u53ef\u4ee5\u901a\u8fc7\u5bf9\u6570\u53d8\u6362\u6765\u5b9e\u73b0\u3002 f_3 \u5217\u7684\u503c\u8303\u56f4\u4e3a 0 \u5230 10000\uff0c\u76f4\u65b9\u56fe\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u7279\u5f81 f_3 \u7684\u76f4\u65b9\u56fe \u6211\u4eec\u53ef\u4ee5\u5bf9\u8fd9\u4e00\u5217\u5e94\u7528 log(1 + x) \u6765\u51cf\u5c11\u5176\u65b9\u5dee\u3002\u56fe 9 \u663e\u793a\u4e86\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u76f4\u65b9\u56fe\u7684\u53d8\u5316\u3002 \u56fe 9\uff1a\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u7684 f_3 \u76f4\u65b9\u56fe \u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e0d\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u548c\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u7684\u65b9\u5dee\u3002 In [ X ]: df . f_3 . var () Out [ X ]: 8077265.875858586 In [ X ]: df . f_3 . apply ( lambda x : np . log ( 1 + x )) . var () Out [ X ]: 0.6058771732119975 \u6709\u65f6\uff0c\u4e5f\u53ef\u4ee5\u7528\u6307\u6570\u6765\u4ee3\u66ff\u5bf9\u6570\u3002\u4e00\u79cd\u975e\u5e38\u6709\u8da3\u7684\u60c5\u51b5\u662f\uff0c\u60a8\u4f7f\u7528\u57fa\u4e8e\u5bf9\u6570\u7684\u8bc4\u4f30\u6307\u6807\uff0c\u4f8b\u5982 RMSLE\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u5728\u5bf9\u6570\u53d8\u6362\u7684\u76ee\u6807\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u7136\u540e\u5728\u9884\u6d4b\u65f6\u4f7f\u7528\u6307\u6570\u503c\u8f6c\u6362\u56de\u539f\u59cb\u503c\u3002\u8fd9\u5c06\u6709\u52a9\u4e8e\u9488\u5bf9\u6307\u6807\u4f18\u5316\u6a21\u578b\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u7c7b\u6570\u5b57\u7279\u5f81\u90fd\u662f\u57fa\u4e8e\u76f4\u89c9\u521b\u5efa\u7684\u3002\u6ca1\u6709\u516c\u5f0f\u53ef\u5faa\u3002\u5982\u679c\u60a8\u4ece\u4e8b\u7684\u662f\u67d0\u4e00\u884c\u4e1a\uff0c\u60a8\u5c06\u521b\u5efa\u7279\u5b9a\u884c\u4e1a\u7684\u7279\u5f81\u3002 \u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u548c\u6570\u503c\u53d8\u91cf\u65f6\uff0c\u53ef\u80fd\u4f1a\u9047\u5230\u7f3a\u5931\u503c\u3002\u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u4e00\u4e9b\u5904\u7406\u5206\u7c7b\u7279\u5f81\u4e2d\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\uff0c\u4f46\u8fd8\u6709\u66f4\u591a\u65b9\u6cd5\u53ef\u4ee5\u5904\u7406\u7f3a\u5931\u503c/NaN \u503c\u3002\u8fd9\u4e5f\u88ab\u89c6\u4e3a\u7279\u5f81\u5de5\u7a0b\u3002 \u5982\u679c\u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u9047\u5230\u7f3a\u5931\u503c\uff0c\u5c31\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\uff01\u8fd9\u6837\u505a\u867d\u7136\u7b80\u5355\uff0c\u4f46\uff08\u51e0\u4e4e\uff09\u603b\u662f\u6709\u6548\u7684\uff01 \u5728\u6570\u503c\u6570\u636e\u4e2d\u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u9009\u62e9\u4e00\u4e2a\u5728\u7279\u5b9a\u7279\u5f81\u4e2d\u6ca1\u6709\u51fa\u73b0\u7684\u503c\uff0c\u7136\u540e\u7528\u5b83\u6765\u586b\u8865\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u7279\u5f81\u4e2d\u6ca1\u6709 0\u3002\u8fd9\u662f\u5176\u4e2d\u4e00\u79cd\u65b9\u6cd5\uff0c\u4f46\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\u3002\u5bf9\u4e8e\u6570\u503c\u6570\u636e\u6765\u8bf4\uff0c\u6bd4\u586b\u5145 0 \u66f4\u6709\u6548\u7684\u65b9\u6cd5\u4e4b\u4e00\u662f\u4f7f\u7528\u5e73\u5747\u503c\u8fdb\u884c\u586b\u5145\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8be5\u7279\u5f81\u6240\u6709\u503c\u7684\u4e2d\u4f4d\u6570\u6765\u586b\u5145\uff0c\u6216\u8005\u4f7f\u7528\u6700\u5e38\u89c1\u7684\u503c\u6765\u586b\u5145\u7f3a\u5931\u503c\u3002\u8fd9\u6837\u505a\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002 \u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u9ad8\u7ea7\u65b9\u6cd5\u662f\u4f7f\u7528 K \u8fd1\u90bb\u6cd5 \u3002 \u60a8\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\uff0c\u7136\u540e\u5229\u7528\u67d0\u79cd\u8ddd\u79bb\u5ea6\u91cf\uff08\u4f8b\u5982\u6b27\u6c0f\u8ddd\u79bb\uff09\u627e\u5230\u6700\u8fd1\u7684\u90bb\u5c45\u3002\u7136\u540e\u53d6\u6240\u6709\u8fd1\u90bb\u7684\u5e73\u5747\u503c\u6765\u586b\u8865\u7f3a\u5931\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 KNN \u6765\u586b\u8865\u8fd9\u6837\u7684\u7f3a\u5931\u503c\u3002 \u56fe 10\uff1a\u6709\u7f3a\u5931\u503c\u7684\u4e8c\u7ef4\u6570\u7ec4 \u8ba9\u6211\u4eec\u770b\u770b KNN \u662f\u5982\u4f55\u5904\u7406\u56fe 10 \u6240\u793a\u7684\u7f3a\u5931\u503c\u77e9\u9635\u7684\u3002 import numpy as np from sklearn import impute # \u751f\u6210\u7ef4\u5ea6\u4e3a (10, 6) \u7684\u968f\u673a\u6574\u6570\u77e9\u9635 X\uff0c\u6570\u503c\u8303\u56f4\u5728 1 \u5230 14 \u4e4b\u95f4 X = np . random . randint ( 1 , 15 , ( 10 , 6 )) # \u6570\u636e\u7c7b\u578b\u8f6c\u6362\u4e3a float X = X . astype ( float ) # \u5728\u77e9\u9635 X \u4e2d\u968f\u673a\u9009\u62e9 10 \u4e2a\u4f4d\u7f6e\uff0c\u5c06\u8fd9\u4e9b\u4f4d\u7f6e\u7684\u5143\u7d20\u8bbe\u7f6e\u4e3a NaN\uff08\u7f3a\u5931\u503c\uff09 X . ravel ()[ np . random . choice ( X . size , 10 , replace = False )] = np . nan # \u521b\u5efa\u4e00\u4e2a KNNImputer \u5bf9\u8c61 knn_imputer\uff0c\u6307\u5b9a\u90bb\u5c45\u6570\u91cf\u4e3a 2 knn_imputer = impute . KNNImputer ( n_neighbors = 2 ) # # \u4f7f\u7528 knn_imputer \u5bf9\u77e9\u9635 X \u8fdb\u884c\u62df\u5408\u548c\u8f6c\u6362\uff0c\u7528 K-\u6700\u8fd1\u90bb\u65b9\u6cd5\u586b\u8865\u7f3a\u5931\u503c knn_imputer . fit_transform ( X ) \u5982\u56fe 11 \u6240\u793a\uff0c\u5b83\u586b\u5145\u4e86\u4e0a\u8ff0\u77e9\u9635\u3002 \u56fe 11\uff1aKNN\u4f30\u7b97\u7684\u6570\u503c \u53e6\u4e00\u79cd\u5f25\u8865\u5217\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\u662f\u8bad\u7ec3\u56de\u5f52\u6a21\u578b\uff0c\u8bd5\u56fe\u6839\u636e\u5176\u4ed6\u5217\u9884\u6d4b\u67d0\u5217\u7684\u7f3a\u5931\u503c\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u4ece\u6709\u7f3a\u5931\u503c\u7684\u4e00\u5217\u5f00\u59cb\uff0c\u5c06\u8fd9\u4e00\u5217\u4f5c\u4e3a\u65e0\u7f3a\u5931\u503c\u56de\u5f52\u6a21\u578b\u7684\u76ee\u6807\u5217\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u6240\u6709\u5176\u4ed6\u5217\uff0c\u5bf9\u76f8\u5173\u5217\u4e2d\u6ca1\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\uff0c\u7136\u540e\u5c1d\u8bd5\u9884\u6d4b\u4e4b\u524d\u5220\u9664\u7684\u6837\u672c\u7684\u76ee\u6807\u5217\uff08\u540c\u4e00\u5217\uff09\u3002\u8fd9\u6837\uff0c\u57fa\u4e8e\u6a21\u578b\u7684\u4f30\u7b97\u5c31\u4f1a\u66f4\u52a0\u7a33\u5065\u3002 \u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u5bf9\u4e8e\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6ca1\u6709\u5fc5\u8981\u8fdb\u884c\u6570\u503c\u5f52\u4e00\u5316\uff0c\u56e0\u4e3a\u5b83\u4eec\u53ef\u4ee5\u81ea\u884c\u5904\u7406\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u6240\u5c55\u793a\u7684\u53ea\u662f\u521b\u5efa\u4e00\u822c\u7279\u5f81\u7684\u4e00\u4e9b\u65b9\u6cd5\u3002\u73b0\u5728\uff0c\u5047\u8bbe\u60a8\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u9884\u6d4b\u4e0d\u540c\u5546\u54c1\uff08\u6bcf\u5468\u6216\u6bcf\u6708\uff09\u5546\u5e97\u9500\u552e\u989d\u7684\u95ee\u9898\u3002\u60a8\u6709\u5546\u54c1\uff0c\u4e5f\u6709\u5546\u5e97 ID\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u6bcf\u4e2a\u5546\u5e97\u7684\u5546\u54c1\u7b49\u7279\u5f81\u3002\u73b0\u5728\uff0c\u8fd9\u662f\u4e0a\u6587\u6ca1\u6709\u8ba8\u8bba\u7684\u7279\u5f81\u4e4b\u4e00\u3002\u8fd9\u7c7b\u7279\u5f81\u4e0d\u80fd\u4e00\u6982\u800c\u8bba\uff0c\u5b8c\u5168\u6765\u81ea\u4e8e\u9886\u57df\u3001\u6570\u636e\u548c\u4e1a\u52a1\u77e5\u8bc6\u3002\u67e5\u770b\u6570\u636e\uff0c\u627e\u51fa\u9002\u5408\u7684\u7279\u5f81\uff0c\u7136\u540e\u521b\u5efa\u76f8\u5e94\u7684\u7279\u5f81\u3002\u5982\u679c\u60a8\u4f7f\u7528\u7684\u662f\u903b\u8f91\u56de\u5f52\u7b49\u7ebf\u6027\u6a21\u578b\u6216 SVM \u7b49\u6a21\u578b\uff0c\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\u5bf9\u7279\u5f81\u8fdb\u884c\u7f29\u653e\u6216\u5f52\u4e00\u5316\u5904\u7406\u3002\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65e0\u9700\u5bf9\u7279\u5f81\u8fdb\u884c\u4efb\u4f55\u5f52\u4e00\u5316\u5904\u7406\u5373\u53ef\u6b63\u5e38\u5de5\u4f5c\u3002","title":"\u7279\u5f81\u5de5\u7a0b"},{"location":"%E7%89%B9%E5%BE%81%E5%B7%A5%E7%A8%8B/#_1","text":"\u7279\u5f81\u5de5\u7a0b\u662f\u6784\u5efa\u826f\u597d\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u90e8\u5206\u4e4b\u4e00\u3002\u5982\u679c\u6211\u4eec\u62e5\u6709\u6709\u7528\u7684\u7279\u5f81\uff0c\u6a21\u578b\u5c31\u4f1a\u8868\u73b0\u5f97\u66f4\u597d\u3002\u5728\u8bb8\u591a\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u907f\u514d\u4f7f\u7528\u5927\u578b\u590d\u6742\u6a21\u578b\uff0c\u800c\u4f7f\u7528\u5177\u6709\u5173\u952e\u5de5\u7a0b\u7279\u5f81\u7684\u7b80\u5355\u6a21\u578b\u3002\u6211\u4eec\u5fc5\u987b\u7262\u8bb0\uff0c\u53ea\u6709\u5f53\u4f60\u5bf9\u95ee\u9898\u7684\u9886\u57df\u6709\u4e00\u5b9a\u7684\u4e86\u89e3\uff0c\u5e76\u4e14\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u76f8\u5173\u6570\u636e\u65f6\uff0c\u624d\u80fd\u4ee5\u6700\u4f73\u65b9\u5f0f\u5b8c\u6210\u7279\u5f81\u5de5\u7a0b\u3002\u4e0d\u8fc7\uff0c\u60a8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u4e00\u4e9b\u901a\u7528\u6280\u672f\uff0c\u4ece\u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u7279\u5f81\u3002\u7279\u5f81\u5de5\u7a0b\u4e0d\u4ec5\u4ec5\u662f\u4ece\u6570\u636e\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\uff0c\u8fd8\u5305\u62ec\u4e0d\u540c\u7c7b\u578b\u7684\u5f52\u4e00\u5316\u548c\u8f6c\u6362\u3002 \u5728\u6709\u5173\u5206\u7c7b\u7279\u5f81\u7684\u7ae0\u8282\u4e2d\uff0c\u6211\u4eec\u5df2\u7ecf\u4e86\u89e3\u4e86\u7ed3\u5408\u4e0d\u540c\u5206\u7c7b\u53d8\u91cf\u7684\u65b9\u6cd5\u3001\u5982\u4f55\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u8ba1\u6570\u3001\u6807\u7b7e\u7f16\u7801\u548c\u4f7f\u7528\u5d4c\u5165\u3002\u8fd9\u4e9b\u51e0\u4e4e\u90fd\u662f\u5229\u7528\u5206\u7c7b\u53d8\u91cf\u8bbe\u8ba1\u7279\u5f81\u7684\u65b9\u6cd5\u3002\u56e0\u6b64\uff0c\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u7684\u91cd\u70b9\u5c06\u4ec5\u9650\u4e8e\u6570\u503c\u53d8\u91cf\u4ee5\u53ca\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u7684\u7ec4\u5408\u3002 \u8ba9\u6211\u4eec\u4ece\u6700\u7b80\u5355\u4f46\u5e94\u7528\u6700\u5e7f\u6cdb\u7684\u7279\u5f81\u5de5\u7a0b\u6280\u672f\u5f00\u59cb\u3002\u5047\u8bbe\u4f60\u6b63\u5728\u5904\u7406\u65e5\u671f\u548c\u65f6\u95f4\u6570\u636e\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709\u4e00\u4e2a\u5e26\u6709\u65e5\u671f\u7c7b\u578b\u5217\u7684 pandas \u6570\u636e\u5e27\u3002\u5229\u7528\u8fd9\u4e00\u5217\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4ee5\u4e0b\u7279\u5f81\uff1a \u5e74 \u5e74\u4e2d\u7684\u5468 \u6708 \u661f\u671f \u5468\u672b \u5c0f\u65f6 \u8fd8\u6709\u66f4\u591a \u800c\u4f7f\u7528pandas\u5c31\u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u6dfb\u52a0'year'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5e74\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'year' ] = df [ 'datetime_column' ] . dt . year # \u6dfb\u52a0'weekofyear'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5468\u6570\u63d0\u53d6\u51fa\u6765 df . loc [:, 'weekofyear' ] = df [ 'datetime_column' ] . dt . weekofyear # \u6dfb\u52a0'month'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u6708\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'month' ] = df [ 'datetime_column' ] . dt . month # \u6dfb\u52a0'dayofweek'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u661f\u671f\u51e0\u63d0\u53d6\u51fa\u6765 df . loc [:, 'dayofweek' ] = df [ 'datetime_column' ] . dt . dayofweek # \u6dfb\u52a0'weekend'\u5217\uff0c\u5224\u65ad\u5f53\u5929\u662f\u5426\u4e3a\u5468\u672b df . loc [:, 'weekend' ] = ( df . datetime_column . dt . weekday >= 5 ) . astype ( int ) # \u6dfb\u52a0 'hour' \u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5c0f\u65f6\u63d0\u53d6\u51fa\u6765 df . loc [:, 'hour' ] = df [ 'datetime_column' ] . dt . hour \u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u65e5\u671f\u65f6\u95f4\u5217\u521b\u5efa\u4e00\u7cfb\u5217\u65b0\u5217\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u53ef\u4ee5\u521b\u5efa\u7684\u4e00\u4e9b\u793a\u4f8b\u529f\u80fd\u3002 import pandas as pd # \u521b\u5efa\u65e5\u671f\u65f6\u95f4\u5e8f\u5217\uff0c\u5305\u542b\u4e86\u4ece '2020-01-06' \u5230 '2020-01-10' \u7684\u65e5\u671f\u65f6\u95f4\u70b9\uff0c\u65f6\u95f4\u95f4\u9694\u4e3a10\u5c0f\u65f6 s = pd . date_range ( '2020-01-06' , '2020-01-10' , freq = '10H' ) . to_series () # \u63d0\u53d6\u5bf9\u5e94\u65f6\u95f4\u7279\u5f81 features = { \"dayofweek\" : s . dt . dayofweek . values , \"dayofyear\" : s . dt . dayofyear . values , \"hour\" : s . dt . hour . values , \"is_leap_year\" : s . dt . is_leap_year . values , \"quarter\" : s . dt . quarter . values , \"weekofyear\" : s . dt . weekofyear . values } \u8fd9\u5c06\u4ece\u7ed9\u5b9a\u7cfb\u5217\u4e2d\u751f\u6210\u4e00\u4e2a\u7279\u5f81\u5b57\u5178\u3002\u60a8\u53ef\u4ee5\u5c06\u6b64\u5e94\u7528\u4e8e pandas \u6570\u636e\u4e2d\u7684\u4efb\u4f55\u65e5\u671f\u65f6\u95f4\u5217\u3002\u8fd9\u4e9b\u662f pandas \u63d0\u4f9b\u7684\u4f17\u591a\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u4e2d\u7684\u4e00\u90e8\u5206\u3002\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u975e\u5e38\u91cd\u8981\uff0c\u4f8b\u5982\uff0c\u5728\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97\u7684\u9500\u552e\u989d\u65f6\uff0c\u5982\u679c\u60f3\u5728\u805a\u5408\u7279\u5f81\u4e0a\u4f7f\u7528 xgboost \u7b49\u6a21\u578b\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u5c31\u975e\u5e38\u91cd\u8981\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u4e0b\u6240\u793a\u7684\u6570\u636e\uff1a \u56fe 1\uff1a\u5305\u542b\u5206\u7c7b\u548c\u65e5\u671f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e \u5728\u56fe 1 \u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u6709\u4e00\u4e2a\u65e5\u671f\u5217\uff0c\u4ece\u4e2d\u53ef\u4ee5\u8f7b\u677e\u63d0\u53d6\u5e74\u3001\u6708\u3001\u5b63\u5ea6\u7b49\u7279\u5f81\u3002\u7136\u540e\uff0c\u6211\u4eec\u6709\u4e00\u4e2a customer_id \u5217\uff0c\u8be5\u5217\u6709\u591a\u4e2a\u6761\u76ee\uff0c\u56e0\u6b64\u4e00\u4e2a\u5ba2\u6237\u4f1a\u88ab\u770b\u5230\u5f88\u591a\u6b21\uff08\u622a\u56fe\u4e2d\u770b\u4e0d\u5230\uff09\u3002\u6bcf\u4e2a\u65e5\u671f\u548c\u5ba2\u6237 ID \u90fd\u6709\u4e09\u4e2a\u5206\u7c7b\u7279\u5f81\u548c\u4e00\u4e2a\u6570\u5b57\u7279\u5f81\u3002\u6211\u4eec\u53ef\u4ee5\u4ece\u4e2d\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff1a - \u5ba2\u6237\u6700\u6d3b\u8dc3\u7684\u6708\u4efd\u662f\u51e0\u6708 - \u67d0\u4e2a\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u7684\u8ba1\u6570\u662f\u591a\u5c11 - \u67d0\u5e74\u67d0\u6708\u67d0\u5468\u67d0\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u6570\u91cf\u662f\u591a\u5c11\uff1f - \u67d0\u4e2a\u5ba2\u6237\u7684 num1 \u5e73\u5747\u503c\u662f\u591a\u5c11\uff1f - \u7b49\u7b49\u3002 \u4f7f\u7528 pandas \u4e2d\u7684\u805a\u5408\uff0c\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u521b\u5efa\u7c7b\u4f3c\u7684\u529f\u80fd\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u5b9e\u73b0\u3002 def generate_features ( df ): df . loc [:, 'year' ] = df [ 'date' ] . dt . year df . loc [:, 'weekofyear' ] = df [ 'date' ] . dt . weekofyear df . loc [:, 'month' ] = df [ 'date' ] . dt . month df . loc [:, 'dayofweek' ] = df [ 'date' ] . dt . dayofweek df . loc [:, 'weekend' ] = ( df [ 'date' ] . dt . weekday >= 5 ) . astype ( int ) aggs = {} # \u5bf9 'month' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'month' ] = [ 'nunique' , 'mean' ] # \u5bf9 'weekofyear' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'weekofyear' ] = [ 'nunique' , 'mean' ] # \u5bf9 'num1' \u5217\u8fdb\u884c sum\u3001max\u3001min\u3001mean \u805a\u5408 aggs [ 'num1' ] = [ 'sum' , 'max' , 'min' , 'mean' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c size \u805a\u5408 aggs [ 'customer_id' ] = [ 'size' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c nunique \u805a\u5408 aggs [ 'customer_id' ] = [ 'nunique' ] # \u5bf9\u6570\u636e\u5e94\u7528\u4e0d\u540c\u7684\u805a\u5408\u51fd\u6570 agg_df = df . groupby ( 'customer_id' ) . agg ( aggs ) # \u91cd\u7f6e\u7d22\u5f15 agg_df = agg_df . reset_index () return agg_df \u8bf7\u6ce8\u610f\uff0c\u5728\u4e0a\u8ff0\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u8df3\u8fc7\u4e86\u5206\u7c7b\u53d8\u91cf\uff0c\u4f46\u60a8\u53ef\u4ee5\u50cf\u4f7f\u7528\u5176\u4ed6\u805a\u5408\u53d8\u91cf\u4e00\u6837\u4f7f\u7528\u5b83\u4eec\u3002 \u56fe 2\uff1a\u603b\u4f53\u7279\u5f81\u548c\u5176\u4ed6\u7279\u5f81 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u56fe 2 \u4e2d\u7684\u6570\u636e\u4e0e\u5e26\u6709 customer_id \u5217\u7684\u539f\u59cb\u6570\u636e\u5e27\u8fde\u63a5\u8d77\u6765\uff0c\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5e76\u4e0d\u662f\u8981\u9884\u6d4b\u4ec0\u4e48\uff1b\u6211\u4eec\u53ea\u662f\u5728\u521b\u5efa\u901a\u7528\u7279\u5f81\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u8bd5\u56fe\u5728\u8fd9\u91cc\u9884\u6d4b\u4ec0\u4e48\uff0c\u521b\u5efa\u7279\u5f81\u4f1a\u66f4\u5bb9\u6613\u3002 \u4f8b\u5982\uff0c\u6709\u65f6\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u95ee\u9898\u65f6\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u7684\u7279\u5f81\u4e0d\u662f\u5355\u4e2a\u503c\uff0c\u800c\u662f\u4e00\u7cfb\u5217\u503c\u3002 \u4f8b\u5982\uff0c\u5ba2\u6237\u5728\u7279\u5b9a\u65f6\u95f4\u6bb5\u5185\u7684\u4ea4\u6613\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4f1a\u521b\u5efa\u4e0d\u540c\u7c7b\u578b\u7684\u7279\u5f81\uff0c\u4f8b\u5982\uff1a\u4f7f\u7528\u6570\u503c\u7279\u5f81\u65f6\uff0c\u5728\u5bf9\u5206\u7c7b\u5217\u8fdb\u884c\u5206\u7ec4\u65f6\uff0c\u4f1a\u5f97\u5230\u7c7b\u4f3c\u4e8e\u65f6\u95f4\u5206\u5e03\u503c\u5217\u8868\u7684\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u7cfb\u5217\u7edf\u8ba1\u7279\u5f81\uff0c\u4f8b\u5982 \u5e73\u5747\u503c \u6700\u5927\u503c \u6700\u5c0f\u503c \u72ec\u7279\u6027 \u504f\u659c \u5cf0\u5ea6 Kstat \u767e\u5206\u4f4d\u6570 \u5b9a\u91cf \u5cf0\u503c\u5230\u5cf0\u503c \u4ee5\u53ca\u66f4\u591a \u8fd9\u4e9b\u53ef\u4ee5\u4f7f\u7528\u7b80\u5355\u7684 numpy \u51fd\u6570\u521b\u5efa\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 import numpy as np # \u521b\u5efa\u5b57\u5178\uff0c\u7528\u4e8e\u5b58\u50a8\u4e0d\u540c\u7684\u7edf\u8ba1\u7279\u5f81 feature_dict = {} # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5e73\u5747\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'mean' \u952e\u4e0b feature_dict [ 'mean' ] = np . mean ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5927\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'max' \u952e\u4e0b feature_dict [ 'max' ] = np . max ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5c0f\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'min' \u952e\u4e0b feature_dict [ 'min' ] = np . min ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6807\u51c6\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'std' \u952e\u4e0b feature_dict [ 'std' ] = np . std ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u65b9\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'var' \u952e\u4e0b feature_dict [ 'var' ] = np . var ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5dee\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'ptp' \u952e\u4e0b feature_dict [ 'ptp' ] = np . ptp ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c10\u767e\u5206\u4f4d\u6570\uff08\u5373\u767e\u5206\u4e4b10\u5206\u4f4d\u6570\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_10' \u952e\u4e0b feature_dict [ 'percentile_10' ] = np . percentile ( x , 10 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c60\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_60' \u952e\u4e0b feature_dict [ 'percentile_60' ] = np . percentile ( x , 60 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c90\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_90' \u952e\u4e0b feature_dict [ 'percentile_90' ] = np . percentile ( x , 90 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u76845%\u5206\u4f4d\u6570\uff08\u53730.05\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_5' \u952e\u4e0b feature_dict [ 'quantile_5' ] = np . quantile ( x , 0.05 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768495%\u5206\u4f4d\u6570\uff08\u53730.95\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_95' \u952e\u4e0b feature_dict [ 'quantile_95' ] = np . quantile ( x , 0.95 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768499%\u5206\u4f4d\u6570\uff08\u53730.99\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_99' \u952e\u4e0b feature_dict [ 'quantile_99' ] = np . quantile ( x , 0.99 ) \u65f6\u95f4\u5e8f\u5217\u6570\u636e\uff08\u6570\u503c\u5217\u8868\uff09\u53ef\u4ee5\u8f6c\u6362\u6210\u8bb8\u591a\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4e00\u4e2a\u540d\u4e3a tsfresh \u7684 python \u5e93\u975e\u5e38\u6709\u7528\u3002 from tsfresh.feature_extraction import feature_calculators as fc # \u8ba1\u7b97 x \u6570\u5217\u7684\u7edd\u5bf9\u80fd\u91cf\uff08abs_energy\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'abs_energy' \u952e\u4e0b feature_dict [ 'abs_energy' ] = fc . abs_energy ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u9ad8\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_above_mean' \u952e\u4e0b feature_dict [ 'count_above_mean' ] = fc . count_above_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u4f4e\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_below_mean' \u952e\u4e0b feature_dict [ 'count_below_mean' ] = fc . count_below_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u7edd\u5bf9\u53d8\u5316\uff08mean_abs_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_abs_change' \u952e\u4e0b feature_dict [ 'mean_abs_change' ] = fc . mean_abs_change ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u53d8\u5316\u7387\uff08mean_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_change' \u952e\u4e0b feature_dict [ 'mean_change' ] = fc . mean_change ( x ) \u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\uff1btsfresh \u63d0\u4f9b\u4e86\u6570\u767e\u79cd\u7279\u5f81\u548c\u6570\u5341\u79cd\u4e0d\u540c\u7279\u5f81\u7684\u53d8\u4f53\uff0c\u4f60\u53ef\u4ee5\u5c06\u5b83\u4eec\u7528\u4e8e\u57fa\u4e8e\u65f6\u95f4\u5e8f\u5217\uff08\u503c\u5217\u8868\uff09\u7684\u7279\u5f81\u3002\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0cx \u662f\u4e00\u4e2a\u503c\u5217\u8868\u3002\u4f46\u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\u3002\u60a8\u8fd8\u53ef\u4ee5\u4e3a\u5305\u542b\u6216\u4e0d\u5305\u542b\u5206\u7c7b\u6570\u636e\u7684\u6570\u503c\u6570\u636e\u521b\u5efa\u8bb8\u591a\u5176\u4ed6\u7279\u5f81\u3002\u751f\u6210\u8bb8\u591a\u7279\u5f81\u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u521b\u5efa\u4e00\u5806\u591a\u9879\u5f0f\u7279\u5f81\u3002\u4f8b\u5982\uff0c\u4ece\u4e24\u4e2a\u7279\u5f81 \"a \"\u548c \"b \"\u751f\u6210\u7684\u4e8c\u7ea7\u591a\u9879\u5f0f\u7279\u5f81\u5305\u62ec \"a\"\u3001\"b\"\u3001\"ab\"\u3001\"a^2 \"\u548c \"b^2\"\u3002 import numpy as np df = pd . DataFrame ( np . random . rand ( 100 , 2 ), columns = [ f \"f_ { i } \" for i in range ( 1 , 3 )]) \u5982\u56fe 3 \u6240\u793a\uff0c\u5b83\u7ed9\u51fa\u4e86\u4e00\u4e2a\u6570\u636e\u8868\u3002 \u56fe 3\uff1a\u5305\u542b\u4e24\u4e2a\u6570\u5b57\u7279\u5f81\u7684\u968f\u673a\u6570\u636e\u8868 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u7684 PolynomialFeatures \u521b\u5efa\u4e24\u6b21\u591a\u9879\u5f0f\u7279\u5f81\u3002 from sklearn import preprocessing # \u6307\u5b9a\u591a\u9879\u5f0f\u7684\u6b21\u6570\u4e3a 2\uff0c\u4e0d\u4ec5\u8003\u8651\u4ea4\u4e92\u9879\uff0c\u4e0d\u5305\u62ec\u504f\u5dee\uff08include_bias=False\uff09 pf = preprocessing . PolynomialFeatures ( degree = 2 , interaction_only = False , include_bias = False ) # \u62df\u5408\uff0c\u521b\u5efa\u591a\u9879\u5f0f\u7279\u5f81 pf . fit ( df ) # \u8f6c\u6362\u6570\u636e poly_feats = pf . transform ( df ) # \u83b7\u53d6\u751f\u6210\u7684\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf num_feats = poly_feats . shape [ 1 ] # \u4e3a\u65b0\u751f\u6210\u7684\u7279\u5f81\u547d\u540d df_transformed = pd . DataFrame ( poly_feats , columns = [ f \"f_ { i } \" for i in range ( 1 , num_feats + 1 )] ) \u8fd9\u6837\u5c31\u5f97\u5230\u4e86\u4e00\u4e2a\u6570\u636e\u8868\uff0c\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1a\u5e26\u6709\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e\u8868 \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e9b\u591a\u9879\u5f0f\u7279\u5f81\u3002\u5982\u679c\u521b\u5efa\u7684\u662f\u4e09\u6b21\u591a\u9879\u5f0f\u7279\u5f81\uff0c\u6700\u7ec8\u603b\u5171\u4f1a\u6709\u4e5d\u4e2a\u7279\u5f81\u3002\u7279\u5f81\u7684\u6570\u91cf\u8d8a\u591a\uff0c\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf\u4e5f\u5c31\u8d8a\u591a\uff0c\u800c\u4e14\u4f60\u8fd8\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5982\u679c\u6570\u636e\u96c6\u4e2d\u6709\u5f88\u591a\u6837\u672c\uff0c\u90a3\u4e48\u521b\u5efa\u8fd9\u7c7b\u7279\u5f81\u5c31\u9700\u8981\u82b1\u8d39\u4e00\u4e9b\u65f6\u95f4\u3002 \u56fe 5\uff1a\u6570\u5b57\u7279\u5f81\u5217\u7684\u76f4\u65b9\u56fe \u53e6\u4e00\u4e2a\u6709\u8da3\u7684\u529f\u80fd\u662f\u5c06\u6570\u5b57\u8f6c\u6362\u4e3a\u7c7b\u522b\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \u5206\u7bb1 \u3002\u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u56fe 5\uff0c\u5b83\u663e\u793a\u4e86\u4e00\u4e2a\u968f\u673a\u6570\u5b57\u7279\u5f81\u7684\u6837\u672c\u76f4\u65b9\u56fe\u3002\u6211\u4eec\u5728\u8be5\u56fe\u4e2d\u4f7f\u7528\u4e8610\u4e2a\u5206\u7bb1\uff0c\u53ef\u4ee5\u770b\u5230\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a10\u4e2a\u90e8\u5206\u3002\u8fd9\u53ef\u4ee5\u4f7f\u7528 pandas \u7684cat\u51fd\u6570\u6765\u5b9e\u73b0\u3002 # \u521b\u5efa10\u4e2a\u5206\u7bb1 df [ \"f_bin_10\" ] = pd . cut ( df [ \"f_1\" ], bins = 10 , labels = False ) # \u521b\u5efa100\u4e2a\u5206\u7bb1 df [ \"f_bin_100\" ] = pd . cut ( df [ \"f_1\" ], bins = 100 , labels = False ) \u5982\u56fe 6 \u6240\u793a\uff0c\u8fd9\u5c06\u5728\u6570\u636e\u5e27\u4e2d\u751f\u6210\u4e24\u4e2a\u65b0\u7279\u5f81\u3002 \u56fe 6\uff1a\u6570\u503c\u7279\u5f81\u5206\u7bb1 \u5f53\u4f60\u8fdb\u884c\u5206\u7c7b\u65f6\uff0c\u53ef\u4ee5\u540c\u65f6\u4f7f\u7528\u5206\u7bb1\u548c\u539f\u59cb\u7279\u5f81\u3002\u6211\u4eec\u5c06\u5728\u672c\u7ae0\u540e\u534a\u90e8\u5206\u5b66\u4e60\u66f4\u591a\u5173\u4e8e\u9009\u62e9\u7279\u5f81\u7684\u77e5\u8bc6\u3002\u5206\u7bb1\u8fd8\u53ef\u4ee5\u5c06\u6570\u5b57\u7279\u5f81\u89c6\u4e3a\u5206\u7c7b\u7279\u5f81\u3002 \u53e6\u4e00\u79cd\u53ef\u4ee5\u4ece\u6570\u503c\u7279\u5f81\u4e2d\u521b\u5efa\u7684\u6709\u8da3\u7279\u5f81\u7c7b\u578b\u662f\u5bf9\u6570\u53d8\u6362\u3002\u8bf7\u770b\u56fe 7 \u4e2d\u7684\u7279\u5f81 f_3\u3002 \u4e0e\u5176\u4ed6\u65b9\u5dee\u8f83\u5c0f\u7684\u7279\u5f81\u76f8\u6bd4\uff08\u5047\u8bbe\u5982\u6b64\uff09\uff0cf_3 \u662f\u4e00\u79cd\u65b9\u5dee\u975e\u5e38\u5927\u7684\u7279\u6b8a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5e0c\u671b\u964d\u4f4e\u8fd9\u4e00\u5217\u7684\u65b9\u5dee\uff0c\u8fd9\u53ef\u4ee5\u901a\u8fc7\u5bf9\u6570\u53d8\u6362\u6765\u5b9e\u73b0\u3002 f_3 \u5217\u7684\u503c\u8303\u56f4\u4e3a 0 \u5230 10000\uff0c\u76f4\u65b9\u56fe\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u7279\u5f81 f_3 \u7684\u76f4\u65b9\u56fe \u6211\u4eec\u53ef\u4ee5\u5bf9\u8fd9\u4e00\u5217\u5e94\u7528 log(1 + x) \u6765\u51cf\u5c11\u5176\u65b9\u5dee\u3002\u56fe 9 \u663e\u793a\u4e86\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u76f4\u65b9\u56fe\u7684\u53d8\u5316\u3002 \u56fe 9\uff1a\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u7684 f_3 \u76f4\u65b9\u56fe \u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e0d\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u548c\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u7684\u65b9\u5dee\u3002 In [ X ]: df . f_3 . var () Out [ X ]: 8077265.875858586 In [ X ]: df . f_3 . apply ( lambda x : np . log ( 1 + x )) . var () Out [ X ]: 0.6058771732119975 \u6709\u65f6\uff0c\u4e5f\u53ef\u4ee5\u7528\u6307\u6570\u6765\u4ee3\u66ff\u5bf9\u6570\u3002\u4e00\u79cd\u975e\u5e38\u6709\u8da3\u7684\u60c5\u51b5\u662f\uff0c\u60a8\u4f7f\u7528\u57fa\u4e8e\u5bf9\u6570\u7684\u8bc4\u4f30\u6307\u6807\uff0c\u4f8b\u5982 RMSLE\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u5728\u5bf9\u6570\u53d8\u6362\u7684\u76ee\u6807\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u7136\u540e\u5728\u9884\u6d4b\u65f6\u4f7f\u7528\u6307\u6570\u503c\u8f6c\u6362\u56de\u539f\u59cb\u503c\u3002\u8fd9\u5c06\u6709\u52a9\u4e8e\u9488\u5bf9\u6307\u6807\u4f18\u5316\u6a21\u578b\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u7c7b\u6570\u5b57\u7279\u5f81\u90fd\u662f\u57fa\u4e8e\u76f4\u89c9\u521b\u5efa\u7684\u3002\u6ca1\u6709\u516c\u5f0f\u53ef\u5faa\u3002\u5982\u679c\u60a8\u4ece\u4e8b\u7684\u662f\u67d0\u4e00\u884c\u4e1a\uff0c\u60a8\u5c06\u521b\u5efa\u7279\u5b9a\u884c\u4e1a\u7684\u7279\u5f81\u3002 \u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u548c\u6570\u503c\u53d8\u91cf\u65f6\uff0c\u53ef\u80fd\u4f1a\u9047\u5230\u7f3a\u5931\u503c\u3002\u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u4e00\u4e9b\u5904\u7406\u5206\u7c7b\u7279\u5f81\u4e2d\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\uff0c\u4f46\u8fd8\u6709\u66f4\u591a\u65b9\u6cd5\u53ef\u4ee5\u5904\u7406\u7f3a\u5931\u503c/NaN \u503c\u3002\u8fd9\u4e5f\u88ab\u89c6\u4e3a\u7279\u5f81\u5de5\u7a0b\u3002 \u5982\u679c\u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u9047\u5230\u7f3a\u5931\u503c\uff0c\u5c31\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\uff01\u8fd9\u6837\u505a\u867d\u7136\u7b80\u5355\uff0c\u4f46\uff08\u51e0\u4e4e\uff09\u603b\u662f\u6709\u6548\u7684\uff01 \u5728\u6570\u503c\u6570\u636e\u4e2d\u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u9009\u62e9\u4e00\u4e2a\u5728\u7279\u5b9a\u7279\u5f81\u4e2d\u6ca1\u6709\u51fa\u73b0\u7684\u503c\uff0c\u7136\u540e\u7528\u5b83\u6765\u586b\u8865\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u7279\u5f81\u4e2d\u6ca1\u6709 0\u3002\u8fd9\u662f\u5176\u4e2d\u4e00\u79cd\u65b9\u6cd5\uff0c\u4f46\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\u3002\u5bf9\u4e8e\u6570\u503c\u6570\u636e\u6765\u8bf4\uff0c\u6bd4\u586b\u5145 0 \u66f4\u6709\u6548\u7684\u65b9\u6cd5\u4e4b\u4e00\u662f\u4f7f\u7528\u5e73\u5747\u503c\u8fdb\u884c\u586b\u5145\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8be5\u7279\u5f81\u6240\u6709\u503c\u7684\u4e2d\u4f4d\u6570\u6765\u586b\u5145\uff0c\u6216\u8005\u4f7f\u7528\u6700\u5e38\u89c1\u7684\u503c\u6765\u586b\u5145\u7f3a\u5931\u503c\u3002\u8fd9\u6837\u505a\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002 \u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u9ad8\u7ea7\u65b9\u6cd5\u662f\u4f7f\u7528 K \u8fd1\u90bb\u6cd5 \u3002 \u60a8\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\uff0c\u7136\u540e\u5229\u7528\u67d0\u79cd\u8ddd\u79bb\u5ea6\u91cf\uff08\u4f8b\u5982\u6b27\u6c0f\u8ddd\u79bb\uff09\u627e\u5230\u6700\u8fd1\u7684\u90bb\u5c45\u3002\u7136\u540e\u53d6\u6240\u6709\u8fd1\u90bb\u7684\u5e73\u5747\u503c\u6765\u586b\u8865\u7f3a\u5931\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 KNN \u6765\u586b\u8865\u8fd9\u6837\u7684\u7f3a\u5931\u503c\u3002 \u56fe 10\uff1a\u6709\u7f3a\u5931\u503c\u7684\u4e8c\u7ef4\u6570\u7ec4 \u8ba9\u6211\u4eec\u770b\u770b KNN \u662f\u5982\u4f55\u5904\u7406\u56fe 10 \u6240\u793a\u7684\u7f3a\u5931\u503c\u77e9\u9635\u7684\u3002 import numpy as np from sklearn import impute # \u751f\u6210\u7ef4\u5ea6\u4e3a (10, 6) \u7684\u968f\u673a\u6574\u6570\u77e9\u9635 X\uff0c\u6570\u503c\u8303\u56f4\u5728 1 \u5230 14 \u4e4b\u95f4 X = np . random . randint ( 1 , 15 , ( 10 , 6 )) # \u6570\u636e\u7c7b\u578b\u8f6c\u6362\u4e3a float X = X . astype ( float ) # \u5728\u77e9\u9635 X \u4e2d\u968f\u673a\u9009\u62e9 10 \u4e2a\u4f4d\u7f6e\uff0c\u5c06\u8fd9\u4e9b\u4f4d\u7f6e\u7684\u5143\u7d20\u8bbe\u7f6e\u4e3a NaN\uff08\u7f3a\u5931\u503c\uff09 X . ravel ()[ np . random . choice ( X . size , 10 , replace = False )] = np . nan # \u521b\u5efa\u4e00\u4e2a KNNImputer \u5bf9\u8c61 knn_imputer\uff0c\u6307\u5b9a\u90bb\u5c45\u6570\u91cf\u4e3a 2 knn_imputer = impute . KNNImputer ( n_neighbors = 2 ) # # \u4f7f\u7528 knn_imputer \u5bf9\u77e9\u9635 X \u8fdb\u884c\u62df\u5408\u548c\u8f6c\u6362\uff0c\u7528 K-\u6700\u8fd1\u90bb\u65b9\u6cd5\u586b\u8865\u7f3a\u5931\u503c knn_imputer . fit_transform ( X ) \u5982\u56fe 11 \u6240\u793a\uff0c\u5b83\u586b\u5145\u4e86\u4e0a\u8ff0\u77e9\u9635\u3002 \u56fe 11\uff1aKNN\u4f30\u7b97\u7684\u6570\u503c \u53e6\u4e00\u79cd\u5f25\u8865\u5217\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\u662f\u8bad\u7ec3\u56de\u5f52\u6a21\u578b\uff0c\u8bd5\u56fe\u6839\u636e\u5176\u4ed6\u5217\u9884\u6d4b\u67d0\u5217\u7684\u7f3a\u5931\u503c\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u4ece\u6709\u7f3a\u5931\u503c\u7684\u4e00\u5217\u5f00\u59cb\uff0c\u5c06\u8fd9\u4e00\u5217\u4f5c\u4e3a\u65e0\u7f3a\u5931\u503c\u56de\u5f52\u6a21\u578b\u7684\u76ee\u6807\u5217\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u6240\u6709\u5176\u4ed6\u5217\uff0c\u5bf9\u76f8\u5173\u5217\u4e2d\u6ca1\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\uff0c\u7136\u540e\u5c1d\u8bd5\u9884\u6d4b\u4e4b\u524d\u5220\u9664\u7684\u6837\u672c\u7684\u76ee\u6807\u5217\uff08\u540c\u4e00\u5217\uff09\u3002\u8fd9\u6837\uff0c\u57fa\u4e8e\u6a21\u578b\u7684\u4f30\u7b97\u5c31\u4f1a\u66f4\u52a0\u7a33\u5065\u3002 \u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u5bf9\u4e8e\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6ca1\u6709\u5fc5\u8981\u8fdb\u884c\u6570\u503c\u5f52\u4e00\u5316\uff0c\u56e0\u4e3a\u5b83\u4eec\u53ef\u4ee5\u81ea\u884c\u5904\u7406\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u6240\u5c55\u793a\u7684\u53ea\u662f\u521b\u5efa\u4e00\u822c\u7279\u5f81\u7684\u4e00\u4e9b\u65b9\u6cd5\u3002\u73b0\u5728\uff0c\u5047\u8bbe\u60a8\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u9884\u6d4b\u4e0d\u540c\u5546\u54c1\uff08\u6bcf\u5468\u6216\u6bcf\u6708\uff09\u5546\u5e97\u9500\u552e\u989d\u7684\u95ee\u9898\u3002\u60a8\u6709\u5546\u54c1\uff0c\u4e5f\u6709\u5546\u5e97 ID\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u6bcf\u4e2a\u5546\u5e97\u7684\u5546\u54c1\u7b49\u7279\u5f81\u3002\u73b0\u5728\uff0c\u8fd9\u662f\u4e0a\u6587\u6ca1\u6709\u8ba8\u8bba\u7684\u7279\u5f81\u4e4b\u4e00\u3002\u8fd9\u7c7b\u7279\u5f81\u4e0d\u80fd\u4e00\u6982\u800c\u8bba\uff0c\u5b8c\u5168\u6765\u81ea\u4e8e\u9886\u57df\u3001\u6570\u636e\u548c\u4e1a\u52a1\u77e5\u8bc6\u3002\u67e5\u770b\u6570\u636e\uff0c\u627e\u51fa\u9002\u5408\u7684\u7279\u5f81\uff0c\u7136\u540e\u521b\u5efa\u76f8\u5e94\u7684\u7279\u5f81\u3002\u5982\u679c\u60a8\u4f7f\u7528\u7684\u662f\u903b\u8f91\u56de\u5f52\u7b49\u7ebf\u6027\u6a21\u578b\u6216 SVM \u7b49\u6a21\u578b\uff0c\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\u5bf9\u7279\u5f81\u8fdb\u884c\u7f29\u653e\u6216\u5f52\u4e00\u5316\u5904\u7406\u3002\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65e0\u9700\u5bf9\u7279\u5f81\u8fdb\u884c\u4efb\u4f55\u5f52\u4e00\u5316\u5904\u7406\u5373\u53ef\u6b63\u5e38\u5de5\u4f5c\u3002","title":"\u7279\u5f81\u5de5\u7a0b"},{"location":"%E7%89%B9%E5%BE%81%E9%80%89%E6%8B%A9/","text":"\u7279\u5f81\u9009\u62e9 \u5f53\u4f60\u521b\u5efa\u4e86\u6210\u5343\u4e0a\u4e07\u4e2a\u7279\u5f81\u540e\uff0c\u5c31\u8be5\u4ece\u4e2d\u6311\u9009\u51fa\u51e0\u4e2a\u4e86\u3002\u4f46\u662f\uff0c\u6211\u4eec\u7edd\u4e0d\u5e94\u8be5\u521b\u5efa\u6210\u767e\u4e0a\u5343\u4e2a\u65e0\u7528\u7684\u7279\u5f81\u3002\u7279\u5f81\u8fc7\u591a\u4f1a\u5e26\u6765\u4e00\u4e2a\u4f17\u6240\u5468\u77e5\u7684\u95ee\u9898\uff0c\u5373 \"\u7ef4\u5ea6\u8bc5\u5492\"\u3002\u5982\u679c\u4f60\u6709\u5f88\u591a\u7279\u5f81\uff0c\u4f60\u4e5f\u5fc5\u987b\u6709\u5f88\u591a\u8bad\u7ec3\u6837\u672c\u6765\u6355\u6349\u6240\u6709\u7279\u5f81\u3002\u4ec0\u4e48\u662f \"\u5927\u91cf \"\u5e76\u6ca1\u6709\u6b63\u786e\u7684\u5b9a\u4e49\uff0c\u8fd9\u9700\u8981\u60a8\u901a\u8fc7\u6b63\u786e\u9a8c\u8bc1\u60a8\u7684\u6a21\u578b\u548c\u68c0\u67e5\u8bad\u7ec3\u6a21\u578b\u6240\u9700\u7684\u65f6\u95f4\u6765\u786e\u5b9a\u3002 \u9009\u62e9\u7279\u5f81\u7684\u6700\u7b80\u5355\u65b9\u6cd5\u662f \u5220\u9664\u65b9\u5dee\u975e\u5e38\u5c0f\u7684\u7279\u5f81 \u3002\u5982\u679c\u7279\u5f81\u7684\u65b9\u5dee\u975e\u5e38\u5c0f\uff08\u5373\u975e\u5e38\u63a5\u8fd1\u4e8e 0\uff09\uff0c\u5b83\u4eec\u5c31\u63a5\u8fd1\u4e8e\u5e38\u91cf\uff0c\u56e0\u6b64\u6839\u672c\u4e0d\u4f1a\u7ed9\u4efb\u4f55\u6a21\u578b\u589e\u52a0\u4efb\u4f55\u4ef7\u503c\u3002\u6700\u597d\u7684\u529e\u6cd5\u5c31\u662f\u53bb\u6389\u5b83\u4eec\uff0c\u4ece\u800c\u964d\u4f4e\u590d\u6742\u5ea6\u3002\u8bf7\u6ce8\u610f\uff0c\u65b9\u5dee\u4e5f\u53d6\u51b3\u4e8e\u6570\u636e\u7684\u7f29\u653e\u3002 Scikit-learn \u7684 VarianceThreshold \u5b9e\u73b0\u4e86\u8fd9\u4e00\u70b9\u3002 from sklearn.feature_selection import VarianceThreshold data = ... # \u521b\u5efa VarianceThreshold \u5bf9\u8c61 var_thresh\uff0c\u6307\u5b9a\u65b9\u5dee\u9608\u503c\u4e3a 0.1 var_thresh = VarianceThreshold ( threshold = 0.1 ) # \u4f7f\u7528 var_thresh \u5bf9\u6570\u636e data \u8fdb\u884c\u62df\u5408\u548c\u53d8\u6362\uff0c\u5c06\u65b9\u5dee\u4f4e\u4e8e\u9608\u503c\u7684\u7279\u5f81\u79fb\u9664 transformed_data = var_thresh . fit_transform ( data ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u5220\u9664\u76f8\u5173\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u3002\u8981\u8ba1\u7b97\u4e0d\u540c\u6570\u5b57\u7279\u5f81\u4e4b\u95f4\u7684\u76f8\u5173\u6027\uff0c\u53ef\u4ee5\u4f7f\u7528\u76ae\u5c14\u900a\u76f8\u5173\u6027\u3002 import pandas as pd from sklearn.datasets import fetch_california_housing # \u52a0\u8f7d\u6570\u636e data = fetch_california_housing () # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u77e9\u9635 X X = data [ \"data\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u7684\u5217\u540d col_names = data [ \"feature_names\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf y y = data [ \"target\" ] df = pd . DataFrame ( X , columns = col_names ) # \u6dfb\u52a0 MedInc_Sqrt \u5217\uff0c\u662f MedInc \u5217\u4e2d\u6bcf\u4e2a\u5143\u7d20\u8fdb\u884c\u5e73\u65b9\u6839\u8fd0\u7b97\u7684\u7ed3\u679c df . loc [:, \"MedInc_Sqrt\" ] = df . MedInc . apply ( np . sqrt ) # \u8ba1\u7b97\u76ae\u5c14\u900a\u76f8\u5173\u6027\u77e9\u9635 df . corr () \u5f97\u51fa\u76f8\u5173\u77e9\u9635\uff0c\u5982\u56fe 1 \u6240\u793a\u3002 \u56fe 1\uff1a\u76ae\u5c14\u900a\u76f8\u5173\u77e9\u9635\u6837\u672c \u6211\u4eec\u770b\u5230\uff0cMedInc_Sqrt \u4e0e MedInc \u7684\u76f8\u5173\u6027\u975e\u5e38\u9ad8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5220\u9664\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u4e00\u4e9b \u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5 \u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u53ea\u4e0d\u8fc7\u662f\u9488\u5bf9\u7ed9\u5b9a\u76ee\u6807\u5bf9\u6bcf\u4e2a\u7279\u5f81\u8fdb\u884c\u8bc4\u5206\u3002 \u4e92\u4fe1\u606f \u3001 \u65b9\u5dee\u5206\u6790 F \u68c0\u9a8c\u548c chi2 \u662f\u4e00\u4e9b\u6700\u5e38\u7528\u7684\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u3002\u5728 scikit- learn \u4e2d\uff0c\u6709\u4e24\u79cd\u65b9\u6cd5\u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e9b\u65b9\u6cd5\u3002 - SelectKBest\uff1a\u4fdd\u7559\u5f97\u5206\u6700\u9ad8\u7684 k \u4e2a\u7279\u5f81 - SelectPercentile\uff1a\u4fdd\u7559\u7528\u6237\u6307\u5b9a\u767e\u5206\u6bd4\u5185\u7684\u9876\u7ea7\u7279\u5f81\u3002 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u53ea\u6709\u975e\u8d1f\u6570\u636e\u624d\u80fd\u4f7f\u7528 chi2\u3002\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u5f53\u6211\u4eec\u6709\u4e00\u4e9b\u5355\u8bcd\u6216\u57fa\u4e8e tf-idf \u7684\u7279\u5f81\u65f6\uff0c\u8fd9\u662f\u4e00\u79cd\u7279\u522b\u6709\u7528\u7684\u7279\u5f81\u9009\u62e9\u6280\u672f\u3002\u6700\u597d\u4e3a\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u521b\u5efa\u4e00\u4e2a\u5305\u88c5\u5668\uff0c\u51e0\u4e4e\u53ef\u4ee5\u7528\u4e8e\u4efb\u4f55\u65b0\u95ee\u9898\u3002 from sklearn.feature_selection import chi2 from sklearn.feature_selection import f_classif from sklearn.feature_selection import f_regression from sklearn.feature_selection import mutual_info_classif from sklearn.feature_selection import mutual_info_regression from sklearn.feature_selection import SelectKBest from sklearn.feature_selection import SelectPercentile class UnivariateFeatureSelction : def __init__ ( self , n_features , problem_type , scoring ): # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u5206\u7c7b\u95ee\u9898 if problem_type == \"classification\" : # \u521b\u5efa\u5b57\u5178 valid_scoring \uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_classif\" : f_classif , \"chi2\" : chi2 , \"mutual_info_classif\" : mutual_info_classif } # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u56de\u5f52\u95ee\u9898 else : # \u521b\u5efa\u5b57\u5178 valid_scoring\uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_regression\" : f_regression , \"mutual_info_regression\" : mutual_info_regression } # \u68c0\u67e5\u7279\u5f81\u91cd\u8981\u6027\u65b9\u5f0f\u662f\u5426\u5728\u5b57\u5178\u4e2d if scoring not in valid_scoring : raise Exception ( \"Invalid scoring function\" ) # \u68c0\u67e5 n_features \u7684\u7c7b\u578b\uff0c\u5982\u679c\u662f\u6574\u6570\uff0c\u5219\u4f7f\u7528 SelectKBest \u8fdb\u884c\u7279\u5f81\u9009\u62e9 if isinstance ( n_features , int ): self . selection = SelectKBest ( valid_scoring [ scoring ], k = n_features ) # \u5982\u679c n_features \u662f\u6d6e\u70b9\u6570\uff0c\u5219\u4f7f\u7528 SelectPercentile \u8fdb\u884c\u7279\u5f81\u9009\u62e9 elif isinstance ( n_features , float ): self . selection = SelectPercentile ( valid_scoring [ scoring ], percentile = int ( n_features * 100 ) ) # \u5982\u679c n_features \u7c7b\u578b\u65e0\u6548\uff0c\u5f15\u53d1\u5f02\u5e38 else : raise Exception ( \"Invalid type of feature\" ) # \u5b9a\u4e49 fit \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 def fit ( self , X , y ): return self . selection . fit ( X , y ) # \u5b9a\u4e49 transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u5bf9\u6570\u636e\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def transform ( self , X ): return self . selection . transform ( X ) # \u5b9a\u4e49 fit_transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668\u5e76\u540c\u65f6\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def fit_transform ( self , X , y ): return self . selection . fit_transform ( X , y ) \u4f7f\u7528\u8be5\u7c7b\u975e\u5e38\u7b80\u5355\u3002 # \u5b9e\u4f8b\u5316\u7279\u5f81\u9009\u62e9\u5668\uff0c\u4fdd\u7559\u524d10%\u7684\u7279\u5f81\uff0c\u56de\u5f52\u95ee\u9898\uff0c\u4f7f\u7528f_regression\u8861\u91cf\u7279\u5f81\u91cd\u8981\u6027 ufs = UnivariateFeatureSelction ( n_features = 0.1 , problem_type = \"regression\" , scoring = \"f_regression\" ) # \u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 ufs . fit ( X , y ) # \u7279\u5f81\u8f6c\u6362 X_transformed = ufs . transform ( X ) \u8fd9\u6837\u5c31\u80fd\u6ee1\u8db3\u5927\u90e8\u5206\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u7684\u9700\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u521b\u5efa\u8f83\u5c11\u800c\u91cd\u8981\u7684\u7279\u5f81\u901a\u5e38\u6bd4\u521b\u5efa\u6570\u4ee5\u767e\u8ba1\u7684\u7279\u5f81\u8981\u597d\u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u4e0d\u4e00\u5b9a\u603b\u662f\u8868\u73b0\u826f\u597d\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u4eba\u4eec\u66f4\u559c\u6b22\u4f7f\u7528\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 \u4f7f\u7528\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u7684\u6700\u7b80\u5355\u5f62\u5f0f\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u3002\u5728\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u4e2d\uff0c\u7b2c\u4e00\u6b65\u662f\u9009\u62e9\u4e00\u4e2a\u6a21\u578b\u3002\u7b2c\u4e8c\u6b65\u662f\u9009\u62e9\u635f\u5931/\u8bc4\u5206\u51fd\u6570\u3002\u7b2c\u4e09\u6b65\u4e5f\u662f\u6700\u540e\u4e00\u6b65\u662f\u53cd\u590d\u8bc4\u4f30\u6bcf\u4e2a\u7279\u5f81\uff0c\u5982\u679c\u80fd\u63d0\u9ad8\u635f\u5931/\u8bc4\u5206\uff0c\u5c31\u5c06\u5176\u6dfb\u52a0\u5230 \"\u597d \"\u7279\u5f81\u5217\u8868\u4e2d\u3002\u6ca1\u6709\u6bd4\u8fd9\u66f4\u7b80\u5355\u7684\u4e86\u3002\u4f46\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u8fd9\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u662f\u6709\u539f\u56e0\u7684\u3002\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u8fc7\u7a0b\u5728\u6bcf\u6b21\u8bc4\u4f30\u7279\u5f81\u65f6\u90fd\u4f1a\u9002\u5408\u7ed9\u5b9a\u7684\u6a21\u578b\u3002\u8fd9\u79cd\u65b9\u6cd5\u7684\u8ba1\u7b97\u6210\u672c\u975e\u5e38\u9ad8\u3002\u5b8c\u6210\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u4e5f\u9700\u8981\u5927\u91cf\u65f6\u95f4\u3002\u5982\u679c\u4e0d\u6b63\u786e\u4f7f\u7528\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\uff0c\u751a\u81f3\u4f1a\u5bfc\u81f4\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn.datasets import make_classification class GreedyFeatureSelection : # \u5b9a\u4e49\u8bc4\u4f30\u5206\u6570\u7684\u65b9\u6cd5\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd def evaluate_score ( self , X , y ): # \u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u8bad\u7ec3\u6a21\u578b model . fit ( X , y ) # \u9884\u6d4b\u6982\u7387\u503c predictions = model . predict_proba ( X )[:, 1 ] # \u8ba1\u7b97 AUC \u5206\u6570 auc = metrics . roc_auc_score ( y , predictions ) return auc # \u7279\u5f81\u9009\u62e9\u51fd\u6570 def _feature_selection ( self , X , y ): # \u521d\u59cb\u5316\u7a7a\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6700\u4f73\u7279\u5f81\u548c\u6700\u4f73\u5206\u6570 good_features = [] best_scores = [] # \u83b7\u53d6\u7279\u5f81\u6570\u91cf num_features = X . shape [ 1 ] # \u5f00\u59cb\u7279\u5f81\u9009\u62e9\u7684\u5faa\u73af while True : this_feature = None best_score = 0 # \u904d\u5386\u6bcf\u4e2a\u7279\u5f81 for feature in range ( num_features ): if feature in good_features : continue selected_features = good_features + [ feature ] xtrain = X [:, selected_features ] score = self . evaluate_score ( xtrain , y ) # \u5982\u679c\u5f53\u524d\u7279\u5f81\u7684\u5f97\u5206\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u5f97\u5206\uff0c\u5219\u66f4\u65b0 if score > best_score : this_feature = feature best_score = score # \u82e5\u627e\u5230\u4e86\u65b0\u7684\u6700\u4f73\u7279\u5f81 if this_feature != None : # \u7279\u5f81\u6dfb\u52a0\u5230 good_features \u5217\u8868 good_features . append ( this_feature ) # \u5f97\u5206\u6dfb\u52a0\u5230 best_scores \u5217\u8868 best_scores . append ( best_score ) # \u5982\u679c best_scores \u5217\u8868\u957f\u5ea6\u5927\u4e8e2\uff0c\u5e76\u4e14\u6700\u540e\u4e24\u4e2a\u5f97\u5206\u76f8\u6bd4\u8f83\u5dee\uff0c\u5219\u7ed3\u675f\u5faa\u73af if len ( best_scores ) > 2 : if best_scores [ - 1 ] < best_scores [ - 2 ]: break # \u8fd4\u56de\u6700\u4f73\u7279\u5f81\u7684\u5f97\u5206\u5217\u8868\u548c\u6700\u4f73\u7279\u5f81\u5217\u8868 return best_scores [: - 1 ], good_features [: - 1 ] # \u5b9a\u4e49\u7c7b\u7684\u8c03\u7528\u65b9\u6cd5\uff0c\u7528\u4e8e\u6267\u884c\u7279\u5f81\u9009\u62e9 def __call__ ( self , X , y ): scores , features = self . _feature_selection ( X , y ) return X [:, features ], scores if __name__ == \"__main__\" : # \u751f\u6210\u4e00\u4e2a\u793a\u4f8b\u7684\u5206\u7c7b\u6570\u636e\u96c6 X \u548c\u6807\u7b7e y X , y = make_classification ( n_samples = 1000 , n_features = 100 ) # \u5b9e\u4f8b\u5316 GreedyFeatureSelection \u7c7b\uff0c\u5e76\u4f7f\u7528 __call__ \u65b9\u6cd5\u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed , scores = GreedyFeatureSelection ()( X , y ) \u8fd9\u79cd\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u4f1a\u8fd4\u56de\u5206\u6570\u548c\u7279\u5f81\u7d22\u5f15\u5217\u8868\u3002\u56fe 2 \u663e\u793a\u4e86\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u589e\u52a0\u4e00\u4e2a\u65b0\u7279\u5f81\u540e\uff0c\u5206\u6570\u662f\u5982\u4f55\u63d0\u9ad8\u7684\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u67d0\u4e00\u70b9\u4e4b\u540e\uff0c\u6211\u4eec\u5c31\u65e0\u6cd5\u63d0\u9ad8\u5206\u6570\u4e86\uff0c\u8fd9\u5c31\u662f\u6211\u4eec\u505c\u6b62\u7684\u5730\u65b9\u3002 \u53e6\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\u88ab\u79f0\u4e3a\u9012\u5f52\u7279\u5f81\u6d88\u9664\u6cd5\uff08RFE\uff09\u3002\u5728\u524d\u4e00\u79cd\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a\u7279\u5f81\u5f00\u59cb\uff0c\u7136\u540e\u4e0d\u65ad\u6dfb\u52a0\u65b0\u7684\u7279\u5f81\uff0c\u4f46\u5728 RFE \u4e2d\uff0c\u6211\u4eec\u4ece\u6240\u6709\u7279\u5f81\u5f00\u59cb\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u4e0d\u65ad\u53bb\u9664\u4e00\u4e2a\u5bf9\u7ed9\u5b9a\u6a21\u578b\u63d0\u4f9b\u6700\u5c0f\u503c\u7684\u7279\u5f81\u3002\u4f46\u6211\u4eec\u5982\u4f55\u77e5\u9053\u54ea\u4e2a\u7279\u5f81\u7684\u4ef7\u503c\u6700\u5c0f\u5462\uff1f\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7ebf\u6027\u652f\u6301\u5411\u91cf\u673a\uff08SVM\uff09\u6216\u903b\u8f91\u56de\u5f52\u7b49\u6a21\u578b\uff0c\u6211\u4eec\u4f1a\u4e3a\u6bcf\u4e2a\u7279\u5f81\u5f97\u5230\u4e00\u4e2a\u7cfb\u6570\uff0c\u8be5\u7cfb\u6570\u51b3\u5b9a\u4e86\u7279\u5f81\u7684\u91cd\u8981\u6027\u3002\u800c\u5bf9\u4e8e\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f\u7279\u5f81\u91cd\u8981\u6027\uff0c\u800c\u4e0d\u662f\u7cfb\u6570\u3002\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u53ef\u4ee5\u5254\u9664\u6700\u4e0d\u91cd\u8981\u7684\u7279\u5f81\uff0c\u76f4\u5230\u8fbe\u5230\u6240\u9700\u7684\u7279\u5f81\u6570\u91cf\u4e3a\u6b62\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u51b3\u5b9a\u8981\u4fdd\u7559\u591a\u5c11\u7279\u5f81\u3002 \u56fe 2\uff1a\u589e\u52a0\u65b0\u7279\u5f81\u540e\uff0c\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7684 AUC \u5206\u6570\u5982\u4f55\u53d8\u5316 \u5f53\u6211\u4eec\u8fdb\u884c\u9012\u5f52\u7279\u5f81\u5254\u9664\u65f6\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u4f1a\u5254\u9664\u7279\u5f81\u91cd\u8981\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u6216\u7cfb\u6570\u63a5\u8fd1 0 \u7684\u7279\u5f81\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5f53\u4f60\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u8fd9\u6837\u7684\u6a21\u578b\u8fdb\u884c\u4e8c\u5143\u5206\u7c7b\u65f6\uff0c\u5982\u679c\u7279\u5f81\u5bf9\u6b63\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u6b63\uff0c\u800c\u5982\u679c\u7279\u5f81\u5bf9\u8d1f\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u8d1f\u3002\u4fee\u6539\u6211\u4eec\u7684\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7c7b\uff0c\u521b\u5efa\u4e00\u4e2a\u65b0\u7684\u9012\u5f52\u7279\u5f81\u6d88\u9664\u7c7b\u975e\u5e38\u5bb9\u6613\uff0c\u4f46 scikit-learn \u4e5f\u63d0\u4f9b\u4e86 RFE\u3002\u4e0b\u9762\u7684\u793a\u4f8b\u5c55\u793a\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u7528\u6cd5\u3002 import pandas as pd from sklearn.feature_selection import RFE from sklearn.linear_model import LinearRegression from sklearn.datasets import fetch_california_housing data = fetch_california_housing () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] model = LinearRegression () # \u521b\u5efa RFE\uff08\u9012\u5f52\u7279\u5f81\u6d88\u9664\uff09\uff0c\u6307\u5b9a\u6a21\u578b\u4e3a\u7ebf\u6027\u56de\u5f52\u6a21\u578b\uff0c\u8981\u9009\u62e9\u7684\u7279\u5f81\u6570\u91cf\u4e3a 3 rfe = RFE ( estimator = model , n_features_to_select = 3 ) # \u8bad\u7ec3\u6a21\u578b rfe . fit ( X , y ) # \u4f7f\u7528 RFE \u9009\u62e9\u7684\u7279\u5f81\u8fdb\u884c\u6570\u636e\u8f6c\u6362 X_transformed = rfe . transform ( X ) \u6211\u4eec\u770b\u5230\u4e86\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u7684\u4e24\u79cd\u4e0d\u540c\u7684\u8d2a\u5a6a\u65b9\u6cd5\u3002\u4f46\u4e5f\u53ef\u4ee5\u6839\u636e\u6570\u636e\u62df\u5408\u6a21\u578b\uff0c\u7136\u540e\u901a\u8fc7\u7279\u5f81\u7cfb\u6570\u6216\u7279\u5f81\u7684\u91cd\u8981\u6027\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u5982\u679c\u4f7f\u7528\u7cfb\u6570\uff0c\u5219\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u9608\u503c\uff0c\u5982\u679c\u7cfb\u6570\u9ad8\u4e8e\u8be5\u9608\u503c\uff0c\u5219\u53ef\u4ee5\u4fdd\u7559\u8be5\u7279\u5f81\uff0c\u5426\u5219\u5c06\u5176\u5254\u9664\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4ece\u968f\u673a\u68ee\u6797\u8fd9\u6837\u7684\u6a21\u578b\u4e2d\u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u5b9e\u4f8b\u5316\u968f\u673a\u68ee\u6797\u6a21\u578b model = RandomForestRegressor () # \u62df\u5408\u6a21\u578b model . fit ( X , y ) \u968f\u673a\u68ee\u6797\uff08\u6216\u4efb\u4f55\u6a21\u578b\uff09\u7684\u7279\u5f81\u91cd\u8981\u6027\u53ef\u6309\u5982\u4e0b\u65b9\u5f0f\u7ed8\u5236\u3002 # \u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027 importances = model . feature_importances_ # \u964d\u5e8f\u6392\u5217 idxs = np . argsort ( importances ) # \u8bbe\u5b9a\u6807\u9898 plt . title ( 'Feature Importances' ) # \u521b\u5efa\u76f4\u65b9\u56fe plt . barh ( range ( len ( idxs )), importances [ idxs ], align = 'center' ) # y\u8f74\u6807\u7b7e plt . yticks ( range ( len ( idxs )), [ col_names [ i ] for i in idxs ]) # x\u8f74\u6807\u7b7e plt . xlabel ( 'Random Forest Feature Importance' ) plt . show () \u7ed3\u679c\u5982\u56fe 3 \u6240\u793a\u3002 \u56fe 3\uff1a\u7279\u5f81\u91cd\u8981\u6027\u56fe \u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u6700\u4f73\u7279\u5f81\u5e76\u4e0d\u662f\u4ec0\u4e48\u65b0\u9c9c\u4e8b\u3002\u60a8\u53ef\u4ee5\u4ece\u4e00\u4e2a\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u53e6\u4e00\u4e2a\u6a21\u578b\u8fdb\u884c\u8bad\u7ec3\u3002\u4f8b\u5982\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u7cfb\u6570\u6765\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u968f\u673a\u68ee\u6797\uff08Random Forest\uff09\u5bf9\u6240\u9009\u7279\u5f81\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\u3002Scikit-learn \u8fd8\u63d0\u4f9b\u4e86 SelectFromModel \u7c7b\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4f60\u76f4\u63a5\u4ece\u7ed9\u5b9a\u7684\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u6839\u636e\u9700\u8981\u6307\u5b9a\u7cfb\u6570\u6216\u7279\u5f81\u91cd\u8981\u6027\u7684\u9608\u503c\uff0c\u4ee5\u53ca\u8981\u9009\u62e9\u7684\u7279\u5f81\u7684\u6700\u5927\u6570\u91cf\u3002 \u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\uff0c\u6211\u4eec\u4f7f\u7528 SelectFromModel \u4e2d\u7684\u9ed8\u8ba4\u53c2\u6570\u6765\u9009\u62e9\u7279\u5f81\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor from sklearn.feature_selection import SelectFromModel data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u521b\u5efa\u968f\u673a\u68ee\u6797\u6a21\u578b\u56de\u5f52\u6a21\u578b model = RandomForestRegressor () # \u521b\u5efa SelectFromModel \u5bf9\u8c61 sfm\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u6a21\u578b\u4f5c\u4e3a\u4f30\u7b97\u5668 sfm = SelectFromModel ( estimator = model ) # \u4f7f\u7528 sfm \u5bf9\u7279\u5f81\u77e9\u9635 X \u548c\u76ee\u6807\u53d8\u91cf y \u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed = sfm . fit_transform ( X , y ) # \u83b7\u53d6\u7ecf\u8fc7\u7279\u5f81\u9009\u62e9\u540e\u7684\u7279\u5f81\u63a9\u7801\uff08True \u8868\u793a\u7279\u5f81\u88ab\u9009\u62e9\uff0cFalse \u8868\u793a\u7279\u5f81\u672a\u88ab\u9009\u62e9\uff09 support = sfm . get_support () # \u6253\u5370\u88ab\u9009\u62e9\u7684\u7279\u5f81\u5217\u540d print ([ x for x , y in zip ( col_names , support ) if y == True ]) \u4e0a\u9762\u7a0b\u5e8f\u6253\u5370\u7ed3\u679c\uff1a ['bmi'\uff0c's5']\u3002\u6211\u4eec\u518d\u770b\u56fe 3\uff0c\u5c31\u4f1a\u53d1\u73b0\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u4e2a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u76f4\u63a5\u4ece\u968f\u673a\u68ee\u6797\u63d0\u4f9b\u7684\u7279\u5f81\u91cd\u8981\u6027\u4e2d\u8fdb\u884c\u9009\u62e9\u3002\u6211\u4eec\u8fd8\u7f3a\u5c11\u4e00\u4ef6\u4e8b\uff0c\u90a3\u5c31\u662f\u4f7f\u7528 L1\uff08Lasso\uff09\u60e9\u7f5a\u6a21\u578b \u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u5f53\u6211\u4eec\u4f7f\u7528 L1 \u60e9\u7f5a\u8fdb\u884c\u6b63\u5219\u5316\u65f6\uff0c\u5927\u90e8\u5206\u7cfb\u6570\u90fd\u5c06\u4e3a 0\uff08\u6216\u63a5\u8fd1 0\uff09\uff0c\u56e0\u6b64\u6211\u4eec\u8981\u9009\u62e9\u7cfb\u6570\u4e0d\u4e3a 0 \u7684\u7279\u5f81\u3002\u53ea\u9700\u5c06\u6a21\u578b\u9009\u62e9\u7247\u6bb5\u4e2d\u7684\u968f\u673a\u68ee\u6797\u66ff\u6362\u4e3a\u652f\u6301 L1 \u60e9\u7f5a\u7684\u6a21\u578b\uff08\u5982 lasso \u56de\u5f52\uff09\u5373\u53ef\u3002\u6240\u6709\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u90fd\u63d0\u4f9b\u7279\u5f81\u91cd\u8981\u6027\uff0c\u56e0\u6b64\u672c\u7ae0\u4e2d\u5c55\u793a\u7684\u6240\u6709\u57fa\u4e8e\u6a21\u578b\u7684\u7247\u6bb5\u90fd\u53ef\u7528\u4e8e XGBoost\u3001LightGBM \u6216 CatBoost\u3002\u7279\u5f81\u91cd\u8981\u6027\u51fd\u6570\u7684\u540d\u79f0\u53ef\u80fd\u4e0d\u540c\uff0c\u4ea7\u751f\u7ed3\u679c\u7684\u683c\u5f0f\u4e5f\u53ef\u80fd\u4e0d\u540c\uff0c\u4f46\u7528\u6cd5\u662f\u4e00\u6837\u7684\u3002\u6700\u540e\uff0c\u5728\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u65f6\u5fc5\u987b\u5c0f\u5fc3\u8c28\u614e\u3002\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u9009\u62e9\u7279\u5f81\uff0c\u5e76\u5728\u9a8c\u8bc1\u6570\u636e\u4e0a\u9a8c\u8bc1\u6a21\u578b\uff0c\u4ee5\u4fbf\u5728\u4e0d\u8fc7\u5ea6\u62df\u5408\u6a21\u578b\u7684\u60c5\u51b5\u4e0b\u6b63\u786e\u9009\u62e9\u7279\u5f81\u3002","title":"\u7279\u5f81\u9009\u62e9"},{"location":"%E7%89%B9%E5%BE%81%E9%80%89%E6%8B%A9/#_1","text":"\u5f53\u4f60\u521b\u5efa\u4e86\u6210\u5343\u4e0a\u4e07\u4e2a\u7279\u5f81\u540e\uff0c\u5c31\u8be5\u4ece\u4e2d\u6311\u9009\u51fa\u51e0\u4e2a\u4e86\u3002\u4f46\u662f\uff0c\u6211\u4eec\u7edd\u4e0d\u5e94\u8be5\u521b\u5efa\u6210\u767e\u4e0a\u5343\u4e2a\u65e0\u7528\u7684\u7279\u5f81\u3002\u7279\u5f81\u8fc7\u591a\u4f1a\u5e26\u6765\u4e00\u4e2a\u4f17\u6240\u5468\u77e5\u7684\u95ee\u9898\uff0c\u5373 \"\u7ef4\u5ea6\u8bc5\u5492\"\u3002\u5982\u679c\u4f60\u6709\u5f88\u591a\u7279\u5f81\uff0c\u4f60\u4e5f\u5fc5\u987b\u6709\u5f88\u591a\u8bad\u7ec3\u6837\u672c\u6765\u6355\u6349\u6240\u6709\u7279\u5f81\u3002\u4ec0\u4e48\u662f \"\u5927\u91cf \"\u5e76\u6ca1\u6709\u6b63\u786e\u7684\u5b9a\u4e49\uff0c\u8fd9\u9700\u8981\u60a8\u901a\u8fc7\u6b63\u786e\u9a8c\u8bc1\u60a8\u7684\u6a21\u578b\u548c\u68c0\u67e5\u8bad\u7ec3\u6a21\u578b\u6240\u9700\u7684\u65f6\u95f4\u6765\u786e\u5b9a\u3002 \u9009\u62e9\u7279\u5f81\u7684\u6700\u7b80\u5355\u65b9\u6cd5\u662f \u5220\u9664\u65b9\u5dee\u975e\u5e38\u5c0f\u7684\u7279\u5f81 \u3002\u5982\u679c\u7279\u5f81\u7684\u65b9\u5dee\u975e\u5e38\u5c0f\uff08\u5373\u975e\u5e38\u63a5\u8fd1\u4e8e 0\uff09\uff0c\u5b83\u4eec\u5c31\u63a5\u8fd1\u4e8e\u5e38\u91cf\uff0c\u56e0\u6b64\u6839\u672c\u4e0d\u4f1a\u7ed9\u4efb\u4f55\u6a21\u578b\u589e\u52a0\u4efb\u4f55\u4ef7\u503c\u3002\u6700\u597d\u7684\u529e\u6cd5\u5c31\u662f\u53bb\u6389\u5b83\u4eec\uff0c\u4ece\u800c\u964d\u4f4e\u590d\u6742\u5ea6\u3002\u8bf7\u6ce8\u610f\uff0c\u65b9\u5dee\u4e5f\u53d6\u51b3\u4e8e\u6570\u636e\u7684\u7f29\u653e\u3002 Scikit-learn \u7684 VarianceThreshold \u5b9e\u73b0\u4e86\u8fd9\u4e00\u70b9\u3002 from sklearn.feature_selection import VarianceThreshold data = ... # \u521b\u5efa VarianceThreshold \u5bf9\u8c61 var_thresh\uff0c\u6307\u5b9a\u65b9\u5dee\u9608\u503c\u4e3a 0.1 var_thresh = VarianceThreshold ( threshold = 0.1 ) # \u4f7f\u7528 var_thresh \u5bf9\u6570\u636e data \u8fdb\u884c\u62df\u5408\u548c\u53d8\u6362\uff0c\u5c06\u65b9\u5dee\u4f4e\u4e8e\u9608\u503c\u7684\u7279\u5f81\u79fb\u9664 transformed_data = var_thresh . fit_transform ( data ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u5220\u9664\u76f8\u5173\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u3002\u8981\u8ba1\u7b97\u4e0d\u540c\u6570\u5b57\u7279\u5f81\u4e4b\u95f4\u7684\u76f8\u5173\u6027\uff0c\u53ef\u4ee5\u4f7f\u7528\u76ae\u5c14\u900a\u76f8\u5173\u6027\u3002 import pandas as pd from sklearn.datasets import fetch_california_housing # \u52a0\u8f7d\u6570\u636e data = fetch_california_housing () # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u77e9\u9635 X X = data [ \"data\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u7684\u5217\u540d col_names = data [ \"feature_names\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf y y = data [ \"target\" ] df = pd . DataFrame ( X , columns = col_names ) # \u6dfb\u52a0 MedInc_Sqrt \u5217\uff0c\u662f MedInc \u5217\u4e2d\u6bcf\u4e2a\u5143\u7d20\u8fdb\u884c\u5e73\u65b9\u6839\u8fd0\u7b97\u7684\u7ed3\u679c df . loc [:, \"MedInc_Sqrt\" ] = df . MedInc . apply ( np . sqrt ) # \u8ba1\u7b97\u76ae\u5c14\u900a\u76f8\u5173\u6027\u77e9\u9635 df . corr () \u5f97\u51fa\u76f8\u5173\u77e9\u9635\uff0c\u5982\u56fe 1 \u6240\u793a\u3002 \u56fe 1\uff1a\u76ae\u5c14\u900a\u76f8\u5173\u77e9\u9635\u6837\u672c \u6211\u4eec\u770b\u5230\uff0cMedInc_Sqrt \u4e0e MedInc \u7684\u76f8\u5173\u6027\u975e\u5e38\u9ad8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5220\u9664\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u4e00\u4e9b \u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5 \u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u53ea\u4e0d\u8fc7\u662f\u9488\u5bf9\u7ed9\u5b9a\u76ee\u6807\u5bf9\u6bcf\u4e2a\u7279\u5f81\u8fdb\u884c\u8bc4\u5206\u3002 \u4e92\u4fe1\u606f \u3001 \u65b9\u5dee\u5206\u6790 F \u68c0\u9a8c\u548c chi2 \u662f\u4e00\u4e9b\u6700\u5e38\u7528\u7684\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u3002\u5728 scikit- learn \u4e2d\uff0c\u6709\u4e24\u79cd\u65b9\u6cd5\u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e9b\u65b9\u6cd5\u3002 - SelectKBest\uff1a\u4fdd\u7559\u5f97\u5206\u6700\u9ad8\u7684 k \u4e2a\u7279\u5f81 - SelectPercentile\uff1a\u4fdd\u7559\u7528\u6237\u6307\u5b9a\u767e\u5206\u6bd4\u5185\u7684\u9876\u7ea7\u7279\u5f81\u3002 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u53ea\u6709\u975e\u8d1f\u6570\u636e\u624d\u80fd\u4f7f\u7528 chi2\u3002\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u5f53\u6211\u4eec\u6709\u4e00\u4e9b\u5355\u8bcd\u6216\u57fa\u4e8e tf-idf \u7684\u7279\u5f81\u65f6\uff0c\u8fd9\u662f\u4e00\u79cd\u7279\u522b\u6709\u7528\u7684\u7279\u5f81\u9009\u62e9\u6280\u672f\u3002\u6700\u597d\u4e3a\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u521b\u5efa\u4e00\u4e2a\u5305\u88c5\u5668\uff0c\u51e0\u4e4e\u53ef\u4ee5\u7528\u4e8e\u4efb\u4f55\u65b0\u95ee\u9898\u3002 from sklearn.feature_selection import chi2 from sklearn.feature_selection import f_classif from sklearn.feature_selection import f_regression from sklearn.feature_selection import mutual_info_classif from sklearn.feature_selection import mutual_info_regression from sklearn.feature_selection import SelectKBest from sklearn.feature_selection import SelectPercentile class UnivariateFeatureSelction : def __init__ ( self , n_features , problem_type , scoring ): # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u5206\u7c7b\u95ee\u9898 if problem_type == \"classification\" : # \u521b\u5efa\u5b57\u5178 valid_scoring \uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_classif\" : f_classif , \"chi2\" : chi2 , \"mutual_info_classif\" : mutual_info_classif } # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u56de\u5f52\u95ee\u9898 else : # \u521b\u5efa\u5b57\u5178 valid_scoring\uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_regression\" : f_regression , \"mutual_info_regression\" : mutual_info_regression } # \u68c0\u67e5\u7279\u5f81\u91cd\u8981\u6027\u65b9\u5f0f\u662f\u5426\u5728\u5b57\u5178\u4e2d if scoring not in valid_scoring : raise Exception ( \"Invalid scoring function\" ) # \u68c0\u67e5 n_features \u7684\u7c7b\u578b\uff0c\u5982\u679c\u662f\u6574\u6570\uff0c\u5219\u4f7f\u7528 SelectKBest \u8fdb\u884c\u7279\u5f81\u9009\u62e9 if isinstance ( n_features , int ): self . selection = SelectKBest ( valid_scoring [ scoring ], k = n_features ) # \u5982\u679c n_features \u662f\u6d6e\u70b9\u6570\uff0c\u5219\u4f7f\u7528 SelectPercentile \u8fdb\u884c\u7279\u5f81\u9009\u62e9 elif isinstance ( n_features , float ): self . selection = SelectPercentile ( valid_scoring [ scoring ], percentile = int ( n_features * 100 ) ) # \u5982\u679c n_features \u7c7b\u578b\u65e0\u6548\uff0c\u5f15\u53d1\u5f02\u5e38 else : raise Exception ( \"Invalid type of feature\" ) # \u5b9a\u4e49 fit \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 def fit ( self , X , y ): return self . selection . fit ( X , y ) # \u5b9a\u4e49 transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u5bf9\u6570\u636e\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def transform ( self , X ): return self . selection . transform ( X ) # \u5b9a\u4e49 fit_transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668\u5e76\u540c\u65f6\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def fit_transform ( self , X , y ): return self . selection . fit_transform ( X , y ) \u4f7f\u7528\u8be5\u7c7b\u975e\u5e38\u7b80\u5355\u3002 # \u5b9e\u4f8b\u5316\u7279\u5f81\u9009\u62e9\u5668\uff0c\u4fdd\u7559\u524d10%\u7684\u7279\u5f81\uff0c\u56de\u5f52\u95ee\u9898\uff0c\u4f7f\u7528f_regression\u8861\u91cf\u7279\u5f81\u91cd\u8981\u6027 ufs = UnivariateFeatureSelction ( n_features = 0.1 , problem_type = \"regression\" , scoring = \"f_regression\" ) # \u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 ufs . fit ( X , y ) # \u7279\u5f81\u8f6c\u6362 X_transformed = ufs . transform ( X ) \u8fd9\u6837\u5c31\u80fd\u6ee1\u8db3\u5927\u90e8\u5206\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u7684\u9700\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u521b\u5efa\u8f83\u5c11\u800c\u91cd\u8981\u7684\u7279\u5f81\u901a\u5e38\u6bd4\u521b\u5efa\u6570\u4ee5\u767e\u8ba1\u7684\u7279\u5f81\u8981\u597d\u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u4e0d\u4e00\u5b9a\u603b\u662f\u8868\u73b0\u826f\u597d\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u4eba\u4eec\u66f4\u559c\u6b22\u4f7f\u7528\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 \u4f7f\u7528\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u7684\u6700\u7b80\u5355\u5f62\u5f0f\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u3002\u5728\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u4e2d\uff0c\u7b2c\u4e00\u6b65\u662f\u9009\u62e9\u4e00\u4e2a\u6a21\u578b\u3002\u7b2c\u4e8c\u6b65\u662f\u9009\u62e9\u635f\u5931/\u8bc4\u5206\u51fd\u6570\u3002\u7b2c\u4e09\u6b65\u4e5f\u662f\u6700\u540e\u4e00\u6b65\u662f\u53cd\u590d\u8bc4\u4f30\u6bcf\u4e2a\u7279\u5f81\uff0c\u5982\u679c\u80fd\u63d0\u9ad8\u635f\u5931/\u8bc4\u5206\uff0c\u5c31\u5c06\u5176\u6dfb\u52a0\u5230 \"\u597d \"\u7279\u5f81\u5217\u8868\u4e2d\u3002\u6ca1\u6709\u6bd4\u8fd9\u66f4\u7b80\u5355\u7684\u4e86\u3002\u4f46\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u8fd9\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u662f\u6709\u539f\u56e0\u7684\u3002\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u8fc7\u7a0b\u5728\u6bcf\u6b21\u8bc4\u4f30\u7279\u5f81\u65f6\u90fd\u4f1a\u9002\u5408\u7ed9\u5b9a\u7684\u6a21\u578b\u3002\u8fd9\u79cd\u65b9\u6cd5\u7684\u8ba1\u7b97\u6210\u672c\u975e\u5e38\u9ad8\u3002\u5b8c\u6210\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u4e5f\u9700\u8981\u5927\u91cf\u65f6\u95f4\u3002\u5982\u679c\u4e0d\u6b63\u786e\u4f7f\u7528\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\uff0c\u751a\u81f3\u4f1a\u5bfc\u81f4\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn.datasets import make_classification class GreedyFeatureSelection : # \u5b9a\u4e49\u8bc4\u4f30\u5206\u6570\u7684\u65b9\u6cd5\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd def evaluate_score ( self , X , y ): # \u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u8bad\u7ec3\u6a21\u578b model . fit ( X , y ) # \u9884\u6d4b\u6982\u7387\u503c predictions = model . predict_proba ( X )[:, 1 ] # \u8ba1\u7b97 AUC \u5206\u6570 auc = metrics . roc_auc_score ( y , predictions ) return auc # \u7279\u5f81\u9009\u62e9\u51fd\u6570 def _feature_selection ( self , X , y ): # \u521d\u59cb\u5316\u7a7a\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6700\u4f73\u7279\u5f81\u548c\u6700\u4f73\u5206\u6570 good_features = [] best_scores = [] # \u83b7\u53d6\u7279\u5f81\u6570\u91cf num_features = X . shape [ 1 ] # \u5f00\u59cb\u7279\u5f81\u9009\u62e9\u7684\u5faa\u73af while True : this_feature = None best_score = 0 # \u904d\u5386\u6bcf\u4e2a\u7279\u5f81 for feature in range ( num_features ): if feature in good_features : continue selected_features = good_features + [ feature ] xtrain = X [:, selected_features ] score = self . evaluate_score ( xtrain , y ) # \u5982\u679c\u5f53\u524d\u7279\u5f81\u7684\u5f97\u5206\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u5f97\u5206\uff0c\u5219\u66f4\u65b0 if score > best_score : this_feature = feature best_score = score # \u82e5\u627e\u5230\u4e86\u65b0\u7684\u6700\u4f73\u7279\u5f81 if this_feature != None : # \u7279\u5f81\u6dfb\u52a0\u5230 good_features \u5217\u8868 good_features . append ( this_feature ) # \u5f97\u5206\u6dfb\u52a0\u5230 best_scores \u5217\u8868 best_scores . append ( best_score ) # \u5982\u679c best_scores \u5217\u8868\u957f\u5ea6\u5927\u4e8e2\uff0c\u5e76\u4e14\u6700\u540e\u4e24\u4e2a\u5f97\u5206\u76f8\u6bd4\u8f83\u5dee\uff0c\u5219\u7ed3\u675f\u5faa\u73af if len ( best_scores ) > 2 : if best_scores [ - 1 ] < best_scores [ - 2 ]: break # \u8fd4\u56de\u6700\u4f73\u7279\u5f81\u7684\u5f97\u5206\u5217\u8868\u548c\u6700\u4f73\u7279\u5f81\u5217\u8868 return best_scores [: - 1 ], good_features [: - 1 ] # \u5b9a\u4e49\u7c7b\u7684\u8c03\u7528\u65b9\u6cd5\uff0c\u7528\u4e8e\u6267\u884c\u7279\u5f81\u9009\u62e9 def __call__ ( self , X , y ): scores , features = self . _feature_selection ( X , y ) return X [:, features ], scores if __name__ == \"__main__\" : # \u751f\u6210\u4e00\u4e2a\u793a\u4f8b\u7684\u5206\u7c7b\u6570\u636e\u96c6 X \u548c\u6807\u7b7e y X , y = make_classification ( n_samples = 1000 , n_features = 100 ) # \u5b9e\u4f8b\u5316 GreedyFeatureSelection \u7c7b\uff0c\u5e76\u4f7f\u7528 __call__ \u65b9\u6cd5\u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed , scores = GreedyFeatureSelection ()( X , y ) \u8fd9\u79cd\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u4f1a\u8fd4\u56de\u5206\u6570\u548c\u7279\u5f81\u7d22\u5f15\u5217\u8868\u3002\u56fe 2 \u663e\u793a\u4e86\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u589e\u52a0\u4e00\u4e2a\u65b0\u7279\u5f81\u540e\uff0c\u5206\u6570\u662f\u5982\u4f55\u63d0\u9ad8\u7684\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u67d0\u4e00\u70b9\u4e4b\u540e\uff0c\u6211\u4eec\u5c31\u65e0\u6cd5\u63d0\u9ad8\u5206\u6570\u4e86\uff0c\u8fd9\u5c31\u662f\u6211\u4eec\u505c\u6b62\u7684\u5730\u65b9\u3002 \u53e6\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\u88ab\u79f0\u4e3a\u9012\u5f52\u7279\u5f81\u6d88\u9664\u6cd5\uff08RFE\uff09\u3002\u5728\u524d\u4e00\u79cd\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a\u7279\u5f81\u5f00\u59cb\uff0c\u7136\u540e\u4e0d\u65ad\u6dfb\u52a0\u65b0\u7684\u7279\u5f81\uff0c\u4f46\u5728 RFE \u4e2d\uff0c\u6211\u4eec\u4ece\u6240\u6709\u7279\u5f81\u5f00\u59cb\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u4e0d\u65ad\u53bb\u9664\u4e00\u4e2a\u5bf9\u7ed9\u5b9a\u6a21\u578b\u63d0\u4f9b\u6700\u5c0f\u503c\u7684\u7279\u5f81\u3002\u4f46\u6211\u4eec\u5982\u4f55\u77e5\u9053\u54ea\u4e2a\u7279\u5f81\u7684\u4ef7\u503c\u6700\u5c0f\u5462\uff1f\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7ebf\u6027\u652f\u6301\u5411\u91cf\u673a\uff08SVM\uff09\u6216\u903b\u8f91\u56de\u5f52\u7b49\u6a21\u578b\uff0c\u6211\u4eec\u4f1a\u4e3a\u6bcf\u4e2a\u7279\u5f81\u5f97\u5230\u4e00\u4e2a\u7cfb\u6570\uff0c\u8be5\u7cfb\u6570\u51b3\u5b9a\u4e86\u7279\u5f81\u7684\u91cd\u8981\u6027\u3002\u800c\u5bf9\u4e8e\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f\u7279\u5f81\u91cd\u8981\u6027\uff0c\u800c\u4e0d\u662f\u7cfb\u6570\u3002\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u53ef\u4ee5\u5254\u9664\u6700\u4e0d\u91cd\u8981\u7684\u7279\u5f81\uff0c\u76f4\u5230\u8fbe\u5230\u6240\u9700\u7684\u7279\u5f81\u6570\u91cf\u4e3a\u6b62\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u51b3\u5b9a\u8981\u4fdd\u7559\u591a\u5c11\u7279\u5f81\u3002 \u56fe 2\uff1a\u589e\u52a0\u65b0\u7279\u5f81\u540e\uff0c\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7684 AUC \u5206\u6570\u5982\u4f55\u53d8\u5316 \u5f53\u6211\u4eec\u8fdb\u884c\u9012\u5f52\u7279\u5f81\u5254\u9664\u65f6\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u4f1a\u5254\u9664\u7279\u5f81\u91cd\u8981\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u6216\u7cfb\u6570\u63a5\u8fd1 0 \u7684\u7279\u5f81\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5f53\u4f60\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u8fd9\u6837\u7684\u6a21\u578b\u8fdb\u884c\u4e8c\u5143\u5206\u7c7b\u65f6\uff0c\u5982\u679c\u7279\u5f81\u5bf9\u6b63\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u6b63\uff0c\u800c\u5982\u679c\u7279\u5f81\u5bf9\u8d1f\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u8d1f\u3002\u4fee\u6539\u6211\u4eec\u7684\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7c7b\uff0c\u521b\u5efa\u4e00\u4e2a\u65b0\u7684\u9012\u5f52\u7279\u5f81\u6d88\u9664\u7c7b\u975e\u5e38\u5bb9\u6613\uff0c\u4f46 scikit-learn \u4e5f\u63d0\u4f9b\u4e86 RFE\u3002\u4e0b\u9762\u7684\u793a\u4f8b\u5c55\u793a\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u7528\u6cd5\u3002 import pandas as pd from sklearn.feature_selection import RFE from sklearn.linear_model import LinearRegression from sklearn.datasets import fetch_california_housing data = fetch_california_housing () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] model = LinearRegression () # \u521b\u5efa RFE\uff08\u9012\u5f52\u7279\u5f81\u6d88\u9664\uff09\uff0c\u6307\u5b9a\u6a21\u578b\u4e3a\u7ebf\u6027\u56de\u5f52\u6a21\u578b\uff0c\u8981\u9009\u62e9\u7684\u7279\u5f81\u6570\u91cf\u4e3a 3 rfe = RFE ( estimator = model , n_features_to_select = 3 ) # \u8bad\u7ec3\u6a21\u578b rfe . fit ( X , y ) # \u4f7f\u7528 RFE \u9009\u62e9\u7684\u7279\u5f81\u8fdb\u884c\u6570\u636e\u8f6c\u6362 X_transformed = rfe . transform ( X ) \u6211\u4eec\u770b\u5230\u4e86\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u7684\u4e24\u79cd\u4e0d\u540c\u7684\u8d2a\u5a6a\u65b9\u6cd5\u3002\u4f46\u4e5f\u53ef\u4ee5\u6839\u636e\u6570\u636e\u62df\u5408\u6a21\u578b\uff0c\u7136\u540e\u901a\u8fc7\u7279\u5f81\u7cfb\u6570\u6216\u7279\u5f81\u7684\u91cd\u8981\u6027\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u5982\u679c\u4f7f\u7528\u7cfb\u6570\uff0c\u5219\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u9608\u503c\uff0c\u5982\u679c\u7cfb\u6570\u9ad8\u4e8e\u8be5\u9608\u503c\uff0c\u5219\u53ef\u4ee5\u4fdd\u7559\u8be5\u7279\u5f81\uff0c\u5426\u5219\u5c06\u5176\u5254\u9664\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4ece\u968f\u673a\u68ee\u6797\u8fd9\u6837\u7684\u6a21\u578b\u4e2d\u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u5b9e\u4f8b\u5316\u968f\u673a\u68ee\u6797\u6a21\u578b model = RandomForestRegressor () # \u62df\u5408\u6a21\u578b model . fit ( X , y ) \u968f\u673a\u68ee\u6797\uff08\u6216\u4efb\u4f55\u6a21\u578b\uff09\u7684\u7279\u5f81\u91cd\u8981\u6027\u53ef\u6309\u5982\u4e0b\u65b9\u5f0f\u7ed8\u5236\u3002 # \u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027 importances = model . feature_importances_ # \u964d\u5e8f\u6392\u5217 idxs = np . argsort ( importances ) # \u8bbe\u5b9a\u6807\u9898 plt . title ( 'Feature Importances' ) # \u521b\u5efa\u76f4\u65b9\u56fe plt . barh ( range ( len ( idxs )), importances [ idxs ], align = 'center' ) # y\u8f74\u6807\u7b7e plt . yticks ( range ( len ( idxs )), [ col_names [ i ] for i in idxs ]) # x\u8f74\u6807\u7b7e plt . xlabel ( 'Random Forest Feature Importance' ) plt . show () \u7ed3\u679c\u5982\u56fe 3 \u6240\u793a\u3002 \u56fe 3\uff1a\u7279\u5f81\u91cd\u8981\u6027\u56fe \u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u6700\u4f73\u7279\u5f81\u5e76\u4e0d\u662f\u4ec0\u4e48\u65b0\u9c9c\u4e8b\u3002\u60a8\u53ef\u4ee5\u4ece\u4e00\u4e2a\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u53e6\u4e00\u4e2a\u6a21\u578b\u8fdb\u884c\u8bad\u7ec3\u3002\u4f8b\u5982\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u7cfb\u6570\u6765\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u968f\u673a\u68ee\u6797\uff08Random Forest\uff09\u5bf9\u6240\u9009\u7279\u5f81\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\u3002Scikit-learn \u8fd8\u63d0\u4f9b\u4e86 SelectFromModel \u7c7b\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4f60\u76f4\u63a5\u4ece\u7ed9\u5b9a\u7684\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u6839\u636e\u9700\u8981\u6307\u5b9a\u7cfb\u6570\u6216\u7279\u5f81\u91cd\u8981\u6027\u7684\u9608\u503c\uff0c\u4ee5\u53ca\u8981\u9009\u62e9\u7684\u7279\u5f81\u7684\u6700\u5927\u6570\u91cf\u3002 \u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\uff0c\u6211\u4eec\u4f7f\u7528 SelectFromModel \u4e2d\u7684\u9ed8\u8ba4\u53c2\u6570\u6765\u9009\u62e9\u7279\u5f81\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor from sklearn.feature_selection import SelectFromModel data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u521b\u5efa\u968f\u673a\u68ee\u6797\u6a21\u578b\u56de\u5f52\u6a21\u578b model = RandomForestRegressor () # \u521b\u5efa SelectFromModel \u5bf9\u8c61 sfm\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u6a21\u578b\u4f5c\u4e3a\u4f30\u7b97\u5668 sfm = SelectFromModel ( estimator = model ) # \u4f7f\u7528 sfm \u5bf9\u7279\u5f81\u77e9\u9635 X \u548c\u76ee\u6807\u53d8\u91cf y \u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed = sfm . fit_transform ( X , y ) # \u83b7\u53d6\u7ecf\u8fc7\u7279\u5f81\u9009\u62e9\u540e\u7684\u7279\u5f81\u63a9\u7801\uff08True \u8868\u793a\u7279\u5f81\u88ab\u9009\u62e9\uff0cFalse \u8868\u793a\u7279\u5f81\u672a\u88ab\u9009\u62e9\uff09 support = sfm . get_support () # \u6253\u5370\u88ab\u9009\u62e9\u7684\u7279\u5f81\u5217\u540d print ([ x for x , y in zip ( col_names , support ) if y == True ]) \u4e0a\u9762\u7a0b\u5e8f\u6253\u5370\u7ed3\u679c\uff1a ['bmi'\uff0c's5']\u3002\u6211\u4eec\u518d\u770b\u56fe 3\uff0c\u5c31\u4f1a\u53d1\u73b0\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u4e2a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u76f4\u63a5\u4ece\u968f\u673a\u68ee\u6797\u63d0\u4f9b\u7684\u7279\u5f81\u91cd\u8981\u6027\u4e2d\u8fdb\u884c\u9009\u62e9\u3002\u6211\u4eec\u8fd8\u7f3a\u5c11\u4e00\u4ef6\u4e8b\uff0c\u90a3\u5c31\u662f\u4f7f\u7528 L1\uff08Lasso\uff09\u60e9\u7f5a\u6a21\u578b \u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u5f53\u6211\u4eec\u4f7f\u7528 L1 \u60e9\u7f5a\u8fdb\u884c\u6b63\u5219\u5316\u65f6\uff0c\u5927\u90e8\u5206\u7cfb\u6570\u90fd\u5c06\u4e3a 0\uff08\u6216\u63a5\u8fd1 0\uff09\uff0c\u56e0\u6b64\u6211\u4eec\u8981\u9009\u62e9\u7cfb\u6570\u4e0d\u4e3a 0 \u7684\u7279\u5f81\u3002\u53ea\u9700\u5c06\u6a21\u578b\u9009\u62e9\u7247\u6bb5\u4e2d\u7684\u968f\u673a\u68ee\u6797\u66ff\u6362\u4e3a\u652f\u6301 L1 \u60e9\u7f5a\u7684\u6a21\u578b\uff08\u5982 lasso \u56de\u5f52\uff09\u5373\u53ef\u3002\u6240\u6709\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u90fd\u63d0\u4f9b\u7279\u5f81\u91cd\u8981\u6027\uff0c\u56e0\u6b64\u672c\u7ae0\u4e2d\u5c55\u793a\u7684\u6240\u6709\u57fa\u4e8e\u6a21\u578b\u7684\u7247\u6bb5\u90fd\u53ef\u7528\u4e8e XGBoost\u3001LightGBM \u6216 CatBoost\u3002\u7279\u5f81\u91cd\u8981\u6027\u51fd\u6570\u7684\u540d\u79f0\u53ef\u80fd\u4e0d\u540c\uff0c\u4ea7\u751f\u7ed3\u679c\u7684\u683c\u5f0f\u4e5f\u53ef\u80fd\u4e0d\u540c\uff0c\u4f46\u7528\u6cd5\u662f\u4e00\u6837\u7684\u3002\u6700\u540e\uff0c\u5728\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u65f6\u5fc5\u987b\u5c0f\u5fc3\u8c28\u614e\u3002\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u9009\u62e9\u7279\u5f81\uff0c\u5e76\u5728\u9a8c\u8bc1\u6570\u636e\u4e0a\u9a8c\u8bc1\u6a21\u578b\uff0c\u4ee5\u4fbf\u5728\u4e0d\u8fc7\u5ea6\u62df\u5408\u6a21\u578b\u7684\u60c5\u51b5\u4e0b\u6b63\u786e\u9009\u62e9\u7279\u5f81\u3002","title":"\u7279\u5f81\u9009\u62e9"},{"location":"%E7%BB%84%E5%90%88%E5%92%8C%E5%A0%86%E5%8F%A0%E6%96%B9%E6%B3%95/","text":"\u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5 \u542c\u5230\u4e0a\u9762\u4e24\u4e2a\u8bcd\uff0c\u6211\u4eec\u9996\u5148\u60f3\u5230\u7684\u5c31\u662f\u5728\u7ebf\uff08online\uff09/\u79bb\u7ebf\uff08offline\uff09\u673a\u5668\u5b66\u4e60\u7ade\u8d5b\u3002\u51e0\u5e74\u524d\u662f\u8fd9\u6837\uff0c\u4f46\u73b0\u5728\u968f\u7740\u8ba1\u7b97\u80fd\u529b\u7684\u8fdb\u6b65\u548c\u865a\u62df\u5b9e\u4f8b\u7684\u5ec9\u4ef7\uff0c\u4eba\u4eec\u751a\u81f3\u5f00\u59cb\u5728\u884c\u4e1a\u4e2d\u4f7f\u7528\u7ec4\u5408\u6a21\u578b\uff08ensemble models\uff09\u3002\u4f8b\u5982\uff0c\u90e8\u7f72\u591a\u4e2a\u795e\u7ecf\u7f51\u7edc\u5e76\u5b9e\u65f6\u4e3a\u5b83\u4eec\u63d0\u4f9b\u670d\u52a1\u975e\u5e38\u5bb9\u6613\uff0c\u54cd\u5e94\u65f6\u95f4\u5c0f\u4e8e 500 \u6beb\u79d2\u3002\u6709\u65f6\uff0c\u4e00\u4e2a\u5e9e\u5927\u7684\u795e\u7ecf\u7f51\u7edc\u6216\u5927\u578b\u6a21\u578b\u4e5f\u53ef\u4ee5\u88ab\u5176\u4ed6\u51e0\u4e2a\u6a21\u578b\u53d6\u4ee3\uff0c\u8fd9\u4e9b\u6a21\u578b\u4f53\u79ef\u5c0f\uff0c\u6027\u80fd\u4e0e\u5927\u578b\u6a21\u578b\u76f8\u4f3c\uff0c\u901f\u5ea6\u5374\u5feb\u4e00\u500d\u3002\u5982\u679c\u662f\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff08\u4e9b\uff09\u6a21\u578b\u5462\uff1f\u6211\u4e2a\u4eba\u66f4\u503e\u5411\u4e8e\u9009\u62e9\u591a\u4e2a\u5c0f\u673a\u578b\uff0c\u5b83\u4eec\u901f\u5ea6\u66f4\u5feb\uff0c\u6027\u80fd\u4e0e\u5927\u673a\u578b\u548c\u6162\u673a\u578b\u76f8\u540c\u3002\u8bf7\u8bb0\u4f4f\uff0c\u8f83\u5c0f\u7684\u578b\u53f7\u4e5f\u66f4\u5bb9\u6613\u548c\u66f4\u5feb\u5730\u8fdb\u884c\u8c03\u6574\u3002 \u7ec4\u5408\uff08ensembling\uff09\u4e0d\u8fc7\u662f\u4e0d\u540c\u6a21\u578b\u7684\u7ec4\u5408\u3002\u6a21\u578b\u53ef\u4ee5\u901a\u8fc7\u9884\u6d4b/\u6982\u7387\u8fdb\u884c\u7ec4\u5408\u3002\u7ec4\u5408\u6a21\u578b\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6c42\u5e73\u5747\u503c\u3002 $$ Ensemble Probabilities = (M1_proba + M2_proba + ... + Mn_Proba)/n $$ \u8fd9\u662f\u6700\u7b80\u5355\u4e5f\u662f\u6700\u6709\u6548\u7684\u7ec4\u5408\u6a21\u578b\u7684\u65b9\u6cd5\u3002\u5728\u7b80\u5355\u5e73\u5747\u6cd5\u4e2d\uff0c\u6240\u6709\u6a21\u578b\u7684\u6743\u91cd\u90fd\u662f\u76f8\u7b49\u7684\u3002\u65e0\u8bba\u91c7\u7528\u54ea\u79cd\u7ec4\u5408\u65b9\u6cd5\uff0c\u60a8\u90fd\u5e94\u8be5\u7262\u8bb0\u4e00\u70b9\uff0c\u90a3\u5c31\u662f\u60a8\u5e94\u8be5\u59cb\u7ec8\u5c06\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b/\u6982\u7387\u7ec4\u5408\u5728\u4e00\u8d77\u3002\u7b80\u5355\u5730\u8bf4\uff0c\u7ec4\u5408\u76f8\u5173\u6027\u4e0d\u9ad8\u7684\u6a21\u578b\u6bd4\u7ec4\u5408\u76f8\u5173\u6027\u5f88\u9ad8\u7684\u6a21\u578b\u6548\u679c\u66f4\u597d\u3002 \u5982\u679c\u6ca1\u6709\u6982\u7387\uff0c\u4e5f\u53ef\u4ee5\u7ec4\u5408\u9884\u6d4b\u3002\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6295\u7968\u3002\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u8fdb\u884c\u591a\u7c7b\u5206\u7c7b\uff0c\u6709\u4e09\u4e2a\u7c7b\u522b\uff1a 0\u30011 \u548c 2\u3002 [0, 0, 1] : \u6700\u9ad8\u7968\u6570\uff1a 0 [0, 1, 2] : \u6700\u9ad8\u7968\u7ea7\uff1a \u65e0\uff08\u968f\u673a\u9009\u62e9\u4e00\u4e2a\uff09 [2, 2, 2] : \u6700\u9ad8\u7968\u6570\uff1a 2 \u4ee5\u4e0b\u7b80\u5355\u51fd\u6570\u53ef\u4ee5\u5b8c\u6210\u8fd9\u4e9b\u7b80\u5355\u64cd\u4f5c\u3002 import numpy as np def mean_predictions ( probas ): # \u8ba1\u7b97\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u6bcf\u884c\u5e73\u5747\u503c return np . mean ( probas , axis = 1 ) def max_voting ( preds ): # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u67e5\u627e\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u7684\u7d22\u5f15 idxs = np . argmax ( preds , axis = 1 ) # \u6839\u636e\u7d22\u5f15\u53d6\u51fa\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u5143\u7d20 return np . take_along_axis ( preds , idxs [:, None ], axis = 1 ) \u8bf7\u6ce8\u610f\uff0cprobas \u7684\u6bcf\u4e00\u5217\u90fd\u53ea\u6709\u4e00\u4e2a\u6982\u7387\uff08\u5373\u4e8c\u5143\u5206\u7c7b\uff0c\u901a\u5e38\u4e3a\u7c7b\u522b 1\uff09\u3002\u56e0\u6b64\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u4e00\u4e2a\u65b0\u6a21\u578b\u3002\u540c\u6837\uff0c\u5bf9\u4e8e preds\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u6765\u81ea\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b\u503c\u3002\u8fd9\u4e24\u4e2a\u51fd\u6570\u90fd\u5047\u8bbe\u4e86\u4e00\u4e2a 2 \u7ef4 numpy \u6570\u7ec4\u3002\u60a8\u53ef\u4ee5\u6839\u636e\u81ea\u5df1\u7684\u9700\u6c42\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u6709\u4e00\u4e2a 2 \u7ef4\u6570\u7ec4\uff0c\u5176\u4e2d\u5305\u542b\u6bcf\u4e2a\u6a21\u578b\u7684\u6982\u7387\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u51fd\u6570\u4f1a\u6709\u4e00\u4e9b\u53d8\u5316\u3002 \u53e6\u4e00\u79cd\u7ec4\u5408\u591a\u4e2a\u6a21\u578b\u7684\u65b9\u6cd5\u662f\u901a\u8fc7\u5b83\u4eec\u7684 \u6982\u7387\u6392\u5e8f \u3002\u5f53\u76f8\u5173\u6307\u6807\u662f\u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u65f6\uff0c\u8fd9\u79cd\u7ec4\u5408\u65b9\u5f0f\u975e\u5e38\u6709\u6548\uff0c\u56e0\u4e3a AUC \u5c31\u662f\u5bf9\u6837\u672c\u8fdb\u884c\u6392\u5e8f\u3002 def rank_mean ( probas ): # \u521b\u5efa\u7a7a\u5217\u8868ranked\u5b58\u50a8\u6bcf\u4e2a\u7c7b\u522b\u6982\u7387\u503c\u6392\u540d ranked = [] # \u904d\u5386\u6982\u7387\u503c\u6bcf\u4e00\u5217\uff08\u6bcf\u4e2a\u7c7b\u522b\u7684\u6982\u7387\u503c\uff09 for i in range ( probas . shape [ 1 ]): # \u5f53\u524d\u5217\u6982\u7387\u503c\u6392\u540d\uff0crank_data\u662f\u6392\u540d\u7ed3\u679c rank_data = stats . rankdata ( probas [:, i ]) # \u5c06\u5f53\u524d\u5217\u6392\u540d\u7ed3\u679c\u6dfb\u52a0\u5230ranked\u5217\u8868\u4e2d ranked . append ( rank_data ) # \u5c06ranked\u5217\u8868\u4e2d\u6392\u540d\u7ed3\u679c\u6309\u5217\u5806\u53e0\uff0c\u5f62\u6210\u4e8c\u7ef4\u6570\u7ec4 ranked = np . column_stack ( ranked ) # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u8ba1\u7b97\u6837\u672c\u6392\u540d\u5e73\u5747\u503c return np . mean ( ranked , axis = 1 ) \u8bf7\u6ce8\u610f\uff0c\u5728 scipy \u7684 rankdata \u4e2d\uff0c\u7b49\u7ea7\u4ece 1 \u5f00\u59cb\u3002 \u4e3a\u4ec0\u4e48\u8fd9\u7c7b\u96c6\u5408\u6709\u6548\uff1f\u8ba9\u6211\u4eec\u770b\u770b\u56fe 1\u3002 \u56fe 1\uff1a\u4e09\u4eba\u731c\u5927\u8c61\u7684\u8eab\u9ad8 \u56fe 1 \u663e\u793a\uff0c\u5982\u679c\u6709\u4e09\u4e2a\u4eba\u5728\u731c\u5927\u8c61\u7684\u9ad8\u5ea6\uff0c\u90a3\u4e48\u539f\u59cb\u9ad8\u5ea6\u5c06\u975e\u5e38\u63a5\u8fd1\u4e09\u4e2a\u4eba\u731c\u6d4b\u7684\u5e73\u5747\u503c\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u4eba\u90fd\u80fd\u731c\u5230\u975e\u5e38\u63a5\u8fd1\u5927\u8c61\u539f\u6765\u7684\u9ad8\u5ea6\u3002\u63a5\u8fd1\u4f30\u8ba1\u503c\u610f\u5473\u7740\u8bef\u5dee\uff0c\u4f46\u5982\u679c\u6211\u4eec\u5c06\u4e09\u4e2a\u9884\u6d4b\u503c\u5e73\u5747\uff0c\u5c31\u80fd\u5c06\u8bef\u5dee\u964d\u5230\u6700\u4f4e\u3002\u8fd9\u5c31\u662f\u591a\u4e2a\u6a21\u578b\u5e73\u5747\u7684\u4e3b\u8981\u601d\u60f3\u3002 $$ Final\\ Probabilities = w_1 \\times M1_proba + w_2 \\times M2_proba + \\cdots + w_n \\times Mn_proba $$ \u5176\u4e2d \\((w_1 + w_2 + w_3 + \\cdots + w_n)=1.0\\) \u4f8b\u5982\uff0c\u5982\u679c\u4f60\u6709\u4e00\u4e2a AUC \u975e\u5e38\u9ad8\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u548c\u4e00\u4e2a AUC \u7a0d\u4f4e\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u4f60\u53ef\u4ee5\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0c\u968f\u673a\u68ee\u6797\u6a21\u578b\u5360 70%\uff0c\u903b\u8f91\u56de\u5f52\u6a21\u578b\u5360 30%\u3002\u90a3\u4e48\uff0c\u6211\u662f\u5982\u4f55\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u7684\u5462\uff1f\u8ba9\u6211\u4eec\u518d\u6dfb\u52a0\u4e00\u4e2a\u6a21\u578b\uff0c\u5047\u8bbe\u73b0\u5728\u6211\u4eec\u4e5f\u6709\u4e00\u4e2a xgboost \u6a21\u578b\uff0c\u5b83\u7684 AUC \u6bd4\u968f\u673a\u68ee\u6797\u9ad8\u3002\u73b0\u5728\uff0c\u6211\u5c06\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0cxgboost\uff1a\u968f\u673a\u68ee\u6797\uff1a\u903b\u8f91\u56de\u5f52\u7684\u6bd4\u4f8b\u4e3a 3:2:1\u3002\u5f88\u7b80\u5355\u5427\uff1f\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u6613\u5982\u53cd\u638c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u662f\u5982\u4f55\u505a\u5230\u7684\u3002 \u5047\u5b9a\u6211\u4eec\u6709\u4e09\u53ea\u7334\u5b50\uff0c\u4e09\u53ea\u65cb\u94ae\u7684\u6570\u503c\u5728 0 \u548c 1 \u4e4b\u95f4\u3002\u8fd9\u4e9b\u7334\u5b50\u8f6c\u52a8\u65cb\u94ae\uff0c\u6211\u4eec\u8ba1\u7b97\u5b83\u4eec\u6bcf\u8f6c\u5230\u4e00\u4e2a\u6570\u503c\u65f6\u7684 AUC \u5206\u6570\u3002\u6700\u7ec8\uff0c\u7334\u5b50\u4eec\u4f1a\u627e\u5230\u4e00\u4e2a\u80fd\u7ed9\u51fa\u6700\u4f73 AUC \u7684\u7ec4\u5408\u3002\u6ca1\u9519\uff0c\u8fd9\u5c31\u662f\u968f\u673a\u641c\u7d22\uff01\u5728\u8fdb\u884c\u8fd9\u7c7b\u641c\u7d22\u4e4b\u524d\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\u4e24\u4e2a\u6700\u91cd\u8981\u7684\u7ec4\u5408\u89c4\u5219\u3002 \u7ec4\u5408\u7684\u7b2c\u4e00\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u7ec4\u5408\u7684\u7b2c\u4e8c\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u662f\u7684\u3002\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u6761\u89c4\u5219\u3002\u7b2c\u4e00\u6b65\u662f\u521b\u5efa\u6298\u53e0\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u5047\u8bbe\u6211\u4eec\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\uff1a\u6298\u53e0 1 \u548c\u6298\u53e0 2\u3002\u8bf7\u6ce8\u610f\uff0c\u8fd9\u6837\u505a\u53ea\u662f\u4e3a\u4e86\u7b80\u5316\u89e3\u91ca\u3002\u5728\u5b9e\u9645\u5e94\u7528\u4e2d\uff0c\u60a8\u5e94\u8be5\u521b\u5efa\u66f4\u591a\u7684\u6298\u53e0\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5728\u6298\u53e0 1 \u4e0a\u8bad\u7ec3\u968f\u673a\u68ee\u6797\u6a21\u578b\u3001\u903b\u8f91\u56de\u5f52\u6a21\u578b\u548c xgboost \u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 2 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u4e4b\u540e\uff0c\u6211\u4eec\u5728\u6298\u53e0 2 \u4e0a\u4ece\u5934\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 1 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u4e3a\u6240\u6709\u8bad\u7ec3\u6570\u636e\u521b\u5efa\u4e86\u9884\u6d4b\u7ed3\u679c\u3002\u73b0\u5728\uff0c\u4e3a\u4e86\u5408\u5e76\u8fd9\u4e9b\u6a21\u578b\uff0c\u6211\u4eec\u5c06\u6298\u53e0 1 \u548c\u6298\u53e0 1 \u7684\u6240\u6709\u9884\u6d4b\u6570\u636e\u5408\u5e76\u5728\u4e00\u8d77\uff0c\u7136\u540e\u521b\u5efa\u4e00\u4e2a\u4f18\u5316\u51fd\u6570\uff0c\u8bd5\u56fe\u627e\u5230\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4fbf\u9488\u5bf9\u6298\u53e0 2 \u7684\u76ee\u6807\u6700\u5c0f\u5316\u8bef\u5dee\u6216\u6700\u5927\u5316 AUC\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u662f\u7528\u4e09\u4e2a\u6a21\u578b\u7684\u9884\u6d4b\u6982\u7387\u5728\u6298\u53e0 1 \u4e0a\u8bad\u7ec3\u4e00\u4e2a\u4f18\u5316\u6a21\u578b\uff0c\u7136\u540e\u5728\u6298\u53e0 2 \u4e0a\u5bf9\u5176\u8fdb\u884c\u8bc4\u4f30\u3002\u8ba9\u6211\u4eec\u5148\u6765\u770b\u770b\u6211\u4eec\u53ef\u4ee5\u7528\u6765\u627e\u5230\u591a\u4e2a\u6a21\u578b\u7684\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4f18\u5316 AUC\uff08\u6216\u4efb\u4f55\u7c7b\u578b\u7684\u9884\u6d4b\u6307\u6807\u7ec4\u5408\uff09\u7684\u7c7b\u3002 import numpy as np from functools import partial from scipy.optimize import fmin from sklearn import metrics class OptimizeAUC : def __init__ ( self ): # \u521d\u59cb\u5316\u7cfb\u6570 self . coef_ = 0 def _auc ( self , coef , X , y ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u7cfb\u6570 x_coef = X * coef # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8ba1\u7b97AUC\u5206\u6570 auc_score = metrics . roc_auc_score ( y , predictions ) # \u8fd4\u56de\u8d1fAUC\u4ee5\u4fbf\u6700\u5c0f\u5316 return - 1.0 * auc_score def fit ( self , X , y ): # \u521b\u5efa\u5e26\u6709\u90e8\u5206\u53c2\u6570\u7684\u76ee\u6807\u51fd\u6570 loss_partial = partial ( self . _auc , X = X , y = y ) # \u521d\u59cb\u5316\u7cfb\u6570 initial_coef = np . random . dirichlet ( np . ones ( X . shape [ 1 ]), size = 1 ) # \u4f7f\u7528fmin\u51fd\u6570\u4f18\u5316AUC\u76ee\u6807\u51fd\u6570\uff0c\u627e\u5230\u6700\u4f18\u7cfb\u6570 self . coef_ = fmin ( loss_partial , initial_coef , disp = True ) def predict ( self , X ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u8bad\u7ec3\u597d\u7684\u7cfb\u6570 x_coef = X * self . coef_ # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8fd4\u56de\u9884\u6d4b\u7ed3\u679c return predictions \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528\u5b83\uff0c\u5e76\u5c06\u5176\u4e0e\u7b80\u5355\u5e73\u5747\u6cd5\u8fdb\u884c\u6bd4\u8f83\u3002 import xgboost as xgb from sklearn.datasets import make_classification from sklearn import ensemble from sklearn import linear_model from sklearn import metrics from sklearn import model_selection # \u751f\u6210\u4e00\u4e2a\u5206\u7c7b\u6570\u636e\u96c6 X , y = make_classification ( n_samples = 10000 , n_features = 25 ) # \u5212\u5206\u6570\u636e\u96c6\u4e3a\u4e24\u4e2a\u4ea4\u53c9\u9a8c\u8bc1\u6298\u53e0 xfold1 , xfold2 , yfold1 , yfold2 = model_selection . train_test_split ( X , y , test_size = 0.5 , stratify = y ) # \u521d\u59cb\u5316\u4e09\u4e2a\u4e0d\u540c\u7684\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold1 , yfold1 ) rf . fit ( xfold1 , yfold1 ) xgbc . fit ( xfold1 , yfold1 ) # \u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold2 )[:, 1 ] pred_rf = rf . predict_proba ( xfold2 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold2 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold2_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold2 = [] for i in range ( fold2_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold2 , fold2_preds [:, i ]) aucs_fold2 . append ( auc ) print ( f \"Fold-2: LR AUC = { aucs_fold2 [ 0 ] } \" ) print ( f \"Fold-2: RF AUC = { aucs_fold2 [ 1 ] } \" ) print ( f \"Fold-2: XGB AUC = { aucs_fold2 [ 2 ] } \" ) print ( f \"Fold-2: Average Pred AUC = { aucs_fold2 [ 3 ] } \" ) # \u91cd\u65b0\u521d\u59cb\u5316\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold2 , yfold2 ) rf . fit ( xfold2 , yfold2 ) xgbc . fit ( xfold2 , yfold2 ) # \u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold1 )[:, 1 ] pred_rf = rf . predict_proba ( xfold1 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold1 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold1_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold1 = [] for i in range ( fold1_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold1 , fold1_preds [:, i ]) aucs_fold1 . append ( auc ) print ( f \"Fold-1: LR AUC = { aucs_fold1 [ 0 ] } \" ) print ( f \"Fold-1: RF AUC = { aucs_fold1 [ 1 ] } \" ) print ( f \"Fold-1: XGB AUC = { aucs_fold1 [ 2 ] } \" ) print ( f \"Fold-1: Average prediction AUC = { aucs_fold1 [ 3 ] } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765\u8bad\u7ec3\u4f18\u5316\u5668 opt . fit ( fold1_preds [:, : - 1 ], yfold1 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold2 = opt . predict ( fold2_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold2 , opt_preds_fold2 ) print ( f \"Optimized AUC, Fold 2 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765 opt . fit ( fold2_preds [:, : - 1 ], yfold2 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold1 = opt . predict ( fold1_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold1 , opt_preds_fold1 ) print ( f \"Optimized AUC, Fold 1 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) \u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u8f93\u51fa\uff1a \u276f python auc_opt . py Fold - 2 : LR AUC = 0.9145446769443348 Fold - 2 : RF AUC = 0.9269918948683287 Fold - 2 : XGB AUC = 0.9302436595508696 Fold - 2 : Average Pred AUC = 0.927701495890154 Fold - 1 : LR AUC = 0.9050872233256017 Fold - 1 : RF AUC = 0.9179382818311258 Fold - 1 : XGB AUC = 0.9195837242005629 Fold - 1 : Average prediction AUC = 0.9189669233123695 Optimization terminated successfully . Current function value : - 0.920643 Iterations : 50 Function evaluations : 109 Optimized AUC , Fold 2 = 0.9305386199756128 Coefficients = [ - 0.00188194 0.19328336 0.35891836 ] Optimization terminated successfully . Current function value : - 0.931232 Iterations : 56 Function evaluations : 113 Optimized AUC , Fold 1 = 0.9192523637234037 Coefficients = [ - 0.15655124 0.22393151 0.58711366 ] \u6211\u4eec\u770b\u5230\uff0c\u5e73\u5747\u503c\u66f4\u597d\uff0c\u4f46\u4f7f\u7528\u4f18\u5316\u5668\u627e\u5230\u9608\u503c\u66f4\u597d\uff01\u6709\u65f6\uff0c\u5e73\u5747\u503c\u662f\u6700\u597d\u7684\u9009\u62e9\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u7cfb\u6570\u52a0\u8d77\u6765\u5e76\u6ca1\u6709\u8fbe\u5230 1.0\uff0c\u4f46\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u6211\u4eec\u8981\u5904\u7406\u7684\u662f AUC\uff0c\u800c AUC \u53ea\u5173\u5fc3\u7b49\u7ea7\u3002 \u5373\u4f7f\u968f\u673a\u68ee\u6797\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u968f\u673a\u68ee\u6797\u53ea\u662f\u8bb8\u591a\u7b80\u5355\u51b3\u7b56\u6811\u7684\u7ec4\u5408\u3002\u968f\u673a\u68ee\u6797\u5c5e\u4e8e\u96c6\u5408\u6a21\u578b\u7684\u4e00\u79cd\uff0c\u4e5f\u5c31\u662f\u4fd7\u79f0\u7684 \"bagging\" \u3002\u5728\u888b\u96c6\u6a21\u578b\u4e2d\uff0c\u6211\u4eec\u521b\u5efa\u5c0f\u6570\u636e\u5b50\u96c6\u5e76\u8bad\u7ec3\u591a\u4e2a\u7b80\u5355\u6a21\u578b\u3002\u6700\u7ec8\u7ed3\u679c\u7531\u6240\u6709\u8fd9\u4e9b\u5c0f\u6a21\u578b\u7684\u9884\u6d4b\u7ed3\u679c\uff08\u5982\u5e73\u5747\u503c\uff09\u7ec4\u5408\u800c\u6210\u3002 \u6211\u4eec\u4f7f\u7528\u7684 xgboost \u6a21\u578b\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u6240\u6709\u68af\u5ea6\u63d0\u5347\u6a21\u578b\u90fd\u662f\u96c6\u5408\u6a21\u578b\uff0c\u7edf\u79f0\u4e3a \u63d0\u5347\u6a21\u578b\uff08boosting models\uff09 \u3002\u63d0\u5347\u6a21\u578b\u7684\u5de5\u4f5c\u539f\u7406\u4e0e\u88c5\u888b\u6a21\u578b\u7c7b\u4f3c\uff0c\u4e0d\u540c\u4e4b\u5904\u5728\u4e8e\u63d0\u5347\u6a21\u578b\u4e2d\u7684\u8fde\u7eed\u6a21\u578b\u662f\u6839\u636e\u8bef\u5dee\u6b8b\u5dee\u8bad\u7ec3\u7684\uff0c\u5e76\u503e\u5411\u4e8e\u6700\u5c0f\u5316\u524d\u9762\u6a21\u578b\u7684\u8bef\u5dee\u3002\u8fd9\u6837\uff0c\u63d0\u5347\u6a21\u578b\u5c31\u80fd\u5b8c\u7f8e\u5730\u5b66\u4e60\u6570\u636e\uff0c\u56e0\u6b64\u5bb9\u6613\u51fa\u73b0\u8fc7\u62df\u5408\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u770b\u5230\u7684\u4ee3\u7801\u7247\u6bb5\u53ea\u8003\u8651\u4e86\u4e00\u5217\u3002\u4f46\u60c5\u51b5\u5e76\u975e\u603b\u662f\u5982\u6b64\uff0c\u5f88\u591a\u65f6\u5019\u60a8\u9700\u8981\u5904\u7406\u591a\u5217\u9884\u6d4b\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u4f1a\u9047\u5230\u4ece\u591a\u4e2a\u7c7b\u522b\u4e2d\u9884\u6d4b\u4e00\u4e2a\u7c7b\u522b\u7684\u95ee\u9898\uff0c\u5373\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5bf9\u4e8e\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\uff0c\u4f60\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u9009\u62e9\u6295\u7968\u65b9\u6cd5\u3002\u4f46\u6295\u7968\u6cd5\u5e76\u4e0d\u603b\u662f\u6700\u4f73\u65b9\u6cd5\u3002\u5982\u679c\u8981\u7ec4\u5408\u6982\u7387\uff0c\u5c31\u4f1a\u6709\u4e00\u4e2a\u4e8c\u7ef4\u6570\u7ec4\uff0c\u800c\u4e0d\u662f\u50cf\u6211\u4eec\u4e4b\u524d\u4f18\u5316 AUC \u65f6\u7684\u5411\u91cf\u3002\u5982\u679c\u6709\u591a\u4e2a\u7c7b\u522b\uff0c\u53ef\u4ee5\u5c1d\u8bd5\u4f18\u5316\u5bf9\u6570\u635f\u5931\uff08\u6216\u5176\u4ed6\u4e0e\u4e1a\u52a1\u76f8\u5173\u7684\u6307\u6807\uff09\u3002 \u8981\u8fdb\u884c\u7ec4\u5408\uff0c\u53ef\u4ee5\u5728\u62df\u5408\u51fd\u6570 (X) \u4e2d\u4f7f\u7528 numpy \u6570\u7ec4\u5217\u8868\u800c\u4e0d\u662f numpy \u6570\u7ec4\uff0c\u968f\u540e\u8fd8\u9700\u8981\u66f4\u6539\u4f18\u5316\u5668\u548c\u9884\u6d4b\u51fd\u6570\u3002\u6211\u5c31\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u5927\u5bb6\u5427\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8fdb\u5165\u4e0b\u4e00\u4e2a\u6709\u8da3\u7684\u8bdd\u9898\uff0c\u8fd9\u4e2a\u8bdd\u9898\u76f8\u5f53\u6d41\u884c\uff0c\u88ab\u79f0\u4e3a \u5806\u53e0 \u3002\u56fe 2 \u5c55\u793a\u4e86\u5982\u4f55\u5806\u53e0\u6a21\u578b\u3002 \u56fe2 : Stacking \u5806\u53e0\u4e0d\u50cf\u5236\u9020\u706b\u7bad\u3002\u5b83\u7b80\u5355\u660e\u4e86\u3002\u5982\u679c\u60a8\u8fdb\u884c\u4e86\u6b63\u786e\u7684\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5e76\u5728\u6574\u4e2a\u5efa\u6a21\u8fc7\u7a0b\u4e2d\u4fdd\u6301\u6298\u53e0\u4e0d\u53d8\uff0c\u90a3\u4e48\u5c31\u4e0d\u4f1a\u51fa\u73b0\u4efb\u4f55\u8fc7\u5ea6\u8d34\u5408\u7684\u60c5\u51b5\u3002 \u8ba9\u6211\u7528\u7b80\u5355\u7684\u8981\u70b9\u5411\u4f60\u63cf\u8ff0\u4e00\u4e0b\u8fd9\u4e2a\u60f3\u6cd5\u3002 - \u5c06\u8bad\u7ec3\u6570\u636e\u5206\u6210\u82e5\u5e72\u6298\u53e0\u3002 - \u8bad\u7ec3\u4e00\u5806\u6a21\u578b\uff1a M1\u3001M2.....Mn\u3002 - \u521b\u5efa\u5b8c\u6574\u7684\u8bad\u7ec3\u9884\u6d4b\uff08\u4f7f\u7528\u975e\u6298\u53e0\u8bad\u7ec3\uff09\uff0c\u5e76\u4f7f\u7528\u6240\u6709\u8fd9\u4e9b\u6a21\u578b\u8fdb\u884c\u6d4b\u8bd5\u9884\u6d4b\u3002 - \u76f4\u5230\u8fd9\u91cc\u662f\u7b2c 1 \u5c42 (L1)\u3002 - \u5c06\u8fd9\u4e9b\u6a21\u578b\u7684\u6298\u53e0\u9884\u6d4b\u4f5c\u4e3a\u53e6\u4e00\u4e2a\u6a21\u578b\u7684\u7279\u5f81\u3002\u8fd9\u5c31\u662f\u4e8c\u7ea7\u6a21\u578b\uff08L2\uff09\u3002 - \u4f7f\u7528\u4e0e\u4e4b\u524d\u76f8\u540c\u7684\u6298\u53e0\u6765\u8bad\u7ec3\u8fd9\u4e2a L2 \u6a21\u578b\u3002 - \u73b0\u5728\uff0c\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u521b\u5efa OOF\uff08\u6298\u53e0\u5916\uff09\u9884\u6d4b\u3002 - \u73b0\u5728\u60a8\u5c31\u6709\u4e86\u8bad\u7ec3\u6570\u636e\u7684 L2 \u9884\u6d4b\u548c\u6700\u7ec8\u6d4b\u8bd5\u96c6\u9884\u6d4b\u3002 \u60a8\u53ef\u4ee5\u4e0d\u65ad\u91cd\u590d L1 \u90e8\u5206\uff0c\u4e5f\u53ef\u4ee5\u521b\u5efa\u4efb\u610f\u591a\u7684\u5c42\u6b21\u3002 \u6709\u65f6\uff0c\u4f60\u8fd8\u4f1a\u9047\u5230\u4e00\u4e2a\u53eb\u6df7\u5408\u7684\u672f\u8bed blending \u3002\u5982\u679c\u4f60\u9047\u5230\u4e86\uff0c\u4e0d\u7528\u592a\u62c5\u5fc3\u3002\u5b83\u53ea\u4e0d\u8fc7\u662f\u7528\u4e00\u4e2a\u4fdd\u7559\u7ec4\u6765\u5806\u53e0\uff0c\u800c\u4e0d\u662f\u591a\u91cd\u6298\u53e0\u3002\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u6211\u5728\u672c\u7ae0\u4e2d\u6240\u63cf\u8ff0\u7684\u5185\u5bb9\u53ef\u4ee5\u5e94\u7528\u4e8e\u4efb\u4f55\u7c7b\u578b\u7684\u95ee\u9898\uff1a\u5206\u7c7b\u3001\u56de\u5f52\u3001\u591a\u6807\u7b7e\u5206\u7c7b\u7b49\u3002","title":"\u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5"},{"location":"%E7%BB%84%E5%90%88%E5%92%8C%E5%A0%86%E5%8F%A0%E6%96%B9%E6%B3%95/#_1","text":"\u542c\u5230\u4e0a\u9762\u4e24\u4e2a\u8bcd\uff0c\u6211\u4eec\u9996\u5148\u60f3\u5230\u7684\u5c31\u662f\u5728\u7ebf\uff08online\uff09/\u79bb\u7ebf\uff08offline\uff09\u673a\u5668\u5b66\u4e60\u7ade\u8d5b\u3002\u51e0\u5e74\u524d\u662f\u8fd9\u6837\uff0c\u4f46\u73b0\u5728\u968f\u7740\u8ba1\u7b97\u80fd\u529b\u7684\u8fdb\u6b65\u548c\u865a\u62df\u5b9e\u4f8b\u7684\u5ec9\u4ef7\uff0c\u4eba\u4eec\u751a\u81f3\u5f00\u59cb\u5728\u884c\u4e1a\u4e2d\u4f7f\u7528\u7ec4\u5408\u6a21\u578b\uff08ensemble models\uff09\u3002\u4f8b\u5982\uff0c\u90e8\u7f72\u591a\u4e2a\u795e\u7ecf\u7f51\u7edc\u5e76\u5b9e\u65f6\u4e3a\u5b83\u4eec\u63d0\u4f9b\u670d\u52a1\u975e\u5e38\u5bb9\u6613\uff0c\u54cd\u5e94\u65f6\u95f4\u5c0f\u4e8e 500 \u6beb\u79d2\u3002\u6709\u65f6\uff0c\u4e00\u4e2a\u5e9e\u5927\u7684\u795e\u7ecf\u7f51\u7edc\u6216\u5927\u578b\u6a21\u578b\u4e5f\u53ef\u4ee5\u88ab\u5176\u4ed6\u51e0\u4e2a\u6a21\u578b\u53d6\u4ee3\uff0c\u8fd9\u4e9b\u6a21\u578b\u4f53\u79ef\u5c0f\uff0c\u6027\u80fd\u4e0e\u5927\u578b\u6a21\u578b\u76f8\u4f3c\uff0c\u901f\u5ea6\u5374\u5feb\u4e00\u500d\u3002\u5982\u679c\u662f\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff08\u4e9b\uff09\u6a21\u578b\u5462\uff1f\u6211\u4e2a\u4eba\u66f4\u503e\u5411\u4e8e\u9009\u62e9\u591a\u4e2a\u5c0f\u673a\u578b\uff0c\u5b83\u4eec\u901f\u5ea6\u66f4\u5feb\uff0c\u6027\u80fd\u4e0e\u5927\u673a\u578b\u548c\u6162\u673a\u578b\u76f8\u540c\u3002\u8bf7\u8bb0\u4f4f\uff0c\u8f83\u5c0f\u7684\u578b\u53f7\u4e5f\u66f4\u5bb9\u6613\u548c\u66f4\u5feb\u5730\u8fdb\u884c\u8c03\u6574\u3002 \u7ec4\u5408\uff08ensembling\uff09\u4e0d\u8fc7\u662f\u4e0d\u540c\u6a21\u578b\u7684\u7ec4\u5408\u3002\u6a21\u578b\u53ef\u4ee5\u901a\u8fc7\u9884\u6d4b/\u6982\u7387\u8fdb\u884c\u7ec4\u5408\u3002\u7ec4\u5408\u6a21\u578b\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6c42\u5e73\u5747\u503c\u3002 $$ Ensemble Probabilities = (M1_proba + M2_proba + ... + Mn_Proba)/n $$ \u8fd9\u662f\u6700\u7b80\u5355\u4e5f\u662f\u6700\u6709\u6548\u7684\u7ec4\u5408\u6a21\u578b\u7684\u65b9\u6cd5\u3002\u5728\u7b80\u5355\u5e73\u5747\u6cd5\u4e2d\uff0c\u6240\u6709\u6a21\u578b\u7684\u6743\u91cd\u90fd\u662f\u76f8\u7b49\u7684\u3002\u65e0\u8bba\u91c7\u7528\u54ea\u79cd\u7ec4\u5408\u65b9\u6cd5\uff0c\u60a8\u90fd\u5e94\u8be5\u7262\u8bb0\u4e00\u70b9\uff0c\u90a3\u5c31\u662f\u60a8\u5e94\u8be5\u59cb\u7ec8\u5c06\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b/\u6982\u7387\u7ec4\u5408\u5728\u4e00\u8d77\u3002\u7b80\u5355\u5730\u8bf4\uff0c\u7ec4\u5408\u76f8\u5173\u6027\u4e0d\u9ad8\u7684\u6a21\u578b\u6bd4\u7ec4\u5408\u76f8\u5173\u6027\u5f88\u9ad8\u7684\u6a21\u578b\u6548\u679c\u66f4\u597d\u3002 \u5982\u679c\u6ca1\u6709\u6982\u7387\uff0c\u4e5f\u53ef\u4ee5\u7ec4\u5408\u9884\u6d4b\u3002\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6295\u7968\u3002\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u8fdb\u884c\u591a\u7c7b\u5206\u7c7b\uff0c\u6709\u4e09\u4e2a\u7c7b\u522b\uff1a 0\u30011 \u548c 2\u3002 [0, 0, 1] : \u6700\u9ad8\u7968\u6570\uff1a 0 [0, 1, 2] : \u6700\u9ad8\u7968\u7ea7\uff1a \u65e0\uff08\u968f\u673a\u9009\u62e9\u4e00\u4e2a\uff09 [2, 2, 2] : \u6700\u9ad8\u7968\u6570\uff1a 2 \u4ee5\u4e0b\u7b80\u5355\u51fd\u6570\u53ef\u4ee5\u5b8c\u6210\u8fd9\u4e9b\u7b80\u5355\u64cd\u4f5c\u3002 import numpy as np def mean_predictions ( probas ): # \u8ba1\u7b97\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u6bcf\u884c\u5e73\u5747\u503c return np . mean ( probas , axis = 1 ) def max_voting ( preds ): # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u67e5\u627e\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u7684\u7d22\u5f15 idxs = np . argmax ( preds , axis = 1 ) # \u6839\u636e\u7d22\u5f15\u53d6\u51fa\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u5143\u7d20 return np . take_along_axis ( preds , idxs [:, None ], axis = 1 ) \u8bf7\u6ce8\u610f\uff0cprobas \u7684\u6bcf\u4e00\u5217\u90fd\u53ea\u6709\u4e00\u4e2a\u6982\u7387\uff08\u5373\u4e8c\u5143\u5206\u7c7b\uff0c\u901a\u5e38\u4e3a\u7c7b\u522b 1\uff09\u3002\u56e0\u6b64\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u4e00\u4e2a\u65b0\u6a21\u578b\u3002\u540c\u6837\uff0c\u5bf9\u4e8e preds\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u6765\u81ea\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b\u503c\u3002\u8fd9\u4e24\u4e2a\u51fd\u6570\u90fd\u5047\u8bbe\u4e86\u4e00\u4e2a 2 \u7ef4 numpy \u6570\u7ec4\u3002\u60a8\u53ef\u4ee5\u6839\u636e\u81ea\u5df1\u7684\u9700\u6c42\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u6709\u4e00\u4e2a 2 \u7ef4\u6570\u7ec4\uff0c\u5176\u4e2d\u5305\u542b\u6bcf\u4e2a\u6a21\u578b\u7684\u6982\u7387\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u51fd\u6570\u4f1a\u6709\u4e00\u4e9b\u53d8\u5316\u3002 \u53e6\u4e00\u79cd\u7ec4\u5408\u591a\u4e2a\u6a21\u578b\u7684\u65b9\u6cd5\u662f\u901a\u8fc7\u5b83\u4eec\u7684 \u6982\u7387\u6392\u5e8f \u3002\u5f53\u76f8\u5173\u6307\u6807\u662f\u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u65f6\uff0c\u8fd9\u79cd\u7ec4\u5408\u65b9\u5f0f\u975e\u5e38\u6709\u6548\uff0c\u56e0\u4e3a AUC \u5c31\u662f\u5bf9\u6837\u672c\u8fdb\u884c\u6392\u5e8f\u3002 def rank_mean ( probas ): # \u521b\u5efa\u7a7a\u5217\u8868ranked\u5b58\u50a8\u6bcf\u4e2a\u7c7b\u522b\u6982\u7387\u503c\u6392\u540d ranked = [] # \u904d\u5386\u6982\u7387\u503c\u6bcf\u4e00\u5217\uff08\u6bcf\u4e2a\u7c7b\u522b\u7684\u6982\u7387\u503c\uff09 for i in range ( probas . shape [ 1 ]): # \u5f53\u524d\u5217\u6982\u7387\u503c\u6392\u540d\uff0crank_data\u662f\u6392\u540d\u7ed3\u679c rank_data = stats . rankdata ( probas [:, i ]) # \u5c06\u5f53\u524d\u5217\u6392\u540d\u7ed3\u679c\u6dfb\u52a0\u5230ranked\u5217\u8868\u4e2d ranked . append ( rank_data ) # \u5c06ranked\u5217\u8868\u4e2d\u6392\u540d\u7ed3\u679c\u6309\u5217\u5806\u53e0\uff0c\u5f62\u6210\u4e8c\u7ef4\u6570\u7ec4 ranked = np . column_stack ( ranked ) # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u8ba1\u7b97\u6837\u672c\u6392\u540d\u5e73\u5747\u503c return np . mean ( ranked , axis = 1 ) \u8bf7\u6ce8\u610f\uff0c\u5728 scipy \u7684 rankdata \u4e2d\uff0c\u7b49\u7ea7\u4ece 1 \u5f00\u59cb\u3002 \u4e3a\u4ec0\u4e48\u8fd9\u7c7b\u96c6\u5408\u6709\u6548\uff1f\u8ba9\u6211\u4eec\u770b\u770b\u56fe 1\u3002 \u56fe 1\uff1a\u4e09\u4eba\u731c\u5927\u8c61\u7684\u8eab\u9ad8 \u56fe 1 \u663e\u793a\uff0c\u5982\u679c\u6709\u4e09\u4e2a\u4eba\u5728\u731c\u5927\u8c61\u7684\u9ad8\u5ea6\uff0c\u90a3\u4e48\u539f\u59cb\u9ad8\u5ea6\u5c06\u975e\u5e38\u63a5\u8fd1\u4e09\u4e2a\u4eba\u731c\u6d4b\u7684\u5e73\u5747\u503c\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u4eba\u90fd\u80fd\u731c\u5230\u975e\u5e38\u63a5\u8fd1\u5927\u8c61\u539f\u6765\u7684\u9ad8\u5ea6\u3002\u63a5\u8fd1\u4f30\u8ba1\u503c\u610f\u5473\u7740\u8bef\u5dee\uff0c\u4f46\u5982\u679c\u6211\u4eec\u5c06\u4e09\u4e2a\u9884\u6d4b\u503c\u5e73\u5747\uff0c\u5c31\u80fd\u5c06\u8bef\u5dee\u964d\u5230\u6700\u4f4e\u3002\u8fd9\u5c31\u662f\u591a\u4e2a\u6a21\u578b\u5e73\u5747\u7684\u4e3b\u8981\u601d\u60f3\u3002 $$ Final\\ Probabilities = w_1 \\times M1_proba + w_2 \\times M2_proba + \\cdots + w_n \\times Mn_proba $$ \u5176\u4e2d \\((w_1 + w_2 + w_3 + \\cdots + w_n)=1.0\\) \u4f8b\u5982\uff0c\u5982\u679c\u4f60\u6709\u4e00\u4e2a AUC \u975e\u5e38\u9ad8\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u548c\u4e00\u4e2a AUC \u7a0d\u4f4e\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u4f60\u53ef\u4ee5\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0c\u968f\u673a\u68ee\u6797\u6a21\u578b\u5360 70%\uff0c\u903b\u8f91\u56de\u5f52\u6a21\u578b\u5360 30%\u3002\u90a3\u4e48\uff0c\u6211\u662f\u5982\u4f55\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u7684\u5462\uff1f\u8ba9\u6211\u4eec\u518d\u6dfb\u52a0\u4e00\u4e2a\u6a21\u578b\uff0c\u5047\u8bbe\u73b0\u5728\u6211\u4eec\u4e5f\u6709\u4e00\u4e2a xgboost \u6a21\u578b\uff0c\u5b83\u7684 AUC \u6bd4\u968f\u673a\u68ee\u6797\u9ad8\u3002\u73b0\u5728\uff0c\u6211\u5c06\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0cxgboost\uff1a\u968f\u673a\u68ee\u6797\uff1a\u903b\u8f91\u56de\u5f52\u7684\u6bd4\u4f8b\u4e3a 3:2:1\u3002\u5f88\u7b80\u5355\u5427\uff1f\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u6613\u5982\u53cd\u638c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u662f\u5982\u4f55\u505a\u5230\u7684\u3002 \u5047\u5b9a\u6211\u4eec\u6709\u4e09\u53ea\u7334\u5b50\uff0c\u4e09\u53ea\u65cb\u94ae\u7684\u6570\u503c\u5728 0 \u548c 1 \u4e4b\u95f4\u3002\u8fd9\u4e9b\u7334\u5b50\u8f6c\u52a8\u65cb\u94ae\uff0c\u6211\u4eec\u8ba1\u7b97\u5b83\u4eec\u6bcf\u8f6c\u5230\u4e00\u4e2a\u6570\u503c\u65f6\u7684 AUC \u5206\u6570\u3002\u6700\u7ec8\uff0c\u7334\u5b50\u4eec\u4f1a\u627e\u5230\u4e00\u4e2a\u80fd\u7ed9\u51fa\u6700\u4f73 AUC \u7684\u7ec4\u5408\u3002\u6ca1\u9519\uff0c\u8fd9\u5c31\u662f\u968f\u673a\u641c\u7d22\uff01\u5728\u8fdb\u884c\u8fd9\u7c7b\u641c\u7d22\u4e4b\u524d\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\u4e24\u4e2a\u6700\u91cd\u8981\u7684\u7ec4\u5408\u89c4\u5219\u3002 \u7ec4\u5408\u7684\u7b2c\u4e00\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u7ec4\u5408\u7684\u7b2c\u4e8c\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u662f\u7684\u3002\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u6761\u89c4\u5219\u3002\u7b2c\u4e00\u6b65\u662f\u521b\u5efa\u6298\u53e0\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u5047\u8bbe\u6211\u4eec\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\uff1a\u6298\u53e0 1 \u548c\u6298\u53e0 2\u3002\u8bf7\u6ce8\u610f\uff0c\u8fd9\u6837\u505a\u53ea\u662f\u4e3a\u4e86\u7b80\u5316\u89e3\u91ca\u3002\u5728\u5b9e\u9645\u5e94\u7528\u4e2d\uff0c\u60a8\u5e94\u8be5\u521b\u5efa\u66f4\u591a\u7684\u6298\u53e0\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5728\u6298\u53e0 1 \u4e0a\u8bad\u7ec3\u968f\u673a\u68ee\u6797\u6a21\u578b\u3001\u903b\u8f91\u56de\u5f52\u6a21\u578b\u548c xgboost \u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 2 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u4e4b\u540e\uff0c\u6211\u4eec\u5728\u6298\u53e0 2 \u4e0a\u4ece\u5934\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 1 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u4e3a\u6240\u6709\u8bad\u7ec3\u6570\u636e\u521b\u5efa\u4e86\u9884\u6d4b\u7ed3\u679c\u3002\u73b0\u5728\uff0c\u4e3a\u4e86\u5408\u5e76\u8fd9\u4e9b\u6a21\u578b\uff0c\u6211\u4eec\u5c06\u6298\u53e0 1 \u548c\u6298\u53e0 1 \u7684\u6240\u6709\u9884\u6d4b\u6570\u636e\u5408\u5e76\u5728\u4e00\u8d77\uff0c\u7136\u540e\u521b\u5efa\u4e00\u4e2a\u4f18\u5316\u51fd\u6570\uff0c\u8bd5\u56fe\u627e\u5230\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4fbf\u9488\u5bf9\u6298\u53e0 2 \u7684\u76ee\u6807\u6700\u5c0f\u5316\u8bef\u5dee\u6216\u6700\u5927\u5316 AUC\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u662f\u7528\u4e09\u4e2a\u6a21\u578b\u7684\u9884\u6d4b\u6982\u7387\u5728\u6298\u53e0 1 \u4e0a\u8bad\u7ec3\u4e00\u4e2a\u4f18\u5316\u6a21\u578b\uff0c\u7136\u540e\u5728\u6298\u53e0 2 \u4e0a\u5bf9\u5176\u8fdb\u884c\u8bc4\u4f30\u3002\u8ba9\u6211\u4eec\u5148\u6765\u770b\u770b\u6211\u4eec\u53ef\u4ee5\u7528\u6765\u627e\u5230\u591a\u4e2a\u6a21\u578b\u7684\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4f18\u5316 AUC\uff08\u6216\u4efb\u4f55\u7c7b\u578b\u7684\u9884\u6d4b\u6307\u6807\u7ec4\u5408\uff09\u7684\u7c7b\u3002 import numpy as np from functools import partial from scipy.optimize import fmin from sklearn import metrics class OptimizeAUC : def __init__ ( self ): # \u521d\u59cb\u5316\u7cfb\u6570 self . coef_ = 0 def _auc ( self , coef , X , y ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u7cfb\u6570 x_coef = X * coef # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8ba1\u7b97AUC\u5206\u6570 auc_score = metrics . roc_auc_score ( y , predictions ) # \u8fd4\u56de\u8d1fAUC\u4ee5\u4fbf\u6700\u5c0f\u5316 return - 1.0 * auc_score def fit ( self , X , y ): # \u521b\u5efa\u5e26\u6709\u90e8\u5206\u53c2\u6570\u7684\u76ee\u6807\u51fd\u6570 loss_partial = partial ( self . _auc , X = X , y = y ) # \u521d\u59cb\u5316\u7cfb\u6570 initial_coef = np . random . dirichlet ( np . ones ( X . shape [ 1 ]), size = 1 ) # \u4f7f\u7528fmin\u51fd\u6570\u4f18\u5316AUC\u76ee\u6807\u51fd\u6570\uff0c\u627e\u5230\u6700\u4f18\u7cfb\u6570 self . coef_ = fmin ( loss_partial , initial_coef , disp = True ) def predict ( self , X ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u8bad\u7ec3\u597d\u7684\u7cfb\u6570 x_coef = X * self . coef_ # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8fd4\u56de\u9884\u6d4b\u7ed3\u679c return predictions \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528\u5b83\uff0c\u5e76\u5c06\u5176\u4e0e\u7b80\u5355\u5e73\u5747\u6cd5\u8fdb\u884c\u6bd4\u8f83\u3002 import xgboost as xgb from sklearn.datasets import make_classification from sklearn import ensemble from sklearn import linear_model from sklearn import metrics from sklearn import model_selection # \u751f\u6210\u4e00\u4e2a\u5206\u7c7b\u6570\u636e\u96c6 X , y = make_classification ( n_samples = 10000 , n_features = 25 ) # \u5212\u5206\u6570\u636e\u96c6\u4e3a\u4e24\u4e2a\u4ea4\u53c9\u9a8c\u8bc1\u6298\u53e0 xfold1 , xfold2 , yfold1 , yfold2 = model_selection . train_test_split ( X , y , test_size = 0.5 , stratify = y ) # \u521d\u59cb\u5316\u4e09\u4e2a\u4e0d\u540c\u7684\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold1 , yfold1 ) rf . fit ( xfold1 , yfold1 ) xgbc . fit ( xfold1 , yfold1 ) # \u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold2 )[:, 1 ] pred_rf = rf . predict_proba ( xfold2 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold2 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold2_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold2 = [] for i in range ( fold2_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold2 , fold2_preds [:, i ]) aucs_fold2 . append ( auc ) print ( f \"Fold-2: LR AUC = { aucs_fold2 [ 0 ] } \" ) print ( f \"Fold-2: RF AUC = { aucs_fold2 [ 1 ] } \" ) print ( f \"Fold-2: XGB AUC = { aucs_fold2 [ 2 ] } \" ) print ( f \"Fold-2: Average Pred AUC = { aucs_fold2 [ 3 ] } \" ) # \u91cd\u65b0\u521d\u59cb\u5316\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold2 , yfold2 ) rf . fit ( xfold2 , yfold2 ) xgbc . fit ( xfold2 , yfold2 ) # \u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold1 )[:, 1 ] pred_rf = rf . predict_proba ( xfold1 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold1 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold1_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold1 = [] for i in range ( fold1_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold1 , fold1_preds [:, i ]) aucs_fold1 . append ( auc ) print ( f \"Fold-1: LR AUC = { aucs_fold1 [ 0 ] } \" ) print ( f \"Fold-1: RF AUC = { aucs_fold1 [ 1 ] } \" ) print ( f \"Fold-1: XGB AUC = { aucs_fold1 [ 2 ] } \" ) print ( f \"Fold-1: Average prediction AUC = { aucs_fold1 [ 3 ] } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765\u8bad\u7ec3\u4f18\u5316\u5668 opt . fit ( fold1_preds [:, : - 1 ], yfold1 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold2 = opt . predict ( fold2_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold2 , opt_preds_fold2 ) print ( f \"Optimized AUC, Fold 2 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765 opt . fit ( fold2_preds [:, : - 1 ], yfold2 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold1 = opt . predict ( fold1_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold1 , opt_preds_fold1 ) print ( f \"Optimized AUC, Fold 1 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) \u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u8f93\u51fa\uff1a \u276f python auc_opt . py Fold - 2 : LR AUC = 0.9145446769443348 Fold - 2 : RF AUC = 0.9269918948683287 Fold - 2 : XGB AUC = 0.9302436595508696 Fold - 2 : Average Pred AUC = 0.927701495890154 Fold - 1 : LR AUC = 0.9050872233256017 Fold - 1 : RF AUC = 0.9179382818311258 Fold - 1 : XGB AUC = 0.9195837242005629 Fold - 1 : Average prediction AUC = 0.9189669233123695 Optimization terminated successfully . Current function value : - 0.920643 Iterations : 50 Function evaluations : 109 Optimized AUC , Fold 2 = 0.9305386199756128 Coefficients = [ - 0.00188194 0.19328336 0.35891836 ] Optimization terminated successfully . Current function value : - 0.931232 Iterations : 56 Function evaluations : 113 Optimized AUC , Fold 1 = 0.9192523637234037 Coefficients = [ - 0.15655124 0.22393151 0.58711366 ] \u6211\u4eec\u770b\u5230\uff0c\u5e73\u5747\u503c\u66f4\u597d\uff0c\u4f46\u4f7f\u7528\u4f18\u5316\u5668\u627e\u5230\u9608\u503c\u66f4\u597d\uff01\u6709\u65f6\uff0c\u5e73\u5747\u503c\u662f\u6700\u597d\u7684\u9009\u62e9\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u7cfb\u6570\u52a0\u8d77\u6765\u5e76\u6ca1\u6709\u8fbe\u5230 1.0\uff0c\u4f46\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u6211\u4eec\u8981\u5904\u7406\u7684\u662f AUC\uff0c\u800c AUC \u53ea\u5173\u5fc3\u7b49\u7ea7\u3002 \u5373\u4f7f\u968f\u673a\u68ee\u6797\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u968f\u673a\u68ee\u6797\u53ea\u662f\u8bb8\u591a\u7b80\u5355\u51b3\u7b56\u6811\u7684\u7ec4\u5408\u3002\u968f\u673a\u68ee\u6797\u5c5e\u4e8e\u96c6\u5408\u6a21\u578b\u7684\u4e00\u79cd\uff0c\u4e5f\u5c31\u662f\u4fd7\u79f0\u7684 \"bagging\" \u3002\u5728\u888b\u96c6\u6a21\u578b\u4e2d\uff0c\u6211\u4eec\u521b\u5efa\u5c0f\u6570\u636e\u5b50\u96c6\u5e76\u8bad\u7ec3\u591a\u4e2a\u7b80\u5355\u6a21\u578b\u3002\u6700\u7ec8\u7ed3\u679c\u7531\u6240\u6709\u8fd9\u4e9b\u5c0f\u6a21\u578b\u7684\u9884\u6d4b\u7ed3\u679c\uff08\u5982\u5e73\u5747\u503c\uff09\u7ec4\u5408\u800c\u6210\u3002 \u6211\u4eec\u4f7f\u7528\u7684 xgboost \u6a21\u578b\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u6240\u6709\u68af\u5ea6\u63d0\u5347\u6a21\u578b\u90fd\u662f\u96c6\u5408\u6a21\u578b\uff0c\u7edf\u79f0\u4e3a \u63d0\u5347\u6a21\u578b\uff08boosting models\uff09 \u3002\u63d0\u5347\u6a21\u578b\u7684\u5de5\u4f5c\u539f\u7406\u4e0e\u88c5\u888b\u6a21\u578b\u7c7b\u4f3c\uff0c\u4e0d\u540c\u4e4b\u5904\u5728\u4e8e\u63d0\u5347\u6a21\u578b\u4e2d\u7684\u8fde\u7eed\u6a21\u578b\u662f\u6839\u636e\u8bef\u5dee\u6b8b\u5dee\u8bad\u7ec3\u7684\uff0c\u5e76\u503e\u5411\u4e8e\u6700\u5c0f\u5316\u524d\u9762\u6a21\u578b\u7684\u8bef\u5dee\u3002\u8fd9\u6837\uff0c\u63d0\u5347\u6a21\u578b\u5c31\u80fd\u5b8c\u7f8e\u5730\u5b66\u4e60\u6570\u636e\uff0c\u56e0\u6b64\u5bb9\u6613\u51fa\u73b0\u8fc7\u62df\u5408\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u770b\u5230\u7684\u4ee3\u7801\u7247\u6bb5\u53ea\u8003\u8651\u4e86\u4e00\u5217\u3002\u4f46\u60c5\u51b5\u5e76\u975e\u603b\u662f\u5982\u6b64\uff0c\u5f88\u591a\u65f6\u5019\u60a8\u9700\u8981\u5904\u7406\u591a\u5217\u9884\u6d4b\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u4f1a\u9047\u5230\u4ece\u591a\u4e2a\u7c7b\u522b\u4e2d\u9884\u6d4b\u4e00\u4e2a\u7c7b\u522b\u7684\u95ee\u9898\uff0c\u5373\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5bf9\u4e8e\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\uff0c\u4f60\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u9009\u62e9\u6295\u7968\u65b9\u6cd5\u3002\u4f46\u6295\u7968\u6cd5\u5e76\u4e0d\u603b\u662f\u6700\u4f73\u65b9\u6cd5\u3002\u5982\u679c\u8981\u7ec4\u5408\u6982\u7387\uff0c\u5c31\u4f1a\u6709\u4e00\u4e2a\u4e8c\u7ef4\u6570\u7ec4\uff0c\u800c\u4e0d\u662f\u50cf\u6211\u4eec\u4e4b\u524d\u4f18\u5316 AUC \u65f6\u7684\u5411\u91cf\u3002\u5982\u679c\u6709\u591a\u4e2a\u7c7b\u522b\uff0c\u53ef\u4ee5\u5c1d\u8bd5\u4f18\u5316\u5bf9\u6570\u635f\u5931\uff08\u6216\u5176\u4ed6\u4e0e\u4e1a\u52a1\u76f8\u5173\u7684\u6307\u6807\uff09\u3002 \u8981\u8fdb\u884c\u7ec4\u5408\uff0c\u53ef\u4ee5\u5728\u62df\u5408\u51fd\u6570 (X) \u4e2d\u4f7f\u7528 numpy \u6570\u7ec4\u5217\u8868\u800c\u4e0d\u662f numpy \u6570\u7ec4\uff0c\u968f\u540e\u8fd8\u9700\u8981\u66f4\u6539\u4f18\u5316\u5668\u548c\u9884\u6d4b\u51fd\u6570\u3002\u6211\u5c31\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u5927\u5bb6\u5427\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8fdb\u5165\u4e0b\u4e00\u4e2a\u6709\u8da3\u7684\u8bdd\u9898\uff0c\u8fd9\u4e2a\u8bdd\u9898\u76f8\u5f53\u6d41\u884c\uff0c\u88ab\u79f0\u4e3a \u5806\u53e0 \u3002\u56fe 2 \u5c55\u793a\u4e86\u5982\u4f55\u5806\u53e0\u6a21\u578b\u3002 \u56fe2 : Stacking \u5806\u53e0\u4e0d\u50cf\u5236\u9020\u706b\u7bad\u3002\u5b83\u7b80\u5355\u660e\u4e86\u3002\u5982\u679c\u60a8\u8fdb\u884c\u4e86\u6b63\u786e\u7684\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5e76\u5728\u6574\u4e2a\u5efa\u6a21\u8fc7\u7a0b\u4e2d\u4fdd\u6301\u6298\u53e0\u4e0d\u53d8\uff0c\u90a3\u4e48\u5c31\u4e0d\u4f1a\u51fa\u73b0\u4efb\u4f55\u8fc7\u5ea6\u8d34\u5408\u7684\u60c5\u51b5\u3002 \u8ba9\u6211\u7528\u7b80\u5355\u7684\u8981\u70b9\u5411\u4f60\u63cf\u8ff0\u4e00\u4e0b\u8fd9\u4e2a\u60f3\u6cd5\u3002 - \u5c06\u8bad\u7ec3\u6570\u636e\u5206\u6210\u82e5\u5e72\u6298\u53e0\u3002 - \u8bad\u7ec3\u4e00\u5806\u6a21\u578b\uff1a M1\u3001M2.....Mn\u3002 - \u521b\u5efa\u5b8c\u6574\u7684\u8bad\u7ec3\u9884\u6d4b\uff08\u4f7f\u7528\u975e\u6298\u53e0\u8bad\u7ec3\uff09\uff0c\u5e76\u4f7f\u7528\u6240\u6709\u8fd9\u4e9b\u6a21\u578b\u8fdb\u884c\u6d4b\u8bd5\u9884\u6d4b\u3002 - \u76f4\u5230\u8fd9\u91cc\u662f\u7b2c 1 \u5c42 (L1)\u3002 - \u5c06\u8fd9\u4e9b\u6a21\u578b\u7684\u6298\u53e0\u9884\u6d4b\u4f5c\u4e3a\u53e6\u4e00\u4e2a\u6a21\u578b\u7684\u7279\u5f81\u3002\u8fd9\u5c31\u662f\u4e8c\u7ea7\u6a21\u578b\uff08L2\uff09\u3002 - \u4f7f\u7528\u4e0e\u4e4b\u524d\u76f8\u540c\u7684\u6298\u53e0\u6765\u8bad\u7ec3\u8fd9\u4e2a L2 \u6a21\u578b\u3002 - \u73b0\u5728\uff0c\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u521b\u5efa OOF\uff08\u6298\u53e0\u5916\uff09\u9884\u6d4b\u3002 - \u73b0\u5728\u60a8\u5c31\u6709\u4e86\u8bad\u7ec3\u6570\u636e\u7684 L2 \u9884\u6d4b\u548c\u6700\u7ec8\u6d4b\u8bd5\u96c6\u9884\u6d4b\u3002 \u60a8\u53ef\u4ee5\u4e0d\u65ad\u91cd\u590d L1 \u90e8\u5206\uff0c\u4e5f\u53ef\u4ee5\u521b\u5efa\u4efb\u610f\u591a\u7684\u5c42\u6b21\u3002 \u6709\u65f6\uff0c\u4f60\u8fd8\u4f1a\u9047\u5230\u4e00\u4e2a\u53eb\u6df7\u5408\u7684\u672f\u8bed blending \u3002\u5982\u679c\u4f60\u9047\u5230\u4e86\uff0c\u4e0d\u7528\u592a\u62c5\u5fc3\u3002\u5b83\u53ea\u4e0d\u8fc7\u662f\u7528\u4e00\u4e2a\u4fdd\u7559\u7ec4\u6765\u5806\u53e0\uff0c\u800c\u4e0d\u662f\u591a\u91cd\u6298\u53e0\u3002\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u6211\u5728\u672c\u7ae0\u4e2d\u6240\u63cf\u8ff0\u7684\u5185\u5bb9\u53ef\u4ee5\u5e94\u7528\u4e8e\u4efb\u4f55\u7c7b\u578b\u7684\u95ee\u9898\uff1a\u5206\u7c7b\u3001\u56de\u5f52\u3001\u591a\u6807\u7b7e\u5206\u7c7b\u7b49\u3002","title":"\u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5"},{"location":"%E7%BB%84%E7%BB%87%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E9%A1%B9%E7%9B%AE/","text":"\u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\u9879\u76ee \u7ec8\u4e8e\uff0c\u6211\u4eec\u53ef\u4ee5\u5f00\u59cb\u6784\u5efa\u7b2c\u4e00\u4e2a\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e86\u3002 \u662f\u8fd9\u6837\u5417\uff1f \u5728\u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u6ce8\u610f\u51e0\u4ef6\u4e8b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u6211\u4eec\u5c06\u5728\u96c6\u6210\u5f00\u53d1\u73af\u5883/\u6587\u672c\u7f16\u8f91\u5668\u4e2d\u5de5\u4f5c\uff0c\u800c\u4e0d\u662f\u5728 jupyter notebook\u4e2d\u3002\u4f60\u4e5f\u53ef\u4ee5\u5728 jupyter notebook\u4e2d\u5de5\u4f5c\uff0c\u8fd9\u5b8c\u5168\u53d6\u51b3\u4e8e\u4f60\u3002\u4e0d\u8fc7\uff0c\u6211\u5c06\u53ea\u4f7f\u7528 jupyter notebook\u6765\u63a2\u7d22\u6570\u636e\u3001\u7ed8\u5236\u56fe\u8868\u548c\u56fe\u5f62\u3002\u6211\u4eec\u5c06\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u6784\u5efa\u5206\u7c7b\u6846\u67b6\uff0c\u5373\u63d2\u5373\u7528\u3002\u60a8\u65e0\u9700\u5bf9\u4ee3\u7801\u505a\u592a\u591a\u6539\u52a8\u5c31\u80fd\u8bad\u7ec3\u6a21\u578b\uff0c\u800c\u4e14\u5f53\u60a8\u6539\u8fdb\u6a21\u578b\u65f6\uff0c\u8fd8\u80fd\u4f7f\u7528 git \u5bf9\u5176\u8fdb\u884c\u8ddf\u8e2a\u3002 \u6211\u4eec\u9996\u5148\u6765\u770b\u770b\u6587\u4ef6\u7684\u7ed3\u6784\u3002\u5bf9\u4e8e\u4f60\u6b63\u5728\u505a\u7684\u4efb\u4f55\u9879\u76ee\uff0c\u90fd\u8981\u521b\u5efa\u4e00\u4e2a\u65b0\u6587\u4ef6\u5939\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u5c06\u9879\u76ee\u547d\u540d\u4e3a \"project\"\u3002 \u9879\u76ee\u6587\u4ef6\u5939\u5185\u90e8\u5e94\u8be5\u5982\u4e0b\u6240\u793a\u3002 input train.csv test.csv src create_folds.py train.py inference.py models.py config.py model_dispatcher.py models model_rf.bin model_et.bin notebooks exploration.ipynb check_data.ipynb README.md LICENSE \u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e9b\u6587\u4ef6\u5939\u548c\u6587\u4ef6\u7684\u5185\u5bb9\u3002 input/ \uff1a\u8be5\u6587\u4ef6\u5939\u5305\u542b\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7684\u6240\u6709\u8f93\u5165\u6587\u4ef6\u548c\u6570\u636e\u3002\u5982\u679c\u60a8\u6b63\u5728\u5f00\u53d1 NLP \u9879\u76ee\uff0c\u60a8\u53ef\u4ee5\u5c06embeddings\u653e\u5728\u8fd9\u91cc\u3002\u5982\u679c\u662f\u56fe\u50cf\u9879\u76ee\uff0c\u6240\u6709\u56fe\u50cf\u90fd\u653e\u5728\u8be5\u6587\u4ef6\u5939\u4e0b\u7684\u5b50\u6587\u4ef6\u5939\u4e2d\u3002 src/ \uff1a\u6211\u4eec\u5c06\u5728\u8fd9\u91cc\u4fdd\u5b58\u4e0e\u9879\u76ee\u76f8\u5173\u7684\u6240\u6709 python \u811a\u672c\u3002\u5982\u679c\u6211\u8bf4\u7684\u662f\u4e00\u4e2a python \u811a\u672c\uff0c\u5373\u4efb\u4f55 *.py \u6587\u4ef6\uff0c\u5b83\u90fd\u5b58\u50a8\u5728 src \u6587\u4ef6\u5939\u4e2d\u3002 models/ \uff1a\u8be5\u6587\u4ef6\u5939\u4fdd\u5b58\u6240\u6709\u8bad\u7ec3\u8fc7\u7684\u6a21\u578b\u3002 notebook/ \uff1a\u6240\u6709 jupyter notebook\uff08\u5373\u4efb\u4f55 *.ipynb \u6587\u4ef6\uff09\u90fd\u5b58\u50a8\u5728\u7b14\u8bb0\u672c \u6587\u4ef6\u5939\u4e2d\u3002 README.md \uff1a\u8fd9\u662f\u4e00\u4e2a\u6807\u8bb0\u7b26\u6587\u4ef6\uff0c\u60a8\u53ef\u4ee5\u5728\u5176\u4e2d\u63cf\u8ff0\u60a8\u7684\u9879\u76ee\uff0c\u5e76\u5199\u660e\u5982\u4f55\u8bad\u7ec3\u6a21\u578b\u6216\u5728\u751f\u4ea7\u73af\u5883\u4e2d\u4f7f\u7528\u3002 LICENSE \uff1a\u8fd9\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6587\u672c\u6587\u4ef6\uff0c\u5305\u542b\u9879\u76ee\u7684\u8bb8\u53ef\u8bc1\uff0c\u5982 MIT\u3001Apache \u7b49\u3002\u5173\u4e8e\u8bb8\u53ef\u8bc1\u7684\u8be6\u7ec6\u4ecb\u7ecd\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 \u5047\u8bbe\u4f60\u6b63\u5728\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u5bf9 MNIST \u6570\u636e\u96c6\uff08\u51e0\u4e4e\u6bcf\u672c\u673a\u5668\u5b66\u4e60\u4e66\u7c4d\u90fd\u4f1a\u7528\u5230\u7684\u6570\u636e\u96c6\uff09\u8fdb\u884c\u5206\u7c7b\u3002\u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\uff0c\u6211\u4eec\u5728\u4ea4\u53c9\u68c0\u9a8c\u4e00\u7ae0\u4e2d\u4e5f\u63d0\u5230\u8fc7 MNIST \u6570\u636e\u96c6\u3002\u6240\u4ee5\uff0c\u6211\u5c31\u4e0d\u89e3\u91ca\u8fd9\u4e2a\u6570\u636e\u96c6\u662f\u4ec0\u4e48\u6837\u5b50\u4e86\u3002\u7f51\u4e0a\u6709\u8bb8\u591a\u4e0d\u540c\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6\uff0c\u4f46\u6211\u4eec\u5c06\u4f7f\u7528 CSV \u683c\u5f0f\u7684\u6570\u636e\u96c6\u3002 \u5728\u8fd9\u79cd\u683c\u5f0f\u7684\u6570\u636e\u96c6\u4e2d\uff0cCSV \u7684\u6bcf\u4e00\u884c\u90fd\u5305\u542b\u56fe\u50cf\u7684\u6807\u7b7e\u548c 784 \u4e2a\u50cf\u7d20\u503c\uff0c\u50cf\u7d20\u503c\u8303\u56f4\u4ece 0 \u5230 255\u3002\u6570\u636e\u96c6\u5305\u542b 60000 \u5f20\u8fd9\u79cd\u683c\u5f0f\u7684\u56fe\u50cf\u3002 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u8f7b\u677e\u8bfb\u53d6\u8fd9\u79cd\u6570\u636e\u683c\u5f0f\u3002 \u8bf7\u6ce8\u610f\uff0c\u5c3d\u7ba1\u56fe 1 \u663e\u793a\u6240\u6709\u50cf\u7d20\u503c\u5747\u4e3a\u96f6\uff0c\u4f46\u4e8b\u5b9e\u5e76\u975e\u5982\u6b64\u3002 \u56fe 1\uff1aCSV\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u6807\u7b7e\u5217\u7684\u8ba1\u6570\u3002 \u56fe 2\uff1aMNIST \u6570\u636e\u96c6\u4e2d\u7684\u6807\u7b7e\u8ba1\u6570 \u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c\u66f4\u591a\u7684\u63a2\u7d22\u3002\u6211\u4eec\u5df2\u7ecf\u77e5\u9053\u4e86\u6211\u4eec\u6240\u62e5\u6709\u7684\u6570\u636e\uff0c\u6ca1\u6709\u5fc5\u8981\u518d\u5bf9\u4e0d\u540c\u7684\u50cf\u7d20\u503c\u8fdb\u884c\u7ed8\u56fe\u3002\u4ece\u56fe 2 \u4e2d\u53ef\u4ee5\u6e05\u695a\u5730\u770b\u51fa\uff0c\u6807\u7b7e\u7684\u5206\u5e03\u76f8\u5f53\u5747\u5300\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u51c6\u786e\u7387/F1 \u4f5c\u4e3a\u8861\u91cf\u6807\u51c6\u3002\u8fd9\u5c31\u662f\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u7684\u7b2c\u4e00\u6b65\uff1a\u786e\u5b9a\u8861\u91cf\u6807\u51c6\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7f16\u5199\u4e00\u4e9b\u4ee3\u7801\u4e86\u3002\u6211\u4eec\u9700\u8981\u521b\u5efa src/ \u6587\u4ef6\u5939\u548c\u4e00\u4e9b python \u811a\u672c\u3002 \u8bf7\u6ce8\u610f\uff0c\u8bad\u7ec3 CSV \u6587\u4ef6\u4f4d\u4e8e input/ \u6587\u4ef6\u5939\u4e2d\uff0c\u540d\u4e3a mnist_train.csv \u3002 \u5bf9\u4e8e\u8fd9\u6837\u4e00\u4e2a\u9879\u76ee\uff0c\u8fd9\u4e9b\u6587\u4ef6\u5e94\u8be5\u662f\u4ec0\u4e48\u6837\u7684\u5462\uff1f \u9996\u5148\u8981\u521b\u5efa\u7684\u811a\u672c\u662f create_folds.py \u3002 \u8fd9\u5c06\u5728 input/ \u6587\u4ef6\u5939\u4e2d\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a mnist_train_folds.csv \u7684\u65b0\u6587\u4ef6\uff0c\u4e0e mnist_train.csv \u76f8\u540c\u3002\u552f\u4e00\u4e0d\u540c\u7684\u662f\uff0c\u8fd9\u4e2a CSV \u6587\u4ef6\u7ecf\u8fc7\u4e86\u968f\u673a\u6392\u5e8f\uff0c\u5e76\u65b0\u589e\u4e86\u4e00\u5217\u540d\u4e3a kfold \u7684\u5185\u5bb9\u3002 \u4e00\u65e6\u6211\u4eec\u51b3\u5b9a\u4e86\u8981\u4f7f\u7528\u54ea\u79cd\u8bc4\u4f30\u6307\u6807\u5e76\u521b\u5efa\u4e86\u6298\u53e0\uff0c\u5c31\u53ef\u4ee5\u5f00\u59cb\u521b\u5efa\u57fa\u672c\u6a21\u578b\u4e86\u3002\u8fd9\u53ef\u4ee5\u5728 train.py \u4e2d\u5b8c\u6210\u3002 import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/mnist_train_folds.csv\" ) # \u9009\u53d6df\u4e2dkfold\u5217\u4e0d\u7b49\u4e8efold df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u9009\u53d6df\u4e2dkfold\u5217\u7b49\u4e8efold df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u8bad\u7ec3\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_train = df_train . drop ( \"label\" , axis = 1 ) . values # \u8bad\u7ec3\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_train = df_train . label . values # \u9a8c\u8bc1\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values # \u9a8c\u8bc1\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_valid = df_valid . label . values # \u5b9e\u4f8b\u5316\u51b3\u7b56\u6811\u6a21\u578b clf = tree . DecisionTreeClassifier () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b clf . fit ( x_train , y_train ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u8f93\u5165\u5f97\u5230\u9884\u6d4b\u7ed3\u679c preds = clf . predict ( x_valid ) # \u8ba1\u7b97\u9a8c\u8bc1\u96c6\u51c6\u786e\u7387 accuracy = metrics . accuracy_score ( y_valid , preds ) # \u6253\u5370fold\u4fe1\u606f\u548c\u51c6\u786e\u7387 print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) # \u4fdd\u5b58\u6a21\u578b joblib . dump ( clf , f \"../models/dt_ { fold } .bin\" ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u60a8\u53ef\u4ee5\u5728\u63a7\u5236\u53f0\u8c03\u7528 python train.py \u8fd0\u884c\u8be5\u811a\u672c\u3002 \u276f python train . py Fold = 0 , Accuracy = 0.8680833333333333 Fold = 1 , Accuracy = 0.8685 Fold = 2 , Accuracy = 0.8674166666666666 Fold = 3 , Accuracy = 0.8703333333333333 Fold = 4 , Accuracy = 0.8699166666666667 \u67e5\u770b\u8bad\u7ec3\u811a\u672c\u65f6\uff0c\u60a8\u4f1a\u53d1\u73b0\u8fd8\u6709\u4e00\u4e9b\u5185\u5bb9\u662f\u786c\u7f16\u7801\u7684\uff0c\u4f8b\u5982\u6298\u53e0\u6570\u3001\u8bad\u7ec3\u6587\u4ef6\u548c\u8f93\u51fa\u6587\u4ef6\u5939\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u6240\u6709\u8fd9\u4e9b\u4fe1\u606f\u7684\u914d\u7f6e\u6587\u4ef6\uff1a config.py \u3002 TRAINING_FILE = \"../input/mnist_train_folds.csv\" MODEL_OUTPUT = \"../models/\" \u6211\u4eec\u8fd8\u5bf9\u8bad\u7ec3\u811a\u672c\u8fdb\u884c\u4e86\u4e00\u4e9b\u4fee\u6539\u3002\u8bad\u7ec3\u6587\u4ef6\u73b0\u5728\u4f7f\u7528\u914d\u7f6e\u6587\u4ef6\u3002\u8fd9\u6837\uff0c\u66f4\u6539\u6570\u636e\u6216\u6a21\u578b\u8f93\u51fa\u5c31\u66f4\u5bb9\u6613\u4e86\u3002 import os import config import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u4f7f\u7528config\u4e2d\u7684\u8def\u5f84\u8bfb\u53d6\u6570\u636e df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values clf = tree . DecisionTreeClassifier () clf . fit ( x_train , y_train ) preds = clf . predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" ) ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5e76\u6ca1\u6709\u5c55\u793a\u8fd9\u4e2a\u57f9\u8bad\u811a\u672c\u4e0e\u4e4b\u524d\u811a\u672c\u7684\u533a\u522b\u3002\u8bf7\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e24\u4e2a\u811a\u672c\uff0c\u81ea\u5df1\u627e\u51fa\u4e0d\u540c\u4e4b\u5904\u3002\u533a\u522b\u5e76\u4e0d\u591a\u3002 \u4e0e\u8bad\u7ec3\u811a\u672c\u76f8\u5173\u7684\u8fd8\u6709\u4e00\u70b9\u53ef\u4ee5\u6539\u8fdb\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u6211\u4eec\u4e3a\u6bcf\u4e2a\u6298\u53e0\u591a\u6b21\u8c03\u7528\u8fd0\u884c\u51fd\u6570\u3002\u6709\u65f6\uff0c\u5728\u540c\u4e00\u4e2a\u811a\u672c\u4e2d\u8fd0\u884c\u591a\u4e2a\u6298\u53e0\u5e76\u4e0d\u53ef\u53d6\uff0c\u56e0\u4e3a\u5185\u5b58\u6d88\u8017\u53ef\u80fd\u4f1a\u4e0d\u65ad\u589e\u52a0\uff0c\u7a0b\u5e8f\u53ef\u80fd\u4f1a\u5d29\u6e83\u3002\u4e3a\u4e86\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u5411\u8bad\u7ec3\u811a\u672c\u4f20\u9012\u53c2\u6570\u3002\u6211\u559c\u6b22\u4f7f\u7528 argparse\u3002 import argparse if __name__ == \"__main__\" : # \u5b9e\u4f8b\u5316\u53c2\u6570\u73af\u5883 parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # \u8bfb\u53d6\u53c2\u6570 args = parser . parse_args () run ( fold = args . fold ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u8fd0\u884c python \u811a\u672c\uff0c\u4f46\u4ec5\u9650\u4e8e\u7ed9\u5b9a\u7684\u6298\u53e0\u3002 \u276f python train . py -- fold 0 Fold = 0 , Accuracy = 0.8656666666666667 \u4ed4\u7ec6\u89c2\u5bdf\uff0c\u6211\u4eec\u7684\u7b2c 0 \u6298\u5f97\u5206\u4e0e\u4e4b\u524d\u6709\u4e9b\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6a21\u578b\u4e2d\u5b58\u5728\u968f\u673a\u6027\u3002\u6211\u4eec\u5c06\u5728\u540e\u9762\u7684\u7ae0\u8282\u4e2d\u8ba8\u8bba\u5982\u4f55\u5904\u7406\u968f\u673a\u6027\u3002 \u73b0\u5728\uff0c\u5982\u679c\u4f60\u613f\u610f\uff0c\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a shell \u811a\u672c \uff0c\u9488\u5bf9\u4e0d\u540c\u7684\u6298\u53e0\u4f7f\u7528\u4e0d\u540c\u7684\u547d\u4ee4\uff0c\u7136\u540e\u4e00\u8d77\u8fd0\u884c\uff0c\u5982\u4e0b\u56fe\u6240\u793a\u3002 python train . py -- fold 0 python train . py -- fold 1 python train . py -- fold 2 python train . py -- fold 3 python train . py -- fold 4 \u60a8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u5b83\u3002 \u276f sh run . sh Fold = 0 , Accuracy = 0.8675 Fold = 1 , Accuracy = 0.8693333333333333 Fold = 2 , Accuracy = 0.8683333333333333 Fold = 3 , Accuracy = 0.8704166666666666 Fold = 4 , Accuracy = 0.8685 \u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u53d6\u5f97\u4e86\u4e00\u4e9b\u8fdb\u5c55\uff0c\u4f46\u5982\u679c\u6211\u4eec\u770b\u4e00\u4e0b\u6211\u4eec\u7684\u8bad\u7ec3\u811a\u672c\uff0c\u6211\u4eec\u4ecd\u7136\u53d7\u5230\u4e00\u4e9b\u4e1c\u897f\u7684\u9650\u5236\uff0c\u4f8b\u5982\u6a21\u578b\u3002\u6a21\u578b\u662f\u786c\u7f16\u7801\u5728\u8bad\u7ec3\u811a\u672c\u4e2d\u7684\uff0c\u53ea\u6709\u4fee\u6539\u811a\u672c\u624d\u80fd\u6539\u53d8\u5b83\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u65b0\u7684 python \u811a\u672c\uff0c\u540d\u4e3a model_dispatcher.py \u3002model_dispatcher.py\uff0c\u987e\u540d\u601d\u4e49\uff0c\u5c06\u8c03\u5ea6\u6211\u4eec\u7684\u6a21\u578b\u5230\u8bad\u7ec3\u811a\u672c\u4e2d\u3002 from sklearn import tree models = { # \u4ee5gini\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_gini\" : tree . DecisionTreeClassifier ( criterion = \"gini\" ), # \u4ee5entropy\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), } model_dispatcher.py \u4ece scikit-learn \u4e2d\u5bfc\u5165\u4e86 tree\uff0c\u5e76\u5b9a\u4e49\u4e86\u4e00\u4e2a\u5b57\u5178\uff0c\u5176\u4e2d\u952e\u662f\u6a21\u578b\u7684\u540d\u79f0\uff0c\u503c\u662f\u6a21\u578b\u672c\u8eab\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5b9a\u4e49\u4e86\u4e24\u79cd\u4e0d\u540c\u7684\u51b3\u7b56\u6811\uff0c\u4e00\u79cd\u4f7f\u7528\u57fa\u5c3c\u6807\u51c6\uff0c\u53e6\u4e00\u79cd\u4f7f\u7528\u71b5\u6807\u51c6\u3002\u8981\u4f7f\u7528 py\uff0c\u6211\u4eec\u9700\u8981\u5bf9\u8bad\u7ec3\u811a\u672c\u505a\u4e00\u4e9b\u4fee\u6539\u3002 import argparse import os import joblib import pandas as pd from sklearn import metrics import config import model_dispatcher def run ( fold , model ): df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values # \u6839\u636emodel\u53c2\u6570\u9009\u62e9\u6a21\u578b clf = model_dispatcher . models [ model ] clf . fit ( x_train , y_train ) preds = clf . predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" )) if __name__ == \"__main__\" : parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # model\u53c2\u6570 parser . add_argument ( \"--model\" , type = str ) args = parser . parse_args () run ( fold = args . fold , model = args . model ) train.py \u6709\u51e0\u5904\u91cd\u5927\u6539\u52a8\uff1a - \u5bfc\u5165 model_dispatcher - \u4e3a ArgumentParser \u6dfb\u52a0 --model \u53c2\u6570 - \u4e3a run() \u51fd\u6570\u6dfb\u52a0model\u53c2\u6570 - \u4f7f\u7528\u8c03\u5ea6\u7a0b\u5e8f\u83b7\u53d6\u6307\u5b9a\u540d\u79f0\u7684\u6a21\u578b \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u811a\u672c\uff1a \u276f python train . py -- fold 0 -- model decision_tree_gini Fold = 0 , Accuracy = 0.8665833333333334 \u6216\u6267\u884c\u4ee5\u4e0b\u547d\u4ee4 \u276f python train . py -- fold 0 -- model decision_tree_entropy Fold = 0 , Accuracy = 0.8705833333333334 \u73b0\u5728\uff0c\u5982\u679c\u8981\u6dfb\u52a0\u65b0\u6a21\u578b\uff0c\u53ea\u9700\u4fee\u6539 model_dispatcher.py \u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u6dfb\u52a0\u968f\u673a\u68ee\u6797\uff0c\u770b\u770b\u51c6\u786e\u7387\u4f1a\u6709\u4ec0\u4e48\u53d8\u5316\u3002 from sklearn import ensemble from sklearn import tree models = { \"decision_tree_gini\" : tree . DecisionTreeClassifier ( criterion = \"gini\" ), \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), # \u968f\u673a\u68ee\u6797\u6a21\u578b \"rf\" : ensemble . RandomForestClassifier (), } \u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u3002 \u276f python train . py -- fold 0 -- model rf Fold = 0 , Accuracy = 0.9670833333333333 \u54c7\uff0c\u4e00\u4e2a\u7b80\u5355\u7684\u6539\u52a8\u5c31\u80fd\u8ba9\u5206\u6570\u6709\u5982\u6b64\u5927\u7684\u63d0\u5347\uff01\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528 run.sh \u811a\u672c\u8fd0\u884c 5 \u4e2a\u6298\u53e0\uff01 python train . py -- fold 0 -- model rf python train . py -- fold 1 -- model rf python train . py -- fold 2 -- model rf python train . py -- fold 3 -- model rf python train . py -- fold 4 -- model rf \u5f97\u5206\u60c5\u51b5\u5982\u4e0b \u276f sh run . sh Fold = 0 , Accuracy = 0.9674166666666667 Fold = 1 , Accuracy = 0.9698333333333333 Fold = 2 , Accuracy = 0.96575 Fold = 3 , Accuracy = 0.9684166666666667 Fold = 4 , Accuracy = 0.9666666666666667 MNIST \u51e0\u4e4e\u662f\u6bcf\u672c\u4e66\u548c\u6bcf\u7bc7\u535a\u5ba2\u90fd\u4f1a\u8ba8\u8bba\u7684\u95ee\u9898\u3002\u4f46\u6211\u8bd5\u56fe\u5c06\u8fd9\u4e2a\u95ee\u9898\u8f6c\u6362\u5f97\u66f4\u6709\u8da3\uff0c\u5e76\u5411\u4f60\u5c55\u793a\u5982\u4f55\u4e3a\u4f60\u6b63\u5728\u505a\u7684\u6216\u8ba1\u5212\u5728\u4e0d\u4e45\u7684\u5c06\u6765\u505a\u7684\u51e0\u4e4e\u6240\u6709\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7f16\u5199\u4e00\u4e2a\u57fa\u672c\u6846\u67b6\u3002\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\u53ef\u4ee5\u6539\u8fdb\u8fd9\u4e2a MNIST \u6a21\u578b\u548c\u8fd9\u4e2a\u6846\u67b6\uff0c\u6211\u4eec\u5c06\u5728\u4ee5\u540e\u7684\u7ae0\u8282\u4e2d\u770b\u5230\u3002 \u6211\u4f7f\u7528\u4e86\u4e00\u4e9b\u811a\u672c\uff0c\u5982 model_dispatcher.py \u548c config.py \uff0c\u5e76\u5c06\u5b83\u4eec\u5bfc\u5165\u5230\u6211\u7684\u8bad\u7ec3\u811a\u672c\u4e2d\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u6ca1\u6709\u5bfc\u5165 \uff0c\u4f60\u4e5f\u4e0d\u5e94\u8be5\u5bfc\u5165\u3002\u5982\u679c\u6211\u5bfc\u5165\u4e86 \uff0c\u4f60\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u77e5\u9053\u6a21\u578b\u5b57\u5178\u662f\u4ece\u54ea\u91cc\u6765\u7684\u3002\u7f16\u5199\u4f18\u79c0\u3001\u6613\u61c2\u7684\u4ee3\u7801\u662f\u4e00\u4e2a\u4eba\u5fc5\u987b\u5177\u5907\u7684\u57fa\u672c\u7d20\u8d28\uff0c\u4f46\u8bb8\u591a\u6570\u636e\u79d1\u5b66\u5bb6\u5374\u5ffd\u89c6\u4e86\u8fd9\u4e00\u70b9\u3002\u5982\u679c\u4f60\u6240\u505a\u7684\u9879\u76ee\u80fd\u8ba9\u5176\u4ed6\u4eba\u7406\u89e3\u5e76\u4f7f\u7528\uff0c\u800c\u65e0\u9700\u54a8\u8be2\u4f60\u7684\u610f\u89c1\uff0c\u90a3\u4e48\u4f60\u5c31\u8282\u7701\u4e86\u4ed6\u4eec\u7684\u65f6\u95f4\u548c\u81ea\u5df1\u7684\u65f6\u95f4\uff0c\u53ef\u4ee5\u5c06\u8fd9\u4e9b\u65f6\u95f4\u6295\u5165\u5230\u6539\u8fdb\u4f60\u7684\u9879\u76ee\u6216\u5f00\u53d1\u65b0\u9879\u76ee\u4e2d\u53bb\u3002","title":"\u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\u9879\u76ee"},{"location":"%E7%BB%84%E7%BB%87%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E9%A1%B9%E7%9B%AE/#_1","text":"\u7ec8\u4e8e\uff0c\u6211\u4eec\u53ef\u4ee5\u5f00\u59cb\u6784\u5efa\u7b2c\u4e00\u4e2a\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e86\u3002 \u662f\u8fd9\u6837\u5417\uff1f \u5728\u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u6ce8\u610f\u51e0\u4ef6\u4e8b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u6211\u4eec\u5c06\u5728\u96c6\u6210\u5f00\u53d1\u73af\u5883/\u6587\u672c\u7f16\u8f91\u5668\u4e2d\u5de5\u4f5c\uff0c\u800c\u4e0d\u662f\u5728 jupyter notebook\u4e2d\u3002\u4f60\u4e5f\u53ef\u4ee5\u5728 jupyter notebook\u4e2d\u5de5\u4f5c\uff0c\u8fd9\u5b8c\u5168\u53d6\u51b3\u4e8e\u4f60\u3002\u4e0d\u8fc7\uff0c\u6211\u5c06\u53ea\u4f7f\u7528 jupyter notebook\u6765\u63a2\u7d22\u6570\u636e\u3001\u7ed8\u5236\u56fe\u8868\u548c\u56fe\u5f62\u3002\u6211\u4eec\u5c06\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u6784\u5efa\u5206\u7c7b\u6846\u67b6\uff0c\u5373\u63d2\u5373\u7528\u3002\u60a8\u65e0\u9700\u5bf9\u4ee3\u7801\u505a\u592a\u591a\u6539\u52a8\u5c31\u80fd\u8bad\u7ec3\u6a21\u578b\uff0c\u800c\u4e14\u5f53\u60a8\u6539\u8fdb\u6a21\u578b\u65f6\uff0c\u8fd8\u80fd\u4f7f\u7528 git \u5bf9\u5176\u8fdb\u884c\u8ddf\u8e2a\u3002 \u6211\u4eec\u9996\u5148\u6765\u770b\u770b\u6587\u4ef6\u7684\u7ed3\u6784\u3002\u5bf9\u4e8e\u4f60\u6b63\u5728\u505a\u7684\u4efb\u4f55\u9879\u76ee\uff0c\u90fd\u8981\u521b\u5efa\u4e00\u4e2a\u65b0\u6587\u4ef6\u5939\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u5c06\u9879\u76ee\u547d\u540d\u4e3a \"project\"\u3002 \u9879\u76ee\u6587\u4ef6\u5939\u5185\u90e8\u5e94\u8be5\u5982\u4e0b\u6240\u793a\u3002 input train.csv test.csv src create_folds.py train.py inference.py models.py config.py model_dispatcher.py models model_rf.bin model_et.bin notebooks exploration.ipynb check_data.ipynb README.md LICENSE \u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e9b\u6587\u4ef6\u5939\u548c\u6587\u4ef6\u7684\u5185\u5bb9\u3002 input/ \uff1a\u8be5\u6587\u4ef6\u5939\u5305\u542b\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7684\u6240\u6709\u8f93\u5165\u6587\u4ef6\u548c\u6570\u636e\u3002\u5982\u679c\u60a8\u6b63\u5728\u5f00\u53d1 NLP \u9879\u76ee\uff0c\u60a8\u53ef\u4ee5\u5c06embeddings\u653e\u5728\u8fd9\u91cc\u3002\u5982\u679c\u662f\u56fe\u50cf\u9879\u76ee\uff0c\u6240\u6709\u56fe\u50cf\u90fd\u653e\u5728\u8be5\u6587\u4ef6\u5939\u4e0b\u7684\u5b50\u6587\u4ef6\u5939\u4e2d\u3002 src/ \uff1a\u6211\u4eec\u5c06\u5728\u8fd9\u91cc\u4fdd\u5b58\u4e0e\u9879\u76ee\u76f8\u5173\u7684\u6240\u6709 python \u811a\u672c\u3002\u5982\u679c\u6211\u8bf4\u7684\u662f\u4e00\u4e2a python \u811a\u672c\uff0c\u5373\u4efb\u4f55 *.py \u6587\u4ef6\uff0c\u5b83\u90fd\u5b58\u50a8\u5728 src \u6587\u4ef6\u5939\u4e2d\u3002 models/ \uff1a\u8be5\u6587\u4ef6\u5939\u4fdd\u5b58\u6240\u6709\u8bad\u7ec3\u8fc7\u7684\u6a21\u578b\u3002 notebook/ \uff1a\u6240\u6709 jupyter notebook\uff08\u5373\u4efb\u4f55 *.ipynb \u6587\u4ef6\uff09\u90fd\u5b58\u50a8\u5728\u7b14\u8bb0\u672c \u6587\u4ef6\u5939\u4e2d\u3002 README.md \uff1a\u8fd9\u662f\u4e00\u4e2a\u6807\u8bb0\u7b26\u6587\u4ef6\uff0c\u60a8\u53ef\u4ee5\u5728\u5176\u4e2d\u63cf\u8ff0\u60a8\u7684\u9879\u76ee\uff0c\u5e76\u5199\u660e\u5982\u4f55\u8bad\u7ec3\u6a21\u578b\u6216\u5728\u751f\u4ea7\u73af\u5883\u4e2d\u4f7f\u7528\u3002 LICENSE \uff1a\u8fd9\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6587\u672c\u6587\u4ef6\uff0c\u5305\u542b\u9879\u76ee\u7684\u8bb8\u53ef\u8bc1\uff0c\u5982 MIT\u3001Apache \u7b49\u3002\u5173\u4e8e\u8bb8\u53ef\u8bc1\u7684\u8be6\u7ec6\u4ecb\u7ecd\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 \u5047\u8bbe\u4f60\u6b63\u5728\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u5bf9 MNIST \u6570\u636e\u96c6\uff08\u51e0\u4e4e\u6bcf\u672c\u673a\u5668\u5b66\u4e60\u4e66\u7c4d\u90fd\u4f1a\u7528\u5230\u7684\u6570\u636e\u96c6\uff09\u8fdb\u884c\u5206\u7c7b\u3002\u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\uff0c\u6211\u4eec\u5728\u4ea4\u53c9\u68c0\u9a8c\u4e00\u7ae0\u4e2d\u4e5f\u63d0\u5230\u8fc7 MNIST \u6570\u636e\u96c6\u3002\u6240\u4ee5\uff0c\u6211\u5c31\u4e0d\u89e3\u91ca\u8fd9\u4e2a\u6570\u636e\u96c6\u662f\u4ec0\u4e48\u6837\u5b50\u4e86\u3002\u7f51\u4e0a\u6709\u8bb8\u591a\u4e0d\u540c\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6\uff0c\u4f46\u6211\u4eec\u5c06\u4f7f\u7528 CSV \u683c\u5f0f\u7684\u6570\u636e\u96c6\u3002 \u5728\u8fd9\u79cd\u683c\u5f0f\u7684\u6570\u636e\u96c6\u4e2d\uff0cCSV \u7684\u6bcf\u4e00\u884c\u90fd\u5305\u542b\u56fe\u50cf\u7684\u6807\u7b7e\u548c 784 \u4e2a\u50cf\u7d20\u503c\uff0c\u50cf\u7d20\u503c\u8303\u56f4\u4ece 0 \u5230 255\u3002\u6570\u636e\u96c6\u5305\u542b 60000 \u5f20\u8fd9\u79cd\u683c\u5f0f\u7684\u56fe\u50cf\u3002 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u8f7b\u677e\u8bfb\u53d6\u8fd9\u79cd\u6570\u636e\u683c\u5f0f\u3002 \u8bf7\u6ce8\u610f\uff0c\u5c3d\u7ba1\u56fe 1 \u663e\u793a\u6240\u6709\u50cf\u7d20\u503c\u5747\u4e3a\u96f6\uff0c\u4f46\u4e8b\u5b9e\u5e76\u975e\u5982\u6b64\u3002 \u56fe 1\uff1aCSV\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u6807\u7b7e\u5217\u7684\u8ba1\u6570\u3002 \u56fe 2\uff1aMNIST \u6570\u636e\u96c6\u4e2d\u7684\u6807\u7b7e\u8ba1\u6570 \u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c\u66f4\u591a\u7684\u63a2\u7d22\u3002\u6211\u4eec\u5df2\u7ecf\u77e5\u9053\u4e86\u6211\u4eec\u6240\u62e5\u6709\u7684\u6570\u636e\uff0c\u6ca1\u6709\u5fc5\u8981\u518d\u5bf9\u4e0d\u540c\u7684\u50cf\u7d20\u503c\u8fdb\u884c\u7ed8\u56fe\u3002\u4ece\u56fe 2 \u4e2d\u53ef\u4ee5\u6e05\u695a\u5730\u770b\u51fa\uff0c\u6807\u7b7e\u7684\u5206\u5e03\u76f8\u5f53\u5747\u5300\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u51c6\u786e\u7387/F1 \u4f5c\u4e3a\u8861\u91cf\u6807\u51c6\u3002\u8fd9\u5c31\u662f\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u7684\u7b2c\u4e00\u6b65\uff1a\u786e\u5b9a\u8861\u91cf\u6807\u51c6\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7f16\u5199\u4e00\u4e9b\u4ee3\u7801\u4e86\u3002\u6211\u4eec\u9700\u8981\u521b\u5efa src/ \u6587\u4ef6\u5939\u548c\u4e00\u4e9b python \u811a\u672c\u3002 \u8bf7\u6ce8\u610f\uff0c\u8bad\u7ec3 CSV \u6587\u4ef6\u4f4d\u4e8e input/ \u6587\u4ef6\u5939\u4e2d\uff0c\u540d\u4e3a mnist_train.csv \u3002 \u5bf9\u4e8e\u8fd9\u6837\u4e00\u4e2a\u9879\u76ee\uff0c\u8fd9\u4e9b\u6587\u4ef6\u5e94\u8be5\u662f\u4ec0\u4e48\u6837\u7684\u5462\uff1f \u9996\u5148\u8981\u521b\u5efa\u7684\u811a\u672c\u662f create_folds.py \u3002 \u8fd9\u5c06\u5728 input/ \u6587\u4ef6\u5939\u4e2d\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a mnist_train_folds.csv \u7684\u65b0\u6587\u4ef6\uff0c\u4e0e mnist_train.csv \u76f8\u540c\u3002\u552f\u4e00\u4e0d\u540c\u7684\u662f\uff0c\u8fd9\u4e2a CSV \u6587\u4ef6\u7ecf\u8fc7\u4e86\u968f\u673a\u6392\u5e8f\uff0c\u5e76\u65b0\u589e\u4e86\u4e00\u5217\u540d\u4e3a kfold \u7684\u5185\u5bb9\u3002 \u4e00\u65e6\u6211\u4eec\u51b3\u5b9a\u4e86\u8981\u4f7f\u7528\u54ea\u79cd\u8bc4\u4f30\u6307\u6807\u5e76\u521b\u5efa\u4e86\u6298\u53e0\uff0c\u5c31\u53ef\u4ee5\u5f00\u59cb\u521b\u5efa\u57fa\u672c\u6a21\u578b\u4e86\u3002\u8fd9\u53ef\u4ee5\u5728 train.py \u4e2d\u5b8c\u6210\u3002 import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/mnist_train_folds.csv\" ) # \u9009\u53d6df\u4e2dkfold\u5217\u4e0d\u7b49\u4e8efold df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u9009\u53d6df\u4e2dkfold\u5217\u7b49\u4e8efold df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u8bad\u7ec3\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_train = df_train . drop ( \"label\" , axis = 1 ) . values # \u8bad\u7ec3\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_train = df_train . label . values # \u9a8c\u8bc1\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values # \u9a8c\u8bc1\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_valid = df_valid . label . values # \u5b9e\u4f8b\u5316\u51b3\u7b56\u6811\u6a21\u578b clf = tree . DecisionTreeClassifier () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b clf . fit ( x_train , y_train ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u8f93\u5165\u5f97\u5230\u9884\u6d4b\u7ed3\u679c preds = clf . predict ( x_valid ) # \u8ba1\u7b97\u9a8c\u8bc1\u96c6\u51c6\u786e\u7387 accuracy = metrics . accuracy_score ( y_valid , preds ) # \u6253\u5370fold\u4fe1\u606f\u548c\u51c6\u786e\u7387 print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) # \u4fdd\u5b58\u6a21\u578b joblib . dump ( clf , f \"../models/dt_ { fold } .bin\" ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u60a8\u53ef\u4ee5\u5728\u63a7\u5236\u53f0\u8c03\u7528 python train.py \u8fd0\u884c\u8be5\u811a\u672c\u3002 \u276f python train . py Fold = 0 , Accuracy = 0.8680833333333333 Fold = 1 , Accuracy = 0.8685 Fold = 2 , Accuracy = 0.8674166666666666 Fold = 3 , Accuracy = 0.8703333333333333 Fold = 4 , Accuracy = 0.8699166666666667 \u67e5\u770b\u8bad\u7ec3\u811a\u672c\u65f6\uff0c\u60a8\u4f1a\u53d1\u73b0\u8fd8\u6709\u4e00\u4e9b\u5185\u5bb9\u662f\u786c\u7f16\u7801\u7684\uff0c\u4f8b\u5982\u6298\u53e0\u6570\u3001\u8bad\u7ec3\u6587\u4ef6\u548c\u8f93\u51fa\u6587\u4ef6\u5939\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u6240\u6709\u8fd9\u4e9b\u4fe1\u606f\u7684\u914d\u7f6e\u6587\u4ef6\uff1a config.py \u3002 TRAINING_FILE = \"../input/mnist_train_folds.csv\" MODEL_OUTPUT = \"../models/\" \u6211\u4eec\u8fd8\u5bf9\u8bad\u7ec3\u811a\u672c\u8fdb\u884c\u4e86\u4e00\u4e9b\u4fee\u6539\u3002\u8bad\u7ec3\u6587\u4ef6\u73b0\u5728\u4f7f\u7528\u914d\u7f6e\u6587\u4ef6\u3002\u8fd9\u6837\uff0c\u66f4\u6539\u6570\u636e\u6216\u6a21\u578b\u8f93\u51fa\u5c31\u66f4\u5bb9\u6613\u4e86\u3002 import os import config import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u4f7f\u7528config\u4e2d\u7684\u8def\u5f84\u8bfb\u53d6\u6570\u636e df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values clf = tree . DecisionTreeClassifier () clf . fit ( x_train , y_train ) preds = clf . predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" ) ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5e76\u6ca1\u6709\u5c55\u793a\u8fd9\u4e2a\u57f9\u8bad\u811a\u672c\u4e0e\u4e4b\u524d\u811a\u672c\u7684\u533a\u522b\u3002\u8bf7\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e24\u4e2a\u811a\u672c\uff0c\u81ea\u5df1\u627e\u51fa\u4e0d\u540c\u4e4b\u5904\u3002\u533a\u522b\u5e76\u4e0d\u591a\u3002 \u4e0e\u8bad\u7ec3\u811a\u672c\u76f8\u5173\u7684\u8fd8\u6709\u4e00\u70b9\u53ef\u4ee5\u6539\u8fdb\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u6211\u4eec\u4e3a\u6bcf\u4e2a\u6298\u53e0\u591a\u6b21\u8c03\u7528\u8fd0\u884c\u51fd\u6570\u3002\u6709\u65f6\uff0c\u5728\u540c\u4e00\u4e2a\u811a\u672c\u4e2d\u8fd0\u884c\u591a\u4e2a\u6298\u53e0\u5e76\u4e0d\u53ef\u53d6\uff0c\u56e0\u4e3a\u5185\u5b58\u6d88\u8017\u53ef\u80fd\u4f1a\u4e0d\u65ad\u589e\u52a0\uff0c\u7a0b\u5e8f\u53ef\u80fd\u4f1a\u5d29\u6e83\u3002\u4e3a\u4e86\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u5411\u8bad\u7ec3\u811a\u672c\u4f20\u9012\u53c2\u6570\u3002\u6211\u559c\u6b22\u4f7f\u7528 argparse\u3002 import argparse if __name__ == \"__main__\" : # \u5b9e\u4f8b\u5316\u53c2\u6570\u73af\u5883 parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # \u8bfb\u53d6\u53c2\u6570 args = parser . parse_args () run ( fold = args . fold ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u8fd0\u884c python \u811a\u672c\uff0c\u4f46\u4ec5\u9650\u4e8e\u7ed9\u5b9a\u7684\u6298\u53e0\u3002 \u276f python train . py -- fold 0 Fold = 0 , Accuracy = 0.8656666666666667 \u4ed4\u7ec6\u89c2\u5bdf\uff0c\u6211\u4eec\u7684\u7b2c 0 \u6298\u5f97\u5206\u4e0e\u4e4b\u524d\u6709\u4e9b\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6a21\u578b\u4e2d\u5b58\u5728\u968f\u673a\u6027\u3002\u6211\u4eec\u5c06\u5728\u540e\u9762\u7684\u7ae0\u8282\u4e2d\u8ba8\u8bba\u5982\u4f55\u5904\u7406\u968f\u673a\u6027\u3002 \u73b0\u5728\uff0c\u5982\u679c\u4f60\u613f\u610f\uff0c\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a shell \u811a\u672c \uff0c\u9488\u5bf9\u4e0d\u540c\u7684\u6298\u53e0\u4f7f\u7528\u4e0d\u540c\u7684\u547d\u4ee4\uff0c\u7136\u540e\u4e00\u8d77\u8fd0\u884c\uff0c\u5982\u4e0b\u56fe\u6240\u793a\u3002 python train . py -- fold 0 python train . py -- fold 1 python train . py -- fold 2 python train . py -- fold 3 python train . py -- fold 4 \u60a8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u5b83\u3002 \u276f sh run . sh Fold = 0 , Accuracy = 0.8675 Fold = 1 , Accuracy = 0.8693333333333333 Fold = 2 , Accuracy = 0.8683333333333333 Fold = 3 , Accuracy = 0.8704166666666666 Fold = 4 , Accuracy = 0.8685 \u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u53d6\u5f97\u4e86\u4e00\u4e9b\u8fdb\u5c55\uff0c\u4f46\u5982\u679c\u6211\u4eec\u770b\u4e00\u4e0b\u6211\u4eec\u7684\u8bad\u7ec3\u811a\u672c\uff0c\u6211\u4eec\u4ecd\u7136\u53d7\u5230\u4e00\u4e9b\u4e1c\u897f\u7684\u9650\u5236\uff0c\u4f8b\u5982\u6a21\u578b\u3002\u6a21\u578b\u662f\u786c\u7f16\u7801\u5728\u8bad\u7ec3\u811a\u672c\u4e2d\u7684\uff0c\u53ea\u6709\u4fee\u6539\u811a\u672c\u624d\u80fd\u6539\u53d8\u5b83\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u65b0\u7684 python \u811a\u672c\uff0c\u540d\u4e3a model_dispatcher.py \u3002model_dispatcher.py\uff0c\u987e\u540d\u601d\u4e49\uff0c\u5c06\u8c03\u5ea6\u6211\u4eec\u7684\u6a21\u578b\u5230\u8bad\u7ec3\u811a\u672c\u4e2d\u3002 from sklearn import tree models = { # \u4ee5gini\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_gini\" : tree . DecisionTreeClassifier ( criterion = \"gini\" ), # \u4ee5entropy\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), } model_dispatcher.py \u4ece scikit-learn \u4e2d\u5bfc\u5165\u4e86 tree\uff0c\u5e76\u5b9a\u4e49\u4e86\u4e00\u4e2a\u5b57\u5178\uff0c\u5176\u4e2d\u952e\u662f\u6a21\u578b\u7684\u540d\u79f0\uff0c\u503c\u662f\u6a21\u578b\u672c\u8eab\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5b9a\u4e49\u4e86\u4e24\u79cd\u4e0d\u540c\u7684\u51b3\u7b56\u6811\uff0c\u4e00\u79cd\u4f7f\u7528\u57fa\u5c3c\u6807\u51c6\uff0c\u53e6\u4e00\u79cd\u4f7f\u7528\u71b5\u6807\u51c6\u3002\u8981\u4f7f\u7528 py\uff0c\u6211\u4eec\u9700\u8981\u5bf9\u8bad\u7ec3\u811a\u672c\u505a\u4e00\u4e9b\u4fee\u6539\u3002 import argparse import os import joblib import pandas as pd from sklearn import metrics import config import model_dispatcher def run ( fold , model ): df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values # \u6839\u636emodel\u53c2\u6570\u9009\u62e9\u6a21\u578b clf = model_dispatcher . models [ model ] clf . fit ( x_train , y_train ) preds = clf . predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" )) if __name__ == \"__main__\" : parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # model\u53c2\u6570 parser . add_argument ( \"--model\" , type = str ) args = parser . parse_args () run ( fold = args . fold , model = args . model ) train.py \u6709\u51e0\u5904\u91cd\u5927\u6539\u52a8\uff1a - \u5bfc\u5165 model_dispatcher - \u4e3a ArgumentParser \u6dfb\u52a0 --model \u53c2\u6570 - \u4e3a run() \u51fd\u6570\u6dfb\u52a0model\u53c2\u6570 - \u4f7f\u7528\u8c03\u5ea6\u7a0b\u5e8f\u83b7\u53d6\u6307\u5b9a\u540d\u79f0\u7684\u6a21\u578b \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u811a\u672c\uff1a \u276f python train . py -- fold 0 -- model decision_tree_gini Fold = 0 , Accuracy = 0.8665833333333334 \u6216\u6267\u884c\u4ee5\u4e0b\u547d\u4ee4 \u276f python train . py -- fold 0 -- model decision_tree_entropy Fold = 0 , Accuracy = 0.8705833333333334 \u73b0\u5728\uff0c\u5982\u679c\u8981\u6dfb\u52a0\u65b0\u6a21\u578b\uff0c\u53ea\u9700\u4fee\u6539 model_dispatcher.py \u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u6dfb\u52a0\u968f\u673a\u68ee\u6797\uff0c\u770b\u770b\u51c6\u786e\u7387\u4f1a\u6709\u4ec0\u4e48\u53d8\u5316\u3002 from sklearn import ensemble from sklearn import tree models = { \"decision_tree_gini\" : tree . DecisionTreeClassifier ( criterion = \"gini\" ), \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), # \u968f\u673a\u68ee\u6797\u6a21\u578b \"rf\" : ensemble . RandomForestClassifier (), } \u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u3002 \u276f python train . py -- fold 0 -- model rf Fold = 0 , Accuracy = 0.9670833333333333 \u54c7\uff0c\u4e00\u4e2a\u7b80\u5355\u7684\u6539\u52a8\u5c31\u80fd\u8ba9\u5206\u6570\u6709\u5982\u6b64\u5927\u7684\u63d0\u5347\uff01\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528 run.sh \u811a\u672c\u8fd0\u884c 5 \u4e2a\u6298\u53e0\uff01 python train . py -- fold 0 -- model rf python train . py -- fold 1 -- model rf python train . py -- fold 2 -- model rf python train . py -- fold 3 -- model rf python train . py -- fold 4 -- model rf \u5f97\u5206\u60c5\u51b5\u5982\u4e0b \u276f sh run . sh Fold = 0 , Accuracy = 0.9674166666666667 Fold = 1 , Accuracy = 0.9698333333333333 Fold = 2 , Accuracy = 0.96575 Fold = 3 , Accuracy = 0.9684166666666667 Fold = 4 , Accuracy = 0.9666666666666667 MNIST \u51e0\u4e4e\u662f\u6bcf\u672c\u4e66\u548c\u6bcf\u7bc7\u535a\u5ba2\u90fd\u4f1a\u8ba8\u8bba\u7684\u95ee\u9898\u3002\u4f46\u6211\u8bd5\u56fe\u5c06\u8fd9\u4e2a\u95ee\u9898\u8f6c\u6362\u5f97\u66f4\u6709\u8da3\uff0c\u5e76\u5411\u4f60\u5c55\u793a\u5982\u4f55\u4e3a\u4f60\u6b63\u5728\u505a\u7684\u6216\u8ba1\u5212\u5728\u4e0d\u4e45\u7684\u5c06\u6765\u505a\u7684\u51e0\u4e4e\u6240\u6709\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7f16\u5199\u4e00\u4e2a\u57fa\u672c\u6846\u67b6\u3002\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\u53ef\u4ee5\u6539\u8fdb\u8fd9\u4e2a MNIST \u6a21\u578b\u548c\u8fd9\u4e2a\u6846\u67b6\uff0c\u6211\u4eec\u5c06\u5728\u4ee5\u540e\u7684\u7ae0\u8282\u4e2d\u770b\u5230\u3002 \u6211\u4f7f\u7528\u4e86\u4e00\u4e9b\u811a\u672c\uff0c\u5982 model_dispatcher.py \u548c config.py \uff0c\u5e76\u5c06\u5b83\u4eec\u5bfc\u5165\u5230\u6211\u7684\u8bad\u7ec3\u811a\u672c\u4e2d\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u6ca1\u6709\u5bfc\u5165 \uff0c\u4f60\u4e5f\u4e0d\u5e94\u8be5\u5bfc\u5165\u3002\u5982\u679c\u6211\u5bfc\u5165\u4e86 \uff0c\u4f60\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u77e5\u9053\u6a21\u578b\u5b57\u5178\u662f\u4ece\u54ea\u91cc\u6765\u7684\u3002\u7f16\u5199\u4f18\u79c0\u3001\u6613\u61c2\u7684\u4ee3\u7801\u662f\u4e00\u4e2a\u4eba\u5fc5\u987b\u5177\u5907\u7684\u57fa\u672c\u7d20\u8d28\uff0c\u4f46\u8bb8\u591a\u6570\u636e\u79d1\u5b66\u5bb6\u5374\u5ffd\u89c6\u4e86\u8fd9\u4e00\u70b9\u3002\u5982\u679c\u4f60\u6240\u505a\u7684\u9879\u76ee\u80fd\u8ba9\u5176\u4ed6\u4eba\u7406\u89e3\u5e76\u4f7f\u7528\uff0c\u800c\u65e0\u9700\u54a8\u8be2\u4f60\u7684\u610f\u89c1\uff0c\u90a3\u4e48\u4f60\u5c31\u8282\u7701\u4e86\u4ed6\u4eec\u7684\u65f6\u95f4\u548c\u81ea\u5df1\u7684\u65f6\u95f4\uff0c\u53ef\u4ee5\u5c06\u8fd9\u4e9b\u65f6\u95f4\u6295\u5165\u5230\u6539\u8fdb\u4f60\u7684\u9879\u76ee\u6216\u5f00\u53d1\u65b0\u9879\u76ee\u4e2d\u53bb\u3002","title":"\u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\u9879\u76ee"},{"location":"%E8%AF%84%E4%BC%B0%E6%8C%87%E6%A0%87/","text":"\u8bc4\u4f30\u6307\u6807 \u8bf4\u5230\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u4f60\u4f1a\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\u9047\u5230\u5f88\u591a\u4e0d\u540c\u7c7b\u578b\u7684\u6307\u6807\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u751a\u81f3\u4f1a\u6839\u636e\u4e1a\u52a1\u95ee\u9898\u521b\u5efa\u5ea6\u91cf\u6807\u51c6\u3002\u9010\u4e00\u4ecb\u7ecd\u548c\u89e3\u91ca\u6bcf\u4e00\u79cd\u5ea6\u91cf\u7c7b\u578b\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002\u76f8\u53cd\uff0c\u6211\u4eec\u5c06\u4ecb\u7ecd\u4e00\u4e9b\u6700\u5e38\u89c1\u7684\u5ea6\u91cf\u6807\u51c6\uff0c\u4f9b\u4f60\u5728\u6700\u521d\u7684\u51e0\u4e2a\u9879\u76ee\u4e2d\u4f7f\u7528\u3002 \u5728\u672c\u4e66\u7684\u5f00\u5934\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u76d1\u7763\u5b66\u4e60\u548c\u975e\u76d1\u7763\u5b66\u4e60\u3002\u867d\u7136\u65e0\u76d1\u7763\u5b66\u4e60\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e9b\u6307\u6807\uff0c\u4f46\u6211\u4eec\u5c06\u53ea\u5173\u6ce8\u6709\u76d1\u7763\u5b66\u4e60\u3002\u8fd9\u662f\u56e0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u591a\uff0c\u800c\u4e14\u5bf9\u65e0\u76d1\u7763\u65b9\u6cd5\u7684\u8bc4\u4f30\u76f8\u5f53\u4e3b\u89c2\u3002 \u5982\u679c\u6211\u4eec\u8c08\u8bba\u5206\u7c7b\u95ee\u9898\uff0c\u6700\u5e38\u7528\u7684\u6307\u6807\u662f\uff1a \u51c6\u786e\u7387\uff08Accuracy\uff09 \u7cbe\u786e\u7387\uff08P\uff09 \u53ec\u56de\u7387\uff08R\uff09 F1 \u5206\u6570\uff08F1\uff09 AUC\uff08AUC\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u8bf4\u5230\u56de\u5f52\uff0c\u6700\u5e38\u7528\u7684\u8bc4\u4ef7\u6307\u6807\u662f \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee \uff08MAE\uff09 \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u5747\u65b9\u6839\u8bef\u5dee \uff08RMSE\uff09 \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \uff08RMSLE\uff09 \u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee \uff08MPE\uff09 \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee \uff08MAPE\uff09 R2 \u4e86\u89e3\u4e0a\u8ff0\u6307\u6807\u7684\u5de5\u4f5c\u539f\u7406\u5e76\u4e0d\u662f\u6211\u4eec\u5fc5\u987b\u4e86\u89e3\u7684\u552f\u4e00\u4e8b\u60c5\u3002\u6211\u4eec\u8fd8\u5fc5\u987b\u77e5\u9053\u4f55\u65f6\u4f7f\u7528\u54ea\u4e9b\u6307\u6807\uff0c\u800c\u8fd9\u53d6\u51b3\u4e8e\u4f60\u6709\u4ec0\u4e48\u6837\u7684\u6570\u636e\u548c\u76ee\u6807\u3002\u6211\u8ba4\u4e3a\u8fd9\u4e0e\u76ee\u6807\u6709\u5173\uff0c\u800c\u4e0e\u6570\u636e\u65e0\u5173\u3002 \u8981\u8fdb\u4e00\u6b65\u4e86\u89e3\u8fd9\u4e9b\u6307\u6807\uff0c\u8ba9\u6211\u4eec\u4ece\u4e00\u4e2a\u7b80\u5355\u7684\u95ee\u9898\u5f00\u59cb\u3002\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a \u4e8c\u5143\u5206\u7c7b \u95ee\u9898\uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u76ee\u6807\u7684\u95ee\u9898\uff0c\u5047\u8bbe\u8fd9\u662f\u4e00\u4e2a\u80f8\u90e8 X \u5149\u56fe\u50cf\u5206\u7c7b\u95ee\u9898\u3002\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6ca1\u6709\u95ee\u9898\uff0c\u800c\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6709\u80ba\u584c\u9677\uff0c\u4e5f\u5c31\u662f\u6240\u8c13\u7684\u6c14\u80f8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u5206\u7c7b\u5668\uff0c\u5728\u7ed9\u5b9a\u80f8\u90e8 X \u5149\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\uff0c\u5b83\u80fd\u68c0\u6d4b\u51fa\u56fe\u50cf\u662f\u5426\u6709\u6c14\u80f8\u3002 \u56fe 1\uff1a\u6c14\u80f8\u80ba\u90e8\u56fe\u50cf \u6211\u4eec\u8fd8\u5047\u8bbe\u6709\u76f8\u540c\u6570\u91cf\u7684\u6c14\u80f8\u548c\u975e\u6c14\u80f8\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u6bd4\u5982\u5404 100 \u5f20\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709 100 \u5f20\u9633\u6027\u6837\u672c\u548c 100 \u5f20\u9634\u6027\u6837\u672c\uff0c\u5171\u8ba1 200 \u5f20\u56fe\u50cf\u3002 \u7b2c\u4e00\u6b65\u662f\u5c06\u4e0a\u8ff0\u6570\u636e\u5206\u4e3a\u4e24\u7ec4\uff0c\u6bcf\u7ec4 100 \u5f20\u56fe\u50cf\uff0c\u5373\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u8fd9\u4e24\u4e2a\u96c6\u5408\u4e2d\uff0c\u6211\u4eec\u90fd\u6709 50 \u4e2a\u6b63\u6837\u672c\u548c 50 \u4e2a\u8d1f\u6837\u672c\u3002 \u5728\u4e8c\u5143\u5206\u7c7b\u6307\u6807\u4e2d\uff0c\u5f53\u6b63\u8d1f\u6837\u672c\u6570\u91cf\u76f8\u7b49\u65f6\uff0c\u6211\u4eec\u901a\u5e38\u4f7f\u7528\u51c6\u786e\u7387\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002 \u51c6\u786e\u7387 \uff1a\u8fd9\u662f\u673a\u5668\u5b66\u4e60\u4e2d\u6700\u76f4\u63a5\u7684\u6307\u6807\u4e4b\u4e00\u3002\u5b83\u5b9a\u4e49\u4e86\u6a21\u578b\u7684\u51c6\u786e\u5ea6\u3002\u5bf9\u4e8e\u4e0a\u8ff0\u95ee\u9898\uff0c\u5982\u679c\u4f60\u5efa\u7acb\u7684\u6a21\u578b\u80fd\u51c6\u786e\u5206\u7c7b 90 \u5f20\u56fe\u7247\uff0c\u90a3\u4e48\u4f60\u7684\u51c6\u786e\u7387\u5c31\u662f 90% \u6216 0.90\u3002\u5982\u679c\u53ea\u6709 83 \u5e45\u56fe\u50cf\u88ab\u6b63\u786e\u5206\u7c7b\uff0c\u90a3\u4e48\u6a21\u578b\u7684\u51c6\u786e\u7387\u5c31\u662f 83% \u6216 0.83\u3002 \u8ba1\u7b97\u51c6\u786e\u7387\u7684 Python \u4ee3\u7801\u4e5f\u975e\u5e38\u7b80\u5355\u3002 def accuracy ( y_true , y_pred ): # \u4e3a\u6b63\u786e\u9884\u6d4b\u6570\u521d\u59cb\u5316\u4e00\u4e2a\u7b80\u5355\u8ba1\u6570\u5668 correct_counter = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): if yt == yp : # \u5982\u679c\u9884\u6d4b\u6807\u7b7e\u4e0e\u771f\u5b9e\u6807\u7b7e\u76f8\u540c\uff0c\u5219\u589e\u52a0\u8ba1\u6570\u5668 correct_counter += 1 # \u8fd4\u56de\u6b63\u786e\u7387\uff0c\u6b63\u786e\u6807\u7b7e\u6570/\u603b\u6807\u7b7e\u6570 return correct_counter / len ( y_true ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u8ba1\u7b97\u51c6\u786e\u7387\u3002 In [ X ]: from sklearn import metrics ... : l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] ... : metrics . accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u73b0\u5728\uff0c\u5047\u8bbe\u6211\u4eec\u628a\u6570\u636e\u96c6\u7a0d\u5fae\u6539\u52a8\u4e00\u4e0b\uff0c\u6709 180 \u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u53ea\u6709 20 \u5f20\u6709\u6c14\u80f8\u3002\u5373\u4f7f\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4e5f\u8981\u521b\u5efa\u6b63\u8d1f\uff08\u6c14\u80f8\u4e0e\u975e\u6c14\u80f8\uff09\u76ee\u6807\u6bd4\u4f8b\u76f8\u540c\u7684\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u6bcf\u4e00\u7ec4\u4e2d\uff0c\u6211\u4eec\u6709 90 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u5982\u679c\u8bf4\u9a8c\u8bc1\u96c6\u4e2d\u7684\u6240\u6709\u56fe\u50cf\u90fd\u662f\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u90a3\u4e48\u60a8\u7684\u51c6\u786e\u7387\u4f1a\u662f\u591a\u5c11\u5462\uff1f\u8ba9\u6211\u4eec\u6765\u770b\u770b\uff1b\u60a8\u5bf9 90% \u7684\u56fe\u50cf\u8fdb\u884c\u4e86\u6b63\u786e\u5206\u7c7b\u3002\u56e0\u6b64\uff0c\u60a8\u7684\u51c6\u786e\u7387\u662f 90%\u3002 \u4f46\u8bf7\u518d\u770b\u4e00\u904d\u3002 \u4f60\u751a\u81f3\u6ca1\u6709\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5c31\u5f97\u5230\u4e86 90% \u7684\u51c6\u786e\u7387\u3002\u8fd9\u4f3c\u4e4e\u6709\u70b9\u6ca1\u7528\u3002\u5982\u679c\u6211\u4eec\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u6570\u636e\u96c6\u662f\u504f\u659c\u7684\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u6bd4\u53e6\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u591a\u5f88\u591a\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f7f\u7528\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u56e0\u4e3a\u5b83\u4e0d\u80fd\u4ee3\u8868\u6570\u636e\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u80fd\u4f1a\u83b7\u5f97\u5f88\u9ad8\u7684\u51c6\u786e\u7387\uff0c\u4f46\u60a8\u7684\u6a21\u578b\u5728\u5b9e\u9645\u6837\u672c\u4e2d\u7684\u8868\u73b0\u53ef\u80fd\u5e76\u4e0d\u7406\u60f3\uff0c\u800c\u4e14\u60a8\u4e5f\u65e0\u6cd5\u5411\u7ecf\u7406\u89e3\u91ca\u539f\u56e0\u3002 \u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b \u7cbe\u786e\u7387 \u7b49\u5176\u4ed6\u6307\u6807\u3002 \u5728\u5b66\u4e60\u7cbe\u786e\u7387\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e86\u89e3\u4e00\u4e9b\u672f\u8bed\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5047\u8bbe\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e3a\u6b63\u7c7b (1)\uff0c\u6ca1\u6709\u6c14\u80f8\u7684\u4e3a\u8d1f\u7c7b (0)\u3002 \u771f\u9633\u6027 \uff08TP\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6709\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9633\u6027\u3002 \u771f\u9634\u6027 \uff08TN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u800c\u5b9e\u9645\u76ee\u6807\u663e\u793a\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9634\u6027\u3002 \u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u6b63\u786e\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9633\u6027\uff1b\u5982\u679c\u60a8\u7684\u6a21\u578b\u51c6\u786e\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9634\u6027\u3002 \u5047\u9633\u6027 \uff08FP\uff09 \uff1a\u7ed9\u5b9a\u4e00\u5f20\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u975e\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9633\u6027\u3002 \u5047\u9634\u6027 \uff08FN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u975e\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9634\u6027\u3002 \u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\uff0c\u90a3\u4e48\u5b83\u5c31\u662f\u5047\u9633\u6027\u3002\u5982\u679c\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5219\u662f\u5047\u9634\u6027\u3002 \u8ba9\u6211\u4eec\u9010\u4e00\u770b\u770b\u8fd9\u4e9b\u5b9e\u73b0\u3002 def true_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u8ba1\u6570\u5668 tp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 1 : tp += 1 # \u8fd4\u56de\u771f\u9633\u6027\u6837\u672c\u6570 return tp def true_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9634\u6027\u6837\u672c\u8ba1\u6570\u5668 tn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 0 : tn += 1 # \u8fd4\u56de\u771f\u9634\u6027\u6837\u672c\u6570 return tn def false_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9633\u6027\u8ba1\u6570\u5668 fp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 1 : fp += 1 # \u8fd4\u56de\u5047\u9633\u6027\u6837\u672c\u6570 return fp def false_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9634\u6027\u8ba1\u6570\u5668 fn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 0 : fn += 1 # \u8fd4\u56de\u5047\u9634\u6027\u6570 return fn \u6211\u5728\u8fd9\u91cc\u5b9e\u73b0\u8fd9\u4e9b\u529f\u80fd\u7684\u65b9\u6cd5\u975e\u5e38\u7b80\u5355\uff0c\u800c\u4e14\u53ea\u9002\u7528\u4e8e\u4e8c\u5143\u5206\u7c7b\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u8fd9\u4e9b\u51fd\u6570\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: true_positive ( l1 , l2 ) Out [ X ]: 2 In [ X ]: false_positive ( l1 , l2 ) Out [ X ]: 1 In [ X ]: false_negative ( l1 , l2 ) Out [ X ]: 2 In [ X ]: true_negative ( l1 , l2 ) Out [ X ]: 3 \u5982\u679c\u6211\u4eec\u5fc5\u987b\u7528\u4e0a\u8ff0\u672f\u8bed\u6765\u5b9a\u4e49\u7cbe\u786e\u7387\uff0c\u6211\u4eec\u53ef\u4ee5\u5199\u4e3a\uff1a \\[ Accuracy Score = (TP + TN)/(TP + TN + FP +FN) \\] \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5728 python \u4e2d\u4f7f\u7528 TP\u3001TN\u3001FP \u548c FN \u5feb\u901f\u5b9e\u73b0\u51c6\u786e\u5ea6\u5f97\u5206\u3002\u6211\u4eec\u5c06\u5176\u79f0\u4e3a accuracy_v2\u3002 def accuracy_v2 ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u51c6\u786e\u7387 accuracy_score = ( tp + tn ) / ( tp + tn + fp + fn ) return accuracy_score \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0e\u4e4b\u524d\u7684\u5b9e\u73b0\u548c scikit-learn \u7248\u672c\u8fdb\u884c\u6bd4\u8f83\uff0c\u5feb\u901f\u68c0\u67e5\u8be5\u51fd\u6570\u7684\u6b63\u786e\u6027\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: accuracy ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: accuracy_v2 ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: metrics . accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u8bf7\u6ce8\u610f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0cmetrics.accuracy_score \u6765\u81ea scikit-learn\u3002 \u5f88\u597d\u3002\u6240\u6709\u503c\u90fd\u5339\u914d\u3002\u8fd9\u8bf4\u660e\u6211\u4eec\u5728\u5b9e\u73b0\u8fc7\u7a0b\u4e2d\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u5176\u4ed6\u91cd\u8981\u6307\u6807\u3002 \u9996\u5148\u662f\u7cbe\u786e\u7387\u3002\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u662f \\[ Precision = TP/(TP + FP) \\] \u5047\u8bbe\u6211\u4eec\u5728\u65b0\u7684\u504f\u659c\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u4e86\u4e00\u4e2a\u65b0\u6a21\u578b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 90 \u5f20\u56fe\u50cf\u4e2d\u7684 80 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u56fe\u50cf\u4e2d\u7684 8 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6210\u529f\u8bc6\u522b\u4e86 100 \u5f20\u56fe\u50cf\u4e2d\u7684 88 \u5f20\u3002\u56e0\u6b64\uff0c\u51c6\u786e\u7387\u4e3a 0.88 \u6216 88%\u3002 \u4f46\u662f\uff0c\u5728\u8fd9 100 \u5f20\u6837\u672c\u4e2d\uff0c\u6709 10 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u6c14\u80f8\uff0c2 \u5f20\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u975e\u6c14\u80f8\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 TP : 8 TN: 80 FP: 10 FN: 2 \u7cbe\u786e\u7387\u4e3a 8 / (8 + 10) = 0.444\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u5728\u8bc6\u522b\u9633\u6027\u6837\u672c\uff08\u6c14\u80f8\uff09\u65f6\u6709 44.4% \u7684\u6b63\u786e\u7387\u3002 \u73b0\u5728\uff0c\u65e2\u7136\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86 TP\u3001TN\u3001FP \u548c FN\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5728 python \u4e2d\u5b9e\u73b0\u7cbe\u786e\u7387\u4e86\u3002 def precision ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8ba9\u6211\u4eec\u8bd5\u8bd5\u8fd9\u79cd\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u65b9\u5f0f\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: precision ( l1 , l2 ) Out [ X ]: 0.6666666666666666 \u8fd9\u4f3c\u4e4e\u6ca1\u6709\u95ee\u9898\u3002 \u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u6765\u770b \u53ec\u56de\u7387 \u3002\u53ec\u56de\u7387\u7684\u5b9a\u4e49\u662f\uff1a \\[ Recall = TP/(TP + FN) \\] \u5728\u4e0a\u8ff0\u60c5\u51b5\u4e0b\uff0c\u53ec\u56de\u7387\u4e3a 8 / (8 + 2) = 0.80\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 80% \u7684\u9633\u6027\u6837\u672c\u3002 def recall ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u53ec\u56de\u7387 recall = tp / ( tp + fn ) return recall \u5c31\u6211\u4eec\u7684\u4e24\u4e2a\u5c0f\u5217\u8868\u800c\u8a00\uff0c\u53ec\u56de\u7387\u5e94\u8be5\u662f 0.5\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: recall ( l1 , l2 ) Out [ X ]: 0.5 \u8fd9\u4e0e\u6211\u4eec\u7684\u8ba1\u7b97\u503c\u76f8\u7b26\uff01 \u5bf9\u4e8e\u4e00\u4e2a \"\u597d \"\u6a21\u578b\u6765\u8bf4\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u503c\u90fd\u5e94\u8be5\u5f88\u9ad8\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0c\u53ec\u56de\u503c\u76f8\u5f53\u9ad8\u3002\u4f46\u662f\uff0c\u7cbe\u786e\u7387\u5374\u5f88\u4f4e\uff01\u6211\u4eec\u7684\u6a21\u578b\u4ea7\u751f\u4e86\u5927\u91cf\u7684\u8bef\u62a5\uff0c\u4f46\u8bef\u62a5\u8f83\u5c11\u3002\u5728\u8fd9\u7c7b\u95ee\u9898\u4e2d\uff0c\u5047\u9634\u6027\u8f83\u5c11\u662f\u597d\u4e8b\uff0c\u56e0\u4e3a\u4f60\u4e0d\u60f3\u5728\u75c5\u4eba\u6709\u6c14\u80f8\u7684\u60c5\u51b5\u4e0b\u5374\u8bf4\u4ed6\u4eec\u6ca1\u6709\u6c14\u80f8\u3002\u8fd9\u6837\u505a\u4f1a\u9020\u6210\u66f4\u5927\u7684\u4f24\u5bb3\u3002\u4f46\u6211\u4eec\u4e5f\u6709\u5f88\u591a\u5047\u9633\u6027\u7ed3\u679c\uff0c\u8fd9\u4e5f\u4e0d\u662f\u597d\u4e8b\u3002 \u5927\u591a\u6570\u6a21\u578b\u90fd\u4f1a\u9884\u6d4b\u4e00\u4e2a\u6982\u7387\uff0c\u5f53\u6211\u4eec\u9884\u6d4b\u65f6\uff0c\u901a\u5e38\u4f1a\u5c06\u8fd9\u4e2a\u9608\u503c\u9009\u4e3a 0.5\u3002\u8fd9\u4e2a\u9608\u503c\u5e76\u4e0d\u603b\u662f\u7406\u60f3\u7684\uff0c\u6839\u636e\u8fd9\u4e2a\u9608\u503c\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u503c\u53ef\u80fd\u4f1a\u53d1\u751f\u5f88\u5927\u7684\u53d8\u5316\u3002\u5982\u679c\u6211\u4eec\u9009\u62e9\u7684\u6bcf\u4e2a\u9608\u503c\u90fd\u80fd\u8ba1\u7b97\u51fa\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u53ef\u4ee5\u5728\u8fd9\u4e9b\u503c\u4e4b\u95f4\u7ed8\u5236\u51fa\u66f2\u7ebf\u56fe\u3002\u8fd9\u5e45\u56fe\u6216\u66f2\u7ebf\u88ab\u79f0\u4e3a \"\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\"\u3002 \u5728\u7814\u7a76\u7cbe\u786e\u7387-\u8c03\u7528\u66f2\u7ebf\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5047\u8bbe\u6709\u4e24\u4e2a\u5217\u8868\u3002 In [ X ]: y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0.02638412 , 0.11114267 , 0.31620708 , ... : 0.0490937 , 0.0191491 , 0.17554844 , ... : 0.15952202 , 0.03819563 , 0.11639273 , ... : 0.079377 , 0.08584789 , 0.39095342 , ... : 0.27259048 , 0.03447096 , 0.04644807 , ... : 0.03543574 , 0.18521942 , 0.05934905 , ... : 0.61977213 , 0.33056815 ] \u56e0\u6b64\uff0cy_true \u662f\u6211\u4eec\u7684\u76ee\u6807\u503c\uff0c\u800c y_pred \u662f\u6837\u672c\u88ab\u8d4b\u503c\u4e3a 1 \u7684\u6982\u7387\u503c\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u6211\u4eec\u8981\u770b\u7684\u662f\u9884\u6d4b\u4e2d\u7684\u6982\u7387\uff0c\u800c\u4e0d\u662f\u9884\u6d4b\u503c\uff08\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u9884\u6d4b\u503c\u7684\u8ba1\u7b97\u9608\u503c\u4e3a 0.5\uff09\u3002 precisions = [] recalls = [] thresholds = [ 0.0490937 , 0.05934905 , 0.079377 , 0.08584789 , 0.11114267 , 0.11639273 , 0.15952202 , 0.17554844 , 0.18521942 , 0.27259048 , 0.31620708 , 0.33056815 , 0.39095342 , 0.61977213 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for i in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_prediction = [ 1 if x >= i else 0 for x in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , temp_prediction ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , temp_prediction ) # \u52a0\u5165\u7cbe\u786e\u7387\u5217\u8868 precisions . append ( p ) # \u52a0\u5165\u53ec\u56de\u7387\u5217\u8868 recalls . append ( r ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7ed8\u5236\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 7 , 7 )) # x\u8f74\u4e3a\u53ec\u56de\u7387\uff0cy\u8f74\u4e3a\u7cbe\u786e\u7387 plt . plot ( recalls , precisions ) # \u6dfb\u52a0x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a15 plt . xlabel ( 'Recall' , fontsize = 15 ) # \u6dfb\u52a0y\u8f74\u6807\u7b7e\uff0c\u5b57\u6761\u5927\u5c0f\u4e3a15 plt . ylabel ( 'Precision' , fontsize = 15 ) \u56fe 2 \u663e\u793a\u4e86\u6211\u4eec\u901a\u8fc7\u8fd9\u79cd\u65b9\u6cd5\u5f97\u5230\u7684\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 \u56fe 2\uff1a\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u8fd9\u6761 \u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u4e0e\u60a8\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230\u7684\u66f2\u7ebf\u622a\u7136\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6211\u4eec\u53ea\u6709 20 \u4e2a\u6837\u672c\uff0c\u5176\u4e2d\u53ea\u6709 3 \u4e2a\u662f\u9633\u6027\u6837\u672c\u3002\u4f46\u8fd9\u6ca1\u4ec0\u4e48\u597d\u62c5\u5fc3\u7684\u3002\u8fd9\u8fd8\u662f\u90a3\u6761\u7cbe\u786e\u7387-\u53ec\u56de\u66f2\u7ebf\u3002 \u4f60\u4f1a\u53d1\u73b0\uff0c\u9009\u62e9\u4e00\u4e2a\u65e2\u80fd\u63d0\u4f9b\u826f\u597d\u7cbe\u786e\u7387\u53c8\u80fd\u63d0\u4f9b\u53ec\u56de\u503c\u7684\u9608\u503c\u662f\u5f88\u6709\u6311\u6218\u6027\u7684\u3002\u5982\u679c\u9608\u503c\u8fc7\u9ad8\uff0c\u771f\u9633\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u51cf\u5c11\uff0c\u800c\u5047\u9634\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u589e\u52a0\u3002\u8fd9\u4f1a\u964d\u4f4e\u53ec\u56de\u7387\uff0c\u4f46\u7cbe\u786e\u7387\u5f97\u5206\u4f1a\u5f88\u9ad8\u3002\u5982\u679c\u5c06\u9608\u503c\u964d\u5f97\u592a\u4f4e\uff0c\u5219\u8bef\u62a5\u4f1a\u5927\u91cf\u589e\u52a0\uff0c\u7cbe\u786e\u7387\u4e5f\u4f1a\u964d\u4f4e\u3002 \u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u8d8a\u63a5\u8fd1 1 \u8d8a\u597d\u3002 F1 \u5206\u6570\u662f\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7efc\u5408\u6307\u6807\u3002\u5b83\u88ab\u5b9a\u4e49\u4e3a\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7b80\u5355\u52a0\u6743\u5e73\u5747\u503c\uff08\u8c03\u548c\u5e73\u5747\u503c\uff09\u3002\u5982\u679c\u6211\u4eec\u7528 P \u8868\u793a\u7cbe\u786e\u7387\uff0c\u7528 R \u8868\u793a\u53ec\u56de\u7387\uff0c\u90a3\u4e48 F1 \u5206\u6570\u53ef\u4ee5\u8868\u793a\u4e3a\uff1a \\[ F1 = 2PR/(P + R) \\] \u6839\u636e TP\u3001FP \u548c FN\uff0c\u7a0d\u52a0\u6570\u5b66\u8ba1\u7b97\u5c31\u80fd\u5f97\u51fa\u4ee5\u4e0b F1 \u7b49\u5f0f\uff1a \\[ F1 = 2TP/(2TP + FP + FN) \\] Python \u5b9e\u73b0\u5f88\u7b80\u5355\uff0c\u56e0\u4e3a\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86\u8fd9\u4e9b def f1 ( y_true , y_pred ): # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , y_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , y_pred ) # \u8ba1\u7b97f1\u503c score = 2 * p * r / ( p + r ) return score \u8ba9\u6211\u4eec\u770b\u770b\u5176\u7ed3\u679c\uff0c\u5e76\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\u3002 In [ X ]: y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0 , 0 , 1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: f1 ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u901a\u8fc7 scikit learn\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u76f8\u540c\u7684\u5217\u8868\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . f1_score ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u4e0e\u5176\u5355\u72ec\u770b\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u60a8\u8fd8\u53ef\u4ee5\u53ea\u770b F1 \u5206\u6570\u3002\u4e0e\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c\u51c6\u786e\u5ea6\u4e00\u6837\uff0cF1 \u5206\u6570\u7684\u8303\u56f4\u4e5f\u662f\u4ece 0 \u5230 1\uff0c\u5b8c\u7f8e\u9884\u6d4b\u6a21\u578b\u7684 F1 \u5206\u6570\u4e3a 1\u3002 \u6b64\u5916\uff0c\u6211\u4eec\u8fd8\u5e94\u8be5\u4e86\u89e3\u5176\u4ed6\u4e00\u4e9b\u5173\u952e\u672f\u8bed\u3002 \u7b2c\u4e00\u4e2a\u672f\u8bed\u662f TPR \u6216\u771f\u9633\u6027\u7387\uff08True Positive Rate\uff09\uff0c\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\u3002 \\[ TPR = TP/(TP + FN) \\] \u5c3d\u7ba1\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\uff0c\u4f46\u6211\u4eec\u5c06\u4e3a\u5b83\u521b\u5efa\u4e00\u4e2a python \u51fd\u6570\uff0c\u4ee5\u4fbf\u4eca\u540e\u4f7f\u7528\u8fd9\u4e2a\u540d\u79f0\u3002 def tpr ( y_true , y_pred ): # \u771f\u9633\u6027\u7387\uff08TPR\uff09\uff0c\u4e0e\u53ec\u56de\u7387\u8ba1\u7b97\u516c\u5f0f\u4e00\u81f4 return recall ( y_true , y_pred ) TPR \u6216\u53ec\u56de\u7387\u4e5f\u88ab\u79f0\u4e3a\u7075\u654f\u5ea6\u3002 \u800c FPR \u6216\u5047\u9633\u6027\u7387\uff08False Positive Rate\uff09\u7684\u5b9a\u4e49\u662f\uff1a \\[ FPR = FP / (TN + FP) \\] def fpr ( y_true , y_pred ): # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u8fd4\u56de\u5047\u9633\u6027\u7387\uff08FPR\uff09 return fp / ( tn + fp ) 1 - FPR \u88ab\u79f0\u4e3a\u7279\u5f02\u6027\u6216\u771f\u9634\u6027\u7387\u6216 TNR\u3002\u8fd9\u4e9b\u672f\u8bed\u5f88\u591a\uff0c\u4f46\u5176\u4e2d\u6700\u91cd\u8981\u7684\u53ea\u6709 TPR \u548c FPR\u3002\u5047\u8bbe\u6211\u4eec\u53ea\u6709 15 \u4e2a\u6837\u672c\uff0c\u5176\u76ee\u6807\u503c\u4e3a\u4e8c\u5143\uff1a Actual targets : [0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1] \u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u7c7b\u4f3c\u968f\u673a\u68ee\u6797\u7684\u6a21\u578b\uff0c\u5c31\u80fd\u5f97\u5230\u6837\u672c\u5448\u9633\u6027\u7684\u6982\u7387\u3002 Predicted probabilities for 1: [0.1, 0.3, 0.2, 0.6, 0.8, 0.05, 0.9, 0.5, 0.3, 0.66, 0.3, 0.2, 0.85, 0.15, 0.99] \u5bf9\u4e8e >= 0.5 \u7684\u5178\u578b\u9608\u503c\uff0c\u6211\u4eec\u53ef\u4ee5\u8bc4\u4f30\u4e0a\u8ff0\u6240\u6709\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387/TPR\u3001F1 \u548c FPR \u503c\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u5c06\u9608\u503c\u9009\u4e3a 0.4 \u6216 0.6\uff0c\u4e5f\u53ef\u4ee5\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0 \u5230 1 \u4e4b\u95f4\u7684\u4efb\u4f55\u503c\uff0c\u5e76\u8ba1\u7b97\u4e0a\u8ff0\u6240\u6709\u6307\u6807\u3002 \u4e0d\u8fc7\uff0c\u6211\u4eec\u53ea\u8ba1\u7b97\u4e24\u4e2a\u503c\uff1a TPR \u548c FPR\u3002 # \u521d\u59cb\u5316\u771f\u9633\u6027\u7387\u5217\u8868 tpr_list = [] # \u521d\u59cb\u5316\u5047\u9633\u6027\u7387\u5217\u8868 fpr_list = [] # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u7387 temp_tpr = tpr ( y_true , temp_pred ) # \u5047\u9633\u6027\u7387 temp_fpr = fpr ( y_true , temp_pred ) # \u5c06\u771f\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 tpr_list . append ( temp_tpr ) # \u5c06\u5047\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 fpr_list . append ( temp_fpr ) \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u6bcf\u4e2a\u9608\u503c\u7684 TPR \u503c\u548c FPR \u503c\u3002 \u56fe 3\uff1a\u9608\u503c\u3001TPR \u548c FPR \u503c\u8868 \u5982\u679c\u6211\u4eec\u7ed8\u5236\u5982\u56fe 3 \u6240\u793a\u7684\u8868\u683c\uff0c\u5373\u4ee5 TPR \u4e3a Y \u8f74\uff0cFPR \u4e3a X \u8f74\uff0c\u5c31\u4f1a\u5f97\u5230\u5982\u56fe 4 \u6240\u793a\u7684\u66f2\u7ebf\u3002 \u56fe 4\uff1aROC\u66f2\u7ebf \u8fd9\u6761\u66f2\u7ebf\u4e5f\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u3002\u5982\u679c\u6211\u4eec\u8ba1\u7b97\u8fd9\u6761 ROC \u66f2\u7ebf\u4e0b\u7684\u9762\u79ef\uff0c\u5c31\u662f\u5728\u8ba1\u7b97\u53e6\u4e00\u4e2a\u6307\u6807\uff0c\u5f53\u6570\u636e\u96c6\u7684\u4e8c\u5143\u76ee\u6807\u504f\u659c\u65f6\uff0c\u8fd9\u4e2a\u6307\u6807\u5c31\u4f1a\u975e\u5e38\u5e38\u7528\u3002 \u8fd9\u4e2a\u6307\u6807\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u6216\u66f2\u7ebf\u4e0b\u9762\u79ef\uff0c\u7b80\u79f0 AUC\u3002\u8ba1\u7b97 ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002\u5728\u6b64\uff0c\u6211\u4eec\u5c06\u91c7\u7528 scikit- learn \u7684\u5947\u5999\u5b9e\u73b0\u65b9\u6cd5\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: metrics . roc_auc_score ( y_true , y_pred ) Out [ X ]: 0.8300000000000001 AUC \u503c\u4ece 0 \u5230 1 \u4e0d\u7b49\u3002 AUC = 1 \u610f\u5473\u7740\u60a8\u62e5\u6709\u4e00\u4e2a\u5b8c\u7f8e\u7684\u6a21\u578b\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u610f\u5473\u7740\u4f60\u5728\u9a8c\u8bc1\u65f6\u72af\u4e86\u4e00\u4e9b\u9519\u8bef\uff0c\u5e94\u8be5\u91cd\u65b0\u5ba1\u89c6\u6570\u636e\u5904\u7406\u548c\u9a8c\u8bc1\u6d41\u7a0b\u3002\u5982\u679c\u4f60\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\uff0c\u90a3\u4e48\u606d\u559c\u4f60\uff0c\u4f60\u5df2\u7ecf\u62e5\u6709\u4e86\u9488\u5bf9\u6570\u636e\u96c6\u5efa\u7acb\u7684\u6700\u4f73\u6a21\u578b\u3002 AUC = 0 \u610f\u5473\u7740\u60a8\u7684\u6a21\u578b\u975e\u5e38\u7cdf\u7cd5\uff08\u6216\u975e\u5e38\u597d\uff01\uff09\u3002\u8bd5\u7740\u53cd\u8f6c\u9884\u6d4b\u7684\u6982\u7387\uff0c\u4f8b\u5982\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u6b63\u7c7b\u7684\u6982\u7387\u662f p\uff0c\u8bd5\u7740\u7528 1-p \u4ee3\u66ff\u5b83\u3002\u8fd9\u79cd AUC \u4e5f\u53ef\u80fd\u610f\u5473\u7740\u60a8\u7684\u9a8c\u8bc1\u6216\u6570\u636e\u5904\u7406\u5b58\u5728\u95ee\u9898\u3002 AUC = 0.5 \u610f\u5473\u7740\u4f60\u7684\u9884\u6d4b\u662f\u968f\u673a\u7684\u3002\u56e0\u6b64\uff0c\u5bf9\u4e8e\u4efb\u4f55\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u5982\u679c\u6211\u5c06\u6240\u6709\u76ee\u6807\u90fd\u9884\u6d4b\u4e3a 0.5\uff0c\u6211\u5c06\u5f97\u5230 0.5 \u7684 AUC\u3002 AUC \u503c\u4ecb\u4e8e 0 \u548c 0.5 \u4e4b\u95f4\uff0c\u610f\u5473\u7740\u4f60\u7684\u6a21\u578b\u6bd4\u968f\u673a\u6a21\u578b\u66f4\u5dee\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u662f\u56e0\u4e3a\u4f60\u98a0\u5012\u4e86\u7c7b\u522b\u3002 \u5982\u679c\u60a8\u5c1d\u8bd5\u53cd\u8f6c\u9884\u6d4b\uff0c\u60a8\u7684 AUC \u503c\u53ef\u80fd\u4f1a\u8d85\u8fc7 0.5\u3002\u63a5\u8fd1 1 \u7684 AUC \u503c\u88ab\u8ba4\u4e3a\u662f\u597d\u503c\u3002 \u4f46 AUC \u5bf9\u6211\u4eec\u7684\u6a21\u578b\u6709\u4ec0\u4e48\u5f71\u54cd\u5462\uff1f \u5047\u8bbe\u60a8\u5efa\u7acb\u4e86\u4e00\u4e2a\u4ece\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e2d\u68c0\u6d4b\u6c14\u80f8\u7684\u6a21\u578b\uff0c\u5176 AUC \u503c\u4e3a 0.85\u3002\u8fd9\u610f\u5473\u7740\uff0c\u5982\u679c\u60a8\u4ece\u6570\u636e\u96c6\u4e2d\u968f\u673a\u9009\u62e9\u4e00\u5f20\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9633\u6027\u6837\u672c\uff09\u548c\u53e6\u4e00\u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9634\u6027\u6837\u672c\uff09\uff0c\u90a3\u4e48\u6c14\u80f8\u56fe\u50cf\u7684\u6392\u540d\u5c06\u9ad8\u4e8e\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u6982\u7387\u4e3a 0.85\u3002 \u8ba1\u7b97\u6982\u7387\u548c AUC \u540e\uff0c\u60a8\u9700\u8981\u5bf9\u6d4b\u8bd5\u96c6\u8fdb\u884c\u9884\u6d4b\u3002\u6839\u636e\u95ee\u9898\u548c\u4f7f\u7528\u60c5\u51b5\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u6982\u7387\u6216\u5b9e\u9645\u7c7b\u522b\u3002\u5982\u679c\u4f60\u60f3\u8981\u6982\u7387\uff0c\u8fd9\u5e76\u4e0d\u96be\u3002\u5982\u679c\u60a8\u60f3\u8981\u7c7b\u522b\uff0c\u5219\u9700\u8981\u9009\u62e9\u4e00\u4e2a\u9608\u503c\u3002\u5728\u4e8c\u5143\u5206\u7c7b\u7684\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u91c7\u7528\u7c7b\u4f3c\u4e0b\u9762\u7684\u65b9\u6cd5\u3002 \\[ Prediction = Probability >= Threshold \\] \u4e5f\u5c31\u662f\u8bf4\uff0c\u9884\u6d4b\u662f\u4e00\u4e2a\u53ea\u5305\u542b\u4e8c\u5143\u53d8\u91cf\u7684\u65b0\u5217\u8868\u3002\u5982\u679c\u6982\u7387\u5927\u4e8e\u6216\u7b49\u4e8e\u7ed9\u5b9a\u7684\u9608\u503c\uff0c\u5219\u9884\u6d4b\u4e2d\u7684\u4e00\u9879\u4e3a 1\uff0c\u5426\u5219\u4e3a 0\u3002 \u4f60\u731c\u600e\u4e48\u7740\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528 ROC \u66f2\u7ebf\u6765\u9009\u62e9\u8fd9\u4e2a\u9608\u503c\uff01ROC \u66f2\u7ebf\u4f1a\u544a\u8bc9\u60a8\u9608\u503c\u5bf9\u5047\u9633\u6027\u7387\u548c\u771f\u9633\u6027\u7387\u7684\u5f71\u54cd\uff0c\u8fdb\u800c\u5f71\u54cd\u5047\u9633\u6027\u548c\u771f\u9633\u6027\u3002\u60a8\u5e94\u8be5\u9009\u62e9\u6700\u9002\u5408\u60a8\u7684\u95ee\u9898\u548c\u6570\u636e\u96c6\u7684\u9608\u503c\u3002 \u4f8b\u5982\uff0c\u5982\u679c\u60a8\u4e0d\u5e0c\u671b\u6709\u592a\u591a\u7684\u8bef\u62a5\uff0c\u90a3\u4e48\u9608\u503c\u5c31\u5e94\u8be5\u9ad8\u4e00\u4e9b\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e5f\u4f1a\u5e26\u6765\u66f4\u591a\u7684\u8bef\u62a5\u3002\u6ce8\u610f\u6743\u8861\u5229\u5f0a\uff0c\u9009\u62e9\u6700\u4f73\u9608\u503c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u9608\u503c\u5982\u4f55\u5f71\u54cd\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u503c\u3002 # \u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list = [] # \u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list = [] # \u771f\u5b9e\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 temp_tp = true_positive ( y_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 temp_fp = false_positive ( y_true , temp_pred ) # \u52a0\u5165\u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list . append ( temp_tp ) # \u52a0\u5165\u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list . append ( temp_fp ) \u5229\u7528\u8fd9\u4e00\u70b9\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u8868\u683c\uff0c\u5982\u56fe 5 \u6240\u793a\u3002 \u56fe 5\uff1a\u4e0d\u540c\u9608\u503c\u7684 TP \u503c\u548c FP \u503c \u5982\u56fe 6 \u6240\u793a\uff0c\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0cROC \u66f2\u7ebf\u5de6\u4e0a\u89d2\u7684\u503c\u5e94\u8be5\u662f\u4e00\u4e2a\u76f8\u5f53\u4e0d\u9519\u7684\u9608\u503c\u3002 \u5bf9\u6bd4\u8868\u683c\u548c ROC \u66f2\u7ebf\uff0c\u6211\u4eec\u53ef\u4ee5\u53d1\u73b0\uff0c0.6 \u5de6\u53f3\u7684\u9608\u503c\u76f8\u5f53\u4e0d\u9519\uff0c\u65e2\u4e0d\u4f1a\u4e22\u5931\u5927\u91cf\u7684\u771f\u9633\u6027\u7ed3\u679c\uff0c\u4e5f\u4e0d\u4f1a\u51fa\u73b0\u5927\u91cf\u7684\u5047\u9633\u6027\u7ed3\u679c\u3002 \u56fe 6\uff1a\u4ece ROC \u66f2\u7ebf\u6700\u5de6\u4fa7\u7684\u9876\u70b9\u9009\u62e9\u6700\u4f73\u9608\u503c AUC \u662f\u4e1a\u5185\u5e7f\u6cdb\u5e94\u7528\u4e8e\u504f\u659c\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6307\u6807\uff0c\u4e5f\u662f\u6bcf\u4e2a\u4eba\u90fd\u5e94\u8be5\u4e86\u89e3\u7684\u6307\u6807\u3002\u4e00\u65e6\u7406\u89e3\u4e86 AUC \u80cc\u540e\u7684\u7406\u5ff5\uff08\u5982\u4e0a\u6587\u6240\u8ff0\uff09\uff0c\u4e5f\u5c31\u5f88\u5bb9\u6613\u5411\u4e1a\u754c\u53ef\u80fd\u4f1a\u8bc4\u4f30\u60a8\u7684\u6a21\u578b\u7684\u975e\u6280\u672f\u4eba\u5458\u89e3\u91ca\u5b83\u4e86\u3002 \u5b66\u4e60 AUC \u540e\uff0c\u4f60\u5e94\u8be5\u5b66\u4e60\u7684\u53e6\u4e00\u4e2a\u91cd\u8981\u6307\u6807\u662f\u5bf9\u6570\u635f\u5931\u3002\u5bf9\u4e8e\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5c06\u5bf9\u6570\u635f\u5931\u5b9a\u4e49\u4e3a\uff1a \\[ LogLoss = -1.0 \\times (target \\times log(prediction) + (1-target) \\times log(1-prediction)) \\] \u5176\u4e2d\uff0c\u76ee\u6807\u503c\u4e3a 0 \u6216 1\uff0c\u9884\u6d4b\u503c\u4e3a\u6837\u672c\u5c5e\u4e8e\u7c7b\u522b 1 \u7684\u6982\u7387\u3002 \u5bf9\u4e8e\u6570\u636e\u96c6\u4e2d\u7684\u591a\u4e2a\u6837\u672c\uff0c\u6240\u6709\u6837\u672c\u7684\u5bf9\u6570\u635f\u5931\u53ea\u662f\u6240\u6709\u5355\u4e2a\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u9700\u8981\u8bb0\u4f4f\u7684\u4e00\u70b9\u662f\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u4e0d\u6b63\u786e\u6216\u504f\u5dee\u8f83\u5927\u7684\u9884\u6d4b\u8fdb\u884c\u76f8\u5f53\u9ad8\u7684\u60e9\u7f5a\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u975e\u5e38\u786e\u5b9a\u548c\u975e\u5e38\u9519\u8bef\u7684\u9884\u6d4b\u8fdb\u884c\u60e9\u7f5a\u3002 import numpy as np def log_loss ( y_true , y_proba ): # \u6781\u5c0f\u503c\uff0c\u9632\u6b620\u505a\u5206\u6bcd epsilon = 1e-15 # \u5bf9\u6570\u635f\u5931\u5217\u8868 loss = [] # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_proba ): # \u9650\u5236yp\u8303\u56f4\uff0c\u6700\u5c0f\u4e3aepsilon\uff0c\u6700\u5927\u4e3a1-epsilon yp = np . clip ( yp , epsilon , 1 - epsilon ) # \u8ba1\u7b97\u5bf9\u6570\u635f\u5931 temp_loss = - 1.0 * ( yt * np . log ( yp ) + ( 1 - yt ) * np . log ( 1 - yp )) # \u52a0\u5165\u5bf9\u6570\u635f\u5931\u5217\u8868 loss . append ( temp_loss ) return np . mean ( loss ) \u8ba9\u6211\u4eec\u6d4b\u8bd5\u4e00\u4e0b\u51fd\u6570\u6267\u884c\u60c5\u51b5\uff1a In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_proba = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u6211\u4eec\u53ef\u4ee5\u5c06\u5176\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u5b9e\u73b0\u662f\u6b63\u786e\u7684\u3002 \u5bf9\u6570\u635f\u5931\u7684\u5b9e\u73b0\u5f88\u5bb9\u6613\u3002\u89e3\u91ca\u8d77\u6765\u4f3c\u4e4e\u6709\u70b9\u56f0\u96be\u3002\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5bf9\u6570\u635f\u5931\u7684\u60e9\u7f5a\u8981\u6bd4\u5176\u4ed6\u6307\u6807\u5927\u5f97\u591a\u3002 \u4f8b\u5982\uff0c\u5982\u679c\u60a8\u6709 51% \u7684\u628a\u63e1\u8ba4\u4e3a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\uff0c\u90a3\u4e48\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (1 \\times log(0.51) + (1 - 1) \\times log(1 - 0.51))=0.67 \\] \u5982\u679c\u4f60\u5bf9\u5c5e\u4e8e 0 \u7c7b\u7684\u6837\u672c\u6709 49% \u7684\u628a\u63e1\uff0c\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (1 \\times log(0.49) + (1 - 1) \\times log(1 - 0.49))=0.67 \\] \u56e0\u6b64\uff0c\u5373\u4f7f\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0.5 \u7684\u622a\u65ad\u503c\u5e76\u5f97\u5230\u5b8c\u7f8e\u7684\u9884\u6d4b\u7ed3\u679c\uff0c\u6211\u4eec\u4ecd\u7136\u4f1a\u6709\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002\u56e0\u6b64\uff0c\u5728\u5904\u7406\u5bf9\u6570\u635f\u5931\u65f6\uff0c\u4f60\u9700\u8981\u975e\u5e38\u5c0f\u5fc3\uff1b\u4efb\u4f55\u4e0d\u786e\u5b9a\u7684\u9884\u6d4b\u90fd\u4f1a\u4ea7\u751f\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002 \u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u5927\u591a\u6570\u6307\u6807\u90fd\u53ef\u4ee5\u8f6c\u6362\u6210\u591a\u7c7b\u7248\u672c\u3002\u8fd9\u4e2a\u60f3\u6cd5\u5f88\u7b80\u5355\u3002\u4ee5\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u4e3a\u4f8b\u3002\u6211\u4eec\u53ef\u4ee5\u8ba1\u7b97\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u4e2d\u6bcf\u4e00\u7c7b\u7684\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 \u6709\u4e09\u79cd\u4e0d\u540c\u7684\u8ba1\u7b97\u65b9\u6cd5\uff0c\u6709\u65f6\u53ef\u80fd\u4f1a\u4ee4\u4eba\u56f0\u60d1\u3002\u5047\u8bbe\u6211\u4eec\u9996\u5148\u5bf9\u7cbe\u786e\u7387\u611f\u5174\u8da3\u3002\u6211\u4eec\u77e5\u9053\uff0c\u7cbe\u786e\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u3002 \u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Macro averaged precision\uff09\uff1a\u5206\u522b\u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u7cbe\u786e\u7387\u7136\u540e\u6c42\u5e73\u5747\u503c \u5fae\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Micro averaged precision\uff09\uff1a\u6309\u7c7b\u8ba1\u7b97\u771f\u9633\u6027\u548c\u5047\u9633\u6027\uff0c\u7136\u540e\u7528\u5176\u8ba1\u7b97\u603b\u4f53\u7cbe\u786e\u7387\u3002\u7136\u540e\u4ee5\u6b64\u8ba1\u7b97\u603b\u4f53\u7cbe\u786e\u7387 \u52a0\u6743\u7cbe\u786e\u7387 \uff08Weighted precision\uff09\uff1a\u4e0e\u5b8f\u89c2\u7cbe\u786e\u7387\u76f8\u540c\uff0c\u4f46\u8fd9\u91cc\u662f\u52a0\u6743\u5e73\u5747\u7cbe\u786e\u7387 \u53d6\u51b3\u4e8e\u6bcf\u4e2a\u7c7b\u522b\u4e2d\u7684\u9879\u76ee\u6570 \u8fd9\u770b\u4f3c\u590d\u6742\uff0c\u4f46\u5728 python \u5b9e\u73b0\u4e2d\u5f88\u5bb9\u6613\u7406\u89e3\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import numpy as np def macro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u5982\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u8ba1\u7b97\u7cbe\u786e\u5ea6 temp_precision = tp / ( tp + fp ) # \u5404\u7c7b\u7cbe\u786e\u7387\u76f8\u52a0 precision += temp_precision # \u8ba1\u7b97\u5e73\u5747\u503c precision /= num_classes return precision \u4f60\u4f1a\u53d1\u73b0\u8fd9\u5e76\u4e0d\u96be\u3002\u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5fae\u5e73\u5747\u7cbe\u786e\u7387\u5206\u6570\u3002 import numpy as np def micro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u6570 tp = 0 # \u521d\u59cb\u5316\u5047\u9633\u6027\u6837\u672c\u6570 fp = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 tp += true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 fp += false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8fd9\u4e5f\u4e0d\u96be\u3002\u90a3\u4ec0\u4e48\u96be\uff1f\u4ec0\u4e48\u90fd\u4e0d\u96be\u3002\u673a\u5668\u5b66\u4e60\u5f88\u7b80\u5355\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u52a0\u6743\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u3002 from collections import Counter import numpy as np def weighted_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 temp_precision = tp / ( tp + fp ) # \u6839\u636e\u8be5\u79cd\u7c7b\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_precision = class_counts [ class_ ] * temp_precision # \u52a0\u6743\u7cbe\u786e\u7387\u6c42\u548c precision += weighted_precision # \u8ba1\u7b97\u5e73\u5747\u7cbe\u786e\u7387 overall_precision = precision / len ( y_true ) return overall_precision \u5c06\u6211\u4eec\u7684\u5b9e\u73b0\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff0c\u4ee5\u4e86\u89e3\u5b9e\u73b0\u662f\u5426\u6b63\u786e\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: macro_precision ( y_true , y_pred ) Out [ X ]: 0.3611111111111111 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"macro\" ) Out [ X ]: 0.3611111111111111 In [ X ]: micro_precision ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"micro\" ) Out [ X ]: 0.4444444444444444 In [ X ]: weighted_precision ( y_true , y_pred ) Out [ X ]: 0.39814814814814814 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.39814814814814814 \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6b63\u786e\u5730\u5b9e\u73b0\u4e86\u4e00\u5207\u3002 \u8bf7\u6ce8\u610f\uff0c\u8fd9\u91cc\u5c55\u793a\u7684\u5b9e\u73b0\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\uff0c\u4f46\u5374\u662f\u6700\u5bb9\u6613\u7406\u89e3\u7684\u3002 \u540c\u6837\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5b9e\u73b0 \u591a\u7c7b\u522b\u7684\u53ec\u56de\u7387\u6307\u6807 \u3002\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u3001\u5047\u9633\u6027\u548c\u5047\u9634\u6027\uff0c\u800c F1 \u5219\u53d6\u51b3\u4e8e\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 \u53ec\u56de\u7387\u7684\u5b9e\u73b0\u65b9\u6cd5\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\uff0c\u8fd9\u91cc\u5b9e\u73b0\u7684\u662f\u591a\u7c7b F1 \u7684\u4e00\u4e2a\u7248\u672c\uff0c\u5373\u52a0\u6743\u5e73\u5747\u503c\u3002 from collections import Counter import numpy as np def weighted_f1 ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316F1\u503c f1 = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( temp_true , temp_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( temp_true , temp_pred ) # \u82e5\u7cbe\u786e\u7387+\u53ec\u56de\u7387\u4e0d\u4e3a0\uff0c\u5219\u4f7f\u7528\u516c\u5f0f\u8ba1\u7b97F1\u503c if p + r != 0 : temp_f1 = 2 * p * r / ( p + r ) # \u5426\u5219\u76f4\u63a5\u4e3a0 else : temp_f1 = 0 # \u6839\u636e\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_f1 = class_counts [ class_ ] * temp_f1 # \u52a0\u6743F1\u503c\u76f8\u52a0 f1 += weighted_f1 # \u8ba1\u7b97\u52a0\u6743\u5e73\u5747F1\u503c overall_f1 = f1 / len ( y_true ) return overall_f1 \u8bf7\u6ce8\u610f\uff0c\u4e0a\u9762\u6709\u51e0\u884c\u4ee3\u7801\u662f\u65b0\u5199\u7684\u3002\u56e0\u6b64\uff0c\u4f60\u5e94\u8be5\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e9b\u4ee3\u7801\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: weighted_f1 ( y_true , y_pred ) Out [ X ]: 0.41269841269841273 In [ X ]: metrics . f1_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.41269841269841273 \u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u591a\u7c7b\u95ee\u9898\u5b9e\u73b0\u4e86\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002\u540c\u6837\uff0c\u60a8\u4e5f\u53ef\u4ee5\u5c06 AUC \u548c\u5bf9\u6570\u635f\u5931\u8f6c\u6362\u4e3a\u591a\u7c7b\u683c\u5f0f\u3002\u8fd9\u79cd\u8f6c\u6362\u683c\u5f0f\u88ab\u79f0\u4e3a one-vs-all \u3002\u8fd9\u91cc\u6211\u4e0d\u6253\u7b97\u5b9e\u73b0\u5b83\u4eec\uff0c\u56e0\u4e3a\u5b9e\u73b0\u65b9\u6cd5\u4e0e\u6211\u4eec\u5df2\u7ecf\u8ba8\u8bba\u8fc7\u7684\u5f88\u76f8\u4f3c\u3002 \u5728\u4e8c\u5143\u6216\u591a\u7c7b\u5206\u7c7b\u4e2d\uff0c\u770b\u4e00\u4e0b \u6df7\u6dc6\u77e9\u9635 \u4e5f\u5f88\u6d41\u884c\u3002\u4e0d\u8981\u56f0\u60d1\uff0c\u8fd9\u5f88\u7b80\u5355\u3002\u6df7\u6dc6\u77e9\u9635\u53ea\u4e0d\u8fc7\u662f\u4e00\u4e2a\u5305\u542b TP\u3001FP\u3001TN \u548c FN \u7684\u8868\u683c\u3002\u4f7f\u7528\u6df7\u6dc6\u77e9\u9635\uff0c\u60a8\u53ef\u4ee5\u5feb\u901f\u67e5\u770b\u6709\u591a\u5c11\u6837\u672c\u88ab\u9519\u8bef\u5206\u7c7b\uff0c\u6709\u591a\u5c11\u6837\u672c\u88ab\u6b63\u786e\u5206\u7c7b\u3002\u4e5f\u8bb8\u6709\u4eba\u4f1a\u8bf4\uff0c\u6df7\u6dc6\u77e9\u9635\u5e94\u8be5\u5728\u672c\u7ae0\u5f88\u65e9\u5c31\u8bb2\u5230\uff0c\u4f46\u6211\u6ca1\u6709\u8fd9\u4e48\u505a\u3002\u5982\u679c\u4e86\u89e3\u4e86 TP\u3001FP\u3001TN\u3001FN\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c AUC\uff0c\u5c31\u5f88\u5bb9\u6613\u7406\u89e3\u548c\u89e3\u91ca\u6df7\u6dc6\u77e9\u9635\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b\u56fe 7 \u4e2d\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6df7\u6dc6\u77e9\u9635\u7531 TP\u3001FP\u3001FN \u548c TN \u7ec4\u6210\u3002\u6211\u4eec\u53ea\u9700\u8981\u8fd9\u4e9b\u503c\u6765\u8ba1\u7b97\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u3001F1 \u5206\u6570\u548c AUC\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u4e5f\u559c\u6b22\u628a FP \u79f0\u4e3a \u7b2c\u4e00\u7c7b\u9519\u8bef \uff0c\u628a FN \u79f0\u4e3a \u7b2c\u4e8c\u7c7b\u9519\u8bef \u3002 \u56fe 7\uff1a\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6df7\u6dc6\u77e9\u9635 \u6211\u4eec\u8fd8\u53ef\u4ee5\u5c06\u4e8c\u5143\u6df7\u6dc6\u77e9\u9635\u6269\u5c55\u4e3a\u591a\u7c7b\u6df7\u6dc6\u77e9\u9635\u3002\u5b83\u4f1a\u662f\u4ec0\u4e48\u6837\u5b50\u5462\uff1f\u5982\u679c\u6211\u4eec\u6709 N \u4e2a\u7c7b\u522b\uff0c\u5b83\u5c06\u662f\u4e00\u4e2a\u5927\u5c0f\u4e3a NxN \u7684\u77e9\u9635\u3002\u5bf9\u4e8e\u6bcf\u4e2a\u7c7b\u522b\uff0c\u6211\u4eec\u90fd\u8981\u8ba1\u7b97\u76f8\u5173\u7c7b\u522b\u548c\u5176\u4ed6\u7c7b\u522b\u7684\u6837\u672c\u603b\u6570\u3002\u4e3e\u4e2a\u4f8b\u5b50\u53ef\u4ee5\u8ba9\u6211\u4eec\u66f4\u597d\u5730\u7406\u89e3\u8fd9\u4e00\u70b9\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4ee5\u4e0b\u771f\u5b9e\u6807\u7b7e\uff1a \\[ [0, 1, 2, 0, 1, 2, 0, 2, 2] \\] \u6211\u4eec\u7684\u9884\u6d4b\u6807\u7b7e\u662f\uff1a \\[ [0, 2, 1, 0, 2, 1, 0, 0, 2] \\] \u90a3\u4e48\uff0c\u6211\u4eec\u7684\u6df7\u6dc6\u77e9\u9635\u5c06\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u591a\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635 \u56fe 8 \u8bf4\u660e\u4e86\u4ec0\u4e48\uff1f \u8ba9\u6211\u4eec\u6765\u770b\u770b 0 \u7c7b\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e 0 \u7c7b\u3002\u7136\u800c\uff0c\u5728\u9884\u6d4b\u4e2d\uff0c\u6211\u4eec\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 0 \u7c7b\uff0c1 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\u3002\u7406\u60f3\u60c5\u51b5\u4e0b\uff0c\u5bf9\u4e8e\u771f\u5b9e\u6807\u7b7e\u4e2d\u7684\u7c7b\u522b 0\uff0c\u9884\u6d4b\u6807\u7b7e 1 \u548c 2 \u5e94\u8be5\u6ca1\u6709\u4efb\u4f55\u6837\u672c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u7c7b\u522b 2\u3002\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 4\uff0c\u800c\u5728\u9884\u6d4b\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 3\u3002 \u4e00\u4e2a\u5b8c\u7f8e\u7684\u6df7\u6dc6\u77e9\u9635\u53ea\u80fd\u4ece\u5de6\u5230\u53f3\u659c\u5411\u586b\u5145\u3002 \u6df7\u6dc6\u77e9\u9635 \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u7684\u65b9\u6cd5\u6765\u8ba1\u7b97\u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u4e0d\u540c\u6307\u6807\u3002Scikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u76f4\u63a5\u7684\u65b9\u6cd5\u6765\u751f\u6210\u6df7\u6dc6\u77e9\u9635\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5728\u56fe 8 \u4e2d\u663e\u793a\u7684\u6df7\u6dc6\u77e9\u9635\u662f scikit-learn \u6df7\u6dc6\u77e9\u9635\u7684\u8f6c\u7f6e\uff0c\u539f\u59cb\u7248\u672c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u7ed8\u5236\u3002 import matplotlib.pyplot as plt import seaborn as sns from sklearn import metrics # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] # \u9884\u6d4b\u6837\u672c\u6807\u7b7e y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] # \u8ba1\u7b97\u6df7\u6dc6\u77e9\u9635 cm = metrics . confusion_matrix ( y_true , y_pred ) # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 10 , 10 )) # \u521b\u5efa\u65b9\u683c cmap = sns . cubehelix_palette ( 50 , hue = 0.05 , rot = 0 , light = 0.9 , dark = 0 , as_cmap = True ) # \u89c4\u5b9a\u5b57\u4f53\u5927\u5c0f sns . set ( font_scale = 2.5 ) # \u7ed8\u5236\u70ed\u56fe sns . heatmap ( cm , annot = True , cmap = cmap , cbar = False ) # y\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . ylabel ( 'Actual Labels' , fontsize = 20 ) # x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . xlabel ( 'Predicted Labels' , fontsize = 20 ) \u56e0\u6b64\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u89e3\u51b3\u4e86\u4e8c\u5143\u5206\u7c7b\u548c\u591a\u7c7b\u5206\u7c7b\u7684\u5ea6\u91cf\u95ee\u9898\u3002\u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u8ba8\u8bba\u53e6\u4e00\u79cd\u7c7b\u578b\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u5373\u591a\u6807\u7b7e\u5206\u7c7b\u3002\u5728\u591a\u6807\u7b7e\u5206\u7c7b\u4e2d\uff0c\u6bcf\u4e2a\u6837\u672c\u90fd\u53ef\u80fd\u4e0e\u4e00\u4e2a\u6216\u591a\u4e2a\u7c7b\u522b\u76f8\u5173\u8054\u3002\u8fd9\u7c7b\u95ee\u9898\u7684\u4e00\u4e2a\u7b80\u5355\u4f8b\u5b50\u5c31\u662f\u8981\u6c42\u4f60\u9884\u6d4b\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53\u3002 \u56fe 9 \u663e\u793a\u4e86\u4e00\u4e2a\u8457\u540d\u6570\u636e\u96c6\u7684\u56fe\u50cf\u793a\u4f8b\u3002\u8bf7\u6ce8\u610f\uff0c\u8be5\u6570\u636e\u96c6\u7684\u76ee\u6807\u6709\u6240\u4e0d\u540c\uff0c\u4f46\u6211\u4eec\u6682\u4e14\u4e0d\u53bb\u8ba8\u8bba\u5b83\u3002\u6211\u4eec\u5047\u8bbe\u5176\u76ee\u7684\u53ea\u662f\u9884\u6d4b\u56fe\u50cf\u4e2d\u662f\u5426\u5b58\u5728\u67d0\u4e2a\u7269\u4f53\u3002\u5728\u56fe 9 \u4e2d\uff0c\u6211\u4eec\u6709\u6905\u5b50\u3001\u82b1\u76c6\u3001\u7a97\u6237\uff0c\u4f46\u6ca1\u6709\u5176\u4ed6\u7269\u4f53\uff0c\u5982\u7535\u8111\u3001\u5e8a\u3001\u7535\u89c6\u7b49\u3002\u56e0\u6b64\uff0c\u4e00\u5e45\u56fe\u50cf\u53ef\u80fd\u6709\u591a\u4e2a\u76f8\u5173\u76ee\u6807\u3002\u8fd9\u7c7b\u95ee\u9898\u5c31\u662f\u591a\u6807\u7b7e\u5206\u7c7b\u95ee\u9898\u3002 \u56fe 9\uff1a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53 \u8fd9\u7c7b\u5206\u7c7b\u95ee\u9898\u7684\u8861\u91cf\u6807\u51c6\u6709\u4e9b\u4e0d\u540c\u3002\u4e00\u4e9b\u5408\u9002\u7684 \u6700\u5e38\u89c1\u7684\u6307\u6807\u6709\uff1a k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u786e\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 \u8ba9\u6211\u4eec\u4ece k \u7cbe\u786e\u7387\u6216\u8005 P@k \u6211\u4eec\u4e0d\u80fd\u5c06\u8fd9\u4e00\u7cbe\u786e\u7387\u4e0e\u524d\u9762\u8ba8\u8bba\u7684\u7cbe\u786e\u7387\u6df7\u6dc6\u3002\u5982\u679c\u60a8\u6709\u4e00\u4e2a\u7ed9\u5b9a\u6837\u672c\u7684\u539f\u59cb\u7c7b\u522b\u5217\u8868\u548c\u540c\u4e00\u4e2a\u6837\u672c\u7684\u9884\u6d4b\u7c7b\u522b\u5217\u8868\uff0c\u90a3\u4e48\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u5c31\u662f\u9884\u6d4b\u5217\u8868\u4e2d\u4ec5\u8003\u8651\u524d k \u4e2a\u9884\u6d4b\u7ed3\u679c\u7684\u547d\u4e2d\u6570\u9664\u4ee5 k\u3002 \u5982\u679c\u60a8\u5bf9\u6b64\u611f\u5230\u56f0\u60d1\uff0c\u4f7f\u7528 python \u4ee3\u7801\u540e\u5c31\u4f1a\u660e\u767d\u3002 def pk ( y_true , y_pred , k ): # \u5982\u679ck\u4e3a0 if k == 0 : # \u8fd4\u56de0 return 0 # \u53d6\u9884\u6d4b\u6807\u7b7e\u524dk\u4e2a y_pred = y_pred [: k ] # \u5c06\u9884\u6d4b\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 pred_set = set ( y_pred ) # \u5c06\u771f\u5b9e\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 true_set = set ( y_true ) # \u9884\u6d4b\u6807\u7b7e\u96c6\u5408\u4e0e\u771f\u5b9e\u6807\u7b7e\u96c6\u5408\u4ea4\u96c6 common_values = pred_set . intersection ( true_set ) # \u8ba1\u7b97\u7cbe\u786e\u7387 return len ( common_values ) / len ( y_pred [: k ]) \u6709\u4e86\u4ee3\u7801\uff0c\u4e00\u5207\u90fd\u53d8\u5f97\u66f4\u5bb9\u6613\u7406\u89e3\u4e86\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6709\u4e86 k \u5e73\u5747\u7cbe\u786e\u7387\u6216 AP@k \u3002AP@k \u662f\u901a\u8fc7 P@k \u8ba1\u7b97\u5f97\u51fa\u7684\u3002\u4f8b\u5982\uff0c\u5982\u679c\u8981\u8ba1\u7b97 AP@3\uff0c\u6211\u4eec\u8981\u5148\u8ba1\u7b97 P@1\u3001P@2 \u548c P@3\uff0c\u7136\u540e\u5c06\u603b\u548c\u9664\u4ee5 3\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u7684\u5b9e\u73b0\u3002 def apk ( y_true , y_pred , k ): # \u521d\u59cb\u5316P@k\u5217\u8868 pk_values = [] # \u904d\u53861~k for i in range ( 1 , k + 1 ): # \u5c06P@k\u52a0\u5165\u5217\u8868 pk_values . append ( pk ( y_true , y_pred , i )) # \u82e5\u957f\u5ea6\u4e3a0 if len ( pk_values ) == 0 : # \u8fd4\u56de0 return 0 # \u5426\u5219\u8ba1\u7b97AP@K return sum ( pk_values ) / len ( pk_values ) \u8fd9\u4e24\u4e2a\u51fd\u6570\u53ef\u4ee5\u7528\u6765\u8ba1\u7b97\u4e24\u4e2a\u7ed9\u5b9a\u5217\u8868\u7684 k \u5e73\u5747\u7cbe\u786e\u7387 (AP@k)\uff1b\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u8ba1\u7b97\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: for i in range ( len ( y_true )): ... : for j in range ( 1 , 4 ): ... : print ( ... : f \"\"\" ...: y_true= { y_true [ i ] } , ...: y_pred= { y_pred [ i ] } , ...: AP@ { j } = { apk ( y_true [ i ], y_pred [ i ], k = j ) } ...: \"\"\" ... : ) ... : y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 1 = 0.0 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 2 = 0.25 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 3 = 0.38888888888888884 \u8bf7\u6ce8\u610f\uff0c\u6211\u7701\u7565\u4e86\u8f93\u51fa\u7ed3\u679c\u4e2d\u7684\u8bb8\u591a\u6570\u503c\uff0c\u4f46\u4f60\u4f1a\u660e\u767d\u5176\u4e2d\u7684\u610f\u601d\u3002\u8fd9\u5c31\u662f\u6211\u4eec\u5982\u4f55\u8ba1\u7b97 AP@k \u7684\u65b9\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u7684 AP@k\u3002\u5728\u673a\u5668\u5b66\u4e60\u4e2d\uff0c\u6211\u4eec\u5bf9\u6240\u6709\u6837\u672c\u90fd\u611f\u5174\u8da3\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u6709 \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387 k \u6216 MAP@k \u3002MAP@k \u53ea\u662f AP@k \u7684\u5e73\u5747\u503c\uff0c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b python \u4ee3\u7801\u8f7b\u677e\u8ba1\u7b97\u3002 def mapk ( y_true , y_pred , k ): # \u521d\u59cb\u5316AP@k\u5217\u8868 apk_values = [] # \u904d\u53860~\uff08\u771f\u5b9e\u6807\u7b7e\u6570-1\uff09 for i in range ( len ( y_true )): # \u5c06AP@K\u52a0\u5165\u5217\u8868 apk_values . append ( apk ( y_true [ i ], y_pred [ i ], k = k ) ) # \u8ba1\u7b97\u5e73\u5747AP@k return sum ( apk_values ) / len ( apk_values ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u9488\u5bf9\u76f8\u540c\u7684\u5217\u8868\u8ba1\u7b97 k=1\u30012\u30013 \u548c 4 \u65f6\u7684 MAP@k\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: mapk ( y_true , y_pred , k = 1 ) Out [ X ]: 0.3333333333333333 In [ X ]: mapk ( y_true , y_pred , k = 2 ) Out [ X ]: 0.375 In [ X ]: mapk ( y_true , y_pred , k = 3 ) Out [ X ]: 0.3611111111111111 In [ X ]: mapk ( y_true , y_pred , k = 4 ) Out [ X ]: 0.34722222222222215 P@k\u3001AP@k \u548c MAP@k \u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u5176\u4e2d 1 \u4e3a\u6700\u4f73\u3002 \u8bf7\u6ce8\u610f\uff0c\u6709\u65f6\u60a8\u53ef\u80fd\u4f1a\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230 P@k \u548c AP@k \u7684\u4e0d\u540c\u5b9e\u73b0\u65b9\u5f0f\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5176\u4e2d\u4e00\u79cd\u5b9e\u73b0\u65b9\u5f0f\u3002 import numpy as np def apk ( actual , predicted , k = 10 ): # \u82e5\u9884\u6d4b\u6807\u7b7e\u957f\u5ea6\u5927\u4e8ek if len ( predicted ) > k : # \u53d6\u524dk\u4e2a\u6807\u7b7e predicted = predicted [: k ] score = 0.0 num_hits = 0.0 for i , p in enumerate ( predicted ): if p in actual and p not in predicted [: i ]: num_hits += 1.0 score += num_hits / ( i + 1.0 ) if not actual : return 0.0 return score / min ( len ( actual ), k ) \u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u662f AP@k \u7684\u53e6\u4e00\u4e2a\u7248\u672c\uff0c\u5176\u4e2d\u987a\u5e8f\u5f88\u91cd\u8981\uff0c\u6211\u4eec\u8981\u6743\u8861\u9884\u6d4b\u7ed3\u679c\u3002\u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u7684\u7ed3\u679c\u4e0e\u6211\u7684\u4ecb\u7ecd\u7565\u6709\u4e0d\u540c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6765\u770b\u770b \u591a\u6807\u7b7e\u5206\u7c7b\u7684\u5bf9\u6570\u635f\u5931 \u3002\u8fd9\u5f88\u5bb9\u6613\u3002\u60a8\u53ef\u4ee5\u5c06\u76ee\u6807\u8f6c\u6362\u4e3a\u4e8c\u5143\u5206\u7c7b\uff0c\u7136\u540e\u5bf9\u6bcf\u4e00\u5217\u4f7f\u7528\u5bf9\u6570\u635f\u5931\u3002\u6700\u540e\uff0c\u4f60\u53ef\u4ee5\u6c42\u51fa\u6bcf\u5217\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u8fd9\u4e5f\u88ab\u79f0\u4e3a\u5e73\u5747\u5217\u5bf9\u6570\u635f\u5931\u3002\u5f53\u7136\uff0c\u8fd8\u6709\u5176\u4ed6\u65b9\u6cd5\u53ef\u4ee5\u5b9e\u73b0\u8fd9\u4e00\u70b9\uff0c\u4f60\u5e94\u8be5\u5728\u9047\u5230\u65f6\u52a0\u4ee5\u63a2\u7d22\u3002 \u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u8bf4\u5df2\u7ecf\u638c\u63e1\u4e86\u6240\u6709\u4e8c\u5143\u5206\u7c7b\u3001\u591a\u7c7b\u5206\u7c7b\u548c\u591a\u6807\u7b7e\u5206\u7c7b\u6307\u6807\uff0c\u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u6307\u6807\u3002 \u56de\u5f52\u4e2d\u6700\u5e38\u89c1\u7684\u6307\u6807\u662f \u8bef\u5dee\uff08Error\uff09 \u3002\u8bef\u5dee\u5f88\u7b80\u5355\uff0c\u4e5f\u5f88\u5bb9\u6613\u7406\u89e3\u3002 \\[ Error = True\\ Value - Predicted\\ Value \\] \u7edd\u5bf9\u8bef\u5dee\uff08Absolute error\uff09 \u53ea\u662f\u4e0a\u8ff0\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\u3002 \\[ Absolute\\ Error = Abs(True\\ Value - Predicted\\ Value) \\] \u63a5\u4e0b\u6765\u6211\u4eec\u8ba8\u8bba \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\uff08MAE\uff09 \u3002\u5b83\u53ea\u662f\u6240\u6709\u7edd\u5bf9\u8bef\u5dee\u7684\u5e73\u5747\u503c\u3002 import numpy as np def mean_absolute_error ( y_true , y_pred ): #\u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u7d2f\u52a0\u7edd\u5bf9\u8bef\u5dee error += np . abs ( yt - yp ) # \u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee return error / len ( y_true ) \u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5e73\u65b9\u8bef\u5dee\u548c \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u3002 \\[ Squared\\ Error = (True Value - Predicted\\ Value)^2 \\] \u5747\u65b9\u8bef\u5dee\uff08MSE\uff09\u7684\u8ba1\u7b97\u65b9\u5f0f\u5982\u4e0b def mean_squared_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u7d2f\u52a0\u8bef\u5dee\u5e73\u65b9\u548c error += ( yt - yp ) ** 2 # \u8ba1\u7b97\u5747\u65b9\u8bef\u5dee return error / len ( y_true ) MSE \u548c RMSE\uff08\u5747\u65b9\u6839\u8bef\u5dee\uff09 \u662f\u8bc4\u4f30\u56de\u5f52\u6a21\u578b\u6700\u5e38\u7528\u7684\u6307\u6807\u3002 \\[ RMSE = SQRT(MSE) \\] \u540c\u4e00\u7c7b\u8bef\u5dee\u7684\u53e6\u4e00\u79cd\u7c7b\u578b\u662f \u5e73\u65b9\u5bf9\u6570\u8bef\u5dee \u3002\u6709\u4eba\u79f0\u5176\u4e3a SLE \uff0c\u5f53\u6211\u4eec\u53d6\u6240\u6709\u6837\u672c\u4e2d\u8fd9\u4e00\u8bef\u5dee\u7684\u5e73\u5747\u503c\u65f6\uff0c\u5b83\u88ab\u79f0\u4e3a MSLE\uff08\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee\uff09\uff0c\u5b9e\u73b0\u65b9\u6cd5\u5982\u4e0b\u3002 import numpy as np def mean_squared_log_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee error += ( np . log ( 1 + yt ) - np . log ( 1 + yp )) ** 2 # \u8ba1\u7b97\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee return error / len ( y_true ) \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \u53ea\u662f\u5176\u5e73\u65b9\u6839\u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a RMSLE \u3002 \u7136\u540e\u662f\u767e\u5206\u6bd4\u8bef\u5dee\uff1a \\[ Percentage\\ Error = (( True\\ Value \u2013 Predicted\\ Value ) / True\\ Value ) \\times 100 \\] \u540c\u6837\u53ef\u4ee5\u8f6c\u6362\u4e3a\u6240\u6709\u6837\u672c\u7684\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee\u3002 def mean_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u767e\u5206\u6bd4\u8bef\u5dee error += ( yt - yp ) / yt # \u8fd4\u56de\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u7edd\u5bf9\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\uff08\u4e5f\u662f\u66f4\u5e38\u89c1\u7684\u7248\u672c\uff09\u88ab\u79f0\u4e3a \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee\u6216 MAPE \u3002 import numpy as np def mean_abs_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee error += np . abs ( yt - yp ) / yt #\u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u56de\u5f52\u7684\u6700\u5927\u4f18\u70b9\u662f\uff0c\u53ea\u6709\u51e0\u4e2a\u6700\u5e38\u7528\u7684\u6307\u6807\uff0c\u51e0\u4e4e\u53ef\u4ee5\u5e94\u7528\u4e8e\u6240\u6709\u56de\u5f52\u95ee\u9898\u3002\u4e0e\u5206\u7c7b\u6307\u6807\u76f8\u6bd4\uff0c\u56de\u5f52\u6307\u6807\u66f4\u5bb9\u6613\u7406\u89e3\u3002 \u8ba9\u6211\u4eec\u6765\u8c08\u8c08\u53e6\u4e00\u4e2a\u56de\u5f52\u6307\u6807 \\(R^2\\) \uff08R \u65b9\uff09\uff0c\u4e5f\u79f0\u4e3a \u5224\u5b9a\u7cfb\u6570 \u3002 \u7b80\u5355\u5730\u8bf4\uff0cR \u65b9\u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u3002R \u65b9\u63a5\u8fd1 1.0 \u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u76f8\u5f53\u597d\uff0c\u800c\u63a5\u8fd1 0 \u5219\u8868\u793a\u6a21\u578b\u4e0d\u662f\u90a3\u4e48\u597d\u3002\u5f53\u6a21\u578b\u53ea\u662f\u505a\u51fa\u8352\u8c2c\u7684\u9884\u6d4b\u65f6\uff0cR \u65b9\u4e5f\u53ef\u80fd\u662f\u8d1f\u503c\u3002 R \u65b9\u7684\u8ba1\u7b97\u516c\u5f0f\u5982\u4e0b\u6240\u793a\uff0c\u4f46 Python \u7684\u5b9e\u73b0\u603b\u662f\u80fd\u8ba9\u4e00\u5207\u66f4\u52a0\u6e05\u6670\u3002 \\[ R^2 = \\frac{\\sum^{N}_{i=1}(y_{t_i}-y_{p_i})^2}{\\sum^{N}_{i=1}(y_{t_i} - y_{t_{mean}})} \\] import numpy as np def r2 ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u5747\u771f\u5b9e\u503c mean_true_value = np . mean ( y_true ) # \u521d\u59cb\u5316\u5e73\u65b9\u8bef\u5dee numerator = 0 denominator = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): numerator += ( yt - yp ) ** 2 denominator += ( yt - mean_true_value ) ** 2 ratio = numerator / denominator # \u8ba1\u7b97R\u65b9 return 1 \u2013 ratio \u8fd8\u6709\u66f4\u591a\u7684\u8bc4\u4ef7\u6307\u6807\uff0c\u8fd9\u4e2a\u6e05\u5355\u6c38\u8fdc\u4e5f\u5217\u4e0d\u5b8c\u3002\u6211\u53ef\u4ee5\u5199\u4e00\u672c\u4e66\uff0c\u53ea\u4ecb\u7ecd\u4e0d\u540c\u7684\u8bc4\u4ef7\u6307\u6807\u3002\u4e5f\u8bb8\u6211\u4f1a\u7684\u3002\u73b0\u5728\uff0c\u8fd9\u4e9b\u8bc4\u4f30\u6307\u6807\u51e0\u4e4e\u53ef\u4ee5\u6ee1\u8db3\u4f60\u60f3\u5c1d\u8bd5\u89e3\u51b3\u7684\u6240\u6709\u95ee\u9898\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ee5\u6700\u76f4\u63a5\u7684\u65b9\u5f0f\u5b9e\u73b0\u4e86\u8fd9\u4e9b\u6307\u6807\uff0c\u8fd9\u610f\u5473\u7740\u5b83\u4eec\u4e0d\u591f\u9ad8\u6548\u3002\u4f60\u53ef\u4ee5\u901a\u8fc7\u6b63\u786e\u4f7f\u7528 numpy \u4ee5\u975e\u5e38\u9ad8\u6548\u7684\u65b9\u5f0f\u5b9e\u73b0\u5176\u4e2d\u5927\u90e8\u5206\u6307\u6807\u3002\u4f8b\u5982\uff0c\u770b\u770b\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\u7684\u5b9e\u73b0\uff0c\u4e0d\u9700\u8981\u4efb\u4f55\u5faa\u73af\u3002 import numpy as np def mae_np ( y_true , y_pred ): return np . mean ( np . abs ( y_true - y_pred )) \u6211\u672c\u53ef\u4ee5\u7528\u8fd9\u79cd\u65b9\u6cd5\u5b9e\u73b0\u6240\u6709\u6307\u6807\uff0c\u4f46\u4e3a\u4e86\u5b66\u4e60\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b\u5e95\u5c42\u5b9e\u73b0\u3002\u4e00\u65e6\u4f60\u5b66\u4f1a\u4e86\u7eaf python \u7684\u5e95\u5c42\u5b9e\u73b0\uff0c\u5e76\u4e14\u4e0d\u4f7f\u7528\u5927\u91cf numpy\uff0c\u4f60\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5c06\u5176\u8f6c\u6362\u4e3a numpy\uff0c\u5e76\u4f7f\u5176\u53d8\u5f97\u66f4\u5feb\u3002 \u7136\u540e\u662f\u4e00\u4e9b\u9ad8\u7ea7\u5ea6\u91cf\u3002 \u5176\u4e2d\u4e00\u4e2a\u5e94\u7528\u76f8\u5f53\u5e7f\u6cdb\u7684\u6307\u6807\u662f \u4e8c\u6b21\u52a0\u6743\u5361\u5e15 \uff0c\u4e5f\u79f0\u4e3a QWK \u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a\u79d1\u6069\u5361\u5e15\u3002 QWK \u8861\u91cf\u4e24\u4e2a \"\u8bc4\u5206 \"\u4e4b\u95f4\u7684 \"\u4e00\u81f4\u6027\"\u3002\u8bc4\u5206\u53ef\u4ee5\u662f 0 \u5230 N \u4e4b\u95f4\u7684\u4efb\u4f55\u5b9e\u6570\uff0c\u9884\u6d4b\u4e5f\u5728\u540c\u4e00\u8303\u56f4\u5185\u3002\u4e00\u81f4\u6027\u53ef\u4ee5\u5b9a\u4e49\u4e3a\u8fd9\u4e9b\u8bc4\u7ea7\u4e4b\u95f4\u7684\u63a5\u8fd1\u7a0b\u5ea6\u3002\u56e0\u6b64\uff0c\u5b83\u9002\u7528\u4e8e\u6709 N \u4e2a\u4e0d\u540c\u7c7b\u522b\u7684\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u4e00\u81f4\u5ea6\u9ad8\uff0c\u5206\u6570\u5c31\u66f4\u63a5\u8fd1 1.0\u3002Cohen's kappa \u5728 scikit-learn \u4e2d\u6709\u5f88\u597d\u7684\u5b9e\u73b0\uff0c\u5173\u4e8e\u8be5\u6307\u6807\u7684\u8be6\u7ec6\u8ba8\u8bba\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3 ] In [ X ]: y_pred = [ 2 , 1 , 3 , 1 , 2 , 3 , 3 , 1 , 2 ] In [ X ]: metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) Out [ X ]: 0.33333333333333337 In [ X ]: metrics . accuracy_score ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 \u60a8\u53ef\u4ee5\u770b\u5230\uff0c\u5c3d\u7ba1\u51c6\u786e\u5ea6\u5f88\u9ad8\uff0c\u4f46 QWK \u5374\u5f88\u4f4e\u3002QWK \u5927\u4e8e 0.85 \u5373\u4e3a\u975e\u5e38\u597d\uff01 \u4e00\u4e2a\u91cd\u8981\u7684\u6307\u6807\u662f \u9a6c\u4fee\u76f8\u5173\u7cfb\u6570\uff08MCC\uff09 \u30021 \u4ee3\u8868\u5b8c\u7f8e\u9884\u6d4b\uff0c-1 \u4ee3\u8868\u4e0d\u5b8c\u7f8e\u9884\u6d4b\uff0c0 \u4ee3\u8868\u968f\u673a\u9884\u6d4b\u3002MCC \u7684\u8ba1\u7b97\u516c\u5f0f\u975e\u5e38\u7b80\u5355\u3002 \\[ MCC = \\frac{TP \\times TN - FP \\times FN}{\\sqrt{(TP + FP) \\times (FN + TN) \\times (FP + TN) \\times (TP + FN)}} \\] \u6211\u4eec\u770b\u5230\uff0cMCC \u8003\u8651\u4e86 TP\u3001FP\u3001TN \u548c FN\uff0c\u56e0\u6b64\u53ef\u7528\u4e8e\u5904\u7406\u7c7b\u504f\u659c\u7684\u95ee\u9898\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u7684\u65b9\u6cd5\u5728 python \u4e2d\u5feb\u901f\u5b9e\u73b0\u5b83\u3002 def mcc ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) numerator = ( tp * tn ) - ( fp * fn ) denominator = ( ( tp + fp ) * ( fn + tn ) * ( fp + tn ) * ( tp + fn ) ) denominator = denominator ** 0.5 return numerator / denominator \u8fd9\u4e9b\u6307\u6807\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5165\u95e8\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002 \u9700\u8981\u6ce8\u610f\u7684\u4e00\u70b9\u662f\uff0c\u5728\u8bc4\u4f30\u975e\u76d1\u7763\u65b9\u6cd5\uff08\u4f8b\u5982\u67d0\u79cd\u805a\u7c7b\uff09\u65f6\uff0c\u6700\u597d\u521b\u5efa\u6216\u624b\u52a8\u6807\u8bb0\u6d4b\u8bd5\u96c6\uff0c\u5e76\u5c06\u5176\u4e0e\u5efa\u6a21\u90e8\u5206\u7684\u6240\u6709\u5185\u5bb9\u5206\u5f00\u3002\u5b8c\u6210\u805a\u7c7b\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u4efb\u4f55\u4e00\u79cd\u76d1\u7763\u5b66\u4e60\u6307\u6807\u6765\u8bc4\u4f30\u6d4b\u8bd5\u96c6\u7684\u6027\u80fd\u4e86\u3002 \u4e00\u65e6\u6211\u4eec\u4e86\u89e3\u4e86\u7279\u5b9a\u95ee\u9898\u5e94\u8be5\u4f7f\u7528\u4ec0\u4e48\u6307\u6807\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f00\u59cb\u66f4\u6df1\u5165\u5730\u7814\u7a76\u6211\u4eec\u7684\u6a21\u578b\uff0c\u4ee5\u6c42\u6539\u8fdb\u3002","title":"\u8bc4\u4f30\u6307\u6807"},{"location":"%E8%AF%84%E4%BC%B0%E6%8C%87%E6%A0%87/#_1","text":"\u8bf4\u5230\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u4f60\u4f1a\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\u9047\u5230\u5f88\u591a\u4e0d\u540c\u7c7b\u578b\u7684\u6307\u6807\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u751a\u81f3\u4f1a\u6839\u636e\u4e1a\u52a1\u95ee\u9898\u521b\u5efa\u5ea6\u91cf\u6807\u51c6\u3002\u9010\u4e00\u4ecb\u7ecd\u548c\u89e3\u91ca\u6bcf\u4e00\u79cd\u5ea6\u91cf\u7c7b\u578b\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002\u76f8\u53cd\uff0c\u6211\u4eec\u5c06\u4ecb\u7ecd\u4e00\u4e9b\u6700\u5e38\u89c1\u7684\u5ea6\u91cf\u6807\u51c6\uff0c\u4f9b\u4f60\u5728\u6700\u521d\u7684\u51e0\u4e2a\u9879\u76ee\u4e2d\u4f7f\u7528\u3002 \u5728\u672c\u4e66\u7684\u5f00\u5934\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u76d1\u7763\u5b66\u4e60\u548c\u975e\u76d1\u7763\u5b66\u4e60\u3002\u867d\u7136\u65e0\u76d1\u7763\u5b66\u4e60\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e9b\u6307\u6807\uff0c\u4f46\u6211\u4eec\u5c06\u53ea\u5173\u6ce8\u6709\u76d1\u7763\u5b66\u4e60\u3002\u8fd9\u662f\u56e0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u591a\uff0c\u800c\u4e14\u5bf9\u65e0\u76d1\u7763\u65b9\u6cd5\u7684\u8bc4\u4f30\u76f8\u5f53\u4e3b\u89c2\u3002 \u5982\u679c\u6211\u4eec\u8c08\u8bba\u5206\u7c7b\u95ee\u9898\uff0c\u6700\u5e38\u7528\u7684\u6307\u6807\u662f\uff1a \u51c6\u786e\u7387\uff08Accuracy\uff09 \u7cbe\u786e\u7387\uff08P\uff09 \u53ec\u56de\u7387\uff08R\uff09 F1 \u5206\u6570\uff08F1\uff09 AUC\uff08AUC\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u8bf4\u5230\u56de\u5f52\uff0c\u6700\u5e38\u7528\u7684\u8bc4\u4ef7\u6307\u6807\u662f \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee \uff08MAE\uff09 \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u5747\u65b9\u6839\u8bef\u5dee \uff08RMSE\uff09 \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \uff08RMSLE\uff09 \u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee \uff08MPE\uff09 \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee \uff08MAPE\uff09 R2 \u4e86\u89e3\u4e0a\u8ff0\u6307\u6807\u7684\u5de5\u4f5c\u539f\u7406\u5e76\u4e0d\u662f\u6211\u4eec\u5fc5\u987b\u4e86\u89e3\u7684\u552f\u4e00\u4e8b\u60c5\u3002\u6211\u4eec\u8fd8\u5fc5\u987b\u77e5\u9053\u4f55\u65f6\u4f7f\u7528\u54ea\u4e9b\u6307\u6807\uff0c\u800c\u8fd9\u53d6\u51b3\u4e8e\u4f60\u6709\u4ec0\u4e48\u6837\u7684\u6570\u636e\u548c\u76ee\u6807\u3002\u6211\u8ba4\u4e3a\u8fd9\u4e0e\u76ee\u6807\u6709\u5173\uff0c\u800c\u4e0e\u6570\u636e\u65e0\u5173\u3002 \u8981\u8fdb\u4e00\u6b65\u4e86\u89e3\u8fd9\u4e9b\u6307\u6807\uff0c\u8ba9\u6211\u4eec\u4ece\u4e00\u4e2a\u7b80\u5355\u7684\u95ee\u9898\u5f00\u59cb\u3002\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a \u4e8c\u5143\u5206\u7c7b \u95ee\u9898\uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u76ee\u6807\u7684\u95ee\u9898\uff0c\u5047\u8bbe\u8fd9\u662f\u4e00\u4e2a\u80f8\u90e8 X \u5149\u56fe\u50cf\u5206\u7c7b\u95ee\u9898\u3002\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6ca1\u6709\u95ee\u9898\uff0c\u800c\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6709\u80ba\u584c\u9677\uff0c\u4e5f\u5c31\u662f\u6240\u8c13\u7684\u6c14\u80f8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u5206\u7c7b\u5668\uff0c\u5728\u7ed9\u5b9a\u80f8\u90e8 X \u5149\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\uff0c\u5b83\u80fd\u68c0\u6d4b\u51fa\u56fe\u50cf\u662f\u5426\u6709\u6c14\u80f8\u3002 \u56fe 1\uff1a\u6c14\u80f8\u80ba\u90e8\u56fe\u50cf \u6211\u4eec\u8fd8\u5047\u8bbe\u6709\u76f8\u540c\u6570\u91cf\u7684\u6c14\u80f8\u548c\u975e\u6c14\u80f8\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u6bd4\u5982\u5404 100 \u5f20\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709 100 \u5f20\u9633\u6027\u6837\u672c\u548c 100 \u5f20\u9634\u6027\u6837\u672c\uff0c\u5171\u8ba1 200 \u5f20\u56fe\u50cf\u3002 \u7b2c\u4e00\u6b65\u662f\u5c06\u4e0a\u8ff0\u6570\u636e\u5206\u4e3a\u4e24\u7ec4\uff0c\u6bcf\u7ec4 100 \u5f20\u56fe\u50cf\uff0c\u5373\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u8fd9\u4e24\u4e2a\u96c6\u5408\u4e2d\uff0c\u6211\u4eec\u90fd\u6709 50 \u4e2a\u6b63\u6837\u672c\u548c 50 \u4e2a\u8d1f\u6837\u672c\u3002 \u5728\u4e8c\u5143\u5206\u7c7b\u6307\u6807\u4e2d\uff0c\u5f53\u6b63\u8d1f\u6837\u672c\u6570\u91cf\u76f8\u7b49\u65f6\uff0c\u6211\u4eec\u901a\u5e38\u4f7f\u7528\u51c6\u786e\u7387\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002 \u51c6\u786e\u7387 \uff1a\u8fd9\u662f\u673a\u5668\u5b66\u4e60\u4e2d\u6700\u76f4\u63a5\u7684\u6307\u6807\u4e4b\u4e00\u3002\u5b83\u5b9a\u4e49\u4e86\u6a21\u578b\u7684\u51c6\u786e\u5ea6\u3002\u5bf9\u4e8e\u4e0a\u8ff0\u95ee\u9898\uff0c\u5982\u679c\u4f60\u5efa\u7acb\u7684\u6a21\u578b\u80fd\u51c6\u786e\u5206\u7c7b 90 \u5f20\u56fe\u7247\uff0c\u90a3\u4e48\u4f60\u7684\u51c6\u786e\u7387\u5c31\u662f 90% \u6216 0.90\u3002\u5982\u679c\u53ea\u6709 83 \u5e45\u56fe\u50cf\u88ab\u6b63\u786e\u5206\u7c7b\uff0c\u90a3\u4e48\u6a21\u578b\u7684\u51c6\u786e\u7387\u5c31\u662f 83% \u6216 0.83\u3002 \u8ba1\u7b97\u51c6\u786e\u7387\u7684 Python \u4ee3\u7801\u4e5f\u975e\u5e38\u7b80\u5355\u3002 def accuracy ( y_true , y_pred ): # \u4e3a\u6b63\u786e\u9884\u6d4b\u6570\u521d\u59cb\u5316\u4e00\u4e2a\u7b80\u5355\u8ba1\u6570\u5668 correct_counter = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): if yt == yp : # \u5982\u679c\u9884\u6d4b\u6807\u7b7e\u4e0e\u771f\u5b9e\u6807\u7b7e\u76f8\u540c\uff0c\u5219\u589e\u52a0\u8ba1\u6570\u5668 correct_counter += 1 # \u8fd4\u56de\u6b63\u786e\u7387\uff0c\u6b63\u786e\u6807\u7b7e\u6570/\u603b\u6807\u7b7e\u6570 return correct_counter / len ( y_true ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u8ba1\u7b97\u51c6\u786e\u7387\u3002 In [ X ]: from sklearn import metrics ... : l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] ... : metrics . accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u73b0\u5728\uff0c\u5047\u8bbe\u6211\u4eec\u628a\u6570\u636e\u96c6\u7a0d\u5fae\u6539\u52a8\u4e00\u4e0b\uff0c\u6709 180 \u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u53ea\u6709 20 \u5f20\u6709\u6c14\u80f8\u3002\u5373\u4f7f\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4e5f\u8981\u521b\u5efa\u6b63\u8d1f\uff08\u6c14\u80f8\u4e0e\u975e\u6c14\u80f8\uff09\u76ee\u6807\u6bd4\u4f8b\u76f8\u540c\u7684\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u6bcf\u4e00\u7ec4\u4e2d\uff0c\u6211\u4eec\u6709 90 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u5982\u679c\u8bf4\u9a8c\u8bc1\u96c6\u4e2d\u7684\u6240\u6709\u56fe\u50cf\u90fd\u662f\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u90a3\u4e48\u60a8\u7684\u51c6\u786e\u7387\u4f1a\u662f\u591a\u5c11\u5462\uff1f\u8ba9\u6211\u4eec\u6765\u770b\u770b\uff1b\u60a8\u5bf9 90% \u7684\u56fe\u50cf\u8fdb\u884c\u4e86\u6b63\u786e\u5206\u7c7b\u3002\u56e0\u6b64\uff0c\u60a8\u7684\u51c6\u786e\u7387\u662f 90%\u3002 \u4f46\u8bf7\u518d\u770b\u4e00\u904d\u3002 \u4f60\u751a\u81f3\u6ca1\u6709\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5c31\u5f97\u5230\u4e86 90% \u7684\u51c6\u786e\u7387\u3002\u8fd9\u4f3c\u4e4e\u6709\u70b9\u6ca1\u7528\u3002\u5982\u679c\u6211\u4eec\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u6570\u636e\u96c6\u662f\u504f\u659c\u7684\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u6bd4\u53e6\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u591a\u5f88\u591a\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f7f\u7528\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u56e0\u4e3a\u5b83\u4e0d\u80fd\u4ee3\u8868\u6570\u636e\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u80fd\u4f1a\u83b7\u5f97\u5f88\u9ad8\u7684\u51c6\u786e\u7387\uff0c\u4f46\u60a8\u7684\u6a21\u578b\u5728\u5b9e\u9645\u6837\u672c\u4e2d\u7684\u8868\u73b0\u53ef\u80fd\u5e76\u4e0d\u7406\u60f3\uff0c\u800c\u4e14\u60a8\u4e5f\u65e0\u6cd5\u5411\u7ecf\u7406\u89e3\u91ca\u539f\u56e0\u3002 \u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b \u7cbe\u786e\u7387 \u7b49\u5176\u4ed6\u6307\u6807\u3002 \u5728\u5b66\u4e60\u7cbe\u786e\u7387\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e86\u89e3\u4e00\u4e9b\u672f\u8bed\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5047\u8bbe\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e3a\u6b63\u7c7b (1)\uff0c\u6ca1\u6709\u6c14\u80f8\u7684\u4e3a\u8d1f\u7c7b (0)\u3002 \u771f\u9633\u6027 \uff08TP\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6709\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9633\u6027\u3002 \u771f\u9634\u6027 \uff08TN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u800c\u5b9e\u9645\u76ee\u6807\u663e\u793a\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9634\u6027\u3002 \u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u6b63\u786e\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9633\u6027\uff1b\u5982\u679c\u60a8\u7684\u6a21\u578b\u51c6\u786e\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9634\u6027\u3002 \u5047\u9633\u6027 \uff08FP\uff09 \uff1a\u7ed9\u5b9a\u4e00\u5f20\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u975e\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9633\u6027\u3002 \u5047\u9634\u6027 \uff08FN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u975e\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9634\u6027\u3002 \u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\uff0c\u90a3\u4e48\u5b83\u5c31\u662f\u5047\u9633\u6027\u3002\u5982\u679c\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5219\u662f\u5047\u9634\u6027\u3002 \u8ba9\u6211\u4eec\u9010\u4e00\u770b\u770b\u8fd9\u4e9b\u5b9e\u73b0\u3002 def true_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u8ba1\u6570\u5668 tp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 1 : tp += 1 # \u8fd4\u56de\u771f\u9633\u6027\u6837\u672c\u6570 return tp def true_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9634\u6027\u6837\u672c\u8ba1\u6570\u5668 tn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 0 : tn += 1 # \u8fd4\u56de\u771f\u9634\u6027\u6837\u672c\u6570 return tn def false_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9633\u6027\u8ba1\u6570\u5668 fp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 1 : fp += 1 # \u8fd4\u56de\u5047\u9633\u6027\u6837\u672c\u6570 return fp def false_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9634\u6027\u8ba1\u6570\u5668 fn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 0 : fn += 1 # \u8fd4\u56de\u5047\u9634\u6027\u6570 return fn \u6211\u5728\u8fd9\u91cc\u5b9e\u73b0\u8fd9\u4e9b\u529f\u80fd\u7684\u65b9\u6cd5\u975e\u5e38\u7b80\u5355\uff0c\u800c\u4e14\u53ea\u9002\u7528\u4e8e\u4e8c\u5143\u5206\u7c7b\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u8fd9\u4e9b\u51fd\u6570\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: true_positive ( l1 , l2 ) Out [ X ]: 2 In [ X ]: false_positive ( l1 , l2 ) Out [ X ]: 1 In [ X ]: false_negative ( l1 , l2 ) Out [ X ]: 2 In [ X ]: true_negative ( l1 , l2 ) Out [ X ]: 3 \u5982\u679c\u6211\u4eec\u5fc5\u987b\u7528\u4e0a\u8ff0\u672f\u8bed\u6765\u5b9a\u4e49\u7cbe\u786e\u7387\uff0c\u6211\u4eec\u53ef\u4ee5\u5199\u4e3a\uff1a \\[ Accuracy Score = (TP + TN)/(TP + TN + FP +FN) \\] \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5728 python \u4e2d\u4f7f\u7528 TP\u3001TN\u3001FP \u548c FN \u5feb\u901f\u5b9e\u73b0\u51c6\u786e\u5ea6\u5f97\u5206\u3002\u6211\u4eec\u5c06\u5176\u79f0\u4e3a accuracy_v2\u3002 def accuracy_v2 ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u51c6\u786e\u7387 accuracy_score = ( tp + tn ) / ( tp + tn + fp + fn ) return accuracy_score \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0e\u4e4b\u524d\u7684\u5b9e\u73b0\u548c scikit-learn \u7248\u672c\u8fdb\u884c\u6bd4\u8f83\uff0c\u5feb\u901f\u68c0\u67e5\u8be5\u51fd\u6570\u7684\u6b63\u786e\u6027\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: accuracy ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: accuracy_v2 ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: metrics . accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u8bf7\u6ce8\u610f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0cmetrics.accuracy_score \u6765\u81ea scikit-learn\u3002 \u5f88\u597d\u3002\u6240\u6709\u503c\u90fd\u5339\u914d\u3002\u8fd9\u8bf4\u660e\u6211\u4eec\u5728\u5b9e\u73b0\u8fc7\u7a0b\u4e2d\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u5176\u4ed6\u91cd\u8981\u6307\u6807\u3002 \u9996\u5148\u662f\u7cbe\u786e\u7387\u3002\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u662f \\[ Precision = TP/(TP + FP) \\] \u5047\u8bbe\u6211\u4eec\u5728\u65b0\u7684\u504f\u659c\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u4e86\u4e00\u4e2a\u65b0\u6a21\u578b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 90 \u5f20\u56fe\u50cf\u4e2d\u7684 80 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u56fe\u50cf\u4e2d\u7684 8 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6210\u529f\u8bc6\u522b\u4e86 100 \u5f20\u56fe\u50cf\u4e2d\u7684 88 \u5f20\u3002\u56e0\u6b64\uff0c\u51c6\u786e\u7387\u4e3a 0.88 \u6216 88%\u3002 \u4f46\u662f\uff0c\u5728\u8fd9 100 \u5f20\u6837\u672c\u4e2d\uff0c\u6709 10 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u6c14\u80f8\uff0c2 \u5f20\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u975e\u6c14\u80f8\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 TP : 8 TN: 80 FP: 10 FN: 2 \u7cbe\u786e\u7387\u4e3a 8 / (8 + 10) = 0.444\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u5728\u8bc6\u522b\u9633\u6027\u6837\u672c\uff08\u6c14\u80f8\uff09\u65f6\u6709 44.4% \u7684\u6b63\u786e\u7387\u3002 \u73b0\u5728\uff0c\u65e2\u7136\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86 TP\u3001TN\u3001FP \u548c FN\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5728 python \u4e2d\u5b9e\u73b0\u7cbe\u786e\u7387\u4e86\u3002 def precision ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8ba9\u6211\u4eec\u8bd5\u8bd5\u8fd9\u79cd\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u65b9\u5f0f\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: precision ( l1 , l2 ) Out [ X ]: 0.6666666666666666 \u8fd9\u4f3c\u4e4e\u6ca1\u6709\u95ee\u9898\u3002 \u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u6765\u770b \u53ec\u56de\u7387 \u3002\u53ec\u56de\u7387\u7684\u5b9a\u4e49\u662f\uff1a \\[ Recall = TP/(TP + FN) \\] \u5728\u4e0a\u8ff0\u60c5\u51b5\u4e0b\uff0c\u53ec\u56de\u7387\u4e3a 8 / (8 + 2) = 0.80\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 80% \u7684\u9633\u6027\u6837\u672c\u3002 def recall ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u53ec\u56de\u7387 recall = tp / ( tp + fn ) return recall \u5c31\u6211\u4eec\u7684\u4e24\u4e2a\u5c0f\u5217\u8868\u800c\u8a00\uff0c\u53ec\u56de\u7387\u5e94\u8be5\u662f 0.5\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: recall ( l1 , l2 ) Out [ X ]: 0.5 \u8fd9\u4e0e\u6211\u4eec\u7684\u8ba1\u7b97\u503c\u76f8\u7b26\uff01 \u5bf9\u4e8e\u4e00\u4e2a \"\u597d \"\u6a21\u578b\u6765\u8bf4\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u503c\u90fd\u5e94\u8be5\u5f88\u9ad8\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0c\u53ec\u56de\u503c\u76f8\u5f53\u9ad8\u3002\u4f46\u662f\uff0c\u7cbe\u786e\u7387\u5374\u5f88\u4f4e\uff01\u6211\u4eec\u7684\u6a21\u578b\u4ea7\u751f\u4e86\u5927\u91cf\u7684\u8bef\u62a5\uff0c\u4f46\u8bef\u62a5\u8f83\u5c11\u3002\u5728\u8fd9\u7c7b\u95ee\u9898\u4e2d\uff0c\u5047\u9634\u6027\u8f83\u5c11\u662f\u597d\u4e8b\uff0c\u56e0\u4e3a\u4f60\u4e0d\u60f3\u5728\u75c5\u4eba\u6709\u6c14\u80f8\u7684\u60c5\u51b5\u4e0b\u5374\u8bf4\u4ed6\u4eec\u6ca1\u6709\u6c14\u80f8\u3002\u8fd9\u6837\u505a\u4f1a\u9020\u6210\u66f4\u5927\u7684\u4f24\u5bb3\u3002\u4f46\u6211\u4eec\u4e5f\u6709\u5f88\u591a\u5047\u9633\u6027\u7ed3\u679c\uff0c\u8fd9\u4e5f\u4e0d\u662f\u597d\u4e8b\u3002 \u5927\u591a\u6570\u6a21\u578b\u90fd\u4f1a\u9884\u6d4b\u4e00\u4e2a\u6982\u7387\uff0c\u5f53\u6211\u4eec\u9884\u6d4b\u65f6\uff0c\u901a\u5e38\u4f1a\u5c06\u8fd9\u4e2a\u9608\u503c\u9009\u4e3a 0.5\u3002\u8fd9\u4e2a\u9608\u503c\u5e76\u4e0d\u603b\u662f\u7406\u60f3\u7684\uff0c\u6839\u636e\u8fd9\u4e2a\u9608\u503c\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u503c\u53ef\u80fd\u4f1a\u53d1\u751f\u5f88\u5927\u7684\u53d8\u5316\u3002\u5982\u679c\u6211\u4eec\u9009\u62e9\u7684\u6bcf\u4e2a\u9608\u503c\u90fd\u80fd\u8ba1\u7b97\u51fa\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u53ef\u4ee5\u5728\u8fd9\u4e9b\u503c\u4e4b\u95f4\u7ed8\u5236\u51fa\u66f2\u7ebf\u56fe\u3002\u8fd9\u5e45\u56fe\u6216\u66f2\u7ebf\u88ab\u79f0\u4e3a \"\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\"\u3002 \u5728\u7814\u7a76\u7cbe\u786e\u7387-\u8c03\u7528\u66f2\u7ebf\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5047\u8bbe\u6709\u4e24\u4e2a\u5217\u8868\u3002 In [ X ]: y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0.02638412 , 0.11114267 , 0.31620708 , ... : 0.0490937 , 0.0191491 , 0.17554844 , ... : 0.15952202 , 0.03819563 , 0.11639273 , ... : 0.079377 , 0.08584789 , 0.39095342 , ... : 0.27259048 , 0.03447096 , 0.04644807 , ... : 0.03543574 , 0.18521942 , 0.05934905 , ... : 0.61977213 , 0.33056815 ] \u56e0\u6b64\uff0cy_true \u662f\u6211\u4eec\u7684\u76ee\u6807\u503c\uff0c\u800c y_pred \u662f\u6837\u672c\u88ab\u8d4b\u503c\u4e3a 1 \u7684\u6982\u7387\u503c\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u6211\u4eec\u8981\u770b\u7684\u662f\u9884\u6d4b\u4e2d\u7684\u6982\u7387\uff0c\u800c\u4e0d\u662f\u9884\u6d4b\u503c\uff08\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u9884\u6d4b\u503c\u7684\u8ba1\u7b97\u9608\u503c\u4e3a 0.5\uff09\u3002 precisions = [] recalls = [] thresholds = [ 0.0490937 , 0.05934905 , 0.079377 , 0.08584789 , 0.11114267 , 0.11639273 , 0.15952202 , 0.17554844 , 0.18521942 , 0.27259048 , 0.31620708 , 0.33056815 , 0.39095342 , 0.61977213 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for i in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_prediction = [ 1 if x >= i else 0 for x in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , temp_prediction ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , temp_prediction ) # \u52a0\u5165\u7cbe\u786e\u7387\u5217\u8868 precisions . append ( p ) # \u52a0\u5165\u53ec\u56de\u7387\u5217\u8868 recalls . append ( r ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7ed8\u5236\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 7 , 7 )) # x\u8f74\u4e3a\u53ec\u56de\u7387\uff0cy\u8f74\u4e3a\u7cbe\u786e\u7387 plt . plot ( recalls , precisions ) # \u6dfb\u52a0x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a15 plt . xlabel ( 'Recall' , fontsize = 15 ) # \u6dfb\u52a0y\u8f74\u6807\u7b7e\uff0c\u5b57\u6761\u5927\u5c0f\u4e3a15 plt . ylabel ( 'Precision' , fontsize = 15 ) \u56fe 2 \u663e\u793a\u4e86\u6211\u4eec\u901a\u8fc7\u8fd9\u79cd\u65b9\u6cd5\u5f97\u5230\u7684\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 \u56fe 2\uff1a\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u8fd9\u6761 \u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u4e0e\u60a8\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230\u7684\u66f2\u7ebf\u622a\u7136\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6211\u4eec\u53ea\u6709 20 \u4e2a\u6837\u672c\uff0c\u5176\u4e2d\u53ea\u6709 3 \u4e2a\u662f\u9633\u6027\u6837\u672c\u3002\u4f46\u8fd9\u6ca1\u4ec0\u4e48\u597d\u62c5\u5fc3\u7684\u3002\u8fd9\u8fd8\u662f\u90a3\u6761\u7cbe\u786e\u7387-\u53ec\u56de\u66f2\u7ebf\u3002 \u4f60\u4f1a\u53d1\u73b0\uff0c\u9009\u62e9\u4e00\u4e2a\u65e2\u80fd\u63d0\u4f9b\u826f\u597d\u7cbe\u786e\u7387\u53c8\u80fd\u63d0\u4f9b\u53ec\u56de\u503c\u7684\u9608\u503c\u662f\u5f88\u6709\u6311\u6218\u6027\u7684\u3002\u5982\u679c\u9608\u503c\u8fc7\u9ad8\uff0c\u771f\u9633\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u51cf\u5c11\uff0c\u800c\u5047\u9634\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u589e\u52a0\u3002\u8fd9\u4f1a\u964d\u4f4e\u53ec\u56de\u7387\uff0c\u4f46\u7cbe\u786e\u7387\u5f97\u5206\u4f1a\u5f88\u9ad8\u3002\u5982\u679c\u5c06\u9608\u503c\u964d\u5f97\u592a\u4f4e\uff0c\u5219\u8bef\u62a5\u4f1a\u5927\u91cf\u589e\u52a0\uff0c\u7cbe\u786e\u7387\u4e5f\u4f1a\u964d\u4f4e\u3002 \u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u8d8a\u63a5\u8fd1 1 \u8d8a\u597d\u3002 F1 \u5206\u6570\u662f\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7efc\u5408\u6307\u6807\u3002\u5b83\u88ab\u5b9a\u4e49\u4e3a\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7b80\u5355\u52a0\u6743\u5e73\u5747\u503c\uff08\u8c03\u548c\u5e73\u5747\u503c\uff09\u3002\u5982\u679c\u6211\u4eec\u7528 P \u8868\u793a\u7cbe\u786e\u7387\uff0c\u7528 R \u8868\u793a\u53ec\u56de\u7387\uff0c\u90a3\u4e48 F1 \u5206\u6570\u53ef\u4ee5\u8868\u793a\u4e3a\uff1a \\[ F1 = 2PR/(P + R) \\] \u6839\u636e TP\u3001FP \u548c FN\uff0c\u7a0d\u52a0\u6570\u5b66\u8ba1\u7b97\u5c31\u80fd\u5f97\u51fa\u4ee5\u4e0b F1 \u7b49\u5f0f\uff1a \\[ F1 = 2TP/(2TP + FP + FN) \\] Python \u5b9e\u73b0\u5f88\u7b80\u5355\uff0c\u56e0\u4e3a\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86\u8fd9\u4e9b def f1 ( y_true , y_pred ): # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , y_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , y_pred ) # \u8ba1\u7b97f1\u503c score = 2 * p * r / ( p + r ) return score \u8ba9\u6211\u4eec\u770b\u770b\u5176\u7ed3\u679c\uff0c\u5e76\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\u3002 In [ X ]: y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0 , 0 , 1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: f1 ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u901a\u8fc7 scikit learn\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u76f8\u540c\u7684\u5217\u8868\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . f1_score ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u4e0e\u5176\u5355\u72ec\u770b\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u60a8\u8fd8\u53ef\u4ee5\u53ea\u770b F1 \u5206\u6570\u3002\u4e0e\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c\u51c6\u786e\u5ea6\u4e00\u6837\uff0cF1 \u5206\u6570\u7684\u8303\u56f4\u4e5f\u662f\u4ece 0 \u5230 1\uff0c\u5b8c\u7f8e\u9884\u6d4b\u6a21\u578b\u7684 F1 \u5206\u6570\u4e3a 1\u3002 \u6b64\u5916\uff0c\u6211\u4eec\u8fd8\u5e94\u8be5\u4e86\u89e3\u5176\u4ed6\u4e00\u4e9b\u5173\u952e\u672f\u8bed\u3002 \u7b2c\u4e00\u4e2a\u672f\u8bed\u662f TPR \u6216\u771f\u9633\u6027\u7387\uff08True Positive Rate\uff09\uff0c\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\u3002 \\[ TPR = TP/(TP + FN) \\] \u5c3d\u7ba1\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\uff0c\u4f46\u6211\u4eec\u5c06\u4e3a\u5b83\u521b\u5efa\u4e00\u4e2a python \u51fd\u6570\uff0c\u4ee5\u4fbf\u4eca\u540e\u4f7f\u7528\u8fd9\u4e2a\u540d\u79f0\u3002 def tpr ( y_true , y_pred ): # \u771f\u9633\u6027\u7387\uff08TPR\uff09\uff0c\u4e0e\u53ec\u56de\u7387\u8ba1\u7b97\u516c\u5f0f\u4e00\u81f4 return recall ( y_true , y_pred ) TPR \u6216\u53ec\u56de\u7387\u4e5f\u88ab\u79f0\u4e3a\u7075\u654f\u5ea6\u3002 \u800c FPR \u6216\u5047\u9633\u6027\u7387\uff08False Positive Rate\uff09\u7684\u5b9a\u4e49\u662f\uff1a \\[ FPR = FP / (TN + FP) \\] def fpr ( y_true , y_pred ): # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u8fd4\u56de\u5047\u9633\u6027\u7387\uff08FPR\uff09 return fp / ( tn + fp ) 1 - FPR \u88ab\u79f0\u4e3a\u7279\u5f02\u6027\u6216\u771f\u9634\u6027\u7387\u6216 TNR\u3002\u8fd9\u4e9b\u672f\u8bed\u5f88\u591a\uff0c\u4f46\u5176\u4e2d\u6700\u91cd\u8981\u7684\u53ea\u6709 TPR \u548c FPR\u3002\u5047\u8bbe\u6211\u4eec\u53ea\u6709 15 \u4e2a\u6837\u672c\uff0c\u5176\u76ee\u6807\u503c\u4e3a\u4e8c\u5143\uff1a Actual targets : [0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1] \u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u7c7b\u4f3c\u968f\u673a\u68ee\u6797\u7684\u6a21\u578b\uff0c\u5c31\u80fd\u5f97\u5230\u6837\u672c\u5448\u9633\u6027\u7684\u6982\u7387\u3002 Predicted probabilities for 1: [0.1, 0.3, 0.2, 0.6, 0.8, 0.05, 0.9, 0.5, 0.3, 0.66, 0.3, 0.2, 0.85, 0.15, 0.99] \u5bf9\u4e8e >= 0.5 \u7684\u5178\u578b\u9608\u503c\uff0c\u6211\u4eec\u53ef\u4ee5\u8bc4\u4f30\u4e0a\u8ff0\u6240\u6709\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387/TPR\u3001F1 \u548c FPR \u503c\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u5c06\u9608\u503c\u9009\u4e3a 0.4 \u6216 0.6\uff0c\u4e5f\u53ef\u4ee5\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0 \u5230 1 \u4e4b\u95f4\u7684\u4efb\u4f55\u503c\uff0c\u5e76\u8ba1\u7b97\u4e0a\u8ff0\u6240\u6709\u6307\u6807\u3002 \u4e0d\u8fc7\uff0c\u6211\u4eec\u53ea\u8ba1\u7b97\u4e24\u4e2a\u503c\uff1a TPR \u548c FPR\u3002 # \u521d\u59cb\u5316\u771f\u9633\u6027\u7387\u5217\u8868 tpr_list = [] # \u521d\u59cb\u5316\u5047\u9633\u6027\u7387\u5217\u8868 fpr_list = [] # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u7387 temp_tpr = tpr ( y_true , temp_pred ) # \u5047\u9633\u6027\u7387 temp_fpr = fpr ( y_true , temp_pred ) # \u5c06\u771f\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 tpr_list . append ( temp_tpr ) # \u5c06\u5047\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 fpr_list . append ( temp_fpr ) \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u6bcf\u4e2a\u9608\u503c\u7684 TPR \u503c\u548c FPR \u503c\u3002 \u56fe 3\uff1a\u9608\u503c\u3001TPR \u548c FPR \u503c\u8868 \u5982\u679c\u6211\u4eec\u7ed8\u5236\u5982\u56fe 3 \u6240\u793a\u7684\u8868\u683c\uff0c\u5373\u4ee5 TPR \u4e3a Y \u8f74\uff0cFPR \u4e3a X \u8f74\uff0c\u5c31\u4f1a\u5f97\u5230\u5982\u56fe 4 \u6240\u793a\u7684\u66f2\u7ebf\u3002 \u56fe 4\uff1aROC\u66f2\u7ebf \u8fd9\u6761\u66f2\u7ebf\u4e5f\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u3002\u5982\u679c\u6211\u4eec\u8ba1\u7b97\u8fd9\u6761 ROC \u66f2\u7ebf\u4e0b\u7684\u9762\u79ef\uff0c\u5c31\u662f\u5728\u8ba1\u7b97\u53e6\u4e00\u4e2a\u6307\u6807\uff0c\u5f53\u6570\u636e\u96c6\u7684\u4e8c\u5143\u76ee\u6807\u504f\u659c\u65f6\uff0c\u8fd9\u4e2a\u6307\u6807\u5c31\u4f1a\u975e\u5e38\u5e38\u7528\u3002 \u8fd9\u4e2a\u6307\u6807\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u6216\u66f2\u7ebf\u4e0b\u9762\u79ef\uff0c\u7b80\u79f0 AUC\u3002\u8ba1\u7b97 ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002\u5728\u6b64\uff0c\u6211\u4eec\u5c06\u91c7\u7528 scikit- learn \u7684\u5947\u5999\u5b9e\u73b0\u65b9\u6cd5\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: metrics . roc_auc_score ( y_true , y_pred ) Out [ X ]: 0.8300000000000001 AUC \u503c\u4ece 0 \u5230 1 \u4e0d\u7b49\u3002 AUC = 1 \u610f\u5473\u7740\u60a8\u62e5\u6709\u4e00\u4e2a\u5b8c\u7f8e\u7684\u6a21\u578b\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u610f\u5473\u7740\u4f60\u5728\u9a8c\u8bc1\u65f6\u72af\u4e86\u4e00\u4e9b\u9519\u8bef\uff0c\u5e94\u8be5\u91cd\u65b0\u5ba1\u89c6\u6570\u636e\u5904\u7406\u548c\u9a8c\u8bc1\u6d41\u7a0b\u3002\u5982\u679c\u4f60\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\uff0c\u90a3\u4e48\u606d\u559c\u4f60\uff0c\u4f60\u5df2\u7ecf\u62e5\u6709\u4e86\u9488\u5bf9\u6570\u636e\u96c6\u5efa\u7acb\u7684\u6700\u4f73\u6a21\u578b\u3002 AUC = 0 \u610f\u5473\u7740\u60a8\u7684\u6a21\u578b\u975e\u5e38\u7cdf\u7cd5\uff08\u6216\u975e\u5e38\u597d\uff01\uff09\u3002\u8bd5\u7740\u53cd\u8f6c\u9884\u6d4b\u7684\u6982\u7387\uff0c\u4f8b\u5982\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u6b63\u7c7b\u7684\u6982\u7387\u662f p\uff0c\u8bd5\u7740\u7528 1-p \u4ee3\u66ff\u5b83\u3002\u8fd9\u79cd AUC \u4e5f\u53ef\u80fd\u610f\u5473\u7740\u60a8\u7684\u9a8c\u8bc1\u6216\u6570\u636e\u5904\u7406\u5b58\u5728\u95ee\u9898\u3002 AUC = 0.5 \u610f\u5473\u7740\u4f60\u7684\u9884\u6d4b\u662f\u968f\u673a\u7684\u3002\u56e0\u6b64\uff0c\u5bf9\u4e8e\u4efb\u4f55\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u5982\u679c\u6211\u5c06\u6240\u6709\u76ee\u6807\u90fd\u9884\u6d4b\u4e3a 0.5\uff0c\u6211\u5c06\u5f97\u5230 0.5 \u7684 AUC\u3002 AUC \u503c\u4ecb\u4e8e 0 \u548c 0.5 \u4e4b\u95f4\uff0c\u610f\u5473\u7740\u4f60\u7684\u6a21\u578b\u6bd4\u968f\u673a\u6a21\u578b\u66f4\u5dee\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u662f\u56e0\u4e3a\u4f60\u98a0\u5012\u4e86\u7c7b\u522b\u3002 \u5982\u679c\u60a8\u5c1d\u8bd5\u53cd\u8f6c\u9884\u6d4b\uff0c\u60a8\u7684 AUC \u503c\u53ef\u80fd\u4f1a\u8d85\u8fc7 0.5\u3002\u63a5\u8fd1 1 \u7684 AUC \u503c\u88ab\u8ba4\u4e3a\u662f\u597d\u503c\u3002 \u4f46 AUC \u5bf9\u6211\u4eec\u7684\u6a21\u578b\u6709\u4ec0\u4e48\u5f71\u54cd\u5462\uff1f \u5047\u8bbe\u60a8\u5efa\u7acb\u4e86\u4e00\u4e2a\u4ece\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e2d\u68c0\u6d4b\u6c14\u80f8\u7684\u6a21\u578b\uff0c\u5176 AUC \u503c\u4e3a 0.85\u3002\u8fd9\u610f\u5473\u7740\uff0c\u5982\u679c\u60a8\u4ece\u6570\u636e\u96c6\u4e2d\u968f\u673a\u9009\u62e9\u4e00\u5f20\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9633\u6027\u6837\u672c\uff09\u548c\u53e6\u4e00\u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9634\u6027\u6837\u672c\uff09\uff0c\u90a3\u4e48\u6c14\u80f8\u56fe\u50cf\u7684\u6392\u540d\u5c06\u9ad8\u4e8e\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u6982\u7387\u4e3a 0.85\u3002 \u8ba1\u7b97\u6982\u7387\u548c AUC \u540e\uff0c\u60a8\u9700\u8981\u5bf9\u6d4b\u8bd5\u96c6\u8fdb\u884c\u9884\u6d4b\u3002\u6839\u636e\u95ee\u9898\u548c\u4f7f\u7528\u60c5\u51b5\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u6982\u7387\u6216\u5b9e\u9645\u7c7b\u522b\u3002\u5982\u679c\u4f60\u60f3\u8981\u6982\u7387\uff0c\u8fd9\u5e76\u4e0d\u96be\u3002\u5982\u679c\u60a8\u60f3\u8981\u7c7b\u522b\uff0c\u5219\u9700\u8981\u9009\u62e9\u4e00\u4e2a\u9608\u503c\u3002\u5728\u4e8c\u5143\u5206\u7c7b\u7684\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u91c7\u7528\u7c7b\u4f3c\u4e0b\u9762\u7684\u65b9\u6cd5\u3002 \\[ Prediction = Probability >= Threshold \\] \u4e5f\u5c31\u662f\u8bf4\uff0c\u9884\u6d4b\u662f\u4e00\u4e2a\u53ea\u5305\u542b\u4e8c\u5143\u53d8\u91cf\u7684\u65b0\u5217\u8868\u3002\u5982\u679c\u6982\u7387\u5927\u4e8e\u6216\u7b49\u4e8e\u7ed9\u5b9a\u7684\u9608\u503c\uff0c\u5219\u9884\u6d4b\u4e2d\u7684\u4e00\u9879\u4e3a 1\uff0c\u5426\u5219\u4e3a 0\u3002 \u4f60\u731c\u600e\u4e48\u7740\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528 ROC \u66f2\u7ebf\u6765\u9009\u62e9\u8fd9\u4e2a\u9608\u503c\uff01ROC \u66f2\u7ebf\u4f1a\u544a\u8bc9\u60a8\u9608\u503c\u5bf9\u5047\u9633\u6027\u7387\u548c\u771f\u9633\u6027\u7387\u7684\u5f71\u54cd\uff0c\u8fdb\u800c\u5f71\u54cd\u5047\u9633\u6027\u548c\u771f\u9633\u6027\u3002\u60a8\u5e94\u8be5\u9009\u62e9\u6700\u9002\u5408\u60a8\u7684\u95ee\u9898\u548c\u6570\u636e\u96c6\u7684\u9608\u503c\u3002 \u4f8b\u5982\uff0c\u5982\u679c\u60a8\u4e0d\u5e0c\u671b\u6709\u592a\u591a\u7684\u8bef\u62a5\uff0c\u90a3\u4e48\u9608\u503c\u5c31\u5e94\u8be5\u9ad8\u4e00\u4e9b\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e5f\u4f1a\u5e26\u6765\u66f4\u591a\u7684\u8bef\u62a5\u3002\u6ce8\u610f\u6743\u8861\u5229\u5f0a\uff0c\u9009\u62e9\u6700\u4f73\u9608\u503c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u9608\u503c\u5982\u4f55\u5f71\u54cd\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u503c\u3002 # \u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list = [] # \u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list = [] # \u771f\u5b9e\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 temp_tp = true_positive ( y_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 temp_fp = false_positive ( y_true , temp_pred ) # \u52a0\u5165\u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list . append ( temp_tp ) # \u52a0\u5165\u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list . append ( temp_fp ) \u5229\u7528\u8fd9\u4e00\u70b9\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u8868\u683c\uff0c\u5982\u56fe 5 \u6240\u793a\u3002 \u56fe 5\uff1a\u4e0d\u540c\u9608\u503c\u7684 TP \u503c\u548c FP \u503c \u5982\u56fe 6 \u6240\u793a\uff0c\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0cROC \u66f2\u7ebf\u5de6\u4e0a\u89d2\u7684\u503c\u5e94\u8be5\u662f\u4e00\u4e2a\u76f8\u5f53\u4e0d\u9519\u7684\u9608\u503c\u3002 \u5bf9\u6bd4\u8868\u683c\u548c ROC \u66f2\u7ebf\uff0c\u6211\u4eec\u53ef\u4ee5\u53d1\u73b0\uff0c0.6 \u5de6\u53f3\u7684\u9608\u503c\u76f8\u5f53\u4e0d\u9519\uff0c\u65e2\u4e0d\u4f1a\u4e22\u5931\u5927\u91cf\u7684\u771f\u9633\u6027\u7ed3\u679c\uff0c\u4e5f\u4e0d\u4f1a\u51fa\u73b0\u5927\u91cf\u7684\u5047\u9633\u6027\u7ed3\u679c\u3002 \u56fe 6\uff1a\u4ece ROC \u66f2\u7ebf\u6700\u5de6\u4fa7\u7684\u9876\u70b9\u9009\u62e9\u6700\u4f73\u9608\u503c AUC \u662f\u4e1a\u5185\u5e7f\u6cdb\u5e94\u7528\u4e8e\u504f\u659c\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6307\u6807\uff0c\u4e5f\u662f\u6bcf\u4e2a\u4eba\u90fd\u5e94\u8be5\u4e86\u89e3\u7684\u6307\u6807\u3002\u4e00\u65e6\u7406\u89e3\u4e86 AUC \u80cc\u540e\u7684\u7406\u5ff5\uff08\u5982\u4e0a\u6587\u6240\u8ff0\uff09\uff0c\u4e5f\u5c31\u5f88\u5bb9\u6613\u5411\u4e1a\u754c\u53ef\u80fd\u4f1a\u8bc4\u4f30\u60a8\u7684\u6a21\u578b\u7684\u975e\u6280\u672f\u4eba\u5458\u89e3\u91ca\u5b83\u4e86\u3002 \u5b66\u4e60 AUC \u540e\uff0c\u4f60\u5e94\u8be5\u5b66\u4e60\u7684\u53e6\u4e00\u4e2a\u91cd\u8981\u6307\u6807\u662f\u5bf9\u6570\u635f\u5931\u3002\u5bf9\u4e8e\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5c06\u5bf9\u6570\u635f\u5931\u5b9a\u4e49\u4e3a\uff1a \\[ LogLoss = -1.0 \\times (target \\times log(prediction) + (1-target) \\times log(1-prediction)) \\] \u5176\u4e2d\uff0c\u76ee\u6807\u503c\u4e3a 0 \u6216 1\uff0c\u9884\u6d4b\u503c\u4e3a\u6837\u672c\u5c5e\u4e8e\u7c7b\u522b 1 \u7684\u6982\u7387\u3002 \u5bf9\u4e8e\u6570\u636e\u96c6\u4e2d\u7684\u591a\u4e2a\u6837\u672c\uff0c\u6240\u6709\u6837\u672c\u7684\u5bf9\u6570\u635f\u5931\u53ea\u662f\u6240\u6709\u5355\u4e2a\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u9700\u8981\u8bb0\u4f4f\u7684\u4e00\u70b9\u662f\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u4e0d\u6b63\u786e\u6216\u504f\u5dee\u8f83\u5927\u7684\u9884\u6d4b\u8fdb\u884c\u76f8\u5f53\u9ad8\u7684\u60e9\u7f5a\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u975e\u5e38\u786e\u5b9a\u548c\u975e\u5e38\u9519\u8bef\u7684\u9884\u6d4b\u8fdb\u884c\u60e9\u7f5a\u3002 import numpy as np def log_loss ( y_true , y_proba ): # \u6781\u5c0f\u503c\uff0c\u9632\u6b620\u505a\u5206\u6bcd epsilon = 1e-15 # \u5bf9\u6570\u635f\u5931\u5217\u8868 loss = [] # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_proba ): # \u9650\u5236yp\u8303\u56f4\uff0c\u6700\u5c0f\u4e3aepsilon\uff0c\u6700\u5927\u4e3a1-epsilon yp = np . clip ( yp , epsilon , 1 - epsilon ) # \u8ba1\u7b97\u5bf9\u6570\u635f\u5931 temp_loss = - 1.0 * ( yt * np . log ( yp ) + ( 1 - yt ) * np . log ( 1 - yp )) # \u52a0\u5165\u5bf9\u6570\u635f\u5931\u5217\u8868 loss . append ( temp_loss ) return np . mean ( loss ) \u8ba9\u6211\u4eec\u6d4b\u8bd5\u4e00\u4e0b\u51fd\u6570\u6267\u884c\u60c5\u51b5\uff1a In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_proba = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u6211\u4eec\u53ef\u4ee5\u5c06\u5176\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u5b9e\u73b0\u662f\u6b63\u786e\u7684\u3002 \u5bf9\u6570\u635f\u5931\u7684\u5b9e\u73b0\u5f88\u5bb9\u6613\u3002\u89e3\u91ca\u8d77\u6765\u4f3c\u4e4e\u6709\u70b9\u56f0\u96be\u3002\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5bf9\u6570\u635f\u5931\u7684\u60e9\u7f5a\u8981\u6bd4\u5176\u4ed6\u6307\u6807\u5927\u5f97\u591a\u3002 \u4f8b\u5982\uff0c\u5982\u679c\u60a8\u6709 51% \u7684\u628a\u63e1\u8ba4\u4e3a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\uff0c\u90a3\u4e48\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (1 \\times log(0.51) + (1 - 1) \\times log(1 - 0.51))=0.67 \\] \u5982\u679c\u4f60\u5bf9\u5c5e\u4e8e 0 \u7c7b\u7684\u6837\u672c\u6709 49% \u7684\u628a\u63e1\uff0c\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (1 \\times log(0.49) + (1 - 1) \\times log(1 - 0.49))=0.67 \\] \u56e0\u6b64\uff0c\u5373\u4f7f\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0.5 \u7684\u622a\u65ad\u503c\u5e76\u5f97\u5230\u5b8c\u7f8e\u7684\u9884\u6d4b\u7ed3\u679c\uff0c\u6211\u4eec\u4ecd\u7136\u4f1a\u6709\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002\u56e0\u6b64\uff0c\u5728\u5904\u7406\u5bf9\u6570\u635f\u5931\u65f6\uff0c\u4f60\u9700\u8981\u975e\u5e38\u5c0f\u5fc3\uff1b\u4efb\u4f55\u4e0d\u786e\u5b9a\u7684\u9884\u6d4b\u90fd\u4f1a\u4ea7\u751f\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002 \u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u5927\u591a\u6570\u6307\u6807\u90fd\u53ef\u4ee5\u8f6c\u6362\u6210\u591a\u7c7b\u7248\u672c\u3002\u8fd9\u4e2a\u60f3\u6cd5\u5f88\u7b80\u5355\u3002\u4ee5\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u4e3a\u4f8b\u3002\u6211\u4eec\u53ef\u4ee5\u8ba1\u7b97\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u4e2d\u6bcf\u4e00\u7c7b\u7684\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 \u6709\u4e09\u79cd\u4e0d\u540c\u7684\u8ba1\u7b97\u65b9\u6cd5\uff0c\u6709\u65f6\u53ef\u80fd\u4f1a\u4ee4\u4eba\u56f0\u60d1\u3002\u5047\u8bbe\u6211\u4eec\u9996\u5148\u5bf9\u7cbe\u786e\u7387\u611f\u5174\u8da3\u3002\u6211\u4eec\u77e5\u9053\uff0c\u7cbe\u786e\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u3002 \u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Macro averaged precision\uff09\uff1a\u5206\u522b\u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u7cbe\u786e\u7387\u7136\u540e\u6c42\u5e73\u5747\u503c \u5fae\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Micro averaged precision\uff09\uff1a\u6309\u7c7b\u8ba1\u7b97\u771f\u9633\u6027\u548c\u5047\u9633\u6027\uff0c\u7136\u540e\u7528\u5176\u8ba1\u7b97\u603b\u4f53\u7cbe\u786e\u7387\u3002\u7136\u540e\u4ee5\u6b64\u8ba1\u7b97\u603b\u4f53\u7cbe\u786e\u7387 \u52a0\u6743\u7cbe\u786e\u7387 \uff08Weighted precision\uff09\uff1a\u4e0e\u5b8f\u89c2\u7cbe\u786e\u7387\u76f8\u540c\uff0c\u4f46\u8fd9\u91cc\u662f\u52a0\u6743\u5e73\u5747\u7cbe\u786e\u7387 \u53d6\u51b3\u4e8e\u6bcf\u4e2a\u7c7b\u522b\u4e2d\u7684\u9879\u76ee\u6570 \u8fd9\u770b\u4f3c\u590d\u6742\uff0c\u4f46\u5728 python \u5b9e\u73b0\u4e2d\u5f88\u5bb9\u6613\u7406\u89e3\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import numpy as np def macro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u5982\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u8ba1\u7b97\u7cbe\u786e\u5ea6 temp_precision = tp / ( tp + fp ) # \u5404\u7c7b\u7cbe\u786e\u7387\u76f8\u52a0 precision += temp_precision # \u8ba1\u7b97\u5e73\u5747\u503c precision /= num_classes return precision \u4f60\u4f1a\u53d1\u73b0\u8fd9\u5e76\u4e0d\u96be\u3002\u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5fae\u5e73\u5747\u7cbe\u786e\u7387\u5206\u6570\u3002 import numpy as np def micro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u6570 tp = 0 # \u521d\u59cb\u5316\u5047\u9633\u6027\u6837\u672c\u6570 fp = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 tp += true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 fp += false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8fd9\u4e5f\u4e0d\u96be\u3002\u90a3\u4ec0\u4e48\u96be\uff1f\u4ec0\u4e48\u90fd\u4e0d\u96be\u3002\u673a\u5668\u5b66\u4e60\u5f88\u7b80\u5355\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u52a0\u6743\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u3002 from collections import Counter import numpy as np def weighted_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 temp_precision = tp / ( tp + fp ) # \u6839\u636e\u8be5\u79cd\u7c7b\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_precision = class_counts [ class_ ] * temp_precision # \u52a0\u6743\u7cbe\u786e\u7387\u6c42\u548c precision += weighted_precision # \u8ba1\u7b97\u5e73\u5747\u7cbe\u786e\u7387 overall_precision = precision / len ( y_true ) return overall_precision \u5c06\u6211\u4eec\u7684\u5b9e\u73b0\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff0c\u4ee5\u4e86\u89e3\u5b9e\u73b0\u662f\u5426\u6b63\u786e\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: macro_precision ( y_true , y_pred ) Out [ X ]: 0.3611111111111111 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"macro\" ) Out [ X ]: 0.3611111111111111 In [ X ]: micro_precision ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"micro\" ) Out [ X ]: 0.4444444444444444 In [ X ]: weighted_precision ( y_true , y_pred ) Out [ X ]: 0.39814814814814814 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.39814814814814814 \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6b63\u786e\u5730\u5b9e\u73b0\u4e86\u4e00\u5207\u3002 \u8bf7\u6ce8\u610f\uff0c\u8fd9\u91cc\u5c55\u793a\u7684\u5b9e\u73b0\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\uff0c\u4f46\u5374\u662f\u6700\u5bb9\u6613\u7406\u89e3\u7684\u3002 \u540c\u6837\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5b9e\u73b0 \u591a\u7c7b\u522b\u7684\u53ec\u56de\u7387\u6307\u6807 \u3002\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u3001\u5047\u9633\u6027\u548c\u5047\u9634\u6027\uff0c\u800c F1 \u5219\u53d6\u51b3\u4e8e\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 \u53ec\u56de\u7387\u7684\u5b9e\u73b0\u65b9\u6cd5\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\uff0c\u8fd9\u91cc\u5b9e\u73b0\u7684\u662f\u591a\u7c7b F1 \u7684\u4e00\u4e2a\u7248\u672c\uff0c\u5373\u52a0\u6743\u5e73\u5747\u503c\u3002 from collections import Counter import numpy as np def weighted_f1 ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316F1\u503c f1 = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( temp_true , temp_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( temp_true , temp_pred ) # \u82e5\u7cbe\u786e\u7387+\u53ec\u56de\u7387\u4e0d\u4e3a0\uff0c\u5219\u4f7f\u7528\u516c\u5f0f\u8ba1\u7b97F1\u503c if p + r != 0 : temp_f1 = 2 * p * r / ( p + r ) # \u5426\u5219\u76f4\u63a5\u4e3a0 else : temp_f1 = 0 # \u6839\u636e\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_f1 = class_counts [ class_ ] * temp_f1 # \u52a0\u6743F1\u503c\u76f8\u52a0 f1 += weighted_f1 # \u8ba1\u7b97\u52a0\u6743\u5e73\u5747F1\u503c overall_f1 = f1 / len ( y_true ) return overall_f1 \u8bf7\u6ce8\u610f\uff0c\u4e0a\u9762\u6709\u51e0\u884c\u4ee3\u7801\u662f\u65b0\u5199\u7684\u3002\u56e0\u6b64\uff0c\u4f60\u5e94\u8be5\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e9b\u4ee3\u7801\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: weighted_f1 ( y_true , y_pred ) Out [ X ]: 0.41269841269841273 In [ X ]: metrics . f1_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.41269841269841273 \u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u591a\u7c7b\u95ee\u9898\u5b9e\u73b0\u4e86\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002\u540c\u6837\uff0c\u60a8\u4e5f\u53ef\u4ee5\u5c06 AUC \u548c\u5bf9\u6570\u635f\u5931\u8f6c\u6362\u4e3a\u591a\u7c7b\u683c\u5f0f\u3002\u8fd9\u79cd\u8f6c\u6362\u683c\u5f0f\u88ab\u79f0\u4e3a one-vs-all \u3002\u8fd9\u91cc\u6211\u4e0d\u6253\u7b97\u5b9e\u73b0\u5b83\u4eec\uff0c\u56e0\u4e3a\u5b9e\u73b0\u65b9\u6cd5\u4e0e\u6211\u4eec\u5df2\u7ecf\u8ba8\u8bba\u8fc7\u7684\u5f88\u76f8\u4f3c\u3002 \u5728\u4e8c\u5143\u6216\u591a\u7c7b\u5206\u7c7b\u4e2d\uff0c\u770b\u4e00\u4e0b \u6df7\u6dc6\u77e9\u9635 \u4e5f\u5f88\u6d41\u884c\u3002\u4e0d\u8981\u56f0\u60d1\uff0c\u8fd9\u5f88\u7b80\u5355\u3002\u6df7\u6dc6\u77e9\u9635\u53ea\u4e0d\u8fc7\u662f\u4e00\u4e2a\u5305\u542b TP\u3001FP\u3001TN \u548c FN \u7684\u8868\u683c\u3002\u4f7f\u7528\u6df7\u6dc6\u77e9\u9635\uff0c\u60a8\u53ef\u4ee5\u5feb\u901f\u67e5\u770b\u6709\u591a\u5c11\u6837\u672c\u88ab\u9519\u8bef\u5206\u7c7b\uff0c\u6709\u591a\u5c11\u6837\u672c\u88ab\u6b63\u786e\u5206\u7c7b\u3002\u4e5f\u8bb8\u6709\u4eba\u4f1a\u8bf4\uff0c\u6df7\u6dc6\u77e9\u9635\u5e94\u8be5\u5728\u672c\u7ae0\u5f88\u65e9\u5c31\u8bb2\u5230\uff0c\u4f46\u6211\u6ca1\u6709\u8fd9\u4e48\u505a\u3002\u5982\u679c\u4e86\u89e3\u4e86 TP\u3001FP\u3001TN\u3001FN\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c AUC\uff0c\u5c31\u5f88\u5bb9\u6613\u7406\u89e3\u548c\u89e3\u91ca\u6df7\u6dc6\u77e9\u9635\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b\u56fe 7 \u4e2d\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6df7\u6dc6\u77e9\u9635\u7531 TP\u3001FP\u3001FN \u548c TN \u7ec4\u6210\u3002\u6211\u4eec\u53ea\u9700\u8981\u8fd9\u4e9b\u503c\u6765\u8ba1\u7b97\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u3001F1 \u5206\u6570\u548c AUC\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u4e5f\u559c\u6b22\u628a FP \u79f0\u4e3a \u7b2c\u4e00\u7c7b\u9519\u8bef \uff0c\u628a FN \u79f0\u4e3a \u7b2c\u4e8c\u7c7b\u9519\u8bef \u3002 \u56fe 7\uff1a\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6df7\u6dc6\u77e9\u9635 \u6211\u4eec\u8fd8\u53ef\u4ee5\u5c06\u4e8c\u5143\u6df7\u6dc6\u77e9\u9635\u6269\u5c55\u4e3a\u591a\u7c7b\u6df7\u6dc6\u77e9\u9635\u3002\u5b83\u4f1a\u662f\u4ec0\u4e48\u6837\u5b50\u5462\uff1f\u5982\u679c\u6211\u4eec\u6709 N \u4e2a\u7c7b\u522b\uff0c\u5b83\u5c06\u662f\u4e00\u4e2a\u5927\u5c0f\u4e3a NxN \u7684\u77e9\u9635\u3002\u5bf9\u4e8e\u6bcf\u4e2a\u7c7b\u522b\uff0c\u6211\u4eec\u90fd\u8981\u8ba1\u7b97\u76f8\u5173\u7c7b\u522b\u548c\u5176\u4ed6\u7c7b\u522b\u7684\u6837\u672c\u603b\u6570\u3002\u4e3e\u4e2a\u4f8b\u5b50\u53ef\u4ee5\u8ba9\u6211\u4eec\u66f4\u597d\u5730\u7406\u89e3\u8fd9\u4e00\u70b9\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4ee5\u4e0b\u771f\u5b9e\u6807\u7b7e\uff1a \\[ [0, 1, 2, 0, 1, 2, 0, 2, 2] \\] \u6211\u4eec\u7684\u9884\u6d4b\u6807\u7b7e\u662f\uff1a \\[ [0, 2, 1, 0, 2, 1, 0, 0, 2] \\] \u90a3\u4e48\uff0c\u6211\u4eec\u7684\u6df7\u6dc6\u77e9\u9635\u5c06\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u591a\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635 \u56fe 8 \u8bf4\u660e\u4e86\u4ec0\u4e48\uff1f \u8ba9\u6211\u4eec\u6765\u770b\u770b 0 \u7c7b\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e 0 \u7c7b\u3002\u7136\u800c\uff0c\u5728\u9884\u6d4b\u4e2d\uff0c\u6211\u4eec\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 0 \u7c7b\uff0c1 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\u3002\u7406\u60f3\u60c5\u51b5\u4e0b\uff0c\u5bf9\u4e8e\u771f\u5b9e\u6807\u7b7e\u4e2d\u7684\u7c7b\u522b 0\uff0c\u9884\u6d4b\u6807\u7b7e 1 \u548c 2 \u5e94\u8be5\u6ca1\u6709\u4efb\u4f55\u6837\u672c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u7c7b\u522b 2\u3002\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 4\uff0c\u800c\u5728\u9884\u6d4b\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 3\u3002 \u4e00\u4e2a\u5b8c\u7f8e\u7684\u6df7\u6dc6\u77e9\u9635\u53ea\u80fd\u4ece\u5de6\u5230\u53f3\u659c\u5411\u586b\u5145\u3002 \u6df7\u6dc6\u77e9\u9635 \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u7684\u65b9\u6cd5\u6765\u8ba1\u7b97\u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u4e0d\u540c\u6307\u6807\u3002Scikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u76f4\u63a5\u7684\u65b9\u6cd5\u6765\u751f\u6210\u6df7\u6dc6\u77e9\u9635\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5728\u56fe 8 \u4e2d\u663e\u793a\u7684\u6df7\u6dc6\u77e9\u9635\u662f scikit-learn \u6df7\u6dc6\u77e9\u9635\u7684\u8f6c\u7f6e\uff0c\u539f\u59cb\u7248\u672c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u7ed8\u5236\u3002 import matplotlib.pyplot as plt import seaborn as sns from sklearn import metrics # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] # \u9884\u6d4b\u6837\u672c\u6807\u7b7e y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] # \u8ba1\u7b97\u6df7\u6dc6\u77e9\u9635 cm = metrics . confusion_matrix ( y_true , y_pred ) # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 10 , 10 )) # \u521b\u5efa\u65b9\u683c cmap = sns . cubehelix_palette ( 50 , hue = 0.05 , rot = 0 , light = 0.9 , dark = 0 , as_cmap = True ) # \u89c4\u5b9a\u5b57\u4f53\u5927\u5c0f sns . set ( font_scale = 2.5 ) # \u7ed8\u5236\u70ed\u56fe sns . heatmap ( cm , annot = True , cmap = cmap , cbar = False ) # y\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . ylabel ( 'Actual Labels' , fontsize = 20 ) # x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . xlabel ( 'Predicted Labels' , fontsize = 20 ) \u56e0\u6b64\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u89e3\u51b3\u4e86\u4e8c\u5143\u5206\u7c7b\u548c\u591a\u7c7b\u5206\u7c7b\u7684\u5ea6\u91cf\u95ee\u9898\u3002\u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u8ba8\u8bba\u53e6\u4e00\u79cd\u7c7b\u578b\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u5373\u591a\u6807\u7b7e\u5206\u7c7b\u3002\u5728\u591a\u6807\u7b7e\u5206\u7c7b\u4e2d\uff0c\u6bcf\u4e2a\u6837\u672c\u90fd\u53ef\u80fd\u4e0e\u4e00\u4e2a\u6216\u591a\u4e2a\u7c7b\u522b\u76f8\u5173\u8054\u3002\u8fd9\u7c7b\u95ee\u9898\u7684\u4e00\u4e2a\u7b80\u5355\u4f8b\u5b50\u5c31\u662f\u8981\u6c42\u4f60\u9884\u6d4b\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53\u3002 \u56fe 9 \u663e\u793a\u4e86\u4e00\u4e2a\u8457\u540d\u6570\u636e\u96c6\u7684\u56fe\u50cf\u793a\u4f8b\u3002\u8bf7\u6ce8\u610f\uff0c\u8be5\u6570\u636e\u96c6\u7684\u76ee\u6807\u6709\u6240\u4e0d\u540c\uff0c\u4f46\u6211\u4eec\u6682\u4e14\u4e0d\u53bb\u8ba8\u8bba\u5b83\u3002\u6211\u4eec\u5047\u8bbe\u5176\u76ee\u7684\u53ea\u662f\u9884\u6d4b\u56fe\u50cf\u4e2d\u662f\u5426\u5b58\u5728\u67d0\u4e2a\u7269\u4f53\u3002\u5728\u56fe 9 \u4e2d\uff0c\u6211\u4eec\u6709\u6905\u5b50\u3001\u82b1\u76c6\u3001\u7a97\u6237\uff0c\u4f46\u6ca1\u6709\u5176\u4ed6\u7269\u4f53\uff0c\u5982\u7535\u8111\u3001\u5e8a\u3001\u7535\u89c6\u7b49\u3002\u56e0\u6b64\uff0c\u4e00\u5e45\u56fe\u50cf\u53ef\u80fd\u6709\u591a\u4e2a\u76f8\u5173\u76ee\u6807\u3002\u8fd9\u7c7b\u95ee\u9898\u5c31\u662f\u591a\u6807\u7b7e\u5206\u7c7b\u95ee\u9898\u3002 \u56fe 9\uff1a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53 \u8fd9\u7c7b\u5206\u7c7b\u95ee\u9898\u7684\u8861\u91cf\u6807\u51c6\u6709\u4e9b\u4e0d\u540c\u3002\u4e00\u4e9b\u5408\u9002\u7684 \u6700\u5e38\u89c1\u7684\u6307\u6807\u6709\uff1a k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u786e\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 \u8ba9\u6211\u4eec\u4ece k \u7cbe\u786e\u7387\u6216\u8005 P@k \u6211\u4eec\u4e0d\u80fd\u5c06\u8fd9\u4e00\u7cbe\u786e\u7387\u4e0e\u524d\u9762\u8ba8\u8bba\u7684\u7cbe\u786e\u7387\u6df7\u6dc6\u3002\u5982\u679c\u60a8\u6709\u4e00\u4e2a\u7ed9\u5b9a\u6837\u672c\u7684\u539f\u59cb\u7c7b\u522b\u5217\u8868\u548c\u540c\u4e00\u4e2a\u6837\u672c\u7684\u9884\u6d4b\u7c7b\u522b\u5217\u8868\uff0c\u90a3\u4e48\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u5c31\u662f\u9884\u6d4b\u5217\u8868\u4e2d\u4ec5\u8003\u8651\u524d k \u4e2a\u9884\u6d4b\u7ed3\u679c\u7684\u547d\u4e2d\u6570\u9664\u4ee5 k\u3002 \u5982\u679c\u60a8\u5bf9\u6b64\u611f\u5230\u56f0\u60d1\uff0c\u4f7f\u7528 python \u4ee3\u7801\u540e\u5c31\u4f1a\u660e\u767d\u3002 def pk ( y_true , y_pred , k ): # \u5982\u679ck\u4e3a0 if k == 0 : # \u8fd4\u56de0 return 0 # \u53d6\u9884\u6d4b\u6807\u7b7e\u524dk\u4e2a y_pred = y_pred [: k ] # \u5c06\u9884\u6d4b\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 pred_set = set ( y_pred ) # \u5c06\u771f\u5b9e\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 true_set = set ( y_true ) # \u9884\u6d4b\u6807\u7b7e\u96c6\u5408\u4e0e\u771f\u5b9e\u6807\u7b7e\u96c6\u5408\u4ea4\u96c6 common_values = pred_set . intersection ( true_set ) # \u8ba1\u7b97\u7cbe\u786e\u7387 return len ( common_values ) / len ( y_pred [: k ]) \u6709\u4e86\u4ee3\u7801\uff0c\u4e00\u5207\u90fd\u53d8\u5f97\u66f4\u5bb9\u6613\u7406\u89e3\u4e86\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6709\u4e86 k \u5e73\u5747\u7cbe\u786e\u7387\u6216 AP@k \u3002AP@k \u662f\u901a\u8fc7 P@k \u8ba1\u7b97\u5f97\u51fa\u7684\u3002\u4f8b\u5982\uff0c\u5982\u679c\u8981\u8ba1\u7b97 AP@3\uff0c\u6211\u4eec\u8981\u5148\u8ba1\u7b97 P@1\u3001P@2 \u548c P@3\uff0c\u7136\u540e\u5c06\u603b\u548c\u9664\u4ee5 3\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u7684\u5b9e\u73b0\u3002 def apk ( y_true , y_pred , k ): # \u521d\u59cb\u5316P@k\u5217\u8868 pk_values = [] # \u904d\u53861~k for i in range ( 1 , k + 1 ): # \u5c06P@k\u52a0\u5165\u5217\u8868 pk_values . append ( pk ( y_true , y_pred , i )) # \u82e5\u957f\u5ea6\u4e3a0 if len ( pk_values ) == 0 : # \u8fd4\u56de0 return 0 # \u5426\u5219\u8ba1\u7b97AP@K return sum ( pk_values ) / len ( pk_values ) \u8fd9\u4e24\u4e2a\u51fd\u6570\u53ef\u4ee5\u7528\u6765\u8ba1\u7b97\u4e24\u4e2a\u7ed9\u5b9a\u5217\u8868\u7684 k \u5e73\u5747\u7cbe\u786e\u7387 (AP@k)\uff1b\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u8ba1\u7b97\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: for i in range ( len ( y_true )): ... : for j in range ( 1 , 4 ): ... : print ( ... : f \"\"\" ...: y_true= { y_true [ i ] } , ...: y_pred= { y_pred [ i ] } , ...: AP@ { j } = { apk ( y_true [ i ], y_pred [ i ], k = j ) } ...: \"\"\" ... : ) ... : y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 1 = 0.0 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 2 = 0.25 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 3 = 0.38888888888888884 \u8bf7\u6ce8\u610f\uff0c\u6211\u7701\u7565\u4e86\u8f93\u51fa\u7ed3\u679c\u4e2d\u7684\u8bb8\u591a\u6570\u503c\uff0c\u4f46\u4f60\u4f1a\u660e\u767d\u5176\u4e2d\u7684\u610f\u601d\u3002\u8fd9\u5c31\u662f\u6211\u4eec\u5982\u4f55\u8ba1\u7b97 AP@k \u7684\u65b9\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u7684 AP@k\u3002\u5728\u673a\u5668\u5b66\u4e60\u4e2d\uff0c\u6211\u4eec\u5bf9\u6240\u6709\u6837\u672c\u90fd\u611f\u5174\u8da3\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u6709 \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387 k \u6216 MAP@k \u3002MAP@k \u53ea\u662f AP@k \u7684\u5e73\u5747\u503c\uff0c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b python \u4ee3\u7801\u8f7b\u677e\u8ba1\u7b97\u3002 def mapk ( y_true , y_pred , k ): # \u521d\u59cb\u5316AP@k\u5217\u8868 apk_values = [] # \u904d\u53860~\uff08\u771f\u5b9e\u6807\u7b7e\u6570-1\uff09 for i in range ( len ( y_true )): # \u5c06AP@K\u52a0\u5165\u5217\u8868 apk_values . append ( apk ( y_true [ i ], y_pred [ i ], k = k ) ) # \u8ba1\u7b97\u5e73\u5747AP@k return sum ( apk_values ) / len ( apk_values ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u9488\u5bf9\u76f8\u540c\u7684\u5217\u8868\u8ba1\u7b97 k=1\u30012\u30013 \u548c 4 \u65f6\u7684 MAP@k\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: mapk ( y_true , y_pred , k = 1 ) Out [ X ]: 0.3333333333333333 In [ X ]: mapk ( y_true , y_pred , k = 2 ) Out [ X ]: 0.375 In [ X ]: mapk ( y_true , y_pred , k = 3 ) Out [ X ]: 0.3611111111111111 In [ X ]: mapk ( y_true , y_pred , k = 4 ) Out [ X ]: 0.34722222222222215 P@k\u3001AP@k \u548c MAP@k \u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u5176\u4e2d 1 \u4e3a\u6700\u4f73\u3002 \u8bf7\u6ce8\u610f\uff0c\u6709\u65f6\u60a8\u53ef\u80fd\u4f1a\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230 P@k \u548c AP@k \u7684\u4e0d\u540c\u5b9e\u73b0\u65b9\u5f0f\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5176\u4e2d\u4e00\u79cd\u5b9e\u73b0\u65b9\u5f0f\u3002 import numpy as np def apk ( actual , predicted , k = 10 ): # \u82e5\u9884\u6d4b\u6807\u7b7e\u957f\u5ea6\u5927\u4e8ek if len ( predicted ) > k : # \u53d6\u524dk\u4e2a\u6807\u7b7e predicted = predicted [: k ] score = 0.0 num_hits = 0.0 for i , p in enumerate ( predicted ): if p in actual and p not in predicted [: i ]: num_hits += 1.0 score += num_hits / ( i + 1.0 ) if not actual : return 0.0 return score / min ( len ( actual ), k ) \u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u662f AP@k \u7684\u53e6\u4e00\u4e2a\u7248\u672c\uff0c\u5176\u4e2d\u987a\u5e8f\u5f88\u91cd\u8981\uff0c\u6211\u4eec\u8981\u6743\u8861\u9884\u6d4b\u7ed3\u679c\u3002\u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u7684\u7ed3\u679c\u4e0e\u6211\u7684\u4ecb\u7ecd\u7565\u6709\u4e0d\u540c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6765\u770b\u770b \u591a\u6807\u7b7e\u5206\u7c7b\u7684\u5bf9\u6570\u635f\u5931 \u3002\u8fd9\u5f88\u5bb9\u6613\u3002\u60a8\u53ef\u4ee5\u5c06\u76ee\u6807\u8f6c\u6362\u4e3a\u4e8c\u5143\u5206\u7c7b\uff0c\u7136\u540e\u5bf9\u6bcf\u4e00\u5217\u4f7f\u7528\u5bf9\u6570\u635f\u5931\u3002\u6700\u540e\uff0c\u4f60\u53ef\u4ee5\u6c42\u51fa\u6bcf\u5217\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u8fd9\u4e5f\u88ab\u79f0\u4e3a\u5e73\u5747\u5217\u5bf9\u6570\u635f\u5931\u3002\u5f53\u7136\uff0c\u8fd8\u6709\u5176\u4ed6\u65b9\u6cd5\u53ef\u4ee5\u5b9e\u73b0\u8fd9\u4e00\u70b9\uff0c\u4f60\u5e94\u8be5\u5728\u9047\u5230\u65f6\u52a0\u4ee5\u63a2\u7d22\u3002 \u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u8bf4\u5df2\u7ecf\u638c\u63e1\u4e86\u6240\u6709\u4e8c\u5143\u5206\u7c7b\u3001\u591a\u7c7b\u5206\u7c7b\u548c\u591a\u6807\u7b7e\u5206\u7c7b\u6307\u6807\uff0c\u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u6307\u6807\u3002 \u56de\u5f52\u4e2d\u6700\u5e38\u89c1\u7684\u6307\u6807\u662f \u8bef\u5dee\uff08Error\uff09 \u3002\u8bef\u5dee\u5f88\u7b80\u5355\uff0c\u4e5f\u5f88\u5bb9\u6613\u7406\u89e3\u3002 \\[ Error = True\\ Value - Predicted\\ Value \\] \u7edd\u5bf9\u8bef\u5dee\uff08Absolute error\uff09 \u53ea\u662f\u4e0a\u8ff0\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\u3002 \\[ Absolute\\ Error = Abs(True\\ Value - Predicted\\ Value) \\] \u63a5\u4e0b\u6765\u6211\u4eec\u8ba8\u8bba \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\uff08MAE\uff09 \u3002\u5b83\u53ea\u662f\u6240\u6709\u7edd\u5bf9\u8bef\u5dee\u7684\u5e73\u5747\u503c\u3002 import numpy as np def mean_absolute_error ( y_true , y_pred ): #\u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u7d2f\u52a0\u7edd\u5bf9\u8bef\u5dee error += np . abs ( yt - yp ) # \u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee return error / len ( y_true ) \u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5e73\u65b9\u8bef\u5dee\u548c \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u3002 \\[ Squared\\ Error = (True Value - Predicted\\ Value)^2 \\] \u5747\u65b9\u8bef\u5dee\uff08MSE\uff09\u7684\u8ba1\u7b97\u65b9\u5f0f\u5982\u4e0b def mean_squared_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u7d2f\u52a0\u8bef\u5dee\u5e73\u65b9\u548c error += ( yt - yp ) ** 2 # \u8ba1\u7b97\u5747\u65b9\u8bef\u5dee return error / len ( y_true ) MSE \u548c RMSE\uff08\u5747\u65b9\u6839\u8bef\u5dee\uff09 \u662f\u8bc4\u4f30\u56de\u5f52\u6a21\u578b\u6700\u5e38\u7528\u7684\u6307\u6807\u3002 \\[ RMSE = SQRT(MSE) \\] \u540c\u4e00\u7c7b\u8bef\u5dee\u7684\u53e6\u4e00\u79cd\u7c7b\u578b\u662f \u5e73\u65b9\u5bf9\u6570\u8bef\u5dee \u3002\u6709\u4eba\u79f0\u5176\u4e3a SLE \uff0c\u5f53\u6211\u4eec\u53d6\u6240\u6709\u6837\u672c\u4e2d\u8fd9\u4e00\u8bef\u5dee\u7684\u5e73\u5747\u503c\u65f6\uff0c\u5b83\u88ab\u79f0\u4e3a MSLE\uff08\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee\uff09\uff0c\u5b9e\u73b0\u65b9\u6cd5\u5982\u4e0b\u3002 import numpy as np def mean_squared_log_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee error += ( np . log ( 1 + yt ) - np . log ( 1 + yp )) ** 2 # \u8ba1\u7b97\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee return error / len ( y_true ) \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \u53ea\u662f\u5176\u5e73\u65b9\u6839\u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a RMSLE \u3002 \u7136\u540e\u662f\u767e\u5206\u6bd4\u8bef\u5dee\uff1a \\[ Percentage\\ Error = (( True\\ Value \u2013 Predicted\\ Value ) / True\\ Value ) \\times 100 \\] \u540c\u6837\u53ef\u4ee5\u8f6c\u6362\u4e3a\u6240\u6709\u6837\u672c\u7684\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee\u3002 def mean_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u767e\u5206\u6bd4\u8bef\u5dee error += ( yt - yp ) / yt # \u8fd4\u56de\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u7edd\u5bf9\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\uff08\u4e5f\u662f\u66f4\u5e38\u89c1\u7684\u7248\u672c\uff09\u88ab\u79f0\u4e3a \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee\u6216 MAPE \u3002 import numpy as np def mean_abs_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee error += np . abs ( yt - yp ) / yt #\u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u56de\u5f52\u7684\u6700\u5927\u4f18\u70b9\u662f\uff0c\u53ea\u6709\u51e0\u4e2a\u6700\u5e38\u7528\u7684\u6307\u6807\uff0c\u51e0\u4e4e\u53ef\u4ee5\u5e94\u7528\u4e8e\u6240\u6709\u56de\u5f52\u95ee\u9898\u3002\u4e0e\u5206\u7c7b\u6307\u6807\u76f8\u6bd4\uff0c\u56de\u5f52\u6307\u6807\u66f4\u5bb9\u6613\u7406\u89e3\u3002 \u8ba9\u6211\u4eec\u6765\u8c08\u8c08\u53e6\u4e00\u4e2a\u56de\u5f52\u6307\u6807 \\(R^2\\) \uff08R \u65b9\uff09\uff0c\u4e5f\u79f0\u4e3a \u5224\u5b9a\u7cfb\u6570 \u3002 \u7b80\u5355\u5730\u8bf4\uff0cR \u65b9\u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u3002R \u65b9\u63a5\u8fd1 1.0 \u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u76f8\u5f53\u597d\uff0c\u800c\u63a5\u8fd1 0 \u5219\u8868\u793a\u6a21\u578b\u4e0d\u662f\u90a3\u4e48\u597d\u3002\u5f53\u6a21\u578b\u53ea\u662f\u505a\u51fa\u8352\u8c2c\u7684\u9884\u6d4b\u65f6\uff0cR \u65b9\u4e5f\u53ef\u80fd\u662f\u8d1f\u503c\u3002 R \u65b9\u7684\u8ba1\u7b97\u516c\u5f0f\u5982\u4e0b\u6240\u793a\uff0c\u4f46 Python \u7684\u5b9e\u73b0\u603b\u662f\u80fd\u8ba9\u4e00\u5207\u66f4\u52a0\u6e05\u6670\u3002 \\[ R^2 = \\frac{\\sum^{N}_{i=1}(y_{t_i}-y_{p_i})^2}{\\sum^{N}_{i=1}(y_{t_i} - y_{t_{mean}})} \\] import numpy as np def r2 ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u5747\u771f\u5b9e\u503c mean_true_value = np . mean ( y_true ) # \u521d\u59cb\u5316\u5e73\u65b9\u8bef\u5dee numerator = 0 denominator = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): numerator += ( yt - yp ) ** 2 denominator += ( yt - mean_true_value ) ** 2 ratio = numerator / denominator # \u8ba1\u7b97R\u65b9 return 1 \u2013 ratio \u8fd8\u6709\u66f4\u591a\u7684\u8bc4\u4ef7\u6307\u6807\uff0c\u8fd9\u4e2a\u6e05\u5355\u6c38\u8fdc\u4e5f\u5217\u4e0d\u5b8c\u3002\u6211\u53ef\u4ee5\u5199\u4e00\u672c\u4e66\uff0c\u53ea\u4ecb\u7ecd\u4e0d\u540c\u7684\u8bc4\u4ef7\u6307\u6807\u3002\u4e5f\u8bb8\u6211\u4f1a\u7684\u3002\u73b0\u5728\uff0c\u8fd9\u4e9b\u8bc4\u4f30\u6307\u6807\u51e0\u4e4e\u53ef\u4ee5\u6ee1\u8db3\u4f60\u60f3\u5c1d\u8bd5\u89e3\u51b3\u7684\u6240\u6709\u95ee\u9898\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ee5\u6700\u76f4\u63a5\u7684\u65b9\u5f0f\u5b9e\u73b0\u4e86\u8fd9\u4e9b\u6307\u6807\uff0c\u8fd9\u610f\u5473\u7740\u5b83\u4eec\u4e0d\u591f\u9ad8\u6548\u3002\u4f60\u53ef\u4ee5\u901a\u8fc7\u6b63\u786e\u4f7f\u7528 numpy \u4ee5\u975e\u5e38\u9ad8\u6548\u7684\u65b9\u5f0f\u5b9e\u73b0\u5176\u4e2d\u5927\u90e8\u5206\u6307\u6807\u3002\u4f8b\u5982\uff0c\u770b\u770b\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\u7684\u5b9e\u73b0\uff0c\u4e0d\u9700\u8981\u4efb\u4f55\u5faa\u73af\u3002 import numpy as np def mae_np ( y_true , y_pred ): return np . mean ( np . abs ( y_true - y_pred )) \u6211\u672c\u53ef\u4ee5\u7528\u8fd9\u79cd\u65b9\u6cd5\u5b9e\u73b0\u6240\u6709\u6307\u6807\uff0c\u4f46\u4e3a\u4e86\u5b66\u4e60\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b\u5e95\u5c42\u5b9e\u73b0\u3002\u4e00\u65e6\u4f60\u5b66\u4f1a\u4e86\u7eaf python \u7684\u5e95\u5c42\u5b9e\u73b0\uff0c\u5e76\u4e14\u4e0d\u4f7f\u7528\u5927\u91cf numpy\uff0c\u4f60\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5c06\u5176\u8f6c\u6362\u4e3a numpy\uff0c\u5e76\u4f7f\u5176\u53d8\u5f97\u66f4\u5feb\u3002 \u7136\u540e\u662f\u4e00\u4e9b\u9ad8\u7ea7\u5ea6\u91cf\u3002 \u5176\u4e2d\u4e00\u4e2a\u5e94\u7528\u76f8\u5f53\u5e7f\u6cdb\u7684\u6307\u6807\u662f \u4e8c\u6b21\u52a0\u6743\u5361\u5e15 \uff0c\u4e5f\u79f0\u4e3a QWK \u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a\u79d1\u6069\u5361\u5e15\u3002 QWK \u8861\u91cf\u4e24\u4e2a \"\u8bc4\u5206 \"\u4e4b\u95f4\u7684 \"\u4e00\u81f4\u6027\"\u3002\u8bc4\u5206\u53ef\u4ee5\u662f 0 \u5230 N \u4e4b\u95f4\u7684\u4efb\u4f55\u5b9e\u6570\uff0c\u9884\u6d4b\u4e5f\u5728\u540c\u4e00\u8303\u56f4\u5185\u3002\u4e00\u81f4\u6027\u53ef\u4ee5\u5b9a\u4e49\u4e3a\u8fd9\u4e9b\u8bc4\u7ea7\u4e4b\u95f4\u7684\u63a5\u8fd1\u7a0b\u5ea6\u3002\u56e0\u6b64\uff0c\u5b83\u9002\u7528\u4e8e\u6709 N \u4e2a\u4e0d\u540c\u7c7b\u522b\u7684\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u4e00\u81f4\u5ea6\u9ad8\uff0c\u5206\u6570\u5c31\u66f4\u63a5\u8fd1 1.0\u3002Cohen's kappa \u5728 scikit-learn \u4e2d\u6709\u5f88\u597d\u7684\u5b9e\u73b0\uff0c\u5173\u4e8e\u8be5\u6307\u6807\u7684\u8be6\u7ec6\u8ba8\u8bba\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3 ] In [ X ]: y_pred = [ 2 , 1 , 3 , 1 , 2 , 3 , 3 , 1 , 2 ] In [ X ]: metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) Out [ X ]: 0.33333333333333337 In [ X ]: metrics . accuracy_score ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 \u60a8\u53ef\u4ee5\u770b\u5230\uff0c\u5c3d\u7ba1\u51c6\u786e\u5ea6\u5f88\u9ad8\uff0c\u4f46 QWK \u5374\u5f88\u4f4e\u3002QWK \u5927\u4e8e 0.85 \u5373\u4e3a\u975e\u5e38\u597d\uff01 \u4e00\u4e2a\u91cd\u8981\u7684\u6307\u6807\u662f \u9a6c\u4fee\u76f8\u5173\u7cfb\u6570\uff08MCC\uff09 \u30021 \u4ee3\u8868\u5b8c\u7f8e\u9884\u6d4b\uff0c-1 \u4ee3\u8868\u4e0d\u5b8c\u7f8e\u9884\u6d4b\uff0c0 \u4ee3\u8868\u968f\u673a\u9884\u6d4b\u3002MCC \u7684\u8ba1\u7b97\u516c\u5f0f\u975e\u5e38\u7b80\u5355\u3002 \\[ MCC = \\frac{TP \\times TN - FP \\times FN}{\\sqrt{(TP + FP) \\times (FN + TN) \\times (FP + TN) \\times (TP + FN)}} \\] \u6211\u4eec\u770b\u5230\uff0cMCC \u8003\u8651\u4e86 TP\u3001FP\u3001TN \u548c FN\uff0c\u56e0\u6b64\u53ef\u7528\u4e8e\u5904\u7406\u7c7b\u504f\u659c\u7684\u95ee\u9898\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u7684\u65b9\u6cd5\u5728 python \u4e2d\u5feb\u901f\u5b9e\u73b0\u5b83\u3002 def mcc ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) numerator = ( tp * tn ) - ( fp * fn ) denominator = ( ( tp + fp ) * ( fn + tn ) * ( fp + tn ) * ( tp + fn ) ) denominator = denominator ** 0.5 return numerator / denominator \u8fd9\u4e9b\u6307\u6807\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5165\u95e8\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002 \u9700\u8981\u6ce8\u610f\u7684\u4e00\u70b9\u662f\uff0c\u5728\u8bc4\u4f30\u975e\u76d1\u7763\u65b9\u6cd5\uff08\u4f8b\u5982\u67d0\u79cd\u805a\u7c7b\uff09\u65f6\uff0c\u6700\u597d\u521b\u5efa\u6216\u624b\u52a8\u6807\u8bb0\u6d4b\u8bd5\u96c6\uff0c\u5e76\u5c06\u5176\u4e0e\u5efa\u6a21\u90e8\u5206\u7684\u6240\u6709\u5185\u5bb9\u5206\u5f00\u3002\u5b8c\u6210\u805a\u7c7b\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u4efb\u4f55\u4e00\u79cd\u76d1\u7763\u5b66\u4e60\u6307\u6807\u6765\u8bc4\u4f30\u6d4b\u8bd5\u96c6\u7684\u6027\u80fd\u4e86\u3002 \u4e00\u65e6\u6211\u4eec\u4e86\u89e3\u4e86\u7279\u5b9a\u95ee\u9898\u5e94\u8be5\u4f7f\u7528\u4ec0\u4e48\u6307\u6807\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f00\u59cb\u66f4\u6df1\u5165\u5730\u7814\u7a76\u6211\u4eec\u7684\u6a21\u578b\uff0c\u4ee5\u6c42\u6539\u8fdb\u3002","title":"\u8bc4\u4f30\u6307\u6807"},{"location":"%E8%B6%85%E5%8F%82%E6%95%B0%E4%BC%98%E5%8C%96/","text":"\u8d85\u53c2\u6570\u4f18\u5316 \u6709\u4e86\u4f18\u79c0\u7684\u6a21\u578b\uff0c\u5c31\u6709\u4e86\u4f18\u5316\u8d85\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u5f97\u5206\u6a21\u578b\u7684\u96be\u9898\u3002\u90a3\u4e48\uff0c\u4ec0\u4e48\u662f\u8d85\u53c2\u6570\u4f18\u5316\u5462\uff1f\u5047\u8bbe\u60a8\u7684\u673a\u5668\u5b66\u4e60\u9879\u76ee\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u6d41\u7a0b\u3002\u6709\u4e00\u4e2a\u6570\u636e\u96c6\uff0c\u4f60\u76f4\u63a5\u5e94\u7528\u4e00\u4e2a\u6a21\u578b\uff0c\u7136\u540e\u5f97\u5230\u7ed3\u679c\u3002\u6a21\u578b\u5728\u8fd9\u91cc\u7684\u53c2\u6570\u88ab\u79f0\u4e3a\u8d85\u53c2\u6570\uff0c\u5373\u63a7\u5236\u6a21\u578b\u8bad\u7ec3/\u62df\u5408\u8fc7\u7a0b\u7684\u53c2\u6570\u3002\u5982\u679c\u6211\u4eec\u7528 SGD \u8bad\u7ec3\u7ebf\u6027\u56de\u5f52\uff0c\u6a21\u578b\u7684\u53c2\u6570\u662f\u659c\u7387\u548c\u504f\u5dee\uff0c\u8d85\u53c2\u6570\u662f\u5b66\u4e60\u7387\u3002\u4f60\u4f1a\u53d1\u73b0\u6211\u5728\u672c\u7ae0\u548c\u672c\u4e66\u4e2d\u4ea4\u66ff\u4f7f\u7528\u8fd9\u4e9b\u672f\u8bed\u3002\u5047\u8bbe\u6a21\u578b\u4e2d\u6709\u4e09\u4e2a\u53c2\u6570 a\u3001b\u3001c\uff0c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u90fd\u53ef\u4ee5\u662f 1 \u5230 10 \u4e4b\u95f4\u7684\u6574\u6570\u3002\u8fd9\u4e9b\u53c2\u6570\u7684 \"\u6b63\u786e \"\u7ec4\u5408\u5c06\u4e3a\u60a8\u63d0\u4f9b\u6700\u4f73\u7ed3\u679c\u3002\u56e0\u6b64\uff0c\u8fd9\u5c31\u6709\u70b9\u50cf\u4e00\u4e2a\u88c5\u6709\u4e09\u62e8\u5bc6\u7801\u9501\u7684\u624b\u63d0\u7bb1\u3002\u4e0d\u8fc7\uff0c\u4e09\u62e8\u5bc6\u7801\u9501\u53ea\u6709\u4e00\u4e2a\u6b63\u786e\u7b54\u6848\u3002\u800c\u6a21\u578b\u6709\u5f88\u591a\u6b63\u786e\u7b54\u6848\u3002\u90a3\u4e48\uff0c\u5982\u4f55\u627e\u5230\u6700\u4f73\u53c2\u6570\u5462\uff1f\u4e00\u79cd\u65b9\u6cd5\u662f\u5bf9\u6240\u6709\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\uff0c\u770b\u54ea\u79cd\u7ec4\u5408\u80fd\u63d0\u9ad8\u6307\u6807\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u521d\u59cb\u5316\u6700\u4f73\u51c6\u786e\u5ea6 best_accuracy = 0 # \u521d\u59cb\u5316\u6700\u4f73\u53c2\u6570\u7684\u5b57\u5178 best_parameters = { \"a\" : 0 , \"b\" : 0 , \"c\" : 0 } # \u5faa\u73af\u904d\u5386 a \u7684\u53d6\u503c\u8303\u56f4 1~10 for a in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 b \u7684\u53d6\u503c\u8303\u56f4 1~10 for b in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 c \u7684\u53d6\u503c\u8303\u56f4 1~10 for c in range ( 1 , 11 ): # \u521b\u5efa\u6a21\u578b\uff0c\u4f7f\u7528 a\u3001b\u3001c \u53c2\u6570 model = MODEL ( a , b , c ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408\u6a21\u578b model . fit ( training_data ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u9a8c\u8bc1\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . predict ( validation_data ) # \u8ba1\u7b97\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 accuracy = metrics . accuracy_score ( targets , preds ) # \u5982\u679c\u5f53\u524d\u51c6\u786e\u5ea6\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u5219\u66f4\u65b0\u6700\u4f73\u51c6\u786e\u5ea6\u548c\u6700\u4f73\u53c2\u6570 if accuracy > best_accuracy : best_accuracy = accuracy best_parameters [ \"a\" ] = a best_parameters [ \"b\" ] = b best_parameters [ \"c\" ] = c \u5728\u4e0a\u8ff0\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4ece 1 \u5230 10 \u5bf9\u6240\u6709\u53c2\u6570\u8fdb\u884c\u4e86\u62df\u5408\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u603b\u5171\u8981\u5bf9\u6a21\u578b\u8fdb\u884c 1000 \u6b21\uff0810 x 10 x 10\uff09\u62df\u5408\u3002\u8fd9\u53ef\u80fd\u4f1a\u5f88\u6602\u8d35\uff0c\u56e0\u4e3a\u6a21\u578b\u7684\u8bad\u7ec3\u9700\u8981\u5f88\u957f\u65f6\u95f4\u3002\u4e0d\u8fc7\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\u5e94\u8be5\u6ca1\u95ee\u9898\uff0c\u4f46\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\uff0c\u5e76\u4e0d\u662f\u53ea\u6709\u4e09\u4e2a\u53c2\u6570\uff0c\u6bcf\u4e2a\u53c2\u6570\u4e5f\u4e0d\u662f\u53ea\u6709\u5341\u4e2a\u503c\u3002 \u5927\u591a\u6570\u6a21\u578b\u53c2\u6570\u90fd\u662f\u5b9e\u6570\uff0c\u4e0d\u540c\u53c2\u6570\u7684\u7ec4\u5408\u53ef\u4ee5\u662f\u65e0\u9650\u7684\u3002 \u8ba9\u6211\u4eec\u770b\u770b scikit-learn \u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002 RandomForestClassifier ( n_estimators = 100 , criterion = 'gini' , max_depth = None , min_samples_split = 2 , min_samples_leaf = 1 , min_weight_fraction_leaf = 0.0 , max_features = 'auto' , max_leaf_nodes = None , min_impurity_decrease = 0.0 , min_impurity_split = None , bootstrap = True , oob_score = False , n_jobs = None , random_state = None , verbose = 0 , warm_start = False , class_weight = None , ccp_alpha = 0.0 , max_samples = None , ) \u6709 19 \u4e2a\u53c2\u6570\uff0c\u800c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u7684\u6240\u6709\u7ec4\u5408\uff0c\u4ee5\u53ca\u5b83\u4eec\u53ef\u4ee5\u627f\u62c5\u7684\u6240\u6709\u503c\uff0c\u90fd\u5c06\u662f\u65e0\u7a77\u65e0\u5c3d\u7684\u3002\u901a\u5e38\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u6ca1\u6709\u8db3\u591f\u7684\u8d44\u6e90\u548c\u65f6\u95f4\u6765\u505a\u8fd9\u4ef6\u4e8b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6307\u5b9a\u4e86\u4e00\u4e2a\u53c2\u6570\u7f51\u683c\u3002\u5728\u8fd9\u4e2a\u7f51\u683c\u4e0a\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408\u7684\u641c\u7d22\u79f0\u4e3a\u7f51\u683c\u641c\u7d22\u3002\u6211\u4eec\u53ef\u4ee5\u8bf4\uff0cn_estimators \u53ef\u4ee5\u662f 100\u3001200\u3001250\u3001300\u3001400\u3001500\uff1bmax_depth \u53ef\u4ee5\u662f 1\u30012\u30015\u30017\u300111\u300115\uff1bcriterion \u53ef\u4ee5\u662f gini \u6216 entropy\u3002\u8fd9\u4e9b\u53c2\u6570\u770b\u8d77\u6765\u5e76\u4e0d\u591a\uff0c\u4f46\u5982\u679c\u6570\u636e\u96c6\u8fc7\u5927\uff0c\u8ba1\u7b97\u8d77\u6765\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u53ef\u4ee5\u50cf\u4e4b\u524d\u4e00\u6837\u521b\u5efa\u4e09\u4e2a for \u5faa\u73af\uff0c\u5e76\u5728\u9a8c\u8bc1\u96c6\u4e0a\u8ba1\u7b97\u5f97\u5206\uff0c\u8fd9\u6837\u5c31\u80fd\u5b9e\u73b0\u7f51\u683c\u641c\u7d22\u3002\u8fd8\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u8981\u8fdb\u884c k \u6298\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5219\u9700\u8981\u66f4\u591a\u7684\u5faa\u73af\uff0c\u8fd9\u610f\u5473\u7740\u9700\u8981\u66f4\u591a\u7684\u65f6\u95f4\u6765\u627e\u5230\u5b8c\u7f8e\u7684\u53c2\u6570\u3002\u56e0\u6b64\uff0c\u7f51\u683c\u641c\u7d22\u5e76\u4e0d\u6d41\u884c\u3002\u8ba9\u6211\u4eec\u4ee5\u6839\u636e \u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4 \u6570\u636e\u96c6\u4e3a\u4f8b\uff0c\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 \u56fe 1\uff1a\u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4\u6570\u636e\u96c6\u5c55\u793a \u8bad\u7ec3\u96c6\u4e2d\u53ea\u6709 2000 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u53ef\u4ee5\u8f7b\u677e\u5730\u4f7f\u7528\u5206\u5c42 kfold \u548c\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5177\u6709\u4e0a\u8ff0\u53c2\u6570\u8303\u56f4\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\uff0c\u5e76\u5728\u4e0b\u9762\u7684\u793a\u4f8b\u4e2d\u4e86\u89e3\u5982\u4f55\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u3002 # rf_grid_search.py import numpy as np import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u5220\u9664 price_range \u5217 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u53d6\u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\uff0c\u4f7f\u7528\u6240\u6709\u53ef\u7528\u7684 CPU \u6838\u5fc3\u8fdb\u884c\u8bad\u7ec3 classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid = { \"n_estimators\" : [ 100 , 200 , 250 , 300 , 400 , 500 ], \"max_depth\" : [ 1 , 2 , 5 , 7 , 11 , 15 ], \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22 model = model_selection . GridSearchCV ( estimator = classifier , param_grid = param_grid , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( f \"Best score: { model . best_score_ } \" ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u8fd9\u91cc\u6253\u5370\u4e86\u5f88\u591a\u5185\u5bb9\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u6700\u540e\u51e0\u884c\u3002 [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.895 , total = 1.0 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.890 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.910 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.880 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.870 , total = 1.1 s [ Parallel ( n_jobs = 1 )]: Done 360 out of 360 | elapsed : 3.7 min finished Best score : 0.889 Best parameters set : criterion : 'entropy' max_depth : 15 n_estimators : 500 \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c5\u6298\u4ea4\u53c9\u68c0\u9a8c\u6700\u4f73\u5f97\u5206\u662f 0.889\uff0c\u6211\u4eec\u7684\u7f51\u683c\u641c\u7d22\u5f97\u5230\u4e86\u6700\u4f73\u53c2\u6570\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u7684\u4e0b\u4e00\u4e2a\u6700\u4f73\u65b9\u6cd5\u662f \u968f\u673a\u641c\u7d22 \u3002\u5728\u968f\u673a\u641c\u7d22\u4e2d\uff0c\u6211\u4eec\u968f\u673a\u9009\u62e9\u4e00\u4e2a\u53c2\u6570\u7ec4\u5408\uff0c\u7136\u540e\u8ba1\u7b97\u4ea4\u53c9\u9a8c\u8bc1\u5f97\u5206\u3002\u8fd9\u91cc\u6d88\u8017\u7684\u65f6\u95f4\u6bd4\u7f51\u683c\u641c\u7d22\u5c11\uff0c\u56e0\u4e3a\u6211\u4eec\u4e0d\u5bf9\u6240\u6709\u4e0d\u540c\u7684\u53c2\u6570\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\u3002\u6211\u4eec\u9009\u62e9\u8981\u5bf9\u6a21\u578b\u8fdb\u884c\u591a\u5c11\u6b21\u8bc4\u4f30\uff0c\u8fd9\u5c31\u51b3\u5b9a\u4e86\u641c\u7d22\u6240\u9700\u7684\u65f6\u95f4\u3002\u4ee3\u7801\u4e0e\u4e0a\u9762\u7684\u5dee\u522b\u4e0d\u5927\u3002\u9664 GridSearchCV \u5916\uff0c\u6211\u4eec\u4f7f\u7528 RandomizedSearchCV\u3002 if __name__ == \"__main__\" : classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u66f4\u6539\u641c\u7d22\u7a7a\u95f4 param_grid = { \"n_estimators\" : np . arange ( 100 , 1500 , 100 ), \"max_depth\" : np . arange ( 1 , 31 ), \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u968f\u673a\u53c2\u6570\u641c\u7d22 model = model_selection . RandomizedSearchCV ( estimator = classifier , param_distributions = param_grid , n_iter = 20 , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) print ( f \"Best score: { model . best_score_ } \" ) print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u6211\u4eec\u66f4\u6539\u4e86\u968f\u673a\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c\uff0c\u7ed3\u679c\u4f3c\u4e4e\u6709\u4e86\u4e9b\u8bb8\u6539\u8fdb\u3002 Best score : 0.8905 Best parameters set : criterion : entropy max_depth : 25 n_estimators : 300 \u5982\u679c\u8fed\u4ee3\u6b21\u6570\u8f83\u5c11\uff0c\u968f\u673a\u641c\u7d22\u6bd4\u7f51\u683c\u641c\u7d22\u66f4\u5feb\u3002\u4f7f\u7528\u8fd9\u4e24\u79cd\u65b9\u6cd5\uff0c\u4f60\u53ef\u4ee5\u4e3a\u5404\u79cd\u6a21\u578b\u627e\u5230\u6700\u4f18\u53c2\u6570\uff0c\u53ea\u8981\u5b83\u4eec\u6709\u62df\u5408\u548c\u9884\u6d4b\u529f\u80fd\uff0c\u8fd9\u4e5f\u662f scikit-learn \u7684\u6807\u51c6\u3002\u6709\u65f6\uff0c\u4f60\u53ef\u80fd\u60f3\u4f7f\u7528\u7ba1\u9053\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u7531\u4e24\u5217\u6587\u672c\u7ec4\u6210\uff0c\u4f60\u9700\u8981\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u7c7b\u522b\u3002\u8ba9\u6211\u4eec\u5047\u8bbe\u4f60\u9009\u62e9\u7684\u7ba1\u9053\u662f\u9996\u5148\u4ee5\u534a\u76d1\u7763\u7684\u65b9\u5f0f\u5e94\u7528 tf-idf\uff0c\u7136\u540e\u4f7f\u7528 SVD \u548c SVM \u5206\u7c7b\u5668\u3002\u73b0\u5728\u7684\u95ee\u9898\u662f\uff0c\u6211\u4eec\u5fc5\u987b\u9009\u62e9 SVD \u7684\u6210\u5206\uff0c\u8fd8\u9700\u8981\u8c03\u6574 SVM \u7684\u53c2\u6570\u3002\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u5c55\u793a\u4e86\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 import numpy as np import pandas as pd from sklearn import metrics from sklearn import model_selection from sklearn import pipeline from sklearn.decomposition import TruncatedSVD from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC # \u8ba1\u7b97\u52a0\u6743\u4e8c\u6b21 Kappa \u5206\u6570 def quadratic_weighted_kappa ( y_true , y_pred ): return metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) if __name__ == '__main__' : # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( '../input/train.csv' ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u63d0\u53d6 id \u5217\u7684\u503c\uff0c\u5e76\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u7c7b\u578b\uff0c\u5b58\u50a8\u5728\u53d8\u91cf idx \u4e2d idx = test . id . values . astype ( int ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 train = train . drop ( 'id' , axis = 1 ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 test = test . drop ( 'id' , axis = 1 ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf 'relevance' \uff0c\u5b58\u50a8\u5728\u53d8\u91cf y \u4e2d y = train . relevance . values # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 traindata \u4e2d traindata = list ( train . apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 testdata \u4e2d testdata = list ( test . apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u521b\u5efa\u4e00\u4e2a TfidfVectorizer \u5bf9\u8c61 tfv\uff0c\u7528\u4e8e\u5c06\u6587\u672c\u6570\u636e\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv = TfidfVectorizer ( min_df = 3 , max_features = None , strip_accents = 'unicode' , analyzer = 'word' , token_pattern = r '\\w{1,}' , ngram_range = ( 1 , 3 ), use_idf = 1 , smooth_idf = 1 , sublinear_tf = 1 , stop_words = 'english' ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408 TfidfVectorizer\uff0c\u5c06\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv . fit ( traindata ) # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X X = tfv . transform ( traindata ) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X_test X_test = tfv . transform ( testdata ) # \u521b\u5efa TruncatedSVD \u5bf9\u8c61 svd\uff0c\u7528\u4e8e\u8fdb\u884c\u5947\u5f02\u503c\u5206\u89e3 svd = TruncatedSVD () # \u521b\u5efa StandardScaler \u5bf9\u8c61 scl\uff0c\u7528\u4e8e\u8fdb\u884c\u7279\u5f81\u7f29\u653e scl = StandardScaler () # \u521b\u5efa\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668\u5bf9\u8c61 svm_model svm_model = SVC () # \u521b\u5efa\u673a\u5668\u5b66\u4e60\u7ba1\u9053 clf\uff0c\u5305\u542b\u5947\u5f02\u503c\u5206\u89e3\u3001\u7279\u5f81\u7f29\u653e\u548c\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668 clf = pipeline . Pipeline ( [ ( 'svd' , svd ), ( 'scl' , scl ), ( 'svm' , svm_model ) ] ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid param_grid = { 'svd__n_components' : [ 200 , 300 ], 'svm__C' : [ 10 , 12 ] } # \u521b\u5efa\u81ea\u5b9a\u4e49\u7684\u8bc4\u5206\u51fd\u6570 kappa_scorer\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd kappa_scorer = metrics . make_scorer ( quadratic_weighted_kappa , greater_is_better = True ) # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model = model_selection . GridSearchCV ( estimator = clf , param_grid = param_grid , scoring = kappa_scorer , verbose = 10 , n_jobs =- 1 , refit = True , cv = 5 ) # \u4f7f\u7528 GridSearchCV \u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( \"Best score: %0.3f \" % model . best_score_ ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( \" \\t %s : %r \" % ( param_name , best_parameters [ param_name ])) # \u83b7\u53d6\u6700\u4f73\u6a21\u578b best_model = model . best_estimator_ best_model . fit ( X , y ) # \u4f7f\u7528\u6700\u4f73\u6a21\u578b\u8fdb\u884c\u9884\u6d4b preds = best_model . predict ( ... ) \u8fd9\u91cc\u663e\u793a\u7684\u7ba1\u9053\u5305\u62ec SVD\uff08\u5947\u5f02\u503c\u5206\u89e3\uff09\u3001\u6807\u51c6\u7f29\u653e\u548c SVM\uff08\u652f\u6301\u5411\u91cf\u673a\uff09\u6a21\u578b\u3002\u8bf7\u6ce8\u610f\uff0c\u7531\u4e8e\u6ca1\u6709\u8bad\u7ec3\u6570\u636e\uff0c\u60a8\u65e0\u6cd5\u6309\u539f\u6837\u8fd0\u884c\u4e0a\u8ff0\u4ee3\u7801\u3002\u5f53\u6211\u4eec\u8fdb\u5165\u9ad8\u7ea7\u8d85\u53c2\u6570\u4f18\u5316\u6280\u672f\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4e0d\u540c\u7c7b\u578b\u7684 \u6700\u5c0f\u5316\u7b97\u6cd5 \u6765\u7814\u7a76\u51fd\u6570\u7684\u6700\u5c0f\u5316\u3002\u8fd9\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u591a\u79cd\u6700\u5c0f\u5316\u51fd\u6570\u6765\u5b9e\u73b0\uff0c\u5982\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u3001\u5185\u5c14\u5fb7-\u6885\u5fb7\u4f18\u5316\u7b97\u6cd5\u3001\u4f7f\u7528\u8d1d\u53f6\u65af\u6280\u672f\u548c\u9ad8\u65af\u8fc7\u7a0b\u5bfb\u627e\u6700\u4f18\u53c2\u6570\u6216\u4f7f\u7528\u9057\u4f20\u7b97\u6cd5\u3002\u6211\u5c06\u5728 \"\u96c6\u5408\u4e0e\u5806\u53e0\uff08ensembling and stacking\uff09 \"\u4e00\u7ae0\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u548c Nelder-Mead \u7b97\u6cd5\u7684\u5e94\u7528\u3002\u9996\u5148\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u9ad8\u65af\u8fc7\u7a0b\u5982\u4f55\u7528\u4e8e\u8d85\u53c2\u6570\u4f18\u5316\u3002\u8fd9\u7c7b\u7b97\u6cd5\u9700\u8981\u4e00\u4e2a\u53ef\u4ee5\u4f18\u5316\u7684\u51fd\u6570\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u90fd\u662f\u6700\u5c0f\u5316\u8fd9\u4e2a\u51fd\u6570\uff0c\u5c31\u50cf\u6211\u4eec\u6700\u5c0f\u5316\u635f\u5931\u4e00\u6837\u3002 \u56e0\u6b64\uff0c\u6bd4\u65b9\u8bf4\uff0c\u4f60\u60f3\u627e\u5230\u6700\u4f73\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u663e\u7136\uff0c\u51c6\u786e\u5ea6\u8d8a\u9ad8\u8d8a\u597d\u3002\u73b0\u5728\uff0c\u6211\u4eec\u4e0d\u80fd\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\uff0c\u4f46\u6211\u4eec\u53ef\u4ee5\u5c06\u7cbe\u786e\u5ea6\u4e58\u4ee5-1\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u662f\u5728\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\u7684\u8d1f\u503c\uff0c\u4f46\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u662f\u5728\u6700\u5927\u5316\u7cbe\u786e\u5ea6\u3002 \u5728\u9ad8\u65af\u8fc7\u7a0b\u4e2d\u4f7f\u7528\u8d1d\u53f6\u65af\u4f18\u5316\uff0c\u53ef\u4ee5\u4f7f\u7528 scikit-optimize (skopt) \u5e93\u4e2d\u7684 gp_minimize \u51fd\u6570\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4f7f\u7528\u8be5\u51fd\u6570\u8c03\u6574\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u53c2\u6570\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from skopt import gp_minimize from skopt import space def optimize ( params , param_names , x , y ): # \u5c06\u53c2\u6570\u540d\u79f0\u548c\u5bf9\u5e94\u7684\u503c\u6253\u5305\u6210\u5b57\u5178 params = dict ( zip ( param_names , params )) # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u6a21\u578b\uff0c\u4f7f\u7528\u4f20\u5165\u7684\u53c2\u6570\u914d\u7f6e model = ensemble . RandomForestClassifier ( ** params ) # \u521b\u5efa StratifiedKFold \u4ea4\u53c9\u9a8c\u8bc1\u5bf9\u8c61\uff0c\u5c06\u6570\u636e\u5206\u4e3a 5 \u6298 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u6bcf\u4e2a\u6298\u53e0\u7684\u51c6\u786e\u5ea6\u7684\u5217\u8868 accuracies = [] # \u5faa\u73af\u904d\u5386\u6bcf\u4e2a\u6298\u53e0\u7684\u8bad\u7ec3\u548c\u6d4b\u8bd5\u6570\u636e for idx in kf . split ( X = x , y = y ): train_idx , test_idx = idx [ 0 ], idx [ 1 ] xtrain = x [ train_idx ] ytrain = y [ train_idx ] xtest = x [ test_idx ] ytest = y [ test_idx ] # \u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u62df\u5408\u6a21\u578b model . fit ( xtrain , ytrain ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u6d4b\u8bd5\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . predict ( xtest ) # \u8ba1\u7b97\u6298\u53e0\u7684\u51c6\u786e\u5ea6 fold_accuracy = metrics . accuracy_score ( ytest , preds ) accuracies . append ( fold_accuracy ) # \u8fd4\u56de\u5e73\u5747\u51c6\u786e\u5ea6\u7684\u8d1f\u6570\uff08\u56e0\u4e3a skopt \u4f7f\u7528\u8d1f\u6570\u6765\u6700\u5c0f\u5316\u76ee\u6807\u51fd\u6570\uff09 return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u53d6\u7279\u5f81\u77e9\u9635 X\uff08\u53bb\u6389\"price_range\"\u5217\uff09 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u5b9a\u4e49\u8d85\u53c2\u6570\u641c\u7d22\u7a7a\u95f4 param_space param_space = [ space . Integer ( 3 , 15 , name = \"max_depth\" ), space . Integer ( 100 , 1500 , name = \"n_estimators\" ), space . Categorical ([ \"gini\" , \"entropy\" ], name = \"criterion\" ), space . Real ( 0.01 , 1 , prior = \"uniform\" , name = \"max_features\" ) ] # \u5b9a\u4e49\u8d85\u53c2\u6570\u7684\u540d\u79f0\u5217\u8868 param_names param_names = [ \"max_depth\" , \"n_estimators\" , \"criterion\" , \"max_features\" ] # \u521b\u5efa\u51fd\u6570 optimization_function\uff0c\u7528\u4e8e\u4f20\u9012\u7ed9 gp_minimize optimization_function = partial ( optimize , param_names = param_names , x = X , y = y ) # \u4f7f\u7528 Bayesian Optimization\uff08\u57fa\u4e8e\u8d1d\u53f6\u65af\u4f18\u5316\uff09\u6765\u641c\u7d22\u6700\u4f73\u8d85\u53c2\u6570 result = gp_minimize ( optimization_function , dimensions = param_space , n_calls = 15 , n_random_starts = 10 , verbose = 10 ) # \u83b7\u53d6\u6700\u4f73\u8d85\u53c2\u6570\u7684\u5b57\u5178 best_params = dict ( zip ( param_names , result . x ) ) # \u6253\u5370\u51fa\u627e\u5230\u7684\u6700\u4f73\u8d85\u53c2\u6570 print ( best_params ) \u8fd9\u540c\u6837\u4f1a\u4ea7\u751f\u5927\u91cf\u8f93\u51fa\uff0c\u6700\u540e\u4e00\u90e8\u5206\u5982\u4e0b\u6240\u793a\u3002 Iteration No : 14 started . Searching for the next optimal point . Iteration No : 14 ended . Search finished for the next optimal point . Time taken : 4.7793 Function value obtained : - 0.9075 Current minimum : - 0.9075 Iteration No : 15 started . Searching for the next optimal point . Iteration No : 15 ended . Search finished for the next optimal point . Time taken : 49.4186 Function value obtained : - 0.9075 Current minimum : - 0.9075 { 'max_depth' : 12 , 'n_estimators' : 100 , 'criterion' : 'entropy' , 'max_features' : 1.0 } \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6210\u529f\u7a81\u7834\u4e86 0.90 \u7684\u51c6\u786e\u7387\u3002\u8fd9\u771f\u662f\u592a\u795e\u5947\u4e86\uff01 \u6211\u4eec\u8fd8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u6bb5\u67e5\u770b\uff08\u7ed8\u5236\uff09\u6211\u4eec\u662f\u5982\u4f55\u5b9e\u73b0\u6536\u655b\u7684\u3002 from skopt.plots import plot_convergence plot_convergence ( result ) \u6536\u655b\u56fe\u5982\u56fe 2 \u6240\u793a\u3002 \u56fe 2\uff1a\u968f\u673a\u68ee\u6797\u53c2\u6570\u4f18\u5316\u7684\u6536\u655b\u56fe Scikit- optimize \u5c31\u662f\u8fd9\u6837\u4e00\u4e2a\u5e93\u3002 hyperopt \u4f7f\u7528\u6811\u72b6\u7ed3\u6784\u8d1d\u53f6\u65af\u4f30\u8ba1\u5668\uff08TPE\uff09\u6765\u627e\u5230\u6700\u4f18\u53c2\u6570\u3002\u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u7247\u6bb5\uff0c\u6211\u5728\u4f7f\u7528 hyperopt \u65f6\u5bf9\u4e4b\u524d\u7684\u4ee3\u7801\u505a\u4e86\u6700\u5c0f\u7684\u6539\u52a8\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from hyperopt import hp , fmin , tpe , Trials from hyperopt.pyll.base import scope def optimize ( params , x , y ): model = ensemble . RandomForestClassifier ( ** params ) kf = model_selection . StratifiedKFold ( n_splits = 5 ) ... return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/mobile_train.csv\" ) X = df . drop ( \"price_range\" , axis = 1 ) . values y = df . price_range . values # \u5b9a\u4e49\u641c\u7d22\u7a7a\u95f4\uff08\u6574\u578b\u3001\u6d6e\u70b9\u6570\u578b\u3001\u9009\u62e9\u578b\uff09 param_space = { \"max_depth\" : scope . int ( hp . quniform ( \"max_depth\" , 1 , 15 , 1 )), \"n_estimators\" : scope . int ( hp . quniform ( \"n_estimators\" , 100 , 1500 , 1 ) ), \"criterion\" : hp . choice ( \"criterion\" , [ \"gini\" , \"entropy\" ]), \"max_features\" : hp . uniform ( \"max_features\" , 0 , 1 ) } # \u5305\u88c5\u51fd\u6570 optimization_function = partial ( optimize , x = X , y = y ) # \u5f00\u59cb\u8bad\u7ec3 trials = Trials () # \u6700\u5c0f\u5316\u76ee\u6807\u503c hopt = fmin ( fn = optimization_function , space = param_space , algo = tpe . suggest , max_evals = 15 , trials = trials ) #\u6253\u5370\u6700\u4f73\u53c2\u6570 print ( hopt ) \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u4e0e\u4e4b\u524d\u7684\u4ee3\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u4f60\u5fc5\u987b\u4ee5\u4e0d\u540c\u7684\u683c\u5f0f\u5b9a\u4e49\u53c2\u6570\u7a7a\u95f4\uff0c\u8fd8\u9700\u8981\u6539\u53d8\u5b9e\u9645\u4f18\u5316\u90e8\u5206\uff0c\u7528 hyperopt \u4ee3\u66ff gp_minimize\u3002\u7ed3\u679c\u76f8\u5f53\u4e0d\u9519\uff01 \u276f python rf_hyperopt . py 100 %| \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 15 / 15 [ 04 : 38 < 00 : 00 , 18.57 s / trial , best loss : - 0.9095000000000001 ] { 'criterion' : 1 , 'max_depth' : 11.0 , 'max_features' : 0.821163568049807 , 'n_estimators' : 806.0 } \u6211\u4eec\u5f97\u5230\u4e86\u6bd4\u4ee5\u524d\u66f4\u597d\u7684\u51c6\u786e\u5ea6\u548c\u4e00\u7ec4\u53ef\u4ee5\u4f7f\u7528\u7684\u53c2\u6570\u3002\u8bf7\u6ce8\u610f\uff0c\u6700\u7ec8\u7ed3\u679c\u4e2d\u7684\u6807\u51c6\u662f 1\u3002\u8fd9\u610f\u5473\u7740\u9009\u62e9\u4e86 1\uff0c\u5373\u71b5\u3002 \u4e0a\u8ff0\u8c03\u6574\u8d85\u53c2\u6570\u7684\u65b9\u6cd5\u662f\u6700\u5e38\u89c1\u7684\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u6a21\u578b\uff1a\u7ebf\u6027\u56de\u5f52\u3001\u903b\u8f91\u56de\u5f52\u3001\u57fa\u4e8e\u6811\u7684\u65b9\u6cd5\u3001\u68af\u5ea6\u63d0\u5347\u6a21\u578b\uff08\u5982 xgboost\u3001lightgbm\uff09\uff0c\u751a\u81f3\u795e\u7ecf\u7f51\u7edc\uff01 \u867d\u7136\u8fd9\u4e9b\u65b9\u6cd5\u5df2\u7ecf\u5b58\u5728\uff0c\u4f46\u5b66\u4e60\u65f6\u5fc5\u987b\u4ece\u624b\u52a8\u8c03\u6574\u8d85\u53c2\u6570\u5f00\u59cb\uff0c\u5373\u624b\u5de5\u8c03\u6574\u3002\u624b\u52a8\u8c03\u6574\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5b66\u4e60\u57fa\u7840\u77e5\u8bc6\uff0c\u4f8b\u5982\uff0c\u5728\u68af\u5ea6\u63d0\u5347\u4e2d\uff0c\u5f53\u4f60\u589e\u52a0\u6df1\u5ea6\u65f6\uff0c\u4f60\u5e94\u8be5\u964d\u4f4e\u5b66\u4e60\u7387\u3002\u5982\u679c\u4f7f\u7528\u81ea\u52a8\u5de5\u5177\uff0c\u5c31\u65e0\u6cd5\u5b66\u4e60\u5230\u8fd9\u4e00\u70b9\u3002\u8bf7\u53c2\u8003\u4e0b\u8868\uff0c\u4e86\u89e3\u5e94\u5982\u4f55\u8c03\u6574\u3002RS* \u8868\u793a\u968f\u673a\u641c\u7d22\u5e94\u8be5\u66f4\u597d\u3002 \u4e00\u65e6\u4f60\u80fd\u66f4\u597d\u5730\u624b\u52a8\u8c03\u6574\u53c2\u6570\uff0c\u4f60\u751a\u81f3\u53ef\u80fd\u4e0d\u9700\u8981\u4efb\u4f55\u81ea\u52a8\u8d85\u53c2\u6570\u8c03\u6574\u3002\u521b\u5efa\u5927\u578b\u6a21\u578b\u6216\u5f15\u5165\u5927\u91cf\u7279\u5f81\u65f6\uff0c\u4e5f\u5bb9\u6613\u9020\u6210\u8bad\u7ec3\u6570\u636e\u7684\u8fc7\u5ea6\u62df\u5408\u3002\u4e3a\u907f\u514d\u8fc7\u5ea6\u62df\u5408\uff0c\u9700\u8981\u5728\u8bad\u7ec3\u6570\u636e\u7279\u5f81\u4e2d\u5f15\u5165\u566a\u58f0\u6216\u5bf9\u4ee3\u4ef7\u51fd\u6570\u8fdb\u884c\u60e9\u7f5a\u3002\u8fd9\u79cd\u60e9\u7f5a\u79f0\u4e3a \u6b63\u5219\u5316 \uff0c\u6709\u52a9\u4e8e\u6cdb\u5316\u6a21\u578b\u3002\u5728\u7ebf\u6027\u6a21\u578b\u4e2d\uff0c\u6700\u5e38\u89c1\u7684\u6b63\u5219\u5316\u7c7b\u578b\u662f L1 \u548c L2\u3002L1 \u4e5f\u79f0\u4e3a Lasso \u56de\u5f52\uff0cL2 \u79f0\u4e3a Ridge \u56de\u5f52\u3002\u8bf4\u5230\u795e\u7ecf\u7f51\u7edc\uff0c\u6211\u4eec\u4f1a\u4f7f\u7528dropout\u3001\u6dfb\u52a0\u589e\u5f3a\u3001\u566a\u58f0\u7b49\u65b9\u6cd5\u5bf9\u6a21\u578b\u8fdb\u884c\u6b63\u5219\u5316\u3002\u5229\u7528\u8d85\u53c2\u6570\u4f18\u5316\uff0c\u8fd8\u53ef\u4ee5\u627e\u5230\u6b63\u786e\u7684\u60e9\u7f5a\u65b9\u6cd5\u3002 Model Optimize Range of values Linear Regression - fit_intercept - normalize - True/False - True/False Ridge - alpha - fit_intercept - normalize - 0.01, 0.1, 1.0, 10, 100 - True/False - True/False k-neighbors - n_neighbors - p - 2, 4, 8, 16, ... - 2, 3, ... SVM - C - gamma - class_weight - 0.001, 0.01, ...,10, 100, 1000 - 'auto', RS* - 'balanced', None Logistic Regression - Penalyt - C - L1 or L2 - 0.001, 0.01, ..., 10, ..., 100 Lasso - Alpha - Normalize - 0.1, 1.0, 10 - True/False Random Forest - n_estimators - max_depth - min_samples_split - min_samples_leaf - max features - 120, 300, 500, 800, 1200 - 5, 8, 15, 25, 30, None - 1, 2, 5, 10, 15, 100 - log2, sqrt, None XGBoost - eta - gamma - max_depth - min_child_weight - subsample - colsample_bytree - lambda - alpha - 0.01, 0.015, 0.025, 0.05, 0.1 - 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0 - 3, 5, 7, 9, 12, 15, 17, 25 - 1, 3, 5, 7 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.01, 0.1, 1.0, RS - 0, 0.1, 0.5, 1.0, RS","title":"\u8d85\u53c2\u6570\u4f18\u5316"},{"location":"%E8%B6%85%E5%8F%82%E6%95%B0%E4%BC%98%E5%8C%96/#_1","text":"\u6709\u4e86\u4f18\u79c0\u7684\u6a21\u578b\uff0c\u5c31\u6709\u4e86\u4f18\u5316\u8d85\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u5f97\u5206\u6a21\u578b\u7684\u96be\u9898\u3002\u90a3\u4e48\uff0c\u4ec0\u4e48\u662f\u8d85\u53c2\u6570\u4f18\u5316\u5462\uff1f\u5047\u8bbe\u60a8\u7684\u673a\u5668\u5b66\u4e60\u9879\u76ee\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u6d41\u7a0b\u3002\u6709\u4e00\u4e2a\u6570\u636e\u96c6\uff0c\u4f60\u76f4\u63a5\u5e94\u7528\u4e00\u4e2a\u6a21\u578b\uff0c\u7136\u540e\u5f97\u5230\u7ed3\u679c\u3002\u6a21\u578b\u5728\u8fd9\u91cc\u7684\u53c2\u6570\u88ab\u79f0\u4e3a\u8d85\u53c2\u6570\uff0c\u5373\u63a7\u5236\u6a21\u578b\u8bad\u7ec3/\u62df\u5408\u8fc7\u7a0b\u7684\u53c2\u6570\u3002\u5982\u679c\u6211\u4eec\u7528 SGD \u8bad\u7ec3\u7ebf\u6027\u56de\u5f52\uff0c\u6a21\u578b\u7684\u53c2\u6570\u662f\u659c\u7387\u548c\u504f\u5dee\uff0c\u8d85\u53c2\u6570\u662f\u5b66\u4e60\u7387\u3002\u4f60\u4f1a\u53d1\u73b0\u6211\u5728\u672c\u7ae0\u548c\u672c\u4e66\u4e2d\u4ea4\u66ff\u4f7f\u7528\u8fd9\u4e9b\u672f\u8bed\u3002\u5047\u8bbe\u6a21\u578b\u4e2d\u6709\u4e09\u4e2a\u53c2\u6570 a\u3001b\u3001c\uff0c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u90fd\u53ef\u4ee5\u662f 1 \u5230 10 \u4e4b\u95f4\u7684\u6574\u6570\u3002\u8fd9\u4e9b\u53c2\u6570\u7684 \"\u6b63\u786e \"\u7ec4\u5408\u5c06\u4e3a\u60a8\u63d0\u4f9b\u6700\u4f73\u7ed3\u679c\u3002\u56e0\u6b64\uff0c\u8fd9\u5c31\u6709\u70b9\u50cf\u4e00\u4e2a\u88c5\u6709\u4e09\u62e8\u5bc6\u7801\u9501\u7684\u624b\u63d0\u7bb1\u3002\u4e0d\u8fc7\uff0c\u4e09\u62e8\u5bc6\u7801\u9501\u53ea\u6709\u4e00\u4e2a\u6b63\u786e\u7b54\u6848\u3002\u800c\u6a21\u578b\u6709\u5f88\u591a\u6b63\u786e\u7b54\u6848\u3002\u90a3\u4e48\uff0c\u5982\u4f55\u627e\u5230\u6700\u4f73\u53c2\u6570\u5462\uff1f\u4e00\u79cd\u65b9\u6cd5\u662f\u5bf9\u6240\u6709\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\uff0c\u770b\u54ea\u79cd\u7ec4\u5408\u80fd\u63d0\u9ad8\u6307\u6807\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u521d\u59cb\u5316\u6700\u4f73\u51c6\u786e\u5ea6 best_accuracy = 0 # \u521d\u59cb\u5316\u6700\u4f73\u53c2\u6570\u7684\u5b57\u5178 best_parameters = { \"a\" : 0 , \"b\" : 0 , \"c\" : 0 } # \u5faa\u73af\u904d\u5386 a \u7684\u53d6\u503c\u8303\u56f4 1~10 for a in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 b \u7684\u53d6\u503c\u8303\u56f4 1~10 for b in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 c \u7684\u53d6\u503c\u8303\u56f4 1~10 for c in range ( 1 , 11 ): # \u521b\u5efa\u6a21\u578b\uff0c\u4f7f\u7528 a\u3001b\u3001c \u53c2\u6570 model = MODEL ( a , b , c ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408\u6a21\u578b model . fit ( training_data ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u9a8c\u8bc1\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . predict ( validation_data ) # \u8ba1\u7b97\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 accuracy = metrics . accuracy_score ( targets , preds ) # \u5982\u679c\u5f53\u524d\u51c6\u786e\u5ea6\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u5219\u66f4\u65b0\u6700\u4f73\u51c6\u786e\u5ea6\u548c\u6700\u4f73\u53c2\u6570 if accuracy > best_accuracy : best_accuracy = accuracy best_parameters [ \"a\" ] = a best_parameters [ \"b\" ] = b best_parameters [ \"c\" ] = c \u5728\u4e0a\u8ff0\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4ece 1 \u5230 10 \u5bf9\u6240\u6709\u53c2\u6570\u8fdb\u884c\u4e86\u62df\u5408\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u603b\u5171\u8981\u5bf9\u6a21\u578b\u8fdb\u884c 1000 \u6b21\uff0810 x 10 x 10\uff09\u62df\u5408\u3002\u8fd9\u53ef\u80fd\u4f1a\u5f88\u6602\u8d35\uff0c\u56e0\u4e3a\u6a21\u578b\u7684\u8bad\u7ec3\u9700\u8981\u5f88\u957f\u65f6\u95f4\u3002\u4e0d\u8fc7\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\u5e94\u8be5\u6ca1\u95ee\u9898\uff0c\u4f46\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\uff0c\u5e76\u4e0d\u662f\u53ea\u6709\u4e09\u4e2a\u53c2\u6570\uff0c\u6bcf\u4e2a\u53c2\u6570\u4e5f\u4e0d\u662f\u53ea\u6709\u5341\u4e2a\u503c\u3002 \u5927\u591a\u6570\u6a21\u578b\u53c2\u6570\u90fd\u662f\u5b9e\u6570\uff0c\u4e0d\u540c\u53c2\u6570\u7684\u7ec4\u5408\u53ef\u4ee5\u662f\u65e0\u9650\u7684\u3002 \u8ba9\u6211\u4eec\u770b\u770b scikit-learn \u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002 RandomForestClassifier ( n_estimators = 100 , criterion = 'gini' , max_depth = None , min_samples_split = 2 , min_samples_leaf = 1 , min_weight_fraction_leaf = 0.0 , max_features = 'auto' , max_leaf_nodes = None , min_impurity_decrease = 0.0 , min_impurity_split = None , bootstrap = True , oob_score = False , n_jobs = None , random_state = None , verbose = 0 , warm_start = False , class_weight = None , ccp_alpha = 0.0 , max_samples = None , ) \u6709 19 \u4e2a\u53c2\u6570\uff0c\u800c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u7684\u6240\u6709\u7ec4\u5408\uff0c\u4ee5\u53ca\u5b83\u4eec\u53ef\u4ee5\u627f\u62c5\u7684\u6240\u6709\u503c\uff0c\u90fd\u5c06\u662f\u65e0\u7a77\u65e0\u5c3d\u7684\u3002\u901a\u5e38\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u6ca1\u6709\u8db3\u591f\u7684\u8d44\u6e90\u548c\u65f6\u95f4\u6765\u505a\u8fd9\u4ef6\u4e8b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6307\u5b9a\u4e86\u4e00\u4e2a\u53c2\u6570\u7f51\u683c\u3002\u5728\u8fd9\u4e2a\u7f51\u683c\u4e0a\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408\u7684\u641c\u7d22\u79f0\u4e3a\u7f51\u683c\u641c\u7d22\u3002\u6211\u4eec\u53ef\u4ee5\u8bf4\uff0cn_estimators \u53ef\u4ee5\u662f 100\u3001200\u3001250\u3001300\u3001400\u3001500\uff1bmax_depth \u53ef\u4ee5\u662f 1\u30012\u30015\u30017\u300111\u300115\uff1bcriterion \u53ef\u4ee5\u662f gini \u6216 entropy\u3002\u8fd9\u4e9b\u53c2\u6570\u770b\u8d77\u6765\u5e76\u4e0d\u591a\uff0c\u4f46\u5982\u679c\u6570\u636e\u96c6\u8fc7\u5927\uff0c\u8ba1\u7b97\u8d77\u6765\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u53ef\u4ee5\u50cf\u4e4b\u524d\u4e00\u6837\u521b\u5efa\u4e09\u4e2a for \u5faa\u73af\uff0c\u5e76\u5728\u9a8c\u8bc1\u96c6\u4e0a\u8ba1\u7b97\u5f97\u5206\uff0c\u8fd9\u6837\u5c31\u80fd\u5b9e\u73b0\u7f51\u683c\u641c\u7d22\u3002\u8fd8\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u8981\u8fdb\u884c k \u6298\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5219\u9700\u8981\u66f4\u591a\u7684\u5faa\u73af\uff0c\u8fd9\u610f\u5473\u7740\u9700\u8981\u66f4\u591a\u7684\u65f6\u95f4\u6765\u627e\u5230\u5b8c\u7f8e\u7684\u53c2\u6570\u3002\u56e0\u6b64\uff0c\u7f51\u683c\u641c\u7d22\u5e76\u4e0d\u6d41\u884c\u3002\u8ba9\u6211\u4eec\u4ee5\u6839\u636e \u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4 \u6570\u636e\u96c6\u4e3a\u4f8b\uff0c\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 \u56fe 1\uff1a\u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4\u6570\u636e\u96c6\u5c55\u793a \u8bad\u7ec3\u96c6\u4e2d\u53ea\u6709 2000 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u53ef\u4ee5\u8f7b\u677e\u5730\u4f7f\u7528\u5206\u5c42 kfold \u548c\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5177\u6709\u4e0a\u8ff0\u53c2\u6570\u8303\u56f4\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\uff0c\u5e76\u5728\u4e0b\u9762\u7684\u793a\u4f8b\u4e2d\u4e86\u89e3\u5982\u4f55\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u3002 # rf_grid_search.py import numpy as np import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u5220\u9664 price_range \u5217 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u53d6\u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\uff0c\u4f7f\u7528\u6240\u6709\u53ef\u7528\u7684 CPU \u6838\u5fc3\u8fdb\u884c\u8bad\u7ec3 classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid = { \"n_estimators\" : [ 100 , 200 , 250 , 300 , 400 , 500 ], \"max_depth\" : [ 1 , 2 , 5 , 7 , 11 , 15 ], \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22 model = model_selection . GridSearchCV ( estimator = classifier , param_grid = param_grid , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( f \"Best score: { model . best_score_ } \" ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u8fd9\u91cc\u6253\u5370\u4e86\u5f88\u591a\u5185\u5bb9\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u6700\u540e\u51e0\u884c\u3002 [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.895 , total = 1.0 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.890 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.910 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.880 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.870 , total = 1.1 s [ Parallel ( n_jobs = 1 )]: Done 360 out of 360 | elapsed : 3.7 min finished Best score : 0.889 Best parameters set : criterion : 'entropy' max_depth : 15 n_estimators : 500 \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c5\u6298\u4ea4\u53c9\u68c0\u9a8c\u6700\u4f73\u5f97\u5206\u662f 0.889\uff0c\u6211\u4eec\u7684\u7f51\u683c\u641c\u7d22\u5f97\u5230\u4e86\u6700\u4f73\u53c2\u6570\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u7684\u4e0b\u4e00\u4e2a\u6700\u4f73\u65b9\u6cd5\u662f \u968f\u673a\u641c\u7d22 \u3002\u5728\u968f\u673a\u641c\u7d22\u4e2d\uff0c\u6211\u4eec\u968f\u673a\u9009\u62e9\u4e00\u4e2a\u53c2\u6570\u7ec4\u5408\uff0c\u7136\u540e\u8ba1\u7b97\u4ea4\u53c9\u9a8c\u8bc1\u5f97\u5206\u3002\u8fd9\u91cc\u6d88\u8017\u7684\u65f6\u95f4\u6bd4\u7f51\u683c\u641c\u7d22\u5c11\uff0c\u56e0\u4e3a\u6211\u4eec\u4e0d\u5bf9\u6240\u6709\u4e0d\u540c\u7684\u53c2\u6570\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\u3002\u6211\u4eec\u9009\u62e9\u8981\u5bf9\u6a21\u578b\u8fdb\u884c\u591a\u5c11\u6b21\u8bc4\u4f30\uff0c\u8fd9\u5c31\u51b3\u5b9a\u4e86\u641c\u7d22\u6240\u9700\u7684\u65f6\u95f4\u3002\u4ee3\u7801\u4e0e\u4e0a\u9762\u7684\u5dee\u522b\u4e0d\u5927\u3002\u9664 GridSearchCV \u5916\uff0c\u6211\u4eec\u4f7f\u7528 RandomizedSearchCV\u3002 if __name__ == \"__main__\" : classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u66f4\u6539\u641c\u7d22\u7a7a\u95f4 param_grid = { \"n_estimators\" : np . arange ( 100 , 1500 , 100 ), \"max_depth\" : np . arange ( 1 , 31 ), \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u968f\u673a\u53c2\u6570\u641c\u7d22 model = model_selection . RandomizedSearchCV ( estimator = classifier , param_distributions = param_grid , n_iter = 20 , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) print ( f \"Best score: { model . best_score_ } \" ) print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u6211\u4eec\u66f4\u6539\u4e86\u968f\u673a\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c\uff0c\u7ed3\u679c\u4f3c\u4e4e\u6709\u4e86\u4e9b\u8bb8\u6539\u8fdb\u3002 Best score : 0.8905 Best parameters set : criterion : entropy max_depth : 25 n_estimators : 300 \u5982\u679c\u8fed\u4ee3\u6b21\u6570\u8f83\u5c11\uff0c\u968f\u673a\u641c\u7d22\u6bd4\u7f51\u683c\u641c\u7d22\u66f4\u5feb\u3002\u4f7f\u7528\u8fd9\u4e24\u79cd\u65b9\u6cd5\uff0c\u4f60\u53ef\u4ee5\u4e3a\u5404\u79cd\u6a21\u578b\u627e\u5230\u6700\u4f18\u53c2\u6570\uff0c\u53ea\u8981\u5b83\u4eec\u6709\u62df\u5408\u548c\u9884\u6d4b\u529f\u80fd\uff0c\u8fd9\u4e5f\u662f scikit-learn \u7684\u6807\u51c6\u3002\u6709\u65f6\uff0c\u4f60\u53ef\u80fd\u60f3\u4f7f\u7528\u7ba1\u9053\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u7531\u4e24\u5217\u6587\u672c\u7ec4\u6210\uff0c\u4f60\u9700\u8981\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u7c7b\u522b\u3002\u8ba9\u6211\u4eec\u5047\u8bbe\u4f60\u9009\u62e9\u7684\u7ba1\u9053\u662f\u9996\u5148\u4ee5\u534a\u76d1\u7763\u7684\u65b9\u5f0f\u5e94\u7528 tf-idf\uff0c\u7136\u540e\u4f7f\u7528 SVD \u548c SVM \u5206\u7c7b\u5668\u3002\u73b0\u5728\u7684\u95ee\u9898\u662f\uff0c\u6211\u4eec\u5fc5\u987b\u9009\u62e9 SVD \u7684\u6210\u5206\uff0c\u8fd8\u9700\u8981\u8c03\u6574 SVM \u7684\u53c2\u6570\u3002\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u5c55\u793a\u4e86\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 import numpy as np import pandas as pd from sklearn import metrics from sklearn import model_selection from sklearn import pipeline from sklearn.decomposition import TruncatedSVD from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC # \u8ba1\u7b97\u52a0\u6743\u4e8c\u6b21 Kappa \u5206\u6570 def quadratic_weighted_kappa ( y_true , y_pred ): return metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) if __name__ == '__main__' : # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( '../input/train.csv' ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u63d0\u53d6 id \u5217\u7684\u503c\uff0c\u5e76\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u7c7b\u578b\uff0c\u5b58\u50a8\u5728\u53d8\u91cf idx \u4e2d idx = test . id . values . astype ( int ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 train = train . drop ( 'id' , axis = 1 ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 test = test . drop ( 'id' , axis = 1 ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf 'relevance' \uff0c\u5b58\u50a8\u5728\u53d8\u91cf y \u4e2d y = train . relevance . values # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 traindata \u4e2d traindata = list ( train . apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 testdata \u4e2d testdata = list ( test . apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u521b\u5efa\u4e00\u4e2a TfidfVectorizer \u5bf9\u8c61 tfv\uff0c\u7528\u4e8e\u5c06\u6587\u672c\u6570\u636e\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv = TfidfVectorizer ( min_df = 3 , max_features = None , strip_accents = 'unicode' , analyzer = 'word' , token_pattern = r '\\w{1,}' , ngram_range = ( 1 , 3 ), use_idf = 1 , smooth_idf = 1 , sublinear_tf = 1 , stop_words = 'english' ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408 TfidfVectorizer\uff0c\u5c06\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv . fit ( traindata ) # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X X = tfv . transform ( traindata ) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X_test X_test = tfv . transform ( testdata ) # \u521b\u5efa TruncatedSVD \u5bf9\u8c61 svd\uff0c\u7528\u4e8e\u8fdb\u884c\u5947\u5f02\u503c\u5206\u89e3 svd = TruncatedSVD () # \u521b\u5efa StandardScaler \u5bf9\u8c61 scl\uff0c\u7528\u4e8e\u8fdb\u884c\u7279\u5f81\u7f29\u653e scl = StandardScaler () # \u521b\u5efa\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668\u5bf9\u8c61 svm_model svm_model = SVC () # \u521b\u5efa\u673a\u5668\u5b66\u4e60\u7ba1\u9053 clf\uff0c\u5305\u542b\u5947\u5f02\u503c\u5206\u89e3\u3001\u7279\u5f81\u7f29\u653e\u548c\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668 clf = pipeline . Pipeline ( [ ( 'svd' , svd ), ( 'scl' , scl ), ( 'svm' , svm_model ) ] ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid param_grid = { 'svd__n_components' : [ 200 , 300 ], 'svm__C' : [ 10 , 12 ] } # \u521b\u5efa\u81ea\u5b9a\u4e49\u7684\u8bc4\u5206\u51fd\u6570 kappa_scorer\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd kappa_scorer = metrics . make_scorer ( quadratic_weighted_kappa , greater_is_better = True ) # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model = model_selection . GridSearchCV ( estimator = clf , param_grid = param_grid , scoring = kappa_scorer , verbose = 10 , n_jobs =- 1 , refit = True , cv = 5 ) # \u4f7f\u7528 GridSearchCV \u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( \"Best score: %0.3f \" % model . best_score_ ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( \" \\t %s : %r \" % ( param_name , best_parameters [ param_name ])) # \u83b7\u53d6\u6700\u4f73\u6a21\u578b best_model = model . best_estimator_ best_model . fit ( X , y ) # \u4f7f\u7528\u6700\u4f73\u6a21\u578b\u8fdb\u884c\u9884\u6d4b preds = best_model . predict ( ... ) \u8fd9\u91cc\u663e\u793a\u7684\u7ba1\u9053\u5305\u62ec SVD\uff08\u5947\u5f02\u503c\u5206\u89e3\uff09\u3001\u6807\u51c6\u7f29\u653e\u548c SVM\uff08\u652f\u6301\u5411\u91cf\u673a\uff09\u6a21\u578b\u3002\u8bf7\u6ce8\u610f\uff0c\u7531\u4e8e\u6ca1\u6709\u8bad\u7ec3\u6570\u636e\uff0c\u60a8\u65e0\u6cd5\u6309\u539f\u6837\u8fd0\u884c\u4e0a\u8ff0\u4ee3\u7801\u3002\u5f53\u6211\u4eec\u8fdb\u5165\u9ad8\u7ea7\u8d85\u53c2\u6570\u4f18\u5316\u6280\u672f\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4e0d\u540c\u7c7b\u578b\u7684 \u6700\u5c0f\u5316\u7b97\u6cd5 \u6765\u7814\u7a76\u51fd\u6570\u7684\u6700\u5c0f\u5316\u3002\u8fd9\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u591a\u79cd\u6700\u5c0f\u5316\u51fd\u6570\u6765\u5b9e\u73b0\uff0c\u5982\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u3001\u5185\u5c14\u5fb7-\u6885\u5fb7\u4f18\u5316\u7b97\u6cd5\u3001\u4f7f\u7528\u8d1d\u53f6\u65af\u6280\u672f\u548c\u9ad8\u65af\u8fc7\u7a0b\u5bfb\u627e\u6700\u4f18\u53c2\u6570\u6216\u4f7f\u7528\u9057\u4f20\u7b97\u6cd5\u3002\u6211\u5c06\u5728 \"\u96c6\u5408\u4e0e\u5806\u53e0\uff08ensembling and stacking\uff09 \"\u4e00\u7ae0\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u548c Nelder-Mead \u7b97\u6cd5\u7684\u5e94\u7528\u3002\u9996\u5148\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u9ad8\u65af\u8fc7\u7a0b\u5982\u4f55\u7528\u4e8e\u8d85\u53c2\u6570\u4f18\u5316\u3002\u8fd9\u7c7b\u7b97\u6cd5\u9700\u8981\u4e00\u4e2a\u53ef\u4ee5\u4f18\u5316\u7684\u51fd\u6570\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u90fd\u662f\u6700\u5c0f\u5316\u8fd9\u4e2a\u51fd\u6570\uff0c\u5c31\u50cf\u6211\u4eec\u6700\u5c0f\u5316\u635f\u5931\u4e00\u6837\u3002 \u56e0\u6b64\uff0c\u6bd4\u65b9\u8bf4\uff0c\u4f60\u60f3\u627e\u5230\u6700\u4f73\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u663e\u7136\uff0c\u51c6\u786e\u5ea6\u8d8a\u9ad8\u8d8a\u597d\u3002\u73b0\u5728\uff0c\u6211\u4eec\u4e0d\u80fd\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\uff0c\u4f46\u6211\u4eec\u53ef\u4ee5\u5c06\u7cbe\u786e\u5ea6\u4e58\u4ee5-1\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u662f\u5728\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\u7684\u8d1f\u503c\uff0c\u4f46\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u662f\u5728\u6700\u5927\u5316\u7cbe\u786e\u5ea6\u3002 \u5728\u9ad8\u65af\u8fc7\u7a0b\u4e2d\u4f7f\u7528\u8d1d\u53f6\u65af\u4f18\u5316\uff0c\u53ef\u4ee5\u4f7f\u7528 scikit-optimize (skopt) \u5e93\u4e2d\u7684 gp_minimize \u51fd\u6570\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4f7f\u7528\u8be5\u51fd\u6570\u8c03\u6574\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u53c2\u6570\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from skopt import gp_minimize from skopt import space def optimize ( params , param_names , x , y ): # \u5c06\u53c2\u6570\u540d\u79f0\u548c\u5bf9\u5e94\u7684\u503c\u6253\u5305\u6210\u5b57\u5178 params = dict ( zip ( param_names , params )) # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u6a21\u578b\uff0c\u4f7f\u7528\u4f20\u5165\u7684\u53c2\u6570\u914d\u7f6e model = ensemble . RandomForestClassifier ( ** params ) # \u521b\u5efa StratifiedKFold \u4ea4\u53c9\u9a8c\u8bc1\u5bf9\u8c61\uff0c\u5c06\u6570\u636e\u5206\u4e3a 5 \u6298 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u6bcf\u4e2a\u6298\u53e0\u7684\u51c6\u786e\u5ea6\u7684\u5217\u8868 accuracies = [] # \u5faa\u73af\u904d\u5386\u6bcf\u4e2a\u6298\u53e0\u7684\u8bad\u7ec3\u548c\u6d4b\u8bd5\u6570\u636e for idx in kf . split ( X = x , y = y ): train_idx , test_idx = idx [ 0 ], idx [ 1 ] xtrain = x [ train_idx ] ytrain = y [ train_idx ] xtest = x [ test_idx ] ytest = y [ test_idx ] # \u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u62df\u5408\u6a21\u578b model . fit ( xtrain , ytrain ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u6d4b\u8bd5\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . predict ( xtest ) # \u8ba1\u7b97\u6298\u53e0\u7684\u51c6\u786e\u5ea6 fold_accuracy = metrics . accuracy_score ( ytest , preds ) accuracies . append ( fold_accuracy ) # \u8fd4\u56de\u5e73\u5747\u51c6\u786e\u5ea6\u7684\u8d1f\u6570\uff08\u56e0\u4e3a skopt \u4f7f\u7528\u8d1f\u6570\u6765\u6700\u5c0f\u5316\u76ee\u6807\u51fd\u6570\uff09 return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u53d6\u7279\u5f81\u77e9\u9635 X\uff08\u53bb\u6389\"price_range\"\u5217\uff09 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u5b9a\u4e49\u8d85\u53c2\u6570\u641c\u7d22\u7a7a\u95f4 param_space param_space = [ space . Integer ( 3 , 15 , name = \"max_depth\" ), space . Integer ( 100 , 1500 , name = \"n_estimators\" ), space . Categorical ([ \"gini\" , \"entropy\" ], name = \"criterion\" ), space . Real ( 0.01 , 1 , prior = \"uniform\" , name = \"max_features\" ) ] # \u5b9a\u4e49\u8d85\u53c2\u6570\u7684\u540d\u79f0\u5217\u8868 param_names param_names = [ \"max_depth\" , \"n_estimators\" , \"criterion\" , \"max_features\" ] # \u521b\u5efa\u51fd\u6570 optimization_function\uff0c\u7528\u4e8e\u4f20\u9012\u7ed9 gp_minimize optimization_function = partial ( optimize , param_names = param_names , x = X , y = y ) # \u4f7f\u7528 Bayesian Optimization\uff08\u57fa\u4e8e\u8d1d\u53f6\u65af\u4f18\u5316\uff09\u6765\u641c\u7d22\u6700\u4f73\u8d85\u53c2\u6570 result = gp_minimize ( optimization_function , dimensions = param_space , n_calls = 15 , n_random_starts = 10 , verbose = 10 ) # \u83b7\u53d6\u6700\u4f73\u8d85\u53c2\u6570\u7684\u5b57\u5178 best_params = dict ( zip ( param_names , result . x ) ) # \u6253\u5370\u51fa\u627e\u5230\u7684\u6700\u4f73\u8d85\u53c2\u6570 print ( best_params ) \u8fd9\u540c\u6837\u4f1a\u4ea7\u751f\u5927\u91cf\u8f93\u51fa\uff0c\u6700\u540e\u4e00\u90e8\u5206\u5982\u4e0b\u6240\u793a\u3002 Iteration No : 14 started . Searching for the next optimal point . Iteration No : 14 ended . Search finished for the next optimal point . Time taken : 4.7793 Function value obtained : - 0.9075 Current minimum : - 0.9075 Iteration No : 15 started . Searching for the next optimal point . Iteration No : 15 ended . Search finished for the next optimal point . Time taken : 49.4186 Function value obtained : - 0.9075 Current minimum : - 0.9075 { 'max_depth' : 12 , 'n_estimators' : 100 , 'criterion' : 'entropy' , 'max_features' : 1.0 } \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6210\u529f\u7a81\u7834\u4e86 0.90 \u7684\u51c6\u786e\u7387\u3002\u8fd9\u771f\u662f\u592a\u795e\u5947\u4e86\uff01 \u6211\u4eec\u8fd8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u6bb5\u67e5\u770b\uff08\u7ed8\u5236\uff09\u6211\u4eec\u662f\u5982\u4f55\u5b9e\u73b0\u6536\u655b\u7684\u3002 from skopt.plots import plot_convergence plot_convergence ( result ) \u6536\u655b\u56fe\u5982\u56fe 2 \u6240\u793a\u3002 \u56fe 2\uff1a\u968f\u673a\u68ee\u6797\u53c2\u6570\u4f18\u5316\u7684\u6536\u655b\u56fe Scikit- optimize \u5c31\u662f\u8fd9\u6837\u4e00\u4e2a\u5e93\u3002 hyperopt \u4f7f\u7528\u6811\u72b6\u7ed3\u6784\u8d1d\u53f6\u65af\u4f30\u8ba1\u5668\uff08TPE\uff09\u6765\u627e\u5230\u6700\u4f18\u53c2\u6570\u3002\u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u7247\u6bb5\uff0c\u6211\u5728\u4f7f\u7528 hyperopt \u65f6\u5bf9\u4e4b\u524d\u7684\u4ee3\u7801\u505a\u4e86\u6700\u5c0f\u7684\u6539\u52a8\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from hyperopt import hp , fmin , tpe , Trials from hyperopt.pyll.base import scope def optimize ( params , x , y ): model = ensemble . RandomForestClassifier ( ** params ) kf = model_selection . StratifiedKFold ( n_splits = 5 ) ... return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/mobile_train.csv\" ) X = df . drop ( \"price_range\" , axis = 1 ) . values y = df . price_range . values # \u5b9a\u4e49\u641c\u7d22\u7a7a\u95f4\uff08\u6574\u578b\u3001\u6d6e\u70b9\u6570\u578b\u3001\u9009\u62e9\u578b\uff09 param_space = { \"max_depth\" : scope . int ( hp . quniform ( \"max_depth\" , 1 , 15 , 1 )), \"n_estimators\" : scope . int ( hp . quniform ( \"n_estimators\" , 100 , 1500 , 1 ) ), \"criterion\" : hp . choice ( \"criterion\" , [ \"gini\" , \"entropy\" ]), \"max_features\" : hp . uniform ( \"max_features\" , 0 , 1 ) } # \u5305\u88c5\u51fd\u6570 optimization_function = partial ( optimize , x = X , y = y ) # \u5f00\u59cb\u8bad\u7ec3 trials = Trials () # \u6700\u5c0f\u5316\u76ee\u6807\u503c hopt = fmin ( fn = optimization_function , space = param_space , algo = tpe . suggest , max_evals = 15 , trials = trials ) #\u6253\u5370\u6700\u4f73\u53c2\u6570 print ( hopt ) \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u4e0e\u4e4b\u524d\u7684\u4ee3\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u4f60\u5fc5\u987b\u4ee5\u4e0d\u540c\u7684\u683c\u5f0f\u5b9a\u4e49\u53c2\u6570\u7a7a\u95f4\uff0c\u8fd8\u9700\u8981\u6539\u53d8\u5b9e\u9645\u4f18\u5316\u90e8\u5206\uff0c\u7528 hyperopt \u4ee3\u66ff gp_minimize\u3002\u7ed3\u679c\u76f8\u5f53\u4e0d\u9519\uff01 \u276f python rf_hyperopt . py 100 %| \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 15 / 15 [ 04 : 38 < 00 : 00 , 18.57 s / trial , best loss : - 0.9095000000000001 ] { 'criterion' : 1 , 'max_depth' : 11.0 , 'max_features' : 0.821163568049807 , 'n_estimators' : 806.0 } \u6211\u4eec\u5f97\u5230\u4e86\u6bd4\u4ee5\u524d\u66f4\u597d\u7684\u51c6\u786e\u5ea6\u548c\u4e00\u7ec4\u53ef\u4ee5\u4f7f\u7528\u7684\u53c2\u6570\u3002\u8bf7\u6ce8\u610f\uff0c\u6700\u7ec8\u7ed3\u679c\u4e2d\u7684\u6807\u51c6\u662f 1\u3002\u8fd9\u610f\u5473\u7740\u9009\u62e9\u4e86 1\uff0c\u5373\u71b5\u3002 \u4e0a\u8ff0\u8c03\u6574\u8d85\u53c2\u6570\u7684\u65b9\u6cd5\u662f\u6700\u5e38\u89c1\u7684\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u6a21\u578b\uff1a\u7ebf\u6027\u56de\u5f52\u3001\u903b\u8f91\u56de\u5f52\u3001\u57fa\u4e8e\u6811\u7684\u65b9\u6cd5\u3001\u68af\u5ea6\u63d0\u5347\u6a21\u578b\uff08\u5982 xgboost\u3001lightgbm\uff09\uff0c\u751a\u81f3\u795e\u7ecf\u7f51\u7edc\uff01 \u867d\u7136\u8fd9\u4e9b\u65b9\u6cd5\u5df2\u7ecf\u5b58\u5728\uff0c\u4f46\u5b66\u4e60\u65f6\u5fc5\u987b\u4ece\u624b\u52a8\u8c03\u6574\u8d85\u53c2\u6570\u5f00\u59cb\uff0c\u5373\u624b\u5de5\u8c03\u6574\u3002\u624b\u52a8\u8c03\u6574\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5b66\u4e60\u57fa\u7840\u77e5\u8bc6\uff0c\u4f8b\u5982\uff0c\u5728\u68af\u5ea6\u63d0\u5347\u4e2d\uff0c\u5f53\u4f60\u589e\u52a0\u6df1\u5ea6\u65f6\uff0c\u4f60\u5e94\u8be5\u964d\u4f4e\u5b66\u4e60\u7387\u3002\u5982\u679c\u4f7f\u7528\u81ea\u52a8\u5de5\u5177\uff0c\u5c31\u65e0\u6cd5\u5b66\u4e60\u5230\u8fd9\u4e00\u70b9\u3002\u8bf7\u53c2\u8003\u4e0b\u8868\uff0c\u4e86\u89e3\u5e94\u5982\u4f55\u8c03\u6574\u3002RS* \u8868\u793a\u968f\u673a\u641c\u7d22\u5e94\u8be5\u66f4\u597d\u3002 \u4e00\u65e6\u4f60\u80fd\u66f4\u597d\u5730\u624b\u52a8\u8c03\u6574\u53c2\u6570\uff0c\u4f60\u751a\u81f3\u53ef\u80fd\u4e0d\u9700\u8981\u4efb\u4f55\u81ea\u52a8\u8d85\u53c2\u6570\u8c03\u6574\u3002\u521b\u5efa\u5927\u578b\u6a21\u578b\u6216\u5f15\u5165\u5927\u91cf\u7279\u5f81\u65f6\uff0c\u4e5f\u5bb9\u6613\u9020\u6210\u8bad\u7ec3\u6570\u636e\u7684\u8fc7\u5ea6\u62df\u5408\u3002\u4e3a\u907f\u514d\u8fc7\u5ea6\u62df\u5408\uff0c\u9700\u8981\u5728\u8bad\u7ec3\u6570\u636e\u7279\u5f81\u4e2d\u5f15\u5165\u566a\u58f0\u6216\u5bf9\u4ee3\u4ef7\u51fd\u6570\u8fdb\u884c\u60e9\u7f5a\u3002\u8fd9\u79cd\u60e9\u7f5a\u79f0\u4e3a \u6b63\u5219\u5316 \uff0c\u6709\u52a9\u4e8e\u6cdb\u5316\u6a21\u578b\u3002\u5728\u7ebf\u6027\u6a21\u578b\u4e2d\uff0c\u6700\u5e38\u89c1\u7684\u6b63\u5219\u5316\u7c7b\u578b\u662f L1 \u548c L2\u3002L1 \u4e5f\u79f0\u4e3a Lasso \u56de\u5f52\uff0cL2 \u79f0\u4e3a Ridge \u56de\u5f52\u3002\u8bf4\u5230\u795e\u7ecf\u7f51\u7edc\uff0c\u6211\u4eec\u4f1a\u4f7f\u7528dropout\u3001\u6dfb\u52a0\u589e\u5f3a\u3001\u566a\u58f0\u7b49\u65b9\u6cd5\u5bf9\u6a21\u578b\u8fdb\u884c\u6b63\u5219\u5316\u3002\u5229\u7528\u8d85\u53c2\u6570\u4f18\u5316\uff0c\u8fd8\u53ef\u4ee5\u627e\u5230\u6b63\u786e\u7684\u60e9\u7f5a\u65b9\u6cd5\u3002 Model Optimize Range of values Linear Regression - fit_intercept - normalize - True/False - True/False Ridge - alpha - fit_intercept - normalize - 0.01, 0.1, 1.0, 10, 100 - True/False - True/False k-neighbors - n_neighbors - p - 2, 4, 8, 16, ... - 2, 3, ... SVM - C - gamma - class_weight - 0.001, 0.01, ...,10, 100, 1000 - 'auto', RS* - 'balanced', None Logistic Regression - Penalyt - C - L1 or L2 - 0.001, 0.01, ..., 10, ..., 100 Lasso - Alpha - Normalize - 0.1, 1.0, 10 - True/False Random Forest - n_estimators - max_depth - min_samples_split - min_samples_leaf - max features - 120, 300, 500, 800, 1200 - 5, 8, 15, 25, 30, None - 1, 2, 5, 10, 15, 100 - log2, sqrt, None XGBoost - eta - gamma - max_depth - min_child_weight - subsample - colsample_bytree - lambda - alpha - 0.01, 0.015, 0.025, 0.05, 0.1 - 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0 - 3, 5, 7, 9, 12, 15, 17, 25 - 1, 3, 5, 7 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.01, 0.1, 1.0, RS - 0, 0.1, 0.5, 1.0, RS","title":"\u8d85\u53c2\u6570\u4f18\u5316"}]} \ No newline at end of file +{"config":{"indexing":"full","lang":["en"],"min_search_length":3,"prebuild_index":false,"separator":"[\\s\\-]+"},"docs":[{"location":"","text":"AAAMLP-CN \u65b0\u7279\u6027 2023.10.25 \ud83d\ude0e\u5b8c\u6210\u5168\u90e8\u7ffb\u8bd1 \ud83d\udcdd\u8ba1\u5212\u5bf9kaggle\u6e38\u4e50\u56ed\u7cfb\u5217\u4f18\u79c0\u89e3\u51b3\u65b9\u6848\u4ee3\u7801\u8fdb\u884c\u89e3\u6790 2023.09.07 \u26a1\u4fee\u6b63\u90e8\u5206\u5df2\u77e5\u6587\u5b57\u9519\u8bef\u548c\u4ee3\u7801\u9519\u8bef \ud83e\udd17\u6dfb\u52a0 \u5728\u7ebf\u6587\u6863 \u7ffb\u8bd1\u8fdb\u7a0b 2023.09.12 \u6dfb\u52a0\u7ae0\u8282\uff1a \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5 \u3001 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5 \u7b80\u4ecb Abhishek Thakur\uff0c\u5f88\u591a kaggler \u5bf9\u4ed6\u90fd\u975e\u5e38\u719f\u6089\uff0c2017 \u5e74\uff0c\u4ed6\u5728 Linkedin \u53d1\u8868\u4e86\u4e00\u7bc7\u540d\u4e3a Approaching (Almost) Any Machine Learning Problem \u7684\u6587\u7ae0\uff0c\u4ecb\u7ecd\u4ed6\u5efa\u7acb\u7684\u4e00\u4e2a\u81ea\u52a8\u7684\u673a\u5668\u5b66\u4e60\u6846\u67b6\uff0c\u51e0\u4e4e\u53ef\u4ee5\u89e3\u51b3\u4efb\u4f55\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u8fd9\u7bc7\u6587\u7ae0\u66fe\u706b\u904d Kaggle\u3002 Abhishek \u5728 Kaggle \u4e0a\u7684\u6210\u5c31\uff1a Competitions Grandmaster\uff0817 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 3\uff09 Kernels Expert \uff08Kagglers \u6392\u540d\u524d 1\uff05\uff09 Discussion Grandmaster\uff0865 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 2\uff09 \u76ee\u524d\uff0cAbhishek \u5728\u632a\u5a01 boost \u516c\u53f8\u62c5\u4efb\u9996\u5e2d\u6570\u636e\u79d1\u5b66\u5bb6\u7684\u804c\u4f4d\uff0c\u8fd9\u662f\u4e00\u5bb6\u4e13\u95e8\u4ece\u4e8b\u4f1a\u8bdd\u4eba\u5de5\u667a\u80fd\u7684\u8f6f\u4ef6\u516c\u53f8\u3002 \u672c\u6587\u5bf9 Approaching (Almost) Any Machine Learning Problem \u8fdb\u884c\u4e86 \u4e2d\u6587\u7ffb\u8bd1 \uff0c\u7531\u4e8e\u672c\u4eba\u6c34\u5e73\u6709\u9650\uff0c\u4e14\u672a\u4f7f\u7528\u673a\u5668\u7ffb\u8bd1\uff0c\u53ef\u80fd\u6709\u90e8\u5206\u8a00\u8bed\u4e0d\u901a\u987a\u6216\u672c\u571f\u5316\u7a0b\u5ea6\u4e0d\u8db3\uff0c\u4e5f\u8bf7\u5927\u5bb6\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u591a\u63d0\u4f9b\u5b9d\u8d35\u610f\u89c1\u3002\u53e6\u9644\u4e0a\u4e66\u7c4d\u539f \u9879\u76ee\u5730\u5740 \uff0c \u8f6c\u8f7d\u8bf7\u4e00\u5b9a\u6807\u660e\u51fa\u5904\uff01 \u672c\u9879\u76ee \u652f\u6301\u5728\u7ebf\u9605\u8bfb \uff0c\u65b9\u4fbf\u60a8\u968f\u65f6\u968f\u5730\u8fdb\u884c\u67e5\u9605\u3002 \u56e0\u4e3a\u6709\u51e0\u7ae0\u5185\u5bb9\u592a\u8fc7\u57fa\u7840\uff0c\u6240\u4ee5\u672a\u8fdb\u884c\u7ffb\u8bd1\uff0c\u8be6\u7ec6\u60c5\u51b5\u8bf7\u53c2\u7167\u4e66\u7c4d\u76ee\u5f55\uff1a \u51c6\u5907\u73af\u5883\uff08\u5df2\u7ffb\u8bd1\uff09 \u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60\uff08\u5df2\u7ffb\u8bd1\uff09 \u4ea4\u53c9\u68c0\u9a8c\uff08\u5df2\u7ffb\u8bd1\uff09 \u8bc4\u4f30\u6307\u6807\uff08\u5df2\u7ffb\u8bd1\uff09 - \u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\uff08\u5df2\u7ffb\u8bd1\uff09 \u5904\u7406\u5206\u7c7b\u53d8\u91cf\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u5de5\u7a0b\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u9009\u62e9\uff08\u5df2\u7ffb\u8bd1\uff09 \u8d85\u53c2\u6570\u4f18\u5316\uff08\u5df2\u7ffb\u8bd1\uff09 \u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u6211\u5c06\u4f1a\u628a\u5b8c\u6574\u7684\u7ffb\u8bd1\u7248 Markdown \u6587\u4ef6\u4e0a\u4f20\u5230 GitHub\uff0c\u4ee5\u4f9b\u5927\u5bb6\u514d\u8d39\u4e0b\u8f7d\u548c\u9605\u8bfb\u3002\u4e3a\u4e86\u6700\u4f73\u7684\u9605\u8bfb\u4f53\u9a8c\uff0c\u63a8\u8350\u4f7f\u7528 PDF \u683c\u5f0f\u6216\u662f\u5728\u7ebf\u9605\u8bfb\u8fdb\u884c\u67e5\u770b \u82e5\u60a8\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u53d1\u73b0\u4efb\u4f55\u9519\u8bef\u6216\u4e0d\u51c6\u786e\u4e4b\u5904\uff0c\u975e\u5e38\u6b22\u8fce\u901a\u8fc7\u63d0\u4ea4 Issue \u6216 Pull Request \u6765\u534f\u52a9\u6211\u8fdb\u884c\u4fee\u6b63\u3002 \u968f\u7740\u65f6\u95f4\u63a8\u79fb\uff0c\u6211\u53ef\u80fd\u4f1a \u7ee7\u7eed\u7ffb\u8bd1\u5c1a\u672a\u5b8c\u6210\u7684\u7ae0\u8282 \u3002\u5982\u679c\u60a8\u89c9\u5f97\u8fd9\u4e2a\u9879\u76ee\u5bf9\u60a8\u6709\u5e2e\u52a9\uff0c\u8bf7\u4e0d\u541d\u7ed9\u4e88 Star \u6216\u8005\u8fdb\u884c\u5173\u6ce8\u3002","title":"\u524d\u8a00"},{"location":"#aaamlp-cn","text":"","title":"AAAMLP-CN"},{"location":"#_1","text":"","title":"\u65b0\u7279\u6027"},{"location":"#20231025","text":"\ud83d\ude0e\u5b8c\u6210\u5168\u90e8\u7ffb\u8bd1 \ud83d\udcdd\u8ba1\u5212\u5bf9kaggle\u6e38\u4e50\u56ed\u7cfb\u5217\u4f18\u79c0\u89e3\u51b3\u65b9\u6848\u4ee3\u7801\u8fdb\u884c\u89e3\u6790","title":"2023.10.25"},{"location":"#20230907","text":"\u26a1\u4fee\u6b63\u90e8\u5206\u5df2\u77e5\u6587\u5b57\u9519\u8bef\u548c\u4ee3\u7801\u9519\u8bef \ud83e\udd17\u6dfb\u52a0 \u5728\u7ebf\u6587\u6863","title":"2023.09.07"},{"location":"#_2","text":"2023.09.12 \u6dfb\u52a0\u7ae0\u8282\uff1a \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5 \u3001 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5","title":"\u7ffb\u8bd1\u8fdb\u7a0b"},{"location":"#_3","text":"Abhishek Thakur\uff0c\u5f88\u591a kaggler \u5bf9\u4ed6\u90fd\u975e\u5e38\u719f\u6089\uff0c2017 \u5e74\uff0c\u4ed6\u5728 Linkedin \u53d1\u8868\u4e86\u4e00\u7bc7\u540d\u4e3a Approaching (Almost) Any Machine Learning Problem \u7684\u6587\u7ae0\uff0c\u4ecb\u7ecd\u4ed6\u5efa\u7acb\u7684\u4e00\u4e2a\u81ea\u52a8\u7684\u673a\u5668\u5b66\u4e60\u6846\u67b6\uff0c\u51e0\u4e4e\u53ef\u4ee5\u89e3\u51b3\u4efb\u4f55\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u8fd9\u7bc7\u6587\u7ae0\u66fe\u706b\u904d Kaggle\u3002 Abhishek \u5728 Kaggle \u4e0a\u7684\u6210\u5c31\uff1a Competitions Grandmaster\uff0817 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 3\uff09 Kernels Expert \uff08Kagglers \u6392\u540d\u524d 1\uff05\uff09 Discussion Grandmaster\uff0865 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 2\uff09 \u76ee\u524d\uff0cAbhishek \u5728\u632a\u5a01 boost \u516c\u53f8\u62c5\u4efb\u9996\u5e2d\u6570\u636e\u79d1\u5b66\u5bb6\u7684\u804c\u4f4d\uff0c\u8fd9\u662f\u4e00\u5bb6\u4e13\u95e8\u4ece\u4e8b\u4f1a\u8bdd\u4eba\u5de5\u667a\u80fd\u7684\u8f6f\u4ef6\u516c\u53f8\u3002 \u672c\u6587\u5bf9 Approaching (Almost) Any Machine Learning Problem \u8fdb\u884c\u4e86 \u4e2d\u6587\u7ffb\u8bd1 \uff0c\u7531\u4e8e\u672c\u4eba\u6c34\u5e73\u6709\u9650\uff0c\u4e14\u672a\u4f7f\u7528\u673a\u5668\u7ffb\u8bd1\uff0c\u53ef\u80fd\u6709\u90e8\u5206\u8a00\u8bed\u4e0d\u901a\u987a\u6216\u672c\u571f\u5316\u7a0b\u5ea6\u4e0d\u8db3\uff0c\u4e5f\u8bf7\u5927\u5bb6\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u591a\u63d0\u4f9b\u5b9d\u8d35\u610f\u89c1\u3002\u53e6\u9644\u4e0a\u4e66\u7c4d\u539f \u9879\u76ee\u5730\u5740 \uff0c \u8f6c\u8f7d\u8bf7\u4e00\u5b9a\u6807\u660e\u51fa\u5904\uff01 \u672c\u9879\u76ee \u652f\u6301\u5728\u7ebf\u9605\u8bfb \uff0c\u65b9\u4fbf\u60a8\u968f\u65f6\u968f\u5730\u8fdb\u884c\u67e5\u9605\u3002 \u56e0\u4e3a\u6709\u51e0\u7ae0\u5185\u5bb9\u592a\u8fc7\u57fa\u7840\uff0c\u6240\u4ee5\u672a\u8fdb\u884c\u7ffb\u8bd1\uff0c\u8be6\u7ec6\u60c5\u51b5\u8bf7\u53c2\u7167\u4e66\u7c4d\u76ee\u5f55\uff1a \u51c6\u5907\u73af\u5883\uff08\u5df2\u7ffb\u8bd1\uff09 \u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60\uff08\u5df2\u7ffb\u8bd1\uff09 \u4ea4\u53c9\u68c0\u9a8c\uff08\u5df2\u7ffb\u8bd1\uff09 \u8bc4\u4f30\u6307\u6807\uff08\u5df2\u7ffb\u8bd1\uff09 - \u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\uff08\u5df2\u7ffb\u8bd1\uff09 \u5904\u7406\u5206\u7c7b\u53d8\u91cf\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u5de5\u7a0b\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u9009\u62e9\uff08\u5df2\u7ffb\u8bd1\uff09 \u8d85\u53c2\u6570\u4f18\u5316\uff08\u5df2\u7ffb\u8bd1\uff09 \u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u6211\u5c06\u4f1a\u628a\u5b8c\u6574\u7684\u7ffb\u8bd1\u7248 Markdown \u6587\u4ef6\u4e0a\u4f20\u5230 GitHub\uff0c\u4ee5\u4f9b\u5927\u5bb6\u514d\u8d39\u4e0b\u8f7d\u548c\u9605\u8bfb\u3002\u4e3a\u4e86\u6700\u4f73\u7684\u9605\u8bfb\u4f53\u9a8c\uff0c\u63a8\u8350\u4f7f\u7528 PDF \u683c\u5f0f\u6216\u662f\u5728\u7ebf\u9605\u8bfb\u8fdb\u884c\u67e5\u770b \u82e5\u60a8\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u53d1\u73b0\u4efb\u4f55\u9519\u8bef\u6216\u4e0d\u51c6\u786e\u4e4b\u5904\uff0c\u975e\u5e38\u6b22\u8fce\u901a\u8fc7\u63d0\u4ea4 Issue \u6216 Pull Request \u6765\u534f\u52a9\u6211\u8fdb\u884c\u4fee\u6b63\u3002 \u968f\u7740\u65f6\u95f4\u63a8\u79fb\uff0c\u6211\u53ef\u80fd\u4f1a \u7ee7\u7eed\u7ffb\u8bd1\u5c1a\u672a\u5b8c\u6210\u7684\u7ae0\u8282 \u3002\u5982\u679c\u60a8\u89c9\u5f97\u8fd9\u4e2a\u9879\u76ee\u5bf9\u60a8\u6709\u5e2e\u52a9\uff0c\u8bf7\u4e0d\u541d\u7ed9\u4e88 Star \u6216\u8005\u8fdb\u884c\u5173\u6ce8\u3002","title":"\u7b80\u4ecb"},{"location":"%E4%BA%A4%E5%8F%89%E6%A3%80%E9%AA%8C/","text":"\u4ea4\u53c9\u68c0\u9a8c \u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u6ca1\u6709\u5efa\u7acb\u4efb\u4f55\u6a21\u578b\u3002\u539f\u56e0\u5f88\u7b80\u5355\uff0c\u5728\u521b\u5efa\u4efb\u4f55\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u77e5\u9053\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4ee5\u53ca\u5982\u4f55\u6839\u636e\u6570\u636e\u96c6\u9009\u62e9\u6700\u4f73\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e\u96c6\u3002 \u90a3\u4e48\uff0c\u4ec0\u4e48\u662f \u4ea4\u53c9\u68c0\u9a8c \uff0c\u6211\u4eec\u4e3a\u4ec0\u4e48\u8981\u5173\u6ce8\u5b83\uff1f \u5173\u4e8e\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u53ef\u4ee5\u627e\u5230\u591a\u79cd\u5b9a\u4e49\u3002\u6211\u7684\u5b9a\u4e49\u53ea\u6709\u4e00\u53e5\u8bdd\uff1a\u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fc7\u7a0b\u4e2d\u7684\u4e00\u4e2a\u6b65\u9aa4\uff0c\u5b83\u53ef\u4ee5\u5e2e\u52a9\u6211\u4eec\u786e\u4fdd\u6a21\u578b\u51c6\u786e\u62df\u5408\u6570\u636e\uff0c\u540c\u65f6\u786e\u4fdd\u6211\u4eec\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u4f46\u8fd9\u53c8\u5f15\u51fa\u4e86\u53e6\u4e00\u4e2a\u8bcd\uff1a \u8fc7\u62df\u5408 \u3002 \u8981\u89e3\u91ca\u8fc7\u62df\u5408\uff0c\u6211\u8ba4\u4e3a\u6700\u597d\u5148\u770b\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6709\u4e00\u4e2a\u76f8\u5f53\u6709\u540d\u7684\u7ea2\u9152\u8d28\u91cf\u6570\u636e\u96c6\uff08 red wine quality dataset \uff09\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 11 \u4e2a\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u8fd9\u4e9b\u7279\u5f81\u51b3\u5b9a\u4e86\u7ea2\u9152\u7684\u8d28\u91cf\u3002 \u8fd9\u4e9b\u5c5e\u6027\u5305\u62ec\uff1a \u56fa\u5b9a\u9178\u5ea6\uff08fixed acidity\uff09 \u6325\u53d1\u6027\u9178\u5ea6\uff08volatile acidity\uff09 \u67e0\u6aac\u9178\uff08citric acid\uff09 \u6b8b\u7559\u7cd6\uff08residual sugar\uff09 \u6c2f\u5316\u7269\uff08chlorides\uff09 \u6e38\u79bb\u4e8c\u6c27\u5316\u786b\uff08free sulfur dioxide\uff09 \u4e8c\u6c27\u5316\u786b\u603b\u91cf\uff08total sulfur dioxide\uff09 \u5bc6\u5ea6\uff08density\uff09 PH \u503c\uff08pH\uff09 \u786b\u9178\u76d0\uff08sulphates\uff09 \u9152\u7cbe\uff08alcohol\uff09 \u6839\u636e\u8fd9\u4e9b\u4e0d\u540c\u7279\u5f81\uff0c\u6211\u4eec\u9700\u8981\u9884\u6d4b\u7ea2\u8461\u8404\u9152\u7684\u8d28\u91cf\uff0c\u8d28\u91cf\u503c\u4ecb\u4e8e 0 \u5230 10 \u4e4b\u95f4\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u6570\u636e\u662f\u600e\u6837\u7684\u3002 import pandas as pd df = pd . read_csv ( \"winequality-red.csv\" ) \u56fe 1:\u7ea2\u8461\u8404\u9152\u8d28\u91cf\u6570\u636e\u96c6\u7b80\u5355\u5c55\u793a \u6211\u4eec\u53ef\u4ee5\u5c06\u8fd9\u4e2a\u95ee\u9898\u89c6\u4e3a\u5206\u7c7b\u95ee\u9898\uff0c\u4e5f\u53ef\u4ee5\u89c6\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u9009\u62e9\u5206\u7c7b\u3002\u7136\u800c\uff0c\u8fd9\u4e2a\u6570\u636e\u96c6\u503c\u5305\u542b 6 \u79cd\u8d28\u91cf\u503c\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u6240\u6709\u8d28\u91cf\u503c\u6620\u5c04\u5230 0 \u5230 5 \u4e4b\u95f4\u3002 # \u4e00\u4e2a\u6620\u5c04\u5b57\u5178\uff0c\u7528\u4e8e\u5c06\u8d28\u91cf\u503c\u4ece 0 \u5230 5 \u8fdb\u884c\u6620\u5c04 quality_mapping = { 3 : 0 , 4 : 1 , 5 : 2 , 6 : 3 , 7 : 4 , 8 : 5 } # \u4f60\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 map \u51fd\u6570\u4ee5\u53ca\u4efb\u4f55\u5b57\u5178\uff0c # \u6765\u8f6c\u6362\u7ed9\u5b9a\u5217\u4e2d\u7684\u503c\u4e3a\u5b57\u5178\u4e2d\u7684\u503c df . loc [:, \"quality\" ] = df . quality . map ( quality_mapping ) \u5f53\u6211\u4eec\u770b\u5927\u8fd9\u4e9b\u6570\u636e\u5e76\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u65f6\uff0c\u6211\u4eec\u8111\u6d77\u4e2d\u4f1a\u6d6e\u73b0\u51fa\u5f88\u591a\u53ef\u4ee5\u5e94\u7528\u7684\u7b97\u6cd5\uff0c\u4e5f\u8bb8\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u4ece\u4e00\u5f00\u59cb\u5c31\u6df1\u5165\u7814\u7a76\u795e\u7ecf\u7f51\u7edc\uff0c\u90a3\u5c31\u6709\u70b9\u7275\u5f3a\u4e86\u3002\u6240\u4ee5\uff0c\u8ba9\u6211\u4eec\u4ece\u7b80\u5355\u7684\u3001\u6211\u4eec\u4e5f\u80fd\u53ef\u89c6\u5316\u7684\u4e1c\u897f\u5f00\u59cb\uff1a\u51b3\u7b56\u6811\u3002 \u5728\u5f00\u59cb\u4e86\u89e3\u4ec0\u4e48\u662f\u8fc7\u62df\u5408\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 1599 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u4fdd\u7559 1000 \u4e2a\u6837\u672c\u7528\u4e8e\u8bad\u7ec3\uff0c599 \u4e2a\u6837\u672c\u4f5c\u4e3a\u4e00\u4e2a\u5355\u72ec\u7684\u96c6\u5408\u3002 \u4ee5\u4e0b\u4ee3\u7801\u53ef\u4ee5\u8f7b\u677e\u5b8c\u6210\u5212\u5206\uff1a # \u4f7f\u7528 frac=1 \u7684 sample \u65b9\u6cd5\u6765\u6253\u4e71 dataframe # \u7531\u4e8e\u6253\u4e71\u540e\u7d22\u5f15\u4f1a\u6539\u53d8\uff0c\u6240\u4ee5\u6211\u4eec\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u9009\u53d6\u524d 1000 \u884c\u4f5c\u4e3a\u8bad\u7ec3\u6570\u636e df_train = df . head ( 1000 ) # \u9009\u53d6\u6700\u540e\u7684 599 \u884c\u4f5c\u4e3a\u6d4b\u8bd5/\u9a8c\u8bc1\u6570\u636e df_test = df . tail ( 599 ) \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u5728\u8bad\u7ec3\u96c6\u4e0a\u4f7f\u7528 scikit-learn \u8bad\u7ec3\u4e00\u4e2a\u51b3\u7b56\u6811\u6a21\u578b\u3002 # \u4ece scikit-learn \u5bfc\u5165\u9700\u8981\u7684\u6a21\u5757 from sklearn import tree from sklearn import metrics # \u521d\u59cb\u5316\u4e00\u4e2a\u51b3\u7b56\u6811\u5206\u7c7b\u5668\uff0c\u8bbe\u7f6e\u6700\u5927\u6df1\u5ea6\u4e3a 3 clf = tree . DecisionTreeClassifier ( max_depth = 3 ) # \u9009\u62e9\u4f60\u60f3\u8981\u8bad\u7ec3\u6a21\u578b\u7684\u5217 # \u8fd9\u4e9b\u5217\u4f5c\u4e3a\u6a21\u578b\u7684\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u4f7f\u7528\u4e4b\u524d\u6620\u5c04\u7684\u8d28\u91cf\u4ee5\u53ca\u63d0\u4f9b\u7684\u7279\u5f81\u6765\u8bad\u7ec3\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5c06\u51b3\u7b56\u6811\u5206\u7c7b\u5668\u7684\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u8bbe\u4e3a 3\u3002\u8be5\u6a21\u578b\u7684\u6240\u6709\u5176\u4ed6\u53c2\u6570\u5747\u4fdd\u6301\u9ed8\u8ba4\u503c\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u6d4b\u8bd5\u8be5\u6a21\u578b\u7684\u51c6\u786e\u6027\uff1a # \u5728\u8bad\u7ec3\u96c6\u4e0a\u751f\u6210\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) # \u5728\u6d4b\u8bd5\u96c6\u4e0a\u751f\u6210\u9884\u6d4b test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) # \u8ba1\u7b97\u6d4b\u8bd5\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 test_accuracy = metrics . accuracy_score ( df_test . quality , test_predictions ) \u8bad\u7ec3\u548c\u6d4b\u8bd5\u7684\u51c6\u786e\u7387\u5206\u522b\u4e3a 58.9%\u548c 54.25%\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5c06\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u589e\u52a0\u5230 7\uff0c\u5e76\u91cd\u590d\u4e0a\u8ff0\u8fc7\u7a0b\u3002\u8fd9\u6837\uff0c\u8bad\u7ec3\u51c6\u786e\u7387\u4e3a 76.6%\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4e3a 57.3%\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4f7f\u7528\u51c6\u786e\u7387\uff0c\u4e3b\u8981\u662f\u56e0\u4e3a\u5b83\u662f\u6700\u76f4\u63a5\u7684\u6307\u6807\u3002\u5bf9\u4e8e\u8fd9\u4e2a\u95ee\u9898\u6765\u8bf4\uff0c\u5b83\u53ef\u80fd\u4e0d\u662f\u6700\u597d\u7684\u6307\u6807\u3002\u6211\u4eec\u53ef\u4ee5\u6839\u636e\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u4e0d\u540c\u503c\u6765\u8ba1\u7b97\u8fd9\u4e9b\u51c6\u786e\u7387\uff0c\u5e76\u7ed8\u5236\u66f2\u7ebf\u56fe\u3002 # \u6ce8\u610f\uff1a\u8fd9\u6bb5\u4ee3\u7801\u5728 Jupyter \u7b14\u8bb0\u672c\u4e2d\u7f16\u5199 # \u5bfc\u5165 scikit-learn \u7684 tree \u548c metrics from sklearn import tree from sklearn import metrics # \u5bfc\u5165 matplotlib \u548c seaborn # \u7528\u4e8e\u7ed8\u56fe import matplotlib import matplotlib.pyplot as plt import seaborn as sns # \u8bbe\u7f6e\u5168\u5c40\u6807\u7b7e\u6587\u672c\u7684\u5927\u5c0f matplotlib . rc ( 'xtick' , labelsize = 20 ) matplotlib . rc ( 'ytick' , labelsize = 20 ) # \u786e\u4fdd\u56fe\u8868\u76f4\u63a5\u5728\u7b14\u8bb0\u672c\u5185\u663e\u793a % matplotlib inline # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6\u7684\u5217\u8868 # \u6211\u4eec\u4ece 50% \u7684\u51c6\u786e\u5ea6\u5f00\u59cb train_accuracies = [ 0.5 ] test_accuracies = [ 0.5 ] # \u904d\u5386\u51e0\u4e2a\u4e0d\u540c\u7684\u6811\u6df1\u5ea6\u503c for depth in range ( 1 , 25 ): # \u521d\u59cb\u5316\u6a21\u578b clf = tree . DecisionTreeClassifier ( max_depth = depth ) # \u9009\u62e9\u7528\u4e8e\u8bad\u7ec3\u7684\u5217/\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u5728\u7ed9\u5b9a\u7279\u5f81\u4e0a\u62df\u5408\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) # \u521b\u5efa\u8bad\u7ec3\u548c\u6d4b\u8bd5\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) test_accuracy = metrics . accuracy_score ( df_test . quality , test_predictions ) # \u6dfb\u52a0\u51c6\u786e\u5ea6\u5230\u5217\u8868 train_accuracies . append ( train_accuracy ) test_accuracies . append ( test_accuracy ) # \u4f7f\u7528 matplotlib \u548c seaborn \u521b\u5efa\u4e24\u4e2a\u56fe plt . figure ( figsize = ( 10 , 5 )) sns . set_style ( \"whitegrid\" ) plt . plot ( train_accuracies , label = \"train accuracy\" ) plt . plot ( test_accuracies , label = \"test accuracy\" ) plt . legend ( loc = \"upper left\" , prop = { 'size' : 15 }) plt . xticks ( range ( 0 , 26 , 5 )) plt . xlabel ( \"max_depth\" , size = 20 ) plt . ylabel ( \"accuracy\" , size = 20 ) plt . show () \u8fd9\u5c06\u751f\u6210\u5982\u56fe 2 \u6240\u793a\u7684\u66f2\u7ebf\u56fe\u3002 \u56fe 2\uff1a\u4e0d\u540c max_depth \u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u7387\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5f53\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u503c\u4e3a 14 \u65f6\uff0c\u6d4b\u8bd5\u6570\u636e\u7684\u5f97\u5206\u6700\u9ad8\u3002\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u589e\u52a0\u8fd9\u4e2a\u53c2\u6570\u7684\u503c\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4f1a\u4fdd\u6301\u4e0d\u53d8\u6216\u53d8\u5dee\uff0c\u4f46\u8bad\u7ec3\u51c6\u786e\u7387\u4f1a\u4e0d\u65ad\u63d0\u9ad8\u3002\u8fd9\u8bf4\u660e\uff0c\u968f\u7740\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u589e\u52a0\uff0c\u51b3\u7b56\u6811\u6a21\u578b\u5bf9\u8bad\u7ec3\u6570\u636e\u7684\u5b66\u4e60\u6548\u679c\u8d8a\u6765\u8d8a\u597d\uff0c\u4f46\u6d4b\u8bd5\u6570\u636e\u7684\u6027\u80fd\u5374\u4e1d\u6beb\u6ca1\u6709\u63d0\u9ad8\u3002 \u8fd9\u5c31\u662f\u6240\u8c13\u7684\u8fc7\u62df\u5408 \u3002 \u6a21\u578b\u5728\u8bad\u7ec3\u96c6\u4e0a\u5b8c\u5168\u62df\u5408\uff0c\u800c\u5728\u6d4b\u8bd5\u96c6\u4e0a\u5374\u8868\u73b0\u4e0d\u4f73\u3002\u8fd9\u610f\u5473\u7740\u6a21\u578b\u53ef\u4ee5\u5f88\u597d\u5730\u5b66\u4e60\u8bad\u7ec3\u6570\u636e\uff0c\u4f46\u65e0\u6cd5\u6cdb\u5316\u5230\u672a\u89c1\u8fc7\u7684\u6837\u672c\u4e0a\u3002\u5728\u4e0a\u9762\u7684\u6570\u636e\u96c6\u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u975e\u5e38\u9ad8\u7684\u6a21\u578b\uff0c\u5b83\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u4f1a\u6709\u51fa\u8272\u7684\u7ed3\u679c\uff0c\u4f46\u8fd9\u79cd\u6a21\u578b\u5e76\u4e0d\u5b9e\u7528\uff0c\u56e0\u4e3a\u5b83\u5728\u771f\u5b9e\u4e16\u754c\u7684\u6837\u672c\u6216\u5b9e\u65f6\u6570\u636e\u4e0a\u4e0d\u4f1a\u63d0\u4f9b\u7c7b\u4f3c\u7684\u7ed3\u679c\u3002 \u6709\u4eba\u53ef\u80fd\u4f1a\u8bf4\uff0c\u8fd9\u79cd\u65b9\u6cd5\u5e76\u6ca1\u6709\u8fc7\u62df\u5408\uff0c\u56e0\u4e3a\u6d4b\u8bd5\u96c6\u7684\u51c6\u786e\u7387\u57fa\u672c\u4fdd\u6301\u4e0d\u53d8\u3002\u8fc7\u62df\u5408\u7684\u53e6\u4e00\u4e2a\u5b9a\u4e49\u662f\uff0c\u5f53\u6211\u4eec\u4e0d\u65ad\u63d0\u9ad8\u8bad\u7ec3\u635f\u5931\u65f6\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u5728\u589e\u52a0\u3002\u8fd9\u79cd\u60c5\u51b5\u5728\u795e\u7ecf\u7f51\u7edc\u4e2d\u975e\u5e38\u5e38\u89c1\u3002 \u6bcf\u5f53\u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u795e\u7ecf\u7f51\u7edc\u65f6\uff0c\u90fd\u5fc5\u987b\u5728\u8bad\u7ec3\u671f\u95f4\u76d1\u63a7\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u3002\u5982\u679c\u6211\u4eec\u6709\u4e00\u4e2a\u975e\u5e38\u5927\u7684\u7f51\u7edc\u6765\u5904\u7406\u4e00\u4e2a\u975e\u5e38\u5c0f\u7684\u6570\u636e\u96c6\uff08\u5373\u6837\u672c\u6570\u975e\u5e38\u5c11\uff09\uff0c\u6211\u4eec\u5c31\u4f1a\u89c2\u5bdf\u5230\uff0c\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u8bad\u7ec3\uff0c\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u90fd\u4f1a\u51cf\u5c11\u3002\u4f46\u662f\uff0c\u5728\u67d0\u4e2a\u65f6\u523b\uff0c\u6d4b\u8bd5\u635f\u5931\u4f1a\u8fbe\u5230\u6700\u5c0f\u503c\uff0c\u4e4b\u540e\uff0c\u5373\u4f7f\u8bad\u7ec3\u635f\u5931\u8fdb\u4e00\u6b65\u51cf\u5c11\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u4f1a\u5f00\u59cb\u589e\u52a0\u3002\u6211\u4eec\u5fc5\u987b\u5728\u9a8c\u8bc1\u635f\u5931\u8fbe\u5230\u6700\u5c0f\u503c\u65f6\u505c\u6b62\u8bad\u7ec3\u3002 \u8fd9\u662f\u5bf9\u8fc7\u62df\u5408\u6700\u5e38\u89c1\u7684\u89e3\u91ca \u3002 \u5965\u5361\u59c6\u5243\u5200\u7528\u7b80\u5355\u7684\u8bdd\u8bf4\uff0c\u5c31\u662f\u4e0d\u8981\u8bd5\u56fe\u628a\u53ef\u4ee5\u7528\u7b80\u5355\u5f97\u591a\u7684\u65b9\u6cd5\u89e3\u51b3\u7684\u4e8b\u60c5\u590d\u6742\u5316\u3002\u6362\u53e5\u8bdd\u8bf4\uff0c\u6700\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6848\u5c31\u662f\u6700\u5177\u901a\u7528\u6027\u7684\u89e3\u51b3\u65b9\u6848\u3002\u4e00\u822c\u6765\u8bf4\uff0c\u53ea\u8981\u4f60\u7684\u6a21\u578b\u4e0d\u7b26\u5408\u5965\u5361\u59c6\u5243\u5200\u539f\u5219\uff0c\u5c31\u5f88\u53ef\u80fd\u662f\u8fc7\u62df\u5408\u3002 \u56fe 3\uff1a\u8fc7\u62df\u5408\u7684\u6700\u4e00\u822c\u5b9a\u4e49 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u56de\u5230\u4ea4\u53c9\u68c0\u9a8c\u3002 \u5728\u89e3\u91ca\u8fc7\u62df\u5408\u65f6\uff0c\u6211\u51b3\u5b9a\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u6211\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u53e6\u4e00\u90e8\u5206\u4e0a\u68c0\u67e5\u5176\u6027\u80fd\u3002\u8fd9\u4e5f\u662f\u4ea4\u53c9\u68c0\u9a8c\u7684\u4e00\u79cd\uff0c\u901a\u5e38\u88ab\u79f0\u4e3a \"\u6682\u7559\u96c6\"\uff08 hold-out set \uff09\u3002\u5f53\u6211\u4eec\u62e5\u6709\u5927\u91cf\u6570\u636e\uff0c\u800c\u6a21\u578b\u63a8\u7406\u662f\u4e00\u4e2a\u8017\u65f6\u7684\u8fc7\u7a0b\u65f6\uff0c\u6211\u4eec\u5c31\u4f1a\u4f7f\u7528\u8fd9\u79cd\uff08\u4ea4\u53c9\uff09\u9a8c\u8bc1\u3002 \u4ea4\u53c9\u68c0\u9a8c\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\uff0c\u5b83\u662f\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u6b65\u9aa4\u3002 \u9009\u62e9\u6b63\u786e\u7684\u4ea4\u53c9\u68c0\u9a8c \u53d6\u51b3\u4e8e\u6240\u5904\u7406\u7684\u6570\u636e\u96c6\uff0c\u5728\u4e00\u4e2a\u6570\u636e\u96c6\u4e0a\u9002\u7528\u7684\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u53ef\u80fd\u4e0d\u9002\u7528\u4e8e\u5176\u4ed6\u6570\u636e\u96c6\u3002\u4e0d\u8fc7\uff0c\u6709\u51e0\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u6700\u4e3a\u6d41\u884c\u548c\u5e7f\u6cdb\u4f7f\u7528\u3002 \u5176\u4e2d\u5305\u62ec\uff1a k \u6298\u4ea4\u53c9\u68c0\u9a8c \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c \u5206\u7ec4 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u4ea4\u53c9\u68c0\u9a8c\u662f\u5c06\u8bad\u7ec3\u6570\u636e\u5206\u5c42\u51e0\u4e2a\u90e8\u5206\uff0c\u6211\u4eec\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u5176\u4f59\u90e8\u5206\u4e0a\u8fdb\u884c\u6d4b\u8bd5\u3002\u8bf7\u770b\u56fe 4\u3002 \u56fe 4\uff1a\u5c06\u6570\u636e\u96c6\u62c6\u5206\u4e3a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u56fe 4 \u548c\u56fe 5 \u8bf4\u660e\uff0c\u5f53\u4f60\u5f97\u5230\u4e00\u4e2a\u6570\u636e\u96c6\u6765\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u65f6\uff0c\u4f60\u4f1a\u628a\u5b83\u4eec\u5206\u6210 \u4e24\u4e2a\u4e0d\u540c\u7684\u96c6\uff1a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u3002\u5f88\u591a\u4eba\u8fd8\u4f1a\u5c06\u5176\u5206\u6210\u7b2c\u4e09\u7ec4\uff0c\u79f0\u4e4b\u4e3a\u6d4b\u8bd5\u96c6\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u5c06\u53ea\u4f7f\u7528\u4e24\u4e2a\u96c6\u3002\u5982\u4f60\u6240\u89c1\uff0c\u6211\u4eec\u5c06\u6837\u672c\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\u8fdb\u884c\u4e86\u5212\u5206\u3002\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a k \u4e2a\u4e92\u4e0d\u5173\u8054\u7684\u4e0d\u540c\u96c6\u5408\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002 \u56fe 5\uff1aK \u6298\u4ea4\u53c9\u68c0\u9a8c \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 KFold \u5c06\u4efb\u4f55\u6570\u636e\u5206\u5272\u6210 k \u4e2a\u76f8\u7b49\u7684\u90e8\u5206\u3002\u6bcf\u4e2a\u6837\u672c\u5206\u914d\u4e00\u4e2a\u4ece 0 \u5230 k-1 \u7684\u503c\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u5b58\u50a8\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6211\u4eec\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a kfold \u7684\u65b0\u5217\uff0c\u5e76\u7528 -1 \u586b\u5145 df [ \"kfold\" ] = - 1 # \u63a5\u4e0b\u6765\u7684\u6b65\u9aa4\u662f\u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4ece model_selection \u6a21\u5757\u521d\u59cb\u5316 kfold \u7c7b kf = model_selection . KFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684 kfold \u5217\uff08enumerate\u7684\u4f5c\u7528\u662f\u8fd4\u56de\u4e00\u4e2a\u8fed\u4ee3\u5668\uff09 for fold , ( trn_ , val_ ) in enumerate ( kf . split ( X = df )): df . loc [ val_ , 'kfold' ] = fold # \u4fdd\u5b58\u5e26\u6709 kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . to_csv ( \"train_folds.csv\" , index = False ) \u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u636e\u96c6\u90fd\u53ef\u4ee5\u4f7f\u7528\u6b64\u6d41\u7a0b\u3002\u4f8b\u5982\uff0c\u5f53\u6570\u636e\u56fe\u50cf\u65f6\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u56fe\u50cf ID\u3001\u56fe\u50cf\u4f4d\u7f6e\u548c\u56fe\u50cf\u6807\u7b7e\u7684 CSV\uff0c\u7136\u540e\u4f7f\u7528\u4e0a\u8ff0\u6d41\u7a0b\u3002 \u53e6\u4e00\u79cd\u91cd\u8981\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u662f \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u6570\u636e\u96c6\uff0c\u5176\u4e2d\u6b63\u6837\u672c\u5360 90%\uff0c\u8d1f\u6837\u672c\u53ea\u5360 10%\uff0c\u90a3\u4e48\u4f60\u5c31\u4e0d\u5e94\u8be5\u4f7f\u7528\u968f\u673a k \u6298\u4ea4\u53c9\u3002\u5bf9\u8fd9\u6837\u7684\u6570\u636e\u96c6\u4f7f\u7528\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u5bfc\u81f4\u6298\u53e0\u6837\u672c\u5168\u90e8\u4e3a\u8d1f\u6837\u672c\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u66f4\u503e\u5411\u4e8e\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u4ee5\u4fdd\u6301\u6bcf\u4e2a\u6298\u4e2d\u6807\u7b7e\u7684\u6bd4\u4f8b\u4e0d\u53d8\u3002\u56e0\u6b64\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u90fd\u4f1a\u6709\u76f8\u540c\u7684 90% \u6b63\u6837\u672c\u548c 10% \u8d1f\u6837\u672c\u3002\u56e0\u6b64\uff0c\u65e0\u8bba\u60a8\u9009\u62e9\u4ec0\u4e48\u6307\u6807\u8fdb\u884c\u8bc4\u4f30\uff0c\u90fd\u4f1a\u5728\u6240\u6709\u6298\u53e0\u4e2d\u5f97\u5230\u76f8\u4f3c\u7684\u7ed3\u679c\u3002 \u4fee\u6539\u521b\u5efa k \u6298\u4ea4\u53c9\u68c0\u9a8c\u7684\u4ee3\u7801\u4ee5\u521b\u5efa\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5f88\u5bb9\u6613\u3002\u6211\u4eec\u53ea\u9700\u5c06 model_selection.KFold \u66f4\u6539\u4e3a model_selection.StratifiedKFold \uff0c\u5e76\u5728 kf.split(...) \u51fd\u6570\u4e2d\u6307\u5b9a\u8981\u5206\u5c42\u7684\u76ee\u6807\u5217\u3002\u6211\u4eec\u5047\u8bbe CSV \u6570\u636e\u96c6\u6709\u4e00\u5217\u540d\u4e3a \"target\" \uff0c\u5e76\u4e14\u662f\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u4fdd\u5b58\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6dfb\u52a0\u4e00\u4e2a\u65b0\u5217 kfold\uff0c\u5e76\u7528 -1 \u521d\u59cb\u5316 df [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u83b7\u53d6\u76ee\u6807\u53d8\u91cf y = df . target . values # \u521d\u59cb\u5316 StratifiedKFold \u7c7b\uff0c\u8bbe\u7f6e\u6298\u6570\uff08folds\uff09\u4e3a 5 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u4f7f\u7528 StratifiedKFold \u5bf9\u8c61\u7684 split \u65b9\u6cd5\u6765\u83b7\u53d6\u8bad\u7ec3\u548c\u9a8c\u8bc1\u7d22\u5f15 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u5305\u542b kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . to_csv ( \"train_folds.csv\" , index = False ) \u5bf9\u4e8e\u8461\u8404\u9152\u6570\u636e\u96c6\uff0c\u6211\u4eec\u6765\u770b\u770b\u6807\u7b7e\u7684\u5206\u5e03\u60c5\u51b5\u3002 b = sns . countplot ( x = 'quality' , data = df ) b . set_xlabel ( \"quality\" , fontsize = 20 ) b . set_ylabel ( \"count\" , fontsize = 20 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u7ee7\u7eed\u4e0a\u9762\u7684\u4ee3\u7801\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u8f6c\u6362\u4e86\u76ee\u6807\u503c\u3002\u4ece\u56fe 6 \u4e2d\u6211\u4eec\u53ef\u4ee5\u770b\u51fa\uff0c\u8d28\u91cf\u504f\u5dee\u5f88\u5927\u3002\u6709\u4e9b\u7c7b\u522b\u6709\u5f88\u591a\u6837\u672c\uff0c\u6709\u4e9b\u5219\u6ca1\u6709\u90a3\u4e48\u591a\u3002\u5982\u679c\u6211\u4eec\u8fdb\u884c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u90a3\u4e48\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u76ee\u6807\u503c\u5206\u5e03\u90fd\u4e0d\u4f1a\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u56fe 6\uff1a\u8461\u8404\u9152\u6570\u636e\u96c6\u4e2d \"\u8d28\u91cf\" \u5206\u5e03\u60c5\u51b5 \u89c4\u5219\u5f88\u7b80\u5355\uff0c\u5982\u679c\u662f\u6807\u51c6\u5206\u7c7b\u95ee\u9898\uff0c\u5c31\u76f2\u76ee\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u4f46\u5982\u679c\u6570\u636e\u91cf\u5f88\u5927\uff0c\u8be5\u600e\u4e48\u529e\u5462\uff1f\u5047\u8bbe\u6211\u4eec\u6709 100 \u4e07\u4e2a\u6837\u672c\u30025 \u500d\u4ea4\u53c9\u68c0\u9a8c\u610f\u5473\u7740\u5728 800k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u5728 200k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u9a8c\u8bc1\u3002\u6839\u636e\u6211\u4eec\u9009\u62e9\u7684\u7b97\u6cd5\uff0c\u5bf9\u4e8e\u8fd9\u6837\u89c4\u6a21\u7684\u6570\u636e\u96c6\u6765\u8bf4\uff0c\u8bad\u7ec3\u751a\u81f3\u9a8c\u8bc1\u90fd\u53ef\u80fd\u975e\u5e38\u6602\u8d35\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u3002 \u521b\u5efa\u4fdd\u6301\u7ed3\u679c\u7684\u8fc7\u7a0b\u4e0e\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u76f8\u540c\u3002\u5bf9\u4e8e\u62e5\u6709 100 \u4e07\u4e2a\u6837\u672c\u7684\u6570\u636e\u96c6\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa 10 \u4e2a\u6298\u53e0\u800c\u4e0d\u662f 5 \u4e2a\uff0c\u5e76\u4fdd\u7559\u5176\u4e2d\u4e00\u4e2a\u6298\u53e0\u4f5c\u4e3a\u4fdd\u7559\u6837\u672c\u3002\u8fd9\u610f\u5473\u7740\uff0c\u6211\u4eec\u5c06\u6709 10 \u4e07\u4e2a\u6837\u672c\u88ab\u4fdd\u7559\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u59cb\u7ec8\u5728\u8fd9\u4e2a\u6837\u672c\u96c6\u4e0a\u8ba1\u7b97\u635f\u5931\u3001\u51c6\u786e\u7387\u548c\u5176\u4ed6\u6307\u6807\uff0c\u5e76\u5728 90 \u4e07\u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\u3002 \u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u6682\u7559\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u975e\u5e38\u5e38\u7528\u3002\u5047\u8bbe\u6211\u4eec\u8981\u89e3\u51b3\u7684\u95ee\u9898\u662f\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97 2020 \u5e74\u7684\u9500\u552e\u989d\uff0c\u800c\u6211\u4eec\u5f97\u5230\u7684\u662f 2015-2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f60\u53ef\u4ee5\u9009\u62e9 2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u4f5c\u4e3a\u4fdd\u7559\u6570\u636e\uff0c\u7136\u540e\u5728 2015 \u5e74\u81f3 2018 \u5e74\u7684\u6240\u6709\u6570\u636e\u4e0a\u8bad\u7ec3\u4f60\u7684\u6a21\u578b\u3002 \u56fe 7\uff1a\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u793a\u4f8b \u5728\u56fe 7 \u6240\u793a\u7684\u793a\u4f8b\u4e2d\uff0c\u5047\u8bbe\u6211\u4eec\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u4ece\u65f6\u95f4\u6b65\u9aa4 31 \u5230 40 \u7684\u9500\u552e\u989d\u3002\u6211\u4eec\u53ef\u4ee5\u4fdd\u7559 21 \u81f3 30 \u6b65\u7684\u6570\u636e\uff0c\u7136\u540e\u4ece 0 \u6b65\u5230 20 \u6b65\u8bad\u7ec3\u6a21\u578b\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5728\u9884\u6d4b 31 \u6b65\u81f3 40 \u6b65\u65f6\uff0c\u5e94\u5c06 21 \u6b65\u81f3 30 \u6b65\u7684\u6570\u636e\u7eb3\u5165\u6a21\u578b\uff0c\u5426\u5219\uff0c\u6a21\u578b\u7684\u6027\u80fd\u5c06\u5927\u6253\u6298\u6263\u3002 \u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u5fc5\u987b\u5904\u7406\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u800c\u521b\u5efa\u5927\u578b\u9a8c\u8bc1\u96c6\u610f\u5473\u7740\u6a21\u578b\u5b66\u4e60\u4f1a\u4e22\u5931\u5927\u91cf\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9\u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c\uff0c\u76f8\u5f53\u4e8e\u7279\u6b8a\u7684 k \u5219\u4ea4\u53c9\u68c0\u9a8c\u5176\u4e2d k=N \uff0cN \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8fd9\u610f\u5473\u7740\u5728\u6240\u6709\u7684\u8bad\u7ec3\u6298\u53e0\u4e2d\uff0c\u6211\u4eec\u5c06\u5bf9\u9664 1 \u4e4b\u5916\u7684\u6240\u6709\u6570\u636e\u6837\u672c\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u7684\u6298\u53e0\u6570\u4e0e\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u76f8\u540c\u3002 \u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u6a21\u578b\u7684\u901f\u5ea6\u4e0d\u591f\u5feb\uff0c\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\uff0c\u4f46\u7531\u4e8e\u8fd9\u79cd\u4ea4\u53c9\u68c0\u9a8c\u53ea\u9002\u7528\u4e8e\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u56e0\u6b64\u5e76\u4e0d\u91cd\u8981\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u95ee\u9898\u4e86\u3002\u56de\u5f52\u95ee\u9898\u7684\u597d\u5904\u5728\u4e8e\uff0c\u9664\u4e86\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e4b\u5916\uff0c\u6211\u4eec\u53ef\u4ee5\u5728\u56de\u5f52\u95ee\u9898\u4e0a\u4f7f\u7528\u4e0a\u8ff0\u6240\u6709\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u6211\u4eec\u4e0d\u80fd\u76f4\u63a5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u6709\u4e00\u4e9b\u65b9\u6cd5\u53ef\u4ee5\u7a0d\u7a0d\u6539\u53d8\u95ee\u9898\uff0c\u4ece\u800c\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u9002\u7528\u4e8e\u4efb\u4f55\u56de\u5f52\u95ee\u9898\u3002\u4f46\u662f\uff0c\u5982\u679c\u53d1\u73b0\u76ee\u6807\u5206\u5e03\u4e0d\u4e00\u81f4\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u8981\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u5fc5\u987b\u5148\u5c06\u76ee\u6807\u5212\u5206\u4e3a\u82e5\u5e72\u4e2a\u5206\u5c42\uff0c\u7136\u540e\u518d\u4ee5\u5904\u7406\u5206\u7c7b\u95ee\u9898\u7684\u76f8\u540c\u65b9\u5f0f\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u9009\u62e9\u5408\u9002\u7684\u5206\u5c42\u6570\u6709\u51e0\u79cd\u9009\u62e9\u3002\u5982\u679c\u6837\u672c\u91cf\u5f88\u5927\uff08> 10k\uff0c> 100k\uff09\uff0c\u90a3\u4e48\u5c31\u4e0d\u9700\u8981\u8003\u8651\u5206\u5c42\u7684\u6570\u91cf\u3002\u53ea\u9700\u5c06\u6570\u636e\u5206\u4e3a 10 \u6216 20 \u5c42\u5373\u53ef\u3002\u5982\u679c\u6837\u672c\u6570\u4e0d\u591a\uff0c\u5219\u53ef\u4ee5\u4f7f\u7528 Sturge's Rule \u8fd9\u6837\u7684\u7b80\u5355\u89c4\u5219\u6765\u8ba1\u7b97\u9002\u5f53\u7684\u5206\u5c42\u6570\u3002 Sturge's Rule\uff1a \\[ Number of Bins = 1 + log_2(N) \\] \u5176\u4e2d \\(N\\) \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8be5\u51fd\u6570\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u5229\u7528\u65af\u7279\u683c\u6cd5\u5219\u7ed8\u5236\u6837\u672c\u4e0e\u7bb1\u6570\u5bf9\u6bd4\u56fe \u8ba9\u6211\u4eec\u5236\u4f5c\u4e00\u4e2a\u56de\u5f52\u6570\u636e\u96c6\u6837\u672c\uff0c\u5e76\u5c1d\u8bd5\u5e94\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 # stratified-kfold for regression # \u4e3a\u56de\u5f52\u95ee\u9898\u8fdb\u884c\u5206\u5c42K-\u6298\u4ea4\u53c9\u9a8c\u8bc1 # \u5bfc\u5165\u9700\u8981\u7684\u5e93 import numpy as np import pandas as pd from sklearn import datasets from sklearn import model_selection # \u521b\u5efa\u5206\u6298\uff08folds\uff09\u7684\u51fd\u6570 def create_folds ( data ): # \u521b\u5efa\u4e00\u4e2a\u65b0\u5217\u53eb\u505akfold\uff0c\u5e76\u7528-1\u6765\u586b\u5145 data [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c data = data . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4f7f\u7528Sturge\u89c4\u5219\u8ba1\u7b97bin\u7684\u6570\u91cf num_bins = int ( np . floor ( 1 + np . log2 ( len ( data )))) # \u4f7f\u7528pandas\u7684cut\u51fd\u6570\u8fdb\u884c\u76ee\u6807\u53d8\u91cf\uff08target\uff09\u7684\u5206\u7bb1 data . loc [:, \"bins\" ] = pd . cut ( data [ \"target\" ], bins = num_bins , labels = False ) # \u521d\u59cb\u5316StratifiedKFold\u7c7b kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684kfold\u5217 # \u6ce8\u610f\uff1a\u6211\u4eec\u4f7f\u7528\u7684\u662fbins\u800c\u4e0d\u662f\u5b9e\u9645\u7684\u76ee\u6807\u53d8\u91cf\uff08target\uff09\uff01 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = data , y = data . bins . values )): data . loc [ v_ , 'kfold' ] = f # \u5220\u9664bins\u5217 data = data . drop ( \"bins\" , axis = 1 ) # \u8fd4\u56de\u5305\u542bfolds\u7684\u6570\u636e return data # \u4e3b\u7a0b\u5e8f\u5f00\u59cb if __name__ == \"__main__\" : # \u521b\u5efa\u4e00\u4e2a\u5e26\u670915000\u4e2a\u6837\u672c\u3001100\u4e2a\u7279\u5f81\u548c1\u4e2a\u76ee\u6807\u53d8\u91cf\u7684\u6837\u672c\u6570\u636e\u96c6 X , y = datasets . make_regression ( n_samples = 15000 , n_features = 100 , n_targets = 1 ) # \u4f7f\u7528numpy\u6570\u7ec4\u521b\u5efa\u4e00\u4e2a\u6570\u636e\u6846 df = pd . DataFrame ( X , columns = [ f \"f_ { i } \" for i in range ( X . shape [ 1 ])] ) df . loc [:, \"target\" ] = y # \u521b\u5efafolds df = create_folds ( df ) \u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u7b2c\u4e00\u6b65\uff0c\u4e5f\u662f\u6700\u57fa\u672c\u7684\u4e00\u6b65\u3002\u5982\u679c\u8981\u505a\u7279\u5f81\u5de5\u7a0b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u8981\u5efa\u7acb\u6a21\u578b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u65b9\u6848\uff0c\u5176\u4e2d\u9a8c\u8bc1\u6570\u636e\u80fd\u591f\u4ee3\u8868\u8bad\u7ec3\u6570\u636e\u548c\u771f\u5b9e\u4e16\u754c\u7684\u6570\u636e\uff0c\u90a3\u4e48\u4f60\u5c31\u80fd\u5efa\u7acb\u4e00\u4e2a\u5177\u6709\u9ad8\u5ea6\u901a\u7528\u6027\u7684\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u3002 \u672c\u7ae0\u4ecb\u7ecd\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u6570\u636e\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u6839\u636e\u4f60\u7684\u95ee\u9898\u548c\u6570\u636e\u91c7\u7528\u65b0\u7684\u4ea4\u53c9\u68c0\u9a8c\u5f62\u5f0f\u3002 \u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u95ee\u9898\uff0c\u5e0c\u671b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u4ece\u60a3\u8005\u7684\u76ae\u80a4\u56fe\u50cf\u4e2d\u68c0\u6d4b\u51fa\u76ae\u80a4\u764c\u3002\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u5668\uff0c\u8be5\u5206\u7c7b\u5668\u63a5\u6536\u8f93\u5165\u56fe\u50cf\u5e76\u9884\u6d4b\u5176\u826f\u6027\u6216\u6076\u6027\u7684\u6982\u7387\u3002 \u5728\u8fd9\u7c7b\u6570\u636e\u96c6\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u96c6\u4e2d\u53ef\u80fd\u6709\u540c\u4e00\u60a3\u8005\u7684\u591a\u5f20\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u8981\u5728\u8fd9\u91cc\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u7cfb\u7edf\uff0c\u5fc5\u987b\u6709\u5206\u5c42\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u4e5f\u5fc5\u987b\u786e\u4fdd\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u60a3\u8005\u4e0d\u4f1a\u51fa\u73b0\u5728\u9a8c\u8bc1\u6570\u636e\u4e2d\u3002\u5e78\u8fd0\u7684\u662f\uff0cscikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u79f0\u4e3a GroupKFold \u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u3002 \u5728\u8fd9\u91cc\uff0c\u60a3\u8005\u53ef\u4ee5\u88ab\u89c6\u4e3a\u7ec4\u3002 \u4f46\u9057\u61be\u7684\u662f\uff0cscikit-learn \u65e0\u6cd5\u5c06 GroupKFold \u4e0e StratifiedKFold \u7ed3\u5408\u8d77\u6765\u3002\u6240\u4ee5\u4f60\u9700\u8981\u81ea\u5df1\u52a8\u624b\u3002\u6211\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u8bfb\u8005\u7684\u7ec3\u4e60\u3002","title":"\u4ea4\u53c9\u68c0\u9a8c"},{"location":"%E4%BA%A4%E5%8F%89%E6%A3%80%E9%AA%8C/#_1","text":"\u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u6ca1\u6709\u5efa\u7acb\u4efb\u4f55\u6a21\u578b\u3002\u539f\u56e0\u5f88\u7b80\u5355\uff0c\u5728\u521b\u5efa\u4efb\u4f55\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u77e5\u9053\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4ee5\u53ca\u5982\u4f55\u6839\u636e\u6570\u636e\u96c6\u9009\u62e9\u6700\u4f73\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e\u96c6\u3002 \u90a3\u4e48\uff0c\u4ec0\u4e48\u662f \u4ea4\u53c9\u68c0\u9a8c \uff0c\u6211\u4eec\u4e3a\u4ec0\u4e48\u8981\u5173\u6ce8\u5b83\uff1f \u5173\u4e8e\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u53ef\u4ee5\u627e\u5230\u591a\u79cd\u5b9a\u4e49\u3002\u6211\u7684\u5b9a\u4e49\u53ea\u6709\u4e00\u53e5\u8bdd\uff1a\u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fc7\u7a0b\u4e2d\u7684\u4e00\u4e2a\u6b65\u9aa4\uff0c\u5b83\u53ef\u4ee5\u5e2e\u52a9\u6211\u4eec\u786e\u4fdd\u6a21\u578b\u51c6\u786e\u62df\u5408\u6570\u636e\uff0c\u540c\u65f6\u786e\u4fdd\u6211\u4eec\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u4f46\u8fd9\u53c8\u5f15\u51fa\u4e86\u53e6\u4e00\u4e2a\u8bcd\uff1a \u8fc7\u62df\u5408 \u3002 \u8981\u89e3\u91ca\u8fc7\u62df\u5408\uff0c\u6211\u8ba4\u4e3a\u6700\u597d\u5148\u770b\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6709\u4e00\u4e2a\u76f8\u5f53\u6709\u540d\u7684\u7ea2\u9152\u8d28\u91cf\u6570\u636e\u96c6\uff08 red wine quality dataset \uff09\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 11 \u4e2a\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u8fd9\u4e9b\u7279\u5f81\u51b3\u5b9a\u4e86\u7ea2\u9152\u7684\u8d28\u91cf\u3002 \u8fd9\u4e9b\u5c5e\u6027\u5305\u62ec\uff1a \u56fa\u5b9a\u9178\u5ea6\uff08fixed acidity\uff09 \u6325\u53d1\u6027\u9178\u5ea6\uff08volatile acidity\uff09 \u67e0\u6aac\u9178\uff08citric acid\uff09 \u6b8b\u7559\u7cd6\uff08residual sugar\uff09 \u6c2f\u5316\u7269\uff08chlorides\uff09 \u6e38\u79bb\u4e8c\u6c27\u5316\u786b\uff08free sulfur dioxide\uff09 \u4e8c\u6c27\u5316\u786b\u603b\u91cf\uff08total sulfur dioxide\uff09 \u5bc6\u5ea6\uff08density\uff09 PH \u503c\uff08pH\uff09 \u786b\u9178\u76d0\uff08sulphates\uff09 \u9152\u7cbe\uff08alcohol\uff09 \u6839\u636e\u8fd9\u4e9b\u4e0d\u540c\u7279\u5f81\uff0c\u6211\u4eec\u9700\u8981\u9884\u6d4b\u7ea2\u8461\u8404\u9152\u7684\u8d28\u91cf\uff0c\u8d28\u91cf\u503c\u4ecb\u4e8e 0 \u5230 10 \u4e4b\u95f4\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u6570\u636e\u662f\u600e\u6837\u7684\u3002 import pandas as pd df = pd . read_csv ( \"winequality-red.csv\" ) \u56fe 1:\u7ea2\u8461\u8404\u9152\u8d28\u91cf\u6570\u636e\u96c6\u7b80\u5355\u5c55\u793a \u6211\u4eec\u53ef\u4ee5\u5c06\u8fd9\u4e2a\u95ee\u9898\u89c6\u4e3a\u5206\u7c7b\u95ee\u9898\uff0c\u4e5f\u53ef\u4ee5\u89c6\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u9009\u62e9\u5206\u7c7b\u3002\u7136\u800c\uff0c\u8fd9\u4e2a\u6570\u636e\u96c6\u503c\u5305\u542b 6 \u79cd\u8d28\u91cf\u503c\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u6240\u6709\u8d28\u91cf\u503c\u6620\u5c04\u5230 0 \u5230 5 \u4e4b\u95f4\u3002 # \u4e00\u4e2a\u6620\u5c04\u5b57\u5178\uff0c\u7528\u4e8e\u5c06\u8d28\u91cf\u503c\u4ece 0 \u5230 5 \u8fdb\u884c\u6620\u5c04 quality_mapping = { 3 : 0 , 4 : 1 , 5 : 2 , 6 : 3 , 7 : 4 , 8 : 5 } # \u4f60\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 map \u51fd\u6570\u4ee5\u53ca\u4efb\u4f55\u5b57\u5178\uff0c # \u6765\u8f6c\u6362\u7ed9\u5b9a\u5217\u4e2d\u7684\u503c\u4e3a\u5b57\u5178\u4e2d\u7684\u503c df . loc [:, \"quality\" ] = df . quality . map ( quality_mapping ) \u5f53\u6211\u4eec\u770b\u5927\u8fd9\u4e9b\u6570\u636e\u5e76\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u65f6\uff0c\u6211\u4eec\u8111\u6d77\u4e2d\u4f1a\u6d6e\u73b0\u51fa\u5f88\u591a\u53ef\u4ee5\u5e94\u7528\u7684\u7b97\u6cd5\uff0c\u4e5f\u8bb8\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u4ece\u4e00\u5f00\u59cb\u5c31\u6df1\u5165\u7814\u7a76\u795e\u7ecf\u7f51\u7edc\uff0c\u90a3\u5c31\u6709\u70b9\u7275\u5f3a\u4e86\u3002\u6240\u4ee5\uff0c\u8ba9\u6211\u4eec\u4ece\u7b80\u5355\u7684\u3001\u6211\u4eec\u4e5f\u80fd\u53ef\u89c6\u5316\u7684\u4e1c\u897f\u5f00\u59cb\uff1a\u51b3\u7b56\u6811\u3002 \u5728\u5f00\u59cb\u4e86\u89e3\u4ec0\u4e48\u662f\u8fc7\u62df\u5408\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 1599 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u4fdd\u7559 1000 \u4e2a\u6837\u672c\u7528\u4e8e\u8bad\u7ec3\uff0c599 \u4e2a\u6837\u672c\u4f5c\u4e3a\u4e00\u4e2a\u5355\u72ec\u7684\u96c6\u5408\u3002 \u4ee5\u4e0b\u4ee3\u7801\u53ef\u4ee5\u8f7b\u677e\u5b8c\u6210\u5212\u5206\uff1a # \u4f7f\u7528 frac=1 \u7684 sample \u65b9\u6cd5\u6765\u6253\u4e71 dataframe # \u7531\u4e8e\u6253\u4e71\u540e\u7d22\u5f15\u4f1a\u6539\u53d8\uff0c\u6240\u4ee5\u6211\u4eec\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u9009\u53d6\u524d 1000 \u884c\u4f5c\u4e3a\u8bad\u7ec3\u6570\u636e df_train = df . head ( 1000 ) # \u9009\u53d6\u6700\u540e\u7684 599 \u884c\u4f5c\u4e3a\u6d4b\u8bd5/\u9a8c\u8bc1\u6570\u636e df_test = df . tail ( 599 ) \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u5728\u8bad\u7ec3\u96c6\u4e0a\u4f7f\u7528 scikit-learn \u8bad\u7ec3\u4e00\u4e2a\u51b3\u7b56\u6811\u6a21\u578b\u3002 # \u4ece scikit-learn \u5bfc\u5165\u9700\u8981\u7684\u6a21\u5757 from sklearn import tree from sklearn import metrics # \u521d\u59cb\u5316\u4e00\u4e2a\u51b3\u7b56\u6811\u5206\u7c7b\u5668\uff0c\u8bbe\u7f6e\u6700\u5927\u6df1\u5ea6\u4e3a 3 clf = tree . DecisionTreeClassifier ( max_depth = 3 ) # \u9009\u62e9\u4f60\u60f3\u8981\u8bad\u7ec3\u6a21\u578b\u7684\u5217 # \u8fd9\u4e9b\u5217\u4f5c\u4e3a\u6a21\u578b\u7684\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u4f7f\u7528\u4e4b\u524d\u6620\u5c04\u7684\u8d28\u91cf\u4ee5\u53ca\u63d0\u4f9b\u7684\u7279\u5f81\u6765\u8bad\u7ec3\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5c06\u51b3\u7b56\u6811\u5206\u7c7b\u5668\u7684\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u8bbe\u4e3a 3\u3002\u8be5\u6a21\u578b\u7684\u6240\u6709\u5176\u4ed6\u53c2\u6570\u5747\u4fdd\u6301\u9ed8\u8ba4\u503c\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u6d4b\u8bd5\u8be5\u6a21\u578b\u7684\u51c6\u786e\u6027\uff1a # \u5728\u8bad\u7ec3\u96c6\u4e0a\u751f\u6210\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) # \u5728\u6d4b\u8bd5\u96c6\u4e0a\u751f\u6210\u9884\u6d4b test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) # \u8ba1\u7b97\u6d4b\u8bd5\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 test_accuracy = metrics . accuracy_score ( df_test . quality , test_predictions ) \u8bad\u7ec3\u548c\u6d4b\u8bd5\u7684\u51c6\u786e\u7387\u5206\u522b\u4e3a 58.9%\u548c 54.25%\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5c06\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u589e\u52a0\u5230 7\uff0c\u5e76\u91cd\u590d\u4e0a\u8ff0\u8fc7\u7a0b\u3002\u8fd9\u6837\uff0c\u8bad\u7ec3\u51c6\u786e\u7387\u4e3a 76.6%\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4e3a 57.3%\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4f7f\u7528\u51c6\u786e\u7387\uff0c\u4e3b\u8981\u662f\u56e0\u4e3a\u5b83\u662f\u6700\u76f4\u63a5\u7684\u6307\u6807\u3002\u5bf9\u4e8e\u8fd9\u4e2a\u95ee\u9898\u6765\u8bf4\uff0c\u5b83\u53ef\u80fd\u4e0d\u662f\u6700\u597d\u7684\u6307\u6807\u3002\u6211\u4eec\u53ef\u4ee5\u6839\u636e\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u4e0d\u540c\u503c\u6765\u8ba1\u7b97\u8fd9\u4e9b\u51c6\u786e\u7387\uff0c\u5e76\u7ed8\u5236\u66f2\u7ebf\u56fe\u3002 # \u6ce8\u610f\uff1a\u8fd9\u6bb5\u4ee3\u7801\u5728 Jupyter \u7b14\u8bb0\u672c\u4e2d\u7f16\u5199 # \u5bfc\u5165 scikit-learn \u7684 tree \u548c metrics from sklearn import tree from sklearn import metrics # \u5bfc\u5165 matplotlib \u548c seaborn # \u7528\u4e8e\u7ed8\u56fe import matplotlib import matplotlib.pyplot as plt import seaborn as sns # \u8bbe\u7f6e\u5168\u5c40\u6807\u7b7e\u6587\u672c\u7684\u5927\u5c0f matplotlib . rc ( 'xtick' , labelsize = 20 ) matplotlib . rc ( 'ytick' , labelsize = 20 ) # \u786e\u4fdd\u56fe\u8868\u76f4\u63a5\u5728\u7b14\u8bb0\u672c\u5185\u663e\u793a % matplotlib inline # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6\u7684\u5217\u8868 # \u6211\u4eec\u4ece 50% \u7684\u51c6\u786e\u5ea6\u5f00\u59cb train_accuracies = [ 0.5 ] test_accuracies = [ 0.5 ] # \u904d\u5386\u51e0\u4e2a\u4e0d\u540c\u7684\u6811\u6df1\u5ea6\u503c for depth in range ( 1 , 25 ): # \u521d\u59cb\u5316\u6a21\u578b clf = tree . DecisionTreeClassifier ( max_depth = depth ) # \u9009\u62e9\u7528\u4e8e\u8bad\u7ec3\u7684\u5217/\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u5728\u7ed9\u5b9a\u7279\u5f81\u4e0a\u62df\u5408\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) # \u521b\u5efa\u8bad\u7ec3\u548c\u6d4b\u8bd5\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) test_accuracy = metrics . accuracy_score ( df_test . quality , test_predictions ) # \u6dfb\u52a0\u51c6\u786e\u5ea6\u5230\u5217\u8868 train_accuracies . append ( train_accuracy ) test_accuracies . append ( test_accuracy ) # \u4f7f\u7528 matplotlib \u548c seaborn \u521b\u5efa\u4e24\u4e2a\u56fe plt . figure ( figsize = ( 10 , 5 )) sns . set_style ( \"whitegrid\" ) plt . plot ( train_accuracies , label = \"train accuracy\" ) plt . plot ( test_accuracies , label = \"test accuracy\" ) plt . legend ( loc = \"upper left\" , prop = { 'size' : 15 }) plt . xticks ( range ( 0 , 26 , 5 )) plt . xlabel ( \"max_depth\" , size = 20 ) plt . ylabel ( \"accuracy\" , size = 20 ) plt . show () \u8fd9\u5c06\u751f\u6210\u5982\u56fe 2 \u6240\u793a\u7684\u66f2\u7ebf\u56fe\u3002 \u56fe 2\uff1a\u4e0d\u540c max_depth \u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u7387\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5f53\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u503c\u4e3a 14 \u65f6\uff0c\u6d4b\u8bd5\u6570\u636e\u7684\u5f97\u5206\u6700\u9ad8\u3002\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u589e\u52a0\u8fd9\u4e2a\u53c2\u6570\u7684\u503c\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4f1a\u4fdd\u6301\u4e0d\u53d8\u6216\u53d8\u5dee\uff0c\u4f46\u8bad\u7ec3\u51c6\u786e\u7387\u4f1a\u4e0d\u65ad\u63d0\u9ad8\u3002\u8fd9\u8bf4\u660e\uff0c\u968f\u7740\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u589e\u52a0\uff0c\u51b3\u7b56\u6811\u6a21\u578b\u5bf9\u8bad\u7ec3\u6570\u636e\u7684\u5b66\u4e60\u6548\u679c\u8d8a\u6765\u8d8a\u597d\uff0c\u4f46\u6d4b\u8bd5\u6570\u636e\u7684\u6027\u80fd\u5374\u4e1d\u6beb\u6ca1\u6709\u63d0\u9ad8\u3002 \u8fd9\u5c31\u662f\u6240\u8c13\u7684\u8fc7\u62df\u5408 \u3002 \u6a21\u578b\u5728\u8bad\u7ec3\u96c6\u4e0a\u5b8c\u5168\u62df\u5408\uff0c\u800c\u5728\u6d4b\u8bd5\u96c6\u4e0a\u5374\u8868\u73b0\u4e0d\u4f73\u3002\u8fd9\u610f\u5473\u7740\u6a21\u578b\u53ef\u4ee5\u5f88\u597d\u5730\u5b66\u4e60\u8bad\u7ec3\u6570\u636e\uff0c\u4f46\u65e0\u6cd5\u6cdb\u5316\u5230\u672a\u89c1\u8fc7\u7684\u6837\u672c\u4e0a\u3002\u5728\u4e0a\u9762\u7684\u6570\u636e\u96c6\u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u975e\u5e38\u9ad8\u7684\u6a21\u578b\uff0c\u5b83\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u4f1a\u6709\u51fa\u8272\u7684\u7ed3\u679c\uff0c\u4f46\u8fd9\u79cd\u6a21\u578b\u5e76\u4e0d\u5b9e\u7528\uff0c\u56e0\u4e3a\u5b83\u5728\u771f\u5b9e\u4e16\u754c\u7684\u6837\u672c\u6216\u5b9e\u65f6\u6570\u636e\u4e0a\u4e0d\u4f1a\u63d0\u4f9b\u7c7b\u4f3c\u7684\u7ed3\u679c\u3002 \u6709\u4eba\u53ef\u80fd\u4f1a\u8bf4\uff0c\u8fd9\u79cd\u65b9\u6cd5\u5e76\u6ca1\u6709\u8fc7\u62df\u5408\uff0c\u56e0\u4e3a\u6d4b\u8bd5\u96c6\u7684\u51c6\u786e\u7387\u57fa\u672c\u4fdd\u6301\u4e0d\u53d8\u3002\u8fc7\u62df\u5408\u7684\u53e6\u4e00\u4e2a\u5b9a\u4e49\u662f\uff0c\u5f53\u6211\u4eec\u4e0d\u65ad\u63d0\u9ad8\u8bad\u7ec3\u635f\u5931\u65f6\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u5728\u589e\u52a0\u3002\u8fd9\u79cd\u60c5\u51b5\u5728\u795e\u7ecf\u7f51\u7edc\u4e2d\u975e\u5e38\u5e38\u89c1\u3002 \u6bcf\u5f53\u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u795e\u7ecf\u7f51\u7edc\u65f6\uff0c\u90fd\u5fc5\u987b\u5728\u8bad\u7ec3\u671f\u95f4\u76d1\u63a7\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u3002\u5982\u679c\u6211\u4eec\u6709\u4e00\u4e2a\u975e\u5e38\u5927\u7684\u7f51\u7edc\u6765\u5904\u7406\u4e00\u4e2a\u975e\u5e38\u5c0f\u7684\u6570\u636e\u96c6\uff08\u5373\u6837\u672c\u6570\u975e\u5e38\u5c11\uff09\uff0c\u6211\u4eec\u5c31\u4f1a\u89c2\u5bdf\u5230\uff0c\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u8bad\u7ec3\uff0c\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u90fd\u4f1a\u51cf\u5c11\u3002\u4f46\u662f\uff0c\u5728\u67d0\u4e2a\u65f6\u523b\uff0c\u6d4b\u8bd5\u635f\u5931\u4f1a\u8fbe\u5230\u6700\u5c0f\u503c\uff0c\u4e4b\u540e\uff0c\u5373\u4f7f\u8bad\u7ec3\u635f\u5931\u8fdb\u4e00\u6b65\u51cf\u5c11\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u4f1a\u5f00\u59cb\u589e\u52a0\u3002\u6211\u4eec\u5fc5\u987b\u5728\u9a8c\u8bc1\u635f\u5931\u8fbe\u5230\u6700\u5c0f\u503c\u65f6\u505c\u6b62\u8bad\u7ec3\u3002 \u8fd9\u662f\u5bf9\u8fc7\u62df\u5408\u6700\u5e38\u89c1\u7684\u89e3\u91ca \u3002 \u5965\u5361\u59c6\u5243\u5200\u7528\u7b80\u5355\u7684\u8bdd\u8bf4\uff0c\u5c31\u662f\u4e0d\u8981\u8bd5\u56fe\u628a\u53ef\u4ee5\u7528\u7b80\u5355\u5f97\u591a\u7684\u65b9\u6cd5\u89e3\u51b3\u7684\u4e8b\u60c5\u590d\u6742\u5316\u3002\u6362\u53e5\u8bdd\u8bf4\uff0c\u6700\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6848\u5c31\u662f\u6700\u5177\u901a\u7528\u6027\u7684\u89e3\u51b3\u65b9\u6848\u3002\u4e00\u822c\u6765\u8bf4\uff0c\u53ea\u8981\u4f60\u7684\u6a21\u578b\u4e0d\u7b26\u5408\u5965\u5361\u59c6\u5243\u5200\u539f\u5219\uff0c\u5c31\u5f88\u53ef\u80fd\u662f\u8fc7\u62df\u5408\u3002 \u56fe 3\uff1a\u8fc7\u62df\u5408\u7684\u6700\u4e00\u822c\u5b9a\u4e49 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u56de\u5230\u4ea4\u53c9\u68c0\u9a8c\u3002 \u5728\u89e3\u91ca\u8fc7\u62df\u5408\u65f6\uff0c\u6211\u51b3\u5b9a\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u6211\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u53e6\u4e00\u90e8\u5206\u4e0a\u68c0\u67e5\u5176\u6027\u80fd\u3002\u8fd9\u4e5f\u662f\u4ea4\u53c9\u68c0\u9a8c\u7684\u4e00\u79cd\uff0c\u901a\u5e38\u88ab\u79f0\u4e3a \"\u6682\u7559\u96c6\"\uff08 hold-out set \uff09\u3002\u5f53\u6211\u4eec\u62e5\u6709\u5927\u91cf\u6570\u636e\uff0c\u800c\u6a21\u578b\u63a8\u7406\u662f\u4e00\u4e2a\u8017\u65f6\u7684\u8fc7\u7a0b\u65f6\uff0c\u6211\u4eec\u5c31\u4f1a\u4f7f\u7528\u8fd9\u79cd\uff08\u4ea4\u53c9\uff09\u9a8c\u8bc1\u3002 \u4ea4\u53c9\u68c0\u9a8c\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\uff0c\u5b83\u662f\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u6b65\u9aa4\u3002 \u9009\u62e9\u6b63\u786e\u7684\u4ea4\u53c9\u68c0\u9a8c \u53d6\u51b3\u4e8e\u6240\u5904\u7406\u7684\u6570\u636e\u96c6\uff0c\u5728\u4e00\u4e2a\u6570\u636e\u96c6\u4e0a\u9002\u7528\u7684\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u53ef\u80fd\u4e0d\u9002\u7528\u4e8e\u5176\u4ed6\u6570\u636e\u96c6\u3002\u4e0d\u8fc7\uff0c\u6709\u51e0\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u6700\u4e3a\u6d41\u884c\u548c\u5e7f\u6cdb\u4f7f\u7528\u3002 \u5176\u4e2d\u5305\u62ec\uff1a k \u6298\u4ea4\u53c9\u68c0\u9a8c \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c \u5206\u7ec4 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u4ea4\u53c9\u68c0\u9a8c\u662f\u5c06\u8bad\u7ec3\u6570\u636e\u5206\u5c42\u51e0\u4e2a\u90e8\u5206\uff0c\u6211\u4eec\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u5176\u4f59\u90e8\u5206\u4e0a\u8fdb\u884c\u6d4b\u8bd5\u3002\u8bf7\u770b\u56fe 4\u3002 \u56fe 4\uff1a\u5c06\u6570\u636e\u96c6\u62c6\u5206\u4e3a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u56fe 4 \u548c\u56fe 5 \u8bf4\u660e\uff0c\u5f53\u4f60\u5f97\u5230\u4e00\u4e2a\u6570\u636e\u96c6\u6765\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u65f6\uff0c\u4f60\u4f1a\u628a\u5b83\u4eec\u5206\u6210 \u4e24\u4e2a\u4e0d\u540c\u7684\u96c6\uff1a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u3002\u5f88\u591a\u4eba\u8fd8\u4f1a\u5c06\u5176\u5206\u6210\u7b2c\u4e09\u7ec4\uff0c\u79f0\u4e4b\u4e3a\u6d4b\u8bd5\u96c6\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u5c06\u53ea\u4f7f\u7528\u4e24\u4e2a\u96c6\u3002\u5982\u4f60\u6240\u89c1\uff0c\u6211\u4eec\u5c06\u6837\u672c\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\u8fdb\u884c\u4e86\u5212\u5206\u3002\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a k \u4e2a\u4e92\u4e0d\u5173\u8054\u7684\u4e0d\u540c\u96c6\u5408\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002 \u56fe 5\uff1aK \u6298\u4ea4\u53c9\u68c0\u9a8c \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 KFold \u5c06\u4efb\u4f55\u6570\u636e\u5206\u5272\u6210 k \u4e2a\u76f8\u7b49\u7684\u90e8\u5206\u3002\u6bcf\u4e2a\u6837\u672c\u5206\u914d\u4e00\u4e2a\u4ece 0 \u5230 k-1 \u7684\u503c\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u5b58\u50a8\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6211\u4eec\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a kfold \u7684\u65b0\u5217\uff0c\u5e76\u7528 -1 \u586b\u5145 df [ \"kfold\" ] = - 1 # \u63a5\u4e0b\u6765\u7684\u6b65\u9aa4\u662f\u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4ece model_selection \u6a21\u5757\u521d\u59cb\u5316 kfold \u7c7b kf = model_selection . KFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684 kfold \u5217\uff08enumerate\u7684\u4f5c\u7528\u662f\u8fd4\u56de\u4e00\u4e2a\u8fed\u4ee3\u5668\uff09 for fold , ( trn_ , val_ ) in enumerate ( kf . split ( X = df )): df . loc [ val_ , 'kfold' ] = fold # \u4fdd\u5b58\u5e26\u6709 kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . to_csv ( \"train_folds.csv\" , index = False ) \u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u636e\u96c6\u90fd\u53ef\u4ee5\u4f7f\u7528\u6b64\u6d41\u7a0b\u3002\u4f8b\u5982\uff0c\u5f53\u6570\u636e\u56fe\u50cf\u65f6\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u56fe\u50cf ID\u3001\u56fe\u50cf\u4f4d\u7f6e\u548c\u56fe\u50cf\u6807\u7b7e\u7684 CSV\uff0c\u7136\u540e\u4f7f\u7528\u4e0a\u8ff0\u6d41\u7a0b\u3002 \u53e6\u4e00\u79cd\u91cd\u8981\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u662f \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u6570\u636e\u96c6\uff0c\u5176\u4e2d\u6b63\u6837\u672c\u5360 90%\uff0c\u8d1f\u6837\u672c\u53ea\u5360 10%\uff0c\u90a3\u4e48\u4f60\u5c31\u4e0d\u5e94\u8be5\u4f7f\u7528\u968f\u673a k \u6298\u4ea4\u53c9\u3002\u5bf9\u8fd9\u6837\u7684\u6570\u636e\u96c6\u4f7f\u7528\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u5bfc\u81f4\u6298\u53e0\u6837\u672c\u5168\u90e8\u4e3a\u8d1f\u6837\u672c\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u66f4\u503e\u5411\u4e8e\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u4ee5\u4fdd\u6301\u6bcf\u4e2a\u6298\u4e2d\u6807\u7b7e\u7684\u6bd4\u4f8b\u4e0d\u53d8\u3002\u56e0\u6b64\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u90fd\u4f1a\u6709\u76f8\u540c\u7684 90% \u6b63\u6837\u672c\u548c 10% \u8d1f\u6837\u672c\u3002\u56e0\u6b64\uff0c\u65e0\u8bba\u60a8\u9009\u62e9\u4ec0\u4e48\u6307\u6807\u8fdb\u884c\u8bc4\u4f30\uff0c\u90fd\u4f1a\u5728\u6240\u6709\u6298\u53e0\u4e2d\u5f97\u5230\u76f8\u4f3c\u7684\u7ed3\u679c\u3002 \u4fee\u6539\u521b\u5efa k \u6298\u4ea4\u53c9\u68c0\u9a8c\u7684\u4ee3\u7801\u4ee5\u521b\u5efa\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5f88\u5bb9\u6613\u3002\u6211\u4eec\u53ea\u9700\u5c06 model_selection.KFold \u66f4\u6539\u4e3a model_selection.StratifiedKFold \uff0c\u5e76\u5728 kf.split(...) \u51fd\u6570\u4e2d\u6307\u5b9a\u8981\u5206\u5c42\u7684\u76ee\u6807\u5217\u3002\u6211\u4eec\u5047\u8bbe CSV \u6570\u636e\u96c6\u6709\u4e00\u5217\u540d\u4e3a \"target\" \uff0c\u5e76\u4e14\u662f\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u4fdd\u5b58\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6dfb\u52a0\u4e00\u4e2a\u65b0\u5217 kfold\uff0c\u5e76\u7528 -1 \u521d\u59cb\u5316 df [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u83b7\u53d6\u76ee\u6807\u53d8\u91cf y = df . target . values # \u521d\u59cb\u5316 StratifiedKFold \u7c7b\uff0c\u8bbe\u7f6e\u6298\u6570\uff08folds\uff09\u4e3a 5 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u4f7f\u7528 StratifiedKFold \u5bf9\u8c61\u7684 split \u65b9\u6cd5\u6765\u83b7\u53d6\u8bad\u7ec3\u548c\u9a8c\u8bc1\u7d22\u5f15 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u5305\u542b kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . to_csv ( \"train_folds.csv\" , index = False ) \u5bf9\u4e8e\u8461\u8404\u9152\u6570\u636e\u96c6\uff0c\u6211\u4eec\u6765\u770b\u770b\u6807\u7b7e\u7684\u5206\u5e03\u60c5\u51b5\u3002 b = sns . countplot ( x = 'quality' , data = df ) b . set_xlabel ( \"quality\" , fontsize = 20 ) b . set_ylabel ( \"count\" , fontsize = 20 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u7ee7\u7eed\u4e0a\u9762\u7684\u4ee3\u7801\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u8f6c\u6362\u4e86\u76ee\u6807\u503c\u3002\u4ece\u56fe 6 \u4e2d\u6211\u4eec\u53ef\u4ee5\u770b\u51fa\uff0c\u8d28\u91cf\u504f\u5dee\u5f88\u5927\u3002\u6709\u4e9b\u7c7b\u522b\u6709\u5f88\u591a\u6837\u672c\uff0c\u6709\u4e9b\u5219\u6ca1\u6709\u90a3\u4e48\u591a\u3002\u5982\u679c\u6211\u4eec\u8fdb\u884c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u90a3\u4e48\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u76ee\u6807\u503c\u5206\u5e03\u90fd\u4e0d\u4f1a\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u56fe 6\uff1a\u8461\u8404\u9152\u6570\u636e\u96c6\u4e2d \"\u8d28\u91cf\" \u5206\u5e03\u60c5\u51b5 \u89c4\u5219\u5f88\u7b80\u5355\uff0c\u5982\u679c\u662f\u6807\u51c6\u5206\u7c7b\u95ee\u9898\uff0c\u5c31\u76f2\u76ee\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u4f46\u5982\u679c\u6570\u636e\u91cf\u5f88\u5927\uff0c\u8be5\u600e\u4e48\u529e\u5462\uff1f\u5047\u8bbe\u6211\u4eec\u6709 100 \u4e07\u4e2a\u6837\u672c\u30025 \u500d\u4ea4\u53c9\u68c0\u9a8c\u610f\u5473\u7740\u5728 800k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u5728 200k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u9a8c\u8bc1\u3002\u6839\u636e\u6211\u4eec\u9009\u62e9\u7684\u7b97\u6cd5\uff0c\u5bf9\u4e8e\u8fd9\u6837\u89c4\u6a21\u7684\u6570\u636e\u96c6\u6765\u8bf4\uff0c\u8bad\u7ec3\u751a\u81f3\u9a8c\u8bc1\u90fd\u53ef\u80fd\u975e\u5e38\u6602\u8d35\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u3002 \u521b\u5efa\u4fdd\u6301\u7ed3\u679c\u7684\u8fc7\u7a0b\u4e0e\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u76f8\u540c\u3002\u5bf9\u4e8e\u62e5\u6709 100 \u4e07\u4e2a\u6837\u672c\u7684\u6570\u636e\u96c6\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa 10 \u4e2a\u6298\u53e0\u800c\u4e0d\u662f 5 \u4e2a\uff0c\u5e76\u4fdd\u7559\u5176\u4e2d\u4e00\u4e2a\u6298\u53e0\u4f5c\u4e3a\u4fdd\u7559\u6837\u672c\u3002\u8fd9\u610f\u5473\u7740\uff0c\u6211\u4eec\u5c06\u6709 10 \u4e07\u4e2a\u6837\u672c\u88ab\u4fdd\u7559\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u59cb\u7ec8\u5728\u8fd9\u4e2a\u6837\u672c\u96c6\u4e0a\u8ba1\u7b97\u635f\u5931\u3001\u51c6\u786e\u7387\u548c\u5176\u4ed6\u6307\u6807\uff0c\u5e76\u5728 90 \u4e07\u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\u3002 \u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u6682\u7559\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u975e\u5e38\u5e38\u7528\u3002\u5047\u8bbe\u6211\u4eec\u8981\u89e3\u51b3\u7684\u95ee\u9898\u662f\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97 2020 \u5e74\u7684\u9500\u552e\u989d\uff0c\u800c\u6211\u4eec\u5f97\u5230\u7684\u662f 2015-2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f60\u53ef\u4ee5\u9009\u62e9 2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u4f5c\u4e3a\u4fdd\u7559\u6570\u636e\uff0c\u7136\u540e\u5728 2015 \u5e74\u81f3 2018 \u5e74\u7684\u6240\u6709\u6570\u636e\u4e0a\u8bad\u7ec3\u4f60\u7684\u6a21\u578b\u3002 \u56fe 7\uff1a\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u793a\u4f8b \u5728\u56fe 7 \u6240\u793a\u7684\u793a\u4f8b\u4e2d\uff0c\u5047\u8bbe\u6211\u4eec\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u4ece\u65f6\u95f4\u6b65\u9aa4 31 \u5230 40 \u7684\u9500\u552e\u989d\u3002\u6211\u4eec\u53ef\u4ee5\u4fdd\u7559 21 \u81f3 30 \u6b65\u7684\u6570\u636e\uff0c\u7136\u540e\u4ece 0 \u6b65\u5230 20 \u6b65\u8bad\u7ec3\u6a21\u578b\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5728\u9884\u6d4b 31 \u6b65\u81f3 40 \u6b65\u65f6\uff0c\u5e94\u5c06 21 \u6b65\u81f3 30 \u6b65\u7684\u6570\u636e\u7eb3\u5165\u6a21\u578b\uff0c\u5426\u5219\uff0c\u6a21\u578b\u7684\u6027\u80fd\u5c06\u5927\u6253\u6298\u6263\u3002 \u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u5fc5\u987b\u5904\u7406\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u800c\u521b\u5efa\u5927\u578b\u9a8c\u8bc1\u96c6\u610f\u5473\u7740\u6a21\u578b\u5b66\u4e60\u4f1a\u4e22\u5931\u5927\u91cf\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9\u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c\uff0c\u76f8\u5f53\u4e8e\u7279\u6b8a\u7684 k \u5219\u4ea4\u53c9\u68c0\u9a8c\u5176\u4e2d k=N \uff0cN \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8fd9\u610f\u5473\u7740\u5728\u6240\u6709\u7684\u8bad\u7ec3\u6298\u53e0\u4e2d\uff0c\u6211\u4eec\u5c06\u5bf9\u9664 1 \u4e4b\u5916\u7684\u6240\u6709\u6570\u636e\u6837\u672c\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u7684\u6298\u53e0\u6570\u4e0e\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u76f8\u540c\u3002 \u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u6a21\u578b\u7684\u901f\u5ea6\u4e0d\u591f\u5feb\uff0c\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\uff0c\u4f46\u7531\u4e8e\u8fd9\u79cd\u4ea4\u53c9\u68c0\u9a8c\u53ea\u9002\u7528\u4e8e\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u56e0\u6b64\u5e76\u4e0d\u91cd\u8981\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u95ee\u9898\u4e86\u3002\u56de\u5f52\u95ee\u9898\u7684\u597d\u5904\u5728\u4e8e\uff0c\u9664\u4e86\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e4b\u5916\uff0c\u6211\u4eec\u53ef\u4ee5\u5728\u56de\u5f52\u95ee\u9898\u4e0a\u4f7f\u7528\u4e0a\u8ff0\u6240\u6709\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u6211\u4eec\u4e0d\u80fd\u76f4\u63a5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u6709\u4e00\u4e9b\u65b9\u6cd5\u53ef\u4ee5\u7a0d\u7a0d\u6539\u53d8\u95ee\u9898\uff0c\u4ece\u800c\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u9002\u7528\u4e8e\u4efb\u4f55\u56de\u5f52\u95ee\u9898\u3002\u4f46\u662f\uff0c\u5982\u679c\u53d1\u73b0\u76ee\u6807\u5206\u5e03\u4e0d\u4e00\u81f4\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u8981\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u5fc5\u987b\u5148\u5c06\u76ee\u6807\u5212\u5206\u4e3a\u82e5\u5e72\u4e2a\u5206\u5c42\uff0c\u7136\u540e\u518d\u4ee5\u5904\u7406\u5206\u7c7b\u95ee\u9898\u7684\u76f8\u540c\u65b9\u5f0f\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u9009\u62e9\u5408\u9002\u7684\u5206\u5c42\u6570\u6709\u51e0\u79cd\u9009\u62e9\u3002\u5982\u679c\u6837\u672c\u91cf\u5f88\u5927\uff08> 10k\uff0c> 100k\uff09\uff0c\u90a3\u4e48\u5c31\u4e0d\u9700\u8981\u8003\u8651\u5206\u5c42\u7684\u6570\u91cf\u3002\u53ea\u9700\u5c06\u6570\u636e\u5206\u4e3a 10 \u6216 20 \u5c42\u5373\u53ef\u3002\u5982\u679c\u6837\u672c\u6570\u4e0d\u591a\uff0c\u5219\u53ef\u4ee5\u4f7f\u7528 Sturge's Rule \u8fd9\u6837\u7684\u7b80\u5355\u89c4\u5219\u6765\u8ba1\u7b97\u9002\u5f53\u7684\u5206\u5c42\u6570\u3002 Sturge's Rule\uff1a \\[ Number of Bins = 1 + log_2(N) \\] \u5176\u4e2d \\(N\\) \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8be5\u51fd\u6570\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u5229\u7528\u65af\u7279\u683c\u6cd5\u5219\u7ed8\u5236\u6837\u672c\u4e0e\u7bb1\u6570\u5bf9\u6bd4\u56fe \u8ba9\u6211\u4eec\u5236\u4f5c\u4e00\u4e2a\u56de\u5f52\u6570\u636e\u96c6\u6837\u672c\uff0c\u5e76\u5c1d\u8bd5\u5e94\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 # stratified-kfold for regression # \u4e3a\u56de\u5f52\u95ee\u9898\u8fdb\u884c\u5206\u5c42K-\u6298\u4ea4\u53c9\u9a8c\u8bc1 # \u5bfc\u5165\u9700\u8981\u7684\u5e93 import numpy as np import pandas as pd from sklearn import datasets from sklearn import model_selection # \u521b\u5efa\u5206\u6298\uff08folds\uff09\u7684\u51fd\u6570 def create_folds ( data ): # \u521b\u5efa\u4e00\u4e2a\u65b0\u5217\u53eb\u505akfold\uff0c\u5e76\u7528-1\u6765\u586b\u5145 data [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c data = data . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4f7f\u7528Sturge\u89c4\u5219\u8ba1\u7b97bin\u7684\u6570\u91cf num_bins = int ( np . floor ( 1 + np . log2 ( len ( data )))) # \u4f7f\u7528pandas\u7684cut\u51fd\u6570\u8fdb\u884c\u76ee\u6807\u53d8\u91cf\uff08target\uff09\u7684\u5206\u7bb1 data . loc [:, \"bins\" ] = pd . cut ( data [ \"target\" ], bins = num_bins , labels = False ) # \u521d\u59cb\u5316StratifiedKFold\u7c7b kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684kfold\u5217 # \u6ce8\u610f\uff1a\u6211\u4eec\u4f7f\u7528\u7684\u662fbins\u800c\u4e0d\u662f\u5b9e\u9645\u7684\u76ee\u6807\u53d8\u91cf\uff08target\uff09\uff01 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = data , y = data . bins . values )): data . loc [ v_ , 'kfold' ] = f # \u5220\u9664bins\u5217 data = data . drop ( \"bins\" , axis = 1 ) # \u8fd4\u56de\u5305\u542bfolds\u7684\u6570\u636e return data # \u4e3b\u7a0b\u5e8f\u5f00\u59cb if __name__ == \"__main__\" : # \u521b\u5efa\u4e00\u4e2a\u5e26\u670915000\u4e2a\u6837\u672c\u3001100\u4e2a\u7279\u5f81\u548c1\u4e2a\u76ee\u6807\u53d8\u91cf\u7684\u6837\u672c\u6570\u636e\u96c6 X , y = datasets . make_regression ( n_samples = 15000 , n_features = 100 , n_targets = 1 ) # \u4f7f\u7528numpy\u6570\u7ec4\u521b\u5efa\u4e00\u4e2a\u6570\u636e\u6846 df = pd . DataFrame ( X , columns = [ f \"f_ { i } \" for i in range ( X . shape [ 1 ])] ) df . loc [:, \"target\" ] = y # \u521b\u5efafolds df = create_folds ( df ) \u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u7b2c\u4e00\u6b65\uff0c\u4e5f\u662f\u6700\u57fa\u672c\u7684\u4e00\u6b65\u3002\u5982\u679c\u8981\u505a\u7279\u5f81\u5de5\u7a0b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u8981\u5efa\u7acb\u6a21\u578b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u65b9\u6848\uff0c\u5176\u4e2d\u9a8c\u8bc1\u6570\u636e\u80fd\u591f\u4ee3\u8868\u8bad\u7ec3\u6570\u636e\u548c\u771f\u5b9e\u4e16\u754c\u7684\u6570\u636e\uff0c\u90a3\u4e48\u4f60\u5c31\u80fd\u5efa\u7acb\u4e00\u4e2a\u5177\u6709\u9ad8\u5ea6\u901a\u7528\u6027\u7684\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u3002 \u672c\u7ae0\u4ecb\u7ecd\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u6570\u636e\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u6839\u636e\u4f60\u7684\u95ee\u9898\u548c\u6570\u636e\u91c7\u7528\u65b0\u7684\u4ea4\u53c9\u68c0\u9a8c\u5f62\u5f0f\u3002 \u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u95ee\u9898\uff0c\u5e0c\u671b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u4ece\u60a3\u8005\u7684\u76ae\u80a4\u56fe\u50cf\u4e2d\u68c0\u6d4b\u51fa\u76ae\u80a4\u764c\u3002\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u5668\uff0c\u8be5\u5206\u7c7b\u5668\u63a5\u6536\u8f93\u5165\u56fe\u50cf\u5e76\u9884\u6d4b\u5176\u826f\u6027\u6216\u6076\u6027\u7684\u6982\u7387\u3002 \u5728\u8fd9\u7c7b\u6570\u636e\u96c6\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u96c6\u4e2d\u53ef\u80fd\u6709\u540c\u4e00\u60a3\u8005\u7684\u591a\u5f20\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u8981\u5728\u8fd9\u91cc\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u7cfb\u7edf\uff0c\u5fc5\u987b\u6709\u5206\u5c42\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u4e5f\u5fc5\u987b\u786e\u4fdd\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u60a3\u8005\u4e0d\u4f1a\u51fa\u73b0\u5728\u9a8c\u8bc1\u6570\u636e\u4e2d\u3002\u5e78\u8fd0\u7684\u662f\uff0cscikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u79f0\u4e3a GroupKFold \u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u3002 \u5728\u8fd9\u91cc\uff0c\u60a3\u8005\u53ef\u4ee5\u88ab\u89c6\u4e3a\u7ec4\u3002 \u4f46\u9057\u61be\u7684\u662f\uff0cscikit-learn \u65e0\u6cd5\u5c06 GroupKFold \u4e0e StratifiedKFold \u7ed3\u5408\u8d77\u6765\u3002\u6240\u4ee5\u4f60\u9700\u8981\u81ea\u5df1\u52a8\u624b\u3002\u6211\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u8bfb\u8005\u7684\u7ec3\u4e60\u3002","title":"\u4ea4\u53c9\u68c0\u9a8c"},{"location":"%E5%87%86%E5%A4%87%E7%8E%AF%E5%A2%83/","text":"\u51c6\u5907\u73af\u5883 \u5728\u6211\u4eec\u5f00\u59cb\u7f16\u7a0b\u4e4b\u524d\uff0c\u5728\u4f60\u7684\u673a\u5668\u4e0a\u8bbe\u7f6e\u597d\u4e00\u5207\u662f\u975e\u5e38\u91cd\u8981\u7684\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Ubuntu 18.04 \u548c Python 3.7.6\u3002\u5982\u679c\u4f60\u662f Windows \u7528\u6237\uff0c\u53ef\u4ee5\u901a\u8fc7\u591a\u79cd\u65b9\u5f0f\u5b89\u88c5 Ubuntu\u3002\u4f8b\u5982\uff0c\u5728\u865a\u62df\u673a\u4e0a\u5b89\u88c5\u7531Oracle\u516c\u53f8\u63d0\u4f9b\u7684\u514d\u8d39\u8f6f\u4ef6 Virtual Box\u3002\u4e0eWindows\u4e00\u8d77\u4f5c\u4e3a\u53cc\u542f\u52a8\u7cfb\u7edf\u3002\u6211\u66f4\u559c\u6b22\u53cc\u542f\u52a8\uff0c\u56e0\u4e3a\u5b83\u662f\u539f\u751f\u7684\u3002\u5982\u679c\u4f60\u4e0d\u662fUbuntu\u7528\u6237\uff0c\u5728\u4f7f\u7528\u672c\u4e66\u4e2d\u7684\u67d0\u4e9bbash\u811a\u672c\u65f6\u53ef\u80fd\u4f1a\u9047\u5230\u95ee\u9898\u3002\u4e3a\u4e86\u907f\u514d\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u53ef\u4ee5\u5728\u865a\u62df\u673a\u4e2d\u5b89\u88c5Ubuntu\uff0c\u6216\u8005\u5728Windows\u4e0a\u5b89\u88c5Linux shell\u3002 \u7528 Anaconda \u5728\u4efb\u4f55\u673a\u5668\u4e0a\u5b89\u88c5 Python \u90fd\u5f88\u7b80\u5355\u3002\u6211\u7279\u522b\u559c\u6b22 Miniconda \uff0c\u5b83\u662f conda \u7684\u6700\u5c0f\u5b89\u88c5\u7a0b\u5e8f\u3002\u5b83\u9002\u7528\u4e8e Linux\u3001OSX \u548c Windows\u3002\u7531\u4e8e Python 2 \u652f\u6301\u5df2\u4e8e 2019 \u5e74\u5e95\u7ed3\u675f\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Python 3 \u53d1\u884c\u7248\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0cminiconda \u5e76\u4e0d\u50cf\u666e\u901a Anaconda \u9644\u5e26\u6240\u6709\u8f6f\u4ef6\u5305\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u968f\u65f6\u5b89\u88c5\u65b0\u8f6f\u4ef6\u5305\u3002\u5b89\u88c5 miniconda \u975e\u5e38\u7b80\u5355\u3002 \u9996\u5148\u8981\u505a\u7684\u662f\u5c06 Miniconda3 \u4e0b\u8f7d\u5230\u7cfb\u7edf\u4e2d\u3002 cd ~/Downloads wget https://repo.anaconda.com/miniconda/... \u5176\u4e2d wget \u547d\u4ee4\u540e\u7684 URL \u662f miniconda3 \u7f51\u9875\u7684 URL\u3002\u5bf9\u4e8e 64 \u4f4d Linux \u7cfb\u7edf\uff0c\u7f16\u5199\u672c\u4e66\u65f6\u7684 URL \u662f https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \u4e0b\u8f7d miniconda3 \u540e\uff0c\u53ef\u4ee5\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\uff1a sh Miniconda3-latest-Linux-x86_64.sh \u63a5\u4e0b\u6765\uff0c\u8bf7\u9605\u8bfb\u5e76\u6309\u7167\u5c4f\u5e55\u4e0a\u7684\u8bf4\u660e\u64cd\u4f5c\u3002\u5982\u679c\u5b89\u88c5\u6b63\u786e\uff0c\u4f60\u5e94\u8be5\u53ef\u4ee5\u901a\u8fc7\u5728\u7ec8\u7aef\u8f93\u5165 conda init \u6765\u542f\u52a8 conda \u73af\u5883\u3002\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u5728\u672c\u4e66\u4e2d\u4e00\u76f4\u4f7f\u7528\u7684 conda \u73af\u5883\u3002\u8981\u521b\u5efa conda \u73af\u5883\uff0c\u53ef\u4ee5\u8f93\u5165\uff1a conda create -n environment_name python = 3 .7.6 \u6b64\u547d\u4ee4\u5c06\u521b\u5efa\u540d\u4e3a environment_name \u7684 conda \u73af\u5883\uff0c\u53ef\u4ee5\u4f7f\u7528\uff1a conda activate environment_name \u73b0\u5728\u6211\u4eec\u7684\u73af\u5883\u5df2\u7ecf\u642d\u5efa\u5b8c\u6bd5\u3002\u662f\u65f6\u5019\u5b89\u88c5\u4e00\u4e9b\u6211\u4eec\u4f1a\u7528\u5230\u7684\u8f6f\u4ef6\u5305\u4e86\u3002\u5728 conda \u73af\u5883\u4e2d\uff0c\u5b89\u88c5\u8f6f\u4ef6\u5305\u6709\u4e24\u79cd\u4e0d\u540c\u7684\u65b9\u5f0f\u3002 \u4f60\u53ef\u4ee5\u4ece conda \u4ed3\u5e93\u6216 PyPi \u5b98\u65b9\u4ed3\u5e93\u5b89\u88c5\u8f6f\u4ef6\u5305\u3002 conda/pip install package_name \u6ce8\u610f\uff1a\u67d0\u4e9b\u8f6f\u4ef6\u5305\u53ef\u80fd\u65e0\u6cd5\u5728 conda \u8f6f\u4ef6\u4ed3\u5e93\u4e2d\u627e\u5230\u3002\u56e0\u6b64\uff0c\u5728\u672c\u4e66\u4e2d\uff0c\u4f7f\u7528 pip \u5b89\u88c5\u662f\u6700\u53ef\u53d6\u7684\u65b9\u6cd5\u3002\u6211\u5df2\u7ecf\u521b\u5efa\u4e86\u4e00\u4e2a\u7f16\u5199\u672c\u4e66\u65f6\u4f7f\u7528\u7684\u8f6f\u4ef6\u5305\u5217\u8868\uff0c\u4fdd\u5b58\u5728 environment.yml \u4e2d\u3002 \u4f60\u53ef\u4ee5\u5728\u6211\u7684 GitHub \u4ed3\u5e93\u4e2d\u7684\u989d\u5916\u8d44\u6599\u4e2d\u627e\u5230\u5b83\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u521b\u5efa\u73af\u5883\uff1a conda env create -f environment.yml \u8be5\u547d\u4ee4\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a ml \u7684\u73af\u5883\u3002\u8981\u6fc0\u6d3b\u8be5\u73af\u5883\u5e76\u5f00\u59cb\u4f7f\u7528\uff0c\u5e94\u8fd0\u884c\uff1a conda activate ml \u73b0\u5728\u6211\u4eec\u5df2\u7ecf\u51c6\u5907\u5c31\u7eea\uff0c\u53ef\u4ee5\u8fdb\u884c\u4e00\u4e9b\u5e94\u7528\u673a\u5668\u5b66\u4e60\u7684\u5de5\u4f5c\u4e86\uff01\u5728\u4f7f\u7528\u672c\u4e66\u8fdb\u884c\u7f16\u7801\u65f6\uff0c\u8bf7\u59cb\u7ec8\u8bb0\u4f4f\u8981\u5728 \"ml \"\u73af\u5883\u4e0b\u8fdb\u884c\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5f00\u59cb\u5b66\u4e60\u771f\u6b63\u7684\u7b2c\u4e00\u7ae0\u3002","title":"\u51c6\u5907\u73af\u5883"},{"location":"%E5%87%86%E5%A4%87%E7%8E%AF%E5%A2%83/#_1","text":"\u5728\u6211\u4eec\u5f00\u59cb\u7f16\u7a0b\u4e4b\u524d\uff0c\u5728\u4f60\u7684\u673a\u5668\u4e0a\u8bbe\u7f6e\u597d\u4e00\u5207\u662f\u975e\u5e38\u91cd\u8981\u7684\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Ubuntu 18.04 \u548c Python 3.7.6\u3002\u5982\u679c\u4f60\u662f Windows \u7528\u6237\uff0c\u53ef\u4ee5\u901a\u8fc7\u591a\u79cd\u65b9\u5f0f\u5b89\u88c5 Ubuntu\u3002\u4f8b\u5982\uff0c\u5728\u865a\u62df\u673a\u4e0a\u5b89\u88c5\u7531Oracle\u516c\u53f8\u63d0\u4f9b\u7684\u514d\u8d39\u8f6f\u4ef6 Virtual Box\u3002\u4e0eWindows\u4e00\u8d77\u4f5c\u4e3a\u53cc\u542f\u52a8\u7cfb\u7edf\u3002\u6211\u66f4\u559c\u6b22\u53cc\u542f\u52a8\uff0c\u56e0\u4e3a\u5b83\u662f\u539f\u751f\u7684\u3002\u5982\u679c\u4f60\u4e0d\u662fUbuntu\u7528\u6237\uff0c\u5728\u4f7f\u7528\u672c\u4e66\u4e2d\u7684\u67d0\u4e9bbash\u811a\u672c\u65f6\u53ef\u80fd\u4f1a\u9047\u5230\u95ee\u9898\u3002\u4e3a\u4e86\u907f\u514d\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u53ef\u4ee5\u5728\u865a\u62df\u673a\u4e2d\u5b89\u88c5Ubuntu\uff0c\u6216\u8005\u5728Windows\u4e0a\u5b89\u88c5Linux shell\u3002 \u7528 Anaconda \u5728\u4efb\u4f55\u673a\u5668\u4e0a\u5b89\u88c5 Python \u90fd\u5f88\u7b80\u5355\u3002\u6211\u7279\u522b\u559c\u6b22 Miniconda \uff0c\u5b83\u662f conda \u7684\u6700\u5c0f\u5b89\u88c5\u7a0b\u5e8f\u3002\u5b83\u9002\u7528\u4e8e Linux\u3001OSX \u548c Windows\u3002\u7531\u4e8e Python 2 \u652f\u6301\u5df2\u4e8e 2019 \u5e74\u5e95\u7ed3\u675f\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Python 3 \u53d1\u884c\u7248\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0cminiconda \u5e76\u4e0d\u50cf\u666e\u901a Anaconda \u9644\u5e26\u6240\u6709\u8f6f\u4ef6\u5305\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u968f\u65f6\u5b89\u88c5\u65b0\u8f6f\u4ef6\u5305\u3002\u5b89\u88c5 miniconda \u975e\u5e38\u7b80\u5355\u3002 \u9996\u5148\u8981\u505a\u7684\u662f\u5c06 Miniconda3 \u4e0b\u8f7d\u5230\u7cfb\u7edf\u4e2d\u3002 cd ~/Downloads wget https://repo.anaconda.com/miniconda/... \u5176\u4e2d wget \u547d\u4ee4\u540e\u7684 URL \u662f miniconda3 \u7f51\u9875\u7684 URL\u3002\u5bf9\u4e8e 64 \u4f4d Linux \u7cfb\u7edf\uff0c\u7f16\u5199\u672c\u4e66\u65f6\u7684 URL \u662f https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \u4e0b\u8f7d miniconda3 \u540e\uff0c\u53ef\u4ee5\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\uff1a sh Miniconda3-latest-Linux-x86_64.sh \u63a5\u4e0b\u6765\uff0c\u8bf7\u9605\u8bfb\u5e76\u6309\u7167\u5c4f\u5e55\u4e0a\u7684\u8bf4\u660e\u64cd\u4f5c\u3002\u5982\u679c\u5b89\u88c5\u6b63\u786e\uff0c\u4f60\u5e94\u8be5\u53ef\u4ee5\u901a\u8fc7\u5728\u7ec8\u7aef\u8f93\u5165 conda init \u6765\u542f\u52a8 conda \u73af\u5883\u3002\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u5728\u672c\u4e66\u4e2d\u4e00\u76f4\u4f7f\u7528\u7684 conda \u73af\u5883\u3002\u8981\u521b\u5efa conda \u73af\u5883\uff0c\u53ef\u4ee5\u8f93\u5165\uff1a conda create -n environment_name python = 3 .7.6 \u6b64\u547d\u4ee4\u5c06\u521b\u5efa\u540d\u4e3a environment_name \u7684 conda \u73af\u5883\uff0c\u53ef\u4ee5\u4f7f\u7528\uff1a conda activate environment_name \u73b0\u5728\u6211\u4eec\u7684\u73af\u5883\u5df2\u7ecf\u642d\u5efa\u5b8c\u6bd5\u3002\u662f\u65f6\u5019\u5b89\u88c5\u4e00\u4e9b\u6211\u4eec\u4f1a\u7528\u5230\u7684\u8f6f\u4ef6\u5305\u4e86\u3002\u5728 conda \u73af\u5883\u4e2d\uff0c\u5b89\u88c5\u8f6f\u4ef6\u5305\u6709\u4e24\u79cd\u4e0d\u540c\u7684\u65b9\u5f0f\u3002 \u4f60\u53ef\u4ee5\u4ece conda \u4ed3\u5e93\u6216 PyPi \u5b98\u65b9\u4ed3\u5e93\u5b89\u88c5\u8f6f\u4ef6\u5305\u3002 conda/pip install package_name \u6ce8\u610f\uff1a\u67d0\u4e9b\u8f6f\u4ef6\u5305\u53ef\u80fd\u65e0\u6cd5\u5728 conda \u8f6f\u4ef6\u4ed3\u5e93\u4e2d\u627e\u5230\u3002\u56e0\u6b64\uff0c\u5728\u672c\u4e66\u4e2d\uff0c\u4f7f\u7528 pip \u5b89\u88c5\u662f\u6700\u53ef\u53d6\u7684\u65b9\u6cd5\u3002\u6211\u5df2\u7ecf\u521b\u5efa\u4e86\u4e00\u4e2a\u7f16\u5199\u672c\u4e66\u65f6\u4f7f\u7528\u7684\u8f6f\u4ef6\u5305\u5217\u8868\uff0c\u4fdd\u5b58\u5728 environment.yml \u4e2d\u3002 \u4f60\u53ef\u4ee5\u5728\u6211\u7684 GitHub \u4ed3\u5e93\u4e2d\u7684\u989d\u5916\u8d44\u6599\u4e2d\u627e\u5230\u5b83\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u521b\u5efa\u73af\u5883\uff1a conda env create -f environment.yml \u8be5\u547d\u4ee4\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a ml \u7684\u73af\u5883\u3002\u8981\u6fc0\u6d3b\u8be5\u73af\u5883\u5e76\u5f00\u59cb\u4f7f\u7528\uff0c\u5e94\u8fd0\u884c\uff1a conda activate ml \u73b0\u5728\u6211\u4eec\u5df2\u7ecf\u51c6\u5907\u5c31\u7eea\uff0c\u53ef\u4ee5\u8fdb\u884c\u4e00\u4e9b\u5e94\u7528\u673a\u5668\u5b66\u4e60\u7684\u5de5\u4f5c\u4e86\uff01\u5728\u4f7f\u7528\u672c\u4e66\u8fdb\u884c\u7f16\u7801\u65f6\uff0c\u8bf7\u59cb\u7ec8\u8bb0\u4f4f\u8981\u5728 \"ml \"\u73af\u5883\u4e0b\u8fdb\u884c\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5f00\u59cb\u5b66\u4e60\u771f\u6b63\u7684\u7b2c\u4e00\u7ae0\u3002","title":"\u51c6\u5907\u73af\u5883"},{"location":"%E5%8F%AF%E9%87%8D%E5%A4%8D%E4%BB%A3%E7%A0%81%E5%92%8C%E6%A8%A1%E5%9E%8B%E6%96%B9%E6%B3%95/","text":"\u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5 \u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u5230\u4e86\u53ef\u4ee5\u5c06\u6a21\u578b/\u8bad\u7ec3\u4ee3\u7801\u5206\u53d1\u7ed9\u4ed6\u4eba\u4f7f\u7528\u7684\u9636\u6bb5\u3002\u60a8\u53ef\u4ee5\u7528\u8f6f\u76d8\u5206\u53d1\u6216\u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\uff0c\u4f46\u8fd9\u5e76\u4e0d\u7406\u60f3\u3002\u662f\u8fd9\u6837\u5417\uff1f\u4e5f\u8bb8\u5f88\u591a\u5e74\u524d\uff0c\u8fd9\u662f\u7406\u60f3\u7684\u505a\u6cd5\uff0c\u4f46\u73b0\u5728\u4e0d\u662f\u4e86\u3002 \u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\u548c\u534f\u4f5c\u7684\u9996\u9009\u65b9\u5f0f\u662f\u4f7f\u7528\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u3002Git \u662f\u6700\u6d41\u884c\u7684\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u4e4b\u4e00\u3002\u90a3\u4e48\uff0c\u5047\u8bbe\u4f60\u5df2\u7ecf\u5b66\u4f1a\u4e86 Git\uff0c\u5e76\u6b63\u786e\u5730\u683c\u5f0f\u5316\u4e86\u4ee3\u7801\uff0c\u7f16\u5199\u4e86\u9002\u5f53\u7684\u6587\u6863\uff0c\u8fd8\u5f00\u6e90\u4e86\u4f60\u7684\u9879\u76ee\u3002\u8fd9\u5c31\u591f\u4e86\u5417\uff1f\u4e0d\uff0c\u8fd8\u4e0d\u591f\u3002\u56e0\u4e3a\u4f60\u5728\u81ea\u5df1\u7684\u7535\u8111\u4e0a\u5199\u7684\u4ee3\u7801\uff0c\u5728\u522b\u4eba\u7684\u7535\u8111\u4e0a\u53ef\u80fd\u4f1a\u56e0\u4e3a\u5404\u79cd\u539f\u56e0\u800c\u65e0\u6cd5\u8fd0\u884c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u5728\u53d1\u5e03\u4ee3\u7801\u65f6\u80fd\u590d\u5236\u81ea\u5df1\u7684\u7535\u8111\uff0c\u800c\u5176\u4ed6\u4eba\u5728\u5b89\u88c5\u60a8\u7684\u8f6f\u4ef6\u6216\u8fd0\u884c\u60a8\u7684\u4ee3\u7801\u65f6\u4e5f\u80fd\u590d\u5236\u60a8\u7684\u7535\u8111\uff0c\u90a3\u5c31\u518d\u597d\u4e0d\u8fc7\u4e86\u3002\u4e3a\u6b64\uff0c\u5982\u4eca\u6700\u6d41\u884c\u7684\u65b9\u6cd5\u662f\u4f7f\u7528 Docker \u5bb9\u5668\uff08Docker Containers\uff09\u3002\u8981\u4f7f\u7528 Docker \u5bb9\u5668\uff0c\u4f60\u9700\u8981\u5b89\u88c5 Docker\u3002 \u8ba9\u6211\u4eec\u7528\u4e0b\u9762\u7684\u547d\u4ee4\u6765\u5b89\u88c5 Docker\u3002 sudo apt install docker.io sudo systemctl start docker sudo systemctl enable docker sudo groupadd docker sudo usermod -aG docker $USER \u8fd9\u4e9b\u547d\u4ee4\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e0a\u8fd0\u884c\u3002Docker \u6700\u68d2\u7684\u5730\u65b9\u5728\u4e8e\u5b83\u53ef\u4ee5\u5b89\u88c5\u5728\u4efb\u4f55\u673a\u5668\u4e0a\uff1a Linux\u3001Windows\u3001OSX\u3002\u56e0\u6b64\uff0c\u5982\u679c\u4f60\u4e00\u76f4\u5728 Docker \u5bb9\u5668\u4e2d\u5de5\u4f5c\uff0c\u54ea\u53f0\u673a\u5668\u90fd\u6ca1\u5173\u7cfb\uff01 Docker \u5bb9\u5668\u53ef\u4ee5\u88ab\u89c6\u4e3a\u5c0f\u578b\u865a\u62df\u673a\u3002\u4f60\u53ef\u4ee5\u4e3a\u4f60\u7684\u4ee3\u7801\u521b\u5efa\u4e00\u4e2a\u5bb9\u5668\uff0c\u7136\u540e\u6bcf\u4e2a\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u548c\u8bbf\u95ee\u5b83\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u521b\u5efa\u53ef\u7528\u4e8e\u8bad\u7ec3\u6a21\u578b\u7684\u5bb9\u5668\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e00\u7ae0\u4e2d\u8bad\u7ec3\u7684 BERT \u6a21\u578b\uff0c\u5e76\u5c1d\u8bd5\u5c06\u8bad\u7ec3\u4ee3\u7801\u5bb9\u5668\u5316\u3002 \u9996\u5148\uff0c\u4f60\u9700\u8981\u4e00\u4e2a\u5305\u542b python \u9879\u76ee\u9700\u6c42\u7684\u6587\u4ef6\u3002\u9700\u6c42\u5305\u542b\u5728\u540d\u4e3a requirements.txt \u7684\u6587\u4ef6\u4e2d\u3002\u6587\u4ef6\u540d\u662f thestandard\u3002\u8be5\u6587\u4ef6\u5305\u542b\u9879\u76ee\u4e2d\u4f7f\u7528\u7684\u6240\u6709 python \u5e93\u3002\u4e5f\u5c31\u662f\u53ef\u4ee5\u901a\u8fc7 PyPI (pip) \u4e0b\u8f7d\u7684 python \u5e93\u3002\u7528\u4e8e \u8bad\u7ec3 BERT \u6a21\u578b\u4ee5\u68c0\u6d4b\u6b63/\u8d1f\u60c5\u611f\uff0c\u6211\u4eec\u4f7f\u7528\u4e86 torch\u3001transformers\u3001tqdm\u3001scikit-learn\u3001pandas \u548c numpy\u3002 \u8ba9\u6211\u4eec\u628a\u5b83\u4eec\u5199\u5165 requirements.txt \u4e2d\u3002\u4f60\u53ef\u4ee5\u53ea\u5199\u540d\u79f0\uff0c\u4e5f\u53ef\u4ee5\u5305\u62ec\u7248\u672c\u3002\u5305\u542b\u7248\u672c\u603b\u662f\u6700\u597d\u7684\uff0c\u8fd9\u4e5f\u662f\u4f60\u5e94\u8be5\u505a\u7684\u3002\u5305\u542b\u7248\u672c\u540e\uff0c\u53ef\u4ee5\u786e\u4fdd\u5176\u4ed6\u4eba\u4f7f\u7528\u7684\u7248\u672c\u4e0e\u4f60\u7684\u7248\u672c\u76f8\u540c\uff0c\u800c\u4e0d\u662f\u6700\u65b0\u7248\u672c\uff0c\u56e0\u4e3a\u6700\u65b0\u7248\u672c\u53ef\u80fd\u4f1a\u66f4\u6539\u67d0\u4e9b\u5185\u5bb9\uff0c\u5982\u679c\u662f\u8fd9\u6837\u7684\u8bdd\uff0c\u6a21\u578b\u7684\u8bad\u7ec3\u65b9\u5f0f\u5c31\u4e0d\u4f1a\u4e0e\u4f60\u7684\u76f8\u540c\u4e86\u3002 \u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u663e\u793a\u4e86 requirements.txt\u3002 # requirements.txt pandas == 1.0.4 scikit - learn == 0.22.1 torch == 1.5.0 transformers == 2.11.0 \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a Dockerfile \u7684 Docker \u6587\u4ef6\u3002\u6ca1\u6709\u6269\u5c55\u540d\u3002Dockerfile \u6709\u51e0\u4e2a\u5143\u7d20\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 # Dockerfile # First of all, we include where we are getting the image # from. Image can be thought of as an operating system. # You can do \"FROM ubuntu:18.04\" # this will start from a clean ubuntu 18.04 image. # All images are downloaded from dockerhub # Here are we grabbing image from nvidia's repo # they created a docker image using ubuntu 18.04 # and installed cuda 10.1 and cudnn7 in it. Thus, we don't have to # install it. Makes our life easy. FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 # this is the same apt-get command that you are used to # except the fact that, we have -y argument. Its because # when we build this container, we cannot press Y when asked for RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* # We add a new user called \"abhishek\" # this can be anything. Anything you want it # to be. Usually, we don't use our own name, # you can use \"user\" or \"ubuntu\" RUN useradd -m abhishek # make our user own its own home directory RUN chown -R abhishek:abhishek /home/abhishek/ # copy all files from this direrctory to a # directory called app inside the home of abhishek # and abhishek owns it. COPY --chown = abhishek *.* /home/abhishek/app/ # change to user abhishek USER abhishek RUN mkdir /home/abhishek/data/ # Now we install all the requirements # after moving to the app directory # PLEASE NOTE that ubuntu 18.04 image # has python 3.6.9 and not python 3.7.6 # you can also install conda python here and use that # however, to simplify it, I will be using python 3.6.9 # inside the docker container!!!! RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt # install mkl. its needed for transformers RUN pip3 install mkl # when we log into the docker container, # we will go inside this directory automatically WORKDIR /home/abhishek/app \u521b\u5efa\u597d Docker \u6587\u4ef6\u540e\uff0c\u6211\u4eec\u5c31\u9700\u8981\u6784\u5efa\u5b83\u3002\u6784\u5efa Docker \u5bb9\u5668\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u547d\u4ee4\u3002 docker build -f Dockerfile -t bert:train . \u8be5\u547d\u4ee4\u6839\u636e\u63d0\u4f9b\u7684 Dockerfile \u6784\u5efa\u4e00\u4e2a\u5bb9\u5668\u3002Docker \u5bb9\u5668\u7684\u540d\u79f0\u662f bert:train\u3002\u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a \u276f docker build -f Dockerfile -t bert:train . Sending build context to Docker daemon 19.97kB Step 1/7 : FROM nvidia/cuda:10.1-cudnn7-ubuntu18.04 ---> 3b55548ae91f Step 2/7 : RUN apt-get update && apt-get install -y git curl ca- certificates python3 python3-pip sudo && rm -rf /var/lib/apt/lists/* . . . . Removing intermediate container 8f6975dd08ba ---> d1802ac9f1b4 Step 7/7 : WORKDIR /home/abhishek/app ---> Running in 257ff09502ed Removing intermediate container 257ff09502ed ---> e5f6eb4cddd7 Successfully built e5f6eb4cddd7 Successfully tagged bert:train \u8bf7\u6ce8\u610f\uff0c\u6211\u5220\u9664\u4e86\u8f93\u51fa\u4e2d\u7684\u8bb8\u591a\u884c\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u767b\u5f55\u5bb9\u5668\u3002 docker run -ti bert:train /bin/bash \u4f60\u9700\u8981\u8bb0\u4f4f\uff0c\u4e00\u65e6\u9000\u51fa shell\uff0c\u4f60\u5728 shell \u4e2d\u6240\u505a\u7684\u4e00\u5207\u90fd\u5c06\u4e22\u5931\u3002\u4f60\u8fd8\u53ef\u4ee5\u5728 Docker \u5bb9\u5668\u4e2d\u4f7f\u7528\u3002 docker run -ti bert:train python3 train.py \u8f93\u51fa\u60c5\u51b5\uff1a Traceback (most recent call last): File \"train.py\", line 2, in import config File \"/home/abhishek/app/config.py\", line 28, in do_lower_case=True File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 393, in from_pretrained return cls._from_pretrained(*inputs, **kwargs) File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 496, in _from_pretrained list(cls.vocab_files_names.values()), OSError: Model name '../input/bert_base_uncased/' was not found in tokenizers model name list (bert-base-uncased, bert-large-uncased, bert- base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base- multilingual-cased, bert-base-chinese, bert-base-german-cased, bert- large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased- whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert- base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base- finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). We assumed '../input/bert_base_uncased/' was a path, a model identifier, or url to a directory containing vocabulary files named ['vocab.txt'] but couldn't find such vocabulary files at this path or url. \u54ce\u5440\uff0c\u51fa\u9519\u4e86\uff01 \u6211\u4e3a\u4ec0\u4e48\u8981\u628a\u9519\u8bef\u5370\u5728\u4e66\u4e0a\u5462\uff1f \u56e0\u4e3a\u7406\u89e3\u8fd9\u4e2a\u9519\u8bef\u975e\u5e38\u91cd\u8981\u3002\u8fd9\u4e2a\u9519\u8bef\u8bf4\u660e\u4ee3\u7801\u65e0\u6cd5\u627e\u5230\u76ee\u5f55\".../input/bert_base_cased\"\u3002\u4e3a\u4ec0\u4e48\u4f1a\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u5462\uff1f\u6211\u4eec\u53ef\u4ee5\u5728\u6ca1\u6709 Docker \u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u8bad\u7ec3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u76ee\u5f55\u548c\u6240\u6709\u6587\u4ef6\u90fd\u5b58\u5728\u3002\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u662f\u56e0\u4e3a Docker \u5c31\u50cf\u4e00\u4e2a\u865a\u62df\u673a\uff01\u5b83\u6709\u81ea\u5df1\u7684\u6587\u4ef6\u7cfb\u7edf\uff0c\u672c\u5730\u673a\u5668\u4e0a\u7684\u6587\u4ef6\u4e0d\u4f1a\u5171\u4eab\u7ed9 Docker \u5bb9\u5668\u3002\u5982\u679c\u4f60\u60f3\u4f7f\u7528\u672c\u5730\u673a\u5668\u4e0a\u7684\u8def\u5f84\u5e76\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\uff0c\u4f60\u9700\u8981\u5728\u8fd0\u884c Docker \u65f6\u5c06\u5176\u6302\u8f7d\u5230 Docker \u5bb9\u5668\u4e0a\u3002\u5f53\u6211\u4eec\u67e5\u770b\u8fd9\u4e2a\u6587\u4ef6\u5939\u7684\u8def\u5f84\u65f6\uff0c\u6211\u4eec\u77e5\u9053\u5b83\u4f4d\u4e8e\u540d\u4e3a input \u7684\u6587\u4ef6\u5939\u7684\u4e0a\u4e00\u7ea7\u3002\u8ba9\u6211\u4eec\u7a0d\u5fae\u4fee\u6539\u4e00\u4e0b config.py \u6587\u4ef6\uff01 # config.py import os import transformers # fetch home directory # in our docker container, it is # /home/abhishek HOME_DIR = os . path . expanduser ( \"~\" ) # this is the maximum number of tokens in the sentence MAX_LEN = 512 # batch sizes is low because model is huge! TRAIN_BATCH_SIZE = 8 VALID_BATCH_SIZE = 4 # let's train for a maximum of 10 epochs EPOCHS = 10 # define path to BERT model files # Now we assume that all the data is stored inside # /home/abhishek/data BERT_PATH = os . path . join ( HOME_DIR , \"data\" , \"bert_base_uncased\" ) # this is where you want to save the model MODEL_PATH = os . path . join ( HOME_DIR , \"data\" , \"model.bin\" ) # training file TRAINING_FILE = os . path . join ( HOME_DIR , \"data\" , \"imdb.csv\" ) TOKENIZER = transformers . BertTokenizer . from_pretrained ( BERT_PATH , do_lower_case = True ) \u73b0\u5728\uff0c\u4ee3\u7801\u5047\u5b9a\u6240\u6709\u5185\u5bb9\u90fd\u5728\u4e3b\u76ee\u5f55\u4e0b\u540d\u4e3a data \u7684\u6587\u4ef6\u5939\u4e2d\u3002 \u8bf7\u6ce8\u610f\uff0c\u5982\u679c Python \u811a\u672c\u6709\u4efb\u4f55\u6539\u52a8\uff0c\u90fd\u610f\u5473\u7740\u9700\u8981\u91cd\u5efa Docker \u5bb9\u5668\uff01\u56e0\u6b64\uff0c\u6211\u4eec\u91cd\u5efa\u5bb9\u5668\uff0c\u7136\u540e\u91cd\u65b0\u8fd0\u884c Docker \u547d\u4ee4\uff0c\u4f46\u8fd9\u6b21\u8981\u6709\u6240\u6539\u53d8\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u6ca1\u6709\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u8fd9\u4e5f\u662f\u884c\u4e0d\u901a\u7684\u3002\u522b\u62c5\u5fc3\uff0c\u8fd9\u53ea\u662f\u4e00\u4e2a Docker \u5bb9\u5668\u3002\u4f60\u53ea\u9700\u8981\u505a\u4e00\u6b21\u3002\u8981\u5b89\u88c5\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e2d\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\u3002 distribution = $( . /etc/os-release ; echo $ID$VERSION_ID ) curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - curl -s -L https://nvidia.github.io/nvidia-docker/ $distribution /nvidia- docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit sudo systemctl restart docker \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u6784\u5efa\u6211\u4eec\u7684\u5bb9\u5668\uff0c\u5e76\u5f00\u59cb\u8bad\u7ec3\u8fc7\u7a0b\uff1a docker run --gpus 1 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:train python3 train.py \u5176\u4e2d\uff0c-gpus 1 \u8868\u793a\u6211\u4eec\u5728 docker \u5bb9\u5668\u4e2d\u4f7f\u7528 1 \u4e2a GPU\uff0c-v \u8868\u793a\u6302\u8f7d\u5377\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u8981\u5c06\u672c\u5730\u76ee\u5f55 /home/abhishek/workspace/approaching_almost/input/ \u6302\u8f7d\u5230 docker \u5bb9\u5668\u4e2d\u7684 /home/abhishek/data/\u3002\u8fd9\u4e00\u6b65\u8981\u82b1\u70b9\u65f6\u95f4\uff0c\u4f46\u5b8c\u6210\u540e\uff0c\u672c\u5730\u6587\u4ef6\u5939\u4e2d\u5c31\u4f1a\u6709 model.bin\u3002 \u8fd9\u6837\uff0c\u53ea\u9700\u505a\u4e00\u4e9b\u7b80\u5355\u7684\u6539\u52a8\uff0c\u4f60\u7684\u8bad\u7ec3\u4ee3\u7801\u5c31\u5df2\u7ecf \"dockerized \"\u4e86\u3002\u73b0\u5728\uff0c\u4f60\u53ef\u4ee5\u5728\uff08\u51e0\u4e4e\uff09\u4efb\u4f55\u4f60\u60f3\u8981\u7684\u7cfb\u7edf\u4e0a\u4f7f\u7528\u8fd9\u4e9b\u4ee3\u7801\u8fdb\u884c\u8bad\u7ec3\u3002 \u4e0b\u4e00\u90e8\u5206\u662f\u5c06\u6211\u4eec\u8bad\u7ec3\u597d\u7684\u6a21\u578b \"\u63d0\u4f9b \"\u7ed9\u6700\u7ec8\u7528\u6237\u3002\u5047\u8bbe\u60a8\u60f3\u4ece\u63a5\u6536\u5230\u7684\u63a8\u6587\u6d41\u4e2d\u63d0\u53d6\u60c5\u611f\u4fe1\u606f\u3002\u8981\u5b8c\u6210\u8fd9\u9879\u4efb\u52a1\uff0c\u60a8\u5fc5\u987b\u521b\u5efa\u4e00\u4e2a API\uff0c\u7528\u4e8e\u8f93\u5165\u53e5\u5b50\uff0c\u7136\u540e\u8fd4\u56de\u5e26\u6709\u60c5\u611f\u6982\u7387\u7684\u8f93\u51fa\u3002\u4f7f\u7528 Python \u6784\u5efa API \u7684\u6700\u5e38\u89c1\u65b9\u6cd5\u662f\u4f7f\u7528 Flask \uff0c\u5b83\u662f\u4e00\u4e2a\u5fae\u578b\u7f51\u7edc\u670d\u52a1\u6846\u67b6\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) MODEL = None DEVICE = \"cuda\" def sentence_prediction ( sentence ): tokenizer = config . TOKENIZER max_len = config . MAX_LEN review = str ( sentence ) review = \" \" . join ( review . split ()) inputs = tokenizer . encode_plus ( review , None , add_special_tokens = True , max_length = max_len ) ids = inputs [ \"input_ids\" ] mask = inputs [ \"attention_mask\" ] token_type_ids = inputs [ \"token_type_ids\" ] padding_length = max_len - len ( ids ) ids = ids + ([ 0 ] * padding_length ) mask = mask + ([ 0 ] * padding_length ) token_type_ids = token_type_ids + ([ 0 ] * padding_length ) ids = torch . tensor ( ids , dtype = torch . long ) . unsqueeze ( 0 ) mask = torch . tensor ( mask , dtype = torch . long ) . unsqueeze ( 0 ) token_type_ids = torch . tensor ( token_type_ids , dtype = torch . long ) . unsqueeze ( 0 ) ids = ids . to ( DEVICE , dtype = torch . long ) token_type_ids = token_type_ids . to ( DEVICE , dtype = torch . long ) mask = mask . to ( DEVICE , dtype = torch . long ) outputs = MODEL ( ids = ids , mask = mask , token_type_ids = token_type_ids ) outputs = torch . sigmoid ( outputs ) . cpu () . detach () . numpy () return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): sentence = request . args . get ( \"sentence\" ) start_time = time . time () positive_prediction = sentence_prediction ( sentence ) negative_prediction = 1 - positive_prediction response = {} response [ \"response\" ] = { \"positive\" : str ( positive_prediction ), \"negative\" : str ( negative_prediction ), \"sentence\" : str ( sentence ), \"time_taken\" : str ( time . time () - start_time ), } return flask . jsonify ( response ) if __name__ == \"__main__\" : MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ) )) MODEL . to ( DEVICE ) MODEL . eval () app . run ( host = \"0.0.0.0\" ) \u7136\u540e\u8fd0\u884c \"python api.py \"\u547d\u4ee4\u542f\u52a8 API\u3002API \u5c06\u5728\u7aef\u53e3 5000 \u7684 localhost \u4e0a\u542f\u52a8\u3002cURL \u8bf7\u6c42\u53ca\u5176\u54cd\u5e94\u793a\u4f8b\u5982\u4e0b\u3002 \u276f curl $'http://192.168.86.48:5000/predict?sentence=this%20is%20the%20best%20boo k%20ever' {\"response\":{\"negative\":\"0.0032927393913269043\",\"positive\":\"0.99670726\",\" sentence\":\"this is the best book ever\",\"time_taken\":\"0.029126882553100586\"}} \u53ef\u4ee5\u770b\u5230\uff0c\u6211\u4eec\u5f97\u5230\u7684\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002 \u60a8\u8fd8\u53ef\u4ee5\u8bbf\u95ee http://127.0.0.1:5000/predict?sentence=this%20book%20is%20too%20complicated%20for%20me\u3002\u8fd9\u5c06\u518d\u6b21\u8fd4\u56de\u4e00\u4e2a JSON \u6587\u4ef6\u3002 { response : { negative : \"0.8646619468927383\" , positive : \"0.13533805\" , sentence : \"this book is too complicated for me\" , time_taken : \"0.03852701187133789\" } } \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\uff0c\u53ef\u4ee5\u7528\u6765\u4e3a\u5c11\u91cf\u7528\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u4e3a\u4ec0\u4e48\u662f\u5c11\u91cf\uff1f\u56e0\u4e3a\u8fd9\u4e2a API \u4e00\u6b21\u53ea\u670d\u52a1\u4e00\u4e2a\u8bf7\u6c42\u3002gunicorn \u662f UNIX \u4e0a\u7684 Python WSGI HTTP \u670d\u52a1\u5668\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528\u5b83\u7684 CPU \u6765\u5904\u7406\u591a\u4e2a\u5e76\u884c\u8bf7\u6c42\u3002Gunicorn \u53ef\u4ee5\u4e3a API \u521b\u5efa\u591a\u4e2a\u8fdb\u7a0b\uff0c\u56e0\u6b64\u6211\u4eec\u53ef\u4ee5\u540c\u65f6\u4e3a\u591a\u4e2a\u5ba2\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 \"pip install gunicorn \"\u5b89\u88c5 gunicorn\u3002 \u4e3a\u4e86\u5c06\u4ee3\u7801\u8f6c\u6362\u4e3a\u4e0e gunicorn \u517c\u5bb9\uff0c\u6211\u4eec\u9700\u8981\u79fb\u9664 init main\uff0c\u5e76\u5c06\u5176\u4e2d\u7684\u6240\u6709\u5185\u5bb9\u79fb\u81f3\u5168\u5c40\u8303\u56f4\u3002\u6b64\u5916\uff0c\u6211\u4eec\u73b0\u5728\u4f7f\u7528\u7684\u662f CPU \u800c\u4e0d\u662f GPU\u3002\u4fee\u6539\u540e\u7684\u4ee3\u7801\u5982\u4e0b\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) DEVICE = \"cpu\" MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ))) MODEL . to ( DEVICE ) MODEL . eval () def sentence_prediction ( sentence ): return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): return flask . jsonify ( response ) \u6211\u4eec\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002 gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8fd9\u610f\u5473\u7740\u6211\u4eec\u5728\u63d0\u4f9b\u7684 IP \u5730\u5740\u548c\u7aef\u53e3\u4e0a\u4f7f\u7528 4 \u4e2a Worker \u8fd0\u884c\u6211\u4eec\u7684 flask \u5e94\u7528\u7a0b\u5e8f\u3002\u7531\u4e8e\u6709 4 \u4e2a Worker\uff0c\u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u540c\u65f6\u5904\u7406 4 \u4e2a\u8bf7\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u73b0\u5728\u6211\u4eec\u7684\u7ec8\u7aef\u4f7f\u7528\u7684\u662f CPU\uff0c\u56e0\u6b64\u4e0d\u9700\u8981 GPU \u673a\u5668\uff0c\u53ef\u4ee5\u5728\u4efb\u4f55\u6807\u51c6\u670d\u52a1\u5668/\u865a\u62df\u673a\u4e0a\u8fd0\u884c\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u8fd8\u6709\u4e00\u4e2a\u95ee\u9898\uff1a\u6211\u4eec\u5df2\u7ecf\u5728\u672c\u5730\u673a\u5668\u4e0a\u5b8c\u6210\u4e86\u6240\u6709\u5de5\u4f5c\uff0c\u56e0\u6b64\u5fc5\u987b\u5c06\u5176\u575e\u5316\u3002\u770b\u770b\u4e0b\u9762\u8fd9\u4e2a\u672a\u6ce8\u91ca\u7684 Dockerfile\uff0c\u5b83\u53ef\u4ee5\u7528\u6765\u90e8\u7f72\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002\u8bf7\u6ce8\u610f\u7528\u4e8e\u57f9\u8bad\u7684\u65e7 Dockerfile \u548c\u8fd9\u4e2a Dockerfile \u4e4b\u95f4\u7684\u533a\u522b\u3002\u533a\u522b\u4e0d\u5927\u3002 # CPU Dockerfile FROM ubuntu:18.04 RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* RUN useradd -m abhishek RUN chown -R abhishek:abhishek /home/abhishek/ COPY --chown = abhishek *.* /home/abhishek/app/ USER abhishek RUN mkdir /home/abhishek/data/ RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt RUN pip3 install mkl WORKDIR /home/abhishek/app \u8ba9\u6211\u4eec\u6784\u5efa\u4e00\u4e2a\u65b0\u7684 Docker \u5bb9\u5668\u3002 docker build -f Dockerfile -t bert:api \u5f53 Docker \u5bb9\u5668\u6784\u5efa\u5b8c\u6210\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u76f4\u63a5\u8fd0\u884c API \u4e86\u3002 docker run -p 5000 :5000 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:api /home/abhishek/.local/bin/gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5c06\u5bb9\u5668\u5185\u7684 5000 \u7aef\u53e3\u66b4\u9732\u7ed9\u5bb9\u5668\u5916\u7684 5000 \u7aef\u53e3\u3002\u5982\u679c\u4f7f\u7528 docker-compose\uff0c\u4e5f\u53ef\u4ee5\u5f88\u597d\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002Dockercompose \u662f\u4e00\u4e2a\u53ef\u4ee5\u8ba9\u4f60\u540c\u65f6\u5728\u4e0d\u540c\u6216\u76f8\u540c\u5bb9\u5668\u4e2d\u8fd0\u884c\u4e0d\u540c\u670d\u52a1\u7684\u5de5\u5177\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528 \"pip install docker-compose \"\u5b89\u88c5 docker-compose\uff0c\u7136\u540e\u5728\u6784\u5efa\u5bb9\u5668\u540e\u8fd0\u884c \"docker-compose up\"\u3002\u8981\u4f7f\u7528 docker-compose\uff0c\u4f60\u9700\u8981\u4e00\u4e2a docker-compose.yml \u6587\u4ef6\u3002 # docker-compose.yml # specify a version of the compose version : '3.7' # you can add multiple services services : # specify service name. we call our service: api api : # specify image name image : bert:api # the command that you would like to run inside the container command : /home/abhishek/.local/bin/gunicorn api:app --bind 0.0.0.0:5000 --workers 4 # mount the volume volumes : - /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ # this ensures that our ports from container will be # exposed as it is network_mode : host \u73b0\u5728\uff0c\u60a8\u53ea\u9700\u4f7f\u7528\u4e0a\u8ff0\u547d\u4ee4\u5373\u53ef\u91cd\u65b0\u8fd0\u884c API\uff0c\u5176\u8fd0\u884c\u65b9\u5f0f\u4e0e\u4e4b\u524d\u76f8\u540c\u3002\u606d\u559c\u4f60\uff0c\u73b0\u5728\uff0c\u4f60\u4e5f\u5df2\u7ecf\u6210\u529f\u5730\u5c06\u9884\u6d4b API \u8fdb\u884c\u4e86 Docker \u5316\uff0c\u53ef\u4ee5\u968f\u65f6\u968f\u5730\u90e8\u7f72\u4e86\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u5b66\u4e60\u4e86 Docker\u3001\u4f7f\u7528 flask \u6784\u5efa API\u3001\u4f7f\u7528 gunicorn \u548c Docker \u670d\u52a1 API \u4ee5\u53ca docker-compose\u3002\u5173\u4e8e docker \u7684\u77e5\u8bc6\u8fdc\u4e0d\u6b62\u8fd9\u4e9b\uff0c\u4f46\u8fd9\u5e94\u8be5\u662f\u4e00\u4e2a\u5f00\u59cb\u3002\u5176\u4ed6\u5185\u5bb9\u53ef\u4ee5\u5728\u5b66\u4e60\u8fc7\u7a0b\u4e2d\u9010\u6e10\u638c\u63e1\u3002 \u6211\u4eec\u8fd8\u8df3\u8fc7\u4e86\u8bb8\u591a\u5de5\u5177\uff0c\u5982 kubernetes\u3001bean-stalk\u3001sagemaker\u3001heroku \u548c\u8bb8\u591a\u5176\u4ed6\u5de5\u5177\uff0c\u8fd9\u4e9b\u5de5\u5177\u5982\u4eca\u88ab\u4eba\u4eec\u7528\u6765\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u6a21\u578b\u3002\"\u6211\u8981\u5199\u4ec0\u4e48\uff1f\u70b9\u51fb\u4fee\u6539\u56fe X \u4e2d\u7684 docker \u5bb9\u5668\"\uff1f\u5728\u4e66\u4e2d\u63cf\u8ff0\u8fd9\u4e9b\u662f\u4e0d\u53ef\u884c\u7684\uff0c\u4e5f\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u6240\u4ee5\u6211\u5c06\u4f7f\u7528\u4e0d\u540c\u7684\u5a92\u4ecb\u6765\u8d5e\u7f8e\u672c\u4e66\u7684\u8fd9\u4e00\u90e8\u5206\u3002\u8bf7\u8bb0\u4f4f\uff0c\u4e00\u65e6\u4f60\u5bf9\u5e94\u7528\u7a0b\u5e8f\u8fdb\u884c\u4e86 Docker \u5316\uff0c\u4f7f\u7528\u8fd9\u4e9b\u6280\u672f/\u5e73\u53f0\u8fdb\u884c\u90e8\u7f72\u5c31\u53d8\u5f97\u6613\u5982\u53cd\u638c\u4e86\u3002\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u8981\u8ba9\u4f60\u7684\u4ee3\u7801\u548c\u6a21\u578b\u5bf9\u4ed6\u4eba\u53ef\u7528\uff0c\u5e76\u505a\u597d\u6587\u6863\u8bb0\u5f55\uff0c\u8fd9\u6837\u4efb\u4f55\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u4f60\u5f00\u53d1\u7684\u4e1c\u897f\uff0c\u800c\u65e0\u9700\u591a\u6b21\u8be2\u95ee\u4f60\u3002\u8fd9\u4e0d\u4ec5\u80fd\u8282\u7701\u60a8\u7684\u65f6\u95f4\uff0c\u8fd8\u80fd\u8282\u7701\u4ed6\u4eba\u7684\u65f6\u95f4\u3002\u597d\u7684\u3001\u5f00\u6e90\u7684\u3001\u53ef\u91cd\u590d\u4f7f\u7528\u7684\u4ee3\u7801\u5728\u60a8\u7684\u4f5c\u54c1\u96c6\u4e2d\u4e5f\u975e\u5e38\u91cd\u8981\u3002","title":"\u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5"},{"location":"%E5%8F%AF%E9%87%8D%E5%A4%8D%E4%BB%A3%E7%A0%81%E5%92%8C%E6%A8%A1%E5%9E%8B%E6%96%B9%E6%B3%95/#_1","text":"\u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u5230\u4e86\u53ef\u4ee5\u5c06\u6a21\u578b/\u8bad\u7ec3\u4ee3\u7801\u5206\u53d1\u7ed9\u4ed6\u4eba\u4f7f\u7528\u7684\u9636\u6bb5\u3002\u60a8\u53ef\u4ee5\u7528\u8f6f\u76d8\u5206\u53d1\u6216\u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\uff0c\u4f46\u8fd9\u5e76\u4e0d\u7406\u60f3\u3002\u662f\u8fd9\u6837\u5417\uff1f\u4e5f\u8bb8\u5f88\u591a\u5e74\u524d\uff0c\u8fd9\u662f\u7406\u60f3\u7684\u505a\u6cd5\uff0c\u4f46\u73b0\u5728\u4e0d\u662f\u4e86\u3002 \u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\u548c\u534f\u4f5c\u7684\u9996\u9009\u65b9\u5f0f\u662f\u4f7f\u7528\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u3002Git \u662f\u6700\u6d41\u884c\u7684\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u4e4b\u4e00\u3002\u90a3\u4e48\uff0c\u5047\u8bbe\u4f60\u5df2\u7ecf\u5b66\u4f1a\u4e86 Git\uff0c\u5e76\u6b63\u786e\u5730\u683c\u5f0f\u5316\u4e86\u4ee3\u7801\uff0c\u7f16\u5199\u4e86\u9002\u5f53\u7684\u6587\u6863\uff0c\u8fd8\u5f00\u6e90\u4e86\u4f60\u7684\u9879\u76ee\u3002\u8fd9\u5c31\u591f\u4e86\u5417\uff1f\u4e0d\uff0c\u8fd8\u4e0d\u591f\u3002\u56e0\u4e3a\u4f60\u5728\u81ea\u5df1\u7684\u7535\u8111\u4e0a\u5199\u7684\u4ee3\u7801\uff0c\u5728\u522b\u4eba\u7684\u7535\u8111\u4e0a\u53ef\u80fd\u4f1a\u56e0\u4e3a\u5404\u79cd\u539f\u56e0\u800c\u65e0\u6cd5\u8fd0\u884c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u5728\u53d1\u5e03\u4ee3\u7801\u65f6\u80fd\u590d\u5236\u81ea\u5df1\u7684\u7535\u8111\uff0c\u800c\u5176\u4ed6\u4eba\u5728\u5b89\u88c5\u60a8\u7684\u8f6f\u4ef6\u6216\u8fd0\u884c\u60a8\u7684\u4ee3\u7801\u65f6\u4e5f\u80fd\u590d\u5236\u60a8\u7684\u7535\u8111\uff0c\u90a3\u5c31\u518d\u597d\u4e0d\u8fc7\u4e86\u3002\u4e3a\u6b64\uff0c\u5982\u4eca\u6700\u6d41\u884c\u7684\u65b9\u6cd5\u662f\u4f7f\u7528 Docker \u5bb9\u5668\uff08Docker Containers\uff09\u3002\u8981\u4f7f\u7528 Docker \u5bb9\u5668\uff0c\u4f60\u9700\u8981\u5b89\u88c5 Docker\u3002 \u8ba9\u6211\u4eec\u7528\u4e0b\u9762\u7684\u547d\u4ee4\u6765\u5b89\u88c5 Docker\u3002 sudo apt install docker.io sudo systemctl start docker sudo systemctl enable docker sudo groupadd docker sudo usermod -aG docker $USER \u8fd9\u4e9b\u547d\u4ee4\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e0a\u8fd0\u884c\u3002Docker \u6700\u68d2\u7684\u5730\u65b9\u5728\u4e8e\u5b83\u53ef\u4ee5\u5b89\u88c5\u5728\u4efb\u4f55\u673a\u5668\u4e0a\uff1a Linux\u3001Windows\u3001OSX\u3002\u56e0\u6b64\uff0c\u5982\u679c\u4f60\u4e00\u76f4\u5728 Docker \u5bb9\u5668\u4e2d\u5de5\u4f5c\uff0c\u54ea\u53f0\u673a\u5668\u90fd\u6ca1\u5173\u7cfb\uff01 Docker \u5bb9\u5668\u53ef\u4ee5\u88ab\u89c6\u4e3a\u5c0f\u578b\u865a\u62df\u673a\u3002\u4f60\u53ef\u4ee5\u4e3a\u4f60\u7684\u4ee3\u7801\u521b\u5efa\u4e00\u4e2a\u5bb9\u5668\uff0c\u7136\u540e\u6bcf\u4e2a\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u548c\u8bbf\u95ee\u5b83\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u521b\u5efa\u53ef\u7528\u4e8e\u8bad\u7ec3\u6a21\u578b\u7684\u5bb9\u5668\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e00\u7ae0\u4e2d\u8bad\u7ec3\u7684 BERT \u6a21\u578b\uff0c\u5e76\u5c1d\u8bd5\u5c06\u8bad\u7ec3\u4ee3\u7801\u5bb9\u5668\u5316\u3002 \u9996\u5148\uff0c\u4f60\u9700\u8981\u4e00\u4e2a\u5305\u542b python \u9879\u76ee\u9700\u6c42\u7684\u6587\u4ef6\u3002\u9700\u6c42\u5305\u542b\u5728\u540d\u4e3a requirements.txt \u7684\u6587\u4ef6\u4e2d\u3002\u6587\u4ef6\u540d\u662f thestandard\u3002\u8be5\u6587\u4ef6\u5305\u542b\u9879\u76ee\u4e2d\u4f7f\u7528\u7684\u6240\u6709 python \u5e93\u3002\u4e5f\u5c31\u662f\u53ef\u4ee5\u901a\u8fc7 PyPI (pip) \u4e0b\u8f7d\u7684 python \u5e93\u3002\u7528\u4e8e \u8bad\u7ec3 BERT \u6a21\u578b\u4ee5\u68c0\u6d4b\u6b63/\u8d1f\u60c5\u611f\uff0c\u6211\u4eec\u4f7f\u7528\u4e86 torch\u3001transformers\u3001tqdm\u3001scikit-learn\u3001pandas \u548c numpy\u3002 \u8ba9\u6211\u4eec\u628a\u5b83\u4eec\u5199\u5165 requirements.txt \u4e2d\u3002\u4f60\u53ef\u4ee5\u53ea\u5199\u540d\u79f0\uff0c\u4e5f\u53ef\u4ee5\u5305\u62ec\u7248\u672c\u3002\u5305\u542b\u7248\u672c\u603b\u662f\u6700\u597d\u7684\uff0c\u8fd9\u4e5f\u662f\u4f60\u5e94\u8be5\u505a\u7684\u3002\u5305\u542b\u7248\u672c\u540e\uff0c\u53ef\u4ee5\u786e\u4fdd\u5176\u4ed6\u4eba\u4f7f\u7528\u7684\u7248\u672c\u4e0e\u4f60\u7684\u7248\u672c\u76f8\u540c\uff0c\u800c\u4e0d\u662f\u6700\u65b0\u7248\u672c\uff0c\u56e0\u4e3a\u6700\u65b0\u7248\u672c\u53ef\u80fd\u4f1a\u66f4\u6539\u67d0\u4e9b\u5185\u5bb9\uff0c\u5982\u679c\u662f\u8fd9\u6837\u7684\u8bdd\uff0c\u6a21\u578b\u7684\u8bad\u7ec3\u65b9\u5f0f\u5c31\u4e0d\u4f1a\u4e0e\u4f60\u7684\u76f8\u540c\u4e86\u3002 \u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u663e\u793a\u4e86 requirements.txt\u3002 # requirements.txt pandas == 1.0.4 scikit - learn == 0.22.1 torch == 1.5.0 transformers == 2.11.0 \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a Dockerfile \u7684 Docker \u6587\u4ef6\u3002\u6ca1\u6709\u6269\u5c55\u540d\u3002Dockerfile \u6709\u51e0\u4e2a\u5143\u7d20\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 # Dockerfile # First of all, we include where we are getting the image # from. Image can be thought of as an operating system. # You can do \"FROM ubuntu:18.04\" # this will start from a clean ubuntu 18.04 image. # All images are downloaded from dockerhub # Here are we grabbing image from nvidia's repo # they created a docker image using ubuntu 18.04 # and installed cuda 10.1 and cudnn7 in it. Thus, we don't have to # install it. Makes our life easy. FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 # this is the same apt-get command that you are used to # except the fact that, we have -y argument. Its because # when we build this container, we cannot press Y when asked for RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* # We add a new user called \"abhishek\" # this can be anything. Anything you want it # to be. Usually, we don't use our own name, # you can use \"user\" or \"ubuntu\" RUN useradd -m abhishek # make our user own its own home directory RUN chown -R abhishek:abhishek /home/abhishek/ # copy all files from this direrctory to a # directory called app inside the home of abhishek # and abhishek owns it. COPY --chown = abhishek *.* /home/abhishek/app/ # change to user abhishek USER abhishek RUN mkdir /home/abhishek/data/ # Now we install all the requirements # after moving to the app directory # PLEASE NOTE that ubuntu 18.04 image # has python 3.6.9 and not python 3.7.6 # you can also install conda python here and use that # however, to simplify it, I will be using python 3.6.9 # inside the docker container!!!! RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt # install mkl. its needed for transformers RUN pip3 install mkl # when we log into the docker container, # we will go inside this directory automatically WORKDIR /home/abhishek/app \u521b\u5efa\u597d Docker \u6587\u4ef6\u540e\uff0c\u6211\u4eec\u5c31\u9700\u8981\u6784\u5efa\u5b83\u3002\u6784\u5efa Docker \u5bb9\u5668\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u547d\u4ee4\u3002 docker build -f Dockerfile -t bert:train . \u8be5\u547d\u4ee4\u6839\u636e\u63d0\u4f9b\u7684 Dockerfile \u6784\u5efa\u4e00\u4e2a\u5bb9\u5668\u3002Docker \u5bb9\u5668\u7684\u540d\u79f0\u662f bert:train\u3002\u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a \u276f docker build -f Dockerfile -t bert:train . Sending build context to Docker daemon 19.97kB Step 1/7 : FROM nvidia/cuda:10.1-cudnn7-ubuntu18.04 ---> 3b55548ae91f Step 2/7 : RUN apt-get update && apt-get install -y git curl ca- certificates python3 python3-pip sudo && rm -rf /var/lib/apt/lists/* . . . . Removing intermediate container 8f6975dd08ba ---> d1802ac9f1b4 Step 7/7 : WORKDIR /home/abhishek/app ---> Running in 257ff09502ed Removing intermediate container 257ff09502ed ---> e5f6eb4cddd7 Successfully built e5f6eb4cddd7 Successfully tagged bert:train \u8bf7\u6ce8\u610f\uff0c\u6211\u5220\u9664\u4e86\u8f93\u51fa\u4e2d\u7684\u8bb8\u591a\u884c\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u767b\u5f55\u5bb9\u5668\u3002 docker run -ti bert:train /bin/bash \u4f60\u9700\u8981\u8bb0\u4f4f\uff0c\u4e00\u65e6\u9000\u51fa shell\uff0c\u4f60\u5728 shell \u4e2d\u6240\u505a\u7684\u4e00\u5207\u90fd\u5c06\u4e22\u5931\u3002\u4f60\u8fd8\u53ef\u4ee5\u5728 Docker \u5bb9\u5668\u4e2d\u4f7f\u7528\u3002 docker run -ti bert:train python3 train.py \u8f93\u51fa\u60c5\u51b5\uff1a Traceback (most recent call last): File \"train.py\", line 2, in import config File \"/home/abhishek/app/config.py\", line 28, in do_lower_case=True File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 393, in from_pretrained return cls._from_pretrained(*inputs, **kwargs) File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 496, in _from_pretrained list(cls.vocab_files_names.values()), OSError: Model name '../input/bert_base_uncased/' was not found in tokenizers model name list (bert-base-uncased, bert-large-uncased, bert- base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base- multilingual-cased, bert-base-chinese, bert-base-german-cased, bert- large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased- whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert- base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base- finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). We assumed '../input/bert_base_uncased/' was a path, a model identifier, or url to a directory containing vocabulary files named ['vocab.txt'] but couldn't find such vocabulary files at this path or url. \u54ce\u5440\uff0c\u51fa\u9519\u4e86\uff01 \u6211\u4e3a\u4ec0\u4e48\u8981\u628a\u9519\u8bef\u5370\u5728\u4e66\u4e0a\u5462\uff1f \u56e0\u4e3a\u7406\u89e3\u8fd9\u4e2a\u9519\u8bef\u975e\u5e38\u91cd\u8981\u3002\u8fd9\u4e2a\u9519\u8bef\u8bf4\u660e\u4ee3\u7801\u65e0\u6cd5\u627e\u5230\u76ee\u5f55\".../input/bert_base_cased\"\u3002\u4e3a\u4ec0\u4e48\u4f1a\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u5462\uff1f\u6211\u4eec\u53ef\u4ee5\u5728\u6ca1\u6709 Docker \u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u8bad\u7ec3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u76ee\u5f55\u548c\u6240\u6709\u6587\u4ef6\u90fd\u5b58\u5728\u3002\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u662f\u56e0\u4e3a Docker \u5c31\u50cf\u4e00\u4e2a\u865a\u62df\u673a\uff01\u5b83\u6709\u81ea\u5df1\u7684\u6587\u4ef6\u7cfb\u7edf\uff0c\u672c\u5730\u673a\u5668\u4e0a\u7684\u6587\u4ef6\u4e0d\u4f1a\u5171\u4eab\u7ed9 Docker \u5bb9\u5668\u3002\u5982\u679c\u4f60\u60f3\u4f7f\u7528\u672c\u5730\u673a\u5668\u4e0a\u7684\u8def\u5f84\u5e76\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\uff0c\u4f60\u9700\u8981\u5728\u8fd0\u884c Docker \u65f6\u5c06\u5176\u6302\u8f7d\u5230 Docker \u5bb9\u5668\u4e0a\u3002\u5f53\u6211\u4eec\u67e5\u770b\u8fd9\u4e2a\u6587\u4ef6\u5939\u7684\u8def\u5f84\u65f6\uff0c\u6211\u4eec\u77e5\u9053\u5b83\u4f4d\u4e8e\u540d\u4e3a input \u7684\u6587\u4ef6\u5939\u7684\u4e0a\u4e00\u7ea7\u3002\u8ba9\u6211\u4eec\u7a0d\u5fae\u4fee\u6539\u4e00\u4e0b config.py \u6587\u4ef6\uff01 # config.py import os import transformers # fetch home directory # in our docker container, it is # /home/abhishek HOME_DIR = os . path . expanduser ( \"~\" ) # this is the maximum number of tokens in the sentence MAX_LEN = 512 # batch sizes is low because model is huge! TRAIN_BATCH_SIZE = 8 VALID_BATCH_SIZE = 4 # let's train for a maximum of 10 epochs EPOCHS = 10 # define path to BERT model files # Now we assume that all the data is stored inside # /home/abhishek/data BERT_PATH = os . path . join ( HOME_DIR , \"data\" , \"bert_base_uncased\" ) # this is where you want to save the model MODEL_PATH = os . path . join ( HOME_DIR , \"data\" , \"model.bin\" ) # training file TRAINING_FILE = os . path . join ( HOME_DIR , \"data\" , \"imdb.csv\" ) TOKENIZER = transformers . BertTokenizer . from_pretrained ( BERT_PATH , do_lower_case = True ) \u73b0\u5728\uff0c\u4ee3\u7801\u5047\u5b9a\u6240\u6709\u5185\u5bb9\u90fd\u5728\u4e3b\u76ee\u5f55\u4e0b\u540d\u4e3a data \u7684\u6587\u4ef6\u5939\u4e2d\u3002 \u8bf7\u6ce8\u610f\uff0c\u5982\u679c Python \u811a\u672c\u6709\u4efb\u4f55\u6539\u52a8\uff0c\u90fd\u610f\u5473\u7740\u9700\u8981\u91cd\u5efa Docker \u5bb9\u5668\uff01\u56e0\u6b64\uff0c\u6211\u4eec\u91cd\u5efa\u5bb9\u5668\uff0c\u7136\u540e\u91cd\u65b0\u8fd0\u884c Docker \u547d\u4ee4\uff0c\u4f46\u8fd9\u6b21\u8981\u6709\u6240\u6539\u53d8\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u6ca1\u6709\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u8fd9\u4e5f\u662f\u884c\u4e0d\u901a\u7684\u3002\u522b\u62c5\u5fc3\uff0c\u8fd9\u53ea\u662f\u4e00\u4e2a Docker \u5bb9\u5668\u3002\u4f60\u53ea\u9700\u8981\u505a\u4e00\u6b21\u3002\u8981\u5b89\u88c5\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e2d\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\u3002 distribution = $( . /etc/os-release ; echo $ID$VERSION_ID ) curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - curl -s -L https://nvidia.github.io/nvidia-docker/ $distribution /nvidia- docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit sudo systemctl restart docker \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u6784\u5efa\u6211\u4eec\u7684\u5bb9\u5668\uff0c\u5e76\u5f00\u59cb\u8bad\u7ec3\u8fc7\u7a0b\uff1a docker run --gpus 1 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:train python3 train.py \u5176\u4e2d\uff0c-gpus 1 \u8868\u793a\u6211\u4eec\u5728 docker \u5bb9\u5668\u4e2d\u4f7f\u7528 1 \u4e2a GPU\uff0c-v \u8868\u793a\u6302\u8f7d\u5377\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u8981\u5c06\u672c\u5730\u76ee\u5f55 /home/abhishek/workspace/approaching_almost/input/ \u6302\u8f7d\u5230 docker \u5bb9\u5668\u4e2d\u7684 /home/abhishek/data/\u3002\u8fd9\u4e00\u6b65\u8981\u82b1\u70b9\u65f6\u95f4\uff0c\u4f46\u5b8c\u6210\u540e\uff0c\u672c\u5730\u6587\u4ef6\u5939\u4e2d\u5c31\u4f1a\u6709 model.bin\u3002 \u8fd9\u6837\uff0c\u53ea\u9700\u505a\u4e00\u4e9b\u7b80\u5355\u7684\u6539\u52a8\uff0c\u4f60\u7684\u8bad\u7ec3\u4ee3\u7801\u5c31\u5df2\u7ecf \"dockerized \"\u4e86\u3002\u73b0\u5728\uff0c\u4f60\u53ef\u4ee5\u5728\uff08\u51e0\u4e4e\uff09\u4efb\u4f55\u4f60\u60f3\u8981\u7684\u7cfb\u7edf\u4e0a\u4f7f\u7528\u8fd9\u4e9b\u4ee3\u7801\u8fdb\u884c\u8bad\u7ec3\u3002 \u4e0b\u4e00\u90e8\u5206\u662f\u5c06\u6211\u4eec\u8bad\u7ec3\u597d\u7684\u6a21\u578b \"\u63d0\u4f9b \"\u7ed9\u6700\u7ec8\u7528\u6237\u3002\u5047\u8bbe\u60a8\u60f3\u4ece\u63a5\u6536\u5230\u7684\u63a8\u6587\u6d41\u4e2d\u63d0\u53d6\u60c5\u611f\u4fe1\u606f\u3002\u8981\u5b8c\u6210\u8fd9\u9879\u4efb\u52a1\uff0c\u60a8\u5fc5\u987b\u521b\u5efa\u4e00\u4e2a API\uff0c\u7528\u4e8e\u8f93\u5165\u53e5\u5b50\uff0c\u7136\u540e\u8fd4\u56de\u5e26\u6709\u60c5\u611f\u6982\u7387\u7684\u8f93\u51fa\u3002\u4f7f\u7528 Python \u6784\u5efa API \u7684\u6700\u5e38\u89c1\u65b9\u6cd5\u662f\u4f7f\u7528 Flask \uff0c\u5b83\u662f\u4e00\u4e2a\u5fae\u578b\u7f51\u7edc\u670d\u52a1\u6846\u67b6\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) MODEL = None DEVICE = \"cuda\" def sentence_prediction ( sentence ): tokenizer = config . TOKENIZER max_len = config . MAX_LEN review = str ( sentence ) review = \" \" . join ( review . split ()) inputs = tokenizer . encode_plus ( review , None , add_special_tokens = True , max_length = max_len ) ids = inputs [ \"input_ids\" ] mask = inputs [ \"attention_mask\" ] token_type_ids = inputs [ \"token_type_ids\" ] padding_length = max_len - len ( ids ) ids = ids + ([ 0 ] * padding_length ) mask = mask + ([ 0 ] * padding_length ) token_type_ids = token_type_ids + ([ 0 ] * padding_length ) ids = torch . tensor ( ids , dtype = torch . long ) . unsqueeze ( 0 ) mask = torch . tensor ( mask , dtype = torch . long ) . unsqueeze ( 0 ) token_type_ids = torch . tensor ( token_type_ids , dtype = torch . long ) . unsqueeze ( 0 ) ids = ids . to ( DEVICE , dtype = torch . long ) token_type_ids = token_type_ids . to ( DEVICE , dtype = torch . long ) mask = mask . to ( DEVICE , dtype = torch . long ) outputs = MODEL ( ids = ids , mask = mask , token_type_ids = token_type_ids ) outputs = torch . sigmoid ( outputs ) . cpu () . detach () . numpy () return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): sentence = request . args . get ( \"sentence\" ) start_time = time . time () positive_prediction = sentence_prediction ( sentence ) negative_prediction = 1 - positive_prediction response = {} response [ \"response\" ] = { \"positive\" : str ( positive_prediction ), \"negative\" : str ( negative_prediction ), \"sentence\" : str ( sentence ), \"time_taken\" : str ( time . time () - start_time ), } return flask . jsonify ( response ) if __name__ == \"__main__\" : MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ) )) MODEL . to ( DEVICE ) MODEL . eval () app . run ( host = \"0.0.0.0\" ) \u7136\u540e\u8fd0\u884c \"python api.py \"\u547d\u4ee4\u542f\u52a8 API\u3002API \u5c06\u5728\u7aef\u53e3 5000 \u7684 localhost \u4e0a\u542f\u52a8\u3002cURL \u8bf7\u6c42\u53ca\u5176\u54cd\u5e94\u793a\u4f8b\u5982\u4e0b\u3002 \u276f curl $'http://192.168.86.48:5000/predict?sentence=this%20is%20the%20best%20boo k%20ever' {\"response\":{\"negative\":\"0.0032927393913269043\",\"positive\":\"0.99670726\",\" sentence\":\"this is the best book ever\",\"time_taken\":\"0.029126882553100586\"}} \u53ef\u4ee5\u770b\u5230\uff0c\u6211\u4eec\u5f97\u5230\u7684\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002 \u60a8\u8fd8\u53ef\u4ee5\u8bbf\u95ee http://127.0.0.1:5000/predict?sentence=this%20book%20is%20too%20complicated%20for%20me\u3002\u8fd9\u5c06\u518d\u6b21\u8fd4\u56de\u4e00\u4e2a JSON \u6587\u4ef6\u3002 { response : { negative : \"0.8646619468927383\" , positive : \"0.13533805\" , sentence : \"this book is too complicated for me\" , time_taken : \"0.03852701187133789\" } } \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\uff0c\u53ef\u4ee5\u7528\u6765\u4e3a\u5c11\u91cf\u7528\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u4e3a\u4ec0\u4e48\u662f\u5c11\u91cf\uff1f\u56e0\u4e3a\u8fd9\u4e2a API \u4e00\u6b21\u53ea\u670d\u52a1\u4e00\u4e2a\u8bf7\u6c42\u3002gunicorn \u662f UNIX \u4e0a\u7684 Python WSGI HTTP \u670d\u52a1\u5668\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528\u5b83\u7684 CPU \u6765\u5904\u7406\u591a\u4e2a\u5e76\u884c\u8bf7\u6c42\u3002Gunicorn \u53ef\u4ee5\u4e3a API \u521b\u5efa\u591a\u4e2a\u8fdb\u7a0b\uff0c\u56e0\u6b64\u6211\u4eec\u53ef\u4ee5\u540c\u65f6\u4e3a\u591a\u4e2a\u5ba2\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 \"pip install gunicorn \"\u5b89\u88c5 gunicorn\u3002 \u4e3a\u4e86\u5c06\u4ee3\u7801\u8f6c\u6362\u4e3a\u4e0e gunicorn \u517c\u5bb9\uff0c\u6211\u4eec\u9700\u8981\u79fb\u9664 init main\uff0c\u5e76\u5c06\u5176\u4e2d\u7684\u6240\u6709\u5185\u5bb9\u79fb\u81f3\u5168\u5c40\u8303\u56f4\u3002\u6b64\u5916\uff0c\u6211\u4eec\u73b0\u5728\u4f7f\u7528\u7684\u662f CPU \u800c\u4e0d\u662f GPU\u3002\u4fee\u6539\u540e\u7684\u4ee3\u7801\u5982\u4e0b\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) DEVICE = \"cpu\" MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ))) MODEL . to ( DEVICE ) MODEL . eval () def sentence_prediction ( sentence ): return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): return flask . jsonify ( response ) \u6211\u4eec\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002 gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8fd9\u610f\u5473\u7740\u6211\u4eec\u5728\u63d0\u4f9b\u7684 IP \u5730\u5740\u548c\u7aef\u53e3\u4e0a\u4f7f\u7528 4 \u4e2a Worker \u8fd0\u884c\u6211\u4eec\u7684 flask \u5e94\u7528\u7a0b\u5e8f\u3002\u7531\u4e8e\u6709 4 \u4e2a Worker\uff0c\u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u540c\u65f6\u5904\u7406 4 \u4e2a\u8bf7\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u73b0\u5728\u6211\u4eec\u7684\u7ec8\u7aef\u4f7f\u7528\u7684\u662f CPU\uff0c\u56e0\u6b64\u4e0d\u9700\u8981 GPU \u673a\u5668\uff0c\u53ef\u4ee5\u5728\u4efb\u4f55\u6807\u51c6\u670d\u52a1\u5668/\u865a\u62df\u673a\u4e0a\u8fd0\u884c\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u8fd8\u6709\u4e00\u4e2a\u95ee\u9898\uff1a\u6211\u4eec\u5df2\u7ecf\u5728\u672c\u5730\u673a\u5668\u4e0a\u5b8c\u6210\u4e86\u6240\u6709\u5de5\u4f5c\uff0c\u56e0\u6b64\u5fc5\u987b\u5c06\u5176\u575e\u5316\u3002\u770b\u770b\u4e0b\u9762\u8fd9\u4e2a\u672a\u6ce8\u91ca\u7684 Dockerfile\uff0c\u5b83\u53ef\u4ee5\u7528\u6765\u90e8\u7f72\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002\u8bf7\u6ce8\u610f\u7528\u4e8e\u57f9\u8bad\u7684\u65e7 Dockerfile \u548c\u8fd9\u4e2a Dockerfile \u4e4b\u95f4\u7684\u533a\u522b\u3002\u533a\u522b\u4e0d\u5927\u3002 # CPU Dockerfile FROM ubuntu:18.04 RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* RUN useradd -m abhishek RUN chown -R abhishek:abhishek /home/abhishek/ COPY --chown = abhishek *.* /home/abhishek/app/ USER abhishek RUN mkdir /home/abhishek/data/ RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt RUN pip3 install mkl WORKDIR /home/abhishek/app \u8ba9\u6211\u4eec\u6784\u5efa\u4e00\u4e2a\u65b0\u7684 Docker \u5bb9\u5668\u3002 docker build -f Dockerfile -t bert:api \u5f53 Docker \u5bb9\u5668\u6784\u5efa\u5b8c\u6210\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u76f4\u63a5\u8fd0\u884c API \u4e86\u3002 docker run -p 5000 :5000 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:api /home/abhishek/.local/bin/gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5c06\u5bb9\u5668\u5185\u7684 5000 \u7aef\u53e3\u66b4\u9732\u7ed9\u5bb9\u5668\u5916\u7684 5000 \u7aef\u53e3\u3002\u5982\u679c\u4f7f\u7528 docker-compose\uff0c\u4e5f\u53ef\u4ee5\u5f88\u597d\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002Dockercompose \u662f\u4e00\u4e2a\u53ef\u4ee5\u8ba9\u4f60\u540c\u65f6\u5728\u4e0d\u540c\u6216\u76f8\u540c\u5bb9\u5668\u4e2d\u8fd0\u884c\u4e0d\u540c\u670d\u52a1\u7684\u5de5\u5177\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528 \"pip install docker-compose \"\u5b89\u88c5 docker-compose\uff0c\u7136\u540e\u5728\u6784\u5efa\u5bb9\u5668\u540e\u8fd0\u884c \"docker-compose up\"\u3002\u8981\u4f7f\u7528 docker-compose\uff0c\u4f60\u9700\u8981\u4e00\u4e2a docker-compose.yml \u6587\u4ef6\u3002 # docker-compose.yml # specify a version of the compose version : '3.7' # you can add multiple services services : # specify service name. we call our service: api api : # specify image name image : bert:api # the command that you would like to run inside the container command : /home/abhishek/.local/bin/gunicorn api:app --bind 0.0.0.0:5000 --workers 4 # mount the volume volumes : - /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ # this ensures that our ports from container will be # exposed as it is network_mode : host \u73b0\u5728\uff0c\u60a8\u53ea\u9700\u4f7f\u7528\u4e0a\u8ff0\u547d\u4ee4\u5373\u53ef\u91cd\u65b0\u8fd0\u884c API\uff0c\u5176\u8fd0\u884c\u65b9\u5f0f\u4e0e\u4e4b\u524d\u76f8\u540c\u3002\u606d\u559c\u4f60\uff0c\u73b0\u5728\uff0c\u4f60\u4e5f\u5df2\u7ecf\u6210\u529f\u5730\u5c06\u9884\u6d4b API \u8fdb\u884c\u4e86 Docker \u5316\uff0c\u53ef\u4ee5\u968f\u65f6\u968f\u5730\u90e8\u7f72\u4e86\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u5b66\u4e60\u4e86 Docker\u3001\u4f7f\u7528 flask \u6784\u5efa API\u3001\u4f7f\u7528 gunicorn \u548c Docker \u670d\u52a1 API \u4ee5\u53ca docker-compose\u3002\u5173\u4e8e docker \u7684\u77e5\u8bc6\u8fdc\u4e0d\u6b62\u8fd9\u4e9b\uff0c\u4f46\u8fd9\u5e94\u8be5\u662f\u4e00\u4e2a\u5f00\u59cb\u3002\u5176\u4ed6\u5185\u5bb9\u53ef\u4ee5\u5728\u5b66\u4e60\u8fc7\u7a0b\u4e2d\u9010\u6e10\u638c\u63e1\u3002 \u6211\u4eec\u8fd8\u8df3\u8fc7\u4e86\u8bb8\u591a\u5de5\u5177\uff0c\u5982 kubernetes\u3001bean-stalk\u3001sagemaker\u3001heroku \u548c\u8bb8\u591a\u5176\u4ed6\u5de5\u5177\uff0c\u8fd9\u4e9b\u5de5\u5177\u5982\u4eca\u88ab\u4eba\u4eec\u7528\u6765\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u6a21\u578b\u3002\"\u6211\u8981\u5199\u4ec0\u4e48\uff1f\u70b9\u51fb\u4fee\u6539\u56fe X \u4e2d\u7684 docker \u5bb9\u5668\"\uff1f\u5728\u4e66\u4e2d\u63cf\u8ff0\u8fd9\u4e9b\u662f\u4e0d\u53ef\u884c\u7684\uff0c\u4e5f\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u6240\u4ee5\u6211\u5c06\u4f7f\u7528\u4e0d\u540c\u7684\u5a92\u4ecb\u6765\u8d5e\u7f8e\u672c\u4e66\u7684\u8fd9\u4e00\u90e8\u5206\u3002\u8bf7\u8bb0\u4f4f\uff0c\u4e00\u65e6\u4f60\u5bf9\u5e94\u7528\u7a0b\u5e8f\u8fdb\u884c\u4e86 Docker \u5316\uff0c\u4f7f\u7528\u8fd9\u4e9b\u6280\u672f/\u5e73\u53f0\u8fdb\u884c\u90e8\u7f72\u5c31\u53d8\u5f97\u6613\u5982\u53cd\u638c\u4e86\u3002\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u8981\u8ba9\u4f60\u7684\u4ee3\u7801\u548c\u6a21\u578b\u5bf9\u4ed6\u4eba\u53ef\u7528\uff0c\u5e76\u505a\u597d\u6587\u6863\u8bb0\u5f55\uff0c\u8fd9\u6837\u4efb\u4f55\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u4f60\u5f00\u53d1\u7684\u4e1c\u897f\uff0c\u800c\u65e0\u9700\u591a\u6b21\u8be2\u95ee\u4f60\u3002\u8fd9\u4e0d\u4ec5\u80fd\u8282\u7701\u60a8\u7684\u65f6\u95f4\uff0c\u8fd8\u80fd\u8282\u7701\u4ed6\u4eba\u7684\u65f6\u95f4\u3002\u597d\u7684\u3001\u5f00\u6e90\u7684\u3001\u53ef\u91cd\u590d\u4f7f\u7528\u7684\u4ee3\u7801\u5728\u60a8\u7684\u4f5c\u54c1\u96c6\u4e2d\u4e5f\u975e\u5e38\u91cd\u8981\u3002","title":"\u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5"},{"location":"%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB%E5%92%8C%E5%88%86%E5%89%B2%E6%96%B9%E6%B3%95/","text":"\u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5 \u8bf4\u5230\u56fe\u50cf\uff0c\u8fc7\u53bb\u51e0\u5e74\u53d6\u5f97\u4e86\u5f88\u591a\u6210\u5c31\u3002\u8ba1\u7b97\u673a\u89c6\u89c9\u7684\u8fdb\u6b65\u76f8\u5f53\u5feb\uff0c\u611f\u89c9\u8ba1\u7b97\u673a\u89c6\u89c9\u7684\u8bb8\u591a\u95ee\u9898\u73b0\u5728\u90fd\u66f4\u5bb9\u6613\u89e3\u51b3\u4e86\u3002\u968f\u7740\u9884\u8bad\u7ec3\u6a21\u578b\u7684\u51fa\u73b0\u548c\u8ba1\u7b97\u6210\u672c\u7684\u964d\u4f4e\uff0c\u73b0\u5728\u5728\u5bb6\u91cc\u5c31\u80fd\u8f7b\u677e\u8bad\u7ec3\u51fa\u63a5\u8fd1\u6700\u5148\u8fdb\u6c34\u5e73\u7684\u6a21\u578b\uff0c\u89e3\u51b3\u5927\u591a\u6570\u4e0e\u56fe\u50cf\u76f8\u5173\u7684\u95ee\u9898\u3002\u4f46\u662f\uff0c\u56fe\u50cf\u95ee\u9898\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u7c7b\u578b\u3002\u4ece\u4e24\u4e2a\u6216\u591a\u4e2a\u7c7b\u522b\u7684\u6807\u51c6\u56fe\u50cf\u5206\u7c7b\uff0c\u5230\u50cf\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66\u8fd9\u6837\u5177\u6709\u6311\u6218\u6027\u7684\u95ee\u9898\u3002\u6211\u4eec\u4e0d\u4f1a\u5728\u672c\u4e66\u4e2d\u8ba8\u8bba\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66\uff0c\u4f46\u6211\u4eec\u663e\u7136\u4f1a\u5904\u7406\u4e00\u4e9b\u6700\u5e38\u89c1\u7684\u56fe\u50cf\u95ee\u9898\u3002 \u6211\u4eec\u53ef\u4ee5\u5bf9\u56fe\u50cf\u91c7\u7528\u54ea\u4e9b\u4e0d\u540c\u7684\u65b9\u6cd5\uff1f\u56fe\u50cf\u53ea\u4e0d\u8fc7\u662f\u4e00\u4e2a\u6570\u5b57\u77e9\u9635\u3002\u8ba1\u7b97\u673a\u65e0\u6cd5\u50cf\u4eba\u7c7b\u4e00\u6837\u770b\u5230\u56fe\u50cf\u3002\u5b83\u53ea\u80fd\u770b\u5230\u6570\u5b57\uff0c\u8fd9\u5c31\u662f\u56fe\u50cf\u3002\u7070\u5ea6\u56fe\u50cf\u662f\u4e00\u4e2a\u4e8c\u7ef4\u77e9\u9635\uff0c\u6570\u503c\u8303\u56f4\u4ece 0 \u5230 255\u30020 \u4ee3\u8868\u9ed1\u8272\uff0c255 \u4ee3\u8868\u767d\u8272\uff0c\u4ecb\u4e8e\u4e24\u8005\u4e4b\u95f4\u7684\u662f\u5404\u79cd\u7070\u8272\u3002\u4ee5\u524d\uff0c\u5728\u6ca1\u6709\u6df1\u5ea6\u5b66\u4e60\u7684\u65f6\u5019\uff08\u6216\u8005\u8bf4\u6df1\u5ea6\u5b66\u4e60\u8fd8\u4e0d\u6d41\u884c\u7684\u65f6\u5019\uff09\uff0c\u4eba\u4eec\u4e60\u60ef\u4e8e\u67e5\u770b\u50cf\u7d20\u3002\u6bcf\u4e2a\u50cf\u7d20\u90fd\u662f\u4e00\u4e2a\u7279\u5f81\u3002\u4f60\u53ef\u4ee5\u5728 Python \u4e2d\u8f7b\u677e\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u53ea\u9700\u4f7f\u7528 OpenCV \u6216 Python-PIL \u8bfb\u53d6\u7070\u5ea6\u56fe\u50cf\uff0c\u8f6c\u6362\u4e3a numpy \u6570\u7ec4\uff0c\u7136\u540e\u5c06\u77e9\u9635\u5e73\u94fa\uff08\u6241\u5e73\u5316\uff09\u5373\u53ef\u3002\u5982\u679c\u5904\u7406\u7684\u662f RGB \u56fe\u50cf\uff0c\u5219\u9700\u8981\u4e09\u4e2a\u77e9\u9635\uff0c\u800c\u4e0d\u662f\u4e00\u4e2a\u3002\u4f46\u601d\u8def\u662f\u4e00\u6837\u7684\u3002 import numpy as np import matplotlib.pyplot as plt # \u751f\u6210\u4e00\u4e2a 256x256 \u7684\u968f\u673a\u7070\u5ea6\u56fe\u50cf\uff0c\u50cf\u7d20\u503c\u57280\u5230255\u4e4b\u95f4\u968f\u673a\u5206\u5e03 random_image = np . random . randint ( 0 , 256 , ( 256 , 256 )) # \u521b\u5efa\u4e00\u4e2a\u65b0\u7684\u56fe\u50cf\u7a97\u53e3\uff0c\u8bbe\u7f6e\u7a97\u53e3\u5927\u5c0f\u4e3a7x7\u82f1\u5bf8 plt . figure ( figsize = ( 7 , 7 )) # \u663e\u793a\u751f\u6210\u7684\u968f\u673a\u56fe\u50cf # \u4f7f\u7528\u7070\u5ea6\u989c\u8272\u6620\u5c04 (colormap)\uff0c\u8303\u56f4\u4ece0\u5230255 plt . imshow ( random_image , cmap = 'gray' , vmin = 0 , vmax = 255 ) # \u663e\u793a\u56fe\u50cf\u7a97\u53e3 plt . show () \u4e0a\u9762\u7684\u4ee3\u7801\u4f7f\u7528 numpy \u751f\u6210\u4e00\u4e2a\u968f\u673a\u77e9\u9635\u3002\u8be5\u77e9\u9635\u7531 0 \u5230 255\uff08\u5305\u542b\uff09\u7684\u503c\u7ec4\u6210\uff0c\u5927\u5c0f\u4e3a 256x256\uff08\u4e5f\u79f0\u4e3a\u50cf\u7d20\uff09\u3002 \u56fe 1\uff1a\u4e8c\u7ef4\u56fe\u50cf\u9635\u5217\uff08\u5355\u901a\u9053\uff09\u53ca\u5176\u5c55\u5e73\u7248\u672c \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u62fc\u5199\u540e\u7684\u7248\u672c\u53ea\u662f\u4e00\u4e2a\u5927\u5c0f\u4e3a M \u7684\u5411\u91cf\uff0c\u5176\u4e2d M = N * N\uff0c\u5728\u672c\u4f8b\u4e2d\uff0c\u8fd9\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u4e3a 256 * 256 = 65536\u3002 \u73b0\u5728\uff0c\u5982\u679c\u6211\u4eec\u7ee7\u7eed\u5bf9\u6570\u636e\u96c6\u4e2d\u7684\u6240\u6709\u56fe\u50cf\u8fdb\u884c\u5904\u7406\uff0c\u6bcf\u4e2a\u6837\u672c\u5c31\u4f1a\u6709 65536 \u4e2a\u7279\u5f81\u3002\u6211\u4eec\u53ef\u4ee5\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u5feb\u901f\u5efa\u7acb \u51b3\u7b56\u6811\u6a21\u578b\u3001\u968f\u673a\u68ee\u6797\u6a21\u578b\u6216\u57fa\u4e8e SVM \u7684\u6a21\u578b \u3002\u8fd9\u4e9b\u6a21\u578b\u5c06\u57fa\u4e8e\u50cf\u7d20\u503c\uff0c\u5c1d\u8bd5\u5c06\u6b63\u6837\u672c\u4e0e\u8d1f\u6837\u672c\u533a\u5206\u5f00\u6765\uff08\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff09\u3002 \u4f60\u4eec\u4e00\u5b9a\u90fd\u542c\u8bf4\u8fc7\u732b\u4e0e\u72d7\u7684\u95ee\u9898\uff0c\u8fd9\u662f\u4e00\u4e2a\u7ecf\u5178\u7684\u95ee\u9898\u3002\u5982\u679c\u4f60\u4eec\u8fd8\u8bb0\u5f97\uff0c\u5728\u8bc4\u4f30\u6307\u6807\u4e00\u7ae0\u7684\u5f00\u5934\uff0c\u6211\u5411\u4f60\u4eec\u4ecb\u7ecd\u4e86\u4e00\u4e2a\u6c14\u80f8\u56fe\u50cf\u6570\u636e\u96c6\u3002\u90a3\u4e48\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u68c0\u6d4b\u80ba\u90e8\u7684 X \u5149\u56fe\u50cf\u662f\u5426\u5b58\u5728\u6c14\u80f8\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u8fd9\u662f\u4e00\u4e2a\uff08\u5e76\u4e0d\uff09\u7b80\u5355\u7684\u4e8c\u5143\u5206\u7c7b\u3002 \u56fe 2\uff1a\u975e\u6c14\u80f8\u4e0e\u6c14\u80f8 X \u5149\u56fe\u50cf\u5bf9\u6bd4 \u5728\u56fe 2 \u4e2d\uff0c\u60a8\u53ef\u4ee5\u770b\u5230\u975e\u6c14\u80f8\u548c\u6c14\u80f8\u56fe\u50cf\u7684\u5bf9\u6bd4\u3002\u60a8\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\u4e86\uff0c\u5bf9\u4e8e\u4e00\u4e2a\u975e\u4e13\u4e1a\u4eba\u58eb\uff08\u6bd4\u5982\u6211\uff09\u6765\u8bf4\uff0c\u8981\u5728\u8fd9\u4e9b\u56fe\u50cf\u4e2d\u8fa8\u522b\u51fa\u54ea\u4e2a\u662f\u6c14\u80f8\u662f\u76f8\u5f53\u56f0\u96be\u7684\u3002 \u6700\u521d\u7684\u6570\u636e\u96c6\u662f\u5173\u4e8e\u68c0\u6d4b\u6c14\u80f8\u7684\u5177\u4f53\u4f4d\u7f6e\uff0c\u4f46\u6211\u4eec\u5c06\u95ee\u9898\u4fee\u6539\u4e3a\u67e5\u627e\u7ed9\u5b9a\u7684 X \u5149\u56fe\u50cf\u662f\u5426\u5b58\u5728\u6c14\u80f8\u3002\u522b\u62c5\u5fc3\uff0c\u6211\u4eec\u5c06\u5728\u672c\u7ae0\u4ecb\u7ecd\u8fd9\u4e2a\u90e8\u5206\u3002\u6570\u636e\u96c6\u7531 10675 \u5f20\u72ec\u7279\u7684\u56fe\u50cf\u7ec4\u6210\uff0c\u5176\u4e2d 2379 \u5f20\u6709\u6c14\u80f8\uff08\u6ce8\u610f\uff0c\u8fd9\u4e9b\u6570\u5b57\u662f\u7ecf\u8fc7\u6570\u636e\u6e05\u7406\u540e\u5f97\u51fa\u7684\uff0c\u56e0\u6b64\u4e0e\u539f\u59cb\u6570\u636e\u96c6\u4e0d\u7b26\uff09\u3002\u6b63\u5982\u6570\u636e\u79d1\u5b66\u5bb6\u6240\u8bf4\uff1a\u8fd9\u662f\u4e00\u4e2a\u5178\u578b\u7684 \u504f\u659c\u4e8c\u5143\u5206\u7c7b\u6848\u4f8b \u3002\u56e0\u6b64\uff0c\u6211\u4eec\u9009\u62e9 AUC \u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\uff0c\u5e76\u91c7\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u9a8c\u8bc1\u65b9\u6848\u3002 \u60a8\u53ef\u4ee5\u5c06\u7279\u5f81\u6241\u5e73\u5316\uff0c\u7136\u540e\u5c1d\u8bd5\u4e00\u4e9b\u7ecf\u5178\u65b9\u6cd5\uff08\u5982 SVM\u3001RF\uff09\u6765\u8fdb\u884c\u5206\u7c7b\uff0c\u8fd9\u5b8c\u5168\u6ca1\u95ee\u9898\uff0c\u4f46\u5374\u65e0\u6cd5\u8ba9\u60a8\u8fbe\u5230\u6700\u5148\u8fdb\u7684\u6c34\u5e73\u3002\u6b64\u5916\uff0c\u56fe\u50cf\u5927\u5c0f\u4e3a 1024x1024\u3002\u5728\u8fd9\u4e2a\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u4e00\u4e2a\u6a21\u578b\u9700\u8981\u5f88\u957f\u65f6\u95f4\u3002\u4e0d\u7ba1\u600e\u6837\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u5efa\u7acb\u4e00\u4e2a\u7b80\u5355\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002\u7531\u4e8e\u56fe\u50cf\u662f\u7070\u5ea6\u7684\uff0c\u6211\u4eec\u4e0d\u9700\u8981\u8fdb\u884c\u4efb\u4f55\u8f6c\u6362\u3002\u6211\u4eec\u5c06\u628a\u56fe\u50cf\u5927\u5c0f\u8c03\u6574\u4e3a 256x256\uff0c\u4f7f\u5176\u66f4\u5c0f\uff0c\u5e76\u4f7f\u7528\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684 AUC \u4f5c\u4e3a\u8861\u91cf\u6307\u6807\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5b83\u7684\u8868\u73b0\u5982\u4f55\u3002 import os import numpy as np import pandas as pd from PIL import Image from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from tqdm import tqdm # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u6765\u521b\u5efa\u6570\u636e\u96c6 def create_dataset ( training_df , image_dir ): # \u521d\u59cb\u5316\u7a7a\u5217\u8868\u6765\u5b58\u50a8\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c images = [] targets = [] # \u8fed\u4ee3\u5904\u7406\u8bad\u7ec3\u6570\u636e\u96c6\u4e2d\u7684\u6bcf\u4e00\u884c for index , row in tqdm ( training_df . iterrows (), total = len ( training_df ), desc = \"processing images\" ): # \u83b7\u53d6\u56fe\u50cf\u6587\u4ef6\u540d image_id = row [ \"ImageId\" ] # \u6784\u5efa\u5b8c\u6574\u7684\u56fe\u50cf\u6587\u4ef6\u8def\u5f84 image_path = os . path . join ( image_dir , image_id ) # \u6253\u5f00\u56fe\u50cf\u6587\u4ef6\u5e76\u8fdb\u884c\u5927\u5c0f\u8c03\u6574\uff08resize\uff09\u4e3a 256x256 \u50cf\u7d20\uff0c\u4f7f\u7528\u53cc\u7ebf\u6027\u63d2\u503c\uff08BILINEAR\uff09 image = Image . open ( image_path + \".png\" ) image = image . resize (( 256 , 256 ), resample = Image . BILINEAR ) # \u5c06\u56fe\u50cf\u8f6c\u6362\u4e3aNumPy\u6570\u7ec4 image = np . array ( image ) # \u5c06\u56fe\u50cf\u6241\u5e73\u5316\u4e3a\u4e00\u7ef4\u6570\u7ec4\uff0c\u5e76\u5c06\u5176\u6dfb\u52a0\u5230\u56fe\u50cf\u5217\u8868 image = image . ravel () images . append ( image ) # \u5c06\u76ee\u6807\u503c\uff08target\uff09\u6dfb\u52a0\u5230\u76ee\u6807\u5217\u8868 targets . append ( int ( row [ \"target\" ])) # \u5c06\u56fe\u50cf\u5217\u8868\u8f6c\u6362\u4e3aNumPy\u6570\u7ec4 images = np . array ( images ) # \u6253\u5370\u56fe\u50cf\u6570\u7ec4\u7684\u5f62\u72b6 print ( images . shape ) # \u8fd4\u56de\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c return images , targets if __name__ == \"__main__\" : # \u5b9a\u4e49CSV\u6587\u4ef6\u8def\u5f84\u548c\u56fe\u50cf\u6587\u4ef6\u76ee\u5f55\u8def\u5f84 csv_path = \"/home/abhishek/workspace/siim_png/train.csv\" image_path = \"/home/abhishek/workspace/siim_png/train_png/\" # \u4eceCSV\u6587\u4ef6\u52a0\u8f7d\u6570\u636e df = pd . read_csv ( csv_path ) # \u6dfb\u52a0\u4e00\u4e2a\u540d\u4e3a'kfold'\u7684\u5217\uff0c\u5e76\u521d\u59cb\u5316\u4e3a-1 df [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u83b7\u53d6\u76ee\u6807\u503c\uff08target\uff09 y = df . target . values # \u4f7f\u7528\u5206\u5c42KFold\u4ea4\u53c9\u9a8c\u8bc1\u5c06\u6570\u636e\u96c6\u5206\u62105\u6298 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u904d\u5386\u6bcf\u4e2a\u6298\uff08fold\uff09 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f # \u904d\u5386\u6bcf\u4e2a\u6298 for fold_ in range ( 5 ): # \u83b7\u53d6\u8bad\u7ec3\u6570\u636e\u548c\u6d4b\u8bd5\u6570\u636e train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6\u7684\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c xtrain , ytrain = create_dataset ( train_df , image_path ) # \u521b\u5efa\u6d4b\u8bd5\u6570\u636e\u96c6\u7684\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c xtest , ytest = create_dataset ( test_df , image_path ) # \u521d\u59cb\u5316\u4e00\u4e2a\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668 clf = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408\u5206\u7c7b\u5668 clf . fit ( xtrain , ytrain ) # \u4f7f\u7528\u5206\u7c7b\u5668\u5bf9\u6d4b\u8bd5\u6570\u636e\u8fdb\u884c\u9884\u6d4b\uff0c\u5e76\u83b7\u53d6\u6982\u7387\u503c preds = clf . predict_proba ( xtest )[:, 1 ] # \u6253\u5370\u6298\u6570\uff08fold\uff09\u548cAUC\uff08ROC\u66f2\u7ebf\u4e0b\u7684\u9762\u79ef\uff09 print ( f \"FOLD: { fold_ } \" ) print ( f \"AUC = { metrics . roc_auc_score ( ytest , preds ) } \" ) print ( \"\" ) \u5e73\u5747 AUC \u503c\u7ea6\u4e3a 0.72\u3002\u8fd9\u8fd8\u4e0d\u9519\uff0c\u4f46\u6211\u4eec\u5e0c\u671b\u80fd\u505a\u5f97\u66f4\u597d\u3002\u4f60\u53ef\u4ee5\u5c06\u8fd9\u79cd\u65b9\u6cd5\u7528\u4e8e\u56fe\u50cf\uff0c\u8fd9\u4e5f\u662f\u5b83\u5728\u4ee5\u524d\u6700\u5e38\u7528\u7684\u65b9\u6cd5\u3002SVM \u5728\u56fe\u50cf\u6570\u636e\u96c6\u65b9\u9762\u76f8\u5f53\u6709\u540d\u3002\u6df1\u5ea6\u5b66\u4e60\u5df2\u88ab\u8bc1\u660e\u662f\u89e3\u51b3\u6b64\u7c7b\u95ee\u9898\u7684\u6700\u5148\u8fdb\u65b9\u6cd5\uff0c\u56e0\u6b64\u6211\u4eec\u4e0b\u4e00\u6b65\u53ef\u4ee5\u8bd5\u8bd5\u5b83\u3002 \u5173\u4e8e\u6df1\u5ea6\u5b66\u4e60\u7684\u5386\u53f2\u4ee5\u53ca\u8c01\u53d1\u660e\u4e86\u4ec0\u4e48\uff0c\u6211\u5c31\u4e0d\u591a\u8bf4\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b\u6700\u8457\u540d\u7684\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e4b\u4e00 AlexNet\u3002 \u56fe 3\uff1aAlexNet \u67b6\u67849 \u8bf7\u6ce8\u610f\uff0c\u672c\u56fe\u4e2d\u7684\u8f93\u5165\u5927\u5c0f\u4e0d\u662f 224x224 \u800c\u662f 227x227 \u5982\u4eca\uff0c\u4f60\u53ef\u80fd\u4f1a\u8bf4\u8fd9\u53ea\u662f\u4e00\u4e2a\u57fa\u672c\u7684 \u6df1\u5ea6\u5377\u79ef\u795e\u7ecf\u7f51\u7edc \uff0c\u4f46\u5b83\u5374\u662f\u8bb8\u591a\u65b0\u578b\u6df1\u5ea6\u7f51\u7edc\uff08\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff09\u7684\u57fa\u7840\u3002\u6211\u4eec\u770b\u5230\uff0c\u56fe 3 \u4e2d\u7684\u7f51\u7edc\u662f\u4e00\u4e2a\u5177\u6709\u4e94\u4e2a\u5377\u79ef\u5c42\u3001\u4e24\u4e2a\u5bc6\u96c6\u5c42\u548c\u4e00\u4e2a\u8f93\u51fa\u5c42\u7684\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u3002\u6211\u4eec\u770b\u5230\u8fd8\u6709\u6700\u5927\u6c60\u5316\u3002\u8fd9\u662f\u4ec0\u4e48\u610f\u601d\uff1f\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5728\u8fdb\u884c\u6df1\u5ea6\u5b66\u4e60\u65f6\u4f1a\u9047\u5230\u7684\u4e00\u4e9b\u672f\u8bed\u3002 \u56fe 4\uff1a\u56fe\u50cf\u5927\u5c0f\u4e3a 8x8\uff0c\u6ee4\u6ce2\u5668\u5927\u5c0f\u4e3a 3x3\uff0c\u6b65\u957f\u4e3a 2\u3002 \u56fe 4 \u5f15\u5165\u4e86\u4e24\u4e2a\u65b0\u672f\u8bed\uff1a\u6ee4\u6ce2\u5668\u548c\u6b65\u957f\u3002 \u6ee4\u6ce2\u5668 \u662f\u7531\u7ed9\u5b9a\u51fd\u6570\u521d\u59cb\u5316\u7684\u4e8c\u7ef4\u77e9\u9635\uff0c\u7531\u6307\u5b9a\u51fd\u6570\u521d\u59cb\u5316\u3002 Kaiming\u6b63\u6001\u521d\u59cb\u5316 \uff0c\u662f\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u7684\u6700\u4f73\u9009\u62e9\u3002\u8fd9\u662f\u56e0\u4e3a\u5927\u591a\u6570\u73b0\u4ee3\u7f51\u7edc\u90fd\u4f7f\u7528 ReLU \uff08\u6574\u6d41\u7ebf\u6027\u5355\u5143\uff09\u6fc0\u6d3b\u51fd\u6570\uff0c\u9700\u8981\u9002\u5f53\u7684\u521d\u59cb\u5316\u6765\u907f\u514d\u68af\u5ea6\u6d88\u5931\u95ee\u9898\uff08\u68af\u5ea6\u8d8b\u8fd1\u4e8e\u96f6\uff0c\u7f51\u7edc\u6743\u91cd\u4e0d\u53d8\uff09\u3002\u8be5\u6ee4\u6ce2\u5668\u4e0e\u56fe\u50cf\u8fdb\u884c\u5377\u79ef\u3002\u5377\u79ef\u4e0d\u8fc7\u662f\u6ee4\u6ce2\u5668\u4e0e\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u5f53\u524d\u91cd\u53e0\u50cf\u7d20\u4e4b\u95f4\u7684\u5143\u7d20\u76f8\u4e58\u7684\u603b\u548c\u3002\u60a8\u53ef\u4ee5\u5728\u4efb\u4f55\u9ad8\u4e2d\u6570\u5b66\u6559\u79d1\u4e66\u4e2d\u9605\u8bfb\u66f4\u591a\u5173\u4e8e\u5377\u79ef\u7684\u5185\u5bb9\u3002\u6211\u4eec\u4ece\u56fe\u50cf\u7684\u5de6\u4e0a\u89d2\u5f00\u59cb\u5bf9\u6ee4\u955c\u8fdb\u884c\u5377\u79ef\uff0c\u7136\u540e\u6c34\u5e73\u79fb\u52a8\u6ee4\u955c\u3002\u5982\u679c\u79fb\u52a8 1 \u4e2a\u50cf\u7d20\uff0c\u5219\u6b65\u957f\u4e3a 1\uff1b\u5982\u679c\u79fb\u52a8 2 \u4e2a\u50cf\u7d20\uff0c\u5219\u6b65\u957f\u4e3a 2\u3002 \u5373\u4f7f\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u4f8b\u5982\u5728\u95ee\u9898\u548c\u56de\u7b54\u7cfb\u7edf\u4e2d\u9700\u8981\u4ece\u5927\u91cf\u6587\u672c\u8bed\u6599\u4e2d\u7b5b\u9009\u7b54\u6848\u65f6\uff0c\u6b65\u957f\u4e5f\u662f\u4e00\u4e2a\u975e\u5e38\u6709\u7528\u7684\u6982\u5ff5\u3002\u5f53\u6211\u4eec\u5728\u6c34\u5e73\u65b9\u5411\u4e0a\u8d70\u5230\u5c3d\u5934\u65f6\uff0c\u5c31\u4f1a\u4ee5\u540c\u6837\u7684\u6b65\u957f\u5782\u76f4\u5411\u4e0b\u79fb\u52a8\u8fc7\u6ee4\u5668\uff0c\u4ece\u5de6\u4fa7\u5f00\u59cb\u3002\u56fe 4 \u8fd8\u663e\u793a\u4e86\u8fc7\u6ee4\u5668\u79fb\u51fa\u56fe\u50cf\u7684\u60c5\u51b5\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u65e0\u6cd5\u8ba1\u7b97\u5377\u79ef\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u8df3\u8fc7\u5b83\u3002\u5982\u679c\u4e0d\u60f3\u8df3\u8fc7\uff0c\u5219\u9700\u8981\u5bf9\u56fe\u50cf\u8fdb\u884c \u586b\u5145\uff08pad\uff09 \u3002\u8fd8\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5377\u79ef\u4f1a\u51cf\u5c0f\u56fe\u50cf\u7684\u5927\u5c0f\u3002\u586b\u5145\u4e5f\u662f\u4fdd\u6301\u56fe\u50cf\u5927\u5c0f\u4e0d\u53d8\u7684\u4e00\u79cd\u65b9\u6cd5\u3002\u5728\u56fe 4 \u4e2d\uff0c\u4e00\u4e2a 3x3 \u6ee4\u6ce2\u5668\u6b63\u5728\u6c34\u5e73\u548c\u5782\u76f4\u79fb\u52a8\uff0c\u6bcf\u6b21\u79fb\u52a8\u90fd\u4f1a\u5206\u522b\u8df3\u8fc7\u4e24\u5217\u548c\u4e24\u884c\uff08\u5373\u50cf\u7d20\uff09\u3002\u7531\u4e8e\u5b83\u8df3\u8fc7\u4e86\u4e24\u4e2a\u50cf\u7d20\uff0c\u6240\u4ee5\u6b65\u957f = 2\u3002\u56e0\u6b64\u56fe\u50cf\u5927\u5c0f\u4e3a [(8-3) / 2] + 1 = 3.5\u3002\u6211\u4eec\u53d6 3.5 \u7684\u4e0b\u9650\uff0c\u6240\u4ee5\u662f 3x3\u3002\u60a8\u53ef\u4ee5\u5728\u8349\u7a3f\u7eb8\u4e0a\u8fdb\u884c\u5c1d\u8bd5\u3002 \u56fe 5\uff1a\u901a\u8fc7\u586b\u5145\uff0c\u6211\u4eec\u53ef\u4ee5\u63d0\u4f9b\u4e0e\u8f93\u5165\u56fe\u50cf\u5927\u5c0f\u76f8\u540c\u7684\u56fe\u50cf \u6211\u4eec\u53ef\u4ee5\u4ece\u56fe 5 \u4e2d\u770b\u5230\u586b\u5145\u7684\u6548\u679c\u3002\u73b0\u5728\uff0c\u6211\u4eec\u6709\u4e00\u4e2a 3x3 \u7684\u6ee4\u6ce2\u5668\uff0c\u5b83\u4ee5 1 \u7684\u6b65\u957f\u79fb\u52a8\u3002\u539f\u59cb\u56fe\u50cf\u7684\u5927\u5c0f\u4e3a 6x6\uff0c\u6211\u4eec\u6dfb\u52a0\u4e86 1 \u7684 \u586b\u5145 \u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u751f\u6210\u7684\u56fe\u50cf\u5c06\u4e0e\u8f93\u5165\u56fe\u50cf\u5927\u5c0f\u76f8\u540c\uff0c\u5373 6x6\u3002\u5728\u5904\u7406\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u65f6\u53ef\u80fd\u4f1a\u9047\u5230\u7684\u53e6\u4e00\u4e2a\u76f8\u5173\u672f\u8bed\u662f \u81a8\u80c0\uff08dilation\uff09 \uff0c\u5982\u56fe 6 \u6240\u793a\u3002 \u56fe 6\uff1a\u81a8\u80c0\uff08dilation\uff09\u7684\u4f8b\u5b50 \u5728\u81a8\u80c0\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u5c06\u6ee4\u6ce2\u5668\u6269\u5927 N-1\uff0c\u5176\u4e2d N \u662f\u81a8\u80c0\u7387\u7684\u503c\uff0c\u6216\u7b80\u79f0\u4e3a\u81a8\u80c0\u3002\u5728\u8fd9\u79cd\u5e26\u81a8\u80c0\u7684\u5185\u6838\u4e2d\uff0c\u6bcf\u6b21\u5377\u79ef\u90fd\u4f1a\u8df3\u8fc7\u4e00\u4e9b\u50cf\u7d20\u3002\u8fd9\u5728\u5206\u5272\u4efb\u52a1\u4e2d\u5c24\u4e3a\u6709\u6548\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u53ea\u8ba8\u8bba\u4e86\u4e8c\u7ef4\u5377\u79ef\u3002 \u8fd8\u6709\u4e00\u7ef4\u5377\u79ef\u548c\u66f4\u9ad8\u7ef4\u5ea6\u7684\u5377\u79ef\u3002\u5b83\u4eec\u90fd\u57fa\u4e8e\u76f8\u540c\u7684\u57fa\u672c\u6982\u5ff5\u3002 \u63a5\u4e0b\u6765\u662f \u6700\u5927\u6c60\u5316\uff08Max pooling\uff09 \u3002\u6700\u5927\u503c\u6c60\u53ea\u662f\u4e00\u4e2a\u8fd4\u56de\u6700\u5927\u503c\u7684\u6ee4\u6ce2\u5668\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u63d0\u53d6\u7684\u4e0d\u662f\u5377\u79ef\uff0c\u800c\u662f\u50cf\u7d20\u7684\u6700\u5927\u503c\u3002\u540c\u6837\uff0c \u5e73\u5747\u6c60\u5316\uff08average pooling\uff09 \u6216 \u5747\u503c\u6c60\u5316\uff08mean pooling\uff09 \u4f1a\u8fd4\u56de\u50cf\u7d20\u7684\u5e73\u5747\u503c\u3002\u5b83\u4eec\u7684\u4f7f\u7528\u65b9\u6cd5\u4e0e\u5377\u79ef\u6838\u76f8\u540c\u3002\u6c60\u5316\u6bd4\u5377\u79ef\u66f4\u5feb\uff0c\u662f\u4e00\u79cd\u5bf9\u56fe\u50cf\u8fdb\u884c\u7f29\u51cf\u91c7\u6837\u7684\u65b9\u6cd5\u3002\u6700\u5927\u6c60\u5316\u53ef\u68c0\u6d4b\u8fb9\u7f18\uff0c\u5e73\u5747\u6c60\u5316\u53ef\u5e73\u6ed1\u56fe\u50cf\u3002 \u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u548c\u6df1\u5ea6\u5b66\u4e60\u7684\u6982\u5ff5\u592a\u591a\u4e86\u3002\u6211\u6240\u8ba8\u8bba\u7684\u662f\u4e00\u4e9b\u57fa\u7840\u77e5\u8bc6\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5165\u95e8\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u5728 PyTorch \u4e2d\u6784\u5efa\u7b2c\u4e00\u4e2a\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u505a\u597d\u4e86\u5145\u5206\u51c6\u5907\u3002PyTorch \u63d0\u4f9b\u4e86\u4e00\u79cd\u76f4\u89c2\u800c\u7b80\u5355\u7684\u65b9\u6cd5\u6765\u5b9e\u73b0\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff0c\u800c\u4e14\u4f60\u4e0d\u9700\u8981\u5173\u5fc3\u53cd\u5411\u4f20\u64ad\u3002\u6211\u4eec\u7528\u4e00\u4e2a python \u7c7b\u548c\u4e00\u4e2a\u524d\u9988\u51fd\u6570\u6765\u5b9a\u4e49\u7f51\u7edc\uff0c\u544a\u8bc9 PyTorch \u5404\u5c42\u4e4b\u95f4\u5982\u4f55\u8fde\u63a5\u3002\u5728 PyTorch \u4e2d\uff0c\u56fe\u50cf\u7b26\u53f7\u662f BS\u3001C\u3001H\u3001W\uff0c\u5176\u4e2d\uff0cBS \u662f\u6279\u5927\u5c0f\uff0cC \u662f\u901a\u9053\uff0cH \u662f\u9ad8\u5ea6\uff0cW \u662f\u5bbd\u5ea6\u3002\u8ba9\u6211\u4eec\u770b\u770b PyTorch \u662f\u5982\u4f55\u5b9e\u73b0 AlexNet \u7684\u3002 import torch import torch.nn as nn import torch.nn.functional as F class AlexNet ( nn . Module ): def __init__ ( self ): super ( AlexNet , self ) . __init__ () self . conv1 = nn . Conv2d ( in_channels = 3 , out_channels = 96 , kernel_size = 11 , stride = 4 , padding = 0 ) self . pool1 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . conv2 = nn . Conv2d ( in_channels = 96 , out_channels = 256 , kernel_size = 5 , stride = 1 , padding = 2 ) self . pool2 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . conv3 = nn . Conv2d ( in_channels = 256 , out_channels = 384 , kernel_size = 3 , stride = 1 , padding = 1 ) self . conv4 = nn . Conv2d ( in_channels = 384 , out_channels = 384 , kernel_size = 3 , stride = 1 , padding = 1 ) self . conv5 = nn . Conv2d ( in_channels = 384 , out_channels = 256 , kernel_size = 3 , stride = 1 , padding = 1 ) self . pool3 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . fc1 = nn . Linear ( in_features = 9216 , out_features = 4096 ) self . dropout1 = nn . Dropout ( 0.5 ) self . fc2 = nn . Linear ( in_features = 4096 , out_features = 4096 ) self . dropout2 = nn . Dropout ( 0.5 ) self . fc3 = nn . Linear ( in_features = 4096 , out_features = 1000 ) def forward ( self , image ): bs , c , h , w = image . size () x = F . relu ( self . conv1 ( image )) # size: (bs, 96, 55, 55) x = self . pool1 ( x ) # size: (bs, 96, 27, 27) x = F . relu ( self . conv2 ( x )) # size: (bs, 256, 27, 27) x = self . pool2 ( x ) # size: (bs, 256, 13, 13) x = F . relu ( self . conv3 ( x )) # size: (bs, 384, 13, 13) x = F . relu ( self . conv4 ( x )) # size: (bs, 384, 13, 13) x = F . relu ( self . conv5 ( x )) # size: (bs, 256, 13, 13) x = self . pool3 ( x ) # size: (bs, 256, 6, 6) x = x . view ( bs , - 1 ) # size: (bs, 9216) x = F . relu ( self . fc1 ( x )) # size: (bs, 4096) x = self . dropout1 ( x ) # size: (bs, 4096) # dropout does not change size # dropout is used for regularization # 0.3 dropout means that only 70% of the nodes # of the current layer are used for the next layer x = F . relu ( self . fc2 ( x )) # size: (bs, 4096) x = self . dropout2 ( x ) # size: (bs, 4096) x = F . relu ( self . fc3 ( x )) # size: (bs, 1000) # 1000 is number of classes in ImageNet Dataset # softmax is an activation function that converts # linear output to probabilities that add up to 1 # for each sample in the batch x = torch . softmax ( x , axis = 1 ) # size: (bs, 1000) return x \u5982\u679c\u60a8\u6709\u4e00\u5e45 3x227x227 \u7684\u56fe\u50cf\uff0c\u5e76\u5e94\u7528\u4e86\u4e00\u4e2a\u5927\u5c0f\u4e3a 11x11 \u7684\u5377\u79ef\u6ee4\u6ce2\u5668\uff0c\u8fd9\u610f\u5473\u7740\u60a8\u5e94\u7528\u4e86\u4e00\u4e2a\u5927\u5c0f\u4e3a 11x11x3 \u7684\u6ee4\u6ce2\u5668\uff0c\u5e76\u4e0e\u4e00\u4e2a\u5927\u5c0f\u4e3a 227x227x3 \u7684\u56fe\u50cf\u8fdb\u884c\u4e86\u5377\u79ef\u3002\u8f93\u51fa\u901a\u9053\u7684\u6570\u91cf\u5c31\u662f\u5206\u522b\u5e94\u7528\u4e8e\u56fe\u50cf\u7684\u76f8\u540c\u5927\u5c0f\u7684\u4e0d\u540c\u5377\u79ef\u6ee4\u6ce2\u5668\u7684\u6570\u91cf\u3002 \u56e0\u6b64\uff0c\u5728\u7b2c\u4e00\u4e2a\u5377\u79ef\u5c42\u4e2d\uff0c\u8f93\u5165\u901a\u9053\u662f 3\uff0c\u4e5f\u5c31\u662f\u539f\u59cb\u8f93\u5165\uff0c\u5373 R\u3001G\u3001B \u4e09\u901a\u9053\u3002PyTorch \u7684 torchvision \u63d0\u4f9b\u4e86\u8bb8\u591a\u4e0e AlexNet \u7c7b\u4f3c\u7684\u4e0d\u540c\u6a21\u578b\uff0c\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0cAlexNet \u7684\u5b9e\u73b0\u4e0e torchvision \u7684\u5b9e\u73b0\u5e76\u4e0d\u76f8\u540c\u3002Torchvision \u7684 AlexNet \u5b9e\u73b0\u662f\u4ece\u53e6\u4e00\u7bc7\u8bba\u6587\u4e2d\u4fee\u6539\u800c\u6765\u7684 AlexNet\uff1a Krizhevsky, A. One weird trick for parallelizing convolutional neural networks. CoRR, abs/1404.5997, 2014. \u4f60\u53ef\u4ee5\u4e3a\u81ea\u5df1\u7684\u4efb\u52a1\u8bbe\u8ba1\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff0c\u5f88\u591a\u65f6\u5019\uff0c\u4ece\u96f6\u505a\u8d77\u662f\u4e2a\u4e0d\u9519\u7684\u4e3b\u610f\u3002\u8ba9\u6211\u4eec\u6784\u5efa\u4e00\u4e2a\u7f51\u7edc\uff0c\u7528\u4e8e\u533a\u5206\u56fe\u50cf\u6709\u65e0\u6c14\u80f8\u3002\u9996\u5148\uff0c\u8ba9\u6211\u4eec\u51c6\u5907\u4e00\u4e9b\u6587\u4ef6\u3002\u7b2c\u4e00\u6b65\u662f\u521b\u5efa\u4e00\u4e2a\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e\u96c6\uff0c\u5373 train.csv\uff0c\u4f46\u589e\u52a0\u4e00\u5217 kfold\u3002\u6211\u4eec\u5c06\u521b\u5efa\u4e94\u4e2a\u6587\u4ef6\u5939\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u5df2\u7ecf\u6f14\u793a\u4e86\u5982\u4f55\u9488\u5bf9\u4e0d\u540c\u7684\u6570\u636e\u96c6\u521b\u5efa\u6298\u53e0\uff0c\u56e0\u6b64\u6211\u5c06\u8df3\u8fc7\u8fd9\u4e00\u90e8\u5206\uff0c\u7559\u4f5c\u7ec3\u4e60\u3002\u5bf9\u4e8e\u57fa\u4e8e PyTorch \u7684\u795e\u7ecf\u7f51\u7edc\uff0c\u6211\u4eec\u9700\u8981\u521b\u5efa\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\u3002\u6570\u636e\u96c6\u7c7b\u7684\u76ee\u7684\u662f\u8fd4\u56de\u4e00\u4e2a\u6570\u636e\u9879\u6216\u6570\u636e\u6837\u672c\u3002\u8fd9\u4e2a\u6570\u636e\u6837\u672c\u5e94\u8be5\u5305\u542b\u8bad\u7ec3\u6216\u8bc4\u4f30\u6a21\u578b\u6240\u9700\u7684\u6240\u6709\u5185\u5bb9\u3002 import torch import numpy as np from PIL import Image from PIL import ImageFile ImageFile . LOAD_TRUNCATED_IMAGES = True # \u5b9a\u4e49\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\uff0c\u7528\u4e8e\u5904\u7406\u56fe\u50cf\u5206\u7c7b\u4efb\u52a1 class ClassificationDataset : def __init__ ( self , image_paths , targets , resize = None , augmentations = None ): # \u56fe\u50cf\u6587\u4ef6\u8def\u5f84\u5217\u8868 self . image_paths = image_paths # \u76ee\u6807\u6807\u7b7e\u5217\u8868 self . targets = targets # \u56fe\u50cf\u5c3a\u5bf8\u8c03\u6574\u53c2\u6570\uff0c\u53ef\u4ee5\u4e3aNone self . resize = resize # \u6570\u636e\u589e\u5f3a\u51fd\u6570\uff0c\u53ef\u4ee5\u4e3aNone self . augmentations = augmentations def __len__ ( self ): # \u8fd4\u56de\u6570\u636e\u96c6\u7684\u5927\u5c0f\uff0c\u5373\u56fe\u50cf\u6570\u91cf return len ( self . image_paths ) def __getitem__ ( self , item ): # \u83b7\u53d6\u6570\u636e\u96c6\u4e2d\u7684\u4e00\u4e2a\u6837\u672c image = Image . open ( self . image_paths [ item ]) image = image . convert ( \"RGB\" ) # \u5c06\u56fe\u50cf\u8f6c\u6362\u4e3aRGB\u683c\u5f0f # \u83b7\u53d6\u8be5\u6837\u672c\u7684\u76ee\u6807\u6807\u7b7e targets = self . targets [ item ] if self . resize is not None : # \u5982\u679c\u6307\u5b9a\u4e86\u5c3a\u5bf8\u8c03\u6574\u53c2\u6570\uff0c\u5c06\u56fe\u50cf\u8fdb\u884c\u5c3a\u5bf8\u8c03\u6574 image = image . resize (( self . resize [ 1 ], self . resize [ 0 ]), resample = Image . BILINEAR ) image = np . array ( image ) if self . augmentations is not None : # \u5982\u679c\u6307\u5b9a\u4e86\u6570\u636e\u589e\u5f3a\u51fd\u6570\uff0c\u5e94\u7528\u6570\u636e\u589e\u5f3a augmented = self . augmentations ( image = image ) image = augmented [ \"image\" ] # \u5c06\u56fe\u50cf\u901a\u9053\u987a\u5e8f\u8c03\u6574\u4e3a(C, H, W)\u7684\u5f62\u5f0f\uff0c\u5e76\u8f6c\u6362\u4e3afloat32\u7c7b\u578b image = np . transpose ( image , ( 2 , 0 , 1 )) . astype ( np . float32 ) # \u8fd4\u56de\u6837\u672c\uff0c\u5305\u62ec\u56fe\u50cf\u548c\u5bf9\u5e94\u7684\u76ee\u6807\u6807\u7b7e return { \"image\" : torch . tensor ( image , dtype = torch . float ), \"targets\" : torch . tensor ( targets , dtype = torch . long ), } \u73b0\u5728\u6211\u4eec\u9700\u8981 engine.py\u3002engine.py \u5305\u542b\u8bad\u7ec3\u548c\u8bc4\u4f30\u529f\u80fd\u3002\u8ba9\u6211\u4eec\u770b\u770b engine.py \u662f\u5982\u4f55\u7f16\u5199\u7684\u3002 import torch import torch.nn as nn from tqdm import tqdm # \u7528\u4e8e\u8bad\u7ec3\u6a21\u578b\u7684\u51fd\u6570 def train ( data_loader , model , optimizer , device ): # \u5c06\u6a21\u578b\u8bbe\u7f6e\u4e3a\u8bad\u7ec3\u6a21\u5f0f model . train () for data in data_loader : # \u4ece\u6570\u636e\u52a0\u8f7d\u5668\u4e2d\u63d0\u53d6\u8f93\u5165\u56fe\u50cf\u548c\u76ee\u6807\u6807\u7b7e inputs = data [ \"image\" ] targets = data [ \"targets\" ] # \u5c06\u8f93\u5165\u548c\u76ee\u6807\u79fb\u52a8\u5230\u6307\u5b9a\u7684\u8bbe\u5907\uff08\u4f8b\u5982\uff0cGPU\uff09 inputs = inputs . to ( device , dtype = torch . float ) targets = targets . to ( device , dtype = torch . float ) # \u5c06\u4f18\u5316\u5668\u4e2d\u7684\u68af\u5ea6\u5f52\u96f6 optimizer . zero_grad () # \u524d\u5411\u4f20\u64ad\uff1a\u8ba1\u7b97\u6a21\u578b\u9884\u6d4b outputs = model ( inputs ) # \u4f7f\u7528\u5e26\u903b\u8f91\u65af\u8482\u51fd\u6570\u7684\u4e8c\u5143\u4ea4\u53c9\u71b5\u635f\u5931\u8ba1\u7b97\u635f\u5931 loss = nn . BCEWithLogitsLoss ()( outputs , targets . view ( - 1 , 1 )) # \u53cd\u5411\u4f20\u64ad\uff1a\u8ba1\u7b97\u68af\u5ea6\u5e76\u66f4\u65b0\u6a21\u578b\u6743\u91cd loss . backward () optimizer . step () # \u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u7684\u51fd\u6570 def evaluate ( data_loader , model , device ): # \u5c06\u6a21\u578b\u8bbe\u7f6e\u4e3a\u8bc4\u4f30\u6a21\u5f0f\uff08\u4e0d\u8fdb\u884c\u68af\u5ea6\u8ba1\u7b97\uff09 model . eval () # \u521d\u59cb\u5316\u5217\u8868\u4ee5\u5b58\u50a8\u771f\u5b9e\u76ee\u6807\u548c\u6a21\u578b\u9884\u6d4b final_targets = [] final_outputs = [] with torch . no_grad (): for data in data_loader : # \u4ece\u6570\u636e\u52a0\u8f7d\u5668\u4e2d\u63d0\u53d6\u8f93\u5165\u56fe\u50cf\u548c\u76ee\u6807\u6807\u7b7e inputs = data [ \"image\" ] targets = data [ \"targets\" ] # \u5c06\u8f93\u5165\u79fb\u52a8\u5230\u6307\u5b9a\u7684\u8bbe\u5907\uff08\u4f8b\u5982\uff0cGPU\uff09 inputs = inputs . to ( device , dtype = torch . float ) # \u83b7\u53d6\u6a21\u578b\u9884\u6d4b output = model ( inputs ) # \u5c06\u76ee\u6807\u548c\u8f93\u51fa\u8f6c\u6362\u4e3aCPU\u548cPython\u5217\u8868 targets = targets . detach () . cpu () . numpy () . tolist () output = output . detach () . cpu () . numpy () . tolist () # \u5c06\u5217\u8868\u6269\u5c55\u4ee5\u5305\u542b\u6279\u6b21\u6570\u636e final_targets . extend ( targets ) final_outputs . extend ( output ) # \u8fd4\u56de\u6700\u7ec8\u7684\u6a21\u578b\u9884\u6d4b\u548c\u771f\u5b9e\u76ee\u6807 return final_outputs , final_targets \u6709\u4e86 engine.py\uff0c\u5c31\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u65b0\u6587\u4ef6\uff1amodel.py\u3002model.py \u5c06\u5305\u542b\u6211\u4eec\u7684\u6a21\u578b\u3002\u628a\u6a21\u578b\u4e0e\u8bad\u7ec3\u5206\u5f00\u662f\u4e2a\u597d\u4e3b\u610f\uff0c\u56e0\u4e3a\u8fd9\u6837\u6211\u4eec\u5c31\u53ef\u4ee5\u8f7b\u677e\u5730\u8bd5\u9a8c\u4e0d\u540c\u7684\u6a21\u578b\u548c\u4e0d\u540c\u7684\u67b6\u6784\u3002\u540d\u4e3a pretrainedmodels \u7684 PyTorch \u5e93\u4e2d\u6709\u5f88\u591a\u4e0d\u540c\u7684\u6a21\u578b\u67b6\u6784\uff0c\u5982 AlexNet\u3001ResNet\u3001DenseNet \u7b49\u3002\u8fd9\u4e9b\u4e0d\u540c\u7684\u6a21\u578b\u67b6\u6784\u662f\u5728\u540d\u4e3a ImageNet \u7684\u5927\u578b\u56fe\u50cf\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u51fa\u6765\u7684\u3002\u5728 ImageNet \u4e0a\u8bad\u7ec3\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5b83\u4eec\u7684\u6743\u91cd\uff0c\u4e5f\u53ef\u4ee5\u4e0d\u4f7f\u7528\u8fd9\u4e9b\u6743\u91cd\u3002\u5982\u679c\u6211\u4eec\u4e0d\u4f7f\u7528 ImageNet \u6743\u91cd\u8fdb\u884c\u8bad\u7ec3\uff0c\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u7f51\u7edc\u5c06\u4ece\u5934\u5f00\u59cb\u5b66\u4e60\u4e00\u5207\u3002\u8fd9\u5c31\u662f model.py \u7684\u6837\u5b50\u3002 import torch.nn as nn import pretrainedmodels # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u4ee5\u83b7\u53d6\u6a21\u578b def get_model ( pretrained ): if pretrained : # \u4f7f\u7528\u9884\u8bad\u7ec3\u7684 AlexNet \u6a21\u578b\uff0c\u52a0\u8f7d\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u7684\u6743\u91cd model = pretrainedmodels . __dict__ [ \"alexnet\" ]( pretrained = 'imagenet' ) else : # \u4f7f\u7528\u672a\u7ecf\u9884\u8bad\u7ec3\u7684 AlexNet \u6a21\u578b model = pretrainedmodels . __dict__ [ \"alexnet\" ]( pretrained = None ) # \u4fee\u6539\u6a21\u578b\u7684\u6700\u540e\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u4ee5\u9002\u5e94\u7279\u5b9a\u4efb\u52a1 model . last_linear = nn . Sequential ( nn . BatchNorm1d ( 4096 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.25 ), # \u968f\u673a\u5931\u6d3b\u5c42\uff0c\u9632\u6b62\u8fc7\u62df\u5408 nn . Linear ( in_features = 4096 , out_features = 2048 ), # \u8fde\u63a5\u5c42 nn . ReLU (), # ReLU \u6fc0\u6d3b\u51fd\u6570 nn . BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.5 ), # \u968f\u673a\u5931\u6d3b\u5c42 nn . Linear ( in_features = 2048 , out_features = 1 ) # \u6700\u7ec8\u7684\u4e8c\u5143\u5206\u7c7b\u5c42 ) return model \u5982\u679c\u4f60\u6253\u5370\u4e86\u7f51\u7edc\uff0c\u4f1a\u5f97\u5230\u5982\u4e0b\u8f93\u51fa\uff1a AlexNet ( ( avgpool ): AdaptiveAvgPool2d ( output_size = ( 6 , 6 )) ( _features ): Sequential ( ( 0 ): Conv2d ( 3 , 64 , kernel_size = ( 11 , 11 ), stride = ( 4 , 4 ), padding = ( 2 , 2 )) ( 1 ): ReLU ( inplace = True ) ( 2 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , ceil_mode = False ) ( 3 ): Conv2d ( 64 , 192 , kernel_size = ( 5 , 5 ), stride = ( 1 , 1 ), padding = ( 2 , 2 )) ( 4 ): ReLU ( inplace = True ) ( 5 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , ceil_mode = False ) ( 6 ): Conv2d ( 192 , 384 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 7 ): ReLU ( inplace = True ) ( 8 ): Conv2d ( 384 , 256 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 9 ): ReLU ( inplace = True ) ( 10 ): Conv2d ( 256 , 256 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 11 ): ReLU ( inplace = True ) ( 12 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , eil_mode = False )) ( dropout0 ): Dropout ( p = 0.5 , inplace = False ) ( linear0 ): Linear ( in_features = 9216 , out_features = 4096 , bias = True ) ( relu0 ): ReLU ( inplace = True ) ( dropout1 ): Dropout ( p = 0.5 , inplace = False ) ( linear1 ): Linear ( in_features = 4096 , out_features = 4096 , bias = True ) ( relu1 ): ReLU ( inplace = True ) ( last_linear ): Sequential ( ( 0 ): BatchNorm1d ( 4096 , eps = 1e-05 , momentum = 0.1 , affine = True , rack_running_stats = True ) ( 1 ): Dropout ( p = 0.25 , inplace = False ) ( 2 ): Linear ( in_features = 4096 , out_features = 2048 , bias = True ) ( 3 ): ReLU () ( 4 ): BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 , affine = True , track_running_stats = True ) ( 5 ): Dropout ( p = 0.5 , inplace = False ) ( 6 ): Linear ( in_features = 2048 , out_features = 1 , bias = True ) ) ) \u73b0\u5728\uff0c\u4e07\u4e8b\u4ff1\u5907\uff0c\u53ef\u4ee5\u5f00\u59cb\u8bad\u7ec3\u4e86\u3002\u6211\u4eec\u5c06\u4f7f\u7528 train.py \u8bad\u7ec3\u6a21\u578b\u3002 import os import pandas as pd import numpy as np import albumentations import torch from sklearn import metrics from sklearn.model_selection import train_test_split import dataset import engine from model import get_model if __name__ == \"__main__\" : # \u5b9a\u4e49\u6570\u636e\u8def\u5f84\u3001\u8bbe\u5907\u3001\u8fed\u4ee3\u6b21\u6570 data_path = \"/home/abhishek/workspace/siim_png/\" device = \"cuda\" # \u4f7f\u7528GPU\u52a0\u901f epochs = 10 # \u4eceCSV\u6587\u4ef6\u8bfb\u53d6\u6570\u636e df = pd . read_csv ( os . path . join ( data_path , \"train.csv\" )) images = df . ImageId . values . tolist () images = [ os . path . join ( data_path , \"train_png\" , i + \".png\" ) for i in images ] targets = df . target . values # \u83b7\u53d6\u9884\u8bad\u7ec3\u7684\u6a21\u578b model = get_model ( pretrained = True ) model . to ( device ) # \u5b9a\u4e49\u5747\u503c\u548c\u6807\u51c6\u5dee\uff0c\u7528\u4e8e\u6570\u636e\u6807\u51c6\u5316 mean = ( 0.485 , 0.456 , 0.406 ) std = ( 0.229 , 0.224 , 0.225 ) # \u6570\u636e\u589e\u5f3a\uff0c\u5c06\u56fe\u50cf\u6807\u51c6\u5316 aug = albumentations . Compose ( [ albumentations . Normalize ( mean , std , max_pixel_value = 255.0 , always_apply = True ) ] ) # \u5212\u5206\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 train_images , valid_images , train_targets , valid_targets = train_test_split ( images , targets , stratify = targets , random_state = 42 ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6\u548c\u9a8c\u8bc1\u6570\u636e\u96c6 train_dataset = dataset . ClassificationDataset ( image_paths = train_images , targets = train_targets , resize = ( 227 , 227 ), augmentations = aug , ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = 16 , shuffle = True , num_workers = 4 ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u96c6 valid_dataset = dataset . ClassificationDataset ( image_paths = valid_images , targets = valid_targets , resize = ( 227 , 227 ), augmentations = aug , ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = 16 , shuffle = False , num_workers = 4 ) # \u5b9a\u4e49\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 5e-4 ) # \u8bad\u7ec3\u5faa\u73af for epoch in range ( epochs ): # \u8bad\u7ec3\u6a21\u578b engine . train ( train_loader , model , optimizer , device = device ) # \u8bc4\u4f30\u6a21\u578b\u6027\u80fd predictions , valid_targets = engine . evaluate ( valid_loader , model , device = device ) # \u8ba1\u7b97ROC AUC\u5206\u6570\u5e76\u6253\u5370 roc_auc = metrics . roc_auc_score ( valid_targets , predictions ) print ( f \"Epoch= { epoch } , Valid ROC AUC= { roc_auc } \" ) \u8ba9\u6211\u4eec\u5728\u6ca1\u6709\u9884\u8bad\u7ec3\u6743\u91cd\u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u8bad\u7ec3\uff1a Epoch = 0 , Valid ROC AUC = 0.5737161981475328 Epoch = 1 , Valid ROC AUC = 0.5362868001588292 Epoch = 2 , Valid ROC AUC = 0.6163448214387008 Epoch = 3 , Valid ROC AUC = 0.6119219143780944 Epoch = 4 , Valid ROC AUC = 0.6229718888519726 Epoch = 5 , Valid ROC AUC = 0.5983014999635341 Epoch = 6 , Valid ROC AUC = 0.5523236874306134 Epoch = 7 , Valid ROC AUC = 0.4717721611306046 Epoch = 8 , Valid ROC AUC = 0.6473408263980617 Epoch = 9 , Valid ROC AUC = 0.6639862888260415 AUC \u7ea6\u4e3a 0.66\uff0c\u751a\u81f3\u4f4e\u4e8e\u6211\u4eec\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002\u4f7f\u7528\u9884\u8bad\u7ec3\u6743\u91cd\u4f1a\u53d1\u751f\u4ec0\u4e48\u60c5\u51b5\uff1f Epoch = 0 , Valid ROC AUC = 0.5730387429803165 Epoch = 1 , Valid ROC AUC = 0.5319813942934937 Epoch = 2 , Valid ROC AUC = 0.627111577514323 Epoch = 3 , Valid ROC AUC = 0.6819736959393209 Epoch = 4 , Valid ROC AUC = 0.5747117168950512 Epoch = 5 , Valid ROC AUC = 0.5994619255609669 Epoch = 6 , Valid ROC AUC = 0.5080889443530546 Epoch = 7 , Valid ROC AUC = 0.6323792776512727 Epoch = 8 , Valid ROC AUC = 0.6685753182661686 Epoch = 9 , Valid ROC AUC = 0.6861802387300147 \u73b0\u5728\u7684 AUC \u597d\u4e86\u5f88\u591a\u3002\u4e0d\u8fc7\uff0c\u5b83\u4ecd\u7136\u8f83\u4f4e\u3002\u9884\u8bad\u7ec3\u6a21\u578b\u7684\u597d\u5904\u662f\u53ef\u4ee5\u8f7b\u677e\u5c1d\u8bd5\u591a\u79cd\u4e0d\u540c\u7684\u6a21\u578b\u3002\u8ba9\u6211\u4eec\u8bd5\u8bd5\u4f7f\u7528\u9884\u8bad\u7ec3\u6743\u91cd\u7684 resnet18 \u3002 import torch.nn as nn import pretrainedmodels # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u4ee5\u83b7\u53d6\u6a21\u578b def get_model ( pretrained ): if pretrained : # \u4f7f\u7528\u9884\u8bad\u7ec3\u7684 ResNet-18 \u6a21\u578b\uff0c\u52a0\u8f7d\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u7684\u6743\u91cd model = pretrainedmodels . __dict__ [ \"resnet18\" ]( pretrained = 'imagenet' ) else : # \u4f7f\u7528\u672a\u7ecf\u9884\u8bad\u7ec3\u7684 ResNet-18 \u6a21\u578b model = pretrainedmodels . __dict__ [ \"resnet18\" ]( pretrained = None ) # \u4fee\u6539\u6a21\u578b\u7684\u6700\u540e\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u4ee5\u9002\u5e94\u7279\u5b9a\u4efb\u52a1 model . last_linear = nn . Sequential ( nn . BatchNorm1d ( 512 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.25 ), # \u968f\u673a\u5931\u6d3b\u5c42\uff0c\u9632\u6b62\u8fc7\u62df\u5408 nn . Linear ( in_features = 512 , out_features = 2048 ), # \u8fde\u63a5\u5c42 nn . ReLU (), # ReLU \u6fc0\u6d3b\u51fd\u6570 nn . BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.5 ), # \u968f\u673a\u5931\u6d3b\u5c42 nn . Linear ( in_features = 2048 , out_features = 1 ) # \u6700\u7ec8\u7684\u4e8c\u5143\u5206\u7c7b\u5c42 ) return model \u5728\u5c1d\u8bd5\u8be5\u6a21\u578b\u65f6\uff0c\u6211\u8fd8\u5c06\u56fe\u50cf\u5927\u5c0f\u6539\u4e3a 512x512\uff0c\u5e76\u6dfb\u52a0\u4e86\u4e00\u4e2a\u5b66\u4e60\u7387\u8c03\u5ea6\u5668\uff0c\u6bcf 3 \u4e2aepochs\u540e\u5c06\u5b66\u4e60\u7387\u4e58\u4ee5 0.5\u3002 Epoch = 0 , Valid ROC AUC = 0.5988225569880796 Epoch = 1 , Valid ROC AUC = 0.730349343208836 Epoch = 2 , Valid ROC AUC = 0.5870943169939142 Epoch = 3 , Valid ROC AUC = 0.5775864444138311 Epoch = 4 , Valid ROC AUC = 0.7330502499939224 Epoch = 5 , Valid ROC AUC = 0.7500336296524395 Epoch = 6 , Valid ROC AUC = 0.7563722113724951 Epoch = 7 , Valid ROC AUC = 0.7987463837994215 Epoch = 8 , Valid ROC AUC = 0.798505708937384 Epoch = 9 , Valid ROC AUC = 0.8025477500546988 \u8fd9\u4e2a\u6a21\u578b\u4f3c\u4e4e\u8868\u73b0\u6700\u597d\u3002\u4e0d\u8fc7\uff0c\u60a8\u53ef\u4ee5\u8c03\u6574 AlexNet \u4e2d\u7684\u4e0d\u540c\u53c2\u6570\u548c\u56fe\u50cf\u5927\u5c0f\uff0c\u4ee5\u83b7\u5f97\u66f4\u597d\u7684\u5206\u6570\u3002 \u4f7f\u7528\u589e\u5f3a\u6280\u672f\u5c06\u8fdb\u4e00\u6b65\u63d0\u9ad8\u5f97\u5206\u3002\u4f18\u5316\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u5f88\u96be\uff0c\u4f46\u5e76\u975e\u4e0d\u53ef\u80fd\u3002\u9009\u62e9 Adam \u4f18\u5316\u5668\u3001\u4f7f\u7528\u4f4e\u5b66\u4e60\u7387\u3001\u5728\u9a8c\u8bc1\u635f\u5931\u8fbe\u5230\u9ad8\u70b9\u65f6\u964d\u4f4e\u5b66\u4e60\u7387\u3001\u5c1d\u8bd5\u4e00\u4e9b\u589e\u5f3a\u6280\u672f\u3001\u5c1d\u8bd5\u5bf9\u56fe\u50cf\u8fdb\u884c\u9884\u5904\u7406\uff08\u5982\u5728\u9700\u8981\u65f6\u8fdb\u884c\u88c1\u526a\uff0c\u8fd9\u4e5f\u53ef\u89c6\u4e3a\u9884\u5904\u7406\uff09\u3001\u6539\u53d8\u6279\u6b21\u5927\u5c0f\u7b49\u3002\u4f60\u53ef\u4ee5\u505a\u5f88\u591a\u4e8b\u60c5\u6765\u4f18\u5316\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u3002 \u4e0e AlexNet \u76f8\u6bd4\uff0c ResNet \u7684\u7ed3\u6784\u8981\u590d\u6742\u5f97\u591a\u3002ResNet \u662f\u6b8b\u5dee\u795e\u7ecf\u7f51\u7edc\uff08Residual Neural Network\uff09\u7684\u7f29\u5199\uff0c\u7531 K. He\u3001X. Zhang\u3001S. Ren \u548c J. Sun \u5728 2015 \u5e74\u53d1\u8868\u7684\u8bba\u6587\u4e2d\u63d0\u51fa\u3002ResNet \u7531 \u6b8b\u5dee\u5757 \uff08residual blocks\uff09\u7ec4\u6210\uff0c\u901a\u8fc7\u8df3\u8fc7\u67d0\u4e9b\u5c42\uff0c\u4f7f\u77e5\u8bc6\u80fd\u591f\u4e0d\u65ad\u5728\u5404\u5c42\u4e2d\u8fdb\u884c\u4f20\u9012\u3002\u8fd9\u4e9b\u5c42\u4e4b\u95f4\u7684 \u8fde\u63a5\u88ab\u79f0\u4e3a \u8df3\u8dc3\u8fde\u63a5 \uff08skip-connections\uff09\uff0c\u56e0\u4e3a\u6211\u4eec\u8df3\u8fc7\u4e86\u4e00\u5c42\u6216\u591a\u5c42\u3002\u8df3\u8dc3\u8fde\u63a5\u901a\u8fc7\u5c06\u68af\u5ea6\u4f20\u64ad\u5230\u66f4\u591a\u5c42\u6765\u5e2e\u52a9\u89e3\u51b3\u68af\u5ea6\u6d88\u5931\u95ee\u9898\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u8bad\u7ec3\u975e\u5e38\u5927\u7684\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff0c\u800c\u4e0d\u4f1a\u635f\u5931\u6027\u80fd\u3002\u901a\u5e38\u60c5\u51b5\u4e0b\uff0c\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7684\u662f\u5927\u578b\u795e\u7ecf\u7f51\u7edc\uff0c\u90a3\u4e48\u5f53\u8bad\u7ec3\u5230\u67d0\u4e00\u8282\u70b9\u4e0a\u65f6\u8bad\u7ec3\u635f\u5931\u53cd\u800c\u4f1a\u589e\u52a0\uff0c\u4f46\u8fd9\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u8df3\u8dc3\u8fde\u63a5\u6765\u907f\u514d\u3002\u901a\u8fc7\u56fe 7 \u53ef\u4ee5\u66f4\u597d\u5730\u7406\u89e3\u8fd9\u4e00\u70b9\u3002 \u56fe 7\uff1a\u7b80\u5355\u8fde\u63a5\u4e0e\u6b8b\u5dee\u8fde\u63a5\u7684\u6bd4\u8f83\u3002\u53c2\u89c1\u8df3\u8dc3\u8fde\u63a5\u3002\u8bf7\u6ce8\u610f\uff0c\u672c\u56fe\u7701\u7565\u4e86\u6700\u540e\u4e00\u5c42\u3002 \u6b8b\u5dee\u5757\u975e\u5e38\u5bb9\u6613\u7406\u89e3\u3002\u4f60\u4ece\u67d0\u4e00\u5c42\u83b7\u53d6\u8f93\u51fa\uff0c\u8df3\u8fc7\u4e00\u4e9b\u5c42\uff0c\u7136\u540e\u5c06\u8f93\u51fa\u6dfb\u52a0\u5230\u7f51\u7edc\u4e2d\u66f4\u8fdc\u7684\u4e00\u5c42\u3002\u865a\u7ebf\u8868\u793a\u8f93\u5165\u5f62\u72b6\u9700\u8981\u8c03\u6574\uff0c\u56e0\u4e3a\u4f7f\u7528\u4e86\u6700\u5927\u6c60\u5316\uff0c\u800c\u6700\u5927\u6c60\u5316\u7684\u4f7f\u7528\u4f1a\u6539\u53d8\u8f93\u51fa\u7684\u5927\u5c0f\u3002 ResNet \u6709\u591a\u79cd\u4e0d\u540c\u7684\u7248\u672c\uff1a \u6709 18 \u5c42\u300134 \u5c42\u300150 \u5c42\u3001101 \u5c42\u548c 152 \u5c42\uff0c\u6240\u6709\u8fd9\u4e9b\u5c42\u90fd\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8fdb\u884c\u4e86\u6743\u91cd\u9884\u8bad\u7ec3\u3002\u5982\u4eca\uff0c\u9884\u8bad\u7ec3\u6a21\u578b\uff08\u51e0\u4e4e\uff09\u9002\u7528\u4e8e\u6240\u6709\u60c5\u51b5\uff0c\u4f46\u8bf7\u786e\u4fdd\u60a8\u4ece\u8f83\u5c0f\u7684\u6a21\u578b\u5f00\u59cb\uff0c\u4f8b\u5982\uff0c\u4ece resnet-18 \u5f00\u59cb\uff0c\u800c\u4e0d\u662f resnet-50\u3002\u5176\u4ed6\u4e00\u4e9b ImageNet \u9884\u8bad\u7ec3\u6a21\u578b\u5305\u62ec\uff1a Inception DenseNet(different variations) NASNet PNASNet VGG Xception ResNeXt EfficientNet, etc. \u5927\u90e8\u5206\u9884\u8bad\u7ec3\u7684\u6700\u5148\u8fdb\u6a21\u578b\u53ef\u4ee5\u5728 GitHub \u4e0a\u7684 pytorch- pretrainedmodels \u8d44\u6e90\u5e93\u4e2d\u627e\u5230\uff1ahttps://github.com/Cadene/pretrained-models.pytorch\u3002\u8be6\u7ec6\u8ba8\u8bba\u8fd9\u4e9b\u6a21\u578b\u4e0d\u5728\u672c\u7ae0\uff08\u548c\u672c\u4e66\uff09\u8303\u56f4\u4e4b\u5185\u3002\u65e2\u7136\u6211\u4eec\u53ea\u5173\u6ce8\u5e94\u7528\uff0c\u90a3\u5c31\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u6837\u7684\u9884\u8bad\u7ec3\u6a21\u578b\u5982\u4f55\u7528\u4e8e\u5206\u5272\u4efb\u52a1\u3002 \u56fe 8\uff1aU-Net\u67b6\u6784 \u5206\u5272\uff08Segmentation\uff09\u662f\u8ba1\u7b97\u673a\u89c6\u89c9\u4e2d\u76f8\u5f53\u6d41\u884c\u7684\u4e00\u9879\u4efb\u52a1\u3002\u5728\u5206\u5272\u4efb\u52a1\u4e2d\uff0c\u6211\u4eec\u8bd5\u56fe\u4ece\u80cc\u666f\u4e2d\u79fb\u9664/\u63d0\u53d6\u524d\u666f\u3002 \u524d\u666f\u548c\u80cc\u666f\u53ef\u4ee5\u6709\u4e0d\u540c\u7684\u5b9a\u4e49\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u8bf4\uff0c\u8fd9\u662f\u4e00\u9879\u50cf\u7d20\u5206\u7c7b\u4efb\u52a1\uff0c\u4f60\u7684\u5de5\u4f5c\u662f\u7ed9\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u7684\u6bcf\u4e2a\u50cf\u7d20\u5206\u914d\u4e00\u4e2a\u7c7b\u522b\u3002\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u6b63\u5728\u5904\u7406\u7684\u6c14\u80f8\u6570\u636e\u96c6\u5c31\u662f\u4e00\u9879\u5206\u5272\u4efb\u52a1\u3002\u5728\u8fd9\u9879\u4efb\u52a1\u4e2d\uff0c\u6211\u4eec\u9700\u8981\u5bf9\u7ed9\u5b9a\u7684\u80f8\u90e8\u653e\u5c04\u56fe\u50cf\u8fdb\u884c\u6c14\u80f8\u5206\u5272\u3002\u7528\u4e8e\u5206\u5272\u4efb\u52a1\u7684\u6700\u5e38\u7528\u6a21\u578b\u662f U-Net\u3002\u5176\u7ed3\u6784\u5982\u56fe 8 \u6240\u793a\u3002 U-Net \u5305\u62ec\u4e24\u4e2a\u90e8\u5206\uff1a\u7f16\u7801\u5668\u548c\u89e3\u7801\u5668\u3002\u7f16\u7801\u5668\u4e0e\u60a8\u76ee\u524d\u6240\u89c1\u8fc7\u7684\u4efb\u4f55 U-Net \u90fd\u662f\u4e00\u6837\u7684\u3002\u89e3\u7801\u5668\u5219\u6709\u4e9b\u4e0d\u540c\u3002\u89e3\u7801\u5668\u7531\u4e0a\u5377\u79ef\u5c42\u7ec4\u6210\u3002\u5728\u4e0a\u5377\u79ef\uff08up-convolutions\uff09\uff08 \u8f6c\u7f6e\u5377\u79ef transposed convolutions\uff09\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u6ee4\u6ce2\u5668\uff0c\u5f53\u5e94\u7528\u5230\u4e00\u4e2a\u5c0f\u56fe\u50cf\u65f6\uff0c\u4f1a\u4ea7\u751f\u4e00\u4e2a\u5927\u56fe\u50cf\u3002\u5728 PyTorch \u4e2d\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528 ConvTranspose2d \u6765\u5b8c\u6210\u8fd9\u4e00\u64cd\u4f5c\u3002\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u4e0a\u5377\u79ef\u4e0e\u4e0a\u91c7\u6837\u5e76\u4e0d\u76f8\u540c\u3002\u4e0a\u91c7\u6837\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u8fc7\u7a0b\uff0c\u6211\u4eec\u5728\u56fe\u50cf\u4e0a\u5e94\u7528\u4e00\u4e2a\u51fd\u6570\u6765\u8c03\u6574\u5b83\u7684\u5927\u5c0f\u3002\u5728\u4e0a\u5377\u79ef\u4e2d\uff0c\u6211\u4eec\u8981\u5b66\u4e60\u6ee4\u6ce2\u5668\u3002\u6211\u4eec\u5c06\u7f16\u7801\u5668\u7684\u67d0\u4e9b\u90e8\u5206\u4f5c\u4e3a\u67d0\u4e9b\u89e3\u7801\u5668\u7684\u8f93\u5165\u3002\u8fd9\u5bf9 \u4e0a\u5377\u79ef\u5c42\u975e\u5e38\u91cd\u8981\u3002 \u8ba9\u6211\u4eec\u770b\u770b U-Net \u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import torch import torch.nn as nn from torch.nn import functional as F # \u5b9a\u4e49\u4e00\u4e2a\u53cc\u5377\u79ef\u5c42 def double_conv ( in_channels , out_channels ): conv = nn . Sequential ( nn . Conv2d ( in_channels , out_channels , kernel_size = 3 ), nn . ReLU ( inplace = True ), nn . Conv2d ( out_channels , out_channels , kernel_size = 3 ), nn . ReLU ( inplace = True ) ) return conv # \u5b9a\u4e49\u51fd\u6570\u7528\u4e8e\u88c1\u526a\u8f93\u5165\u5f20\u91cf def crop_tensor ( tensor , target_tensor ): target_size = target_tensor . size ()[ 2 ] tensor_size = tensor . size ()[ 2 ] delta = tensor_size - target_size delta = delta // 2 return tensor [:, :, delta : tensor_size - delta , delta : tensor_size - delta ] # \u5b9a\u4e49 U-Net \u6a21\u578b class UNet ( nn . Module ): def __init__ ( self ): super ( UNet , self ) . __init () # \u5b9a\u4e49\u6c60\u5316\u5c42\uff0c\u7f16\u7801\u5668\u548c\u89e3\u7801\u5668\u7684\u53cc\u5377\u79ef\u5c42 self . max_pool_2x2 = nn . MaxPool2d ( kernel_size = 2 , stride = 2 ) self . down_conv_1 = double_conv ( 1 , 64 ) self . down_conv_2 = double_conv ( 64 , 128 ) self . down_conv_3 = double_conv ( 128 , 256 ) self . down_conv_4 = double_conv ( 256 , 512 ) self . down_conv_5 = double_conv ( 512 , 1024 ) # \u5b9a\u4e49\u4e0a\u91c7\u6837\u5c42\u548c\u89e3\u7801\u5668\u7684\u53cc\u5377\u79ef\u5c42 self . up_trans_1 = nn . ConvTranspose2d ( in_channels = 1024 , out_channels = 512 , kernel_size = 2 , stride = 2 ) self . up_conv_1 = double_conv ( 1024 , 512 ) self . up_trans_2 = nn . ConvTranspose2d ( in_channels = 512 , out_channels = 256 , kernel_size = 2 , stride = 2 ) self . up_conv_2 = double_conv ( 512 , 256 ) self . up_trans_3 = nn . ConvTranspose2d ( in_channels = 256 , out_channels = 128 , kernel_size = 2 , stride = 2 ) self . up_conv_3 = double_conv ( 256 , 128 ) self . up_trans_4 = nn . ConvTranspose2d ( in_channels = 128 , out_channels = 64 , kernel_size = 2 , stride = 2 ) self . up_conv_4 = double_conv ( 128 , 64 ) # \u5b9a\u4e49\u8f93\u51fa\u5c42 self . out = nn . Conv2d ( in_channels = 64 , out_channels = 2 , kernel_size = 1 ) def forward ( self , image ): # \u7f16\u7801\u5668\u90e8\u5206 x1 = self . down_conv_1 ( image ) x2 = self . max_pool_2x2 ( x1 ) x3 = self . down_conv_2 ( x2 ) x4 = self . max_pool_2x2 ( x3 ) x5 = self . down_conv_3 ( x4 ) x6 = self . max_pool_2x2 ( x5 ) x7 = self . down_conv_4 ( x6 ) x8 = self . max_pool_2x2 ( x7 ) x9 = self . down_conv_5 ( x8 ) # \u89e3\u7801\u5668\u90e8\u5206 x = self . up_trans_1 ( x9 ) y = crop_tensor ( x7 , x ) x = self . up_conv_1 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_2 ( x ) y = crop_tensor ( x5 , x ) x = self . up_conv_2 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_3 ( x ) y = crop_tensor ( x3 , x ) x = self . up_conv_3 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_4 ( x ) y = crop_tensor ( x1 , x ) x = self . up_conv_4 ( torch . cat ([ x , y ], axis = 1 )) # \u8f93\u51fa\u5c42 out = self . out ( x ) return out if __name__ == \"__main__\" : image = torch . rand (( 1 , 1 , 572 , 572 )) model = UNet () print ( model ( image )) \u8bf7\u6ce8\u610f\uff0c\u6211\u4e0a\u9762\u5c55\u793a\u7684 U-Net \u5b9e\u73b0\u662f U-Net \u8bba\u6587\u7684\u539f\u59cb\u5b9e\u73b0\u3002\u4e92\u8054\u7f51\u4e0a\u6709\u5f88\u591a\u4e0d\u540c\u7684\u5b9e\u73b0\u65b9\u6cd5\u3002 \u6709\u4e9b\u4eba\u559c\u6b22\u4f7f\u7528\u53cc\u7ebf\u6027\u91c7\u6837\u4ee3\u66ff\u8f6c\u7f6e\u5377\u79ef\u8fdb\u884c\u4e0a\u91c7\u6837\uff0c\u4f46\u8fd9\u5e76\u4e0d\u662f\u8bba\u6587\u7684\u771f\u6b63\u5b9e\u73b0\u3002\u4e0d\u8fc7\uff0c\u5b83\u7684\u6027\u80fd\u53ef\u80fd\u4f1a\u66f4\u597d\u3002\u5728\u4e0a\u56fe\u6240\u793a\u7684\u539f\u59cb\u5b9e\u73b0\u4e2d\uff0c\u6709\u4e00\u4e2a\u5355\u901a\u9053\u56fe\u50cf\uff0c\u8f93\u51fa\u4e2d\u6709\u4e24\u4e2a\u901a\u9053\uff1a\u4e00\u4e2a\u662f\u524d\u666f\uff0c\u4e00\u4e2a\u662f\u80cc\u666f\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u4e3a\u4efb\u610f\u6570\u91cf\u7684\u7c7b\u548c\u4efb\u610f\u6570\u91cf\u7684\u8f93\u5165\u901a\u9053\u8fdb\u884c\u5b9a\u5236\u3002\u5728\u6b64\u5b9e\u73b0\u4e2d\uff0c\u8f93\u5165\u56fe\u50cf\u7684\u5927\u5c0f\u4e0e\u8f93\u51fa\u56fe\u50cf\u7684\u5927\u5c0f\u4e0d\u540c\uff0c\u56e0\u4e3a\u6211\u4eec\u4f7f\u7528\u7684\u662f\u65e0\u586b\u5145\u5377\u79ef\uff08convolutions without padding\uff09\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0cU-Net \u7684\u7f16\u7801\u5668\u90e8\u5206\u53ea\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u5377\u79ef\u7f51\u7edc\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u7528\u4efb\u4f55\u7f51\u7edc\uff08\u5982 ResNet\uff09\u6765\u66ff\u6362\u5b83\u3002 \u8fd9\u79cd\u66ff\u6362\u4e5f\u53ef\u4ee5\u901a\u8fc7\u9884\u8bad\u7ec3\u6743\u91cd\u6765\u5b8c\u6210\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u57fa\u4e8e ResNet \u7684\u7f16\u7801\u5668\uff0c\u8be5\u7f16\u7801\u5668\u5df2\u5728 ImageNet \u548c\u901a\u7528\u89e3\u7801\u5668\u4e0a\u8fdb\u884c\u4e86\u9884\u8bad\u7ec3\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u591a\u79cd\u4e0d\u540c\u7684\u7f51\u7edc\u67b6\u6784\u6765\u4ee3\u66ff ResNet\u3002Pavel Yakubovskiy \u6240\u8457\u7684\u300aSegmentation Models Pytorch\u300b\u5c31\u662f\u8bb8\u591a\u6b64\u7c7b\u53d8\u4f53\u7684\u5b9e\u73b0\uff0c\u5176\u4e2d\u7f16\u7801\u5668\u53ef\u4ee5\u88ab\u9884\u8bad\u7ec3\u6a21\u578b\u6240\u53d6\u4ee3\u3002\u8ba9\u6211\u4eec\u5e94\u7528\u57fa\u4e8e ResNet \u7684 U-Net \u6765\u89e3\u51b3\u6c14\u80f8\u68c0\u6d4b\u95ee\u9898\u3002 \u5927\u591a\u6570\u7c7b\u4f3c\u7684\u95ee\u9898\u90fd\u6709\u4e24\u4e2a\u8f93\u5165\uff1a\u539f\u59cb\u56fe\u50cf\u548c\u63a9\u7801\uff08mask\uff09\u3002 \u5982\u679c\u6709\u591a\u4e2a\u5bf9\u8c61\uff0c\u5c31\u4f1a\u6709\u591a\u4e2a\u63a9\u7801\u3002 \u5728\u6211\u4eec\u7684\u6c14\u80f8\u6570\u636e\u96c6\u4e2d\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f RLE\u3002RLE \u4ee3\u8868\u8fd0\u884c\u957f\u5ea6\u7f16\u7801\uff0c\u662f\u4e00\u79cd\u8868\u793a\u4e8c\u8fdb\u5236\u63a9\u7801\u4ee5\u8282\u7701\u7a7a\u95f4\u7684\u65b9\u6cd5\u3002\u6df1\u5165\u7814\u7a76 RLE \u8d85\u51fa\u4e86\u672c\u7ae0\u7684\u8303\u56f4\u3002\u56e0\u6b64\uff0c\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u5f20\u8f93\u5165\u56fe\u50cf\u548c\u76f8\u5e94\u7684\u63a9\u7801\u3002\u8ba9\u6211\u4eec\u5148\u8bbe\u8ba1\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\uff0c\u7528\u4e8e\u8f93\u51fa\u56fe\u50cf\u548c\u63a9\u7801\u56fe\u50cf\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u521b\u5efa\u7684\u811a\u672c\u51e0\u4e4e\u53ef\u4ee5\u5e94\u7528\u4e8e\u4efb\u4f55\u5206\u5272\u95ee\u9898\u3002\u8bad\u7ec3\u6570\u636e\u96c6\u662f\u4e00\u4e2a CSV \u6587\u4ef6\uff0c\u53ea\u5305\u542b\u56fe\u50cf ID\uff08\u4e5f\u662f\u6587\u4ef6\u540d\uff09\u3002 import os import glob import torch import numpy as np import pandas as pd from PIL import Image , ImageFile from tqdm import tqdm from collections import defaultdict from torchvision import transforms from albumentations import ( Compose , OneOf , RandomBrightnessContrast , RandomGamma , ShiftScaleRotate , ) # \u8bbe\u7f6ePIL\u56fe\u50cf\u52a0\u8f7d\u622a\u65ad\u7684\u5904\u7406 ImageFile . LOAD_TRUNCATED_IMAGES = True # \u521b\u5efaSIIM\u6570\u636e\u96c6\u7c7b class SIIMDataset ( torch . utils . data . Dataset ): def __init__ ( self , image_ids , transform = True , preprocessing_fn = None ): self . data = defaultdict ( dict ) self . transform = transform self . preprocessing_fn = preprocessing_fn # \u5b9a\u4e49\u6570\u636e\u589e\u5f3a self . aug = Compose ( [ ShiftScaleRotate ( shift_limit = 0.0625 , scale_limit = 0.1 , rotate_limit = 10 , p = 0.8 ), OneOf ( [ RandomGamma ( gamma_limit = ( 90 , 110 ) ), RandomBrightnessContrast ( brightness_limit = 0.1 , contrast_limit = 0.1 ), ], p = 0.5 , ), ] ) # \u6784\u5efa\u6570\u636e\u5b57\u5178\uff0c\u5176\u4e2d\u5305\u542b\u56fe\u50cf\u548c\u63a9\u7801\u7684\u8def\u5f84\u4fe1\u606f for imgid in image_ids : files = glob . glob ( os . path . join ( TRAIN_PATH , imgid , \"*.png\" )) self . data [ counter ] = { \"img_path\" : os . path . join ( TRAIN_PATH , imgid + \".png\" ), \"mask_path\" : os . path . join ( TRAIN_PATH , imgid + \"_mask.png\" ), } def __len__ ( self ): return len ( self . data ) def __getitem__ ( self , item ): img_path = self . data [ item ][ \"img_path\" ] mask_path = self . data [ item ][ \"mask_path\" ] # \u6253\u5f00\u56fe\u50cf\u5e76\u5c06\u5176\u8f6c\u6362\u4e3aRGB\u6a21\u5f0f img = Image . open ( img_path ) img = img . convert ( \"RGB\" ) img = np . array ( img ) # \u6253\u5f00\u63a9\u7801\u56fe\u50cf\uff0c\u5e76\u5c06\u5176\u8f6c\u6362\u4e3a\u6d6e\u70b9\u6570 mask = Image . open ( mask_path ) mask = ( mask >= 1 ) . astype ( \"float32\" ) # \u5982\u679c\u9700\u8981\u8fdb\u884c\u6570\u636e\u589e\u5f3a if self . transform is True : augmented = self . aug ( image = img , mask = mask ) img = augmented [ \"image\" ] mask = augmented [ \"mask\" ] # \u5e94\u7528\u9884\u5904\u7406\u51fd\u6570\uff08\u5982\u679c\u6709\uff09 img = self . preprocessing_fn ( img ) # \u8fd4\u56de\u56fe\u50cf\u548c\u63a9\u7801 return { \"image\" : transforms . ToTensor ()( img ), \"mask\" : transforms . ToTensor ()( mask ) . float (), } \u6709\u4e86\u6570\u636e\u96c6\u7c7b\u4e4b\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u8bad\u7ec3\u51fd\u6570\u3002 import os import sys import torch import numpy as np import pandas as pd import segmentation_models_pytorch as smp import torch.nn as nn import torch.optim as optim from apex import amp from collections import OrderedDict from sklearn import model_selection from tqdm import tqdm from torch.optim import lr_scheduler from dataset import SIIMDataset # \u5b9a\u4e49\u8bad\u7ec3\u6570\u636e\u96c6CSV\u6587\u4ef6\u8def\u5f84 TRAINING_CSV = \"../input/train_pneumothorax.csv\" # \u5b9a\u4e49\u8bad\u7ec3\u548c\u6d4b\u8bd5\u7684\u6279\u91cf\u5927\u5c0f TRAINING_BATCH_SIZE = 16 TEST_BATCH_SIZE = 4 # \u5b9a\u4e49\u8bad\u7ec3\u7684\u65f6\u671f\u6570 EPOCHS = 10 # \u6307\u5b9a\u4f7f\u7528\u7684\u7f16\u7801\u5668\u548c\u6743\u91cd ENCODER = \"resnet18\" ENCODER_WEIGHTS = \"imagenet\" # \u6307\u5b9a\u8bbe\u5907\uff08GPU\uff09 DEVICE = \"cuda\" # \u5b9a\u4e49\u8bad\u7ec3\u51fd\u6570 def train ( dataset , data_loader , model , criterion , optimizer ): model . train () num_batches = int ( len ( dataset ) / data_loader . batch_size ) tk0 = tqdm ( data_loader , total = num_batches ) for d in tk0 : inputs = d [ \"image\" ] targets = d [ \"mask\" ] inputs = inputs . to ( DEVICE , dtype = torch . float ) targets = targets . to ( DEVICE , dtype = torch . float ) optimizer . zero_grad () outputs = model ( inputs ) loss = criterion ( outputs , targets ) with amp . scale_loss ( loss , optimizer ) as scaled_loss : scaled_loss . backward () optimizer . step () tk0 . close () # \u5b9a\u4e49\u8bc4\u4f30\u51fd\u6570 def evaluate ( dataset , data_loader , model ): model . eval () final_loss = 0 num_batches = int ( len ( dataset ) / data_loader . batch_size ) tk0 = tqdm ( data_loader , total = num_batches ) with torch . no_grad (): for d in tk0 : inputs = d [ \"image\" ] targets = d [ \"mask\" ] inputs = inputs to ( DEVICE , dtype = torch . float ) targets = targets . to ( DEVICE , dtype = torch . float ) output = model ( inputs ) loss = criterion ( output , targets ) final_loss += loss tk0 . close () return final_loss / num_batches if __name__ == \"__main__\" : df = pd . read_csv ( TRAINING_CSV ) df_train , df_valid = model_selection . train_test_split ( df , random_state = 42 , test_size = 0.1 ) training_images = df_train . image_id . values validation_images = df_valid . image_id . values # \u521b\u5efa U-Net \u6a21\u578b model = smp . Unet ( encoder_name = ENCODER , encoder_weights = ENCODER_WEIGHTS , classes = 1 , activation = None , ) # \u83b7\u53d6\u6570\u636e\u9884\u5904\u7406\u51fd\u6570 prep_fn = smp . encoders . get_preprocessing_fn ( ENCODER , ENCODER_WEIGHTS ) # \u5c06\u6a21\u578b\u653e\u5728\u8bbe\u5907\u4e0a model . to ( DEVICE ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6 train_dataset = SIIMDataset ( training_images , transform = True , preprocessing_fn = prep_fn , ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = TRAINING_BATCH_SIZE , shuffle = True , num_workers = 12 ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u96c6 valid_dataset = SIIMDataset ( validation_images , transform = False , preprocessing_fn = prep_fn , ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = TEST_BATCH_SIZE , shuffle = True , num_workers = 4 ) # \u5b9a\u4e49\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 1e-3 ) # \u5b9a\u4e49\u5b66\u4e60\u7387\u8c03\u5ea6\u5668 scheduler = lr_scheduler . ReduceLROnPlateau ( optimizer , mode = \"min\" , patience = 3 , verbose = True ) # \u521d\u59cb\u5316 Apex \u6df7\u5408\u7cbe\u5ea6\u8bad\u7ec3 model , optimizer = amp . initialize ( model , optimizer , opt_level = \"O1\" , verbosity = 0 ) # \u5982\u679c\u6709\u591a\u4e2aGPU\uff0c\u5219\u4f7f\u7528 DataParallel \u8fdb\u884c\u5e76\u884c\u8bad\u7ec3 if torch . cuda . device_count () > 1 : print ( f \"Let's use { torch . cuda . device_count () } GPUs!\" ) model = nn . DataParallel ( model ) # \u8f93\u51fa\u8bad\u7ec3\u76f8\u5173\u7684\u4fe1\u606f print ( f \"Training batch size: { TRAINING_BATCH_SIZE } \" ) print ( f \"Test batch size: { TEST_BATCH_SIZE } \" ) print ( f \"Epochs: { EPOCHS } \" ) print ( f \"Image size: { IMAGE_SIZE } \" ) print ( f \"Number of training images: { len ( train_dataset ) } \" ) print ( f \"Number of validation images: { len ( valid_dataset ) } \" ) print ( f \"Encoder: { ENCODER } \" ) # \u5faa\u73af\u8bad\u7ec3\u591a\u4e2a\u65f6\u671f for epoch in range ( EPOCHS ): print ( f \"Training Epoch: { epoch } \" ) train ( train_dataset , train_loader , model , criterion , optimizer ) print ( f \"Validation Epoch: { epoch } \" ) val_log = evaluate ( valid_dataset , valid_loader , model ) scheduler . step ( val_log [ \"loss\" ]) print ( \" \\n \" ) \u5728\u5206\u5272\u95ee\u9898\u4e2d\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528\u5404\u79cd\u635f\u5931\u51fd\u6570\uff0c\u4f8b\u5982\u4e8c\u5143\u4ea4\u53c9\u71b5\u3001focal\u635f\u5931\u3001dice\u635f\u5931\u7b49\u3002\u6211\u628a\u8fd9\u4e2a\u95ee\u9898\u7559\u7ed9 \u8bfb\u8005\u6839\u636e\u8bc4\u4f30\u6307\u6807\u6765\u51b3\u5b9a\u5408\u9002\u7684\u635f\u5931\u3002\u5f53\u8bad\u7ec3\u8fd9\u6837\u4e00\u4e2a\u6a21\u578b\u65f6\uff0c\u60a8\u5c06\u5efa\u7acb\u9884\u6d4b\u6c14\u80f8\u4f4d\u7f6e\u7684\u6a21\u578b\uff0c\u5982\u56fe 9 \u6240\u793a\u3002\u5728\u4e0a\u8ff0\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u82f1\u4f1f\u8fbe apex \u8fdb\u884c\u4e86\u6df7\u5408\u7cbe\u5ea6\u8bad\u7ec3\u3002\u8bf7\u6ce8\u610f\uff0c\u4ece PyTorch 1.6.0+ \u7248\u672c\u5f00\u59cb\uff0cPyTorch \u672c\u8eab\u5c31\u63d0\u4f9b\u4e86\u8fd9\u4e00\u529f\u80fd\u3002 \u56fe 9\uff1a\u4ece\u8bad\u7ec3\u6709\u7d20\u7684\u6a21\u578b\u4e2d\u68c0\u6d4b\u5230\u6c14\u80f8\u7684\u793a\u4f8b\uff08\u53ef\u80fd\u4e0d\u662f\u6b63\u786e\u9884\u6d4b\uff09\u3002 \u6211\u5728\u4e00\u4e2a\u540d\u4e3a \"Well That's Fantastic Machine Learning (WTFML) \"\u7684 python \u8f6f\u4ef6\u5305\u4e2d\u6536\u5f55\u4e86\u4e00\u4e9b\u5e38\u7528\u51fd\u6570\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b83\u5982\u4f55\u5e2e\u52a9\u6211\u4eec\u4e3a FGVC 202013 \u690d\u7269\u75c5\u7406\u5b66\u6311\u6218\u8d5b\u4e2d\u7684\u690d\u7269\u56fe\u50cf\u5efa\u7acb\u591a\u7c7b\u5206\u7c7b\u6a21\u578b\u3002 import os import pandas as pd import numpy as np import albumentations import argparse import torch import torchvision import torch.nn as nn import torch.nn.functional as F from sklearn import metrics from sklearn.model_selection import train_test_split from wtfml.engine import Engine from wtfml.data_loaders.image import ClassificationDataLoader # \u81ea\u5b9a\u4e49\u635f\u5931\u51fd\u6570\uff0c\u5b9e\u73b0\u5bc6\u96c6\u4ea4\u53c9\u71b5 class DenseCrossEntropy ( nn . Module ): def __init__ ( self ): super ( DenseCrossEntropy , self ) . __init__ () def forward ( self , logits , labels ): logits = logits . float () labels = labels . float () logprobs = F . log_softmax ( logits , dim =- 1 ) loss = - labels * logprobs loss = loss . sum ( - 1 ) return loss . mean () # \u81ea\u5b9a\u4e49\u795e\u7ecf\u7f51\u7edc\u6a21\u578b class Model ( nn . Module ): def __init__ ( self ): super () . __init () self . base_model = torchvision . models . resnet18 ( pretrained = True ) in_features = self . base_model . fc . in_features self . out = nn . Linear ( in_features , 4 ) def forward ( self , image , targets = None ): batch_size , C , H , W = image . shape x = self . base_model . conv1 ( image ) x = self . base_model . bn1 ( x ) x = self . base_model . relu ( x ) x = self . base_model . maxpool ( x ) x = self . base_model . layer1 ( x ) x = self . base_model . layer2 ( x ) x = self . base_model . layer3 ( x ) x = self . base_model . layer4 ( x ) x = F . adaptive_avg_pool2d ( x , 1 ) . reshape ( batch_size , - 1 ) x = self . out ( x ) loss = None if targets is not None : loss = DenseCrossEntropy ()( x , targets . type_as ( x )) return x , loss if __name__ == \"__main__\" : # \u547d\u4ee4\u884c\u53c2\u6570\u89e3\u6790\u5668 parser = argparse . ArgumentParser () parser . add_argument ( \"--data_path\" , type = str , ) parser . add_argument ( \"--device\" , type = str ,) parser . add_argument ( \"--epochs\" , type = int ,) args = parser . parse_args () # \u4eceCSV\u6587\u4ef6\u52a0\u8f7d\u6570\u636e df = pd . read_csv ( os . path . join ( args . data_path , \"train.csv\" )) images = df . image_id . values . tolist () images = [ os . path . join ( args . data_path , \"images\" , i + \".jpg\" ) for i in images ] targets = df [[ \"healthy\" , \"multiple_diseases\" , \"rust\" , \"scab\" ]] . values # \u521b\u5efa\u795e\u7ecf\u7f51\u7edc\u6a21\u578b model = Model () model . to ( args . device ) # \u5b9a\u4e49\u5747\u503c\u548c\u6807\u51c6\u5dee\u4ee5\u53ca\u6570\u636e\u589e\u5f3a mean = ( 0.485 , 0.456 , 0.406 ) std = ( 0.229 , 0.224 , 0.225 ) aug = albumentations . Compose ( [ albumentations . Normalize ( mean , std , max_pixel_value = 255.0 , always_apply = True ) ] ) # \u5206\u5272\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 ( train_images , valid_images , train_targets , valid_targets ) = train_test_split ( images , targets ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = ClassificationDataLoader ( image_paths = train_images , targets = train_targets , resize = ( 128 , 128 ), augmentations = aug , ) . fetch ( batch_size = 16 , num_workers = 4 , drop_last = False , shuffle = True , tpu = False ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = ClassificationDataLoader ( image_paths = valid_images , targets = valid_targets , resize = ( 128 , 128 ), augmentations = aug , ) . fetch ( batch_size = 16 , num_workers = 4 , drop_last = False , shuffle = False , tpu = False ) # \u521b\u5efa\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 5e-4 ) # \u521b\u5efa\u5b66\u4e60\u7387\u8c03\u5ea6\u5668 scheduler = torch . optim . lr_scheduler . StepLR ( optimizer , step_size = 15 , gamma = 0.6 ) # \u5faa\u73af\u8bad\u7ec3\u591a\u4e2a\u65f6\u671f for epoch in range ( args . epochs ): # \u8bad\u7ec3\u6a21\u578b train_loss = Engine . train ( train_loader , model , optimizer , device = args . device ) # \u8bc4\u4f30\u6a21\u578b valid_loss = Engine . evaluate ( valid_loader , model , device = args . device ) # \u6253\u5370\u635f\u5931\u4fe1\u606f print ( f \" { epoch } , Train Loss= { train_loss } Valid Loss= { valid_loss } \" ) \u6709\u4e86\u6570\u636e\u540e\uff0c\u5c31\u53ef\u4ee5\u8fd0\u884c\u811a\u672c\u4e86\uff1a python plant.py --data_path ../../plant_pathology --device cuda -- epochs 2 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 86 /86 [ 00 :12< 00 :00, 6 .73it/s, loss = 0 .723 ] 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 29 /29 [ 00 :04< 00 :00, 6 .62it/s, loss = 0 .433 ] 0 , Train Loss = 0 .7228777609592261 Valid Loss = 0 .4327834551704341 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 86 /86 [ 00 :12< 00 :00, 6 .74it/s, loss = 0 .271 ] 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 29 /29 [ 00 :04< 00 :00, 6 .63it/s, loss = 0 .568 ] 1 , Train Loss = 0 .2708700496790021 Valid Loss = 0 .56841839541649 \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u8ba9\u6211\u4eec\u6784\u5efa\u6a21\u578b\u53d8\u5f97\u7b80\u5355\uff0c\u4ee3\u7801\u4e5f\u6613\u4e8e\u9605\u8bfb\u548c\u7406\u89e3\u3002\u6ca1\u6709\u4efb\u4f55\u5c01\u88c5\u7684 PyTorch \u6548\u679c\u6700\u597d\u3002\u56fe\u50cf\u4e2d\u4e0d\u4ec5\u4ec5\u6709\u5206\u7c7b\uff0c\u8fd8\u6709\u5f88\u591a\u5176\u4ed6\u7684\u5185\u5bb9\uff0c\u5982\u679c\u6211\u5f00\u59cb\u5199\u6240\u6709\u7684\u5185\u5bb9\uff0c\u5c31\u5f97\u518d\u5199\u4e00\u672c\u4e66\u4e86\uff0c \u63a5\u8fd1\uff08\u51e0\u4e4e\uff09\u4efb\u4f55\u56fe\u50cf\u95ee\u9898\uff08\u4f5c\u8005\u5728\u5f00\u73a9\u7b11\uff09\u3002","title":"\u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5"},{"location":"%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB%E5%92%8C%E5%88%86%E5%89%B2%E6%96%B9%E6%B3%95/#_1","text":"\u8bf4\u5230\u56fe\u50cf\uff0c\u8fc7\u53bb\u51e0\u5e74\u53d6\u5f97\u4e86\u5f88\u591a\u6210\u5c31\u3002\u8ba1\u7b97\u673a\u89c6\u89c9\u7684\u8fdb\u6b65\u76f8\u5f53\u5feb\uff0c\u611f\u89c9\u8ba1\u7b97\u673a\u89c6\u89c9\u7684\u8bb8\u591a\u95ee\u9898\u73b0\u5728\u90fd\u66f4\u5bb9\u6613\u89e3\u51b3\u4e86\u3002\u968f\u7740\u9884\u8bad\u7ec3\u6a21\u578b\u7684\u51fa\u73b0\u548c\u8ba1\u7b97\u6210\u672c\u7684\u964d\u4f4e\uff0c\u73b0\u5728\u5728\u5bb6\u91cc\u5c31\u80fd\u8f7b\u677e\u8bad\u7ec3\u51fa\u63a5\u8fd1\u6700\u5148\u8fdb\u6c34\u5e73\u7684\u6a21\u578b\uff0c\u89e3\u51b3\u5927\u591a\u6570\u4e0e\u56fe\u50cf\u76f8\u5173\u7684\u95ee\u9898\u3002\u4f46\u662f\uff0c\u56fe\u50cf\u95ee\u9898\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u7c7b\u578b\u3002\u4ece\u4e24\u4e2a\u6216\u591a\u4e2a\u7c7b\u522b\u7684\u6807\u51c6\u56fe\u50cf\u5206\u7c7b\uff0c\u5230\u50cf\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66\u8fd9\u6837\u5177\u6709\u6311\u6218\u6027\u7684\u95ee\u9898\u3002\u6211\u4eec\u4e0d\u4f1a\u5728\u672c\u4e66\u4e2d\u8ba8\u8bba\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66\uff0c\u4f46\u6211\u4eec\u663e\u7136\u4f1a\u5904\u7406\u4e00\u4e9b\u6700\u5e38\u89c1\u7684\u56fe\u50cf\u95ee\u9898\u3002 \u6211\u4eec\u53ef\u4ee5\u5bf9\u56fe\u50cf\u91c7\u7528\u54ea\u4e9b\u4e0d\u540c\u7684\u65b9\u6cd5\uff1f\u56fe\u50cf\u53ea\u4e0d\u8fc7\u662f\u4e00\u4e2a\u6570\u5b57\u77e9\u9635\u3002\u8ba1\u7b97\u673a\u65e0\u6cd5\u50cf\u4eba\u7c7b\u4e00\u6837\u770b\u5230\u56fe\u50cf\u3002\u5b83\u53ea\u80fd\u770b\u5230\u6570\u5b57\uff0c\u8fd9\u5c31\u662f\u56fe\u50cf\u3002\u7070\u5ea6\u56fe\u50cf\u662f\u4e00\u4e2a\u4e8c\u7ef4\u77e9\u9635\uff0c\u6570\u503c\u8303\u56f4\u4ece 0 \u5230 255\u30020 \u4ee3\u8868\u9ed1\u8272\uff0c255 \u4ee3\u8868\u767d\u8272\uff0c\u4ecb\u4e8e\u4e24\u8005\u4e4b\u95f4\u7684\u662f\u5404\u79cd\u7070\u8272\u3002\u4ee5\u524d\uff0c\u5728\u6ca1\u6709\u6df1\u5ea6\u5b66\u4e60\u7684\u65f6\u5019\uff08\u6216\u8005\u8bf4\u6df1\u5ea6\u5b66\u4e60\u8fd8\u4e0d\u6d41\u884c\u7684\u65f6\u5019\uff09\uff0c\u4eba\u4eec\u4e60\u60ef\u4e8e\u67e5\u770b\u50cf\u7d20\u3002\u6bcf\u4e2a\u50cf\u7d20\u90fd\u662f\u4e00\u4e2a\u7279\u5f81\u3002\u4f60\u53ef\u4ee5\u5728 Python \u4e2d\u8f7b\u677e\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u53ea\u9700\u4f7f\u7528 OpenCV \u6216 Python-PIL \u8bfb\u53d6\u7070\u5ea6\u56fe\u50cf\uff0c\u8f6c\u6362\u4e3a numpy \u6570\u7ec4\uff0c\u7136\u540e\u5c06\u77e9\u9635\u5e73\u94fa\uff08\u6241\u5e73\u5316\uff09\u5373\u53ef\u3002\u5982\u679c\u5904\u7406\u7684\u662f RGB \u56fe\u50cf\uff0c\u5219\u9700\u8981\u4e09\u4e2a\u77e9\u9635\uff0c\u800c\u4e0d\u662f\u4e00\u4e2a\u3002\u4f46\u601d\u8def\u662f\u4e00\u6837\u7684\u3002 import numpy as np import matplotlib.pyplot as plt # \u751f\u6210\u4e00\u4e2a 256x256 \u7684\u968f\u673a\u7070\u5ea6\u56fe\u50cf\uff0c\u50cf\u7d20\u503c\u57280\u5230255\u4e4b\u95f4\u968f\u673a\u5206\u5e03 random_image = np . random . randint ( 0 , 256 , ( 256 , 256 )) # \u521b\u5efa\u4e00\u4e2a\u65b0\u7684\u56fe\u50cf\u7a97\u53e3\uff0c\u8bbe\u7f6e\u7a97\u53e3\u5927\u5c0f\u4e3a7x7\u82f1\u5bf8 plt . figure ( figsize = ( 7 , 7 )) # \u663e\u793a\u751f\u6210\u7684\u968f\u673a\u56fe\u50cf # \u4f7f\u7528\u7070\u5ea6\u989c\u8272\u6620\u5c04 (colormap)\uff0c\u8303\u56f4\u4ece0\u5230255 plt . imshow ( random_image , cmap = 'gray' , vmin = 0 , vmax = 255 ) # \u663e\u793a\u56fe\u50cf\u7a97\u53e3 plt . show () \u4e0a\u9762\u7684\u4ee3\u7801\u4f7f\u7528 numpy \u751f\u6210\u4e00\u4e2a\u968f\u673a\u77e9\u9635\u3002\u8be5\u77e9\u9635\u7531 0 \u5230 255\uff08\u5305\u542b\uff09\u7684\u503c\u7ec4\u6210\uff0c\u5927\u5c0f\u4e3a 256x256\uff08\u4e5f\u79f0\u4e3a\u50cf\u7d20\uff09\u3002 \u56fe 1\uff1a\u4e8c\u7ef4\u56fe\u50cf\u9635\u5217\uff08\u5355\u901a\u9053\uff09\u53ca\u5176\u5c55\u5e73\u7248\u672c \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u62fc\u5199\u540e\u7684\u7248\u672c\u53ea\u662f\u4e00\u4e2a\u5927\u5c0f\u4e3a M \u7684\u5411\u91cf\uff0c\u5176\u4e2d M = N * N\uff0c\u5728\u672c\u4f8b\u4e2d\uff0c\u8fd9\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u4e3a 256 * 256 = 65536\u3002 \u73b0\u5728\uff0c\u5982\u679c\u6211\u4eec\u7ee7\u7eed\u5bf9\u6570\u636e\u96c6\u4e2d\u7684\u6240\u6709\u56fe\u50cf\u8fdb\u884c\u5904\u7406\uff0c\u6bcf\u4e2a\u6837\u672c\u5c31\u4f1a\u6709 65536 \u4e2a\u7279\u5f81\u3002\u6211\u4eec\u53ef\u4ee5\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u5feb\u901f\u5efa\u7acb \u51b3\u7b56\u6811\u6a21\u578b\u3001\u968f\u673a\u68ee\u6797\u6a21\u578b\u6216\u57fa\u4e8e SVM \u7684\u6a21\u578b \u3002\u8fd9\u4e9b\u6a21\u578b\u5c06\u57fa\u4e8e\u50cf\u7d20\u503c\uff0c\u5c1d\u8bd5\u5c06\u6b63\u6837\u672c\u4e0e\u8d1f\u6837\u672c\u533a\u5206\u5f00\u6765\uff08\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff09\u3002 \u4f60\u4eec\u4e00\u5b9a\u90fd\u542c\u8bf4\u8fc7\u732b\u4e0e\u72d7\u7684\u95ee\u9898\uff0c\u8fd9\u662f\u4e00\u4e2a\u7ecf\u5178\u7684\u95ee\u9898\u3002\u5982\u679c\u4f60\u4eec\u8fd8\u8bb0\u5f97\uff0c\u5728\u8bc4\u4f30\u6307\u6807\u4e00\u7ae0\u7684\u5f00\u5934\uff0c\u6211\u5411\u4f60\u4eec\u4ecb\u7ecd\u4e86\u4e00\u4e2a\u6c14\u80f8\u56fe\u50cf\u6570\u636e\u96c6\u3002\u90a3\u4e48\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u68c0\u6d4b\u80ba\u90e8\u7684 X \u5149\u56fe\u50cf\u662f\u5426\u5b58\u5728\u6c14\u80f8\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u8fd9\u662f\u4e00\u4e2a\uff08\u5e76\u4e0d\uff09\u7b80\u5355\u7684\u4e8c\u5143\u5206\u7c7b\u3002 \u56fe 2\uff1a\u975e\u6c14\u80f8\u4e0e\u6c14\u80f8 X \u5149\u56fe\u50cf\u5bf9\u6bd4 \u5728\u56fe 2 \u4e2d\uff0c\u60a8\u53ef\u4ee5\u770b\u5230\u975e\u6c14\u80f8\u548c\u6c14\u80f8\u56fe\u50cf\u7684\u5bf9\u6bd4\u3002\u60a8\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\u4e86\uff0c\u5bf9\u4e8e\u4e00\u4e2a\u975e\u4e13\u4e1a\u4eba\u58eb\uff08\u6bd4\u5982\u6211\uff09\u6765\u8bf4\uff0c\u8981\u5728\u8fd9\u4e9b\u56fe\u50cf\u4e2d\u8fa8\u522b\u51fa\u54ea\u4e2a\u662f\u6c14\u80f8\u662f\u76f8\u5f53\u56f0\u96be\u7684\u3002 \u6700\u521d\u7684\u6570\u636e\u96c6\u662f\u5173\u4e8e\u68c0\u6d4b\u6c14\u80f8\u7684\u5177\u4f53\u4f4d\u7f6e\uff0c\u4f46\u6211\u4eec\u5c06\u95ee\u9898\u4fee\u6539\u4e3a\u67e5\u627e\u7ed9\u5b9a\u7684 X \u5149\u56fe\u50cf\u662f\u5426\u5b58\u5728\u6c14\u80f8\u3002\u522b\u62c5\u5fc3\uff0c\u6211\u4eec\u5c06\u5728\u672c\u7ae0\u4ecb\u7ecd\u8fd9\u4e2a\u90e8\u5206\u3002\u6570\u636e\u96c6\u7531 10675 \u5f20\u72ec\u7279\u7684\u56fe\u50cf\u7ec4\u6210\uff0c\u5176\u4e2d 2379 \u5f20\u6709\u6c14\u80f8\uff08\u6ce8\u610f\uff0c\u8fd9\u4e9b\u6570\u5b57\u662f\u7ecf\u8fc7\u6570\u636e\u6e05\u7406\u540e\u5f97\u51fa\u7684\uff0c\u56e0\u6b64\u4e0e\u539f\u59cb\u6570\u636e\u96c6\u4e0d\u7b26\uff09\u3002\u6b63\u5982\u6570\u636e\u79d1\u5b66\u5bb6\u6240\u8bf4\uff1a\u8fd9\u662f\u4e00\u4e2a\u5178\u578b\u7684 \u504f\u659c\u4e8c\u5143\u5206\u7c7b\u6848\u4f8b \u3002\u56e0\u6b64\uff0c\u6211\u4eec\u9009\u62e9 AUC \u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\uff0c\u5e76\u91c7\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u9a8c\u8bc1\u65b9\u6848\u3002 \u60a8\u53ef\u4ee5\u5c06\u7279\u5f81\u6241\u5e73\u5316\uff0c\u7136\u540e\u5c1d\u8bd5\u4e00\u4e9b\u7ecf\u5178\u65b9\u6cd5\uff08\u5982 SVM\u3001RF\uff09\u6765\u8fdb\u884c\u5206\u7c7b\uff0c\u8fd9\u5b8c\u5168\u6ca1\u95ee\u9898\uff0c\u4f46\u5374\u65e0\u6cd5\u8ba9\u60a8\u8fbe\u5230\u6700\u5148\u8fdb\u7684\u6c34\u5e73\u3002\u6b64\u5916\uff0c\u56fe\u50cf\u5927\u5c0f\u4e3a 1024x1024\u3002\u5728\u8fd9\u4e2a\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u4e00\u4e2a\u6a21\u578b\u9700\u8981\u5f88\u957f\u65f6\u95f4\u3002\u4e0d\u7ba1\u600e\u6837\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u5efa\u7acb\u4e00\u4e2a\u7b80\u5355\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002\u7531\u4e8e\u56fe\u50cf\u662f\u7070\u5ea6\u7684\uff0c\u6211\u4eec\u4e0d\u9700\u8981\u8fdb\u884c\u4efb\u4f55\u8f6c\u6362\u3002\u6211\u4eec\u5c06\u628a\u56fe\u50cf\u5927\u5c0f\u8c03\u6574\u4e3a 256x256\uff0c\u4f7f\u5176\u66f4\u5c0f\uff0c\u5e76\u4f7f\u7528\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684 AUC \u4f5c\u4e3a\u8861\u91cf\u6307\u6807\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5b83\u7684\u8868\u73b0\u5982\u4f55\u3002 import os import numpy as np import pandas as pd from PIL import Image from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from tqdm import tqdm # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u6765\u521b\u5efa\u6570\u636e\u96c6 def create_dataset ( training_df , image_dir ): # \u521d\u59cb\u5316\u7a7a\u5217\u8868\u6765\u5b58\u50a8\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c images = [] targets = [] # \u8fed\u4ee3\u5904\u7406\u8bad\u7ec3\u6570\u636e\u96c6\u4e2d\u7684\u6bcf\u4e00\u884c for index , row in tqdm ( training_df . iterrows (), total = len ( training_df ), desc = \"processing images\" ): # \u83b7\u53d6\u56fe\u50cf\u6587\u4ef6\u540d image_id = row [ \"ImageId\" ] # \u6784\u5efa\u5b8c\u6574\u7684\u56fe\u50cf\u6587\u4ef6\u8def\u5f84 image_path = os . path . join ( image_dir , image_id ) # \u6253\u5f00\u56fe\u50cf\u6587\u4ef6\u5e76\u8fdb\u884c\u5927\u5c0f\u8c03\u6574\uff08resize\uff09\u4e3a 256x256 \u50cf\u7d20\uff0c\u4f7f\u7528\u53cc\u7ebf\u6027\u63d2\u503c\uff08BILINEAR\uff09 image = Image . open ( image_path + \".png\" ) image = image . resize (( 256 , 256 ), resample = Image . BILINEAR ) # \u5c06\u56fe\u50cf\u8f6c\u6362\u4e3aNumPy\u6570\u7ec4 image = np . array ( image ) # \u5c06\u56fe\u50cf\u6241\u5e73\u5316\u4e3a\u4e00\u7ef4\u6570\u7ec4\uff0c\u5e76\u5c06\u5176\u6dfb\u52a0\u5230\u56fe\u50cf\u5217\u8868 image = image . ravel () images . append ( image ) # \u5c06\u76ee\u6807\u503c\uff08target\uff09\u6dfb\u52a0\u5230\u76ee\u6807\u5217\u8868 targets . append ( int ( row [ \"target\" ])) # \u5c06\u56fe\u50cf\u5217\u8868\u8f6c\u6362\u4e3aNumPy\u6570\u7ec4 images = np . array ( images ) # \u6253\u5370\u56fe\u50cf\u6570\u7ec4\u7684\u5f62\u72b6 print ( images . shape ) # \u8fd4\u56de\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c return images , targets if __name__ == \"__main__\" : # \u5b9a\u4e49CSV\u6587\u4ef6\u8def\u5f84\u548c\u56fe\u50cf\u6587\u4ef6\u76ee\u5f55\u8def\u5f84 csv_path = \"/home/abhishek/workspace/siim_png/train.csv\" image_path = \"/home/abhishek/workspace/siim_png/train_png/\" # \u4eceCSV\u6587\u4ef6\u52a0\u8f7d\u6570\u636e df = pd . read_csv ( csv_path ) # \u6dfb\u52a0\u4e00\u4e2a\u540d\u4e3a'kfold'\u7684\u5217\uff0c\u5e76\u521d\u59cb\u5316\u4e3a-1 df [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u83b7\u53d6\u76ee\u6807\u503c\uff08target\uff09 y = df . target . values # \u4f7f\u7528\u5206\u5c42KFold\u4ea4\u53c9\u9a8c\u8bc1\u5c06\u6570\u636e\u96c6\u5206\u62105\u6298 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u904d\u5386\u6bcf\u4e2a\u6298\uff08fold\uff09 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f # \u904d\u5386\u6bcf\u4e2a\u6298 for fold_ in range ( 5 ): # \u83b7\u53d6\u8bad\u7ec3\u6570\u636e\u548c\u6d4b\u8bd5\u6570\u636e train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6\u7684\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c xtrain , ytrain = create_dataset ( train_df , image_path ) # \u521b\u5efa\u6d4b\u8bd5\u6570\u636e\u96c6\u7684\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c xtest , ytest = create_dataset ( test_df , image_path ) # \u521d\u59cb\u5316\u4e00\u4e2a\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668 clf = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408\u5206\u7c7b\u5668 clf . fit ( xtrain , ytrain ) # \u4f7f\u7528\u5206\u7c7b\u5668\u5bf9\u6d4b\u8bd5\u6570\u636e\u8fdb\u884c\u9884\u6d4b\uff0c\u5e76\u83b7\u53d6\u6982\u7387\u503c preds = clf . predict_proba ( xtest )[:, 1 ] # \u6253\u5370\u6298\u6570\uff08fold\uff09\u548cAUC\uff08ROC\u66f2\u7ebf\u4e0b\u7684\u9762\u79ef\uff09 print ( f \"FOLD: { fold_ } \" ) print ( f \"AUC = { metrics . roc_auc_score ( ytest , preds ) } \" ) print ( \"\" ) \u5e73\u5747 AUC \u503c\u7ea6\u4e3a 0.72\u3002\u8fd9\u8fd8\u4e0d\u9519\uff0c\u4f46\u6211\u4eec\u5e0c\u671b\u80fd\u505a\u5f97\u66f4\u597d\u3002\u4f60\u53ef\u4ee5\u5c06\u8fd9\u79cd\u65b9\u6cd5\u7528\u4e8e\u56fe\u50cf\uff0c\u8fd9\u4e5f\u662f\u5b83\u5728\u4ee5\u524d\u6700\u5e38\u7528\u7684\u65b9\u6cd5\u3002SVM \u5728\u56fe\u50cf\u6570\u636e\u96c6\u65b9\u9762\u76f8\u5f53\u6709\u540d\u3002\u6df1\u5ea6\u5b66\u4e60\u5df2\u88ab\u8bc1\u660e\u662f\u89e3\u51b3\u6b64\u7c7b\u95ee\u9898\u7684\u6700\u5148\u8fdb\u65b9\u6cd5\uff0c\u56e0\u6b64\u6211\u4eec\u4e0b\u4e00\u6b65\u53ef\u4ee5\u8bd5\u8bd5\u5b83\u3002 \u5173\u4e8e\u6df1\u5ea6\u5b66\u4e60\u7684\u5386\u53f2\u4ee5\u53ca\u8c01\u53d1\u660e\u4e86\u4ec0\u4e48\uff0c\u6211\u5c31\u4e0d\u591a\u8bf4\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b\u6700\u8457\u540d\u7684\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e4b\u4e00 AlexNet\u3002 \u56fe 3\uff1aAlexNet \u67b6\u67849 \u8bf7\u6ce8\u610f\uff0c\u672c\u56fe\u4e2d\u7684\u8f93\u5165\u5927\u5c0f\u4e0d\u662f 224x224 \u800c\u662f 227x227 \u5982\u4eca\uff0c\u4f60\u53ef\u80fd\u4f1a\u8bf4\u8fd9\u53ea\u662f\u4e00\u4e2a\u57fa\u672c\u7684 \u6df1\u5ea6\u5377\u79ef\u795e\u7ecf\u7f51\u7edc \uff0c\u4f46\u5b83\u5374\u662f\u8bb8\u591a\u65b0\u578b\u6df1\u5ea6\u7f51\u7edc\uff08\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff09\u7684\u57fa\u7840\u3002\u6211\u4eec\u770b\u5230\uff0c\u56fe 3 \u4e2d\u7684\u7f51\u7edc\u662f\u4e00\u4e2a\u5177\u6709\u4e94\u4e2a\u5377\u79ef\u5c42\u3001\u4e24\u4e2a\u5bc6\u96c6\u5c42\u548c\u4e00\u4e2a\u8f93\u51fa\u5c42\u7684\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u3002\u6211\u4eec\u770b\u5230\u8fd8\u6709\u6700\u5927\u6c60\u5316\u3002\u8fd9\u662f\u4ec0\u4e48\u610f\u601d\uff1f\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5728\u8fdb\u884c\u6df1\u5ea6\u5b66\u4e60\u65f6\u4f1a\u9047\u5230\u7684\u4e00\u4e9b\u672f\u8bed\u3002 \u56fe 4\uff1a\u56fe\u50cf\u5927\u5c0f\u4e3a 8x8\uff0c\u6ee4\u6ce2\u5668\u5927\u5c0f\u4e3a 3x3\uff0c\u6b65\u957f\u4e3a 2\u3002 \u56fe 4 \u5f15\u5165\u4e86\u4e24\u4e2a\u65b0\u672f\u8bed\uff1a\u6ee4\u6ce2\u5668\u548c\u6b65\u957f\u3002 \u6ee4\u6ce2\u5668 \u662f\u7531\u7ed9\u5b9a\u51fd\u6570\u521d\u59cb\u5316\u7684\u4e8c\u7ef4\u77e9\u9635\uff0c\u7531\u6307\u5b9a\u51fd\u6570\u521d\u59cb\u5316\u3002 Kaiming\u6b63\u6001\u521d\u59cb\u5316 \uff0c\u662f\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u7684\u6700\u4f73\u9009\u62e9\u3002\u8fd9\u662f\u56e0\u4e3a\u5927\u591a\u6570\u73b0\u4ee3\u7f51\u7edc\u90fd\u4f7f\u7528 ReLU \uff08\u6574\u6d41\u7ebf\u6027\u5355\u5143\uff09\u6fc0\u6d3b\u51fd\u6570\uff0c\u9700\u8981\u9002\u5f53\u7684\u521d\u59cb\u5316\u6765\u907f\u514d\u68af\u5ea6\u6d88\u5931\u95ee\u9898\uff08\u68af\u5ea6\u8d8b\u8fd1\u4e8e\u96f6\uff0c\u7f51\u7edc\u6743\u91cd\u4e0d\u53d8\uff09\u3002\u8be5\u6ee4\u6ce2\u5668\u4e0e\u56fe\u50cf\u8fdb\u884c\u5377\u79ef\u3002\u5377\u79ef\u4e0d\u8fc7\u662f\u6ee4\u6ce2\u5668\u4e0e\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u5f53\u524d\u91cd\u53e0\u50cf\u7d20\u4e4b\u95f4\u7684\u5143\u7d20\u76f8\u4e58\u7684\u603b\u548c\u3002\u60a8\u53ef\u4ee5\u5728\u4efb\u4f55\u9ad8\u4e2d\u6570\u5b66\u6559\u79d1\u4e66\u4e2d\u9605\u8bfb\u66f4\u591a\u5173\u4e8e\u5377\u79ef\u7684\u5185\u5bb9\u3002\u6211\u4eec\u4ece\u56fe\u50cf\u7684\u5de6\u4e0a\u89d2\u5f00\u59cb\u5bf9\u6ee4\u955c\u8fdb\u884c\u5377\u79ef\uff0c\u7136\u540e\u6c34\u5e73\u79fb\u52a8\u6ee4\u955c\u3002\u5982\u679c\u79fb\u52a8 1 \u4e2a\u50cf\u7d20\uff0c\u5219\u6b65\u957f\u4e3a 1\uff1b\u5982\u679c\u79fb\u52a8 2 \u4e2a\u50cf\u7d20\uff0c\u5219\u6b65\u957f\u4e3a 2\u3002 \u5373\u4f7f\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u4f8b\u5982\u5728\u95ee\u9898\u548c\u56de\u7b54\u7cfb\u7edf\u4e2d\u9700\u8981\u4ece\u5927\u91cf\u6587\u672c\u8bed\u6599\u4e2d\u7b5b\u9009\u7b54\u6848\u65f6\uff0c\u6b65\u957f\u4e5f\u662f\u4e00\u4e2a\u975e\u5e38\u6709\u7528\u7684\u6982\u5ff5\u3002\u5f53\u6211\u4eec\u5728\u6c34\u5e73\u65b9\u5411\u4e0a\u8d70\u5230\u5c3d\u5934\u65f6\uff0c\u5c31\u4f1a\u4ee5\u540c\u6837\u7684\u6b65\u957f\u5782\u76f4\u5411\u4e0b\u79fb\u52a8\u8fc7\u6ee4\u5668\uff0c\u4ece\u5de6\u4fa7\u5f00\u59cb\u3002\u56fe 4 \u8fd8\u663e\u793a\u4e86\u8fc7\u6ee4\u5668\u79fb\u51fa\u56fe\u50cf\u7684\u60c5\u51b5\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u65e0\u6cd5\u8ba1\u7b97\u5377\u79ef\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u8df3\u8fc7\u5b83\u3002\u5982\u679c\u4e0d\u60f3\u8df3\u8fc7\uff0c\u5219\u9700\u8981\u5bf9\u56fe\u50cf\u8fdb\u884c \u586b\u5145\uff08pad\uff09 \u3002\u8fd8\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5377\u79ef\u4f1a\u51cf\u5c0f\u56fe\u50cf\u7684\u5927\u5c0f\u3002\u586b\u5145\u4e5f\u662f\u4fdd\u6301\u56fe\u50cf\u5927\u5c0f\u4e0d\u53d8\u7684\u4e00\u79cd\u65b9\u6cd5\u3002\u5728\u56fe 4 \u4e2d\uff0c\u4e00\u4e2a 3x3 \u6ee4\u6ce2\u5668\u6b63\u5728\u6c34\u5e73\u548c\u5782\u76f4\u79fb\u52a8\uff0c\u6bcf\u6b21\u79fb\u52a8\u90fd\u4f1a\u5206\u522b\u8df3\u8fc7\u4e24\u5217\u548c\u4e24\u884c\uff08\u5373\u50cf\u7d20\uff09\u3002\u7531\u4e8e\u5b83\u8df3\u8fc7\u4e86\u4e24\u4e2a\u50cf\u7d20\uff0c\u6240\u4ee5\u6b65\u957f = 2\u3002\u56e0\u6b64\u56fe\u50cf\u5927\u5c0f\u4e3a [(8-3) / 2] + 1 = 3.5\u3002\u6211\u4eec\u53d6 3.5 \u7684\u4e0b\u9650\uff0c\u6240\u4ee5\u662f 3x3\u3002\u60a8\u53ef\u4ee5\u5728\u8349\u7a3f\u7eb8\u4e0a\u8fdb\u884c\u5c1d\u8bd5\u3002 \u56fe 5\uff1a\u901a\u8fc7\u586b\u5145\uff0c\u6211\u4eec\u53ef\u4ee5\u63d0\u4f9b\u4e0e\u8f93\u5165\u56fe\u50cf\u5927\u5c0f\u76f8\u540c\u7684\u56fe\u50cf \u6211\u4eec\u53ef\u4ee5\u4ece\u56fe 5 \u4e2d\u770b\u5230\u586b\u5145\u7684\u6548\u679c\u3002\u73b0\u5728\uff0c\u6211\u4eec\u6709\u4e00\u4e2a 3x3 \u7684\u6ee4\u6ce2\u5668\uff0c\u5b83\u4ee5 1 \u7684\u6b65\u957f\u79fb\u52a8\u3002\u539f\u59cb\u56fe\u50cf\u7684\u5927\u5c0f\u4e3a 6x6\uff0c\u6211\u4eec\u6dfb\u52a0\u4e86 1 \u7684 \u586b\u5145 \u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u751f\u6210\u7684\u56fe\u50cf\u5c06\u4e0e\u8f93\u5165\u56fe\u50cf\u5927\u5c0f\u76f8\u540c\uff0c\u5373 6x6\u3002\u5728\u5904\u7406\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u65f6\u53ef\u80fd\u4f1a\u9047\u5230\u7684\u53e6\u4e00\u4e2a\u76f8\u5173\u672f\u8bed\u662f \u81a8\u80c0\uff08dilation\uff09 \uff0c\u5982\u56fe 6 \u6240\u793a\u3002 \u56fe 6\uff1a\u81a8\u80c0\uff08dilation\uff09\u7684\u4f8b\u5b50 \u5728\u81a8\u80c0\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u5c06\u6ee4\u6ce2\u5668\u6269\u5927 N-1\uff0c\u5176\u4e2d N \u662f\u81a8\u80c0\u7387\u7684\u503c\uff0c\u6216\u7b80\u79f0\u4e3a\u81a8\u80c0\u3002\u5728\u8fd9\u79cd\u5e26\u81a8\u80c0\u7684\u5185\u6838\u4e2d\uff0c\u6bcf\u6b21\u5377\u79ef\u90fd\u4f1a\u8df3\u8fc7\u4e00\u4e9b\u50cf\u7d20\u3002\u8fd9\u5728\u5206\u5272\u4efb\u52a1\u4e2d\u5c24\u4e3a\u6709\u6548\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u53ea\u8ba8\u8bba\u4e86\u4e8c\u7ef4\u5377\u79ef\u3002 \u8fd8\u6709\u4e00\u7ef4\u5377\u79ef\u548c\u66f4\u9ad8\u7ef4\u5ea6\u7684\u5377\u79ef\u3002\u5b83\u4eec\u90fd\u57fa\u4e8e\u76f8\u540c\u7684\u57fa\u672c\u6982\u5ff5\u3002 \u63a5\u4e0b\u6765\u662f \u6700\u5927\u6c60\u5316\uff08Max pooling\uff09 \u3002\u6700\u5927\u503c\u6c60\u53ea\u662f\u4e00\u4e2a\u8fd4\u56de\u6700\u5927\u503c\u7684\u6ee4\u6ce2\u5668\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u63d0\u53d6\u7684\u4e0d\u662f\u5377\u79ef\uff0c\u800c\u662f\u50cf\u7d20\u7684\u6700\u5927\u503c\u3002\u540c\u6837\uff0c \u5e73\u5747\u6c60\u5316\uff08average pooling\uff09 \u6216 \u5747\u503c\u6c60\u5316\uff08mean pooling\uff09 \u4f1a\u8fd4\u56de\u50cf\u7d20\u7684\u5e73\u5747\u503c\u3002\u5b83\u4eec\u7684\u4f7f\u7528\u65b9\u6cd5\u4e0e\u5377\u79ef\u6838\u76f8\u540c\u3002\u6c60\u5316\u6bd4\u5377\u79ef\u66f4\u5feb\uff0c\u662f\u4e00\u79cd\u5bf9\u56fe\u50cf\u8fdb\u884c\u7f29\u51cf\u91c7\u6837\u7684\u65b9\u6cd5\u3002\u6700\u5927\u6c60\u5316\u53ef\u68c0\u6d4b\u8fb9\u7f18\uff0c\u5e73\u5747\u6c60\u5316\u53ef\u5e73\u6ed1\u56fe\u50cf\u3002 \u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u548c\u6df1\u5ea6\u5b66\u4e60\u7684\u6982\u5ff5\u592a\u591a\u4e86\u3002\u6211\u6240\u8ba8\u8bba\u7684\u662f\u4e00\u4e9b\u57fa\u7840\u77e5\u8bc6\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5165\u95e8\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u5728 PyTorch \u4e2d\u6784\u5efa\u7b2c\u4e00\u4e2a\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u505a\u597d\u4e86\u5145\u5206\u51c6\u5907\u3002PyTorch \u63d0\u4f9b\u4e86\u4e00\u79cd\u76f4\u89c2\u800c\u7b80\u5355\u7684\u65b9\u6cd5\u6765\u5b9e\u73b0\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff0c\u800c\u4e14\u4f60\u4e0d\u9700\u8981\u5173\u5fc3\u53cd\u5411\u4f20\u64ad\u3002\u6211\u4eec\u7528\u4e00\u4e2a python \u7c7b\u548c\u4e00\u4e2a\u524d\u9988\u51fd\u6570\u6765\u5b9a\u4e49\u7f51\u7edc\uff0c\u544a\u8bc9 PyTorch \u5404\u5c42\u4e4b\u95f4\u5982\u4f55\u8fde\u63a5\u3002\u5728 PyTorch \u4e2d\uff0c\u56fe\u50cf\u7b26\u53f7\u662f BS\u3001C\u3001H\u3001W\uff0c\u5176\u4e2d\uff0cBS \u662f\u6279\u5927\u5c0f\uff0cC \u662f\u901a\u9053\uff0cH \u662f\u9ad8\u5ea6\uff0cW \u662f\u5bbd\u5ea6\u3002\u8ba9\u6211\u4eec\u770b\u770b PyTorch \u662f\u5982\u4f55\u5b9e\u73b0 AlexNet \u7684\u3002 import torch import torch.nn as nn import torch.nn.functional as F class AlexNet ( nn . Module ): def __init__ ( self ): super ( AlexNet , self ) . __init__ () self . conv1 = nn . Conv2d ( in_channels = 3 , out_channels = 96 , kernel_size = 11 , stride = 4 , padding = 0 ) self . pool1 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . conv2 = nn . Conv2d ( in_channels = 96 , out_channels = 256 , kernel_size = 5 , stride = 1 , padding = 2 ) self . pool2 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . conv3 = nn . Conv2d ( in_channels = 256 , out_channels = 384 , kernel_size = 3 , stride = 1 , padding = 1 ) self . conv4 = nn . Conv2d ( in_channels = 384 , out_channels = 384 , kernel_size = 3 , stride = 1 , padding = 1 ) self . conv5 = nn . Conv2d ( in_channels = 384 , out_channels = 256 , kernel_size = 3 , stride = 1 , padding = 1 ) self . pool3 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . fc1 = nn . Linear ( in_features = 9216 , out_features = 4096 ) self . dropout1 = nn . Dropout ( 0.5 ) self . fc2 = nn . Linear ( in_features = 4096 , out_features = 4096 ) self . dropout2 = nn . Dropout ( 0.5 ) self . fc3 = nn . Linear ( in_features = 4096 , out_features = 1000 ) def forward ( self , image ): bs , c , h , w = image . size () x = F . relu ( self . conv1 ( image )) # size: (bs, 96, 55, 55) x = self . pool1 ( x ) # size: (bs, 96, 27, 27) x = F . relu ( self . conv2 ( x )) # size: (bs, 256, 27, 27) x = self . pool2 ( x ) # size: (bs, 256, 13, 13) x = F . relu ( self . conv3 ( x )) # size: (bs, 384, 13, 13) x = F . relu ( self . conv4 ( x )) # size: (bs, 384, 13, 13) x = F . relu ( self . conv5 ( x )) # size: (bs, 256, 13, 13) x = self . pool3 ( x ) # size: (bs, 256, 6, 6) x = x . view ( bs , - 1 ) # size: (bs, 9216) x = F . relu ( self . fc1 ( x )) # size: (bs, 4096) x = self . dropout1 ( x ) # size: (bs, 4096) # dropout does not change size # dropout is used for regularization # 0.3 dropout means that only 70% of the nodes # of the current layer are used for the next layer x = F . relu ( self . fc2 ( x )) # size: (bs, 4096) x = self . dropout2 ( x ) # size: (bs, 4096) x = F . relu ( self . fc3 ( x )) # size: (bs, 1000) # 1000 is number of classes in ImageNet Dataset # softmax is an activation function that converts # linear output to probabilities that add up to 1 # for each sample in the batch x = torch . softmax ( x , axis = 1 ) # size: (bs, 1000) return x \u5982\u679c\u60a8\u6709\u4e00\u5e45 3x227x227 \u7684\u56fe\u50cf\uff0c\u5e76\u5e94\u7528\u4e86\u4e00\u4e2a\u5927\u5c0f\u4e3a 11x11 \u7684\u5377\u79ef\u6ee4\u6ce2\u5668\uff0c\u8fd9\u610f\u5473\u7740\u60a8\u5e94\u7528\u4e86\u4e00\u4e2a\u5927\u5c0f\u4e3a 11x11x3 \u7684\u6ee4\u6ce2\u5668\uff0c\u5e76\u4e0e\u4e00\u4e2a\u5927\u5c0f\u4e3a 227x227x3 \u7684\u56fe\u50cf\u8fdb\u884c\u4e86\u5377\u79ef\u3002\u8f93\u51fa\u901a\u9053\u7684\u6570\u91cf\u5c31\u662f\u5206\u522b\u5e94\u7528\u4e8e\u56fe\u50cf\u7684\u76f8\u540c\u5927\u5c0f\u7684\u4e0d\u540c\u5377\u79ef\u6ee4\u6ce2\u5668\u7684\u6570\u91cf\u3002 \u56e0\u6b64\uff0c\u5728\u7b2c\u4e00\u4e2a\u5377\u79ef\u5c42\u4e2d\uff0c\u8f93\u5165\u901a\u9053\u662f 3\uff0c\u4e5f\u5c31\u662f\u539f\u59cb\u8f93\u5165\uff0c\u5373 R\u3001G\u3001B \u4e09\u901a\u9053\u3002PyTorch \u7684 torchvision \u63d0\u4f9b\u4e86\u8bb8\u591a\u4e0e AlexNet \u7c7b\u4f3c\u7684\u4e0d\u540c\u6a21\u578b\uff0c\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0cAlexNet \u7684\u5b9e\u73b0\u4e0e torchvision \u7684\u5b9e\u73b0\u5e76\u4e0d\u76f8\u540c\u3002Torchvision \u7684 AlexNet \u5b9e\u73b0\u662f\u4ece\u53e6\u4e00\u7bc7\u8bba\u6587\u4e2d\u4fee\u6539\u800c\u6765\u7684 AlexNet\uff1a Krizhevsky, A. One weird trick for parallelizing convolutional neural networks. CoRR, abs/1404.5997, 2014. \u4f60\u53ef\u4ee5\u4e3a\u81ea\u5df1\u7684\u4efb\u52a1\u8bbe\u8ba1\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff0c\u5f88\u591a\u65f6\u5019\uff0c\u4ece\u96f6\u505a\u8d77\u662f\u4e2a\u4e0d\u9519\u7684\u4e3b\u610f\u3002\u8ba9\u6211\u4eec\u6784\u5efa\u4e00\u4e2a\u7f51\u7edc\uff0c\u7528\u4e8e\u533a\u5206\u56fe\u50cf\u6709\u65e0\u6c14\u80f8\u3002\u9996\u5148\uff0c\u8ba9\u6211\u4eec\u51c6\u5907\u4e00\u4e9b\u6587\u4ef6\u3002\u7b2c\u4e00\u6b65\u662f\u521b\u5efa\u4e00\u4e2a\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e\u96c6\uff0c\u5373 train.csv\uff0c\u4f46\u589e\u52a0\u4e00\u5217 kfold\u3002\u6211\u4eec\u5c06\u521b\u5efa\u4e94\u4e2a\u6587\u4ef6\u5939\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u5df2\u7ecf\u6f14\u793a\u4e86\u5982\u4f55\u9488\u5bf9\u4e0d\u540c\u7684\u6570\u636e\u96c6\u521b\u5efa\u6298\u53e0\uff0c\u56e0\u6b64\u6211\u5c06\u8df3\u8fc7\u8fd9\u4e00\u90e8\u5206\uff0c\u7559\u4f5c\u7ec3\u4e60\u3002\u5bf9\u4e8e\u57fa\u4e8e PyTorch \u7684\u795e\u7ecf\u7f51\u7edc\uff0c\u6211\u4eec\u9700\u8981\u521b\u5efa\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\u3002\u6570\u636e\u96c6\u7c7b\u7684\u76ee\u7684\u662f\u8fd4\u56de\u4e00\u4e2a\u6570\u636e\u9879\u6216\u6570\u636e\u6837\u672c\u3002\u8fd9\u4e2a\u6570\u636e\u6837\u672c\u5e94\u8be5\u5305\u542b\u8bad\u7ec3\u6216\u8bc4\u4f30\u6a21\u578b\u6240\u9700\u7684\u6240\u6709\u5185\u5bb9\u3002 import torch import numpy as np from PIL import Image from PIL import ImageFile ImageFile . LOAD_TRUNCATED_IMAGES = True # \u5b9a\u4e49\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\uff0c\u7528\u4e8e\u5904\u7406\u56fe\u50cf\u5206\u7c7b\u4efb\u52a1 class ClassificationDataset : def __init__ ( self , image_paths , targets , resize = None , augmentations = None ): # \u56fe\u50cf\u6587\u4ef6\u8def\u5f84\u5217\u8868 self . image_paths = image_paths # \u76ee\u6807\u6807\u7b7e\u5217\u8868 self . targets = targets # \u56fe\u50cf\u5c3a\u5bf8\u8c03\u6574\u53c2\u6570\uff0c\u53ef\u4ee5\u4e3aNone self . resize = resize # \u6570\u636e\u589e\u5f3a\u51fd\u6570\uff0c\u53ef\u4ee5\u4e3aNone self . augmentations = augmentations def __len__ ( self ): # \u8fd4\u56de\u6570\u636e\u96c6\u7684\u5927\u5c0f\uff0c\u5373\u56fe\u50cf\u6570\u91cf return len ( self . image_paths ) def __getitem__ ( self , item ): # \u83b7\u53d6\u6570\u636e\u96c6\u4e2d\u7684\u4e00\u4e2a\u6837\u672c image = Image . open ( self . image_paths [ item ]) image = image . convert ( \"RGB\" ) # \u5c06\u56fe\u50cf\u8f6c\u6362\u4e3aRGB\u683c\u5f0f # \u83b7\u53d6\u8be5\u6837\u672c\u7684\u76ee\u6807\u6807\u7b7e targets = self . targets [ item ] if self . resize is not None : # \u5982\u679c\u6307\u5b9a\u4e86\u5c3a\u5bf8\u8c03\u6574\u53c2\u6570\uff0c\u5c06\u56fe\u50cf\u8fdb\u884c\u5c3a\u5bf8\u8c03\u6574 image = image . resize (( self . resize [ 1 ], self . resize [ 0 ]), resample = Image . BILINEAR ) image = np . array ( image ) if self . augmentations is not None : # \u5982\u679c\u6307\u5b9a\u4e86\u6570\u636e\u589e\u5f3a\u51fd\u6570\uff0c\u5e94\u7528\u6570\u636e\u589e\u5f3a augmented = self . augmentations ( image = image ) image = augmented [ \"image\" ] # \u5c06\u56fe\u50cf\u901a\u9053\u987a\u5e8f\u8c03\u6574\u4e3a(C, H, W)\u7684\u5f62\u5f0f\uff0c\u5e76\u8f6c\u6362\u4e3afloat32\u7c7b\u578b image = np . transpose ( image , ( 2 , 0 , 1 )) . astype ( np . float32 ) # \u8fd4\u56de\u6837\u672c\uff0c\u5305\u62ec\u56fe\u50cf\u548c\u5bf9\u5e94\u7684\u76ee\u6807\u6807\u7b7e return { \"image\" : torch . tensor ( image , dtype = torch . float ), \"targets\" : torch . tensor ( targets , dtype = torch . long ), } \u73b0\u5728\u6211\u4eec\u9700\u8981 engine.py\u3002engine.py \u5305\u542b\u8bad\u7ec3\u548c\u8bc4\u4f30\u529f\u80fd\u3002\u8ba9\u6211\u4eec\u770b\u770b engine.py \u662f\u5982\u4f55\u7f16\u5199\u7684\u3002 import torch import torch.nn as nn from tqdm import tqdm # \u7528\u4e8e\u8bad\u7ec3\u6a21\u578b\u7684\u51fd\u6570 def train ( data_loader , model , optimizer , device ): # \u5c06\u6a21\u578b\u8bbe\u7f6e\u4e3a\u8bad\u7ec3\u6a21\u5f0f model . train () for data in data_loader : # \u4ece\u6570\u636e\u52a0\u8f7d\u5668\u4e2d\u63d0\u53d6\u8f93\u5165\u56fe\u50cf\u548c\u76ee\u6807\u6807\u7b7e inputs = data [ \"image\" ] targets = data [ \"targets\" ] # \u5c06\u8f93\u5165\u548c\u76ee\u6807\u79fb\u52a8\u5230\u6307\u5b9a\u7684\u8bbe\u5907\uff08\u4f8b\u5982\uff0cGPU\uff09 inputs = inputs . to ( device , dtype = torch . float ) targets = targets . to ( device , dtype = torch . float ) # \u5c06\u4f18\u5316\u5668\u4e2d\u7684\u68af\u5ea6\u5f52\u96f6 optimizer . zero_grad () # \u524d\u5411\u4f20\u64ad\uff1a\u8ba1\u7b97\u6a21\u578b\u9884\u6d4b outputs = model ( inputs ) # \u4f7f\u7528\u5e26\u903b\u8f91\u65af\u8482\u51fd\u6570\u7684\u4e8c\u5143\u4ea4\u53c9\u71b5\u635f\u5931\u8ba1\u7b97\u635f\u5931 loss = nn . BCEWithLogitsLoss ()( outputs , targets . view ( - 1 , 1 )) # \u53cd\u5411\u4f20\u64ad\uff1a\u8ba1\u7b97\u68af\u5ea6\u5e76\u66f4\u65b0\u6a21\u578b\u6743\u91cd loss . backward () optimizer . step () # \u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u7684\u51fd\u6570 def evaluate ( data_loader , model , device ): # \u5c06\u6a21\u578b\u8bbe\u7f6e\u4e3a\u8bc4\u4f30\u6a21\u5f0f\uff08\u4e0d\u8fdb\u884c\u68af\u5ea6\u8ba1\u7b97\uff09 model . eval () # \u521d\u59cb\u5316\u5217\u8868\u4ee5\u5b58\u50a8\u771f\u5b9e\u76ee\u6807\u548c\u6a21\u578b\u9884\u6d4b final_targets = [] final_outputs = [] with torch . no_grad (): for data in data_loader : # \u4ece\u6570\u636e\u52a0\u8f7d\u5668\u4e2d\u63d0\u53d6\u8f93\u5165\u56fe\u50cf\u548c\u76ee\u6807\u6807\u7b7e inputs = data [ \"image\" ] targets = data [ \"targets\" ] # \u5c06\u8f93\u5165\u79fb\u52a8\u5230\u6307\u5b9a\u7684\u8bbe\u5907\uff08\u4f8b\u5982\uff0cGPU\uff09 inputs = inputs . to ( device , dtype = torch . float ) # \u83b7\u53d6\u6a21\u578b\u9884\u6d4b output = model ( inputs ) # \u5c06\u76ee\u6807\u548c\u8f93\u51fa\u8f6c\u6362\u4e3aCPU\u548cPython\u5217\u8868 targets = targets . detach () . cpu () . numpy () . tolist () output = output . detach () . cpu () . numpy () . tolist () # \u5c06\u5217\u8868\u6269\u5c55\u4ee5\u5305\u542b\u6279\u6b21\u6570\u636e final_targets . extend ( targets ) final_outputs . extend ( output ) # \u8fd4\u56de\u6700\u7ec8\u7684\u6a21\u578b\u9884\u6d4b\u548c\u771f\u5b9e\u76ee\u6807 return final_outputs , final_targets \u6709\u4e86 engine.py\uff0c\u5c31\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u65b0\u6587\u4ef6\uff1amodel.py\u3002model.py \u5c06\u5305\u542b\u6211\u4eec\u7684\u6a21\u578b\u3002\u628a\u6a21\u578b\u4e0e\u8bad\u7ec3\u5206\u5f00\u662f\u4e2a\u597d\u4e3b\u610f\uff0c\u56e0\u4e3a\u8fd9\u6837\u6211\u4eec\u5c31\u53ef\u4ee5\u8f7b\u677e\u5730\u8bd5\u9a8c\u4e0d\u540c\u7684\u6a21\u578b\u548c\u4e0d\u540c\u7684\u67b6\u6784\u3002\u540d\u4e3a pretrainedmodels \u7684 PyTorch \u5e93\u4e2d\u6709\u5f88\u591a\u4e0d\u540c\u7684\u6a21\u578b\u67b6\u6784\uff0c\u5982 AlexNet\u3001ResNet\u3001DenseNet \u7b49\u3002\u8fd9\u4e9b\u4e0d\u540c\u7684\u6a21\u578b\u67b6\u6784\u662f\u5728\u540d\u4e3a ImageNet \u7684\u5927\u578b\u56fe\u50cf\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u51fa\u6765\u7684\u3002\u5728 ImageNet \u4e0a\u8bad\u7ec3\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5b83\u4eec\u7684\u6743\u91cd\uff0c\u4e5f\u53ef\u4ee5\u4e0d\u4f7f\u7528\u8fd9\u4e9b\u6743\u91cd\u3002\u5982\u679c\u6211\u4eec\u4e0d\u4f7f\u7528 ImageNet \u6743\u91cd\u8fdb\u884c\u8bad\u7ec3\uff0c\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u7f51\u7edc\u5c06\u4ece\u5934\u5f00\u59cb\u5b66\u4e60\u4e00\u5207\u3002\u8fd9\u5c31\u662f model.py \u7684\u6837\u5b50\u3002 import torch.nn as nn import pretrainedmodels # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u4ee5\u83b7\u53d6\u6a21\u578b def get_model ( pretrained ): if pretrained : # \u4f7f\u7528\u9884\u8bad\u7ec3\u7684 AlexNet \u6a21\u578b\uff0c\u52a0\u8f7d\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u7684\u6743\u91cd model = pretrainedmodels . __dict__ [ \"alexnet\" ]( pretrained = 'imagenet' ) else : # \u4f7f\u7528\u672a\u7ecf\u9884\u8bad\u7ec3\u7684 AlexNet \u6a21\u578b model = pretrainedmodels . __dict__ [ \"alexnet\" ]( pretrained = None ) # \u4fee\u6539\u6a21\u578b\u7684\u6700\u540e\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u4ee5\u9002\u5e94\u7279\u5b9a\u4efb\u52a1 model . last_linear = nn . Sequential ( nn . BatchNorm1d ( 4096 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.25 ), # \u968f\u673a\u5931\u6d3b\u5c42\uff0c\u9632\u6b62\u8fc7\u62df\u5408 nn . Linear ( in_features = 4096 , out_features = 2048 ), # \u8fde\u63a5\u5c42 nn . ReLU (), # ReLU \u6fc0\u6d3b\u51fd\u6570 nn . BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.5 ), # \u968f\u673a\u5931\u6d3b\u5c42 nn . Linear ( in_features = 2048 , out_features = 1 ) # \u6700\u7ec8\u7684\u4e8c\u5143\u5206\u7c7b\u5c42 ) return model \u5982\u679c\u4f60\u6253\u5370\u4e86\u7f51\u7edc\uff0c\u4f1a\u5f97\u5230\u5982\u4e0b\u8f93\u51fa\uff1a AlexNet ( ( avgpool ): AdaptiveAvgPool2d ( output_size = ( 6 , 6 )) ( _features ): Sequential ( ( 0 ): Conv2d ( 3 , 64 , kernel_size = ( 11 , 11 ), stride = ( 4 , 4 ), padding = ( 2 , 2 )) ( 1 ): ReLU ( inplace = True ) ( 2 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , ceil_mode = False ) ( 3 ): Conv2d ( 64 , 192 , kernel_size = ( 5 , 5 ), stride = ( 1 , 1 ), padding = ( 2 , 2 )) ( 4 ): ReLU ( inplace = True ) ( 5 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , ceil_mode = False ) ( 6 ): Conv2d ( 192 , 384 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 7 ): ReLU ( inplace = True ) ( 8 ): Conv2d ( 384 , 256 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 9 ): ReLU ( inplace = True ) ( 10 ): Conv2d ( 256 , 256 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 11 ): ReLU ( inplace = True ) ( 12 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , eil_mode = False )) ( dropout0 ): Dropout ( p = 0.5 , inplace = False ) ( linear0 ): Linear ( in_features = 9216 , out_features = 4096 , bias = True ) ( relu0 ): ReLU ( inplace = True ) ( dropout1 ): Dropout ( p = 0.5 , inplace = False ) ( linear1 ): Linear ( in_features = 4096 , out_features = 4096 , bias = True ) ( relu1 ): ReLU ( inplace = True ) ( last_linear ): Sequential ( ( 0 ): BatchNorm1d ( 4096 , eps = 1e-05 , momentum = 0.1 , affine = True , rack_running_stats = True ) ( 1 ): Dropout ( p = 0.25 , inplace = False ) ( 2 ): Linear ( in_features = 4096 , out_features = 2048 , bias = True ) ( 3 ): ReLU () ( 4 ): BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 , affine = True , track_running_stats = True ) ( 5 ): Dropout ( p = 0.5 , inplace = False ) ( 6 ): Linear ( in_features = 2048 , out_features = 1 , bias = True ) ) ) \u73b0\u5728\uff0c\u4e07\u4e8b\u4ff1\u5907\uff0c\u53ef\u4ee5\u5f00\u59cb\u8bad\u7ec3\u4e86\u3002\u6211\u4eec\u5c06\u4f7f\u7528 train.py \u8bad\u7ec3\u6a21\u578b\u3002 import os import pandas as pd import numpy as np import albumentations import torch from sklearn import metrics from sklearn.model_selection import train_test_split import dataset import engine from model import get_model if __name__ == \"__main__\" : # \u5b9a\u4e49\u6570\u636e\u8def\u5f84\u3001\u8bbe\u5907\u3001\u8fed\u4ee3\u6b21\u6570 data_path = \"/home/abhishek/workspace/siim_png/\" device = \"cuda\" # \u4f7f\u7528GPU\u52a0\u901f epochs = 10 # \u4eceCSV\u6587\u4ef6\u8bfb\u53d6\u6570\u636e df = pd . read_csv ( os . path . join ( data_path , \"train.csv\" )) images = df . ImageId . values . tolist () images = [ os . path . join ( data_path , \"train_png\" , i + \".png\" ) for i in images ] targets = df . target . values # \u83b7\u53d6\u9884\u8bad\u7ec3\u7684\u6a21\u578b model = get_model ( pretrained = True ) model . to ( device ) # \u5b9a\u4e49\u5747\u503c\u548c\u6807\u51c6\u5dee\uff0c\u7528\u4e8e\u6570\u636e\u6807\u51c6\u5316 mean = ( 0.485 , 0.456 , 0.406 ) std = ( 0.229 , 0.224 , 0.225 ) # \u6570\u636e\u589e\u5f3a\uff0c\u5c06\u56fe\u50cf\u6807\u51c6\u5316 aug = albumentations . Compose ( [ albumentations . Normalize ( mean , std , max_pixel_value = 255.0 , always_apply = True ) ] ) # \u5212\u5206\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 train_images , valid_images , train_targets , valid_targets = train_test_split ( images , targets , stratify = targets , random_state = 42 ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6\u548c\u9a8c\u8bc1\u6570\u636e\u96c6 train_dataset = dataset . ClassificationDataset ( image_paths = train_images , targets = train_targets , resize = ( 227 , 227 ), augmentations = aug , ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = 16 , shuffle = True , num_workers = 4 ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u96c6 valid_dataset = dataset . ClassificationDataset ( image_paths = valid_images , targets = valid_targets , resize = ( 227 , 227 ), augmentations = aug , ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = 16 , shuffle = False , num_workers = 4 ) # \u5b9a\u4e49\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 5e-4 ) # \u8bad\u7ec3\u5faa\u73af for epoch in range ( epochs ): # \u8bad\u7ec3\u6a21\u578b engine . train ( train_loader , model , optimizer , device = device ) # \u8bc4\u4f30\u6a21\u578b\u6027\u80fd predictions , valid_targets = engine . evaluate ( valid_loader , model , device = device ) # \u8ba1\u7b97ROC AUC\u5206\u6570\u5e76\u6253\u5370 roc_auc = metrics . roc_auc_score ( valid_targets , predictions ) print ( f \"Epoch= { epoch } , Valid ROC AUC= { roc_auc } \" ) \u8ba9\u6211\u4eec\u5728\u6ca1\u6709\u9884\u8bad\u7ec3\u6743\u91cd\u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u8bad\u7ec3\uff1a Epoch = 0 , Valid ROC AUC = 0.5737161981475328 Epoch = 1 , Valid ROC AUC = 0.5362868001588292 Epoch = 2 , Valid ROC AUC = 0.6163448214387008 Epoch = 3 , Valid ROC AUC = 0.6119219143780944 Epoch = 4 , Valid ROC AUC = 0.6229718888519726 Epoch = 5 , Valid ROC AUC = 0.5983014999635341 Epoch = 6 , Valid ROC AUC = 0.5523236874306134 Epoch = 7 , Valid ROC AUC = 0.4717721611306046 Epoch = 8 , Valid ROC AUC = 0.6473408263980617 Epoch = 9 , Valid ROC AUC = 0.6639862888260415 AUC \u7ea6\u4e3a 0.66\uff0c\u751a\u81f3\u4f4e\u4e8e\u6211\u4eec\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002\u4f7f\u7528\u9884\u8bad\u7ec3\u6743\u91cd\u4f1a\u53d1\u751f\u4ec0\u4e48\u60c5\u51b5\uff1f Epoch = 0 , Valid ROC AUC = 0.5730387429803165 Epoch = 1 , Valid ROC AUC = 0.5319813942934937 Epoch = 2 , Valid ROC AUC = 0.627111577514323 Epoch = 3 , Valid ROC AUC = 0.6819736959393209 Epoch = 4 , Valid ROC AUC = 0.5747117168950512 Epoch = 5 , Valid ROC AUC = 0.5994619255609669 Epoch = 6 , Valid ROC AUC = 0.5080889443530546 Epoch = 7 , Valid ROC AUC = 0.6323792776512727 Epoch = 8 , Valid ROC AUC = 0.6685753182661686 Epoch = 9 , Valid ROC AUC = 0.6861802387300147 \u73b0\u5728\u7684 AUC \u597d\u4e86\u5f88\u591a\u3002\u4e0d\u8fc7\uff0c\u5b83\u4ecd\u7136\u8f83\u4f4e\u3002\u9884\u8bad\u7ec3\u6a21\u578b\u7684\u597d\u5904\u662f\u53ef\u4ee5\u8f7b\u677e\u5c1d\u8bd5\u591a\u79cd\u4e0d\u540c\u7684\u6a21\u578b\u3002\u8ba9\u6211\u4eec\u8bd5\u8bd5\u4f7f\u7528\u9884\u8bad\u7ec3\u6743\u91cd\u7684 resnet18 \u3002 import torch.nn as nn import pretrainedmodels # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u4ee5\u83b7\u53d6\u6a21\u578b def get_model ( pretrained ): if pretrained : # \u4f7f\u7528\u9884\u8bad\u7ec3\u7684 ResNet-18 \u6a21\u578b\uff0c\u52a0\u8f7d\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u7684\u6743\u91cd model = pretrainedmodels . __dict__ [ \"resnet18\" ]( pretrained = 'imagenet' ) else : # \u4f7f\u7528\u672a\u7ecf\u9884\u8bad\u7ec3\u7684 ResNet-18 \u6a21\u578b model = pretrainedmodels . __dict__ [ \"resnet18\" ]( pretrained = None ) # \u4fee\u6539\u6a21\u578b\u7684\u6700\u540e\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u4ee5\u9002\u5e94\u7279\u5b9a\u4efb\u52a1 model . last_linear = nn . Sequential ( nn . BatchNorm1d ( 512 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.25 ), # \u968f\u673a\u5931\u6d3b\u5c42\uff0c\u9632\u6b62\u8fc7\u62df\u5408 nn . Linear ( in_features = 512 , out_features = 2048 ), # \u8fde\u63a5\u5c42 nn . ReLU (), # ReLU \u6fc0\u6d3b\u51fd\u6570 nn . BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.5 ), # \u968f\u673a\u5931\u6d3b\u5c42 nn . Linear ( in_features = 2048 , out_features = 1 ) # \u6700\u7ec8\u7684\u4e8c\u5143\u5206\u7c7b\u5c42 ) return model \u5728\u5c1d\u8bd5\u8be5\u6a21\u578b\u65f6\uff0c\u6211\u8fd8\u5c06\u56fe\u50cf\u5927\u5c0f\u6539\u4e3a 512x512\uff0c\u5e76\u6dfb\u52a0\u4e86\u4e00\u4e2a\u5b66\u4e60\u7387\u8c03\u5ea6\u5668\uff0c\u6bcf 3 \u4e2aepochs\u540e\u5c06\u5b66\u4e60\u7387\u4e58\u4ee5 0.5\u3002 Epoch = 0 , Valid ROC AUC = 0.5988225569880796 Epoch = 1 , Valid ROC AUC = 0.730349343208836 Epoch = 2 , Valid ROC AUC = 0.5870943169939142 Epoch = 3 , Valid ROC AUC = 0.5775864444138311 Epoch = 4 , Valid ROC AUC = 0.7330502499939224 Epoch = 5 , Valid ROC AUC = 0.7500336296524395 Epoch = 6 , Valid ROC AUC = 0.7563722113724951 Epoch = 7 , Valid ROC AUC = 0.7987463837994215 Epoch = 8 , Valid ROC AUC = 0.798505708937384 Epoch = 9 , Valid ROC AUC = 0.8025477500546988 \u8fd9\u4e2a\u6a21\u578b\u4f3c\u4e4e\u8868\u73b0\u6700\u597d\u3002\u4e0d\u8fc7\uff0c\u60a8\u53ef\u4ee5\u8c03\u6574 AlexNet \u4e2d\u7684\u4e0d\u540c\u53c2\u6570\u548c\u56fe\u50cf\u5927\u5c0f\uff0c\u4ee5\u83b7\u5f97\u66f4\u597d\u7684\u5206\u6570\u3002 \u4f7f\u7528\u589e\u5f3a\u6280\u672f\u5c06\u8fdb\u4e00\u6b65\u63d0\u9ad8\u5f97\u5206\u3002\u4f18\u5316\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u5f88\u96be\uff0c\u4f46\u5e76\u975e\u4e0d\u53ef\u80fd\u3002\u9009\u62e9 Adam \u4f18\u5316\u5668\u3001\u4f7f\u7528\u4f4e\u5b66\u4e60\u7387\u3001\u5728\u9a8c\u8bc1\u635f\u5931\u8fbe\u5230\u9ad8\u70b9\u65f6\u964d\u4f4e\u5b66\u4e60\u7387\u3001\u5c1d\u8bd5\u4e00\u4e9b\u589e\u5f3a\u6280\u672f\u3001\u5c1d\u8bd5\u5bf9\u56fe\u50cf\u8fdb\u884c\u9884\u5904\u7406\uff08\u5982\u5728\u9700\u8981\u65f6\u8fdb\u884c\u88c1\u526a\uff0c\u8fd9\u4e5f\u53ef\u89c6\u4e3a\u9884\u5904\u7406\uff09\u3001\u6539\u53d8\u6279\u6b21\u5927\u5c0f\u7b49\u3002\u4f60\u53ef\u4ee5\u505a\u5f88\u591a\u4e8b\u60c5\u6765\u4f18\u5316\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u3002 \u4e0e AlexNet \u76f8\u6bd4\uff0c ResNet \u7684\u7ed3\u6784\u8981\u590d\u6742\u5f97\u591a\u3002ResNet \u662f\u6b8b\u5dee\u795e\u7ecf\u7f51\u7edc\uff08Residual Neural Network\uff09\u7684\u7f29\u5199\uff0c\u7531 K. He\u3001X. Zhang\u3001S. Ren \u548c J. Sun \u5728 2015 \u5e74\u53d1\u8868\u7684\u8bba\u6587\u4e2d\u63d0\u51fa\u3002ResNet \u7531 \u6b8b\u5dee\u5757 \uff08residual blocks\uff09\u7ec4\u6210\uff0c\u901a\u8fc7\u8df3\u8fc7\u67d0\u4e9b\u5c42\uff0c\u4f7f\u77e5\u8bc6\u80fd\u591f\u4e0d\u65ad\u5728\u5404\u5c42\u4e2d\u8fdb\u884c\u4f20\u9012\u3002\u8fd9\u4e9b\u5c42\u4e4b\u95f4\u7684 \u8fde\u63a5\u88ab\u79f0\u4e3a \u8df3\u8dc3\u8fde\u63a5 \uff08skip-connections\uff09\uff0c\u56e0\u4e3a\u6211\u4eec\u8df3\u8fc7\u4e86\u4e00\u5c42\u6216\u591a\u5c42\u3002\u8df3\u8dc3\u8fde\u63a5\u901a\u8fc7\u5c06\u68af\u5ea6\u4f20\u64ad\u5230\u66f4\u591a\u5c42\u6765\u5e2e\u52a9\u89e3\u51b3\u68af\u5ea6\u6d88\u5931\u95ee\u9898\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u8bad\u7ec3\u975e\u5e38\u5927\u7684\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff0c\u800c\u4e0d\u4f1a\u635f\u5931\u6027\u80fd\u3002\u901a\u5e38\u60c5\u51b5\u4e0b\uff0c\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7684\u662f\u5927\u578b\u795e\u7ecf\u7f51\u7edc\uff0c\u90a3\u4e48\u5f53\u8bad\u7ec3\u5230\u67d0\u4e00\u8282\u70b9\u4e0a\u65f6\u8bad\u7ec3\u635f\u5931\u53cd\u800c\u4f1a\u589e\u52a0\uff0c\u4f46\u8fd9\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u8df3\u8dc3\u8fde\u63a5\u6765\u907f\u514d\u3002\u901a\u8fc7\u56fe 7 \u53ef\u4ee5\u66f4\u597d\u5730\u7406\u89e3\u8fd9\u4e00\u70b9\u3002 \u56fe 7\uff1a\u7b80\u5355\u8fde\u63a5\u4e0e\u6b8b\u5dee\u8fde\u63a5\u7684\u6bd4\u8f83\u3002\u53c2\u89c1\u8df3\u8dc3\u8fde\u63a5\u3002\u8bf7\u6ce8\u610f\uff0c\u672c\u56fe\u7701\u7565\u4e86\u6700\u540e\u4e00\u5c42\u3002 \u6b8b\u5dee\u5757\u975e\u5e38\u5bb9\u6613\u7406\u89e3\u3002\u4f60\u4ece\u67d0\u4e00\u5c42\u83b7\u53d6\u8f93\u51fa\uff0c\u8df3\u8fc7\u4e00\u4e9b\u5c42\uff0c\u7136\u540e\u5c06\u8f93\u51fa\u6dfb\u52a0\u5230\u7f51\u7edc\u4e2d\u66f4\u8fdc\u7684\u4e00\u5c42\u3002\u865a\u7ebf\u8868\u793a\u8f93\u5165\u5f62\u72b6\u9700\u8981\u8c03\u6574\uff0c\u56e0\u4e3a\u4f7f\u7528\u4e86\u6700\u5927\u6c60\u5316\uff0c\u800c\u6700\u5927\u6c60\u5316\u7684\u4f7f\u7528\u4f1a\u6539\u53d8\u8f93\u51fa\u7684\u5927\u5c0f\u3002 ResNet \u6709\u591a\u79cd\u4e0d\u540c\u7684\u7248\u672c\uff1a \u6709 18 \u5c42\u300134 \u5c42\u300150 \u5c42\u3001101 \u5c42\u548c 152 \u5c42\uff0c\u6240\u6709\u8fd9\u4e9b\u5c42\u90fd\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8fdb\u884c\u4e86\u6743\u91cd\u9884\u8bad\u7ec3\u3002\u5982\u4eca\uff0c\u9884\u8bad\u7ec3\u6a21\u578b\uff08\u51e0\u4e4e\uff09\u9002\u7528\u4e8e\u6240\u6709\u60c5\u51b5\uff0c\u4f46\u8bf7\u786e\u4fdd\u60a8\u4ece\u8f83\u5c0f\u7684\u6a21\u578b\u5f00\u59cb\uff0c\u4f8b\u5982\uff0c\u4ece resnet-18 \u5f00\u59cb\uff0c\u800c\u4e0d\u662f resnet-50\u3002\u5176\u4ed6\u4e00\u4e9b ImageNet \u9884\u8bad\u7ec3\u6a21\u578b\u5305\u62ec\uff1a Inception DenseNet(different variations) NASNet PNASNet VGG Xception ResNeXt EfficientNet, etc. \u5927\u90e8\u5206\u9884\u8bad\u7ec3\u7684\u6700\u5148\u8fdb\u6a21\u578b\u53ef\u4ee5\u5728 GitHub \u4e0a\u7684 pytorch- pretrainedmodels \u8d44\u6e90\u5e93\u4e2d\u627e\u5230\uff1ahttps://github.com/Cadene/pretrained-models.pytorch\u3002\u8be6\u7ec6\u8ba8\u8bba\u8fd9\u4e9b\u6a21\u578b\u4e0d\u5728\u672c\u7ae0\uff08\u548c\u672c\u4e66\uff09\u8303\u56f4\u4e4b\u5185\u3002\u65e2\u7136\u6211\u4eec\u53ea\u5173\u6ce8\u5e94\u7528\uff0c\u90a3\u5c31\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u6837\u7684\u9884\u8bad\u7ec3\u6a21\u578b\u5982\u4f55\u7528\u4e8e\u5206\u5272\u4efb\u52a1\u3002 \u56fe 8\uff1aU-Net\u67b6\u6784 \u5206\u5272\uff08Segmentation\uff09\u662f\u8ba1\u7b97\u673a\u89c6\u89c9\u4e2d\u76f8\u5f53\u6d41\u884c\u7684\u4e00\u9879\u4efb\u52a1\u3002\u5728\u5206\u5272\u4efb\u52a1\u4e2d\uff0c\u6211\u4eec\u8bd5\u56fe\u4ece\u80cc\u666f\u4e2d\u79fb\u9664/\u63d0\u53d6\u524d\u666f\u3002 \u524d\u666f\u548c\u80cc\u666f\u53ef\u4ee5\u6709\u4e0d\u540c\u7684\u5b9a\u4e49\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u8bf4\uff0c\u8fd9\u662f\u4e00\u9879\u50cf\u7d20\u5206\u7c7b\u4efb\u52a1\uff0c\u4f60\u7684\u5de5\u4f5c\u662f\u7ed9\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u7684\u6bcf\u4e2a\u50cf\u7d20\u5206\u914d\u4e00\u4e2a\u7c7b\u522b\u3002\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u6b63\u5728\u5904\u7406\u7684\u6c14\u80f8\u6570\u636e\u96c6\u5c31\u662f\u4e00\u9879\u5206\u5272\u4efb\u52a1\u3002\u5728\u8fd9\u9879\u4efb\u52a1\u4e2d\uff0c\u6211\u4eec\u9700\u8981\u5bf9\u7ed9\u5b9a\u7684\u80f8\u90e8\u653e\u5c04\u56fe\u50cf\u8fdb\u884c\u6c14\u80f8\u5206\u5272\u3002\u7528\u4e8e\u5206\u5272\u4efb\u52a1\u7684\u6700\u5e38\u7528\u6a21\u578b\u662f U-Net\u3002\u5176\u7ed3\u6784\u5982\u56fe 8 \u6240\u793a\u3002 U-Net \u5305\u62ec\u4e24\u4e2a\u90e8\u5206\uff1a\u7f16\u7801\u5668\u548c\u89e3\u7801\u5668\u3002\u7f16\u7801\u5668\u4e0e\u60a8\u76ee\u524d\u6240\u89c1\u8fc7\u7684\u4efb\u4f55 U-Net \u90fd\u662f\u4e00\u6837\u7684\u3002\u89e3\u7801\u5668\u5219\u6709\u4e9b\u4e0d\u540c\u3002\u89e3\u7801\u5668\u7531\u4e0a\u5377\u79ef\u5c42\u7ec4\u6210\u3002\u5728\u4e0a\u5377\u79ef\uff08up-convolutions\uff09\uff08 \u8f6c\u7f6e\u5377\u79ef transposed convolutions\uff09\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u6ee4\u6ce2\u5668\uff0c\u5f53\u5e94\u7528\u5230\u4e00\u4e2a\u5c0f\u56fe\u50cf\u65f6\uff0c\u4f1a\u4ea7\u751f\u4e00\u4e2a\u5927\u56fe\u50cf\u3002\u5728 PyTorch \u4e2d\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528 ConvTranspose2d \u6765\u5b8c\u6210\u8fd9\u4e00\u64cd\u4f5c\u3002\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u4e0a\u5377\u79ef\u4e0e\u4e0a\u91c7\u6837\u5e76\u4e0d\u76f8\u540c\u3002\u4e0a\u91c7\u6837\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u8fc7\u7a0b\uff0c\u6211\u4eec\u5728\u56fe\u50cf\u4e0a\u5e94\u7528\u4e00\u4e2a\u51fd\u6570\u6765\u8c03\u6574\u5b83\u7684\u5927\u5c0f\u3002\u5728\u4e0a\u5377\u79ef\u4e2d\uff0c\u6211\u4eec\u8981\u5b66\u4e60\u6ee4\u6ce2\u5668\u3002\u6211\u4eec\u5c06\u7f16\u7801\u5668\u7684\u67d0\u4e9b\u90e8\u5206\u4f5c\u4e3a\u67d0\u4e9b\u89e3\u7801\u5668\u7684\u8f93\u5165\u3002\u8fd9\u5bf9 \u4e0a\u5377\u79ef\u5c42\u975e\u5e38\u91cd\u8981\u3002 \u8ba9\u6211\u4eec\u770b\u770b U-Net \u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import torch import torch.nn as nn from torch.nn import functional as F # \u5b9a\u4e49\u4e00\u4e2a\u53cc\u5377\u79ef\u5c42 def double_conv ( in_channels , out_channels ): conv = nn . Sequential ( nn . Conv2d ( in_channels , out_channels , kernel_size = 3 ), nn . ReLU ( inplace = True ), nn . Conv2d ( out_channels , out_channels , kernel_size = 3 ), nn . ReLU ( inplace = True ) ) return conv # \u5b9a\u4e49\u51fd\u6570\u7528\u4e8e\u88c1\u526a\u8f93\u5165\u5f20\u91cf def crop_tensor ( tensor , target_tensor ): target_size = target_tensor . size ()[ 2 ] tensor_size = tensor . size ()[ 2 ] delta = tensor_size - target_size delta = delta // 2 return tensor [:, :, delta : tensor_size - delta , delta : tensor_size - delta ] # \u5b9a\u4e49 U-Net \u6a21\u578b class UNet ( nn . Module ): def __init__ ( self ): super ( UNet , self ) . __init () # \u5b9a\u4e49\u6c60\u5316\u5c42\uff0c\u7f16\u7801\u5668\u548c\u89e3\u7801\u5668\u7684\u53cc\u5377\u79ef\u5c42 self . max_pool_2x2 = nn . MaxPool2d ( kernel_size = 2 , stride = 2 ) self . down_conv_1 = double_conv ( 1 , 64 ) self . down_conv_2 = double_conv ( 64 , 128 ) self . down_conv_3 = double_conv ( 128 , 256 ) self . down_conv_4 = double_conv ( 256 , 512 ) self . down_conv_5 = double_conv ( 512 , 1024 ) # \u5b9a\u4e49\u4e0a\u91c7\u6837\u5c42\u548c\u89e3\u7801\u5668\u7684\u53cc\u5377\u79ef\u5c42 self . up_trans_1 = nn . ConvTranspose2d ( in_channels = 1024 , out_channels = 512 , kernel_size = 2 , stride = 2 ) self . up_conv_1 = double_conv ( 1024 , 512 ) self . up_trans_2 = nn . ConvTranspose2d ( in_channels = 512 , out_channels = 256 , kernel_size = 2 , stride = 2 ) self . up_conv_2 = double_conv ( 512 , 256 ) self . up_trans_3 = nn . ConvTranspose2d ( in_channels = 256 , out_channels = 128 , kernel_size = 2 , stride = 2 ) self . up_conv_3 = double_conv ( 256 , 128 ) self . up_trans_4 = nn . ConvTranspose2d ( in_channels = 128 , out_channels = 64 , kernel_size = 2 , stride = 2 ) self . up_conv_4 = double_conv ( 128 , 64 ) # \u5b9a\u4e49\u8f93\u51fa\u5c42 self . out = nn . Conv2d ( in_channels = 64 , out_channels = 2 , kernel_size = 1 ) def forward ( self , image ): # \u7f16\u7801\u5668\u90e8\u5206 x1 = self . down_conv_1 ( image ) x2 = self . max_pool_2x2 ( x1 ) x3 = self . down_conv_2 ( x2 ) x4 = self . max_pool_2x2 ( x3 ) x5 = self . down_conv_3 ( x4 ) x6 = self . max_pool_2x2 ( x5 ) x7 = self . down_conv_4 ( x6 ) x8 = self . max_pool_2x2 ( x7 ) x9 = self . down_conv_5 ( x8 ) # \u89e3\u7801\u5668\u90e8\u5206 x = self . up_trans_1 ( x9 ) y = crop_tensor ( x7 , x ) x = self . up_conv_1 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_2 ( x ) y = crop_tensor ( x5 , x ) x = self . up_conv_2 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_3 ( x ) y = crop_tensor ( x3 , x ) x = self . up_conv_3 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_4 ( x ) y = crop_tensor ( x1 , x ) x = self . up_conv_4 ( torch . cat ([ x , y ], axis = 1 )) # \u8f93\u51fa\u5c42 out = self . out ( x ) return out if __name__ == \"__main__\" : image = torch . rand (( 1 , 1 , 572 , 572 )) model = UNet () print ( model ( image )) \u8bf7\u6ce8\u610f\uff0c\u6211\u4e0a\u9762\u5c55\u793a\u7684 U-Net \u5b9e\u73b0\u662f U-Net \u8bba\u6587\u7684\u539f\u59cb\u5b9e\u73b0\u3002\u4e92\u8054\u7f51\u4e0a\u6709\u5f88\u591a\u4e0d\u540c\u7684\u5b9e\u73b0\u65b9\u6cd5\u3002 \u6709\u4e9b\u4eba\u559c\u6b22\u4f7f\u7528\u53cc\u7ebf\u6027\u91c7\u6837\u4ee3\u66ff\u8f6c\u7f6e\u5377\u79ef\u8fdb\u884c\u4e0a\u91c7\u6837\uff0c\u4f46\u8fd9\u5e76\u4e0d\u662f\u8bba\u6587\u7684\u771f\u6b63\u5b9e\u73b0\u3002\u4e0d\u8fc7\uff0c\u5b83\u7684\u6027\u80fd\u53ef\u80fd\u4f1a\u66f4\u597d\u3002\u5728\u4e0a\u56fe\u6240\u793a\u7684\u539f\u59cb\u5b9e\u73b0\u4e2d\uff0c\u6709\u4e00\u4e2a\u5355\u901a\u9053\u56fe\u50cf\uff0c\u8f93\u51fa\u4e2d\u6709\u4e24\u4e2a\u901a\u9053\uff1a\u4e00\u4e2a\u662f\u524d\u666f\uff0c\u4e00\u4e2a\u662f\u80cc\u666f\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u4e3a\u4efb\u610f\u6570\u91cf\u7684\u7c7b\u548c\u4efb\u610f\u6570\u91cf\u7684\u8f93\u5165\u901a\u9053\u8fdb\u884c\u5b9a\u5236\u3002\u5728\u6b64\u5b9e\u73b0\u4e2d\uff0c\u8f93\u5165\u56fe\u50cf\u7684\u5927\u5c0f\u4e0e\u8f93\u51fa\u56fe\u50cf\u7684\u5927\u5c0f\u4e0d\u540c\uff0c\u56e0\u4e3a\u6211\u4eec\u4f7f\u7528\u7684\u662f\u65e0\u586b\u5145\u5377\u79ef\uff08convolutions without padding\uff09\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0cU-Net \u7684\u7f16\u7801\u5668\u90e8\u5206\u53ea\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u5377\u79ef\u7f51\u7edc\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u7528\u4efb\u4f55\u7f51\u7edc\uff08\u5982 ResNet\uff09\u6765\u66ff\u6362\u5b83\u3002 \u8fd9\u79cd\u66ff\u6362\u4e5f\u53ef\u4ee5\u901a\u8fc7\u9884\u8bad\u7ec3\u6743\u91cd\u6765\u5b8c\u6210\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u57fa\u4e8e ResNet \u7684\u7f16\u7801\u5668\uff0c\u8be5\u7f16\u7801\u5668\u5df2\u5728 ImageNet \u548c\u901a\u7528\u89e3\u7801\u5668\u4e0a\u8fdb\u884c\u4e86\u9884\u8bad\u7ec3\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u591a\u79cd\u4e0d\u540c\u7684\u7f51\u7edc\u67b6\u6784\u6765\u4ee3\u66ff ResNet\u3002Pavel Yakubovskiy \u6240\u8457\u7684\u300aSegmentation Models Pytorch\u300b\u5c31\u662f\u8bb8\u591a\u6b64\u7c7b\u53d8\u4f53\u7684\u5b9e\u73b0\uff0c\u5176\u4e2d\u7f16\u7801\u5668\u53ef\u4ee5\u88ab\u9884\u8bad\u7ec3\u6a21\u578b\u6240\u53d6\u4ee3\u3002\u8ba9\u6211\u4eec\u5e94\u7528\u57fa\u4e8e ResNet \u7684 U-Net \u6765\u89e3\u51b3\u6c14\u80f8\u68c0\u6d4b\u95ee\u9898\u3002 \u5927\u591a\u6570\u7c7b\u4f3c\u7684\u95ee\u9898\u90fd\u6709\u4e24\u4e2a\u8f93\u5165\uff1a\u539f\u59cb\u56fe\u50cf\u548c\u63a9\u7801\uff08mask\uff09\u3002 \u5982\u679c\u6709\u591a\u4e2a\u5bf9\u8c61\uff0c\u5c31\u4f1a\u6709\u591a\u4e2a\u63a9\u7801\u3002 \u5728\u6211\u4eec\u7684\u6c14\u80f8\u6570\u636e\u96c6\u4e2d\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f RLE\u3002RLE \u4ee3\u8868\u8fd0\u884c\u957f\u5ea6\u7f16\u7801\uff0c\u662f\u4e00\u79cd\u8868\u793a\u4e8c\u8fdb\u5236\u63a9\u7801\u4ee5\u8282\u7701\u7a7a\u95f4\u7684\u65b9\u6cd5\u3002\u6df1\u5165\u7814\u7a76 RLE \u8d85\u51fa\u4e86\u672c\u7ae0\u7684\u8303\u56f4\u3002\u56e0\u6b64\uff0c\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u5f20\u8f93\u5165\u56fe\u50cf\u548c\u76f8\u5e94\u7684\u63a9\u7801\u3002\u8ba9\u6211\u4eec\u5148\u8bbe\u8ba1\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\uff0c\u7528\u4e8e\u8f93\u51fa\u56fe\u50cf\u548c\u63a9\u7801\u56fe\u50cf\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u521b\u5efa\u7684\u811a\u672c\u51e0\u4e4e\u53ef\u4ee5\u5e94\u7528\u4e8e\u4efb\u4f55\u5206\u5272\u95ee\u9898\u3002\u8bad\u7ec3\u6570\u636e\u96c6\u662f\u4e00\u4e2a CSV \u6587\u4ef6\uff0c\u53ea\u5305\u542b\u56fe\u50cf ID\uff08\u4e5f\u662f\u6587\u4ef6\u540d\uff09\u3002 import os import glob import torch import numpy as np import pandas as pd from PIL import Image , ImageFile from tqdm import tqdm from collections import defaultdict from torchvision import transforms from albumentations import ( Compose , OneOf , RandomBrightnessContrast , RandomGamma , ShiftScaleRotate , ) # \u8bbe\u7f6ePIL\u56fe\u50cf\u52a0\u8f7d\u622a\u65ad\u7684\u5904\u7406 ImageFile . LOAD_TRUNCATED_IMAGES = True # \u521b\u5efaSIIM\u6570\u636e\u96c6\u7c7b class SIIMDataset ( torch . utils . data . Dataset ): def __init__ ( self , image_ids , transform = True , preprocessing_fn = None ): self . data = defaultdict ( dict ) self . transform = transform self . preprocessing_fn = preprocessing_fn # \u5b9a\u4e49\u6570\u636e\u589e\u5f3a self . aug = Compose ( [ ShiftScaleRotate ( shift_limit = 0.0625 , scale_limit = 0.1 , rotate_limit = 10 , p = 0.8 ), OneOf ( [ RandomGamma ( gamma_limit = ( 90 , 110 ) ), RandomBrightnessContrast ( brightness_limit = 0.1 , contrast_limit = 0.1 ), ], p = 0.5 , ), ] ) # \u6784\u5efa\u6570\u636e\u5b57\u5178\uff0c\u5176\u4e2d\u5305\u542b\u56fe\u50cf\u548c\u63a9\u7801\u7684\u8def\u5f84\u4fe1\u606f for imgid in image_ids : files = glob . glob ( os . path . join ( TRAIN_PATH , imgid , \"*.png\" )) self . data [ counter ] = { \"img_path\" : os . path . join ( TRAIN_PATH , imgid + \".png\" ), \"mask_path\" : os . path . join ( TRAIN_PATH , imgid + \"_mask.png\" ), } def __len__ ( self ): return len ( self . data ) def __getitem__ ( self , item ): img_path = self . data [ item ][ \"img_path\" ] mask_path = self . data [ item ][ \"mask_path\" ] # \u6253\u5f00\u56fe\u50cf\u5e76\u5c06\u5176\u8f6c\u6362\u4e3aRGB\u6a21\u5f0f img = Image . open ( img_path ) img = img . convert ( \"RGB\" ) img = np . array ( img ) # \u6253\u5f00\u63a9\u7801\u56fe\u50cf\uff0c\u5e76\u5c06\u5176\u8f6c\u6362\u4e3a\u6d6e\u70b9\u6570 mask = Image . open ( mask_path ) mask = ( mask >= 1 ) . astype ( \"float32\" ) # \u5982\u679c\u9700\u8981\u8fdb\u884c\u6570\u636e\u589e\u5f3a if self . transform is True : augmented = self . aug ( image = img , mask = mask ) img = augmented [ \"image\" ] mask = augmented [ \"mask\" ] # \u5e94\u7528\u9884\u5904\u7406\u51fd\u6570\uff08\u5982\u679c\u6709\uff09 img = self . preprocessing_fn ( img ) # \u8fd4\u56de\u56fe\u50cf\u548c\u63a9\u7801 return { \"image\" : transforms . ToTensor ()( img ), \"mask\" : transforms . ToTensor ()( mask ) . float (), } \u6709\u4e86\u6570\u636e\u96c6\u7c7b\u4e4b\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u8bad\u7ec3\u51fd\u6570\u3002 import os import sys import torch import numpy as np import pandas as pd import segmentation_models_pytorch as smp import torch.nn as nn import torch.optim as optim from apex import amp from collections import OrderedDict from sklearn import model_selection from tqdm import tqdm from torch.optim import lr_scheduler from dataset import SIIMDataset # \u5b9a\u4e49\u8bad\u7ec3\u6570\u636e\u96c6CSV\u6587\u4ef6\u8def\u5f84 TRAINING_CSV = \"../input/train_pneumothorax.csv\" # \u5b9a\u4e49\u8bad\u7ec3\u548c\u6d4b\u8bd5\u7684\u6279\u91cf\u5927\u5c0f TRAINING_BATCH_SIZE = 16 TEST_BATCH_SIZE = 4 # \u5b9a\u4e49\u8bad\u7ec3\u7684\u65f6\u671f\u6570 EPOCHS = 10 # \u6307\u5b9a\u4f7f\u7528\u7684\u7f16\u7801\u5668\u548c\u6743\u91cd ENCODER = \"resnet18\" ENCODER_WEIGHTS = \"imagenet\" # \u6307\u5b9a\u8bbe\u5907\uff08GPU\uff09 DEVICE = \"cuda\" # \u5b9a\u4e49\u8bad\u7ec3\u51fd\u6570 def train ( dataset , data_loader , model , criterion , optimizer ): model . train () num_batches = int ( len ( dataset ) / data_loader . batch_size ) tk0 = tqdm ( data_loader , total = num_batches ) for d in tk0 : inputs = d [ \"image\" ] targets = d [ \"mask\" ] inputs = inputs . to ( DEVICE , dtype = torch . float ) targets = targets . to ( DEVICE , dtype = torch . float ) optimizer . zero_grad () outputs = model ( inputs ) loss = criterion ( outputs , targets ) with amp . scale_loss ( loss , optimizer ) as scaled_loss : scaled_loss . backward () optimizer . step () tk0 . close () # \u5b9a\u4e49\u8bc4\u4f30\u51fd\u6570 def evaluate ( dataset , data_loader , model ): model . eval () final_loss = 0 num_batches = int ( len ( dataset ) / data_loader . batch_size ) tk0 = tqdm ( data_loader , total = num_batches ) with torch . no_grad (): for d in tk0 : inputs = d [ \"image\" ] targets = d [ \"mask\" ] inputs = inputs to ( DEVICE , dtype = torch . float ) targets = targets . to ( DEVICE , dtype = torch . float ) output = model ( inputs ) loss = criterion ( output , targets ) final_loss += loss tk0 . close () return final_loss / num_batches if __name__ == \"__main__\" : df = pd . read_csv ( TRAINING_CSV ) df_train , df_valid = model_selection . train_test_split ( df , random_state = 42 , test_size = 0.1 ) training_images = df_train . image_id . values validation_images = df_valid . image_id . values # \u521b\u5efa U-Net \u6a21\u578b model = smp . Unet ( encoder_name = ENCODER , encoder_weights = ENCODER_WEIGHTS , classes = 1 , activation = None , ) # \u83b7\u53d6\u6570\u636e\u9884\u5904\u7406\u51fd\u6570 prep_fn = smp . encoders . get_preprocessing_fn ( ENCODER , ENCODER_WEIGHTS ) # \u5c06\u6a21\u578b\u653e\u5728\u8bbe\u5907\u4e0a model . to ( DEVICE ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6 train_dataset = SIIMDataset ( training_images , transform = True , preprocessing_fn = prep_fn , ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = TRAINING_BATCH_SIZE , shuffle = True , num_workers = 12 ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u96c6 valid_dataset = SIIMDataset ( validation_images , transform = False , preprocessing_fn = prep_fn , ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = TEST_BATCH_SIZE , shuffle = True , num_workers = 4 ) # \u5b9a\u4e49\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 1e-3 ) # \u5b9a\u4e49\u5b66\u4e60\u7387\u8c03\u5ea6\u5668 scheduler = lr_scheduler . ReduceLROnPlateau ( optimizer , mode = \"min\" , patience = 3 , verbose = True ) # \u521d\u59cb\u5316 Apex \u6df7\u5408\u7cbe\u5ea6\u8bad\u7ec3 model , optimizer = amp . initialize ( model , optimizer , opt_level = \"O1\" , verbosity = 0 ) # \u5982\u679c\u6709\u591a\u4e2aGPU\uff0c\u5219\u4f7f\u7528 DataParallel \u8fdb\u884c\u5e76\u884c\u8bad\u7ec3 if torch . cuda . device_count () > 1 : print ( f \"Let's use { torch . cuda . device_count () } GPUs!\" ) model = nn . DataParallel ( model ) # \u8f93\u51fa\u8bad\u7ec3\u76f8\u5173\u7684\u4fe1\u606f print ( f \"Training batch size: { TRAINING_BATCH_SIZE } \" ) print ( f \"Test batch size: { TEST_BATCH_SIZE } \" ) print ( f \"Epochs: { EPOCHS } \" ) print ( f \"Image size: { IMAGE_SIZE } \" ) print ( f \"Number of training images: { len ( train_dataset ) } \" ) print ( f \"Number of validation images: { len ( valid_dataset ) } \" ) print ( f \"Encoder: { ENCODER } \" ) # \u5faa\u73af\u8bad\u7ec3\u591a\u4e2a\u65f6\u671f for epoch in range ( EPOCHS ): print ( f \"Training Epoch: { epoch } \" ) train ( train_dataset , train_loader , model , criterion , optimizer ) print ( f \"Validation Epoch: { epoch } \" ) val_log = evaluate ( valid_dataset , valid_loader , model ) scheduler . step ( val_log [ \"loss\" ]) print ( \" \\n \" ) \u5728\u5206\u5272\u95ee\u9898\u4e2d\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528\u5404\u79cd\u635f\u5931\u51fd\u6570\uff0c\u4f8b\u5982\u4e8c\u5143\u4ea4\u53c9\u71b5\u3001focal\u635f\u5931\u3001dice\u635f\u5931\u7b49\u3002\u6211\u628a\u8fd9\u4e2a\u95ee\u9898\u7559\u7ed9 \u8bfb\u8005\u6839\u636e\u8bc4\u4f30\u6307\u6807\u6765\u51b3\u5b9a\u5408\u9002\u7684\u635f\u5931\u3002\u5f53\u8bad\u7ec3\u8fd9\u6837\u4e00\u4e2a\u6a21\u578b\u65f6\uff0c\u60a8\u5c06\u5efa\u7acb\u9884\u6d4b\u6c14\u80f8\u4f4d\u7f6e\u7684\u6a21\u578b\uff0c\u5982\u56fe 9 \u6240\u793a\u3002\u5728\u4e0a\u8ff0\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u82f1\u4f1f\u8fbe apex \u8fdb\u884c\u4e86\u6df7\u5408\u7cbe\u5ea6\u8bad\u7ec3\u3002\u8bf7\u6ce8\u610f\uff0c\u4ece PyTorch 1.6.0+ \u7248\u672c\u5f00\u59cb\uff0cPyTorch \u672c\u8eab\u5c31\u63d0\u4f9b\u4e86\u8fd9\u4e00\u529f\u80fd\u3002 \u56fe 9\uff1a\u4ece\u8bad\u7ec3\u6709\u7d20\u7684\u6a21\u578b\u4e2d\u68c0\u6d4b\u5230\u6c14\u80f8\u7684\u793a\u4f8b\uff08\u53ef\u80fd\u4e0d\u662f\u6b63\u786e\u9884\u6d4b\uff09\u3002 \u6211\u5728\u4e00\u4e2a\u540d\u4e3a \"Well That's Fantastic Machine Learning (WTFML) \"\u7684 python \u8f6f\u4ef6\u5305\u4e2d\u6536\u5f55\u4e86\u4e00\u4e9b\u5e38\u7528\u51fd\u6570\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b83\u5982\u4f55\u5e2e\u52a9\u6211\u4eec\u4e3a FGVC 202013 \u690d\u7269\u75c5\u7406\u5b66\u6311\u6218\u8d5b\u4e2d\u7684\u690d\u7269\u56fe\u50cf\u5efa\u7acb\u591a\u7c7b\u5206\u7c7b\u6a21\u578b\u3002 import os import pandas as pd import numpy as np import albumentations import argparse import torch import torchvision import torch.nn as nn import torch.nn.functional as F from sklearn import metrics from sklearn.model_selection import train_test_split from wtfml.engine import Engine from wtfml.data_loaders.image import ClassificationDataLoader # \u81ea\u5b9a\u4e49\u635f\u5931\u51fd\u6570\uff0c\u5b9e\u73b0\u5bc6\u96c6\u4ea4\u53c9\u71b5 class DenseCrossEntropy ( nn . Module ): def __init__ ( self ): super ( DenseCrossEntropy , self ) . __init__ () def forward ( self , logits , labels ): logits = logits . float () labels = labels . float () logprobs = F . log_softmax ( logits , dim =- 1 ) loss = - labels * logprobs loss = loss . sum ( - 1 ) return loss . mean () # \u81ea\u5b9a\u4e49\u795e\u7ecf\u7f51\u7edc\u6a21\u578b class Model ( nn . Module ): def __init__ ( self ): super () . __init () self . base_model = torchvision . models . resnet18 ( pretrained = True ) in_features = self . base_model . fc . in_features self . out = nn . Linear ( in_features , 4 ) def forward ( self , image , targets = None ): batch_size , C , H , W = image . shape x = self . base_model . conv1 ( image ) x = self . base_model . bn1 ( x ) x = self . base_model . relu ( x ) x = self . base_model . maxpool ( x ) x = self . base_model . layer1 ( x ) x = self . base_model . layer2 ( x ) x = self . base_model . layer3 ( x ) x = self . base_model . layer4 ( x ) x = F . adaptive_avg_pool2d ( x , 1 ) . reshape ( batch_size , - 1 ) x = self . out ( x ) loss = None if targets is not None : loss = DenseCrossEntropy ()( x , targets . type_as ( x )) return x , loss if __name__ == \"__main__\" : # \u547d\u4ee4\u884c\u53c2\u6570\u89e3\u6790\u5668 parser = argparse . ArgumentParser () parser . add_argument ( \"--data_path\" , type = str , ) parser . add_argument ( \"--device\" , type = str ,) parser . add_argument ( \"--epochs\" , type = int ,) args = parser . parse_args () # \u4eceCSV\u6587\u4ef6\u52a0\u8f7d\u6570\u636e df = pd . read_csv ( os . path . join ( args . data_path , \"train.csv\" )) images = df . image_id . values . tolist () images = [ os . path . join ( args . data_path , \"images\" , i + \".jpg\" ) for i in images ] targets = df [[ \"healthy\" , \"multiple_diseases\" , \"rust\" , \"scab\" ]] . values # \u521b\u5efa\u795e\u7ecf\u7f51\u7edc\u6a21\u578b model = Model () model . to ( args . device ) # \u5b9a\u4e49\u5747\u503c\u548c\u6807\u51c6\u5dee\u4ee5\u53ca\u6570\u636e\u589e\u5f3a mean = ( 0.485 , 0.456 , 0.406 ) std = ( 0.229 , 0.224 , 0.225 ) aug = albumentations . Compose ( [ albumentations . Normalize ( mean , std , max_pixel_value = 255.0 , always_apply = True ) ] ) # \u5206\u5272\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 ( train_images , valid_images , train_targets , valid_targets ) = train_test_split ( images , targets ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = ClassificationDataLoader ( image_paths = train_images , targets = train_targets , resize = ( 128 , 128 ), augmentations = aug , ) . fetch ( batch_size = 16 , num_workers = 4 , drop_last = False , shuffle = True , tpu = False ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = ClassificationDataLoader ( image_paths = valid_images , targets = valid_targets , resize = ( 128 , 128 ), augmentations = aug , ) . fetch ( batch_size = 16 , num_workers = 4 , drop_last = False , shuffle = False , tpu = False ) # \u521b\u5efa\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 5e-4 ) # \u521b\u5efa\u5b66\u4e60\u7387\u8c03\u5ea6\u5668 scheduler = torch . optim . lr_scheduler . StepLR ( optimizer , step_size = 15 , gamma = 0.6 ) # \u5faa\u73af\u8bad\u7ec3\u591a\u4e2a\u65f6\u671f for epoch in range ( args . epochs ): # \u8bad\u7ec3\u6a21\u578b train_loss = Engine . train ( train_loader , model , optimizer , device = args . device ) # \u8bc4\u4f30\u6a21\u578b valid_loss = Engine . evaluate ( valid_loader , model , device = args . device ) # \u6253\u5370\u635f\u5931\u4fe1\u606f print ( f \" { epoch } , Train Loss= { train_loss } Valid Loss= { valid_loss } \" ) \u6709\u4e86\u6570\u636e\u540e\uff0c\u5c31\u53ef\u4ee5\u8fd0\u884c\u811a\u672c\u4e86\uff1a python plant.py --data_path ../../plant_pathology --device cuda -- epochs 2 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 86 /86 [ 00 :12< 00 :00, 6 .73it/s, loss = 0 .723 ] 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 29 /29 [ 00 :04< 00 :00, 6 .62it/s, loss = 0 .433 ] 0 , Train Loss = 0 .7228777609592261 Valid Loss = 0 .4327834551704341 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 86 /86 [ 00 :12< 00 :00, 6 .74it/s, loss = 0 .271 ] 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 29 /29 [ 00 :04< 00 :00, 6 .63it/s, loss = 0 .568 ] 1 , Train Loss = 0 .2708700496790021 Valid Loss = 0 .56841839541649 \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u8ba9\u6211\u4eec\u6784\u5efa\u6a21\u578b\u53d8\u5f97\u7b80\u5355\uff0c\u4ee3\u7801\u4e5f\u6613\u4e8e\u9605\u8bfb\u548c\u7406\u89e3\u3002\u6ca1\u6709\u4efb\u4f55\u5c01\u88c5\u7684 PyTorch \u6548\u679c\u6700\u597d\u3002\u56fe\u50cf\u4e2d\u4e0d\u4ec5\u4ec5\u6709\u5206\u7c7b\uff0c\u8fd8\u6709\u5f88\u591a\u5176\u4ed6\u7684\u5185\u5bb9\uff0c\u5982\u679c\u6211\u5f00\u59cb\u5199\u6240\u6709\u7684\u5185\u5bb9\uff0c\u5c31\u5f97\u518d\u5199\u4e00\u672c\u4e66\u4e86\uff0c \u63a5\u8fd1\uff08\u51e0\u4e4e\uff09\u4efb\u4f55\u56fe\u50cf\u95ee\u9898\uff08\u4f5c\u8005\u5728\u5f00\u73a9\u7b11\uff09\u3002","title":"\u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5"},{"location":"%E5%A4%84%E7%90%86%E5%88%86%E7%B1%BB%E5%8F%98%E9%87%8F/","text":"\u5904\u7406\u5206\u7c7b\u53d8\u91cf \u5f88\u591a\u4eba\u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u65f6\u90fd\u4f1a\u9047\u5230\u5f88\u591a\u56f0\u96be\uff0c\u56e0\u6b64\u8fd9\u503c\u5f97\u7528\u6574\u6574\u4e00\u7ae0\u7684\u7bc7\u5e45\u6765\u8ba8\u8bba\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u5c06\u8bb2\u8ff0\u4e0d\u540c\u7c7b\u578b\u7684\u5206\u7c7b\u6570\u636e\uff0c\u4ee5\u53ca\u5982\u4f55\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u3002 \u4ec0\u4e48\u662f\u5206\u7c7b\u53d8\u91cf\uff1f \u5206\u7c7b\u53d8\u91cf/\u7279\u5f81\u662f\u6307\u4efb\u4f55\u7279\u5f81\u7c7b\u578b\uff0c\u53ef\u5206\u4e3a\u4e24\u5927\u7c7b\uff1a - \u65e0\u5e8f - \u6709\u5e8f \u65e0\u5e8f\u53d8\u91cf \u662f\u6307\u6709\u4e24\u4e2a\u6216\u4e24\u4e2a\u4ee5\u4e0a\u7c7b\u522b\u7684\u53d8\u91cf\uff0c\u8fd9\u4e9b\u7c7b\u522b\u6ca1\u6709\u4efb\u4f55\u76f8\u5173\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u5982\u679c\u5c06\u6027\u522b\u5206\u4e3a\u4e24\u7ec4\uff0c\u5373\u7537\u6027\u548c\u5973\u6027\uff0c\u5219\u53ef\u5c06\u5176\u89c6\u4e3a\u540d\u4e49\u53d8\u91cf\u3002 \u6709\u5e8f\u53d8\u91cf \u5219\u6709 \"\u7b49\u7ea7 \"\u6216\u7c7b\u522b\uff0c\u5e76\u6709\u7279\u5b9a\u7684\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u4e00\u4e2a\u987a\u5e8f\u5206\u7c7b\u53d8\u91cf\u53ef\u4ee5\u662f\u4e00\u4e2a\u5177\u6709\u4f4e\u3001\u4e2d\u3001\u9ad8\u4e09\u4e2a\u4e0d\u540c\u7b49\u7ea7\u7684\u7279\u5f81\u3002\u987a\u5e8f\u5f88\u91cd\u8981\u3002 \u5c31\u5b9a\u4e49\u800c\u8a00\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06\u5206\u7c7b\u53d8\u91cf\u5206\u4e3a \u4e8c\u5143\u53d8\u91cf \uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u7c7b\u522b\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u6709\u4e9b\u4eba\u751a\u81f3\u628a\u5206\u7c7b\u53d8\u91cf\u79f0\u4e3a \" \u5faa\u73af \"\u53d8\u91cf\u3002\u5468\u671f\u53d8\u91cf\u4ee5 \"\u5468\u671f \"\u7684\u5f62\u5f0f\u5b58\u5728\uff0c\u4f8b\u5982\u4e00\u5468\u4e2d\u7684\u5929\u6570\uff1a \u5468\u65e5\u3001\u5468\u4e00\u3001\u5468\u4e8c\u3001\u5468\u4e09\u3001\u5468\u56db\u3001\u5468\u4e94\u548c\u5468\u516d\u3002\u5468\u516d\u8fc7\u540e\uff0c\u53c8\u662f\u5468\u65e5\u3002\u8fd9\u5c31\u662f\u4e00\u4e2a\u5faa\u73af\u3002\u53e6\u4e00\u4e2a\u4f8b\u5b50\u662f\u4e00\u5929\u4e2d\u7684\u5c0f\u65f6\u6570\uff0c\u5982\u679c\u6211\u4eec\u5c06\u5b83\u4eec\u89c6\u4e3a\u7c7b\u522b\u7684\u8bdd\u3002 \u5206\u7c7b\u53d8\u91cf\u6709\u5f88\u591a\u4e0d\u540c\u7684\u5b9a\u4e49\uff0c\u5f88\u591a\u4eba\u4e5f\u8c08\u5230\u8981\u6839\u636e\u5206\u7c7b\u53d8\u91cf\u7684\u7c7b\u578b\u6765\u5904\u7406\u4e0d\u540c\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u4e0d\u8fc7\uff0c\u6211\u8ba4\u4e3a\u6ca1\u6709\u5fc5\u8981\u8fd9\u6837\u505a\u3002\u6240\u6709\u6d89\u53ca\u5206\u7c7b\u53d8\u91cf\u7684\u95ee\u9898\u90fd\u53ef\u4ee5\u7528\u540c\u6837\u7684\u65b9\u6cd5\u5904\u7406\u3002 \u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e00\u4e2a\u6570\u636e\u96c6\uff08\u4e00\u5982\u65e2\u5f80\uff09\u3002\u8981\u4e86\u89e3\u5206\u7c7b\u53d8\u91cf\uff0c\u6700\u597d\u7684\u514d\u8d39\u6570\u636e\u96c6\u4e4b\u4e00\u662f Kaggle \u5206\u7c7b\u7279\u5f81\u7f16\u7801\u6311\u6218\u8d5b\u4e2d\u7684 cat-in-the-dat \u3002\u5171\u6709\u4e24\u4e2a\u6311\u6218\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6311\u6218\u7684\u6570\u636e\uff0c\u56e0\u4e3a\u5b83\u6bd4\u524d\u4e00\u4e2a\u7248\u672c\u6709\u66f4\u591a\u53d8\u91cf\uff0c\u96be\u5ea6\u4e5f\u66f4\u5927\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u3002 \u56fe 1\uff1aCat-in-the-dat-ii challenge\u90e8\u5206\u6570\u636e\u5c55\u793a \u6570\u636e\u96c6\u7531\u5404\u79cd\u5206\u7c7b\u53d8\u91cf\u7ec4\u6210\uff1a \u65e0\u5e8f \u6709\u5e8f \u5faa\u73af \u4e8c\u5143 \u5728\u56fe 1 \u4e2d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6240\u6709\u5b58\u5728\u7684\u53d8\u91cf\u548c\u76ee\u6807\u53d8\u91cf\u7684\u5b50\u96c6\u3002 \u8fd9\u662f\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002 \u76ee\u6807\u53d8\u91cf\u5bf9\u4e8e\u6211\u4eec\u5b66\u4e60\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\u5e76\u4e0d\u5341\u5206\u91cd\u8981\uff0c\u4f46\u6700\u7ec8\u6211\u4eec\u5c06\u5efa\u7acb\u4e00\u4e2a\u7aef\u5230\u7aef\u6a21\u578b\uff0c\u56e0\u6b64\u8ba9\u6211\u4eec\u770b\u770b\u56fe 2 \u4e2d\u7684\u76ee\u6807\u53d8\u91cf\u5206\u5e03\u3002\u6211\u4eec\u770b\u5230\u76ee\u6807\u662f \u504f\u659c \u7684\uff0c\u56e0\u6b64\u5bf9\u4e8e\u8fd9\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u6765\u8bf4\uff0c\u6700\u597d\u7684\u6307\u6807\u662f ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528\u7cbe\u786e\u5ea6\u548c\u53ec\u56de\u7387\uff0c\u4f46 AUC \u7ed3\u5408\u4e86\u8fd9\u4e24\u4e2a\u6307\u6807\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 AUC \u6765\u8bc4\u4f30\u6211\u4eec\u5728\u8be5\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u7684\u6a21\u578b\u3002 \u56fe 2\uff1a\u6807\u7b7e\u8ba1\u6570\u3002X \u8f74\u8868\u793a\u6807\u7b7e\uff0cY \u8f74\u8868\u793a\u6807\u7b7e\u8ba1\u6570 \u603b\u4f53\u800c\u8a00\uff0c\u6709\uff1a 5\u4e2a\u4e8c\u5143\u53d8\u91cf 10\u4e2a\u65e0\u5e8f\u53d8\u91cf 6\u4e2a\u6709\u5e8f\u53d8\u91cf 2\u4e2a\u5faa\u73af\u53d8\u91cf 1\u4e2a\u76ee\u6807\u53d8\u91cf \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u96c6\u4e2d\u7684 ord_2 \u7279\u5f81\u3002\u5b83\u5305\u62ec6\u4e2a\u4e0d\u540c\u7684\u7c7b\u522b\uff1a - \u51b0\u51bb - \u6e29\u6696 - \u5bd2\u51b7 - \u8f83\u70ed - \u70ed - \u975e\u5e38\u70ed \u6211\u4eec\u5fc5\u987b\u77e5\u9053\uff0c\u8ba1\u7b97\u673a\u65e0\u6cd5\u7406\u89e3\u6587\u672c\u6570\u636e\uff0c\u56e0\u6b64\u6211\u4eec\u9700\u8981\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002\u4e00\u4e2a\u7b80\u5355\u7684\u65b9\u6cd5\u662f\u521b\u5efa\u4e00\u4e2a\u5b57\u5178\uff0c\u5c06\u8fd9\u4e9b\u503c\u6620\u5c04\u4e3a\u4ece 0 \u5230 N-1 \u7684\u6570\u5b57\uff0c\u5176\u4e2d N \u662f\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7c7b\u522b\u7684\u603b\u6570\u3002 # \u6620\u5c04\u5b57\u5178 mapping = { \"Freezing\" : 0 , \"Warm\" : 1 , \"Cold\" : 2 , \"Boiling Hot\" : 3 , \"Hot\" : 4 , \"Lava Hot\" : 5 } \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8bfb\u53d6\u6570\u636e\u96c6\uff0c\u5e76\u8f7b\u677e\u5730\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002 import pandas as pd # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u53d6*ord_2*\u5217\uff0c\u5e76\u4f7f\u7528\u6620\u5c04\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57 df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. map ( mapping ) \u6620\u5c04\u524d\u7684\u6570\u503c\u8ba1\u6570\uff1a df .* ord_2 *. value_counts () Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : * ord_2 * , dtype : int64 \u6620\u5c04\u540e\u7684\u6570\u503c\u8ba1\u6570\uff1a 0.0 142726 1.0 124239 2.0 97822 3.0 84790 4.0 67508 5.0 64840 Name : * ord_2 * , dtype : int64 \u8fd9\u79cd\u5206\u7c7b\u53d8\u91cf\u7684\u7f16\u7801\u65b9\u5f0f\u88ab\u79f0\u4e3a\u6807\u7b7e\u7f16\u7801\uff08Label Encoding\uff09\u6211\u4eec\u5c06\u6bcf\u4e2a\u7c7b\u522b\u7f16\u7801\u4e3a\u4e00\u4e2a\u6570\u5b57\u6807\u7b7e\u3002 \u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 LabelEncoder \u8fdb\u884c\u7f16\u7801\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u5c06\u7f3a\u5931\u503c\u586b\u5145\u4e3a\"NONE\" df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. fillna ( \"NONE\" ) # LabelEncoder\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u8f6c\u6362\u6570\u636e df . loc [:, \"*ord_2*\" ] = lbl_enc . fit_transform ( df .* ord_2 *. values ) \u4f60\u4f1a\u770b\u5230\u6211\u4f7f\u7528\u4e86 pandas \u7684 fillna\u3002\u539f\u56e0\u662f scikit-learn \u7684 LabelEncoder \u65e0\u6cd5\u5904\u7406 NaN \u503c\uff0c\u800c ord_2 \u5217\u4e2d\u6709 NaN \u503c\u3002 \u6211\u4eec\u53ef\u4ee5\u5728\u8bb8\u591a\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u4e2d\u76f4\u63a5\u4f7f\u7528\u5b83\uff1a - \u51b3\u7b56\u6811 - \u968f\u673a\u68ee\u6797 - \u63d0\u5347\u6811 - \u6216\u4efb\u4f55\u4e00\u79cd\u63d0\u5347\u6811\u6a21\u578b - XGBoost - GBM - LightGBM \u8fd9\u79cd\u7f16\u7801\u65b9\u5f0f\u4e0d\u80fd\u7528\u4e8e\u7ebf\u6027\u6a21\u578b\u3001\u652f\u6301\u5411\u91cf\u673a\u6216\u795e\u7ecf\u7f51\u7edc\uff0c\u56e0\u4e3a\u5b83\u4eec\u5e0c\u671b\u6570\u636e\u662f\u6807\u51c6\u5316\u7684\u3002 \u5bf9\u4e8e\u8fd9\u4e9b\u7c7b\u578b\u7684\u6a21\u578b\uff0c\u6211\u4eec\u53ef\u4ee5\u5bf9\u6570\u636e\u8fdb\u884c\u4e8c\u503c\u5316\uff08binarize\uff09\u5904\u7406\u3002 Freezing --> 0 --> 0 0 0 Warm --> 1 --> 0 0 1 Cold --> 2 --> 0 1 0 Boiling Hot --> 3 --> 0 1 1 Hot --> 4 --> 1 0 0 Lava Hot --> 5 --> 1 0 1 \u8fd9\u53ea\u662f\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\uff0c\u7136\u540e\u518d\u8f6c\u6362\u4e3a\u4e8c\u503c\u5316\u8868\u793a\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u628a\u4e00\u4e2a\u7279\u5f81\u5206\u6210\u4e86\u4e09\u4e2a\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u7279\u5f81\uff08\u6216\u5217\uff09\u3002\u5982\u679c\u6211\u4eec\u6709\u66f4\u591a\u7684\u7c7b\u522b\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u5206\u6210\u66f4\u591a\u7684\u5217\u3002 \u5982\u679c\u6211\u4eec\u7528\u7a00\u758f\u683c\u5f0f\u5b58\u50a8\u5927\u91cf\u4e8c\u503c\u5316\u53d8\u91cf\uff0c\u5c31\u53ef\u4ee5\u8f7b\u677e\u5730\u5b58\u50a8\u8fd9\u4e9b\u53d8\u91cf\u3002\u7a00\u758f\u683c\u5f0f\u4e0d\u8fc7\u662f\u4e00\u79cd\u5728\u5185\u5b58\u4e2d\u5b58\u50a8\u6570\u636e\u7684\u8868\u793a\u6216\u65b9\u5f0f\uff0c\u5728\u8fd9\u79cd\u683c\u5f0f\u4e2d\uff0c\u4f60\u5e76\u4e0d\u5b58\u50a8\u6240\u6709\u7684\u503c\uff0c\u800c\u53ea\u5b58\u50a8\u91cd\u8981\u7684\u503c\u3002\u5728\u4e0a\u8ff0\u4e8c\u8fdb\u5236\u53d8\u91cf\u7684\u60c5\u51b5\u4e2d\uff0c\u6700\u91cd\u8981\u7684\u5c31\u662f\u6709 1 \u7684\u5730\u65b9\u3002 \u5f88\u96be\u60f3\u8c61\u8fd9\u6837\u7684\u683c\u5f0f\uff0c\u4f46\u4e3e\u4e2a\u4f8b\u5b50\u5c31\u4f1a\u660e\u767d\u3002 \u5047\u8bbe\u4e0a\u9762\u7684\u6570\u636e\u5e27\u4e2d\u53ea\u6709\u4e00\u4e2a\u7279\u5f81\uff1a ord_2 \u3002 Index Feature 0 Warm 1 Hot 2 Lava hot \u76ee\u524d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6570\u636e\u96c6\u4e2d\u7684\u4e09\u4e2a\u6837\u672c\u3002\u8ba9\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u4e8c\u503c\u8868\u793a\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u6709\u4e09\u4e2a\u9879\u76ee\u3002 \u8fd9\u4e09\u4e2a\u9879\u76ee\u5c31\u662f\u4e09\u4e2a\u7279\u5f81\u3002 Index Feature_0 Feature_1 Feature_2 0 0 0 1 1 1 0 0 2 1 0 1 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u7279\u5f81\u5b58\u50a8\u5728\u4e00\u4e2a\u6709 3 \u884c 3 \u5217\uff083x3\uff09\u7684\u77e9\u9635\u4e2d\u3002\u77e9\u9635\u7684\u6bcf\u4e2a\u5143\u7d20\u5360\u7528 8 \u4e2a\u5b57\u8282\u3002\u56e0\u6b64\uff0c\u8fd9\u4e2a\u6570\u7ec4\u7684\u603b\u5185\u5b58\u9700\u6c42\u4e3a 8x3x3 = 72 \u5b57\u8282\u3002 \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e2a\u7b80\u5355\u7684 python \u4ee3\u7801\u6bb5\u6765\u68c0\u67e5\u8fd9\u4e00\u70b9\u3002 import numpy as np example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) print ( example . nbytes ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u6253\u5370\u51fa 72\uff0c\u5c31\u50cf\u6211\u4eec\u4e4b\u524d\u8ba1\u7b97\u7684\u90a3\u6837\u3002\u4f46\u6211\u4eec\u9700\u8981\u5b58\u50a8\u8fd9\u4e2a\u77e9\u9635\u7684\u6240\u6709\u5143\u7d20\u5417\uff1f\u5982\u524d\u6240\u8ff0\uff0c\u6211\u4eec\u53ea\u5bf9 1 \u611f\u5174\u8da3\u30020 \u5e76\u4e0d\u91cd\u8981\uff0c\u56e0\u4e3a\u4efb\u4f55\u4e0e 0 \u76f8\u4e58\u7684\u5143\u7d20\u90fd\u662f 0\uff0c\u800c 0 \u4e0e\u4efb\u4f55\u5143\u7d20\u76f8\u52a0\u6216\u76f8\u51cf\u4e5f\u6ca1\u6709\u4efb\u4f55\u533a\u522b\u3002\u53ea\u7528 1 \u8868\u793a\u77e9\u9635\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u67d0\u79cd\u5b57\u5178\u65b9\u6cd5\uff0c\u5176\u4e2d\u952e\u662f\u884c\u548c\u5217\u7684\u7d22\u5f15\uff0c\u503c\u662f 1\uff1a ( 0 , 2 ) 1 ( 1 , 0 ) 1 ( 2 , 0 ) 1 ( 2 , 2 ) 1 \u8fd9\u6837\u7684\u7b26\u53f7\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u56e0\u4e3a\u5b83\u53ea\u9700\u5b58\u50a8\u56db\u4e2a\u503c\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u3002\u4f7f\u7528\u7684\u603b\u5185\u5b58\u4e3a 8x4 = 32 \u5b57\u8282\u3002\u4efb\u4f55 numpy \u6570\u7ec4\u90fd\u53ef\u4ee5\u901a\u8fc7\u7b80\u5355\u7684 python \u4ee3\u7801\u8f6c\u6362\u4e3a\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) sparse_example = sparse . csr_matrix ( example ) print ( sparse_example . data . nbytes ) \u8fd9\u5c06\u6253\u5370 32\uff0c\u6bd4\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u5c11\u4e86\u8fd9\u4e48\u591a\uff01\u7a00\u758f csr \u77e9\u9635\u7684\u603b\u5927\u5c0f\u662f\u4e09\u4e2a\u503c\u7684\u603b\u548c\u3002 print ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) \u8fd9\u5c06\u6253\u5370\u51fa 64 \u4e2a\u5143\u7d20\uff0c\u4ecd\u7136\u5c11\u4e8e\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u3002\u9057\u61be\u7684\u662f\uff0c\u6211\u4e0d\u4f1a\u8be6\u7ec6\u4ecb\u7ecd\u8fd9\u4e9b\u5143\u7d20\u3002\u4f60\u53ef\u4ee5\u5728 scipy \u6587\u6863\u4e2d\u4e86\u89e3\u66f4\u591a\u3002\u5f53\u6211\u4eec\u62e5\u6709\u66f4\u5927\u7684\u6570\u7ec4\u65f6\uff0c\u6bd4\u5982\u8bf4\u62e5\u6709\u6570\u5343\u4e2a\u6837\u672c\u548c\u6570\u4e07\u4e2a\u7279\u5f81\u7684\u6570\u7ec4\uff0c\u5927\u5c0f\u5dee\u5f02\u5c31\u4f1a\u53d8\u5f97\u975e\u5e38\u5927\u3002\u4f8b\u5982\uff0c\u6211\u4eec\u4f7f\u7528\u57fa\u4e8e\u8ba1\u6570\u7279\u5f81\u7684\u6587\u672c\u6570\u636e\u96c6\u3002 import numpy as np from scipy import sparse n_rows = 10000 n_cols = 100000 # \u751f\u6210\u7b26\u5408\u4f2f\u52aa\u5229\u5206\u5e03\u7684\u968f\u673a\u6570\u7ec4\uff0c\u7ef4\u5ea6\u4e3a[10000, 100000] example = np . random . binomial ( 1 , p = 0.05 , size = ( n_rows , n_cols )) print ( f \"Size of dense array: { example . nbytes } \" ) # \u5c06\u968f\u673a\u77e9\u9635\u8f6c\u6362\u4e3a\u6d17\u6f31\u77e9\u9635 sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u8fd9\u5c06\u6253\u5370\uff1a Size of dense array : 8000000000 Size of sparse array : 399932496 Full size of sparse array : 599938748 \u56e0\u6b64\uff0c\u5bc6\u96c6\u9635\u5217\u9700\u8981 ~8000MB \u6216\u5927\u7ea6 8GB \u5185\u5b58\u3002\u800c\u7a00\u758f\u9635\u5217\u53ea\u5360\u7528 399MB \u5185\u5b58\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u5f53\u6211\u4eec\u7684\u7279\u5f81\u4e2d\u6709\u5927\u91cf\u96f6\u65f6\uff0c\u6211\u4eec\u66f4\u559c\u6b22\u7a00\u758f\u9635\u5217\u800c\u4e0d\u662f\u5bc6\u96c6\u9635\u5217\u7684\u539f\u56e0\u3002 \u8bf7\u6ce8\u610f\uff0c\u7a00\u758f\u77e9\u9635\u6709\u591a\u79cd\u4e0d\u540c\u7684\u8868\u793a\u65b9\u6cd5\u3002\u8fd9\u91cc\u6211\u53ea\u5c55\u793a\u4e86\u5176\u4e2d\u4e00\u79cd\uff08\u53ef\u80fd\u4e5f\u662f\u6700\u5e38\u7528\u7684\uff09\u65b9\u6cd5\u3002\u6df1\u5165\u63a2\u8ba8\u8fd9\u4e9b\u65b9\u6cd5\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\uff0c\u56e0\u6b64\u7559\u7ed9\u8bfb\u8005\u4e00\u4e2a\u7ec3\u4e60\u3002 \u5c3d\u7ba1\u4e8c\u503c\u5316\u7279\u5f81\u7684\u7a00\u758f\u8868\u793a\u6bd4\u5176\u5bc6\u96c6\u8868\u793a\u6240\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u4f46\u5bf9\u4e8e\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\uff0c\u8fd8\u6709\u4e00\u79cd\u8f6c\u6362\u6240\u5360\u7528\u7684\u5185\u5b58\u66f4\u5c11\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \" \u72ec\u70ed\u7f16\u7801 \"\u3002 \u72ec\u70ed\u7f16\u7801\u4e5f\u662f\u4e00\u79cd\u4e8c\u503c\u7f16\u7801\uff0c\u56e0\u4e3a\u53ea\u6709 0 \u548c 1 \u4e24\u4e2a\u503c\u3002\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5b83\u5e76\u4e0d\u662f\u4e8c\u503c\u8868\u793a\u6cd5\u3002\u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0b\u9762\u7684\u4f8b\u5b50\u6765\u7406\u89e3\u5b83\u7684\u8868\u793a\u6cd5\u3002 \u5047\u8bbe\u6211\u4eec\u7528\u4e00\u4e2a\u5411\u91cf\u6765\u8868\u793a ord_2 \u53d8\u91cf\u7684\u6bcf\u4e2a\u7c7b\u522b\u3002\u8fd9\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u4e0e ord_2 \u53d8\u91cf\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u6bcf\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u90fd\u662f 6\uff0c\u5e76\u4e14\u9664\u4e86\u4e00\u4e2a\u4f4d\u7f6e\u5916\uff0c\u5176\u4ed6\u4f4d\u7f6e\u90fd\u662f 0\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u7279\u6b8a\u7684\u5411\u91cf\u8868\u3002 Freezing 0 0 0 0 0 1 Warm 0 0 0 0 1 0 Cold 0 0 0 1 0 0 Boiling Hot 0 0 1 0 0 0 Hot 0 1 0 0 0 0 Lava Hot 1 0 0 0 0 0 \u6211\u4eec\u770b\u5230\u5411\u91cf\u7684\u5927\u5c0f\u662f 1x6\uff0c\u5373\u5411\u91cf\u4e2d\u67096\u4e2a\u5143\u7d20\u3002\u8fd9\u4e2a\u6570\u5b57\u662f\u600e\u4e48\u6765\u7684\u5462\uff1f\u5982\u679c\u4f60\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u5982\u524d\u6240\u8ff0\uff0c\u67096\u4e2a\u7c7b\u522b\u3002\u5728\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u65f6\uff0c\u5411\u91cf\u7684\u5927\u5c0f\u5fc5\u987b\u4e0e\u6211\u4eec\u8981\u67e5\u770b\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u6bcf\u4e2a\u5411\u91cf\u90fd\u6709\u4e00\u4e2a 1\uff0c\u5176\u4f59\u6240\u6709\u503c\u90fd\u662f 0\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u7528\u8fd9\u4e9b\u7279\u5f81\u6765\u4ee3\u66ff\u4e4b\u524d\u7684\u4e8c\u503c\u5316\u7279\u5f81\uff0c\u770b\u770b\u80fd\u8282\u7701\u591a\u5c11\u5185\u5b58\u3002 \u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\u4ee5\u524d\u7684\u6570\u636e\uff0c\u5b83\u770b\u8d77\u6765\u5982\u4e0b\uff1a Index Feature 0 Warm 1 Hot 2 Lava hot \u6bcf\u4e2a\u6837\u672c\u67093\u4e2a\u7279\u5f81\u3002\u4f46\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u72ec\u70ed\u5411\u91cf\u7684\u5927\u5c0f\u4e3a 6\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u67096\u4e2a\u7279\u5f81\uff0c\u800c\u4e0d\u662f3\u4e2a\u3002 Index F_0 F_1 F_2 F_3 F_4 F_5 0 0 0 0 0 1 0 1 0 1 0 0 0 0 2 1 0 1 0 0 0 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 6 \u4e2a\u7279\u5f81\uff0c\u800c\u5728\u8fd9\u4e2a 3x6 \u6570\u7ec4\u4e2d\uff0c\u53ea\u6709 3 \u4e2a1\u3002\u4f7f\u7528 numpy \u8ba1\u7b97\u5927\u5c0f\u4e0e\u4e8c\u503c\u5316\u5927\u5c0f\u8ba1\u7b97\u811a\u672c\u975e\u5e38\u76f8\u4f3c\u3002\u4f60\u9700\u8981\u6539\u53d8\u7684\u53ea\u662f\u6570\u7ec4\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u6bb5\u4ee3\u7801\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 0 , 0 , 1 , 0 ], [ 0 , 1 , 0 , 0 , 0 , 0 ], [ 1 , 0 , 0 , 0 , 0 , 0 ] ] ) print ( f \"Size of dense array: { example . nbytes } \" ) sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u6253\u5370\u5185\u5b58\u5927\u5c0f\u4e3a\uff1a Size of dense array : 144 Size of sparse array : 24 Full size of sparse array : 52 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5bc6\u96c6\u77e9\u9635\u7684\u5927\u5c0f\u8fdc\u8fdc\u5927\u4e8e\u4e8c\u503c\u5316\u77e9\u9635\u7684\u5927\u5c0f\u3002\u4e0d\u8fc7\uff0c\u7a00\u758f\u6570\u7ec4\u7684\u5927\u5c0f\u8981\u66f4\u5c0f\u3002\u8ba9\u6211\u4eec\u7528\u66f4\u5927\u7684\u6570\u7ec4\u6765\u8bd5\u8bd5\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 OneHotEncoder \u5c06\u5305\u542b 1001 \u4e2a\u7c7b\u522b\u7684\u7279\u5f81\u6570\u7ec4\u8f6c\u6362\u4e3a\u5bc6\u96c6\u77e9\u9635\u548c\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from sklearn import preprocessing # \u751f\u6210\u7b26\u5408\u5747\u5300\u5206\u5e03\u7684\u968f\u673a\u6574\u6570\uff0c\u7ef4\u5ea6\u4e3a[1000000, 10000000] example = np . random . randint ( 1000 , size = 1000000 ) # \u72ec\u70ed\u7f16\u7801\uff0c\u975e\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = False ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of dense array: { ohe_example . nbytes } \" ) # \u72ec\u70ed\u7f16\u7801\uff0c\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = True ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of sparse array: { ohe_example . data . nbytes } \" ) full_size = ( ohe_example . data . nbytes + ohe_example . indptr . nbytes + ohe_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u4e0a\u9762\u4ee3\u7801\u6253\u5370\u7684\u8f93\u51fa\uff1a Size of dense array : 8000000000 Size of sparse array : 8000000 Full size of sparse array : 16000004 \u8fd9\u91cc\u7684\u5bc6\u96c6\u9635\u5217\u5927\u5c0f\u7ea6\u4e3a 8GB\uff0c\u7a00\u758f\u9635\u5217\u4e3a 8MB\u3002\u5982\u679c\u53ef\u4ee5\u9009\u62e9\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff1f\u5728\u6211\u770b\u6765\uff0c\u9009\u62e9\u5f88\u7b80\u5355\uff0c\u4e0d\u662f\u5417\uff1f \u8fd9\u4e09\u79cd\u65b9\u6cd5\uff08\u6807\u7b7e\u7f16\u7801\u3001\u7a00\u758f\u77e9\u9635\u3001\u72ec\u70ed\u7f16\u7801\uff09\u662f\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u7684\u6700\u91cd\u8981\u65b9\u6cd5\u3002\u4e0d\u8fc7\uff0c\u4f60\u8fd8\u53ef\u4ee5\u7528\u5f88\u591a\u5176\u4ed6\u4e0d\u540c\u7684\u65b9\u6cd5\u6765\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u3002\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u6570\u503c\u53d8\u91cf\u5c31\u662f\u5176\u4e2d\u7684\u4e00\u4e2a\u4f8b\u5b50\u3002 \u5047\u8bbe\u6211\u4eec\u56de\u5230\u4e4b\u524d\u7684\u5206\u7c7b\u7279\u5f81\u6570\u636e\uff08\u539f\u59cb\u6570\u636e\u4e2d\u7684 cat-in-the-dat-ii\uff09\u3002\u5728\u6570\u636e\u4e2d\uff0c ord_2 \u7684\u503c\u4e3a\u201c\u70ed\u201c\u7684 id \u6709\u591a\u5c11\uff1f \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u8ba1\u7b97\u6570\u636e\u7684\u5f62\u72b6\uff08shape\uff09\u8f7b\u677e\u8ba1\u7b97\u51fa\u8fd9\u4e2a\u503c\uff0c\u5176\u4e2d ord_2 \u5217\u7684\u503c\u4e3a Boiling Hot \u3002 In [ X ]: df [ df . ord_2 == \"Boiling Hot\" ] . shape Out [ X ]: ( 84790 , 25 ) \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 84790 \u6761\u8bb0\u5f55\u5177\u6709\u6b64\u503c\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 pandas \u4e2d\u7684 groupby \u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u8be5\u503c\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . count () Out [ X ]: ord_2 Boiling Hot 84790 Cold 97822 Freezing 142726 Hot 67508 Lava Hot 64840 Warm 124239 Name : id , dtype : int64 \u5982\u679c\u6211\u4eec\u53ea\u662f\u5c06 ord_2 \u5217\u66ff\u6362\u4e3a\u5176\u8ba1\u6570\u503c\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u5c06\u5176\u8f6c\u6362\u4e3a\u4e00\u79cd\u6570\u503c\u7279\u5f81\u4e86\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 transform \u51fd\u6570\u548c groupby \u6765\u521b\u5efa\u65b0\u5217\u6216\u66ff\u6362\u8fd9\u4e00\u5217\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . transform ( \"count\" ) Out [ X ]: 0 67508.0 1 124239.0 2 142726.0 3 64840.0 4 97822.0 ... 599995 142726.0 599996 84790.0 599997 142726.0 599998 124239.0 599999 84790.0 Name : id , Length : 600000 , dtype : float64 \u4f60\u53ef\u4ee5\u6dfb\u52a0\u6240\u6709\u7279\u5f81\u7684\u8ba1\u6570\uff0c\u4e5f\u53ef\u4ee5\u66ff\u6362\u5b83\u4eec\uff0c\u6216\u8005\u6839\u636e\u591a\u4e2a\u5217\u53ca\u5176\u8ba1\u6570\u8fdb\u884c\u5206\u7ec4\u3002\u4f8b\u5982\uff0c\u4ee5\u4e0b\u4ee3\u7801\u901a\u8fc7\u5bf9 ord_1 \u548c ord_2 \u5217\u5206\u7ec4\u8fdb\u884c\u8ba1\u6570\u3002 In [ X ]: df . groupby ( ... : [ ... : \"ord_1\" , ... : \"ord_2\" ... : ] ... : )[ \"id\" ] . count () . reset_index ( name = \"count\" ) Out [ X ]: ord_1 ord_2 count 0 Contributor Boiling Hot 15634 1 Contributor Cold 17734 2 Contributor Freezing 26082 3 Contributor Hot 12428 4 Contributor Lava Hot 11919 5 Contributor Warm 22774 6 Expert Boiling Hot 19477 7 Expert Cold 22956 8 Expert Freezing 33249 9 Expert Hot 15792 10 Expert Lava Hot 15078 11 Expert Warm 28900 12 Grandmaster Boiling Hot 13623 13 Grandmaster Cold 15464 14 Grandmaster Freezing 22818 15 Grandmaster Hot 10805 16 Grandmaster Lava Hot 10363 17 Grandmaster Warm 19899 18 Master Boiling Hot 10800 ... \u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u884c\uff0c\u4ee5\u4fbf\u5728\u4e00\u9875\u4e2d\u5bb9\u7eb3\u8fd9\u4e9b\u884c\u3002\u8fd9\u662f\u53e6\u4e00\u79cd\u53ef\u4ee5\u4f5c\u4e3a\u529f\u80fd\u6dfb\u52a0\u7684\u8ba1\u6570\u3002\u60a8\u73b0\u5728\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\uff0c\u6211\u4f7f\u7528 id \u5217\u8fdb\u884c\u8ba1\u6570\u3002\u4e0d\u8fc7\uff0c\u4f60\u4e5f\u53ef\u4ee5\u901a\u8fc7\u5bf9\u5217\u7684\u7ec4\u5408\u8fdb\u884c\u5206\u7ec4\uff0c\u5bf9\u5176\u4ed6\u5217\u8fdb\u884c\u8ba1\u6570\u3002 \u8fd8\u6709\u4e00\u4e2a\u5c0f\u7a8d\u95e8\uff0c\u5c31\u662f\u4ece\u8fd9\u4e9b\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\u3002\u4f60\u53ef\u4ee5\u4ece\u73b0\u6709\u7684\u7279\u5f81\u4e2d\u521b\u5efa\u65b0\u7684\u5206\u7c7b\u7279\u5f81\uff0c\u800c\u4e14\u53ef\u4ee5\u6beb\u4e0d\u8d39\u529b\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot 1 Grandmaster_Warm 2 nan_Freezing 3 Novice_Lava Hot 4 Grandmaster_Cold ... 599999 Contributor_Boiling Hot Name : new_feature , Length : 600000 , dtype : object \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u7528\u4e0b\u5212\u7ebf\u5c06 ord_1 \u548c ord_2 \u5408\u5e76\uff0c\u7136\u540e\u5c06\u8fd9\u4e9b\u5217\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u7c7b\u578b\u3002\u8bf7\u6ce8\u610f\uff0cNaN \u4e5f\u4f1a\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u3002\u4e0d\u8fc7\u6ca1\u5173\u7cfb\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06 NaN \u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u6709\u4e86\u4e00\u4e2a\u7531\u8fd9\u4e24\u4e2a\u7279\u5f81\u7ec4\u5408\u800c\u6210\u7684\u65b0\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u5c06\u4e09\u5217\u4ee5\u4e0a\u6216\u56db\u5217\u751a\u81f3\u66f4\u591a\u5217\u7ec4\u5408\u5728\u4e00\u8d77\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : + \"_\" ... : + df . ord_3 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot_c 1 Grandmaster_Warm_e 2 nan_Freezing_n 3 Novice_Lava Hot_a 4 Grandmaster_Cold_h ... 599999 Contributor_Boiling Hot_b Name : new_feature , Length : 600000 , dtype : object \u90a3\u4e48\uff0c\u6211\u4eec\u5e94\u8be5\u628a\u54ea\u4e9b\u7c7b\u522b\u7ed3\u5408\u8d77\u6765\u5462\uff1f\u8fd9\u5e76\u6ca1\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u7b54\u6848\u3002\u8fd9\u53d6\u51b3\u4e8e\u60a8\u7684\u6570\u636e\u548c\u7279\u5f81\u7c7b\u578b\u3002\u4e00\u4e9b\u9886\u57df\u77e5\u8bc6\u5bf9\u4e8e\u521b\u5efa\u8fd9\u6837\u7684\u7279\u5f81\u53ef\u80fd\u5f88\u6709\u7528\u3002\u4f46\u662f\uff0c\u5982\u679c\u4f60\u4e0d\u62c5\u5fc3\u5185\u5b58\u548c CPU \u7684\u4f7f\u7528\uff0c\u4f60\u53ef\u4ee5\u91c7\u7528\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\uff0c\u5373\u521b\u5efa\u8bb8\u591a\u8fd9\u6837\u7684\u7ec4\u5408\uff0c\u7136\u540e\u4f7f\u7528\u4e00\u4e2a\u6a21\u578b\u6765\u51b3\u5b9a\u54ea\u4e9b\u7279\u5f81\u662f\u6709\u7528\u7684\uff0c\u5e76\u4fdd\u7559\u5b83\u4eec\u3002\u6211\u4eec\u5c06\u5728\u672c\u4e66\u7a0d\u540e\u90e8\u5206\u4ecb\u7ecd\u8fd9\u79cd\u65b9\u6cd5\u3002 \u65e0\u8bba\u4f55\u65f6\u83b7\u5f97\u5206\u7c7b\u53d8\u91cf\uff0c\u90fd\u8981\u9075\u5faa\u4ee5\u4e0b\u7b80\u5355\u6b65\u9aa4\uff1a - \u586b\u5145 NaN \u503c\uff08\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff01\uff09\u3002 - \u4f7f\u7528 scikit-learn \u7684 LabelEncoder \u6216\u6620\u5c04\u5b57\u5178\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\uff0c\u5c06\u5b83\u4eec\u8f6c\u6362\u4e3a\u6574\u6570\u3002\u5982\u679c\u6ca1\u6709\u586b\u5145 NaN \u503c\uff0c\u53ef\u80fd\u9700\u8981\u5728\u8fd9\u4e00\u6b65\u4e2d\u8fdb\u884c\u5904\u7406 - \u521b\u5efa\u72ec\u70ed\u7f16\u7801\u3002\u662f\u7684\uff0c\u4f60\u53ef\u4ee5\u8df3\u8fc7\u4e8c\u503c\u5316\uff01 - \u5efa\u6a21\uff01\u6211\u6307\u7684\u662f\u673a\u5668\u5b66\u4e60\u3002 \u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u5904\u7406 NaN \u6570\u636e\u975e\u5e38\u91cd\u8981\uff0c\u5426\u5219\u60a8\u53ef\u80fd\u4f1a\u4ece scikit-learn \u7684 LabelEncoder \u4e2d\u5f97\u5230\u81ed\u540d\u662d\u8457\u7684\u9519\u8bef\u4fe1\u606f\uff1a ValueError: y \u5305\u542b\u4ee5\u524d\u672a\u89c1\u8fc7\u7684\u6807\u7b7e\uff1a [Nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan) \u8fd9\u4ec5\u4ec5\u610f\u5473\u7740\uff0c\u5728\u8f6c\u6362\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6570\u636e\u4e2d\u51fa\u73b0\u4e86 NaN \u503c\u3002\u8fd9\u662f\u56e0\u4e3a\u4f60\u5728\u8bad\u7ec3\u65f6\u5fd8\u8bb0\u4e86\u5904\u7406\u5b83\u4eec\u3002 \u5904\u7406 NaN \u503c \u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u4e22\u5f03\u5b83\u4eec\u3002\u867d\u7136\u7b80\u5355\uff0c\u4f46\u5e76\u4e0d\u7406\u60f3\u3002NaN \u503c\u4e2d\u53ef\u80fd\u5305\u542b\u5f88\u591a\u4fe1\u606f\uff0c\u5982\u679c\u53ea\u662f\u4e22\u5f03\u8fd9\u4e9b\u503c\uff0c\u5c31\u4f1a\u4e22\u5931\u8fd9\u4e9b\u4fe1\u606f\u3002\u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u5927\u90e8\u5206\u6570\u636e\u90fd\u662f NaN \u503c\uff0c\u56e0\u6b64\u4e0d\u80fd\u4e22\u5f03 NaN \u503c\u7684\u884c/\u6837\u672c\u3002\u5904\u7406 NaN \u503c\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u5c06\u5176\u4f5c\u4e3a\u4e00\u4e2a\u5168\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u662f\u5904\u7406 NaN \u503c\u6700\u5e38\u7528\u7684\u65b9\u6cd5\u3002\u5982\u679c\u4f7f\u7528 pandas\uff0c\u8fd8\u53ef\u4ee5\u901a\u8fc7\u975e\u5e38\u7b80\u5355\u7684\u65b9\u5f0f\u5b9e\u73b0\u3002 \u8bf7\u770b\u6211\u4eec\u4e4b\u524d\u67e5\u770b\u8fc7\u7684\u6570\u636e\u7684 ord_2 \u5217\u3002 In [ X ]: df . ord_2 . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : ord_2 , dtype : int64 \u586b\u5165 NaN \u503c\u540e\uff0c\u5c31\u53d8\u6210\u4e86 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u54c7\uff01\u8fd9\u4e00\u5217\u4e2d\u6709 18075 \u4e2a NaN \u503c\uff0c\u800c\u6211\u4eec\u4e4b\u524d\u751a\u81f3\u90fd\u6ca1\u6709\u8003\u8651\u4f7f\u7528\u5b83\u4eec\u3002\u589e\u52a0\u4e86\u8fd9\u4e2a\u65b0\u7c7b\u522b\u540e\uff0c\u7c7b\u522b\u603b\u6570\u4ece 6 \u4e2a\u589e\u52a0\u5230\u4e86 7 \u4e2a\u3002\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u73b0\u5728\u6211\u4eec\u5728\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u4e5f\u4f1a\u8003\u8651 NaN\u3002\u76f8\u5173\u4fe1\u606f\u8d8a\u591a\uff0c\u6a21\u578b\u5c31\u8d8a\u597d\u3002 \u5047\u8bbe ord_2 \u6ca1\u6709\u4efb\u4f55 NaN \u503c\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u4e00\u5217\u4e2d\u7684\u6240\u6709\u7c7b\u522b\u90fd\u6709\u663e\u8457\u7684\u8ba1\u6570\u3002\u5176\u4e2d\u6ca1\u6709 \"\u7f55\u89c1 \"\u7c7b\u522b\uff0c\u5373\u53ea\u5728\u6837\u672c\u603b\u6570\u4e2d\u5360\u5f88\u5c0f\u6bd4\u4f8b\u7684\u7c7b\u522b\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5047\u8bbe\u60a8\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u4e86\u4f7f\u7528\u8fd9\u4e00\u5217\u7684\u6a21\u578b\uff0c\u5f53\u6a21\u578b\u6216\u9879\u76ee\u4e0a\u7ebf\u65f6\uff0c\u60a8\u5728 ord_2 \u5217\u4e2d\u5f97\u5230\u4e86\u4e00\u4e2a\u5728\u8bad\u7ec3\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6a21\u578b\u7ba1\u9053\u4f1a\u629b\u51fa\u4e00\u4e2a\u9519\u8bef\uff0c\u60a8\u5bf9\u6b64\u65e0\u80fd\u4e3a\u529b\u3002\u5982\u679c\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\uff0c\u90a3\u4e48\u53ef\u80fd\u662f\u751f\u4ea7\u4e2d\u7684\u7ba1\u9053\u51fa\u4e86\u95ee\u9898\u3002\u5982\u679c\u8fd9\u662f\u9884\u6599\u4e4b\u4e2d\u7684\uff0c\u90a3\u4e48\u60a8\u5c31\u5fc5\u987b\u4fee\u6539\u60a8\u7684\u6a21\u578b\u7ba1\u9053\uff0c\u5e76\u5728\u8fd9\u516d\u4e2a\u7c7b\u522b\u4e2d\u52a0\u5165\u4e00\u4e2a\u65b0\u7c7b\u522b\u3002 \u8fd9\u4e2a\u65b0\u7c7b\u522b\u88ab\u79f0\u4e3a \"\u7f55\u89c1 \"\u7c7b\u522b\u3002\u7f55\u89c1\u7c7b\u522b\u662f\u4e00\u79cd\u4e0d\u5e38\u89c1\u7684\u7c7b\u522b\uff0c\u53ef\u4ee5\u5305\u62ec\u8bb8\u591a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8fd1\u90bb\u6a21\u578b\u6765 \"\u9884\u6d4b \"\u672a\u77e5\u7c7b\u522b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u4e86\u8fd9\u4e2a\u7c7b\u522b\uff0c\u5b83\u5c31\u4f1a\u6210\u4e3a\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u4e00\u4e2a\u7c7b\u522b\u3002 \u56fe 3\uff1a\u5177\u6709\u4e0d\u540c\u7279\u5f81\u4e14\u65e0\u6807\u7b7e\u7684\u6570\u636e\u96c6\u793a\u610f\u56fe\uff0c\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u53ef\u80fd\u4f1a\u5728\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u51fa\u73b0\u65b0\u503c \u5f53\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u56fe 3 \u6240\u793a\u7684\u6570\u636e\u96c6\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u7b80\u5355\u7684\u6a21\u578b\uff0c\u5bf9\u9664 \"f3 \"\u4e4b\u5916\u7684\u6240\u6709\u7279\u5f81\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u6837\uff0c\u4f60\u5c06\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u4e0d\u77e5\u9053\u6216\u8bad\u7ec3\u4e2d\u6ca1\u6709 \"f3 \"\u65f6\u9884\u6d4b\u5b83\u3002\u6211\u4e0d\u6562\u8bf4\u8fd9\u6837\u7684\u6a21\u578b\u662f\u5426\u80fd\u5e26\u6765\u51fa\u8272\u7684\u6027\u80fd\uff0c\u4f46\u4e5f\u8bb8\u80fd\u5904\u7406\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u7684\u7f3a\u5931\u503c\uff0c\u5c31\u50cf\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u5176\u4ed6\u4e8b\u60c5\u4e00\u6837\uff0c\u4e0d\u5c1d\u8bd5\u4e00\u4e0b\u662f\u8bf4\u4e0d\u51c6\u7684\u3002 \u5982\u679c\u4f60\u6709\u4e00\u4e2a\u56fa\u5b9a\u7684\u6d4b\u8bd5\u96c6\uff0c\u4f60\u53ef\u4ee5\u5c06\u6d4b\u8bd5\u6570\u636e\u6dfb\u52a0\u5230\u8bad\u7ec3\u4e2d\uff0c\u4ee5\u4e86\u89e3\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u7c7b\u522b\u3002\u8fd9\u4e0e\u534a\u76d1\u7763\u5b66\u4e60\u975e\u5e38\u76f8\u4f3c\uff0c\u5373\u4f7f\u7528\u65e0\u6cd5\u7528\u4e8e\u8bad\u7ec3\u7684\u6570\u636e\u6765\u6539\u8fdb\u6a21\u578b\u3002\u8fd9\u4e5f\u4f1a\u7167\u987e\u5230\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u51fa\u73b0\u6b21\u6570\u6781\u5c11\u4f46\u5728\u6d4b\u8bd5\u6570\u636e\u4e2d\u5927\u91cf\u5b58\u5728\u7684\u7a00\u6709\u503c\u3002\u4f60\u7684\u6a21\u578b\u5c06\u66f4\u52a0\u7a33\u5065\u3002 \u5f88\u591a\u4eba\u8ba4\u4e3a\u8fd9\u79cd\u60f3\u6cd5\u4f1a\u8fc7\u5ea6\u62df\u5408\u3002\u53ef\u80fd\u8fc7\u62df\u5408\uff0c\u4e5f\u53ef\u80fd\u4e0d\u8fc7\u62df\u5408\u3002\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6cd5\u3002\u5982\u679c\u4f60\u5728\u8bbe\u8ba1\u4ea4\u53c9\u9a8c\u8bc1\u65f6\uff0c\u80fd\u591f\u5728\u6d4b\u8bd5\u6570\u636e\u4e0a\u8fd0\u884c\u6a21\u578b\u65f6\u590d\u5236\u9884\u6d4b\u8fc7\u7a0b\uff0c\u90a3\u4e48\u5b83\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u8fd9\u610f\u5473\u7740\u7b2c\u4e00\u6b65\u5e94\u8be5\u662f\u5206\u79bb\u6298\u53e0\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u4f60\u5e94\u8be5\u5e94\u7528\u4e0e\u6d4b\u8bd5\u6570\u636e\u76f8\u540c\u7684\u9884\u5904\u7406\u3002\u5047\u8bbe\u60a8\u60f3\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u6d4b\u8bd5\u6570\u636e\uff0c\u90a3\u4e48\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u60a8\u5fc5\u987b\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\uff0c\u5e76\u786e\u4fdd\u9a8c\u8bc1\u6570\u636e\u96c6\u590d\u5236\u4e86\u6d4b\u8bd5\u96c6\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u60a8\u5fc5\u987b\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u8bbe\u8ba1\u9a8c\u8bc1\u96c6\uff0c\u4f7f\u5176\u5305\u542b\u8bad\u7ec3\u96c6\u4e2d \"\u672a\u89c1 \"\u7684\u7c7b\u522b\u3002 \u56fe 4\uff1a\u5bf9\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u8fdb\u884c\u7b80\u5355\u5408\u5e76\uff0c\u4ee5\u4e86\u89e3\u6d4b\u8bd5\u96c6\u4e2d\u5b58\u5728\u4f46\u8bad\u7ec3\u96c6\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u6216\u8bad\u7ec3\u96c6\u4e2d\u7f55\u89c1\u7684\u7c7b\u522b \u53ea\u8981\u770b\u4e00\u4e0b\u56fe 4 \u548c\u4e0b\u9762\u7684\u4ee3\u7801\uff0c\u5c31\u80fd\u5f88\u5bb9\u6613\u7406\u89e3\u5176\u5de5\u4f5c\u539f\u7406\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( \"../input/cat_train.csv\" ) # \u8bfb\u53d6\u6d4b\u8bd5\u96c6 test = pd . read_csv ( \"../input/cat_test.csv\" ) # \u5c06\u6d4b\u8bd5\u96c6\"target\"\u5217\u5168\u90e8\u7f6e\u4e3a-1 test . loc [:, \"target\" ] = - 1 # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u62fc\u63a5 data = pd . concat ([ train , test ]) . reset_index ( drop = True ) # \u5c06\u9664\"id\"\u548c\"target\"\u5217\u7684\u5176\u4ed6\u7279\u5f81\u5217\u540d\u53d6\u51fa features = [ x for x in train . columns if x not in [ \"id\" , \"target\" ]] # \u904d\u5386\u7279\u5f81 for feat in features : # \u6807\u7b7e\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\",\u5e76\u5c06\u8be5\u5217\u683c\u5f0f\u53d8\u4e3astr temp_col = data [ feat ] . fillna ( \"NONE\" ) . astype ( str ) . values # \u8f6c\u6362\u6570\u503c data . loc [:, feat ] = lbl_enc . fit_transform ( temp_col ) # \u6839\u636e\"target\"\u5217\u5c06\u8bad\u7ec3\u96c6\u4e0e\u6d4b\u8bd5\u96c6\u5206\u5f00 train = data [ data . target != - 1 ] . reset_index ( drop = True ) test = data [ data . target == - 1 ] . reset_index ( drop = True ) \u5f53\u60a8\u9047\u5230\u5df2\u7ecf\u6709\u6d4b\u8bd5\u6570\u636e\u96c6\u7684\u95ee\u9898\u65f6\uff0c\u8fd9\u4e2a\u6280\u5de7\u5c31\u4f1a\u8d77\u4f5c\u7528\u3002\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd9\u4e00\u62db\u5728\u5b9e\u65f6\u73af\u5883\u4e2d\u4e0d\u8d77\u4f5c\u7528\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u60a8\u6240\u5728\u7684\u516c\u53f8\u63d0\u4f9b\u5b9e\u65f6\u7ade\u4ef7\u89e3\u51b3\u65b9\u6848\uff08RTB\uff09\u3002RTB \u7cfb\u7edf\u4f1a\u5bf9\u5728\u7ebf\u770b\u5230\u7684\u6bcf\u4e2a\u7528\u6237\u8fdb\u884c\u7ade\u4ef7\uff0c\u4ee5\u8d2d\u4e70\u5e7f\u544a\u7a7a\u95f4\u3002\u8fd9\u79cd\u6a21\u5f0f\u53ef\u4f7f\u7528\u7684\u529f\u80fd\u53ef\u80fd\u5305\u62ec\u7f51\u7ad9\u4e2d\u6d4f\u89c8\u7684\u9875\u9762\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u7279\u5f81\u662f\u7528\u6237\u8bbf\u95ee\u7684\u6700\u540e\u4e94\u4e2a\u7c7b\u522b/\u9875\u9762\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5982\u679c\u7f51\u7ad9\u5f15\u5165\u4e86\u65b0\u7684\u7c7b\u522b\uff0c\u6211\u4eec\u5c06\u65e0\u6cd5\u518d\u51c6\u786e\u9884\u6d4b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u5c31\u4f1a\u5931\u6548\u3002\u8fd9\u79cd\u60c5\u51b5\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528 \"\u672a\u77e5 \"\u7c7b\u522b\u6765\u907f\u514d \u3002 \u5728\u6211\u4eec\u7684 cat-in-the-dat \u6570\u636e\u96c6\u4e2d\uff0c ord_2 \u5217\u4e2d\u5df2\u7ecf\u6709\u4e86\u672a\u77e5\u7c7b\u522b\u3002 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u6211\u4eec\u53ef\u4ee5\u5c06 \"NONE \"\u89c6\u4e3a\u672a\u77e5\u3002\u56e0\u6b64\uff0c\u5982\u679c\u5728\u5b9e\u65f6\u6d4b\u8bd5\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u83b7\u5f97\u4e86\u4ee5\u524d\u4ece\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\uff0c\u6211\u4eec\u5c31\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a \"NONE\"\u3002 \u8fd9\u4e0e\u81ea\u7136\u8bed\u8a00\u5904\u7406\u95ee\u9898\u975e\u5e38\u76f8\u4f3c\u3002\u6211\u4eec\u603b\u662f\u57fa\u4e8e\u56fa\u5b9a\u7684\u8bcd\u6c47\u5efa\u7acb\u6a21\u578b\u3002\u589e\u52a0\u8bcd\u6c47\u91cf\u5c31\u4f1a\u589e\u52a0\u6a21\u578b\u7684\u5927\u5c0f\u3002\u50cf BERT \u8fd9\u6837\u7684\u8f6c\u6362\u5668\u6a21\u578b\u662f\u5728 ~30000 \u4e2a\u5355\u8bcd\uff08\u82f1\u8bed\uff09\u7684\u57fa\u7840\u4e0a\u8bad\u7ec3\u7684\u3002\u56e0\u6b64\uff0c\u5f53\u6709\u65b0\u8bcd\u8f93\u5165\u65f6\uff0c\u6211\u4eec\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a UNK\uff08\u672a\u77e5\uff09\u3002 \u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u5047\u8bbe\u6d4b\u8bd5\u6570\u636e\u4e0e\u8bad\u7ec3\u6570\u636e\u5177\u6709\u76f8\u540c\u7684\u7c7b\u522b\uff0c\u4e5f\u53ef\u4ee5\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u5f15\u5165\u7f55\u89c1\u6216\u672a\u77e5\u7c7b\u522b\uff0c\u4ee5\u5904\u7406\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u65b0\u7c7b\u522b\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u586b\u5165 NaN \u503c\u540e ord_4 \u5217\u7684\u503c\u8ba1\u6570\uff1a In [ X ]: df . ord_4 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 . . . K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 G 3404 V 3107 J 1950 L 1657 Name : ord_4 , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u6709\u4e9b\u6570\u503c\u53ea\u51fa\u73b0\u4e86\u51e0\u5343\u6b21\uff0c\u6709\u4e9b\u5219\u51fa\u73b0\u4e86\u8fd1 40000 \u6b21\u3002NaN \u4e5f\u7ecf\u5e38\u51fa\u73b0\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u503c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5b9a\u4e49\u5c06\u4e00\u4e2a\u503c\u79f0\u4e3a \" \u7f55\u89c1\uff08rare\uff09 \"\u7684\u6807\u51c6\u4e86\u3002\u6bd4\u65b9\u8bf4\uff0c\u5728\u8fd9\u4e00\u5217\u4e2d\uff0c\u7a00\u6709\u503c\u7684\u8981\u6c42\u662f\u8ba1\u6570\u5c0f\u4e8e 2000\u3002\u8fd9\u6837\u770b\u6765\uff0cJ \u548c L \u5c31\u53ef\u4ee5\u88ab\u6807\u8bb0\u4e3a\u7a00\u6709\u503c\u4e86\u3002\u4f7f\u7528 pandas\uff0c\u6839\u636e\u8ba1\u6570\u9608\u503c\u66ff\u6362\u7c7b\u522b\u975e\u5e38\u7b80\u5355\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 In [ X ]: df . ord_4 = df . ord_4 . fillna ( \"NONE\" ) In [ X ]: df . loc [ ... : df [ \"ord_4\" ] . value_counts ()[ df [ \"ord_4\" ]] . values < 2000 , ... : \"ord_4\" ... : ] = \"RARE\" In [ X ]: df . ord_4 . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 M 32504 . . . B 25212 E 21871 K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 RARE 3607 G 3404 V 3107 Name : ord_4 , dtype : int64 \u6211\u4eec\u8ba4\u4e3a\uff0c\u53ea\u8981\u67d0\u4e2a\u7c7b\u522b\u7684\u503c\u5c0f\u4e8e 2000\uff0c\u5c31\u5c06\u5176\u66ff\u6362\u4e3a\u7f55\u89c1\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u5728\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6240\u6709\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"RARE\"\uff0c\u800c\u6240\u6709\u7f3a\u5931\u503c\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"NONE\"\u3002 \u8fd9\u79cd\u65b9\u6cd5\u8fd8\u80fd\u786e\u4fdd\u5373\u4f7f\u6709\u65b0\u7684\u7c7b\u522b\uff0c\u6a21\u578b\u4e5f\u80fd\u5728\u5b9e\u9645\u73af\u5883\u4e2d\u6b63\u5e38\u5de5\u4f5c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u5177\u5907\u4e86\u5904\u7406\u4efb\u4f55\u5e26\u6709\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u6240\u9700\u7684\u4e00\u5207\u6761\u4ef6\u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5efa\u7acb\u7b2c\u4e00\u4e2a\u6a21\u578b\uff0c\u5e76\u9010\u6b65\u63d0\u9ad8\u5176\u6027\u80fd\u3002 \u5728\u6784\u5efa\u4efb\u4f55\u7c7b\u578b\u7684\u6a21\u578b\u4e4b\u524d\uff0c\u4ea4\u53c9\u68c0\u9a8c\u81f3\u5173\u91cd\u8981\u3002\u6211\u4eec\u5df2\u7ecf\u770b\u5230\u4e86\u6807\u7b7e/\u76ee\u6807\u5206\u5e03\uff0c\u77e5\u9053\u8fd9\u662f\u4e00\u4e2a\u76ee\u6807\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 StratifiedKFold \u6765\u5206\u5272\u6570\u636e\u3002 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u6dfb\u52a0\"kfold\"\u5217\uff0c\u5e76\u7f6e\u4e3a-1 df [ \"kfold\" ] = - 1 # \u6253\u4e71\u6570\u636e\u987a\u5e8f\uff0c\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u5c06\u76ee\u6807\u5217\u53d6\u51fa y = df . target . values # \u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): # \u533a\u5206\u6298\u53e0 df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u6587\u4ef6 df . to_csv ( \"../input/cat_train_folds.csv\" , index = False ) \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u68c0\u67e5\u65b0\u7684\u6298\u53e0 csv\uff0c\u67e5\u770b\u6bcf\u4e2a\u6298\u53e0\u7684\u6837\u672c\u6570\uff1a In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) In [ X ]: df . kfold . value_counts () Out [ X ]: 4 120000 3 120000 2 120000 1 120000 0 120000 Name : kfold , dtype : int64 \u6240\u6709\u6298\u53e0\u90fd\u6709 120000 \u4e2a\u6837\u672c\u3002\u8fd9\u662f\u610f\u6599\u4e4b\u4e2d\u7684\uff0c\u56e0\u4e3a\u8bad\u7ec3\u6570\u636e\u6709 600000 \u4e2a\u6837\u672c\uff0c\u800c\u6211\u4eec\u505a\u4e865\u6b21\u6298\u53e0\u3002\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u4e00\u5207\u987a\u5229\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u68c0\u67e5\u6bcf\u4e2a\u6298\u53e0\u7684\u76ee\u6807\u5206\u5e03\u3002 In [ X ]: df [ df . kfold == 0 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 1 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 2 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 3 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 4 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u76ee\u6807\u7684\u5206\u5e03\u90fd\u662f\u4e00\u6837\u7684\u3002\u8fd9\u6b63\u662f\u6211\u4eec\u6240\u9700\u8981\u7684\u3002\u5b83\u4e5f\u53ef\u4ee5\u662f\u76f8\u4f3c\u7684\uff0c\u5e76\u4e0d\u4e00\u5b9a\u8981\u4e00\u76f4\u76f8\u540c\u3002\u73b0\u5728\uff0c\u5f53\u6211\u4eec\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u6807\u7b7e\u5206\u5e03\u90fd\u5c06\u76f8\u540c\u3002 \u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u7684\u6700\u7b80\u5355\u7684\u6a21\u578b\u4e4b\u4e00\u662f\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u5e76\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): # \u8bfb\u53d6\u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) # \u53d6\u9664\"id\", \"target\", \"kfold\"\u5916\u7684\u5176\u4ed6\u7279\u5f81\u5217 features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] # \u904d\u5386\u7279\u5f81\u5217\u8868 for col in features : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u9a8c\u8bc1\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u6d4b\u8bd5\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u903b\u8f91\u56de\u5f52 model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . target . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( auc ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6298\u53e00 run ( 0 ) \u90a3\u4e48\uff0c\u53d1\u751f\u4e86\u4ec0\u4e48\u5462\uff1f \u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u51fd\u6570\uff0c\u5c06\u6570\u636e\u5206\u4e3a\u8bad\u7ec3\u548c\u9a8c\u8bc1\u4e24\u90e8\u5206\uff0c\u7ed9\u5b9a\u6298\u53e0\u6570\uff0c\u5904\u7406 NaN \u503c\uff0c\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u5355\u6b21\u7f16\u7801\uff0c\u5e76\u8bad\u7ec3\u4e00\u4e2a\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\u3002 \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u90e8\u5206\u4ee3\u7801\u65f6\uff0c\u4f1a\u4ea7\u751f\u5982\u4e0b\u8f93\u51fa\uff1a \u276f python ohe_logres . py / home / abhishek / miniconda3 / envs / ml / lib / python3 .7 / site - packages / sklearn / linear_model / _logistic . py : 939 : ConvergenceWarning : lbfgs failed to converge ( status = 1 ): STOP : TOTAL NO . of ITERATIONS REACHED LIMIT . Increase the number of iterations ( max_iter ) or scale the data as shown in : https : // scikit - learn . org / stable / modules / preprocessing . html . Please also refer to the documentation for alternative solver options : https : // scikit - learn . org / stable / modules / linear_model . html #logistic- regression extra_warning_msg = _LOGISTIC_SOLVER_CONVERGENCE_MSG ) 0.7847865042255127 \u6709\u4e00\u4e9b\u8b66\u544a\u3002\u903b\u8f91\u56de\u5f52\u4f3c\u4e4e\u6ca1\u6709\u6536\u655b\u5230\u6700\u5927\u8fed\u4ee3\u6b21\u6570\u3002\u6211\u4eec\u6ca1\u6709\u8c03\u6574\u53c2\u6570\uff0c\u6240\u4ee5\u6ca1\u6709\u95ee\u9898\u3002\u6211\u4eec\u770b\u5230 AUC \u4e3a 0.785\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u5bf9\u4ee3\u7801\u8fdb\u884c\u7b80\u5355\u4fee\u6539\uff0c\u8fd0\u884c\u6240\u6709\u6298\u53e0\u3002 .... model = linear_model . LogisticRegression () model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u5faa\u73af\u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5e76\u6ca1\u6709\u505a\u5f88\u5927\u7684\u6539\u52a8\uff0c\u6240\u4ee5\u6211\u53ea\u663e\u793a\u4e86\u90e8\u5206\u4ee3\u7801\u884c\uff0c\u5176\u4e2d\u4e00\u4e9b\u4ee3\u7801\u884c\u6709\u6539\u52a8\u3002 \u8fd9\u5c31\u6253\u5370\u51fa\u4e86\uff1a python - W ignore ohe_logres . py Fold = 0 , AUC = 0.7847865042255127 Fold = 1 , AUC = 0.7853553605899214 Fold = 2 , AUC = 0.7879321942914885 Fold = 3 , AUC = 0.7870315929550808 Fold = 4 , AUC = 0.7864668243125608 \u8bf7\u6ce8\u610f\uff0c\u6211\u4f7f\u7528\"-W ignore \"\u5ffd\u7565\u4e86\u6240\u6709\u8b66\u544a\u3002 \u6211\u4eec\u770b\u5230\uff0cAUC \u5206\u6570\u5728\u6240\u6709\u8936\u76b1\u4e2d\u90fd\u76f8\u5f53\u7a33\u5b9a\u3002\u5e73\u5747 AUC \u4e3a 0.78631449527\u3002\u5bf9\u4e8e\u6211\u4eec\u7684\u7b2c\u4e00\u4e2a\u6a21\u578b\u6765\u8bf4\u76f8\u5f53\u4e0d\u9519\uff01 \u5f88\u591a\u4eba\u5728\u9047\u5230\u8fd9\u79cd\u95ee\u9898\u65f6\u4f1a\u9996\u5148\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6bd4\u5982\u968f\u673a\u68ee\u6797\u3002\u5728\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u5e94\u7528\u968f\u673a\u68ee\u6797\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\uff08label encoding\uff09\uff0c\u5c06\u6bcf\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81\u90fd\u8f6c\u6362\u4e3a\u6574\u6570\uff0c\u800c\u4e0d\u662f\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u72ec\u70ed\u7f16\u7801\u3002 \u8fd9\u79cd\u7f16\u7801\u4e0e\u72ec\u70ed\u7f16\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # \u968f\u673a\u68ee\u6797\u6a21\u578b model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u4f7f\u7528 scikit-learn \u4e2d\u7684\u968f\u673a\u68ee\u6797\uff0c\u5e76\u53d6\u6d88\u4e86\u72ec\u70ed\u7f16\u7801\u3002\u6211\u4eec\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u4ee3\u66ff\u72ec\u70ed\u7f16\u7801\u3002\u5f97\u5206\u5982\u4e0b \u276f python lbl_rf . py Fold = 0 , AUC = 0.7167390828113697 Fold = 1 , AUC = 0.7165459672958506 Fold = 2 , AUC = 0.7159709909587376 Fold = 3 , AUC = 0.7161589664189556 Fold = 4 , AUC = 0.7156020216155978 \u54c7 \u5de8\u5927\u7684\u5dee\u5f02\uff01 \u968f\u673a\u68ee\u6797\u6a21\u578b\u5728\u6ca1\u6709\u4efb\u4f55\u8d85\u53c2\u6570\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u8868\u73b0\u8981\u6bd4\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u5dee\u5f88\u591a\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u603b\u662f\u5e94\u8be5\u5148\u4ece\u7b80\u5355\u6a21\u578b\u5f00\u59cb\u7684\u539f\u56e0\u3002\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u7c89\u4e1d\u4f1a\u4ece\u8fd9\u91cc\u5f00\u59cb\uff0c\u800c\u5ffd\u7565\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u8ba4\u4e3a\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u6a21\u578b\uff0c\u4e0d\u80fd\u5e26\u6765\u6bd4\u968f\u673a\u68ee\u6797\u66f4\u597d\u7684\u4ef7\u503c\u3002\u8fd9\u79cd\u4eba\u5c06\u4f1a\u72af\u4e0b\u5927\u9519\u3002\u5728\u6211\u4eec\u5b9e\u73b0\u968f\u673a\u68ee\u6797\u7684\u8fc7\u7a0b\u4e2d\uff0c\u4e0e\u903b\u8f91\u56de\u5f52\u76f8\u6bd4\uff0c\u6298\u53e0\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u624d\u80fd\u5b8c\u6210\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e0d\u4ec5\u635f\u5931\u4e86 AUC\uff0c\u8fd8\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u6765\u5b8c\u6210\u8bad\u7ec3\u3002\u8bf7\u6ce8\u610f\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u8fdb\u884c\u63a8\u7406\u4e5f\u5f88\u8017\u65f6\uff0c\u800c\u4e14\u5360\u7528\u7684\u7a7a\u95f4\u4e5f\u66f4\u5927\u3002 \u5982\u679c\u6211\u4eec\u613f\u610f\uff0c\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u5728\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u6570\u636e\u4e0a\u8fd0\u884c\u968f\u673a\u68ee\u6797\uff0c\u4f46\u8fd9\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u5947\u5f02\u503c\u5206\u89e3\u6765\u51cf\u5c11\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u77e9\u9635\u3002\u8fd9\u662f\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u63d0\u53d6\u4e3b\u9898\u7684\u5e38\u7528\u65b9\u6cd5\u3002 import pandas as pd from scipy import sparse from sklearn import decomposition from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" )] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) x_train = ohe . transform ( df_train [ features ]) x_valid = ohe . transform ( df_valid [ features ]) # \u5947\u5f02\u503c\u5206\u89e3 svd = decomposition . TruncatedSVD ( n_components = 120 ) full_sparse = sparse . vstack (( x_train , x_valid )) svd . fit ( full_sparse ) x_train = svd . transform ( x_train ) x_valid = svd . transform ( x_valid ) model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u5bf9\u5168\u90e8\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\uff0c\u7136\u540e\u7528\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\u5728\u7a00\u758f\u77e9\u9635\u4e0a\u62df\u5408 scikit-learn \u7684 TruncatedSVD\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c06\u9ad8\u7ef4\u7a00\u758f\u77e9\u9635\u51cf\u5c11\u5230 120 \u4e2a\u7279\u5f81\uff0c\u7136\u540e\u62df\u5408\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u3002 \u4ee5\u4e0b\u662f\u8be5\u6a21\u578b\u7684\u8f93\u51fa\u7ed3\u679c\uff1a \u276f python ohe_svd_rf . py Fold = 0 , AUC = 0.7064863038754249 Fold = 1 , AUC = 0.706050102937374 Fold = 2 , AUC = 0.7086069243167242 Fold = 3 , AUC = 0.7066819080085971 Fold = 4 , AUC = 0.7058154015055585 \u6211\u4eec\u53d1\u73b0\u60c5\u51b5\u66f4\u7cdf\u3002\u770b\u6765\uff0c\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\u7684\u6700\u4f73\u65b9\u6cd5\u662f\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u3002\u968f\u673a\u68ee\u6797\u4f3c\u4e4e\u8017\u65f6\u592a\u591a\u3002\u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u8bd5\u8bd5 XGBoost\u3002\u5982\u679c\u4f60\u4e0d\u77e5\u9053 XGBoost\uff0c\u5b83\u662f\u6700\u6d41\u884c\u7684\u68af\u5ea6\u63d0\u5347\u7b97\u6cd5\u4e4b\u4e00\u3002\u7531\u4e8e\u5b83\u662f\u4e00\u79cd\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u6570\u636e\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 , n_estimators = 200 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0c\u6211\u5bf9 xgboost \u53c2\u6570\u505a\u4e86\u4e00\u4e9b\u4fee\u6539\u3002xgboost \u7684\u9ed8\u8ba4\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u662f 3\uff0c\u6211\u628a\u5b83\u6539\u6210\u4e86 7\uff0c\u8fd8\u628a\u4f30\u8ba1\u5668\u6570\u91cf\uff08n_estimators\uff09\u4ece 100 \u6539\u6210\u4e86 200\u3002 \u8be5\u6a21\u578b\u7684 5 \u6298\u4ea4\u53c9\u68c0\u9a8c\u5f97\u5206\u5982\u4e0b\uff1a \u276f python lbl_xgb . py Fold = 0 , AUC = 0.7656768851999011 Fold = 1 , AUC = 0.7633006564148015 Fold = 2 , AUC = 0.7654277821434345 Fold = 3 , AUC = 0.7663609758878182 Fold = 4 , AUC = 0.764914671468069 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u4e0d\u505a\u4efb\u4f55\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u5f97\u5206\u6bd4\u666e\u901a\u968f\u673a\u68ee\u6797\u8981\u9ad8\u5f97\u591a\u3002 \u60a8\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4e00\u4e9b\u7279\u5f81\u5de5\u7a0b\uff0c\u653e\u5f03\u67d0\u4e9b\u5bf9\u6a21\u578b\u6ca1\u6709\u4efb\u4f55\u4ef7\u503c\u7684\u5217\u7b49\u3002\u4f46\u4f3c\u4e4e\u6211\u4eec\u80fd\u505a\u7684\u4e0d\u591a\uff0c\u65e0\u6cd5\u8bc1\u660e\u6a21\u578b\u7684\u6539\u8fdb\u3002\u8ba9\u6211\u4eec\u628a\u6570\u636e\u96c6\u6362\u6210\u53e6\u4e00\u4e2a\u6709\u5927\u91cf\u5206\u7c7b\u53d8\u91cf\u7684\u6570\u636e\u96c6\u3002\u53e6\u4e00\u4e2a\u6709\u540d\u7684\u6570\u636e\u96c6\u662f \u7f8e\u56fd\u6210\u4eba\u4eba\u53e3\u666e\u67e5\u6570\u636e\uff08US adult census data\uff09 \u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u5305\u542b\u4e00\u4e9b\u7279\u5f81\uff0c\u800c\u4f60\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u5de5\u8d44\u7b49\u7ea7\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u3002\u56fe 5 \u663e\u793a\u4e86\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u4e00\u4e9b\u5217\u3002 \u56fe 5\uff1a\u90e8\u5206\u6570\u636e\u96c6\u5c55\u793a \u8be5\u6570\u636e\u96c6\u6709\u4ee5\u4e0b\u51e0\u5217\uff1a - \u5e74\u9f84\uff08age\uff09 \u5de5\u4f5c\u7c7b\u522b\uff08workclass\uff09 \u5b66\u5386\uff08fnlwgt\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education.num\uff09 \u5a5a\u59fb\u72b6\u51b5\uff08marital.status\uff09 \u804c\u4e1a\uff08occupation\uff09 \u5173\u7cfb\uff08relationship\uff09 \u79cd\u65cf\uff08race\uff09 \u6027\u522b\uff08sex\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 \u539f\u7c4d\u56fd\uff08native.country\uff09 \u6536\u5165\uff08income\uff09 \u8fd9\u4e9b\u7279\u5f81\u5927\u591a\u4e0d\u8a00\u81ea\u660e\u3002\u90a3\u4e9b\u4e0d\u660e\u767d\u7684\uff0c\u6211\u4eec\u53ef\u4ee5\u4e0d\u8003\u8651\u3002\u8ba9\u6211\u4eec\u5148\u5c1d\u8bd5\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u3002 \u6211\u4eec\u770b\u5230\u6536\u5165\u5217\u662f\u4e00\u4e2a\u5b57\u7b26\u4e32\u3002\u8ba9\u6211\u4eec\u5bf9\u8fd9\u4e00\u5217\u8fdb\u884c\u6570\u503c\u7edf\u8ba1\u3002 In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/adult.csv\" ) In [ X ]: df . income . value_counts () Out [ X ]: <= 50 K 24720 > 50 K 7841 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 7841 \u4e2a\u5b9e\u4f8b\u7684\u6536\u5165\u8d85\u8fc7 5 \u4e07\u7f8e\u5143\u3002\u8fd9\u5360\u6837\u672c\u603b\u6570\u7684 24%\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4fdd\u6301\u4e0e\u732b\u6570\u636e\u96c6\u76f8\u540c\u7684\u8bc4\u4f30\u65b9\u6cd5\uff0c\u5373 AUC\u3002 \u5728\u5f00\u59cb\u5efa\u6a21\u4e4b\u524d\uff0c\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u5c06\u53bb\u6389\u51e0\u5217\u7279\u5f81\uff0c\u5373 \u5b66\u5386\uff08fnlwgt\uff09 \u5e74\u9f84\uff08age\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 \u8ba9\u6211\u4eec\u8bd5\u7740\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u5668\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002\u7b2c\u4e00\u6b65\u603b\u662f\u8981\u8fdb\u884c\u4ea4\u53c9\u9a8c\u8bc1\u3002\u6211\u4e0d\u4f1a\u5728\u8fd9\u91cc\u5c55\u793a\u8fd9\u90e8\u5206\u4ee3\u7801\u3002\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) # \u9700\u8981\u5220\u9664\u7684\u5217 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) # \u6620\u5c04 target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } # \u4f7f\u7528\u6620\u5c04\u66ff\u6362 df . loc [:, \"income\" ] = df . income . map ( target_mapping ) # \u53d6\u9664\"kfold\", \"income\"\u5217\u7684\u5176\u4ed6\u5217\u540d features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u9a8c\u8bc1\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u6784\u5efa\u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . income . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u65f6\uff0c\u6211\u4eec\u4f1a\u5f97\u5230 \u276f python - W ignore ohe_logres . py Fold = 0 , AUC = 0.8794809708119079 Fold = 1 , AUC = 0.8875785068274882 Fold = 2 , AUC = 0.8852609687685753 Fold = 3 , AUC = 0.8681236223251438 Fold = 4 , AUC = 0.8728581541840037 \u5bf9\u4e8e\u4e00\u4e2a\u5982\u6b64\u7b80\u5355\u7684\u6a21\u578b\u6765\u8bf4\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u4e0d\u9519\u7684 AUC\uff01 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5728\u4e0d\u8c03\u6574\u4efb\u4f55\u8d85\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c1d\u8bd5\u4e00\u4e0b\u6807\u7b7e\u7f16\u7801\u7684xgboost\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8ba9\u6211\u4eec\u8fd0\u884c\u4e0a\u9762\u4ee3\u7801\uff1a \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8800810634234078 Fold = 1 , AUC = 0.886811884948154 Fold = 2 , AUC = 0.8854421433318472 Fold = 3 , AUC = 0.8676319549361007 Fold = 4 , AUC = 0.8714450054900602 \u8fd9\u770b\u8d77\u6765\u5df2\u7ecf\u76f8\u5f53\u4e0d\u9519\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b max_depth \u589e\u52a0\u5230 7 \u548c n_estimators \u589e\u52a0\u5230 200 \u65f6\u7684\u5f97\u5206\u3002 \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8764108944332032 Fold = 1 , AUC = 0.8840708537662638 Fold = 2 , AUC = 0.8816601162613102 Fold = 3 , AUC = 0.8662335762581732 Fold = 4 , AUC = 0.8698983461709926 \u770b\u8d77\u6765\u5e76\u6ca1\u6709\u6539\u5584\u3002 \u8fd9\u8868\u660e\uff0c\u4e00\u4e2a\u6570\u636e\u96c6\u7684\u53c2\u6570\u4e0d\u80fd\u79fb\u690d\u5230\u53e6\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6211\u4eec\u5fc5\u987b\u518d\u6b21\u5c1d\u8bd5\u8c03\u6574\u53c2\u6570\uff0c\u4f46\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u8bf4\u660e\u3002 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5728\u4e0d\u8c03\u6574\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c06\u6570\u503c\u7279\u5f81\u7eb3\u5165 xgboost \u6a21\u578b\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) # \u52a0\u5165\u6570\u503c\u7279\u5f81 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : if col not in num_cols : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u56e0\u6b64\uff0c\u6211\u4eec\u4fdd\u7559\u6570\u5b57\u5217\uff0c\u53ea\u662f\u4e0d\u5bf9\u5176\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u7684\u6700\u7ec8\u7279\u5f81\u77e9\u9635\u5c31\u7531\u6570\u5b57\u5217\uff08\u539f\u6837\uff09\u548c\u7f16\u7801\u5206\u7c7b\u5217\u7ec4\u6210\u4e86\u3002\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\u90fd\u80fd\u8f7b\u677e\u5904\u7406\u8fd9\u79cd\u6df7\u5408\u3002 \u8bf7\u6ce8\u610f\uff0c\u5728\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65f6\uff0c\u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u6570\u636e\u8fdb\u884c\u5f52\u4e00\u5316\u5904\u7406\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff0c\u5728\u4f7f\u7528\u7ebf\u6027\u6a21\u578b\uff08\u5982\u903b\u8f91\u56de\u5f52\uff09\u65f6\u4e0d\u5bb9\u5ffd\u89c6\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u4e2a\u811a\u672c\uff01 \u276f python lbl_xgb_num . py Fold = 0 , AUC = 0.9209790185449889 Fold = 1 , AUC = 0.9247157449144706 Fold = 2 , AUC = 0.9269329887598243 Fold = 3 , AUC = 0.9119349082169275 Fold = 4 , AUC = 0.9166408030141667 \u54c7\u54e6 \u8fd9\u662f\u4e00\u4e2a\u5f88\u597d\u7684\u5206\u6570\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c1d\u8bd5\u6dfb\u52a0\u4e00\u4e9b\u529f\u80fd\u3002\u6211\u4eec\u5c06\u63d0\u53d6\u6240\u6709\u5206\u7c7b\u5217\uff0c\u5e76\u521b\u5efa\u6240\u6709\u4e8c\u5ea6\u7ec4\u5408\u3002\u8bf7\u770b\u4e0b\u9762\u4ee3\u7801\u6bb5\u4e2d\u7684 feature_engineering \u51fd\u6570\uff0c\u4e86\u89e3\u5982\u4f55\u5b9e\u73b0\u8fd9\u4e00\u70b9\u3002 import itertools import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def feature_engineering ( df , cat_cols ): # \u751f\u6210\u4e24\u4e2a\u7279\u5f81\u7684\u7ec4\u5408 combi = list ( itertools . combinations ( cat_cols , 2 )) for c1 , c2 in combi : df . loc [:, c1 + \"_\" + c2 ] = df [ c1 ] . astype ( str ) + \"_\" + df [ c2 ] . astype ( str ) return df def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) cat_cols = [ c for c in df . columns if c not in num_cols and c not in ( \"kfold\" , \"income\" )] # \u7279\u5f81\u5de5\u7a0b df = feature_engineering ( df , cat_cols ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" )] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u8fd9\u662f\u4ece\u5206\u7c7b\u5217\u4e2d\u521b\u5efa\u7279\u5f81\u7684\u4e00\u79cd\u975e\u5e38\u5e7c\u7a1a\u7684\u65b9\u6cd5\u3002\u6211\u4eec\u5e94\u8be5\u4ed4\u7ec6\u7814\u7a76\u6570\u636e\uff0c\u770b\u770b\u54ea\u4e9b\u7ec4\u5408\u6700\u5408\u7406\u3002\u5982\u679c\u4f7f\u7528\u8fd9\u79cd\u65b9\u6cd5\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5c31\u9700\u8981\u4f7f\u7528\u67d0\u79cd\u7279\u5f81\u9009\u62e9\u6765\u9009\u51fa\u6700\u4f73\u7279\u5f81\u3002\u7a0d\u540e\u6211\u4eec\u5c06\u8be6\u7ec6\u4ecb\u7ecd\u7279\u5f81\u9009\u62e9\u3002\u73b0\u5728\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5206\u6570\u3002 \u276f python lbl_xgb_num_feat . py Fold = 0 , AUC = 0.9211483465031423 Fold = 1 , AUC = 0.9251499446866125 Fold = 2 , AUC = 0.9262344766486692 Fold = 3 , AUC = 0.9114264068794995 Fold = 4 , AUC = 0.9177914453099201 \u770b\u6765\uff0c\u5373\u4f7f\u4e0d\u6539\u53d8\u4efb\u4f55\u8d85\u53c2\u6570\uff0c\u53ea\u589e\u52a0\u4e00\u4e9b\u7279\u5f81\uff0c\u6211\u4eec\u4e5f\u80fd\u63d0\u9ad8\u4e00\u4e9b\u6298\u53e0\u5f97\u5206\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5c06 max_depth \u589e\u52a0\u5230 7 \u662f\u5426\u6709\u5e2e\u52a9\u3002 \u276f python lbl_xgb_num_feat . py Fold = 0 , AUC = 0.9286668430204137 Fold = 1 , AUC = 0.9329340656165378 Fold = 2 , AUC = 0.9319817543218744 Fold = 3 , AUC = 0.919046187194538 Fold = 4 , AUC = 0.9245692057162671 \u6211\u4eec\u518d\u6b21\u6539\u8fdb\u4e86\u6211\u4eec\u7684\u6a21\u578b\u3002 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u8fd8\u6ca1\u6709\u4f7f\u7528\u7a00\u6709\u503c\u3001\u4e8c\u503c\u5316\u3001\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u7279\u5f81\u7684\u7ec4\u5408\u4ee5\u53ca\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u3002 \u4ece\u5206\u7c7b\u7279\u5f81\u4e2d\u8fdb\u884c\u7279\u5f81\u5de5\u7a0b\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u4f7f\u7528 \u76ee\u6807\u7f16\u7801 \u3002\u4f46\u662f\uff0c\u60a8\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u8fd9\u53ef\u80fd\u4f1a\u4f7f\u60a8\u7684\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u76ee\u6807\u7f16\u7801\u662f\u4e00\u79cd\u5c06\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u6620\u5c04\u5230\u5176\u5e73\u5747\u76ee\u6807\u503c\u7684\u6280\u672f\uff0c\u4f46\u5fc5\u987b\u59cb\u7ec8\u4ee5\u4ea4\u53c9\u9a8c\u8bc1\u7684\u65b9\u5f0f\u8fdb\u884c\u3002\u8fd9\u610f\u5473\u7740\u9996\u5148\u8981\u521b\u5efa\u6298\u53e0\uff0c\u7136\u540e\u4f7f\u7528\u8fd9\u4e9b\u6298\u53e0\u4e3a\u6570\u636e\u7684\u4e0d\u540c\u5217\u521b\u5efa\u76ee\u6807\u7f16\u7801\u7279\u5f81\uff0c\u65b9\u6cd5\u4e0e\u5728\u6298\u53e0\u4e0a\u62df\u5408\u548c\u9884\u6d4b\u6a21\u578b\u7684\u65b9\u6cd5\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u521b\u5efa\u4e86 5 \u4e2a\u6298\u53e0\uff0c\u60a8\u5c31\u5fc5\u987b\u521b\u5efa 5 \u6b21\u76ee\u6807\u7f16\u7801\uff0c\u8fd9\u6837\u6700\u7ec8\uff0c\u60a8\u5c31\u53ef\u4ee5\u4e3a\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u53d8\u91cf\u521b\u5efa\u7f16\u7801\uff0c\u800c\u8fd9\u4e9b\u53d8\u91cf\u5e76\u975e\u6765\u81ea\u540c\u4e00\u4e2a\u6298\u53e0\u3002\u7136\u540e\u5728\u62df\u5408\u6a21\u578b\u65f6\uff0c\u5fc5\u987b\u518d\u6b21\u4f7f\u7528\u76f8\u540c\u7684\u6298\u53e0\u3002\u672a\u89c1\u6d4b\u8bd5\u6570\u636e\u7684\u76ee\u6807\u7f16\u7801\u53ef\u4ee5\u6765\u81ea\u5168\u90e8\u8bad\u7ec3\u6570\u636e\uff0c\u4e5f\u53ef\u4ee5\u662f\u6240\u6709 5 \u4e2a\u6298\u53e0\u7684\u5e73\u5747\u503c\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u5728\u540c\u4e00\u4e2a\u6210\u4eba\u6570\u636e\u96c6\u4e0a\u4f7f\u7528\u76ee\u6807\u7f16\u7801\uff0c\u4ee5\u4fbf\u8fdb\u884c\u6bd4\u8f83\u3002 import copy import pandas as pd from sklearn import metrics from sklearn import preprocessing import xgboost as xgb def mean_target_encoding ( data ): df = copy . deepcopy ( data ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) and f not in num_cols ] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) encoded_dfs = [] for fold in range ( 5 ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) for column in features : # \u76ee\u6807\u7f16\u7801 mapping_dict = dict ( df_train . groupby ( column )[ \"income\" ] . mean () ) df_valid . loc [:, column + \"_enc\" ] = df_valid [ column ] . map ( mapping_dict ) encoded_dfs . append ( df_valid ) encoded_df = pd . concat ( encoded_dfs , axis = 0 ) return encoded_df def run ( df , fold ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/adult_folds.csv\" ) df = mean_target_encoding ( df ) for fold_ in range ( 5 ): run ( df , fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u4e0a\u8ff0\u7247\u6bb5\u4e2d\uff0c\u6211\u5728\u8fdb\u884c\u76ee\u6807\u7f16\u7801\u65f6\u5e76\u6ca1\u6709\u5220\u9664\u5206\u7c7b\u5217\u3002\u6211\u4fdd\u7559\u4e86\u6240\u6709\u7279\u5f81\uff0c\u5e76\u5728\u6b64\u57fa\u7840\u4e0a\u6dfb\u52a0\u4e86\u76ee\u6807\u7f16\u7801\u7279\u5f81\u3002\u6b64\u5916\uff0c\u6211\u8fd8\u4f7f\u7528\u4e86\u5e73\u5747\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5e73\u5747\u503c\u3001\u4e2d\u4f4d\u6570\u3001\u6807\u51c6\u504f\u5dee\u6216\u76ee\u6807\u7684\u4efb\u4f55\u5176\u4ed6\u51fd\u6570\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u7ed3\u679c\u3002 Fold = 0 , AUC = 0.9332240662017529 Fold = 1 , AUC = 0.9363551625140347 Fold = 2 , AUC = 0.9375013544556173 Fold = 3 , AUC = 0.92237621307625 Fold = 4 , AUC = 0.9292131180445478 \u4e0d\u9519\uff01\u770b\u6765\u6211\u4eec\u53c8\u6709\u8fdb\u6b65\u4e86\u3002\u4e0d\u8fc7\uff0c\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u5b83\u592a\u5bb9\u6613\u51fa\u73b0\u8fc7\u5ea6\u62df\u5408\u3002\u5f53\u6211\u4eec\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\uff0c\u6700\u597d\u4f7f\u7528\u67d0\u79cd\u5e73\u6ed1\u65b9\u6cd5\u6216\u5728\u7f16\u7801\u503c\u4e2d\u6dfb\u52a0\u566a\u58f0\u3002 Scikit-learn \u7684\u8d21\u732e\u5e93\u4e2d\u6709\u5e26\u5e73\u6ed1\u7684\u76ee\u6807\u7f16\u7801\uff0c\u4f60\u4e5f\u53ef\u4ee5\u521b\u5efa\u81ea\u5df1\u7684\u5e73\u6ed1\u3002\u5e73\u6ed1\u4f1a\u5f15\u5165\u67d0\u79cd\u6b63\u5219\u5316\uff0c\u6709\u52a9\u4e8e\u907f\u514d\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u8fd9\u5e76\u4e0d\u96be\u3002 \u5904\u7406\u5206\u7c7b\u7279\u5f81\u662f\u4e00\u9879\u590d\u6742\u7684\u4efb\u52a1\u3002\u8bb8\u591a\u8d44\u6e90\u4e2d\u90fd\u6709\u5927\u91cf\u4fe1\u606f\u3002\u672c\u7ae0\u5e94\u8be5\u80fd\u5e2e\u52a9\u4f60\u5f00\u59cb\u89e3\u51b3\u5206\u7c7b\u53d8\u91cf\u7684\u4efb\u4f55\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u5bf9\u4e8e\u5927\u591a\u6570\u95ee\u9898\u6765\u8bf4\uff0c\u9664\u4e86\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u4e4b\u5916\uff0c\u4f60\u4e0d\u9700\u8981\u66f4\u591a\u7684\u4e1c\u897f\u3002 \u8981\u8fdb\u4e00\u6b65\u6539\u8fdb\u6a21\u578b\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u66f4\u591a\uff01 \u5728\u672c\u7ae0\u7684\u6700\u540e\uff0c\u6211\u4eec\u4e0d\u80fd\u4e0d\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e00\u79cd\u79f0\u4e3a \u5b9e\u4f53\u5d4c\u5165 \u7684\u6280\u672f\u3002\u5728\u5b9e\u4f53\u5d4c\u5165\u4e2d\uff0c\u7c7b\u522b\u7528\u5411\u91cf\u8868\u793a\u3002\u5728\u4e8c\u503c\u5316\u548c\u72ec\u70ed\u7f16\u7801\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u90fd\u662f\u7528\u5411\u91cf\u6765\u8868\u793a\u7c7b\u522b\u7684\u3002 \u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u6709\u6570\u4ee5\u4e07\u8ba1\u7684\u7c7b\u522b\u600e\u4e48\u529e\uff1f\u8fd9\u5c06\u4f1a\u4ea7\u751f\u5de8\u5927\u7684\u77e9\u9635\uff0c\u6211\u4eec\u5c06\u9700\u8981\u5f88\u957f\u65f6\u95f4\u6765\u8bad\u7ec3\u590d\u6742\u7684\u6a21\u578b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u7528\u5e26\u6709\u6d6e\u70b9\u503c\u7684\u5411\u91cf\u6765\u8868\u793a\u5b83\u4eec\u3002 \u8fd9\u4e2a\u60f3\u6cd5\u975e\u5e38\u7b80\u5355\u3002\u6bcf\u4e2a\u5206\u7c7b\u7279\u5f81\u90fd\u6709\u4e00\u4e2a\u5d4c\u5165\u5c42\u3002\u56e0\u6b64\uff0c\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u73b0\u5728\u90fd\u53ef\u4ee5\u6620\u5c04\u5230\u4e00\u4e2a\u5d4c\u5165\u5c42\uff08\u5c31\u50cf\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u5c06\u5355\u8bcd\u6620\u5c04\u5230\u5d4c\u5165\u5c42\u4e00\u6837\uff09\u3002\u7136\u540e\uff0c\u6839\u636e\u5176\u7ef4\u5ea6\u91cd\u5851\u8fd9\u4e9b\u5d4c\u5165\u5c42\uff0c\u4f7f\u5176\u6241\u5e73\u5316\uff0c\u7136\u540e\u5c06\u6240\u6709\u6241\u5e73\u5316\u7684\u8f93\u5165\u5d4c\u5165\u5c42\u8fde\u63a5\u8d77\u6765\u3002\u7136\u540e\u6dfb\u52a0\u4e00\u5806\u5bc6\u96c6\u5c42\u548c\u4e00\u4e2a\u8f93\u51fa\u5c42\uff0c\u5c31\u5927\u529f\u544a\u6210\u4e86\u3002 \u56fe 6\uff1a\u7c7b\u522b\u8f6c\u6362\u4e3a\u6d6e\u70b9\u6216\u5d4c\u5165\u5411\u91cf \u51fa\u4e8e\u67d0\u79cd\u539f\u56e0\uff0c\u6211\u53d1\u73b0\u4f7f\u7528 TF/Keras \u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528 TF/Keras \u5b9e\u73b0\u5b83\u3002\u6b64\u5916\uff0c\u8fd9\u662f\u672c\u4e66\u4e2d\u552f\u4e00\u4e00\u4e2a\u4f7f\u7528 TF/Keras \u7684\u793a\u4f8b\uff0c\u5c06\u5176\u8f6c\u6362\u4e3a PyTorch\uff08\u4f7f\u7528 cat-in-the-dat-ii \u6570\u636e\u96c6\uff09\u4e5f\u975e\u5e38\u5bb9\u6613 import os import gc import joblib import pandas as pd import numpy as np from sklearn import metrics , preprocessing from tensorflow.keras import layers from tensorflow.keras import optimizers from tensorflow.keras.models import Model , load_model from tensorflow.keras import callbacks from tensorflow.keras import backend as K from tensorflow.keras import utils def create_model ( data , catcols ): # \u521b\u5efa\u7a7a\u7684\u8f93\u5165\u5217\u8868\u548c\u8f93\u51fa\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6a21\u578b\u7684\u8f93\u5165\u548c\u8f93\u51fa inputs = [] outputs = [] # \u904d\u5386\u5206\u7c7b\u7279\u5f81\u5217\u8868\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81 for c in catcols : # \u8ba1\u7b97\u7279\u5f81\u4e2d\u552f\u4e00\u503c\u7684\u6570\u91cf num_unique_values = int ( data [ c ] . nunique ()) # \u8ba1\u7b97\u5d4c\u5165\u7ef4\u5ea6\uff0c\u6700\u5927\u4e0d\u8d85\u8fc750 embed_dim = int ( min ( np . ceil (( num_unique_values ) / 2 ), 50 )) # \u521b\u5efa\u6a21\u578b\u7684\u8f93\u5165\u5c42\uff0c\u6bcf\u4e2a\u7279\u5f81\u5bf9\u5e94\u4e00\u4e2a\u8f93\u5165 inp = layers . Input ( shape = ( 1 ,)) # \u521b\u5efa\u5d4c\u5165\u5c42\uff0c\u5c06\u5206\u7c7b\u7279\u5f81\u6620\u5c04\u5230\u4f4e\u7ef4\u5ea6\u7684\u8fde\u7eed\u5411\u91cf out = layers . Embedding ( num_unique_values + 1 , embed_dim , name = c )( inp ) # \u5bf9\u5d4c\u5165\u5c42\u8fdb\u884c\u7a7a\u95f4\u4e22\u5f03\uff08Dropout\uff09 out = layers . SpatialDropout1D ( 0.3 )( out ) # \u5c06\u5d4c\u5165\u5c42\u7684\u5f62\u72b6\u91cd\u65b0\u8c03\u6574\u4e3a\u4e00\u7ef4 out = layers . Reshape ( target_shape = ( embed_dim ,))( out ) # \u5c06\u8f93\u5165\u548c\u8f93\u51fa\u6dfb\u52a0\u5230\u5bf9\u5e94\u7684\u5217\u8868\u4e2d inputs . append ( inp ) outputs . append ( out ) # \u4f7f\u7528Concatenate\u5c42\u5c06\u6240\u6709\u7684\u5d4c\u5165\u5c42\u8f93\u51fa\u8fde\u63a5\u5728\u4e00\u8d77 x = layers . Concatenate ()( outputs ) # \u5bf9\u8fde\u63a5\u540e\u7684\u6570\u636e\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u6dfb\u52a0\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u6dfb\u52a0\u53e6\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u8f93\u51fa\u5c42\uff0c\u5177\u67092\u4e2a\u795e\u7ecf\u5143\uff08\u7528\u4e8e\u4e8c\u8fdb\u5236\u5206\u7c7b\uff09\uff0c\u5e76\u4f7f\u7528softmax\u6fc0\u6d3b\u51fd\u6570 y = layers . Dense ( 2 , activation = \"softmax\" )( x ) # \u521b\u5efa\u6a21\u578b\uff0c\u5c06\u8f93\u5165\u548c\u8f93\u51fa\u4f20\u9012\u7ed9Model\u6784\u9020\u51fd\u6570 model = Model ( inputs = inputs , outputs = y ) # \u7f16\u8bd1\u6a21\u578b\uff0c\u6307\u5b9a\u635f\u5931\u51fd\u6570\u548c\u4f18\u5316\u5668 model . compile ( loss = 'binary_crossentropy' , optimizer = 'adam' ) # \u8fd4\u56de\u521b\u5efa\u7684\u6a21\u578b return model def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for feat in features : lbl_enc = preprocessing . LabelEncoder () df . loc [:, feat ] = lbl_enc . fit_transform ( df [ feat ] . values ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) model = create_model ( df , features ) xtrain = [ df_train [ features ] . values [:, k ] for k in range ( len ( features ))] xvalid = [ df_valid [ features ] . values [:, k ] for k in range ( len ( features )) ] ytrain = df_train . target . values yvalid = df_valid . target . values ytrain_cat = utils . to_categorical ( ytrain ) yvalid_cat = utils . to_categorical ( yvalid ) model . fit ( xtrain , ytrain_cat , validation_data = ( xvalid , yvalid_cat ), verbose = 1 , batch_size = 1024 , epochs = 3 ) valid_preds = model . predict ( xvalid )[:, 1 ] print ( metrics . roc_auc_score ( yvalid , valid_preds )) K . clear_session () if __name__ == \"__main__\" : run ( 0 ) run ( 1 ) run ( 2 ) run ( 3 ) run ( 4 ) \u4f60\u4f1a\u53d1\u73b0\u8fd9\u79cd\u65b9\u6cd5\u6548\u679c\u6700\u597d\uff0c\u800c\u4e14\u5982\u679c\u4f60\u6709 GPU\uff0c\u901f\u5ea6\u4e5f\u8d85\u5feb\uff01\u8fd9\u79cd\u65b9\u6cd5\u8fd8\u53ef\u4ee5\u8fdb\u4e00\u6b65\u6539\u8fdb\uff0c\u800c\u4e14\u4f60\u65e0\u9700\u62c5\u5fc3\u7279\u5f81\u5de5\u7a0b\uff0c\u56e0\u4e3a\u795e\u7ecf\u7f51\u7edc\u4f1a\u81ea\u884c\u5904\u7406\u3002\u5728\u5904\u7406\u5927\u91cf\u5206\u7c7b\u7279\u5f81\u6570\u636e\u96c6\u65f6\uff0c\u8fd9\u7edd\u5bf9\u503c\u5f97\u4e00\u8bd5\u3002\u5f53\u5d4c\u5165\u5927\u5c0f\u4e0e\u552f\u4e00\u7c7b\u522b\u7684\u6570\u91cf\u76f8\u540c\u65f6\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u72ec\u70ed\u7f16\u7801\uff08one-hot-encoding\uff09\u3002 \u672c\u7ae0\u57fa\u672c\u4e0a\u90fd\u662f\u5173\u4e8e\u7279\u5f81\u5de5\u7a0b\u7684\u3002\u8ba9\u6211\u4eec\u5728\u4e0b\u4e00\u7ae0\u4e2d\u770b\u770b\u5982\u4f55\u5728\u6570\u5b57\u7279\u5f81\u548c\u4e0d\u540c\u7c7b\u578b\u7279\u5f81\u7684\u7ec4\u5408\u65b9\u9762\u8fdb\u884c\u66f4\u591a\u7684\u7279\u5f81\u5de5\u7a0b\u3002","title":"\u5904\u7406\u5206\u7c7b\u53d8\u91cf"},{"location":"%E5%A4%84%E7%90%86%E5%88%86%E7%B1%BB%E5%8F%98%E9%87%8F/#_1","text":"\u5f88\u591a\u4eba\u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u65f6\u90fd\u4f1a\u9047\u5230\u5f88\u591a\u56f0\u96be\uff0c\u56e0\u6b64\u8fd9\u503c\u5f97\u7528\u6574\u6574\u4e00\u7ae0\u7684\u7bc7\u5e45\u6765\u8ba8\u8bba\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u5c06\u8bb2\u8ff0\u4e0d\u540c\u7c7b\u578b\u7684\u5206\u7c7b\u6570\u636e\uff0c\u4ee5\u53ca\u5982\u4f55\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u3002 \u4ec0\u4e48\u662f\u5206\u7c7b\u53d8\u91cf\uff1f \u5206\u7c7b\u53d8\u91cf/\u7279\u5f81\u662f\u6307\u4efb\u4f55\u7279\u5f81\u7c7b\u578b\uff0c\u53ef\u5206\u4e3a\u4e24\u5927\u7c7b\uff1a - \u65e0\u5e8f - \u6709\u5e8f \u65e0\u5e8f\u53d8\u91cf \u662f\u6307\u6709\u4e24\u4e2a\u6216\u4e24\u4e2a\u4ee5\u4e0a\u7c7b\u522b\u7684\u53d8\u91cf\uff0c\u8fd9\u4e9b\u7c7b\u522b\u6ca1\u6709\u4efb\u4f55\u76f8\u5173\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u5982\u679c\u5c06\u6027\u522b\u5206\u4e3a\u4e24\u7ec4\uff0c\u5373\u7537\u6027\u548c\u5973\u6027\uff0c\u5219\u53ef\u5c06\u5176\u89c6\u4e3a\u540d\u4e49\u53d8\u91cf\u3002 \u6709\u5e8f\u53d8\u91cf \u5219\u6709 \"\u7b49\u7ea7 \"\u6216\u7c7b\u522b\uff0c\u5e76\u6709\u7279\u5b9a\u7684\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u4e00\u4e2a\u987a\u5e8f\u5206\u7c7b\u53d8\u91cf\u53ef\u4ee5\u662f\u4e00\u4e2a\u5177\u6709\u4f4e\u3001\u4e2d\u3001\u9ad8\u4e09\u4e2a\u4e0d\u540c\u7b49\u7ea7\u7684\u7279\u5f81\u3002\u987a\u5e8f\u5f88\u91cd\u8981\u3002 \u5c31\u5b9a\u4e49\u800c\u8a00\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06\u5206\u7c7b\u53d8\u91cf\u5206\u4e3a \u4e8c\u5143\u53d8\u91cf \uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u7c7b\u522b\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u6709\u4e9b\u4eba\u751a\u81f3\u628a\u5206\u7c7b\u53d8\u91cf\u79f0\u4e3a \" \u5faa\u73af \"\u53d8\u91cf\u3002\u5468\u671f\u53d8\u91cf\u4ee5 \"\u5468\u671f \"\u7684\u5f62\u5f0f\u5b58\u5728\uff0c\u4f8b\u5982\u4e00\u5468\u4e2d\u7684\u5929\u6570\uff1a \u5468\u65e5\u3001\u5468\u4e00\u3001\u5468\u4e8c\u3001\u5468\u4e09\u3001\u5468\u56db\u3001\u5468\u4e94\u548c\u5468\u516d\u3002\u5468\u516d\u8fc7\u540e\uff0c\u53c8\u662f\u5468\u65e5\u3002\u8fd9\u5c31\u662f\u4e00\u4e2a\u5faa\u73af\u3002\u53e6\u4e00\u4e2a\u4f8b\u5b50\u662f\u4e00\u5929\u4e2d\u7684\u5c0f\u65f6\u6570\uff0c\u5982\u679c\u6211\u4eec\u5c06\u5b83\u4eec\u89c6\u4e3a\u7c7b\u522b\u7684\u8bdd\u3002 \u5206\u7c7b\u53d8\u91cf\u6709\u5f88\u591a\u4e0d\u540c\u7684\u5b9a\u4e49\uff0c\u5f88\u591a\u4eba\u4e5f\u8c08\u5230\u8981\u6839\u636e\u5206\u7c7b\u53d8\u91cf\u7684\u7c7b\u578b\u6765\u5904\u7406\u4e0d\u540c\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u4e0d\u8fc7\uff0c\u6211\u8ba4\u4e3a\u6ca1\u6709\u5fc5\u8981\u8fd9\u6837\u505a\u3002\u6240\u6709\u6d89\u53ca\u5206\u7c7b\u53d8\u91cf\u7684\u95ee\u9898\u90fd\u53ef\u4ee5\u7528\u540c\u6837\u7684\u65b9\u6cd5\u5904\u7406\u3002 \u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e00\u4e2a\u6570\u636e\u96c6\uff08\u4e00\u5982\u65e2\u5f80\uff09\u3002\u8981\u4e86\u89e3\u5206\u7c7b\u53d8\u91cf\uff0c\u6700\u597d\u7684\u514d\u8d39\u6570\u636e\u96c6\u4e4b\u4e00\u662f Kaggle \u5206\u7c7b\u7279\u5f81\u7f16\u7801\u6311\u6218\u8d5b\u4e2d\u7684 cat-in-the-dat \u3002\u5171\u6709\u4e24\u4e2a\u6311\u6218\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6311\u6218\u7684\u6570\u636e\uff0c\u56e0\u4e3a\u5b83\u6bd4\u524d\u4e00\u4e2a\u7248\u672c\u6709\u66f4\u591a\u53d8\u91cf\uff0c\u96be\u5ea6\u4e5f\u66f4\u5927\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u3002 \u56fe 1\uff1aCat-in-the-dat-ii challenge\u90e8\u5206\u6570\u636e\u5c55\u793a \u6570\u636e\u96c6\u7531\u5404\u79cd\u5206\u7c7b\u53d8\u91cf\u7ec4\u6210\uff1a \u65e0\u5e8f \u6709\u5e8f \u5faa\u73af \u4e8c\u5143 \u5728\u56fe 1 \u4e2d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6240\u6709\u5b58\u5728\u7684\u53d8\u91cf\u548c\u76ee\u6807\u53d8\u91cf\u7684\u5b50\u96c6\u3002 \u8fd9\u662f\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002 \u76ee\u6807\u53d8\u91cf\u5bf9\u4e8e\u6211\u4eec\u5b66\u4e60\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\u5e76\u4e0d\u5341\u5206\u91cd\u8981\uff0c\u4f46\u6700\u7ec8\u6211\u4eec\u5c06\u5efa\u7acb\u4e00\u4e2a\u7aef\u5230\u7aef\u6a21\u578b\uff0c\u56e0\u6b64\u8ba9\u6211\u4eec\u770b\u770b\u56fe 2 \u4e2d\u7684\u76ee\u6807\u53d8\u91cf\u5206\u5e03\u3002\u6211\u4eec\u770b\u5230\u76ee\u6807\u662f \u504f\u659c \u7684\uff0c\u56e0\u6b64\u5bf9\u4e8e\u8fd9\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u6765\u8bf4\uff0c\u6700\u597d\u7684\u6307\u6807\u662f ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528\u7cbe\u786e\u5ea6\u548c\u53ec\u56de\u7387\uff0c\u4f46 AUC \u7ed3\u5408\u4e86\u8fd9\u4e24\u4e2a\u6307\u6807\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 AUC \u6765\u8bc4\u4f30\u6211\u4eec\u5728\u8be5\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u7684\u6a21\u578b\u3002 \u56fe 2\uff1a\u6807\u7b7e\u8ba1\u6570\u3002X \u8f74\u8868\u793a\u6807\u7b7e\uff0cY \u8f74\u8868\u793a\u6807\u7b7e\u8ba1\u6570 \u603b\u4f53\u800c\u8a00\uff0c\u6709\uff1a 5\u4e2a\u4e8c\u5143\u53d8\u91cf 10\u4e2a\u65e0\u5e8f\u53d8\u91cf 6\u4e2a\u6709\u5e8f\u53d8\u91cf 2\u4e2a\u5faa\u73af\u53d8\u91cf 1\u4e2a\u76ee\u6807\u53d8\u91cf \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u96c6\u4e2d\u7684 ord_2 \u7279\u5f81\u3002\u5b83\u5305\u62ec6\u4e2a\u4e0d\u540c\u7684\u7c7b\u522b\uff1a - \u51b0\u51bb - \u6e29\u6696 - \u5bd2\u51b7 - \u8f83\u70ed - \u70ed - \u975e\u5e38\u70ed \u6211\u4eec\u5fc5\u987b\u77e5\u9053\uff0c\u8ba1\u7b97\u673a\u65e0\u6cd5\u7406\u89e3\u6587\u672c\u6570\u636e\uff0c\u56e0\u6b64\u6211\u4eec\u9700\u8981\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002\u4e00\u4e2a\u7b80\u5355\u7684\u65b9\u6cd5\u662f\u521b\u5efa\u4e00\u4e2a\u5b57\u5178\uff0c\u5c06\u8fd9\u4e9b\u503c\u6620\u5c04\u4e3a\u4ece 0 \u5230 N-1 \u7684\u6570\u5b57\uff0c\u5176\u4e2d N \u662f\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7c7b\u522b\u7684\u603b\u6570\u3002 # \u6620\u5c04\u5b57\u5178 mapping = { \"Freezing\" : 0 , \"Warm\" : 1 , \"Cold\" : 2 , \"Boiling Hot\" : 3 , \"Hot\" : 4 , \"Lava Hot\" : 5 } \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8bfb\u53d6\u6570\u636e\u96c6\uff0c\u5e76\u8f7b\u677e\u5730\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002 import pandas as pd # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u53d6*ord_2*\u5217\uff0c\u5e76\u4f7f\u7528\u6620\u5c04\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57 df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. map ( mapping ) \u6620\u5c04\u524d\u7684\u6570\u503c\u8ba1\u6570\uff1a df .* ord_2 *. value_counts () Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : * ord_2 * , dtype : int64 \u6620\u5c04\u540e\u7684\u6570\u503c\u8ba1\u6570\uff1a 0.0 142726 1.0 124239 2.0 97822 3.0 84790 4.0 67508 5.0 64840 Name : * ord_2 * , dtype : int64 \u8fd9\u79cd\u5206\u7c7b\u53d8\u91cf\u7684\u7f16\u7801\u65b9\u5f0f\u88ab\u79f0\u4e3a\u6807\u7b7e\u7f16\u7801\uff08Label Encoding\uff09\u6211\u4eec\u5c06\u6bcf\u4e2a\u7c7b\u522b\u7f16\u7801\u4e3a\u4e00\u4e2a\u6570\u5b57\u6807\u7b7e\u3002 \u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 LabelEncoder \u8fdb\u884c\u7f16\u7801\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u5c06\u7f3a\u5931\u503c\u586b\u5145\u4e3a\"NONE\" df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. fillna ( \"NONE\" ) # LabelEncoder\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u8f6c\u6362\u6570\u636e df . loc [:, \"*ord_2*\" ] = lbl_enc . fit_transform ( df .* ord_2 *. values ) \u4f60\u4f1a\u770b\u5230\u6211\u4f7f\u7528\u4e86 pandas \u7684 fillna\u3002\u539f\u56e0\u662f scikit-learn \u7684 LabelEncoder \u65e0\u6cd5\u5904\u7406 NaN \u503c\uff0c\u800c ord_2 \u5217\u4e2d\u6709 NaN \u503c\u3002 \u6211\u4eec\u53ef\u4ee5\u5728\u8bb8\u591a\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u4e2d\u76f4\u63a5\u4f7f\u7528\u5b83\uff1a - \u51b3\u7b56\u6811 - \u968f\u673a\u68ee\u6797 - \u63d0\u5347\u6811 - \u6216\u4efb\u4f55\u4e00\u79cd\u63d0\u5347\u6811\u6a21\u578b - XGBoost - GBM - LightGBM \u8fd9\u79cd\u7f16\u7801\u65b9\u5f0f\u4e0d\u80fd\u7528\u4e8e\u7ebf\u6027\u6a21\u578b\u3001\u652f\u6301\u5411\u91cf\u673a\u6216\u795e\u7ecf\u7f51\u7edc\uff0c\u56e0\u4e3a\u5b83\u4eec\u5e0c\u671b\u6570\u636e\u662f\u6807\u51c6\u5316\u7684\u3002 \u5bf9\u4e8e\u8fd9\u4e9b\u7c7b\u578b\u7684\u6a21\u578b\uff0c\u6211\u4eec\u53ef\u4ee5\u5bf9\u6570\u636e\u8fdb\u884c\u4e8c\u503c\u5316\uff08binarize\uff09\u5904\u7406\u3002 Freezing --> 0 --> 0 0 0 Warm --> 1 --> 0 0 1 Cold --> 2 --> 0 1 0 Boiling Hot --> 3 --> 0 1 1 Hot --> 4 --> 1 0 0 Lava Hot --> 5 --> 1 0 1 \u8fd9\u53ea\u662f\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\uff0c\u7136\u540e\u518d\u8f6c\u6362\u4e3a\u4e8c\u503c\u5316\u8868\u793a\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u628a\u4e00\u4e2a\u7279\u5f81\u5206\u6210\u4e86\u4e09\u4e2a\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u7279\u5f81\uff08\u6216\u5217\uff09\u3002\u5982\u679c\u6211\u4eec\u6709\u66f4\u591a\u7684\u7c7b\u522b\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u5206\u6210\u66f4\u591a\u7684\u5217\u3002 \u5982\u679c\u6211\u4eec\u7528\u7a00\u758f\u683c\u5f0f\u5b58\u50a8\u5927\u91cf\u4e8c\u503c\u5316\u53d8\u91cf\uff0c\u5c31\u53ef\u4ee5\u8f7b\u677e\u5730\u5b58\u50a8\u8fd9\u4e9b\u53d8\u91cf\u3002\u7a00\u758f\u683c\u5f0f\u4e0d\u8fc7\u662f\u4e00\u79cd\u5728\u5185\u5b58\u4e2d\u5b58\u50a8\u6570\u636e\u7684\u8868\u793a\u6216\u65b9\u5f0f\uff0c\u5728\u8fd9\u79cd\u683c\u5f0f\u4e2d\uff0c\u4f60\u5e76\u4e0d\u5b58\u50a8\u6240\u6709\u7684\u503c\uff0c\u800c\u53ea\u5b58\u50a8\u91cd\u8981\u7684\u503c\u3002\u5728\u4e0a\u8ff0\u4e8c\u8fdb\u5236\u53d8\u91cf\u7684\u60c5\u51b5\u4e2d\uff0c\u6700\u91cd\u8981\u7684\u5c31\u662f\u6709 1 \u7684\u5730\u65b9\u3002 \u5f88\u96be\u60f3\u8c61\u8fd9\u6837\u7684\u683c\u5f0f\uff0c\u4f46\u4e3e\u4e2a\u4f8b\u5b50\u5c31\u4f1a\u660e\u767d\u3002 \u5047\u8bbe\u4e0a\u9762\u7684\u6570\u636e\u5e27\u4e2d\u53ea\u6709\u4e00\u4e2a\u7279\u5f81\uff1a ord_2 \u3002 Index Feature 0 Warm 1 Hot 2 Lava hot \u76ee\u524d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6570\u636e\u96c6\u4e2d\u7684\u4e09\u4e2a\u6837\u672c\u3002\u8ba9\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u4e8c\u503c\u8868\u793a\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u6709\u4e09\u4e2a\u9879\u76ee\u3002 \u8fd9\u4e09\u4e2a\u9879\u76ee\u5c31\u662f\u4e09\u4e2a\u7279\u5f81\u3002 Index Feature_0 Feature_1 Feature_2 0 0 0 1 1 1 0 0 2 1 0 1 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u7279\u5f81\u5b58\u50a8\u5728\u4e00\u4e2a\u6709 3 \u884c 3 \u5217\uff083x3\uff09\u7684\u77e9\u9635\u4e2d\u3002\u77e9\u9635\u7684\u6bcf\u4e2a\u5143\u7d20\u5360\u7528 8 \u4e2a\u5b57\u8282\u3002\u56e0\u6b64\uff0c\u8fd9\u4e2a\u6570\u7ec4\u7684\u603b\u5185\u5b58\u9700\u6c42\u4e3a 8x3x3 = 72 \u5b57\u8282\u3002 \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e2a\u7b80\u5355\u7684 python \u4ee3\u7801\u6bb5\u6765\u68c0\u67e5\u8fd9\u4e00\u70b9\u3002 import numpy as np example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) print ( example . nbytes ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u6253\u5370\u51fa 72\uff0c\u5c31\u50cf\u6211\u4eec\u4e4b\u524d\u8ba1\u7b97\u7684\u90a3\u6837\u3002\u4f46\u6211\u4eec\u9700\u8981\u5b58\u50a8\u8fd9\u4e2a\u77e9\u9635\u7684\u6240\u6709\u5143\u7d20\u5417\uff1f\u5982\u524d\u6240\u8ff0\uff0c\u6211\u4eec\u53ea\u5bf9 1 \u611f\u5174\u8da3\u30020 \u5e76\u4e0d\u91cd\u8981\uff0c\u56e0\u4e3a\u4efb\u4f55\u4e0e 0 \u76f8\u4e58\u7684\u5143\u7d20\u90fd\u662f 0\uff0c\u800c 0 \u4e0e\u4efb\u4f55\u5143\u7d20\u76f8\u52a0\u6216\u76f8\u51cf\u4e5f\u6ca1\u6709\u4efb\u4f55\u533a\u522b\u3002\u53ea\u7528 1 \u8868\u793a\u77e9\u9635\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u67d0\u79cd\u5b57\u5178\u65b9\u6cd5\uff0c\u5176\u4e2d\u952e\u662f\u884c\u548c\u5217\u7684\u7d22\u5f15\uff0c\u503c\u662f 1\uff1a ( 0 , 2 ) 1 ( 1 , 0 ) 1 ( 2 , 0 ) 1 ( 2 , 2 ) 1 \u8fd9\u6837\u7684\u7b26\u53f7\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u56e0\u4e3a\u5b83\u53ea\u9700\u5b58\u50a8\u56db\u4e2a\u503c\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u3002\u4f7f\u7528\u7684\u603b\u5185\u5b58\u4e3a 8x4 = 32 \u5b57\u8282\u3002\u4efb\u4f55 numpy \u6570\u7ec4\u90fd\u53ef\u4ee5\u901a\u8fc7\u7b80\u5355\u7684 python \u4ee3\u7801\u8f6c\u6362\u4e3a\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) sparse_example = sparse . csr_matrix ( example ) print ( sparse_example . data . nbytes ) \u8fd9\u5c06\u6253\u5370 32\uff0c\u6bd4\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u5c11\u4e86\u8fd9\u4e48\u591a\uff01\u7a00\u758f csr \u77e9\u9635\u7684\u603b\u5927\u5c0f\u662f\u4e09\u4e2a\u503c\u7684\u603b\u548c\u3002 print ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) \u8fd9\u5c06\u6253\u5370\u51fa 64 \u4e2a\u5143\u7d20\uff0c\u4ecd\u7136\u5c11\u4e8e\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u3002\u9057\u61be\u7684\u662f\uff0c\u6211\u4e0d\u4f1a\u8be6\u7ec6\u4ecb\u7ecd\u8fd9\u4e9b\u5143\u7d20\u3002\u4f60\u53ef\u4ee5\u5728 scipy \u6587\u6863\u4e2d\u4e86\u89e3\u66f4\u591a\u3002\u5f53\u6211\u4eec\u62e5\u6709\u66f4\u5927\u7684\u6570\u7ec4\u65f6\uff0c\u6bd4\u5982\u8bf4\u62e5\u6709\u6570\u5343\u4e2a\u6837\u672c\u548c\u6570\u4e07\u4e2a\u7279\u5f81\u7684\u6570\u7ec4\uff0c\u5927\u5c0f\u5dee\u5f02\u5c31\u4f1a\u53d8\u5f97\u975e\u5e38\u5927\u3002\u4f8b\u5982\uff0c\u6211\u4eec\u4f7f\u7528\u57fa\u4e8e\u8ba1\u6570\u7279\u5f81\u7684\u6587\u672c\u6570\u636e\u96c6\u3002 import numpy as np from scipy import sparse n_rows = 10000 n_cols = 100000 # \u751f\u6210\u7b26\u5408\u4f2f\u52aa\u5229\u5206\u5e03\u7684\u968f\u673a\u6570\u7ec4\uff0c\u7ef4\u5ea6\u4e3a[10000, 100000] example = np . random . binomial ( 1 , p = 0.05 , size = ( n_rows , n_cols )) print ( f \"Size of dense array: { example . nbytes } \" ) # \u5c06\u968f\u673a\u77e9\u9635\u8f6c\u6362\u4e3a\u6d17\u6f31\u77e9\u9635 sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u8fd9\u5c06\u6253\u5370\uff1a Size of dense array : 8000000000 Size of sparse array : 399932496 Full size of sparse array : 599938748 \u56e0\u6b64\uff0c\u5bc6\u96c6\u9635\u5217\u9700\u8981 ~8000MB \u6216\u5927\u7ea6 8GB \u5185\u5b58\u3002\u800c\u7a00\u758f\u9635\u5217\u53ea\u5360\u7528 399MB \u5185\u5b58\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u5f53\u6211\u4eec\u7684\u7279\u5f81\u4e2d\u6709\u5927\u91cf\u96f6\u65f6\uff0c\u6211\u4eec\u66f4\u559c\u6b22\u7a00\u758f\u9635\u5217\u800c\u4e0d\u662f\u5bc6\u96c6\u9635\u5217\u7684\u539f\u56e0\u3002 \u8bf7\u6ce8\u610f\uff0c\u7a00\u758f\u77e9\u9635\u6709\u591a\u79cd\u4e0d\u540c\u7684\u8868\u793a\u65b9\u6cd5\u3002\u8fd9\u91cc\u6211\u53ea\u5c55\u793a\u4e86\u5176\u4e2d\u4e00\u79cd\uff08\u53ef\u80fd\u4e5f\u662f\u6700\u5e38\u7528\u7684\uff09\u65b9\u6cd5\u3002\u6df1\u5165\u63a2\u8ba8\u8fd9\u4e9b\u65b9\u6cd5\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\uff0c\u56e0\u6b64\u7559\u7ed9\u8bfb\u8005\u4e00\u4e2a\u7ec3\u4e60\u3002 \u5c3d\u7ba1\u4e8c\u503c\u5316\u7279\u5f81\u7684\u7a00\u758f\u8868\u793a\u6bd4\u5176\u5bc6\u96c6\u8868\u793a\u6240\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u4f46\u5bf9\u4e8e\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\uff0c\u8fd8\u6709\u4e00\u79cd\u8f6c\u6362\u6240\u5360\u7528\u7684\u5185\u5b58\u66f4\u5c11\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \" \u72ec\u70ed\u7f16\u7801 \"\u3002 \u72ec\u70ed\u7f16\u7801\u4e5f\u662f\u4e00\u79cd\u4e8c\u503c\u7f16\u7801\uff0c\u56e0\u4e3a\u53ea\u6709 0 \u548c 1 \u4e24\u4e2a\u503c\u3002\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5b83\u5e76\u4e0d\u662f\u4e8c\u503c\u8868\u793a\u6cd5\u3002\u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0b\u9762\u7684\u4f8b\u5b50\u6765\u7406\u89e3\u5b83\u7684\u8868\u793a\u6cd5\u3002 \u5047\u8bbe\u6211\u4eec\u7528\u4e00\u4e2a\u5411\u91cf\u6765\u8868\u793a ord_2 \u53d8\u91cf\u7684\u6bcf\u4e2a\u7c7b\u522b\u3002\u8fd9\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u4e0e ord_2 \u53d8\u91cf\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u6bcf\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u90fd\u662f 6\uff0c\u5e76\u4e14\u9664\u4e86\u4e00\u4e2a\u4f4d\u7f6e\u5916\uff0c\u5176\u4ed6\u4f4d\u7f6e\u90fd\u662f 0\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u7279\u6b8a\u7684\u5411\u91cf\u8868\u3002 Freezing 0 0 0 0 0 1 Warm 0 0 0 0 1 0 Cold 0 0 0 1 0 0 Boiling Hot 0 0 1 0 0 0 Hot 0 1 0 0 0 0 Lava Hot 1 0 0 0 0 0 \u6211\u4eec\u770b\u5230\u5411\u91cf\u7684\u5927\u5c0f\u662f 1x6\uff0c\u5373\u5411\u91cf\u4e2d\u67096\u4e2a\u5143\u7d20\u3002\u8fd9\u4e2a\u6570\u5b57\u662f\u600e\u4e48\u6765\u7684\u5462\uff1f\u5982\u679c\u4f60\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u5982\u524d\u6240\u8ff0\uff0c\u67096\u4e2a\u7c7b\u522b\u3002\u5728\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u65f6\uff0c\u5411\u91cf\u7684\u5927\u5c0f\u5fc5\u987b\u4e0e\u6211\u4eec\u8981\u67e5\u770b\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u6bcf\u4e2a\u5411\u91cf\u90fd\u6709\u4e00\u4e2a 1\uff0c\u5176\u4f59\u6240\u6709\u503c\u90fd\u662f 0\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u7528\u8fd9\u4e9b\u7279\u5f81\u6765\u4ee3\u66ff\u4e4b\u524d\u7684\u4e8c\u503c\u5316\u7279\u5f81\uff0c\u770b\u770b\u80fd\u8282\u7701\u591a\u5c11\u5185\u5b58\u3002 \u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\u4ee5\u524d\u7684\u6570\u636e\uff0c\u5b83\u770b\u8d77\u6765\u5982\u4e0b\uff1a Index Feature 0 Warm 1 Hot 2 Lava hot \u6bcf\u4e2a\u6837\u672c\u67093\u4e2a\u7279\u5f81\u3002\u4f46\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u72ec\u70ed\u5411\u91cf\u7684\u5927\u5c0f\u4e3a 6\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u67096\u4e2a\u7279\u5f81\uff0c\u800c\u4e0d\u662f3\u4e2a\u3002 Index F_0 F_1 F_2 F_3 F_4 F_5 0 0 0 0 0 1 0 1 0 1 0 0 0 0 2 1 0 1 0 0 0 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 6 \u4e2a\u7279\u5f81\uff0c\u800c\u5728\u8fd9\u4e2a 3x6 \u6570\u7ec4\u4e2d\uff0c\u53ea\u6709 3 \u4e2a1\u3002\u4f7f\u7528 numpy \u8ba1\u7b97\u5927\u5c0f\u4e0e\u4e8c\u503c\u5316\u5927\u5c0f\u8ba1\u7b97\u811a\u672c\u975e\u5e38\u76f8\u4f3c\u3002\u4f60\u9700\u8981\u6539\u53d8\u7684\u53ea\u662f\u6570\u7ec4\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u6bb5\u4ee3\u7801\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 0 , 0 , 1 , 0 ], [ 0 , 1 , 0 , 0 , 0 , 0 ], [ 1 , 0 , 0 , 0 , 0 , 0 ] ] ) print ( f \"Size of dense array: { example . nbytes } \" ) sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u6253\u5370\u5185\u5b58\u5927\u5c0f\u4e3a\uff1a Size of dense array : 144 Size of sparse array : 24 Full size of sparse array : 52 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5bc6\u96c6\u77e9\u9635\u7684\u5927\u5c0f\u8fdc\u8fdc\u5927\u4e8e\u4e8c\u503c\u5316\u77e9\u9635\u7684\u5927\u5c0f\u3002\u4e0d\u8fc7\uff0c\u7a00\u758f\u6570\u7ec4\u7684\u5927\u5c0f\u8981\u66f4\u5c0f\u3002\u8ba9\u6211\u4eec\u7528\u66f4\u5927\u7684\u6570\u7ec4\u6765\u8bd5\u8bd5\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 OneHotEncoder \u5c06\u5305\u542b 1001 \u4e2a\u7c7b\u522b\u7684\u7279\u5f81\u6570\u7ec4\u8f6c\u6362\u4e3a\u5bc6\u96c6\u77e9\u9635\u548c\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from sklearn import preprocessing # \u751f\u6210\u7b26\u5408\u5747\u5300\u5206\u5e03\u7684\u968f\u673a\u6574\u6570\uff0c\u7ef4\u5ea6\u4e3a[1000000, 10000000] example = np . random . randint ( 1000 , size = 1000000 ) # \u72ec\u70ed\u7f16\u7801\uff0c\u975e\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = False ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of dense array: { ohe_example . nbytes } \" ) # \u72ec\u70ed\u7f16\u7801\uff0c\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = True ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of sparse array: { ohe_example . data . nbytes } \" ) full_size = ( ohe_example . data . nbytes + ohe_example . indptr . nbytes + ohe_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u4e0a\u9762\u4ee3\u7801\u6253\u5370\u7684\u8f93\u51fa\uff1a Size of dense array : 8000000000 Size of sparse array : 8000000 Full size of sparse array : 16000004 \u8fd9\u91cc\u7684\u5bc6\u96c6\u9635\u5217\u5927\u5c0f\u7ea6\u4e3a 8GB\uff0c\u7a00\u758f\u9635\u5217\u4e3a 8MB\u3002\u5982\u679c\u53ef\u4ee5\u9009\u62e9\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff1f\u5728\u6211\u770b\u6765\uff0c\u9009\u62e9\u5f88\u7b80\u5355\uff0c\u4e0d\u662f\u5417\uff1f \u8fd9\u4e09\u79cd\u65b9\u6cd5\uff08\u6807\u7b7e\u7f16\u7801\u3001\u7a00\u758f\u77e9\u9635\u3001\u72ec\u70ed\u7f16\u7801\uff09\u662f\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u7684\u6700\u91cd\u8981\u65b9\u6cd5\u3002\u4e0d\u8fc7\uff0c\u4f60\u8fd8\u53ef\u4ee5\u7528\u5f88\u591a\u5176\u4ed6\u4e0d\u540c\u7684\u65b9\u6cd5\u6765\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u3002\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u6570\u503c\u53d8\u91cf\u5c31\u662f\u5176\u4e2d\u7684\u4e00\u4e2a\u4f8b\u5b50\u3002 \u5047\u8bbe\u6211\u4eec\u56de\u5230\u4e4b\u524d\u7684\u5206\u7c7b\u7279\u5f81\u6570\u636e\uff08\u539f\u59cb\u6570\u636e\u4e2d\u7684 cat-in-the-dat-ii\uff09\u3002\u5728\u6570\u636e\u4e2d\uff0c ord_2 \u7684\u503c\u4e3a\u201c\u70ed\u201c\u7684 id \u6709\u591a\u5c11\uff1f \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u8ba1\u7b97\u6570\u636e\u7684\u5f62\u72b6\uff08shape\uff09\u8f7b\u677e\u8ba1\u7b97\u51fa\u8fd9\u4e2a\u503c\uff0c\u5176\u4e2d ord_2 \u5217\u7684\u503c\u4e3a Boiling Hot \u3002 In [ X ]: df [ df . ord_2 == \"Boiling Hot\" ] . shape Out [ X ]: ( 84790 , 25 ) \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 84790 \u6761\u8bb0\u5f55\u5177\u6709\u6b64\u503c\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 pandas \u4e2d\u7684 groupby \u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u8be5\u503c\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . count () Out [ X ]: ord_2 Boiling Hot 84790 Cold 97822 Freezing 142726 Hot 67508 Lava Hot 64840 Warm 124239 Name : id , dtype : int64 \u5982\u679c\u6211\u4eec\u53ea\u662f\u5c06 ord_2 \u5217\u66ff\u6362\u4e3a\u5176\u8ba1\u6570\u503c\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u5c06\u5176\u8f6c\u6362\u4e3a\u4e00\u79cd\u6570\u503c\u7279\u5f81\u4e86\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 transform \u51fd\u6570\u548c groupby \u6765\u521b\u5efa\u65b0\u5217\u6216\u66ff\u6362\u8fd9\u4e00\u5217\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . transform ( \"count\" ) Out [ X ]: 0 67508.0 1 124239.0 2 142726.0 3 64840.0 4 97822.0 ... 599995 142726.0 599996 84790.0 599997 142726.0 599998 124239.0 599999 84790.0 Name : id , Length : 600000 , dtype : float64 \u4f60\u53ef\u4ee5\u6dfb\u52a0\u6240\u6709\u7279\u5f81\u7684\u8ba1\u6570\uff0c\u4e5f\u53ef\u4ee5\u66ff\u6362\u5b83\u4eec\uff0c\u6216\u8005\u6839\u636e\u591a\u4e2a\u5217\u53ca\u5176\u8ba1\u6570\u8fdb\u884c\u5206\u7ec4\u3002\u4f8b\u5982\uff0c\u4ee5\u4e0b\u4ee3\u7801\u901a\u8fc7\u5bf9 ord_1 \u548c ord_2 \u5217\u5206\u7ec4\u8fdb\u884c\u8ba1\u6570\u3002 In [ X ]: df . groupby ( ... : [ ... : \"ord_1\" , ... : \"ord_2\" ... : ] ... : )[ \"id\" ] . count () . reset_index ( name = \"count\" ) Out [ X ]: ord_1 ord_2 count 0 Contributor Boiling Hot 15634 1 Contributor Cold 17734 2 Contributor Freezing 26082 3 Contributor Hot 12428 4 Contributor Lava Hot 11919 5 Contributor Warm 22774 6 Expert Boiling Hot 19477 7 Expert Cold 22956 8 Expert Freezing 33249 9 Expert Hot 15792 10 Expert Lava Hot 15078 11 Expert Warm 28900 12 Grandmaster Boiling Hot 13623 13 Grandmaster Cold 15464 14 Grandmaster Freezing 22818 15 Grandmaster Hot 10805 16 Grandmaster Lava Hot 10363 17 Grandmaster Warm 19899 18 Master Boiling Hot 10800 ... \u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u884c\uff0c\u4ee5\u4fbf\u5728\u4e00\u9875\u4e2d\u5bb9\u7eb3\u8fd9\u4e9b\u884c\u3002\u8fd9\u662f\u53e6\u4e00\u79cd\u53ef\u4ee5\u4f5c\u4e3a\u529f\u80fd\u6dfb\u52a0\u7684\u8ba1\u6570\u3002\u60a8\u73b0\u5728\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\uff0c\u6211\u4f7f\u7528 id \u5217\u8fdb\u884c\u8ba1\u6570\u3002\u4e0d\u8fc7\uff0c\u4f60\u4e5f\u53ef\u4ee5\u901a\u8fc7\u5bf9\u5217\u7684\u7ec4\u5408\u8fdb\u884c\u5206\u7ec4\uff0c\u5bf9\u5176\u4ed6\u5217\u8fdb\u884c\u8ba1\u6570\u3002 \u8fd8\u6709\u4e00\u4e2a\u5c0f\u7a8d\u95e8\uff0c\u5c31\u662f\u4ece\u8fd9\u4e9b\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\u3002\u4f60\u53ef\u4ee5\u4ece\u73b0\u6709\u7684\u7279\u5f81\u4e2d\u521b\u5efa\u65b0\u7684\u5206\u7c7b\u7279\u5f81\uff0c\u800c\u4e14\u53ef\u4ee5\u6beb\u4e0d\u8d39\u529b\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot 1 Grandmaster_Warm 2 nan_Freezing 3 Novice_Lava Hot 4 Grandmaster_Cold ... 599999 Contributor_Boiling Hot Name : new_feature , Length : 600000 , dtype : object \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u7528\u4e0b\u5212\u7ebf\u5c06 ord_1 \u548c ord_2 \u5408\u5e76\uff0c\u7136\u540e\u5c06\u8fd9\u4e9b\u5217\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u7c7b\u578b\u3002\u8bf7\u6ce8\u610f\uff0cNaN \u4e5f\u4f1a\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u3002\u4e0d\u8fc7\u6ca1\u5173\u7cfb\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06 NaN \u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u6709\u4e86\u4e00\u4e2a\u7531\u8fd9\u4e24\u4e2a\u7279\u5f81\u7ec4\u5408\u800c\u6210\u7684\u65b0\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u5c06\u4e09\u5217\u4ee5\u4e0a\u6216\u56db\u5217\u751a\u81f3\u66f4\u591a\u5217\u7ec4\u5408\u5728\u4e00\u8d77\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : + \"_\" ... : + df . ord_3 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot_c 1 Grandmaster_Warm_e 2 nan_Freezing_n 3 Novice_Lava Hot_a 4 Grandmaster_Cold_h ... 599999 Contributor_Boiling Hot_b Name : new_feature , Length : 600000 , dtype : object \u90a3\u4e48\uff0c\u6211\u4eec\u5e94\u8be5\u628a\u54ea\u4e9b\u7c7b\u522b\u7ed3\u5408\u8d77\u6765\u5462\uff1f\u8fd9\u5e76\u6ca1\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u7b54\u6848\u3002\u8fd9\u53d6\u51b3\u4e8e\u60a8\u7684\u6570\u636e\u548c\u7279\u5f81\u7c7b\u578b\u3002\u4e00\u4e9b\u9886\u57df\u77e5\u8bc6\u5bf9\u4e8e\u521b\u5efa\u8fd9\u6837\u7684\u7279\u5f81\u53ef\u80fd\u5f88\u6709\u7528\u3002\u4f46\u662f\uff0c\u5982\u679c\u4f60\u4e0d\u62c5\u5fc3\u5185\u5b58\u548c CPU \u7684\u4f7f\u7528\uff0c\u4f60\u53ef\u4ee5\u91c7\u7528\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\uff0c\u5373\u521b\u5efa\u8bb8\u591a\u8fd9\u6837\u7684\u7ec4\u5408\uff0c\u7136\u540e\u4f7f\u7528\u4e00\u4e2a\u6a21\u578b\u6765\u51b3\u5b9a\u54ea\u4e9b\u7279\u5f81\u662f\u6709\u7528\u7684\uff0c\u5e76\u4fdd\u7559\u5b83\u4eec\u3002\u6211\u4eec\u5c06\u5728\u672c\u4e66\u7a0d\u540e\u90e8\u5206\u4ecb\u7ecd\u8fd9\u79cd\u65b9\u6cd5\u3002 \u65e0\u8bba\u4f55\u65f6\u83b7\u5f97\u5206\u7c7b\u53d8\u91cf\uff0c\u90fd\u8981\u9075\u5faa\u4ee5\u4e0b\u7b80\u5355\u6b65\u9aa4\uff1a - \u586b\u5145 NaN \u503c\uff08\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff01\uff09\u3002 - \u4f7f\u7528 scikit-learn \u7684 LabelEncoder \u6216\u6620\u5c04\u5b57\u5178\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\uff0c\u5c06\u5b83\u4eec\u8f6c\u6362\u4e3a\u6574\u6570\u3002\u5982\u679c\u6ca1\u6709\u586b\u5145 NaN \u503c\uff0c\u53ef\u80fd\u9700\u8981\u5728\u8fd9\u4e00\u6b65\u4e2d\u8fdb\u884c\u5904\u7406 - \u521b\u5efa\u72ec\u70ed\u7f16\u7801\u3002\u662f\u7684\uff0c\u4f60\u53ef\u4ee5\u8df3\u8fc7\u4e8c\u503c\u5316\uff01 - \u5efa\u6a21\uff01\u6211\u6307\u7684\u662f\u673a\u5668\u5b66\u4e60\u3002 \u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u5904\u7406 NaN \u6570\u636e\u975e\u5e38\u91cd\u8981\uff0c\u5426\u5219\u60a8\u53ef\u80fd\u4f1a\u4ece scikit-learn \u7684 LabelEncoder \u4e2d\u5f97\u5230\u81ed\u540d\u662d\u8457\u7684\u9519\u8bef\u4fe1\u606f\uff1a ValueError: y \u5305\u542b\u4ee5\u524d\u672a\u89c1\u8fc7\u7684\u6807\u7b7e\uff1a [Nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan) \u8fd9\u4ec5\u4ec5\u610f\u5473\u7740\uff0c\u5728\u8f6c\u6362\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6570\u636e\u4e2d\u51fa\u73b0\u4e86 NaN \u503c\u3002\u8fd9\u662f\u56e0\u4e3a\u4f60\u5728\u8bad\u7ec3\u65f6\u5fd8\u8bb0\u4e86\u5904\u7406\u5b83\u4eec\u3002 \u5904\u7406 NaN \u503c \u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u4e22\u5f03\u5b83\u4eec\u3002\u867d\u7136\u7b80\u5355\uff0c\u4f46\u5e76\u4e0d\u7406\u60f3\u3002NaN \u503c\u4e2d\u53ef\u80fd\u5305\u542b\u5f88\u591a\u4fe1\u606f\uff0c\u5982\u679c\u53ea\u662f\u4e22\u5f03\u8fd9\u4e9b\u503c\uff0c\u5c31\u4f1a\u4e22\u5931\u8fd9\u4e9b\u4fe1\u606f\u3002\u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u5927\u90e8\u5206\u6570\u636e\u90fd\u662f NaN \u503c\uff0c\u56e0\u6b64\u4e0d\u80fd\u4e22\u5f03 NaN \u503c\u7684\u884c/\u6837\u672c\u3002\u5904\u7406 NaN \u503c\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u5c06\u5176\u4f5c\u4e3a\u4e00\u4e2a\u5168\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u662f\u5904\u7406 NaN \u503c\u6700\u5e38\u7528\u7684\u65b9\u6cd5\u3002\u5982\u679c\u4f7f\u7528 pandas\uff0c\u8fd8\u53ef\u4ee5\u901a\u8fc7\u975e\u5e38\u7b80\u5355\u7684\u65b9\u5f0f\u5b9e\u73b0\u3002 \u8bf7\u770b\u6211\u4eec\u4e4b\u524d\u67e5\u770b\u8fc7\u7684\u6570\u636e\u7684 ord_2 \u5217\u3002 In [ X ]: df . ord_2 . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : ord_2 , dtype : int64 \u586b\u5165 NaN \u503c\u540e\uff0c\u5c31\u53d8\u6210\u4e86 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u54c7\uff01\u8fd9\u4e00\u5217\u4e2d\u6709 18075 \u4e2a NaN \u503c\uff0c\u800c\u6211\u4eec\u4e4b\u524d\u751a\u81f3\u90fd\u6ca1\u6709\u8003\u8651\u4f7f\u7528\u5b83\u4eec\u3002\u589e\u52a0\u4e86\u8fd9\u4e2a\u65b0\u7c7b\u522b\u540e\uff0c\u7c7b\u522b\u603b\u6570\u4ece 6 \u4e2a\u589e\u52a0\u5230\u4e86 7 \u4e2a\u3002\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u73b0\u5728\u6211\u4eec\u5728\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u4e5f\u4f1a\u8003\u8651 NaN\u3002\u76f8\u5173\u4fe1\u606f\u8d8a\u591a\uff0c\u6a21\u578b\u5c31\u8d8a\u597d\u3002 \u5047\u8bbe ord_2 \u6ca1\u6709\u4efb\u4f55 NaN \u503c\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u4e00\u5217\u4e2d\u7684\u6240\u6709\u7c7b\u522b\u90fd\u6709\u663e\u8457\u7684\u8ba1\u6570\u3002\u5176\u4e2d\u6ca1\u6709 \"\u7f55\u89c1 \"\u7c7b\u522b\uff0c\u5373\u53ea\u5728\u6837\u672c\u603b\u6570\u4e2d\u5360\u5f88\u5c0f\u6bd4\u4f8b\u7684\u7c7b\u522b\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5047\u8bbe\u60a8\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u4e86\u4f7f\u7528\u8fd9\u4e00\u5217\u7684\u6a21\u578b\uff0c\u5f53\u6a21\u578b\u6216\u9879\u76ee\u4e0a\u7ebf\u65f6\uff0c\u60a8\u5728 ord_2 \u5217\u4e2d\u5f97\u5230\u4e86\u4e00\u4e2a\u5728\u8bad\u7ec3\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6a21\u578b\u7ba1\u9053\u4f1a\u629b\u51fa\u4e00\u4e2a\u9519\u8bef\uff0c\u60a8\u5bf9\u6b64\u65e0\u80fd\u4e3a\u529b\u3002\u5982\u679c\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\uff0c\u90a3\u4e48\u53ef\u80fd\u662f\u751f\u4ea7\u4e2d\u7684\u7ba1\u9053\u51fa\u4e86\u95ee\u9898\u3002\u5982\u679c\u8fd9\u662f\u9884\u6599\u4e4b\u4e2d\u7684\uff0c\u90a3\u4e48\u60a8\u5c31\u5fc5\u987b\u4fee\u6539\u60a8\u7684\u6a21\u578b\u7ba1\u9053\uff0c\u5e76\u5728\u8fd9\u516d\u4e2a\u7c7b\u522b\u4e2d\u52a0\u5165\u4e00\u4e2a\u65b0\u7c7b\u522b\u3002 \u8fd9\u4e2a\u65b0\u7c7b\u522b\u88ab\u79f0\u4e3a \"\u7f55\u89c1 \"\u7c7b\u522b\u3002\u7f55\u89c1\u7c7b\u522b\u662f\u4e00\u79cd\u4e0d\u5e38\u89c1\u7684\u7c7b\u522b\uff0c\u53ef\u4ee5\u5305\u62ec\u8bb8\u591a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8fd1\u90bb\u6a21\u578b\u6765 \"\u9884\u6d4b \"\u672a\u77e5\u7c7b\u522b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u4e86\u8fd9\u4e2a\u7c7b\u522b\uff0c\u5b83\u5c31\u4f1a\u6210\u4e3a\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u4e00\u4e2a\u7c7b\u522b\u3002 \u56fe 3\uff1a\u5177\u6709\u4e0d\u540c\u7279\u5f81\u4e14\u65e0\u6807\u7b7e\u7684\u6570\u636e\u96c6\u793a\u610f\u56fe\uff0c\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u53ef\u80fd\u4f1a\u5728\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u51fa\u73b0\u65b0\u503c \u5f53\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u56fe 3 \u6240\u793a\u7684\u6570\u636e\u96c6\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u7b80\u5355\u7684\u6a21\u578b\uff0c\u5bf9\u9664 \"f3 \"\u4e4b\u5916\u7684\u6240\u6709\u7279\u5f81\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u6837\uff0c\u4f60\u5c06\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u4e0d\u77e5\u9053\u6216\u8bad\u7ec3\u4e2d\u6ca1\u6709 \"f3 \"\u65f6\u9884\u6d4b\u5b83\u3002\u6211\u4e0d\u6562\u8bf4\u8fd9\u6837\u7684\u6a21\u578b\u662f\u5426\u80fd\u5e26\u6765\u51fa\u8272\u7684\u6027\u80fd\uff0c\u4f46\u4e5f\u8bb8\u80fd\u5904\u7406\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u7684\u7f3a\u5931\u503c\uff0c\u5c31\u50cf\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u5176\u4ed6\u4e8b\u60c5\u4e00\u6837\uff0c\u4e0d\u5c1d\u8bd5\u4e00\u4e0b\u662f\u8bf4\u4e0d\u51c6\u7684\u3002 \u5982\u679c\u4f60\u6709\u4e00\u4e2a\u56fa\u5b9a\u7684\u6d4b\u8bd5\u96c6\uff0c\u4f60\u53ef\u4ee5\u5c06\u6d4b\u8bd5\u6570\u636e\u6dfb\u52a0\u5230\u8bad\u7ec3\u4e2d\uff0c\u4ee5\u4e86\u89e3\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u7c7b\u522b\u3002\u8fd9\u4e0e\u534a\u76d1\u7763\u5b66\u4e60\u975e\u5e38\u76f8\u4f3c\uff0c\u5373\u4f7f\u7528\u65e0\u6cd5\u7528\u4e8e\u8bad\u7ec3\u7684\u6570\u636e\u6765\u6539\u8fdb\u6a21\u578b\u3002\u8fd9\u4e5f\u4f1a\u7167\u987e\u5230\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u51fa\u73b0\u6b21\u6570\u6781\u5c11\u4f46\u5728\u6d4b\u8bd5\u6570\u636e\u4e2d\u5927\u91cf\u5b58\u5728\u7684\u7a00\u6709\u503c\u3002\u4f60\u7684\u6a21\u578b\u5c06\u66f4\u52a0\u7a33\u5065\u3002 \u5f88\u591a\u4eba\u8ba4\u4e3a\u8fd9\u79cd\u60f3\u6cd5\u4f1a\u8fc7\u5ea6\u62df\u5408\u3002\u53ef\u80fd\u8fc7\u62df\u5408\uff0c\u4e5f\u53ef\u80fd\u4e0d\u8fc7\u62df\u5408\u3002\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6cd5\u3002\u5982\u679c\u4f60\u5728\u8bbe\u8ba1\u4ea4\u53c9\u9a8c\u8bc1\u65f6\uff0c\u80fd\u591f\u5728\u6d4b\u8bd5\u6570\u636e\u4e0a\u8fd0\u884c\u6a21\u578b\u65f6\u590d\u5236\u9884\u6d4b\u8fc7\u7a0b\uff0c\u90a3\u4e48\u5b83\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u8fd9\u610f\u5473\u7740\u7b2c\u4e00\u6b65\u5e94\u8be5\u662f\u5206\u79bb\u6298\u53e0\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u4f60\u5e94\u8be5\u5e94\u7528\u4e0e\u6d4b\u8bd5\u6570\u636e\u76f8\u540c\u7684\u9884\u5904\u7406\u3002\u5047\u8bbe\u60a8\u60f3\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u6d4b\u8bd5\u6570\u636e\uff0c\u90a3\u4e48\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u60a8\u5fc5\u987b\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\uff0c\u5e76\u786e\u4fdd\u9a8c\u8bc1\u6570\u636e\u96c6\u590d\u5236\u4e86\u6d4b\u8bd5\u96c6\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u60a8\u5fc5\u987b\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u8bbe\u8ba1\u9a8c\u8bc1\u96c6\uff0c\u4f7f\u5176\u5305\u542b\u8bad\u7ec3\u96c6\u4e2d \"\u672a\u89c1 \"\u7684\u7c7b\u522b\u3002 \u56fe 4\uff1a\u5bf9\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u8fdb\u884c\u7b80\u5355\u5408\u5e76\uff0c\u4ee5\u4e86\u89e3\u6d4b\u8bd5\u96c6\u4e2d\u5b58\u5728\u4f46\u8bad\u7ec3\u96c6\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u6216\u8bad\u7ec3\u96c6\u4e2d\u7f55\u89c1\u7684\u7c7b\u522b \u53ea\u8981\u770b\u4e00\u4e0b\u56fe 4 \u548c\u4e0b\u9762\u7684\u4ee3\u7801\uff0c\u5c31\u80fd\u5f88\u5bb9\u6613\u7406\u89e3\u5176\u5de5\u4f5c\u539f\u7406\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( \"../input/cat_train.csv\" ) # \u8bfb\u53d6\u6d4b\u8bd5\u96c6 test = pd . read_csv ( \"../input/cat_test.csv\" ) # \u5c06\u6d4b\u8bd5\u96c6\"target\"\u5217\u5168\u90e8\u7f6e\u4e3a-1 test . loc [:, \"target\" ] = - 1 # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u62fc\u63a5 data = pd . concat ([ train , test ]) . reset_index ( drop = True ) # \u5c06\u9664\"id\"\u548c\"target\"\u5217\u7684\u5176\u4ed6\u7279\u5f81\u5217\u540d\u53d6\u51fa features = [ x for x in train . columns if x not in [ \"id\" , \"target\" ]] # \u904d\u5386\u7279\u5f81 for feat in features : # \u6807\u7b7e\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\",\u5e76\u5c06\u8be5\u5217\u683c\u5f0f\u53d8\u4e3astr temp_col = data [ feat ] . fillna ( \"NONE\" ) . astype ( str ) . values # \u8f6c\u6362\u6570\u503c data . loc [:, feat ] = lbl_enc . fit_transform ( temp_col ) # \u6839\u636e\"target\"\u5217\u5c06\u8bad\u7ec3\u96c6\u4e0e\u6d4b\u8bd5\u96c6\u5206\u5f00 train = data [ data . target != - 1 ] . reset_index ( drop = True ) test = data [ data . target == - 1 ] . reset_index ( drop = True ) \u5f53\u60a8\u9047\u5230\u5df2\u7ecf\u6709\u6d4b\u8bd5\u6570\u636e\u96c6\u7684\u95ee\u9898\u65f6\uff0c\u8fd9\u4e2a\u6280\u5de7\u5c31\u4f1a\u8d77\u4f5c\u7528\u3002\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd9\u4e00\u62db\u5728\u5b9e\u65f6\u73af\u5883\u4e2d\u4e0d\u8d77\u4f5c\u7528\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u60a8\u6240\u5728\u7684\u516c\u53f8\u63d0\u4f9b\u5b9e\u65f6\u7ade\u4ef7\u89e3\u51b3\u65b9\u6848\uff08RTB\uff09\u3002RTB \u7cfb\u7edf\u4f1a\u5bf9\u5728\u7ebf\u770b\u5230\u7684\u6bcf\u4e2a\u7528\u6237\u8fdb\u884c\u7ade\u4ef7\uff0c\u4ee5\u8d2d\u4e70\u5e7f\u544a\u7a7a\u95f4\u3002\u8fd9\u79cd\u6a21\u5f0f\u53ef\u4f7f\u7528\u7684\u529f\u80fd\u53ef\u80fd\u5305\u62ec\u7f51\u7ad9\u4e2d\u6d4f\u89c8\u7684\u9875\u9762\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u7279\u5f81\u662f\u7528\u6237\u8bbf\u95ee\u7684\u6700\u540e\u4e94\u4e2a\u7c7b\u522b/\u9875\u9762\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5982\u679c\u7f51\u7ad9\u5f15\u5165\u4e86\u65b0\u7684\u7c7b\u522b\uff0c\u6211\u4eec\u5c06\u65e0\u6cd5\u518d\u51c6\u786e\u9884\u6d4b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u5c31\u4f1a\u5931\u6548\u3002\u8fd9\u79cd\u60c5\u51b5\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528 \"\u672a\u77e5 \"\u7c7b\u522b\u6765\u907f\u514d \u3002 \u5728\u6211\u4eec\u7684 cat-in-the-dat \u6570\u636e\u96c6\u4e2d\uff0c ord_2 \u5217\u4e2d\u5df2\u7ecf\u6709\u4e86\u672a\u77e5\u7c7b\u522b\u3002 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u6211\u4eec\u53ef\u4ee5\u5c06 \"NONE \"\u89c6\u4e3a\u672a\u77e5\u3002\u56e0\u6b64\uff0c\u5982\u679c\u5728\u5b9e\u65f6\u6d4b\u8bd5\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u83b7\u5f97\u4e86\u4ee5\u524d\u4ece\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\uff0c\u6211\u4eec\u5c31\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a \"NONE\"\u3002 \u8fd9\u4e0e\u81ea\u7136\u8bed\u8a00\u5904\u7406\u95ee\u9898\u975e\u5e38\u76f8\u4f3c\u3002\u6211\u4eec\u603b\u662f\u57fa\u4e8e\u56fa\u5b9a\u7684\u8bcd\u6c47\u5efa\u7acb\u6a21\u578b\u3002\u589e\u52a0\u8bcd\u6c47\u91cf\u5c31\u4f1a\u589e\u52a0\u6a21\u578b\u7684\u5927\u5c0f\u3002\u50cf BERT \u8fd9\u6837\u7684\u8f6c\u6362\u5668\u6a21\u578b\u662f\u5728 ~30000 \u4e2a\u5355\u8bcd\uff08\u82f1\u8bed\uff09\u7684\u57fa\u7840\u4e0a\u8bad\u7ec3\u7684\u3002\u56e0\u6b64\uff0c\u5f53\u6709\u65b0\u8bcd\u8f93\u5165\u65f6\uff0c\u6211\u4eec\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a UNK\uff08\u672a\u77e5\uff09\u3002 \u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u5047\u8bbe\u6d4b\u8bd5\u6570\u636e\u4e0e\u8bad\u7ec3\u6570\u636e\u5177\u6709\u76f8\u540c\u7684\u7c7b\u522b\uff0c\u4e5f\u53ef\u4ee5\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u5f15\u5165\u7f55\u89c1\u6216\u672a\u77e5\u7c7b\u522b\uff0c\u4ee5\u5904\u7406\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u65b0\u7c7b\u522b\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u586b\u5165 NaN \u503c\u540e ord_4 \u5217\u7684\u503c\u8ba1\u6570\uff1a In [ X ]: df . ord_4 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 . . . K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 G 3404 V 3107 J 1950 L 1657 Name : ord_4 , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u6709\u4e9b\u6570\u503c\u53ea\u51fa\u73b0\u4e86\u51e0\u5343\u6b21\uff0c\u6709\u4e9b\u5219\u51fa\u73b0\u4e86\u8fd1 40000 \u6b21\u3002NaN \u4e5f\u7ecf\u5e38\u51fa\u73b0\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u503c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5b9a\u4e49\u5c06\u4e00\u4e2a\u503c\u79f0\u4e3a \" \u7f55\u89c1\uff08rare\uff09 \"\u7684\u6807\u51c6\u4e86\u3002\u6bd4\u65b9\u8bf4\uff0c\u5728\u8fd9\u4e00\u5217\u4e2d\uff0c\u7a00\u6709\u503c\u7684\u8981\u6c42\u662f\u8ba1\u6570\u5c0f\u4e8e 2000\u3002\u8fd9\u6837\u770b\u6765\uff0cJ \u548c L \u5c31\u53ef\u4ee5\u88ab\u6807\u8bb0\u4e3a\u7a00\u6709\u503c\u4e86\u3002\u4f7f\u7528 pandas\uff0c\u6839\u636e\u8ba1\u6570\u9608\u503c\u66ff\u6362\u7c7b\u522b\u975e\u5e38\u7b80\u5355\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 In [ X ]: df . ord_4 = df . ord_4 . fillna ( \"NONE\" ) In [ X ]: df . loc [ ... : df [ \"ord_4\" ] . value_counts ()[ df [ \"ord_4\" ]] . values < 2000 , ... : \"ord_4\" ... : ] = \"RARE\" In [ X ]: df . ord_4 . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 M 32504 . . . B 25212 E 21871 K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 RARE 3607 G 3404 V 3107 Name : ord_4 , dtype : int64 \u6211\u4eec\u8ba4\u4e3a\uff0c\u53ea\u8981\u67d0\u4e2a\u7c7b\u522b\u7684\u503c\u5c0f\u4e8e 2000\uff0c\u5c31\u5c06\u5176\u66ff\u6362\u4e3a\u7f55\u89c1\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u5728\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6240\u6709\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"RARE\"\uff0c\u800c\u6240\u6709\u7f3a\u5931\u503c\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"NONE\"\u3002 \u8fd9\u79cd\u65b9\u6cd5\u8fd8\u80fd\u786e\u4fdd\u5373\u4f7f\u6709\u65b0\u7684\u7c7b\u522b\uff0c\u6a21\u578b\u4e5f\u80fd\u5728\u5b9e\u9645\u73af\u5883\u4e2d\u6b63\u5e38\u5de5\u4f5c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u5177\u5907\u4e86\u5904\u7406\u4efb\u4f55\u5e26\u6709\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u6240\u9700\u7684\u4e00\u5207\u6761\u4ef6\u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5efa\u7acb\u7b2c\u4e00\u4e2a\u6a21\u578b\uff0c\u5e76\u9010\u6b65\u63d0\u9ad8\u5176\u6027\u80fd\u3002 \u5728\u6784\u5efa\u4efb\u4f55\u7c7b\u578b\u7684\u6a21\u578b\u4e4b\u524d\uff0c\u4ea4\u53c9\u68c0\u9a8c\u81f3\u5173\u91cd\u8981\u3002\u6211\u4eec\u5df2\u7ecf\u770b\u5230\u4e86\u6807\u7b7e/\u76ee\u6807\u5206\u5e03\uff0c\u77e5\u9053\u8fd9\u662f\u4e00\u4e2a\u76ee\u6807\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 StratifiedKFold \u6765\u5206\u5272\u6570\u636e\u3002 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u6dfb\u52a0\"kfold\"\u5217\uff0c\u5e76\u7f6e\u4e3a-1 df [ \"kfold\" ] = - 1 # \u6253\u4e71\u6570\u636e\u987a\u5e8f\uff0c\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u5c06\u76ee\u6807\u5217\u53d6\u51fa y = df . target . values # \u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): # \u533a\u5206\u6298\u53e0 df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u6587\u4ef6 df . to_csv ( \"../input/cat_train_folds.csv\" , index = False ) \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u68c0\u67e5\u65b0\u7684\u6298\u53e0 csv\uff0c\u67e5\u770b\u6bcf\u4e2a\u6298\u53e0\u7684\u6837\u672c\u6570\uff1a In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) In [ X ]: df . kfold . value_counts () Out [ X ]: 4 120000 3 120000 2 120000 1 120000 0 120000 Name : kfold , dtype : int64 \u6240\u6709\u6298\u53e0\u90fd\u6709 120000 \u4e2a\u6837\u672c\u3002\u8fd9\u662f\u610f\u6599\u4e4b\u4e2d\u7684\uff0c\u56e0\u4e3a\u8bad\u7ec3\u6570\u636e\u6709 600000 \u4e2a\u6837\u672c\uff0c\u800c\u6211\u4eec\u505a\u4e865\u6b21\u6298\u53e0\u3002\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u4e00\u5207\u987a\u5229\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u68c0\u67e5\u6bcf\u4e2a\u6298\u53e0\u7684\u76ee\u6807\u5206\u5e03\u3002 In [ X ]: df [ df . kfold == 0 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 1 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 2 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 3 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 4 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u76ee\u6807\u7684\u5206\u5e03\u90fd\u662f\u4e00\u6837\u7684\u3002\u8fd9\u6b63\u662f\u6211\u4eec\u6240\u9700\u8981\u7684\u3002\u5b83\u4e5f\u53ef\u4ee5\u662f\u76f8\u4f3c\u7684\uff0c\u5e76\u4e0d\u4e00\u5b9a\u8981\u4e00\u76f4\u76f8\u540c\u3002\u73b0\u5728\uff0c\u5f53\u6211\u4eec\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u6807\u7b7e\u5206\u5e03\u90fd\u5c06\u76f8\u540c\u3002 \u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u7684\u6700\u7b80\u5355\u7684\u6a21\u578b\u4e4b\u4e00\u662f\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u5e76\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): # \u8bfb\u53d6\u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) # \u53d6\u9664\"id\", \"target\", \"kfold\"\u5916\u7684\u5176\u4ed6\u7279\u5f81\u5217 features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] # \u904d\u5386\u7279\u5f81\u5217\u8868 for col in features : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u9a8c\u8bc1\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u6d4b\u8bd5\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u903b\u8f91\u56de\u5f52 model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . target . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( auc ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6298\u53e00 run ( 0 ) \u90a3\u4e48\uff0c\u53d1\u751f\u4e86\u4ec0\u4e48\u5462\uff1f \u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u51fd\u6570\uff0c\u5c06\u6570\u636e\u5206\u4e3a\u8bad\u7ec3\u548c\u9a8c\u8bc1\u4e24\u90e8\u5206\uff0c\u7ed9\u5b9a\u6298\u53e0\u6570\uff0c\u5904\u7406 NaN \u503c\uff0c\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u5355\u6b21\u7f16\u7801\uff0c\u5e76\u8bad\u7ec3\u4e00\u4e2a\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\u3002 \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u90e8\u5206\u4ee3\u7801\u65f6\uff0c\u4f1a\u4ea7\u751f\u5982\u4e0b\u8f93\u51fa\uff1a \u276f python ohe_logres . py / home / abhishek / miniconda3 / envs / ml / lib / python3 .7 / site - packages / sklearn / linear_model / _logistic . py : 939 : ConvergenceWarning : lbfgs failed to converge ( status = 1 ): STOP : TOTAL NO . of ITERATIONS REACHED LIMIT . Increase the number of iterations ( max_iter ) or scale the data as shown in : https : // scikit - learn . org / stable / modules / preprocessing . html . Please also refer to the documentation for alternative solver options : https : // scikit - learn . org / stable / modules / linear_model . html #logistic- regression extra_warning_msg = _LOGISTIC_SOLVER_CONVERGENCE_MSG ) 0.7847865042255127 \u6709\u4e00\u4e9b\u8b66\u544a\u3002\u903b\u8f91\u56de\u5f52\u4f3c\u4e4e\u6ca1\u6709\u6536\u655b\u5230\u6700\u5927\u8fed\u4ee3\u6b21\u6570\u3002\u6211\u4eec\u6ca1\u6709\u8c03\u6574\u53c2\u6570\uff0c\u6240\u4ee5\u6ca1\u6709\u95ee\u9898\u3002\u6211\u4eec\u770b\u5230 AUC \u4e3a 0.785\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u5bf9\u4ee3\u7801\u8fdb\u884c\u7b80\u5355\u4fee\u6539\uff0c\u8fd0\u884c\u6240\u6709\u6298\u53e0\u3002 .... model = linear_model . LogisticRegression () model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u5faa\u73af\u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5e76\u6ca1\u6709\u505a\u5f88\u5927\u7684\u6539\u52a8\uff0c\u6240\u4ee5\u6211\u53ea\u663e\u793a\u4e86\u90e8\u5206\u4ee3\u7801\u884c\uff0c\u5176\u4e2d\u4e00\u4e9b\u4ee3\u7801\u884c\u6709\u6539\u52a8\u3002 \u8fd9\u5c31\u6253\u5370\u51fa\u4e86\uff1a python - W ignore ohe_logres . py Fold = 0 , AUC = 0.7847865042255127 Fold = 1 , AUC = 0.7853553605899214 Fold = 2 , AUC = 0.7879321942914885 Fold = 3 , AUC = 0.7870315929550808 Fold = 4 , AUC = 0.7864668243125608 \u8bf7\u6ce8\u610f\uff0c\u6211\u4f7f\u7528\"-W ignore \"\u5ffd\u7565\u4e86\u6240\u6709\u8b66\u544a\u3002 \u6211\u4eec\u770b\u5230\uff0cAUC \u5206\u6570\u5728\u6240\u6709\u8936\u76b1\u4e2d\u90fd\u76f8\u5f53\u7a33\u5b9a\u3002\u5e73\u5747 AUC \u4e3a 0.78631449527\u3002\u5bf9\u4e8e\u6211\u4eec\u7684\u7b2c\u4e00\u4e2a\u6a21\u578b\u6765\u8bf4\u76f8\u5f53\u4e0d\u9519\uff01 \u5f88\u591a\u4eba\u5728\u9047\u5230\u8fd9\u79cd\u95ee\u9898\u65f6\u4f1a\u9996\u5148\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6bd4\u5982\u968f\u673a\u68ee\u6797\u3002\u5728\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u5e94\u7528\u968f\u673a\u68ee\u6797\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\uff08label encoding\uff09\uff0c\u5c06\u6bcf\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81\u90fd\u8f6c\u6362\u4e3a\u6574\u6570\uff0c\u800c\u4e0d\u662f\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u72ec\u70ed\u7f16\u7801\u3002 \u8fd9\u79cd\u7f16\u7801\u4e0e\u72ec\u70ed\u7f16\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # \u968f\u673a\u68ee\u6797\u6a21\u578b model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u4f7f\u7528 scikit-learn \u4e2d\u7684\u968f\u673a\u68ee\u6797\uff0c\u5e76\u53d6\u6d88\u4e86\u72ec\u70ed\u7f16\u7801\u3002\u6211\u4eec\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u4ee3\u66ff\u72ec\u70ed\u7f16\u7801\u3002\u5f97\u5206\u5982\u4e0b \u276f python lbl_rf . py Fold = 0 , AUC = 0.7167390828113697 Fold = 1 , AUC = 0.7165459672958506 Fold = 2 , AUC = 0.7159709909587376 Fold = 3 , AUC = 0.7161589664189556 Fold = 4 , AUC = 0.7156020216155978 \u54c7 \u5de8\u5927\u7684\u5dee\u5f02\uff01 \u968f\u673a\u68ee\u6797\u6a21\u578b\u5728\u6ca1\u6709\u4efb\u4f55\u8d85\u53c2\u6570\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u8868\u73b0\u8981\u6bd4\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u5dee\u5f88\u591a\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u603b\u662f\u5e94\u8be5\u5148\u4ece\u7b80\u5355\u6a21\u578b\u5f00\u59cb\u7684\u539f\u56e0\u3002\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u7c89\u4e1d\u4f1a\u4ece\u8fd9\u91cc\u5f00\u59cb\uff0c\u800c\u5ffd\u7565\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u8ba4\u4e3a\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u6a21\u578b\uff0c\u4e0d\u80fd\u5e26\u6765\u6bd4\u968f\u673a\u68ee\u6797\u66f4\u597d\u7684\u4ef7\u503c\u3002\u8fd9\u79cd\u4eba\u5c06\u4f1a\u72af\u4e0b\u5927\u9519\u3002\u5728\u6211\u4eec\u5b9e\u73b0\u968f\u673a\u68ee\u6797\u7684\u8fc7\u7a0b\u4e2d\uff0c\u4e0e\u903b\u8f91\u56de\u5f52\u76f8\u6bd4\uff0c\u6298\u53e0\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u624d\u80fd\u5b8c\u6210\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e0d\u4ec5\u635f\u5931\u4e86 AUC\uff0c\u8fd8\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u6765\u5b8c\u6210\u8bad\u7ec3\u3002\u8bf7\u6ce8\u610f\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u8fdb\u884c\u63a8\u7406\u4e5f\u5f88\u8017\u65f6\uff0c\u800c\u4e14\u5360\u7528\u7684\u7a7a\u95f4\u4e5f\u66f4\u5927\u3002 \u5982\u679c\u6211\u4eec\u613f\u610f\uff0c\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u5728\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u6570\u636e\u4e0a\u8fd0\u884c\u968f\u673a\u68ee\u6797\uff0c\u4f46\u8fd9\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u5947\u5f02\u503c\u5206\u89e3\u6765\u51cf\u5c11\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u77e9\u9635\u3002\u8fd9\u662f\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u63d0\u53d6\u4e3b\u9898\u7684\u5e38\u7528\u65b9\u6cd5\u3002 import pandas as pd from scipy import sparse from sklearn import decomposition from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" )] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) x_train = ohe . transform ( df_train [ features ]) x_valid = ohe . transform ( df_valid [ features ]) # \u5947\u5f02\u503c\u5206\u89e3 svd = decomposition . TruncatedSVD ( n_components = 120 ) full_sparse = sparse . vstack (( x_train , x_valid )) svd . fit ( full_sparse ) x_train = svd . transform ( x_train ) x_valid = svd . transform ( x_valid ) model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u5bf9\u5168\u90e8\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\uff0c\u7136\u540e\u7528\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\u5728\u7a00\u758f\u77e9\u9635\u4e0a\u62df\u5408 scikit-learn \u7684 TruncatedSVD\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c06\u9ad8\u7ef4\u7a00\u758f\u77e9\u9635\u51cf\u5c11\u5230 120 \u4e2a\u7279\u5f81\uff0c\u7136\u540e\u62df\u5408\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u3002 \u4ee5\u4e0b\u662f\u8be5\u6a21\u578b\u7684\u8f93\u51fa\u7ed3\u679c\uff1a \u276f python ohe_svd_rf . py Fold = 0 , AUC = 0.7064863038754249 Fold = 1 , AUC = 0.706050102937374 Fold = 2 , AUC = 0.7086069243167242 Fold = 3 , AUC = 0.7066819080085971 Fold = 4 , AUC = 0.7058154015055585 \u6211\u4eec\u53d1\u73b0\u60c5\u51b5\u66f4\u7cdf\u3002\u770b\u6765\uff0c\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\u7684\u6700\u4f73\u65b9\u6cd5\u662f\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u3002\u968f\u673a\u68ee\u6797\u4f3c\u4e4e\u8017\u65f6\u592a\u591a\u3002\u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u8bd5\u8bd5 XGBoost\u3002\u5982\u679c\u4f60\u4e0d\u77e5\u9053 XGBoost\uff0c\u5b83\u662f\u6700\u6d41\u884c\u7684\u68af\u5ea6\u63d0\u5347\u7b97\u6cd5\u4e4b\u4e00\u3002\u7531\u4e8e\u5b83\u662f\u4e00\u79cd\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u6570\u636e\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 , n_estimators = 200 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0c\u6211\u5bf9 xgboost \u53c2\u6570\u505a\u4e86\u4e00\u4e9b\u4fee\u6539\u3002xgboost \u7684\u9ed8\u8ba4\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u662f 3\uff0c\u6211\u628a\u5b83\u6539\u6210\u4e86 7\uff0c\u8fd8\u628a\u4f30\u8ba1\u5668\u6570\u91cf\uff08n_estimators\uff09\u4ece 100 \u6539\u6210\u4e86 200\u3002 \u8be5\u6a21\u578b\u7684 5 \u6298\u4ea4\u53c9\u68c0\u9a8c\u5f97\u5206\u5982\u4e0b\uff1a \u276f python lbl_xgb . py Fold = 0 , AUC = 0.7656768851999011 Fold = 1 , AUC = 0.7633006564148015 Fold = 2 , AUC = 0.7654277821434345 Fold = 3 , AUC = 0.7663609758878182 Fold = 4 , AUC = 0.764914671468069 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u4e0d\u505a\u4efb\u4f55\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u5f97\u5206\u6bd4\u666e\u901a\u968f\u673a\u68ee\u6797\u8981\u9ad8\u5f97\u591a\u3002 \u60a8\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4e00\u4e9b\u7279\u5f81\u5de5\u7a0b\uff0c\u653e\u5f03\u67d0\u4e9b\u5bf9\u6a21\u578b\u6ca1\u6709\u4efb\u4f55\u4ef7\u503c\u7684\u5217\u7b49\u3002\u4f46\u4f3c\u4e4e\u6211\u4eec\u80fd\u505a\u7684\u4e0d\u591a\uff0c\u65e0\u6cd5\u8bc1\u660e\u6a21\u578b\u7684\u6539\u8fdb\u3002\u8ba9\u6211\u4eec\u628a\u6570\u636e\u96c6\u6362\u6210\u53e6\u4e00\u4e2a\u6709\u5927\u91cf\u5206\u7c7b\u53d8\u91cf\u7684\u6570\u636e\u96c6\u3002\u53e6\u4e00\u4e2a\u6709\u540d\u7684\u6570\u636e\u96c6\u662f \u7f8e\u56fd\u6210\u4eba\u4eba\u53e3\u666e\u67e5\u6570\u636e\uff08US adult census data\uff09 \u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u5305\u542b\u4e00\u4e9b\u7279\u5f81\uff0c\u800c\u4f60\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u5de5\u8d44\u7b49\u7ea7\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u3002\u56fe 5 \u663e\u793a\u4e86\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u4e00\u4e9b\u5217\u3002 \u56fe 5\uff1a\u90e8\u5206\u6570\u636e\u96c6\u5c55\u793a \u8be5\u6570\u636e\u96c6\u6709\u4ee5\u4e0b\u51e0\u5217\uff1a - \u5e74\u9f84\uff08age\uff09 \u5de5\u4f5c\u7c7b\u522b\uff08workclass\uff09 \u5b66\u5386\uff08fnlwgt\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education.num\uff09 \u5a5a\u59fb\u72b6\u51b5\uff08marital.status\uff09 \u804c\u4e1a\uff08occupation\uff09 \u5173\u7cfb\uff08relationship\uff09 \u79cd\u65cf\uff08race\uff09 \u6027\u522b\uff08sex\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 \u539f\u7c4d\u56fd\uff08native.country\uff09 \u6536\u5165\uff08income\uff09 \u8fd9\u4e9b\u7279\u5f81\u5927\u591a\u4e0d\u8a00\u81ea\u660e\u3002\u90a3\u4e9b\u4e0d\u660e\u767d\u7684\uff0c\u6211\u4eec\u53ef\u4ee5\u4e0d\u8003\u8651\u3002\u8ba9\u6211\u4eec\u5148\u5c1d\u8bd5\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u3002 \u6211\u4eec\u770b\u5230\u6536\u5165\u5217\u662f\u4e00\u4e2a\u5b57\u7b26\u4e32\u3002\u8ba9\u6211\u4eec\u5bf9\u8fd9\u4e00\u5217\u8fdb\u884c\u6570\u503c\u7edf\u8ba1\u3002 In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/adult.csv\" ) In [ X ]: df . income . value_counts () Out [ X ]: <= 50 K 24720 > 50 K 7841 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 7841 \u4e2a\u5b9e\u4f8b\u7684\u6536\u5165\u8d85\u8fc7 5 \u4e07\u7f8e\u5143\u3002\u8fd9\u5360\u6837\u672c\u603b\u6570\u7684 24%\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4fdd\u6301\u4e0e\u732b\u6570\u636e\u96c6\u76f8\u540c\u7684\u8bc4\u4f30\u65b9\u6cd5\uff0c\u5373 AUC\u3002 \u5728\u5f00\u59cb\u5efa\u6a21\u4e4b\u524d\uff0c\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u5c06\u53bb\u6389\u51e0\u5217\u7279\u5f81\uff0c\u5373 \u5b66\u5386\uff08fnlwgt\uff09 \u5e74\u9f84\uff08age\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 \u8ba9\u6211\u4eec\u8bd5\u7740\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u5668\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002\u7b2c\u4e00\u6b65\u603b\u662f\u8981\u8fdb\u884c\u4ea4\u53c9\u9a8c\u8bc1\u3002\u6211\u4e0d\u4f1a\u5728\u8fd9\u91cc\u5c55\u793a\u8fd9\u90e8\u5206\u4ee3\u7801\u3002\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) # \u9700\u8981\u5220\u9664\u7684\u5217 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) # \u6620\u5c04 target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } # \u4f7f\u7528\u6620\u5c04\u66ff\u6362 df . loc [:, \"income\" ] = df . income . map ( target_mapping ) # \u53d6\u9664\"kfold\", \"income\"\u5217\u7684\u5176\u4ed6\u5217\u540d features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u9a8c\u8bc1\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u6784\u5efa\u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . income . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u65f6\uff0c\u6211\u4eec\u4f1a\u5f97\u5230 \u276f python - W ignore ohe_logres . py Fold = 0 , AUC = 0.8794809708119079 Fold = 1 , AUC = 0.8875785068274882 Fold = 2 , AUC = 0.8852609687685753 Fold = 3 , AUC = 0.8681236223251438 Fold = 4 , AUC = 0.8728581541840037 \u5bf9\u4e8e\u4e00\u4e2a\u5982\u6b64\u7b80\u5355\u7684\u6a21\u578b\u6765\u8bf4\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u4e0d\u9519\u7684 AUC\uff01 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5728\u4e0d\u8c03\u6574\u4efb\u4f55\u8d85\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c1d\u8bd5\u4e00\u4e0b\u6807\u7b7e\u7f16\u7801\u7684xgboost\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8ba9\u6211\u4eec\u8fd0\u884c\u4e0a\u9762\u4ee3\u7801\uff1a \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8800810634234078 Fold = 1 , AUC = 0.886811884948154 Fold = 2 , AUC = 0.8854421433318472 Fold = 3 , AUC = 0.8676319549361007 Fold = 4 , AUC = 0.8714450054900602 \u8fd9\u770b\u8d77\u6765\u5df2\u7ecf\u76f8\u5f53\u4e0d\u9519\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b max_depth \u589e\u52a0\u5230 7 \u548c n_estimators \u589e\u52a0\u5230 200 \u65f6\u7684\u5f97\u5206\u3002 \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8764108944332032 Fold = 1 , AUC = 0.8840708537662638 Fold = 2 , AUC = 0.8816601162613102 Fold = 3 , AUC = 0.8662335762581732 Fold = 4 , AUC = 0.8698983461709926 \u770b\u8d77\u6765\u5e76\u6ca1\u6709\u6539\u5584\u3002 \u8fd9\u8868\u660e\uff0c\u4e00\u4e2a\u6570\u636e\u96c6\u7684\u53c2\u6570\u4e0d\u80fd\u79fb\u690d\u5230\u53e6\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6211\u4eec\u5fc5\u987b\u518d\u6b21\u5c1d\u8bd5\u8c03\u6574\u53c2\u6570\uff0c\u4f46\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u8bf4\u660e\u3002 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5728\u4e0d\u8c03\u6574\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c06\u6570\u503c\u7279\u5f81\u7eb3\u5165 xgboost \u6a21\u578b\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) # \u52a0\u5165\u6570\u503c\u7279\u5f81 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : if col not in num_cols : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u56e0\u6b64\uff0c\u6211\u4eec\u4fdd\u7559\u6570\u5b57\u5217\uff0c\u53ea\u662f\u4e0d\u5bf9\u5176\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u7684\u6700\u7ec8\u7279\u5f81\u77e9\u9635\u5c31\u7531\u6570\u5b57\u5217\uff08\u539f\u6837\uff09\u548c\u7f16\u7801\u5206\u7c7b\u5217\u7ec4\u6210\u4e86\u3002\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\u90fd\u80fd\u8f7b\u677e\u5904\u7406\u8fd9\u79cd\u6df7\u5408\u3002 \u8bf7\u6ce8\u610f\uff0c\u5728\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65f6\uff0c\u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u6570\u636e\u8fdb\u884c\u5f52\u4e00\u5316\u5904\u7406\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff0c\u5728\u4f7f\u7528\u7ebf\u6027\u6a21\u578b\uff08\u5982\u903b\u8f91\u56de\u5f52\uff09\u65f6\u4e0d\u5bb9\u5ffd\u89c6\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u4e2a\u811a\u672c\uff01 \u276f python lbl_xgb_num . py Fold = 0 , AUC = 0.9209790185449889 Fold = 1 , AUC = 0.9247157449144706 Fold = 2 , AUC = 0.9269329887598243 Fold = 3 , AUC = 0.9119349082169275 Fold = 4 , AUC = 0.9166408030141667 \u54c7\u54e6 \u8fd9\u662f\u4e00\u4e2a\u5f88\u597d\u7684\u5206\u6570\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c1d\u8bd5\u6dfb\u52a0\u4e00\u4e9b\u529f\u80fd\u3002\u6211\u4eec\u5c06\u63d0\u53d6\u6240\u6709\u5206\u7c7b\u5217\uff0c\u5e76\u521b\u5efa\u6240\u6709\u4e8c\u5ea6\u7ec4\u5408\u3002\u8bf7\u770b\u4e0b\u9762\u4ee3\u7801\u6bb5\u4e2d\u7684 feature_engineering \u51fd\u6570\uff0c\u4e86\u89e3\u5982\u4f55\u5b9e\u73b0\u8fd9\u4e00\u70b9\u3002 import itertools import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def feature_engineering ( df , cat_cols ): # \u751f\u6210\u4e24\u4e2a\u7279\u5f81\u7684\u7ec4\u5408 combi = list ( itertools . combinations ( cat_cols , 2 )) for c1 , c2 in combi : df . loc [:, c1 + \"_\" + c2 ] = df [ c1 ] . astype ( str ) + \"_\" + df [ c2 ] . astype ( str ) return df def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) cat_cols = [ c for c in df . columns if c not in num_cols and c not in ( \"kfold\" , \"income\" )] # \u7279\u5f81\u5de5\u7a0b df = feature_engineering ( df , cat_cols ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" )] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u8fd9\u662f\u4ece\u5206\u7c7b\u5217\u4e2d\u521b\u5efa\u7279\u5f81\u7684\u4e00\u79cd\u975e\u5e38\u5e7c\u7a1a\u7684\u65b9\u6cd5\u3002\u6211\u4eec\u5e94\u8be5\u4ed4\u7ec6\u7814\u7a76\u6570\u636e\uff0c\u770b\u770b\u54ea\u4e9b\u7ec4\u5408\u6700\u5408\u7406\u3002\u5982\u679c\u4f7f\u7528\u8fd9\u79cd\u65b9\u6cd5\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5c31\u9700\u8981\u4f7f\u7528\u67d0\u79cd\u7279\u5f81\u9009\u62e9\u6765\u9009\u51fa\u6700\u4f73\u7279\u5f81\u3002\u7a0d\u540e\u6211\u4eec\u5c06\u8be6\u7ec6\u4ecb\u7ecd\u7279\u5f81\u9009\u62e9\u3002\u73b0\u5728\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5206\u6570\u3002 \u276f python lbl_xgb_num_feat . py Fold = 0 , AUC = 0.9211483465031423 Fold = 1 , AUC = 0.9251499446866125 Fold = 2 , AUC = 0.9262344766486692 Fold = 3 , AUC = 0.9114264068794995 Fold = 4 , AUC = 0.9177914453099201 \u770b\u6765\uff0c\u5373\u4f7f\u4e0d\u6539\u53d8\u4efb\u4f55\u8d85\u53c2\u6570\uff0c\u53ea\u589e\u52a0\u4e00\u4e9b\u7279\u5f81\uff0c\u6211\u4eec\u4e5f\u80fd\u63d0\u9ad8\u4e00\u4e9b\u6298\u53e0\u5f97\u5206\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5c06 max_depth \u589e\u52a0\u5230 7 \u662f\u5426\u6709\u5e2e\u52a9\u3002 \u276f python lbl_xgb_num_feat . py Fold = 0 , AUC = 0.9286668430204137 Fold = 1 , AUC = 0.9329340656165378 Fold = 2 , AUC = 0.9319817543218744 Fold = 3 , AUC = 0.919046187194538 Fold = 4 , AUC = 0.9245692057162671 \u6211\u4eec\u518d\u6b21\u6539\u8fdb\u4e86\u6211\u4eec\u7684\u6a21\u578b\u3002 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u8fd8\u6ca1\u6709\u4f7f\u7528\u7a00\u6709\u503c\u3001\u4e8c\u503c\u5316\u3001\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u7279\u5f81\u7684\u7ec4\u5408\u4ee5\u53ca\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u3002 \u4ece\u5206\u7c7b\u7279\u5f81\u4e2d\u8fdb\u884c\u7279\u5f81\u5de5\u7a0b\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u4f7f\u7528 \u76ee\u6807\u7f16\u7801 \u3002\u4f46\u662f\uff0c\u60a8\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u8fd9\u53ef\u80fd\u4f1a\u4f7f\u60a8\u7684\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u76ee\u6807\u7f16\u7801\u662f\u4e00\u79cd\u5c06\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u6620\u5c04\u5230\u5176\u5e73\u5747\u76ee\u6807\u503c\u7684\u6280\u672f\uff0c\u4f46\u5fc5\u987b\u59cb\u7ec8\u4ee5\u4ea4\u53c9\u9a8c\u8bc1\u7684\u65b9\u5f0f\u8fdb\u884c\u3002\u8fd9\u610f\u5473\u7740\u9996\u5148\u8981\u521b\u5efa\u6298\u53e0\uff0c\u7136\u540e\u4f7f\u7528\u8fd9\u4e9b\u6298\u53e0\u4e3a\u6570\u636e\u7684\u4e0d\u540c\u5217\u521b\u5efa\u76ee\u6807\u7f16\u7801\u7279\u5f81\uff0c\u65b9\u6cd5\u4e0e\u5728\u6298\u53e0\u4e0a\u62df\u5408\u548c\u9884\u6d4b\u6a21\u578b\u7684\u65b9\u6cd5\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u521b\u5efa\u4e86 5 \u4e2a\u6298\u53e0\uff0c\u60a8\u5c31\u5fc5\u987b\u521b\u5efa 5 \u6b21\u76ee\u6807\u7f16\u7801\uff0c\u8fd9\u6837\u6700\u7ec8\uff0c\u60a8\u5c31\u53ef\u4ee5\u4e3a\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u53d8\u91cf\u521b\u5efa\u7f16\u7801\uff0c\u800c\u8fd9\u4e9b\u53d8\u91cf\u5e76\u975e\u6765\u81ea\u540c\u4e00\u4e2a\u6298\u53e0\u3002\u7136\u540e\u5728\u62df\u5408\u6a21\u578b\u65f6\uff0c\u5fc5\u987b\u518d\u6b21\u4f7f\u7528\u76f8\u540c\u7684\u6298\u53e0\u3002\u672a\u89c1\u6d4b\u8bd5\u6570\u636e\u7684\u76ee\u6807\u7f16\u7801\u53ef\u4ee5\u6765\u81ea\u5168\u90e8\u8bad\u7ec3\u6570\u636e\uff0c\u4e5f\u53ef\u4ee5\u662f\u6240\u6709 5 \u4e2a\u6298\u53e0\u7684\u5e73\u5747\u503c\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u5728\u540c\u4e00\u4e2a\u6210\u4eba\u6570\u636e\u96c6\u4e0a\u4f7f\u7528\u76ee\u6807\u7f16\u7801\uff0c\u4ee5\u4fbf\u8fdb\u884c\u6bd4\u8f83\u3002 import copy import pandas as pd from sklearn import metrics from sklearn import preprocessing import xgboost as xgb def mean_target_encoding ( data ): df = copy . deepcopy ( data ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) and f not in num_cols ] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) encoded_dfs = [] for fold in range ( 5 ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) for column in features : # \u76ee\u6807\u7f16\u7801 mapping_dict = dict ( df_train . groupby ( column )[ \"income\" ] . mean () ) df_valid . loc [:, column + \"_enc\" ] = df_valid [ column ] . map ( mapping_dict ) encoded_dfs . append ( df_valid ) encoded_df = pd . concat ( encoded_dfs , axis = 0 ) return encoded_df def run ( df , fold ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/adult_folds.csv\" ) df = mean_target_encoding ( df ) for fold_ in range ( 5 ): run ( df , fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u4e0a\u8ff0\u7247\u6bb5\u4e2d\uff0c\u6211\u5728\u8fdb\u884c\u76ee\u6807\u7f16\u7801\u65f6\u5e76\u6ca1\u6709\u5220\u9664\u5206\u7c7b\u5217\u3002\u6211\u4fdd\u7559\u4e86\u6240\u6709\u7279\u5f81\uff0c\u5e76\u5728\u6b64\u57fa\u7840\u4e0a\u6dfb\u52a0\u4e86\u76ee\u6807\u7f16\u7801\u7279\u5f81\u3002\u6b64\u5916\uff0c\u6211\u8fd8\u4f7f\u7528\u4e86\u5e73\u5747\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5e73\u5747\u503c\u3001\u4e2d\u4f4d\u6570\u3001\u6807\u51c6\u504f\u5dee\u6216\u76ee\u6807\u7684\u4efb\u4f55\u5176\u4ed6\u51fd\u6570\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u7ed3\u679c\u3002 Fold = 0 , AUC = 0.9332240662017529 Fold = 1 , AUC = 0.9363551625140347 Fold = 2 , AUC = 0.9375013544556173 Fold = 3 , AUC = 0.92237621307625 Fold = 4 , AUC = 0.9292131180445478 \u4e0d\u9519\uff01\u770b\u6765\u6211\u4eec\u53c8\u6709\u8fdb\u6b65\u4e86\u3002\u4e0d\u8fc7\uff0c\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u5b83\u592a\u5bb9\u6613\u51fa\u73b0\u8fc7\u5ea6\u62df\u5408\u3002\u5f53\u6211\u4eec\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\uff0c\u6700\u597d\u4f7f\u7528\u67d0\u79cd\u5e73\u6ed1\u65b9\u6cd5\u6216\u5728\u7f16\u7801\u503c\u4e2d\u6dfb\u52a0\u566a\u58f0\u3002 Scikit-learn \u7684\u8d21\u732e\u5e93\u4e2d\u6709\u5e26\u5e73\u6ed1\u7684\u76ee\u6807\u7f16\u7801\uff0c\u4f60\u4e5f\u53ef\u4ee5\u521b\u5efa\u81ea\u5df1\u7684\u5e73\u6ed1\u3002\u5e73\u6ed1\u4f1a\u5f15\u5165\u67d0\u79cd\u6b63\u5219\u5316\uff0c\u6709\u52a9\u4e8e\u907f\u514d\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u8fd9\u5e76\u4e0d\u96be\u3002 \u5904\u7406\u5206\u7c7b\u7279\u5f81\u662f\u4e00\u9879\u590d\u6742\u7684\u4efb\u52a1\u3002\u8bb8\u591a\u8d44\u6e90\u4e2d\u90fd\u6709\u5927\u91cf\u4fe1\u606f\u3002\u672c\u7ae0\u5e94\u8be5\u80fd\u5e2e\u52a9\u4f60\u5f00\u59cb\u89e3\u51b3\u5206\u7c7b\u53d8\u91cf\u7684\u4efb\u4f55\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u5bf9\u4e8e\u5927\u591a\u6570\u95ee\u9898\u6765\u8bf4\uff0c\u9664\u4e86\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u4e4b\u5916\uff0c\u4f60\u4e0d\u9700\u8981\u66f4\u591a\u7684\u4e1c\u897f\u3002 \u8981\u8fdb\u4e00\u6b65\u6539\u8fdb\u6a21\u578b\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u66f4\u591a\uff01 \u5728\u672c\u7ae0\u7684\u6700\u540e\uff0c\u6211\u4eec\u4e0d\u80fd\u4e0d\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e00\u79cd\u79f0\u4e3a \u5b9e\u4f53\u5d4c\u5165 \u7684\u6280\u672f\u3002\u5728\u5b9e\u4f53\u5d4c\u5165\u4e2d\uff0c\u7c7b\u522b\u7528\u5411\u91cf\u8868\u793a\u3002\u5728\u4e8c\u503c\u5316\u548c\u72ec\u70ed\u7f16\u7801\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u90fd\u662f\u7528\u5411\u91cf\u6765\u8868\u793a\u7c7b\u522b\u7684\u3002 \u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u6709\u6570\u4ee5\u4e07\u8ba1\u7684\u7c7b\u522b\u600e\u4e48\u529e\uff1f\u8fd9\u5c06\u4f1a\u4ea7\u751f\u5de8\u5927\u7684\u77e9\u9635\uff0c\u6211\u4eec\u5c06\u9700\u8981\u5f88\u957f\u65f6\u95f4\u6765\u8bad\u7ec3\u590d\u6742\u7684\u6a21\u578b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u7528\u5e26\u6709\u6d6e\u70b9\u503c\u7684\u5411\u91cf\u6765\u8868\u793a\u5b83\u4eec\u3002 \u8fd9\u4e2a\u60f3\u6cd5\u975e\u5e38\u7b80\u5355\u3002\u6bcf\u4e2a\u5206\u7c7b\u7279\u5f81\u90fd\u6709\u4e00\u4e2a\u5d4c\u5165\u5c42\u3002\u56e0\u6b64\uff0c\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u73b0\u5728\u90fd\u53ef\u4ee5\u6620\u5c04\u5230\u4e00\u4e2a\u5d4c\u5165\u5c42\uff08\u5c31\u50cf\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u5c06\u5355\u8bcd\u6620\u5c04\u5230\u5d4c\u5165\u5c42\u4e00\u6837\uff09\u3002\u7136\u540e\uff0c\u6839\u636e\u5176\u7ef4\u5ea6\u91cd\u5851\u8fd9\u4e9b\u5d4c\u5165\u5c42\uff0c\u4f7f\u5176\u6241\u5e73\u5316\uff0c\u7136\u540e\u5c06\u6240\u6709\u6241\u5e73\u5316\u7684\u8f93\u5165\u5d4c\u5165\u5c42\u8fde\u63a5\u8d77\u6765\u3002\u7136\u540e\u6dfb\u52a0\u4e00\u5806\u5bc6\u96c6\u5c42\u548c\u4e00\u4e2a\u8f93\u51fa\u5c42\uff0c\u5c31\u5927\u529f\u544a\u6210\u4e86\u3002 \u56fe 6\uff1a\u7c7b\u522b\u8f6c\u6362\u4e3a\u6d6e\u70b9\u6216\u5d4c\u5165\u5411\u91cf \u51fa\u4e8e\u67d0\u79cd\u539f\u56e0\uff0c\u6211\u53d1\u73b0\u4f7f\u7528 TF/Keras \u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528 TF/Keras \u5b9e\u73b0\u5b83\u3002\u6b64\u5916\uff0c\u8fd9\u662f\u672c\u4e66\u4e2d\u552f\u4e00\u4e00\u4e2a\u4f7f\u7528 TF/Keras \u7684\u793a\u4f8b\uff0c\u5c06\u5176\u8f6c\u6362\u4e3a PyTorch\uff08\u4f7f\u7528 cat-in-the-dat-ii \u6570\u636e\u96c6\uff09\u4e5f\u975e\u5e38\u5bb9\u6613 import os import gc import joblib import pandas as pd import numpy as np from sklearn import metrics , preprocessing from tensorflow.keras import layers from tensorflow.keras import optimizers from tensorflow.keras.models import Model , load_model from tensorflow.keras import callbacks from tensorflow.keras import backend as K from tensorflow.keras import utils def create_model ( data , catcols ): # \u521b\u5efa\u7a7a\u7684\u8f93\u5165\u5217\u8868\u548c\u8f93\u51fa\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6a21\u578b\u7684\u8f93\u5165\u548c\u8f93\u51fa inputs = [] outputs = [] # \u904d\u5386\u5206\u7c7b\u7279\u5f81\u5217\u8868\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81 for c in catcols : # \u8ba1\u7b97\u7279\u5f81\u4e2d\u552f\u4e00\u503c\u7684\u6570\u91cf num_unique_values = int ( data [ c ] . nunique ()) # \u8ba1\u7b97\u5d4c\u5165\u7ef4\u5ea6\uff0c\u6700\u5927\u4e0d\u8d85\u8fc750 embed_dim = int ( min ( np . ceil (( num_unique_values ) / 2 ), 50 )) # \u521b\u5efa\u6a21\u578b\u7684\u8f93\u5165\u5c42\uff0c\u6bcf\u4e2a\u7279\u5f81\u5bf9\u5e94\u4e00\u4e2a\u8f93\u5165 inp = layers . Input ( shape = ( 1 ,)) # \u521b\u5efa\u5d4c\u5165\u5c42\uff0c\u5c06\u5206\u7c7b\u7279\u5f81\u6620\u5c04\u5230\u4f4e\u7ef4\u5ea6\u7684\u8fde\u7eed\u5411\u91cf out = layers . Embedding ( num_unique_values + 1 , embed_dim , name = c )( inp ) # \u5bf9\u5d4c\u5165\u5c42\u8fdb\u884c\u7a7a\u95f4\u4e22\u5f03\uff08Dropout\uff09 out = layers . SpatialDropout1D ( 0.3 )( out ) # \u5c06\u5d4c\u5165\u5c42\u7684\u5f62\u72b6\u91cd\u65b0\u8c03\u6574\u4e3a\u4e00\u7ef4 out = layers . Reshape ( target_shape = ( embed_dim ,))( out ) # \u5c06\u8f93\u5165\u548c\u8f93\u51fa\u6dfb\u52a0\u5230\u5bf9\u5e94\u7684\u5217\u8868\u4e2d inputs . append ( inp ) outputs . append ( out ) # \u4f7f\u7528Concatenate\u5c42\u5c06\u6240\u6709\u7684\u5d4c\u5165\u5c42\u8f93\u51fa\u8fde\u63a5\u5728\u4e00\u8d77 x = layers . Concatenate ()( outputs ) # \u5bf9\u8fde\u63a5\u540e\u7684\u6570\u636e\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u6dfb\u52a0\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u6dfb\u52a0\u53e6\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u8f93\u51fa\u5c42\uff0c\u5177\u67092\u4e2a\u795e\u7ecf\u5143\uff08\u7528\u4e8e\u4e8c\u8fdb\u5236\u5206\u7c7b\uff09\uff0c\u5e76\u4f7f\u7528softmax\u6fc0\u6d3b\u51fd\u6570 y = layers . Dense ( 2 , activation = \"softmax\" )( x ) # \u521b\u5efa\u6a21\u578b\uff0c\u5c06\u8f93\u5165\u548c\u8f93\u51fa\u4f20\u9012\u7ed9Model\u6784\u9020\u51fd\u6570 model = Model ( inputs = inputs , outputs = y ) # \u7f16\u8bd1\u6a21\u578b\uff0c\u6307\u5b9a\u635f\u5931\u51fd\u6570\u548c\u4f18\u5316\u5668 model . compile ( loss = 'binary_crossentropy' , optimizer = 'adam' ) # \u8fd4\u56de\u521b\u5efa\u7684\u6a21\u578b return model def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for feat in features : lbl_enc = preprocessing . LabelEncoder () df . loc [:, feat ] = lbl_enc . fit_transform ( df [ feat ] . values ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) model = create_model ( df , features ) xtrain = [ df_train [ features ] . values [:, k ] for k in range ( len ( features ))] xvalid = [ df_valid [ features ] . values [:, k ] for k in range ( len ( features )) ] ytrain = df_train . target . values yvalid = df_valid . target . values ytrain_cat = utils . to_categorical ( ytrain ) yvalid_cat = utils . to_categorical ( yvalid ) model . fit ( xtrain , ytrain_cat , validation_data = ( xvalid , yvalid_cat ), verbose = 1 , batch_size = 1024 , epochs = 3 ) valid_preds = model . predict ( xvalid )[:, 1 ] print ( metrics . roc_auc_score ( yvalid , valid_preds )) K . clear_session () if __name__ == \"__main__\" : run ( 0 ) run ( 1 ) run ( 2 ) run ( 3 ) run ( 4 ) \u4f60\u4f1a\u53d1\u73b0\u8fd9\u79cd\u65b9\u6cd5\u6548\u679c\u6700\u597d\uff0c\u800c\u4e14\u5982\u679c\u4f60\u6709 GPU\uff0c\u901f\u5ea6\u4e5f\u8d85\u5feb\uff01\u8fd9\u79cd\u65b9\u6cd5\u8fd8\u53ef\u4ee5\u8fdb\u4e00\u6b65\u6539\u8fdb\uff0c\u800c\u4e14\u4f60\u65e0\u9700\u62c5\u5fc3\u7279\u5f81\u5de5\u7a0b\uff0c\u56e0\u4e3a\u795e\u7ecf\u7f51\u7edc\u4f1a\u81ea\u884c\u5904\u7406\u3002\u5728\u5904\u7406\u5927\u91cf\u5206\u7c7b\u7279\u5f81\u6570\u636e\u96c6\u65f6\uff0c\u8fd9\u7edd\u5bf9\u503c\u5f97\u4e00\u8bd5\u3002\u5f53\u5d4c\u5165\u5927\u5c0f\u4e0e\u552f\u4e00\u7c7b\u522b\u7684\u6570\u91cf\u76f8\u540c\u65f6\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u72ec\u70ed\u7f16\u7801\uff08one-hot-encoding\uff09\u3002 \u672c\u7ae0\u57fa\u672c\u4e0a\u90fd\u662f\u5173\u4e8e\u7279\u5f81\u5de5\u7a0b\u7684\u3002\u8ba9\u6211\u4eec\u5728\u4e0b\u4e00\u7ae0\u4e2d\u770b\u770b\u5982\u4f55\u5728\u6570\u5b57\u7279\u5f81\u548c\u4e0d\u540c\u7c7b\u578b\u7279\u5f81\u7684\u7ec4\u5408\u65b9\u9762\u8fdb\u884c\u66f4\u591a\u7684\u7279\u5f81\u5de5\u7a0b\u3002","title":"\u5904\u7406\u5206\u7c7b\u53d8\u91cf"},{"location":"%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%88%96%E5%9B%9E%E5%BD%92%E6%96%B9%E6%B3%95/","text":"\u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5 \u6587\u672c\u95ee\u9898\u662f\u6211\u7684\u6700\u7231\u3002\u4e00\u822c\u6765\u8bf4\uff0c\u8fd9\u4e9b\u95ee\u9898\u4e5f\u88ab\u79f0\u4e3a \u81ea\u7136\u8bed\u8a00\u5904\u7406\uff08NLP\uff09\u95ee\u9898 \u3002NLP \u95ee\u9898\u4e0e\u56fe\u50cf\u95ee\u9898\u4e5f\u6709\u5f88\u5927\u4e0d\u540c\u3002\u4f60\u9700\u8981\u521b\u5efa\u4ee5\u524d\u4ece\u672a\u4e3a\u8868\u683c\u95ee\u9898\u521b\u5efa\u8fc7\u7684\u6570\u636e\u7ba1\u9053\u3002\u4f60\u9700\u8981\u4e86\u89e3\u5546\u4e1a\u6848\u4f8b\uff0c\u624d\u80fd\u5efa\u7acb\u4e00\u4e2a\u597d\u7684\u6a21\u578b\u3002\u987a\u4fbf\u8bf4\u4e00\u53e5\uff0c\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u4efb\u4f55\u4e8b\u60c5\u90fd\u662f\u5982\u6b64\u3002\u5efa\u7acb\u6a21\u578b\u4f1a\u8ba9\u4f60\u8fbe\u5230\u4e00\u5b9a\u7684\u6c34\u5e73\uff0c\u4f46\u8981\u60f3\u6539\u5584\u548c\u4fc3\u8fdb\u4f60\u6240\u5efa\u7acb\u6a21\u578b\u7684\u4e1a\u52a1\uff0c\u4f60\u5fc5\u987b\u4e86\u89e3\u5b83\u5bf9\u4e1a\u52a1\u7684\u5f71\u54cd\u3002 NLP \u95ee\u9898\u6709\u5f88\u591a\u79cd\uff0c\u5176\u4e2d\u6700\u5e38\u89c1\u7684\u662f\u5b57\u7b26\u4e32\u5206\u7c7b\u3002\u5f88\u591a\u65f6\u5019\uff0c\u6211\u4eec\u4f1a\u770b\u5230\u4eba\u4eec\u5728\u5904\u7406\u8868\u683c\u6570\u636e\u6216\u56fe\u50cf\u65f6\u8868\u73b0\u51fa\u8272\uff0c\u4f46\u5728\u5904\u7406\u6587\u672c\u65f6\uff0c\u4ed6\u4eec\u751a\u81f3\u4e0d\u77e5\u9053\u4ece\u4f55\u5165\u624b\u3002\u6587\u672c\u6570\u636e\u4e0e\u5176\u4ed6\u7c7b\u578b\u7684\u6570\u636e\u96c6\u6ca1\u6709\u4ec0\u4e48\u4e0d\u540c\u3002\u5bf9\u4e8e\u8ba1\u7b97\u673a\u6765\u8bf4\uff0c\u4e00\u5207\u90fd\u662f\u6570\u5b57\u3002 \u5047\u8bbe\u6211\u4eec\u4ece\u60c5\u611f\u5206\u7c7b\u8fd9\u4e00\u57fa\u672c\u4efb\u52a1\u5f00\u59cb\u3002\u6211\u4eec\u5c06\u5c1d\u8bd5\u5bf9\u7535\u5f71\u8bc4\u8bba\u8fdb\u884c\u60c5\u611f\u5206\u7c7b\u3002\u56e0\u6b64\uff0c\u60a8\u6709\u4e00\u4e2a\u6587\u672c\uff0c\u5e76\u6709\u4e0e\u4e4b\u76f8\u5173\u7684\u60c5\u611f\u3002\u4f60\u5c06\u5982\u4f55\u5904\u7406\u8fd9\u7c7b\u95ee\u9898\uff1f\u662f\u5e94\u7528\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff1f \u4e0d\uff0c\u7edd\u5bf9\u9519\u4e86\u3002\u4f60\u8981\u4ece\u6700\u57fa\u672c\u7684\u5f00\u59cb\u3002\u8ba9\u6211\u4eec\u5148\u770b\u770b\u8fd9\u4e9b\u6570\u636e\u662f\u4ec0\u4e48\u6837\u5b50\u7684\u3002 \u6211\u4eec\u4ece IMDB \u7535\u5f71\u8bc4\u8bba\u6570\u636e\u96c6 \u5f00\u59cb\uff0c\u8be5\u6570\u636e\u96c6\u5305\u542b 25000 \u7bc7\u6b63\u9762\u60c5\u611f\u8bc4\u8bba\u548c 25000 \u7bc7\u8d1f\u9762\u60c5\u611f\u8bc4\u8bba\u3002 \u6211\u5c06\u5728\u6b64\u8ba8\u8bba\u7684\u6982\u5ff5\u51e0\u4e4e\u9002\u7528\u4e8e\u4efb\u4f55\u6587\u672c\u5206\u7c7b\u6570\u636e\u96c6\u3002 \u8fd9\u4e2a\u6570\u636e\u96c6\u975e\u5e38\u5bb9\u6613\u7406\u89e3\u3002\u4e00\u7bc7\u8bc4\u8bba\u5bf9\u5e94\u4e00\u4e2a\u76ee\u6807\u53d8\u91cf\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5199\u7684\u662f\u8bc4\u8bba\u800c\u4e0d\u662f\u53e5\u5b50\u3002\u8bc4\u8bba\u5c31\u662f\u4e00\u5806\u53e5\u5b50\u3002\u6240\u4ee5\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u4f60\u4e00\u5b9a\u53ea\u770b\u5230\u4e86\u5bf9\u5355\u53e5\u7684\u5206\u7c7b\uff0c\u4f46\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u6211\u4eec\u5c06\u5bf9\u591a\u4e2a\u53e5\u5b50\u8fdb\u884c\u5206\u7c7b\u3002\u7b80\u5355\u5730\u8bf4\uff0c\u8fd9\u610f\u5473\u7740\u4e0d\u4ec5\u4e00\u4e2a\u53e5\u5b50\u4f1a\u5bf9\u60c5\u611f\u4ea7\u751f\u5f71\u54cd\uff0c\u800c\u4e14\u60c5\u611f\u5f97\u5206\u662f\u591a\u4e2a\u53e5\u5b50\u5f97\u5206\u7684\u7ec4\u5408\u3002\u6570\u636e\u7b80\u4ecb\u5982\u56fe 1 \u6240\u793a\u3002 \u5982\u4f55\u7740\u624b\u89e3\u51b3\u8fd9\u6837\u7684\u95ee\u9898\uff1f\u4e00\u4e2a\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u624b\u5de5\u5236\u4f5c\u4e24\u4efd\u5355\u8bcd\u8868\u3002\u4e00\u4e2a\u5217\u8868\u5305\u542b\u4f60\u80fd\u60f3\u8c61\u5230\u7684\u6240\u6709\u6b63\u9762\u8bcd\u6c47\uff0c\u4f8b\u5982\u597d\u3001\u68d2\u3001\u597d\u7b49\uff1b\u53e6\u4e00\u4e2a\u5217\u8868\u5305\u542b\u6240\u6709\u8d1f\u9762\u8bcd\u6c47\uff0c\u4f8b\u5982\u574f\u3001\u6076\u7b49\u3002\u6211\u4eec\u5148\u4e0d\u8981\u4e3e\u4f8b\u8bf4\u660e\u574f\u8bcd\uff0c\u5426\u5219\u8fd9\u672c\u4e66\u5c31\u53ea\u80fd\u4f9b 18 \u5c81\u4ee5\u4e0a\u7684\u4eba\u9605\u8bfb\u4e86\u3002\u4e00\u65e6\u4f60\u6709\u4e86\u8fd9\u4e9b\u5217\u8868\uff0c\u4f60\u751a\u81f3\u4e0d\u9700\u8981\u4e00\u4e2a\u6a21\u578b\u6765\u8fdb\u884c\u9884\u6d4b\u3002\u8fd9\u4e9b\u5217\u8868\u4e5f\u88ab\u79f0\u4e3a\u60c5\u611f\u8bcd\u5178\u3002\u4f60\u53ef\u4ee5\u7528\u4e00\u4e2a\u7b80\u5355\u7684\u8ba1\u6570\u5668\u6765\u8ba1\u7b97\u53e5\u5b50\u4e2d\u6b63\u9762\u548c\u8d1f\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u3002\u5982\u679c\u6b63\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u8f83\u591a\uff0c\u5219\u8868\u793a\u8be5\u53e5\u5b50\u5177\u6709\u6b63\u9762\u60c5\u611f\uff1b\u5982\u679c\u8d1f\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u8f83\u591a\uff0c\u5219\u8868\u793a\u8be5\u53e5\u5b50\u5177\u6709\u8d1f\u9762\u60c5\u611f\u3002\u5982\u679c\u53e5\u5b50\u4e2d\u6ca1\u6709\u8fd9\u4e9b\u8bcd\uff0c\u5219\u53ef\u4ee5\u8bf4\u8be5\u53e5\u5b50\u5177\u6709\u4e2d\u6027\u60c5\u611f\u3002\u8fd9\u662f\u6700\u53e4\u8001\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u73b0\u5728\u4ecd\u6709\u4eba\u5728\u4f7f\u7528\u3002\u5b83\u4e5f\u4e0d\u9700\u8981\u592a\u591a\u4ee3\u7801\u3002 def find_sentiment ( sentence , pos , neg ): sentence = sentence . split () sentence = set ( sentence ) num_common_pos = len ( sentence . intersection ( pos )) num_common_neg = len ( sentence . intersection ( neg )) if num_common_pos > num_common_neg : return \"positive\" if num_common_pos < num_common_neg : return \"negative\" return \"neutral\" \u4e0d\u8fc7\uff0c\u8fd9\u79cd\u65b9\u6cd5\u8003\u8651\u7684\u56e0\u7d20\u5e76\u4e0d\u591a\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u6211\u4eec\u7684 split() \u4e5f\u5e76\u4e0d\u5b8c\u7f8e\u3002\u5982\u679c\u4f7f\u7528 split()\uff0c\u5c31\u4f1a\u51fa\u73b0\u8fd9\u6837\u7684\u53e5\u5b50\uff1a \"hi, how are you?\" \u7ecf\u8fc7\u5206\u5272\u540e\u53d8\u4e3a\uff1a [\"hi,\", \"how\",\"are\",\"you?\"] \u8fd9\u79cd\u65b9\u6cd5\u5e76\u4e0d\u7406\u60f3\uff0c\u56e0\u4e3a\u5355\u8bcd\u4e2d\u5305\u542b\u4e86\u9017\u53f7\u548c\u95ee\u53f7\uff0c\u5b83\u4eec\u5e76\u6ca1\u6709\u88ab\u5206\u5272\u3002\u56e0\u6b64\uff0c\u5982\u679c\u6ca1\u6709\u5728\u5206\u5272\u524d\u5bf9\u8fd9\u4e9b\u7279\u6b8a\u5b57\u7b26\u8fdb\u884c\u9884\u5904\u7406\uff0c\u4e0d\u5efa\u8bae\u4f7f\u7528\u8fd9\u79cd\u65b9\u6cd5\u3002\u5c06\u5b57\u7b26\u4e32\u62c6\u5206\u4e3a\u5355\u8bcd\u5217\u8868\u79f0\u4e3a\u6807\u8bb0\u5316\u3002\u6700\u6d41\u884c\u7684\u6807\u8bb0\u5316\u65b9\u6cd5\u4e4b\u4e00\u6765\u81ea NLTK\uff08\u81ea\u7136\u8bed\u8a00\u5de5\u5177\u5305\uff09 \u3002 In [ X ]: from nltk.tokenize import word_tokenize In [ X ]: sentence = \"hi, how are you?\" In [ X ]: sentence . split () Out [ X ]: [ 'hi,' , 'how' , 'are' , 'you?' ] In [ X ]: word_tokenize ( sentence ) Out [ X ]: [ 'hi' , ',' , 'how' , 'are' , 'you' , '?' ] \u6b63\u5982\u60a8\u6240\u770b\u5230\u7684\uff0c\u4f7f\u7528 NLTK \u7684\u5355\u8bcd\u6807\u8bb0\u5316\u529f\u80fd\uff0c\u540c\u4e00\u4e2a\u53e5\u5b50\u7684\u62c6\u5206\u6548\u679c\u8981\u597d\u5f97\u591a\u3002\u4f7f\u7528\u5355\u8bcd\u5217\u8868\u8fdb\u884c\u5bf9\u6bd4\u7684\u6548\u679c\u4e5f\u4f1a\u66f4\u597d\uff01\u8fd9\u5c31\u662f\u6211\u4eec\u5c06\u5e94\u7528\u4e8e\u7b2c\u4e00\u4e2a\u60c5\u611f\u68c0\u6d4b\u6a21\u578b\u7684\u65b9\u6cd5\u3002 \u5728\u5904\u7406 NLP \u5206\u7c7b\u95ee\u9898\u65f6\uff0c\u60a8\u5e94\u8be5\u7ecf\u5e38\u5c1d\u8bd5\u7684\u57fa\u672c\u6a21\u578b\u4e4b\u4e00\u662f \u8bcd\u888b\u6a21\u578b\uff08bag of words\uff09 \u3002\u5728\u8bcd\u888b\u6a21\u578b\u4e2d\uff0c\u6211\u4eec\u521b\u5efa\u4e00\u4e2a\u5de8\u5927\u7684\u7a00\u758f\u77e9\u9635\uff0c\u5b58\u50a8\u8bed\u6599\u5e93\uff08\u8bed\u6599\u5e93=\u6240\u6709\u6587\u6863=\u6240\u6709\u53e5\u5b50\uff09\u4e2d\u6240\u6709\u5355\u8bcd\u7684\u8ba1\u6570\u3002\u4e3a\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 CountVectorizer\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b83\u662f\u5982\u4f55\u5de5\u4f5c\u7684\u3002 from sklearn.feature_extraction.text import CountVectorizer corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] ctv = CountVectorizer () ctv . fit ( corpus ) corpus_transformed = ctv . transform ( corpus ) \u5982\u679c\u6211\u4eec\u6253\u5370 corpus_transformed\uff0c\u5c31\u4f1a\u5f97\u5230\u7c7b\u4f3c\u4e0b\u9762\u7684\u7ed3\u679c\uff1a ( 0 , 2 ) 1 ( 0 , 9 ) 1 ( 0 , 11 ) 1 ( 0 , 22 ) 1 ( 1 , 1 ) 1 ( 1 , 3 ) 1 ( 1 , 4 ) 1 ( 1 , 7 ) 1 ( 1 , 8 ) 1 ( 1 , 10 ) 1 ( 1 , 13 ) 1 ( 1 , 17 ) 1 ( 1 , 19 ) 1 ( 1 , 22 ) 2 ( 2 , 0 ) 1 ( 2 , 5 ) 1 ( 2 , 6 ) 1 ( 2 , 14 ) 1 ( 2 , 22 ) 1 ( 3 , 12 ) 1 ( 3 , 15 ) 1 ( 3 , 16 ) 1 ( 3 , 18 ) 1 ( 3 , 20 ) 1 ( 4 , 21 ) 1 \u5728\u524d\u9762\u7684\u7ae0\u8282\u4e2d\uff0c\u6211\u4eec\u5df2\u7ecf\u89c1\u8bc6\u8fc7\u8fd9\u79cd\u8868\u793a\u6cd5\u3002\u5373\u7a00\u758f\u8868\u793a\u6cd5\u3002\u56e0\u6b64\uff0c\u8bed\u6599\u5e93\u73b0\u5728\u662f\u4e00\u4e2a\u7a00\u758f\u77e9\u9635\uff0c\u5176\u4e2d\u7b2c\u4e00\u4e2a\u6837\u672c\u6709 4 \u4e2a\u5143\u7d20\uff0c\u7b2c\u4e8c\u4e2a\u6837\u672c\u6709 10 \u4e2a\u5143\u7d20\uff0c\u4ee5\u6b64\u7c7b\u63a8\uff0c\u7b2c\u4e09\u4e2a\u6837\u672c\u6709 5 \u4e2a\u5143\u7d20\uff0c\u4ee5\u6b64\u7c7b\u63a8\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u4e9b\u5143\u7d20\u90fd\u6709\u76f8\u5173\u7684\u8ba1\u6570\u3002\u6709\u4e9b\u5143\u7d20\u4f1a\u51fa\u73b0\u4e24\u6b21\uff0c\u6709\u4e9b\u5219\u53ea\u6709\u4e00\u6b21\u3002\u4f8b\u5982\uff0c\u5728\u6837\u672c 2\uff08\u7b2c 1 \u884c\uff09\u4e2d\uff0c\u6211\u4eec\u770b\u5230\u7b2c 22 \u5217\u7684\u6570\u503c\u662f 2\u3002\u8fd9\u662f\u4e3a\u4ec0\u4e48\u5462\uff1f\u7b2c 22 \u5217\u662f\u4ec0\u4e48\uff1f CountVectorizer \u7684\u5de5\u4f5c\u65b9\u5f0f\u662f\u9996\u5148\u5bf9\u53e5\u5b50\u8fdb\u884c\u6807\u8bb0\u5316\u5904\u7406\uff0c\u7136\u540e\u4e3a\u6bcf\u4e2a\u6807\u8bb0\u8d4b\u503c\u3002\u56e0\u6b64\uff0c\u6bcf\u4e2a\u6807\u8bb0\u90fd\u7531\u4e00\u4e2a\u552f\u4e00\u7d22\u5f15\u8868\u793a\u3002\u8fd9\u4e9b\u552f\u4e00\u7d22\u5f15\u5c31\u662f\u6211\u4eec\u770b\u5230\u7684\u5217\u3002CountVectorizer \u4f1a\u5b58\u50a8\u8fd9\u4e9b\u4fe1\u606f\u3002 print ( ctv . vocabulary_ ) { 'hello' : 9 , 'how' : 11 , 'are' : 2 , 'you' : 22 , 'im' : 13 , 'getting' : 8 , 'bored' : 4 , 'at' : 3 , 'home' : 10 , 'and' : 1 , 'what' : 19 , 'do' : 7 , 'think' : 17 , 'did' : 6 , 'know' : 14 , 'about' : 0 , 'counts' : 5 , 'let' : 15 , 'see' : 16 , 'if' : 12 , 'this' : 18 , 'works' : 20 , 'yes' : 21 } \u6211\u4eec\u770b\u5230\uff0c\u7d22\u5f15 22 \u5c5e\u4e8e \"you\"\uff0c\u800c\u5728\u7b2c\u4e8c\u53e5\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u4e86\u4e24\u6b21 \"you\"\u3002\u6211\u5e0c\u671b\u5927\u5bb6\u73b0\u5728\u5df2\u7ecf\u6e05\u695a\u4ec0\u4e48\u662f\u8bcd\u888b\u4e86\u3002\u4f46\u662f\u6211\u4eec\u8fd8\u7f3a\u5c11\u4e00\u4e9b\u7279\u6b8a\u5b57\u7b26\u3002\u6709\u65f6\uff0c\u8fd9\u4e9b\u7279\u6b8a\u5b57\u7b26\u4e5f\u5f88\u6709\u7528\u3002\u4f8b\u5982\uff0c\"? \"\u5728\u5927\u591a\u6570\u53e5\u5b50\u4e2d\u8868\u793a\u7591\u95ee\u53e5\u3002\u8ba9\u6211\u4eec\u628a scikit-learn \u7684 word_tokenize \u6574\u5408\u5230 CountVectorizer \u4e2d\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002 from sklearn.feature_extraction.text import CountVectorizer from nltk.tokenize import word_tokenize corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] ctv = CountVectorizer ( tokenizer = word_tokenize , token_pattern = None ) ctv . fit ( corpus ) corpus_transformed = ctv . transform ( corpus ) print ( ctv . vocabulary_ ) \u8fd9\u6837\uff0c\u6211\u4eec\u7684\u8bcd\u888b\u5c31\u53d8\u6210\u4e86\uff1a { 'hello' : 14 , ',' : 2 , 'how' : 16 , 'are' : 7 , 'you' : 27 , '?' : 4 , 'im' : 18 , 'getting' : 13 , 'bored' : 9 , 'at' : 8 , 'home' : 15 , '.' : 3 , 'and' : 6 , 'what' : 24 , 'do' : 12 , 'think' : 22 , 'did' : 11 , 'know' : 19 , 'about' : 5 , 'counts' : 10 , 'let' : 20 , \"'s\" : 1 , 'see' : 21 , 'if' : 17 , 'this' : 23 , 'works' : 25 , '!' : 0 , 'yes' : 26 } \u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u5229\u7528 IMDB \u6570\u636e\u96c6\u4e2d\u7684\u6240\u6709\u53e5\u5b50\u521b\u5efa\u4e00\u4e2a\u7a00\u758f\u77e9\u9635\uff0c\u5e76\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u3002\u8be5\u6570\u636e\u96c6\u4e2d\u6b63\u8d1f\u6837\u672c\u7684\u6bd4\u4f8b\u4e3a 1:1\uff0c\u56e0\u6b64\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u51c6\u786e\u7387\u4f5c\u4e3a\u8861\u91cf\u6807\u51c6\u3002\u6211\u4eec\u5c06\u4f7f\u7528 StratifiedKFold \u5e76\u521b\u5efa\u4e00\u4e2a\u811a\u672c\u6765\u8bad\u7ec35\u4e2a\u6298\u53e0\u3002\u4f60\u4f1a\u95ee\u4f7f\u7528\u54ea\u4e2a\u6a21\u578b\uff1f\u5bf9\u4e8e\u9ad8\u7ef4\u7a00\u758f\u6570\u636e\uff0c\u54ea\u4e2a\u6a21\u578b\u6700\u5feb\uff1f\u903b\u8f91\u56de\u5f52\u3002\u6211\u4eec\u5c06\u9996\u5148\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u6765\u5904\u7406\u8fd9\u4e2a\u6570\u636e\u96c6\uff0c\u5e76\u521b\u5efa\u7b2c\u4e00\u4e2a\u57fa\u51c6\u6a21\u578b\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import CountVectorizer if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df [ \"kfold\" ] = - 1 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f for fold_ in range ( 5 ): train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) count_vec = CountVectorizer ( tokenizer = word_tokenize , token_pattern = None ) count_vec . fit ( train_df . review ) xtrain = count_vec . transform ( train_df . review ) xtest = count_vec . transform ( test_df . review ) model = linear_model . LogisticRegression () model . fit ( xtrain , train_df . sentiment ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( test_df . sentiment , preds ) print ( f \"Fold: { fold_ } \" ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u8fd9\u6bb5\u4ee3\u7801\u7684\u8fd0\u884c\u9700\u8981\u4e00\u5b9a\u7684\u65f6\u95f4\uff0c\u4f46\u53ef\u4ee5\u5f97\u5230\u4ee5\u4e0b\u8f93\u51fa\u7ed3\u679c\uff1a Fold : 0 Accuracy = 0.8903 Fold : 1 Accuracy = 0.897 Fold : 2 Accuracy = 0.891 Fold : 3 Accuracy = 0.8914 Fold : 4 Accuracy = 0.8931 \u54c7\uff0c\u51c6\u786e\u7387\u5df2\u7ecf\u8fbe\u5230 89%\uff0c\u800c\u6211\u4eec\u6240\u505a\u7684\u53ea\u662f\u4f7f\u7528\u8bcd\u888b\u548c\u903b\u8f91\u56de\u5f52\uff01\u8fd9\u771f\u662f\u592a\u68d2\u4e86\uff01\u4e0d\u8fc7\uff0c\u8fd9\u4e2a\u6a21\u578b\u7684\u8bad\u7ec3\u82b1\u8d39\u4e86\u5f88\u591a\u65f6\u95f4\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u80fd\u5426\u901a\u8fc7\u4f7f\u7528\u6734\u7d20\u8d1d\u53f6\u65af\u5206\u7c7b\u5668\u6765\u7f29\u77ed\u8bad\u7ec3\u65f6\u95f4\u3002\u6734\u7d20\u8d1d\u53f6\u65af\u5206\u7c7b\u5668\u5728 NLP \u4efb\u52a1\u4e2d\u76f8\u5f53\u6d41\u884c\uff0c\u56e0\u4e3a\u7a00\u758f\u77e9\u9635\u975e\u5e38\u5e9e\u5927\uff0c\u800c\u6734\u7d20\u8d1d\u53f6\u65af\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6a21\u578b\u3002\u8981\u4f7f\u7528\u8fd9\u4e2a\u6a21\u578b\uff0c\u9700\u8981\u66f4\u6539\u4e00\u4e2a\u5bfc\u5165\u548c\u6a21\u578b\u7684\u884c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e2a\u6a21\u578b\u7684\u6027\u80fd\u5982\u4f55\u3002\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 MultinomialNB\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import naive_bayes from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import CountVectorizer model = naive_bayes . MultinomialNB () model . fit ( xtrain , train_df . sentiment ) \u5f97\u5230\u5982\u4e0b\u7ed3\u679c\uff1a Fold : 0 Accuracy = 0.8444 Fold : 1 Accuracy = 0.8499 Fold : 2 Accuracy = 0.8422 Fold : 3 Accuracy = 0.8443 Fold : 4 Accuracy = 0.8455 \u6211\u4eec\u770b\u5230\u8fd9\u4e2a\u5206\u6570\u5f88\u4f4e\u3002\u4f46\u6734\u7d20\u8d1d\u53f6\u65af\u6a21\u578b\u7684\u901f\u5ea6\u975e\u5e38\u5feb\u3002 NLP \u4e2d\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f TF-IDF\uff0c\u5982\u4eca\u5927\u591a\u6570\u4eba\u90fd\u503e\u5411\u4e8e\u5ffd\u7565\u6216\u4e0d\u5c51\u4e8e\u4e86\u89e3\u8fd9\u79cd\u65b9\u6cd5\u3002TF \u662f\u672f\u8bed\u9891\u7387\uff0cIDF \u662f\u53cd\u5411\u6587\u6863\u9891\u7387\u3002\u4ece\u8fd9\u4e9b\u672f\u8bed\u6765\u770b\uff0c\u8fd9\u4f3c\u4e4e\u6709\u4e9b\u56f0\u96be\uff0c\u4f46\u901a\u8fc7 TF \u548c IDF \u7684\u8ba1\u7b97\u516c\u5f0f\uff0c\u4e8b\u60c5\u5c31\u4f1a\u53d8\u5f97\u5f88\u660e\u663e\u3002 $$ TF(t) = \\frac{Number\\ of\\ times\\ a\\ term\\ t\\ appears\\ in\\ a\\ document}{Total\\ number\\ of\\ terms\\ in \\ the\\ document} $$ \\[ IDF(t) = LOG\\left(\\frac{Total\\ number\\ of\\ documents}{Number\\ of\\ documents with\\ term\\ t\\ in\\ it}\\right) \\] \u672f\u8bed t \u7684 TF-IDF \u5b9a\u4e49\u4e3a\uff1a $$ TF-IDF(t) = TF(t) \\times IDF(t) $$ \u4e0e scikit-learn \u4e2d\u7684 CountVectorizer \u7c7b\u4f3c\uff0c\u6211\u4eec\u4e5f\u6709 TfidfVectorizer\u3002\u8ba9\u6211\u4eec\u8bd5\u7740\u50cf\u4f7f\u7528 CountVectorizer \u4e00\u6837\u4f7f\u7528\u5b83\u3002 from sklearn.feature_extraction.text import TfidfVectorizer from nltk.tokenize import word_tokenize corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] tfv = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfv . fit ( corpus ) corpus_transformed = tfv . transform ( corpus ) print ( corpus_transformed ) \u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a ( 0 , 27 ) 0.2965698850220162 ( 0 , 16 ) 0.4428321995085722 ( 0 , 14 ) 0.4428321995085722 ( 0 , 7 ) 0.4428321995085722 ( 0 , 4 ) 0.35727423026525224 ( 0 , 2 ) 0.4428321995085722 ( 1 , 27 ) 0.35299699146792735 ( 1 , 24 ) 0.2635440111190765 ( 1 , 22 ) 0.2635440111190765 ( 1 , 18 ) 0.2635440111190765 ( 1 , 15 ) 0.2635440111190765 ( 1 , 13 ) 0.2635440111190765 ( 1 , 12 ) 0.2635440111190765 ( 1 , 9 ) 0.2635440111190765 ( 1 , 8 ) 0.2635440111190765 ( 1 , 6 ) 0.2635440111190765 ( 1 , 4 ) 0.42525129752567803 ( 1 , 3 ) 0.2635440111190765 ( 2 , 27 ) 0.31752680284846835 ( 2 , 19 ) 0.4741246485558491 ( 2 , 11 ) 0.4741246485558491 ( 2 , 10 ) 0.4741246485558491 ( 2 , 5 ) 0.4741246485558491 ( 3 , 25 ) 0.38775666010579296 ( 3 , 23 ) 0.38775666010579296 ( 3 , 21 ) 0.38775666010579296 ( 3 , 20 ) 0.38775666010579296 ( 3 , 17 ) 0.38775666010579296 ( 3 , 1 ) 0.38775666010579296 ( 3 , 0 ) 0.3128396318588854 ( 4 , 26 ) 0.2959842226518677 ( 4 , 0 ) 0.9551928286692534 \u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u6b21\u6211\u4eec\u5f97\u5230\u7684\u4e0d\u662f\u6574\u6570\u503c\uff0c\u800c\u662f\u6d6e\u70b9\u6570\u3002 \u7528 TfidfVectorizer \u4ee3\u66ff CountVectorizer \u4e5f\u662f\u5c0f\u83dc\u4e00\u789f\u3002Scikit-learn \u8fd8\u63d0\u4f9b\u4e86 TfidfTransformer\u3002\u5982\u679c\u4f60\u4f7f\u7528\u7684\u662f\u8ba1\u6570\u503c\uff0c\u53ef\u4ee5\u4f7f\u7528 TfidfTransformer \u5e76\u83b7\u5f97\u4e0e TfidfVectorizer \u76f8\u540c\u7684\u6548\u679c\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import TfidfVectorizer for fold_ in range ( 5 ): train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) tfidf_vec = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfidf_vec . fit ( train_df . review ) xtrain = tfidf_vec . transform ( train_df . review ) xtest = tfidf_vec . transform ( test_df . review ) model = linear_model . LogisticRegression () model . fit ( xtrain , train_df . sentiment ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( test_df . sentiment , preds ) print ( f \"Fold: { fold_ } \" ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u6211\u4eec\u53ef\u4ee5\u770b\u770b TF-IDF \u5728\u903b\u8f91\u56de\u5f52\u6a21\u578b\u4e0a\u7684\u8868\u73b0\u5982\u4f55\u3002 Fold : 0 Accuracy = 0.8976 Fold : 1 Accuracy = 0.8998 Fold : 2 Accuracy = 0.8948 Fold : 3 Accuracy = 0.8912 Fold : 4 Accuracy = 0.8995 \u6211\u4eec\u770b\u5230\uff0c\u8fd9\u4e9b\u5206\u6570\u90fd\u6bd4 CountVectorizer \u9ad8\u4e00\u4e9b\uff0c\u56e0\u6b64\u5b83\u6210\u4e3a\u4e86\u6211\u4eec\u60f3\u8981\u51fb\u8d25\u7684\u65b0\u57fa\u51c6\u3002 NLP \u4e2d\u53e6\u4e00\u4e2a\u6709\u8da3\u7684\u6982\u5ff5\u662f N-gram\u3002N-grams \u662f\u6309\u987a\u5e8f\u6392\u5217\u7684\u5355\u8bcd\u7ec4\u5408\u3002N-grams \u5f88\u5bb9\u6613\u521b\u5efa\u3002\u60a8\u53ea\u9700\u6ce8\u610f\u987a\u5e8f\u5373\u53ef\u3002\u4e3a\u4e86\u8ba9\u4e8b\u60c5\u53d8\u5f97\u66f4\u7b80\u5355\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 NLTK \u7684 N-gram \u5b9e\u73b0\u3002 from nltk import ngrams from nltk.tokenize import word_tokenize N = 3 sentence = \"hi, how are you?\" tokenized_sentence = word_tokenize ( sentence ) n_grams = list ( ngrams ( tokenized_sentence , N )) print ( n_grams ) \u7531\u6b64\u5f97\u5230\uff1a [( 'hi' , ',' , 'how' ), ( ',' , 'how' , 'are' ), ( 'how' , 'are' , 'you' ), ( 'are' , 'you' , '?' )] \u540c\u6837\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u521b\u5efa 2-gram \u6216 4-gram \u7b49\u3002\u73b0\u5728\uff0c\u8fd9\u4e9b n-gram \u5c06\u6210\u4e3a\u6211\u4eec\u8bcd\u6c47\u8868\u7684\u4e00\u90e8\u5206\uff0c\u5f53\u6211\u4eec\u8ba1\u7b97\u8ba1\u6570\u6216 tf-idf \u65f6\uff0c\u6211\u4eec\u4f1a\u5c06\u4e00\u4e2a n-gram \u89c6\u4e3a\u4e00\u4e2a\u5168\u65b0\u7684\u6807\u8bb0\u3002\u56e0\u6b64\uff0c\u5728\u67d0\u79cd\u7a0b\u5ea6\u4e0a\uff0c\u6211\u4eec\u662f\u5728\u7ed3\u5408\u4e0a\u4e0b\u6587\u3002scikit-learn \u7684 CountVectorizer \u548c TfidfVectorizer \u5b9e\u73b0\u90fd\u901a\u8fc7 ngram_range \u53c2\u6570\u63d0\u4f9b n-gram\uff0c\u8be5\u53c2\u6570\u6709\u6700\u5c0f\u548c\u6700\u5927\u9650\u5236\u3002\u9ed8\u8ba4\u60c5\u51b5\u4e0b\uff0c\u8be5\u53c2\u6570\u4e3a\uff081, 1\uff09\u3002\u5f53\u6211\u4eec\u5c06\u5176\u6539\u4e3a (1, 3) \u65f6\uff0c\u6211\u4eec\u5c06\u770b\u5230\u5355\u5b57\u5143\u3001\u53cc\u5b57\u5143\u548c\u4e09\u5b57\u5143\u3002\u4ee3\u7801\u6539\u52a8\u5f88\u5c0f\u3002 \u7531\u4e8e\u5230\u76ee\u524d\u4e3a\u6b62\u6211\u4eec\u4f7f\u7528 tf-idf \u5f97\u5230\u4e86\u6700\u597d\u7684\u7ed3\u679c\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5305\u542b n-grams \u76f4\u81f3 trigrams \u662f\u5426\u80fd\u6539\u8fdb\u6a21\u578b\u3002\u552f\u4e00\u9700\u8981\u4fee\u6539\u7684\u662f TfidfVectorizer \u7684\u521d\u59cb\u5316\u3002 tfidf_vec = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None , ngram_range = ( 1 , 3 ) ) \u8ba9\u6211\u4eec\u770b\u770b\u662f\u5426\u4f1a\u6709\u6539\u8fdb\u3002 Fold : 0 Accuracy = 0.8931 Fold : 1 Accuracy = 0.8941 Fold : 2 Accuracy = 0.897 Fold : 3 Accuracy = 0.8922 Fold : 4 Accuracy = 0.8847 \u770b\u8d77\u6765\u8fd8\u884c\uff0c\u4f46\u6211\u4eec\u770b\u4e0d\u5230\u4efb\u4f55\u6539\u8fdb\u3002 \u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u591a\u4f7f\u7528 bigrams \u6765\u83b7\u5f97\u6539\u8fdb\u3002 \u6211\u4e0d\u4f1a\u5728\u8fd9\u91cc\u5c55\u793a\u8fd9\u4e00\u90e8\u5206\u3002\u4e5f\u8bb8\u4f60\u53ef\u4ee5\u81ea\u5df1\u8bd5\u7740\u505a\u3002 NLP \u7684\u57fa\u7840\u77e5\u8bc6\u8fd8\u6709\u5f88\u591a\u3002\u4f60\u5fc5\u987b\u77e5\u9053\u7684\u4e00\u4e2a\u672f\u8bed\u662f\u8bcd\u5e72\u63d0\u53d6\uff08strmming\uff09\u3002\u53e6\u4e00\u4e2a\u662f\u8bcd\u5f62\u8fd8\u539f\uff08lemmatization\uff09\u3002 \u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f \u53ef\u4ee5\u5c06\u4e00\u4e2a\u8bcd\u51cf\u5c11\u5230\u6700\u5c0f\u5f62\u5f0f\u3002\u5728\u8bcd\u5e72\u63d0\u53d6\u7684\u60c5\u51b5\u4e0b\uff0c\u5904\u7406\u540e\u7684\u5355\u8bcd\u79f0\u4e3a\u8bcd\u5e72\u5355\u8bcd\uff0c\u800c\u5728\u8bcd\u5f62\u8fd8\u539f\u60c5\u51b5\u4e0b\uff0c\u5904\u7406\u540e\u7684\u5355\u8bcd\u79f0\u4e3a\u8bcd\u5f62\u3002\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u8bcd\u5f62\u8fd8\u539f\u6bd4\u8bcd\u5e72\u63d0\u53d6\u66f4\u6fc0\u8fdb\uff0c\u800c\u8bcd\u5e72\u63d0\u53d6\u66f4\u6d41\u884c\u548c\u5e7f\u6cdb\u3002\u8bcd\u5e72\u548c\u8bcd\u5f62\u90fd\u6765\u81ea\u8bed\u8a00\u5b66\u3002\u5982\u679c\u4f60\u6253\u7b97\u4e3a\u67d0\u79cd\u8bed\u8a00\u5236\u4f5c\u8bcd\u5e72\u6216\u8bcd\u578b\uff0c\u9700\u8981\u5bf9\u8be5\u8bed\u8a00\u6709\u6df1\u5165\u7684\u4e86\u89e3\u3002\u5982\u679c\u8981\u8fc7\u591a\u5730\u4ecb\u7ecd\u8fd9\u4e9b\u77e5\u8bc6\uff0c\u5c31\u610f\u5473\u7740\u8981\u5728\u672c\u4e66\u4e2d\u589e\u52a0\u4e00\u7ae0\u3002\u4f7f\u7528 NLTK \u8f6f\u4ef6\u5305\u53ef\u4ee5\u8f7b\u677e\u5b8c\u6210\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e24\u79cd\u65b9\u6cd5\u7684\u4e00\u4e9b\u793a\u4f8b\u3002\u6709\u8bb8\u591a\u4e0d\u540c\u7c7b\u578b\u7684\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u5668\u3002\u6211\u5c06\u7528\u6700\u5e38\u89c1\u7684 Snowball Stemmer \u548c WordNet Lemmatizer \u6765\u4e3e\u4f8b\u8bf4\u660e\u3002 from nltk.stem import WordNetLemmatizer from nltk.stem.snowball import SnowballStemmer lemmatizer = WordNetLemmatizer () stemmer = SnowballStemmer ( \"english\" ) words = [ \"fishing\" , \"fishes\" , \"fished\" ] for word in words : print ( f \"word= { word } \" ) print ( f \"stemmed_word= { stemmer . stem ( word ) } \" ) print ( f \"lemma= { lemmatizer . lemmatize ( word ) } \" ) print ( \"\" ) \u8fd9\u5c06\u6253\u5370\uff1a word = fishing stemmed_word = fish lemma = fishing word = fishes stemmed_word = fish lemma = fish word = fished stemmed_word = fish lemma = fished \u6b63\u5982\u60a8\u6240\u770b\u5230\u7684\uff0c\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u662f\u622a\u7136\u4e0d\u540c\u7684\u3002\u5f53\u6211\u4eec\u8fdb\u884c\u8bcd\u5e72\u63d0\u53d6\u65f6\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f\u4e00\u4e2a\u8bcd\u7684\u6700\u5c0f\u5f62\u5f0f\uff0c\u5b83\u53ef\u80fd\u662f\u4e5f\u53ef\u80fd\u4e0d\u662f\u8be5\u8bcd\u6240\u5c5e\u8bed\u8a00\u8bcd\u5178\u4e2d\u7684\u4e00\u4e2a\u8bcd\u3002\u4f46\u662f\uff0c\u5728\u8bcd\u5f62\u8fd8\u539f\u60c5\u51b5\u4e0b\uff0c\u8fd9\u5c06\u662f\u4e00\u4e2a\u8bcd\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u81ea\u5df1\u5c1d\u8bd5\u6dfb\u52a0\u8bcd\u5e72\u548c\u8bcd\u7d20\u5316\uff0c\u770b\u770b\u662f\u5426\u80fd\u6539\u5584\u7ed3\u679c\u3002 \u60a8\u8fd8\u5e94\u8be5\u4e86\u89e3\u7684\u4e00\u4e2a\u4e3b\u9898\u662f\u4e3b\u9898\u63d0\u53d6\u3002 \u4e3b\u9898\u63d0\u53d6 \u53ef\u4ee5\u4f7f\u7528\u975e\u8d1f\u77e9\u9635\u56e0\u5f0f\u5206\u89e3\uff08NMF\uff09\u6216\u6f5c\u5728\u8bed\u4e49\u5206\u6790\uff08LSA\uff09\u6765\u5b8c\u6210\uff0c\u540e\u8005\u4e5f\u88ab\u79f0\u4e3a\u5947\u5f02\u503c\u5206\u89e3\u6216 SVD\u3002\u8fd9\u4e9b\u5206\u89e3\u6280\u672f\u53ef\u5c06\u6570\u636e\u7b80\u5316\u4e3a\u7ed9\u5b9a\u6570\u91cf\u7684\u6210\u5206\u3002 \u60a8\u53ef\u4ee5\u5728\u4ece CountVectorizer \u6216 TfidfVectorizer \u4e2d\u83b7\u5f97\u7684\u7a00\u758f\u77e9\u9635\u4e0a\u5e94\u7528\u5176\u4e2d\u4efb\u4f55\u4e00\u79cd\u6280\u672f\u3002 \u8ba9\u6211\u4eec\u628a\u5b83\u5e94\u7528\u5230\u4e4b\u524d\u4f7f\u7528\u8fc7\u7684 TfidfVetorizer \u4e0a\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import decomposition from sklearn.feature_extraction.text import TfidfVectorizer corpus = pd . read_csv ( \"../input/imdb.csv\" , nrows = 10000 ) corpus = corpus . review . values tfv = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfv . fit ( corpus ) corpus_transformed = tfv . transform ( corpus ) svd = decomposition . TruncatedSVD ( n_components = 10 ) corpus_svd = svd . fit ( corpus_transformed ) sample_index = 0 feature_scores = dict ( zip ( tfv . get_feature_names (), corpus_svd . components_ [ sample_index ] ) ) N = 5 print ( sorted ( feature_scores , key = feature_scores . get , reverse = True )[: N ]) \u60a8\u53ef\u4ee5\u4f7f\u7528\u5faa\u73af\u6765\u8fd0\u884c\u591a\u4e2a\u6837\u672c\u3002 N = 5 for sample_index in range ( 5 ): feature_scores = dict ( zip ( tfv . get_feature_names (), corpus_svd . components_ [ sample_index ] ) ) print ( sorted ( feature_scores , key = feature_scores . get , reverse = True )[: N ] ) \u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a [ 'the' , ',' , '.' , 'a' , 'and' ] [ 'br' , '<' , '>' , '/' , '-' ] [ 'i' , 'movie' , '!' , 'it' , 'was' ] [ ',' , '!' , \"''\" , '``' , 'you' ] [ '!' , 'the' , '...' , \"''\" , '``' ] \u4f60\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u6839\u672c\u8bf4\u4e0d\u901a\u3002\u600e\u4e48\u529e\u5462\uff1f\u8ba9\u6211\u4eec\u8bd5\u7740\u6e05\u7406\u4e00\u4e0b\uff0c\u770b\u770b\u662f\u5426\u6709\u610f\u4e49\u3002\u8981\u6e05\u7406\u4efb\u4f55\u6587\u672c\u6570\u636e\uff0c\u5c24\u5176\u662f pandas \u6570\u636e\u5e27\u4e2d\u7684\u6587\u672c\u6570\u636e\uff0c\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u51fd\u6570\u3002 import re import string def clean_text ( s ): s = s . split () s = \" \" . join ( s ) s = re . sub ( f '[ { re . escape ( string . punctuation ) } ]' , '' , s ) return s \u8be5\u51fd\u6570\u4f1a\u5c06 \"hi, how are you????\" \u8fd9\u6837\u7684\u5b57\u7b26\u4e32\u8f6c\u6362\u4e3a \"hi how are you\"\u3002\u8ba9\u6211\u4eec\u628a\u8fd9\u4e2a\u51fd\u6570\u5e94\u7528\u5230\u65e7\u7684 SVD \u4ee3\u7801\u4e2d\uff0c\u770b\u770b\u5b83\u662f\u5426\u80fd\u7ed9\u63d0\u53d6\u7684\u4e3b\u9898\u5e26\u6765\u63d0\u5347\u3002\u4f7f\u7528 pandas\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528 apply \u51fd\u6570\u5c06\u6e05\u7406\u4ee3\u7801 \"\u5e94\u7528 \"\u5230\u4efb\u610f\u7ed9\u5b9a\u7684\u5217\u4e2d\u3002 import pandas as pd corpus = pd . read_csv ( \"../input/imdb.csv\" , nrows = 10000 ) corpus . loc [:, \"review\" ] = corpus . review . apply ( clean_text ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u53ea\u5728\u4e3b SVD \u811a\u672c\u4e2d\u6dfb\u52a0\u4e86\u4e00\u884c\u4ee3\u7801\uff0c\u8fd9\u5c31\u662f\u4f7f\u7528\u51fd\u6570\u548c pandas \u5e94\u7528\u7684\u597d\u5904\u3002\u8fd9\u6b21\u751f\u6210\u7684\u4e3b\u9898\u5982\u4e0b\u3002 [ 'the' , 'a' , 'and' , 'of' , 'to' ] [ 'i' , 'movie' , 'it' , 'was' , 'this' ] [ 'the' , 'was' , 'i' , 'were' , 'of' ] [ 'her' , 'was' , 'she' , 'i' , 'he' ] [ 'br' , 'to' , 'they' , 'he' , 'show' ] \u547c\uff01\u81f3\u5c11\u8fd9\u6bd4\u6211\u4eec\u4e4b\u524d\u597d\u591a\u4e86\u3002\u4f46\u4f60\u77e5\u9053\u5417\uff1f\u4f60\u53ef\u4ee5\u901a\u8fc7\u5728\u6e05\u7406\u529f\u80fd\u4e2d\u5220\u9664\u505c\u6b62\u8bcd\uff08stopwords\uff09\u6765\u4f7f\u5b83\u53d8\u5f97\u66f4\u597d\u3002\u4ec0\u4e48\u662fstopwords\uff1f\u5b83\u4eec\u662f\u5b58\u5728\u4e8e\u6bcf\u79cd\u8bed\u8a00\u4e2d\u7684\u9ad8\u9891\u8bcd\u3002\u4f8b\u5982\uff0c\u5728\u82f1\u8bed\u4e2d\uff0c\u8fd9\u4e9b\u8bcd\u5305\u62ec \"a\"\u3001\"an\"\u3001\"the\"\u3001\"for \"\u7b49\u3002\u5220\u9664\u505c\u6b62\u8bcd\u5e76\u975e\u603b\u662f\u660e\u667a\u7684\u9009\u62e9\uff0c\u8fd9\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u4e1a\u52a1\u95ee\u9898\u3002\u50cf \"I need a new dog\"\u8fd9\u6837\u7684\u53e5\u5b50\uff0c\u53bb\u6389\u505c\u6b62\u8bcd\u540e\u4f1a\u53d8\u6210 \"need new dog\"\uff0c\u6b64\u65f6\u6211\u4eec\u4e0d\u77e5\u9053\u8c01\u9700\u8981new dog\u3002 \u5982\u679c\u6211\u4eec\u603b\u662f\u5220\u9664\u505c\u6b62\u8bcd\uff0c\u5c31\u4f1a\u4e22\u5931\u5f88\u591a\u4e0a\u4e0b\u6587\u4fe1\u606f\u3002\u4f60\u53ef\u4ee5\u5728 NLTK \u4e2d\u627e\u5230\u8bb8\u591a\u8bed\u8a00\u7684\u505c\u6b62\u8bcd\uff0c\u5982\u679c\u6ca1\u6709\uff0c\u4f60\u4e5f\u53ef\u4ee5\u5728\u81ea\u5df1\u559c\u6b22\u7684\u641c\u7d22\u5f15\u64ce\u4e0a\u5feb\u901f\u641c\u7d22\u4e00\u4e0b\u3002 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u8f6c\u5230\u5927\u591a\u6570\u4eba\u90fd\u559c\u6b22\u4f7f\u7528\u7684\u65b9\u6cd5\uff1a\u6df1\u5ea6\u5b66\u4e60\u3002\u4f46\u9996\u5148\uff0c\u6211\u4eec\u5fc5\u987b\u77e5\u9053\u4ec0\u4e48\u662f\u8bcd\u5d4c\u5165\uff08embedings for words\uff09\u3002\u4f60\u5df2\u7ecf\u770b\u5230\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u5c06\u6807\u8bb0\u8f6c\u6362\u6210\u4e86\u6570\u5b57\u3002\u56e0\u6b64\uff0c\u5982\u679c\u67d0\u4e2a\u8bed\u6599\u5e93\u4e2d\u6709 N \u4e2a\u552f\u4e00\u7684\u8bcd\u5757\uff0c\u5b83\u4eec\u53ef\u4ee5\u7528 0 \u5230 N-1 \u4e4b\u95f4\u7684\u6574\u6570\u6765\u8868\u793a\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5c06\u7528\u5411\u91cf\u6765\u8868\u793a\u8fd9\u4e9b\u6574\u6570\u8bcd\u5757\u3002\u8fd9\u79cd\u5c06\u5355\u8bcd\u8868\u793a\u6210\u5411\u91cf\u7684\u65b9\u6cd5\u88ab\u79f0\u4e3a\u5355\u8bcd\u5d4c\u5165\u6216\u5355\u8bcd\u5411\u91cf\u3002\u8c37\u6b4c\u7684 Word2Vec \u662f\u5c06\u5355\u8bcd\u8f6c\u6362\u4e3a\u5411\u91cf\u7684\u6700\u53e4\u8001\u65b9\u6cd5\u4e4b\u4e00\u3002\u6b64\u5916\uff0c\u8fd8\u6709 Facebook \u7684 FastText \u548c\u65af\u5766\u798f\u5927\u5b66\u7684 GloVe\uff08\u7528\u4e8e\u5355\u8bcd\u8868\u793a\u7684\u5168\u5c40\u5411\u91cf\uff09\u3002\u8fd9\u4e9b\u65b9\u6cd5\u5f7c\u6b64\u5927\u76f8\u5f84\u5ead\u3002 \u5176\u57fa\u672c\u601d\u60f3\u662f\u5efa\u7acb\u4e00\u4e2a\u6d45\u5c42\u7f51\u7edc\uff0c\u901a\u8fc7\u91cd\u6784\u8f93\u5165\u53e5\u5b50\u6765\u5b66\u4e60\u5355\u8bcd\u7684\u5d4c\u5165\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u5468\u56f4\u7684\u6240\u6709\u5355\u8bcd\u6765\u8bad\u7ec3\u7f51\u7edc\u9884\u6d4b\u4e00\u4e2a\u7f3a\u5931\u7684\u5355\u8bcd\uff0c\u5728\u6b64\u8fc7\u7a0b\u4e2d\uff0c\u7f51\u7edc\u5c06\u5b66\u4e60\u5e76\u66f4\u65b0\u6240\u6709\u76f8\u5173\u5355\u8bcd\u7684\u5d4c\u5165\u3002\u8fd9\u79cd\u65b9\u6cd5\u4e5f\u88ab\u79f0\u4e3a\u8fde\u7eed\u8bcd\u888b\u6216 CBoW \u6a21\u578b\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u4e00\u4e2a\u5355\u8bcd\u6765\u9884\u6d4b\u4e0a\u4e0b\u6587\u4e2d\u7684\u5355\u8bcd\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684\u8df3\u683c\u6a21\u578b\u3002Word2Vec \u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e24\u79cd\u65b9\u6cd5\u5b66\u4e60\u5d4c\u5165\u3002 FastText \u53ef\u4ee5\u5b66\u4e60\u5b57\u7b26 n-gram \u7684\u5d4c\u5165\u3002\u548c\u5355\u8bcd n-gram \u4e00\u6837\uff0c\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7684\u662f\u5b57\u7b26\uff0c\u5219\u79f0\u4e3a\u5b57\u7b26 n-gram\uff0c\u6700\u540e\uff0cGloVe \u901a\u8fc7\u5171\u73b0\u77e9\u9635\u6765\u5b66\u4e60\u8fd9\u4e9b\u5d4c\u5165\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u8bf4\uff0c\u6240\u6709\u8fd9\u4e9b\u4e0d\u540c\u7c7b\u578b\u7684\u5d4c\u5165\u6700\u7ec8\u90fd\u4f1a\u8fd4\u56de\u4e00\u4e2a\u5b57\u5178\uff0c\u5176\u4e2d\u952e\u662f\u8bed\u6599\u5e93\uff08\u4f8b\u5982\u82f1\u8bed\u7ef4\u57fa\u767e\u79d1\uff09\u4e2d\u7684\u5355\u8bcd\uff0c\u503c\u662f\u5927\u5c0f\u4e3a N\uff08\u901a\u5e38\u4e3a 300\uff09\u7684\u5411\u91cf\u3002 \u56fe 1\uff1a\u53ef\u89c6\u5316\u4e8c\u7ef4\u5355\u8bcd\u5d4c\u5165\u3002 \u56fe 1 \u663e\u793a\u4e86\u4e8c\u7ef4\u5355\u8bcd\u5d4c\u5165\u7684\u53ef\u89c6\u5316\u6548\u679c\u3002\u5047\u8bbe\u6211\u4eec\u4ee5\u67d0\u79cd\u65b9\u5f0f\u5b8c\u6210\u4e86\u8bcd\u8bed\u7684\u4e8c\u7ef4\u8868\u793a\u3002\u56fe 1 \u663e\u793a\uff0c\u5982\u679c\u4eceBerlin\uff08\u5fb7\u56fd\u9996\u90fd\uff09\u7684\u5411\u91cf\u4e2d\u51cf\u53bb\u5fb7\u56fd\uff08Germany\uff09\u7684\u5411\u91cf\uff0c\u518d\u52a0\u4e0a\u6cd5\u56fd\uff08france\uff09\u7684\u5411\u91cf\uff0c\u5c31\u4f1a\u5f97\u5230\u4e00\u4e2a\u63a5\u8fd1Paris\uff08\u6cd5\u56fd\u9996\u90fd\uff09\u7684\u5411\u91cf\u3002\u7531\u6b64\u53ef\u89c1\uff0c\u5d4c\u5165\u5f0f\u4e5f\u80fd\u8fdb\u884c\u7c7b\u6bd4\u3002 \u8fd9\u5e76\u4e0d\u603b\u662f\u6b63\u786e\u7684\uff0c\u4f46\u8fd9\u6837\u7684\u4f8b\u5b50\u6709\u52a9\u4e8e\u7406\u89e3\u5355\u8bcd\u5d4c\u5165\u7684\u4f5c\u7528\u3002\u50cf \"\u55e8\uff0c\u4f60\u597d\u5417 \"\u8fd9\u6837\u7684\u53e5\u5b50\u53ef\u4ee5\u7528\u4e0b\u9762\u7684\u4e00\u5806\u5411\u91cf\u6765\u8868\u793a\u3002 hi \u2500> [vector (v1) of size 300] , \u2500> [vector (v2) of size 300] how \u2500> [vector (v3) of size 300] are \u2500> [vector (v4) of size 300] you \u2500> [vector (v5) of size 300] ? \u2500> [vector (v6) of size 300] \u4f7f\u7528\u8fd9\u4e9b\u4fe1\u606f\u6709\u591a\u79cd\u65b9\u6cd5\u3002\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u4e4b\u4e00\u5c31\u662f\u4f7f\u7528\u5d4c\u5165\u5411\u91cf\u3002\u5982\u4e0a\u4f8b\u6240\u793a\uff0c\u6bcf\u4e2a\u5355\u8bcd\u90fd\u6709\u4e00\u4e2a 1x300 \u7684\u5d4c\u5165\u5411\u91cf\u3002\u5229\u7528\u8fd9\u4e9b\u4fe1\u606f\uff0c\u6211\u4eec\u53ef\u4ee5\u8ba1\u7b97\u51fa\u6574\u4e2a\u53e5\u5b50\u7684\u5d4c\u5165\u3002\u8ba1\u7b97\u65b9\u6cd5\u6709\u591a\u79cd\u3002\u5176\u4e2d\u4e00\u79cd\u65b9\u6cd5\u5982\u4e0b\u6240\u793a\u3002\u5728\u8fd9\u4e2a\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u5c06\u7ed9\u5b9a\u53e5\u5b50\u4e2d\u7684\u6240\u6709\u5355\u8bcd\u5411\u91cf\u63d0\u53d6\u51fa\u6765\uff0c\u7136\u540e\u4ece\u6240\u6709\u6807\u8bb0\u8bcd\u7684\u5355\u8bcd\u5411\u91cf\u4e2d\u521b\u5efa\u4e00\u4e2a\u5f52\u4e00\u5316\u7684\u5355\u8bcd\u5411\u91cf\u3002\u8fd9\u6837\u5c31\u5f97\u5230\u4e86\u4e00\u4e2a\u53e5\u5b50\u5411\u91cf\u3002 import numpy as np def sentence_to_vec ( s , embedding_dict , stop_words , tokenizer ): words = str ( s ) . lower () words = tokenizer ( words ) words = [ w for w in words if not w in stop_words ] words = [ w for w in words if w . isalpha ()] M = [] for w in words : if w in embedding_dict : M . append ( embedding_dict [ w ]) if len ( M ) == 0 : return np . zeros ( 300 ) M = np . array ( M ) v = M . sum ( axis = 0 ) return v / np . sqrt (( v ** 2 ) . sum ()) \u6211\u4eec\u53ef\u4ee5\u7528\u8fd9\u79cd\u65b9\u6cd5\u5c06\u6240\u6709\u793a\u4f8b\u8f6c\u6362\u6210\u4e00\u4e2a\u5411\u91cf\u3002\u6211\u4eec\u80fd\u5426\u4f7f\u7528 fastText \u5411\u91cf\u6765\u6539\u8fdb\u4e4b\u524d\u7684\u7ed3\u679c\uff1f\u6bcf\u7bc7\u8bc4\u8bba\u90fd\u6709 300 \u4e2a\u7279\u5f81\u3002 import io import numpy as np import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import TfidfVectorizer def load_vectors ( fname ): fin = io . open ( fname , 'r' , encoding = 'utf-8' , newline = ' \\n ' , errors = 'ignore' ) n , d = map ( int , fin . readline () . split ()) data = {} for line in fin : tokens = line . rstrip () . split ( ' ' ) data [ tokens [ 0 ]] = list ( map ( float , tokens [ 1 :])) return data def sentence_to_vec ( s , embedding_dict , stop_words , tokenizer ): if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df = df . sample ( frac = 1 ) . reset_index ( drop = True ) print ( \"Loading embeddings\" ) embeddings = load_vectors ( \"../input/crawl-300d-2M.vec\" ) print ( \"Creating sentence vectors\" ) vectors = [] for review in df . review . values : vectors . append ( sentence_to_vec ( s = review , embedding_dict = embeddings , stop_words = [], tokenizer = word_tokenize ) ) vectors = np . array ( vectors ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for fold_ , ( t_ , v_ ) in enumerate ( kf . split ( X = vectors , y = y )): print ( f \"Training fold: { fold_ } \" ) xtrain = vectors [ t_ , :] ytrain = y [ t_ ] xtest = vectors [ v_ , :] ytest = y [ v_ ] model = linear_model . LogisticRegression () model . fit ( xtrain , ytrain ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( ytest , preds ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u8fd9\u5c06\u5f97\u5230\u5982\u4e0b\u7ed3\u679c\uff1a Loading embeddings Creating sentence vectors Training fold : 0 Accuracy = 0.8619 Training fold : 1 Accuracy = 0.8661 Training fold : 2 Accuracy = 0.8544 Training fold : 3 Accuracy = 0.8624 Training fold : 4 Accuracy = 0.8595 Wow\uff01\u771f\u662f\u51fa\u4e4e\u610f\u6599\u3002\u6211\u4eec\u6240\u505a\u7684\u4e00\u5207\u90fd\u662f\u4e3a\u4e86\u4f7f\u7528 FastText \u5d4c\u5165\u3002\u8bd5\u7740\u628a\u5d4c\u5165\u5f0f\u6362\u6210 GloVe\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002\u6211\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u5927\u5bb6\u3002 \u5f53\u6211\u4eec\u8c08\u8bba\u6587\u672c\u6570\u636e\u65f6\uff0c\u6211\u4eec\u5fc5\u987b\u7262\u8bb0\u4e00\u4ef6\u4e8b\u3002\u6587\u672c\u6570\u636e\u4e0e\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u975e\u5e38\u76f8\u4f3c\u3002\u5982\u56fe 2 \u6240\u793a\uff0c\u6211\u4eec\u8bc4\u8bba\u4e2d\u7684\u4efb\u4f55\u6837\u672c\u90fd\u662f\u5728\u4e0d\u540c\u65f6\u95f4\u6233\u4e0a\u6309\u9012\u589e\u987a\u5e8f\u6392\u5217\u7684\u6807\u8bb0\u5e8f\u5217\uff0c\u6bcf\u4e2a\u6807\u8bb0\u90fd\u53ef\u4ee5\u8868\u793a\u4e3a\u4e00\u4e2a\u5411\u91cf/\u5d4c\u5165\u3002 \u56fe 2\uff1a\u5c06\u6807\u8bb0\u8868\u793a\u4e3a\u5d4c\u5165\uff0c\u5e76\u5c06\u5176\u89c6\u4e3a\u65f6\u95f4\u5e8f\u5217 \u8fd9\u610f\u5473\u7740\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5e7f\u6cdb\u7528\u4e8e\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u7684\u6a21\u578b\uff0c\u4f8b\u5982\u957f\u77ed\u671f\u8bb0\u5fc6\uff08LSTM\uff09\u6216\u95e8\u63a7\u9012\u5f52\u5355\u5143\uff08GRU\uff09\uff0c\u751a\u81f3\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff08CNN\uff09\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u5728\u8be5\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u4e00\u4e2a\u7b80\u5355\u7684\u53cc\u5411 LSTM \u6a21\u578b\u3002 \u9996\u5148\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u9879\u76ee\u3002\u4f60\u53ef\u4ee5\u968f\u610f\u7ed9\u5b83\u547d\u540d\u3002\u7136\u540e\uff0c\u6211\u4eec\u7684\u7b2c\u4e00\u6b65\u5c06\u662f\u5206\u5272\u6570\u636e\u8fdb\u884c\u4ea4\u53c9\u9a8c\u8bc1\u3002 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df [ \"kfold\" ] = - 1 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f df . to_csv ( \"../input/imdb_folds.csv\" , index = False ) \u5c06\u6570\u636e\u96c6\u5212\u5206\u4e3a\u591a\u4e2a\u6298\u53e0\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5728 dataset.py \u4e2d\u521b\u5efa\u4e00\u4e2a\u7b80\u5355\u7684\u6570\u636e\u96c6\u7c7b\u3002\u6570\u636e\u96c6\u7c7b\u4f1a\u8fd4\u56de\u4e00\u4e2a\u8bad\u7ec3\u6216\u9a8c\u8bc1\u6570\u636e\u6837\u672c\u3002 import torch class IMDBDataset : def __init__ ( self , reviews , targets ): self . reviews = reviews self . target = targets def __len__ ( self ): return len ( self . reviews ) def __getitem__ ( self , item ): review = self . reviews [ item , :] target = self . target [ item ] return { \"review\" : torch . tensor ( review , dtype = torch . long ), \"target\" : torch . tensor ( target , dtype = torch . float ) } \u5b8c\u6210\u6570\u636e\u96c6\u5206\u7c7b\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u521b\u5efa lstm.py\uff0c\u5176\u4e2d\u5305\u542b\u6211\u4eec\u7684 LSTM \u6a21\u578b import torch import torch.nn as nn class LSTM ( nn . Module ): def __init__ ( self , embedding_matrix ): super ( LSTM , self ) . __init__ () num_words = embedding_matrix . shape [ 0 ] embed_dim = embedding_matrix . shape [ 1 ] self . embedding = nn . Embedding ( num_embeddings = num_words , embedding_dim = embed_dim ) self . embedding . weight = nn . Parameter ( torch . tensor ( embedding_matrix , dtype = torch . float32 ) ) self . embedding . weight . requires_grad = False self . lstm = nn . LSTM ( embed_dim , 128 , bidirectional = True , batch_first = True , ) self . out = nn . Linear ( 512 , 1 ) def forward ( self , x ): x = self . embedding ( x ) x , _ = self . lstm ( x ) avg_pool = torch . mean ( x , 1 ) max_pool , _ = torch . max ( x , 1 ) out = torch . cat (( avg_pool , max_pool ), 1 ) out = self . out ( out ) return out \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa engine.py\uff0c\u5176\u4e2d\u5305\u542b\u8bad\u7ec3\u548c\u8bc4\u4f30\u51fd\u6570\u3002 import torch import torch.nn as nn def train ( data_loader , model , optimizer , device ): model . train () for data in data_loader : reviews = data [ \"review\" ] targets = data [ \"target\" ] reviews = reviews . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) optimizer . zero_grad () predictions = model ( reviews ) loss = nn . BCEWithLogitsLoss ()( predictions , targets . view ( - 1 , 1 ) ) loss . backward () optimizer . step () def evaluate ( data_loader , model , device ): final_predictions = [] final_targets = [] model . eval () with torch . no_grad (): for data in data_loader : reviews = data [ \"review\" ] targets = data [ \"target\" ] reviews = reviews . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) predictions = model ( reviews ) predictions = predictions . cpu () . numpy () . tolist () targets = data [ \"target\" ] . cpu () . numpy () . tolist () final_predictions . extend ( predictions ) final_targets . extend ( targets ) return final_predictions , final_targets \u8fd9\u4e9b\u51fd\u6570\u5c06\u5728 train.py \u4e2d\u4e3a\u6211\u4eec\u63d0\u4f9b\u5e2e\u52a9\uff0c\u8be5\u51fd\u6570\u7528\u4e8e\u8bad\u7ec3\u591a\u4e2a\u6298\u53e0\u3002 import io import torch import numpy as np import pandas as pd import tensorflow as tf from sklearn import metrics import config import dataset import engine import lstm def load_vectors ( fname ): fin = io . open ( fname , 'r' , encoding = 'utf-8' , newline = ' \\n ' , errors = 'ignore' ) n , d = map ( int , fin . readline () . split ()) data = {} for line in fin : tokens = line . rstrip () . split ( ' ' ) data [ tokens [ 0 ]] = list ( map ( float , tokens [ 1 :])) return data def create_embedding_matrix ( word_index , embedding_dict ): embedding_matrix = np . zeros (( len ( word_index ) + 1 , 300 )) for word , i in word_index . items (): if word in embedding_dict : embedding_matrix [ i ] = embedding_dict [ word ] return embedding_matrix def run ( df , fold ): train_df = df [ df . kfold != fold ] . reset_index ( drop = True ) valid_df = df [ df . kfold == fold ] . reset_index ( drop = True ) print ( \"Fitting tokenizer\" ) tokenizer = tf . keras . preprocessing . text . Tokenizer () tokenizer . fit_on_texts ( df . review . values . tolist ()) xtrain = tokenizer . texts_to_sequences ( train_df . review . values ) xtest = tokenizer . texts_to_sequences ( valid_df . review . values ) xtrain = tf . keras . preprocessing . sequence . pad_sequences ( xtrain , maxlen = config . MAX_LEN ) xtest = tf . keras . preprocessing . sequence . pad_sequences ( xtest , maxlen = config . MAX_LEN ) train_dataset = dataset . IMDBDataset ( reviews = xtrain , targets = train_df . sentiment . values ) train_data_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = config . TRAIN_BATCH_SIZE , num_workers = 2 ) valid_dataset = dataset . IMDBDataset ( reviews = xtest , targets = valid_df . sentiment . values ) valid_data_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = config . VALID_BATCH_SIZE , num_workers = 1 ) print ( \"Loading embeddings\" ) embedding_dict = load_vectors ( \"../input/crawl-300d-2M.vec\" ) embedding_matrix = create_embedding_matrix ( tokenizer . word_index , embedding_dict ) device = torch . device ( \"cuda\" ) model = lstm . LSTM ( embedding_matrix ) model . to ( device ) optimizer = torch . optim . Adam ( model . parameters (), lr = 1e-3 ) print ( \"Training Model\" ) best_accuracy = 0 early_stopping_counter = 0 for epoch in range ( config . EPOCHS ): engine . train ( train_data_loader , model , optimizer , device ) outputs , targets = engine . evaluate ( valid_data_loader , model , device ) outputs = np . array ( outputs ) >= 0.5 accuracy = metrics . accuracy_score ( targets , outputs ) print ( f \"FOLD: { fold } , Epoch: { epoch } , Accuracy Score = { accuracy } \" ) if accuracy > best_accuracy : best_accuracy = accuracy else : early_stopping_counter += 1 if early_stopping_counter > 2 : break if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb_folds.csv\" ) run ( df , fold = 0 ) run ( df , fold = 1 ) run ( df , fold = 2 ) run ( df , fold = 3 ) run ( df , fold = 4 ) \u6700\u540e\u662f config.py\u3002 MAX_LEN = 128 TRAIN_BATCH_SIZE = 16 VALID_BATCH_SIZE = 8 EPOCHS = 10 \u8ba9\u6211\u4eec\u770b\u770b\u8f93\u51fa\uff1a FOLD : 0 , Epoch : 3 , Accuracy Score = 0.9015 FOLD : 1 , Epoch : 4 , Accuracy Score = 0.9007 FOLD : 2 , Epoch : 3 , Accuracy Score = 0.8924 FOLD : 3 , Epoch : 2 , Accuracy Score = 0.9 FOLD : 4 , Epoch : 1 , Accuracy Score = 0.878 \u8fd9\u662f\u8fc4\u4eca\u4e3a\u6b62\u6211\u4eec\u83b7\u5f97\u7684\u6700\u597d\u6210\u7ee9\u3002 \u8bf7\u6ce8\u610f\uff0c\u6211\u53ea\u663e\u793a\u4e86\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7cbe\u5ea6\u6700\u9ad8\u7684Epoch\u3002 \u4f60\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\uff0c\u6211\u4eec\u4f7f\u7528\u4e86\u9884\u5148\u8bad\u7ec3\u7684\u5d4c\u5165\u548c\u7b80\u5355\u7684\u53cc\u5411 LSTM\u3002 \u5982\u679c\u4f60\u60f3\u6539\u53d8\u6a21\u578b\uff0c\u4f60\u53ef\u4ee5\u53ea\u6539\u53d8 lstm.py \u4e2d\u7684\u6a21\u578b\u5e76\u4fdd\u6301\u4e00\u5207\u4e0d\u53d8\u3002 \u8fd9\u79cd\u4ee3\u7801\u53ea\u9700\u8981\u5f88\u5c11\u7684\u5b9e\u9a8c\u6539\u52a8\uff0c\u5e76\u4e14\u5f88\u5bb9\u6613\u7406\u89e3\u3002 \u4f8b\u5982\uff0c\u60a8\u53ef\u4ee5\u81ea\u5df1\u5b66\u4e60\u5d4c\u5165\u800c\u4e0d\u662f\u4f7f\u7528\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u5176\u4ed6\u4e00\u4e9b\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u7ec4\u5408\u591a\u4e2a\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528GRU\uff0c\u60a8\u53ef\u4ee5\u5728\u5d4c\u5165\u540e\u4f7f\u7528\u7a7a\u95f4dropout\uff0c\u60a8\u53ef\u4ee5\u6dfb\u52a0GRU LSTM \u5c42\u4e4b\u540e\uff0c\u60a8\u53ef\u4ee5\u6dfb\u52a0\u4e24\u4e2a LSTM \u5c42\uff0c\u60a8\u53ef\u4ee5\u8fdb\u884c LSTM-GRU-LSTM \u914d\u7f6e\uff0c\u60a8\u53ef\u4ee5\u7528\u5377\u79ef\u5c42\u66ff\u6362 LSTM \u7b49\uff0c\u800c\u65e0\u9700\u5bf9\u4ee3\u7801\u8fdb\u884c\u592a\u591a\u66f4\u6539\u3002 \u6211\u63d0\u5230\u7684\u5927\u90e8\u5206\u5185\u5bb9\u53ea\u9700\u8981\u66f4\u6539\u6a21\u578b\u7c7b\u3002 \u5f53\u60a8\u4f7f\u7528\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\u65f6\uff0c\u5c1d\u8bd5\u67e5\u770b\u6709\u591a\u5c11\u5355\u8bcd\u65e0\u6cd5\u627e\u5230\u5d4c\u5165\u4ee5\u53ca\u539f\u56e0\u3002 \u9884\u8bad\u7ec3\u5d4c\u5165\u7684\u5355\u8bcd\u8d8a\u591a\uff0c\u7ed3\u679c\u5c31\u8d8a\u597d\u3002 \u6211\u5411\u60a8\u5c55\u793a\u4ee5\u4e0b\u672a\u6ce8\u91ca\u7684 (!) \u51fd\u6570\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u5b83\u4e3a\u4efb\u4f55\u7c7b\u578b\u7684\u9884\u8bad\u7ec3\u5d4c\u5165\u521b\u5efa\u5d4c\u5165\u77e9\u9635\uff0c\u5176\u683c\u5f0f\u4e0e glove \u6216 fastText \u76f8\u540c\uff08\u53ef\u80fd\u9700\u8981\u8fdb\u884c\u4e00\u4e9b\u66f4\u6539\uff09\u3002 def load_embeddings ( word_index , embedding_file , vector_length = 300 ): max_features = len ( word_index ) + 1 words_to_find = list ( word_index . keys ()) more_words_to_find = [] for wtf in words_to_find : more_words_to_find . append ( wtf ) more_words_to_find . append ( str ( wtf ) . capitalize ()) more_words_to_find = set ( more_words_to_find ) def get_coefs ( word , * arr ): return word , np . asarray ( arr , dtype = 'float32' ) embeddings_index = dict ( get_coefs ( * o . strip () . split ( \" \" )) for o in open ( embedding_file ) if o . split ( \" \" )[ 0 ] in more_words_to_find and len ( o ) > 100 ) embedding_matrix = np . zeros (( max_features , vector_length )) for word , i in word_index . items (): if i >= max_features : continue embedding_vector = embeddings_index . get ( word ) if embedding_vector is None : embedding_vector = embeddings_index . get ( str ( word ) . capitalize () ) if embedding_vector is None : embedding_vector = embeddings_index . get ( str ( word ) . upper () ) if ( embedding_vector is not None and len ( embedding_vector ) == vector_length ): embedding_matrix [ i ] = embedding_vector return embedding_matrix \u9605\u8bfb\u5e76\u8fd0\u884c\u4e0a\u9762\u7684\u51fd\u6570\uff0c\u770b\u770b\u53d1\u751f\u4e86\u4ec0\u4e48\u3002 \u8be5\u51fd\u6570\u8fd8\u53ef\u4ee5\u4fee\u6539\u4e3a\u4f7f\u7528\u8bcd\u5e72\u8bcd\u6216\u8bcd\u5f62\u8fd8\u539f\u8bcd\u3002 \u6700\u540e\uff0c\u60a8\u5e0c\u671b\u8bad\u7ec3\u8bed\u6599\u5e93\u4e2d\u7684\u672a\u77e5\u5355\u8bcd\u6570\u91cf\u6700\u5c11\u3002 \u53e6\u4e00\u4e2a\u6280\u5de7\u662f\u5b66\u4e60\u5d4c\u5165\u5c42\uff0c\u5373\u4f7f\u5176\u53ef\u8bad\u7ec3\uff0c\u7136\u540e\u8bad\u7ec3\u7f51\u7edc\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u5206\u7c7b\u95ee\u9898\u6784\u5efa\u4e86\u5f88\u591a\u6a21\u578b\u3002 \u7136\u800c\uff0c\u73b0\u5728\u662f\u5e03\u5076\u65f6\u4ee3\uff0c\u8d8a\u6765\u8d8a\u591a\u7684\u4eba\u8f6c\u5411\u57fa\u4e8e\u53d8\u5f62\u91d1\u521a\u7684\u6a21\u578b\u3002 \u57fa\u4e8e Transformer \u7684\u7f51\u7edc\u80fd\u591f\u5904\u7406\u672c\u8d28\u4e0a\u957f\u671f\u7684\u4f9d\u8d56\u5173\u7cfb\u3002 LSTM \u4ec5\u5f53\u5b83\u770b\u5230\u524d\u4e00\u4e2a\u5355\u8bcd\u65f6\u624d\u67e5\u770b\u4e0b\u4e00\u4e2a\u5355\u8bcd\u3002 \u53d8\u538b\u5668\u7684\u60c5\u51b5\u5e76\u975e\u5982\u6b64\u3002 \u5b83\u53ef\u4ee5\u540c\u65f6\u67e5\u770b\u6574\u4e2a\u53e5\u5b50\u4e2d\u7684\u6240\u6709\u5355\u8bcd\u3002 \u56e0\u6b64\uff0c\u53e6\u4e00\u4e2a\u4f18\u70b9\u662f\u5b83\u53ef\u4ee5\u8f7b\u677e\u5e76\u884c\u5316\u5e76\u66f4\u6709\u6548\u5730\u4f7f\u7528 GPU\u3002 Transformers \u662f\u4e00\u4e2a\u975e\u5e38\u5e7f\u6cdb\u7684\u8bdd\u9898\uff0c\u6709\u592a\u591a\u7684\u6a21\u578b\uff1a BERT\u3001RoBERTa\u3001XLNet\u3001XLM-RoBERTa\u3001T5 \u7b49\u3002\u6211\u5c06\u5411\u60a8\u5c55\u793a\u4e00\u79cd\u53ef\u7528\u4e8e\u6240\u6709\u8fd9\u4e9b\u6a21\u578b\uff08T5 \u9664\u5916\uff09\u8fdb\u884c\u5206\u7c7b\u7684\u901a\u7528\u65b9\u6cd5 \u6211\u4eec\u4e00\u76f4\u5728\u8ba8\u8bba\u7684\u95ee\u9898\u3002 \u8bf7\u6ce8\u610f\uff0c\u8fd9\u4e9b\u53d8\u538b\u5668\u9700\u8981\u8bad\u7ec3\u5b83\u4eec\u6240\u9700\u7684\u8ba1\u7b97\u80fd\u529b\u3002 \u56e0\u6b64\uff0c\u5982\u679c\u60a8\u6ca1\u6709\u9ad8\u7aef\u7cfb\u7edf\uff0c\u4e0e\u57fa\u4e8e LSTM \u6216 TF-IDF \u7684\u6a21\u578b\u76f8\u6bd4\uff0c\u8bad\u7ec3\u6a21\u578b\u53ef\u80fd\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u3002 \u6211\u4eec\u8981\u505a\u7684\u7b2c\u4e00\u4ef6\u4e8b\u662f\u521b\u5efa\u4e00\u4e2a\u914d\u7f6e\u6587\u4ef6\u3002 import transformers MAX_LEN = 512 TRAIN_BATCH_SIZE = 8 VALID_BATCH_SIZE = 4 EPOCHS = 10 BERT_PATH = \"../input/bert_base_uncased/\" MODEL_PATH = \"model.bin\" TRAINING_FILE = \"../input/imdb.csv\" TOKENIZER = transformers . BertTokenizer . from_pretrained ( BERT_PATH , do_lower_case = True ) \u8fd9\u91cc\u7684\u914d\u7f6e\u6587\u4ef6\u662f\u6211\u4eec\u5b9a\u4e49\u5206\u8bcd\u5668\u548c\u5176\u4ed6\u6211\u4eec\u60f3\u8981\u7ecf\u5e38\u66f4\u6539\u7684\u53c2\u6570\u7684\u552f\u4e00\u5730\u65b9 \u2014\u2014 \u8fd9\u6837\u6211\u4eec\u5c31\u53ef\u4ee5\u505a\u5f88\u591a\u5b9e\u9a8c\u800c\u4e0d\u9700\u8981\u8fdb\u884c\u5927\u91cf\u66f4\u6539\u3002 \u4e0b\u4e00\u6b65\u662f\u6784\u5efa\u6570\u636e\u96c6\u7c7b\u3002 import config import torch class BERTDataset : def __init__ ( self , review , target ): self . review = review self . target = target self . tokenizer = config . TOKENIZER self . max_len = config . MAX_LEN def __len__ ( self ): return len ( self . review ) def __getitem__ ( self , item ): review = str ( self . review [ item ]) review = \" \" . join ( review . split ()) inputs = self . tokenizer . encode_plus ( review , None , add_special_tokens = True , max_length = self . max_len , pad_to_max_length = True , ) ids = inputs [ \"input_ids\" ] mask = inputs [ \"attention_mask\" ] token_type_ids = inputs [ \"token_type_ids\" ] return { \"ids\" : torch . tensor ( ids , dtype = torch . long ), \"mask\" : torch . tensor ( mask , dtype = torch . long ), \"token_type_ids\" : torch . tensor ( token_type_ids , dtype = torch . long ), \"targets\" : torch . tensor ( self . target [ item ], dtype = torch . float ) } \u73b0\u5728\u6211\u4eec\u6765\u5230\u4e86\u8be5\u9879\u76ee\u7684\u6838\u5fc3\uff0c\u5373\u6a21\u578b\u3002 import config import transformers import torch.nn as nn class BERTBaseUncased ( nn . Module ): def __init__ ( self ): super ( BERTBaseUncased , self ) . __init__ () self . bert = transformers . BertModel . from_pretrained ( config . BERT_PATH ) self . bert_drop = nn . Dropout ( 0.3 ) self . out = nn . Linear ( 768 , 1 ) def forward ( self , ids , mask , token_type_ids ): hidden state _ , o2 = self . bert ( ids , attention_mask = mask , token_type_ids = token_type_ids ) bo = self . bert_drop ( o2 ) output = self . out ( bo ) return output \u8be5\u6a21\u578b\u8fd4\u56de\u5355\u4e2a\u8f93\u51fa\u3002 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5e26\u6709 logits \u7684\u4e8c\u5143\u4ea4\u53c9\u71b5\u635f\u5931\uff0c\u5b83\u9996\u5148\u5e94\u7528 sigmoid\uff0c\u7136\u540e\u8ba1\u7b97\u635f\u5931\u3002 \u8fd9\u662f\u5728engine.py \u4e2d\u5b8c\u6210\u7684\u3002 import torch import torch.nn as nn def loss_fn ( outputs , targets ): return nn . BCEWithLogitsLoss ()( outputs , targets . view ( - 1 , 1 )) def train_fn ( data_loader , model , optimizer , device , scheduler ): model . train () for d in data_loader : ids = d [ \"ids\" ] token_type_ids = d [ \"token_type_ids\" ] mask = d [ \"mask\" ] targets = d [ \"targets\" ] ids = ids . to ( device , dtype = torch . long ) token_type_ids = token_type_ids . to ( device , dtype = torch . long ) mask = mask . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) optimizer . zero_grad () outputs = model ( ids = ids , mask = mask , token_type_ids = token_type_ids ) loss = loss_fn ( outputs , targets ) loss . backward () optimizer . step () scheduler . step () def eval_fn ( data_loader , model , device ): model . eval () fin_targets = [] fin_outputs = [] with torch . no_grad (): for d in data_loader : ids = d [ \"ids\" ] token_type_ids = d [ \"token_type_ids\" ] mask = d [ \"mask\" ] targets = d [ \"targets\" ] ids = ids . to ( device , dtype = torch . long ) token_type_ids = token_type_ids . to ( device , dtype = torch . long ) mask = mask . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) outputs = model ( ids = ids , mask = mask , token_type_ids = token_type_ids ) targets = targets . cpu () . detach () fin_targets . extend ( targets . numpy () . tolist ()) outputs = torch . sigmoid ( outputs ) . cpu () . detach () fin_outputs . extend ( outputs . numpy () . tolist ()) return fin_outputs , fin_targets \u6700\u540e\uff0c\u6211\u4eec\u51c6\u5907\u597d\u8bad\u7ec3\u4e86\u3002 \u6211\u4eec\u6765\u770b\u770b\u8bad\u7ec3\u811a\u672c\u5427\uff01 import config import dataset import engine import torch import pandas as pd import torch.nn as nn import numpy as np from model import BERTBaseUncased from sklearn import model_selection from sklearn import metrics from transformers import AdamW from transformers import get_linear_schedule_with_warmup def train (): dfx = pd . read_csv ( config . TRAINING_FILE ) . fillna ( \"none\" ) dfx . sentiment = dfx . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df_train , df_valid = model_selection . train_test_split ( dfx , test_size = 0.1 , random_state = 42 , stratify = dfx . sentiment . values ) df_train = df_train . reset_index ( drop = True ) df_valid = df_valid . reset_index ( drop = True ) train_dataset = dataset . BERTDataset ( review = df_train . review . values , target = df_train . sentiment . values ) train_data_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = config . TRAIN_BATCH_SIZE , num_workers = 4 ) valid_dataset = dataset . BERTDataset ( review = df_valid . review . values , target = df_valid . sentiment . values ) valid_data_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = config . VALID_BATCH_SIZE , num_workers = 1 ) device = torch . device ( \"cuda\" ) model = BERTBaseUncased () model . to ( device ) param_optimizer = list ( model . named_parameters ()) no_decay = [ \"bias\" , \"LayerNorm.bias\" , \"LayerNorm.weight\" ] optimizer_parameters = [ { \"params\" : [ p for n , p in param_optimizer if not any ( nd in n for nd in no_decay ) ], \"weight_decay\" : 0.001 , } \uff0c { \"params\" : [ p for n , p in param_optimizer if any ( nd in n for nd in no_decay ) ], \"weight_decay\" : 0.0 , }] num_train_steps = int ( len ( df_train ) / config . TRAIN_BATCH_SIZE * config . EPOCHS ) optimizer = AdamW ( optimizer_parameters , lr = 3e-5 ) scheduler = get_linear_schedule_with_warmup ( optimizer , num_warmup_steps = 0 , num_training_steps = num_train_steps ) model = nn . DataParallel ( model ) best_accuracy = 0 for epoch in range ( config . EPOCHS ): engine . train_fn ( train_data_loader , model , optimizer , device , scheduler ) outputs , targets = engine . eval_fn ( valid_data_loader , model , device ) outputs = np . array ( outputs ) >= 0.5 accuracy = metrics . accuracy_score ( targets , outputs ) print ( f \"Accuracy Score = { accuracy } \" ) if accuracy > best_accuracy : torch . save ( model . state_dict (), config . MODEL_PATH ) best_accuracy = accuracy if __name__ == \"__main__\" : train () \u4e4d\u4e00\u770b\u53ef\u80fd\u770b\u8d77\u6765\u5f88\u591a\uff0c\u4f46\u4e00\u65e6\u60a8\u4e86\u89e3\u4e86\u5404\u4e2a\u7ec4\u4ef6\uff0c\u5c31\u4e0d\u518d\u90a3\u4e48\u7b80\u5355\u4e86\u3002 \u60a8\u53ea\u9700\u66f4\u6539\u51e0\u884c\u4ee3\u7801\u5373\u53ef\u8f7b\u677e\u5c06\u5176\u66f4\u6539\u4e3a\u60a8\u60f3\u8981\u4f7f\u7528\u7684\u4efb\u4f55\u5176\u4ed6\u53d8\u538b\u5668\u6a21\u578b\u3002 \u8be5\u6a21\u578b\u7684\u51c6\u786e\u7387\u4e3a 93%\uff01 \u54c7\uff01 \u8fd9\u6bd4\u4efb\u4f55\u5176\u4ed6\u6a21\u578b\u90fd\u8981\u597d\u5f97\u591a\u3002 \u4f46\u662f\u8fd9\u503c\u5f97\u5417\uff1f \u6211\u4eec\u4f7f\u7528 LSTM \u80fd\u591f\u5b9e\u73b0 90% \u7684\u76ee\u6807\uff0c\u800c\u4e14\u5b83\u4eec\u66f4\u7b80\u5355\u3001\u66f4\u5bb9\u6613\u8bad\u7ec3\u5e76\u4e14\u63a8\u7406\u901f\u5ea6\u66f4\u5feb\u3002 \u901a\u8fc7\u4f7f\u7528\u4e0d\u540c\u7684\u6570\u636e\u5904\u7406\u6216\u8c03\u6574\u5c42\u3001\u8282\u70b9\u3001dropout\u3001\u5b66\u4e60\u7387\u3001\u66f4\u6539\u4f18\u5316\u5668\u7b49\u53c2\u6570\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u8be5\u6a21\u578b\u6539\u8fdb\u4e00\u4e2a\u767e\u5206\u70b9\u3002\u7136\u540e\u6211\u4eec\u5c06\u4ece BERT \u4e2d\u83b7\u5f97\u7ea6 2% \u7684\u6536\u76ca\u3002 \u53e6\u4e00\u65b9\u9762\uff0cBERT \u7684\u8bad\u7ec3\u65f6\u95f4\u8981\u957f\u5f97\u591a\uff0c\u53c2\u6570\u5f88\u591a\uff0c\u800c\u4e14\u63a8\u7406\u901f\u5ea6\u4e5f\u5f88\u6162\u3002 \u6700\u540e\uff0c\u60a8\u5e94\u8be5\u5ba1\u89c6\u81ea\u5df1\u7684\u4e1a\u52a1\u5e76\u505a\u51fa\u660e\u667a\u7684\u9009\u62e9\u3002 \u4e0d\u8981\u4ec5\u4ec5\u56e0\u4e3a BERT\u201c\u9177\u201d\u800c\u9009\u62e9\u5b83\u3002 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u6211\u4eec\u5728\u8fd9\u91cc\u8ba8\u8bba\u7684\u552f\u4e00\u4efb\u52a1\u662f\u5206\u7c7b\uff0c\u4f46\u5c06\u5176\u66f4\u6539\u4e3a\u56de\u5f52\u3001\u591a\u6807\u7b7e\u6216\u591a\u7c7b\u53ea\u9700\u8981\u66f4\u6539\u51e0\u884c\u4ee3\u7801\u3002 \u4f8b\u5982\uff0c\u591a\u7c7b\u5206\u7c7b\u8bbe\u7f6e\u4e2d\u7684\u540c\u4e00\u95ee\u9898\u5c06\u6709\u591a\u4e2a\u8f93\u51fa\u548c\u4ea4\u53c9\u71b5\u635f\u5931\u3002 \u5176\u4ed6\u4e00\u5207\u90fd\u5e94\u8be5\u4fdd\u6301\u4e0d\u53d8\u3002 \u81ea\u7136\u8bed\u8a00\u5904\u7406\u975e\u5e38\u5e9e\u5927\uff0c\u6211\u4eec\u53ea\u8ba8\u8bba\u4e86\u5176\u4e2d\u7684\u4e00\u5c0f\u90e8\u5206\u3002 \u663e\u7136\uff0c\u8fd9\u662f\u4e00\u4e2a\u5f88\u5927\u7684\u6bd4\u4f8b\uff0c\u56e0\u4e3a\u5927\u591a\u6570\u5de5\u4e1a\u6a21\u578b\u90fd\u662f\u5206\u7c7b\u6216\u56de\u5f52\u6a21\u578b\u3002 \u5982\u679c\u6211\u5f00\u59cb\u8be6\u7ec6\u5199\u6240\u6709\u5185\u5bb9\uff0c\u6211\u6700\u7ec8\u53ef\u80fd\u4f1a\u5199\u51e0\u767e\u9875\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u51b3\u5b9a\u5c06\u6240\u6709\u5185\u5bb9\u5305\u542b\u5728\u4e00\u672c\u5355\u72ec\u7684\u4e66\u4e2d\uff1a\u63a5\u8fd1\uff08\u51e0\u4e4e\uff09\u4efb\u4f55 NLP \u95ee\u9898\uff01","title":"\u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5"},{"location":"%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%88%96%E5%9B%9E%E5%BD%92%E6%96%B9%E6%B3%95/#_1","text":"\u6587\u672c\u95ee\u9898\u662f\u6211\u7684\u6700\u7231\u3002\u4e00\u822c\u6765\u8bf4\uff0c\u8fd9\u4e9b\u95ee\u9898\u4e5f\u88ab\u79f0\u4e3a \u81ea\u7136\u8bed\u8a00\u5904\u7406\uff08NLP\uff09\u95ee\u9898 \u3002NLP \u95ee\u9898\u4e0e\u56fe\u50cf\u95ee\u9898\u4e5f\u6709\u5f88\u5927\u4e0d\u540c\u3002\u4f60\u9700\u8981\u521b\u5efa\u4ee5\u524d\u4ece\u672a\u4e3a\u8868\u683c\u95ee\u9898\u521b\u5efa\u8fc7\u7684\u6570\u636e\u7ba1\u9053\u3002\u4f60\u9700\u8981\u4e86\u89e3\u5546\u4e1a\u6848\u4f8b\uff0c\u624d\u80fd\u5efa\u7acb\u4e00\u4e2a\u597d\u7684\u6a21\u578b\u3002\u987a\u4fbf\u8bf4\u4e00\u53e5\uff0c\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u4efb\u4f55\u4e8b\u60c5\u90fd\u662f\u5982\u6b64\u3002\u5efa\u7acb\u6a21\u578b\u4f1a\u8ba9\u4f60\u8fbe\u5230\u4e00\u5b9a\u7684\u6c34\u5e73\uff0c\u4f46\u8981\u60f3\u6539\u5584\u548c\u4fc3\u8fdb\u4f60\u6240\u5efa\u7acb\u6a21\u578b\u7684\u4e1a\u52a1\uff0c\u4f60\u5fc5\u987b\u4e86\u89e3\u5b83\u5bf9\u4e1a\u52a1\u7684\u5f71\u54cd\u3002 NLP \u95ee\u9898\u6709\u5f88\u591a\u79cd\uff0c\u5176\u4e2d\u6700\u5e38\u89c1\u7684\u662f\u5b57\u7b26\u4e32\u5206\u7c7b\u3002\u5f88\u591a\u65f6\u5019\uff0c\u6211\u4eec\u4f1a\u770b\u5230\u4eba\u4eec\u5728\u5904\u7406\u8868\u683c\u6570\u636e\u6216\u56fe\u50cf\u65f6\u8868\u73b0\u51fa\u8272\uff0c\u4f46\u5728\u5904\u7406\u6587\u672c\u65f6\uff0c\u4ed6\u4eec\u751a\u81f3\u4e0d\u77e5\u9053\u4ece\u4f55\u5165\u624b\u3002\u6587\u672c\u6570\u636e\u4e0e\u5176\u4ed6\u7c7b\u578b\u7684\u6570\u636e\u96c6\u6ca1\u6709\u4ec0\u4e48\u4e0d\u540c\u3002\u5bf9\u4e8e\u8ba1\u7b97\u673a\u6765\u8bf4\uff0c\u4e00\u5207\u90fd\u662f\u6570\u5b57\u3002 \u5047\u8bbe\u6211\u4eec\u4ece\u60c5\u611f\u5206\u7c7b\u8fd9\u4e00\u57fa\u672c\u4efb\u52a1\u5f00\u59cb\u3002\u6211\u4eec\u5c06\u5c1d\u8bd5\u5bf9\u7535\u5f71\u8bc4\u8bba\u8fdb\u884c\u60c5\u611f\u5206\u7c7b\u3002\u56e0\u6b64\uff0c\u60a8\u6709\u4e00\u4e2a\u6587\u672c\uff0c\u5e76\u6709\u4e0e\u4e4b\u76f8\u5173\u7684\u60c5\u611f\u3002\u4f60\u5c06\u5982\u4f55\u5904\u7406\u8fd9\u7c7b\u95ee\u9898\uff1f\u662f\u5e94\u7528\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff1f \u4e0d\uff0c\u7edd\u5bf9\u9519\u4e86\u3002\u4f60\u8981\u4ece\u6700\u57fa\u672c\u7684\u5f00\u59cb\u3002\u8ba9\u6211\u4eec\u5148\u770b\u770b\u8fd9\u4e9b\u6570\u636e\u662f\u4ec0\u4e48\u6837\u5b50\u7684\u3002 \u6211\u4eec\u4ece IMDB \u7535\u5f71\u8bc4\u8bba\u6570\u636e\u96c6 \u5f00\u59cb\uff0c\u8be5\u6570\u636e\u96c6\u5305\u542b 25000 \u7bc7\u6b63\u9762\u60c5\u611f\u8bc4\u8bba\u548c 25000 \u7bc7\u8d1f\u9762\u60c5\u611f\u8bc4\u8bba\u3002 \u6211\u5c06\u5728\u6b64\u8ba8\u8bba\u7684\u6982\u5ff5\u51e0\u4e4e\u9002\u7528\u4e8e\u4efb\u4f55\u6587\u672c\u5206\u7c7b\u6570\u636e\u96c6\u3002 \u8fd9\u4e2a\u6570\u636e\u96c6\u975e\u5e38\u5bb9\u6613\u7406\u89e3\u3002\u4e00\u7bc7\u8bc4\u8bba\u5bf9\u5e94\u4e00\u4e2a\u76ee\u6807\u53d8\u91cf\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5199\u7684\u662f\u8bc4\u8bba\u800c\u4e0d\u662f\u53e5\u5b50\u3002\u8bc4\u8bba\u5c31\u662f\u4e00\u5806\u53e5\u5b50\u3002\u6240\u4ee5\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u4f60\u4e00\u5b9a\u53ea\u770b\u5230\u4e86\u5bf9\u5355\u53e5\u7684\u5206\u7c7b\uff0c\u4f46\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u6211\u4eec\u5c06\u5bf9\u591a\u4e2a\u53e5\u5b50\u8fdb\u884c\u5206\u7c7b\u3002\u7b80\u5355\u5730\u8bf4\uff0c\u8fd9\u610f\u5473\u7740\u4e0d\u4ec5\u4e00\u4e2a\u53e5\u5b50\u4f1a\u5bf9\u60c5\u611f\u4ea7\u751f\u5f71\u54cd\uff0c\u800c\u4e14\u60c5\u611f\u5f97\u5206\u662f\u591a\u4e2a\u53e5\u5b50\u5f97\u5206\u7684\u7ec4\u5408\u3002\u6570\u636e\u7b80\u4ecb\u5982\u56fe 1 \u6240\u793a\u3002 \u5982\u4f55\u7740\u624b\u89e3\u51b3\u8fd9\u6837\u7684\u95ee\u9898\uff1f\u4e00\u4e2a\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u624b\u5de5\u5236\u4f5c\u4e24\u4efd\u5355\u8bcd\u8868\u3002\u4e00\u4e2a\u5217\u8868\u5305\u542b\u4f60\u80fd\u60f3\u8c61\u5230\u7684\u6240\u6709\u6b63\u9762\u8bcd\u6c47\uff0c\u4f8b\u5982\u597d\u3001\u68d2\u3001\u597d\u7b49\uff1b\u53e6\u4e00\u4e2a\u5217\u8868\u5305\u542b\u6240\u6709\u8d1f\u9762\u8bcd\u6c47\uff0c\u4f8b\u5982\u574f\u3001\u6076\u7b49\u3002\u6211\u4eec\u5148\u4e0d\u8981\u4e3e\u4f8b\u8bf4\u660e\u574f\u8bcd\uff0c\u5426\u5219\u8fd9\u672c\u4e66\u5c31\u53ea\u80fd\u4f9b 18 \u5c81\u4ee5\u4e0a\u7684\u4eba\u9605\u8bfb\u4e86\u3002\u4e00\u65e6\u4f60\u6709\u4e86\u8fd9\u4e9b\u5217\u8868\uff0c\u4f60\u751a\u81f3\u4e0d\u9700\u8981\u4e00\u4e2a\u6a21\u578b\u6765\u8fdb\u884c\u9884\u6d4b\u3002\u8fd9\u4e9b\u5217\u8868\u4e5f\u88ab\u79f0\u4e3a\u60c5\u611f\u8bcd\u5178\u3002\u4f60\u53ef\u4ee5\u7528\u4e00\u4e2a\u7b80\u5355\u7684\u8ba1\u6570\u5668\u6765\u8ba1\u7b97\u53e5\u5b50\u4e2d\u6b63\u9762\u548c\u8d1f\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u3002\u5982\u679c\u6b63\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u8f83\u591a\uff0c\u5219\u8868\u793a\u8be5\u53e5\u5b50\u5177\u6709\u6b63\u9762\u60c5\u611f\uff1b\u5982\u679c\u8d1f\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u8f83\u591a\uff0c\u5219\u8868\u793a\u8be5\u53e5\u5b50\u5177\u6709\u8d1f\u9762\u60c5\u611f\u3002\u5982\u679c\u53e5\u5b50\u4e2d\u6ca1\u6709\u8fd9\u4e9b\u8bcd\uff0c\u5219\u53ef\u4ee5\u8bf4\u8be5\u53e5\u5b50\u5177\u6709\u4e2d\u6027\u60c5\u611f\u3002\u8fd9\u662f\u6700\u53e4\u8001\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u73b0\u5728\u4ecd\u6709\u4eba\u5728\u4f7f\u7528\u3002\u5b83\u4e5f\u4e0d\u9700\u8981\u592a\u591a\u4ee3\u7801\u3002 def find_sentiment ( sentence , pos , neg ): sentence = sentence . split () sentence = set ( sentence ) num_common_pos = len ( sentence . intersection ( pos )) num_common_neg = len ( sentence . intersection ( neg )) if num_common_pos > num_common_neg : return \"positive\" if num_common_pos < num_common_neg : return \"negative\" return \"neutral\" \u4e0d\u8fc7\uff0c\u8fd9\u79cd\u65b9\u6cd5\u8003\u8651\u7684\u56e0\u7d20\u5e76\u4e0d\u591a\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u6211\u4eec\u7684 split() \u4e5f\u5e76\u4e0d\u5b8c\u7f8e\u3002\u5982\u679c\u4f7f\u7528 split()\uff0c\u5c31\u4f1a\u51fa\u73b0\u8fd9\u6837\u7684\u53e5\u5b50\uff1a \"hi, how are you?\" \u7ecf\u8fc7\u5206\u5272\u540e\u53d8\u4e3a\uff1a [\"hi,\", \"how\",\"are\",\"you?\"] \u8fd9\u79cd\u65b9\u6cd5\u5e76\u4e0d\u7406\u60f3\uff0c\u56e0\u4e3a\u5355\u8bcd\u4e2d\u5305\u542b\u4e86\u9017\u53f7\u548c\u95ee\u53f7\uff0c\u5b83\u4eec\u5e76\u6ca1\u6709\u88ab\u5206\u5272\u3002\u56e0\u6b64\uff0c\u5982\u679c\u6ca1\u6709\u5728\u5206\u5272\u524d\u5bf9\u8fd9\u4e9b\u7279\u6b8a\u5b57\u7b26\u8fdb\u884c\u9884\u5904\u7406\uff0c\u4e0d\u5efa\u8bae\u4f7f\u7528\u8fd9\u79cd\u65b9\u6cd5\u3002\u5c06\u5b57\u7b26\u4e32\u62c6\u5206\u4e3a\u5355\u8bcd\u5217\u8868\u79f0\u4e3a\u6807\u8bb0\u5316\u3002\u6700\u6d41\u884c\u7684\u6807\u8bb0\u5316\u65b9\u6cd5\u4e4b\u4e00\u6765\u81ea NLTK\uff08\u81ea\u7136\u8bed\u8a00\u5de5\u5177\u5305\uff09 \u3002 In [ X ]: from nltk.tokenize import word_tokenize In [ X ]: sentence = \"hi, how are you?\" In [ X ]: sentence . split () Out [ X ]: [ 'hi,' , 'how' , 'are' , 'you?' ] In [ X ]: word_tokenize ( sentence ) Out [ X ]: [ 'hi' , ',' , 'how' , 'are' , 'you' , '?' ] \u6b63\u5982\u60a8\u6240\u770b\u5230\u7684\uff0c\u4f7f\u7528 NLTK \u7684\u5355\u8bcd\u6807\u8bb0\u5316\u529f\u80fd\uff0c\u540c\u4e00\u4e2a\u53e5\u5b50\u7684\u62c6\u5206\u6548\u679c\u8981\u597d\u5f97\u591a\u3002\u4f7f\u7528\u5355\u8bcd\u5217\u8868\u8fdb\u884c\u5bf9\u6bd4\u7684\u6548\u679c\u4e5f\u4f1a\u66f4\u597d\uff01\u8fd9\u5c31\u662f\u6211\u4eec\u5c06\u5e94\u7528\u4e8e\u7b2c\u4e00\u4e2a\u60c5\u611f\u68c0\u6d4b\u6a21\u578b\u7684\u65b9\u6cd5\u3002 \u5728\u5904\u7406 NLP \u5206\u7c7b\u95ee\u9898\u65f6\uff0c\u60a8\u5e94\u8be5\u7ecf\u5e38\u5c1d\u8bd5\u7684\u57fa\u672c\u6a21\u578b\u4e4b\u4e00\u662f \u8bcd\u888b\u6a21\u578b\uff08bag of words\uff09 \u3002\u5728\u8bcd\u888b\u6a21\u578b\u4e2d\uff0c\u6211\u4eec\u521b\u5efa\u4e00\u4e2a\u5de8\u5927\u7684\u7a00\u758f\u77e9\u9635\uff0c\u5b58\u50a8\u8bed\u6599\u5e93\uff08\u8bed\u6599\u5e93=\u6240\u6709\u6587\u6863=\u6240\u6709\u53e5\u5b50\uff09\u4e2d\u6240\u6709\u5355\u8bcd\u7684\u8ba1\u6570\u3002\u4e3a\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 CountVectorizer\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b83\u662f\u5982\u4f55\u5de5\u4f5c\u7684\u3002 from sklearn.feature_extraction.text import CountVectorizer corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] ctv = CountVectorizer () ctv . fit ( corpus ) corpus_transformed = ctv . transform ( corpus ) \u5982\u679c\u6211\u4eec\u6253\u5370 corpus_transformed\uff0c\u5c31\u4f1a\u5f97\u5230\u7c7b\u4f3c\u4e0b\u9762\u7684\u7ed3\u679c\uff1a ( 0 , 2 ) 1 ( 0 , 9 ) 1 ( 0 , 11 ) 1 ( 0 , 22 ) 1 ( 1 , 1 ) 1 ( 1 , 3 ) 1 ( 1 , 4 ) 1 ( 1 , 7 ) 1 ( 1 , 8 ) 1 ( 1 , 10 ) 1 ( 1 , 13 ) 1 ( 1 , 17 ) 1 ( 1 , 19 ) 1 ( 1 , 22 ) 2 ( 2 , 0 ) 1 ( 2 , 5 ) 1 ( 2 , 6 ) 1 ( 2 , 14 ) 1 ( 2 , 22 ) 1 ( 3 , 12 ) 1 ( 3 , 15 ) 1 ( 3 , 16 ) 1 ( 3 , 18 ) 1 ( 3 , 20 ) 1 ( 4 , 21 ) 1 \u5728\u524d\u9762\u7684\u7ae0\u8282\u4e2d\uff0c\u6211\u4eec\u5df2\u7ecf\u89c1\u8bc6\u8fc7\u8fd9\u79cd\u8868\u793a\u6cd5\u3002\u5373\u7a00\u758f\u8868\u793a\u6cd5\u3002\u56e0\u6b64\uff0c\u8bed\u6599\u5e93\u73b0\u5728\u662f\u4e00\u4e2a\u7a00\u758f\u77e9\u9635\uff0c\u5176\u4e2d\u7b2c\u4e00\u4e2a\u6837\u672c\u6709 4 \u4e2a\u5143\u7d20\uff0c\u7b2c\u4e8c\u4e2a\u6837\u672c\u6709 10 \u4e2a\u5143\u7d20\uff0c\u4ee5\u6b64\u7c7b\u63a8\uff0c\u7b2c\u4e09\u4e2a\u6837\u672c\u6709 5 \u4e2a\u5143\u7d20\uff0c\u4ee5\u6b64\u7c7b\u63a8\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u4e9b\u5143\u7d20\u90fd\u6709\u76f8\u5173\u7684\u8ba1\u6570\u3002\u6709\u4e9b\u5143\u7d20\u4f1a\u51fa\u73b0\u4e24\u6b21\uff0c\u6709\u4e9b\u5219\u53ea\u6709\u4e00\u6b21\u3002\u4f8b\u5982\uff0c\u5728\u6837\u672c 2\uff08\u7b2c 1 \u884c\uff09\u4e2d\uff0c\u6211\u4eec\u770b\u5230\u7b2c 22 \u5217\u7684\u6570\u503c\u662f 2\u3002\u8fd9\u662f\u4e3a\u4ec0\u4e48\u5462\uff1f\u7b2c 22 \u5217\u662f\u4ec0\u4e48\uff1f CountVectorizer \u7684\u5de5\u4f5c\u65b9\u5f0f\u662f\u9996\u5148\u5bf9\u53e5\u5b50\u8fdb\u884c\u6807\u8bb0\u5316\u5904\u7406\uff0c\u7136\u540e\u4e3a\u6bcf\u4e2a\u6807\u8bb0\u8d4b\u503c\u3002\u56e0\u6b64\uff0c\u6bcf\u4e2a\u6807\u8bb0\u90fd\u7531\u4e00\u4e2a\u552f\u4e00\u7d22\u5f15\u8868\u793a\u3002\u8fd9\u4e9b\u552f\u4e00\u7d22\u5f15\u5c31\u662f\u6211\u4eec\u770b\u5230\u7684\u5217\u3002CountVectorizer \u4f1a\u5b58\u50a8\u8fd9\u4e9b\u4fe1\u606f\u3002 print ( ctv . vocabulary_ ) { 'hello' : 9 , 'how' : 11 , 'are' : 2 , 'you' : 22 , 'im' : 13 , 'getting' : 8 , 'bored' : 4 , 'at' : 3 , 'home' : 10 , 'and' : 1 , 'what' : 19 , 'do' : 7 , 'think' : 17 , 'did' : 6 , 'know' : 14 , 'about' : 0 , 'counts' : 5 , 'let' : 15 , 'see' : 16 , 'if' : 12 , 'this' : 18 , 'works' : 20 , 'yes' : 21 } \u6211\u4eec\u770b\u5230\uff0c\u7d22\u5f15 22 \u5c5e\u4e8e \"you\"\uff0c\u800c\u5728\u7b2c\u4e8c\u53e5\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u4e86\u4e24\u6b21 \"you\"\u3002\u6211\u5e0c\u671b\u5927\u5bb6\u73b0\u5728\u5df2\u7ecf\u6e05\u695a\u4ec0\u4e48\u662f\u8bcd\u888b\u4e86\u3002\u4f46\u662f\u6211\u4eec\u8fd8\u7f3a\u5c11\u4e00\u4e9b\u7279\u6b8a\u5b57\u7b26\u3002\u6709\u65f6\uff0c\u8fd9\u4e9b\u7279\u6b8a\u5b57\u7b26\u4e5f\u5f88\u6709\u7528\u3002\u4f8b\u5982\uff0c\"? \"\u5728\u5927\u591a\u6570\u53e5\u5b50\u4e2d\u8868\u793a\u7591\u95ee\u53e5\u3002\u8ba9\u6211\u4eec\u628a scikit-learn \u7684 word_tokenize \u6574\u5408\u5230 CountVectorizer \u4e2d\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002 from sklearn.feature_extraction.text import CountVectorizer from nltk.tokenize import word_tokenize corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] ctv = CountVectorizer ( tokenizer = word_tokenize , token_pattern = None ) ctv . fit ( corpus ) corpus_transformed = ctv . transform ( corpus ) print ( ctv . vocabulary_ ) \u8fd9\u6837\uff0c\u6211\u4eec\u7684\u8bcd\u888b\u5c31\u53d8\u6210\u4e86\uff1a { 'hello' : 14 , ',' : 2 , 'how' : 16 , 'are' : 7 , 'you' : 27 , '?' : 4 , 'im' : 18 , 'getting' : 13 , 'bored' : 9 , 'at' : 8 , 'home' : 15 , '.' : 3 , 'and' : 6 , 'what' : 24 , 'do' : 12 , 'think' : 22 , 'did' : 11 , 'know' : 19 , 'about' : 5 , 'counts' : 10 , 'let' : 20 , \"'s\" : 1 , 'see' : 21 , 'if' : 17 , 'this' : 23 , 'works' : 25 , '!' : 0 , 'yes' : 26 } \u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u5229\u7528 IMDB \u6570\u636e\u96c6\u4e2d\u7684\u6240\u6709\u53e5\u5b50\u521b\u5efa\u4e00\u4e2a\u7a00\u758f\u77e9\u9635\uff0c\u5e76\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u3002\u8be5\u6570\u636e\u96c6\u4e2d\u6b63\u8d1f\u6837\u672c\u7684\u6bd4\u4f8b\u4e3a 1:1\uff0c\u56e0\u6b64\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u51c6\u786e\u7387\u4f5c\u4e3a\u8861\u91cf\u6807\u51c6\u3002\u6211\u4eec\u5c06\u4f7f\u7528 StratifiedKFold \u5e76\u521b\u5efa\u4e00\u4e2a\u811a\u672c\u6765\u8bad\u7ec35\u4e2a\u6298\u53e0\u3002\u4f60\u4f1a\u95ee\u4f7f\u7528\u54ea\u4e2a\u6a21\u578b\uff1f\u5bf9\u4e8e\u9ad8\u7ef4\u7a00\u758f\u6570\u636e\uff0c\u54ea\u4e2a\u6a21\u578b\u6700\u5feb\uff1f\u903b\u8f91\u56de\u5f52\u3002\u6211\u4eec\u5c06\u9996\u5148\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u6765\u5904\u7406\u8fd9\u4e2a\u6570\u636e\u96c6\uff0c\u5e76\u521b\u5efa\u7b2c\u4e00\u4e2a\u57fa\u51c6\u6a21\u578b\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import CountVectorizer if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df [ \"kfold\" ] = - 1 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f for fold_ in range ( 5 ): train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) count_vec = CountVectorizer ( tokenizer = word_tokenize , token_pattern = None ) count_vec . fit ( train_df . review ) xtrain = count_vec . transform ( train_df . review ) xtest = count_vec . transform ( test_df . review ) model = linear_model . LogisticRegression () model . fit ( xtrain , train_df . sentiment ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( test_df . sentiment , preds ) print ( f \"Fold: { fold_ } \" ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u8fd9\u6bb5\u4ee3\u7801\u7684\u8fd0\u884c\u9700\u8981\u4e00\u5b9a\u7684\u65f6\u95f4\uff0c\u4f46\u53ef\u4ee5\u5f97\u5230\u4ee5\u4e0b\u8f93\u51fa\u7ed3\u679c\uff1a Fold : 0 Accuracy = 0.8903 Fold : 1 Accuracy = 0.897 Fold : 2 Accuracy = 0.891 Fold : 3 Accuracy = 0.8914 Fold : 4 Accuracy = 0.8931 \u54c7\uff0c\u51c6\u786e\u7387\u5df2\u7ecf\u8fbe\u5230 89%\uff0c\u800c\u6211\u4eec\u6240\u505a\u7684\u53ea\u662f\u4f7f\u7528\u8bcd\u888b\u548c\u903b\u8f91\u56de\u5f52\uff01\u8fd9\u771f\u662f\u592a\u68d2\u4e86\uff01\u4e0d\u8fc7\uff0c\u8fd9\u4e2a\u6a21\u578b\u7684\u8bad\u7ec3\u82b1\u8d39\u4e86\u5f88\u591a\u65f6\u95f4\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u80fd\u5426\u901a\u8fc7\u4f7f\u7528\u6734\u7d20\u8d1d\u53f6\u65af\u5206\u7c7b\u5668\u6765\u7f29\u77ed\u8bad\u7ec3\u65f6\u95f4\u3002\u6734\u7d20\u8d1d\u53f6\u65af\u5206\u7c7b\u5668\u5728 NLP \u4efb\u52a1\u4e2d\u76f8\u5f53\u6d41\u884c\uff0c\u56e0\u4e3a\u7a00\u758f\u77e9\u9635\u975e\u5e38\u5e9e\u5927\uff0c\u800c\u6734\u7d20\u8d1d\u53f6\u65af\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6a21\u578b\u3002\u8981\u4f7f\u7528\u8fd9\u4e2a\u6a21\u578b\uff0c\u9700\u8981\u66f4\u6539\u4e00\u4e2a\u5bfc\u5165\u548c\u6a21\u578b\u7684\u884c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e2a\u6a21\u578b\u7684\u6027\u80fd\u5982\u4f55\u3002\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 MultinomialNB\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import naive_bayes from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import CountVectorizer model = naive_bayes . MultinomialNB () model . fit ( xtrain , train_df . sentiment ) \u5f97\u5230\u5982\u4e0b\u7ed3\u679c\uff1a Fold : 0 Accuracy = 0.8444 Fold : 1 Accuracy = 0.8499 Fold : 2 Accuracy = 0.8422 Fold : 3 Accuracy = 0.8443 Fold : 4 Accuracy = 0.8455 \u6211\u4eec\u770b\u5230\u8fd9\u4e2a\u5206\u6570\u5f88\u4f4e\u3002\u4f46\u6734\u7d20\u8d1d\u53f6\u65af\u6a21\u578b\u7684\u901f\u5ea6\u975e\u5e38\u5feb\u3002 NLP \u4e2d\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f TF-IDF\uff0c\u5982\u4eca\u5927\u591a\u6570\u4eba\u90fd\u503e\u5411\u4e8e\u5ffd\u7565\u6216\u4e0d\u5c51\u4e8e\u4e86\u89e3\u8fd9\u79cd\u65b9\u6cd5\u3002TF \u662f\u672f\u8bed\u9891\u7387\uff0cIDF \u662f\u53cd\u5411\u6587\u6863\u9891\u7387\u3002\u4ece\u8fd9\u4e9b\u672f\u8bed\u6765\u770b\uff0c\u8fd9\u4f3c\u4e4e\u6709\u4e9b\u56f0\u96be\uff0c\u4f46\u901a\u8fc7 TF \u548c IDF \u7684\u8ba1\u7b97\u516c\u5f0f\uff0c\u4e8b\u60c5\u5c31\u4f1a\u53d8\u5f97\u5f88\u660e\u663e\u3002 $$ TF(t) = \\frac{Number\\ of\\ times\\ a\\ term\\ t\\ appears\\ in\\ a\\ document}{Total\\ number\\ of\\ terms\\ in \\ the\\ document} $$ \\[ IDF(t) = LOG\\left(\\frac{Total\\ number\\ of\\ documents}{Number\\ of\\ documents with\\ term\\ t\\ in\\ it}\\right) \\] \u672f\u8bed t \u7684 TF-IDF \u5b9a\u4e49\u4e3a\uff1a $$ TF-IDF(t) = TF(t) \\times IDF(t) $$ \u4e0e scikit-learn \u4e2d\u7684 CountVectorizer \u7c7b\u4f3c\uff0c\u6211\u4eec\u4e5f\u6709 TfidfVectorizer\u3002\u8ba9\u6211\u4eec\u8bd5\u7740\u50cf\u4f7f\u7528 CountVectorizer \u4e00\u6837\u4f7f\u7528\u5b83\u3002 from sklearn.feature_extraction.text import TfidfVectorizer from nltk.tokenize import word_tokenize corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] tfv = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfv . fit ( corpus ) corpus_transformed = tfv . transform ( corpus ) print ( corpus_transformed ) \u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a ( 0 , 27 ) 0.2965698850220162 ( 0 , 16 ) 0.4428321995085722 ( 0 , 14 ) 0.4428321995085722 ( 0 , 7 ) 0.4428321995085722 ( 0 , 4 ) 0.35727423026525224 ( 0 , 2 ) 0.4428321995085722 ( 1 , 27 ) 0.35299699146792735 ( 1 , 24 ) 0.2635440111190765 ( 1 , 22 ) 0.2635440111190765 ( 1 , 18 ) 0.2635440111190765 ( 1 , 15 ) 0.2635440111190765 ( 1 , 13 ) 0.2635440111190765 ( 1 , 12 ) 0.2635440111190765 ( 1 , 9 ) 0.2635440111190765 ( 1 , 8 ) 0.2635440111190765 ( 1 , 6 ) 0.2635440111190765 ( 1 , 4 ) 0.42525129752567803 ( 1 , 3 ) 0.2635440111190765 ( 2 , 27 ) 0.31752680284846835 ( 2 , 19 ) 0.4741246485558491 ( 2 , 11 ) 0.4741246485558491 ( 2 , 10 ) 0.4741246485558491 ( 2 , 5 ) 0.4741246485558491 ( 3 , 25 ) 0.38775666010579296 ( 3 , 23 ) 0.38775666010579296 ( 3 , 21 ) 0.38775666010579296 ( 3 , 20 ) 0.38775666010579296 ( 3 , 17 ) 0.38775666010579296 ( 3 , 1 ) 0.38775666010579296 ( 3 , 0 ) 0.3128396318588854 ( 4 , 26 ) 0.2959842226518677 ( 4 , 0 ) 0.9551928286692534 \u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u6b21\u6211\u4eec\u5f97\u5230\u7684\u4e0d\u662f\u6574\u6570\u503c\uff0c\u800c\u662f\u6d6e\u70b9\u6570\u3002 \u7528 TfidfVectorizer \u4ee3\u66ff CountVectorizer \u4e5f\u662f\u5c0f\u83dc\u4e00\u789f\u3002Scikit-learn \u8fd8\u63d0\u4f9b\u4e86 TfidfTransformer\u3002\u5982\u679c\u4f60\u4f7f\u7528\u7684\u662f\u8ba1\u6570\u503c\uff0c\u53ef\u4ee5\u4f7f\u7528 TfidfTransformer \u5e76\u83b7\u5f97\u4e0e TfidfVectorizer \u76f8\u540c\u7684\u6548\u679c\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import TfidfVectorizer for fold_ in range ( 5 ): train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) tfidf_vec = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfidf_vec . fit ( train_df . review ) xtrain = tfidf_vec . transform ( train_df . review ) xtest = tfidf_vec . transform ( test_df . review ) model = linear_model . LogisticRegression () model . fit ( xtrain , train_df . sentiment ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( test_df . sentiment , preds ) print ( f \"Fold: { fold_ } \" ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u6211\u4eec\u53ef\u4ee5\u770b\u770b TF-IDF \u5728\u903b\u8f91\u56de\u5f52\u6a21\u578b\u4e0a\u7684\u8868\u73b0\u5982\u4f55\u3002 Fold : 0 Accuracy = 0.8976 Fold : 1 Accuracy = 0.8998 Fold : 2 Accuracy = 0.8948 Fold : 3 Accuracy = 0.8912 Fold : 4 Accuracy = 0.8995 \u6211\u4eec\u770b\u5230\uff0c\u8fd9\u4e9b\u5206\u6570\u90fd\u6bd4 CountVectorizer \u9ad8\u4e00\u4e9b\uff0c\u56e0\u6b64\u5b83\u6210\u4e3a\u4e86\u6211\u4eec\u60f3\u8981\u51fb\u8d25\u7684\u65b0\u57fa\u51c6\u3002 NLP \u4e2d\u53e6\u4e00\u4e2a\u6709\u8da3\u7684\u6982\u5ff5\u662f N-gram\u3002N-grams \u662f\u6309\u987a\u5e8f\u6392\u5217\u7684\u5355\u8bcd\u7ec4\u5408\u3002N-grams \u5f88\u5bb9\u6613\u521b\u5efa\u3002\u60a8\u53ea\u9700\u6ce8\u610f\u987a\u5e8f\u5373\u53ef\u3002\u4e3a\u4e86\u8ba9\u4e8b\u60c5\u53d8\u5f97\u66f4\u7b80\u5355\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 NLTK \u7684 N-gram \u5b9e\u73b0\u3002 from nltk import ngrams from nltk.tokenize import word_tokenize N = 3 sentence = \"hi, how are you?\" tokenized_sentence = word_tokenize ( sentence ) n_grams = list ( ngrams ( tokenized_sentence , N )) print ( n_grams ) \u7531\u6b64\u5f97\u5230\uff1a [( 'hi' , ',' , 'how' ), ( ',' , 'how' , 'are' ), ( 'how' , 'are' , 'you' ), ( 'are' , 'you' , '?' )] \u540c\u6837\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u521b\u5efa 2-gram \u6216 4-gram \u7b49\u3002\u73b0\u5728\uff0c\u8fd9\u4e9b n-gram \u5c06\u6210\u4e3a\u6211\u4eec\u8bcd\u6c47\u8868\u7684\u4e00\u90e8\u5206\uff0c\u5f53\u6211\u4eec\u8ba1\u7b97\u8ba1\u6570\u6216 tf-idf \u65f6\uff0c\u6211\u4eec\u4f1a\u5c06\u4e00\u4e2a n-gram \u89c6\u4e3a\u4e00\u4e2a\u5168\u65b0\u7684\u6807\u8bb0\u3002\u56e0\u6b64\uff0c\u5728\u67d0\u79cd\u7a0b\u5ea6\u4e0a\uff0c\u6211\u4eec\u662f\u5728\u7ed3\u5408\u4e0a\u4e0b\u6587\u3002scikit-learn \u7684 CountVectorizer \u548c TfidfVectorizer \u5b9e\u73b0\u90fd\u901a\u8fc7 ngram_range \u53c2\u6570\u63d0\u4f9b n-gram\uff0c\u8be5\u53c2\u6570\u6709\u6700\u5c0f\u548c\u6700\u5927\u9650\u5236\u3002\u9ed8\u8ba4\u60c5\u51b5\u4e0b\uff0c\u8be5\u53c2\u6570\u4e3a\uff081, 1\uff09\u3002\u5f53\u6211\u4eec\u5c06\u5176\u6539\u4e3a (1, 3) \u65f6\uff0c\u6211\u4eec\u5c06\u770b\u5230\u5355\u5b57\u5143\u3001\u53cc\u5b57\u5143\u548c\u4e09\u5b57\u5143\u3002\u4ee3\u7801\u6539\u52a8\u5f88\u5c0f\u3002 \u7531\u4e8e\u5230\u76ee\u524d\u4e3a\u6b62\u6211\u4eec\u4f7f\u7528 tf-idf \u5f97\u5230\u4e86\u6700\u597d\u7684\u7ed3\u679c\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5305\u542b n-grams \u76f4\u81f3 trigrams \u662f\u5426\u80fd\u6539\u8fdb\u6a21\u578b\u3002\u552f\u4e00\u9700\u8981\u4fee\u6539\u7684\u662f TfidfVectorizer \u7684\u521d\u59cb\u5316\u3002 tfidf_vec = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None , ngram_range = ( 1 , 3 ) ) \u8ba9\u6211\u4eec\u770b\u770b\u662f\u5426\u4f1a\u6709\u6539\u8fdb\u3002 Fold : 0 Accuracy = 0.8931 Fold : 1 Accuracy = 0.8941 Fold : 2 Accuracy = 0.897 Fold : 3 Accuracy = 0.8922 Fold : 4 Accuracy = 0.8847 \u770b\u8d77\u6765\u8fd8\u884c\uff0c\u4f46\u6211\u4eec\u770b\u4e0d\u5230\u4efb\u4f55\u6539\u8fdb\u3002 \u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u591a\u4f7f\u7528 bigrams \u6765\u83b7\u5f97\u6539\u8fdb\u3002 \u6211\u4e0d\u4f1a\u5728\u8fd9\u91cc\u5c55\u793a\u8fd9\u4e00\u90e8\u5206\u3002\u4e5f\u8bb8\u4f60\u53ef\u4ee5\u81ea\u5df1\u8bd5\u7740\u505a\u3002 NLP \u7684\u57fa\u7840\u77e5\u8bc6\u8fd8\u6709\u5f88\u591a\u3002\u4f60\u5fc5\u987b\u77e5\u9053\u7684\u4e00\u4e2a\u672f\u8bed\u662f\u8bcd\u5e72\u63d0\u53d6\uff08strmming\uff09\u3002\u53e6\u4e00\u4e2a\u662f\u8bcd\u5f62\u8fd8\u539f\uff08lemmatization\uff09\u3002 \u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f \u53ef\u4ee5\u5c06\u4e00\u4e2a\u8bcd\u51cf\u5c11\u5230\u6700\u5c0f\u5f62\u5f0f\u3002\u5728\u8bcd\u5e72\u63d0\u53d6\u7684\u60c5\u51b5\u4e0b\uff0c\u5904\u7406\u540e\u7684\u5355\u8bcd\u79f0\u4e3a\u8bcd\u5e72\u5355\u8bcd\uff0c\u800c\u5728\u8bcd\u5f62\u8fd8\u539f\u60c5\u51b5\u4e0b\uff0c\u5904\u7406\u540e\u7684\u5355\u8bcd\u79f0\u4e3a\u8bcd\u5f62\u3002\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u8bcd\u5f62\u8fd8\u539f\u6bd4\u8bcd\u5e72\u63d0\u53d6\u66f4\u6fc0\u8fdb\uff0c\u800c\u8bcd\u5e72\u63d0\u53d6\u66f4\u6d41\u884c\u548c\u5e7f\u6cdb\u3002\u8bcd\u5e72\u548c\u8bcd\u5f62\u90fd\u6765\u81ea\u8bed\u8a00\u5b66\u3002\u5982\u679c\u4f60\u6253\u7b97\u4e3a\u67d0\u79cd\u8bed\u8a00\u5236\u4f5c\u8bcd\u5e72\u6216\u8bcd\u578b\uff0c\u9700\u8981\u5bf9\u8be5\u8bed\u8a00\u6709\u6df1\u5165\u7684\u4e86\u89e3\u3002\u5982\u679c\u8981\u8fc7\u591a\u5730\u4ecb\u7ecd\u8fd9\u4e9b\u77e5\u8bc6\uff0c\u5c31\u610f\u5473\u7740\u8981\u5728\u672c\u4e66\u4e2d\u589e\u52a0\u4e00\u7ae0\u3002\u4f7f\u7528 NLTK \u8f6f\u4ef6\u5305\u53ef\u4ee5\u8f7b\u677e\u5b8c\u6210\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e24\u79cd\u65b9\u6cd5\u7684\u4e00\u4e9b\u793a\u4f8b\u3002\u6709\u8bb8\u591a\u4e0d\u540c\u7c7b\u578b\u7684\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u5668\u3002\u6211\u5c06\u7528\u6700\u5e38\u89c1\u7684 Snowball Stemmer \u548c WordNet Lemmatizer \u6765\u4e3e\u4f8b\u8bf4\u660e\u3002 from nltk.stem import WordNetLemmatizer from nltk.stem.snowball import SnowballStemmer lemmatizer = WordNetLemmatizer () stemmer = SnowballStemmer ( \"english\" ) words = [ \"fishing\" , \"fishes\" , \"fished\" ] for word in words : print ( f \"word= { word } \" ) print ( f \"stemmed_word= { stemmer . stem ( word ) } \" ) print ( f \"lemma= { lemmatizer . lemmatize ( word ) } \" ) print ( \"\" ) \u8fd9\u5c06\u6253\u5370\uff1a word = fishing stemmed_word = fish lemma = fishing word = fishes stemmed_word = fish lemma = fish word = fished stemmed_word = fish lemma = fished \u6b63\u5982\u60a8\u6240\u770b\u5230\u7684\uff0c\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u662f\u622a\u7136\u4e0d\u540c\u7684\u3002\u5f53\u6211\u4eec\u8fdb\u884c\u8bcd\u5e72\u63d0\u53d6\u65f6\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f\u4e00\u4e2a\u8bcd\u7684\u6700\u5c0f\u5f62\u5f0f\uff0c\u5b83\u53ef\u80fd\u662f\u4e5f\u53ef\u80fd\u4e0d\u662f\u8be5\u8bcd\u6240\u5c5e\u8bed\u8a00\u8bcd\u5178\u4e2d\u7684\u4e00\u4e2a\u8bcd\u3002\u4f46\u662f\uff0c\u5728\u8bcd\u5f62\u8fd8\u539f\u60c5\u51b5\u4e0b\uff0c\u8fd9\u5c06\u662f\u4e00\u4e2a\u8bcd\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u81ea\u5df1\u5c1d\u8bd5\u6dfb\u52a0\u8bcd\u5e72\u548c\u8bcd\u7d20\u5316\uff0c\u770b\u770b\u662f\u5426\u80fd\u6539\u5584\u7ed3\u679c\u3002 \u60a8\u8fd8\u5e94\u8be5\u4e86\u89e3\u7684\u4e00\u4e2a\u4e3b\u9898\u662f\u4e3b\u9898\u63d0\u53d6\u3002 \u4e3b\u9898\u63d0\u53d6 \u53ef\u4ee5\u4f7f\u7528\u975e\u8d1f\u77e9\u9635\u56e0\u5f0f\u5206\u89e3\uff08NMF\uff09\u6216\u6f5c\u5728\u8bed\u4e49\u5206\u6790\uff08LSA\uff09\u6765\u5b8c\u6210\uff0c\u540e\u8005\u4e5f\u88ab\u79f0\u4e3a\u5947\u5f02\u503c\u5206\u89e3\u6216 SVD\u3002\u8fd9\u4e9b\u5206\u89e3\u6280\u672f\u53ef\u5c06\u6570\u636e\u7b80\u5316\u4e3a\u7ed9\u5b9a\u6570\u91cf\u7684\u6210\u5206\u3002 \u60a8\u53ef\u4ee5\u5728\u4ece CountVectorizer \u6216 TfidfVectorizer \u4e2d\u83b7\u5f97\u7684\u7a00\u758f\u77e9\u9635\u4e0a\u5e94\u7528\u5176\u4e2d\u4efb\u4f55\u4e00\u79cd\u6280\u672f\u3002 \u8ba9\u6211\u4eec\u628a\u5b83\u5e94\u7528\u5230\u4e4b\u524d\u4f7f\u7528\u8fc7\u7684 TfidfVetorizer \u4e0a\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import decomposition from sklearn.feature_extraction.text import TfidfVectorizer corpus = pd . read_csv ( \"../input/imdb.csv\" , nrows = 10000 ) corpus = corpus . review . values tfv = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfv . fit ( corpus ) corpus_transformed = tfv . transform ( corpus ) svd = decomposition . TruncatedSVD ( n_components = 10 ) corpus_svd = svd . fit ( corpus_transformed ) sample_index = 0 feature_scores = dict ( zip ( tfv . get_feature_names (), corpus_svd . components_ [ sample_index ] ) ) N = 5 print ( sorted ( feature_scores , key = feature_scores . get , reverse = True )[: N ]) \u60a8\u53ef\u4ee5\u4f7f\u7528\u5faa\u73af\u6765\u8fd0\u884c\u591a\u4e2a\u6837\u672c\u3002 N = 5 for sample_index in range ( 5 ): feature_scores = dict ( zip ( tfv . get_feature_names (), corpus_svd . components_ [ sample_index ] ) ) print ( sorted ( feature_scores , key = feature_scores . get , reverse = True )[: N ] ) \u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a [ 'the' , ',' , '.' , 'a' , 'and' ] [ 'br' , '<' , '>' , '/' , '-' ] [ 'i' , 'movie' , '!' , 'it' , 'was' ] [ ',' , '!' , \"''\" , '``' , 'you' ] [ '!' , 'the' , '...' , \"''\" , '``' ] \u4f60\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u6839\u672c\u8bf4\u4e0d\u901a\u3002\u600e\u4e48\u529e\u5462\uff1f\u8ba9\u6211\u4eec\u8bd5\u7740\u6e05\u7406\u4e00\u4e0b\uff0c\u770b\u770b\u662f\u5426\u6709\u610f\u4e49\u3002\u8981\u6e05\u7406\u4efb\u4f55\u6587\u672c\u6570\u636e\uff0c\u5c24\u5176\u662f pandas \u6570\u636e\u5e27\u4e2d\u7684\u6587\u672c\u6570\u636e\uff0c\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u51fd\u6570\u3002 import re import string def clean_text ( s ): s = s . split () s = \" \" . join ( s ) s = re . sub ( f '[ { re . escape ( string . punctuation ) } ]' , '' , s ) return s \u8be5\u51fd\u6570\u4f1a\u5c06 \"hi, how are you????\" \u8fd9\u6837\u7684\u5b57\u7b26\u4e32\u8f6c\u6362\u4e3a \"hi how are you\"\u3002\u8ba9\u6211\u4eec\u628a\u8fd9\u4e2a\u51fd\u6570\u5e94\u7528\u5230\u65e7\u7684 SVD \u4ee3\u7801\u4e2d\uff0c\u770b\u770b\u5b83\u662f\u5426\u80fd\u7ed9\u63d0\u53d6\u7684\u4e3b\u9898\u5e26\u6765\u63d0\u5347\u3002\u4f7f\u7528 pandas\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528 apply \u51fd\u6570\u5c06\u6e05\u7406\u4ee3\u7801 \"\u5e94\u7528 \"\u5230\u4efb\u610f\u7ed9\u5b9a\u7684\u5217\u4e2d\u3002 import pandas as pd corpus = pd . read_csv ( \"../input/imdb.csv\" , nrows = 10000 ) corpus . loc [:, \"review\" ] = corpus . review . apply ( clean_text ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u53ea\u5728\u4e3b SVD \u811a\u672c\u4e2d\u6dfb\u52a0\u4e86\u4e00\u884c\u4ee3\u7801\uff0c\u8fd9\u5c31\u662f\u4f7f\u7528\u51fd\u6570\u548c pandas \u5e94\u7528\u7684\u597d\u5904\u3002\u8fd9\u6b21\u751f\u6210\u7684\u4e3b\u9898\u5982\u4e0b\u3002 [ 'the' , 'a' , 'and' , 'of' , 'to' ] [ 'i' , 'movie' , 'it' , 'was' , 'this' ] [ 'the' , 'was' , 'i' , 'were' , 'of' ] [ 'her' , 'was' , 'she' , 'i' , 'he' ] [ 'br' , 'to' , 'they' , 'he' , 'show' ] \u547c\uff01\u81f3\u5c11\u8fd9\u6bd4\u6211\u4eec\u4e4b\u524d\u597d\u591a\u4e86\u3002\u4f46\u4f60\u77e5\u9053\u5417\uff1f\u4f60\u53ef\u4ee5\u901a\u8fc7\u5728\u6e05\u7406\u529f\u80fd\u4e2d\u5220\u9664\u505c\u6b62\u8bcd\uff08stopwords\uff09\u6765\u4f7f\u5b83\u53d8\u5f97\u66f4\u597d\u3002\u4ec0\u4e48\u662fstopwords\uff1f\u5b83\u4eec\u662f\u5b58\u5728\u4e8e\u6bcf\u79cd\u8bed\u8a00\u4e2d\u7684\u9ad8\u9891\u8bcd\u3002\u4f8b\u5982\uff0c\u5728\u82f1\u8bed\u4e2d\uff0c\u8fd9\u4e9b\u8bcd\u5305\u62ec \"a\"\u3001\"an\"\u3001\"the\"\u3001\"for \"\u7b49\u3002\u5220\u9664\u505c\u6b62\u8bcd\u5e76\u975e\u603b\u662f\u660e\u667a\u7684\u9009\u62e9\uff0c\u8fd9\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u4e1a\u52a1\u95ee\u9898\u3002\u50cf \"I need a new dog\"\u8fd9\u6837\u7684\u53e5\u5b50\uff0c\u53bb\u6389\u505c\u6b62\u8bcd\u540e\u4f1a\u53d8\u6210 \"need new dog\"\uff0c\u6b64\u65f6\u6211\u4eec\u4e0d\u77e5\u9053\u8c01\u9700\u8981new dog\u3002 \u5982\u679c\u6211\u4eec\u603b\u662f\u5220\u9664\u505c\u6b62\u8bcd\uff0c\u5c31\u4f1a\u4e22\u5931\u5f88\u591a\u4e0a\u4e0b\u6587\u4fe1\u606f\u3002\u4f60\u53ef\u4ee5\u5728 NLTK \u4e2d\u627e\u5230\u8bb8\u591a\u8bed\u8a00\u7684\u505c\u6b62\u8bcd\uff0c\u5982\u679c\u6ca1\u6709\uff0c\u4f60\u4e5f\u53ef\u4ee5\u5728\u81ea\u5df1\u559c\u6b22\u7684\u641c\u7d22\u5f15\u64ce\u4e0a\u5feb\u901f\u641c\u7d22\u4e00\u4e0b\u3002 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u8f6c\u5230\u5927\u591a\u6570\u4eba\u90fd\u559c\u6b22\u4f7f\u7528\u7684\u65b9\u6cd5\uff1a\u6df1\u5ea6\u5b66\u4e60\u3002\u4f46\u9996\u5148\uff0c\u6211\u4eec\u5fc5\u987b\u77e5\u9053\u4ec0\u4e48\u662f\u8bcd\u5d4c\u5165\uff08embedings for words\uff09\u3002\u4f60\u5df2\u7ecf\u770b\u5230\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u5c06\u6807\u8bb0\u8f6c\u6362\u6210\u4e86\u6570\u5b57\u3002\u56e0\u6b64\uff0c\u5982\u679c\u67d0\u4e2a\u8bed\u6599\u5e93\u4e2d\u6709 N \u4e2a\u552f\u4e00\u7684\u8bcd\u5757\uff0c\u5b83\u4eec\u53ef\u4ee5\u7528 0 \u5230 N-1 \u4e4b\u95f4\u7684\u6574\u6570\u6765\u8868\u793a\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5c06\u7528\u5411\u91cf\u6765\u8868\u793a\u8fd9\u4e9b\u6574\u6570\u8bcd\u5757\u3002\u8fd9\u79cd\u5c06\u5355\u8bcd\u8868\u793a\u6210\u5411\u91cf\u7684\u65b9\u6cd5\u88ab\u79f0\u4e3a\u5355\u8bcd\u5d4c\u5165\u6216\u5355\u8bcd\u5411\u91cf\u3002\u8c37\u6b4c\u7684 Word2Vec \u662f\u5c06\u5355\u8bcd\u8f6c\u6362\u4e3a\u5411\u91cf\u7684\u6700\u53e4\u8001\u65b9\u6cd5\u4e4b\u4e00\u3002\u6b64\u5916\uff0c\u8fd8\u6709 Facebook \u7684 FastText \u548c\u65af\u5766\u798f\u5927\u5b66\u7684 GloVe\uff08\u7528\u4e8e\u5355\u8bcd\u8868\u793a\u7684\u5168\u5c40\u5411\u91cf\uff09\u3002\u8fd9\u4e9b\u65b9\u6cd5\u5f7c\u6b64\u5927\u76f8\u5f84\u5ead\u3002 \u5176\u57fa\u672c\u601d\u60f3\u662f\u5efa\u7acb\u4e00\u4e2a\u6d45\u5c42\u7f51\u7edc\uff0c\u901a\u8fc7\u91cd\u6784\u8f93\u5165\u53e5\u5b50\u6765\u5b66\u4e60\u5355\u8bcd\u7684\u5d4c\u5165\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u5468\u56f4\u7684\u6240\u6709\u5355\u8bcd\u6765\u8bad\u7ec3\u7f51\u7edc\u9884\u6d4b\u4e00\u4e2a\u7f3a\u5931\u7684\u5355\u8bcd\uff0c\u5728\u6b64\u8fc7\u7a0b\u4e2d\uff0c\u7f51\u7edc\u5c06\u5b66\u4e60\u5e76\u66f4\u65b0\u6240\u6709\u76f8\u5173\u5355\u8bcd\u7684\u5d4c\u5165\u3002\u8fd9\u79cd\u65b9\u6cd5\u4e5f\u88ab\u79f0\u4e3a\u8fde\u7eed\u8bcd\u888b\u6216 CBoW \u6a21\u578b\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u4e00\u4e2a\u5355\u8bcd\u6765\u9884\u6d4b\u4e0a\u4e0b\u6587\u4e2d\u7684\u5355\u8bcd\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684\u8df3\u683c\u6a21\u578b\u3002Word2Vec \u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e24\u79cd\u65b9\u6cd5\u5b66\u4e60\u5d4c\u5165\u3002 FastText \u53ef\u4ee5\u5b66\u4e60\u5b57\u7b26 n-gram \u7684\u5d4c\u5165\u3002\u548c\u5355\u8bcd n-gram \u4e00\u6837\uff0c\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7684\u662f\u5b57\u7b26\uff0c\u5219\u79f0\u4e3a\u5b57\u7b26 n-gram\uff0c\u6700\u540e\uff0cGloVe \u901a\u8fc7\u5171\u73b0\u77e9\u9635\u6765\u5b66\u4e60\u8fd9\u4e9b\u5d4c\u5165\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u8bf4\uff0c\u6240\u6709\u8fd9\u4e9b\u4e0d\u540c\u7c7b\u578b\u7684\u5d4c\u5165\u6700\u7ec8\u90fd\u4f1a\u8fd4\u56de\u4e00\u4e2a\u5b57\u5178\uff0c\u5176\u4e2d\u952e\u662f\u8bed\u6599\u5e93\uff08\u4f8b\u5982\u82f1\u8bed\u7ef4\u57fa\u767e\u79d1\uff09\u4e2d\u7684\u5355\u8bcd\uff0c\u503c\u662f\u5927\u5c0f\u4e3a N\uff08\u901a\u5e38\u4e3a 300\uff09\u7684\u5411\u91cf\u3002 \u56fe 1\uff1a\u53ef\u89c6\u5316\u4e8c\u7ef4\u5355\u8bcd\u5d4c\u5165\u3002 \u56fe 1 \u663e\u793a\u4e86\u4e8c\u7ef4\u5355\u8bcd\u5d4c\u5165\u7684\u53ef\u89c6\u5316\u6548\u679c\u3002\u5047\u8bbe\u6211\u4eec\u4ee5\u67d0\u79cd\u65b9\u5f0f\u5b8c\u6210\u4e86\u8bcd\u8bed\u7684\u4e8c\u7ef4\u8868\u793a\u3002\u56fe 1 \u663e\u793a\uff0c\u5982\u679c\u4eceBerlin\uff08\u5fb7\u56fd\u9996\u90fd\uff09\u7684\u5411\u91cf\u4e2d\u51cf\u53bb\u5fb7\u56fd\uff08Germany\uff09\u7684\u5411\u91cf\uff0c\u518d\u52a0\u4e0a\u6cd5\u56fd\uff08france\uff09\u7684\u5411\u91cf\uff0c\u5c31\u4f1a\u5f97\u5230\u4e00\u4e2a\u63a5\u8fd1Paris\uff08\u6cd5\u56fd\u9996\u90fd\uff09\u7684\u5411\u91cf\u3002\u7531\u6b64\u53ef\u89c1\uff0c\u5d4c\u5165\u5f0f\u4e5f\u80fd\u8fdb\u884c\u7c7b\u6bd4\u3002 \u8fd9\u5e76\u4e0d\u603b\u662f\u6b63\u786e\u7684\uff0c\u4f46\u8fd9\u6837\u7684\u4f8b\u5b50\u6709\u52a9\u4e8e\u7406\u89e3\u5355\u8bcd\u5d4c\u5165\u7684\u4f5c\u7528\u3002\u50cf \"\u55e8\uff0c\u4f60\u597d\u5417 \"\u8fd9\u6837\u7684\u53e5\u5b50\u53ef\u4ee5\u7528\u4e0b\u9762\u7684\u4e00\u5806\u5411\u91cf\u6765\u8868\u793a\u3002 hi \u2500> [vector (v1) of size 300] , \u2500> [vector (v2) of size 300] how \u2500> [vector (v3) of size 300] are \u2500> [vector (v4) of size 300] you \u2500> [vector (v5) of size 300] ? \u2500> [vector (v6) of size 300] \u4f7f\u7528\u8fd9\u4e9b\u4fe1\u606f\u6709\u591a\u79cd\u65b9\u6cd5\u3002\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u4e4b\u4e00\u5c31\u662f\u4f7f\u7528\u5d4c\u5165\u5411\u91cf\u3002\u5982\u4e0a\u4f8b\u6240\u793a\uff0c\u6bcf\u4e2a\u5355\u8bcd\u90fd\u6709\u4e00\u4e2a 1x300 \u7684\u5d4c\u5165\u5411\u91cf\u3002\u5229\u7528\u8fd9\u4e9b\u4fe1\u606f\uff0c\u6211\u4eec\u53ef\u4ee5\u8ba1\u7b97\u51fa\u6574\u4e2a\u53e5\u5b50\u7684\u5d4c\u5165\u3002\u8ba1\u7b97\u65b9\u6cd5\u6709\u591a\u79cd\u3002\u5176\u4e2d\u4e00\u79cd\u65b9\u6cd5\u5982\u4e0b\u6240\u793a\u3002\u5728\u8fd9\u4e2a\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u5c06\u7ed9\u5b9a\u53e5\u5b50\u4e2d\u7684\u6240\u6709\u5355\u8bcd\u5411\u91cf\u63d0\u53d6\u51fa\u6765\uff0c\u7136\u540e\u4ece\u6240\u6709\u6807\u8bb0\u8bcd\u7684\u5355\u8bcd\u5411\u91cf\u4e2d\u521b\u5efa\u4e00\u4e2a\u5f52\u4e00\u5316\u7684\u5355\u8bcd\u5411\u91cf\u3002\u8fd9\u6837\u5c31\u5f97\u5230\u4e86\u4e00\u4e2a\u53e5\u5b50\u5411\u91cf\u3002 import numpy as np def sentence_to_vec ( s , embedding_dict , stop_words , tokenizer ): words = str ( s ) . lower () words = tokenizer ( words ) words = [ w for w in words if not w in stop_words ] words = [ w for w in words if w . isalpha ()] M = [] for w in words : if w in embedding_dict : M . append ( embedding_dict [ w ]) if len ( M ) == 0 : return np . zeros ( 300 ) M = np . array ( M ) v = M . sum ( axis = 0 ) return v / np . sqrt (( v ** 2 ) . sum ()) \u6211\u4eec\u53ef\u4ee5\u7528\u8fd9\u79cd\u65b9\u6cd5\u5c06\u6240\u6709\u793a\u4f8b\u8f6c\u6362\u6210\u4e00\u4e2a\u5411\u91cf\u3002\u6211\u4eec\u80fd\u5426\u4f7f\u7528 fastText \u5411\u91cf\u6765\u6539\u8fdb\u4e4b\u524d\u7684\u7ed3\u679c\uff1f\u6bcf\u7bc7\u8bc4\u8bba\u90fd\u6709 300 \u4e2a\u7279\u5f81\u3002 import io import numpy as np import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import TfidfVectorizer def load_vectors ( fname ): fin = io . open ( fname , 'r' , encoding = 'utf-8' , newline = ' \\n ' , errors = 'ignore' ) n , d = map ( int , fin . readline () . split ()) data = {} for line in fin : tokens = line . rstrip () . split ( ' ' ) data [ tokens [ 0 ]] = list ( map ( float , tokens [ 1 :])) return data def sentence_to_vec ( s , embedding_dict , stop_words , tokenizer ): if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df = df . sample ( frac = 1 ) . reset_index ( drop = True ) print ( \"Loading embeddings\" ) embeddings = load_vectors ( \"../input/crawl-300d-2M.vec\" ) print ( \"Creating sentence vectors\" ) vectors = [] for review in df . review . values : vectors . append ( sentence_to_vec ( s = review , embedding_dict = embeddings , stop_words = [], tokenizer = word_tokenize ) ) vectors = np . array ( vectors ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for fold_ , ( t_ , v_ ) in enumerate ( kf . split ( X = vectors , y = y )): print ( f \"Training fold: { fold_ } \" ) xtrain = vectors [ t_ , :] ytrain = y [ t_ ] xtest = vectors [ v_ , :] ytest = y [ v_ ] model = linear_model . LogisticRegression () model . fit ( xtrain , ytrain ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( ytest , preds ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u8fd9\u5c06\u5f97\u5230\u5982\u4e0b\u7ed3\u679c\uff1a Loading embeddings Creating sentence vectors Training fold : 0 Accuracy = 0.8619 Training fold : 1 Accuracy = 0.8661 Training fold : 2 Accuracy = 0.8544 Training fold : 3 Accuracy = 0.8624 Training fold : 4 Accuracy = 0.8595 Wow\uff01\u771f\u662f\u51fa\u4e4e\u610f\u6599\u3002\u6211\u4eec\u6240\u505a\u7684\u4e00\u5207\u90fd\u662f\u4e3a\u4e86\u4f7f\u7528 FastText \u5d4c\u5165\u3002\u8bd5\u7740\u628a\u5d4c\u5165\u5f0f\u6362\u6210 GloVe\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002\u6211\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u5927\u5bb6\u3002 \u5f53\u6211\u4eec\u8c08\u8bba\u6587\u672c\u6570\u636e\u65f6\uff0c\u6211\u4eec\u5fc5\u987b\u7262\u8bb0\u4e00\u4ef6\u4e8b\u3002\u6587\u672c\u6570\u636e\u4e0e\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u975e\u5e38\u76f8\u4f3c\u3002\u5982\u56fe 2 \u6240\u793a\uff0c\u6211\u4eec\u8bc4\u8bba\u4e2d\u7684\u4efb\u4f55\u6837\u672c\u90fd\u662f\u5728\u4e0d\u540c\u65f6\u95f4\u6233\u4e0a\u6309\u9012\u589e\u987a\u5e8f\u6392\u5217\u7684\u6807\u8bb0\u5e8f\u5217\uff0c\u6bcf\u4e2a\u6807\u8bb0\u90fd\u53ef\u4ee5\u8868\u793a\u4e3a\u4e00\u4e2a\u5411\u91cf/\u5d4c\u5165\u3002 \u56fe 2\uff1a\u5c06\u6807\u8bb0\u8868\u793a\u4e3a\u5d4c\u5165\uff0c\u5e76\u5c06\u5176\u89c6\u4e3a\u65f6\u95f4\u5e8f\u5217 \u8fd9\u610f\u5473\u7740\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5e7f\u6cdb\u7528\u4e8e\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u7684\u6a21\u578b\uff0c\u4f8b\u5982\u957f\u77ed\u671f\u8bb0\u5fc6\uff08LSTM\uff09\u6216\u95e8\u63a7\u9012\u5f52\u5355\u5143\uff08GRU\uff09\uff0c\u751a\u81f3\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff08CNN\uff09\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u5728\u8be5\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u4e00\u4e2a\u7b80\u5355\u7684\u53cc\u5411 LSTM \u6a21\u578b\u3002 \u9996\u5148\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u9879\u76ee\u3002\u4f60\u53ef\u4ee5\u968f\u610f\u7ed9\u5b83\u547d\u540d\u3002\u7136\u540e\uff0c\u6211\u4eec\u7684\u7b2c\u4e00\u6b65\u5c06\u662f\u5206\u5272\u6570\u636e\u8fdb\u884c\u4ea4\u53c9\u9a8c\u8bc1\u3002 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df [ \"kfold\" ] = - 1 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f df . to_csv ( \"../input/imdb_folds.csv\" , index = False ) \u5c06\u6570\u636e\u96c6\u5212\u5206\u4e3a\u591a\u4e2a\u6298\u53e0\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5728 dataset.py \u4e2d\u521b\u5efa\u4e00\u4e2a\u7b80\u5355\u7684\u6570\u636e\u96c6\u7c7b\u3002\u6570\u636e\u96c6\u7c7b\u4f1a\u8fd4\u56de\u4e00\u4e2a\u8bad\u7ec3\u6216\u9a8c\u8bc1\u6570\u636e\u6837\u672c\u3002 import torch class IMDBDataset : def __init__ ( self , reviews , targets ): self . reviews = reviews self . target = targets def __len__ ( self ): return len ( self . reviews ) def __getitem__ ( self , item ): review = self . reviews [ item , :] target = self . target [ item ] return { \"review\" : torch . tensor ( review , dtype = torch . long ), \"target\" : torch . tensor ( target , dtype = torch . float ) } \u5b8c\u6210\u6570\u636e\u96c6\u5206\u7c7b\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u521b\u5efa lstm.py\uff0c\u5176\u4e2d\u5305\u542b\u6211\u4eec\u7684 LSTM \u6a21\u578b import torch import torch.nn as nn class LSTM ( nn . Module ): def __init__ ( self , embedding_matrix ): super ( LSTM , self ) . __init__ () num_words = embedding_matrix . shape [ 0 ] embed_dim = embedding_matrix . shape [ 1 ] self . embedding = nn . Embedding ( num_embeddings = num_words , embedding_dim = embed_dim ) self . embedding . weight = nn . Parameter ( torch . tensor ( embedding_matrix , dtype = torch . float32 ) ) self . embedding . weight . requires_grad = False self . lstm = nn . LSTM ( embed_dim , 128 , bidirectional = True , batch_first = True , ) self . out = nn . Linear ( 512 , 1 ) def forward ( self , x ): x = self . embedding ( x ) x , _ = self . lstm ( x ) avg_pool = torch . mean ( x , 1 ) max_pool , _ = torch . max ( x , 1 ) out = torch . cat (( avg_pool , max_pool ), 1 ) out = self . out ( out ) return out \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa engine.py\uff0c\u5176\u4e2d\u5305\u542b\u8bad\u7ec3\u548c\u8bc4\u4f30\u51fd\u6570\u3002 import torch import torch.nn as nn def train ( data_loader , model , optimizer , device ): model . train () for data in data_loader : reviews = data [ \"review\" ] targets = data [ \"target\" ] reviews = reviews . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) optimizer . zero_grad () predictions = model ( reviews ) loss = nn . BCEWithLogitsLoss ()( predictions , targets . view ( - 1 , 1 ) ) loss . backward () optimizer . step () def evaluate ( data_loader , model , device ): final_predictions = [] final_targets = [] model . eval () with torch . no_grad (): for data in data_loader : reviews = data [ \"review\" ] targets = data [ \"target\" ] reviews = reviews . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) predictions = model ( reviews ) predictions = predictions . cpu () . numpy () . tolist () targets = data [ \"target\" ] . cpu () . numpy () . tolist () final_predictions . extend ( predictions ) final_targets . extend ( targets ) return final_predictions , final_targets \u8fd9\u4e9b\u51fd\u6570\u5c06\u5728 train.py \u4e2d\u4e3a\u6211\u4eec\u63d0\u4f9b\u5e2e\u52a9\uff0c\u8be5\u51fd\u6570\u7528\u4e8e\u8bad\u7ec3\u591a\u4e2a\u6298\u53e0\u3002 import io import torch import numpy as np import pandas as pd import tensorflow as tf from sklearn import metrics import config import dataset import engine import lstm def load_vectors ( fname ): fin = io . open ( fname , 'r' , encoding = 'utf-8' , newline = ' \\n ' , errors = 'ignore' ) n , d = map ( int , fin . readline () . split ()) data = {} for line in fin : tokens = line . rstrip () . split ( ' ' ) data [ tokens [ 0 ]] = list ( map ( float , tokens [ 1 :])) return data def create_embedding_matrix ( word_index , embedding_dict ): embedding_matrix = np . zeros (( len ( word_index ) + 1 , 300 )) for word , i in word_index . items (): if word in embedding_dict : embedding_matrix [ i ] = embedding_dict [ word ] return embedding_matrix def run ( df , fold ): train_df = df [ df . kfold != fold ] . reset_index ( drop = True ) valid_df = df [ df . kfold == fold ] . reset_index ( drop = True ) print ( \"Fitting tokenizer\" ) tokenizer = tf . keras . preprocessing . text . Tokenizer () tokenizer . fit_on_texts ( df . review . values . tolist ()) xtrain = tokenizer . texts_to_sequences ( train_df . review . values ) xtest = tokenizer . texts_to_sequences ( valid_df . review . values ) xtrain = tf . keras . preprocessing . sequence . pad_sequences ( xtrain , maxlen = config . MAX_LEN ) xtest = tf . keras . preprocessing . sequence . pad_sequences ( xtest , maxlen = config . MAX_LEN ) train_dataset = dataset . IMDBDataset ( reviews = xtrain , targets = train_df . sentiment . values ) train_data_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = config . TRAIN_BATCH_SIZE , num_workers = 2 ) valid_dataset = dataset . IMDBDataset ( reviews = xtest , targets = valid_df . sentiment . values ) valid_data_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = config . VALID_BATCH_SIZE , num_workers = 1 ) print ( \"Loading embeddings\" ) embedding_dict = load_vectors ( \"../input/crawl-300d-2M.vec\" ) embedding_matrix = create_embedding_matrix ( tokenizer . word_index , embedding_dict ) device = torch . device ( \"cuda\" ) model = lstm . LSTM ( embedding_matrix ) model . to ( device ) optimizer = torch . optim . Adam ( model . parameters (), lr = 1e-3 ) print ( \"Training Model\" ) best_accuracy = 0 early_stopping_counter = 0 for epoch in range ( config . EPOCHS ): engine . train ( train_data_loader , model , optimizer , device ) outputs , targets = engine . evaluate ( valid_data_loader , model , device ) outputs = np . array ( outputs ) >= 0.5 accuracy = metrics . accuracy_score ( targets , outputs ) print ( f \"FOLD: { fold } , Epoch: { epoch } , Accuracy Score = { accuracy } \" ) if accuracy > best_accuracy : best_accuracy = accuracy else : early_stopping_counter += 1 if early_stopping_counter > 2 : break if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb_folds.csv\" ) run ( df , fold = 0 ) run ( df , fold = 1 ) run ( df , fold = 2 ) run ( df , fold = 3 ) run ( df , fold = 4 ) \u6700\u540e\u662f config.py\u3002 MAX_LEN = 128 TRAIN_BATCH_SIZE = 16 VALID_BATCH_SIZE = 8 EPOCHS = 10 \u8ba9\u6211\u4eec\u770b\u770b\u8f93\u51fa\uff1a FOLD : 0 , Epoch : 3 , Accuracy Score = 0.9015 FOLD : 1 , Epoch : 4 , Accuracy Score = 0.9007 FOLD : 2 , Epoch : 3 , Accuracy Score = 0.8924 FOLD : 3 , Epoch : 2 , Accuracy Score = 0.9 FOLD : 4 , Epoch : 1 , Accuracy Score = 0.878 \u8fd9\u662f\u8fc4\u4eca\u4e3a\u6b62\u6211\u4eec\u83b7\u5f97\u7684\u6700\u597d\u6210\u7ee9\u3002 \u8bf7\u6ce8\u610f\uff0c\u6211\u53ea\u663e\u793a\u4e86\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7cbe\u5ea6\u6700\u9ad8\u7684Epoch\u3002 \u4f60\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\uff0c\u6211\u4eec\u4f7f\u7528\u4e86\u9884\u5148\u8bad\u7ec3\u7684\u5d4c\u5165\u548c\u7b80\u5355\u7684\u53cc\u5411 LSTM\u3002 \u5982\u679c\u4f60\u60f3\u6539\u53d8\u6a21\u578b\uff0c\u4f60\u53ef\u4ee5\u53ea\u6539\u53d8 lstm.py \u4e2d\u7684\u6a21\u578b\u5e76\u4fdd\u6301\u4e00\u5207\u4e0d\u53d8\u3002 \u8fd9\u79cd\u4ee3\u7801\u53ea\u9700\u8981\u5f88\u5c11\u7684\u5b9e\u9a8c\u6539\u52a8\uff0c\u5e76\u4e14\u5f88\u5bb9\u6613\u7406\u89e3\u3002 \u4f8b\u5982\uff0c\u60a8\u53ef\u4ee5\u81ea\u5df1\u5b66\u4e60\u5d4c\u5165\u800c\u4e0d\u662f\u4f7f\u7528\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u5176\u4ed6\u4e00\u4e9b\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u7ec4\u5408\u591a\u4e2a\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528GRU\uff0c\u60a8\u53ef\u4ee5\u5728\u5d4c\u5165\u540e\u4f7f\u7528\u7a7a\u95f4dropout\uff0c\u60a8\u53ef\u4ee5\u6dfb\u52a0GRU LSTM \u5c42\u4e4b\u540e\uff0c\u60a8\u53ef\u4ee5\u6dfb\u52a0\u4e24\u4e2a LSTM \u5c42\uff0c\u60a8\u53ef\u4ee5\u8fdb\u884c LSTM-GRU-LSTM \u914d\u7f6e\uff0c\u60a8\u53ef\u4ee5\u7528\u5377\u79ef\u5c42\u66ff\u6362 LSTM \u7b49\uff0c\u800c\u65e0\u9700\u5bf9\u4ee3\u7801\u8fdb\u884c\u592a\u591a\u66f4\u6539\u3002 \u6211\u63d0\u5230\u7684\u5927\u90e8\u5206\u5185\u5bb9\u53ea\u9700\u8981\u66f4\u6539\u6a21\u578b\u7c7b\u3002 \u5f53\u60a8\u4f7f\u7528\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\u65f6\uff0c\u5c1d\u8bd5\u67e5\u770b\u6709\u591a\u5c11\u5355\u8bcd\u65e0\u6cd5\u627e\u5230\u5d4c\u5165\u4ee5\u53ca\u539f\u56e0\u3002 \u9884\u8bad\u7ec3\u5d4c\u5165\u7684\u5355\u8bcd\u8d8a\u591a\uff0c\u7ed3\u679c\u5c31\u8d8a\u597d\u3002 \u6211\u5411\u60a8\u5c55\u793a\u4ee5\u4e0b\u672a\u6ce8\u91ca\u7684 (!) \u51fd\u6570\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u5b83\u4e3a\u4efb\u4f55\u7c7b\u578b\u7684\u9884\u8bad\u7ec3\u5d4c\u5165\u521b\u5efa\u5d4c\u5165\u77e9\u9635\uff0c\u5176\u683c\u5f0f\u4e0e glove \u6216 fastText \u76f8\u540c\uff08\u53ef\u80fd\u9700\u8981\u8fdb\u884c\u4e00\u4e9b\u66f4\u6539\uff09\u3002 def load_embeddings ( word_index , embedding_file , vector_length = 300 ): max_features = len ( word_index ) + 1 words_to_find = list ( word_index . keys ()) more_words_to_find = [] for wtf in words_to_find : more_words_to_find . append ( wtf ) more_words_to_find . append ( str ( wtf ) . capitalize ()) more_words_to_find = set ( more_words_to_find ) def get_coefs ( word , * arr ): return word , np . asarray ( arr , dtype = 'float32' ) embeddings_index = dict ( get_coefs ( * o . strip () . split ( \" \" )) for o in open ( embedding_file ) if o . split ( \" \" )[ 0 ] in more_words_to_find and len ( o ) > 100 ) embedding_matrix = np . zeros (( max_features , vector_length )) for word , i in word_index . items (): if i >= max_features : continue embedding_vector = embeddings_index . get ( word ) if embedding_vector is None : embedding_vector = embeddings_index . get ( str ( word ) . capitalize () ) if embedding_vector is None : embedding_vector = embeddings_index . get ( str ( word ) . upper () ) if ( embedding_vector is not None and len ( embedding_vector ) == vector_length ): embedding_matrix [ i ] = embedding_vector return embedding_matrix \u9605\u8bfb\u5e76\u8fd0\u884c\u4e0a\u9762\u7684\u51fd\u6570\uff0c\u770b\u770b\u53d1\u751f\u4e86\u4ec0\u4e48\u3002 \u8be5\u51fd\u6570\u8fd8\u53ef\u4ee5\u4fee\u6539\u4e3a\u4f7f\u7528\u8bcd\u5e72\u8bcd\u6216\u8bcd\u5f62\u8fd8\u539f\u8bcd\u3002 \u6700\u540e\uff0c\u60a8\u5e0c\u671b\u8bad\u7ec3\u8bed\u6599\u5e93\u4e2d\u7684\u672a\u77e5\u5355\u8bcd\u6570\u91cf\u6700\u5c11\u3002 \u53e6\u4e00\u4e2a\u6280\u5de7\u662f\u5b66\u4e60\u5d4c\u5165\u5c42\uff0c\u5373\u4f7f\u5176\u53ef\u8bad\u7ec3\uff0c\u7136\u540e\u8bad\u7ec3\u7f51\u7edc\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u5206\u7c7b\u95ee\u9898\u6784\u5efa\u4e86\u5f88\u591a\u6a21\u578b\u3002 \u7136\u800c\uff0c\u73b0\u5728\u662f\u5e03\u5076\u65f6\u4ee3\uff0c\u8d8a\u6765\u8d8a\u591a\u7684\u4eba\u8f6c\u5411\u57fa\u4e8e\u53d8\u5f62\u91d1\u521a\u7684\u6a21\u578b\u3002 \u57fa\u4e8e Transformer \u7684\u7f51\u7edc\u80fd\u591f\u5904\u7406\u672c\u8d28\u4e0a\u957f\u671f\u7684\u4f9d\u8d56\u5173\u7cfb\u3002 LSTM \u4ec5\u5f53\u5b83\u770b\u5230\u524d\u4e00\u4e2a\u5355\u8bcd\u65f6\u624d\u67e5\u770b\u4e0b\u4e00\u4e2a\u5355\u8bcd\u3002 \u53d8\u538b\u5668\u7684\u60c5\u51b5\u5e76\u975e\u5982\u6b64\u3002 \u5b83\u53ef\u4ee5\u540c\u65f6\u67e5\u770b\u6574\u4e2a\u53e5\u5b50\u4e2d\u7684\u6240\u6709\u5355\u8bcd\u3002 \u56e0\u6b64\uff0c\u53e6\u4e00\u4e2a\u4f18\u70b9\u662f\u5b83\u53ef\u4ee5\u8f7b\u677e\u5e76\u884c\u5316\u5e76\u66f4\u6709\u6548\u5730\u4f7f\u7528 GPU\u3002 Transformers \u662f\u4e00\u4e2a\u975e\u5e38\u5e7f\u6cdb\u7684\u8bdd\u9898\uff0c\u6709\u592a\u591a\u7684\u6a21\u578b\uff1a BERT\u3001RoBERTa\u3001XLNet\u3001XLM-RoBERTa\u3001T5 \u7b49\u3002\u6211\u5c06\u5411\u60a8\u5c55\u793a\u4e00\u79cd\u53ef\u7528\u4e8e\u6240\u6709\u8fd9\u4e9b\u6a21\u578b\uff08T5 \u9664\u5916\uff09\u8fdb\u884c\u5206\u7c7b\u7684\u901a\u7528\u65b9\u6cd5 \u6211\u4eec\u4e00\u76f4\u5728\u8ba8\u8bba\u7684\u95ee\u9898\u3002 \u8bf7\u6ce8\u610f\uff0c\u8fd9\u4e9b\u53d8\u538b\u5668\u9700\u8981\u8bad\u7ec3\u5b83\u4eec\u6240\u9700\u7684\u8ba1\u7b97\u80fd\u529b\u3002 \u56e0\u6b64\uff0c\u5982\u679c\u60a8\u6ca1\u6709\u9ad8\u7aef\u7cfb\u7edf\uff0c\u4e0e\u57fa\u4e8e LSTM \u6216 TF-IDF \u7684\u6a21\u578b\u76f8\u6bd4\uff0c\u8bad\u7ec3\u6a21\u578b\u53ef\u80fd\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u3002 \u6211\u4eec\u8981\u505a\u7684\u7b2c\u4e00\u4ef6\u4e8b\u662f\u521b\u5efa\u4e00\u4e2a\u914d\u7f6e\u6587\u4ef6\u3002 import transformers MAX_LEN = 512 TRAIN_BATCH_SIZE = 8 VALID_BATCH_SIZE = 4 EPOCHS = 10 BERT_PATH = \"../input/bert_base_uncased/\" MODEL_PATH = \"model.bin\" TRAINING_FILE = \"../input/imdb.csv\" TOKENIZER = transformers . BertTokenizer . from_pretrained ( BERT_PATH , do_lower_case = True ) \u8fd9\u91cc\u7684\u914d\u7f6e\u6587\u4ef6\u662f\u6211\u4eec\u5b9a\u4e49\u5206\u8bcd\u5668\u548c\u5176\u4ed6\u6211\u4eec\u60f3\u8981\u7ecf\u5e38\u66f4\u6539\u7684\u53c2\u6570\u7684\u552f\u4e00\u5730\u65b9 \u2014\u2014 \u8fd9\u6837\u6211\u4eec\u5c31\u53ef\u4ee5\u505a\u5f88\u591a\u5b9e\u9a8c\u800c\u4e0d\u9700\u8981\u8fdb\u884c\u5927\u91cf\u66f4\u6539\u3002 \u4e0b\u4e00\u6b65\u662f\u6784\u5efa\u6570\u636e\u96c6\u7c7b\u3002 import config import torch class BERTDataset : def __init__ ( self , review , target ): self . review = review self . target = target self . tokenizer = config . TOKENIZER self . max_len = config . MAX_LEN def __len__ ( self ): return len ( self . review ) def __getitem__ ( self , item ): review = str ( self . review [ item ]) review = \" \" . join ( review . split ()) inputs = self . tokenizer . encode_plus ( review , None , add_special_tokens = True , max_length = self . max_len , pad_to_max_length = True , ) ids = inputs [ \"input_ids\" ] mask = inputs [ \"attention_mask\" ] token_type_ids = inputs [ \"token_type_ids\" ] return { \"ids\" : torch . tensor ( ids , dtype = torch . long ), \"mask\" : torch . tensor ( mask , dtype = torch . long ), \"token_type_ids\" : torch . tensor ( token_type_ids , dtype = torch . long ), \"targets\" : torch . tensor ( self . target [ item ], dtype = torch . float ) } \u73b0\u5728\u6211\u4eec\u6765\u5230\u4e86\u8be5\u9879\u76ee\u7684\u6838\u5fc3\uff0c\u5373\u6a21\u578b\u3002 import config import transformers import torch.nn as nn class BERTBaseUncased ( nn . Module ): def __init__ ( self ): super ( BERTBaseUncased , self ) . __init__ () self . bert = transformers . BertModel . from_pretrained ( config . BERT_PATH ) self . bert_drop = nn . Dropout ( 0.3 ) self . out = nn . Linear ( 768 , 1 ) def forward ( self , ids , mask , token_type_ids ): hidden state _ , o2 = self . bert ( ids , attention_mask = mask , token_type_ids = token_type_ids ) bo = self . bert_drop ( o2 ) output = self . out ( bo ) return output \u8be5\u6a21\u578b\u8fd4\u56de\u5355\u4e2a\u8f93\u51fa\u3002 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5e26\u6709 logits \u7684\u4e8c\u5143\u4ea4\u53c9\u71b5\u635f\u5931\uff0c\u5b83\u9996\u5148\u5e94\u7528 sigmoid\uff0c\u7136\u540e\u8ba1\u7b97\u635f\u5931\u3002 \u8fd9\u662f\u5728engine.py \u4e2d\u5b8c\u6210\u7684\u3002 import torch import torch.nn as nn def loss_fn ( outputs , targets ): return nn . BCEWithLogitsLoss ()( outputs , targets . view ( - 1 , 1 )) def train_fn ( data_loader , model , optimizer , device , scheduler ): model . train () for d in data_loader : ids = d [ \"ids\" ] token_type_ids = d [ \"token_type_ids\" ] mask = d [ \"mask\" ] targets = d [ \"targets\" ] ids = ids . to ( device , dtype = torch . long ) token_type_ids = token_type_ids . to ( device , dtype = torch . long ) mask = mask . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) optimizer . zero_grad () outputs = model ( ids = ids , mask = mask , token_type_ids = token_type_ids ) loss = loss_fn ( outputs , targets ) loss . backward () optimizer . step () scheduler . step () def eval_fn ( data_loader , model , device ): model . eval () fin_targets = [] fin_outputs = [] with torch . no_grad (): for d in data_loader : ids = d [ \"ids\" ] token_type_ids = d [ \"token_type_ids\" ] mask = d [ \"mask\" ] targets = d [ \"targets\" ] ids = ids . to ( device , dtype = torch . long ) token_type_ids = token_type_ids . to ( device , dtype = torch . long ) mask = mask . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) outputs = model ( ids = ids , mask = mask , token_type_ids = token_type_ids ) targets = targets . cpu () . detach () fin_targets . extend ( targets . numpy () . tolist ()) outputs = torch . sigmoid ( outputs ) . cpu () . detach () fin_outputs . extend ( outputs . numpy () . tolist ()) return fin_outputs , fin_targets \u6700\u540e\uff0c\u6211\u4eec\u51c6\u5907\u597d\u8bad\u7ec3\u4e86\u3002 \u6211\u4eec\u6765\u770b\u770b\u8bad\u7ec3\u811a\u672c\u5427\uff01 import config import dataset import engine import torch import pandas as pd import torch.nn as nn import numpy as np from model import BERTBaseUncased from sklearn import model_selection from sklearn import metrics from transformers import AdamW from transformers import get_linear_schedule_with_warmup def train (): dfx = pd . read_csv ( config . TRAINING_FILE ) . fillna ( \"none\" ) dfx . sentiment = dfx . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df_train , df_valid = model_selection . train_test_split ( dfx , test_size = 0.1 , random_state = 42 , stratify = dfx . sentiment . values ) df_train = df_train . reset_index ( drop = True ) df_valid = df_valid . reset_index ( drop = True ) train_dataset = dataset . BERTDataset ( review = df_train . review . values , target = df_train . sentiment . values ) train_data_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = config . TRAIN_BATCH_SIZE , num_workers = 4 ) valid_dataset = dataset . BERTDataset ( review = df_valid . review . values , target = df_valid . sentiment . values ) valid_data_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = config . VALID_BATCH_SIZE , num_workers = 1 ) device = torch . device ( \"cuda\" ) model = BERTBaseUncased () model . to ( device ) param_optimizer = list ( model . named_parameters ()) no_decay = [ \"bias\" , \"LayerNorm.bias\" , \"LayerNorm.weight\" ] optimizer_parameters = [ { \"params\" : [ p for n , p in param_optimizer if not any ( nd in n for nd in no_decay ) ], \"weight_decay\" : 0.001 , } \uff0c { \"params\" : [ p for n , p in param_optimizer if any ( nd in n for nd in no_decay ) ], \"weight_decay\" : 0.0 , }] num_train_steps = int ( len ( df_train ) / config . TRAIN_BATCH_SIZE * config . EPOCHS ) optimizer = AdamW ( optimizer_parameters , lr = 3e-5 ) scheduler = get_linear_schedule_with_warmup ( optimizer , num_warmup_steps = 0 , num_training_steps = num_train_steps ) model = nn . DataParallel ( model ) best_accuracy = 0 for epoch in range ( config . EPOCHS ): engine . train_fn ( train_data_loader , model , optimizer , device , scheduler ) outputs , targets = engine . eval_fn ( valid_data_loader , model , device ) outputs = np . array ( outputs ) >= 0.5 accuracy = metrics . accuracy_score ( targets , outputs ) print ( f \"Accuracy Score = { accuracy } \" ) if accuracy > best_accuracy : torch . save ( model . state_dict (), config . MODEL_PATH ) best_accuracy = accuracy if __name__ == \"__main__\" : train () \u4e4d\u4e00\u770b\u53ef\u80fd\u770b\u8d77\u6765\u5f88\u591a\uff0c\u4f46\u4e00\u65e6\u60a8\u4e86\u89e3\u4e86\u5404\u4e2a\u7ec4\u4ef6\uff0c\u5c31\u4e0d\u518d\u90a3\u4e48\u7b80\u5355\u4e86\u3002 \u60a8\u53ea\u9700\u66f4\u6539\u51e0\u884c\u4ee3\u7801\u5373\u53ef\u8f7b\u677e\u5c06\u5176\u66f4\u6539\u4e3a\u60a8\u60f3\u8981\u4f7f\u7528\u7684\u4efb\u4f55\u5176\u4ed6\u53d8\u538b\u5668\u6a21\u578b\u3002 \u8be5\u6a21\u578b\u7684\u51c6\u786e\u7387\u4e3a 93%\uff01 \u54c7\uff01 \u8fd9\u6bd4\u4efb\u4f55\u5176\u4ed6\u6a21\u578b\u90fd\u8981\u597d\u5f97\u591a\u3002 \u4f46\u662f\u8fd9\u503c\u5f97\u5417\uff1f \u6211\u4eec\u4f7f\u7528 LSTM \u80fd\u591f\u5b9e\u73b0 90% \u7684\u76ee\u6807\uff0c\u800c\u4e14\u5b83\u4eec\u66f4\u7b80\u5355\u3001\u66f4\u5bb9\u6613\u8bad\u7ec3\u5e76\u4e14\u63a8\u7406\u901f\u5ea6\u66f4\u5feb\u3002 \u901a\u8fc7\u4f7f\u7528\u4e0d\u540c\u7684\u6570\u636e\u5904\u7406\u6216\u8c03\u6574\u5c42\u3001\u8282\u70b9\u3001dropout\u3001\u5b66\u4e60\u7387\u3001\u66f4\u6539\u4f18\u5316\u5668\u7b49\u53c2\u6570\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u8be5\u6a21\u578b\u6539\u8fdb\u4e00\u4e2a\u767e\u5206\u70b9\u3002\u7136\u540e\u6211\u4eec\u5c06\u4ece BERT \u4e2d\u83b7\u5f97\u7ea6 2% \u7684\u6536\u76ca\u3002 \u53e6\u4e00\u65b9\u9762\uff0cBERT \u7684\u8bad\u7ec3\u65f6\u95f4\u8981\u957f\u5f97\u591a\uff0c\u53c2\u6570\u5f88\u591a\uff0c\u800c\u4e14\u63a8\u7406\u901f\u5ea6\u4e5f\u5f88\u6162\u3002 \u6700\u540e\uff0c\u60a8\u5e94\u8be5\u5ba1\u89c6\u81ea\u5df1\u7684\u4e1a\u52a1\u5e76\u505a\u51fa\u660e\u667a\u7684\u9009\u62e9\u3002 \u4e0d\u8981\u4ec5\u4ec5\u56e0\u4e3a BERT\u201c\u9177\u201d\u800c\u9009\u62e9\u5b83\u3002 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u6211\u4eec\u5728\u8fd9\u91cc\u8ba8\u8bba\u7684\u552f\u4e00\u4efb\u52a1\u662f\u5206\u7c7b\uff0c\u4f46\u5c06\u5176\u66f4\u6539\u4e3a\u56de\u5f52\u3001\u591a\u6807\u7b7e\u6216\u591a\u7c7b\u53ea\u9700\u8981\u66f4\u6539\u51e0\u884c\u4ee3\u7801\u3002 \u4f8b\u5982\uff0c\u591a\u7c7b\u5206\u7c7b\u8bbe\u7f6e\u4e2d\u7684\u540c\u4e00\u95ee\u9898\u5c06\u6709\u591a\u4e2a\u8f93\u51fa\u548c\u4ea4\u53c9\u71b5\u635f\u5931\u3002 \u5176\u4ed6\u4e00\u5207\u90fd\u5e94\u8be5\u4fdd\u6301\u4e0d\u53d8\u3002 \u81ea\u7136\u8bed\u8a00\u5904\u7406\u975e\u5e38\u5e9e\u5927\uff0c\u6211\u4eec\u53ea\u8ba8\u8bba\u4e86\u5176\u4e2d\u7684\u4e00\u5c0f\u90e8\u5206\u3002 \u663e\u7136\uff0c\u8fd9\u662f\u4e00\u4e2a\u5f88\u5927\u7684\u6bd4\u4f8b\uff0c\u56e0\u4e3a\u5927\u591a\u6570\u5de5\u4e1a\u6a21\u578b\u90fd\u662f\u5206\u7c7b\u6216\u56de\u5f52\u6a21\u578b\u3002 \u5982\u679c\u6211\u5f00\u59cb\u8be6\u7ec6\u5199\u6240\u6709\u5185\u5bb9\uff0c\u6211\u6700\u7ec8\u53ef\u80fd\u4f1a\u5199\u51e0\u767e\u9875\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u51b3\u5b9a\u5c06\u6240\u6709\u5185\u5bb9\u5305\u542b\u5728\u4e00\u672c\u5355\u72ec\u7684\u4e66\u4e2d\uff1a\u63a5\u8fd1\uff08\u51e0\u4e4e\uff09\u4efb\u4f55 NLP \u95ee\u9898\uff01","title":"\u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5"},{"location":"%E6%97%A0%E7%9B%91%E7%9D%A3%E5%92%8C%E6%9C%89%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0/","text":"\u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60 \u5728\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u65f6\uff0c\u901a\u5e38\u6709\u4e24\u7c7b\u6570\u636e\uff08\u548c\u673a\u5668\u5b66\u4e60\u6a21\u578b\uff09\uff1a \u76d1\u7763\u6570\u636e\uff1a\u603b\u662f\u6709\u4e00\u4e2a\u6216\u591a\u4e2a\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807 \u65e0\u76d1\u7763\u6570\u636e\uff1a\u6ca1\u6709\u4efb\u4f55\u76ee\u6807\u53d8\u91cf\u3002 \u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\u3002\u6211\u4eec\u9700\u8981\u9884\u6d4b\u4e00\u4e2a\u503c\u7684\u95ee\u9898\u88ab\u79f0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u3002\u4f8b\u5982\uff0c\u5982\u679c\u95ee\u9898\u662f\u6839\u636e\u5386\u53f2\u623f\u4ef7\u9884\u6d4b\u623f\u4ef7\uff0c\u90a3\u4e48\u533b\u9662\u3001\u5b66\u6821\u6216\u8d85\u5e02\u7684\u5b58\u5728\uff0c\u4e0e\u6700\u8fd1\u516c\u5171\u4ea4\u901a\u7684\u8ddd\u79bb\u7b49\u7279\u5f81\u5c31\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u540c\u6837\uff0c\u5f53\u6211\u4eec\u5f97\u5230\u732b\u548c\u72d7\u7684\u56fe\u50cf\u65f6\uff0c\u6211\u4eec\u4e8b\u5148\u77e5\u9053\u54ea\u4e9b\u662f\u732b\uff0c\u54ea\u4e9b\u662f\u72d7\uff0c\u5982\u679c\u4efb\u52a1\u662f\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u6240\u63d0\u4f9b\u7684\u56fe\u50cf\u662f\u732b\u8fd8\u662f\u72d7\uff0c\u90a3\u4e48\u8fd9\u4e2a\u95ee\u9898\u5c31\u88ab\u8ba4\u4e3a\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002 \u56fe 1\uff1a\u6709\u76d1\u7763\u5b66\u4e60\u6570\u636e \u5982\u56fe 1 \u6240\u793a\uff0c\u6570\u636e\u7684\u6bcf\u4e00\u884c\u90fd\u4e0e\u4e00\u4e2a\u76ee\u6807\u6216\u6807\u7b7e\u76f8\u5173\u8054\u3002\u5217\u662f\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u884c\u4ee3\u8868\u4e0d\u540c\u7684\u6570\u636e\u70b9\uff0c\u901a\u5e38\u79f0\u4e3a\u6837\u672c\u3002\u793a\u4f8b\u4e2d\u7684\u5341\u4e2a\u6837\u672c\u6709\u5341\u4e2a\u7279\u5f81\u548c\u4e00\u4e2a\u76ee\u6807\u53d8\u91cf\uff0c\u76ee\u6807\u53d8\u91cf\u53ef\u4ee5\u662f\u6570\u5b57\u6216\u7c7b\u522b\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5206\u7c7b\u53d8\u91cf\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5b9e\u6570\uff0c\u95ee\u9898\u5c31\u88ab\u5b9a\u4e49\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6709\u76d1\u7763\u95ee\u9898\u53ef\u5206\u4e3a\u4e24\u4e2a\u5b50\u7c7b\uff1a \u5206\u7c7b\uff1a\u9884\u6d4b\u7c7b\u522b\uff0c\u5982\u732b\u6216\u72d7 \u56de\u5f52\uff1a\u9884\u6d4b\u503c\uff0c\u5982\u623f\u4ef7 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u6709\u65f6\u6211\u4eec\u53ef\u80fd\u4f1a\u5728\u5206\u7c7b\u8bbe\u7f6e\u4e2d\u4f7f\u7528\u56de\u5f52\uff0c\u8fd9\u53d6\u51b3\u4e8e\u7528\u4e8e\u8bc4\u4f30\u7684\u6307\u6807\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u7a0d\u540e\u4f1a\u8ba8\u8bba\u8fd9\u4e2a\u95ee\u9898\u3002 \u53e6\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u95ee\u9898\u662f\u65e0\u76d1\u7763\u7c7b\u578b\u3002 \u65e0\u76d1\u7763 \u6570\u636e\u96c6\u6ca1\u6709\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\uff0c\u4e00\u822c\u6765\u8bf4\uff0c\u4e0e\u6709\u76d1\u7763\u95ee\u9898\u76f8\u6bd4\uff0c\u5904\u7406\u65e0\u76d1\u7763\u6570\u636e\u96c6\u66f4\u5177\u6311\u6218\u6027\u3002 \u5047\u8bbe\u4f60\u5728\u4e00\u5bb6\u5904\u7406\u4fe1\u7528\u5361\u4ea4\u6613\u7684\u91d1\u878d\u516c\u53f8\u5de5\u4f5c\u3002\u6bcf\u79d2\u949f\u90fd\u6709\u5927\u91cf\u6570\u636e\u6d8c\u5165\u3002\u552f\u4e00\u7684\u95ee\u9898\u662f\uff0c\u5f88\u96be\u627e\u5230\u4e00\u4e2a\u4eba\u6765\u5c06\u6bcf\u7b14\u4ea4\u6613\u6807\u8bb0\u4e3a\u6709\u6548\u4ea4\u6613\u3001\u771f\u5b9e\u4ea4\u6613\u6216\u6b3a\u8bc8\u4ea4\u6613\u3002\u5f53\u6211\u4eec\u6ca1\u6709\u4efb\u4f55\u5173\u4e8e\u4ea4\u6613\u662f\u6b3a\u8bc8\u8fd8\u662f\u771f\u5b9e\u7684\u4fe1\u606f\u65f6\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u65e0\u76d1\u7763\u95ee\u9898\u3002\u8981\u89e3\u51b3\u8fd9\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5fc5\u987b\u8003\u8651\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a\u591a\u5c11\u4e2a \u805a\u7c7b \u3002\u805a\u7c7b\u662f\u89e3\u51b3\u6b64\u7c7b\u95ee\u9898\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd8\u6709\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u53ef\u4ee5\u5e94\u7528\u4e8e\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5bf9\u4e8e\u6b3a\u8bc8\u68c0\u6d4b\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u8bf4\u6570\u636e\u53ef\u4ee5\u5206\u4e3a\u4e24\u7c7b\uff08\u6b3a\u8bc8\u6216\u771f\u5b9e\uff09\u3002 \u5f53\u6211\u4eec\u77e5\u9053\u805a\u7c7b\u7684\u6570\u91cf\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u805a\u7c7b\u7b97\u6cd5\u6765\u89e3\u51b3\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5728\u56fe 2 \u4e2d\uff0c\u5047\u8bbe\u6570\u636e\u5206\u4e3a\u4e24\u7c7b\uff0c\u6df1\u8272\u4ee3\u8868\u6b3a\u8bc8\uff0c\u6d45\u8272\u4ee3\u8868\u771f\u5b9e\u4ea4\u6613\u3002\u7136\u800c\uff0c\u5728\u4f7f\u7528\u805a\u7c7b\u65b9\u6cd5\u4e4b\u524d\uff0c\u6211\u4eec\u5e76\u4e0d\u77e5\u9053\u8fd9\u4e9b\u7c7b\u522b\u3002\u5e94\u7528\u805a\u7c7b\u7b97\u6cd5\u540e\uff0c\u6211\u4eec\u5e94\u8be5\u80fd\u591f\u533a\u5206\u8fd9\u4e24\u4e2a\u5047\u5b9a\u76ee\u6807\u3002 \u4e3a\u4e86\u7406\u89e3\u65e0\u76d1\u7763\u95ee\u9898\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u8bb8\u591a\u5206\u89e3\u6280\u672f\uff0c\u5982 \u4e3b\u6210\u5206\u5206\u6790\uff08PCA\uff09\u3001t-\u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09 \u7b49\u3002 \u6709\u76d1\u7763\u7684\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\uff0c\u56e0\u4e3a\u5b83\u4eec\u5f88\u5bb9\u6613\u8bc4\u4f30\u3002\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u8bc4\u4f30\u6280\u672f\u3002\u7136\u800c\uff0c\u5bf9\u65e0\u76d1\u7763\u7b97\u6cd5\u7684\u7ed3\u679c\u8fdb\u884c\u8bc4\u4f30\u5177\u6709\u6311\u6218\u6027\uff0c\u9700\u8981\u5927\u91cf\u7684\u4eba\u4e3a\u5e72\u9884\u6216\u542f\u53d1\u5f0f\u65b9\u6cd5\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4e3b\u8981\u5173\u6ce8\u6709\u76d1\u7763\u6570\u636e\u548c\u6a21\u578b\uff0c\u4f46\u8fd9\u5e76\u4e0d\u610f\u5473\u7740\u6211\u4eec\u4f1a\u5ffd\u7565\u65e0\u76d1\u7763\u6570\u636e\u95ee\u9898\u3002 \u56fe 2\uff1a\u65e0\u76d1\u7763\u5b66\u4e60\u6570\u636e\u96c6 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u5f53\u4eba\u4eec\u5f00\u59cb\u5b66\u4e60\u6570\u636e\u79d1\u5b66\u6216\u673a\u5668\u5b66\u4e60\u65f6\uff0c\u90fd\u4f1a\u4ece\u975e\u5e38\u8457\u540d\u7684\u6570\u636e\u96c6\u5f00\u59cb\uff0c\u4f8b\u5982\u6cf0\u5766\u5c3c\u514b\u6570\u636e\u96c6\u6216\u9e22\u5c3e\u82b1\u6570\u636e\u96c6\uff0c\u8fd9\u4e9b\u90fd\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u5728\u6cf0\u5766\u5c3c\u514b\u53f7\u6570\u636e\u96c6\u4e2d\uff0c\u4f60\u5fc5\u987b\u6839\u636e\u8239\u7968\u7b49\u7ea7\u3001\u6027\u522b\u3001\u5e74\u9f84\u7b49\u56e0\u7d20\u9884\u6d4b\u6cf0\u5766\u5c3c\u514b\u53f7\u4e0a\u4e58\u5ba2\u7684\u5b58\u6d3b\u7387\u3002\u540c\u6837\uff0c\u5728\u9e22\u5c3e\u82b1\u6570\u636e\u96c6\u4e2d\uff0c\u60a8\u5fc5\u987b\u6839\u636e\u843c\u7247\u5bbd\u5ea6\u3001\u82b1\u74e3\u957f\u5ea6\u3001\u843c\u7247\u957f\u5ea6\u548c\u82b1\u74e3\u5bbd\u5ea6\u7b49\u56e0\u7d20\u9884\u6d4b\u82b1\u7684\u79cd\u7c7b\u3002 \u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u80fd\u5305\u62ec\u7528\u4e8e\u5ba2\u6237\u7ec6\u5206\u7684\u6570\u636e\u96c6\u3002 \u4f8b\u5982\uff0c\u60a8\u62e5\u6709\u8bbf\u95ee\u60a8\u7684\u7535\u5b50\u5546\u52a1\u7f51\u7ad9\u7684\u5ba2\u6237\u6570\u636e\uff0c\u6216\u8005\u8bbf\u95ee\u5546\u5e97\u6216\u5546\u573a\u7684\u5ba2\u6237\u6570\u636e\uff0c\u800c\u60a8\u5e0c\u671b\u5c06\u5b83\u4eec\u7ec6\u5206\u6216\u805a\u7c7b\u4e3a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u65e0\u76d1\u7763\u6570\u636e\u96c6\u7684\u53e6\u4e00\u4e2a\u4f8b\u5b50\u53ef\u80fd\u5305\u62ec\u4fe1\u7528\u5361\u6b3a\u8bc8\u68c0\u6d4b\u6216\u5bf9\u51e0\u5f20\u56fe\u7247\u8fdb\u884c\u805a\u7c7b\u7b49\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd8\u53ef\u4ee5\u5c06\u6709\u76d1\u7763\u6570\u636e\u96c6\u8f6c\u6362\u4e3a\u65e0\u76d1\u7763\u6570\u636e\u96c6\uff0c\u4ee5\u67e5\u770b\u5b83\u4eec\u5728\u7ed8\u5236\u65f6\u7684\u6548\u679c\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u56fe 3 \u4e2d\u7684\u6570\u636e\u96c6\u3002\u56fe 3 \u663e\u793a\u7684\u662f MNIST \u6570\u636e\u96c6\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u6d41\u884c\u7684\u624b\u5199\u6570\u5b57\u6570\u636e\u96c6\uff0c\u5b83\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\uff0c\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u4f60\u4f1a\u5f97\u5230\u6570\u5b57\u56fe\u50cf\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u6b63\u786e\u6807\u7b7e\u3002\u4f60\u5fc5\u987b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u53ea\u63d0\u4f9b\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\u8bc6\u522b\u51fa\u54ea\u4e2a\u6570\u5b57\u662f\u5b83\u3002 \u56fe 3\uff1aMNIST\u6570\u636e\u96c6 \u5982\u679c\u6211\u4eec\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c t \u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09\u5206\u89e3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u53ea\u9700\u5728\u56fe\u50cf\u50cf\u7d20\u4e0a\u964d\u7ef4\u81f3 2 \u4e2a\u7ef4\u5ea6\uff0c\u5c31\u80fd\u5728\u4e00\u5b9a\u7a0b\u5ea6\u4e0a\u5206\u79bb\u56fe\u50cf\u3002\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1aMNIST \u6570\u636e\u96c6\u7684 t-SNE \u53ef\u89c6\u5316\u3002\u4f7f\u7528\u4e86 3000 \u5e45\u56fe\u50cf\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002\u9996\u5148\u662f\u5bfc\u5165\u6240\u6709\u9700\u8981\u7684\u5e93\u3002 import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns from sklearn import datasets from sklearn import manifold % matplotlib inline \u6211\u4eec\u4f7f\u7528 matplotlib \u548c seaborn \u8fdb\u884c\u7ed8\u56fe\uff0c\u4f7f\u7528 numpy \u5904\u7406\u6570\u503c\u6570\u7ec4\uff0c\u4f7f\u7528 pandas \u4ece\u6570\u503c\u6570\u7ec4\u521b\u5efa\u6570\u636e\u5e27\uff0c\u4f7f\u7528 scikit-learn (sklearn) \u83b7\u53d6\u6570\u636e\u5e76\u6267\u884c t-SNE\u3002 \u5bfc\u5165\u540e\uff0c\u6211\u4eec\u9700\u8981\u4e0b\u8f7d\u6570\u636e\u5e76\u5355\u72ec\u8bfb\u53d6\uff0c\u6216\u8005\u4f7f\u7528 sklearn \u7684\u5185\u7f6e\u51fd\u6570\u6765\u63d0\u4f9b MNIST \u6570\u636e\u96c6\u3002 data = datasets . fetch_openml ( 'mnist_784' , version = 1 , return_X_y = True ) pixel_values , targets = data targets = targets . astype ( int ) \u5728\u8fd9\u90e8\u5206\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528 sklearn \u6570\u636e\u96c6\u83b7\u53d6\u4e86\u6570\u636e\uff0c\u5e76\u83b7\u5f97\u4e86\u4e00\u4e2a\u50cf\u7d20\u503c\u6570\u7ec4\u548c\u53e6\u4e00\u4e2a\u76ee\u6807\u6570\u7ec4\u3002\u7531\u4e8e\u76ee\u6807\u662f\u5b57\u7b26\u4e32\u7c7b\u578b\uff0c\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u3002 pixel_values \u662f\u4e00\u4e2a\u5f62\u72b6\u4e3a 70000x784 \u7684\u4e8c\u7ef4\u6570\u7ec4\u3002 \u5171\u6709 70000 \u5f20\u4e0d\u540c\u7684\u56fe\u50cf\uff0c\u6bcf\u5f20\u56fe\u50cf\u5927\u5c0f\u4e3a 28x28 \u50cf\u7d20\u3002\u5e73\u94fa 28x28 \u540e\u5f97\u5230 784 \u4e2a\u6570\u636e\u70b9\u3002 \u6211\u4eec\u53ef\u4ee5\u5c06\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u91cd\u5851\u4e3a\u539f\u6765\u7684\u5f62\u72b6\uff0c\u7136\u540e\u4f7f\u7528 matplotlib \u7ed8\u5236\u6210\u56fe\u8868\uff0c\u4ece\u800c\u5c06\u5176\u53ef\u89c6\u5316\u3002 single_image = pixel_values [ 1 , :] . reshape ( 28 , 28 ) plt . imshow ( single_image , cmap = 'gray' ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u7ed8\u5236\u5982\u4e0b\u56fe\u50cf\uff1a \u56fe 5\uff1a\u7ed8\u5236MNIST\u6570\u636e\u96c6\u5355\u5f20\u56fe\u7247 \u6700\u91cd\u8981\u7684\u4e00\u6b65\u662f\u5728\u6211\u4eec\u83b7\u53d6\u6570\u636e\u4e4b\u540e\u3002 tsne = manifold . TSNE ( n_components = 2 , random_state = 42 ) transformed_data = tsne . fit_transform ( pixel_values [: 3000 , :]) \u8fd9\u4e00\u6b65\u521b\u5efa\u4e86\u6570\u636e\u7684 t-SNE \u53d8\u6362\u3002\u6211\u4eec\u53ea\u4f7f\u7528 2 \u4e2a\u7ef4\u5ea6\uff0c\u56e0\u4e3a\u5728\u4e8c\u7ef4\u73af\u5883\u4e2d\u53ef\u4ee5\u5f88\u597d\u5730\u5c06\u5b83\u4eec\u53ef\u89c6\u5316\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u8f6c\u6362\u540e\u7684\u6570\u636e\u662f\u4e00\u4e2a 3000x2 \u5f62\u72b6\u7684\u6570\u7ec4\uff083000 \u884c 2 \u5217\uff09\u3002\u5728\u6570\u7ec4\u4e0a\u8c03\u7528 pd.DataFrame \u53ef\u4ee5\u5c06\u8fd9\u6837\u7684\u6570\u636e\u8f6c\u6362\u4e3a pandas \u6570\u636e\u5e27\u3002 tsne_df = pd . DataFrame ( np . column_stack (( transformed_data , targets [: 3000 ])), columns = [ \"x\" , \"y\" , \"targets\" ]) tsne_df . loc [:, \"targets\" ] = tsne_df . targets . astype ( int ) \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a numpy \u6570\u7ec4\u521b\u5efa\u4e00\u4e2a pandas \u6570\u636e\u5e27\u3002x \u548c y \u662f t-SNE \u5206\u89e3\u7684\u4e24\u4e2a\u7ef4\u5ea6\uff0ctarget \u662f\u5b9e\u9645\u6570\u5b57\u3002\u8fd9\u6837\u6211\u4eec\u5c31\u5f97\u5230\u4e86\u5982\u56fe 6 \u6240\u793a\u7684\u6570\u636e\u5e27\u3002 \u56fe 6\uff1at-SNE\u540e\u6570\u636e\u524d10\u884c \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 seaborn \u548c matplotlib \u7ed8\u5236\u5b83\u3002 grid = sns . FacetGrid ( tsne_df , hue = \"targets\" , size = 8 ) grid . map ( plt . scatter , \"x\" , \"y\" ) . add_legend () \u8fd9\u662f\u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u89c6\u5316\u7684\u4e00\u79cd\u65b9\u6cd5\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5728\u540c\u4e00\u6570\u636e\u96c6\u4e0a\u8fdb\u884c k-means \u805a\u7c7b\uff0c\u770b\u770b\u5b83\u5728\u65e0\u76d1\u7763\u73af\u5883\u4e0b\u7684\u8868\u73b0\u5982\u4f55\u3002\u4e00\u4e2a\u7ecf\u5e38\u51fa\u73b0\u7684\u95ee\u9898\u662f\uff0c\u5982\u4f55\u5728 k-means \u805a\u7c7b\u4e2d\u627e\u5230\u6700\u4f73\u7684\u7c07\u6570\u3002\u8fd9\u4e2a\u95ee\u9898\u6ca1\u6709\u6b63\u786e\u7b54\u6848\u3002\u4f60\u5fc5\u987b\u901a\u8fc7\u4ea4\u53c9\u9a8c\u8bc1\u6765\u627e\u5230\u6700\u4f73\u7c07\u6570\u3002\u672c\u4e66\u7a0d\u540e\u5c06\u8ba8\u8bba\u4ea4\u53c9\u9a8c\u8bc1\u3002\u8bf7\u6ce8\u610f\uff0c\u4e0a\u8ff0\u4ee3\u7801\u662f\u5728 jupyter \u7b14\u8bb0\u672c\u4e2d\u8fd0\u884c\u7684\u3002 \u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 jupyter \u505a\u4e00\u4e9b\u7b80\u5355\u7684\u4e8b\u60c5\uff0c\u6bd4\u5982\u4e0a\u9762\u7684\u4f8b\u5b50\u548c \u7ed8\u56fe\u3002\u5bf9\u4e8e\u672c\u4e66\u4e2d\u7684\u5927\u90e8\u5206\u5185\u5bb9\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 python \u811a\u672c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5176\u4ed6 IDE \u56e0\u4e3a\u7ed3\u679c\u90fd\u662f\u4e00\u6837\u7684\u3002 MNIST \u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u628a\u5b83\u8f6c\u6362\u6210\u4e00\u4e2a\u65e0\u76d1\u7763\u7684\u95ee\u9898\uff0c\u53ea\u662f\u4e3a\u4e86\u68c0\u67e5\u5b83\u662f\u5426\u80fd\u5e26\u6765\u4efb\u4f55\u597d\u7684\u7ed3\u679c\u3002\u5982\u679c\u6211\u4eec\u4f7f\u7528\u5206\u7c7b\u7b97\u6cd5\uff0c\u6548\u679c\u4f1a\u66f4\u597d\u3002\u8ba9\u6211\u4eec\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u4e00\u63a2\u7a76\u7adf\u3002","title":"\u6709\u76d1\u7763\u548c\u65e0\u76d1\u7763\u5b66\u4e60"},{"location":"%E6%97%A0%E7%9B%91%E7%9D%A3%E5%92%8C%E6%9C%89%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0/#_1","text":"\u5728\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u65f6\uff0c\u901a\u5e38\u6709\u4e24\u7c7b\u6570\u636e\uff08\u548c\u673a\u5668\u5b66\u4e60\u6a21\u578b\uff09\uff1a \u76d1\u7763\u6570\u636e\uff1a\u603b\u662f\u6709\u4e00\u4e2a\u6216\u591a\u4e2a\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807 \u65e0\u76d1\u7763\u6570\u636e\uff1a\u6ca1\u6709\u4efb\u4f55\u76ee\u6807\u53d8\u91cf\u3002 \u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\u3002\u6211\u4eec\u9700\u8981\u9884\u6d4b\u4e00\u4e2a\u503c\u7684\u95ee\u9898\u88ab\u79f0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u3002\u4f8b\u5982\uff0c\u5982\u679c\u95ee\u9898\u662f\u6839\u636e\u5386\u53f2\u623f\u4ef7\u9884\u6d4b\u623f\u4ef7\uff0c\u90a3\u4e48\u533b\u9662\u3001\u5b66\u6821\u6216\u8d85\u5e02\u7684\u5b58\u5728\uff0c\u4e0e\u6700\u8fd1\u516c\u5171\u4ea4\u901a\u7684\u8ddd\u79bb\u7b49\u7279\u5f81\u5c31\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u540c\u6837\uff0c\u5f53\u6211\u4eec\u5f97\u5230\u732b\u548c\u72d7\u7684\u56fe\u50cf\u65f6\uff0c\u6211\u4eec\u4e8b\u5148\u77e5\u9053\u54ea\u4e9b\u662f\u732b\uff0c\u54ea\u4e9b\u662f\u72d7\uff0c\u5982\u679c\u4efb\u52a1\u662f\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u6240\u63d0\u4f9b\u7684\u56fe\u50cf\u662f\u732b\u8fd8\u662f\u72d7\uff0c\u90a3\u4e48\u8fd9\u4e2a\u95ee\u9898\u5c31\u88ab\u8ba4\u4e3a\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002 \u56fe 1\uff1a\u6709\u76d1\u7763\u5b66\u4e60\u6570\u636e \u5982\u56fe 1 \u6240\u793a\uff0c\u6570\u636e\u7684\u6bcf\u4e00\u884c\u90fd\u4e0e\u4e00\u4e2a\u76ee\u6807\u6216\u6807\u7b7e\u76f8\u5173\u8054\u3002\u5217\u662f\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u884c\u4ee3\u8868\u4e0d\u540c\u7684\u6570\u636e\u70b9\uff0c\u901a\u5e38\u79f0\u4e3a\u6837\u672c\u3002\u793a\u4f8b\u4e2d\u7684\u5341\u4e2a\u6837\u672c\u6709\u5341\u4e2a\u7279\u5f81\u548c\u4e00\u4e2a\u76ee\u6807\u53d8\u91cf\uff0c\u76ee\u6807\u53d8\u91cf\u53ef\u4ee5\u662f\u6570\u5b57\u6216\u7c7b\u522b\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5206\u7c7b\u53d8\u91cf\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5b9e\u6570\uff0c\u95ee\u9898\u5c31\u88ab\u5b9a\u4e49\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6709\u76d1\u7763\u95ee\u9898\u53ef\u5206\u4e3a\u4e24\u4e2a\u5b50\u7c7b\uff1a \u5206\u7c7b\uff1a\u9884\u6d4b\u7c7b\u522b\uff0c\u5982\u732b\u6216\u72d7 \u56de\u5f52\uff1a\u9884\u6d4b\u503c\uff0c\u5982\u623f\u4ef7 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u6709\u65f6\u6211\u4eec\u53ef\u80fd\u4f1a\u5728\u5206\u7c7b\u8bbe\u7f6e\u4e2d\u4f7f\u7528\u56de\u5f52\uff0c\u8fd9\u53d6\u51b3\u4e8e\u7528\u4e8e\u8bc4\u4f30\u7684\u6307\u6807\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u7a0d\u540e\u4f1a\u8ba8\u8bba\u8fd9\u4e2a\u95ee\u9898\u3002 \u53e6\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u95ee\u9898\u662f\u65e0\u76d1\u7763\u7c7b\u578b\u3002 \u65e0\u76d1\u7763 \u6570\u636e\u96c6\u6ca1\u6709\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\uff0c\u4e00\u822c\u6765\u8bf4\uff0c\u4e0e\u6709\u76d1\u7763\u95ee\u9898\u76f8\u6bd4\uff0c\u5904\u7406\u65e0\u76d1\u7763\u6570\u636e\u96c6\u66f4\u5177\u6311\u6218\u6027\u3002 \u5047\u8bbe\u4f60\u5728\u4e00\u5bb6\u5904\u7406\u4fe1\u7528\u5361\u4ea4\u6613\u7684\u91d1\u878d\u516c\u53f8\u5de5\u4f5c\u3002\u6bcf\u79d2\u949f\u90fd\u6709\u5927\u91cf\u6570\u636e\u6d8c\u5165\u3002\u552f\u4e00\u7684\u95ee\u9898\u662f\uff0c\u5f88\u96be\u627e\u5230\u4e00\u4e2a\u4eba\u6765\u5c06\u6bcf\u7b14\u4ea4\u6613\u6807\u8bb0\u4e3a\u6709\u6548\u4ea4\u6613\u3001\u771f\u5b9e\u4ea4\u6613\u6216\u6b3a\u8bc8\u4ea4\u6613\u3002\u5f53\u6211\u4eec\u6ca1\u6709\u4efb\u4f55\u5173\u4e8e\u4ea4\u6613\u662f\u6b3a\u8bc8\u8fd8\u662f\u771f\u5b9e\u7684\u4fe1\u606f\u65f6\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u65e0\u76d1\u7763\u95ee\u9898\u3002\u8981\u89e3\u51b3\u8fd9\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5fc5\u987b\u8003\u8651\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a\u591a\u5c11\u4e2a \u805a\u7c7b \u3002\u805a\u7c7b\u662f\u89e3\u51b3\u6b64\u7c7b\u95ee\u9898\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd8\u6709\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u53ef\u4ee5\u5e94\u7528\u4e8e\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5bf9\u4e8e\u6b3a\u8bc8\u68c0\u6d4b\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u8bf4\u6570\u636e\u53ef\u4ee5\u5206\u4e3a\u4e24\u7c7b\uff08\u6b3a\u8bc8\u6216\u771f\u5b9e\uff09\u3002 \u5f53\u6211\u4eec\u77e5\u9053\u805a\u7c7b\u7684\u6570\u91cf\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u805a\u7c7b\u7b97\u6cd5\u6765\u89e3\u51b3\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5728\u56fe 2 \u4e2d\uff0c\u5047\u8bbe\u6570\u636e\u5206\u4e3a\u4e24\u7c7b\uff0c\u6df1\u8272\u4ee3\u8868\u6b3a\u8bc8\uff0c\u6d45\u8272\u4ee3\u8868\u771f\u5b9e\u4ea4\u6613\u3002\u7136\u800c\uff0c\u5728\u4f7f\u7528\u805a\u7c7b\u65b9\u6cd5\u4e4b\u524d\uff0c\u6211\u4eec\u5e76\u4e0d\u77e5\u9053\u8fd9\u4e9b\u7c7b\u522b\u3002\u5e94\u7528\u805a\u7c7b\u7b97\u6cd5\u540e\uff0c\u6211\u4eec\u5e94\u8be5\u80fd\u591f\u533a\u5206\u8fd9\u4e24\u4e2a\u5047\u5b9a\u76ee\u6807\u3002 \u4e3a\u4e86\u7406\u89e3\u65e0\u76d1\u7763\u95ee\u9898\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u8bb8\u591a\u5206\u89e3\u6280\u672f\uff0c\u5982 \u4e3b\u6210\u5206\u5206\u6790\uff08PCA\uff09\u3001t-\u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09 \u7b49\u3002 \u6709\u76d1\u7763\u7684\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\uff0c\u56e0\u4e3a\u5b83\u4eec\u5f88\u5bb9\u6613\u8bc4\u4f30\u3002\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u8bc4\u4f30\u6280\u672f\u3002\u7136\u800c\uff0c\u5bf9\u65e0\u76d1\u7763\u7b97\u6cd5\u7684\u7ed3\u679c\u8fdb\u884c\u8bc4\u4f30\u5177\u6709\u6311\u6218\u6027\uff0c\u9700\u8981\u5927\u91cf\u7684\u4eba\u4e3a\u5e72\u9884\u6216\u542f\u53d1\u5f0f\u65b9\u6cd5\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4e3b\u8981\u5173\u6ce8\u6709\u76d1\u7763\u6570\u636e\u548c\u6a21\u578b\uff0c\u4f46\u8fd9\u5e76\u4e0d\u610f\u5473\u7740\u6211\u4eec\u4f1a\u5ffd\u7565\u65e0\u76d1\u7763\u6570\u636e\u95ee\u9898\u3002 \u56fe 2\uff1a\u65e0\u76d1\u7763\u5b66\u4e60\u6570\u636e\u96c6 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u5f53\u4eba\u4eec\u5f00\u59cb\u5b66\u4e60\u6570\u636e\u79d1\u5b66\u6216\u673a\u5668\u5b66\u4e60\u65f6\uff0c\u90fd\u4f1a\u4ece\u975e\u5e38\u8457\u540d\u7684\u6570\u636e\u96c6\u5f00\u59cb\uff0c\u4f8b\u5982\u6cf0\u5766\u5c3c\u514b\u6570\u636e\u96c6\u6216\u9e22\u5c3e\u82b1\u6570\u636e\u96c6\uff0c\u8fd9\u4e9b\u90fd\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u5728\u6cf0\u5766\u5c3c\u514b\u53f7\u6570\u636e\u96c6\u4e2d\uff0c\u4f60\u5fc5\u987b\u6839\u636e\u8239\u7968\u7b49\u7ea7\u3001\u6027\u522b\u3001\u5e74\u9f84\u7b49\u56e0\u7d20\u9884\u6d4b\u6cf0\u5766\u5c3c\u514b\u53f7\u4e0a\u4e58\u5ba2\u7684\u5b58\u6d3b\u7387\u3002\u540c\u6837\uff0c\u5728\u9e22\u5c3e\u82b1\u6570\u636e\u96c6\u4e2d\uff0c\u60a8\u5fc5\u987b\u6839\u636e\u843c\u7247\u5bbd\u5ea6\u3001\u82b1\u74e3\u957f\u5ea6\u3001\u843c\u7247\u957f\u5ea6\u548c\u82b1\u74e3\u5bbd\u5ea6\u7b49\u56e0\u7d20\u9884\u6d4b\u82b1\u7684\u79cd\u7c7b\u3002 \u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u80fd\u5305\u62ec\u7528\u4e8e\u5ba2\u6237\u7ec6\u5206\u7684\u6570\u636e\u96c6\u3002 \u4f8b\u5982\uff0c\u60a8\u62e5\u6709\u8bbf\u95ee\u60a8\u7684\u7535\u5b50\u5546\u52a1\u7f51\u7ad9\u7684\u5ba2\u6237\u6570\u636e\uff0c\u6216\u8005\u8bbf\u95ee\u5546\u5e97\u6216\u5546\u573a\u7684\u5ba2\u6237\u6570\u636e\uff0c\u800c\u60a8\u5e0c\u671b\u5c06\u5b83\u4eec\u7ec6\u5206\u6216\u805a\u7c7b\u4e3a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u65e0\u76d1\u7763\u6570\u636e\u96c6\u7684\u53e6\u4e00\u4e2a\u4f8b\u5b50\u53ef\u80fd\u5305\u62ec\u4fe1\u7528\u5361\u6b3a\u8bc8\u68c0\u6d4b\u6216\u5bf9\u51e0\u5f20\u56fe\u7247\u8fdb\u884c\u805a\u7c7b\u7b49\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd8\u53ef\u4ee5\u5c06\u6709\u76d1\u7763\u6570\u636e\u96c6\u8f6c\u6362\u4e3a\u65e0\u76d1\u7763\u6570\u636e\u96c6\uff0c\u4ee5\u67e5\u770b\u5b83\u4eec\u5728\u7ed8\u5236\u65f6\u7684\u6548\u679c\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u56fe 3 \u4e2d\u7684\u6570\u636e\u96c6\u3002\u56fe 3 \u663e\u793a\u7684\u662f MNIST \u6570\u636e\u96c6\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u6d41\u884c\u7684\u624b\u5199\u6570\u5b57\u6570\u636e\u96c6\uff0c\u5b83\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\uff0c\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u4f60\u4f1a\u5f97\u5230\u6570\u5b57\u56fe\u50cf\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u6b63\u786e\u6807\u7b7e\u3002\u4f60\u5fc5\u987b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u53ea\u63d0\u4f9b\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\u8bc6\u522b\u51fa\u54ea\u4e2a\u6570\u5b57\u662f\u5b83\u3002 \u56fe 3\uff1aMNIST\u6570\u636e\u96c6 \u5982\u679c\u6211\u4eec\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c t \u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09\u5206\u89e3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u53ea\u9700\u5728\u56fe\u50cf\u50cf\u7d20\u4e0a\u964d\u7ef4\u81f3 2 \u4e2a\u7ef4\u5ea6\uff0c\u5c31\u80fd\u5728\u4e00\u5b9a\u7a0b\u5ea6\u4e0a\u5206\u79bb\u56fe\u50cf\u3002\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1aMNIST \u6570\u636e\u96c6\u7684 t-SNE \u53ef\u89c6\u5316\u3002\u4f7f\u7528\u4e86 3000 \u5e45\u56fe\u50cf\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002\u9996\u5148\u662f\u5bfc\u5165\u6240\u6709\u9700\u8981\u7684\u5e93\u3002 import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns from sklearn import datasets from sklearn import manifold % matplotlib inline \u6211\u4eec\u4f7f\u7528 matplotlib \u548c seaborn \u8fdb\u884c\u7ed8\u56fe\uff0c\u4f7f\u7528 numpy \u5904\u7406\u6570\u503c\u6570\u7ec4\uff0c\u4f7f\u7528 pandas \u4ece\u6570\u503c\u6570\u7ec4\u521b\u5efa\u6570\u636e\u5e27\uff0c\u4f7f\u7528 scikit-learn (sklearn) \u83b7\u53d6\u6570\u636e\u5e76\u6267\u884c t-SNE\u3002 \u5bfc\u5165\u540e\uff0c\u6211\u4eec\u9700\u8981\u4e0b\u8f7d\u6570\u636e\u5e76\u5355\u72ec\u8bfb\u53d6\uff0c\u6216\u8005\u4f7f\u7528 sklearn \u7684\u5185\u7f6e\u51fd\u6570\u6765\u63d0\u4f9b MNIST \u6570\u636e\u96c6\u3002 data = datasets . fetch_openml ( 'mnist_784' , version = 1 , return_X_y = True ) pixel_values , targets = data targets = targets . astype ( int ) \u5728\u8fd9\u90e8\u5206\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528 sklearn \u6570\u636e\u96c6\u83b7\u53d6\u4e86\u6570\u636e\uff0c\u5e76\u83b7\u5f97\u4e86\u4e00\u4e2a\u50cf\u7d20\u503c\u6570\u7ec4\u548c\u53e6\u4e00\u4e2a\u76ee\u6807\u6570\u7ec4\u3002\u7531\u4e8e\u76ee\u6807\u662f\u5b57\u7b26\u4e32\u7c7b\u578b\uff0c\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u3002 pixel_values \u662f\u4e00\u4e2a\u5f62\u72b6\u4e3a 70000x784 \u7684\u4e8c\u7ef4\u6570\u7ec4\u3002 \u5171\u6709 70000 \u5f20\u4e0d\u540c\u7684\u56fe\u50cf\uff0c\u6bcf\u5f20\u56fe\u50cf\u5927\u5c0f\u4e3a 28x28 \u50cf\u7d20\u3002\u5e73\u94fa 28x28 \u540e\u5f97\u5230 784 \u4e2a\u6570\u636e\u70b9\u3002 \u6211\u4eec\u53ef\u4ee5\u5c06\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u91cd\u5851\u4e3a\u539f\u6765\u7684\u5f62\u72b6\uff0c\u7136\u540e\u4f7f\u7528 matplotlib \u7ed8\u5236\u6210\u56fe\u8868\uff0c\u4ece\u800c\u5c06\u5176\u53ef\u89c6\u5316\u3002 single_image = pixel_values [ 1 , :] . reshape ( 28 , 28 ) plt . imshow ( single_image , cmap = 'gray' ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u7ed8\u5236\u5982\u4e0b\u56fe\u50cf\uff1a \u56fe 5\uff1a\u7ed8\u5236MNIST\u6570\u636e\u96c6\u5355\u5f20\u56fe\u7247 \u6700\u91cd\u8981\u7684\u4e00\u6b65\u662f\u5728\u6211\u4eec\u83b7\u53d6\u6570\u636e\u4e4b\u540e\u3002 tsne = manifold . TSNE ( n_components = 2 , random_state = 42 ) transformed_data = tsne . fit_transform ( pixel_values [: 3000 , :]) \u8fd9\u4e00\u6b65\u521b\u5efa\u4e86\u6570\u636e\u7684 t-SNE \u53d8\u6362\u3002\u6211\u4eec\u53ea\u4f7f\u7528 2 \u4e2a\u7ef4\u5ea6\uff0c\u56e0\u4e3a\u5728\u4e8c\u7ef4\u73af\u5883\u4e2d\u53ef\u4ee5\u5f88\u597d\u5730\u5c06\u5b83\u4eec\u53ef\u89c6\u5316\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u8f6c\u6362\u540e\u7684\u6570\u636e\u662f\u4e00\u4e2a 3000x2 \u5f62\u72b6\u7684\u6570\u7ec4\uff083000 \u884c 2 \u5217\uff09\u3002\u5728\u6570\u7ec4\u4e0a\u8c03\u7528 pd.DataFrame \u53ef\u4ee5\u5c06\u8fd9\u6837\u7684\u6570\u636e\u8f6c\u6362\u4e3a pandas \u6570\u636e\u5e27\u3002 tsne_df = pd . DataFrame ( np . column_stack (( transformed_data , targets [: 3000 ])), columns = [ \"x\" , \"y\" , \"targets\" ]) tsne_df . loc [:, \"targets\" ] = tsne_df . targets . astype ( int ) \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a numpy \u6570\u7ec4\u521b\u5efa\u4e00\u4e2a pandas \u6570\u636e\u5e27\u3002x \u548c y \u662f t-SNE \u5206\u89e3\u7684\u4e24\u4e2a\u7ef4\u5ea6\uff0ctarget \u662f\u5b9e\u9645\u6570\u5b57\u3002\u8fd9\u6837\u6211\u4eec\u5c31\u5f97\u5230\u4e86\u5982\u56fe 6 \u6240\u793a\u7684\u6570\u636e\u5e27\u3002 \u56fe 6\uff1at-SNE\u540e\u6570\u636e\u524d10\u884c \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 seaborn \u548c matplotlib \u7ed8\u5236\u5b83\u3002 grid = sns . FacetGrid ( tsne_df , hue = \"targets\" , size = 8 ) grid . map ( plt . scatter , \"x\" , \"y\" ) . add_legend () \u8fd9\u662f\u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u89c6\u5316\u7684\u4e00\u79cd\u65b9\u6cd5\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5728\u540c\u4e00\u6570\u636e\u96c6\u4e0a\u8fdb\u884c k-means \u805a\u7c7b\uff0c\u770b\u770b\u5b83\u5728\u65e0\u76d1\u7763\u73af\u5883\u4e0b\u7684\u8868\u73b0\u5982\u4f55\u3002\u4e00\u4e2a\u7ecf\u5e38\u51fa\u73b0\u7684\u95ee\u9898\u662f\uff0c\u5982\u4f55\u5728 k-means \u805a\u7c7b\u4e2d\u627e\u5230\u6700\u4f73\u7684\u7c07\u6570\u3002\u8fd9\u4e2a\u95ee\u9898\u6ca1\u6709\u6b63\u786e\u7b54\u6848\u3002\u4f60\u5fc5\u987b\u901a\u8fc7\u4ea4\u53c9\u9a8c\u8bc1\u6765\u627e\u5230\u6700\u4f73\u7c07\u6570\u3002\u672c\u4e66\u7a0d\u540e\u5c06\u8ba8\u8bba\u4ea4\u53c9\u9a8c\u8bc1\u3002\u8bf7\u6ce8\u610f\uff0c\u4e0a\u8ff0\u4ee3\u7801\u662f\u5728 jupyter \u7b14\u8bb0\u672c\u4e2d\u8fd0\u884c\u7684\u3002 \u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 jupyter \u505a\u4e00\u4e9b\u7b80\u5355\u7684\u4e8b\u60c5\uff0c\u6bd4\u5982\u4e0a\u9762\u7684\u4f8b\u5b50\u548c \u7ed8\u56fe\u3002\u5bf9\u4e8e\u672c\u4e66\u4e2d\u7684\u5927\u90e8\u5206\u5185\u5bb9\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 python \u811a\u672c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5176\u4ed6 IDE \u56e0\u4e3a\u7ed3\u679c\u90fd\u662f\u4e00\u6837\u7684\u3002 MNIST \u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u628a\u5b83\u8f6c\u6362\u6210\u4e00\u4e2a\u65e0\u76d1\u7763\u7684\u95ee\u9898\uff0c\u53ea\u662f\u4e3a\u4e86\u68c0\u67e5\u5b83\u662f\u5426\u80fd\u5e26\u6765\u4efb\u4f55\u597d\u7684\u7ed3\u679c\u3002\u5982\u679c\u6211\u4eec\u4f7f\u7528\u5206\u7c7b\u7b97\u6cd5\uff0c\u6548\u679c\u4f1a\u66f4\u597d\u3002\u8ba9\u6211\u4eec\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u4e00\u63a2\u7a76\u7adf\u3002","title":"\u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60"},{"location":"%E7%89%B9%E5%BE%81%E5%B7%A5%E7%A8%8B/","text":"\u7279\u5f81\u5de5\u7a0b \u7279\u5f81\u5de5\u7a0b\u662f\u6784\u5efa\u826f\u597d\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u90e8\u5206\u4e4b\u4e00\u3002\u5982\u679c\u6211\u4eec\u62e5\u6709\u6709\u7528\u7684\u7279\u5f81\uff0c\u6a21\u578b\u5c31\u4f1a\u8868\u73b0\u5f97\u66f4\u597d\u3002\u5728\u8bb8\u591a\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u907f\u514d\u4f7f\u7528\u5927\u578b\u590d\u6742\u6a21\u578b\uff0c\u800c\u4f7f\u7528\u5177\u6709\u5173\u952e\u5de5\u7a0b\u7279\u5f81\u7684\u7b80\u5355\u6a21\u578b\u3002\u6211\u4eec\u5fc5\u987b\u7262\u8bb0\uff0c\u53ea\u6709\u5f53\u4f60\u5bf9\u95ee\u9898\u7684\u9886\u57df\u6709\u4e00\u5b9a\u7684\u4e86\u89e3\uff0c\u5e76\u4e14\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u76f8\u5173\u6570\u636e\u65f6\uff0c\u624d\u80fd\u4ee5\u6700\u4f73\u65b9\u5f0f\u5b8c\u6210\u7279\u5f81\u5de5\u7a0b\u3002\u4e0d\u8fc7\uff0c\u60a8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u4e00\u4e9b\u901a\u7528\u6280\u672f\uff0c\u4ece\u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u7279\u5f81\u3002\u7279\u5f81\u5de5\u7a0b\u4e0d\u4ec5\u4ec5\u662f\u4ece\u6570\u636e\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\uff0c\u8fd8\u5305\u62ec\u4e0d\u540c\u7c7b\u578b\u7684\u5f52\u4e00\u5316\u548c\u8f6c\u6362\u3002 \u5728\u6709\u5173\u5206\u7c7b\u7279\u5f81\u7684\u7ae0\u8282\u4e2d\uff0c\u6211\u4eec\u5df2\u7ecf\u4e86\u89e3\u4e86\u7ed3\u5408\u4e0d\u540c\u5206\u7c7b\u53d8\u91cf\u7684\u65b9\u6cd5\u3001\u5982\u4f55\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u8ba1\u6570\u3001\u6807\u7b7e\u7f16\u7801\u548c\u4f7f\u7528\u5d4c\u5165\u3002\u8fd9\u4e9b\u51e0\u4e4e\u90fd\u662f\u5229\u7528\u5206\u7c7b\u53d8\u91cf\u8bbe\u8ba1\u7279\u5f81\u7684\u65b9\u6cd5\u3002\u56e0\u6b64\uff0c\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u7684\u91cd\u70b9\u5c06\u4ec5\u9650\u4e8e\u6570\u503c\u53d8\u91cf\u4ee5\u53ca\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u7684\u7ec4\u5408\u3002 \u8ba9\u6211\u4eec\u4ece\u6700\u7b80\u5355\u4f46\u5e94\u7528\u6700\u5e7f\u6cdb\u7684\u7279\u5f81\u5de5\u7a0b\u6280\u672f\u5f00\u59cb\u3002\u5047\u8bbe\u4f60\u6b63\u5728\u5904\u7406\u65e5\u671f\u548c\u65f6\u95f4\u6570\u636e\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709\u4e00\u4e2a\u5e26\u6709\u65e5\u671f\u7c7b\u578b\u5217\u7684 pandas \u6570\u636e\u5e27\u3002\u5229\u7528\u8fd9\u4e00\u5217\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4ee5\u4e0b\u7279\u5f81\uff1a \u5e74 \u5e74\u4e2d\u7684\u5468 \u6708 \u661f\u671f \u5468\u672b \u5c0f\u65f6 \u8fd8\u6709\u66f4\u591a \u800c\u4f7f\u7528pandas\u5c31\u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u6dfb\u52a0'year'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5e74\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'year' ] = df [ 'datetime_column' ] . dt . year # \u6dfb\u52a0'weekofyear'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5468\u6570\u63d0\u53d6\u51fa\u6765 df . loc [:, 'weekofyear' ] = df [ 'datetime_column' ] . dt . weekofyear # \u6dfb\u52a0'month'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u6708\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'month' ] = df [ 'datetime_column' ] . dt . month # \u6dfb\u52a0'dayofweek'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u661f\u671f\u51e0\u63d0\u53d6\u51fa\u6765 df . loc [:, 'dayofweek' ] = df [ 'datetime_column' ] . dt . dayofweek # \u6dfb\u52a0'weekend'\u5217\uff0c\u5224\u65ad\u5f53\u5929\u662f\u5426\u4e3a\u5468\u672b df . loc [:, 'weekend' ] = ( df . datetime_column . dt . weekday >= 5 ) . astype ( int ) # \u6dfb\u52a0 'hour' \u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5c0f\u65f6\u63d0\u53d6\u51fa\u6765 df . loc [:, 'hour' ] = df [ 'datetime_column' ] . dt . hour \u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u65e5\u671f\u65f6\u95f4\u5217\u521b\u5efa\u4e00\u7cfb\u5217\u65b0\u5217\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u53ef\u4ee5\u521b\u5efa\u7684\u4e00\u4e9b\u793a\u4f8b\u529f\u80fd\u3002 import pandas as pd # \u521b\u5efa\u65e5\u671f\u65f6\u95f4\u5e8f\u5217\uff0c\u5305\u542b\u4e86\u4ece '2020-01-06' \u5230 '2020-01-10' \u7684\u65e5\u671f\u65f6\u95f4\u70b9\uff0c\u65f6\u95f4\u95f4\u9694\u4e3a10\u5c0f\u65f6 s = pd . date_range ( '2020-01-06' , '2020-01-10' , freq = '10H' ) . to_series () # \u63d0\u53d6\u5bf9\u5e94\u65f6\u95f4\u7279\u5f81 features = { \"dayofweek\" : s . dt . dayofweek . values , \"dayofyear\" : s . dt . dayofyear . values , \"hour\" : s . dt . hour . values , \"is_leap_year\" : s . dt . is_leap_year . values , \"quarter\" : s . dt . quarter . values , \"weekofyear\" : s . dt . weekofyear . values } \u8fd9\u5c06\u4ece\u7ed9\u5b9a\u7cfb\u5217\u4e2d\u751f\u6210\u4e00\u4e2a\u7279\u5f81\u5b57\u5178\u3002\u60a8\u53ef\u4ee5\u5c06\u6b64\u5e94\u7528\u4e8e pandas \u6570\u636e\u4e2d\u7684\u4efb\u4f55\u65e5\u671f\u65f6\u95f4\u5217\u3002\u8fd9\u4e9b\u662f pandas \u63d0\u4f9b\u7684\u4f17\u591a\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u4e2d\u7684\u4e00\u90e8\u5206\u3002\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u975e\u5e38\u91cd\u8981\uff0c\u4f8b\u5982\uff0c\u5728\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97\u7684\u9500\u552e\u989d\u65f6\uff0c\u5982\u679c\u60f3\u5728\u805a\u5408\u7279\u5f81\u4e0a\u4f7f\u7528 xgboost \u7b49\u6a21\u578b\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u5c31\u975e\u5e38\u91cd\u8981\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u4e0b\u6240\u793a\u7684\u6570\u636e\uff1a \u56fe 1\uff1a\u5305\u542b\u5206\u7c7b\u548c\u65e5\u671f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e \u5728\u56fe 1 \u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u6709\u4e00\u4e2a\u65e5\u671f\u5217\uff0c\u4ece\u4e2d\u53ef\u4ee5\u8f7b\u677e\u63d0\u53d6\u5e74\u3001\u6708\u3001\u5b63\u5ea6\u7b49\u7279\u5f81\u3002\u7136\u540e\uff0c\u6211\u4eec\u6709\u4e00\u4e2a customer_id \u5217\uff0c\u8be5\u5217\u6709\u591a\u4e2a\u6761\u76ee\uff0c\u56e0\u6b64\u4e00\u4e2a\u5ba2\u6237\u4f1a\u88ab\u770b\u5230\u5f88\u591a\u6b21\uff08\u622a\u56fe\u4e2d\u770b\u4e0d\u5230\uff09\u3002\u6bcf\u4e2a\u65e5\u671f\u548c\u5ba2\u6237 ID \u90fd\u6709\u4e09\u4e2a\u5206\u7c7b\u7279\u5f81\u548c\u4e00\u4e2a\u6570\u5b57\u7279\u5f81\u3002\u6211\u4eec\u53ef\u4ee5\u4ece\u4e2d\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff1a - \u5ba2\u6237\u6700\u6d3b\u8dc3\u7684\u6708\u4efd\u662f\u51e0\u6708 - \u67d0\u4e2a\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u7684\u8ba1\u6570\u662f\u591a\u5c11 - \u67d0\u5e74\u67d0\u6708\u67d0\u5468\u67d0\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u6570\u91cf\u662f\u591a\u5c11\uff1f - \u67d0\u4e2a\u5ba2\u6237\u7684 num1 \u5e73\u5747\u503c\u662f\u591a\u5c11\uff1f - \u7b49\u7b49\u3002 \u4f7f\u7528 pandas \u4e2d\u7684\u805a\u5408\uff0c\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u521b\u5efa\u7c7b\u4f3c\u7684\u529f\u80fd\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u5b9e\u73b0\u3002 def generate_features ( df ): df . loc [:, 'year' ] = df [ 'date' ] . dt . year df . loc [:, 'weekofyear' ] = df [ 'date' ] . dt . weekofyear df . loc [:, 'month' ] = df [ 'date' ] . dt . month df . loc [:, 'dayofweek' ] = df [ 'date' ] . dt . dayofweek df . loc [:, 'weekend' ] = ( df [ 'date' ] . dt . weekday >= 5 ) . astype ( int ) aggs = {} # \u5bf9 'month' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'month' ] = [ 'nunique' , 'mean' ] # \u5bf9 'weekofyear' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'weekofyear' ] = [ 'nunique' , 'mean' ] # \u5bf9 'num1' \u5217\u8fdb\u884c sum\u3001max\u3001min\u3001mean \u805a\u5408 aggs [ 'num1' ] = [ 'sum' , 'max' , 'min' , 'mean' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c size \u805a\u5408 aggs [ 'customer_id' ] = [ 'size' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c nunique \u805a\u5408 aggs [ 'customer_id' ] = [ 'nunique' ] # \u5bf9\u6570\u636e\u5e94\u7528\u4e0d\u540c\u7684\u805a\u5408\u51fd\u6570 agg_df = df . groupby ( 'customer_id' ) . agg ( aggs ) # \u91cd\u7f6e\u7d22\u5f15 agg_df = agg_df . reset_index () return agg_df \u8bf7\u6ce8\u610f\uff0c\u5728\u4e0a\u8ff0\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u8df3\u8fc7\u4e86\u5206\u7c7b\u53d8\u91cf\uff0c\u4f46\u60a8\u53ef\u4ee5\u50cf\u4f7f\u7528\u5176\u4ed6\u805a\u5408\u53d8\u91cf\u4e00\u6837\u4f7f\u7528\u5b83\u4eec\u3002 \u56fe 2\uff1a\u603b\u4f53\u7279\u5f81\u548c\u5176\u4ed6\u7279\u5f81 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u56fe 2 \u4e2d\u7684\u6570\u636e\u4e0e\u5e26\u6709 customer_id \u5217\u7684\u539f\u59cb\u6570\u636e\u5e27\u8fde\u63a5\u8d77\u6765\uff0c\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5e76\u4e0d\u662f\u8981\u9884\u6d4b\u4ec0\u4e48\uff1b\u6211\u4eec\u53ea\u662f\u5728\u521b\u5efa\u901a\u7528\u7279\u5f81\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u8bd5\u56fe\u5728\u8fd9\u91cc\u9884\u6d4b\u4ec0\u4e48\uff0c\u521b\u5efa\u7279\u5f81\u4f1a\u66f4\u5bb9\u6613\u3002 \u4f8b\u5982\uff0c\u6709\u65f6\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u95ee\u9898\u65f6\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u7684\u7279\u5f81\u4e0d\u662f\u5355\u4e2a\u503c\uff0c\u800c\u662f\u4e00\u7cfb\u5217\u503c\u3002 \u4f8b\u5982\uff0c\u5ba2\u6237\u5728\u7279\u5b9a\u65f6\u95f4\u6bb5\u5185\u7684\u4ea4\u6613\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4f1a\u521b\u5efa\u4e0d\u540c\u7c7b\u578b\u7684\u7279\u5f81\uff0c\u4f8b\u5982\uff1a\u4f7f\u7528\u6570\u503c\u7279\u5f81\u65f6\uff0c\u5728\u5bf9\u5206\u7c7b\u5217\u8fdb\u884c\u5206\u7ec4\u65f6\uff0c\u4f1a\u5f97\u5230\u7c7b\u4f3c\u4e8e\u65f6\u95f4\u5206\u5e03\u503c\u5217\u8868\u7684\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u7cfb\u5217\u7edf\u8ba1\u7279\u5f81\uff0c\u4f8b\u5982 \u5e73\u5747\u503c \u6700\u5927\u503c \u6700\u5c0f\u503c \u72ec\u7279\u6027 \u504f\u659c \u5cf0\u5ea6 Kstat \u767e\u5206\u4f4d\u6570 \u5b9a\u91cf \u5cf0\u503c\u5230\u5cf0\u503c \u4ee5\u53ca\u66f4\u591a \u8fd9\u4e9b\u53ef\u4ee5\u4f7f\u7528\u7b80\u5355\u7684 numpy \u51fd\u6570\u521b\u5efa\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 import numpy as np # \u521b\u5efa\u5b57\u5178\uff0c\u7528\u4e8e\u5b58\u50a8\u4e0d\u540c\u7684\u7edf\u8ba1\u7279\u5f81 feature_dict = {} # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5e73\u5747\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'mean' \u952e\u4e0b feature_dict [ 'mean' ] = np . mean ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5927\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'max' \u952e\u4e0b feature_dict [ 'max' ] = np . max ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5c0f\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'min' \u952e\u4e0b feature_dict [ 'min' ] = np . min ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6807\u51c6\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'std' \u952e\u4e0b feature_dict [ 'std' ] = np . std ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u65b9\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'var' \u952e\u4e0b feature_dict [ 'var' ] = np . var ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5dee\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'ptp' \u952e\u4e0b feature_dict [ 'ptp' ] = np . ptp ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c10\u767e\u5206\u4f4d\u6570\uff08\u5373\u767e\u5206\u4e4b10\u5206\u4f4d\u6570\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_10' \u952e\u4e0b feature_dict [ 'percentile_10' ] = np . percentile ( x , 10 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c60\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_60' \u952e\u4e0b feature_dict [ 'percentile_60' ] = np . percentile ( x , 60 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c90\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_90' \u952e\u4e0b feature_dict [ 'percentile_90' ] = np . percentile ( x , 90 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u76845%\u5206\u4f4d\u6570\uff08\u53730.05\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_5' \u952e\u4e0b feature_dict [ 'quantile_5' ] = np . quantile ( x , 0.05 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768495%\u5206\u4f4d\u6570\uff08\u53730.95\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_95' \u952e\u4e0b feature_dict [ 'quantile_95' ] = np . quantile ( x , 0.95 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768499%\u5206\u4f4d\u6570\uff08\u53730.99\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_99' \u952e\u4e0b feature_dict [ 'quantile_99' ] = np . quantile ( x , 0.99 ) \u65f6\u95f4\u5e8f\u5217\u6570\u636e\uff08\u6570\u503c\u5217\u8868\uff09\u53ef\u4ee5\u8f6c\u6362\u6210\u8bb8\u591a\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4e00\u4e2a\u540d\u4e3a tsfresh \u7684 python \u5e93\u975e\u5e38\u6709\u7528\u3002 from tsfresh.feature_extraction import feature_calculators as fc # \u8ba1\u7b97 x \u6570\u5217\u7684\u7edd\u5bf9\u80fd\u91cf\uff08abs_energy\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'abs_energy' \u952e\u4e0b feature_dict [ 'abs_energy' ] = fc . abs_energy ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u9ad8\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_above_mean' \u952e\u4e0b feature_dict [ 'count_above_mean' ] = fc . count_above_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u4f4e\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_below_mean' \u952e\u4e0b feature_dict [ 'count_below_mean' ] = fc . count_below_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u7edd\u5bf9\u53d8\u5316\uff08mean_abs_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_abs_change' \u952e\u4e0b feature_dict [ 'mean_abs_change' ] = fc . mean_abs_change ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u53d8\u5316\u7387\uff08mean_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_change' \u952e\u4e0b feature_dict [ 'mean_change' ] = fc . mean_change ( x ) \u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\uff1btsfresh \u63d0\u4f9b\u4e86\u6570\u767e\u79cd\u7279\u5f81\u548c\u6570\u5341\u79cd\u4e0d\u540c\u7279\u5f81\u7684\u53d8\u4f53\uff0c\u4f60\u53ef\u4ee5\u5c06\u5b83\u4eec\u7528\u4e8e\u57fa\u4e8e\u65f6\u95f4\u5e8f\u5217\uff08\u503c\u5217\u8868\uff09\u7684\u7279\u5f81\u3002\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0cx \u662f\u4e00\u4e2a\u503c\u5217\u8868\u3002\u4f46\u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\u3002\u60a8\u8fd8\u53ef\u4ee5\u4e3a\u5305\u542b\u6216\u4e0d\u5305\u542b\u5206\u7c7b\u6570\u636e\u7684\u6570\u503c\u6570\u636e\u521b\u5efa\u8bb8\u591a\u5176\u4ed6\u7279\u5f81\u3002\u751f\u6210\u8bb8\u591a\u7279\u5f81\u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u521b\u5efa\u4e00\u5806\u591a\u9879\u5f0f\u7279\u5f81\u3002\u4f8b\u5982\uff0c\u4ece\u4e24\u4e2a\u7279\u5f81 \"a \"\u548c \"b \"\u751f\u6210\u7684\u4e8c\u7ea7\u591a\u9879\u5f0f\u7279\u5f81\u5305\u62ec \"a\"\u3001\"b\"\u3001\"ab\"\u3001\"a^2 \"\u548c \"b^2\"\u3002 import numpy as np df = pd . DataFrame ( np . random . rand ( 100 , 2 ), columns = [ f \"f_ { i } \" for i in range ( 1 , 3 )]) \u5982\u56fe 3 \u6240\u793a\uff0c\u5b83\u7ed9\u51fa\u4e86\u4e00\u4e2a\u6570\u636e\u8868\u3002 \u56fe 3\uff1a\u5305\u542b\u4e24\u4e2a\u6570\u5b57\u7279\u5f81\u7684\u968f\u673a\u6570\u636e\u8868 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u7684 PolynomialFeatures \u521b\u5efa\u4e24\u6b21\u591a\u9879\u5f0f\u7279\u5f81\u3002 from sklearn import preprocessing # \u6307\u5b9a\u591a\u9879\u5f0f\u7684\u6b21\u6570\u4e3a 2\uff0c\u4e0d\u4ec5\u8003\u8651\u4ea4\u4e92\u9879\uff0c\u4e0d\u5305\u62ec\u504f\u5dee\uff08include_bias=False\uff09 pf = preprocessing . PolynomialFeatures ( degree = 2 , interaction_only = False , include_bias = False ) # \u62df\u5408\uff0c\u521b\u5efa\u591a\u9879\u5f0f\u7279\u5f81 pf . fit ( df ) # \u8f6c\u6362\u6570\u636e poly_feats = pf . transform ( df ) # \u83b7\u53d6\u751f\u6210\u7684\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf num_feats = poly_feats . shape [ 1 ] # \u4e3a\u65b0\u751f\u6210\u7684\u7279\u5f81\u547d\u540d df_transformed = pd . DataFrame ( poly_feats , columns = [ f \"f_ { i } \" for i in range ( 1 , num_feats + 1 )] ) \u8fd9\u6837\u5c31\u5f97\u5230\u4e86\u4e00\u4e2a\u6570\u636e\u8868\uff0c\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1a\u5e26\u6709\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e\u8868 \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e9b\u591a\u9879\u5f0f\u7279\u5f81\u3002\u5982\u679c\u521b\u5efa\u7684\u662f\u4e09\u6b21\u591a\u9879\u5f0f\u7279\u5f81\uff0c\u6700\u7ec8\u603b\u5171\u4f1a\u6709\u4e5d\u4e2a\u7279\u5f81\u3002\u7279\u5f81\u7684\u6570\u91cf\u8d8a\u591a\uff0c\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf\u4e5f\u5c31\u8d8a\u591a\uff0c\u800c\u4e14\u4f60\u8fd8\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5982\u679c\u6570\u636e\u96c6\u4e2d\u6709\u5f88\u591a\u6837\u672c\uff0c\u90a3\u4e48\u521b\u5efa\u8fd9\u7c7b\u7279\u5f81\u5c31\u9700\u8981\u82b1\u8d39\u4e00\u4e9b\u65f6\u95f4\u3002 \u56fe 5\uff1a\u6570\u5b57\u7279\u5f81\u5217\u7684\u76f4\u65b9\u56fe \u53e6\u4e00\u4e2a\u6709\u8da3\u7684\u529f\u80fd\u662f\u5c06\u6570\u5b57\u8f6c\u6362\u4e3a\u7c7b\u522b\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \u5206\u7bb1 \u3002\u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u56fe 5\uff0c\u5b83\u663e\u793a\u4e86\u4e00\u4e2a\u968f\u673a\u6570\u5b57\u7279\u5f81\u7684\u6837\u672c\u76f4\u65b9\u56fe\u3002\u6211\u4eec\u5728\u8be5\u56fe\u4e2d\u4f7f\u7528\u4e8610\u4e2a\u5206\u7bb1\uff0c\u53ef\u4ee5\u770b\u5230\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a10\u4e2a\u90e8\u5206\u3002\u8fd9\u53ef\u4ee5\u4f7f\u7528 pandas \u7684cat\u51fd\u6570\u6765\u5b9e\u73b0\u3002 # \u521b\u5efa10\u4e2a\u5206\u7bb1 df [ \"f_bin_10\" ] = pd . cut ( df [ \"f_1\" ], bins = 10 , labels = False ) # \u521b\u5efa100\u4e2a\u5206\u7bb1 df [ \"f_bin_100\" ] = pd . cut ( df [ \"f_1\" ], bins = 100 , labels = False ) \u5982\u56fe 6 \u6240\u793a\uff0c\u8fd9\u5c06\u5728\u6570\u636e\u5e27\u4e2d\u751f\u6210\u4e24\u4e2a\u65b0\u7279\u5f81\u3002 \u56fe 6\uff1a\u6570\u503c\u7279\u5f81\u5206\u7bb1 \u5f53\u4f60\u8fdb\u884c\u5206\u7c7b\u65f6\uff0c\u53ef\u4ee5\u540c\u65f6\u4f7f\u7528\u5206\u7bb1\u548c\u539f\u59cb\u7279\u5f81\u3002\u6211\u4eec\u5c06\u5728\u672c\u7ae0\u540e\u534a\u90e8\u5206\u5b66\u4e60\u66f4\u591a\u5173\u4e8e\u9009\u62e9\u7279\u5f81\u7684\u77e5\u8bc6\u3002\u5206\u7bb1\u8fd8\u53ef\u4ee5\u5c06\u6570\u5b57\u7279\u5f81\u89c6\u4e3a\u5206\u7c7b\u7279\u5f81\u3002 \u53e6\u4e00\u79cd\u53ef\u4ee5\u4ece\u6570\u503c\u7279\u5f81\u4e2d\u521b\u5efa\u7684\u6709\u8da3\u7279\u5f81\u7c7b\u578b\u662f\u5bf9\u6570\u53d8\u6362\u3002\u8bf7\u770b\u56fe 7 \u4e2d\u7684\u7279\u5f81 f_3\u3002 \u4e0e\u5176\u4ed6\u65b9\u5dee\u8f83\u5c0f\u7684\u7279\u5f81\u76f8\u6bd4\uff08\u5047\u8bbe\u5982\u6b64\uff09\uff0cf_3 \u662f\u4e00\u79cd\u65b9\u5dee\u975e\u5e38\u5927\u7684\u7279\u6b8a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5e0c\u671b\u964d\u4f4e\u8fd9\u4e00\u5217\u7684\u65b9\u5dee\uff0c\u8fd9\u53ef\u4ee5\u901a\u8fc7\u5bf9\u6570\u53d8\u6362\u6765\u5b9e\u73b0\u3002 f_3 \u5217\u7684\u503c\u8303\u56f4\u4e3a 0 \u5230 10000\uff0c\u76f4\u65b9\u56fe\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u7279\u5f81 f_3 \u7684\u76f4\u65b9\u56fe \u6211\u4eec\u53ef\u4ee5\u5bf9\u8fd9\u4e00\u5217\u5e94\u7528 log(1 + x) \u6765\u51cf\u5c11\u5176\u65b9\u5dee\u3002\u56fe 9 \u663e\u793a\u4e86\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u76f4\u65b9\u56fe\u7684\u53d8\u5316\u3002 \u56fe 9\uff1a\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u7684 f_3 \u76f4\u65b9\u56fe \u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e0d\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u548c\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u7684\u65b9\u5dee\u3002 In [ X ]: df . f_3 . var () Out [ X ]: 8077265.875858586 In [ X ]: df . f_3 . apply ( lambda x : np . log ( 1 + x )) . var () Out [ X ]: 0.6058771732119975 \u6709\u65f6\uff0c\u4e5f\u53ef\u4ee5\u7528\u6307\u6570\u6765\u4ee3\u66ff\u5bf9\u6570\u3002\u4e00\u79cd\u975e\u5e38\u6709\u8da3\u7684\u60c5\u51b5\u662f\uff0c\u60a8\u4f7f\u7528\u57fa\u4e8e\u5bf9\u6570\u7684\u8bc4\u4f30\u6307\u6807\uff0c\u4f8b\u5982 RMSLE\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u5728\u5bf9\u6570\u53d8\u6362\u7684\u76ee\u6807\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u7136\u540e\u5728\u9884\u6d4b\u65f6\u4f7f\u7528\u6307\u6570\u503c\u8f6c\u6362\u56de\u539f\u59cb\u503c\u3002\u8fd9\u5c06\u6709\u52a9\u4e8e\u9488\u5bf9\u6307\u6807\u4f18\u5316\u6a21\u578b\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u7c7b\u6570\u5b57\u7279\u5f81\u90fd\u662f\u57fa\u4e8e\u76f4\u89c9\u521b\u5efa\u7684\u3002\u6ca1\u6709\u516c\u5f0f\u53ef\u5faa\u3002\u5982\u679c\u60a8\u4ece\u4e8b\u7684\u662f\u67d0\u4e00\u884c\u4e1a\uff0c\u60a8\u5c06\u521b\u5efa\u7279\u5b9a\u884c\u4e1a\u7684\u7279\u5f81\u3002 \u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u548c\u6570\u503c\u53d8\u91cf\u65f6\uff0c\u53ef\u80fd\u4f1a\u9047\u5230\u7f3a\u5931\u503c\u3002\u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u4e00\u4e9b\u5904\u7406\u5206\u7c7b\u7279\u5f81\u4e2d\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\uff0c\u4f46\u8fd8\u6709\u66f4\u591a\u65b9\u6cd5\u53ef\u4ee5\u5904\u7406\u7f3a\u5931\u503c/NaN \u503c\u3002\u8fd9\u4e5f\u88ab\u89c6\u4e3a\u7279\u5f81\u5de5\u7a0b\u3002 \u5982\u679c\u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u9047\u5230\u7f3a\u5931\u503c\uff0c\u5c31\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\uff01\u8fd9\u6837\u505a\u867d\u7136\u7b80\u5355\uff0c\u4f46\uff08\u51e0\u4e4e\uff09\u603b\u662f\u6709\u6548\u7684\uff01 \u5728\u6570\u503c\u6570\u636e\u4e2d\u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u9009\u62e9\u4e00\u4e2a\u5728\u7279\u5b9a\u7279\u5f81\u4e2d\u6ca1\u6709\u51fa\u73b0\u7684\u503c\uff0c\u7136\u540e\u7528\u5b83\u6765\u586b\u8865\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u7279\u5f81\u4e2d\u6ca1\u6709 0\u3002\u8fd9\u662f\u5176\u4e2d\u4e00\u79cd\u65b9\u6cd5\uff0c\u4f46\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\u3002\u5bf9\u4e8e\u6570\u503c\u6570\u636e\u6765\u8bf4\uff0c\u6bd4\u586b\u5145 0 \u66f4\u6709\u6548\u7684\u65b9\u6cd5\u4e4b\u4e00\u662f\u4f7f\u7528\u5e73\u5747\u503c\u8fdb\u884c\u586b\u5145\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8be5\u7279\u5f81\u6240\u6709\u503c\u7684\u4e2d\u4f4d\u6570\u6765\u586b\u5145\uff0c\u6216\u8005\u4f7f\u7528\u6700\u5e38\u89c1\u7684\u503c\u6765\u586b\u5145\u7f3a\u5931\u503c\u3002\u8fd9\u6837\u505a\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002 \u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u9ad8\u7ea7\u65b9\u6cd5\u662f\u4f7f\u7528 K \u8fd1\u90bb\u6cd5 \u3002 \u60a8\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\uff0c\u7136\u540e\u5229\u7528\u67d0\u79cd\u8ddd\u79bb\u5ea6\u91cf\uff08\u4f8b\u5982\u6b27\u6c0f\u8ddd\u79bb\uff09\u627e\u5230\u6700\u8fd1\u7684\u90bb\u5c45\u3002\u7136\u540e\u53d6\u6240\u6709\u8fd1\u90bb\u7684\u5e73\u5747\u503c\u6765\u586b\u8865\u7f3a\u5931\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 KNN \u6765\u586b\u8865\u8fd9\u6837\u7684\u7f3a\u5931\u503c\u3002 \u56fe 10\uff1a\u6709\u7f3a\u5931\u503c\u7684\u4e8c\u7ef4\u6570\u7ec4 \u8ba9\u6211\u4eec\u770b\u770b KNN \u662f\u5982\u4f55\u5904\u7406\u56fe 10 \u6240\u793a\u7684\u7f3a\u5931\u503c\u77e9\u9635\u7684\u3002 import numpy as np from sklearn import impute # \u751f\u6210\u7ef4\u5ea6\u4e3a (10, 6) \u7684\u968f\u673a\u6574\u6570\u77e9\u9635 X\uff0c\u6570\u503c\u8303\u56f4\u5728 1 \u5230 14 \u4e4b\u95f4 X = np . random . randint ( 1 , 15 , ( 10 , 6 )) # \u6570\u636e\u7c7b\u578b\u8f6c\u6362\u4e3a float X = X . astype ( float ) # \u5728\u77e9\u9635 X \u4e2d\u968f\u673a\u9009\u62e9 10 \u4e2a\u4f4d\u7f6e\uff0c\u5c06\u8fd9\u4e9b\u4f4d\u7f6e\u7684\u5143\u7d20\u8bbe\u7f6e\u4e3a NaN\uff08\u7f3a\u5931\u503c\uff09 X . ravel ()[ np . random . choice ( X . size , 10 , replace = False )] = np . nan # \u521b\u5efa\u4e00\u4e2a KNNImputer \u5bf9\u8c61 knn_imputer\uff0c\u6307\u5b9a\u90bb\u5c45\u6570\u91cf\u4e3a 2 knn_imputer = impute . KNNImputer ( n_neighbors = 2 ) # # \u4f7f\u7528 knn_imputer \u5bf9\u77e9\u9635 X \u8fdb\u884c\u62df\u5408\u548c\u8f6c\u6362\uff0c\u7528 K-\u6700\u8fd1\u90bb\u65b9\u6cd5\u586b\u8865\u7f3a\u5931\u503c knn_imputer . fit_transform ( X ) \u5982\u56fe 11 \u6240\u793a\uff0c\u5b83\u586b\u5145\u4e86\u4e0a\u8ff0\u77e9\u9635\u3002 \u56fe 11\uff1aKNN\u4f30\u7b97\u7684\u6570\u503c \u53e6\u4e00\u79cd\u5f25\u8865\u5217\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\u662f\u8bad\u7ec3\u56de\u5f52\u6a21\u578b\uff0c\u8bd5\u56fe\u6839\u636e\u5176\u4ed6\u5217\u9884\u6d4b\u67d0\u5217\u7684\u7f3a\u5931\u503c\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u4ece\u6709\u7f3a\u5931\u503c\u7684\u4e00\u5217\u5f00\u59cb\uff0c\u5c06\u8fd9\u4e00\u5217\u4f5c\u4e3a\u65e0\u7f3a\u5931\u503c\u56de\u5f52\u6a21\u578b\u7684\u76ee\u6807\u5217\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u6240\u6709\u5176\u4ed6\u5217\uff0c\u5bf9\u76f8\u5173\u5217\u4e2d\u6ca1\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\uff0c\u7136\u540e\u5c1d\u8bd5\u9884\u6d4b\u4e4b\u524d\u5220\u9664\u7684\u6837\u672c\u7684\u76ee\u6807\u5217\uff08\u540c\u4e00\u5217\uff09\u3002\u8fd9\u6837\uff0c\u57fa\u4e8e\u6a21\u578b\u7684\u4f30\u7b97\u5c31\u4f1a\u66f4\u52a0\u7a33\u5065\u3002 \u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u5bf9\u4e8e\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6ca1\u6709\u5fc5\u8981\u8fdb\u884c\u6570\u503c\u5f52\u4e00\u5316\uff0c\u56e0\u4e3a\u5b83\u4eec\u53ef\u4ee5\u81ea\u884c\u5904\u7406\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u6240\u5c55\u793a\u7684\u53ea\u662f\u521b\u5efa\u4e00\u822c\u7279\u5f81\u7684\u4e00\u4e9b\u65b9\u6cd5\u3002\u73b0\u5728\uff0c\u5047\u8bbe\u60a8\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u9884\u6d4b\u4e0d\u540c\u5546\u54c1\uff08\u6bcf\u5468\u6216\u6bcf\u6708\uff09\u5546\u5e97\u9500\u552e\u989d\u7684\u95ee\u9898\u3002\u60a8\u6709\u5546\u54c1\uff0c\u4e5f\u6709\u5546\u5e97 ID\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u6bcf\u4e2a\u5546\u5e97\u7684\u5546\u54c1\u7b49\u7279\u5f81\u3002\u73b0\u5728\uff0c\u8fd9\u662f\u4e0a\u6587\u6ca1\u6709\u8ba8\u8bba\u7684\u7279\u5f81\u4e4b\u4e00\u3002\u8fd9\u7c7b\u7279\u5f81\u4e0d\u80fd\u4e00\u6982\u800c\u8bba\uff0c\u5b8c\u5168\u6765\u81ea\u4e8e\u9886\u57df\u3001\u6570\u636e\u548c\u4e1a\u52a1\u77e5\u8bc6\u3002\u67e5\u770b\u6570\u636e\uff0c\u627e\u51fa\u9002\u5408\u7684\u7279\u5f81\uff0c\u7136\u540e\u521b\u5efa\u76f8\u5e94\u7684\u7279\u5f81\u3002\u5982\u679c\u60a8\u4f7f\u7528\u7684\u662f\u903b\u8f91\u56de\u5f52\u7b49\u7ebf\u6027\u6a21\u578b\u6216 SVM \u7b49\u6a21\u578b\uff0c\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\u5bf9\u7279\u5f81\u8fdb\u884c\u7f29\u653e\u6216\u5f52\u4e00\u5316\u5904\u7406\u3002\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65e0\u9700\u5bf9\u7279\u5f81\u8fdb\u884c\u4efb\u4f55\u5f52\u4e00\u5316\u5904\u7406\u5373\u53ef\u6b63\u5e38\u5de5\u4f5c\u3002","title":"\u7279\u5f81\u5de5\u7a0b"},{"location":"%E7%89%B9%E5%BE%81%E5%B7%A5%E7%A8%8B/#_1","text":"\u7279\u5f81\u5de5\u7a0b\u662f\u6784\u5efa\u826f\u597d\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u90e8\u5206\u4e4b\u4e00\u3002\u5982\u679c\u6211\u4eec\u62e5\u6709\u6709\u7528\u7684\u7279\u5f81\uff0c\u6a21\u578b\u5c31\u4f1a\u8868\u73b0\u5f97\u66f4\u597d\u3002\u5728\u8bb8\u591a\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u907f\u514d\u4f7f\u7528\u5927\u578b\u590d\u6742\u6a21\u578b\uff0c\u800c\u4f7f\u7528\u5177\u6709\u5173\u952e\u5de5\u7a0b\u7279\u5f81\u7684\u7b80\u5355\u6a21\u578b\u3002\u6211\u4eec\u5fc5\u987b\u7262\u8bb0\uff0c\u53ea\u6709\u5f53\u4f60\u5bf9\u95ee\u9898\u7684\u9886\u57df\u6709\u4e00\u5b9a\u7684\u4e86\u89e3\uff0c\u5e76\u4e14\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u76f8\u5173\u6570\u636e\u65f6\uff0c\u624d\u80fd\u4ee5\u6700\u4f73\u65b9\u5f0f\u5b8c\u6210\u7279\u5f81\u5de5\u7a0b\u3002\u4e0d\u8fc7\uff0c\u60a8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u4e00\u4e9b\u901a\u7528\u6280\u672f\uff0c\u4ece\u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u7279\u5f81\u3002\u7279\u5f81\u5de5\u7a0b\u4e0d\u4ec5\u4ec5\u662f\u4ece\u6570\u636e\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\uff0c\u8fd8\u5305\u62ec\u4e0d\u540c\u7c7b\u578b\u7684\u5f52\u4e00\u5316\u548c\u8f6c\u6362\u3002 \u5728\u6709\u5173\u5206\u7c7b\u7279\u5f81\u7684\u7ae0\u8282\u4e2d\uff0c\u6211\u4eec\u5df2\u7ecf\u4e86\u89e3\u4e86\u7ed3\u5408\u4e0d\u540c\u5206\u7c7b\u53d8\u91cf\u7684\u65b9\u6cd5\u3001\u5982\u4f55\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u8ba1\u6570\u3001\u6807\u7b7e\u7f16\u7801\u548c\u4f7f\u7528\u5d4c\u5165\u3002\u8fd9\u4e9b\u51e0\u4e4e\u90fd\u662f\u5229\u7528\u5206\u7c7b\u53d8\u91cf\u8bbe\u8ba1\u7279\u5f81\u7684\u65b9\u6cd5\u3002\u56e0\u6b64\uff0c\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u7684\u91cd\u70b9\u5c06\u4ec5\u9650\u4e8e\u6570\u503c\u53d8\u91cf\u4ee5\u53ca\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u7684\u7ec4\u5408\u3002 \u8ba9\u6211\u4eec\u4ece\u6700\u7b80\u5355\u4f46\u5e94\u7528\u6700\u5e7f\u6cdb\u7684\u7279\u5f81\u5de5\u7a0b\u6280\u672f\u5f00\u59cb\u3002\u5047\u8bbe\u4f60\u6b63\u5728\u5904\u7406\u65e5\u671f\u548c\u65f6\u95f4\u6570\u636e\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709\u4e00\u4e2a\u5e26\u6709\u65e5\u671f\u7c7b\u578b\u5217\u7684 pandas \u6570\u636e\u5e27\u3002\u5229\u7528\u8fd9\u4e00\u5217\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4ee5\u4e0b\u7279\u5f81\uff1a \u5e74 \u5e74\u4e2d\u7684\u5468 \u6708 \u661f\u671f \u5468\u672b \u5c0f\u65f6 \u8fd8\u6709\u66f4\u591a \u800c\u4f7f\u7528pandas\u5c31\u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u6dfb\u52a0'year'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5e74\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'year' ] = df [ 'datetime_column' ] . dt . year # \u6dfb\u52a0'weekofyear'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5468\u6570\u63d0\u53d6\u51fa\u6765 df . loc [:, 'weekofyear' ] = df [ 'datetime_column' ] . dt . weekofyear # \u6dfb\u52a0'month'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u6708\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'month' ] = df [ 'datetime_column' ] . dt . month # \u6dfb\u52a0'dayofweek'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u661f\u671f\u51e0\u63d0\u53d6\u51fa\u6765 df . loc [:, 'dayofweek' ] = df [ 'datetime_column' ] . dt . dayofweek # \u6dfb\u52a0'weekend'\u5217\uff0c\u5224\u65ad\u5f53\u5929\u662f\u5426\u4e3a\u5468\u672b df . loc [:, 'weekend' ] = ( df . datetime_column . dt . weekday >= 5 ) . astype ( int ) # \u6dfb\u52a0 'hour' \u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5c0f\u65f6\u63d0\u53d6\u51fa\u6765 df . loc [:, 'hour' ] = df [ 'datetime_column' ] . dt . hour \u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u65e5\u671f\u65f6\u95f4\u5217\u521b\u5efa\u4e00\u7cfb\u5217\u65b0\u5217\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u53ef\u4ee5\u521b\u5efa\u7684\u4e00\u4e9b\u793a\u4f8b\u529f\u80fd\u3002 import pandas as pd # \u521b\u5efa\u65e5\u671f\u65f6\u95f4\u5e8f\u5217\uff0c\u5305\u542b\u4e86\u4ece '2020-01-06' \u5230 '2020-01-10' \u7684\u65e5\u671f\u65f6\u95f4\u70b9\uff0c\u65f6\u95f4\u95f4\u9694\u4e3a10\u5c0f\u65f6 s = pd . date_range ( '2020-01-06' , '2020-01-10' , freq = '10H' ) . to_series () # \u63d0\u53d6\u5bf9\u5e94\u65f6\u95f4\u7279\u5f81 features = { \"dayofweek\" : s . dt . dayofweek . values , \"dayofyear\" : s . dt . dayofyear . values , \"hour\" : s . dt . hour . values , \"is_leap_year\" : s . dt . is_leap_year . values , \"quarter\" : s . dt . quarter . values , \"weekofyear\" : s . dt . weekofyear . values } \u8fd9\u5c06\u4ece\u7ed9\u5b9a\u7cfb\u5217\u4e2d\u751f\u6210\u4e00\u4e2a\u7279\u5f81\u5b57\u5178\u3002\u60a8\u53ef\u4ee5\u5c06\u6b64\u5e94\u7528\u4e8e pandas \u6570\u636e\u4e2d\u7684\u4efb\u4f55\u65e5\u671f\u65f6\u95f4\u5217\u3002\u8fd9\u4e9b\u662f pandas \u63d0\u4f9b\u7684\u4f17\u591a\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u4e2d\u7684\u4e00\u90e8\u5206\u3002\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u975e\u5e38\u91cd\u8981\uff0c\u4f8b\u5982\uff0c\u5728\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97\u7684\u9500\u552e\u989d\u65f6\uff0c\u5982\u679c\u60f3\u5728\u805a\u5408\u7279\u5f81\u4e0a\u4f7f\u7528 xgboost \u7b49\u6a21\u578b\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u5c31\u975e\u5e38\u91cd\u8981\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u4e0b\u6240\u793a\u7684\u6570\u636e\uff1a \u56fe 1\uff1a\u5305\u542b\u5206\u7c7b\u548c\u65e5\u671f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e \u5728\u56fe 1 \u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u6709\u4e00\u4e2a\u65e5\u671f\u5217\uff0c\u4ece\u4e2d\u53ef\u4ee5\u8f7b\u677e\u63d0\u53d6\u5e74\u3001\u6708\u3001\u5b63\u5ea6\u7b49\u7279\u5f81\u3002\u7136\u540e\uff0c\u6211\u4eec\u6709\u4e00\u4e2a customer_id \u5217\uff0c\u8be5\u5217\u6709\u591a\u4e2a\u6761\u76ee\uff0c\u56e0\u6b64\u4e00\u4e2a\u5ba2\u6237\u4f1a\u88ab\u770b\u5230\u5f88\u591a\u6b21\uff08\u622a\u56fe\u4e2d\u770b\u4e0d\u5230\uff09\u3002\u6bcf\u4e2a\u65e5\u671f\u548c\u5ba2\u6237 ID \u90fd\u6709\u4e09\u4e2a\u5206\u7c7b\u7279\u5f81\u548c\u4e00\u4e2a\u6570\u5b57\u7279\u5f81\u3002\u6211\u4eec\u53ef\u4ee5\u4ece\u4e2d\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff1a - \u5ba2\u6237\u6700\u6d3b\u8dc3\u7684\u6708\u4efd\u662f\u51e0\u6708 - \u67d0\u4e2a\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u7684\u8ba1\u6570\u662f\u591a\u5c11 - \u67d0\u5e74\u67d0\u6708\u67d0\u5468\u67d0\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u6570\u91cf\u662f\u591a\u5c11\uff1f - \u67d0\u4e2a\u5ba2\u6237\u7684 num1 \u5e73\u5747\u503c\u662f\u591a\u5c11\uff1f - \u7b49\u7b49\u3002 \u4f7f\u7528 pandas \u4e2d\u7684\u805a\u5408\uff0c\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u521b\u5efa\u7c7b\u4f3c\u7684\u529f\u80fd\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u5b9e\u73b0\u3002 def generate_features ( df ): df . loc [:, 'year' ] = df [ 'date' ] . dt . year df . loc [:, 'weekofyear' ] = df [ 'date' ] . dt . weekofyear df . loc [:, 'month' ] = df [ 'date' ] . dt . month df . loc [:, 'dayofweek' ] = df [ 'date' ] . dt . dayofweek df . loc [:, 'weekend' ] = ( df [ 'date' ] . dt . weekday >= 5 ) . astype ( int ) aggs = {} # \u5bf9 'month' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'month' ] = [ 'nunique' , 'mean' ] # \u5bf9 'weekofyear' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'weekofyear' ] = [ 'nunique' , 'mean' ] # \u5bf9 'num1' \u5217\u8fdb\u884c sum\u3001max\u3001min\u3001mean \u805a\u5408 aggs [ 'num1' ] = [ 'sum' , 'max' , 'min' , 'mean' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c size \u805a\u5408 aggs [ 'customer_id' ] = [ 'size' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c nunique \u805a\u5408 aggs [ 'customer_id' ] = [ 'nunique' ] # \u5bf9\u6570\u636e\u5e94\u7528\u4e0d\u540c\u7684\u805a\u5408\u51fd\u6570 agg_df = df . groupby ( 'customer_id' ) . agg ( aggs ) # \u91cd\u7f6e\u7d22\u5f15 agg_df = agg_df . reset_index () return agg_df \u8bf7\u6ce8\u610f\uff0c\u5728\u4e0a\u8ff0\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u8df3\u8fc7\u4e86\u5206\u7c7b\u53d8\u91cf\uff0c\u4f46\u60a8\u53ef\u4ee5\u50cf\u4f7f\u7528\u5176\u4ed6\u805a\u5408\u53d8\u91cf\u4e00\u6837\u4f7f\u7528\u5b83\u4eec\u3002 \u56fe 2\uff1a\u603b\u4f53\u7279\u5f81\u548c\u5176\u4ed6\u7279\u5f81 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u56fe 2 \u4e2d\u7684\u6570\u636e\u4e0e\u5e26\u6709 customer_id \u5217\u7684\u539f\u59cb\u6570\u636e\u5e27\u8fde\u63a5\u8d77\u6765\uff0c\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5e76\u4e0d\u662f\u8981\u9884\u6d4b\u4ec0\u4e48\uff1b\u6211\u4eec\u53ea\u662f\u5728\u521b\u5efa\u901a\u7528\u7279\u5f81\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u8bd5\u56fe\u5728\u8fd9\u91cc\u9884\u6d4b\u4ec0\u4e48\uff0c\u521b\u5efa\u7279\u5f81\u4f1a\u66f4\u5bb9\u6613\u3002 \u4f8b\u5982\uff0c\u6709\u65f6\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u95ee\u9898\u65f6\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u7684\u7279\u5f81\u4e0d\u662f\u5355\u4e2a\u503c\uff0c\u800c\u662f\u4e00\u7cfb\u5217\u503c\u3002 \u4f8b\u5982\uff0c\u5ba2\u6237\u5728\u7279\u5b9a\u65f6\u95f4\u6bb5\u5185\u7684\u4ea4\u6613\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4f1a\u521b\u5efa\u4e0d\u540c\u7c7b\u578b\u7684\u7279\u5f81\uff0c\u4f8b\u5982\uff1a\u4f7f\u7528\u6570\u503c\u7279\u5f81\u65f6\uff0c\u5728\u5bf9\u5206\u7c7b\u5217\u8fdb\u884c\u5206\u7ec4\u65f6\uff0c\u4f1a\u5f97\u5230\u7c7b\u4f3c\u4e8e\u65f6\u95f4\u5206\u5e03\u503c\u5217\u8868\u7684\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u7cfb\u5217\u7edf\u8ba1\u7279\u5f81\uff0c\u4f8b\u5982 \u5e73\u5747\u503c \u6700\u5927\u503c \u6700\u5c0f\u503c \u72ec\u7279\u6027 \u504f\u659c \u5cf0\u5ea6 Kstat \u767e\u5206\u4f4d\u6570 \u5b9a\u91cf \u5cf0\u503c\u5230\u5cf0\u503c \u4ee5\u53ca\u66f4\u591a \u8fd9\u4e9b\u53ef\u4ee5\u4f7f\u7528\u7b80\u5355\u7684 numpy \u51fd\u6570\u521b\u5efa\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 import numpy as np # \u521b\u5efa\u5b57\u5178\uff0c\u7528\u4e8e\u5b58\u50a8\u4e0d\u540c\u7684\u7edf\u8ba1\u7279\u5f81 feature_dict = {} # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5e73\u5747\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'mean' \u952e\u4e0b feature_dict [ 'mean' ] = np . mean ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5927\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'max' \u952e\u4e0b feature_dict [ 'max' ] = np . max ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5c0f\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'min' \u952e\u4e0b feature_dict [ 'min' ] = np . min ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6807\u51c6\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'std' \u952e\u4e0b feature_dict [ 'std' ] = np . std ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u65b9\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'var' \u952e\u4e0b feature_dict [ 'var' ] = np . var ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5dee\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'ptp' \u952e\u4e0b feature_dict [ 'ptp' ] = np . ptp ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c10\u767e\u5206\u4f4d\u6570\uff08\u5373\u767e\u5206\u4e4b10\u5206\u4f4d\u6570\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_10' \u952e\u4e0b feature_dict [ 'percentile_10' ] = np . percentile ( x , 10 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c60\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_60' \u952e\u4e0b feature_dict [ 'percentile_60' ] = np . percentile ( x , 60 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c90\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_90' \u952e\u4e0b feature_dict [ 'percentile_90' ] = np . percentile ( x , 90 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u76845%\u5206\u4f4d\u6570\uff08\u53730.05\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_5' \u952e\u4e0b feature_dict [ 'quantile_5' ] = np . quantile ( x , 0.05 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768495%\u5206\u4f4d\u6570\uff08\u53730.95\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_95' \u952e\u4e0b feature_dict [ 'quantile_95' ] = np . quantile ( x , 0.95 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768499%\u5206\u4f4d\u6570\uff08\u53730.99\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_99' \u952e\u4e0b feature_dict [ 'quantile_99' ] = np . quantile ( x , 0.99 ) \u65f6\u95f4\u5e8f\u5217\u6570\u636e\uff08\u6570\u503c\u5217\u8868\uff09\u53ef\u4ee5\u8f6c\u6362\u6210\u8bb8\u591a\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4e00\u4e2a\u540d\u4e3a tsfresh \u7684 python \u5e93\u975e\u5e38\u6709\u7528\u3002 from tsfresh.feature_extraction import feature_calculators as fc # \u8ba1\u7b97 x \u6570\u5217\u7684\u7edd\u5bf9\u80fd\u91cf\uff08abs_energy\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'abs_energy' \u952e\u4e0b feature_dict [ 'abs_energy' ] = fc . abs_energy ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u9ad8\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_above_mean' \u952e\u4e0b feature_dict [ 'count_above_mean' ] = fc . count_above_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u4f4e\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_below_mean' \u952e\u4e0b feature_dict [ 'count_below_mean' ] = fc . count_below_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u7edd\u5bf9\u53d8\u5316\uff08mean_abs_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_abs_change' \u952e\u4e0b feature_dict [ 'mean_abs_change' ] = fc . mean_abs_change ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u53d8\u5316\u7387\uff08mean_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_change' \u952e\u4e0b feature_dict [ 'mean_change' ] = fc . mean_change ( x ) \u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\uff1btsfresh \u63d0\u4f9b\u4e86\u6570\u767e\u79cd\u7279\u5f81\u548c\u6570\u5341\u79cd\u4e0d\u540c\u7279\u5f81\u7684\u53d8\u4f53\uff0c\u4f60\u53ef\u4ee5\u5c06\u5b83\u4eec\u7528\u4e8e\u57fa\u4e8e\u65f6\u95f4\u5e8f\u5217\uff08\u503c\u5217\u8868\uff09\u7684\u7279\u5f81\u3002\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0cx \u662f\u4e00\u4e2a\u503c\u5217\u8868\u3002\u4f46\u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\u3002\u60a8\u8fd8\u53ef\u4ee5\u4e3a\u5305\u542b\u6216\u4e0d\u5305\u542b\u5206\u7c7b\u6570\u636e\u7684\u6570\u503c\u6570\u636e\u521b\u5efa\u8bb8\u591a\u5176\u4ed6\u7279\u5f81\u3002\u751f\u6210\u8bb8\u591a\u7279\u5f81\u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u521b\u5efa\u4e00\u5806\u591a\u9879\u5f0f\u7279\u5f81\u3002\u4f8b\u5982\uff0c\u4ece\u4e24\u4e2a\u7279\u5f81 \"a \"\u548c \"b \"\u751f\u6210\u7684\u4e8c\u7ea7\u591a\u9879\u5f0f\u7279\u5f81\u5305\u62ec \"a\"\u3001\"b\"\u3001\"ab\"\u3001\"a^2 \"\u548c \"b^2\"\u3002 import numpy as np df = pd . DataFrame ( np . random . rand ( 100 , 2 ), columns = [ f \"f_ { i } \" for i in range ( 1 , 3 )]) \u5982\u56fe 3 \u6240\u793a\uff0c\u5b83\u7ed9\u51fa\u4e86\u4e00\u4e2a\u6570\u636e\u8868\u3002 \u56fe 3\uff1a\u5305\u542b\u4e24\u4e2a\u6570\u5b57\u7279\u5f81\u7684\u968f\u673a\u6570\u636e\u8868 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u7684 PolynomialFeatures \u521b\u5efa\u4e24\u6b21\u591a\u9879\u5f0f\u7279\u5f81\u3002 from sklearn import preprocessing # \u6307\u5b9a\u591a\u9879\u5f0f\u7684\u6b21\u6570\u4e3a 2\uff0c\u4e0d\u4ec5\u8003\u8651\u4ea4\u4e92\u9879\uff0c\u4e0d\u5305\u62ec\u504f\u5dee\uff08include_bias=False\uff09 pf = preprocessing . PolynomialFeatures ( degree = 2 , interaction_only = False , include_bias = False ) # \u62df\u5408\uff0c\u521b\u5efa\u591a\u9879\u5f0f\u7279\u5f81 pf . fit ( df ) # \u8f6c\u6362\u6570\u636e poly_feats = pf . transform ( df ) # \u83b7\u53d6\u751f\u6210\u7684\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf num_feats = poly_feats . shape [ 1 ] # \u4e3a\u65b0\u751f\u6210\u7684\u7279\u5f81\u547d\u540d df_transformed = pd . DataFrame ( poly_feats , columns = [ f \"f_ { i } \" for i in range ( 1 , num_feats + 1 )] ) \u8fd9\u6837\u5c31\u5f97\u5230\u4e86\u4e00\u4e2a\u6570\u636e\u8868\uff0c\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1a\u5e26\u6709\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e\u8868 \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e9b\u591a\u9879\u5f0f\u7279\u5f81\u3002\u5982\u679c\u521b\u5efa\u7684\u662f\u4e09\u6b21\u591a\u9879\u5f0f\u7279\u5f81\uff0c\u6700\u7ec8\u603b\u5171\u4f1a\u6709\u4e5d\u4e2a\u7279\u5f81\u3002\u7279\u5f81\u7684\u6570\u91cf\u8d8a\u591a\uff0c\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf\u4e5f\u5c31\u8d8a\u591a\uff0c\u800c\u4e14\u4f60\u8fd8\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5982\u679c\u6570\u636e\u96c6\u4e2d\u6709\u5f88\u591a\u6837\u672c\uff0c\u90a3\u4e48\u521b\u5efa\u8fd9\u7c7b\u7279\u5f81\u5c31\u9700\u8981\u82b1\u8d39\u4e00\u4e9b\u65f6\u95f4\u3002 \u56fe 5\uff1a\u6570\u5b57\u7279\u5f81\u5217\u7684\u76f4\u65b9\u56fe \u53e6\u4e00\u4e2a\u6709\u8da3\u7684\u529f\u80fd\u662f\u5c06\u6570\u5b57\u8f6c\u6362\u4e3a\u7c7b\u522b\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \u5206\u7bb1 \u3002\u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u56fe 5\uff0c\u5b83\u663e\u793a\u4e86\u4e00\u4e2a\u968f\u673a\u6570\u5b57\u7279\u5f81\u7684\u6837\u672c\u76f4\u65b9\u56fe\u3002\u6211\u4eec\u5728\u8be5\u56fe\u4e2d\u4f7f\u7528\u4e8610\u4e2a\u5206\u7bb1\uff0c\u53ef\u4ee5\u770b\u5230\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a10\u4e2a\u90e8\u5206\u3002\u8fd9\u53ef\u4ee5\u4f7f\u7528 pandas \u7684cat\u51fd\u6570\u6765\u5b9e\u73b0\u3002 # \u521b\u5efa10\u4e2a\u5206\u7bb1 df [ \"f_bin_10\" ] = pd . cut ( df [ \"f_1\" ], bins = 10 , labels = False ) # \u521b\u5efa100\u4e2a\u5206\u7bb1 df [ \"f_bin_100\" ] = pd . cut ( df [ \"f_1\" ], bins = 100 , labels = False ) \u5982\u56fe 6 \u6240\u793a\uff0c\u8fd9\u5c06\u5728\u6570\u636e\u5e27\u4e2d\u751f\u6210\u4e24\u4e2a\u65b0\u7279\u5f81\u3002 \u56fe 6\uff1a\u6570\u503c\u7279\u5f81\u5206\u7bb1 \u5f53\u4f60\u8fdb\u884c\u5206\u7c7b\u65f6\uff0c\u53ef\u4ee5\u540c\u65f6\u4f7f\u7528\u5206\u7bb1\u548c\u539f\u59cb\u7279\u5f81\u3002\u6211\u4eec\u5c06\u5728\u672c\u7ae0\u540e\u534a\u90e8\u5206\u5b66\u4e60\u66f4\u591a\u5173\u4e8e\u9009\u62e9\u7279\u5f81\u7684\u77e5\u8bc6\u3002\u5206\u7bb1\u8fd8\u53ef\u4ee5\u5c06\u6570\u5b57\u7279\u5f81\u89c6\u4e3a\u5206\u7c7b\u7279\u5f81\u3002 \u53e6\u4e00\u79cd\u53ef\u4ee5\u4ece\u6570\u503c\u7279\u5f81\u4e2d\u521b\u5efa\u7684\u6709\u8da3\u7279\u5f81\u7c7b\u578b\u662f\u5bf9\u6570\u53d8\u6362\u3002\u8bf7\u770b\u56fe 7 \u4e2d\u7684\u7279\u5f81 f_3\u3002 \u4e0e\u5176\u4ed6\u65b9\u5dee\u8f83\u5c0f\u7684\u7279\u5f81\u76f8\u6bd4\uff08\u5047\u8bbe\u5982\u6b64\uff09\uff0cf_3 \u662f\u4e00\u79cd\u65b9\u5dee\u975e\u5e38\u5927\u7684\u7279\u6b8a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5e0c\u671b\u964d\u4f4e\u8fd9\u4e00\u5217\u7684\u65b9\u5dee\uff0c\u8fd9\u53ef\u4ee5\u901a\u8fc7\u5bf9\u6570\u53d8\u6362\u6765\u5b9e\u73b0\u3002 f_3 \u5217\u7684\u503c\u8303\u56f4\u4e3a 0 \u5230 10000\uff0c\u76f4\u65b9\u56fe\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u7279\u5f81 f_3 \u7684\u76f4\u65b9\u56fe \u6211\u4eec\u53ef\u4ee5\u5bf9\u8fd9\u4e00\u5217\u5e94\u7528 log(1 + x) \u6765\u51cf\u5c11\u5176\u65b9\u5dee\u3002\u56fe 9 \u663e\u793a\u4e86\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u76f4\u65b9\u56fe\u7684\u53d8\u5316\u3002 \u56fe 9\uff1a\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u7684 f_3 \u76f4\u65b9\u56fe \u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e0d\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u548c\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u7684\u65b9\u5dee\u3002 In [ X ]: df . f_3 . var () Out [ X ]: 8077265.875858586 In [ X ]: df . f_3 . apply ( lambda x : np . log ( 1 + x )) . var () Out [ X ]: 0.6058771732119975 \u6709\u65f6\uff0c\u4e5f\u53ef\u4ee5\u7528\u6307\u6570\u6765\u4ee3\u66ff\u5bf9\u6570\u3002\u4e00\u79cd\u975e\u5e38\u6709\u8da3\u7684\u60c5\u51b5\u662f\uff0c\u60a8\u4f7f\u7528\u57fa\u4e8e\u5bf9\u6570\u7684\u8bc4\u4f30\u6307\u6807\uff0c\u4f8b\u5982 RMSLE\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u5728\u5bf9\u6570\u53d8\u6362\u7684\u76ee\u6807\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u7136\u540e\u5728\u9884\u6d4b\u65f6\u4f7f\u7528\u6307\u6570\u503c\u8f6c\u6362\u56de\u539f\u59cb\u503c\u3002\u8fd9\u5c06\u6709\u52a9\u4e8e\u9488\u5bf9\u6307\u6807\u4f18\u5316\u6a21\u578b\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u7c7b\u6570\u5b57\u7279\u5f81\u90fd\u662f\u57fa\u4e8e\u76f4\u89c9\u521b\u5efa\u7684\u3002\u6ca1\u6709\u516c\u5f0f\u53ef\u5faa\u3002\u5982\u679c\u60a8\u4ece\u4e8b\u7684\u662f\u67d0\u4e00\u884c\u4e1a\uff0c\u60a8\u5c06\u521b\u5efa\u7279\u5b9a\u884c\u4e1a\u7684\u7279\u5f81\u3002 \u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u548c\u6570\u503c\u53d8\u91cf\u65f6\uff0c\u53ef\u80fd\u4f1a\u9047\u5230\u7f3a\u5931\u503c\u3002\u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u4e00\u4e9b\u5904\u7406\u5206\u7c7b\u7279\u5f81\u4e2d\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\uff0c\u4f46\u8fd8\u6709\u66f4\u591a\u65b9\u6cd5\u53ef\u4ee5\u5904\u7406\u7f3a\u5931\u503c/NaN \u503c\u3002\u8fd9\u4e5f\u88ab\u89c6\u4e3a\u7279\u5f81\u5de5\u7a0b\u3002 \u5982\u679c\u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u9047\u5230\u7f3a\u5931\u503c\uff0c\u5c31\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\uff01\u8fd9\u6837\u505a\u867d\u7136\u7b80\u5355\uff0c\u4f46\uff08\u51e0\u4e4e\uff09\u603b\u662f\u6709\u6548\u7684\uff01 \u5728\u6570\u503c\u6570\u636e\u4e2d\u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u9009\u62e9\u4e00\u4e2a\u5728\u7279\u5b9a\u7279\u5f81\u4e2d\u6ca1\u6709\u51fa\u73b0\u7684\u503c\uff0c\u7136\u540e\u7528\u5b83\u6765\u586b\u8865\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u7279\u5f81\u4e2d\u6ca1\u6709 0\u3002\u8fd9\u662f\u5176\u4e2d\u4e00\u79cd\u65b9\u6cd5\uff0c\u4f46\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\u3002\u5bf9\u4e8e\u6570\u503c\u6570\u636e\u6765\u8bf4\uff0c\u6bd4\u586b\u5145 0 \u66f4\u6709\u6548\u7684\u65b9\u6cd5\u4e4b\u4e00\u662f\u4f7f\u7528\u5e73\u5747\u503c\u8fdb\u884c\u586b\u5145\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8be5\u7279\u5f81\u6240\u6709\u503c\u7684\u4e2d\u4f4d\u6570\u6765\u586b\u5145\uff0c\u6216\u8005\u4f7f\u7528\u6700\u5e38\u89c1\u7684\u503c\u6765\u586b\u5145\u7f3a\u5931\u503c\u3002\u8fd9\u6837\u505a\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002 \u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u9ad8\u7ea7\u65b9\u6cd5\u662f\u4f7f\u7528 K \u8fd1\u90bb\u6cd5 \u3002 \u60a8\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\uff0c\u7136\u540e\u5229\u7528\u67d0\u79cd\u8ddd\u79bb\u5ea6\u91cf\uff08\u4f8b\u5982\u6b27\u6c0f\u8ddd\u79bb\uff09\u627e\u5230\u6700\u8fd1\u7684\u90bb\u5c45\u3002\u7136\u540e\u53d6\u6240\u6709\u8fd1\u90bb\u7684\u5e73\u5747\u503c\u6765\u586b\u8865\u7f3a\u5931\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 KNN \u6765\u586b\u8865\u8fd9\u6837\u7684\u7f3a\u5931\u503c\u3002 \u56fe 10\uff1a\u6709\u7f3a\u5931\u503c\u7684\u4e8c\u7ef4\u6570\u7ec4 \u8ba9\u6211\u4eec\u770b\u770b KNN \u662f\u5982\u4f55\u5904\u7406\u56fe 10 \u6240\u793a\u7684\u7f3a\u5931\u503c\u77e9\u9635\u7684\u3002 import numpy as np from sklearn import impute # \u751f\u6210\u7ef4\u5ea6\u4e3a (10, 6) \u7684\u968f\u673a\u6574\u6570\u77e9\u9635 X\uff0c\u6570\u503c\u8303\u56f4\u5728 1 \u5230 14 \u4e4b\u95f4 X = np . random . randint ( 1 , 15 , ( 10 , 6 )) # \u6570\u636e\u7c7b\u578b\u8f6c\u6362\u4e3a float X = X . astype ( float ) # \u5728\u77e9\u9635 X \u4e2d\u968f\u673a\u9009\u62e9 10 \u4e2a\u4f4d\u7f6e\uff0c\u5c06\u8fd9\u4e9b\u4f4d\u7f6e\u7684\u5143\u7d20\u8bbe\u7f6e\u4e3a NaN\uff08\u7f3a\u5931\u503c\uff09 X . ravel ()[ np . random . choice ( X . size , 10 , replace = False )] = np . nan # \u521b\u5efa\u4e00\u4e2a KNNImputer \u5bf9\u8c61 knn_imputer\uff0c\u6307\u5b9a\u90bb\u5c45\u6570\u91cf\u4e3a 2 knn_imputer = impute . KNNImputer ( n_neighbors = 2 ) # # \u4f7f\u7528 knn_imputer \u5bf9\u77e9\u9635 X \u8fdb\u884c\u62df\u5408\u548c\u8f6c\u6362\uff0c\u7528 K-\u6700\u8fd1\u90bb\u65b9\u6cd5\u586b\u8865\u7f3a\u5931\u503c knn_imputer . fit_transform ( X ) \u5982\u56fe 11 \u6240\u793a\uff0c\u5b83\u586b\u5145\u4e86\u4e0a\u8ff0\u77e9\u9635\u3002 \u56fe 11\uff1aKNN\u4f30\u7b97\u7684\u6570\u503c \u53e6\u4e00\u79cd\u5f25\u8865\u5217\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\u662f\u8bad\u7ec3\u56de\u5f52\u6a21\u578b\uff0c\u8bd5\u56fe\u6839\u636e\u5176\u4ed6\u5217\u9884\u6d4b\u67d0\u5217\u7684\u7f3a\u5931\u503c\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u4ece\u6709\u7f3a\u5931\u503c\u7684\u4e00\u5217\u5f00\u59cb\uff0c\u5c06\u8fd9\u4e00\u5217\u4f5c\u4e3a\u65e0\u7f3a\u5931\u503c\u56de\u5f52\u6a21\u578b\u7684\u76ee\u6807\u5217\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u6240\u6709\u5176\u4ed6\u5217\uff0c\u5bf9\u76f8\u5173\u5217\u4e2d\u6ca1\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\uff0c\u7136\u540e\u5c1d\u8bd5\u9884\u6d4b\u4e4b\u524d\u5220\u9664\u7684\u6837\u672c\u7684\u76ee\u6807\u5217\uff08\u540c\u4e00\u5217\uff09\u3002\u8fd9\u6837\uff0c\u57fa\u4e8e\u6a21\u578b\u7684\u4f30\u7b97\u5c31\u4f1a\u66f4\u52a0\u7a33\u5065\u3002 \u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u5bf9\u4e8e\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6ca1\u6709\u5fc5\u8981\u8fdb\u884c\u6570\u503c\u5f52\u4e00\u5316\uff0c\u56e0\u4e3a\u5b83\u4eec\u53ef\u4ee5\u81ea\u884c\u5904\u7406\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u6240\u5c55\u793a\u7684\u53ea\u662f\u521b\u5efa\u4e00\u822c\u7279\u5f81\u7684\u4e00\u4e9b\u65b9\u6cd5\u3002\u73b0\u5728\uff0c\u5047\u8bbe\u60a8\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u9884\u6d4b\u4e0d\u540c\u5546\u54c1\uff08\u6bcf\u5468\u6216\u6bcf\u6708\uff09\u5546\u5e97\u9500\u552e\u989d\u7684\u95ee\u9898\u3002\u60a8\u6709\u5546\u54c1\uff0c\u4e5f\u6709\u5546\u5e97 ID\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u6bcf\u4e2a\u5546\u5e97\u7684\u5546\u54c1\u7b49\u7279\u5f81\u3002\u73b0\u5728\uff0c\u8fd9\u662f\u4e0a\u6587\u6ca1\u6709\u8ba8\u8bba\u7684\u7279\u5f81\u4e4b\u4e00\u3002\u8fd9\u7c7b\u7279\u5f81\u4e0d\u80fd\u4e00\u6982\u800c\u8bba\uff0c\u5b8c\u5168\u6765\u81ea\u4e8e\u9886\u57df\u3001\u6570\u636e\u548c\u4e1a\u52a1\u77e5\u8bc6\u3002\u67e5\u770b\u6570\u636e\uff0c\u627e\u51fa\u9002\u5408\u7684\u7279\u5f81\uff0c\u7136\u540e\u521b\u5efa\u76f8\u5e94\u7684\u7279\u5f81\u3002\u5982\u679c\u60a8\u4f7f\u7528\u7684\u662f\u903b\u8f91\u56de\u5f52\u7b49\u7ebf\u6027\u6a21\u578b\u6216 SVM \u7b49\u6a21\u578b\uff0c\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\u5bf9\u7279\u5f81\u8fdb\u884c\u7f29\u653e\u6216\u5f52\u4e00\u5316\u5904\u7406\u3002\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65e0\u9700\u5bf9\u7279\u5f81\u8fdb\u884c\u4efb\u4f55\u5f52\u4e00\u5316\u5904\u7406\u5373\u53ef\u6b63\u5e38\u5de5\u4f5c\u3002","title":"\u7279\u5f81\u5de5\u7a0b"},{"location":"%E7%89%B9%E5%BE%81%E9%80%89%E6%8B%A9/","text":"\u7279\u5f81\u9009\u62e9 \u5f53\u4f60\u521b\u5efa\u4e86\u6210\u5343\u4e0a\u4e07\u4e2a\u7279\u5f81\u540e\uff0c\u5c31\u8be5\u4ece\u4e2d\u6311\u9009\u51fa\u51e0\u4e2a\u4e86\u3002\u4f46\u662f\uff0c\u6211\u4eec\u7edd\u4e0d\u5e94\u8be5\u521b\u5efa\u6210\u767e\u4e0a\u5343\u4e2a\u65e0\u7528\u7684\u7279\u5f81\u3002\u7279\u5f81\u8fc7\u591a\u4f1a\u5e26\u6765\u4e00\u4e2a\u4f17\u6240\u5468\u77e5\u7684\u95ee\u9898\uff0c\u5373 \"\u7ef4\u5ea6\u8bc5\u5492\"\u3002\u5982\u679c\u4f60\u6709\u5f88\u591a\u7279\u5f81\uff0c\u4f60\u4e5f\u5fc5\u987b\u6709\u5f88\u591a\u8bad\u7ec3\u6837\u672c\u6765\u6355\u6349\u6240\u6709\u7279\u5f81\u3002\u4ec0\u4e48\u662f \"\u5927\u91cf \"\u5e76\u6ca1\u6709\u6b63\u786e\u7684\u5b9a\u4e49\uff0c\u8fd9\u9700\u8981\u60a8\u901a\u8fc7\u6b63\u786e\u9a8c\u8bc1\u60a8\u7684\u6a21\u578b\u548c\u68c0\u67e5\u8bad\u7ec3\u6a21\u578b\u6240\u9700\u7684\u65f6\u95f4\u6765\u786e\u5b9a\u3002 \u9009\u62e9\u7279\u5f81\u7684\u6700\u7b80\u5355\u65b9\u6cd5\u662f \u5220\u9664\u65b9\u5dee\u975e\u5e38\u5c0f\u7684\u7279\u5f81 \u3002\u5982\u679c\u7279\u5f81\u7684\u65b9\u5dee\u975e\u5e38\u5c0f\uff08\u5373\u975e\u5e38\u63a5\u8fd1\u4e8e 0\uff09\uff0c\u5b83\u4eec\u5c31\u63a5\u8fd1\u4e8e\u5e38\u91cf\uff0c\u56e0\u6b64\u6839\u672c\u4e0d\u4f1a\u7ed9\u4efb\u4f55\u6a21\u578b\u589e\u52a0\u4efb\u4f55\u4ef7\u503c\u3002\u6700\u597d\u7684\u529e\u6cd5\u5c31\u662f\u53bb\u6389\u5b83\u4eec\uff0c\u4ece\u800c\u964d\u4f4e\u590d\u6742\u5ea6\u3002\u8bf7\u6ce8\u610f\uff0c\u65b9\u5dee\u4e5f\u53d6\u51b3\u4e8e\u6570\u636e\u7684\u7f29\u653e\u3002 Scikit-learn \u7684 VarianceThreshold \u5b9e\u73b0\u4e86\u8fd9\u4e00\u70b9\u3002 from sklearn.feature_selection import VarianceThreshold data = ... # \u521b\u5efa VarianceThreshold \u5bf9\u8c61 var_thresh\uff0c\u6307\u5b9a\u65b9\u5dee\u9608\u503c\u4e3a 0.1 var_thresh = VarianceThreshold ( threshold = 0.1 ) # \u4f7f\u7528 var_thresh \u5bf9\u6570\u636e data \u8fdb\u884c\u62df\u5408\u548c\u53d8\u6362\uff0c\u5c06\u65b9\u5dee\u4f4e\u4e8e\u9608\u503c\u7684\u7279\u5f81\u79fb\u9664 transformed_data = var_thresh . fit_transform ( data ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u5220\u9664\u76f8\u5173\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u3002\u8981\u8ba1\u7b97\u4e0d\u540c\u6570\u5b57\u7279\u5f81\u4e4b\u95f4\u7684\u76f8\u5173\u6027\uff0c\u53ef\u4ee5\u4f7f\u7528\u76ae\u5c14\u900a\u76f8\u5173\u6027\u3002 import pandas as pd from sklearn.datasets import fetch_california_housing # \u52a0\u8f7d\u6570\u636e data = fetch_california_housing () # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u77e9\u9635 X X = data [ \"data\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u7684\u5217\u540d col_names = data [ \"feature_names\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf y y = data [ \"target\" ] df = pd . DataFrame ( X , columns = col_names ) # \u6dfb\u52a0 MedInc_Sqrt \u5217\uff0c\u662f MedInc \u5217\u4e2d\u6bcf\u4e2a\u5143\u7d20\u8fdb\u884c\u5e73\u65b9\u6839\u8fd0\u7b97\u7684\u7ed3\u679c df . loc [:, \"MedInc_Sqrt\" ] = df . MedInc . apply ( np . sqrt ) # \u8ba1\u7b97\u76ae\u5c14\u900a\u76f8\u5173\u6027\u77e9\u9635 df . corr () \u5f97\u51fa\u76f8\u5173\u77e9\u9635\uff0c\u5982\u56fe 1 \u6240\u793a\u3002 \u56fe 1\uff1a\u76ae\u5c14\u900a\u76f8\u5173\u77e9\u9635\u6837\u672c \u6211\u4eec\u770b\u5230\uff0cMedInc_Sqrt \u4e0e MedInc \u7684\u76f8\u5173\u6027\u975e\u5e38\u9ad8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5220\u9664\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u4e00\u4e9b \u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5 \u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u53ea\u4e0d\u8fc7\u662f\u9488\u5bf9\u7ed9\u5b9a\u76ee\u6807\u5bf9\u6bcf\u4e2a\u7279\u5f81\u8fdb\u884c\u8bc4\u5206\u3002 \u4e92\u4fe1\u606f \u3001 \u65b9\u5dee\u5206\u6790 F \u68c0\u9a8c\u548c chi2 \u662f\u4e00\u4e9b\u6700\u5e38\u7528\u7684\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u3002\u5728 scikit- learn \u4e2d\uff0c\u6709\u4e24\u79cd\u65b9\u6cd5\u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e9b\u65b9\u6cd5\u3002 - SelectKBest\uff1a\u4fdd\u7559\u5f97\u5206\u6700\u9ad8\u7684 k \u4e2a\u7279\u5f81 - SelectPercentile\uff1a\u4fdd\u7559\u7528\u6237\u6307\u5b9a\u767e\u5206\u6bd4\u5185\u7684\u9876\u7ea7\u7279\u5f81\u3002 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u53ea\u6709\u975e\u8d1f\u6570\u636e\u624d\u80fd\u4f7f\u7528 chi2\u3002\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u5f53\u6211\u4eec\u6709\u4e00\u4e9b\u5355\u8bcd\u6216\u57fa\u4e8e tf-idf \u7684\u7279\u5f81\u65f6\uff0c\u8fd9\u662f\u4e00\u79cd\u7279\u522b\u6709\u7528\u7684\u7279\u5f81\u9009\u62e9\u6280\u672f\u3002\u6700\u597d\u4e3a\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u521b\u5efa\u4e00\u4e2a\u5305\u88c5\u5668\uff0c\u51e0\u4e4e\u53ef\u4ee5\u7528\u4e8e\u4efb\u4f55\u65b0\u95ee\u9898\u3002 from sklearn.feature_selection import chi2 from sklearn.feature_selection import f_classif from sklearn.feature_selection import f_regression from sklearn.feature_selection import mutual_info_classif from sklearn.feature_selection import mutual_info_regression from sklearn.feature_selection import SelectKBest from sklearn.feature_selection import SelectPercentile class UnivariateFeatureSelction : def __init__ ( self , n_features , problem_type , scoring ): # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u5206\u7c7b\u95ee\u9898 if problem_type == \"classification\" : # \u521b\u5efa\u5b57\u5178 valid_scoring \uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_classif\" : f_classif , \"chi2\" : chi2 , \"mutual_info_classif\" : mutual_info_classif } # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u56de\u5f52\u95ee\u9898 else : # \u521b\u5efa\u5b57\u5178 valid_scoring\uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_regression\" : f_regression , \"mutual_info_regression\" : mutual_info_regression } # \u68c0\u67e5\u7279\u5f81\u91cd\u8981\u6027\u65b9\u5f0f\u662f\u5426\u5728\u5b57\u5178\u4e2d if scoring not in valid_scoring : raise Exception ( \"Invalid scoring function\" ) # \u68c0\u67e5 n_features \u7684\u7c7b\u578b\uff0c\u5982\u679c\u662f\u6574\u6570\uff0c\u5219\u4f7f\u7528 SelectKBest \u8fdb\u884c\u7279\u5f81\u9009\u62e9 if isinstance ( n_features , int ): self . selection = SelectKBest ( valid_scoring [ scoring ], k = n_features ) # \u5982\u679c n_features \u662f\u6d6e\u70b9\u6570\uff0c\u5219\u4f7f\u7528 SelectPercentile \u8fdb\u884c\u7279\u5f81\u9009\u62e9 elif isinstance ( n_features , float ): self . selection = SelectPercentile ( valid_scoring [ scoring ], percentile = int ( n_features * 100 ) ) # \u5982\u679c n_features \u7c7b\u578b\u65e0\u6548\uff0c\u5f15\u53d1\u5f02\u5e38 else : raise Exception ( \"Invalid type of feature\" ) # \u5b9a\u4e49 fit \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 def fit ( self , X , y ): return self . selection . fit ( X , y ) # \u5b9a\u4e49 transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u5bf9\u6570\u636e\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def transform ( self , X ): return self . selection . transform ( X ) # \u5b9a\u4e49 fit_transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668\u5e76\u540c\u65f6\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def fit_transform ( self , X , y ): return self . selection . fit_transform ( X , y ) \u4f7f\u7528\u8be5\u7c7b\u975e\u5e38\u7b80\u5355\u3002 # \u5b9e\u4f8b\u5316\u7279\u5f81\u9009\u62e9\u5668\uff0c\u4fdd\u7559\u524d10%\u7684\u7279\u5f81\uff0c\u56de\u5f52\u95ee\u9898\uff0c\u4f7f\u7528f_regression\u8861\u91cf\u7279\u5f81\u91cd\u8981\u6027 ufs = UnivariateFeatureSelction ( n_features = 0.1 , problem_type = \"regression\" , scoring = \"f_regression\" ) # \u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 ufs . fit ( X , y ) # \u7279\u5f81\u8f6c\u6362 X_transformed = ufs . transform ( X ) \u8fd9\u6837\u5c31\u80fd\u6ee1\u8db3\u5927\u90e8\u5206\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u7684\u9700\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u521b\u5efa\u8f83\u5c11\u800c\u91cd\u8981\u7684\u7279\u5f81\u901a\u5e38\u6bd4\u521b\u5efa\u6570\u4ee5\u767e\u8ba1\u7684\u7279\u5f81\u8981\u597d\u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u4e0d\u4e00\u5b9a\u603b\u662f\u8868\u73b0\u826f\u597d\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u4eba\u4eec\u66f4\u559c\u6b22\u4f7f\u7528\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 \u4f7f\u7528\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u7684\u6700\u7b80\u5355\u5f62\u5f0f\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u3002\u5728\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u4e2d\uff0c\u7b2c\u4e00\u6b65\u662f\u9009\u62e9\u4e00\u4e2a\u6a21\u578b\u3002\u7b2c\u4e8c\u6b65\u662f\u9009\u62e9\u635f\u5931/\u8bc4\u5206\u51fd\u6570\u3002\u7b2c\u4e09\u6b65\u4e5f\u662f\u6700\u540e\u4e00\u6b65\u662f\u53cd\u590d\u8bc4\u4f30\u6bcf\u4e2a\u7279\u5f81\uff0c\u5982\u679c\u80fd\u63d0\u9ad8\u635f\u5931/\u8bc4\u5206\uff0c\u5c31\u5c06\u5176\u6dfb\u52a0\u5230 \"\u597d \"\u7279\u5f81\u5217\u8868\u4e2d\u3002\u6ca1\u6709\u6bd4\u8fd9\u66f4\u7b80\u5355\u7684\u4e86\u3002\u4f46\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u8fd9\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u662f\u6709\u539f\u56e0\u7684\u3002\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u8fc7\u7a0b\u5728\u6bcf\u6b21\u8bc4\u4f30\u7279\u5f81\u65f6\u90fd\u4f1a\u9002\u5408\u7ed9\u5b9a\u7684\u6a21\u578b\u3002\u8fd9\u79cd\u65b9\u6cd5\u7684\u8ba1\u7b97\u6210\u672c\u975e\u5e38\u9ad8\u3002\u5b8c\u6210\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u4e5f\u9700\u8981\u5927\u91cf\u65f6\u95f4\u3002\u5982\u679c\u4e0d\u6b63\u786e\u4f7f\u7528\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\uff0c\u751a\u81f3\u4f1a\u5bfc\u81f4\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn.datasets import make_classification class GreedyFeatureSelection : # \u5b9a\u4e49\u8bc4\u4f30\u5206\u6570\u7684\u65b9\u6cd5\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd def evaluate_score ( self , X , y ): # \u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u8bad\u7ec3\u6a21\u578b model . fit ( X , y ) # \u9884\u6d4b\u6982\u7387\u503c predictions = model . predict_proba ( X )[:, 1 ] # \u8ba1\u7b97 AUC \u5206\u6570 auc = metrics . roc_auc_score ( y , predictions ) return auc # \u7279\u5f81\u9009\u62e9\u51fd\u6570 def _feature_selection ( self , X , y ): # \u521d\u59cb\u5316\u7a7a\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6700\u4f73\u7279\u5f81\u548c\u6700\u4f73\u5206\u6570 good_features = [] best_scores = [] # \u83b7\u53d6\u7279\u5f81\u6570\u91cf num_features = X . shape [ 1 ] # \u5f00\u59cb\u7279\u5f81\u9009\u62e9\u7684\u5faa\u73af while True : this_feature = None best_score = 0 # \u904d\u5386\u6bcf\u4e2a\u7279\u5f81 for feature in range ( num_features ): if feature in good_features : continue selected_features = good_features + [ feature ] xtrain = X [:, selected_features ] score = self . evaluate_score ( xtrain , y ) # \u5982\u679c\u5f53\u524d\u7279\u5f81\u7684\u5f97\u5206\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u5f97\u5206\uff0c\u5219\u66f4\u65b0 if score > best_score : this_feature = feature best_score = score # \u82e5\u627e\u5230\u4e86\u65b0\u7684\u6700\u4f73\u7279\u5f81 if this_feature != None : # \u7279\u5f81\u6dfb\u52a0\u5230 good_features \u5217\u8868 good_features . append ( this_feature ) # \u5f97\u5206\u6dfb\u52a0\u5230 best_scores \u5217\u8868 best_scores . append ( best_score ) # \u5982\u679c best_scores \u5217\u8868\u957f\u5ea6\u5927\u4e8e2\uff0c\u5e76\u4e14\u6700\u540e\u4e24\u4e2a\u5f97\u5206\u76f8\u6bd4\u8f83\u5dee\uff0c\u5219\u7ed3\u675f\u5faa\u73af if len ( best_scores ) > 2 : if best_scores [ - 1 ] < best_scores [ - 2 ]: break # \u8fd4\u56de\u6700\u4f73\u7279\u5f81\u7684\u5f97\u5206\u5217\u8868\u548c\u6700\u4f73\u7279\u5f81\u5217\u8868 return best_scores [: - 1 ], good_features [: - 1 ] # \u5b9a\u4e49\u7c7b\u7684\u8c03\u7528\u65b9\u6cd5\uff0c\u7528\u4e8e\u6267\u884c\u7279\u5f81\u9009\u62e9 def __call__ ( self , X , y ): scores , features = self . _feature_selection ( X , y ) return X [:, features ], scores if __name__ == \"__main__\" : # \u751f\u6210\u4e00\u4e2a\u793a\u4f8b\u7684\u5206\u7c7b\u6570\u636e\u96c6 X \u548c\u6807\u7b7e y X , y = make_classification ( n_samples = 1000 , n_features = 100 ) # \u5b9e\u4f8b\u5316 GreedyFeatureSelection \u7c7b\uff0c\u5e76\u4f7f\u7528 __call__ \u65b9\u6cd5\u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed , scores = GreedyFeatureSelection ()( X , y ) \u8fd9\u79cd\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u4f1a\u8fd4\u56de\u5206\u6570\u548c\u7279\u5f81\u7d22\u5f15\u5217\u8868\u3002\u56fe 2 \u663e\u793a\u4e86\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u589e\u52a0\u4e00\u4e2a\u65b0\u7279\u5f81\u540e\uff0c\u5206\u6570\u662f\u5982\u4f55\u63d0\u9ad8\u7684\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u67d0\u4e00\u70b9\u4e4b\u540e\uff0c\u6211\u4eec\u5c31\u65e0\u6cd5\u63d0\u9ad8\u5206\u6570\u4e86\uff0c\u8fd9\u5c31\u662f\u6211\u4eec\u505c\u6b62\u7684\u5730\u65b9\u3002 \u53e6\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\u88ab\u79f0\u4e3a\u9012\u5f52\u7279\u5f81\u6d88\u9664\u6cd5\uff08RFE\uff09\u3002\u5728\u524d\u4e00\u79cd\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a\u7279\u5f81\u5f00\u59cb\uff0c\u7136\u540e\u4e0d\u65ad\u6dfb\u52a0\u65b0\u7684\u7279\u5f81\uff0c\u4f46\u5728 RFE \u4e2d\uff0c\u6211\u4eec\u4ece\u6240\u6709\u7279\u5f81\u5f00\u59cb\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u4e0d\u65ad\u53bb\u9664\u4e00\u4e2a\u5bf9\u7ed9\u5b9a\u6a21\u578b\u63d0\u4f9b\u6700\u5c0f\u503c\u7684\u7279\u5f81\u3002\u4f46\u6211\u4eec\u5982\u4f55\u77e5\u9053\u54ea\u4e2a\u7279\u5f81\u7684\u4ef7\u503c\u6700\u5c0f\u5462\uff1f\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7ebf\u6027\u652f\u6301\u5411\u91cf\u673a\uff08SVM\uff09\u6216\u903b\u8f91\u56de\u5f52\u7b49\u6a21\u578b\uff0c\u6211\u4eec\u4f1a\u4e3a\u6bcf\u4e2a\u7279\u5f81\u5f97\u5230\u4e00\u4e2a\u7cfb\u6570\uff0c\u8be5\u7cfb\u6570\u51b3\u5b9a\u4e86\u7279\u5f81\u7684\u91cd\u8981\u6027\u3002\u800c\u5bf9\u4e8e\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f\u7279\u5f81\u91cd\u8981\u6027\uff0c\u800c\u4e0d\u662f\u7cfb\u6570\u3002\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u53ef\u4ee5\u5254\u9664\u6700\u4e0d\u91cd\u8981\u7684\u7279\u5f81\uff0c\u76f4\u5230\u8fbe\u5230\u6240\u9700\u7684\u7279\u5f81\u6570\u91cf\u4e3a\u6b62\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u51b3\u5b9a\u8981\u4fdd\u7559\u591a\u5c11\u7279\u5f81\u3002 \u56fe 2\uff1a\u589e\u52a0\u65b0\u7279\u5f81\u540e\uff0c\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7684 AUC \u5206\u6570\u5982\u4f55\u53d8\u5316 \u5f53\u6211\u4eec\u8fdb\u884c\u9012\u5f52\u7279\u5f81\u5254\u9664\u65f6\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u4f1a\u5254\u9664\u7279\u5f81\u91cd\u8981\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u6216\u7cfb\u6570\u63a5\u8fd1 0 \u7684\u7279\u5f81\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5f53\u4f60\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u8fd9\u6837\u7684\u6a21\u578b\u8fdb\u884c\u4e8c\u5143\u5206\u7c7b\u65f6\uff0c\u5982\u679c\u7279\u5f81\u5bf9\u6b63\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u6b63\uff0c\u800c\u5982\u679c\u7279\u5f81\u5bf9\u8d1f\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u8d1f\u3002\u4fee\u6539\u6211\u4eec\u7684\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7c7b\uff0c\u521b\u5efa\u4e00\u4e2a\u65b0\u7684\u9012\u5f52\u7279\u5f81\u6d88\u9664\u7c7b\u975e\u5e38\u5bb9\u6613\uff0c\u4f46 scikit-learn \u4e5f\u63d0\u4f9b\u4e86 RFE\u3002\u4e0b\u9762\u7684\u793a\u4f8b\u5c55\u793a\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u7528\u6cd5\u3002 import pandas as pd from sklearn.feature_selection import RFE from sklearn.linear_model import LinearRegression from sklearn.datasets import fetch_california_housing data = fetch_california_housing () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] model = LinearRegression () # \u521b\u5efa RFE\uff08\u9012\u5f52\u7279\u5f81\u6d88\u9664\uff09\uff0c\u6307\u5b9a\u6a21\u578b\u4e3a\u7ebf\u6027\u56de\u5f52\u6a21\u578b\uff0c\u8981\u9009\u62e9\u7684\u7279\u5f81\u6570\u91cf\u4e3a 3 rfe = RFE ( estimator = model , n_features_to_select = 3 ) # \u8bad\u7ec3\u6a21\u578b rfe . fit ( X , y ) # \u4f7f\u7528 RFE \u9009\u62e9\u7684\u7279\u5f81\u8fdb\u884c\u6570\u636e\u8f6c\u6362 X_transformed = rfe . transform ( X ) \u6211\u4eec\u770b\u5230\u4e86\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u7684\u4e24\u79cd\u4e0d\u540c\u7684\u8d2a\u5a6a\u65b9\u6cd5\u3002\u4f46\u4e5f\u53ef\u4ee5\u6839\u636e\u6570\u636e\u62df\u5408\u6a21\u578b\uff0c\u7136\u540e\u901a\u8fc7\u7279\u5f81\u7cfb\u6570\u6216\u7279\u5f81\u7684\u91cd\u8981\u6027\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u5982\u679c\u4f7f\u7528\u7cfb\u6570\uff0c\u5219\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u9608\u503c\uff0c\u5982\u679c\u7cfb\u6570\u9ad8\u4e8e\u8be5\u9608\u503c\uff0c\u5219\u53ef\u4ee5\u4fdd\u7559\u8be5\u7279\u5f81\uff0c\u5426\u5219\u5c06\u5176\u5254\u9664\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4ece\u968f\u673a\u68ee\u6797\u8fd9\u6837\u7684\u6a21\u578b\u4e2d\u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u5b9e\u4f8b\u5316\u968f\u673a\u68ee\u6797\u6a21\u578b model = RandomForestRegressor () # \u62df\u5408\u6a21\u578b model . fit ( X , y ) \u968f\u673a\u68ee\u6797\uff08\u6216\u4efb\u4f55\u6a21\u578b\uff09\u7684\u7279\u5f81\u91cd\u8981\u6027\u53ef\u6309\u5982\u4e0b\u65b9\u5f0f\u7ed8\u5236\u3002 # \u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027 importances = model . feature_importances_ # \u964d\u5e8f\u6392\u5217 idxs = np . argsort ( importances ) # \u8bbe\u5b9a\u6807\u9898 plt . title ( 'Feature Importances' ) # \u521b\u5efa\u76f4\u65b9\u56fe plt . barh ( range ( len ( idxs )), importances [ idxs ], align = 'center' ) # y\u8f74\u6807\u7b7e plt . yticks ( range ( len ( idxs )), [ col_names [ i ] for i in idxs ]) # x\u8f74\u6807\u7b7e plt . xlabel ( 'Random Forest Feature Importance' ) plt . show () \u7ed3\u679c\u5982\u56fe 3 \u6240\u793a\u3002 \u56fe 3\uff1a\u7279\u5f81\u91cd\u8981\u6027\u56fe \u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u6700\u4f73\u7279\u5f81\u5e76\u4e0d\u662f\u4ec0\u4e48\u65b0\u9c9c\u4e8b\u3002\u60a8\u53ef\u4ee5\u4ece\u4e00\u4e2a\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u53e6\u4e00\u4e2a\u6a21\u578b\u8fdb\u884c\u8bad\u7ec3\u3002\u4f8b\u5982\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u7cfb\u6570\u6765\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u968f\u673a\u68ee\u6797\uff08Random Forest\uff09\u5bf9\u6240\u9009\u7279\u5f81\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\u3002Scikit-learn \u8fd8\u63d0\u4f9b\u4e86 SelectFromModel \u7c7b\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4f60\u76f4\u63a5\u4ece\u7ed9\u5b9a\u7684\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u6839\u636e\u9700\u8981\u6307\u5b9a\u7cfb\u6570\u6216\u7279\u5f81\u91cd\u8981\u6027\u7684\u9608\u503c\uff0c\u4ee5\u53ca\u8981\u9009\u62e9\u7684\u7279\u5f81\u7684\u6700\u5927\u6570\u91cf\u3002 \u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\uff0c\u6211\u4eec\u4f7f\u7528 SelectFromModel \u4e2d\u7684\u9ed8\u8ba4\u53c2\u6570\u6765\u9009\u62e9\u7279\u5f81\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor from sklearn.feature_selection import SelectFromModel data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u521b\u5efa\u968f\u673a\u68ee\u6797\u6a21\u578b\u56de\u5f52\u6a21\u578b model = RandomForestRegressor () # \u521b\u5efa SelectFromModel \u5bf9\u8c61 sfm\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u6a21\u578b\u4f5c\u4e3a\u4f30\u7b97\u5668 sfm = SelectFromModel ( estimator = model ) # \u4f7f\u7528 sfm \u5bf9\u7279\u5f81\u77e9\u9635 X \u548c\u76ee\u6807\u53d8\u91cf y \u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed = sfm . fit_transform ( X , y ) # \u83b7\u53d6\u7ecf\u8fc7\u7279\u5f81\u9009\u62e9\u540e\u7684\u7279\u5f81\u63a9\u7801\uff08True \u8868\u793a\u7279\u5f81\u88ab\u9009\u62e9\uff0cFalse \u8868\u793a\u7279\u5f81\u672a\u88ab\u9009\u62e9\uff09 support = sfm . get_support () # \u6253\u5370\u88ab\u9009\u62e9\u7684\u7279\u5f81\u5217\u540d print ([ x for x , y in zip ( col_names , support ) if y == True ]) \u4e0a\u9762\u7a0b\u5e8f\u6253\u5370\u7ed3\u679c\uff1a ['bmi'\uff0c's5']\u3002\u6211\u4eec\u518d\u770b\u56fe 3\uff0c\u5c31\u4f1a\u53d1\u73b0\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u4e2a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u76f4\u63a5\u4ece\u968f\u673a\u68ee\u6797\u63d0\u4f9b\u7684\u7279\u5f81\u91cd\u8981\u6027\u4e2d\u8fdb\u884c\u9009\u62e9\u3002\u6211\u4eec\u8fd8\u7f3a\u5c11\u4e00\u4ef6\u4e8b\uff0c\u90a3\u5c31\u662f\u4f7f\u7528 L1\uff08Lasso\uff09\u60e9\u7f5a\u6a21\u578b \u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u5f53\u6211\u4eec\u4f7f\u7528 L1 \u60e9\u7f5a\u8fdb\u884c\u6b63\u5219\u5316\u65f6\uff0c\u5927\u90e8\u5206\u7cfb\u6570\u90fd\u5c06\u4e3a 0\uff08\u6216\u63a5\u8fd1 0\uff09\uff0c\u56e0\u6b64\u6211\u4eec\u8981\u9009\u62e9\u7cfb\u6570\u4e0d\u4e3a 0 \u7684\u7279\u5f81\u3002\u53ea\u9700\u5c06\u6a21\u578b\u9009\u62e9\u7247\u6bb5\u4e2d\u7684\u968f\u673a\u68ee\u6797\u66ff\u6362\u4e3a\u652f\u6301 L1 \u60e9\u7f5a\u7684\u6a21\u578b\uff08\u5982 lasso \u56de\u5f52\uff09\u5373\u53ef\u3002\u6240\u6709\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u90fd\u63d0\u4f9b\u7279\u5f81\u91cd\u8981\u6027\uff0c\u56e0\u6b64\u672c\u7ae0\u4e2d\u5c55\u793a\u7684\u6240\u6709\u57fa\u4e8e\u6a21\u578b\u7684\u7247\u6bb5\u90fd\u53ef\u7528\u4e8e XGBoost\u3001LightGBM \u6216 CatBoost\u3002\u7279\u5f81\u91cd\u8981\u6027\u51fd\u6570\u7684\u540d\u79f0\u53ef\u80fd\u4e0d\u540c\uff0c\u4ea7\u751f\u7ed3\u679c\u7684\u683c\u5f0f\u4e5f\u53ef\u80fd\u4e0d\u540c\uff0c\u4f46\u7528\u6cd5\u662f\u4e00\u6837\u7684\u3002\u6700\u540e\uff0c\u5728\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u65f6\u5fc5\u987b\u5c0f\u5fc3\u8c28\u614e\u3002\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u9009\u62e9\u7279\u5f81\uff0c\u5e76\u5728\u9a8c\u8bc1\u6570\u636e\u4e0a\u9a8c\u8bc1\u6a21\u578b\uff0c\u4ee5\u4fbf\u5728\u4e0d\u8fc7\u5ea6\u62df\u5408\u6a21\u578b\u7684\u60c5\u51b5\u4e0b\u6b63\u786e\u9009\u62e9\u7279\u5f81\u3002","title":"\u7279\u5f81\u9009\u62e9"},{"location":"%E7%89%B9%E5%BE%81%E9%80%89%E6%8B%A9/#_1","text":"\u5f53\u4f60\u521b\u5efa\u4e86\u6210\u5343\u4e0a\u4e07\u4e2a\u7279\u5f81\u540e\uff0c\u5c31\u8be5\u4ece\u4e2d\u6311\u9009\u51fa\u51e0\u4e2a\u4e86\u3002\u4f46\u662f\uff0c\u6211\u4eec\u7edd\u4e0d\u5e94\u8be5\u521b\u5efa\u6210\u767e\u4e0a\u5343\u4e2a\u65e0\u7528\u7684\u7279\u5f81\u3002\u7279\u5f81\u8fc7\u591a\u4f1a\u5e26\u6765\u4e00\u4e2a\u4f17\u6240\u5468\u77e5\u7684\u95ee\u9898\uff0c\u5373 \"\u7ef4\u5ea6\u8bc5\u5492\"\u3002\u5982\u679c\u4f60\u6709\u5f88\u591a\u7279\u5f81\uff0c\u4f60\u4e5f\u5fc5\u987b\u6709\u5f88\u591a\u8bad\u7ec3\u6837\u672c\u6765\u6355\u6349\u6240\u6709\u7279\u5f81\u3002\u4ec0\u4e48\u662f \"\u5927\u91cf \"\u5e76\u6ca1\u6709\u6b63\u786e\u7684\u5b9a\u4e49\uff0c\u8fd9\u9700\u8981\u60a8\u901a\u8fc7\u6b63\u786e\u9a8c\u8bc1\u60a8\u7684\u6a21\u578b\u548c\u68c0\u67e5\u8bad\u7ec3\u6a21\u578b\u6240\u9700\u7684\u65f6\u95f4\u6765\u786e\u5b9a\u3002 \u9009\u62e9\u7279\u5f81\u7684\u6700\u7b80\u5355\u65b9\u6cd5\u662f \u5220\u9664\u65b9\u5dee\u975e\u5e38\u5c0f\u7684\u7279\u5f81 \u3002\u5982\u679c\u7279\u5f81\u7684\u65b9\u5dee\u975e\u5e38\u5c0f\uff08\u5373\u975e\u5e38\u63a5\u8fd1\u4e8e 0\uff09\uff0c\u5b83\u4eec\u5c31\u63a5\u8fd1\u4e8e\u5e38\u91cf\uff0c\u56e0\u6b64\u6839\u672c\u4e0d\u4f1a\u7ed9\u4efb\u4f55\u6a21\u578b\u589e\u52a0\u4efb\u4f55\u4ef7\u503c\u3002\u6700\u597d\u7684\u529e\u6cd5\u5c31\u662f\u53bb\u6389\u5b83\u4eec\uff0c\u4ece\u800c\u964d\u4f4e\u590d\u6742\u5ea6\u3002\u8bf7\u6ce8\u610f\uff0c\u65b9\u5dee\u4e5f\u53d6\u51b3\u4e8e\u6570\u636e\u7684\u7f29\u653e\u3002 Scikit-learn \u7684 VarianceThreshold \u5b9e\u73b0\u4e86\u8fd9\u4e00\u70b9\u3002 from sklearn.feature_selection import VarianceThreshold data = ... # \u521b\u5efa VarianceThreshold \u5bf9\u8c61 var_thresh\uff0c\u6307\u5b9a\u65b9\u5dee\u9608\u503c\u4e3a 0.1 var_thresh = VarianceThreshold ( threshold = 0.1 ) # \u4f7f\u7528 var_thresh \u5bf9\u6570\u636e data \u8fdb\u884c\u62df\u5408\u548c\u53d8\u6362\uff0c\u5c06\u65b9\u5dee\u4f4e\u4e8e\u9608\u503c\u7684\u7279\u5f81\u79fb\u9664 transformed_data = var_thresh . fit_transform ( data ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u5220\u9664\u76f8\u5173\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u3002\u8981\u8ba1\u7b97\u4e0d\u540c\u6570\u5b57\u7279\u5f81\u4e4b\u95f4\u7684\u76f8\u5173\u6027\uff0c\u53ef\u4ee5\u4f7f\u7528\u76ae\u5c14\u900a\u76f8\u5173\u6027\u3002 import pandas as pd from sklearn.datasets import fetch_california_housing # \u52a0\u8f7d\u6570\u636e data = fetch_california_housing () # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u77e9\u9635 X X = data [ \"data\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u7684\u5217\u540d col_names = data [ \"feature_names\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf y y = data [ \"target\" ] df = pd . DataFrame ( X , columns = col_names ) # \u6dfb\u52a0 MedInc_Sqrt \u5217\uff0c\u662f MedInc \u5217\u4e2d\u6bcf\u4e2a\u5143\u7d20\u8fdb\u884c\u5e73\u65b9\u6839\u8fd0\u7b97\u7684\u7ed3\u679c df . loc [:, \"MedInc_Sqrt\" ] = df . MedInc . apply ( np . sqrt ) # \u8ba1\u7b97\u76ae\u5c14\u900a\u76f8\u5173\u6027\u77e9\u9635 df . corr () \u5f97\u51fa\u76f8\u5173\u77e9\u9635\uff0c\u5982\u56fe 1 \u6240\u793a\u3002 \u56fe 1\uff1a\u76ae\u5c14\u900a\u76f8\u5173\u77e9\u9635\u6837\u672c \u6211\u4eec\u770b\u5230\uff0cMedInc_Sqrt \u4e0e MedInc \u7684\u76f8\u5173\u6027\u975e\u5e38\u9ad8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5220\u9664\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u4e00\u4e9b \u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5 \u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u53ea\u4e0d\u8fc7\u662f\u9488\u5bf9\u7ed9\u5b9a\u76ee\u6807\u5bf9\u6bcf\u4e2a\u7279\u5f81\u8fdb\u884c\u8bc4\u5206\u3002 \u4e92\u4fe1\u606f \u3001 \u65b9\u5dee\u5206\u6790 F \u68c0\u9a8c\u548c chi2 \u662f\u4e00\u4e9b\u6700\u5e38\u7528\u7684\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u3002\u5728 scikit- learn \u4e2d\uff0c\u6709\u4e24\u79cd\u65b9\u6cd5\u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e9b\u65b9\u6cd5\u3002 - SelectKBest\uff1a\u4fdd\u7559\u5f97\u5206\u6700\u9ad8\u7684 k \u4e2a\u7279\u5f81 - SelectPercentile\uff1a\u4fdd\u7559\u7528\u6237\u6307\u5b9a\u767e\u5206\u6bd4\u5185\u7684\u9876\u7ea7\u7279\u5f81\u3002 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u53ea\u6709\u975e\u8d1f\u6570\u636e\u624d\u80fd\u4f7f\u7528 chi2\u3002\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u5f53\u6211\u4eec\u6709\u4e00\u4e9b\u5355\u8bcd\u6216\u57fa\u4e8e tf-idf \u7684\u7279\u5f81\u65f6\uff0c\u8fd9\u662f\u4e00\u79cd\u7279\u522b\u6709\u7528\u7684\u7279\u5f81\u9009\u62e9\u6280\u672f\u3002\u6700\u597d\u4e3a\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u521b\u5efa\u4e00\u4e2a\u5305\u88c5\u5668\uff0c\u51e0\u4e4e\u53ef\u4ee5\u7528\u4e8e\u4efb\u4f55\u65b0\u95ee\u9898\u3002 from sklearn.feature_selection import chi2 from sklearn.feature_selection import f_classif from sklearn.feature_selection import f_regression from sklearn.feature_selection import mutual_info_classif from sklearn.feature_selection import mutual_info_regression from sklearn.feature_selection import SelectKBest from sklearn.feature_selection import SelectPercentile class UnivariateFeatureSelction : def __init__ ( self , n_features , problem_type , scoring ): # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u5206\u7c7b\u95ee\u9898 if problem_type == \"classification\" : # \u521b\u5efa\u5b57\u5178 valid_scoring \uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_classif\" : f_classif , \"chi2\" : chi2 , \"mutual_info_classif\" : mutual_info_classif } # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u56de\u5f52\u95ee\u9898 else : # \u521b\u5efa\u5b57\u5178 valid_scoring\uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_regression\" : f_regression , \"mutual_info_regression\" : mutual_info_regression } # \u68c0\u67e5\u7279\u5f81\u91cd\u8981\u6027\u65b9\u5f0f\u662f\u5426\u5728\u5b57\u5178\u4e2d if scoring not in valid_scoring : raise Exception ( \"Invalid scoring function\" ) # \u68c0\u67e5 n_features \u7684\u7c7b\u578b\uff0c\u5982\u679c\u662f\u6574\u6570\uff0c\u5219\u4f7f\u7528 SelectKBest \u8fdb\u884c\u7279\u5f81\u9009\u62e9 if isinstance ( n_features , int ): self . selection = SelectKBest ( valid_scoring [ scoring ], k = n_features ) # \u5982\u679c n_features \u662f\u6d6e\u70b9\u6570\uff0c\u5219\u4f7f\u7528 SelectPercentile \u8fdb\u884c\u7279\u5f81\u9009\u62e9 elif isinstance ( n_features , float ): self . selection = SelectPercentile ( valid_scoring [ scoring ], percentile = int ( n_features * 100 ) ) # \u5982\u679c n_features \u7c7b\u578b\u65e0\u6548\uff0c\u5f15\u53d1\u5f02\u5e38 else : raise Exception ( \"Invalid type of feature\" ) # \u5b9a\u4e49 fit \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 def fit ( self , X , y ): return self . selection . fit ( X , y ) # \u5b9a\u4e49 transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u5bf9\u6570\u636e\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def transform ( self , X ): return self . selection . transform ( X ) # \u5b9a\u4e49 fit_transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668\u5e76\u540c\u65f6\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def fit_transform ( self , X , y ): return self . selection . fit_transform ( X , y ) \u4f7f\u7528\u8be5\u7c7b\u975e\u5e38\u7b80\u5355\u3002 # \u5b9e\u4f8b\u5316\u7279\u5f81\u9009\u62e9\u5668\uff0c\u4fdd\u7559\u524d10%\u7684\u7279\u5f81\uff0c\u56de\u5f52\u95ee\u9898\uff0c\u4f7f\u7528f_regression\u8861\u91cf\u7279\u5f81\u91cd\u8981\u6027 ufs = UnivariateFeatureSelction ( n_features = 0.1 , problem_type = \"regression\" , scoring = \"f_regression\" ) # \u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 ufs . fit ( X , y ) # \u7279\u5f81\u8f6c\u6362 X_transformed = ufs . transform ( X ) \u8fd9\u6837\u5c31\u80fd\u6ee1\u8db3\u5927\u90e8\u5206\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u7684\u9700\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u521b\u5efa\u8f83\u5c11\u800c\u91cd\u8981\u7684\u7279\u5f81\u901a\u5e38\u6bd4\u521b\u5efa\u6570\u4ee5\u767e\u8ba1\u7684\u7279\u5f81\u8981\u597d\u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u4e0d\u4e00\u5b9a\u603b\u662f\u8868\u73b0\u826f\u597d\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u4eba\u4eec\u66f4\u559c\u6b22\u4f7f\u7528\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 \u4f7f\u7528\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u7684\u6700\u7b80\u5355\u5f62\u5f0f\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u3002\u5728\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u4e2d\uff0c\u7b2c\u4e00\u6b65\u662f\u9009\u62e9\u4e00\u4e2a\u6a21\u578b\u3002\u7b2c\u4e8c\u6b65\u662f\u9009\u62e9\u635f\u5931/\u8bc4\u5206\u51fd\u6570\u3002\u7b2c\u4e09\u6b65\u4e5f\u662f\u6700\u540e\u4e00\u6b65\u662f\u53cd\u590d\u8bc4\u4f30\u6bcf\u4e2a\u7279\u5f81\uff0c\u5982\u679c\u80fd\u63d0\u9ad8\u635f\u5931/\u8bc4\u5206\uff0c\u5c31\u5c06\u5176\u6dfb\u52a0\u5230 \"\u597d \"\u7279\u5f81\u5217\u8868\u4e2d\u3002\u6ca1\u6709\u6bd4\u8fd9\u66f4\u7b80\u5355\u7684\u4e86\u3002\u4f46\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u8fd9\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u662f\u6709\u539f\u56e0\u7684\u3002\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u8fc7\u7a0b\u5728\u6bcf\u6b21\u8bc4\u4f30\u7279\u5f81\u65f6\u90fd\u4f1a\u9002\u5408\u7ed9\u5b9a\u7684\u6a21\u578b\u3002\u8fd9\u79cd\u65b9\u6cd5\u7684\u8ba1\u7b97\u6210\u672c\u975e\u5e38\u9ad8\u3002\u5b8c\u6210\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u4e5f\u9700\u8981\u5927\u91cf\u65f6\u95f4\u3002\u5982\u679c\u4e0d\u6b63\u786e\u4f7f\u7528\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\uff0c\u751a\u81f3\u4f1a\u5bfc\u81f4\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn.datasets import make_classification class GreedyFeatureSelection : # \u5b9a\u4e49\u8bc4\u4f30\u5206\u6570\u7684\u65b9\u6cd5\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd def evaluate_score ( self , X , y ): # \u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u8bad\u7ec3\u6a21\u578b model . fit ( X , y ) # \u9884\u6d4b\u6982\u7387\u503c predictions = model . predict_proba ( X )[:, 1 ] # \u8ba1\u7b97 AUC \u5206\u6570 auc = metrics . roc_auc_score ( y , predictions ) return auc # \u7279\u5f81\u9009\u62e9\u51fd\u6570 def _feature_selection ( self , X , y ): # \u521d\u59cb\u5316\u7a7a\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6700\u4f73\u7279\u5f81\u548c\u6700\u4f73\u5206\u6570 good_features = [] best_scores = [] # \u83b7\u53d6\u7279\u5f81\u6570\u91cf num_features = X . shape [ 1 ] # \u5f00\u59cb\u7279\u5f81\u9009\u62e9\u7684\u5faa\u73af while True : this_feature = None best_score = 0 # \u904d\u5386\u6bcf\u4e2a\u7279\u5f81 for feature in range ( num_features ): if feature in good_features : continue selected_features = good_features + [ feature ] xtrain = X [:, selected_features ] score = self . evaluate_score ( xtrain , y ) # \u5982\u679c\u5f53\u524d\u7279\u5f81\u7684\u5f97\u5206\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u5f97\u5206\uff0c\u5219\u66f4\u65b0 if score > best_score : this_feature = feature best_score = score # \u82e5\u627e\u5230\u4e86\u65b0\u7684\u6700\u4f73\u7279\u5f81 if this_feature != None : # \u7279\u5f81\u6dfb\u52a0\u5230 good_features \u5217\u8868 good_features . append ( this_feature ) # \u5f97\u5206\u6dfb\u52a0\u5230 best_scores \u5217\u8868 best_scores . append ( best_score ) # \u5982\u679c best_scores \u5217\u8868\u957f\u5ea6\u5927\u4e8e2\uff0c\u5e76\u4e14\u6700\u540e\u4e24\u4e2a\u5f97\u5206\u76f8\u6bd4\u8f83\u5dee\uff0c\u5219\u7ed3\u675f\u5faa\u73af if len ( best_scores ) > 2 : if best_scores [ - 1 ] < best_scores [ - 2 ]: break # \u8fd4\u56de\u6700\u4f73\u7279\u5f81\u7684\u5f97\u5206\u5217\u8868\u548c\u6700\u4f73\u7279\u5f81\u5217\u8868 return best_scores [: - 1 ], good_features [: - 1 ] # \u5b9a\u4e49\u7c7b\u7684\u8c03\u7528\u65b9\u6cd5\uff0c\u7528\u4e8e\u6267\u884c\u7279\u5f81\u9009\u62e9 def __call__ ( self , X , y ): scores , features = self . _feature_selection ( X , y ) return X [:, features ], scores if __name__ == \"__main__\" : # \u751f\u6210\u4e00\u4e2a\u793a\u4f8b\u7684\u5206\u7c7b\u6570\u636e\u96c6 X \u548c\u6807\u7b7e y X , y = make_classification ( n_samples = 1000 , n_features = 100 ) # \u5b9e\u4f8b\u5316 GreedyFeatureSelection \u7c7b\uff0c\u5e76\u4f7f\u7528 __call__ \u65b9\u6cd5\u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed , scores = GreedyFeatureSelection ()( X , y ) \u8fd9\u79cd\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u4f1a\u8fd4\u56de\u5206\u6570\u548c\u7279\u5f81\u7d22\u5f15\u5217\u8868\u3002\u56fe 2 \u663e\u793a\u4e86\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u589e\u52a0\u4e00\u4e2a\u65b0\u7279\u5f81\u540e\uff0c\u5206\u6570\u662f\u5982\u4f55\u63d0\u9ad8\u7684\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u67d0\u4e00\u70b9\u4e4b\u540e\uff0c\u6211\u4eec\u5c31\u65e0\u6cd5\u63d0\u9ad8\u5206\u6570\u4e86\uff0c\u8fd9\u5c31\u662f\u6211\u4eec\u505c\u6b62\u7684\u5730\u65b9\u3002 \u53e6\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\u88ab\u79f0\u4e3a\u9012\u5f52\u7279\u5f81\u6d88\u9664\u6cd5\uff08RFE\uff09\u3002\u5728\u524d\u4e00\u79cd\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a\u7279\u5f81\u5f00\u59cb\uff0c\u7136\u540e\u4e0d\u65ad\u6dfb\u52a0\u65b0\u7684\u7279\u5f81\uff0c\u4f46\u5728 RFE \u4e2d\uff0c\u6211\u4eec\u4ece\u6240\u6709\u7279\u5f81\u5f00\u59cb\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u4e0d\u65ad\u53bb\u9664\u4e00\u4e2a\u5bf9\u7ed9\u5b9a\u6a21\u578b\u63d0\u4f9b\u6700\u5c0f\u503c\u7684\u7279\u5f81\u3002\u4f46\u6211\u4eec\u5982\u4f55\u77e5\u9053\u54ea\u4e2a\u7279\u5f81\u7684\u4ef7\u503c\u6700\u5c0f\u5462\uff1f\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7ebf\u6027\u652f\u6301\u5411\u91cf\u673a\uff08SVM\uff09\u6216\u903b\u8f91\u56de\u5f52\u7b49\u6a21\u578b\uff0c\u6211\u4eec\u4f1a\u4e3a\u6bcf\u4e2a\u7279\u5f81\u5f97\u5230\u4e00\u4e2a\u7cfb\u6570\uff0c\u8be5\u7cfb\u6570\u51b3\u5b9a\u4e86\u7279\u5f81\u7684\u91cd\u8981\u6027\u3002\u800c\u5bf9\u4e8e\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f\u7279\u5f81\u91cd\u8981\u6027\uff0c\u800c\u4e0d\u662f\u7cfb\u6570\u3002\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u53ef\u4ee5\u5254\u9664\u6700\u4e0d\u91cd\u8981\u7684\u7279\u5f81\uff0c\u76f4\u5230\u8fbe\u5230\u6240\u9700\u7684\u7279\u5f81\u6570\u91cf\u4e3a\u6b62\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u51b3\u5b9a\u8981\u4fdd\u7559\u591a\u5c11\u7279\u5f81\u3002 \u56fe 2\uff1a\u589e\u52a0\u65b0\u7279\u5f81\u540e\uff0c\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7684 AUC \u5206\u6570\u5982\u4f55\u53d8\u5316 \u5f53\u6211\u4eec\u8fdb\u884c\u9012\u5f52\u7279\u5f81\u5254\u9664\u65f6\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u4f1a\u5254\u9664\u7279\u5f81\u91cd\u8981\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u6216\u7cfb\u6570\u63a5\u8fd1 0 \u7684\u7279\u5f81\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5f53\u4f60\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u8fd9\u6837\u7684\u6a21\u578b\u8fdb\u884c\u4e8c\u5143\u5206\u7c7b\u65f6\uff0c\u5982\u679c\u7279\u5f81\u5bf9\u6b63\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u6b63\uff0c\u800c\u5982\u679c\u7279\u5f81\u5bf9\u8d1f\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u8d1f\u3002\u4fee\u6539\u6211\u4eec\u7684\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7c7b\uff0c\u521b\u5efa\u4e00\u4e2a\u65b0\u7684\u9012\u5f52\u7279\u5f81\u6d88\u9664\u7c7b\u975e\u5e38\u5bb9\u6613\uff0c\u4f46 scikit-learn \u4e5f\u63d0\u4f9b\u4e86 RFE\u3002\u4e0b\u9762\u7684\u793a\u4f8b\u5c55\u793a\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u7528\u6cd5\u3002 import pandas as pd from sklearn.feature_selection import RFE from sklearn.linear_model import LinearRegression from sklearn.datasets import fetch_california_housing data = fetch_california_housing () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] model = LinearRegression () # \u521b\u5efa RFE\uff08\u9012\u5f52\u7279\u5f81\u6d88\u9664\uff09\uff0c\u6307\u5b9a\u6a21\u578b\u4e3a\u7ebf\u6027\u56de\u5f52\u6a21\u578b\uff0c\u8981\u9009\u62e9\u7684\u7279\u5f81\u6570\u91cf\u4e3a 3 rfe = RFE ( estimator = model , n_features_to_select = 3 ) # \u8bad\u7ec3\u6a21\u578b rfe . fit ( X , y ) # \u4f7f\u7528 RFE \u9009\u62e9\u7684\u7279\u5f81\u8fdb\u884c\u6570\u636e\u8f6c\u6362 X_transformed = rfe . transform ( X ) \u6211\u4eec\u770b\u5230\u4e86\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u7684\u4e24\u79cd\u4e0d\u540c\u7684\u8d2a\u5a6a\u65b9\u6cd5\u3002\u4f46\u4e5f\u53ef\u4ee5\u6839\u636e\u6570\u636e\u62df\u5408\u6a21\u578b\uff0c\u7136\u540e\u901a\u8fc7\u7279\u5f81\u7cfb\u6570\u6216\u7279\u5f81\u7684\u91cd\u8981\u6027\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u5982\u679c\u4f7f\u7528\u7cfb\u6570\uff0c\u5219\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u9608\u503c\uff0c\u5982\u679c\u7cfb\u6570\u9ad8\u4e8e\u8be5\u9608\u503c\uff0c\u5219\u53ef\u4ee5\u4fdd\u7559\u8be5\u7279\u5f81\uff0c\u5426\u5219\u5c06\u5176\u5254\u9664\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4ece\u968f\u673a\u68ee\u6797\u8fd9\u6837\u7684\u6a21\u578b\u4e2d\u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u5b9e\u4f8b\u5316\u968f\u673a\u68ee\u6797\u6a21\u578b model = RandomForestRegressor () # \u62df\u5408\u6a21\u578b model . fit ( X , y ) \u968f\u673a\u68ee\u6797\uff08\u6216\u4efb\u4f55\u6a21\u578b\uff09\u7684\u7279\u5f81\u91cd\u8981\u6027\u53ef\u6309\u5982\u4e0b\u65b9\u5f0f\u7ed8\u5236\u3002 # \u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027 importances = model . feature_importances_ # \u964d\u5e8f\u6392\u5217 idxs = np . argsort ( importances ) # \u8bbe\u5b9a\u6807\u9898 plt . title ( 'Feature Importances' ) # \u521b\u5efa\u76f4\u65b9\u56fe plt . barh ( range ( len ( idxs )), importances [ idxs ], align = 'center' ) # y\u8f74\u6807\u7b7e plt . yticks ( range ( len ( idxs )), [ col_names [ i ] for i in idxs ]) # x\u8f74\u6807\u7b7e plt . xlabel ( 'Random Forest Feature Importance' ) plt . show () \u7ed3\u679c\u5982\u56fe 3 \u6240\u793a\u3002 \u56fe 3\uff1a\u7279\u5f81\u91cd\u8981\u6027\u56fe \u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u6700\u4f73\u7279\u5f81\u5e76\u4e0d\u662f\u4ec0\u4e48\u65b0\u9c9c\u4e8b\u3002\u60a8\u53ef\u4ee5\u4ece\u4e00\u4e2a\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u53e6\u4e00\u4e2a\u6a21\u578b\u8fdb\u884c\u8bad\u7ec3\u3002\u4f8b\u5982\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u7cfb\u6570\u6765\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u968f\u673a\u68ee\u6797\uff08Random Forest\uff09\u5bf9\u6240\u9009\u7279\u5f81\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\u3002Scikit-learn \u8fd8\u63d0\u4f9b\u4e86 SelectFromModel \u7c7b\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4f60\u76f4\u63a5\u4ece\u7ed9\u5b9a\u7684\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u6839\u636e\u9700\u8981\u6307\u5b9a\u7cfb\u6570\u6216\u7279\u5f81\u91cd\u8981\u6027\u7684\u9608\u503c\uff0c\u4ee5\u53ca\u8981\u9009\u62e9\u7684\u7279\u5f81\u7684\u6700\u5927\u6570\u91cf\u3002 \u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\uff0c\u6211\u4eec\u4f7f\u7528 SelectFromModel \u4e2d\u7684\u9ed8\u8ba4\u53c2\u6570\u6765\u9009\u62e9\u7279\u5f81\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor from sklearn.feature_selection import SelectFromModel data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u521b\u5efa\u968f\u673a\u68ee\u6797\u6a21\u578b\u56de\u5f52\u6a21\u578b model = RandomForestRegressor () # \u521b\u5efa SelectFromModel \u5bf9\u8c61 sfm\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u6a21\u578b\u4f5c\u4e3a\u4f30\u7b97\u5668 sfm = SelectFromModel ( estimator = model ) # \u4f7f\u7528 sfm \u5bf9\u7279\u5f81\u77e9\u9635 X \u548c\u76ee\u6807\u53d8\u91cf y \u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed = sfm . fit_transform ( X , y ) # \u83b7\u53d6\u7ecf\u8fc7\u7279\u5f81\u9009\u62e9\u540e\u7684\u7279\u5f81\u63a9\u7801\uff08True \u8868\u793a\u7279\u5f81\u88ab\u9009\u62e9\uff0cFalse \u8868\u793a\u7279\u5f81\u672a\u88ab\u9009\u62e9\uff09 support = sfm . get_support () # \u6253\u5370\u88ab\u9009\u62e9\u7684\u7279\u5f81\u5217\u540d print ([ x for x , y in zip ( col_names , support ) if y == True ]) \u4e0a\u9762\u7a0b\u5e8f\u6253\u5370\u7ed3\u679c\uff1a ['bmi'\uff0c's5']\u3002\u6211\u4eec\u518d\u770b\u56fe 3\uff0c\u5c31\u4f1a\u53d1\u73b0\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u4e2a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u76f4\u63a5\u4ece\u968f\u673a\u68ee\u6797\u63d0\u4f9b\u7684\u7279\u5f81\u91cd\u8981\u6027\u4e2d\u8fdb\u884c\u9009\u62e9\u3002\u6211\u4eec\u8fd8\u7f3a\u5c11\u4e00\u4ef6\u4e8b\uff0c\u90a3\u5c31\u662f\u4f7f\u7528 L1\uff08Lasso\uff09\u60e9\u7f5a\u6a21\u578b \u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u5f53\u6211\u4eec\u4f7f\u7528 L1 \u60e9\u7f5a\u8fdb\u884c\u6b63\u5219\u5316\u65f6\uff0c\u5927\u90e8\u5206\u7cfb\u6570\u90fd\u5c06\u4e3a 0\uff08\u6216\u63a5\u8fd1 0\uff09\uff0c\u56e0\u6b64\u6211\u4eec\u8981\u9009\u62e9\u7cfb\u6570\u4e0d\u4e3a 0 \u7684\u7279\u5f81\u3002\u53ea\u9700\u5c06\u6a21\u578b\u9009\u62e9\u7247\u6bb5\u4e2d\u7684\u968f\u673a\u68ee\u6797\u66ff\u6362\u4e3a\u652f\u6301 L1 \u60e9\u7f5a\u7684\u6a21\u578b\uff08\u5982 lasso \u56de\u5f52\uff09\u5373\u53ef\u3002\u6240\u6709\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u90fd\u63d0\u4f9b\u7279\u5f81\u91cd\u8981\u6027\uff0c\u56e0\u6b64\u672c\u7ae0\u4e2d\u5c55\u793a\u7684\u6240\u6709\u57fa\u4e8e\u6a21\u578b\u7684\u7247\u6bb5\u90fd\u53ef\u7528\u4e8e XGBoost\u3001LightGBM \u6216 CatBoost\u3002\u7279\u5f81\u91cd\u8981\u6027\u51fd\u6570\u7684\u540d\u79f0\u53ef\u80fd\u4e0d\u540c\uff0c\u4ea7\u751f\u7ed3\u679c\u7684\u683c\u5f0f\u4e5f\u53ef\u80fd\u4e0d\u540c\uff0c\u4f46\u7528\u6cd5\u662f\u4e00\u6837\u7684\u3002\u6700\u540e\uff0c\u5728\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u65f6\u5fc5\u987b\u5c0f\u5fc3\u8c28\u614e\u3002\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u9009\u62e9\u7279\u5f81\uff0c\u5e76\u5728\u9a8c\u8bc1\u6570\u636e\u4e0a\u9a8c\u8bc1\u6a21\u578b\uff0c\u4ee5\u4fbf\u5728\u4e0d\u8fc7\u5ea6\u62df\u5408\u6a21\u578b\u7684\u60c5\u51b5\u4e0b\u6b63\u786e\u9009\u62e9\u7279\u5f81\u3002","title":"\u7279\u5f81\u9009\u62e9"},{"location":"%E7%BB%84%E5%90%88%E5%92%8C%E5%A0%86%E5%8F%A0%E6%96%B9%E6%B3%95/","text":"\u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5 \u542c\u5230\u4e0a\u9762\u4e24\u4e2a\u8bcd\uff0c\u6211\u4eec\u9996\u5148\u60f3\u5230\u7684\u5c31\u662f\u5728\u7ebf\uff08online\uff09/\u79bb\u7ebf\uff08offline\uff09\u673a\u5668\u5b66\u4e60\u7ade\u8d5b\u3002\u51e0\u5e74\u524d\u662f\u8fd9\u6837\uff0c\u4f46\u73b0\u5728\u968f\u7740\u8ba1\u7b97\u80fd\u529b\u7684\u8fdb\u6b65\u548c\u865a\u62df\u5b9e\u4f8b\u7684\u5ec9\u4ef7\uff0c\u4eba\u4eec\u751a\u81f3\u5f00\u59cb\u5728\u884c\u4e1a\u4e2d\u4f7f\u7528\u7ec4\u5408\u6a21\u578b\uff08ensemble models\uff09\u3002\u4f8b\u5982\uff0c\u90e8\u7f72\u591a\u4e2a\u795e\u7ecf\u7f51\u7edc\u5e76\u5b9e\u65f6\u4e3a\u5b83\u4eec\u63d0\u4f9b\u670d\u52a1\u975e\u5e38\u5bb9\u6613\uff0c\u54cd\u5e94\u65f6\u95f4\u5c0f\u4e8e 500 \u6beb\u79d2\u3002\u6709\u65f6\uff0c\u4e00\u4e2a\u5e9e\u5927\u7684\u795e\u7ecf\u7f51\u7edc\u6216\u5927\u578b\u6a21\u578b\u4e5f\u53ef\u4ee5\u88ab\u5176\u4ed6\u51e0\u4e2a\u6a21\u578b\u53d6\u4ee3\uff0c\u8fd9\u4e9b\u6a21\u578b\u4f53\u79ef\u5c0f\uff0c\u6027\u80fd\u4e0e\u5927\u578b\u6a21\u578b\u76f8\u4f3c\uff0c\u901f\u5ea6\u5374\u5feb\u4e00\u500d\u3002\u5982\u679c\u662f\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff08\u4e9b\uff09\u6a21\u578b\u5462\uff1f\u6211\u4e2a\u4eba\u66f4\u503e\u5411\u4e8e\u9009\u62e9\u591a\u4e2a\u5c0f\u673a\u578b\uff0c\u5b83\u4eec\u901f\u5ea6\u66f4\u5feb\uff0c\u6027\u80fd\u4e0e\u5927\u673a\u578b\u548c\u6162\u673a\u578b\u76f8\u540c\u3002\u8bf7\u8bb0\u4f4f\uff0c\u8f83\u5c0f\u7684\u578b\u53f7\u4e5f\u66f4\u5bb9\u6613\u548c\u66f4\u5feb\u5730\u8fdb\u884c\u8c03\u6574\u3002 \u7ec4\u5408\uff08ensembling\uff09\u4e0d\u8fc7\u662f\u4e0d\u540c\u6a21\u578b\u7684\u7ec4\u5408\u3002\u6a21\u578b\u53ef\u4ee5\u901a\u8fc7\u9884\u6d4b/\u6982\u7387\u8fdb\u884c\u7ec4\u5408\u3002\u7ec4\u5408\u6a21\u578b\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6c42\u5e73\u5747\u503c\u3002 $$ Ensemble Probabilities = (M1_proba + M2_proba + ... + Mn_Proba)/n $$ \u8fd9\u662f\u6700\u7b80\u5355\u4e5f\u662f\u6700\u6709\u6548\u7684\u7ec4\u5408\u6a21\u578b\u7684\u65b9\u6cd5\u3002\u5728\u7b80\u5355\u5e73\u5747\u6cd5\u4e2d\uff0c\u6240\u6709\u6a21\u578b\u7684\u6743\u91cd\u90fd\u662f\u76f8\u7b49\u7684\u3002\u65e0\u8bba\u91c7\u7528\u54ea\u79cd\u7ec4\u5408\u65b9\u6cd5\uff0c\u60a8\u90fd\u5e94\u8be5\u7262\u8bb0\u4e00\u70b9\uff0c\u90a3\u5c31\u662f\u60a8\u5e94\u8be5\u59cb\u7ec8\u5c06\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b/\u6982\u7387\u7ec4\u5408\u5728\u4e00\u8d77\u3002\u7b80\u5355\u5730\u8bf4\uff0c\u7ec4\u5408\u76f8\u5173\u6027\u4e0d\u9ad8\u7684\u6a21\u578b\u6bd4\u7ec4\u5408\u76f8\u5173\u6027\u5f88\u9ad8\u7684\u6a21\u578b\u6548\u679c\u66f4\u597d\u3002 \u5982\u679c\u6ca1\u6709\u6982\u7387\uff0c\u4e5f\u53ef\u4ee5\u7ec4\u5408\u9884\u6d4b\u3002\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6295\u7968\u3002\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u8fdb\u884c\u591a\u7c7b\u5206\u7c7b\uff0c\u6709\u4e09\u4e2a\u7c7b\u522b\uff1a 0\u30011 \u548c 2\u3002 [0, 0, 1] : \u6700\u9ad8\u7968\u6570\uff1a 0 [0, 1, 2] : \u6700\u9ad8\u7968\u7ea7\uff1a \u65e0\uff08\u968f\u673a\u9009\u62e9\u4e00\u4e2a\uff09 [2, 2, 2] : \u6700\u9ad8\u7968\u6570\uff1a 2 \u4ee5\u4e0b\u7b80\u5355\u51fd\u6570\u53ef\u4ee5\u5b8c\u6210\u8fd9\u4e9b\u7b80\u5355\u64cd\u4f5c\u3002 import numpy as np def mean_predictions ( probas ): # \u8ba1\u7b97\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u6bcf\u884c\u5e73\u5747\u503c return np . mean ( probas , axis = 1 ) def max_voting ( preds ): # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u67e5\u627e\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u7684\u7d22\u5f15 idxs = np . argmax ( preds , axis = 1 ) # \u6839\u636e\u7d22\u5f15\u53d6\u51fa\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u5143\u7d20 return np . take_along_axis ( preds , idxs [:, None ], axis = 1 ) \u8bf7\u6ce8\u610f\uff0cprobas \u7684\u6bcf\u4e00\u5217\u90fd\u53ea\u6709\u4e00\u4e2a\u6982\u7387\uff08\u5373\u4e8c\u5143\u5206\u7c7b\uff0c\u901a\u5e38\u4e3a\u7c7b\u522b 1\uff09\u3002\u56e0\u6b64\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u4e00\u4e2a\u65b0\u6a21\u578b\u3002\u540c\u6837\uff0c\u5bf9\u4e8e preds\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u6765\u81ea\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b\u503c\u3002\u8fd9\u4e24\u4e2a\u51fd\u6570\u90fd\u5047\u8bbe\u4e86\u4e00\u4e2a 2 \u7ef4 numpy \u6570\u7ec4\u3002\u60a8\u53ef\u4ee5\u6839\u636e\u81ea\u5df1\u7684\u9700\u6c42\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u6709\u4e00\u4e2a 2 \u7ef4\u6570\u7ec4\uff0c\u5176\u4e2d\u5305\u542b\u6bcf\u4e2a\u6a21\u578b\u7684\u6982\u7387\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u51fd\u6570\u4f1a\u6709\u4e00\u4e9b\u53d8\u5316\u3002 \u53e6\u4e00\u79cd\u7ec4\u5408\u591a\u4e2a\u6a21\u578b\u7684\u65b9\u6cd5\u662f\u901a\u8fc7\u5b83\u4eec\u7684 \u6982\u7387\u6392\u5e8f \u3002\u5f53\u76f8\u5173\u6307\u6807\u662f\u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u65f6\uff0c\u8fd9\u79cd\u7ec4\u5408\u65b9\u5f0f\u975e\u5e38\u6709\u6548\uff0c\u56e0\u4e3a AUC \u5c31\u662f\u5bf9\u6837\u672c\u8fdb\u884c\u6392\u5e8f\u3002 def rank_mean ( probas ): # \u521b\u5efa\u7a7a\u5217\u8868ranked\u5b58\u50a8\u6bcf\u4e2a\u7c7b\u522b\u6982\u7387\u503c\u6392\u540d ranked = [] # \u904d\u5386\u6982\u7387\u503c\u6bcf\u4e00\u5217\uff08\u6bcf\u4e2a\u7c7b\u522b\u7684\u6982\u7387\u503c\uff09 for i in range ( probas . shape [ 1 ]): # \u5f53\u524d\u5217\u6982\u7387\u503c\u6392\u540d\uff0crank_data\u662f\u6392\u540d\u7ed3\u679c rank_data = stats . rankdata ( probas [:, i ]) # \u5c06\u5f53\u524d\u5217\u6392\u540d\u7ed3\u679c\u6dfb\u52a0\u5230ranked\u5217\u8868\u4e2d ranked . append ( rank_data ) # \u5c06ranked\u5217\u8868\u4e2d\u6392\u540d\u7ed3\u679c\u6309\u5217\u5806\u53e0\uff0c\u5f62\u6210\u4e8c\u7ef4\u6570\u7ec4 ranked = np . column_stack ( ranked ) # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u8ba1\u7b97\u6837\u672c\u6392\u540d\u5e73\u5747\u503c return np . mean ( ranked , axis = 1 ) \u8bf7\u6ce8\u610f\uff0c\u5728 scipy \u7684 rankdata \u4e2d\uff0c\u7b49\u7ea7\u4ece 1 \u5f00\u59cb\u3002 \u4e3a\u4ec0\u4e48\u8fd9\u7c7b\u96c6\u5408\u6709\u6548\uff1f\u8ba9\u6211\u4eec\u770b\u770b\u56fe 1\u3002 \u56fe 1\uff1a\u4e09\u4eba\u731c\u5927\u8c61\u7684\u8eab\u9ad8 \u56fe 1 \u663e\u793a\uff0c\u5982\u679c\u6709\u4e09\u4e2a\u4eba\u5728\u731c\u5927\u8c61\u7684\u9ad8\u5ea6\uff0c\u90a3\u4e48\u539f\u59cb\u9ad8\u5ea6\u5c06\u975e\u5e38\u63a5\u8fd1\u4e09\u4e2a\u4eba\u731c\u6d4b\u7684\u5e73\u5747\u503c\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u4eba\u90fd\u80fd\u731c\u5230\u975e\u5e38\u63a5\u8fd1\u5927\u8c61\u539f\u6765\u7684\u9ad8\u5ea6\u3002\u63a5\u8fd1\u4f30\u8ba1\u503c\u610f\u5473\u7740\u8bef\u5dee\uff0c\u4f46\u5982\u679c\u6211\u4eec\u5c06\u4e09\u4e2a\u9884\u6d4b\u503c\u5e73\u5747\uff0c\u5c31\u80fd\u5c06\u8bef\u5dee\u964d\u5230\u6700\u4f4e\u3002\u8fd9\u5c31\u662f\u591a\u4e2a\u6a21\u578b\u5e73\u5747\u7684\u4e3b\u8981\u601d\u60f3\u3002 $$ Final\\ Probabilities = w_1 \\times M1_proba + w_2 \\times M2_proba + \\cdots + w_n \\times Mn_proba $$ \u5176\u4e2d \\((w_1 + w_2 + w_3 + \\cdots + w_n)=1.0\\) \u4f8b\u5982\uff0c\u5982\u679c\u4f60\u6709\u4e00\u4e2a AUC \u975e\u5e38\u9ad8\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u548c\u4e00\u4e2a AUC \u7a0d\u4f4e\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u4f60\u53ef\u4ee5\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0c\u968f\u673a\u68ee\u6797\u6a21\u578b\u5360 70%\uff0c\u903b\u8f91\u56de\u5f52\u6a21\u578b\u5360 30%\u3002\u90a3\u4e48\uff0c\u6211\u662f\u5982\u4f55\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u7684\u5462\uff1f\u8ba9\u6211\u4eec\u518d\u6dfb\u52a0\u4e00\u4e2a\u6a21\u578b\uff0c\u5047\u8bbe\u73b0\u5728\u6211\u4eec\u4e5f\u6709\u4e00\u4e2a xgboost \u6a21\u578b\uff0c\u5b83\u7684 AUC \u6bd4\u968f\u673a\u68ee\u6797\u9ad8\u3002\u73b0\u5728\uff0c\u6211\u5c06\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0cxgboost\uff1a\u968f\u673a\u68ee\u6797\uff1a\u903b\u8f91\u56de\u5f52\u7684\u6bd4\u4f8b\u4e3a 3:2:1\u3002\u5f88\u7b80\u5355\u5427\uff1f\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u6613\u5982\u53cd\u638c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u662f\u5982\u4f55\u505a\u5230\u7684\u3002 \u5047\u5b9a\u6211\u4eec\u6709\u4e09\u53ea\u7334\u5b50\uff0c\u4e09\u53ea\u65cb\u94ae\u7684\u6570\u503c\u5728 0 \u548c 1 \u4e4b\u95f4\u3002\u8fd9\u4e9b\u7334\u5b50\u8f6c\u52a8\u65cb\u94ae\uff0c\u6211\u4eec\u8ba1\u7b97\u5b83\u4eec\u6bcf\u8f6c\u5230\u4e00\u4e2a\u6570\u503c\u65f6\u7684 AUC \u5206\u6570\u3002\u6700\u7ec8\uff0c\u7334\u5b50\u4eec\u4f1a\u627e\u5230\u4e00\u4e2a\u80fd\u7ed9\u51fa\u6700\u4f73 AUC \u7684\u7ec4\u5408\u3002\u6ca1\u9519\uff0c\u8fd9\u5c31\u662f\u968f\u673a\u641c\u7d22\uff01\u5728\u8fdb\u884c\u8fd9\u7c7b\u641c\u7d22\u4e4b\u524d\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\u4e24\u4e2a\u6700\u91cd\u8981\u7684\u7ec4\u5408\u89c4\u5219\u3002 \u7ec4\u5408\u7684\u7b2c\u4e00\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u7ec4\u5408\u7684\u7b2c\u4e8c\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u662f\u7684\u3002\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u6761\u89c4\u5219\u3002\u7b2c\u4e00\u6b65\u662f\u521b\u5efa\u6298\u53e0\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u5047\u8bbe\u6211\u4eec\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\uff1a\u6298\u53e0 1 \u548c\u6298\u53e0 2\u3002\u8bf7\u6ce8\u610f\uff0c\u8fd9\u6837\u505a\u53ea\u662f\u4e3a\u4e86\u7b80\u5316\u89e3\u91ca\u3002\u5728\u5b9e\u9645\u5e94\u7528\u4e2d\uff0c\u60a8\u5e94\u8be5\u521b\u5efa\u66f4\u591a\u7684\u6298\u53e0\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5728\u6298\u53e0 1 \u4e0a\u8bad\u7ec3\u968f\u673a\u68ee\u6797\u6a21\u578b\u3001\u903b\u8f91\u56de\u5f52\u6a21\u578b\u548c xgboost \u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 2 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u4e4b\u540e\uff0c\u6211\u4eec\u5728\u6298\u53e0 2 \u4e0a\u4ece\u5934\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 1 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u4e3a\u6240\u6709\u8bad\u7ec3\u6570\u636e\u521b\u5efa\u4e86\u9884\u6d4b\u7ed3\u679c\u3002\u73b0\u5728\uff0c\u4e3a\u4e86\u5408\u5e76\u8fd9\u4e9b\u6a21\u578b\uff0c\u6211\u4eec\u5c06\u6298\u53e0 1 \u548c\u6298\u53e0 1 \u7684\u6240\u6709\u9884\u6d4b\u6570\u636e\u5408\u5e76\u5728\u4e00\u8d77\uff0c\u7136\u540e\u521b\u5efa\u4e00\u4e2a\u4f18\u5316\u51fd\u6570\uff0c\u8bd5\u56fe\u627e\u5230\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4fbf\u9488\u5bf9\u6298\u53e0 2 \u7684\u76ee\u6807\u6700\u5c0f\u5316\u8bef\u5dee\u6216\u6700\u5927\u5316 AUC\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u662f\u7528\u4e09\u4e2a\u6a21\u578b\u7684\u9884\u6d4b\u6982\u7387\u5728\u6298\u53e0 1 \u4e0a\u8bad\u7ec3\u4e00\u4e2a\u4f18\u5316\u6a21\u578b\uff0c\u7136\u540e\u5728\u6298\u53e0 2 \u4e0a\u5bf9\u5176\u8fdb\u884c\u8bc4\u4f30\u3002\u8ba9\u6211\u4eec\u5148\u6765\u770b\u770b\u6211\u4eec\u53ef\u4ee5\u7528\u6765\u627e\u5230\u591a\u4e2a\u6a21\u578b\u7684\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4f18\u5316 AUC\uff08\u6216\u4efb\u4f55\u7c7b\u578b\u7684\u9884\u6d4b\u6307\u6807\u7ec4\u5408\uff09\u7684\u7c7b\u3002 import numpy as np from functools import partial from scipy.optimize import fmin from sklearn import metrics class OptimizeAUC : def __init__ ( self ): # \u521d\u59cb\u5316\u7cfb\u6570 self . coef_ = 0 def _auc ( self , coef , X , y ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u7cfb\u6570 x_coef = X * coef # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8ba1\u7b97AUC\u5206\u6570 auc_score = metrics . roc_auc_score ( y , predictions ) # \u8fd4\u56de\u8d1fAUC\u4ee5\u4fbf\u6700\u5c0f\u5316 return - 1.0 * auc_score def fit ( self , X , y ): # \u521b\u5efa\u5e26\u6709\u90e8\u5206\u53c2\u6570\u7684\u76ee\u6807\u51fd\u6570 loss_partial = partial ( self . _auc , X = X , y = y ) # \u521d\u59cb\u5316\u7cfb\u6570 initial_coef = np . random . dirichlet ( np . ones ( X . shape [ 1 ]), size = 1 ) # \u4f7f\u7528fmin\u51fd\u6570\u4f18\u5316AUC\u76ee\u6807\u51fd\u6570\uff0c\u627e\u5230\u6700\u4f18\u7cfb\u6570 self . coef_ = fmin ( loss_partial , initial_coef , disp = True ) def predict ( self , X ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u8bad\u7ec3\u597d\u7684\u7cfb\u6570 x_coef = X * self . coef_ # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8fd4\u56de\u9884\u6d4b\u7ed3\u679c return predictions \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528\u5b83\uff0c\u5e76\u5c06\u5176\u4e0e\u7b80\u5355\u5e73\u5747\u6cd5\u8fdb\u884c\u6bd4\u8f83\u3002 import xgboost as xgb from sklearn.datasets import make_classification from sklearn import ensemble from sklearn import linear_model from sklearn import metrics from sklearn import model_selection # \u751f\u6210\u4e00\u4e2a\u5206\u7c7b\u6570\u636e\u96c6 X , y = make_classification ( n_samples = 10000 , n_features = 25 ) # \u5212\u5206\u6570\u636e\u96c6\u4e3a\u4e24\u4e2a\u4ea4\u53c9\u9a8c\u8bc1\u6298\u53e0 xfold1 , xfold2 , yfold1 , yfold2 = model_selection . train_test_split ( X , y , test_size = 0.5 , stratify = y ) # \u521d\u59cb\u5316\u4e09\u4e2a\u4e0d\u540c\u7684\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold1 , yfold1 ) rf . fit ( xfold1 , yfold1 ) xgbc . fit ( xfold1 , yfold1 ) # \u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold2 )[:, 1 ] pred_rf = rf . predict_proba ( xfold2 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold2 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold2_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold2 = [] for i in range ( fold2_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold2 , fold2_preds [:, i ]) aucs_fold2 . append ( auc ) print ( f \"Fold-2: LR AUC = { aucs_fold2 [ 0 ] } \" ) print ( f \"Fold-2: RF AUC = { aucs_fold2 [ 1 ] } \" ) print ( f \"Fold-2: XGB AUC = { aucs_fold2 [ 2 ] } \" ) print ( f \"Fold-2: Average Pred AUC = { aucs_fold2 [ 3 ] } \" ) # \u91cd\u65b0\u521d\u59cb\u5316\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold2 , yfold2 ) rf . fit ( xfold2 , yfold2 ) xgbc . fit ( xfold2 , yfold2 ) # \u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold1 )[:, 1 ] pred_rf = rf . predict_proba ( xfold1 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold1 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold1_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold1 = [] for i in range ( fold1_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold1 , fold1_preds [:, i ]) aucs_fold1 . append ( auc ) print ( f \"Fold-1: LR AUC = { aucs_fold1 [ 0 ] } \" ) print ( f \"Fold-1: RF AUC = { aucs_fold1 [ 1 ] } \" ) print ( f \"Fold-1: XGB AUC = { aucs_fold1 [ 2 ] } \" ) print ( f \"Fold-1: Average prediction AUC = { aucs_fold1 [ 3 ] } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765\u8bad\u7ec3\u4f18\u5316\u5668 opt . fit ( fold1_preds [:, : - 1 ], yfold1 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold2 = opt . predict ( fold2_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold2 , opt_preds_fold2 ) print ( f \"Optimized AUC, Fold 2 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765 opt . fit ( fold2_preds [:, : - 1 ], yfold2 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold1 = opt . predict ( fold1_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold1 , opt_preds_fold1 ) print ( f \"Optimized AUC, Fold 1 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) \u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u8f93\u51fa\uff1a \u276f python auc_opt . py Fold - 2 : LR AUC = 0.9145446769443348 Fold - 2 : RF AUC = 0.9269918948683287 Fold - 2 : XGB AUC = 0.9302436595508696 Fold - 2 : Average Pred AUC = 0.927701495890154 Fold - 1 : LR AUC = 0.9050872233256017 Fold - 1 : RF AUC = 0.9179382818311258 Fold - 1 : XGB AUC = 0.9195837242005629 Fold - 1 : Average prediction AUC = 0.9189669233123695 Optimization terminated successfully . Current function value : - 0.920643 Iterations : 50 Function evaluations : 109 Optimized AUC , Fold 2 = 0.9305386199756128 Coefficients = [ - 0.00188194 0.19328336 0.35891836 ] Optimization terminated successfully . Current function value : - 0.931232 Iterations : 56 Function evaluations : 113 Optimized AUC , Fold 1 = 0.9192523637234037 Coefficients = [ - 0.15655124 0.22393151 0.58711366 ] \u6211\u4eec\u770b\u5230\uff0c\u5e73\u5747\u503c\u66f4\u597d\uff0c\u4f46\u4f7f\u7528\u4f18\u5316\u5668\u627e\u5230\u9608\u503c\u66f4\u597d\uff01\u6709\u65f6\uff0c\u5e73\u5747\u503c\u662f\u6700\u597d\u7684\u9009\u62e9\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u7cfb\u6570\u52a0\u8d77\u6765\u5e76\u6ca1\u6709\u8fbe\u5230 1.0\uff0c\u4f46\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u6211\u4eec\u8981\u5904\u7406\u7684\u662f AUC\uff0c\u800c AUC \u53ea\u5173\u5fc3\u7b49\u7ea7\u3002 \u5373\u4f7f\u968f\u673a\u68ee\u6797\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u968f\u673a\u68ee\u6797\u53ea\u662f\u8bb8\u591a\u7b80\u5355\u51b3\u7b56\u6811\u7684\u7ec4\u5408\u3002\u968f\u673a\u68ee\u6797\u5c5e\u4e8e\u96c6\u5408\u6a21\u578b\u7684\u4e00\u79cd\uff0c\u4e5f\u5c31\u662f\u4fd7\u79f0\u7684 \"bagging\" \u3002\u5728\u888b\u96c6\u6a21\u578b\u4e2d\uff0c\u6211\u4eec\u521b\u5efa\u5c0f\u6570\u636e\u5b50\u96c6\u5e76\u8bad\u7ec3\u591a\u4e2a\u7b80\u5355\u6a21\u578b\u3002\u6700\u7ec8\u7ed3\u679c\u7531\u6240\u6709\u8fd9\u4e9b\u5c0f\u6a21\u578b\u7684\u9884\u6d4b\u7ed3\u679c\uff08\u5982\u5e73\u5747\u503c\uff09\u7ec4\u5408\u800c\u6210\u3002 \u6211\u4eec\u4f7f\u7528\u7684 xgboost \u6a21\u578b\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u6240\u6709\u68af\u5ea6\u63d0\u5347\u6a21\u578b\u90fd\u662f\u96c6\u5408\u6a21\u578b\uff0c\u7edf\u79f0\u4e3a \u63d0\u5347\u6a21\u578b\uff08boosting models\uff09 \u3002\u63d0\u5347\u6a21\u578b\u7684\u5de5\u4f5c\u539f\u7406\u4e0e\u88c5\u888b\u6a21\u578b\u7c7b\u4f3c\uff0c\u4e0d\u540c\u4e4b\u5904\u5728\u4e8e\u63d0\u5347\u6a21\u578b\u4e2d\u7684\u8fde\u7eed\u6a21\u578b\u662f\u6839\u636e\u8bef\u5dee\u6b8b\u5dee\u8bad\u7ec3\u7684\uff0c\u5e76\u503e\u5411\u4e8e\u6700\u5c0f\u5316\u524d\u9762\u6a21\u578b\u7684\u8bef\u5dee\u3002\u8fd9\u6837\uff0c\u63d0\u5347\u6a21\u578b\u5c31\u80fd\u5b8c\u7f8e\u5730\u5b66\u4e60\u6570\u636e\uff0c\u56e0\u6b64\u5bb9\u6613\u51fa\u73b0\u8fc7\u62df\u5408\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u770b\u5230\u7684\u4ee3\u7801\u7247\u6bb5\u53ea\u8003\u8651\u4e86\u4e00\u5217\u3002\u4f46\u60c5\u51b5\u5e76\u975e\u603b\u662f\u5982\u6b64\uff0c\u5f88\u591a\u65f6\u5019\u60a8\u9700\u8981\u5904\u7406\u591a\u5217\u9884\u6d4b\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u4f1a\u9047\u5230\u4ece\u591a\u4e2a\u7c7b\u522b\u4e2d\u9884\u6d4b\u4e00\u4e2a\u7c7b\u522b\u7684\u95ee\u9898\uff0c\u5373\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5bf9\u4e8e\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\uff0c\u4f60\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u9009\u62e9\u6295\u7968\u65b9\u6cd5\u3002\u4f46\u6295\u7968\u6cd5\u5e76\u4e0d\u603b\u662f\u6700\u4f73\u65b9\u6cd5\u3002\u5982\u679c\u8981\u7ec4\u5408\u6982\u7387\uff0c\u5c31\u4f1a\u6709\u4e00\u4e2a\u4e8c\u7ef4\u6570\u7ec4\uff0c\u800c\u4e0d\u662f\u50cf\u6211\u4eec\u4e4b\u524d\u4f18\u5316 AUC \u65f6\u7684\u5411\u91cf\u3002\u5982\u679c\u6709\u591a\u4e2a\u7c7b\u522b\uff0c\u53ef\u4ee5\u5c1d\u8bd5\u4f18\u5316\u5bf9\u6570\u635f\u5931\uff08\u6216\u5176\u4ed6\u4e0e\u4e1a\u52a1\u76f8\u5173\u7684\u6307\u6807\uff09\u3002 \u8981\u8fdb\u884c\u7ec4\u5408\uff0c\u53ef\u4ee5\u5728\u62df\u5408\u51fd\u6570 (X) \u4e2d\u4f7f\u7528 numpy \u6570\u7ec4\u5217\u8868\u800c\u4e0d\u662f numpy \u6570\u7ec4\uff0c\u968f\u540e\u8fd8\u9700\u8981\u66f4\u6539\u4f18\u5316\u5668\u548c\u9884\u6d4b\u51fd\u6570\u3002\u6211\u5c31\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u5927\u5bb6\u5427\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8fdb\u5165\u4e0b\u4e00\u4e2a\u6709\u8da3\u7684\u8bdd\u9898\uff0c\u8fd9\u4e2a\u8bdd\u9898\u76f8\u5f53\u6d41\u884c\uff0c\u88ab\u79f0\u4e3a \u5806\u53e0 \u3002\u56fe 2 \u5c55\u793a\u4e86\u5982\u4f55\u5806\u53e0\u6a21\u578b\u3002 \u56fe2 : Stacking \u5806\u53e0\u4e0d\u50cf\u5236\u9020\u706b\u7bad\u3002\u5b83\u7b80\u5355\u660e\u4e86\u3002\u5982\u679c\u60a8\u8fdb\u884c\u4e86\u6b63\u786e\u7684\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5e76\u5728\u6574\u4e2a\u5efa\u6a21\u8fc7\u7a0b\u4e2d\u4fdd\u6301\u6298\u53e0\u4e0d\u53d8\uff0c\u90a3\u4e48\u5c31\u4e0d\u4f1a\u51fa\u73b0\u4efb\u4f55\u8fc7\u5ea6\u8d34\u5408\u7684\u60c5\u51b5\u3002 \u8ba9\u6211\u7528\u7b80\u5355\u7684\u8981\u70b9\u5411\u4f60\u63cf\u8ff0\u4e00\u4e0b\u8fd9\u4e2a\u60f3\u6cd5\u3002 - \u5c06\u8bad\u7ec3\u6570\u636e\u5206\u6210\u82e5\u5e72\u6298\u53e0\u3002 - \u8bad\u7ec3\u4e00\u5806\u6a21\u578b\uff1a M1\u3001M2.....Mn\u3002 - \u521b\u5efa\u5b8c\u6574\u7684\u8bad\u7ec3\u9884\u6d4b\uff08\u4f7f\u7528\u975e\u6298\u53e0\u8bad\u7ec3\uff09\uff0c\u5e76\u4f7f\u7528\u6240\u6709\u8fd9\u4e9b\u6a21\u578b\u8fdb\u884c\u6d4b\u8bd5\u9884\u6d4b\u3002 - \u76f4\u5230\u8fd9\u91cc\u662f\u7b2c 1 \u5c42 (L1)\u3002 - \u5c06\u8fd9\u4e9b\u6a21\u578b\u7684\u6298\u53e0\u9884\u6d4b\u4f5c\u4e3a\u53e6\u4e00\u4e2a\u6a21\u578b\u7684\u7279\u5f81\u3002\u8fd9\u5c31\u662f\u4e8c\u7ea7\u6a21\u578b\uff08L2\uff09\u3002 - \u4f7f\u7528\u4e0e\u4e4b\u524d\u76f8\u540c\u7684\u6298\u53e0\u6765\u8bad\u7ec3\u8fd9\u4e2a L2 \u6a21\u578b\u3002 - \u73b0\u5728\uff0c\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u521b\u5efa OOF\uff08\u6298\u53e0\u5916\uff09\u9884\u6d4b\u3002 - \u73b0\u5728\u60a8\u5c31\u6709\u4e86\u8bad\u7ec3\u6570\u636e\u7684 L2 \u9884\u6d4b\u548c\u6700\u7ec8\u6d4b\u8bd5\u96c6\u9884\u6d4b\u3002 \u60a8\u53ef\u4ee5\u4e0d\u65ad\u91cd\u590d L1 \u90e8\u5206\uff0c\u4e5f\u53ef\u4ee5\u521b\u5efa\u4efb\u610f\u591a\u7684\u5c42\u6b21\u3002 \u6709\u65f6\uff0c\u4f60\u8fd8\u4f1a\u9047\u5230\u4e00\u4e2a\u53eb\u6df7\u5408\u7684\u672f\u8bed blending \u3002\u5982\u679c\u4f60\u9047\u5230\u4e86\uff0c\u4e0d\u7528\u592a\u62c5\u5fc3\u3002\u5b83\u53ea\u4e0d\u8fc7\u662f\u7528\u4e00\u4e2a\u4fdd\u7559\u7ec4\u6765\u5806\u53e0\uff0c\u800c\u4e0d\u662f\u591a\u91cd\u6298\u53e0\u3002\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u6211\u5728\u672c\u7ae0\u4e2d\u6240\u63cf\u8ff0\u7684\u5185\u5bb9\u53ef\u4ee5\u5e94\u7528\u4e8e\u4efb\u4f55\u7c7b\u578b\u7684\u95ee\u9898\uff1a\u5206\u7c7b\u3001\u56de\u5f52\u3001\u591a\u6807\u7b7e\u5206\u7c7b\u7b49\u3002","title":"\u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5"},{"location":"%E7%BB%84%E5%90%88%E5%92%8C%E5%A0%86%E5%8F%A0%E6%96%B9%E6%B3%95/#_1","text":"\u542c\u5230\u4e0a\u9762\u4e24\u4e2a\u8bcd\uff0c\u6211\u4eec\u9996\u5148\u60f3\u5230\u7684\u5c31\u662f\u5728\u7ebf\uff08online\uff09/\u79bb\u7ebf\uff08offline\uff09\u673a\u5668\u5b66\u4e60\u7ade\u8d5b\u3002\u51e0\u5e74\u524d\u662f\u8fd9\u6837\uff0c\u4f46\u73b0\u5728\u968f\u7740\u8ba1\u7b97\u80fd\u529b\u7684\u8fdb\u6b65\u548c\u865a\u62df\u5b9e\u4f8b\u7684\u5ec9\u4ef7\uff0c\u4eba\u4eec\u751a\u81f3\u5f00\u59cb\u5728\u884c\u4e1a\u4e2d\u4f7f\u7528\u7ec4\u5408\u6a21\u578b\uff08ensemble models\uff09\u3002\u4f8b\u5982\uff0c\u90e8\u7f72\u591a\u4e2a\u795e\u7ecf\u7f51\u7edc\u5e76\u5b9e\u65f6\u4e3a\u5b83\u4eec\u63d0\u4f9b\u670d\u52a1\u975e\u5e38\u5bb9\u6613\uff0c\u54cd\u5e94\u65f6\u95f4\u5c0f\u4e8e 500 \u6beb\u79d2\u3002\u6709\u65f6\uff0c\u4e00\u4e2a\u5e9e\u5927\u7684\u795e\u7ecf\u7f51\u7edc\u6216\u5927\u578b\u6a21\u578b\u4e5f\u53ef\u4ee5\u88ab\u5176\u4ed6\u51e0\u4e2a\u6a21\u578b\u53d6\u4ee3\uff0c\u8fd9\u4e9b\u6a21\u578b\u4f53\u79ef\u5c0f\uff0c\u6027\u80fd\u4e0e\u5927\u578b\u6a21\u578b\u76f8\u4f3c\uff0c\u901f\u5ea6\u5374\u5feb\u4e00\u500d\u3002\u5982\u679c\u662f\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff08\u4e9b\uff09\u6a21\u578b\u5462\uff1f\u6211\u4e2a\u4eba\u66f4\u503e\u5411\u4e8e\u9009\u62e9\u591a\u4e2a\u5c0f\u673a\u578b\uff0c\u5b83\u4eec\u901f\u5ea6\u66f4\u5feb\uff0c\u6027\u80fd\u4e0e\u5927\u673a\u578b\u548c\u6162\u673a\u578b\u76f8\u540c\u3002\u8bf7\u8bb0\u4f4f\uff0c\u8f83\u5c0f\u7684\u578b\u53f7\u4e5f\u66f4\u5bb9\u6613\u548c\u66f4\u5feb\u5730\u8fdb\u884c\u8c03\u6574\u3002 \u7ec4\u5408\uff08ensembling\uff09\u4e0d\u8fc7\u662f\u4e0d\u540c\u6a21\u578b\u7684\u7ec4\u5408\u3002\u6a21\u578b\u53ef\u4ee5\u901a\u8fc7\u9884\u6d4b/\u6982\u7387\u8fdb\u884c\u7ec4\u5408\u3002\u7ec4\u5408\u6a21\u578b\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6c42\u5e73\u5747\u503c\u3002 $$ Ensemble Probabilities = (M1_proba + M2_proba + ... + Mn_Proba)/n $$ \u8fd9\u662f\u6700\u7b80\u5355\u4e5f\u662f\u6700\u6709\u6548\u7684\u7ec4\u5408\u6a21\u578b\u7684\u65b9\u6cd5\u3002\u5728\u7b80\u5355\u5e73\u5747\u6cd5\u4e2d\uff0c\u6240\u6709\u6a21\u578b\u7684\u6743\u91cd\u90fd\u662f\u76f8\u7b49\u7684\u3002\u65e0\u8bba\u91c7\u7528\u54ea\u79cd\u7ec4\u5408\u65b9\u6cd5\uff0c\u60a8\u90fd\u5e94\u8be5\u7262\u8bb0\u4e00\u70b9\uff0c\u90a3\u5c31\u662f\u60a8\u5e94\u8be5\u59cb\u7ec8\u5c06\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b/\u6982\u7387\u7ec4\u5408\u5728\u4e00\u8d77\u3002\u7b80\u5355\u5730\u8bf4\uff0c\u7ec4\u5408\u76f8\u5173\u6027\u4e0d\u9ad8\u7684\u6a21\u578b\u6bd4\u7ec4\u5408\u76f8\u5173\u6027\u5f88\u9ad8\u7684\u6a21\u578b\u6548\u679c\u66f4\u597d\u3002 \u5982\u679c\u6ca1\u6709\u6982\u7387\uff0c\u4e5f\u53ef\u4ee5\u7ec4\u5408\u9884\u6d4b\u3002\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6295\u7968\u3002\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u8fdb\u884c\u591a\u7c7b\u5206\u7c7b\uff0c\u6709\u4e09\u4e2a\u7c7b\u522b\uff1a 0\u30011 \u548c 2\u3002 [0, 0, 1] : \u6700\u9ad8\u7968\u6570\uff1a 0 [0, 1, 2] : \u6700\u9ad8\u7968\u7ea7\uff1a \u65e0\uff08\u968f\u673a\u9009\u62e9\u4e00\u4e2a\uff09 [2, 2, 2] : \u6700\u9ad8\u7968\u6570\uff1a 2 \u4ee5\u4e0b\u7b80\u5355\u51fd\u6570\u53ef\u4ee5\u5b8c\u6210\u8fd9\u4e9b\u7b80\u5355\u64cd\u4f5c\u3002 import numpy as np def mean_predictions ( probas ): # \u8ba1\u7b97\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u6bcf\u884c\u5e73\u5747\u503c return np . mean ( probas , axis = 1 ) def max_voting ( preds ): # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u67e5\u627e\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u7684\u7d22\u5f15 idxs = np . argmax ( preds , axis = 1 ) # \u6839\u636e\u7d22\u5f15\u53d6\u51fa\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u5143\u7d20 return np . take_along_axis ( preds , idxs [:, None ], axis = 1 ) \u8bf7\u6ce8\u610f\uff0cprobas \u7684\u6bcf\u4e00\u5217\u90fd\u53ea\u6709\u4e00\u4e2a\u6982\u7387\uff08\u5373\u4e8c\u5143\u5206\u7c7b\uff0c\u901a\u5e38\u4e3a\u7c7b\u522b 1\uff09\u3002\u56e0\u6b64\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u4e00\u4e2a\u65b0\u6a21\u578b\u3002\u540c\u6837\uff0c\u5bf9\u4e8e preds\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u6765\u81ea\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b\u503c\u3002\u8fd9\u4e24\u4e2a\u51fd\u6570\u90fd\u5047\u8bbe\u4e86\u4e00\u4e2a 2 \u7ef4 numpy \u6570\u7ec4\u3002\u60a8\u53ef\u4ee5\u6839\u636e\u81ea\u5df1\u7684\u9700\u6c42\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u6709\u4e00\u4e2a 2 \u7ef4\u6570\u7ec4\uff0c\u5176\u4e2d\u5305\u542b\u6bcf\u4e2a\u6a21\u578b\u7684\u6982\u7387\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u51fd\u6570\u4f1a\u6709\u4e00\u4e9b\u53d8\u5316\u3002 \u53e6\u4e00\u79cd\u7ec4\u5408\u591a\u4e2a\u6a21\u578b\u7684\u65b9\u6cd5\u662f\u901a\u8fc7\u5b83\u4eec\u7684 \u6982\u7387\u6392\u5e8f \u3002\u5f53\u76f8\u5173\u6307\u6807\u662f\u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u65f6\uff0c\u8fd9\u79cd\u7ec4\u5408\u65b9\u5f0f\u975e\u5e38\u6709\u6548\uff0c\u56e0\u4e3a AUC \u5c31\u662f\u5bf9\u6837\u672c\u8fdb\u884c\u6392\u5e8f\u3002 def rank_mean ( probas ): # \u521b\u5efa\u7a7a\u5217\u8868ranked\u5b58\u50a8\u6bcf\u4e2a\u7c7b\u522b\u6982\u7387\u503c\u6392\u540d ranked = [] # \u904d\u5386\u6982\u7387\u503c\u6bcf\u4e00\u5217\uff08\u6bcf\u4e2a\u7c7b\u522b\u7684\u6982\u7387\u503c\uff09 for i in range ( probas . shape [ 1 ]): # \u5f53\u524d\u5217\u6982\u7387\u503c\u6392\u540d\uff0crank_data\u662f\u6392\u540d\u7ed3\u679c rank_data = stats . rankdata ( probas [:, i ]) # \u5c06\u5f53\u524d\u5217\u6392\u540d\u7ed3\u679c\u6dfb\u52a0\u5230ranked\u5217\u8868\u4e2d ranked . append ( rank_data ) # \u5c06ranked\u5217\u8868\u4e2d\u6392\u540d\u7ed3\u679c\u6309\u5217\u5806\u53e0\uff0c\u5f62\u6210\u4e8c\u7ef4\u6570\u7ec4 ranked = np . column_stack ( ranked ) # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u8ba1\u7b97\u6837\u672c\u6392\u540d\u5e73\u5747\u503c return np . mean ( ranked , axis = 1 ) \u8bf7\u6ce8\u610f\uff0c\u5728 scipy \u7684 rankdata \u4e2d\uff0c\u7b49\u7ea7\u4ece 1 \u5f00\u59cb\u3002 \u4e3a\u4ec0\u4e48\u8fd9\u7c7b\u96c6\u5408\u6709\u6548\uff1f\u8ba9\u6211\u4eec\u770b\u770b\u56fe 1\u3002 \u56fe 1\uff1a\u4e09\u4eba\u731c\u5927\u8c61\u7684\u8eab\u9ad8 \u56fe 1 \u663e\u793a\uff0c\u5982\u679c\u6709\u4e09\u4e2a\u4eba\u5728\u731c\u5927\u8c61\u7684\u9ad8\u5ea6\uff0c\u90a3\u4e48\u539f\u59cb\u9ad8\u5ea6\u5c06\u975e\u5e38\u63a5\u8fd1\u4e09\u4e2a\u4eba\u731c\u6d4b\u7684\u5e73\u5747\u503c\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u4eba\u90fd\u80fd\u731c\u5230\u975e\u5e38\u63a5\u8fd1\u5927\u8c61\u539f\u6765\u7684\u9ad8\u5ea6\u3002\u63a5\u8fd1\u4f30\u8ba1\u503c\u610f\u5473\u7740\u8bef\u5dee\uff0c\u4f46\u5982\u679c\u6211\u4eec\u5c06\u4e09\u4e2a\u9884\u6d4b\u503c\u5e73\u5747\uff0c\u5c31\u80fd\u5c06\u8bef\u5dee\u964d\u5230\u6700\u4f4e\u3002\u8fd9\u5c31\u662f\u591a\u4e2a\u6a21\u578b\u5e73\u5747\u7684\u4e3b\u8981\u601d\u60f3\u3002 $$ Final\\ Probabilities = w_1 \\times M1_proba + w_2 \\times M2_proba + \\cdots + w_n \\times Mn_proba $$ \u5176\u4e2d \\((w_1 + w_2 + w_3 + \\cdots + w_n)=1.0\\) \u4f8b\u5982\uff0c\u5982\u679c\u4f60\u6709\u4e00\u4e2a AUC \u975e\u5e38\u9ad8\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u548c\u4e00\u4e2a AUC \u7a0d\u4f4e\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u4f60\u53ef\u4ee5\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0c\u968f\u673a\u68ee\u6797\u6a21\u578b\u5360 70%\uff0c\u903b\u8f91\u56de\u5f52\u6a21\u578b\u5360 30%\u3002\u90a3\u4e48\uff0c\u6211\u662f\u5982\u4f55\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u7684\u5462\uff1f\u8ba9\u6211\u4eec\u518d\u6dfb\u52a0\u4e00\u4e2a\u6a21\u578b\uff0c\u5047\u8bbe\u73b0\u5728\u6211\u4eec\u4e5f\u6709\u4e00\u4e2a xgboost \u6a21\u578b\uff0c\u5b83\u7684 AUC \u6bd4\u968f\u673a\u68ee\u6797\u9ad8\u3002\u73b0\u5728\uff0c\u6211\u5c06\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0cxgboost\uff1a\u968f\u673a\u68ee\u6797\uff1a\u903b\u8f91\u56de\u5f52\u7684\u6bd4\u4f8b\u4e3a 3:2:1\u3002\u5f88\u7b80\u5355\u5427\uff1f\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u6613\u5982\u53cd\u638c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u662f\u5982\u4f55\u505a\u5230\u7684\u3002 \u5047\u5b9a\u6211\u4eec\u6709\u4e09\u53ea\u7334\u5b50\uff0c\u4e09\u53ea\u65cb\u94ae\u7684\u6570\u503c\u5728 0 \u548c 1 \u4e4b\u95f4\u3002\u8fd9\u4e9b\u7334\u5b50\u8f6c\u52a8\u65cb\u94ae\uff0c\u6211\u4eec\u8ba1\u7b97\u5b83\u4eec\u6bcf\u8f6c\u5230\u4e00\u4e2a\u6570\u503c\u65f6\u7684 AUC \u5206\u6570\u3002\u6700\u7ec8\uff0c\u7334\u5b50\u4eec\u4f1a\u627e\u5230\u4e00\u4e2a\u80fd\u7ed9\u51fa\u6700\u4f73 AUC \u7684\u7ec4\u5408\u3002\u6ca1\u9519\uff0c\u8fd9\u5c31\u662f\u968f\u673a\u641c\u7d22\uff01\u5728\u8fdb\u884c\u8fd9\u7c7b\u641c\u7d22\u4e4b\u524d\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\u4e24\u4e2a\u6700\u91cd\u8981\u7684\u7ec4\u5408\u89c4\u5219\u3002 \u7ec4\u5408\u7684\u7b2c\u4e00\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u7ec4\u5408\u7684\u7b2c\u4e8c\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u662f\u7684\u3002\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u6761\u89c4\u5219\u3002\u7b2c\u4e00\u6b65\u662f\u521b\u5efa\u6298\u53e0\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u5047\u8bbe\u6211\u4eec\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\uff1a\u6298\u53e0 1 \u548c\u6298\u53e0 2\u3002\u8bf7\u6ce8\u610f\uff0c\u8fd9\u6837\u505a\u53ea\u662f\u4e3a\u4e86\u7b80\u5316\u89e3\u91ca\u3002\u5728\u5b9e\u9645\u5e94\u7528\u4e2d\uff0c\u60a8\u5e94\u8be5\u521b\u5efa\u66f4\u591a\u7684\u6298\u53e0\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5728\u6298\u53e0 1 \u4e0a\u8bad\u7ec3\u968f\u673a\u68ee\u6797\u6a21\u578b\u3001\u903b\u8f91\u56de\u5f52\u6a21\u578b\u548c xgboost \u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 2 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u4e4b\u540e\uff0c\u6211\u4eec\u5728\u6298\u53e0 2 \u4e0a\u4ece\u5934\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 1 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u4e3a\u6240\u6709\u8bad\u7ec3\u6570\u636e\u521b\u5efa\u4e86\u9884\u6d4b\u7ed3\u679c\u3002\u73b0\u5728\uff0c\u4e3a\u4e86\u5408\u5e76\u8fd9\u4e9b\u6a21\u578b\uff0c\u6211\u4eec\u5c06\u6298\u53e0 1 \u548c\u6298\u53e0 1 \u7684\u6240\u6709\u9884\u6d4b\u6570\u636e\u5408\u5e76\u5728\u4e00\u8d77\uff0c\u7136\u540e\u521b\u5efa\u4e00\u4e2a\u4f18\u5316\u51fd\u6570\uff0c\u8bd5\u56fe\u627e\u5230\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4fbf\u9488\u5bf9\u6298\u53e0 2 \u7684\u76ee\u6807\u6700\u5c0f\u5316\u8bef\u5dee\u6216\u6700\u5927\u5316 AUC\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u662f\u7528\u4e09\u4e2a\u6a21\u578b\u7684\u9884\u6d4b\u6982\u7387\u5728\u6298\u53e0 1 \u4e0a\u8bad\u7ec3\u4e00\u4e2a\u4f18\u5316\u6a21\u578b\uff0c\u7136\u540e\u5728\u6298\u53e0 2 \u4e0a\u5bf9\u5176\u8fdb\u884c\u8bc4\u4f30\u3002\u8ba9\u6211\u4eec\u5148\u6765\u770b\u770b\u6211\u4eec\u53ef\u4ee5\u7528\u6765\u627e\u5230\u591a\u4e2a\u6a21\u578b\u7684\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4f18\u5316 AUC\uff08\u6216\u4efb\u4f55\u7c7b\u578b\u7684\u9884\u6d4b\u6307\u6807\u7ec4\u5408\uff09\u7684\u7c7b\u3002 import numpy as np from functools import partial from scipy.optimize import fmin from sklearn import metrics class OptimizeAUC : def __init__ ( self ): # \u521d\u59cb\u5316\u7cfb\u6570 self . coef_ = 0 def _auc ( self , coef , X , y ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u7cfb\u6570 x_coef = X * coef # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8ba1\u7b97AUC\u5206\u6570 auc_score = metrics . roc_auc_score ( y , predictions ) # \u8fd4\u56de\u8d1fAUC\u4ee5\u4fbf\u6700\u5c0f\u5316 return - 1.0 * auc_score def fit ( self , X , y ): # \u521b\u5efa\u5e26\u6709\u90e8\u5206\u53c2\u6570\u7684\u76ee\u6807\u51fd\u6570 loss_partial = partial ( self . _auc , X = X , y = y ) # \u521d\u59cb\u5316\u7cfb\u6570 initial_coef = np . random . dirichlet ( np . ones ( X . shape [ 1 ]), size = 1 ) # \u4f7f\u7528fmin\u51fd\u6570\u4f18\u5316AUC\u76ee\u6807\u51fd\u6570\uff0c\u627e\u5230\u6700\u4f18\u7cfb\u6570 self . coef_ = fmin ( loss_partial , initial_coef , disp = True ) def predict ( self , X ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u8bad\u7ec3\u597d\u7684\u7cfb\u6570 x_coef = X * self . coef_ # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8fd4\u56de\u9884\u6d4b\u7ed3\u679c return predictions \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528\u5b83\uff0c\u5e76\u5c06\u5176\u4e0e\u7b80\u5355\u5e73\u5747\u6cd5\u8fdb\u884c\u6bd4\u8f83\u3002 import xgboost as xgb from sklearn.datasets import make_classification from sklearn import ensemble from sklearn import linear_model from sklearn import metrics from sklearn import model_selection # \u751f\u6210\u4e00\u4e2a\u5206\u7c7b\u6570\u636e\u96c6 X , y = make_classification ( n_samples = 10000 , n_features = 25 ) # \u5212\u5206\u6570\u636e\u96c6\u4e3a\u4e24\u4e2a\u4ea4\u53c9\u9a8c\u8bc1\u6298\u53e0 xfold1 , xfold2 , yfold1 , yfold2 = model_selection . train_test_split ( X , y , test_size = 0.5 , stratify = y ) # \u521d\u59cb\u5316\u4e09\u4e2a\u4e0d\u540c\u7684\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold1 , yfold1 ) rf . fit ( xfold1 , yfold1 ) xgbc . fit ( xfold1 , yfold1 ) # \u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold2 )[:, 1 ] pred_rf = rf . predict_proba ( xfold2 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold2 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold2_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold2 = [] for i in range ( fold2_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold2 , fold2_preds [:, i ]) aucs_fold2 . append ( auc ) print ( f \"Fold-2: LR AUC = { aucs_fold2 [ 0 ] } \" ) print ( f \"Fold-2: RF AUC = { aucs_fold2 [ 1 ] } \" ) print ( f \"Fold-2: XGB AUC = { aucs_fold2 [ 2 ] } \" ) print ( f \"Fold-2: Average Pred AUC = { aucs_fold2 [ 3 ] } \" ) # \u91cd\u65b0\u521d\u59cb\u5316\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold2 , yfold2 ) rf . fit ( xfold2 , yfold2 ) xgbc . fit ( xfold2 , yfold2 ) # \u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold1 )[:, 1 ] pred_rf = rf . predict_proba ( xfold1 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold1 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold1_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold1 = [] for i in range ( fold1_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold1 , fold1_preds [:, i ]) aucs_fold1 . append ( auc ) print ( f \"Fold-1: LR AUC = { aucs_fold1 [ 0 ] } \" ) print ( f \"Fold-1: RF AUC = { aucs_fold1 [ 1 ] } \" ) print ( f \"Fold-1: XGB AUC = { aucs_fold1 [ 2 ] } \" ) print ( f \"Fold-1: Average prediction AUC = { aucs_fold1 [ 3 ] } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765\u8bad\u7ec3\u4f18\u5316\u5668 opt . fit ( fold1_preds [:, : - 1 ], yfold1 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold2 = opt . predict ( fold2_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold2 , opt_preds_fold2 ) print ( f \"Optimized AUC, Fold 2 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765 opt . fit ( fold2_preds [:, : - 1 ], yfold2 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold1 = opt . predict ( fold1_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold1 , opt_preds_fold1 ) print ( f \"Optimized AUC, Fold 1 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) \u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u8f93\u51fa\uff1a \u276f python auc_opt . py Fold - 2 : LR AUC = 0.9145446769443348 Fold - 2 : RF AUC = 0.9269918948683287 Fold - 2 : XGB AUC = 0.9302436595508696 Fold - 2 : Average Pred AUC = 0.927701495890154 Fold - 1 : LR AUC = 0.9050872233256017 Fold - 1 : RF AUC = 0.9179382818311258 Fold - 1 : XGB AUC = 0.9195837242005629 Fold - 1 : Average prediction AUC = 0.9189669233123695 Optimization terminated successfully . Current function value : - 0.920643 Iterations : 50 Function evaluations : 109 Optimized AUC , Fold 2 = 0.9305386199756128 Coefficients = [ - 0.00188194 0.19328336 0.35891836 ] Optimization terminated successfully . Current function value : - 0.931232 Iterations : 56 Function evaluations : 113 Optimized AUC , Fold 1 = 0.9192523637234037 Coefficients = [ - 0.15655124 0.22393151 0.58711366 ] \u6211\u4eec\u770b\u5230\uff0c\u5e73\u5747\u503c\u66f4\u597d\uff0c\u4f46\u4f7f\u7528\u4f18\u5316\u5668\u627e\u5230\u9608\u503c\u66f4\u597d\uff01\u6709\u65f6\uff0c\u5e73\u5747\u503c\u662f\u6700\u597d\u7684\u9009\u62e9\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u7cfb\u6570\u52a0\u8d77\u6765\u5e76\u6ca1\u6709\u8fbe\u5230 1.0\uff0c\u4f46\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u6211\u4eec\u8981\u5904\u7406\u7684\u662f AUC\uff0c\u800c AUC \u53ea\u5173\u5fc3\u7b49\u7ea7\u3002 \u5373\u4f7f\u968f\u673a\u68ee\u6797\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u968f\u673a\u68ee\u6797\u53ea\u662f\u8bb8\u591a\u7b80\u5355\u51b3\u7b56\u6811\u7684\u7ec4\u5408\u3002\u968f\u673a\u68ee\u6797\u5c5e\u4e8e\u96c6\u5408\u6a21\u578b\u7684\u4e00\u79cd\uff0c\u4e5f\u5c31\u662f\u4fd7\u79f0\u7684 \"bagging\" \u3002\u5728\u888b\u96c6\u6a21\u578b\u4e2d\uff0c\u6211\u4eec\u521b\u5efa\u5c0f\u6570\u636e\u5b50\u96c6\u5e76\u8bad\u7ec3\u591a\u4e2a\u7b80\u5355\u6a21\u578b\u3002\u6700\u7ec8\u7ed3\u679c\u7531\u6240\u6709\u8fd9\u4e9b\u5c0f\u6a21\u578b\u7684\u9884\u6d4b\u7ed3\u679c\uff08\u5982\u5e73\u5747\u503c\uff09\u7ec4\u5408\u800c\u6210\u3002 \u6211\u4eec\u4f7f\u7528\u7684 xgboost \u6a21\u578b\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u6240\u6709\u68af\u5ea6\u63d0\u5347\u6a21\u578b\u90fd\u662f\u96c6\u5408\u6a21\u578b\uff0c\u7edf\u79f0\u4e3a \u63d0\u5347\u6a21\u578b\uff08boosting models\uff09 \u3002\u63d0\u5347\u6a21\u578b\u7684\u5de5\u4f5c\u539f\u7406\u4e0e\u88c5\u888b\u6a21\u578b\u7c7b\u4f3c\uff0c\u4e0d\u540c\u4e4b\u5904\u5728\u4e8e\u63d0\u5347\u6a21\u578b\u4e2d\u7684\u8fde\u7eed\u6a21\u578b\u662f\u6839\u636e\u8bef\u5dee\u6b8b\u5dee\u8bad\u7ec3\u7684\uff0c\u5e76\u503e\u5411\u4e8e\u6700\u5c0f\u5316\u524d\u9762\u6a21\u578b\u7684\u8bef\u5dee\u3002\u8fd9\u6837\uff0c\u63d0\u5347\u6a21\u578b\u5c31\u80fd\u5b8c\u7f8e\u5730\u5b66\u4e60\u6570\u636e\uff0c\u56e0\u6b64\u5bb9\u6613\u51fa\u73b0\u8fc7\u62df\u5408\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u770b\u5230\u7684\u4ee3\u7801\u7247\u6bb5\u53ea\u8003\u8651\u4e86\u4e00\u5217\u3002\u4f46\u60c5\u51b5\u5e76\u975e\u603b\u662f\u5982\u6b64\uff0c\u5f88\u591a\u65f6\u5019\u60a8\u9700\u8981\u5904\u7406\u591a\u5217\u9884\u6d4b\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u4f1a\u9047\u5230\u4ece\u591a\u4e2a\u7c7b\u522b\u4e2d\u9884\u6d4b\u4e00\u4e2a\u7c7b\u522b\u7684\u95ee\u9898\uff0c\u5373\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5bf9\u4e8e\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\uff0c\u4f60\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u9009\u62e9\u6295\u7968\u65b9\u6cd5\u3002\u4f46\u6295\u7968\u6cd5\u5e76\u4e0d\u603b\u662f\u6700\u4f73\u65b9\u6cd5\u3002\u5982\u679c\u8981\u7ec4\u5408\u6982\u7387\uff0c\u5c31\u4f1a\u6709\u4e00\u4e2a\u4e8c\u7ef4\u6570\u7ec4\uff0c\u800c\u4e0d\u662f\u50cf\u6211\u4eec\u4e4b\u524d\u4f18\u5316 AUC \u65f6\u7684\u5411\u91cf\u3002\u5982\u679c\u6709\u591a\u4e2a\u7c7b\u522b\uff0c\u53ef\u4ee5\u5c1d\u8bd5\u4f18\u5316\u5bf9\u6570\u635f\u5931\uff08\u6216\u5176\u4ed6\u4e0e\u4e1a\u52a1\u76f8\u5173\u7684\u6307\u6807\uff09\u3002 \u8981\u8fdb\u884c\u7ec4\u5408\uff0c\u53ef\u4ee5\u5728\u62df\u5408\u51fd\u6570 (X) \u4e2d\u4f7f\u7528 numpy \u6570\u7ec4\u5217\u8868\u800c\u4e0d\u662f numpy \u6570\u7ec4\uff0c\u968f\u540e\u8fd8\u9700\u8981\u66f4\u6539\u4f18\u5316\u5668\u548c\u9884\u6d4b\u51fd\u6570\u3002\u6211\u5c31\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u5927\u5bb6\u5427\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8fdb\u5165\u4e0b\u4e00\u4e2a\u6709\u8da3\u7684\u8bdd\u9898\uff0c\u8fd9\u4e2a\u8bdd\u9898\u76f8\u5f53\u6d41\u884c\uff0c\u88ab\u79f0\u4e3a \u5806\u53e0 \u3002\u56fe 2 \u5c55\u793a\u4e86\u5982\u4f55\u5806\u53e0\u6a21\u578b\u3002 \u56fe2 : Stacking \u5806\u53e0\u4e0d\u50cf\u5236\u9020\u706b\u7bad\u3002\u5b83\u7b80\u5355\u660e\u4e86\u3002\u5982\u679c\u60a8\u8fdb\u884c\u4e86\u6b63\u786e\u7684\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5e76\u5728\u6574\u4e2a\u5efa\u6a21\u8fc7\u7a0b\u4e2d\u4fdd\u6301\u6298\u53e0\u4e0d\u53d8\uff0c\u90a3\u4e48\u5c31\u4e0d\u4f1a\u51fa\u73b0\u4efb\u4f55\u8fc7\u5ea6\u8d34\u5408\u7684\u60c5\u51b5\u3002 \u8ba9\u6211\u7528\u7b80\u5355\u7684\u8981\u70b9\u5411\u4f60\u63cf\u8ff0\u4e00\u4e0b\u8fd9\u4e2a\u60f3\u6cd5\u3002 - \u5c06\u8bad\u7ec3\u6570\u636e\u5206\u6210\u82e5\u5e72\u6298\u53e0\u3002 - \u8bad\u7ec3\u4e00\u5806\u6a21\u578b\uff1a M1\u3001M2.....Mn\u3002 - \u521b\u5efa\u5b8c\u6574\u7684\u8bad\u7ec3\u9884\u6d4b\uff08\u4f7f\u7528\u975e\u6298\u53e0\u8bad\u7ec3\uff09\uff0c\u5e76\u4f7f\u7528\u6240\u6709\u8fd9\u4e9b\u6a21\u578b\u8fdb\u884c\u6d4b\u8bd5\u9884\u6d4b\u3002 - \u76f4\u5230\u8fd9\u91cc\u662f\u7b2c 1 \u5c42 (L1)\u3002 - \u5c06\u8fd9\u4e9b\u6a21\u578b\u7684\u6298\u53e0\u9884\u6d4b\u4f5c\u4e3a\u53e6\u4e00\u4e2a\u6a21\u578b\u7684\u7279\u5f81\u3002\u8fd9\u5c31\u662f\u4e8c\u7ea7\u6a21\u578b\uff08L2\uff09\u3002 - \u4f7f\u7528\u4e0e\u4e4b\u524d\u76f8\u540c\u7684\u6298\u53e0\u6765\u8bad\u7ec3\u8fd9\u4e2a L2 \u6a21\u578b\u3002 - \u73b0\u5728\uff0c\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u521b\u5efa OOF\uff08\u6298\u53e0\u5916\uff09\u9884\u6d4b\u3002 - \u73b0\u5728\u60a8\u5c31\u6709\u4e86\u8bad\u7ec3\u6570\u636e\u7684 L2 \u9884\u6d4b\u548c\u6700\u7ec8\u6d4b\u8bd5\u96c6\u9884\u6d4b\u3002 \u60a8\u53ef\u4ee5\u4e0d\u65ad\u91cd\u590d L1 \u90e8\u5206\uff0c\u4e5f\u53ef\u4ee5\u521b\u5efa\u4efb\u610f\u591a\u7684\u5c42\u6b21\u3002 \u6709\u65f6\uff0c\u4f60\u8fd8\u4f1a\u9047\u5230\u4e00\u4e2a\u53eb\u6df7\u5408\u7684\u672f\u8bed blending \u3002\u5982\u679c\u4f60\u9047\u5230\u4e86\uff0c\u4e0d\u7528\u592a\u62c5\u5fc3\u3002\u5b83\u53ea\u4e0d\u8fc7\u662f\u7528\u4e00\u4e2a\u4fdd\u7559\u7ec4\u6765\u5806\u53e0\uff0c\u800c\u4e0d\u662f\u591a\u91cd\u6298\u53e0\u3002\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u6211\u5728\u672c\u7ae0\u4e2d\u6240\u63cf\u8ff0\u7684\u5185\u5bb9\u53ef\u4ee5\u5e94\u7528\u4e8e\u4efb\u4f55\u7c7b\u578b\u7684\u95ee\u9898\uff1a\u5206\u7c7b\u3001\u56de\u5f52\u3001\u591a\u6807\u7b7e\u5206\u7c7b\u7b49\u3002","title":"\u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5"},{"location":"%E7%BB%84%E7%BB%87%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E9%A1%B9%E7%9B%AE/","text":"\u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\u9879\u76ee \u7ec8\u4e8e\uff0c\u6211\u4eec\u53ef\u4ee5\u5f00\u59cb\u6784\u5efa\u7b2c\u4e00\u4e2a\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e86\u3002 \u662f\u8fd9\u6837\u5417\uff1f \u5728\u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u6ce8\u610f\u51e0\u4ef6\u4e8b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u6211\u4eec\u5c06\u5728\u96c6\u6210\u5f00\u53d1\u73af\u5883/\u6587\u672c\u7f16\u8f91\u5668\u4e2d\u5de5\u4f5c\uff0c\u800c\u4e0d\u662f\u5728 jupyter notebook\u4e2d\u3002\u4f60\u4e5f\u53ef\u4ee5\u5728 jupyter notebook\u4e2d\u5de5\u4f5c\uff0c\u8fd9\u5b8c\u5168\u53d6\u51b3\u4e8e\u4f60\u3002\u4e0d\u8fc7\uff0c\u6211\u5c06\u53ea\u4f7f\u7528 jupyter notebook\u6765\u63a2\u7d22\u6570\u636e\u3001\u7ed8\u5236\u56fe\u8868\u548c\u56fe\u5f62\u3002\u6211\u4eec\u5c06\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u6784\u5efa\u5206\u7c7b\u6846\u67b6\uff0c\u5373\u63d2\u5373\u7528\u3002\u60a8\u65e0\u9700\u5bf9\u4ee3\u7801\u505a\u592a\u591a\u6539\u52a8\u5c31\u80fd\u8bad\u7ec3\u6a21\u578b\uff0c\u800c\u4e14\u5f53\u60a8\u6539\u8fdb\u6a21\u578b\u65f6\uff0c\u8fd8\u80fd\u4f7f\u7528 git \u5bf9\u5176\u8fdb\u884c\u8ddf\u8e2a\u3002 \u6211\u4eec\u9996\u5148\u6765\u770b\u770b\u6587\u4ef6\u7684\u7ed3\u6784\u3002\u5bf9\u4e8e\u4f60\u6b63\u5728\u505a\u7684\u4efb\u4f55\u9879\u76ee\uff0c\u90fd\u8981\u521b\u5efa\u4e00\u4e2a\u65b0\u6587\u4ef6\u5939\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u5c06\u9879\u76ee\u547d\u540d\u4e3a \"project\"\u3002 \u9879\u76ee\u6587\u4ef6\u5939\u5185\u90e8\u5e94\u8be5\u5982\u4e0b\u6240\u793a\u3002 input train.csv test.csv src create_folds.py train.py inference.py models.py config.py model_dispatcher.py models model_rf.bin model_et.bin notebooks exploration.ipynb check_data.ipynb README.md LICENSE \u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e9b\u6587\u4ef6\u5939\u548c\u6587\u4ef6\u7684\u5185\u5bb9\u3002 input/ \uff1a\u8be5\u6587\u4ef6\u5939\u5305\u542b\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7684\u6240\u6709\u8f93\u5165\u6587\u4ef6\u548c\u6570\u636e\u3002\u5982\u679c\u60a8\u6b63\u5728\u5f00\u53d1 NLP \u9879\u76ee\uff0c\u60a8\u53ef\u4ee5\u5c06embeddings\u653e\u5728\u8fd9\u91cc\u3002\u5982\u679c\u662f\u56fe\u50cf\u9879\u76ee\uff0c\u6240\u6709\u56fe\u50cf\u90fd\u653e\u5728\u8be5\u6587\u4ef6\u5939\u4e0b\u7684\u5b50\u6587\u4ef6\u5939\u4e2d\u3002 src/ \uff1a\u6211\u4eec\u5c06\u5728\u8fd9\u91cc\u4fdd\u5b58\u4e0e\u9879\u76ee\u76f8\u5173\u7684\u6240\u6709 python \u811a\u672c\u3002\u5982\u679c\u6211\u8bf4\u7684\u662f\u4e00\u4e2a python \u811a\u672c\uff0c\u5373\u4efb\u4f55 *.py \u6587\u4ef6\uff0c\u5b83\u90fd\u5b58\u50a8\u5728 src \u6587\u4ef6\u5939\u4e2d\u3002 models/ \uff1a\u8be5\u6587\u4ef6\u5939\u4fdd\u5b58\u6240\u6709\u8bad\u7ec3\u8fc7\u7684\u6a21\u578b\u3002 notebook/ \uff1a\u6240\u6709 jupyter notebook\uff08\u5373\u4efb\u4f55 *.ipynb \u6587\u4ef6\uff09\u90fd\u5b58\u50a8\u5728\u7b14\u8bb0\u672c \u6587\u4ef6\u5939\u4e2d\u3002 README.md \uff1a\u8fd9\u662f\u4e00\u4e2a\u6807\u8bb0\u7b26\u6587\u4ef6\uff0c\u60a8\u53ef\u4ee5\u5728\u5176\u4e2d\u63cf\u8ff0\u60a8\u7684\u9879\u76ee\uff0c\u5e76\u5199\u660e\u5982\u4f55\u8bad\u7ec3\u6a21\u578b\u6216\u5728\u751f\u4ea7\u73af\u5883\u4e2d\u4f7f\u7528\u3002 LICENSE \uff1a\u8fd9\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6587\u672c\u6587\u4ef6\uff0c\u5305\u542b\u9879\u76ee\u7684\u8bb8\u53ef\u8bc1\uff0c\u5982 MIT\u3001Apache \u7b49\u3002\u5173\u4e8e\u8bb8\u53ef\u8bc1\u7684\u8be6\u7ec6\u4ecb\u7ecd\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 \u5047\u8bbe\u4f60\u6b63\u5728\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u5bf9 MNIST \u6570\u636e\u96c6\uff08\u51e0\u4e4e\u6bcf\u672c\u673a\u5668\u5b66\u4e60\u4e66\u7c4d\u90fd\u4f1a\u7528\u5230\u7684\u6570\u636e\u96c6\uff09\u8fdb\u884c\u5206\u7c7b\u3002\u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\uff0c\u6211\u4eec\u5728\u4ea4\u53c9\u68c0\u9a8c\u4e00\u7ae0\u4e2d\u4e5f\u63d0\u5230\u8fc7 MNIST \u6570\u636e\u96c6\u3002\u6240\u4ee5\uff0c\u6211\u5c31\u4e0d\u89e3\u91ca\u8fd9\u4e2a\u6570\u636e\u96c6\u662f\u4ec0\u4e48\u6837\u5b50\u4e86\u3002\u7f51\u4e0a\u6709\u8bb8\u591a\u4e0d\u540c\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6\uff0c\u4f46\u6211\u4eec\u5c06\u4f7f\u7528 CSV \u683c\u5f0f\u7684\u6570\u636e\u96c6\u3002 \u5728\u8fd9\u79cd\u683c\u5f0f\u7684\u6570\u636e\u96c6\u4e2d\uff0cCSV \u7684\u6bcf\u4e00\u884c\u90fd\u5305\u542b\u56fe\u50cf\u7684\u6807\u7b7e\u548c 784 \u4e2a\u50cf\u7d20\u503c\uff0c\u50cf\u7d20\u503c\u8303\u56f4\u4ece 0 \u5230 255\u3002\u6570\u636e\u96c6\u5305\u542b 60000 \u5f20\u8fd9\u79cd\u683c\u5f0f\u7684\u56fe\u50cf\u3002 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u8f7b\u677e\u8bfb\u53d6\u8fd9\u79cd\u6570\u636e\u683c\u5f0f\u3002 \u8bf7\u6ce8\u610f\uff0c\u5c3d\u7ba1\u56fe 1 \u663e\u793a\u6240\u6709\u50cf\u7d20\u503c\u5747\u4e3a\u96f6\uff0c\u4f46\u4e8b\u5b9e\u5e76\u975e\u5982\u6b64\u3002 \u56fe 1\uff1aCSV\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u6807\u7b7e\u5217\u7684\u8ba1\u6570\u3002 \u56fe 2\uff1aMNIST \u6570\u636e\u96c6\u4e2d\u7684\u6807\u7b7e\u8ba1\u6570 \u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c\u66f4\u591a\u7684\u63a2\u7d22\u3002\u6211\u4eec\u5df2\u7ecf\u77e5\u9053\u4e86\u6211\u4eec\u6240\u62e5\u6709\u7684\u6570\u636e\uff0c\u6ca1\u6709\u5fc5\u8981\u518d\u5bf9\u4e0d\u540c\u7684\u50cf\u7d20\u503c\u8fdb\u884c\u7ed8\u56fe\u3002\u4ece\u56fe 2 \u4e2d\u53ef\u4ee5\u6e05\u695a\u5730\u770b\u51fa\uff0c\u6807\u7b7e\u7684\u5206\u5e03\u76f8\u5f53\u5747\u5300\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u51c6\u786e\u7387/F1 \u4f5c\u4e3a\u8861\u91cf\u6807\u51c6\u3002\u8fd9\u5c31\u662f\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u7684\u7b2c\u4e00\u6b65\uff1a\u786e\u5b9a\u8861\u91cf\u6807\u51c6\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7f16\u5199\u4e00\u4e9b\u4ee3\u7801\u4e86\u3002\u6211\u4eec\u9700\u8981\u521b\u5efa src/ \u6587\u4ef6\u5939\u548c\u4e00\u4e9b python \u811a\u672c\u3002 \u8bf7\u6ce8\u610f\uff0c\u8bad\u7ec3 CSV \u6587\u4ef6\u4f4d\u4e8e input/ \u6587\u4ef6\u5939\u4e2d\uff0c\u540d\u4e3a mnist_train.csv \u3002 \u5bf9\u4e8e\u8fd9\u6837\u4e00\u4e2a\u9879\u76ee\uff0c\u8fd9\u4e9b\u6587\u4ef6\u5e94\u8be5\u662f\u4ec0\u4e48\u6837\u7684\u5462\uff1f \u9996\u5148\u8981\u521b\u5efa\u7684\u811a\u672c\u662f create_folds.py \u3002 \u8fd9\u5c06\u5728 input/ \u6587\u4ef6\u5939\u4e2d\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a mnist_train_folds.csv \u7684\u65b0\u6587\u4ef6\uff0c\u4e0e mnist_train.csv \u76f8\u540c\u3002\u552f\u4e00\u4e0d\u540c\u7684\u662f\uff0c\u8fd9\u4e2a CSV \u6587\u4ef6\u7ecf\u8fc7\u4e86\u968f\u673a\u6392\u5e8f\uff0c\u5e76\u65b0\u589e\u4e86\u4e00\u5217\u540d\u4e3a kfold \u7684\u5185\u5bb9\u3002 \u4e00\u65e6\u6211\u4eec\u51b3\u5b9a\u4e86\u8981\u4f7f\u7528\u54ea\u79cd\u8bc4\u4f30\u6307\u6807\u5e76\u521b\u5efa\u4e86\u6298\u53e0\uff0c\u5c31\u53ef\u4ee5\u5f00\u59cb\u521b\u5efa\u57fa\u672c\u6a21\u578b\u4e86\u3002\u8fd9\u53ef\u4ee5\u5728 train.py \u4e2d\u5b8c\u6210\u3002 import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/mnist_train_folds.csv\" ) # \u9009\u53d6df\u4e2dkfold\u5217\u4e0d\u7b49\u4e8efold df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u9009\u53d6df\u4e2dkfold\u5217\u7b49\u4e8efold df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u8bad\u7ec3\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_train = df_train . drop ( \"label\" , axis = 1 ) . values # \u8bad\u7ec3\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_train = df_train . label . values # \u9a8c\u8bc1\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values # \u9a8c\u8bc1\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_valid = df_valid . label . values # \u5b9e\u4f8b\u5316\u51b3\u7b56\u6811\u6a21\u578b clf = tree . DecisionTreeClassifier () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b clf . fit ( x_train , y_train ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u8f93\u5165\u5f97\u5230\u9884\u6d4b\u7ed3\u679c preds = clf . predict ( x_valid ) # \u8ba1\u7b97\u9a8c\u8bc1\u96c6\u51c6\u786e\u7387 accuracy = metrics . accuracy_score ( y_valid , preds ) # \u6253\u5370fold\u4fe1\u606f\u548c\u51c6\u786e\u7387 print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) # \u4fdd\u5b58\u6a21\u578b joblib . dump ( clf , f \"../models/dt_ { fold } .bin\" ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u60a8\u53ef\u4ee5\u5728\u63a7\u5236\u53f0\u8c03\u7528 python train.py \u8fd0\u884c\u8be5\u811a\u672c\u3002 \u276f python train . py Fold = 0 , Accuracy = 0.8680833333333333 Fold = 1 , Accuracy = 0.8685 Fold = 2 , Accuracy = 0.8674166666666666 Fold = 3 , Accuracy = 0.8703333333333333 Fold = 4 , Accuracy = 0.8699166666666667 \u67e5\u770b\u8bad\u7ec3\u811a\u672c\u65f6\uff0c\u60a8\u4f1a\u53d1\u73b0\u8fd8\u6709\u4e00\u4e9b\u5185\u5bb9\u662f\u786c\u7f16\u7801\u7684\uff0c\u4f8b\u5982\u6298\u53e0\u6570\u3001\u8bad\u7ec3\u6587\u4ef6\u548c\u8f93\u51fa\u6587\u4ef6\u5939\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u6240\u6709\u8fd9\u4e9b\u4fe1\u606f\u7684\u914d\u7f6e\u6587\u4ef6\uff1a config.py \u3002 TRAINING_FILE = \"../input/mnist_train_folds.csv\" MODEL_OUTPUT = \"../models/\" \u6211\u4eec\u8fd8\u5bf9\u8bad\u7ec3\u811a\u672c\u8fdb\u884c\u4e86\u4e00\u4e9b\u4fee\u6539\u3002\u8bad\u7ec3\u6587\u4ef6\u73b0\u5728\u4f7f\u7528\u914d\u7f6e\u6587\u4ef6\u3002\u8fd9\u6837\uff0c\u66f4\u6539\u6570\u636e\u6216\u6a21\u578b\u8f93\u51fa\u5c31\u66f4\u5bb9\u6613\u4e86\u3002 import os import config import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u4f7f\u7528config\u4e2d\u7684\u8def\u5f84\u8bfb\u53d6\u6570\u636e df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values clf = tree . DecisionTreeClassifier () clf . fit ( x_train , y_train ) preds = clf . predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" ) ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5e76\u6ca1\u6709\u5c55\u793a\u8fd9\u4e2a\u57f9\u8bad\u811a\u672c\u4e0e\u4e4b\u524d\u811a\u672c\u7684\u533a\u522b\u3002\u8bf7\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e24\u4e2a\u811a\u672c\uff0c\u81ea\u5df1\u627e\u51fa\u4e0d\u540c\u4e4b\u5904\u3002\u533a\u522b\u5e76\u4e0d\u591a\u3002 \u4e0e\u8bad\u7ec3\u811a\u672c\u76f8\u5173\u7684\u8fd8\u6709\u4e00\u70b9\u53ef\u4ee5\u6539\u8fdb\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u6211\u4eec\u4e3a\u6bcf\u4e2a\u6298\u53e0\u591a\u6b21\u8c03\u7528\u8fd0\u884c\u51fd\u6570\u3002\u6709\u65f6\uff0c\u5728\u540c\u4e00\u4e2a\u811a\u672c\u4e2d\u8fd0\u884c\u591a\u4e2a\u6298\u53e0\u5e76\u4e0d\u53ef\u53d6\uff0c\u56e0\u4e3a\u5185\u5b58\u6d88\u8017\u53ef\u80fd\u4f1a\u4e0d\u65ad\u589e\u52a0\uff0c\u7a0b\u5e8f\u53ef\u80fd\u4f1a\u5d29\u6e83\u3002\u4e3a\u4e86\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u5411\u8bad\u7ec3\u811a\u672c\u4f20\u9012\u53c2\u6570\u3002\u6211\u559c\u6b22\u4f7f\u7528 argparse\u3002 import argparse if __name__ == \"__main__\" : # \u5b9e\u4f8b\u5316\u53c2\u6570\u73af\u5883 parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # \u8bfb\u53d6\u53c2\u6570 args = parser . parse_args () run ( fold = args . fold ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u8fd0\u884c python \u811a\u672c\uff0c\u4f46\u4ec5\u9650\u4e8e\u7ed9\u5b9a\u7684\u6298\u53e0\u3002 \u276f python train . py -- fold 0 Fold = 0 , Accuracy = 0.8656666666666667 \u4ed4\u7ec6\u89c2\u5bdf\uff0c\u6211\u4eec\u7684\u7b2c 0 \u6298\u5f97\u5206\u4e0e\u4e4b\u524d\u6709\u4e9b\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6a21\u578b\u4e2d\u5b58\u5728\u968f\u673a\u6027\u3002\u6211\u4eec\u5c06\u5728\u540e\u9762\u7684\u7ae0\u8282\u4e2d\u8ba8\u8bba\u5982\u4f55\u5904\u7406\u968f\u673a\u6027\u3002 \u73b0\u5728\uff0c\u5982\u679c\u4f60\u613f\u610f\uff0c\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a shell \u811a\u672c \uff0c\u9488\u5bf9\u4e0d\u540c\u7684\u6298\u53e0\u4f7f\u7528\u4e0d\u540c\u7684\u547d\u4ee4\uff0c\u7136\u540e\u4e00\u8d77\u8fd0\u884c\uff0c\u5982\u4e0b\u56fe\u6240\u793a\u3002 python train . py -- fold 0 python train . py -- fold 1 python train . py -- fold 2 python train . py -- fold 3 python train . py -- fold 4 \u60a8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u5b83\u3002 \u276f sh run . sh Fold = 0 , Accuracy = 0.8675 Fold = 1 , Accuracy = 0.8693333333333333 Fold = 2 , Accuracy = 0.8683333333333333 Fold = 3 , Accuracy = 0.8704166666666666 Fold = 4 , Accuracy = 0.8685 \u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u53d6\u5f97\u4e86\u4e00\u4e9b\u8fdb\u5c55\uff0c\u4f46\u5982\u679c\u6211\u4eec\u770b\u4e00\u4e0b\u6211\u4eec\u7684\u8bad\u7ec3\u811a\u672c\uff0c\u6211\u4eec\u4ecd\u7136\u53d7\u5230\u4e00\u4e9b\u4e1c\u897f\u7684\u9650\u5236\uff0c\u4f8b\u5982\u6a21\u578b\u3002\u6a21\u578b\u662f\u786c\u7f16\u7801\u5728\u8bad\u7ec3\u811a\u672c\u4e2d\u7684\uff0c\u53ea\u6709\u4fee\u6539\u811a\u672c\u624d\u80fd\u6539\u53d8\u5b83\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u65b0\u7684 python \u811a\u672c\uff0c\u540d\u4e3a model_dispatcher.py \u3002model_dispatcher.py\uff0c\u987e\u540d\u601d\u4e49\uff0c\u5c06\u8c03\u5ea6\u6211\u4eec\u7684\u6a21\u578b\u5230\u8bad\u7ec3\u811a\u672c\u4e2d\u3002 from sklearn import tree models = { # \u4ee5gini\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_gini\" : tree . DecisionTreeClassifier ( criterion = \"gini\" ), # \u4ee5entropy\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), } model_dispatcher.py \u4ece scikit-learn \u4e2d\u5bfc\u5165\u4e86 tree\uff0c\u5e76\u5b9a\u4e49\u4e86\u4e00\u4e2a\u5b57\u5178\uff0c\u5176\u4e2d\u952e\u662f\u6a21\u578b\u7684\u540d\u79f0\uff0c\u503c\u662f\u6a21\u578b\u672c\u8eab\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5b9a\u4e49\u4e86\u4e24\u79cd\u4e0d\u540c\u7684\u51b3\u7b56\u6811\uff0c\u4e00\u79cd\u4f7f\u7528\u57fa\u5c3c\u6807\u51c6\uff0c\u53e6\u4e00\u79cd\u4f7f\u7528\u71b5\u6807\u51c6\u3002\u8981\u4f7f\u7528 py\uff0c\u6211\u4eec\u9700\u8981\u5bf9\u8bad\u7ec3\u811a\u672c\u505a\u4e00\u4e9b\u4fee\u6539\u3002 import argparse import os import joblib import pandas as pd from sklearn import metrics import config import model_dispatcher def run ( fold , model ): df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values # \u6839\u636emodel\u53c2\u6570\u9009\u62e9\u6a21\u578b clf = model_dispatcher . models [ model ] clf . fit ( x_train , y_train ) preds = clf . predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" )) if __name__ == \"__main__\" : parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # model\u53c2\u6570 parser . add_argument ( \"--model\" , type = str ) args = parser . parse_args () run ( fold = args . fold , model = args . model ) train.py \u6709\u51e0\u5904\u91cd\u5927\u6539\u52a8\uff1a - \u5bfc\u5165 model_dispatcher - \u4e3a ArgumentParser \u6dfb\u52a0 --model \u53c2\u6570 - \u4e3a run() \u51fd\u6570\u6dfb\u52a0model\u53c2\u6570 - \u4f7f\u7528\u8c03\u5ea6\u7a0b\u5e8f\u83b7\u53d6\u6307\u5b9a\u540d\u79f0\u7684\u6a21\u578b \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u811a\u672c\uff1a \u276f python train . py -- fold 0 -- model decision_tree_gini Fold = 0 , Accuracy = 0.8665833333333334 \u6216\u6267\u884c\u4ee5\u4e0b\u547d\u4ee4 \u276f python train . py -- fold 0 -- model decision_tree_entropy Fold = 0 , Accuracy = 0.8705833333333334 \u73b0\u5728\uff0c\u5982\u679c\u8981\u6dfb\u52a0\u65b0\u6a21\u578b\uff0c\u53ea\u9700\u4fee\u6539 model_dispatcher.py \u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u6dfb\u52a0\u968f\u673a\u68ee\u6797\uff0c\u770b\u770b\u51c6\u786e\u7387\u4f1a\u6709\u4ec0\u4e48\u53d8\u5316\u3002 from sklearn import ensemble from sklearn import tree models = { \"decision_tree_gini\" : tree . DecisionTreeClassifier ( criterion = \"gini\" ), \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), # \u968f\u673a\u68ee\u6797\u6a21\u578b \"rf\" : ensemble . RandomForestClassifier (), } \u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u3002 \u276f python train . py -- fold 0 -- model rf Fold = 0 , Accuracy = 0.9670833333333333 \u54c7\uff0c\u4e00\u4e2a\u7b80\u5355\u7684\u6539\u52a8\u5c31\u80fd\u8ba9\u5206\u6570\u6709\u5982\u6b64\u5927\u7684\u63d0\u5347\uff01\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528 run.sh \u811a\u672c\u8fd0\u884c 5 \u4e2a\u6298\u53e0\uff01 python train . py -- fold 0 -- model rf python train . py -- fold 1 -- model rf python train . py -- fold 2 -- model rf python train . py -- fold 3 -- model rf python train . py -- fold 4 -- model rf \u5f97\u5206\u60c5\u51b5\u5982\u4e0b \u276f sh run . sh Fold = 0 , Accuracy = 0.9674166666666667 Fold = 1 , Accuracy = 0.9698333333333333 Fold = 2 , Accuracy = 0.96575 Fold = 3 , Accuracy = 0.9684166666666667 Fold = 4 , Accuracy = 0.9666666666666667 MNIST \u51e0\u4e4e\u662f\u6bcf\u672c\u4e66\u548c\u6bcf\u7bc7\u535a\u5ba2\u90fd\u4f1a\u8ba8\u8bba\u7684\u95ee\u9898\u3002\u4f46\u6211\u8bd5\u56fe\u5c06\u8fd9\u4e2a\u95ee\u9898\u8f6c\u6362\u5f97\u66f4\u6709\u8da3\uff0c\u5e76\u5411\u4f60\u5c55\u793a\u5982\u4f55\u4e3a\u4f60\u6b63\u5728\u505a\u7684\u6216\u8ba1\u5212\u5728\u4e0d\u4e45\u7684\u5c06\u6765\u505a\u7684\u51e0\u4e4e\u6240\u6709\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7f16\u5199\u4e00\u4e2a\u57fa\u672c\u6846\u67b6\u3002\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\u53ef\u4ee5\u6539\u8fdb\u8fd9\u4e2a MNIST \u6a21\u578b\u548c\u8fd9\u4e2a\u6846\u67b6\uff0c\u6211\u4eec\u5c06\u5728\u4ee5\u540e\u7684\u7ae0\u8282\u4e2d\u770b\u5230\u3002 \u6211\u4f7f\u7528\u4e86\u4e00\u4e9b\u811a\u672c\uff0c\u5982 model_dispatcher.py \u548c config.py \uff0c\u5e76\u5c06\u5b83\u4eec\u5bfc\u5165\u5230\u6211\u7684\u8bad\u7ec3\u811a\u672c\u4e2d\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u6ca1\u6709\u5bfc\u5165 \uff0c\u4f60\u4e5f\u4e0d\u5e94\u8be5\u5bfc\u5165\u3002\u5982\u679c\u6211\u5bfc\u5165\u4e86 \uff0c\u4f60\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u77e5\u9053\u6a21\u578b\u5b57\u5178\u662f\u4ece\u54ea\u91cc\u6765\u7684\u3002\u7f16\u5199\u4f18\u79c0\u3001\u6613\u61c2\u7684\u4ee3\u7801\u662f\u4e00\u4e2a\u4eba\u5fc5\u987b\u5177\u5907\u7684\u57fa\u672c\u7d20\u8d28\uff0c\u4f46\u8bb8\u591a\u6570\u636e\u79d1\u5b66\u5bb6\u5374\u5ffd\u89c6\u4e86\u8fd9\u4e00\u70b9\u3002\u5982\u679c\u4f60\u6240\u505a\u7684\u9879\u76ee\u80fd\u8ba9\u5176\u4ed6\u4eba\u7406\u89e3\u5e76\u4f7f\u7528\uff0c\u800c\u65e0\u9700\u54a8\u8be2\u4f60\u7684\u610f\u89c1\uff0c\u90a3\u4e48\u4f60\u5c31\u8282\u7701\u4e86\u4ed6\u4eec\u7684\u65f6\u95f4\u548c\u81ea\u5df1\u7684\u65f6\u95f4\uff0c\u53ef\u4ee5\u5c06\u8fd9\u4e9b\u65f6\u95f4\u6295\u5165\u5230\u6539\u8fdb\u4f60\u7684\u9879\u76ee\u6216\u5f00\u53d1\u65b0\u9879\u76ee\u4e2d\u53bb\u3002","title":"\u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\u9879\u76ee"},{"location":"%E7%BB%84%E7%BB%87%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E9%A1%B9%E7%9B%AE/#_1","text":"\u7ec8\u4e8e\uff0c\u6211\u4eec\u53ef\u4ee5\u5f00\u59cb\u6784\u5efa\u7b2c\u4e00\u4e2a\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e86\u3002 \u662f\u8fd9\u6837\u5417\uff1f \u5728\u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u6ce8\u610f\u51e0\u4ef6\u4e8b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u6211\u4eec\u5c06\u5728\u96c6\u6210\u5f00\u53d1\u73af\u5883/\u6587\u672c\u7f16\u8f91\u5668\u4e2d\u5de5\u4f5c\uff0c\u800c\u4e0d\u662f\u5728 jupyter notebook\u4e2d\u3002\u4f60\u4e5f\u53ef\u4ee5\u5728 jupyter notebook\u4e2d\u5de5\u4f5c\uff0c\u8fd9\u5b8c\u5168\u53d6\u51b3\u4e8e\u4f60\u3002\u4e0d\u8fc7\uff0c\u6211\u5c06\u53ea\u4f7f\u7528 jupyter notebook\u6765\u63a2\u7d22\u6570\u636e\u3001\u7ed8\u5236\u56fe\u8868\u548c\u56fe\u5f62\u3002\u6211\u4eec\u5c06\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u6784\u5efa\u5206\u7c7b\u6846\u67b6\uff0c\u5373\u63d2\u5373\u7528\u3002\u60a8\u65e0\u9700\u5bf9\u4ee3\u7801\u505a\u592a\u591a\u6539\u52a8\u5c31\u80fd\u8bad\u7ec3\u6a21\u578b\uff0c\u800c\u4e14\u5f53\u60a8\u6539\u8fdb\u6a21\u578b\u65f6\uff0c\u8fd8\u80fd\u4f7f\u7528 git \u5bf9\u5176\u8fdb\u884c\u8ddf\u8e2a\u3002 \u6211\u4eec\u9996\u5148\u6765\u770b\u770b\u6587\u4ef6\u7684\u7ed3\u6784\u3002\u5bf9\u4e8e\u4f60\u6b63\u5728\u505a\u7684\u4efb\u4f55\u9879\u76ee\uff0c\u90fd\u8981\u521b\u5efa\u4e00\u4e2a\u65b0\u6587\u4ef6\u5939\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u5c06\u9879\u76ee\u547d\u540d\u4e3a \"project\"\u3002 \u9879\u76ee\u6587\u4ef6\u5939\u5185\u90e8\u5e94\u8be5\u5982\u4e0b\u6240\u793a\u3002 input train.csv test.csv src create_folds.py train.py inference.py models.py config.py model_dispatcher.py models model_rf.bin model_et.bin notebooks exploration.ipynb check_data.ipynb README.md LICENSE \u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e9b\u6587\u4ef6\u5939\u548c\u6587\u4ef6\u7684\u5185\u5bb9\u3002 input/ \uff1a\u8be5\u6587\u4ef6\u5939\u5305\u542b\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7684\u6240\u6709\u8f93\u5165\u6587\u4ef6\u548c\u6570\u636e\u3002\u5982\u679c\u60a8\u6b63\u5728\u5f00\u53d1 NLP \u9879\u76ee\uff0c\u60a8\u53ef\u4ee5\u5c06embeddings\u653e\u5728\u8fd9\u91cc\u3002\u5982\u679c\u662f\u56fe\u50cf\u9879\u76ee\uff0c\u6240\u6709\u56fe\u50cf\u90fd\u653e\u5728\u8be5\u6587\u4ef6\u5939\u4e0b\u7684\u5b50\u6587\u4ef6\u5939\u4e2d\u3002 src/ \uff1a\u6211\u4eec\u5c06\u5728\u8fd9\u91cc\u4fdd\u5b58\u4e0e\u9879\u76ee\u76f8\u5173\u7684\u6240\u6709 python \u811a\u672c\u3002\u5982\u679c\u6211\u8bf4\u7684\u662f\u4e00\u4e2a python \u811a\u672c\uff0c\u5373\u4efb\u4f55 *.py \u6587\u4ef6\uff0c\u5b83\u90fd\u5b58\u50a8\u5728 src \u6587\u4ef6\u5939\u4e2d\u3002 models/ \uff1a\u8be5\u6587\u4ef6\u5939\u4fdd\u5b58\u6240\u6709\u8bad\u7ec3\u8fc7\u7684\u6a21\u578b\u3002 notebook/ \uff1a\u6240\u6709 jupyter notebook\uff08\u5373\u4efb\u4f55 *.ipynb \u6587\u4ef6\uff09\u90fd\u5b58\u50a8\u5728\u7b14\u8bb0\u672c \u6587\u4ef6\u5939\u4e2d\u3002 README.md \uff1a\u8fd9\u662f\u4e00\u4e2a\u6807\u8bb0\u7b26\u6587\u4ef6\uff0c\u60a8\u53ef\u4ee5\u5728\u5176\u4e2d\u63cf\u8ff0\u60a8\u7684\u9879\u76ee\uff0c\u5e76\u5199\u660e\u5982\u4f55\u8bad\u7ec3\u6a21\u578b\u6216\u5728\u751f\u4ea7\u73af\u5883\u4e2d\u4f7f\u7528\u3002 LICENSE \uff1a\u8fd9\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6587\u672c\u6587\u4ef6\uff0c\u5305\u542b\u9879\u76ee\u7684\u8bb8\u53ef\u8bc1\uff0c\u5982 MIT\u3001Apache \u7b49\u3002\u5173\u4e8e\u8bb8\u53ef\u8bc1\u7684\u8be6\u7ec6\u4ecb\u7ecd\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 \u5047\u8bbe\u4f60\u6b63\u5728\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u5bf9 MNIST \u6570\u636e\u96c6\uff08\u51e0\u4e4e\u6bcf\u672c\u673a\u5668\u5b66\u4e60\u4e66\u7c4d\u90fd\u4f1a\u7528\u5230\u7684\u6570\u636e\u96c6\uff09\u8fdb\u884c\u5206\u7c7b\u3002\u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\uff0c\u6211\u4eec\u5728\u4ea4\u53c9\u68c0\u9a8c\u4e00\u7ae0\u4e2d\u4e5f\u63d0\u5230\u8fc7 MNIST \u6570\u636e\u96c6\u3002\u6240\u4ee5\uff0c\u6211\u5c31\u4e0d\u89e3\u91ca\u8fd9\u4e2a\u6570\u636e\u96c6\u662f\u4ec0\u4e48\u6837\u5b50\u4e86\u3002\u7f51\u4e0a\u6709\u8bb8\u591a\u4e0d\u540c\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6\uff0c\u4f46\u6211\u4eec\u5c06\u4f7f\u7528 CSV \u683c\u5f0f\u7684\u6570\u636e\u96c6\u3002 \u5728\u8fd9\u79cd\u683c\u5f0f\u7684\u6570\u636e\u96c6\u4e2d\uff0cCSV \u7684\u6bcf\u4e00\u884c\u90fd\u5305\u542b\u56fe\u50cf\u7684\u6807\u7b7e\u548c 784 \u4e2a\u50cf\u7d20\u503c\uff0c\u50cf\u7d20\u503c\u8303\u56f4\u4ece 0 \u5230 255\u3002\u6570\u636e\u96c6\u5305\u542b 60000 \u5f20\u8fd9\u79cd\u683c\u5f0f\u7684\u56fe\u50cf\u3002 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u8f7b\u677e\u8bfb\u53d6\u8fd9\u79cd\u6570\u636e\u683c\u5f0f\u3002 \u8bf7\u6ce8\u610f\uff0c\u5c3d\u7ba1\u56fe 1 \u663e\u793a\u6240\u6709\u50cf\u7d20\u503c\u5747\u4e3a\u96f6\uff0c\u4f46\u4e8b\u5b9e\u5e76\u975e\u5982\u6b64\u3002 \u56fe 1\uff1aCSV\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u6807\u7b7e\u5217\u7684\u8ba1\u6570\u3002 \u56fe 2\uff1aMNIST \u6570\u636e\u96c6\u4e2d\u7684\u6807\u7b7e\u8ba1\u6570 \u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c\u66f4\u591a\u7684\u63a2\u7d22\u3002\u6211\u4eec\u5df2\u7ecf\u77e5\u9053\u4e86\u6211\u4eec\u6240\u62e5\u6709\u7684\u6570\u636e\uff0c\u6ca1\u6709\u5fc5\u8981\u518d\u5bf9\u4e0d\u540c\u7684\u50cf\u7d20\u503c\u8fdb\u884c\u7ed8\u56fe\u3002\u4ece\u56fe 2 \u4e2d\u53ef\u4ee5\u6e05\u695a\u5730\u770b\u51fa\uff0c\u6807\u7b7e\u7684\u5206\u5e03\u76f8\u5f53\u5747\u5300\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u51c6\u786e\u7387/F1 \u4f5c\u4e3a\u8861\u91cf\u6807\u51c6\u3002\u8fd9\u5c31\u662f\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u7684\u7b2c\u4e00\u6b65\uff1a\u786e\u5b9a\u8861\u91cf\u6807\u51c6\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7f16\u5199\u4e00\u4e9b\u4ee3\u7801\u4e86\u3002\u6211\u4eec\u9700\u8981\u521b\u5efa src/ \u6587\u4ef6\u5939\u548c\u4e00\u4e9b python \u811a\u672c\u3002 \u8bf7\u6ce8\u610f\uff0c\u8bad\u7ec3 CSV \u6587\u4ef6\u4f4d\u4e8e input/ \u6587\u4ef6\u5939\u4e2d\uff0c\u540d\u4e3a mnist_train.csv \u3002 \u5bf9\u4e8e\u8fd9\u6837\u4e00\u4e2a\u9879\u76ee\uff0c\u8fd9\u4e9b\u6587\u4ef6\u5e94\u8be5\u662f\u4ec0\u4e48\u6837\u7684\u5462\uff1f \u9996\u5148\u8981\u521b\u5efa\u7684\u811a\u672c\u662f create_folds.py \u3002 \u8fd9\u5c06\u5728 input/ \u6587\u4ef6\u5939\u4e2d\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a mnist_train_folds.csv \u7684\u65b0\u6587\u4ef6\uff0c\u4e0e mnist_train.csv \u76f8\u540c\u3002\u552f\u4e00\u4e0d\u540c\u7684\u662f\uff0c\u8fd9\u4e2a CSV \u6587\u4ef6\u7ecf\u8fc7\u4e86\u968f\u673a\u6392\u5e8f\uff0c\u5e76\u65b0\u589e\u4e86\u4e00\u5217\u540d\u4e3a kfold \u7684\u5185\u5bb9\u3002 \u4e00\u65e6\u6211\u4eec\u51b3\u5b9a\u4e86\u8981\u4f7f\u7528\u54ea\u79cd\u8bc4\u4f30\u6307\u6807\u5e76\u521b\u5efa\u4e86\u6298\u53e0\uff0c\u5c31\u53ef\u4ee5\u5f00\u59cb\u521b\u5efa\u57fa\u672c\u6a21\u578b\u4e86\u3002\u8fd9\u53ef\u4ee5\u5728 train.py \u4e2d\u5b8c\u6210\u3002 import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/mnist_train_folds.csv\" ) # \u9009\u53d6df\u4e2dkfold\u5217\u4e0d\u7b49\u4e8efold df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u9009\u53d6df\u4e2dkfold\u5217\u7b49\u4e8efold df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u8bad\u7ec3\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_train = df_train . drop ( \"label\" , axis = 1 ) . values # \u8bad\u7ec3\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_train = df_train . label . values # \u9a8c\u8bc1\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values # \u9a8c\u8bc1\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_valid = df_valid . label . values # \u5b9e\u4f8b\u5316\u51b3\u7b56\u6811\u6a21\u578b clf = tree . DecisionTreeClassifier () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b clf . fit ( x_train , y_train ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u8f93\u5165\u5f97\u5230\u9884\u6d4b\u7ed3\u679c preds = clf . predict ( x_valid ) # \u8ba1\u7b97\u9a8c\u8bc1\u96c6\u51c6\u786e\u7387 accuracy = metrics . accuracy_score ( y_valid , preds ) # \u6253\u5370fold\u4fe1\u606f\u548c\u51c6\u786e\u7387 print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) # \u4fdd\u5b58\u6a21\u578b joblib . dump ( clf , f \"../models/dt_ { fold } .bin\" ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u60a8\u53ef\u4ee5\u5728\u63a7\u5236\u53f0\u8c03\u7528 python train.py \u8fd0\u884c\u8be5\u811a\u672c\u3002 \u276f python train . py Fold = 0 , Accuracy = 0.8680833333333333 Fold = 1 , Accuracy = 0.8685 Fold = 2 , Accuracy = 0.8674166666666666 Fold = 3 , Accuracy = 0.8703333333333333 Fold = 4 , Accuracy = 0.8699166666666667 \u67e5\u770b\u8bad\u7ec3\u811a\u672c\u65f6\uff0c\u60a8\u4f1a\u53d1\u73b0\u8fd8\u6709\u4e00\u4e9b\u5185\u5bb9\u662f\u786c\u7f16\u7801\u7684\uff0c\u4f8b\u5982\u6298\u53e0\u6570\u3001\u8bad\u7ec3\u6587\u4ef6\u548c\u8f93\u51fa\u6587\u4ef6\u5939\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u6240\u6709\u8fd9\u4e9b\u4fe1\u606f\u7684\u914d\u7f6e\u6587\u4ef6\uff1a config.py \u3002 TRAINING_FILE = \"../input/mnist_train_folds.csv\" MODEL_OUTPUT = \"../models/\" \u6211\u4eec\u8fd8\u5bf9\u8bad\u7ec3\u811a\u672c\u8fdb\u884c\u4e86\u4e00\u4e9b\u4fee\u6539\u3002\u8bad\u7ec3\u6587\u4ef6\u73b0\u5728\u4f7f\u7528\u914d\u7f6e\u6587\u4ef6\u3002\u8fd9\u6837\uff0c\u66f4\u6539\u6570\u636e\u6216\u6a21\u578b\u8f93\u51fa\u5c31\u66f4\u5bb9\u6613\u4e86\u3002 import os import config import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u4f7f\u7528config\u4e2d\u7684\u8def\u5f84\u8bfb\u53d6\u6570\u636e df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values clf = tree . DecisionTreeClassifier () clf . fit ( x_train , y_train ) preds = clf . predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" ) ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5e76\u6ca1\u6709\u5c55\u793a\u8fd9\u4e2a\u57f9\u8bad\u811a\u672c\u4e0e\u4e4b\u524d\u811a\u672c\u7684\u533a\u522b\u3002\u8bf7\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e24\u4e2a\u811a\u672c\uff0c\u81ea\u5df1\u627e\u51fa\u4e0d\u540c\u4e4b\u5904\u3002\u533a\u522b\u5e76\u4e0d\u591a\u3002 \u4e0e\u8bad\u7ec3\u811a\u672c\u76f8\u5173\u7684\u8fd8\u6709\u4e00\u70b9\u53ef\u4ee5\u6539\u8fdb\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u6211\u4eec\u4e3a\u6bcf\u4e2a\u6298\u53e0\u591a\u6b21\u8c03\u7528\u8fd0\u884c\u51fd\u6570\u3002\u6709\u65f6\uff0c\u5728\u540c\u4e00\u4e2a\u811a\u672c\u4e2d\u8fd0\u884c\u591a\u4e2a\u6298\u53e0\u5e76\u4e0d\u53ef\u53d6\uff0c\u56e0\u4e3a\u5185\u5b58\u6d88\u8017\u53ef\u80fd\u4f1a\u4e0d\u65ad\u589e\u52a0\uff0c\u7a0b\u5e8f\u53ef\u80fd\u4f1a\u5d29\u6e83\u3002\u4e3a\u4e86\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u5411\u8bad\u7ec3\u811a\u672c\u4f20\u9012\u53c2\u6570\u3002\u6211\u559c\u6b22\u4f7f\u7528 argparse\u3002 import argparse if __name__ == \"__main__\" : # \u5b9e\u4f8b\u5316\u53c2\u6570\u73af\u5883 parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # \u8bfb\u53d6\u53c2\u6570 args = parser . parse_args () run ( fold = args . fold ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u8fd0\u884c python \u811a\u672c\uff0c\u4f46\u4ec5\u9650\u4e8e\u7ed9\u5b9a\u7684\u6298\u53e0\u3002 \u276f python train . py -- fold 0 Fold = 0 , Accuracy = 0.8656666666666667 \u4ed4\u7ec6\u89c2\u5bdf\uff0c\u6211\u4eec\u7684\u7b2c 0 \u6298\u5f97\u5206\u4e0e\u4e4b\u524d\u6709\u4e9b\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6a21\u578b\u4e2d\u5b58\u5728\u968f\u673a\u6027\u3002\u6211\u4eec\u5c06\u5728\u540e\u9762\u7684\u7ae0\u8282\u4e2d\u8ba8\u8bba\u5982\u4f55\u5904\u7406\u968f\u673a\u6027\u3002 \u73b0\u5728\uff0c\u5982\u679c\u4f60\u613f\u610f\uff0c\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a shell \u811a\u672c \uff0c\u9488\u5bf9\u4e0d\u540c\u7684\u6298\u53e0\u4f7f\u7528\u4e0d\u540c\u7684\u547d\u4ee4\uff0c\u7136\u540e\u4e00\u8d77\u8fd0\u884c\uff0c\u5982\u4e0b\u56fe\u6240\u793a\u3002 python train . py -- fold 0 python train . py -- fold 1 python train . py -- fold 2 python train . py -- fold 3 python train . py -- fold 4 \u60a8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u5b83\u3002 \u276f sh run . sh Fold = 0 , Accuracy = 0.8675 Fold = 1 , Accuracy = 0.8693333333333333 Fold = 2 , Accuracy = 0.8683333333333333 Fold = 3 , Accuracy = 0.8704166666666666 Fold = 4 , Accuracy = 0.8685 \u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u53d6\u5f97\u4e86\u4e00\u4e9b\u8fdb\u5c55\uff0c\u4f46\u5982\u679c\u6211\u4eec\u770b\u4e00\u4e0b\u6211\u4eec\u7684\u8bad\u7ec3\u811a\u672c\uff0c\u6211\u4eec\u4ecd\u7136\u53d7\u5230\u4e00\u4e9b\u4e1c\u897f\u7684\u9650\u5236\uff0c\u4f8b\u5982\u6a21\u578b\u3002\u6a21\u578b\u662f\u786c\u7f16\u7801\u5728\u8bad\u7ec3\u811a\u672c\u4e2d\u7684\uff0c\u53ea\u6709\u4fee\u6539\u811a\u672c\u624d\u80fd\u6539\u53d8\u5b83\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u65b0\u7684 python \u811a\u672c\uff0c\u540d\u4e3a model_dispatcher.py \u3002model_dispatcher.py\uff0c\u987e\u540d\u601d\u4e49\uff0c\u5c06\u8c03\u5ea6\u6211\u4eec\u7684\u6a21\u578b\u5230\u8bad\u7ec3\u811a\u672c\u4e2d\u3002 from sklearn import tree models = { # \u4ee5gini\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_gini\" : tree . DecisionTreeClassifier ( criterion = \"gini\" ), # \u4ee5entropy\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), } model_dispatcher.py \u4ece scikit-learn \u4e2d\u5bfc\u5165\u4e86 tree\uff0c\u5e76\u5b9a\u4e49\u4e86\u4e00\u4e2a\u5b57\u5178\uff0c\u5176\u4e2d\u952e\u662f\u6a21\u578b\u7684\u540d\u79f0\uff0c\u503c\u662f\u6a21\u578b\u672c\u8eab\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5b9a\u4e49\u4e86\u4e24\u79cd\u4e0d\u540c\u7684\u51b3\u7b56\u6811\uff0c\u4e00\u79cd\u4f7f\u7528\u57fa\u5c3c\u6807\u51c6\uff0c\u53e6\u4e00\u79cd\u4f7f\u7528\u71b5\u6807\u51c6\u3002\u8981\u4f7f\u7528 py\uff0c\u6211\u4eec\u9700\u8981\u5bf9\u8bad\u7ec3\u811a\u672c\u505a\u4e00\u4e9b\u4fee\u6539\u3002 import argparse import os import joblib import pandas as pd from sklearn import metrics import config import model_dispatcher def run ( fold , model ): df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values # \u6839\u636emodel\u53c2\u6570\u9009\u62e9\u6a21\u578b clf = model_dispatcher . models [ model ] clf . fit ( x_train , y_train ) preds = clf . predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" )) if __name__ == \"__main__\" : parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # model\u53c2\u6570 parser . add_argument ( \"--model\" , type = str ) args = parser . parse_args () run ( fold = args . fold , model = args . model ) train.py \u6709\u51e0\u5904\u91cd\u5927\u6539\u52a8\uff1a - \u5bfc\u5165 model_dispatcher - \u4e3a ArgumentParser \u6dfb\u52a0 --model \u53c2\u6570 - \u4e3a run() \u51fd\u6570\u6dfb\u52a0model\u53c2\u6570 - \u4f7f\u7528\u8c03\u5ea6\u7a0b\u5e8f\u83b7\u53d6\u6307\u5b9a\u540d\u79f0\u7684\u6a21\u578b \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u811a\u672c\uff1a \u276f python train . py -- fold 0 -- model decision_tree_gini Fold = 0 , Accuracy = 0.8665833333333334 \u6216\u6267\u884c\u4ee5\u4e0b\u547d\u4ee4 \u276f python train . py -- fold 0 -- model decision_tree_entropy Fold = 0 , Accuracy = 0.8705833333333334 \u73b0\u5728\uff0c\u5982\u679c\u8981\u6dfb\u52a0\u65b0\u6a21\u578b\uff0c\u53ea\u9700\u4fee\u6539 model_dispatcher.py \u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u6dfb\u52a0\u968f\u673a\u68ee\u6797\uff0c\u770b\u770b\u51c6\u786e\u7387\u4f1a\u6709\u4ec0\u4e48\u53d8\u5316\u3002 from sklearn import ensemble from sklearn import tree models = { \"decision_tree_gini\" : tree . DecisionTreeClassifier ( criterion = \"gini\" ), \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), # \u968f\u673a\u68ee\u6797\u6a21\u578b \"rf\" : ensemble . RandomForestClassifier (), } \u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u3002 \u276f python train . py -- fold 0 -- model rf Fold = 0 , Accuracy = 0.9670833333333333 \u54c7\uff0c\u4e00\u4e2a\u7b80\u5355\u7684\u6539\u52a8\u5c31\u80fd\u8ba9\u5206\u6570\u6709\u5982\u6b64\u5927\u7684\u63d0\u5347\uff01\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528 run.sh \u811a\u672c\u8fd0\u884c 5 \u4e2a\u6298\u53e0\uff01 python train . py -- fold 0 -- model rf python train . py -- fold 1 -- model rf python train . py -- fold 2 -- model rf python train . py -- fold 3 -- model rf python train . py -- fold 4 -- model rf \u5f97\u5206\u60c5\u51b5\u5982\u4e0b \u276f sh run . sh Fold = 0 , Accuracy = 0.9674166666666667 Fold = 1 , Accuracy = 0.9698333333333333 Fold = 2 , Accuracy = 0.96575 Fold = 3 , Accuracy = 0.9684166666666667 Fold = 4 , Accuracy = 0.9666666666666667 MNIST \u51e0\u4e4e\u662f\u6bcf\u672c\u4e66\u548c\u6bcf\u7bc7\u535a\u5ba2\u90fd\u4f1a\u8ba8\u8bba\u7684\u95ee\u9898\u3002\u4f46\u6211\u8bd5\u56fe\u5c06\u8fd9\u4e2a\u95ee\u9898\u8f6c\u6362\u5f97\u66f4\u6709\u8da3\uff0c\u5e76\u5411\u4f60\u5c55\u793a\u5982\u4f55\u4e3a\u4f60\u6b63\u5728\u505a\u7684\u6216\u8ba1\u5212\u5728\u4e0d\u4e45\u7684\u5c06\u6765\u505a\u7684\u51e0\u4e4e\u6240\u6709\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7f16\u5199\u4e00\u4e2a\u57fa\u672c\u6846\u67b6\u3002\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\u53ef\u4ee5\u6539\u8fdb\u8fd9\u4e2a MNIST \u6a21\u578b\u548c\u8fd9\u4e2a\u6846\u67b6\uff0c\u6211\u4eec\u5c06\u5728\u4ee5\u540e\u7684\u7ae0\u8282\u4e2d\u770b\u5230\u3002 \u6211\u4f7f\u7528\u4e86\u4e00\u4e9b\u811a\u672c\uff0c\u5982 model_dispatcher.py \u548c config.py \uff0c\u5e76\u5c06\u5b83\u4eec\u5bfc\u5165\u5230\u6211\u7684\u8bad\u7ec3\u811a\u672c\u4e2d\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u6ca1\u6709\u5bfc\u5165 \uff0c\u4f60\u4e5f\u4e0d\u5e94\u8be5\u5bfc\u5165\u3002\u5982\u679c\u6211\u5bfc\u5165\u4e86 \uff0c\u4f60\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u77e5\u9053\u6a21\u578b\u5b57\u5178\u662f\u4ece\u54ea\u91cc\u6765\u7684\u3002\u7f16\u5199\u4f18\u79c0\u3001\u6613\u61c2\u7684\u4ee3\u7801\u662f\u4e00\u4e2a\u4eba\u5fc5\u987b\u5177\u5907\u7684\u57fa\u672c\u7d20\u8d28\uff0c\u4f46\u8bb8\u591a\u6570\u636e\u79d1\u5b66\u5bb6\u5374\u5ffd\u89c6\u4e86\u8fd9\u4e00\u70b9\u3002\u5982\u679c\u4f60\u6240\u505a\u7684\u9879\u76ee\u80fd\u8ba9\u5176\u4ed6\u4eba\u7406\u89e3\u5e76\u4f7f\u7528\uff0c\u800c\u65e0\u9700\u54a8\u8be2\u4f60\u7684\u610f\u89c1\uff0c\u90a3\u4e48\u4f60\u5c31\u8282\u7701\u4e86\u4ed6\u4eec\u7684\u65f6\u95f4\u548c\u81ea\u5df1\u7684\u65f6\u95f4\uff0c\u53ef\u4ee5\u5c06\u8fd9\u4e9b\u65f6\u95f4\u6295\u5165\u5230\u6539\u8fdb\u4f60\u7684\u9879\u76ee\u6216\u5f00\u53d1\u65b0\u9879\u76ee\u4e2d\u53bb\u3002","title":"\u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\u9879\u76ee"},{"location":"%E8%AF%84%E4%BC%B0%E6%8C%87%E6%A0%87/","text":"\u8bc4\u4f30\u6307\u6807 \u8bf4\u5230\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u4f60\u4f1a\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\u9047\u5230\u5f88\u591a\u4e0d\u540c\u7c7b\u578b\u7684\u6307\u6807\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u751a\u81f3\u4f1a\u6839\u636e\u4e1a\u52a1\u95ee\u9898\u521b\u5efa\u5ea6\u91cf\u6807\u51c6\u3002\u9010\u4e00\u4ecb\u7ecd\u548c\u89e3\u91ca\u6bcf\u4e00\u79cd\u5ea6\u91cf\u7c7b\u578b\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002\u76f8\u53cd\uff0c\u6211\u4eec\u5c06\u4ecb\u7ecd\u4e00\u4e9b\u6700\u5e38\u89c1\u7684\u5ea6\u91cf\u6807\u51c6\uff0c\u4f9b\u4f60\u5728\u6700\u521d\u7684\u51e0\u4e2a\u9879\u76ee\u4e2d\u4f7f\u7528\u3002 \u5728\u672c\u4e66\u7684\u5f00\u5934\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u76d1\u7763\u5b66\u4e60\u548c\u975e\u76d1\u7763\u5b66\u4e60\u3002\u867d\u7136\u65e0\u76d1\u7763\u5b66\u4e60\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e9b\u6307\u6807\uff0c\u4f46\u6211\u4eec\u5c06\u53ea\u5173\u6ce8\u6709\u76d1\u7763\u5b66\u4e60\u3002\u8fd9\u662f\u56e0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u591a\uff0c\u800c\u4e14\u5bf9\u65e0\u76d1\u7763\u65b9\u6cd5\u7684\u8bc4\u4f30\u76f8\u5f53\u4e3b\u89c2\u3002 \u5982\u679c\u6211\u4eec\u8c08\u8bba\u5206\u7c7b\u95ee\u9898\uff0c\u6700\u5e38\u7528\u7684\u6307\u6807\u662f\uff1a \u51c6\u786e\u7387\uff08Accuracy\uff09 \u7cbe\u786e\u7387\uff08P\uff09 \u53ec\u56de\u7387\uff08R\uff09 F1 \u5206\u6570\uff08F1\uff09 AUC\uff08AUC\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u8bf4\u5230\u56de\u5f52\uff0c\u6700\u5e38\u7528\u7684\u8bc4\u4ef7\u6307\u6807\u662f \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee \uff08MAE\uff09 \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u5747\u65b9\u6839\u8bef\u5dee \uff08RMSE\uff09 \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \uff08RMSLE\uff09 \u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee \uff08MPE\uff09 \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee \uff08MAPE\uff09 R2 \u4e86\u89e3\u4e0a\u8ff0\u6307\u6807\u7684\u5de5\u4f5c\u539f\u7406\u5e76\u4e0d\u662f\u6211\u4eec\u5fc5\u987b\u4e86\u89e3\u7684\u552f\u4e00\u4e8b\u60c5\u3002\u6211\u4eec\u8fd8\u5fc5\u987b\u77e5\u9053\u4f55\u65f6\u4f7f\u7528\u54ea\u4e9b\u6307\u6807\uff0c\u800c\u8fd9\u53d6\u51b3\u4e8e\u4f60\u6709\u4ec0\u4e48\u6837\u7684\u6570\u636e\u548c\u76ee\u6807\u3002\u6211\u8ba4\u4e3a\u8fd9\u4e0e\u76ee\u6807\u6709\u5173\uff0c\u800c\u4e0e\u6570\u636e\u65e0\u5173\u3002 \u8981\u8fdb\u4e00\u6b65\u4e86\u89e3\u8fd9\u4e9b\u6307\u6807\uff0c\u8ba9\u6211\u4eec\u4ece\u4e00\u4e2a\u7b80\u5355\u7684\u95ee\u9898\u5f00\u59cb\u3002\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a \u4e8c\u5143\u5206\u7c7b \u95ee\u9898\uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u76ee\u6807\u7684\u95ee\u9898\uff0c\u5047\u8bbe\u8fd9\u662f\u4e00\u4e2a\u80f8\u90e8 X \u5149\u56fe\u50cf\u5206\u7c7b\u95ee\u9898\u3002\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6ca1\u6709\u95ee\u9898\uff0c\u800c\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6709\u80ba\u584c\u9677\uff0c\u4e5f\u5c31\u662f\u6240\u8c13\u7684\u6c14\u80f8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u5206\u7c7b\u5668\uff0c\u5728\u7ed9\u5b9a\u80f8\u90e8 X \u5149\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\uff0c\u5b83\u80fd\u68c0\u6d4b\u51fa\u56fe\u50cf\u662f\u5426\u6709\u6c14\u80f8\u3002 \u56fe 1\uff1a\u6c14\u80f8\u80ba\u90e8\u56fe\u50cf \u6211\u4eec\u8fd8\u5047\u8bbe\u6709\u76f8\u540c\u6570\u91cf\u7684\u6c14\u80f8\u548c\u975e\u6c14\u80f8\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u6bd4\u5982\u5404 100 \u5f20\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709 100 \u5f20\u9633\u6027\u6837\u672c\u548c 100 \u5f20\u9634\u6027\u6837\u672c\uff0c\u5171\u8ba1 200 \u5f20\u56fe\u50cf\u3002 \u7b2c\u4e00\u6b65\u662f\u5c06\u4e0a\u8ff0\u6570\u636e\u5206\u4e3a\u4e24\u7ec4\uff0c\u6bcf\u7ec4 100 \u5f20\u56fe\u50cf\uff0c\u5373\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u8fd9\u4e24\u4e2a\u96c6\u5408\u4e2d\uff0c\u6211\u4eec\u90fd\u6709 50 \u4e2a\u6b63\u6837\u672c\u548c 50 \u4e2a\u8d1f\u6837\u672c\u3002 \u5728\u4e8c\u5143\u5206\u7c7b\u6307\u6807\u4e2d\uff0c\u5f53\u6b63\u8d1f\u6837\u672c\u6570\u91cf\u76f8\u7b49\u65f6\uff0c\u6211\u4eec\u901a\u5e38\u4f7f\u7528\u51c6\u786e\u7387\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002 \u51c6\u786e\u7387 \uff1a\u8fd9\u662f\u673a\u5668\u5b66\u4e60\u4e2d\u6700\u76f4\u63a5\u7684\u6307\u6807\u4e4b\u4e00\u3002\u5b83\u5b9a\u4e49\u4e86\u6a21\u578b\u7684\u51c6\u786e\u5ea6\u3002\u5bf9\u4e8e\u4e0a\u8ff0\u95ee\u9898\uff0c\u5982\u679c\u4f60\u5efa\u7acb\u7684\u6a21\u578b\u80fd\u51c6\u786e\u5206\u7c7b 90 \u5f20\u56fe\u7247\uff0c\u90a3\u4e48\u4f60\u7684\u51c6\u786e\u7387\u5c31\u662f 90% \u6216 0.90\u3002\u5982\u679c\u53ea\u6709 83 \u5e45\u56fe\u50cf\u88ab\u6b63\u786e\u5206\u7c7b\uff0c\u90a3\u4e48\u6a21\u578b\u7684\u51c6\u786e\u7387\u5c31\u662f 83% \u6216 0.83\u3002 \u8ba1\u7b97\u51c6\u786e\u7387\u7684 Python \u4ee3\u7801\u4e5f\u975e\u5e38\u7b80\u5355\u3002 def accuracy ( y_true , y_pred ): # \u4e3a\u6b63\u786e\u9884\u6d4b\u6570\u521d\u59cb\u5316\u4e00\u4e2a\u7b80\u5355\u8ba1\u6570\u5668 correct_counter = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): if yt == yp : # \u5982\u679c\u9884\u6d4b\u6807\u7b7e\u4e0e\u771f\u5b9e\u6807\u7b7e\u76f8\u540c\uff0c\u5219\u589e\u52a0\u8ba1\u6570\u5668 correct_counter += 1 # \u8fd4\u56de\u6b63\u786e\u7387\uff0c\u6b63\u786e\u6807\u7b7e\u6570/\u603b\u6807\u7b7e\u6570 return correct_counter / len ( y_true ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u8ba1\u7b97\u51c6\u786e\u7387\u3002 In [ X ]: from sklearn import metrics ... : l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] ... : metrics . accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u73b0\u5728\uff0c\u5047\u8bbe\u6211\u4eec\u628a\u6570\u636e\u96c6\u7a0d\u5fae\u6539\u52a8\u4e00\u4e0b\uff0c\u6709 180 \u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u53ea\u6709 20 \u5f20\u6709\u6c14\u80f8\u3002\u5373\u4f7f\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4e5f\u8981\u521b\u5efa\u6b63\u8d1f\uff08\u6c14\u80f8\u4e0e\u975e\u6c14\u80f8\uff09\u76ee\u6807\u6bd4\u4f8b\u76f8\u540c\u7684\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u6bcf\u4e00\u7ec4\u4e2d\uff0c\u6211\u4eec\u6709 90 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u5982\u679c\u8bf4\u9a8c\u8bc1\u96c6\u4e2d\u7684\u6240\u6709\u56fe\u50cf\u90fd\u662f\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u90a3\u4e48\u60a8\u7684\u51c6\u786e\u7387\u4f1a\u662f\u591a\u5c11\u5462\uff1f\u8ba9\u6211\u4eec\u6765\u770b\u770b\uff1b\u60a8\u5bf9 90% \u7684\u56fe\u50cf\u8fdb\u884c\u4e86\u6b63\u786e\u5206\u7c7b\u3002\u56e0\u6b64\uff0c\u60a8\u7684\u51c6\u786e\u7387\u662f 90%\u3002 \u4f46\u8bf7\u518d\u770b\u4e00\u904d\u3002 \u4f60\u751a\u81f3\u6ca1\u6709\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5c31\u5f97\u5230\u4e86 90% \u7684\u51c6\u786e\u7387\u3002\u8fd9\u4f3c\u4e4e\u6709\u70b9\u6ca1\u7528\u3002\u5982\u679c\u6211\u4eec\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u6570\u636e\u96c6\u662f\u504f\u659c\u7684\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u6bd4\u53e6\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u591a\u5f88\u591a\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f7f\u7528\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u56e0\u4e3a\u5b83\u4e0d\u80fd\u4ee3\u8868\u6570\u636e\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u80fd\u4f1a\u83b7\u5f97\u5f88\u9ad8\u7684\u51c6\u786e\u7387\uff0c\u4f46\u60a8\u7684\u6a21\u578b\u5728\u5b9e\u9645\u6837\u672c\u4e2d\u7684\u8868\u73b0\u53ef\u80fd\u5e76\u4e0d\u7406\u60f3\uff0c\u800c\u4e14\u60a8\u4e5f\u65e0\u6cd5\u5411\u7ecf\u7406\u89e3\u91ca\u539f\u56e0\u3002 \u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b \u7cbe\u786e\u7387 \u7b49\u5176\u4ed6\u6307\u6807\u3002 \u5728\u5b66\u4e60\u7cbe\u786e\u7387\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e86\u89e3\u4e00\u4e9b\u672f\u8bed\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5047\u8bbe\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e3a\u6b63\u7c7b (1)\uff0c\u6ca1\u6709\u6c14\u80f8\u7684\u4e3a\u8d1f\u7c7b (0)\u3002 \u771f\u9633\u6027 \uff08TP\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6709\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9633\u6027\u3002 \u771f\u9634\u6027 \uff08TN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u800c\u5b9e\u9645\u76ee\u6807\u663e\u793a\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9634\u6027\u3002 \u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u6b63\u786e\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9633\u6027\uff1b\u5982\u679c\u60a8\u7684\u6a21\u578b\u51c6\u786e\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9634\u6027\u3002 \u5047\u9633\u6027 \uff08FP\uff09 \uff1a\u7ed9\u5b9a\u4e00\u5f20\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u975e\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9633\u6027\u3002 \u5047\u9634\u6027 \uff08FN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u975e\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9634\u6027\u3002 \u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\uff0c\u90a3\u4e48\u5b83\u5c31\u662f\u5047\u9633\u6027\u3002\u5982\u679c\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5219\u662f\u5047\u9634\u6027\u3002 \u8ba9\u6211\u4eec\u9010\u4e00\u770b\u770b\u8fd9\u4e9b\u5b9e\u73b0\u3002 def true_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u8ba1\u6570\u5668 tp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 1 : tp += 1 # \u8fd4\u56de\u771f\u9633\u6027\u6837\u672c\u6570 return tp def true_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9634\u6027\u6837\u672c\u8ba1\u6570\u5668 tn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 0 : tn += 1 # \u8fd4\u56de\u771f\u9634\u6027\u6837\u672c\u6570 return tn def false_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9633\u6027\u8ba1\u6570\u5668 fp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 1 : fp += 1 # \u8fd4\u56de\u5047\u9633\u6027\u6837\u672c\u6570 return fp def false_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9634\u6027\u8ba1\u6570\u5668 fn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 0 : fn += 1 # \u8fd4\u56de\u5047\u9634\u6027\u6570 return fn \u6211\u5728\u8fd9\u91cc\u5b9e\u73b0\u8fd9\u4e9b\u529f\u80fd\u7684\u65b9\u6cd5\u975e\u5e38\u7b80\u5355\uff0c\u800c\u4e14\u53ea\u9002\u7528\u4e8e\u4e8c\u5143\u5206\u7c7b\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u8fd9\u4e9b\u51fd\u6570\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: true_positive ( l1 , l2 ) Out [ X ]: 2 In [ X ]: false_positive ( l1 , l2 ) Out [ X ]: 1 In [ X ]: false_negative ( l1 , l2 ) Out [ X ]: 2 In [ X ]: true_negative ( l1 , l2 ) Out [ X ]: 3 \u5982\u679c\u6211\u4eec\u5fc5\u987b\u7528\u4e0a\u8ff0\u672f\u8bed\u6765\u5b9a\u4e49\u7cbe\u786e\u7387\uff0c\u6211\u4eec\u53ef\u4ee5\u5199\u4e3a\uff1a \\[ Accuracy Score = (TP + TN)/(TP + TN + FP +FN) \\] \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5728 python \u4e2d\u4f7f\u7528 TP\u3001TN\u3001FP \u548c FN \u5feb\u901f\u5b9e\u73b0\u51c6\u786e\u5ea6\u5f97\u5206\u3002\u6211\u4eec\u5c06\u5176\u79f0\u4e3a accuracy_v2\u3002 def accuracy_v2 ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u51c6\u786e\u7387 accuracy_score = ( tp + tn ) / ( tp + tn + fp + fn ) return accuracy_score \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0e\u4e4b\u524d\u7684\u5b9e\u73b0\u548c scikit-learn \u7248\u672c\u8fdb\u884c\u6bd4\u8f83\uff0c\u5feb\u901f\u68c0\u67e5\u8be5\u51fd\u6570\u7684\u6b63\u786e\u6027\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: accuracy ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: accuracy_v2 ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: metrics . accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u8bf7\u6ce8\u610f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0cmetrics.accuracy_score \u6765\u81ea scikit-learn\u3002 \u5f88\u597d\u3002\u6240\u6709\u503c\u90fd\u5339\u914d\u3002\u8fd9\u8bf4\u660e\u6211\u4eec\u5728\u5b9e\u73b0\u8fc7\u7a0b\u4e2d\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u5176\u4ed6\u91cd\u8981\u6307\u6807\u3002 \u9996\u5148\u662f\u7cbe\u786e\u7387\u3002\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u662f \\[ Precision = TP/(TP + FP) \\] \u5047\u8bbe\u6211\u4eec\u5728\u65b0\u7684\u504f\u659c\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u4e86\u4e00\u4e2a\u65b0\u6a21\u578b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 90 \u5f20\u56fe\u50cf\u4e2d\u7684 80 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u56fe\u50cf\u4e2d\u7684 8 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6210\u529f\u8bc6\u522b\u4e86 100 \u5f20\u56fe\u50cf\u4e2d\u7684 88 \u5f20\u3002\u56e0\u6b64\uff0c\u51c6\u786e\u7387\u4e3a 0.88 \u6216 88%\u3002 \u4f46\u662f\uff0c\u5728\u8fd9 100 \u5f20\u6837\u672c\u4e2d\uff0c\u6709 10 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u6c14\u80f8\uff0c2 \u5f20\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u975e\u6c14\u80f8\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 TP : 8 TN: 80 FP: 10 FN: 2 \u7cbe\u786e\u7387\u4e3a 8 / (8 + 10) = 0.444\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u5728\u8bc6\u522b\u9633\u6027\u6837\u672c\uff08\u6c14\u80f8\uff09\u65f6\u6709 44.4% \u7684\u6b63\u786e\u7387\u3002 \u73b0\u5728\uff0c\u65e2\u7136\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86 TP\u3001TN\u3001FP \u548c FN\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5728 python \u4e2d\u5b9e\u73b0\u7cbe\u786e\u7387\u4e86\u3002 def precision ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8ba9\u6211\u4eec\u8bd5\u8bd5\u8fd9\u79cd\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u65b9\u5f0f\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: precision ( l1 , l2 ) Out [ X ]: 0.6666666666666666 \u8fd9\u4f3c\u4e4e\u6ca1\u6709\u95ee\u9898\u3002 \u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u6765\u770b \u53ec\u56de\u7387 \u3002\u53ec\u56de\u7387\u7684\u5b9a\u4e49\u662f\uff1a \\[ Recall = TP/(TP + FN) \\] \u5728\u4e0a\u8ff0\u60c5\u51b5\u4e0b\uff0c\u53ec\u56de\u7387\u4e3a 8 / (8 + 2) = 0.80\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 80% \u7684\u9633\u6027\u6837\u672c\u3002 def recall ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u53ec\u56de\u7387 recall = tp / ( tp + fn ) return recall \u5c31\u6211\u4eec\u7684\u4e24\u4e2a\u5c0f\u5217\u8868\u800c\u8a00\uff0c\u53ec\u56de\u7387\u5e94\u8be5\u662f 0.5\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: recall ( l1 , l2 ) Out [ X ]: 0.5 \u8fd9\u4e0e\u6211\u4eec\u7684\u8ba1\u7b97\u503c\u76f8\u7b26\uff01 \u5bf9\u4e8e\u4e00\u4e2a \"\u597d \"\u6a21\u578b\u6765\u8bf4\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u503c\u90fd\u5e94\u8be5\u5f88\u9ad8\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0c\u53ec\u56de\u503c\u76f8\u5f53\u9ad8\u3002\u4f46\u662f\uff0c\u7cbe\u786e\u7387\u5374\u5f88\u4f4e\uff01\u6211\u4eec\u7684\u6a21\u578b\u4ea7\u751f\u4e86\u5927\u91cf\u7684\u8bef\u62a5\uff0c\u4f46\u6f0f\u62a5\u8f83\u5c11\u3002\u5728\u8fd9\u7c7b\u95ee\u9898\u4e2d\uff0c\u5047\u9634\u6027\u8f83\u5c11\u662f\u597d\u4e8b\uff0c\u56e0\u4e3a\u4f60\u4e0d\u60f3\u5728\u75c5\u4eba\u6709\u6c14\u80f8\u7684\u60c5\u51b5\u4e0b\u5374\u8bf4\u4ed6\u4eec\u6ca1\u6709\u6c14\u80f8\u3002\u8fd9\u6837\u505a\u4f1a\u9020\u6210\u66f4\u5927\u7684\u4f24\u5bb3\u3002\u4f46\u6211\u4eec\u4e5f\u6709\u5f88\u591a\u5047\u9633\u6027\u7ed3\u679c\uff0c\u8fd9\u4e5f\u4e0d\u662f\u597d\u4e8b\u3002 \u5927\u591a\u6570\u6a21\u578b\u90fd\u4f1a\u9884\u6d4b\u4e00\u4e2a\u6982\u7387\uff0c\u5f53\u6211\u4eec\u9884\u6d4b\u65f6\uff0c\u901a\u5e38\u4f1a\u5c06\u8fd9\u4e2a\u9608\u503c\u9009\u4e3a 0.5\u3002\u8fd9\u4e2a\u9608\u503c\u5e76\u4e0d\u603b\u662f\u7406\u60f3\u7684\uff0c\u6839\u636e\u8fd9\u4e2a\u9608\u503c\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u503c\u53ef\u80fd\u4f1a\u53d1\u751f\u5f88\u5927\u7684\u53d8\u5316\u3002\u5982\u679c\u6211\u4eec\u9009\u62e9\u7684\u6bcf\u4e2a\u9608\u503c\u90fd\u80fd\u8ba1\u7b97\u51fa\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u53ef\u4ee5\u5728\u8fd9\u4e9b\u503c\u4e4b\u95f4\u7ed8\u5236\u51fa\u66f2\u7ebf\u56fe\u3002\u8fd9\u5e45\u56fe\u6216\u66f2\u7ebf\u88ab\u79f0\u4e3a \"\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\"\u3002 \u5728\u7814\u7a76\u7cbe\u786e\u7387-\u8c03\u7528\u66f2\u7ebf\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5047\u8bbe\u6709\u4e24\u4e2a\u5217\u8868\u3002 In [ X ]: y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0.02638412 , 0.11114267 , 0.31620708 , ... : 0.0490937 , 0.0191491 , 0.17554844 , ... : 0.15952202 , 0.03819563 , 0.11639273 , ... : 0.079377 , 0.08584789 , 0.39095342 , ... : 0.27259048 , 0.03447096 , 0.04644807 , ... : 0.03543574 , 0.18521942 , 0.05934905 , ... : 0.61977213 , 0.33056815 ] \u56e0\u6b64\uff0cy_true \u662f\u6211\u4eec\u7684\u76ee\u6807\u503c\uff0c\u800c y_pred \u662f\u6837\u672c\u88ab\u8d4b\u503c\u4e3a 1 \u7684\u6982\u7387\u503c\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u6211\u4eec\u8981\u770b\u7684\u662f\u9884\u6d4b\u4e2d\u7684\u6982\u7387\uff0c\u800c\u4e0d\u662f\u9884\u6d4b\u503c\uff08\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u9884\u6d4b\u503c\u7684\u8ba1\u7b97\u9608\u503c\u4e3a 0.5\uff09\u3002 precisions = [] recalls = [] thresholds = [ 0.0490937 , 0.05934905 , 0.079377 , 0.08584789 , 0.11114267 , 0.11639273 , 0.15952202 , 0.17554844 , 0.18521942 , 0.27259048 , 0.31620708 , 0.33056815 , 0.39095342 , 0.61977213 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for i in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_prediction = [ 1 if x >= i else 0 for x in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , temp_prediction ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , temp_prediction ) # \u52a0\u5165\u7cbe\u786e\u7387\u5217\u8868 precisions . append ( p ) # \u52a0\u5165\u53ec\u56de\u7387\u5217\u8868 recalls . append ( r ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7ed8\u5236\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 7 , 7 )) # x\u8f74\u4e3a\u53ec\u56de\u7387\uff0cy\u8f74\u4e3a\u7cbe\u786e\u7387 plt . plot ( recalls , precisions ) # \u6dfb\u52a0x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a15 plt . xlabel ( 'Recall' , fontsize = 15 ) # \u6dfb\u52a0y\u8f74\u6807\u7b7e\uff0c\u5b57\u6761\u5927\u5c0f\u4e3a15 plt . ylabel ( 'Precision' , fontsize = 15 ) \u56fe 2 \u663e\u793a\u4e86\u6211\u4eec\u901a\u8fc7\u8fd9\u79cd\u65b9\u6cd5\u5f97\u5230\u7684\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 \u56fe 2\uff1a\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u8fd9\u6761 \u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u4e0e\u60a8\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230\u7684\u66f2\u7ebf\u622a\u7136\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6211\u4eec\u53ea\u6709 20 \u4e2a\u6837\u672c\uff0c\u5176\u4e2d\u53ea\u6709 3 \u4e2a\u662f\u9633\u6027\u6837\u672c\u3002\u4f46\u8fd9\u6ca1\u4ec0\u4e48\u597d\u62c5\u5fc3\u7684\u3002\u8fd9\u8fd8\u662f\u90a3\u6761\u7cbe\u786e\u7387-\u53ec\u56de\u66f2\u7ebf\u3002 \u4f60\u4f1a\u53d1\u73b0\uff0c\u9009\u62e9\u4e00\u4e2a\u65e2\u80fd\u63d0\u4f9b\u826f\u597d\u7cbe\u786e\u7387\u53c8\u80fd\u63d0\u4f9b\u53ec\u56de\u503c\u7684\u9608\u503c\u662f\u5f88\u6709\u6311\u6218\u6027\u7684\u3002\u5982\u679c\u9608\u503c\u8fc7\u9ad8\uff0c\u771f\u9633\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u51cf\u5c11\uff0c\u800c\u5047\u9634\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u589e\u52a0\u3002\u8fd9\u4f1a\u964d\u4f4e\u53ec\u56de\u7387\uff0c\u4f46\u7cbe\u786e\u7387\u5f97\u5206\u4f1a\u5f88\u9ad8\u3002\u5982\u679c\u5c06\u9608\u503c\u964d\u5f97\u592a\u4f4e\uff0c\u5219\u8bef\u62a5\u4f1a\u5927\u91cf\u589e\u52a0\uff0c\u7cbe\u786e\u7387\u4e5f\u4f1a\u964d\u4f4e\u3002 \u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u8d8a\u63a5\u8fd1 1 \u8d8a\u597d\u3002 F1 \u5206\u6570\u662f\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7efc\u5408\u6307\u6807\u3002\u5b83\u88ab\u5b9a\u4e49\u4e3a\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7b80\u5355\u52a0\u6743\u5e73\u5747\u503c\uff08\u8c03\u548c\u5e73\u5747\u503c\uff09\u3002\u5982\u679c\u6211\u4eec\u7528 P \u8868\u793a\u7cbe\u786e\u7387\uff0c\u7528 R \u8868\u793a\u53ec\u56de\u7387\uff0c\u90a3\u4e48 F1 \u5206\u6570\u53ef\u4ee5\u8868\u793a\u4e3a\uff1a \\[ F1 = 2PR/(P + R) \\] \u6839\u636e TP\u3001FP \u548c FN\uff0c\u7a0d\u52a0\u6570\u5b66\u8ba1\u7b97\u5c31\u80fd\u5f97\u51fa\u4ee5\u4e0b F1 \u7b49\u5f0f\uff1a \\[ F1 = 2TP/(2TP + FP + FN) \\] Python \u5b9e\u73b0\u5f88\u7b80\u5355\uff0c\u56e0\u4e3a\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86\u8fd9\u4e9b def f1 ( y_true , y_pred ): # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , y_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , y_pred ) # \u8ba1\u7b97f1\u503c score = 2 * p * r / ( p + r ) return score \u8ba9\u6211\u4eec\u770b\u770b\u5176\u7ed3\u679c\uff0c\u5e76\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\u3002 In [ X ]: y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0 , 0 , 1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: f1 ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u901a\u8fc7 scikit learn\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u76f8\u540c\u7684\u5217\u8868\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . f1_score ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u4e0e\u5176\u5355\u72ec\u770b\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u60a8\u8fd8\u53ef\u4ee5\u53ea\u770b F1 \u5206\u6570\u3002\u4e0e\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c\u51c6\u786e\u5ea6\u4e00\u6837\uff0cF1 \u5206\u6570\u7684\u8303\u56f4\u4e5f\u662f\u4ece 0 \u5230 1\uff0c\u5b8c\u7f8e\u9884\u6d4b\u6a21\u578b\u7684 F1 \u5206\u6570\u4e3a 1\u3002 \u6b64\u5916\uff0c\u6211\u4eec\u8fd8\u5e94\u8be5\u4e86\u89e3\u5176\u4ed6\u4e00\u4e9b\u5173\u952e\u672f\u8bed\u3002 \u7b2c\u4e00\u4e2a\u672f\u8bed\u662f TPR \u6216\u771f\u9633\u6027\u7387\uff08True Positive Rate\uff09\uff0c\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\u3002 \\[ TPR = TP/(TP + FN) \\] \u5c3d\u7ba1\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\uff0c\u4f46\u6211\u4eec\u5c06\u4e3a\u5b83\u521b\u5efa\u4e00\u4e2a python \u51fd\u6570\uff0c\u4ee5\u4fbf\u4eca\u540e\u4f7f\u7528\u8fd9\u4e2a\u540d\u79f0\u3002 def tpr ( y_true , y_pred ): # \u771f\u9633\u6027\u7387\uff08TPR\uff09\uff0c\u4e0e\u53ec\u56de\u7387\u8ba1\u7b97\u516c\u5f0f\u4e00\u81f4 return recall ( y_true , y_pred ) TPR \u6216\u53ec\u56de\u7387\u4e5f\u88ab\u79f0\u4e3a\u7075\u654f\u5ea6\u3002 \u800c FPR \u6216\u5047\u9633\u6027\u7387\uff08False Positive Rate\uff09\u7684\u5b9a\u4e49\u662f\uff1a \\[ FPR = FP / (TN + FP) \\] def fpr ( y_true , y_pred ): # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u8fd4\u56de\u5047\u9633\u6027\u7387\uff08FPR\uff09 return fp / ( tn + fp ) 1 - FPR \u88ab\u79f0\u4e3a\u7279\u5f02\u6027\u6216\u771f\u9634\u6027\u7387\u6216 TNR\u3002\u8fd9\u4e9b\u672f\u8bed\u5f88\u591a\uff0c\u4f46\u5176\u4e2d\u6700\u91cd\u8981\u7684\u53ea\u6709 TPR \u548c FPR\u3002\u5047\u8bbe\u6211\u4eec\u53ea\u6709 15 \u4e2a\u6837\u672c\uff0c\u5176\u76ee\u6807\u503c\u4e3a\u4e8c\u5143\uff1a Actual targets : [0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1] \u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u7c7b\u4f3c\u968f\u673a\u68ee\u6797\u7684\u6a21\u578b\uff0c\u5c31\u80fd\u5f97\u5230\u6837\u672c\u5448\u9633\u6027\u7684\u6982\u7387\u3002 Predicted probabilities for 1: [0.1, 0.3, 0.2, 0.6, 0.8, 0.05, 0.9, 0.5, 0.3, 0.66, 0.3, 0.2, 0.85, 0.15, 0.99] \u5bf9\u4e8e >= 0.5 \u7684\u5178\u578b\u9608\u503c\uff0c\u6211\u4eec\u53ef\u4ee5\u8bc4\u4f30\u4e0a\u8ff0\u6240\u6709\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387/TPR\u3001F1 \u548c FPR \u503c\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u5c06\u9608\u503c\u9009\u4e3a 0.4 \u6216 0.6\uff0c\u4e5f\u53ef\u4ee5\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0 \u5230 1 \u4e4b\u95f4\u7684\u4efb\u4f55\u503c\uff0c\u5e76\u8ba1\u7b97\u4e0a\u8ff0\u6240\u6709\u6307\u6807\u3002 \u4e0d\u8fc7\uff0c\u6211\u4eec\u53ea\u8ba1\u7b97\u4e24\u4e2a\u503c\uff1a TPR \u548c FPR\u3002 # \u521d\u59cb\u5316\u771f\u9633\u6027\u7387\u5217\u8868 tpr_list = [] # \u521d\u59cb\u5316\u5047\u9633\u6027\u7387\u5217\u8868 fpr_list = [] # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u7387 temp_tpr = tpr ( y_true , temp_pred ) # \u5047\u9633\u6027\u7387 temp_fpr = fpr ( y_true , temp_pred ) # \u5c06\u771f\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 tpr_list . append ( temp_tpr ) # \u5c06\u5047\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 fpr_list . append ( temp_fpr ) \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u6bcf\u4e2a\u9608\u503c\u7684 TPR \u503c\u548c FPR \u503c\u3002 \u56fe 3\uff1a\u9608\u503c\u3001TPR \u548c FPR \u503c\u8868 \u5982\u679c\u6211\u4eec\u7ed8\u5236\u5982\u56fe 3 \u6240\u793a\u7684\u8868\u683c\uff0c\u5373\u4ee5 TPR \u4e3a Y \u8f74\uff0cFPR \u4e3a X \u8f74\uff0c\u5c31\u4f1a\u5f97\u5230\u5982\u56fe 4 \u6240\u793a\u7684\u66f2\u7ebf\u3002 \u56fe 4\uff1aROC\u66f2\u7ebf \u8fd9\u6761\u66f2\u7ebf\u4e5f\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u3002\u5982\u679c\u6211\u4eec\u8ba1\u7b97\u8fd9\u6761 ROC \u66f2\u7ebf\u4e0b\u7684\u9762\u79ef\uff0c\u5c31\u662f\u5728\u8ba1\u7b97\u53e6\u4e00\u4e2a\u6307\u6807\uff0c\u5f53\u6570\u636e\u96c6\u7684\u4e8c\u5143\u76ee\u6807\u504f\u659c\u65f6\uff0c\u8fd9\u4e2a\u6307\u6807\u5c31\u4f1a\u975e\u5e38\u5e38\u7528\u3002 \u8fd9\u4e2a\u6307\u6807\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u6216\u66f2\u7ebf\u4e0b\u9762\u79ef\uff0c\u7b80\u79f0 AUC\u3002\u8ba1\u7b97 ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002\u5728\u6b64\uff0c\u6211\u4eec\u5c06\u91c7\u7528 scikit- learn \u7684\u5947\u5999\u5b9e\u73b0\u65b9\u6cd5\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: metrics . roc_auc_score ( y_true , y_pred ) Out [ X ]: 0.8300000000000001 AUC \u503c\u4ece 0 \u5230 1 \u4e0d\u7b49\u3002 AUC = 1 \u610f\u5473\u7740\u60a8\u62e5\u6709\u4e00\u4e2a\u5b8c\u7f8e\u7684\u6a21\u578b\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u610f\u5473\u7740\u4f60\u5728\u9a8c\u8bc1\u65f6\u72af\u4e86\u4e00\u4e9b\u9519\u8bef\uff0c\u5e94\u8be5\u91cd\u65b0\u5ba1\u89c6\u6570\u636e\u5904\u7406\u548c\u9a8c\u8bc1\u6d41\u7a0b\u3002\u5982\u679c\u4f60\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\uff0c\u90a3\u4e48\u606d\u559c\u4f60\uff0c\u4f60\u5df2\u7ecf\u62e5\u6709\u4e86\u9488\u5bf9\u6570\u636e\u96c6\u5efa\u7acb\u7684\u6700\u4f73\u6a21\u578b\u3002 AUC = 0 \u610f\u5473\u7740\u60a8\u7684\u6a21\u578b\u975e\u5e38\u7cdf\u7cd5\uff08\u6216\u975e\u5e38\u597d\uff01\uff09\u3002\u8bd5\u7740\u53cd\u8f6c\u9884\u6d4b\u7684\u6982\u7387\uff0c\u4f8b\u5982\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u6b63\u7c7b\u7684\u6982\u7387\u662f p\uff0c\u8bd5\u7740\u7528 1-p \u4ee3\u66ff\u5b83\u3002\u8fd9\u79cd AUC \u4e5f\u53ef\u80fd\u610f\u5473\u7740\u60a8\u7684\u9a8c\u8bc1\u6216\u6570\u636e\u5904\u7406\u5b58\u5728\u95ee\u9898\u3002 AUC = 0.5 \u610f\u5473\u7740\u4f60\u7684\u9884\u6d4b\u662f\u968f\u673a\u7684\u3002\u56e0\u6b64\uff0c\u5bf9\u4e8e\u4efb\u4f55\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u5982\u679c\u6211\u5c06\u6240\u6709\u76ee\u6807\u90fd\u9884\u6d4b\u4e3a 0.5\uff0c\u6211\u5c06\u5f97\u5230 0.5 \u7684 AUC\u3002 AUC \u503c\u4ecb\u4e8e 0 \u548c 0.5 \u4e4b\u95f4\uff0c\u610f\u5473\u7740\u4f60\u7684\u6a21\u578b\u6bd4\u968f\u673a\u6a21\u578b\u66f4\u5dee\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u662f\u56e0\u4e3a\u4f60\u98a0\u5012\u4e86\u7c7b\u522b\u3002 \u5982\u679c\u60a8\u5c1d\u8bd5\u53cd\u8f6c\u9884\u6d4b\uff0c\u60a8\u7684 AUC \u503c\u53ef\u80fd\u4f1a\u8d85\u8fc7 0.5\u3002\u63a5\u8fd1 1 \u7684 AUC \u503c\u88ab\u8ba4\u4e3a\u662f\u597d\u503c\u3002 \u4f46 AUC \u5bf9\u6211\u4eec\u7684\u6a21\u578b\u6709\u4ec0\u4e48\u5f71\u54cd\u5462\uff1f \u5047\u8bbe\u60a8\u5efa\u7acb\u4e86\u4e00\u4e2a\u4ece\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e2d\u68c0\u6d4b\u6c14\u80f8\u7684\u6a21\u578b\uff0c\u5176 AUC \u503c\u4e3a 0.85\u3002\u8fd9\u610f\u5473\u7740\uff0c\u5982\u679c\u60a8\u4ece\u6570\u636e\u96c6\u4e2d\u968f\u673a\u9009\u62e9\u4e00\u5f20\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9633\u6027\u6837\u672c\uff09\u548c\u53e6\u4e00\u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9634\u6027\u6837\u672c\uff09\uff0c\u90a3\u4e48\u6c14\u80f8\u56fe\u50cf\u7684\u6392\u540d\u5c06\u9ad8\u4e8e\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u6982\u7387\u4e3a 0.85\u3002 \u8ba1\u7b97\u6982\u7387\u548c AUC \u540e\uff0c\u60a8\u9700\u8981\u5bf9\u6d4b\u8bd5\u96c6\u8fdb\u884c\u9884\u6d4b\u3002\u6839\u636e\u95ee\u9898\u548c\u4f7f\u7528\u60c5\u51b5\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u6982\u7387\u6216\u5b9e\u9645\u7c7b\u522b\u3002\u5982\u679c\u4f60\u60f3\u8981\u6982\u7387\uff0c\u8fd9\u5e76\u4e0d\u96be\u3002\u5982\u679c\u60a8\u60f3\u8981\u7c7b\u522b\uff0c\u5219\u9700\u8981\u9009\u62e9\u4e00\u4e2a\u9608\u503c\u3002\u5728\u4e8c\u5143\u5206\u7c7b\u7684\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u91c7\u7528\u7c7b\u4f3c\u4e0b\u9762\u7684\u65b9\u6cd5\u3002 \\[ Prediction = Probability >= Threshold \\] \u4e5f\u5c31\u662f\u8bf4\uff0c\u9884\u6d4b\u662f\u4e00\u4e2a\u53ea\u5305\u542b\u4e8c\u5143\u53d8\u91cf\u7684\u65b0\u5217\u8868\u3002\u5982\u679c\u6982\u7387\u5927\u4e8e\u6216\u7b49\u4e8e\u7ed9\u5b9a\u7684\u9608\u503c\uff0c\u5219\u9884\u6d4b\u4e2d\u7684\u4e00\u9879\u4e3a 1\uff0c\u5426\u5219\u4e3a 0\u3002 \u4f60\u731c\u600e\u4e48\u7740\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528 ROC \u66f2\u7ebf\u6765\u9009\u62e9\u8fd9\u4e2a\u9608\u503c\uff01ROC \u66f2\u7ebf\u4f1a\u544a\u8bc9\u60a8\u9608\u503c\u5bf9\u5047\u9633\u6027\u7387\u548c\u771f\u9633\u6027\u7387\u7684\u5f71\u54cd\uff0c\u8fdb\u800c\u5f71\u54cd\u5047\u9633\u6027\u548c\u771f\u9633\u6027\u3002\u60a8\u5e94\u8be5\u9009\u62e9\u6700\u9002\u5408\u60a8\u7684\u95ee\u9898\u548c\u6570\u636e\u96c6\u7684\u9608\u503c\u3002 \u4f8b\u5982\uff0c\u5982\u679c\u60a8\u4e0d\u5e0c\u671b\u6709\u592a\u591a\u7684\u8bef\u62a5\uff0c\u90a3\u4e48\u9608\u503c\u5c31\u5e94\u8be5\u9ad8\u4e00\u4e9b\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e5f\u4f1a\u5e26\u6765\u66f4\u591a\u7684\u8bef\u62a5\u3002\u6ce8\u610f\u6743\u8861\u5229\u5f0a\uff0c\u9009\u62e9\u6700\u4f73\u9608\u503c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u9608\u503c\u5982\u4f55\u5f71\u54cd\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u503c\u3002 # \u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list = [] # \u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list = [] # \u771f\u5b9e\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 temp_tp = true_positive ( y_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 temp_fp = false_positive ( y_true , temp_pred ) # \u52a0\u5165\u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list . append ( temp_tp ) # \u52a0\u5165\u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list . append ( temp_fp ) \u5229\u7528\u8fd9\u4e00\u70b9\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u8868\u683c\uff0c\u5982\u56fe 5 \u6240\u793a\u3002 \u56fe 5\uff1a\u4e0d\u540c\u9608\u503c\u7684 TP \u503c\u548c FP \u503c \u5982\u56fe 6 \u6240\u793a\uff0c\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0cROC \u66f2\u7ebf\u5de6\u4e0a\u89d2\u7684\u503c\u5e94\u8be5\u662f\u4e00\u4e2a\u76f8\u5f53\u4e0d\u9519\u7684\u9608\u503c\u3002 \u5bf9\u6bd4\u8868\u683c\u548c ROC \u66f2\u7ebf\uff0c\u6211\u4eec\u53ef\u4ee5\u53d1\u73b0\uff0c0.6 \u5de6\u53f3\u7684\u9608\u503c\u76f8\u5f53\u4e0d\u9519\uff0c\u65e2\u4e0d\u4f1a\u4e22\u5931\u5927\u91cf\u7684\u771f\u9633\u6027\u7ed3\u679c\uff0c\u4e5f\u4e0d\u4f1a\u51fa\u73b0\u5927\u91cf\u7684\u5047\u9633\u6027\u7ed3\u679c\u3002 \u56fe 6\uff1a\u4ece ROC \u66f2\u7ebf\u6700\u5de6\u4fa7\u7684\u9876\u70b9\u9009\u62e9\u6700\u4f73\u9608\u503c AUC \u662f\u4e1a\u5185\u5e7f\u6cdb\u5e94\u7528\u4e8e\u504f\u659c\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6307\u6807\uff0c\u4e5f\u662f\u6bcf\u4e2a\u4eba\u90fd\u5e94\u8be5\u4e86\u89e3\u7684\u6307\u6807\u3002\u4e00\u65e6\u7406\u89e3\u4e86 AUC \u80cc\u540e\u7684\u7406\u5ff5\uff08\u5982\u4e0a\u6587\u6240\u8ff0\uff09\uff0c\u4e5f\u5c31\u5f88\u5bb9\u6613\u5411\u4e1a\u754c\u53ef\u80fd\u4f1a\u8bc4\u4f30\u60a8\u7684\u6a21\u578b\u7684\u975e\u6280\u672f\u4eba\u5458\u89e3\u91ca\u5b83\u4e86\u3002 \u5b66\u4e60 AUC \u540e\uff0c\u4f60\u5e94\u8be5\u5b66\u4e60\u7684\u53e6\u4e00\u4e2a\u91cd\u8981\u6307\u6807\u662f\u5bf9\u6570\u635f\u5931\u3002\u5bf9\u4e8e\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5c06\u5bf9\u6570\u635f\u5931\u5b9a\u4e49\u4e3a\uff1a \\[ LogLoss = -1.0 \\times (target \\times log(prediction) + (1-target) \\times log(1-prediction)) \\] \u5176\u4e2d\uff0c\u76ee\u6807\u503c\u4e3a 0 \u6216 1\uff0c\u9884\u6d4b\u503c\u4e3a\u6837\u672c\u5c5e\u4e8e\u7c7b\u522b 1 \u7684\u6982\u7387\u3002 \u5bf9\u4e8e\u6570\u636e\u96c6\u4e2d\u7684\u591a\u4e2a\u6837\u672c\uff0c\u6240\u6709\u6837\u672c\u7684\u5bf9\u6570\u635f\u5931\u53ea\u662f\u6240\u6709\u5355\u4e2a\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u9700\u8981\u8bb0\u4f4f\u7684\u4e00\u70b9\u662f\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u4e0d\u6b63\u786e\u6216\u504f\u5dee\u8f83\u5927\u7684\u9884\u6d4b\u8fdb\u884c\u76f8\u5f53\u9ad8\u7684\u60e9\u7f5a\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u975e\u5e38\u786e\u5b9a\u548c\u975e\u5e38\u9519\u8bef\u7684\u9884\u6d4b\u8fdb\u884c\u60e9\u7f5a\u3002 import numpy as np def log_loss ( y_true , y_proba ): # \u6781\u5c0f\u503c\uff0c\u9632\u6b620\u505a\u5206\u6bcd epsilon = 1e-15 # \u5bf9\u6570\u635f\u5931\u5217\u8868 loss = [] # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_proba ): # \u9650\u5236yp\u8303\u56f4\uff0c\u6700\u5c0f\u4e3aepsilon\uff0c\u6700\u5927\u4e3a1-epsilon yp = np . clip ( yp , epsilon , 1 - epsilon ) # \u8ba1\u7b97\u5bf9\u6570\u635f\u5931 temp_loss = - 1.0 * ( yt * np . log ( yp ) + ( 1 - yt ) * np . log ( 1 - yp )) # \u52a0\u5165\u5bf9\u6570\u635f\u5931\u5217\u8868 loss . append ( temp_loss ) return np . mean ( loss ) \u8ba9\u6211\u4eec\u6d4b\u8bd5\u4e00\u4e0b\u51fd\u6570\u6267\u884c\u60c5\u51b5\uff1a In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_proba = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u6211\u4eec\u53ef\u4ee5\u5c06\u5176\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u5b9e\u73b0\u662f\u6b63\u786e\u7684\u3002 \u5bf9\u6570\u635f\u5931\u7684\u5b9e\u73b0\u5f88\u5bb9\u6613\u3002\u89e3\u91ca\u8d77\u6765\u4f3c\u4e4e\u6709\u70b9\u56f0\u96be\u3002\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5bf9\u6570\u635f\u5931\u7684\u60e9\u7f5a\u8981\u6bd4\u5176\u4ed6\u6307\u6807\u5927\u5f97\u591a\u3002 \u4f8b\u5982\uff0c\u5982\u679c\u60a8\u6709 51% \u7684\u628a\u63e1\u8ba4\u4e3a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\uff0c\u90a3\u4e48\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (1 \\times log(0.51) + (1 - 1) \\times log(1 - 0.51))=0.67 \\] \u5982\u679c\u4f60\u5bf9\u5c5e\u4e8e 0 \u7c7b\u7684\u6837\u672c\u6709 49% \u7684\u628a\u63e1\uff0c\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (1 \\times log(0.49) + (1 - 1) \\times log(1 - 0.49))=0.67 \\] \u56e0\u6b64\uff0c\u5373\u4f7f\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0.5 \u7684\u622a\u65ad\u503c\u5e76\u5f97\u5230\u5b8c\u7f8e\u7684\u9884\u6d4b\u7ed3\u679c\uff0c\u6211\u4eec\u4ecd\u7136\u4f1a\u6709\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002\u56e0\u6b64\uff0c\u5728\u5904\u7406\u5bf9\u6570\u635f\u5931\u65f6\uff0c\u4f60\u9700\u8981\u975e\u5e38\u5c0f\u5fc3\uff1b\u4efb\u4f55\u4e0d\u786e\u5b9a\u7684\u9884\u6d4b\u90fd\u4f1a\u4ea7\u751f\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002 \u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u5927\u591a\u6570\u6307\u6807\u90fd\u53ef\u4ee5\u8f6c\u6362\u6210\u591a\u7c7b\u7248\u672c\u3002\u8fd9\u4e2a\u60f3\u6cd5\u5f88\u7b80\u5355\u3002\u4ee5\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u4e3a\u4f8b\u3002\u6211\u4eec\u53ef\u4ee5\u8ba1\u7b97\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u4e2d\u6bcf\u4e00\u7c7b\u7684\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 \u6709\u4e09\u79cd\u4e0d\u540c\u7684\u8ba1\u7b97\u65b9\u6cd5\uff0c\u6709\u65f6\u53ef\u80fd\u4f1a\u4ee4\u4eba\u56f0\u60d1\u3002\u5047\u8bbe\u6211\u4eec\u9996\u5148\u5bf9\u7cbe\u786e\u7387\u611f\u5174\u8da3\u3002\u6211\u4eec\u77e5\u9053\uff0c\u7cbe\u786e\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u3002 \u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Macro averaged precision\uff09\uff1a\u5206\u522b\u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u7cbe\u786e\u7387\u7136\u540e\u6c42\u5e73\u5747\u503c \u5fae\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Micro averaged precision\uff09\uff1a\u6309\u7c7b\u8ba1\u7b97\u771f\u9633\u6027\u548c\u5047\u9633\u6027\uff0c\u7136\u540e\u7528\u5176\u8ba1\u7b97\u603b\u4f53\u7cbe\u786e\u7387\u3002\u7136\u540e\u4ee5\u6b64\u8ba1\u7b97\u603b\u4f53\u7cbe\u786e\u7387 \u52a0\u6743\u7cbe\u786e\u7387 \uff08Weighted precision\uff09\uff1a\u4e0e\u5b8f\u89c2\u7cbe\u786e\u7387\u76f8\u540c\uff0c\u4f46\u8fd9\u91cc\u662f\u52a0\u6743\u5e73\u5747\u7cbe\u786e\u7387 \u53d6\u51b3\u4e8e\u6bcf\u4e2a\u7c7b\u522b\u4e2d\u7684\u9879\u76ee\u6570 \u8fd9\u770b\u4f3c\u590d\u6742\uff0c\u4f46\u5728 python \u5b9e\u73b0\u4e2d\u5f88\u5bb9\u6613\u7406\u89e3\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import numpy as np def macro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u5982\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u8ba1\u7b97\u7cbe\u786e\u5ea6 temp_precision = tp / ( tp + fp ) # \u5404\u7c7b\u7cbe\u786e\u7387\u76f8\u52a0 precision += temp_precision # \u8ba1\u7b97\u5e73\u5747\u503c precision /= num_classes return precision \u4f60\u4f1a\u53d1\u73b0\u8fd9\u5e76\u4e0d\u96be\u3002\u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5fae\u5e73\u5747\u7cbe\u786e\u7387\u5206\u6570\u3002 import numpy as np def micro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u6570 tp = 0 # \u521d\u59cb\u5316\u5047\u9633\u6027\u6837\u672c\u6570 fp = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 tp += true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 fp += false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8fd9\u4e5f\u4e0d\u96be\u3002\u90a3\u4ec0\u4e48\u96be\uff1f\u4ec0\u4e48\u90fd\u4e0d\u96be\u3002\u673a\u5668\u5b66\u4e60\u5f88\u7b80\u5355\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u52a0\u6743\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u3002 from collections import Counter import numpy as np def weighted_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 temp_precision = tp / ( tp + fp ) # \u6839\u636e\u8be5\u79cd\u7c7b\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_precision = class_counts [ class_ ] * temp_precision # \u52a0\u6743\u7cbe\u786e\u7387\u6c42\u548c precision += weighted_precision # \u8ba1\u7b97\u5e73\u5747\u7cbe\u786e\u7387 overall_precision = precision / len ( y_true ) return overall_precision \u5c06\u6211\u4eec\u7684\u5b9e\u73b0\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff0c\u4ee5\u4e86\u89e3\u5b9e\u73b0\u662f\u5426\u6b63\u786e\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: macro_precision ( y_true , y_pred ) Out [ X ]: 0.3611111111111111 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"macro\" ) Out [ X ]: 0.3611111111111111 In [ X ]: micro_precision ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"micro\" ) Out [ X ]: 0.4444444444444444 In [ X ]: weighted_precision ( y_true , y_pred ) Out [ X ]: 0.39814814814814814 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.39814814814814814 \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6b63\u786e\u5730\u5b9e\u73b0\u4e86\u4e00\u5207\u3002 \u8bf7\u6ce8\u610f\uff0c\u8fd9\u91cc\u5c55\u793a\u7684\u5b9e\u73b0\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\uff0c\u4f46\u5374\u662f\u6700\u5bb9\u6613\u7406\u89e3\u7684\u3002 \u540c\u6837\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5b9e\u73b0 \u591a\u7c7b\u522b\u7684\u53ec\u56de\u7387\u6307\u6807 \u3002\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u3001\u5047\u9633\u6027\u548c\u5047\u9634\u6027\uff0c\u800c F1 \u5219\u53d6\u51b3\u4e8e\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 \u53ec\u56de\u7387\u7684\u5b9e\u73b0\u65b9\u6cd5\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\uff0c\u8fd9\u91cc\u5b9e\u73b0\u7684\u662f\u591a\u7c7b F1 \u7684\u4e00\u4e2a\u7248\u672c\uff0c\u5373\u52a0\u6743\u5e73\u5747\u503c\u3002 from collections import Counter import numpy as np def weighted_f1 ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316F1\u503c f1 = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( temp_true , temp_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( temp_true , temp_pred ) # \u82e5\u7cbe\u786e\u7387+\u53ec\u56de\u7387\u4e0d\u4e3a0\uff0c\u5219\u4f7f\u7528\u516c\u5f0f\u8ba1\u7b97F1\u503c if p + r != 0 : temp_f1 = 2 * p * r / ( p + r ) # \u5426\u5219\u76f4\u63a5\u4e3a0 else : temp_f1 = 0 # \u6839\u636e\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_f1 = class_counts [ class_ ] * temp_f1 # \u52a0\u6743F1\u503c\u76f8\u52a0 f1 += weighted_f1 # \u8ba1\u7b97\u52a0\u6743\u5e73\u5747F1\u503c overall_f1 = f1 / len ( y_true ) return overall_f1 \u8bf7\u6ce8\u610f\uff0c\u4e0a\u9762\u6709\u51e0\u884c\u4ee3\u7801\u662f\u65b0\u5199\u7684\u3002\u56e0\u6b64\uff0c\u4f60\u5e94\u8be5\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e9b\u4ee3\u7801\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: weighted_f1 ( y_true , y_pred ) Out [ X ]: 0.41269841269841273 In [ X ]: metrics . f1_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.41269841269841273 \u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u591a\u7c7b\u95ee\u9898\u5b9e\u73b0\u4e86\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002\u540c\u6837\uff0c\u60a8\u4e5f\u53ef\u4ee5\u5c06 AUC \u548c\u5bf9\u6570\u635f\u5931\u8f6c\u6362\u4e3a\u591a\u7c7b\u683c\u5f0f\u3002\u8fd9\u79cd\u8f6c\u6362\u683c\u5f0f\u88ab\u79f0\u4e3a one-vs-all \u3002\u8fd9\u91cc\u6211\u4e0d\u6253\u7b97\u5b9e\u73b0\u5b83\u4eec\uff0c\u56e0\u4e3a\u5b9e\u73b0\u65b9\u6cd5\u4e0e\u6211\u4eec\u5df2\u7ecf\u8ba8\u8bba\u8fc7\u7684\u5f88\u76f8\u4f3c\u3002 \u5728\u4e8c\u5143\u6216\u591a\u7c7b\u5206\u7c7b\u4e2d\uff0c\u770b\u4e00\u4e0b \u6df7\u6dc6\u77e9\u9635 \u4e5f\u5f88\u6d41\u884c\u3002\u4e0d\u8981\u56f0\u60d1\uff0c\u8fd9\u5f88\u7b80\u5355\u3002\u6df7\u6dc6\u77e9\u9635\u53ea\u4e0d\u8fc7\u662f\u4e00\u4e2a\u5305\u542b TP\u3001FP\u3001TN \u548c FN \u7684\u8868\u683c\u3002\u4f7f\u7528\u6df7\u6dc6\u77e9\u9635\uff0c\u60a8\u53ef\u4ee5\u5feb\u901f\u67e5\u770b\u6709\u591a\u5c11\u6837\u672c\u88ab\u9519\u8bef\u5206\u7c7b\uff0c\u6709\u591a\u5c11\u6837\u672c\u88ab\u6b63\u786e\u5206\u7c7b\u3002\u4e5f\u8bb8\u6709\u4eba\u4f1a\u8bf4\uff0c\u6df7\u6dc6\u77e9\u9635\u5e94\u8be5\u5728\u672c\u7ae0\u5f88\u65e9\u5c31\u8bb2\u5230\uff0c\u4f46\u6211\u6ca1\u6709\u8fd9\u4e48\u505a\u3002\u5982\u679c\u4e86\u89e3\u4e86 TP\u3001FP\u3001TN\u3001FN\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c AUC\uff0c\u5c31\u5f88\u5bb9\u6613\u7406\u89e3\u548c\u89e3\u91ca\u6df7\u6dc6\u77e9\u9635\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b\u56fe 7 \u4e2d\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6df7\u6dc6\u77e9\u9635\u7531 TP\u3001FP\u3001FN \u548c TN \u7ec4\u6210\u3002\u6211\u4eec\u53ea\u9700\u8981\u8fd9\u4e9b\u503c\u6765\u8ba1\u7b97\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u3001F1 \u5206\u6570\u548c AUC\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u4e5f\u559c\u6b22\u628a FP \u79f0\u4e3a \u7b2c\u4e00\u7c7b\u9519\u8bef \uff0c\u628a FN \u79f0\u4e3a \u7b2c\u4e8c\u7c7b\u9519\u8bef \u3002 \u56fe 7\uff1a\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6df7\u6dc6\u77e9\u9635 \u6211\u4eec\u8fd8\u53ef\u4ee5\u5c06\u4e8c\u5143\u6df7\u6dc6\u77e9\u9635\u6269\u5c55\u4e3a\u591a\u7c7b\u6df7\u6dc6\u77e9\u9635\u3002\u5b83\u4f1a\u662f\u4ec0\u4e48\u6837\u5b50\u5462\uff1f\u5982\u679c\u6211\u4eec\u6709 N \u4e2a\u7c7b\u522b\uff0c\u5b83\u5c06\u662f\u4e00\u4e2a\u5927\u5c0f\u4e3a NxN \u7684\u77e9\u9635\u3002\u5bf9\u4e8e\u6bcf\u4e2a\u7c7b\u522b\uff0c\u6211\u4eec\u90fd\u8981\u8ba1\u7b97\u76f8\u5173\u7c7b\u522b\u548c\u5176\u4ed6\u7c7b\u522b\u7684\u6837\u672c\u603b\u6570\u3002\u4e3e\u4e2a\u4f8b\u5b50\u53ef\u4ee5\u8ba9\u6211\u4eec\u66f4\u597d\u5730\u7406\u89e3\u8fd9\u4e00\u70b9\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4ee5\u4e0b\u771f\u5b9e\u6807\u7b7e\uff1a \\[ [0, 1, 2, 0, 1, 2, 0, 2, 2] \\] \u6211\u4eec\u7684\u9884\u6d4b\u6807\u7b7e\u662f\uff1a \\[ [0, 2, 1, 0, 2, 1, 0, 0, 2] \\] \u90a3\u4e48\uff0c\u6211\u4eec\u7684\u6df7\u6dc6\u77e9\u9635\u5c06\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u591a\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635 \u56fe 8 \u8bf4\u660e\u4e86\u4ec0\u4e48\uff1f \u8ba9\u6211\u4eec\u6765\u770b\u770b 0 \u7c7b\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e 0 \u7c7b\u3002\u7136\u800c\uff0c\u5728\u9884\u6d4b\u4e2d\uff0c\u6211\u4eec\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 0 \u7c7b\uff0c1 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\u3002\u7406\u60f3\u60c5\u51b5\u4e0b\uff0c\u5bf9\u4e8e\u771f\u5b9e\u6807\u7b7e\u4e2d\u7684\u7c7b\u522b 0\uff0c\u9884\u6d4b\u6807\u7b7e 1 \u548c 2 \u5e94\u8be5\u6ca1\u6709\u4efb\u4f55\u6837\u672c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u7c7b\u522b 2\u3002\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 4\uff0c\u800c\u5728\u9884\u6d4b\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 3\u3002 \u4e00\u4e2a\u5b8c\u7f8e\u7684\u6df7\u6dc6\u77e9\u9635\u53ea\u80fd\u4ece\u5de6\u5230\u53f3\u659c\u5411\u586b\u5145\u3002 \u6df7\u6dc6\u77e9\u9635 \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u7684\u65b9\u6cd5\u6765\u8ba1\u7b97\u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u4e0d\u540c\u6307\u6807\u3002Scikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u76f4\u63a5\u7684\u65b9\u6cd5\u6765\u751f\u6210\u6df7\u6dc6\u77e9\u9635\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5728\u56fe 8 \u4e2d\u663e\u793a\u7684\u6df7\u6dc6\u77e9\u9635\u662f scikit-learn \u6df7\u6dc6\u77e9\u9635\u7684\u8f6c\u7f6e\uff0c\u539f\u59cb\u7248\u672c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u7ed8\u5236\u3002 import matplotlib.pyplot as plt import seaborn as sns from sklearn import metrics # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] # \u9884\u6d4b\u6837\u672c\u6807\u7b7e y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] # \u8ba1\u7b97\u6df7\u6dc6\u77e9\u9635 cm = metrics . confusion_matrix ( y_true , y_pred ) # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 10 , 10 )) # \u521b\u5efa\u65b9\u683c cmap = sns . cubehelix_palette ( 50 , hue = 0.05 , rot = 0 , light = 0.9 , dark = 0 , as_cmap = True ) # \u89c4\u5b9a\u5b57\u4f53\u5927\u5c0f sns . set ( font_scale = 2.5 ) # \u7ed8\u5236\u70ed\u56fe sns . heatmap ( cm , annot = True , cmap = cmap , cbar = False ) # y\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . ylabel ( 'Actual Labels' , fontsize = 20 ) # x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . xlabel ( 'Predicted Labels' , fontsize = 20 ) \u56e0\u6b64\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u89e3\u51b3\u4e86\u4e8c\u5143\u5206\u7c7b\u548c\u591a\u7c7b\u5206\u7c7b\u7684\u5ea6\u91cf\u95ee\u9898\u3002\u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u8ba8\u8bba\u53e6\u4e00\u79cd\u7c7b\u578b\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u5373\u591a\u6807\u7b7e\u5206\u7c7b\u3002\u5728\u591a\u6807\u7b7e\u5206\u7c7b\u4e2d\uff0c\u6bcf\u4e2a\u6837\u672c\u90fd\u53ef\u80fd\u4e0e\u4e00\u4e2a\u6216\u591a\u4e2a\u7c7b\u522b\u76f8\u5173\u8054\u3002\u8fd9\u7c7b\u95ee\u9898\u7684\u4e00\u4e2a\u7b80\u5355\u4f8b\u5b50\u5c31\u662f\u8981\u6c42\u4f60\u9884\u6d4b\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53\u3002 \u56fe 9 \u663e\u793a\u4e86\u4e00\u4e2a\u8457\u540d\u6570\u636e\u96c6\u7684\u56fe\u50cf\u793a\u4f8b\u3002\u8bf7\u6ce8\u610f\uff0c\u8be5\u6570\u636e\u96c6\u7684\u76ee\u6807\u6709\u6240\u4e0d\u540c\uff0c\u4f46\u6211\u4eec\u6682\u4e14\u4e0d\u53bb\u8ba8\u8bba\u5b83\u3002\u6211\u4eec\u5047\u8bbe\u5176\u76ee\u7684\u53ea\u662f\u9884\u6d4b\u56fe\u50cf\u4e2d\u662f\u5426\u5b58\u5728\u67d0\u4e2a\u7269\u4f53\u3002\u5728\u56fe 9 \u4e2d\uff0c\u6211\u4eec\u6709\u6905\u5b50\u3001\u82b1\u76c6\u3001\u7a97\u6237\uff0c\u4f46\u6ca1\u6709\u5176\u4ed6\u7269\u4f53\uff0c\u5982\u7535\u8111\u3001\u5e8a\u3001\u7535\u89c6\u7b49\u3002\u56e0\u6b64\uff0c\u4e00\u5e45\u56fe\u50cf\u53ef\u80fd\u6709\u591a\u4e2a\u76f8\u5173\u76ee\u6807\u3002\u8fd9\u7c7b\u95ee\u9898\u5c31\u662f\u591a\u6807\u7b7e\u5206\u7c7b\u95ee\u9898\u3002 \u56fe 9\uff1a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53 \u8fd9\u7c7b\u5206\u7c7b\u95ee\u9898\u7684\u8861\u91cf\u6807\u51c6\u6709\u4e9b\u4e0d\u540c\u3002\u4e00\u4e9b\u5408\u9002\u7684 \u6700\u5e38\u89c1\u7684\u6307\u6807\u6709\uff1a k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u786e\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 \u8ba9\u6211\u4eec\u4ece k \u7cbe\u786e\u7387\u6216\u8005 P@k \u6211\u4eec\u4e0d\u80fd\u5c06\u8fd9\u4e00\u7cbe\u786e\u7387\u4e0e\u524d\u9762\u8ba8\u8bba\u7684\u7cbe\u786e\u7387\u6df7\u6dc6\u3002\u5982\u679c\u60a8\u6709\u4e00\u4e2a\u7ed9\u5b9a\u6837\u672c\u7684\u539f\u59cb\u7c7b\u522b\u5217\u8868\u548c\u540c\u4e00\u4e2a\u6837\u672c\u7684\u9884\u6d4b\u7c7b\u522b\u5217\u8868\uff0c\u90a3\u4e48\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u5c31\u662f\u9884\u6d4b\u5217\u8868\u4e2d\u4ec5\u8003\u8651\u524d k \u4e2a\u9884\u6d4b\u7ed3\u679c\u7684\u547d\u4e2d\u6570\u9664\u4ee5 k\u3002 \u5982\u679c\u60a8\u5bf9\u6b64\u611f\u5230\u56f0\u60d1\uff0c\u4f7f\u7528 python \u4ee3\u7801\u540e\u5c31\u4f1a\u660e\u767d\u3002 def pk ( y_true , y_pred , k ): # \u5982\u679ck\u4e3a0 if k == 0 : # \u8fd4\u56de0 return 0 # \u53d6\u9884\u6d4b\u6807\u7b7e\u524dk\u4e2a y_pred = y_pred [: k ] # \u5c06\u9884\u6d4b\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 pred_set = set ( y_pred ) # \u5c06\u771f\u5b9e\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 true_set = set ( y_true ) # \u9884\u6d4b\u6807\u7b7e\u96c6\u5408\u4e0e\u771f\u5b9e\u6807\u7b7e\u96c6\u5408\u4ea4\u96c6 common_values = pred_set . intersection ( true_set ) # \u8ba1\u7b97\u7cbe\u786e\u7387 return len ( common_values ) / len ( y_pred [: k ]) \u6709\u4e86\u4ee3\u7801\uff0c\u4e00\u5207\u90fd\u53d8\u5f97\u66f4\u5bb9\u6613\u7406\u89e3\u4e86\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6709\u4e86 k \u5e73\u5747\u7cbe\u786e\u7387\u6216 AP@k \u3002AP@k \u662f\u901a\u8fc7 P@k \u8ba1\u7b97\u5f97\u51fa\u7684\u3002\u4f8b\u5982\uff0c\u5982\u679c\u8981\u8ba1\u7b97 AP@3\uff0c\u6211\u4eec\u8981\u5148\u8ba1\u7b97 P@1\u3001P@2 \u548c P@3\uff0c\u7136\u540e\u5c06\u603b\u548c\u9664\u4ee5 3\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u7684\u5b9e\u73b0\u3002 def apk ( y_true , y_pred , k ): # \u521d\u59cb\u5316P@k\u5217\u8868 pk_values = [] # \u904d\u53861~k for i in range ( 1 , k + 1 ): # \u5c06P@k\u52a0\u5165\u5217\u8868 pk_values . append ( pk ( y_true , y_pred , i )) # \u82e5\u957f\u5ea6\u4e3a0 if len ( pk_values ) == 0 : # \u8fd4\u56de0 return 0 # \u5426\u5219\u8ba1\u7b97AP@K return sum ( pk_values ) / len ( pk_values ) \u8fd9\u4e24\u4e2a\u51fd\u6570\u53ef\u4ee5\u7528\u6765\u8ba1\u7b97\u4e24\u4e2a\u7ed9\u5b9a\u5217\u8868\u7684 k \u5e73\u5747\u7cbe\u786e\u7387 (AP@k)\uff1b\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u8ba1\u7b97\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: for i in range ( len ( y_true )): ... : for j in range ( 1 , 4 ): ... : print ( ... : f \"\"\" ...: y_true= { y_true [ i ] } , ...: y_pred= { y_pred [ i ] } , ...: AP@ { j } = { apk ( y_true [ i ], y_pred [ i ], k = j ) } ...: \"\"\" ... : ) ... : y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 1 = 0.0 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 2 = 0.25 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 3 = 0.38888888888888884 \u8bf7\u6ce8\u610f\uff0c\u6211\u7701\u7565\u4e86\u8f93\u51fa\u7ed3\u679c\u4e2d\u7684\u8bb8\u591a\u6570\u503c\uff0c\u4f46\u4f60\u4f1a\u660e\u767d\u5176\u4e2d\u7684\u610f\u601d\u3002\u8fd9\u5c31\u662f\u6211\u4eec\u5982\u4f55\u8ba1\u7b97 AP@k \u7684\u65b9\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u7684 AP@k\u3002\u5728\u673a\u5668\u5b66\u4e60\u4e2d\uff0c\u6211\u4eec\u5bf9\u6240\u6709\u6837\u672c\u90fd\u611f\u5174\u8da3\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u6709 \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387 k \u6216 MAP@k \u3002MAP@k \u53ea\u662f AP@k \u7684\u5e73\u5747\u503c\uff0c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b python \u4ee3\u7801\u8f7b\u677e\u8ba1\u7b97\u3002 def mapk ( y_true , y_pred , k ): # \u521d\u59cb\u5316AP@k\u5217\u8868 apk_values = [] # \u904d\u53860~\uff08\u771f\u5b9e\u6807\u7b7e\u6570-1\uff09 for i in range ( len ( y_true )): # \u5c06AP@K\u52a0\u5165\u5217\u8868 apk_values . append ( apk ( y_true [ i ], y_pred [ i ], k = k ) ) # \u8ba1\u7b97\u5e73\u5747AP@k return sum ( apk_values ) / len ( apk_values ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u9488\u5bf9\u76f8\u540c\u7684\u5217\u8868\u8ba1\u7b97 k=1\u30012\u30013 \u548c 4 \u65f6\u7684 MAP@k\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: mapk ( y_true , y_pred , k = 1 ) Out [ X ]: 0.3333333333333333 In [ X ]: mapk ( y_true , y_pred , k = 2 ) Out [ X ]: 0.375 In [ X ]: mapk ( y_true , y_pred , k = 3 ) Out [ X ]: 0.3611111111111111 In [ X ]: mapk ( y_true , y_pred , k = 4 ) Out [ X ]: 0.34722222222222215 P@k\u3001AP@k \u548c MAP@k \u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u5176\u4e2d 1 \u4e3a\u6700\u4f73\u3002 \u8bf7\u6ce8\u610f\uff0c\u6709\u65f6\u60a8\u53ef\u80fd\u4f1a\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230 P@k \u548c AP@k \u7684\u4e0d\u540c\u5b9e\u73b0\u65b9\u5f0f\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5176\u4e2d\u4e00\u79cd\u5b9e\u73b0\u65b9\u5f0f\u3002 import numpy as np def apk ( actual , predicted , k = 10 ): # \u82e5\u9884\u6d4b\u6807\u7b7e\u957f\u5ea6\u5927\u4e8ek if len ( predicted ) > k : # \u53d6\u524dk\u4e2a\u6807\u7b7e predicted = predicted [: k ] score = 0.0 num_hits = 0.0 for i , p in enumerate ( predicted ): if p in actual and p not in predicted [: i ]: num_hits += 1.0 score += num_hits / ( i + 1.0 ) if not actual : return 0.0 return score / min ( len ( actual ), k ) \u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u662f AP@k \u7684\u53e6\u4e00\u4e2a\u7248\u672c\uff0c\u5176\u4e2d\u987a\u5e8f\u5f88\u91cd\u8981\uff0c\u6211\u4eec\u8981\u6743\u8861\u9884\u6d4b\u7ed3\u679c\u3002\u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u7684\u7ed3\u679c\u4e0e\u6211\u7684\u4ecb\u7ecd\u7565\u6709\u4e0d\u540c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6765\u770b\u770b \u591a\u6807\u7b7e\u5206\u7c7b\u7684\u5bf9\u6570\u635f\u5931 \u3002\u8fd9\u5f88\u5bb9\u6613\u3002\u60a8\u53ef\u4ee5\u5c06\u76ee\u6807\u8f6c\u6362\u4e3a\u4e8c\u5143\u5206\u7c7b\uff0c\u7136\u540e\u5bf9\u6bcf\u4e00\u5217\u4f7f\u7528\u5bf9\u6570\u635f\u5931\u3002\u6700\u540e\uff0c\u4f60\u53ef\u4ee5\u6c42\u51fa\u6bcf\u5217\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u8fd9\u4e5f\u88ab\u79f0\u4e3a\u5e73\u5747\u5217\u5bf9\u6570\u635f\u5931\u3002\u5f53\u7136\uff0c\u8fd8\u6709\u5176\u4ed6\u65b9\u6cd5\u53ef\u4ee5\u5b9e\u73b0\u8fd9\u4e00\u70b9\uff0c\u4f60\u5e94\u8be5\u5728\u9047\u5230\u65f6\u52a0\u4ee5\u63a2\u7d22\u3002 \u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u8bf4\u5df2\u7ecf\u638c\u63e1\u4e86\u6240\u6709\u4e8c\u5143\u5206\u7c7b\u3001\u591a\u7c7b\u5206\u7c7b\u548c\u591a\u6807\u7b7e\u5206\u7c7b\u6307\u6807\uff0c\u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u6307\u6807\u3002 \u56de\u5f52\u4e2d\u6700\u5e38\u89c1\u7684\u6307\u6807\u662f \u8bef\u5dee\uff08Error\uff09 \u3002\u8bef\u5dee\u5f88\u7b80\u5355\uff0c\u4e5f\u5f88\u5bb9\u6613\u7406\u89e3\u3002 \\[ Error = True\\ Value - Predicted\\ Value \\] \u7edd\u5bf9\u8bef\u5dee\uff08Absolute error\uff09 \u53ea\u662f\u4e0a\u8ff0\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\u3002 \\[ Absolute\\ Error = Abs(True\\ Value - Predicted\\ Value) \\] \u63a5\u4e0b\u6765\u6211\u4eec\u8ba8\u8bba \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\uff08MAE\uff09 \u3002\u5b83\u53ea\u662f\u6240\u6709\u7edd\u5bf9\u8bef\u5dee\u7684\u5e73\u5747\u503c\u3002 import numpy as np def mean_absolute_error ( y_true , y_pred ): #\u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u7d2f\u52a0\u7edd\u5bf9\u8bef\u5dee error += np . abs ( yt - yp ) # \u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee return error / len ( y_true ) \u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5e73\u65b9\u8bef\u5dee\u548c \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u3002 \\[ Squared\\ Error = (True Value - Predicted\\ Value)^2 \\] \u5747\u65b9\u8bef\u5dee\uff08MSE\uff09\u7684\u8ba1\u7b97\u65b9\u5f0f\u5982\u4e0b def mean_squared_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u7d2f\u52a0\u8bef\u5dee\u5e73\u65b9\u548c error += ( yt - yp ) ** 2 # \u8ba1\u7b97\u5747\u65b9\u8bef\u5dee return error / len ( y_true ) MSE \u548c RMSE\uff08\u5747\u65b9\u6839\u8bef\u5dee\uff09 \u662f\u8bc4\u4f30\u56de\u5f52\u6a21\u578b\u6700\u5e38\u7528\u7684\u6307\u6807\u3002 \\[ RMSE = SQRT(MSE) \\] \u540c\u4e00\u7c7b\u8bef\u5dee\u7684\u53e6\u4e00\u79cd\u7c7b\u578b\u662f \u5e73\u65b9\u5bf9\u6570\u8bef\u5dee \u3002\u6709\u4eba\u79f0\u5176\u4e3a SLE \uff0c\u5f53\u6211\u4eec\u53d6\u6240\u6709\u6837\u672c\u4e2d\u8fd9\u4e00\u8bef\u5dee\u7684\u5e73\u5747\u503c\u65f6\uff0c\u5b83\u88ab\u79f0\u4e3a MSLE\uff08\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee\uff09\uff0c\u5b9e\u73b0\u65b9\u6cd5\u5982\u4e0b\u3002 import numpy as np def mean_squared_log_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee error += ( np . log ( 1 + yt ) - np . log ( 1 + yp )) ** 2 # \u8ba1\u7b97\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee return error / len ( y_true ) \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \u53ea\u662f\u5176\u5e73\u65b9\u6839\u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a RMSLE \u3002 \u7136\u540e\u662f\u767e\u5206\u6bd4\u8bef\u5dee\uff1a \\[ Percentage\\ Error = (( True\\ Value \u2013 Predicted\\ Value ) / True\\ Value ) \\times 100 \\] \u540c\u6837\u53ef\u4ee5\u8f6c\u6362\u4e3a\u6240\u6709\u6837\u672c\u7684\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee\u3002 def mean_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u767e\u5206\u6bd4\u8bef\u5dee error += ( yt - yp ) / yt # \u8fd4\u56de\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u7edd\u5bf9\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\uff08\u4e5f\u662f\u66f4\u5e38\u89c1\u7684\u7248\u672c\uff09\u88ab\u79f0\u4e3a \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee\u6216 MAPE \u3002 import numpy as np def mean_abs_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee error += np . abs ( yt - yp ) / yt #\u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u56de\u5f52\u7684\u6700\u5927\u4f18\u70b9\u662f\uff0c\u53ea\u6709\u51e0\u4e2a\u6700\u5e38\u7528\u7684\u6307\u6807\uff0c\u51e0\u4e4e\u53ef\u4ee5\u5e94\u7528\u4e8e\u6240\u6709\u56de\u5f52\u95ee\u9898\u3002\u4e0e\u5206\u7c7b\u6307\u6807\u76f8\u6bd4\uff0c\u56de\u5f52\u6307\u6807\u66f4\u5bb9\u6613\u7406\u89e3\u3002 \u8ba9\u6211\u4eec\u6765\u8c08\u8c08\u53e6\u4e00\u4e2a\u56de\u5f52\u6307\u6807 \\(R^2\\) \uff08R \u65b9\uff09\uff0c\u4e5f\u79f0\u4e3a \u5224\u5b9a\u7cfb\u6570 \u3002 \u7b80\u5355\u5730\u8bf4\uff0cR \u65b9\u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u3002R \u65b9\u63a5\u8fd1 1.0 \u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u76f8\u5f53\u597d\uff0c\u800c\u63a5\u8fd1 0 \u5219\u8868\u793a\u6a21\u578b\u4e0d\u662f\u90a3\u4e48\u597d\u3002\u5f53\u6a21\u578b\u53ea\u662f\u505a\u51fa\u8352\u8c2c\u7684\u9884\u6d4b\u65f6\uff0cR \u65b9\u4e5f\u53ef\u80fd\u662f\u8d1f\u503c\u3002 R \u65b9\u7684\u8ba1\u7b97\u516c\u5f0f\u5982\u4e0b\u6240\u793a\uff0c\u4f46 Python \u7684\u5b9e\u73b0\u603b\u662f\u80fd\u8ba9\u4e00\u5207\u66f4\u52a0\u6e05\u6670\u3002 \\[ R^2 = \\frac{\\sum^{N}_{i=1}(y_{t_i}-y_{p_i})^2}{\\sum^{N}_{i=1}(y_{t_i} - y_{t_{mean}})} \\] import numpy as np def r2 ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u5747\u771f\u5b9e\u503c mean_true_value = np . mean ( y_true ) # \u521d\u59cb\u5316\u5e73\u65b9\u8bef\u5dee numerator = 0 denominator = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): numerator += ( yt - yp ) ** 2 denominator += ( yt - mean_true_value ) ** 2 ratio = numerator / denominator # \u8ba1\u7b97R\u65b9 return 1 \u2013 ratio \u8fd8\u6709\u66f4\u591a\u7684\u8bc4\u4ef7\u6307\u6807\uff0c\u8fd9\u4e2a\u6e05\u5355\u6c38\u8fdc\u4e5f\u5217\u4e0d\u5b8c\u3002\u6211\u53ef\u4ee5\u5199\u4e00\u672c\u4e66\uff0c\u53ea\u4ecb\u7ecd\u4e0d\u540c\u7684\u8bc4\u4ef7\u6307\u6807\u3002\u4e5f\u8bb8\u6211\u4f1a\u7684\u3002\u73b0\u5728\uff0c\u8fd9\u4e9b\u8bc4\u4f30\u6307\u6807\u51e0\u4e4e\u53ef\u4ee5\u6ee1\u8db3\u4f60\u60f3\u5c1d\u8bd5\u89e3\u51b3\u7684\u6240\u6709\u95ee\u9898\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ee5\u6700\u76f4\u63a5\u7684\u65b9\u5f0f\u5b9e\u73b0\u4e86\u8fd9\u4e9b\u6307\u6807\uff0c\u8fd9\u610f\u5473\u7740\u5b83\u4eec\u4e0d\u591f\u9ad8\u6548\u3002\u4f60\u53ef\u4ee5\u901a\u8fc7\u6b63\u786e\u4f7f\u7528 numpy \u4ee5\u975e\u5e38\u9ad8\u6548\u7684\u65b9\u5f0f\u5b9e\u73b0\u5176\u4e2d\u5927\u90e8\u5206\u6307\u6807\u3002\u4f8b\u5982\uff0c\u770b\u770b\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\u7684\u5b9e\u73b0\uff0c\u4e0d\u9700\u8981\u4efb\u4f55\u5faa\u73af\u3002 import numpy as np def mae_np ( y_true , y_pred ): return np . mean ( np . abs ( y_true - y_pred )) \u6211\u672c\u53ef\u4ee5\u7528\u8fd9\u79cd\u65b9\u6cd5\u5b9e\u73b0\u6240\u6709\u6307\u6807\uff0c\u4f46\u4e3a\u4e86\u5b66\u4e60\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b\u5e95\u5c42\u5b9e\u73b0\u3002\u4e00\u65e6\u4f60\u5b66\u4f1a\u4e86\u7eaf python \u7684\u5e95\u5c42\u5b9e\u73b0\uff0c\u5e76\u4e14\u4e0d\u4f7f\u7528\u5927\u91cf numpy\uff0c\u4f60\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5c06\u5176\u8f6c\u6362\u4e3a numpy\uff0c\u5e76\u4f7f\u5176\u53d8\u5f97\u66f4\u5feb\u3002 \u7136\u540e\u662f\u4e00\u4e9b\u9ad8\u7ea7\u5ea6\u91cf\u3002 \u5176\u4e2d\u4e00\u4e2a\u5e94\u7528\u76f8\u5f53\u5e7f\u6cdb\u7684\u6307\u6807\u662f \u4e8c\u6b21\u52a0\u6743\u5361\u5e15 \uff0c\u4e5f\u79f0\u4e3a QWK \u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a\u79d1\u6069\u5361\u5e15\u3002 QWK \u8861\u91cf\u4e24\u4e2a \"\u8bc4\u5206 \"\u4e4b\u95f4\u7684 \"\u4e00\u81f4\u6027\"\u3002\u8bc4\u5206\u53ef\u4ee5\u662f 0 \u5230 N \u4e4b\u95f4\u7684\u4efb\u4f55\u5b9e\u6570\uff0c\u9884\u6d4b\u4e5f\u5728\u540c\u4e00\u8303\u56f4\u5185\u3002\u4e00\u81f4\u6027\u53ef\u4ee5\u5b9a\u4e49\u4e3a\u8fd9\u4e9b\u8bc4\u7ea7\u4e4b\u95f4\u7684\u63a5\u8fd1\u7a0b\u5ea6\u3002\u56e0\u6b64\uff0c\u5b83\u9002\u7528\u4e8e\u6709 N \u4e2a\u4e0d\u540c\u7c7b\u522b\u7684\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u4e00\u81f4\u5ea6\u9ad8\uff0c\u5206\u6570\u5c31\u66f4\u63a5\u8fd1 1.0\u3002Cohen's kappa \u5728 scikit-learn \u4e2d\u6709\u5f88\u597d\u7684\u5b9e\u73b0\uff0c\u5173\u4e8e\u8be5\u6307\u6807\u7684\u8be6\u7ec6\u8ba8\u8bba\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3 ] In [ X ]: y_pred = [ 2 , 1 , 3 , 1 , 2 , 3 , 3 , 1 , 2 ] In [ X ]: metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) Out [ X ]: 0.33333333333333337 In [ X ]: metrics . accuracy_score ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 \u60a8\u53ef\u4ee5\u770b\u5230\uff0c\u5c3d\u7ba1\u51c6\u786e\u5ea6\u5f88\u9ad8\uff0c\u4f46 QWK \u5374\u5f88\u4f4e\u3002QWK \u5927\u4e8e 0.85 \u5373\u4e3a\u975e\u5e38\u597d\uff01 \u4e00\u4e2a\u91cd\u8981\u7684\u6307\u6807\u662f \u9a6c\u4fee\u76f8\u5173\u7cfb\u6570\uff08MCC\uff09 \u30021 \u4ee3\u8868\u5b8c\u7f8e\u9884\u6d4b\uff0c-1 \u4ee3\u8868\u4e0d\u5b8c\u7f8e\u9884\u6d4b\uff0c0 \u4ee3\u8868\u968f\u673a\u9884\u6d4b\u3002MCC \u7684\u8ba1\u7b97\u516c\u5f0f\u975e\u5e38\u7b80\u5355\u3002 \\[ MCC = \\frac{TP \\times TN - FP \\times FN}{\\sqrt{(TP + FP) \\times (FN + TN) \\times (FP + TN) \\times (TP + FN)}} \\] \u6211\u4eec\u770b\u5230\uff0cMCC \u8003\u8651\u4e86 TP\u3001FP\u3001TN \u548c FN\uff0c\u56e0\u6b64\u53ef\u7528\u4e8e\u5904\u7406\u7c7b\u504f\u659c\u7684\u95ee\u9898\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u7684\u65b9\u6cd5\u5728 python \u4e2d\u5feb\u901f\u5b9e\u73b0\u5b83\u3002 def mcc ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) numerator = ( tp * tn ) - ( fp * fn ) denominator = ( ( tp + fp ) * ( fn + tn ) * ( fp + tn ) * ( tp + fn ) ) denominator = denominator ** 0.5 return numerator / denominator \u8fd9\u4e9b\u6307\u6807\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5165\u95e8\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002 \u9700\u8981\u6ce8\u610f\u7684\u4e00\u70b9\u662f\uff0c\u5728\u8bc4\u4f30\u975e\u76d1\u7763\u65b9\u6cd5\uff08\u4f8b\u5982\u67d0\u79cd\u805a\u7c7b\uff09\u65f6\uff0c\u6700\u597d\u521b\u5efa\u6216\u624b\u52a8\u6807\u8bb0\u6d4b\u8bd5\u96c6\uff0c\u5e76\u5c06\u5176\u4e0e\u5efa\u6a21\u90e8\u5206\u7684\u6240\u6709\u5185\u5bb9\u5206\u5f00\u3002\u5b8c\u6210\u805a\u7c7b\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u4efb\u4f55\u4e00\u79cd\u76d1\u7763\u5b66\u4e60\u6307\u6807\u6765\u8bc4\u4f30\u6d4b\u8bd5\u96c6\u7684\u6027\u80fd\u4e86\u3002 \u4e00\u65e6\u6211\u4eec\u4e86\u89e3\u4e86\u7279\u5b9a\u95ee\u9898\u5e94\u8be5\u4f7f\u7528\u4ec0\u4e48\u6307\u6807\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f00\u59cb\u66f4\u6df1\u5165\u5730\u7814\u7a76\u6211\u4eec\u7684\u6a21\u578b\uff0c\u4ee5\u6c42\u6539\u8fdb\u3002","title":"\u8bc4\u4f30\u6307\u6807"},{"location":"%E8%AF%84%E4%BC%B0%E6%8C%87%E6%A0%87/#_1","text":"\u8bf4\u5230\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u4f60\u4f1a\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\u9047\u5230\u5f88\u591a\u4e0d\u540c\u7c7b\u578b\u7684\u6307\u6807\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u751a\u81f3\u4f1a\u6839\u636e\u4e1a\u52a1\u95ee\u9898\u521b\u5efa\u5ea6\u91cf\u6807\u51c6\u3002\u9010\u4e00\u4ecb\u7ecd\u548c\u89e3\u91ca\u6bcf\u4e00\u79cd\u5ea6\u91cf\u7c7b\u578b\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002\u76f8\u53cd\uff0c\u6211\u4eec\u5c06\u4ecb\u7ecd\u4e00\u4e9b\u6700\u5e38\u89c1\u7684\u5ea6\u91cf\u6807\u51c6\uff0c\u4f9b\u4f60\u5728\u6700\u521d\u7684\u51e0\u4e2a\u9879\u76ee\u4e2d\u4f7f\u7528\u3002 \u5728\u672c\u4e66\u7684\u5f00\u5934\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u76d1\u7763\u5b66\u4e60\u548c\u975e\u76d1\u7763\u5b66\u4e60\u3002\u867d\u7136\u65e0\u76d1\u7763\u5b66\u4e60\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e9b\u6307\u6807\uff0c\u4f46\u6211\u4eec\u5c06\u53ea\u5173\u6ce8\u6709\u76d1\u7763\u5b66\u4e60\u3002\u8fd9\u662f\u56e0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u591a\uff0c\u800c\u4e14\u5bf9\u65e0\u76d1\u7763\u65b9\u6cd5\u7684\u8bc4\u4f30\u76f8\u5f53\u4e3b\u89c2\u3002 \u5982\u679c\u6211\u4eec\u8c08\u8bba\u5206\u7c7b\u95ee\u9898\uff0c\u6700\u5e38\u7528\u7684\u6307\u6807\u662f\uff1a \u51c6\u786e\u7387\uff08Accuracy\uff09 \u7cbe\u786e\u7387\uff08P\uff09 \u53ec\u56de\u7387\uff08R\uff09 F1 \u5206\u6570\uff08F1\uff09 AUC\uff08AUC\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u8bf4\u5230\u56de\u5f52\uff0c\u6700\u5e38\u7528\u7684\u8bc4\u4ef7\u6307\u6807\u662f \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee \uff08MAE\uff09 \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u5747\u65b9\u6839\u8bef\u5dee \uff08RMSE\uff09 \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \uff08RMSLE\uff09 \u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee \uff08MPE\uff09 \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee \uff08MAPE\uff09 R2 \u4e86\u89e3\u4e0a\u8ff0\u6307\u6807\u7684\u5de5\u4f5c\u539f\u7406\u5e76\u4e0d\u662f\u6211\u4eec\u5fc5\u987b\u4e86\u89e3\u7684\u552f\u4e00\u4e8b\u60c5\u3002\u6211\u4eec\u8fd8\u5fc5\u987b\u77e5\u9053\u4f55\u65f6\u4f7f\u7528\u54ea\u4e9b\u6307\u6807\uff0c\u800c\u8fd9\u53d6\u51b3\u4e8e\u4f60\u6709\u4ec0\u4e48\u6837\u7684\u6570\u636e\u548c\u76ee\u6807\u3002\u6211\u8ba4\u4e3a\u8fd9\u4e0e\u76ee\u6807\u6709\u5173\uff0c\u800c\u4e0e\u6570\u636e\u65e0\u5173\u3002 \u8981\u8fdb\u4e00\u6b65\u4e86\u89e3\u8fd9\u4e9b\u6307\u6807\uff0c\u8ba9\u6211\u4eec\u4ece\u4e00\u4e2a\u7b80\u5355\u7684\u95ee\u9898\u5f00\u59cb\u3002\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a \u4e8c\u5143\u5206\u7c7b \u95ee\u9898\uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u76ee\u6807\u7684\u95ee\u9898\uff0c\u5047\u8bbe\u8fd9\u662f\u4e00\u4e2a\u80f8\u90e8 X \u5149\u56fe\u50cf\u5206\u7c7b\u95ee\u9898\u3002\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6ca1\u6709\u95ee\u9898\uff0c\u800c\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6709\u80ba\u584c\u9677\uff0c\u4e5f\u5c31\u662f\u6240\u8c13\u7684\u6c14\u80f8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u5206\u7c7b\u5668\uff0c\u5728\u7ed9\u5b9a\u80f8\u90e8 X \u5149\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\uff0c\u5b83\u80fd\u68c0\u6d4b\u51fa\u56fe\u50cf\u662f\u5426\u6709\u6c14\u80f8\u3002 \u56fe 1\uff1a\u6c14\u80f8\u80ba\u90e8\u56fe\u50cf \u6211\u4eec\u8fd8\u5047\u8bbe\u6709\u76f8\u540c\u6570\u91cf\u7684\u6c14\u80f8\u548c\u975e\u6c14\u80f8\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u6bd4\u5982\u5404 100 \u5f20\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709 100 \u5f20\u9633\u6027\u6837\u672c\u548c 100 \u5f20\u9634\u6027\u6837\u672c\uff0c\u5171\u8ba1 200 \u5f20\u56fe\u50cf\u3002 \u7b2c\u4e00\u6b65\u662f\u5c06\u4e0a\u8ff0\u6570\u636e\u5206\u4e3a\u4e24\u7ec4\uff0c\u6bcf\u7ec4 100 \u5f20\u56fe\u50cf\uff0c\u5373\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u8fd9\u4e24\u4e2a\u96c6\u5408\u4e2d\uff0c\u6211\u4eec\u90fd\u6709 50 \u4e2a\u6b63\u6837\u672c\u548c 50 \u4e2a\u8d1f\u6837\u672c\u3002 \u5728\u4e8c\u5143\u5206\u7c7b\u6307\u6807\u4e2d\uff0c\u5f53\u6b63\u8d1f\u6837\u672c\u6570\u91cf\u76f8\u7b49\u65f6\uff0c\u6211\u4eec\u901a\u5e38\u4f7f\u7528\u51c6\u786e\u7387\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002 \u51c6\u786e\u7387 \uff1a\u8fd9\u662f\u673a\u5668\u5b66\u4e60\u4e2d\u6700\u76f4\u63a5\u7684\u6307\u6807\u4e4b\u4e00\u3002\u5b83\u5b9a\u4e49\u4e86\u6a21\u578b\u7684\u51c6\u786e\u5ea6\u3002\u5bf9\u4e8e\u4e0a\u8ff0\u95ee\u9898\uff0c\u5982\u679c\u4f60\u5efa\u7acb\u7684\u6a21\u578b\u80fd\u51c6\u786e\u5206\u7c7b 90 \u5f20\u56fe\u7247\uff0c\u90a3\u4e48\u4f60\u7684\u51c6\u786e\u7387\u5c31\u662f 90% \u6216 0.90\u3002\u5982\u679c\u53ea\u6709 83 \u5e45\u56fe\u50cf\u88ab\u6b63\u786e\u5206\u7c7b\uff0c\u90a3\u4e48\u6a21\u578b\u7684\u51c6\u786e\u7387\u5c31\u662f 83% \u6216 0.83\u3002 \u8ba1\u7b97\u51c6\u786e\u7387\u7684 Python \u4ee3\u7801\u4e5f\u975e\u5e38\u7b80\u5355\u3002 def accuracy ( y_true , y_pred ): # \u4e3a\u6b63\u786e\u9884\u6d4b\u6570\u521d\u59cb\u5316\u4e00\u4e2a\u7b80\u5355\u8ba1\u6570\u5668 correct_counter = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): if yt == yp : # \u5982\u679c\u9884\u6d4b\u6807\u7b7e\u4e0e\u771f\u5b9e\u6807\u7b7e\u76f8\u540c\uff0c\u5219\u589e\u52a0\u8ba1\u6570\u5668 correct_counter += 1 # \u8fd4\u56de\u6b63\u786e\u7387\uff0c\u6b63\u786e\u6807\u7b7e\u6570/\u603b\u6807\u7b7e\u6570 return correct_counter / len ( y_true ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u8ba1\u7b97\u51c6\u786e\u7387\u3002 In [ X ]: from sklearn import metrics ... : l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] ... : metrics . accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u73b0\u5728\uff0c\u5047\u8bbe\u6211\u4eec\u628a\u6570\u636e\u96c6\u7a0d\u5fae\u6539\u52a8\u4e00\u4e0b\uff0c\u6709 180 \u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u53ea\u6709 20 \u5f20\u6709\u6c14\u80f8\u3002\u5373\u4f7f\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4e5f\u8981\u521b\u5efa\u6b63\u8d1f\uff08\u6c14\u80f8\u4e0e\u975e\u6c14\u80f8\uff09\u76ee\u6807\u6bd4\u4f8b\u76f8\u540c\u7684\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u6bcf\u4e00\u7ec4\u4e2d\uff0c\u6211\u4eec\u6709 90 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u5982\u679c\u8bf4\u9a8c\u8bc1\u96c6\u4e2d\u7684\u6240\u6709\u56fe\u50cf\u90fd\u662f\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u90a3\u4e48\u60a8\u7684\u51c6\u786e\u7387\u4f1a\u662f\u591a\u5c11\u5462\uff1f\u8ba9\u6211\u4eec\u6765\u770b\u770b\uff1b\u60a8\u5bf9 90% \u7684\u56fe\u50cf\u8fdb\u884c\u4e86\u6b63\u786e\u5206\u7c7b\u3002\u56e0\u6b64\uff0c\u60a8\u7684\u51c6\u786e\u7387\u662f 90%\u3002 \u4f46\u8bf7\u518d\u770b\u4e00\u904d\u3002 \u4f60\u751a\u81f3\u6ca1\u6709\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5c31\u5f97\u5230\u4e86 90% \u7684\u51c6\u786e\u7387\u3002\u8fd9\u4f3c\u4e4e\u6709\u70b9\u6ca1\u7528\u3002\u5982\u679c\u6211\u4eec\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u6570\u636e\u96c6\u662f\u504f\u659c\u7684\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u6bd4\u53e6\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u591a\u5f88\u591a\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f7f\u7528\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u56e0\u4e3a\u5b83\u4e0d\u80fd\u4ee3\u8868\u6570\u636e\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u80fd\u4f1a\u83b7\u5f97\u5f88\u9ad8\u7684\u51c6\u786e\u7387\uff0c\u4f46\u60a8\u7684\u6a21\u578b\u5728\u5b9e\u9645\u6837\u672c\u4e2d\u7684\u8868\u73b0\u53ef\u80fd\u5e76\u4e0d\u7406\u60f3\uff0c\u800c\u4e14\u60a8\u4e5f\u65e0\u6cd5\u5411\u7ecf\u7406\u89e3\u91ca\u539f\u56e0\u3002 \u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b \u7cbe\u786e\u7387 \u7b49\u5176\u4ed6\u6307\u6807\u3002 \u5728\u5b66\u4e60\u7cbe\u786e\u7387\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e86\u89e3\u4e00\u4e9b\u672f\u8bed\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5047\u8bbe\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e3a\u6b63\u7c7b (1)\uff0c\u6ca1\u6709\u6c14\u80f8\u7684\u4e3a\u8d1f\u7c7b (0)\u3002 \u771f\u9633\u6027 \uff08TP\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6709\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9633\u6027\u3002 \u771f\u9634\u6027 \uff08TN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u800c\u5b9e\u9645\u76ee\u6807\u663e\u793a\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9634\u6027\u3002 \u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u6b63\u786e\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9633\u6027\uff1b\u5982\u679c\u60a8\u7684\u6a21\u578b\u51c6\u786e\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9634\u6027\u3002 \u5047\u9633\u6027 \uff08FP\uff09 \uff1a\u7ed9\u5b9a\u4e00\u5f20\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u975e\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9633\u6027\u3002 \u5047\u9634\u6027 \uff08FN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u975e\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9634\u6027\u3002 \u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\uff0c\u90a3\u4e48\u5b83\u5c31\u662f\u5047\u9633\u6027\u3002\u5982\u679c\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5219\u662f\u5047\u9634\u6027\u3002 \u8ba9\u6211\u4eec\u9010\u4e00\u770b\u770b\u8fd9\u4e9b\u5b9e\u73b0\u3002 def true_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u8ba1\u6570\u5668 tp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 1 : tp += 1 # \u8fd4\u56de\u771f\u9633\u6027\u6837\u672c\u6570 return tp def true_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9634\u6027\u6837\u672c\u8ba1\u6570\u5668 tn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 0 : tn += 1 # \u8fd4\u56de\u771f\u9634\u6027\u6837\u672c\u6570 return tn def false_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9633\u6027\u8ba1\u6570\u5668 fp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 1 : fp += 1 # \u8fd4\u56de\u5047\u9633\u6027\u6837\u672c\u6570 return fp def false_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9634\u6027\u8ba1\u6570\u5668 fn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 0 : fn += 1 # \u8fd4\u56de\u5047\u9634\u6027\u6570 return fn \u6211\u5728\u8fd9\u91cc\u5b9e\u73b0\u8fd9\u4e9b\u529f\u80fd\u7684\u65b9\u6cd5\u975e\u5e38\u7b80\u5355\uff0c\u800c\u4e14\u53ea\u9002\u7528\u4e8e\u4e8c\u5143\u5206\u7c7b\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u8fd9\u4e9b\u51fd\u6570\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: true_positive ( l1 , l2 ) Out [ X ]: 2 In [ X ]: false_positive ( l1 , l2 ) Out [ X ]: 1 In [ X ]: false_negative ( l1 , l2 ) Out [ X ]: 2 In [ X ]: true_negative ( l1 , l2 ) Out [ X ]: 3 \u5982\u679c\u6211\u4eec\u5fc5\u987b\u7528\u4e0a\u8ff0\u672f\u8bed\u6765\u5b9a\u4e49\u7cbe\u786e\u7387\uff0c\u6211\u4eec\u53ef\u4ee5\u5199\u4e3a\uff1a \\[ Accuracy Score = (TP + TN)/(TP + TN + FP +FN) \\] \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5728 python \u4e2d\u4f7f\u7528 TP\u3001TN\u3001FP \u548c FN \u5feb\u901f\u5b9e\u73b0\u51c6\u786e\u5ea6\u5f97\u5206\u3002\u6211\u4eec\u5c06\u5176\u79f0\u4e3a accuracy_v2\u3002 def accuracy_v2 ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u51c6\u786e\u7387 accuracy_score = ( tp + tn ) / ( tp + tn + fp + fn ) return accuracy_score \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0e\u4e4b\u524d\u7684\u5b9e\u73b0\u548c scikit-learn \u7248\u672c\u8fdb\u884c\u6bd4\u8f83\uff0c\u5feb\u901f\u68c0\u67e5\u8be5\u51fd\u6570\u7684\u6b63\u786e\u6027\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: accuracy ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: accuracy_v2 ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: metrics . accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u8bf7\u6ce8\u610f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0cmetrics.accuracy_score \u6765\u81ea scikit-learn\u3002 \u5f88\u597d\u3002\u6240\u6709\u503c\u90fd\u5339\u914d\u3002\u8fd9\u8bf4\u660e\u6211\u4eec\u5728\u5b9e\u73b0\u8fc7\u7a0b\u4e2d\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u5176\u4ed6\u91cd\u8981\u6307\u6807\u3002 \u9996\u5148\u662f\u7cbe\u786e\u7387\u3002\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u662f \\[ Precision = TP/(TP + FP) \\] \u5047\u8bbe\u6211\u4eec\u5728\u65b0\u7684\u504f\u659c\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u4e86\u4e00\u4e2a\u65b0\u6a21\u578b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 90 \u5f20\u56fe\u50cf\u4e2d\u7684 80 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u56fe\u50cf\u4e2d\u7684 8 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6210\u529f\u8bc6\u522b\u4e86 100 \u5f20\u56fe\u50cf\u4e2d\u7684 88 \u5f20\u3002\u56e0\u6b64\uff0c\u51c6\u786e\u7387\u4e3a 0.88 \u6216 88%\u3002 \u4f46\u662f\uff0c\u5728\u8fd9 100 \u5f20\u6837\u672c\u4e2d\uff0c\u6709 10 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u6c14\u80f8\uff0c2 \u5f20\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u975e\u6c14\u80f8\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 TP : 8 TN: 80 FP: 10 FN: 2 \u7cbe\u786e\u7387\u4e3a 8 / (8 + 10) = 0.444\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u5728\u8bc6\u522b\u9633\u6027\u6837\u672c\uff08\u6c14\u80f8\uff09\u65f6\u6709 44.4% \u7684\u6b63\u786e\u7387\u3002 \u73b0\u5728\uff0c\u65e2\u7136\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86 TP\u3001TN\u3001FP \u548c FN\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5728 python \u4e2d\u5b9e\u73b0\u7cbe\u786e\u7387\u4e86\u3002 def precision ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8ba9\u6211\u4eec\u8bd5\u8bd5\u8fd9\u79cd\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u65b9\u5f0f\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: precision ( l1 , l2 ) Out [ X ]: 0.6666666666666666 \u8fd9\u4f3c\u4e4e\u6ca1\u6709\u95ee\u9898\u3002 \u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u6765\u770b \u53ec\u56de\u7387 \u3002\u53ec\u56de\u7387\u7684\u5b9a\u4e49\u662f\uff1a \\[ Recall = TP/(TP + FN) \\] \u5728\u4e0a\u8ff0\u60c5\u51b5\u4e0b\uff0c\u53ec\u56de\u7387\u4e3a 8 / (8 + 2) = 0.80\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 80% \u7684\u9633\u6027\u6837\u672c\u3002 def recall ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u53ec\u56de\u7387 recall = tp / ( tp + fn ) return recall \u5c31\u6211\u4eec\u7684\u4e24\u4e2a\u5c0f\u5217\u8868\u800c\u8a00\uff0c\u53ec\u56de\u7387\u5e94\u8be5\u662f 0.5\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: recall ( l1 , l2 ) Out [ X ]: 0.5 \u8fd9\u4e0e\u6211\u4eec\u7684\u8ba1\u7b97\u503c\u76f8\u7b26\uff01 \u5bf9\u4e8e\u4e00\u4e2a \"\u597d \"\u6a21\u578b\u6765\u8bf4\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u503c\u90fd\u5e94\u8be5\u5f88\u9ad8\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0c\u53ec\u56de\u503c\u76f8\u5f53\u9ad8\u3002\u4f46\u662f\uff0c\u7cbe\u786e\u7387\u5374\u5f88\u4f4e\uff01\u6211\u4eec\u7684\u6a21\u578b\u4ea7\u751f\u4e86\u5927\u91cf\u7684\u8bef\u62a5\uff0c\u4f46\u6f0f\u62a5\u8f83\u5c11\u3002\u5728\u8fd9\u7c7b\u95ee\u9898\u4e2d\uff0c\u5047\u9634\u6027\u8f83\u5c11\u662f\u597d\u4e8b\uff0c\u56e0\u4e3a\u4f60\u4e0d\u60f3\u5728\u75c5\u4eba\u6709\u6c14\u80f8\u7684\u60c5\u51b5\u4e0b\u5374\u8bf4\u4ed6\u4eec\u6ca1\u6709\u6c14\u80f8\u3002\u8fd9\u6837\u505a\u4f1a\u9020\u6210\u66f4\u5927\u7684\u4f24\u5bb3\u3002\u4f46\u6211\u4eec\u4e5f\u6709\u5f88\u591a\u5047\u9633\u6027\u7ed3\u679c\uff0c\u8fd9\u4e5f\u4e0d\u662f\u597d\u4e8b\u3002 \u5927\u591a\u6570\u6a21\u578b\u90fd\u4f1a\u9884\u6d4b\u4e00\u4e2a\u6982\u7387\uff0c\u5f53\u6211\u4eec\u9884\u6d4b\u65f6\uff0c\u901a\u5e38\u4f1a\u5c06\u8fd9\u4e2a\u9608\u503c\u9009\u4e3a 0.5\u3002\u8fd9\u4e2a\u9608\u503c\u5e76\u4e0d\u603b\u662f\u7406\u60f3\u7684\uff0c\u6839\u636e\u8fd9\u4e2a\u9608\u503c\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u503c\u53ef\u80fd\u4f1a\u53d1\u751f\u5f88\u5927\u7684\u53d8\u5316\u3002\u5982\u679c\u6211\u4eec\u9009\u62e9\u7684\u6bcf\u4e2a\u9608\u503c\u90fd\u80fd\u8ba1\u7b97\u51fa\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u53ef\u4ee5\u5728\u8fd9\u4e9b\u503c\u4e4b\u95f4\u7ed8\u5236\u51fa\u66f2\u7ebf\u56fe\u3002\u8fd9\u5e45\u56fe\u6216\u66f2\u7ebf\u88ab\u79f0\u4e3a \"\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\"\u3002 \u5728\u7814\u7a76\u7cbe\u786e\u7387-\u8c03\u7528\u66f2\u7ebf\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5047\u8bbe\u6709\u4e24\u4e2a\u5217\u8868\u3002 In [ X ]: y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0.02638412 , 0.11114267 , 0.31620708 , ... : 0.0490937 , 0.0191491 , 0.17554844 , ... : 0.15952202 , 0.03819563 , 0.11639273 , ... : 0.079377 , 0.08584789 , 0.39095342 , ... : 0.27259048 , 0.03447096 , 0.04644807 , ... : 0.03543574 , 0.18521942 , 0.05934905 , ... : 0.61977213 , 0.33056815 ] \u56e0\u6b64\uff0cy_true \u662f\u6211\u4eec\u7684\u76ee\u6807\u503c\uff0c\u800c y_pred \u662f\u6837\u672c\u88ab\u8d4b\u503c\u4e3a 1 \u7684\u6982\u7387\u503c\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u6211\u4eec\u8981\u770b\u7684\u662f\u9884\u6d4b\u4e2d\u7684\u6982\u7387\uff0c\u800c\u4e0d\u662f\u9884\u6d4b\u503c\uff08\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u9884\u6d4b\u503c\u7684\u8ba1\u7b97\u9608\u503c\u4e3a 0.5\uff09\u3002 precisions = [] recalls = [] thresholds = [ 0.0490937 , 0.05934905 , 0.079377 , 0.08584789 , 0.11114267 , 0.11639273 , 0.15952202 , 0.17554844 , 0.18521942 , 0.27259048 , 0.31620708 , 0.33056815 , 0.39095342 , 0.61977213 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for i in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_prediction = [ 1 if x >= i else 0 for x in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , temp_prediction ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , temp_prediction ) # \u52a0\u5165\u7cbe\u786e\u7387\u5217\u8868 precisions . append ( p ) # \u52a0\u5165\u53ec\u56de\u7387\u5217\u8868 recalls . append ( r ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7ed8\u5236\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 7 , 7 )) # x\u8f74\u4e3a\u53ec\u56de\u7387\uff0cy\u8f74\u4e3a\u7cbe\u786e\u7387 plt . plot ( recalls , precisions ) # \u6dfb\u52a0x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a15 plt . xlabel ( 'Recall' , fontsize = 15 ) # \u6dfb\u52a0y\u8f74\u6807\u7b7e\uff0c\u5b57\u6761\u5927\u5c0f\u4e3a15 plt . ylabel ( 'Precision' , fontsize = 15 ) \u56fe 2 \u663e\u793a\u4e86\u6211\u4eec\u901a\u8fc7\u8fd9\u79cd\u65b9\u6cd5\u5f97\u5230\u7684\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 \u56fe 2\uff1a\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u8fd9\u6761 \u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u4e0e\u60a8\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230\u7684\u66f2\u7ebf\u622a\u7136\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6211\u4eec\u53ea\u6709 20 \u4e2a\u6837\u672c\uff0c\u5176\u4e2d\u53ea\u6709 3 \u4e2a\u662f\u9633\u6027\u6837\u672c\u3002\u4f46\u8fd9\u6ca1\u4ec0\u4e48\u597d\u62c5\u5fc3\u7684\u3002\u8fd9\u8fd8\u662f\u90a3\u6761\u7cbe\u786e\u7387-\u53ec\u56de\u66f2\u7ebf\u3002 \u4f60\u4f1a\u53d1\u73b0\uff0c\u9009\u62e9\u4e00\u4e2a\u65e2\u80fd\u63d0\u4f9b\u826f\u597d\u7cbe\u786e\u7387\u53c8\u80fd\u63d0\u4f9b\u53ec\u56de\u503c\u7684\u9608\u503c\u662f\u5f88\u6709\u6311\u6218\u6027\u7684\u3002\u5982\u679c\u9608\u503c\u8fc7\u9ad8\uff0c\u771f\u9633\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u51cf\u5c11\uff0c\u800c\u5047\u9634\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u589e\u52a0\u3002\u8fd9\u4f1a\u964d\u4f4e\u53ec\u56de\u7387\uff0c\u4f46\u7cbe\u786e\u7387\u5f97\u5206\u4f1a\u5f88\u9ad8\u3002\u5982\u679c\u5c06\u9608\u503c\u964d\u5f97\u592a\u4f4e\uff0c\u5219\u8bef\u62a5\u4f1a\u5927\u91cf\u589e\u52a0\uff0c\u7cbe\u786e\u7387\u4e5f\u4f1a\u964d\u4f4e\u3002 \u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u8d8a\u63a5\u8fd1 1 \u8d8a\u597d\u3002 F1 \u5206\u6570\u662f\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7efc\u5408\u6307\u6807\u3002\u5b83\u88ab\u5b9a\u4e49\u4e3a\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7b80\u5355\u52a0\u6743\u5e73\u5747\u503c\uff08\u8c03\u548c\u5e73\u5747\u503c\uff09\u3002\u5982\u679c\u6211\u4eec\u7528 P \u8868\u793a\u7cbe\u786e\u7387\uff0c\u7528 R \u8868\u793a\u53ec\u56de\u7387\uff0c\u90a3\u4e48 F1 \u5206\u6570\u53ef\u4ee5\u8868\u793a\u4e3a\uff1a \\[ F1 = 2PR/(P + R) \\] \u6839\u636e TP\u3001FP \u548c FN\uff0c\u7a0d\u52a0\u6570\u5b66\u8ba1\u7b97\u5c31\u80fd\u5f97\u51fa\u4ee5\u4e0b F1 \u7b49\u5f0f\uff1a \\[ F1 = 2TP/(2TP + FP + FN) \\] Python \u5b9e\u73b0\u5f88\u7b80\u5355\uff0c\u56e0\u4e3a\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86\u8fd9\u4e9b def f1 ( y_true , y_pred ): # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , y_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , y_pred ) # \u8ba1\u7b97f1\u503c score = 2 * p * r / ( p + r ) return score \u8ba9\u6211\u4eec\u770b\u770b\u5176\u7ed3\u679c\uff0c\u5e76\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\u3002 In [ X ]: y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0 , 0 , 1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: f1 ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u901a\u8fc7 scikit learn\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u76f8\u540c\u7684\u5217\u8868\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . f1_score ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u4e0e\u5176\u5355\u72ec\u770b\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u60a8\u8fd8\u53ef\u4ee5\u53ea\u770b F1 \u5206\u6570\u3002\u4e0e\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c\u51c6\u786e\u5ea6\u4e00\u6837\uff0cF1 \u5206\u6570\u7684\u8303\u56f4\u4e5f\u662f\u4ece 0 \u5230 1\uff0c\u5b8c\u7f8e\u9884\u6d4b\u6a21\u578b\u7684 F1 \u5206\u6570\u4e3a 1\u3002 \u6b64\u5916\uff0c\u6211\u4eec\u8fd8\u5e94\u8be5\u4e86\u89e3\u5176\u4ed6\u4e00\u4e9b\u5173\u952e\u672f\u8bed\u3002 \u7b2c\u4e00\u4e2a\u672f\u8bed\u662f TPR \u6216\u771f\u9633\u6027\u7387\uff08True Positive Rate\uff09\uff0c\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\u3002 \\[ TPR = TP/(TP + FN) \\] \u5c3d\u7ba1\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\uff0c\u4f46\u6211\u4eec\u5c06\u4e3a\u5b83\u521b\u5efa\u4e00\u4e2a python \u51fd\u6570\uff0c\u4ee5\u4fbf\u4eca\u540e\u4f7f\u7528\u8fd9\u4e2a\u540d\u79f0\u3002 def tpr ( y_true , y_pred ): # \u771f\u9633\u6027\u7387\uff08TPR\uff09\uff0c\u4e0e\u53ec\u56de\u7387\u8ba1\u7b97\u516c\u5f0f\u4e00\u81f4 return recall ( y_true , y_pred ) TPR \u6216\u53ec\u56de\u7387\u4e5f\u88ab\u79f0\u4e3a\u7075\u654f\u5ea6\u3002 \u800c FPR \u6216\u5047\u9633\u6027\u7387\uff08False Positive Rate\uff09\u7684\u5b9a\u4e49\u662f\uff1a \\[ FPR = FP / (TN + FP) \\] def fpr ( y_true , y_pred ): # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u8fd4\u56de\u5047\u9633\u6027\u7387\uff08FPR\uff09 return fp / ( tn + fp ) 1 - FPR \u88ab\u79f0\u4e3a\u7279\u5f02\u6027\u6216\u771f\u9634\u6027\u7387\u6216 TNR\u3002\u8fd9\u4e9b\u672f\u8bed\u5f88\u591a\uff0c\u4f46\u5176\u4e2d\u6700\u91cd\u8981\u7684\u53ea\u6709 TPR \u548c FPR\u3002\u5047\u8bbe\u6211\u4eec\u53ea\u6709 15 \u4e2a\u6837\u672c\uff0c\u5176\u76ee\u6807\u503c\u4e3a\u4e8c\u5143\uff1a Actual targets : [0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1] \u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u7c7b\u4f3c\u968f\u673a\u68ee\u6797\u7684\u6a21\u578b\uff0c\u5c31\u80fd\u5f97\u5230\u6837\u672c\u5448\u9633\u6027\u7684\u6982\u7387\u3002 Predicted probabilities for 1: [0.1, 0.3, 0.2, 0.6, 0.8, 0.05, 0.9, 0.5, 0.3, 0.66, 0.3, 0.2, 0.85, 0.15, 0.99] \u5bf9\u4e8e >= 0.5 \u7684\u5178\u578b\u9608\u503c\uff0c\u6211\u4eec\u53ef\u4ee5\u8bc4\u4f30\u4e0a\u8ff0\u6240\u6709\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387/TPR\u3001F1 \u548c FPR \u503c\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u5c06\u9608\u503c\u9009\u4e3a 0.4 \u6216 0.6\uff0c\u4e5f\u53ef\u4ee5\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0 \u5230 1 \u4e4b\u95f4\u7684\u4efb\u4f55\u503c\uff0c\u5e76\u8ba1\u7b97\u4e0a\u8ff0\u6240\u6709\u6307\u6807\u3002 \u4e0d\u8fc7\uff0c\u6211\u4eec\u53ea\u8ba1\u7b97\u4e24\u4e2a\u503c\uff1a TPR \u548c FPR\u3002 # \u521d\u59cb\u5316\u771f\u9633\u6027\u7387\u5217\u8868 tpr_list = [] # \u521d\u59cb\u5316\u5047\u9633\u6027\u7387\u5217\u8868 fpr_list = [] # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u7387 temp_tpr = tpr ( y_true , temp_pred ) # \u5047\u9633\u6027\u7387 temp_fpr = fpr ( y_true , temp_pred ) # \u5c06\u771f\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 tpr_list . append ( temp_tpr ) # \u5c06\u5047\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 fpr_list . append ( temp_fpr ) \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u6bcf\u4e2a\u9608\u503c\u7684 TPR \u503c\u548c FPR \u503c\u3002 \u56fe 3\uff1a\u9608\u503c\u3001TPR \u548c FPR \u503c\u8868 \u5982\u679c\u6211\u4eec\u7ed8\u5236\u5982\u56fe 3 \u6240\u793a\u7684\u8868\u683c\uff0c\u5373\u4ee5 TPR \u4e3a Y \u8f74\uff0cFPR \u4e3a X \u8f74\uff0c\u5c31\u4f1a\u5f97\u5230\u5982\u56fe 4 \u6240\u793a\u7684\u66f2\u7ebf\u3002 \u56fe 4\uff1aROC\u66f2\u7ebf \u8fd9\u6761\u66f2\u7ebf\u4e5f\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u3002\u5982\u679c\u6211\u4eec\u8ba1\u7b97\u8fd9\u6761 ROC \u66f2\u7ebf\u4e0b\u7684\u9762\u79ef\uff0c\u5c31\u662f\u5728\u8ba1\u7b97\u53e6\u4e00\u4e2a\u6307\u6807\uff0c\u5f53\u6570\u636e\u96c6\u7684\u4e8c\u5143\u76ee\u6807\u504f\u659c\u65f6\uff0c\u8fd9\u4e2a\u6307\u6807\u5c31\u4f1a\u975e\u5e38\u5e38\u7528\u3002 \u8fd9\u4e2a\u6307\u6807\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u6216\u66f2\u7ebf\u4e0b\u9762\u79ef\uff0c\u7b80\u79f0 AUC\u3002\u8ba1\u7b97 ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002\u5728\u6b64\uff0c\u6211\u4eec\u5c06\u91c7\u7528 scikit- learn \u7684\u5947\u5999\u5b9e\u73b0\u65b9\u6cd5\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: metrics . roc_auc_score ( y_true , y_pred ) Out [ X ]: 0.8300000000000001 AUC \u503c\u4ece 0 \u5230 1 \u4e0d\u7b49\u3002 AUC = 1 \u610f\u5473\u7740\u60a8\u62e5\u6709\u4e00\u4e2a\u5b8c\u7f8e\u7684\u6a21\u578b\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u610f\u5473\u7740\u4f60\u5728\u9a8c\u8bc1\u65f6\u72af\u4e86\u4e00\u4e9b\u9519\u8bef\uff0c\u5e94\u8be5\u91cd\u65b0\u5ba1\u89c6\u6570\u636e\u5904\u7406\u548c\u9a8c\u8bc1\u6d41\u7a0b\u3002\u5982\u679c\u4f60\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\uff0c\u90a3\u4e48\u606d\u559c\u4f60\uff0c\u4f60\u5df2\u7ecf\u62e5\u6709\u4e86\u9488\u5bf9\u6570\u636e\u96c6\u5efa\u7acb\u7684\u6700\u4f73\u6a21\u578b\u3002 AUC = 0 \u610f\u5473\u7740\u60a8\u7684\u6a21\u578b\u975e\u5e38\u7cdf\u7cd5\uff08\u6216\u975e\u5e38\u597d\uff01\uff09\u3002\u8bd5\u7740\u53cd\u8f6c\u9884\u6d4b\u7684\u6982\u7387\uff0c\u4f8b\u5982\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u6b63\u7c7b\u7684\u6982\u7387\u662f p\uff0c\u8bd5\u7740\u7528 1-p \u4ee3\u66ff\u5b83\u3002\u8fd9\u79cd AUC \u4e5f\u53ef\u80fd\u610f\u5473\u7740\u60a8\u7684\u9a8c\u8bc1\u6216\u6570\u636e\u5904\u7406\u5b58\u5728\u95ee\u9898\u3002 AUC = 0.5 \u610f\u5473\u7740\u4f60\u7684\u9884\u6d4b\u662f\u968f\u673a\u7684\u3002\u56e0\u6b64\uff0c\u5bf9\u4e8e\u4efb\u4f55\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u5982\u679c\u6211\u5c06\u6240\u6709\u76ee\u6807\u90fd\u9884\u6d4b\u4e3a 0.5\uff0c\u6211\u5c06\u5f97\u5230 0.5 \u7684 AUC\u3002 AUC \u503c\u4ecb\u4e8e 0 \u548c 0.5 \u4e4b\u95f4\uff0c\u610f\u5473\u7740\u4f60\u7684\u6a21\u578b\u6bd4\u968f\u673a\u6a21\u578b\u66f4\u5dee\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u662f\u56e0\u4e3a\u4f60\u98a0\u5012\u4e86\u7c7b\u522b\u3002 \u5982\u679c\u60a8\u5c1d\u8bd5\u53cd\u8f6c\u9884\u6d4b\uff0c\u60a8\u7684 AUC \u503c\u53ef\u80fd\u4f1a\u8d85\u8fc7 0.5\u3002\u63a5\u8fd1 1 \u7684 AUC \u503c\u88ab\u8ba4\u4e3a\u662f\u597d\u503c\u3002 \u4f46 AUC \u5bf9\u6211\u4eec\u7684\u6a21\u578b\u6709\u4ec0\u4e48\u5f71\u54cd\u5462\uff1f \u5047\u8bbe\u60a8\u5efa\u7acb\u4e86\u4e00\u4e2a\u4ece\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e2d\u68c0\u6d4b\u6c14\u80f8\u7684\u6a21\u578b\uff0c\u5176 AUC \u503c\u4e3a 0.85\u3002\u8fd9\u610f\u5473\u7740\uff0c\u5982\u679c\u60a8\u4ece\u6570\u636e\u96c6\u4e2d\u968f\u673a\u9009\u62e9\u4e00\u5f20\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9633\u6027\u6837\u672c\uff09\u548c\u53e6\u4e00\u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9634\u6027\u6837\u672c\uff09\uff0c\u90a3\u4e48\u6c14\u80f8\u56fe\u50cf\u7684\u6392\u540d\u5c06\u9ad8\u4e8e\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u6982\u7387\u4e3a 0.85\u3002 \u8ba1\u7b97\u6982\u7387\u548c AUC \u540e\uff0c\u60a8\u9700\u8981\u5bf9\u6d4b\u8bd5\u96c6\u8fdb\u884c\u9884\u6d4b\u3002\u6839\u636e\u95ee\u9898\u548c\u4f7f\u7528\u60c5\u51b5\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u6982\u7387\u6216\u5b9e\u9645\u7c7b\u522b\u3002\u5982\u679c\u4f60\u60f3\u8981\u6982\u7387\uff0c\u8fd9\u5e76\u4e0d\u96be\u3002\u5982\u679c\u60a8\u60f3\u8981\u7c7b\u522b\uff0c\u5219\u9700\u8981\u9009\u62e9\u4e00\u4e2a\u9608\u503c\u3002\u5728\u4e8c\u5143\u5206\u7c7b\u7684\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u91c7\u7528\u7c7b\u4f3c\u4e0b\u9762\u7684\u65b9\u6cd5\u3002 \\[ Prediction = Probability >= Threshold \\] \u4e5f\u5c31\u662f\u8bf4\uff0c\u9884\u6d4b\u662f\u4e00\u4e2a\u53ea\u5305\u542b\u4e8c\u5143\u53d8\u91cf\u7684\u65b0\u5217\u8868\u3002\u5982\u679c\u6982\u7387\u5927\u4e8e\u6216\u7b49\u4e8e\u7ed9\u5b9a\u7684\u9608\u503c\uff0c\u5219\u9884\u6d4b\u4e2d\u7684\u4e00\u9879\u4e3a 1\uff0c\u5426\u5219\u4e3a 0\u3002 \u4f60\u731c\u600e\u4e48\u7740\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528 ROC \u66f2\u7ebf\u6765\u9009\u62e9\u8fd9\u4e2a\u9608\u503c\uff01ROC \u66f2\u7ebf\u4f1a\u544a\u8bc9\u60a8\u9608\u503c\u5bf9\u5047\u9633\u6027\u7387\u548c\u771f\u9633\u6027\u7387\u7684\u5f71\u54cd\uff0c\u8fdb\u800c\u5f71\u54cd\u5047\u9633\u6027\u548c\u771f\u9633\u6027\u3002\u60a8\u5e94\u8be5\u9009\u62e9\u6700\u9002\u5408\u60a8\u7684\u95ee\u9898\u548c\u6570\u636e\u96c6\u7684\u9608\u503c\u3002 \u4f8b\u5982\uff0c\u5982\u679c\u60a8\u4e0d\u5e0c\u671b\u6709\u592a\u591a\u7684\u8bef\u62a5\uff0c\u90a3\u4e48\u9608\u503c\u5c31\u5e94\u8be5\u9ad8\u4e00\u4e9b\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e5f\u4f1a\u5e26\u6765\u66f4\u591a\u7684\u8bef\u62a5\u3002\u6ce8\u610f\u6743\u8861\u5229\u5f0a\uff0c\u9009\u62e9\u6700\u4f73\u9608\u503c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u9608\u503c\u5982\u4f55\u5f71\u54cd\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u503c\u3002 # \u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list = [] # \u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list = [] # \u771f\u5b9e\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 temp_tp = true_positive ( y_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 temp_fp = false_positive ( y_true , temp_pred ) # \u52a0\u5165\u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list . append ( temp_tp ) # \u52a0\u5165\u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list . append ( temp_fp ) \u5229\u7528\u8fd9\u4e00\u70b9\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u8868\u683c\uff0c\u5982\u56fe 5 \u6240\u793a\u3002 \u56fe 5\uff1a\u4e0d\u540c\u9608\u503c\u7684 TP \u503c\u548c FP \u503c \u5982\u56fe 6 \u6240\u793a\uff0c\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0cROC \u66f2\u7ebf\u5de6\u4e0a\u89d2\u7684\u503c\u5e94\u8be5\u662f\u4e00\u4e2a\u76f8\u5f53\u4e0d\u9519\u7684\u9608\u503c\u3002 \u5bf9\u6bd4\u8868\u683c\u548c ROC \u66f2\u7ebf\uff0c\u6211\u4eec\u53ef\u4ee5\u53d1\u73b0\uff0c0.6 \u5de6\u53f3\u7684\u9608\u503c\u76f8\u5f53\u4e0d\u9519\uff0c\u65e2\u4e0d\u4f1a\u4e22\u5931\u5927\u91cf\u7684\u771f\u9633\u6027\u7ed3\u679c\uff0c\u4e5f\u4e0d\u4f1a\u51fa\u73b0\u5927\u91cf\u7684\u5047\u9633\u6027\u7ed3\u679c\u3002 \u56fe 6\uff1a\u4ece ROC \u66f2\u7ebf\u6700\u5de6\u4fa7\u7684\u9876\u70b9\u9009\u62e9\u6700\u4f73\u9608\u503c AUC \u662f\u4e1a\u5185\u5e7f\u6cdb\u5e94\u7528\u4e8e\u504f\u659c\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6307\u6807\uff0c\u4e5f\u662f\u6bcf\u4e2a\u4eba\u90fd\u5e94\u8be5\u4e86\u89e3\u7684\u6307\u6807\u3002\u4e00\u65e6\u7406\u89e3\u4e86 AUC \u80cc\u540e\u7684\u7406\u5ff5\uff08\u5982\u4e0a\u6587\u6240\u8ff0\uff09\uff0c\u4e5f\u5c31\u5f88\u5bb9\u6613\u5411\u4e1a\u754c\u53ef\u80fd\u4f1a\u8bc4\u4f30\u60a8\u7684\u6a21\u578b\u7684\u975e\u6280\u672f\u4eba\u5458\u89e3\u91ca\u5b83\u4e86\u3002 \u5b66\u4e60 AUC \u540e\uff0c\u4f60\u5e94\u8be5\u5b66\u4e60\u7684\u53e6\u4e00\u4e2a\u91cd\u8981\u6307\u6807\u662f\u5bf9\u6570\u635f\u5931\u3002\u5bf9\u4e8e\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5c06\u5bf9\u6570\u635f\u5931\u5b9a\u4e49\u4e3a\uff1a \\[ LogLoss = -1.0 \\times (target \\times log(prediction) + (1-target) \\times log(1-prediction)) \\] \u5176\u4e2d\uff0c\u76ee\u6807\u503c\u4e3a 0 \u6216 1\uff0c\u9884\u6d4b\u503c\u4e3a\u6837\u672c\u5c5e\u4e8e\u7c7b\u522b 1 \u7684\u6982\u7387\u3002 \u5bf9\u4e8e\u6570\u636e\u96c6\u4e2d\u7684\u591a\u4e2a\u6837\u672c\uff0c\u6240\u6709\u6837\u672c\u7684\u5bf9\u6570\u635f\u5931\u53ea\u662f\u6240\u6709\u5355\u4e2a\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u9700\u8981\u8bb0\u4f4f\u7684\u4e00\u70b9\u662f\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u4e0d\u6b63\u786e\u6216\u504f\u5dee\u8f83\u5927\u7684\u9884\u6d4b\u8fdb\u884c\u76f8\u5f53\u9ad8\u7684\u60e9\u7f5a\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u975e\u5e38\u786e\u5b9a\u548c\u975e\u5e38\u9519\u8bef\u7684\u9884\u6d4b\u8fdb\u884c\u60e9\u7f5a\u3002 import numpy as np def log_loss ( y_true , y_proba ): # \u6781\u5c0f\u503c\uff0c\u9632\u6b620\u505a\u5206\u6bcd epsilon = 1e-15 # \u5bf9\u6570\u635f\u5931\u5217\u8868 loss = [] # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_proba ): # \u9650\u5236yp\u8303\u56f4\uff0c\u6700\u5c0f\u4e3aepsilon\uff0c\u6700\u5927\u4e3a1-epsilon yp = np . clip ( yp , epsilon , 1 - epsilon ) # \u8ba1\u7b97\u5bf9\u6570\u635f\u5931 temp_loss = - 1.0 * ( yt * np . log ( yp ) + ( 1 - yt ) * np . log ( 1 - yp )) # \u52a0\u5165\u5bf9\u6570\u635f\u5931\u5217\u8868 loss . append ( temp_loss ) return np . mean ( loss ) \u8ba9\u6211\u4eec\u6d4b\u8bd5\u4e00\u4e0b\u51fd\u6570\u6267\u884c\u60c5\u51b5\uff1a In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_proba = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u6211\u4eec\u53ef\u4ee5\u5c06\u5176\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u5b9e\u73b0\u662f\u6b63\u786e\u7684\u3002 \u5bf9\u6570\u635f\u5931\u7684\u5b9e\u73b0\u5f88\u5bb9\u6613\u3002\u89e3\u91ca\u8d77\u6765\u4f3c\u4e4e\u6709\u70b9\u56f0\u96be\u3002\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5bf9\u6570\u635f\u5931\u7684\u60e9\u7f5a\u8981\u6bd4\u5176\u4ed6\u6307\u6807\u5927\u5f97\u591a\u3002 \u4f8b\u5982\uff0c\u5982\u679c\u60a8\u6709 51% \u7684\u628a\u63e1\u8ba4\u4e3a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\uff0c\u90a3\u4e48\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (1 \\times log(0.51) + (1 - 1) \\times log(1 - 0.51))=0.67 \\] \u5982\u679c\u4f60\u5bf9\u5c5e\u4e8e 0 \u7c7b\u7684\u6837\u672c\u6709 49% \u7684\u628a\u63e1\uff0c\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (1 \\times log(0.49) + (1 - 1) \\times log(1 - 0.49))=0.67 \\] \u56e0\u6b64\uff0c\u5373\u4f7f\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0.5 \u7684\u622a\u65ad\u503c\u5e76\u5f97\u5230\u5b8c\u7f8e\u7684\u9884\u6d4b\u7ed3\u679c\uff0c\u6211\u4eec\u4ecd\u7136\u4f1a\u6709\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002\u56e0\u6b64\uff0c\u5728\u5904\u7406\u5bf9\u6570\u635f\u5931\u65f6\uff0c\u4f60\u9700\u8981\u975e\u5e38\u5c0f\u5fc3\uff1b\u4efb\u4f55\u4e0d\u786e\u5b9a\u7684\u9884\u6d4b\u90fd\u4f1a\u4ea7\u751f\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002 \u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u5927\u591a\u6570\u6307\u6807\u90fd\u53ef\u4ee5\u8f6c\u6362\u6210\u591a\u7c7b\u7248\u672c\u3002\u8fd9\u4e2a\u60f3\u6cd5\u5f88\u7b80\u5355\u3002\u4ee5\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u4e3a\u4f8b\u3002\u6211\u4eec\u53ef\u4ee5\u8ba1\u7b97\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u4e2d\u6bcf\u4e00\u7c7b\u7684\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 \u6709\u4e09\u79cd\u4e0d\u540c\u7684\u8ba1\u7b97\u65b9\u6cd5\uff0c\u6709\u65f6\u53ef\u80fd\u4f1a\u4ee4\u4eba\u56f0\u60d1\u3002\u5047\u8bbe\u6211\u4eec\u9996\u5148\u5bf9\u7cbe\u786e\u7387\u611f\u5174\u8da3\u3002\u6211\u4eec\u77e5\u9053\uff0c\u7cbe\u786e\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u3002 \u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Macro averaged precision\uff09\uff1a\u5206\u522b\u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u7cbe\u786e\u7387\u7136\u540e\u6c42\u5e73\u5747\u503c \u5fae\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Micro averaged precision\uff09\uff1a\u6309\u7c7b\u8ba1\u7b97\u771f\u9633\u6027\u548c\u5047\u9633\u6027\uff0c\u7136\u540e\u7528\u5176\u8ba1\u7b97\u603b\u4f53\u7cbe\u786e\u7387\u3002\u7136\u540e\u4ee5\u6b64\u8ba1\u7b97\u603b\u4f53\u7cbe\u786e\u7387 \u52a0\u6743\u7cbe\u786e\u7387 \uff08Weighted precision\uff09\uff1a\u4e0e\u5b8f\u89c2\u7cbe\u786e\u7387\u76f8\u540c\uff0c\u4f46\u8fd9\u91cc\u662f\u52a0\u6743\u5e73\u5747\u7cbe\u786e\u7387 \u53d6\u51b3\u4e8e\u6bcf\u4e2a\u7c7b\u522b\u4e2d\u7684\u9879\u76ee\u6570 \u8fd9\u770b\u4f3c\u590d\u6742\uff0c\u4f46\u5728 python \u5b9e\u73b0\u4e2d\u5f88\u5bb9\u6613\u7406\u89e3\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import numpy as np def macro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u5982\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u8ba1\u7b97\u7cbe\u786e\u5ea6 temp_precision = tp / ( tp + fp ) # \u5404\u7c7b\u7cbe\u786e\u7387\u76f8\u52a0 precision += temp_precision # \u8ba1\u7b97\u5e73\u5747\u503c precision /= num_classes return precision \u4f60\u4f1a\u53d1\u73b0\u8fd9\u5e76\u4e0d\u96be\u3002\u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5fae\u5e73\u5747\u7cbe\u786e\u7387\u5206\u6570\u3002 import numpy as np def micro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u6570 tp = 0 # \u521d\u59cb\u5316\u5047\u9633\u6027\u6837\u672c\u6570 fp = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 tp += true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 fp += false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8fd9\u4e5f\u4e0d\u96be\u3002\u90a3\u4ec0\u4e48\u96be\uff1f\u4ec0\u4e48\u90fd\u4e0d\u96be\u3002\u673a\u5668\u5b66\u4e60\u5f88\u7b80\u5355\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u52a0\u6743\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u3002 from collections import Counter import numpy as np def weighted_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 temp_precision = tp / ( tp + fp ) # \u6839\u636e\u8be5\u79cd\u7c7b\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_precision = class_counts [ class_ ] * temp_precision # \u52a0\u6743\u7cbe\u786e\u7387\u6c42\u548c precision += weighted_precision # \u8ba1\u7b97\u5e73\u5747\u7cbe\u786e\u7387 overall_precision = precision / len ( y_true ) return overall_precision \u5c06\u6211\u4eec\u7684\u5b9e\u73b0\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff0c\u4ee5\u4e86\u89e3\u5b9e\u73b0\u662f\u5426\u6b63\u786e\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: macro_precision ( y_true , y_pred ) Out [ X ]: 0.3611111111111111 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"macro\" ) Out [ X ]: 0.3611111111111111 In [ X ]: micro_precision ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"micro\" ) Out [ X ]: 0.4444444444444444 In [ X ]: weighted_precision ( y_true , y_pred ) Out [ X ]: 0.39814814814814814 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.39814814814814814 \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6b63\u786e\u5730\u5b9e\u73b0\u4e86\u4e00\u5207\u3002 \u8bf7\u6ce8\u610f\uff0c\u8fd9\u91cc\u5c55\u793a\u7684\u5b9e\u73b0\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\uff0c\u4f46\u5374\u662f\u6700\u5bb9\u6613\u7406\u89e3\u7684\u3002 \u540c\u6837\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5b9e\u73b0 \u591a\u7c7b\u522b\u7684\u53ec\u56de\u7387\u6307\u6807 \u3002\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u3001\u5047\u9633\u6027\u548c\u5047\u9634\u6027\uff0c\u800c F1 \u5219\u53d6\u51b3\u4e8e\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 \u53ec\u56de\u7387\u7684\u5b9e\u73b0\u65b9\u6cd5\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\uff0c\u8fd9\u91cc\u5b9e\u73b0\u7684\u662f\u591a\u7c7b F1 \u7684\u4e00\u4e2a\u7248\u672c\uff0c\u5373\u52a0\u6743\u5e73\u5747\u503c\u3002 from collections import Counter import numpy as np def weighted_f1 ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316F1\u503c f1 = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( temp_true , temp_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( temp_true , temp_pred ) # \u82e5\u7cbe\u786e\u7387+\u53ec\u56de\u7387\u4e0d\u4e3a0\uff0c\u5219\u4f7f\u7528\u516c\u5f0f\u8ba1\u7b97F1\u503c if p + r != 0 : temp_f1 = 2 * p * r / ( p + r ) # \u5426\u5219\u76f4\u63a5\u4e3a0 else : temp_f1 = 0 # \u6839\u636e\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_f1 = class_counts [ class_ ] * temp_f1 # \u52a0\u6743F1\u503c\u76f8\u52a0 f1 += weighted_f1 # \u8ba1\u7b97\u52a0\u6743\u5e73\u5747F1\u503c overall_f1 = f1 / len ( y_true ) return overall_f1 \u8bf7\u6ce8\u610f\uff0c\u4e0a\u9762\u6709\u51e0\u884c\u4ee3\u7801\u662f\u65b0\u5199\u7684\u3002\u56e0\u6b64\uff0c\u4f60\u5e94\u8be5\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e9b\u4ee3\u7801\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: weighted_f1 ( y_true , y_pred ) Out [ X ]: 0.41269841269841273 In [ X ]: metrics . f1_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.41269841269841273 \u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u591a\u7c7b\u95ee\u9898\u5b9e\u73b0\u4e86\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002\u540c\u6837\uff0c\u60a8\u4e5f\u53ef\u4ee5\u5c06 AUC \u548c\u5bf9\u6570\u635f\u5931\u8f6c\u6362\u4e3a\u591a\u7c7b\u683c\u5f0f\u3002\u8fd9\u79cd\u8f6c\u6362\u683c\u5f0f\u88ab\u79f0\u4e3a one-vs-all \u3002\u8fd9\u91cc\u6211\u4e0d\u6253\u7b97\u5b9e\u73b0\u5b83\u4eec\uff0c\u56e0\u4e3a\u5b9e\u73b0\u65b9\u6cd5\u4e0e\u6211\u4eec\u5df2\u7ecf\u8ba8\u8bba\u8fc7\u7684\u5f88\u76f8\u4f3c\u3002 \u5728\u4e8c\u5143\u6216\u591a\u7c7b\u5206\u7c7b\u4e2d\uff0c\u770b\u4e00\u4e0b \u6df7\u6dc6\u77e9\u9635 \u4e5f\u5f88\u6d41\u884c\u3002\u4e0d\u8981\u56f0\u60d1\uff0c\u8fd9\u5f88\u7b80\u5355\u3002\u6df7\u6dc6\u77e9\u9635\u53ea\u4e0d\u8fc7\u662f\u4e00\u4e2a\u5305\u542b TP\u3001FP\u3001TN \u548c FN \u7684\u8868\u683c\u3002\u4f7f\u7528\u6df7\u6dc6\u77e9\u9635\uff0c\u60a8\u53ef\u4ee5\u5feb\u901f\u67e5\u770b\u6709\u591a\u5c11\u6837\u672c\u88ab\u9519\u8bef\u5206\u7c7b\uff0c\u6709\u591a\u5c11\u6837\u672c\u88ab\u6b63\u786e\u5206\u7c7b\u3002\u4e5f\u8bb8\u6709\u4eba\u4f1a\u8bf4\uff0c\u6df7\u6dc6\u77e9\u9635\u5e94\u8be5\u5728\u672c\u7ae0\u5f88\u65e9\u5c31\u8bb2\u5230\uff0c\u4f46\u6211\u6ca1\u6709\u8fd9\u4e48\u505a\u3002\u5982\u679c\u4e86\u89e3\u4e86 TP\u3001FP\u3001TN\u3001FN\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c AUC\uff0c\u5c31\u5f88\u5bb9\u6613\u7406\u89e3\u548c\u89e3\u91ca\u6df7\u6dc6\u77e9\u9635\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b\u56fe 7 \u4e2d\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6df7\u6dc6\u77e9\u9635\u7531 TP\u3001FP\u3001FN \u548c TN \u7ec4\u6210\u3002\u6211\u4eec\u53ea\u9700\u8981\u8fd9\u4e9b\u503c\u6765\u8ba1\u7b97\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u3001F1 \u5206\u6570\u548c AUC\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u4e5f\u559c\u6b22\u628a FP \u79f0\u4e3a \u7b2c\u4e00\u7c7b\u9519\u8bef \uff0c\u628a FN \u79f0\u4e3a \u7b2c\u4e8c\u7c7b\u9519\u8bef \u3002 \u56fe 7\uff1a\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6df7\u6dc6\u77e9\u9635 \u6211\u4eec\u8fd8\u53ef\u4ee5\u5c06\u4e8c\u5143\u6df7\u6dc6\u77e9\u9635\u6269\u5c55\u4e3a\u591a\u7c7b\u6df7\u6dc6\u77e9\u9635\u3002\u5b83\u4f1a\u662f\u4ec0\u4e48\u6837\u5b50\u5462\uff1f\u5982\u679c\u6211\u4eec\u6709 N \u4e2a\u7c7b\u522b\uff0c\u5b83\u5c06\u662f\u4e00\u4e2a\u5927\u5c0f\u4e3a NxN \u7684\u77e9\u9635\u3002\u5bf9\u4e8e\u6bcf\u4e2a\u7c7b\u522b\uff0c\u6211\u4eec\u90fd\u8981\u8ba1\u7b97\u76f8\u5173\u7c7b\u522b\u548c\u5176\u4ed6\u7c7b\u522b\u7684\u6837\u672c\u603b\u6570\u3002\u4e3e\u4e2a\u4f8b\u5b50\u53ef\u4ee5\u8ba9\u6211\u4eec\u66f4\u597d\u5730\u7406\u89e3\u8fd9\u4e00\u70b9\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4ee5\u4e0b\u771f\u5b9e\u6807\u7b7e\uff1a \\[ [0, 1, 2, 0, 1, 2, 0, 2, 2] \\] \u6211\u4eec\u7684\u9884\u6d4b\u6807\u7b7e\u662f\uff1a \\[ [0, 2, 1, 0, 2, 1, 0, 0, 2] \\] \u90a3\u4e48\uff0c\u6211\u4eec\u7684\u6df7\u6dc6\u77e9\u9635\u5c06\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u591a\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635 \u56fe 8 \u8bf4\u660e\u4e86\u4ec0\u4e48\uff1f \u8ba9\u6211\u4eec\u6765\u770b\u770b 0 \u7c7b\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e 0 \u7c7b\u3002\u7136\u800c\uff0c\u5728\u9884\u6d4b\u4e2d\uff0c\u6211\u4eec\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 0 \u7c7b\uff0c1 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\u3002\u7406\u60f3\u60c5\u51b5\u4e0b\uff0c\u5bf9\u4e8e\u771f\u5b9e\u6807\u7b7e\u4e2d\u7684\u7c7b\u522b 0\uff0c\u9884\u6d4b\u6807\u7b7e 1 \u548c 2 \u5e94\u8be5\u6ca1\u6709\u4efb\u4f55\u6837\u672c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u7c7b\u522b 2\u3002\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 4\uff0c\u800c\u5728\u9884\u6d4b\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 3\u3002 \u4e00\u4e2a\u5b8c\u7f8e\u7684\u6df7\u6dc6\u77e9\u9635\u53ea\u80fd\u4ece\u5de6\u5230\u53f3\u659c\u5411\u586b\u5145\u3002 \u6df7\u6dc6\u77e9\u9635 \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u7684\u65b9\u6cd5\u6765\u8ba1\u7b97\u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u4e0d\u540c\u6307\u6807\u3002Scikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u76f4\u63a5\u7684\u65b9\u6cd5\u6765\u751f\u6210\u6df7\u6dc6\u77e9\u9635\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5728\u56fe 8 \u4e2d\u663e\u793a\u7684\u6df7\u6dc6\u77e9\u9635\u662f scikit-learn \u6df7\u6dc6\u77e9\u9635\u7684\u8f6c\u7f6e\uff0c\u539f\u59cb\u7248\u672c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u7ed8\u5236\u3002 import matplotlib.pyplot as plt import seaborn as sns from sklearn import metrics # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] # \u9884\u6d4b\u6837\u672c\u6807\u7b7e y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] # \u8ba1\u7b97\u6df7\u6dc6\u77e9\u9635 cm = metrics . confusion_matrix ( y_true , y_pred ) # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 10 , 10 )) # \u521b\u5efa\u65b9\u683c cmap = sns . cubehelix_palette ( 50 , hue = 0.05 , rot = 0 , light = 0.9 , dark = 0 , as_cmap = True ) # \u89c4\u5b9a\u5b57\u4f53\u5927\u5c0f sns . set ( font_scale = 2.5 ) # \u7ed8\u5236\u70ed\u56fe sns . heatmap ( cm , annot = True , cmap = cmap , cbar = False ) # y\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . ylabel ( 'Actual Labels' , fontsize = 20 ) # x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . xlabel ( 'Predicted Labels' , fontsize = 20 ) \u56e0\u6b64\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u89e3\u51b3\u4e86\u4e8c\u5143\u5206\u7c7b\u548c\u591a\u7c7b\u5206\u7c7b\u7684\u5ea6\u91cf\u95ee\u9898\u3002\u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u8ba8\u8bba\u53e6\u4e00\u79cd\u7c7b\u578b\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u5373\u591a\u6807\u7b7e\u5206\u7c7b\u3002\u5728\u591a\u6807\u7b7e\u5206\u7c7b\u4e2d\uff0c\u6bcf\u4e2a\u6837\u672c\u90fd\u53ef\u80fd\u4e0e\u4e00\u4e2a\u6216\u591a\u4e2a\u7c7b\u522b\u76f8\u5173\u8054\u3002\u8fd9\u7c7b\u95ee\u9898\u7684\u4e00\u4e2a\u7b80\u5355\u4f8b\u5b50\u5c31\u662f\u8981\u6c42\u4f60\u9884\u6d4b\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53\u3002 \u56fe 9 \u663e\u793a\u4e86\u4e00\u4e2a\u8457\u540d\u6570\u636e\u96c6\u7684\u56fe\u50cf\u793a\u4f8b\u3002\u8bf7\u6ce8\u610f\uff0c\u8be5\u6570\u636e\u96c6\u7684\u76ee\u6807\u6709\u6240\u4e0d\u540c\uff0c\u4f46\u6211\u4eec\u6682\u4e14\u4e0d\u53bb\u8ba8\u8bba\u5b83\u3002\u6211\u4eec\u5047\u8bbe\u5176\u76ee\u7684\u53ea\u662f\u9884\u6d4b\u56fe\u50cf\u4e2d\u662f\u5426\u5b58\u5728\u67d0\u4e2a\u7269\u4f53\u3002\u5728\u56fe 9 \u4e2d\uff0c\u6211\u4eec\u6709\u6905\u5b50\u3001\u82b1\u76c6\u3001\u7a97\u6237\uff0c\u4f46\u6ca1\u6709\u5176\u4ed6\u7269\u4f53\uff0c\u5982\u7535\u8111\u3001\u5e8a\u3001\u7535\u89c6\u7b49\u3002\u56e0\u6b64\uff0c\u4e00\u5e45\u56fe\u50cf\u53ef\u80fd\u6709\u591a\u4e2a\u76f8\u5173\u76ee\u6807\u3002\u8fd9\u7c7b\u95ee\u9898\u5c31\u662f\u591a\u6807\u7b7e\u5206\u7c7b\u95ee\u9898\u3002 \u56fe 9\uff1a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53 \u8fd9\u7c7b\u5206\u7c7b\u95ee\u9898\u7684\u8861\u91cf\u6807\u51c6\u6709\u4e9b\u4e0d\u540c\u3002\u4e00\u4e9b\u5408\u9002\u7684 \u6700\u5e38\u89c1\u7684\u6307\u6807\u6709\uff1a k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u786e\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 \u8ba9\u6211\u4eec\u4ece k \u7cbe\u786e\u7387\u6216\u8005 P@k \u6211\u4eec\u4e0d\u80fd\u5c06\u8fd9\u4e00\u7cbe\u786e\u7387\u4e0e\u524d\u9762\u8ba8\u8bba\u7684\u7cbe\u786e\u7387\u6df7\u6dc6\u3002\u5982\u679c\u60a8\u6709\u4e00\u4e2a\u7ed9\u5b9a\u6837\u672c\u7684\u539f\u59cb\u7c7b\u522b\u5217\u8868\u548c\u540c\u4e00\u4e2a\u6837\u672c\u7684\u9884\u6d4b\u7c7b\u522b\u5217\u8868\uff0c\u90a3\u4e48\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u5c31\u662f\u9884\u6d4b\u5217\u8868\u4e2d\u4ec5\u8003\u8651\u524d k \u4e2a\u9884\u6d4b\u7ed3\u679c\u7684\u547d\u4e2d\u6570\u9664\u4ee5 k\u3002 \u5982\u679c\u60a8\u5bf9\u6b64\u611f\u5230\u56f0\u60d1\uff0c\u4f7f\u7528 python \u4ee3\u7801\u540e\u5c31\u4f1a\u660e\u767d\u3002 def pk ( y_true , y_pred , k ): # \u5982\u679ck\u4e3a0 if k == 0 : # \u8fd4\u56de0 return 0 # \u53d6\u9884\u6d4b\u6807\u7b7e\u524dk\u4e2a y_pred = y_pred [: k ] # \u5c06\u9884\u6d4b\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 pred_set = set ( y_pred ) # \u5c06\u771f\u5b9e\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 true_set = set ( y_true ) # \u9884\u6d4b\u6807\u7b7e\u96c6\u5408\u4e0e\u771f\u5b9e\u6807\u7b7e\u96c6\u5408\u4ea4\u96c6 common_values = pred_set . intersection ( true_set ) # \u8ba1\u7b97\u7cbe\u786e\u7387 return len ( common_values ) / len ( y_pred [: k ]) \u6709\u4e86\u4ee3\u7801\uff0c\u4e00\u5207\u90fd\u53d8\u5f97\u66f4\u5bb9\u6613\u7406\u89e3\u4e86\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6709\u4e86 k \u5e73\u5747\u7cbe\u786e\u7387\u6216 AP@k \u3002AP@k \u662f\u901a\u8fc7 P@k \u8ba1\u7b97\u5f97\u51fa\u7684\u3002\u4f8b\u5982\uff0c\u5982\u679c\u8981\u8ba1\u7b97 AP@3\uff0c\u6211\u4eec\u8981\u5148\u8ba1\u7b97 P@1\u3001P@2 \u548c P@3\uff0c\u7136\u540e\u5c06\u603b\u548c\u9664\u4ee5 3\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u7684\u5b9e\u73b0\u3002 def apk ( y_true , y_pred , k ): # \u521d\u59cb\u5316P@k\u5217\u8868 pk_values = [] # \u904d\u53861~k for i in range ( 1 , k + 1 ): # \u5c06P@k\u52a0\u5165\u5217\u8868 pk_values . append ( pk ( y_true , y_pred , i )) # \u82e5\u957f\u5ea6\u4e3a0 if len ( pk_values ) == 0 : # \u8fd4\u56de0 return 0 # \u5426\u5219\u8ba1\u7b97AP@K return sum ( pk_values ) / len ( pk_values ) \u8fd9\u4e24\u4e2a\u51fd\u6570\u53ef\u4ee5\u7528\u6765\u8ba1\u7b97\u4e24\u4e2a\u7ed9\u5b9a\u5217\u8868\u7684 k \u5e73\u5747\u7cbe\u786e\u7387 (AP@k)\uff1b\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u8ba1\u7b97\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: for i in range ( len ( y_true )): ... : for j in range ( 1 , 4 ): ... : print ( ... : f \"\"\" ...: y_true= { y_true [ i ] } , ...: y_pred= { y_pred [ i ] } , ...: AP@ { j } = { apk ( y_true [ i ], y_pred [ i ], k = j ) } ...: \"\"\" ... : ) ... : y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 1 = 0.0 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 2 = 0.25 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 3 = 0.38888888888888884 \u8bf7\u6ce8\u610f\uff0c\u6211\u7701\u7565\u4e86\u8f93\u51fa\u7ed3\u679c\u4e2d\u7684\u8bb8\u591a\u6570\u503c\uff0c\u4f46\u4f60\u4f1a\u660e\u767d\u5176\u4e2d\u7684\u610f\u601d\u3002\u8fd9\u5c31\u662f\u6211\u4eec\u5982\u4f55\u8ba1\u7b97 AP@k \u7684\u65b9\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u7684 AP@k\u3002\u5728\u673a\u5668\u5b66\u4e60\u4e2d\uff0c\u6211\u4eec\u5bf9\u6240\u6709\u6837\u672c\u90fd\u611f\u5174\u8da3\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u6709 \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387 k \u6216 MAP@k \u3002MAP@k \u53ea\u662f AP@k \u7684\u5e73\u5747\u503c\uff0c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b python \u4ee3\u7801\u8f7b\u677e\u8ba1\u7b97\u3002 def mapk ( y_true , y_pred , k ): # \u521d\u59cb\u5316AP@k\u5217\u8868 apk_values = [] # \u904d\u53860~\uff08\u771f\u5b9e\u6807\u7b7e\u6570-1\uff09 for i in range ( len ( y_true )): # \u5c06AP@K\u52a0\u5165\u5217\u8868 apk_values . append ( apk ( y_true [ i ], y_pred [ i ], k = k ) ) # \u8ba1\u7b97\u5e73\u5747AP@k return sum ( apk_values ) / len ( apk_values ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u9488\u5bf9\u76f8\u540c\u7684\u5217\u8868\u8ba1\u7b97 k=1\u30012\u30013 \u548c 4 \u65f6\u7684 MAP@k\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: mapk ( y_true , y_pred , k = 1 ) Out [ X ]: 0.3333333333333333 In [ X ]: mapk ( y_true , y_pred , k = 2 ) Out [ X ]: 0.375 In [ X ]: mapk ( y_true , y_pred , k = 3 ) Out [ X ]: 0.3611111111111111 In [ X ]: mapk ( y_true , y_pred , k = 4 ) Out [ X ]: 0.34722222222222215 P@k\u3001AP@k \u548c MAP@k \u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u5176\u4e2d 1 \u4e3a\u6700\u4f73\u3002 \u8bf7\u6ce8\u610f\uff0c\u6709\u65f6\u60a8\u53ef\u80fd\u4f1a\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230 P@k \u548c AP@k \u7684\u4e0d\u540c\u5b9e\u73b0\u65b9\u5f0f\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5176\u4e2d\u4e00\u79cd\u5b9e\u73b0\u65b9\u5f0f\u3002 import numpy as np def apk ( actual , predicted , k = 10 ): # \u82e5\u9884\u6d4b\u6807\u7b7e\u957f\u5ea6\u5927\u4e8ek if len ( predicted ) > k : # \u53d6\u524dk\u4e2a\u6807\u7b7e predicted = predicted [: k ] score = 0.0 num_hits = 0.0 for i , p in enumerate ( predicted ): if p in actual and p not in predicted [: i ]: num_hits += 1.0 score += num_hits / ( i + 1.0 ) if not actual : return 0.0 return score / min ( len ( actual ), k ) \u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u662f AP@k \u7684\u53e6\u4e00\u4e2a\u7248\u672c\uff0c\u5176\u4e2d\u987a\u5e8f\u5f88\u91cd\u8981\uff0c\u6211\u4eec\u8981\u6743\u8861\u9884\u6d4b\u7ed3\u679c\u3002\u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u7684\u7ed3\u679c\u4e0e\u6211\u7684\u4ecb\u7ecd\u7565\u6709\u4e0d\u540c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6765\u770b\u770b \u591a\u6807\u7b7e\u5206\u7c7b\u7684\u5bf9\u6570\u635f\u5931 \u3002\u8fd9\u5f88\u5bb9\u6613\u3002\u60a8\u53ef\u4ee5\u5c06\u76ee\u6807\u8f6c\u6362\u4e3a\u4e8c\u5143\u5206\u7c7b\uff0c\u7136\u540e\u5bf9\u6bcf\u4e00\u5217\u4f7f\u7528\u5bf9\u6570\u635f\u5931\u3002\u6700\u540e\uff0c\u4f60\u53ef\u4ee5\u6c42\u51fa\u6bcf\u5217\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u8fd9\u4e5f\u88ab\u79f0\u4e3a\u5e73\u5747\u5217\u5bf9\u6570\u635f\u5931\u3002\u5f53\u7136\uff0c\u8fd8\u6709\u5176\u4ed6\u65b9\u6cd5\u53ef\u4ee5\u5b9e\u73b0\u8fd9\u4e00\u70b9\uff0c\u4f60\u5e94\u8be5\u5728\u9047\u5230\u65f6\u52a0\u4ee5\u63a2\u7d22\u3002 \u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u8bf4\u5df2\u7ecf\u638c\u63e1\u4e86\u6240\u6709\u4e8c\u5143\u5206\u7c7b\u3001\u591a\u7c7b\u5206\u7c7b\u548c\u591a\u6807\u7b7e\u5206\u7c7b\u6307\u6807\uff0c\u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u6307\u6807\u3002 \u56de\u5f52\u4e2d\u6700\u5e38\u89c1\u7684\u6307\u6807\u662f \u8bef\u5dee\uff08Error\uff09 \u3002\u8bef\u5dee\u5f88\u7b80\u5355\uff0c\u4e5f\u5f88\u5bb9\u6613\u7406\u89e3\u3002 \\[ Error = True\\ Value - Predicted\\ Value \\] \u7edd\u5bf9\u8bef\u5dee\uff08Absolute error\uff09 \u53ea\u662f\u4e0a\u8ff0\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\u3002 \\[ Absolute\\ Error = Abs(True\\ Value - Predicted\\ Value) \\] \u63a5\u4e0b\u6765\u6211\u4eec\u8ba8\u8bba \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\uff08MAE\uff09 \u3002\u5b83\u53ea\u662f\u6240\u6709\u7edd\u5bf9\u8bef\u5dee\u7684\u5e73\u5747\u503c\u3002 import numpy as np def mean_absolute_error ( y_true , y_pred ): #\u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u7d2f\u52a0\u7edd\u5bf9\u8bef\u5dee error += np . abs ( yt - yp ) # \u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee return error / len ( y_true ) \u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5e73\u65b9\u8bef\u5dee\u548c \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u3002 \\[ Squared\\ Error = (True Value - Predicted\\ Value)^2 \\] \u5747\u65b9\u8bef\u5dee\uff08MSE\uff09\u7684\u8ba1\u7b97\u65b9\u5f0f\u5982\u4e0b def mean_squared_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u7d2f\u52a0\u8bef\u5dee\u5e73\u65b9\u548c error += ( yt - yp ) ** 2 # \u8ba1\u7b97\u5747\u65b9\u8bef\u5dee return error / len ( y_true ) MSE \u548c RMSE\uff08\u5747\u65b9\u6839\u8bef\u5dee\uff09 \u662f\u8bc4\u4f30\u56de\u5f52\u6a21\u578b\u6700\u5e38\u7528\u7684\u6307\u6807\u3002 \\[ RMSE = SQRT(MSE) \\] \u540c\u4e00\u7c7b\u8bef\u5dee\u7684\u53e6\u4e00\u79cd\u7c7b\u578b\u662f \u5e73\u65b9\u5bf9\u6570\u8bef\u5dee \u3002\u6709\u4eba\u79f0\u5176\u4e3a SLE \uff0c\u5f53\u6211\u4eec\u53d6\u6240\u6709\u6837\u672c\u4e2d\u8fd9\u4e00\u8bef\u5dee\u7684\u5e73\u5747\u503c\u65f6\uff0c\u5b83\u88ab\u79f0\u4e3a MSLE\uff08\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee\uff09\uff0c\u5b9e\u73b0\u65b9\u6cd5\u5982\u4e0b\u3002 import numpy as np def mean_squared_log_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee error += ( np . log ( 1 + yt ) - np . log ( 1 + yp )) ** 2 # \u8ba1\u7b97\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee return error / len ( y_true ) \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \u53ea\u662f\u5176\u5e73\u65b9\u6839\u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a RMSLE \u3002 \u7136\u540e\u662f\u767e\u5206\u6bd4\u8bef\u5dee\uff1a \\[ Percentage\\ Error = (( True\\ Value \u2013 Predicted\\ Value ) / True\\ Value ) \\times 100 \\] \u540c\u6837\u53ef\u4ee5\u8f6c\u6362\u4e3a\u6240\u6709\u6837\u672c\u7684\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee\u3002 def mean_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u767e\u5206\u6bd4\u8bef\u5dee error += ( yt - yp ) / yt # \u8fd4\u56de\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u7edd\u5bf9\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\uff08\u4e5f\u662f\u66f4\u5e38\u89c1\u7684\u7248\u672c\uff09\u88ab\u79f0\u4e3a \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee\u6216 MAPE \u3002 import numpy as np def mean_abs_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee error += np . abs ( yt - yp ) / yt #\u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u56de\u5f52\u7684\u6700\u5927\u4f18\u70b9\u662f\uff0c\u53ea\u6709\u51e0\u4e2a\u6700\u5e38\u7528\u7684\u6307\u6807\uff0c\u51e0\u4e4e\u53ef\u4ee5\u5e94\u7528\u4e8e\u6240\u6709\u56de\u5f52\u95ee\u9898\u3002\u4e0e\u5206\u7c7b\u6307\u6807\u76f8\u6bd4\uff0c\u56de\u5f52\u6307\u6807\u66f4\u5bb9\u6613\u7406\u89e3\u3002 \u8ba9\u6211\u4eec\u6765\u8c08\u8c08\u53e6\u4e00\u4e2a\u56de\u5f52\u6307\u6807 \\(R^2\\) \uff08R \u65b9\uff09\uff0c\u4e5f\u79f0\u4e3a \u5224\u5b9a\u7cfb\u6570 \u3002 \u7b80\u5355\u5730\u8bf4\uff0cR \u65b9\u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u3002R \u65b9\u63a5\u8fd1 1.0 \u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u76f8\u5f53\u597d\uff0c\u800c\u63a5\u8fd1 0 \u5219\u8868\u793a\u6a21\u578b\u4e0d\u662f\u90a3\u4e48\u597d\u3002\u5f53\u6a21\u578b\u53ea\u662f\u505a\u51fa\u8352\u8c2c\u7684\u9884\u6d4b\u65f6\uff0cR \u65b9\u4e5f\u53ef\u80fd\u662f\u8d1f\u503c\u3002 R \u65b9\u7684\u8ba1\u7b97\u516c\u5f0f\u5982\u4e0b\u6240\u793a\uff0c\u4f46 Python \u7684\u5b9e\u73b0\u603b\u662f\u80fd\u8ba9\u4e00\u5207\u66f4\u52a0\u6e05\u6670\u3002 \\[ R^2 = \\frac{\\sum^{N}_{i=1}(y_{t_i}-y_{p_i})^2}{\\sum^{N}_{i=1}(y_{t_i} - y_{t_{mean}})} \\] import numpy as np def r2 ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u5747\u771f\u5b9e\u503c mean_true_value = np . mean ( y_true ) # \u521d\u59cb\u5316\u5e73\u65b9\u8bef\u5dee numerator = 0 denominator = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): numerator += ( yt - yp ) ** 2 denominator += ( yt - mean_true_value ) ** 2 ratio = numerator / denominator # \u8ba1\u7b97R\u65b9 return 1 \u2013 ratio \u8fd8\u6709\u66f4\u591a\u7684\u8bc4\u4ef7\u6307\u6807\uff0c\u8fd9\u4e2a\u6e05\u5355\u6c38\u8fdc\u4e5f\u5217\u4e0d\u5b8c\u3002\u6211\u53ef\u4ee5\u5199\u4e00\u672c\u4e66\uff0c\u53ea\u4ecb\u7ecd\u4e0d\u540c\u7684\u8bc4\u4ef7\u6307\u6807\u3002\u4e5f\u8bb8\u6211\u4f1a\u7684\u3002\u73b0\u5728\uff0c\u8fd9\u4e9b\u8bc4\u4f30\u6307\u6807\u51e0\u4e4e\u53ef\u4ee5\u6ee1\u8db3\u4f60\u60f3\u5c1d\u8bd5\u89e3\u51b3\u7684\u6240\u6709\u95ee\u9898\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ee5\u6700\u76f4\u63a5\u7684\u65b9\u5f0f\u5b9e\u73b0\u4e86\u8fd9\u4e9b\u6307\u6807\uff0c\u8fd9\u610f\u5473\u7740\u5b83\u4eec\u4e0d\u591f\u9ad8\u6548\u3002\u4f60\u53ef\u4ee5\u901a\u8fc7\u6b63\u786e\u4f7f\u7528 numpy \u4ee5\u975e\u5e38\u9ad8\u6548\u7684\u65b9\u5f0f\u5b9e\u73b0\u5176\u4e2d\u5927\u90e8\u5206\u6307\u6807\u3002\u4f8b\u5982\uff0c\u770b\u770b\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\u7684\u5b9e\u73b0\uff0c\u4e0d\u9700\u8981\u4efb\u4f55\u5faa\u73af\u3002 import numpy as np def mae_np ( y_true , y_pred ): return np . mean ( np . abs ( y_true - y_pred )) \u6211\u672c\u53ef\u4ee5\u7528\u8fd9\u79cd\u65b9\u6cd5\u5b9e\u73b0\u6240\u6709\u6307\u6807\uff0c\u4f46\u4e3a\u4e86\u5b66\u4e60\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b\u5e95\u5c42\u5b9e\u73b0\u3002\u4e00\u65e6\u4f60\u5b66\u4f1a\u4e86\u7eaf python \u7684\u5e95\u5c42\u5b9e\u73b0\uff0c\u5e76\u4e14\u4e0d\u4f7f\u7528\u5927\u91cf numpy\uff0c\u4f60\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5c06\u5176\u8f6c\u6362\u4e3a numpy\uff0c\u5e76\u4f7f\u5176\u53d8\u5f97\u66f4\u5feb\u3002 \u7136\u540e\u662f\u4e00\u4e9b\u9ad8\u7ea7\u5ea6\u91cf\u3002 \u5176\u4e2d\u4e00\u4e2a\u5e94\u7528\u76f8\u5f53\u5e7f\u6cdb\u7684\u6307\u6807\u662f \u4e8c\u6b21\u52a0\u6743\u5361\u5e15 \uff0c\u4e5f\u79f0\u4e3a QWK \u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a\u79d1\u6069\u5361\u5e15\u3002 QWK \u8861\u91cf\u4e24\u4e2a \"\u8bc4\u5206 \"\u4e4b\u95f4\u7684 \"\u4e00\u81f4\u6027\"\u3002\u8bc4\u5206\u53ef\u4ee5\u662f 0 \u5230 N \u4e4b\u95f4\u7684\u4efb\u4f55\u5b9e\u6570\uff0c\u9884\u6d4b\u4e5f\u5728\u540c\u4e00\u8303\u56f4\u5185\u3002\u4e00\u81f4\u6027\u53ef\u4ee5\u5b9a\u4e49\u4e3a\u8fd9\u4e9b\u8bc4\u7ea7\u4e4b\u95f4\u7684\u63a5\u8fd1\u7a0b\u5ea6\u3002\u56e0\u6b64\uff0c\u5b83\u9002\u7528\u4e8e\u6709 N \u4e2a\u4e0d\u540c\u7c7b\u522b\u7684\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u4e00\u81f4\u5ea6\u9ad8\uff0c\u5206\u6570\u5c31\u66f4\u63a5\u8fd1 1.0\u3002Cohen's kappa \u5728 scikit-learn \u4e2d\u6709\u5f88\u597d\u7684\u5b9e\u73b0\uff0c\u5173\u4e8e\u8be5\u6307\u6807\u7684\u8be6\u7ec6\u8ba8\u8bba\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3 ] In [ X ]: y_pred = [ 2 , 1 , 3 , 1 , 2 , 3 , 3 , 1 , 2 ] In [ X ]: metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) Out [ X ]: 0.33333333333333337 In [ X ]: metrics . accuracy_score ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 \u60a8\u53ef\u4ee5\u770b\u5230\uff0c\u5c3d\u7ba1\u51c6\u786e\u5ea6\u5f88\u9ad8\uff0c\u4f46 QWK \u5374\u5f88\u4f4e\u3002QWK \u5927\u4e8e 0.85 \u5373\u4e3a\u975e\u5e38\u597d\uff01 \u4e00\u4e2a\u91cd\u8981\u7684\u6307\u6807\u662f \u9a6c\u4fee\u76f8\u5173\u7cfb\u6570\uff08MCC\uff09 \u30021 \u4ee3\u8868\u5b8c\u7f8e\u9884\u6d4b\uff0c-1 \u4ee3\u8868\u4e0d\u5b8c\u7f8e\u9884\u6d4b\uff0c0 \u4ee3\u8868\u968f\u673a\u9884\u6d4b\u3002MCC \u7684\u8ba1\u7b97\u516c\u5f0f\u975e\u5e38\u7b80\u5355\u3002 \\[ MCC = \\frac{TP \\times TN - FP \\times FN}{\\sqrt{(TP + FP) \\times (FN + TN) \\times (FP + TN) \\times (TP + FN)}} \\] \u6211\u4eec\u770b\u5230\uff0cMCC \u8003\u8651\u4e86 TP\u3001FP\u3001TN \u548c FN\uff0c\u56e0\u6b64\u53ef\u7528\u4e8e\u5904\u7406\u7c7b\u504f\u659c\u7684\u95ee\u9898\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u7684\u65b9\u6cd5\u5728 python \u4e2d\u5feb\u901f\u5b9e\u73b0\u5b83\u3002 def mcc ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) numerator = ( tp * tn ) - ( fp * fn ) denominator = ( ( tp + fp ) * ( fn + tn ) * ( fp + tn ) * ( tp + fn ) ) denominator = denominator ** 0.5 return numerator / denominator \u8fd9\u4e9b\u6307\u6807\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5165\u95e8\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002 \u9700\u8981\u6ce8\u610f\u7684\u4e00\u70b9\u662f\uff0c\u5728\u8bc4\u4f30\u975e\u76d1\u7763\u65b9\u6cd5\uff08\u4f8b\u5982\u67d0\u79cd\u805a\u7c7b\uff09\u65f6\uff0c\u6700\u597d\u521b\u5efa\u6216\u624b\u52a8\u6807\u8bb0\u6d4b\u8bd5\u96c6\uff0c\u5e76\u5c06\u5176\u4e0e\u5efa\u6a21\u90e8\u5206\u7684\u6240\u6709\u5185\u5bb9\u5206\u5f00\u3002\u5b8c\u6210\u805a\u7c7b\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u4efb\u4f55\u4e00\u79cd\u76d1\u7763\u5b66\u4e60\u6307\u6807\u6765\u8bc4\u4f30\u6d4b\u8bd5\u96c6\u7684\u6027\u80fd\u4e86\u3002 \u4e00\u65e6\u6211\u4eec\u4e86\u89e3\u4e86\u7279\u5b9a\u95ee\u9898\u5e94\u8be5\u4f7f\u7528\u4ec0\u4e48\u6307\u6807\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f00\u59cb\u66f4\u6df1\u5165\u5730\u7814\u7a76\u6211\u4eec\u7684\u6a21\u578b\uff0c\u4ee5\u6c42\u6539\u8fdb\u3002","title":"\u8bc4\u4f30\u6307\u6807"},{"location":"%E8%B6%85%E5%8F%82%E6%95%B0%E4%BC%98%E5%8C%96/","text":"\u8d85\u53c2\u6570\u4f18\u5316 \u6709\u4e86\u4f18\u79c0\u7684\u6a21\u578b\uff0c\u5c31\u6709\u4e86\u4f18\u5316\u8d85\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u5f97\u5206\u6a21\u578b\u7684\u96be\u9898\u3002\u90a3\u4e48\uff0c\u4ec0\u4e48\u662f\u8d85\u53c2\u6570\u4f18\u5316\u5462\uff1f\u5047\u8bbe\u60a8\u7684\u673a\u5668\u5b66\u4e60\u9879\u76ee\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u6d41\u7a0b\u3002\u6709\u4e00\u4e2a\u6570\u636e\u96c6\uff0c\u4f60\u76f4\u63a5\u5e94\u7528\u4e00\u4e2a\u6a21\u578b\uff0c\u7136\u540e\u5f97\u5230\u7ed3\u679c\u3002\u6a21\u578b\u5728\u8fd9\u91cc\u7684\u53c2\u6570\u88ab\u79f0\u4e3a\u8d85\u53c2\u6570\uff0c\u5373\u63a7\u5236\u6a21\u578b\u8bad\u7ec3/\u62df\u5408\u8fc7\u7a0b\u7684\u53c2\u6570\u3002\u5982\u679c\u6211\u4eec\u7528 SGD \u8bad\u7ec3\u7ebf\u6027\u56de\u5f52\uff0c\u6a21\u578b\u7684\u53c2\u6570\u662f\u659c\u7387\u548c\u504f\u5dee\uff0c\u8d85\u53c2\u6570\u662f\u5b66\u4e60\u7387\u3002\u4f60\u4f1a\u53d1\u73b0\u6211\u5728\u672c\u7ae0\u548c\u672c\u4e66\u4e2d\u4ea4\u66ff\u4f7f\u7528\u8fd9\u4e9b\u672f\u8bed\u3002\u5047\u8bbe\u6a21\u578b\u4e2d\u6709\u4e09\u4e2a\u53c2\u6570 a\u3001b\u3001c\uff0c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u90fd\u53ef\u4ee5\u662f 1 \u5230 10 \u4e4b\u95f4\u7684\u6574\u6570\u3002\u8fd9\u4e9b\u53c2\u6570\u7684 \"\u6b63\u786e \"\u7ec4\u5408\u5c06\u4e3a\u60a8\u63d0\u4f9b\u6700\u4f73\u7ed3\u679c\u3002\u56e0\u6b64\uff0c\u8fd9\u5c31\u6709\u70b9\u50cf\u4e00\u4e2a\u88c5\u6709\u4e09\u62e8\u5bc6\u7801\u9501\u7684\u624b\u63d0\u7bb1\u3002\u4e0d\u8fc7\uff0c\u4e09\u62e8\u5bc6\u7801\u9501\u53ea\u6709\u4e00\u4e2a\u6b63\u786e\u7b54\u6848\u3002\u800c\u6a21\u578b\u6709\u5f88\u591a\u6b63\u786e\u7b54\u6848\u3002\u90a3\u4e48\uff0c\u5982\u4f55\u627e\u5230\u6700\u4f73\u53c2\u6570\u5462\uff1f\u4e00\u79cd\u65b9\u6cd5\u662f\u5bf9\u6240\u6709\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\uff0c\u770b\u54ea\u79cd\u7ec4\u5408\u80fd\u63d0\u9ad8\u6307\u6807\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u521d\u59cb\u5316\u6700\u4f73\u51c6\u786e\u5ea6 best_accuracy = 0 # \u521d\u59cb\u5316\u6700\u4f73\u53c2\u6570\u7684\u5b57\u5178 best_parameters = { \"a\" : 0 , \"b\" : 0 , \"c\" : 0 } # \u5faa\u73af\u904d\u5386 a \u7684\u53d6\u503c\u8303\u56f4 1~10 for a in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 b \u7684\u53d6\u503c\u8303\u56f4 1~10 for b in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 c \u7684\u53d6\u503c\u8303\u56f4 1~10 for c in range ( 1 , 11 ): # \u521b\u5efa\u6a21\u578b\uff0c\u4f7f\u7528 a\u3001b\u3001c \u53c2\u6570 model = MODEL ( a , b , c ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408\u6a21\u578b model . fit ( training_data ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u9a8c\u8bc1\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . predict ( validation_data ) # \u8ba1\u7b97\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 accuracy = metrics . accuracy_score ( targets , preds ) # \u5982\u679c\u5f53\u524d\u51c6\u786e\u5ea6\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u5219\u66f4\u65b0\u6700\u4f73\u51c6\u786e\u5ea6\u548c\u6700\u4f73\u53c2\u6570 if accuracy > best_accuracy : best_accuracy = accuracy best_parameters [ \"a\" ] = a best_parameters [ \"b\" ] = b best_parameters [ \"c\" ] = c \u5728\u4e0a\u8ff0\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4ece 1 \u5230 10 \u5bf9\u6240\u6709\u53c2\u6570\u8fdb\u884c\u4e86\u62df\u5408\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u603b\u5171\u8981\u5bf9\u6a21\u578b\u8fdb\u884c 1000 \u6b21\uff0810 x 10 x 10\uff09\u62df\u5408\u3002\u8fd9\u53ef\u80fd\u4f1a\u5f88\u6602\u8d35\uff0c\u56e0\u4e3a\u6a21\u578b\u7684\u8bad\u7ec3\u9700\u8981\u5f88\u957f\u65f6\u95f4\u3002\u4e0d\u8fc7\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\u5e94\u8be5\u6ca1\u95ee\u9898\uff0c\u4f46\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\uff0c\u5e76\u4e0d\u662f\u53ea\u6709\u4e09\u4e2a\u53c2\u6570\uff0c\u6bcf\u4e2a\u53c2\u6570\u4e5f\u4e0d\u662f\u53ea\u6709\u5341\u4e2a\u503c\u3002 \u5927\u591a\u6570\u6a21\u578b\u53c2\u6570\u90fd\u662f\u5b9e\u6570\uff0c\u4e0d\u540c\u53c2\u6570\u7684\u7ec4\u5408\u53ef\u4ee5\u662f\u65e0\u9650\u7684\u3002 \u8ba9\u6211\u4eec\u770b\u770b scikit-learn \u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002 RandomForestClassifier ( n_estimators = 100 , criterion = 'gini' , max_depth = None , min_samples_split = 2 , min_samples_leaf = 1 , min_weight_fraction_leaf = 0.0 , max_features = 'auto' , max_leaf_nodes = None , min_impurity_decrease = 0.0 , min_impurity_split = None , bootstrap = True , oob_score = False , n_jobs = None , random_state = None , verbose = 0 , warm_start = False , class_weight = None , ccp_alpha = 0.0 , max_samples = None , ) \u6709 19 \u4e2a\u53c2\u6570\uff0c\u800c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u7684\u6240\u6709\u7ec4\u5408\uff0c\u4ee5\u53ca\u5b83\u4eec\u53ef\u4ee5\u627f\u62c5\u7684\u6240\u6709\u503c\uff0c\u90fd\u5c06\u662f\u65e0\u7a77\u65e0\u5c3d\u7684\u3002\u901a\u5e38\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u6ca1\u6709\u8db3\u591f\u7684\u8d44\u6e90\u548c\u65f6\u95f4\u6765\u505a\u8fd9\u4ef6\u4e8b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6307\u5b9a\u4e86\u4e00\u4e2a\u53c2\u6570\u7f51\u683c\u3002\u5728\u8fd9\u4e2a\u7f51\u683c\u4e0a\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408\u7684\u641c\u7d22\u79f0\u4e3a\u7f51\u683c\u641c\u7d22\u3002\u6211\u4eec\u53ef\u4ee5\u8bf4\uff0cn_estimators \u53ef\u4ee5\u662f 100\u3001200\u3001250\u3001300\u3001400\u3001500\uff1bmax_depth \u53ef\u4ee5\u662f 1\u30012\u30015\u30017\u300111\u300115\uff1bcriterion \u53ef\u4ee5\u662f gini \u6216 entropy\u3002\u8fd9\u4e9b\u53c2\u6570\u770b\u8d77\u6765\u5e76\u4e0d\u591a\uff0c\u4f46\u5982\u679c\u6570\u636e\u96c6\u8fc7\u5927\uff0c\u8ba1\u7b97\u8d77\u6765\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u53ef\u4ee5\u50cf\u4e4b\u524d\u4e00\u6837\u521b\u5efa\u4e09\u4e2a for \u5faa\u73af\uff0c\u5e76\u5728\u9a8c\u8bc1\u96c6\u4e0a\u8ba1\u7b97\u5f97\u5206\uff0c\u8fd9\u6837\u5c31\u80fd\u5b9e\u73b0\u7f51\u683c\u641c\u7d22\u3002\u8fd8\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u8981\u8fdb\u884c k \u6298\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5219\u9700\u8981\u66f4\u591a\u7684\u5faa\u73af\uff0c\u8fd9\u610f\u5473\u7740\u9700\u8981\u66f4\u591a\u7684\u65f6\u95f4\u6765\u627e\u5230\u5b8c\u7f8e\u7684\u53c2\u6570\u3002\u56e0\u6b64\uff0c\u7f51\u683c\u641c\u7d22\u5e76\u4e0d\u6d41\u884c\u3002\u8ba9\u6211\u4eec\u4ee5\u6839\u636e \u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4 \u6570\u636e\u96c6\u4e3a\u4f8b\uff0c\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 \u56fe 1\uff1a\u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4\u6570\u636e\u96c6\u5c55\u793a \u8bad\u7ec3\u96c6\u4e2d\u53ea\u6709 2000 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u53ef\u4ee5\u8f7b\u677e\u5730\u4f7f\u7528\u5206\u5c42 kfold \u548c\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5177\u6709\u4e0a\u8ff0\u53c2\u6570\u8303\u56f4\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\uff0c\u5e76\u5728\u4e0b\u9762\u7684\u793a\u4f8b\u4e2d\u4e86\u89e3\u5982\u4f55\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u3002 # rf_grid_search.py import numpy as np import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u5220\u9664 price_range \u5217 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u53d6\u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\uff0c\u4f7f\u7528\u6240\u6709\u53ef\u7528\u7684 CPU \u6838\u5fc3\u8fdb\u884c\u8bad\u7ec3 classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid = { \"n_estimators\" : [ 100 , 200 , 250 , 300 , 400 , 500 ], \"max_depth\" : [ 1 , 2 , 5 , 7 , 11 , 15 ], \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22 model = model_selection . GridSearchCV ( estimator = classifier , param_grid = param_grid , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( f \"Best score: { model . best_score_ } \" ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u8fd9\u91cc\u6253\u5370\u4e86\u5f88\u591a\u5185\u5bb9\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u6700\u540e\u51e0\u884c\u3002 [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.895 , total = 1.0 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.890 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.910 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.880 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.870 , total = 1.1 s [ Parallel ( n_jobs = 1 )]: Done 360 out of 360 | elapsed : 3.7 min finished Best score : 0.889 Best parameters set : criterion : 'entropy' max_depth : 15 n_estimators : 500 \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c5\u6298\u4ea4\u53c9\u68c0\u9a8c\u6700\u4f73\u5f97\u5206\u662f 0.889\uff0c\u6211\u4eec\u7684\u7f51\u683c\u641c\u7d22\u5f97\u5230\u4e86\u6700\u4f73\u53c2\u6570\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u7684\u4e0b\u4e00\u4e2a\u6700\u4f73\u65b9\u6cd5\u662f \u968f\u673a\u641c\u7d22 \u3002\u5728\u968f\u673a\u641c\u7d22\u4e2d\uff0c\u6211\u4eec\u968f\u673a\u9009\u62e9\u4e00\u4e2a\u53c2\u6570\u7ec4\u5408\uff0c\u7136\u540e\u8ba1\u7b97\u4ea4\u53c9\u9a8c\u8bc1\u5f97\u5206\u3002\u8fd9\u91cc\u6d88\u8017\u7684\u65f6\u95f4\u6bd4\u7f51\u683c\u641c\u7d22\u5c11\uff0c\u56e0\u4e3a\u6211\u4eec\u4e0d\u5bf9\u6240\u6709\u4e0d\u540c\u7684\u53c2\u6570\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\u3002\u6211\u4eec\u9009\u62e9\u8981\u5bf9\u6a21\u578b\u8fdb\u884c\u591a\u5c11\u6b21\u8bc4\u4f30\uff0c\u8fd9\u5c31\u51b3\u5b9a\u4e86\u641c\u7d22\u6240\u9700\u7684\u65f6\u95f4\u3002\u4ee3\u7801\u4e0e\u4e0a\u9762\u7684\u5dee\u522b\u4e0d\u5927\u3002\u9664 GridSearchCV \u5916\uff0c\u6211\u4eec\u4f7f\u7528 RandomizedSearchCV\u3002 if __name__ == \"__main__\" : classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u66f4\u6539\u641c\u7d22\u7a7a\u95f4 param_grid = { \"n_estimators\" : np . arange ( 100 , 1500 , 100 ), \"max_depth\" : np . arange ( 1 , 31 ), \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u968f\u673a\u53c2\u6570\u641c\u7d22 model = model_selection . RandomizedSearchCV ( estimator = classifier , param_distributions = param_grid , n_iter = 20 , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) print ( f \"Best score: { model . best_score_ } \" ) print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u6211\u4eec\u66f4\u6539\u4e86\u968f\u673a\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c\uff0c\u7ed3\u679c\u4f3c\u4e4e\u6709\u4e86\u4e9b\u8bb8\u6539\u8fdb\u3002 Best score : 0.8905 Best parameters set : criterion : entropy max_depth : 25 n_estimators : 300 \u5982\u679c\u8fed\u4ee3\u6b21\u6570\u8f83\u5c11\uff0c\u968f\u673a\u641c\u7d22\u6bd4\u7f51\u683c\u641c\u7d22\u66f4\u5feb\u3002\u4f7f\u7528\u8fd9\u4e24\u79cd\u65b9\u6cd5\uff0c\u4f60\u53ef\u4ee5\u4e3a\u5404\u79cd\u6a21\u578b\u627e\u5230\u6700\u4f18\u53c2\u6570\uff0c\u53ea\u8981\u5b83\u4eec\u6709\u62df\u5408\u548c\u9884\u6d4b\u529f\u80fd\uff0c\u8fd9\u4e5f\u662f scikit-learn \u7684\u6807\u51c6\u3002\u6709\u65f6\uff0c\u4f60\u53ef\u80fd\u60f3\u4f7f\u7528\u7ba1\u9053\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u7531\u4e24\u5217\u6587\u672c\u7ec4\u6210\uff0c\u4f60\u9700\u8981\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u7c7b\u522b\u3002\u8ba9\u6211\u4eec\u5047\u8bbe\u4f60\u9009\u62e9\u7684\u7ba1\u9053\u662f\u9996\u5148\u4ee5\u534a\u76d1\u7763\u7684\u65b9\u5f0f\u5e94\u7528 tf-idf\uff0c\u7136\u540e\u4f7f\u7528 SVD \u548c SVM \u5206\u7c7b\u5668\u3002\u73b0\u5728\u7684\u95ee\u9898\u662f\uff0c\u6211\u4eec\u5fc5\u987b\u9009\u62e9 SVD \u7684\u6210\u5206\uff0c\u8fd8\u9700\u8981\u8c03\u6574 SVM \u7684\u53c2\u6570\u3002\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u5c55\u793a\u4e86\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 import numpy as np import pandas as pd from sklearn import metrics from sklearn import model_selection from sklearn import pipeline from sklearn.decomposition import TruncatedSVD from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC # \u8ba1\u7b97\u52a0\u6743\u4e8c\u6b21 Kappa \u5206\u6570 def quadratic_weighted_kappa ( y_true , y_pred ): return metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) if __name__ == '__main__' : # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( '../input/train.csv' ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u63d0\u53d6 id \u5217\u7684\u503c\uff0c\u5e76\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u7c7b\u578b\uff0c\u5b58\u50a8\u5728\u53d8\u91cf idx \u4e2d idx = test . id . values . astype ( int ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 train = train . drop ( 'id' , axis = 1 ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 test = test . drop ( 'id' , axis = 1 ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf 'relevance' \uff0c\u5b58\u50a8\u5728\u53d8\u91cf y \u4e2d y = train . relevance . values # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 traindata \u4e2d traindata = list ( train . apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 testdata \u4e2d testdata = list ( test . apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u521b\u5efa\u4e00\u4e2a TfidfVectorizer \u5bf9\u8c61 tfv\uff0c\u7528\u4e8e\u5c06\u6587\u672c\u6570\u636e\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv = TfidfVectorizer ( min_df = 3 , max_features = None , strip_accents = 'unicode' , analyzer = 'word' , token_pattern = r '\\w{1,}' , ngram_range = ( 1 , 3 ), use_idf = 1 , smooth_idf = 1 , sublinear_tf = 1 , stop_words = 'english' ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408 TfidfVectorizer\uff0c\u5c06\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv . fit ( traindata ) # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X X = tfv . transform ( traindata ) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X_test X_test = tfv . transform ( testdata ) # \u521b\u5efa TruncatedSVD \u5bf9\u8c61 svd\uff0c\u7528\u4e8e\u8fdb\u884c\u5947\u5f02\u503c\u5206\u89e3 svd = TruncatedSVD () # \u521b\u5efa StandardScaler \u5bf9\u8c61 scl\uff0c\u7528\u4e8e\u8fdb\u884c\u7279\u5f81\u7f29\u653e scl = StandardScaler () # \u521b\u5efa\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668\u5bf9\u8c61 svm_model svm_model = SVC () # \u521b\u5efa\u673a\u5668\u5b66\u4e60\u7ba1\u9053 clf\uff0c\u5305\u542b\u5947\u5f02\u503c\u5206\u89e3\u3001\u7279\u5f81\u7f29\u653e\u548c\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668 clf = pipeline . Pipeline ( [ ( 'svd' , svd ), ( 'scl' , scl ), ( 'svm' , svm_model ) ] ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid param_grid = { 'svd__n_components' : [ 200 , 300 ], 'svm__C' : [ 10 , 12 ] } # \u521b\u5efa\u81ea\u5b9a\u4e49\u7684\u8bc4\u5206\u51fd\u6570 kappa_scorer\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd kappa_scorer = metrics . make_scorer ( quadratic_weighted_kappa , greater_is_better = True ) # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model = model_selection . GridSearchCV ( estimator = clf , param_grid = param_grid , scoring = kappa_scorer , verbose = 10 , n_jobs =- 1 , refit = True , cv = 5 ) # \u4f7f\u7528 GridSearchCV \u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( \"Best score: %0.3f \" % model . best_score_ ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( \" \\t %s : %r \" % ( param_name , best_parameters [ param_name ])) # \u83b7\u53d6\u6700\u4f73\u6a21\u578b best_model = model . best_estimator_ best_model . fit ( X , y ) # \u4f7f\u7528\u6700\u4f73\u6a21\u578b\u8fdb\u884c\u9884\u6d4b preds = best_model . predict ( ... ) \u8fd9\u91cc\u663e\u793a\u7684\u7ba1\u9053\u5305\u62ec SVD\uff08\u5947\u5f02\u503c\u5206\u89e3\uff09\u3001\u6807\u51c6\u7f29\u653e\u548c SVM\uff08\u652f\u6301\u5411\u91cf\u673a\uff09\u6a21\u578b\u3002\u8bf7\u6ce8\u610f\uff0c\u7531\u4e8e\u6ca1\u6709\u8bad\u7ec3\u6570\u636e\uff0c\u60a8\u65e0\u6cd5\u6309\u539f\u6837\u8fd0\u884c\u4e0a\u8ff0\u4ee3\u7801\u3002\u5f53\u6211\u4eec\u8fdb\u5165\u9ad8\u7ea7\u8d85\u53c2\u6570\u4f18\u5316\u6280\u672f\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4e0d\u540c\u7c7b\u578b\u7684 \u6700\u5c0f\u5316\u7b97\u6cd5 \u6765\u7814\u7a76\u51fd\u6570\u7684\u6700\u5c0f\u5316\u3002\u8fd9\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u591a\u79cd\u6700\u5c0f\u5316\u51fd\u6570\u6765\u5b9e\u73b0\uff0c\u5982\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u3001\u5185\u5c14\u5fb7-\u6885\u5fb7\u4f18\u5316\u7b97\u6cd5\u3001\u4f7f\u7528\u8d1d\u53f6\u65af\u6280\u672f\u548c\u9ad8\u65af\u8fc7\u7a0b\u5bfb\u627e\u6700\u4f18\u53c2\u6570\u6216\u4f7f\u7528\u9057\u4f20\u7b97\u6cd5\u3002\u6211\u5c06\u5728 \"\u96c6\u5408\u4e0e\u5806\u53e0\uff08ensembling and stacking\uff09 \"\u4e00\u7ae0\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u548c Nelder-Mead \u7b97\u6cd5\u7684\u5e94\u7528\u3002\u9996\u5148\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u9ad8\u65af\u8fc7\u7a0b\u5982\u4f55\u7528\u4e8e\u8d85\u53c2\u6570\u4f18\u5316\u3002\u8fd9\u7c7b\u7b97\u6cd5\u9700\u8981\u4e00\u4e2a\u53ef\u4ee5\u4f18\u5316\u7684\u51fd\u6570\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u90fd\u662f\u6700\u5c0f\u5316\u8fd9\u4e2a\u51fd\u6570\uff0c\u5c31\u50cf\u6211\u4eec\u6700\u5c0f\u5316\u635f\u5931\u4e00\u6837\u3002 \u56e0\u6b64\uff0c\u6bd4\u65b9\u8bf4\uff0c\u4f60\u60f3\u627e\u5230\u6700\u4f73\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u663e\u7136\uff0c\u51c6\u786e\u5ea6\u8d8a\u9ad8\u8d8a\u597d\u3002\u73b0\u5728\uff0c\u6211\u4eec\u4e0d\u80fd\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\uff0c\u4f46\u6211\u4eec\u53ef\u4ee5\u5c06\u7cbe\u786e\u5ea6\u4e58\u4ee5-1\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u662f\u5728\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\u7684\u8d1f\u503c\uff0c\u4f46\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u662f\u5728\u6700\u5927\u5316\u7cbe\u786e\u5ea6\u3002 \u5728\u9ad8\u65af\u8fc7\u7a0b\u4e2d\u4f7f\u7528\u8d1d\u53f6\u65af\u4f18\u5316\uff0c\u53ef\u4ee5\u4f7f\u7528 scikit-optimize (skopt) \u5e93\u4e2d\u7684 gp_minimize \u51fd\u6570\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4f7f\u7528\u8be5\u51fd\u6570\u8c03\u6574\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u53c2\u6570\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from skopt import gp_minimize from skopt import space def optimize ( params , param_names , x , y ): # \u5c06\u53c2\u6570\u540d\u79f0\u548c\u5bf9\u5e94\u7684\u503c\u6253\u5305\u6210\u5b57\u5178 params = dict ( zip ( param_names , params )) # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u6a21\u578b\uff0c\u4f7f\u7528\u4f20\u5165\u7684\u53c2\u6570\u914d\u7f6e model = ensemble . RandomForestClassifier ( ** params ) # \u521b\u5efa StratifiedKFold \u4ea4\u53c9\u9a8c\u8bc1\u5bf9\u8c61\uff0c\u5c06\u6570\u636e\u5206\u4e3a 5 \u6298 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u6bcf\u4e2a\u6298\u53e0\u7684\u51c6\u786e\u5ea6\u7684\u5217\u8868 accuracies = [] # \u5faa\u73af\u904d\u5386\u6bcf\u4e2a\u6298\u53e0\u7684\u8bad\u7ec3\u548c\u6d4b\u8bd5\u6570\u636e for idx in kf . split ( X = x , y = y ): train_idx , test_idx = idx [ 0 ], idx [ 1 ] xtrain = x [ train_idx ] ytrain = y [ train_idx ] xtest = x [ test_idx ] ytest = y [ test_idx ] # \u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u62df\u5408\u6a21\u578b model . fit ( xtrain , ytrain ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u6d4b\u8bd5\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . predict ( xtest ) # \u8ba1\u7b97\u6298\u53e0\u7684\u51c6\u786e\u5ea6 fold_accuracy = metrics . accuracy_score ( ytest , preds ) accuracies . append ( fold_accuracy ) # \u8fd4\u56de\u5e73\u5747\u51c6\u786e\u5ea6\u7684\u8d1f\u6570\uff08\u56e0\u4e3a skopt \u4f7f\u7528\u8d1f\u6570\u6765\u6700\u5c0f\u5316\u76ee\u6807\u51fd\u6570\uff09 return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u53d6\u7279\u5f81\u77e9\u9635 X\uff08\u53bb\u6389\"price_range\"\u5217\uff09 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u5b9a\u4e49\u8d85\u53c2\u6570\u641c\u7d22\u7a7a\u95f4 param_space param_space = [ space . Integer ( 3 , 15 , name = \"max_depth\" ), space . Integer ( 100 , 1500 , name = \"n_estimators\" ), space . Categorical ([ \"gini\" , \"entropy\" ], name = \"criterion\" ), space . Real ( 0.01 , 1 , prior = \"uniform\" , name = \"max_features\" ) ] # \u5b9a\u4e49\u8d85\u53c2\u6570\u7684\u540d\u79f0\u5217\u8868 param_names param_names = [ \"max_depth\" , \"n_estimators\" , \"criterion\" , \"max_features\" ] # \u521b\u5efa\u51fd\u6570 optimization_function\uff0c\u7528\u4e8e\u4f20\u9012\u7ed9 gp_minimize optimization_function = partial ( optimize , param_names = param_names , x = X , y = y ) # \u4f7f\u7528 Bayesian Optimization\uff08\u57fa\u4e8e\u8d1d\u53f6\u65af\u4f18\u5316\uff09\u6765\u641c\u7d22\u6700\u4f73\u8d85\u53c2\u6570 result = gp_minimize ( optimization_function , dimensions = param_space , n_calls = 15 , n_random_starts = 10 , verbose = 10 ) # \u83b7\u53d6\u6700\u4f73\u8d85\u53c2\u6570\u7684\u5b57\u5178 best_params = dict ( zip ( param_names , result . x ) ) # \u6253\u5370\u51fa\u627e\u5230\u7684\u6700\u4f73\u8d85\u53c2\u6570 print ( best_params ) \u8fd9\u540c\u6837\u4f1a\u4ea7\u751f\u5927\u91cf\u8f93\u51fa\uff0c\u6700\u540e\u4e00\u90e8\u5206\u5982\u4e0b\u6240\u793a\u3002 Iteration No : 14 started . Searching for the next optimal point . Iteration No : 14 ended . Search finished for the next optimal point . Time taken : 4.7793 Function value obtained : - 0.9075 Current minimum : - 0.9075 Iteration No : 15 started . Searching for the next optimal point . Iteration No : 15 ended . Search finished for the next optimal point . Time taken : 49.4186 Function value obtained : - 0.9075 Current minimum : - 0.9075 { 'max_depth' : 12 , 'n_estimators' : 100 , 'criterion' : 'entropy' , 'max_features' : 1.0 } \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6210\u529f\u7a81\u7834\u4e86 0.90 \u7684\u51c6\u786e\u7387\u3002\u8fd9\u771f\u662f\u592a\u795e\u5947\u4e86\uff01 \u6211\u4eec\u8fd8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u6bb5\u67e5\u770b\uff08\u7ed8\u5236\uff09\u6211\u4eec\u662f\u5982\u4f55\u5b9e\u73b0\u6536\u655b\u7684\u3002 from skopt.plots import plot_convergence plot_convergence ( result ) \u6536\u655b\u56fe\u5982\u56fe 2 \u6240\u793a\u3002 \u56fe 2\uff1a\u968f\u673a\u68ee\u6797\u53c2\u6570\u4f18\u5316\u7684\u6536\u655b\u56fe Scikit- optimize \u5c31\u662f\u8fd9\u6837\u4e00\u4e2a\u5e93\u3002 hyperopt \u4f7f\u7528\u6811\u72b6\u7ed3\u6784\u8d1d\u53f6\u65af\u4f30\u8ba1\u5668\uff08TPE\uff09\u6765\u627e\u5230\u6700\u4f18\u53c2\u6570\u3002\u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u7247\u6bb5\uff0c\u6211\u5728\u4f7f\u7528 hyperopt \u65f6\u5bf9\u4e4b\u524d\u7684\u4ee3\u7801\u505a\u4e86\u6700\u5c0f\u7684\u6539\u52a8\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from hyperopt import hp , fmin , tpe , Trials from hyperopt.pyll.base import scope def optimize ( params , x , y ): model = ensemble . RandomForestClassifier ( ** params ) kf = model_selection . StratifiedKFold ( n_splits = 5 ) ... return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/mobile_train.csv\" ) X = df . drop ( \"price_range\" , axis = 1 ) . values y = df . price_range . values # \u5b9a\u4e49\u641c\u7d22\u7a7a\u95f4\uff08\u6574\u578b\u3001\u6d6e\u70b9\u6570\u578b\u3001\u9009\u62e9\u578b\uff09 param_space = { \"max_depth\" : scope . int ( hp . quniform ( \"max_depth\" , 1 , 15 , 1 )), \"n_estimators\" : scope . int ( hp . quniform ( \"n_estimators\" , 100 , 1500 , 1 ) ), \"criterion\" : hp . choice ( \"criterion\" , [ \"gini\" , \"entropy\" ]), \"max_features\" : hp . uniform ( \"max_features\" , 0 , 1 ) } # \u5305\u88c5\u51fd\u6570 optimization_function = partial ( optimize , x = X , y = y ) # \u5f00\u59cb\u8bad\u7ec3 trials = Trials () # \u6700\u5c0f\u5316\u76ee\u6807\u503c hopt = fmin ( fn = optimization_function , space = param_space , algo = tpe . suggest , max_evals = 15 , trials = trials ) #\u6253\u5370\u6700\u4f73\u53c2\u6570 print ( hopt ) \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u4e0e\u4e4b\u524d\u7684\u4ee3\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u4f60\u5fc5\u987b\u4ee5\u4e0d\u540c\u7684\u683c\u5f0f\u5b9a\u4e49\u53c2\u6570\u7a7a\u95f4\uff0c\u8fd8\u9700\u8981\u6539\u53d8\u5b9e\u9645\u4f18\u5316\u90e8\u5206\uff0c\u7528 hyperopt \u4ee3\u66ff gp_minimize\u3002\u7ed3\u679c\u76f8\u5f53\u4e0d\u9519\uff01 \u276f python rf_hyperopt . py 100 %| \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 15 / 15 [ 04 : 38 < 00 : 00 , 18.57 s / trial , best loss : - 0.9095000000000001 ] { 'criterion' : 1 , 'max_depth' : 11.0 , 'max_features' : 0.821163568049807 , 'n_estimators' : 806.0 } \u6211\u4eec\u5f97\u5230\u4e86\u6bd4\u4ee5\u524d\u66f4\u597d\u7684\u51c6\u786e\u5ea6\u548c\u4e00\u7ec4\u53ef\u4ee5\u4f7f\u7528\u7684\u53c2\u6570\u3002\u8bf7\u6ce8\u610f\uff0c\u6700\u7ec8\u7ed3\u679c\u4e2d\u7684\u6807\u51c6\u662f 1\u3002\u8fd9\u610f\u5473\u7740\u9009\u62e9\u4e86 1\uff0c\u5373\u71b5\u3002 \u4e0a\u8ff0\u8c03\u6574\u8d85\u53c2\u6570\u7684\u65b9\u6cd5\u662f\u6700\u5e38\u89c1\u7684\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u6a21\u578b\uff1a\u7ebf\u6027\u56de\u5f52\u3001\u903b\u8f91\u56de\u5f52\u3001\u57fa\u4e8e\u6811\u7684\u65b9\u6cd5\u3001\u68af\u5ea6\u63d0\u5347\u6a21\u578b\uff08\u5982 xgboost\u3001lightgbm\uff09\uff0c\u751a\u81f3\u795e\u7ecf\u7f51\u7edc\uff01 \u867d\u7136\u8fd9\u4e9b\u65b9\u6cd5\u5df2\u7ecf\u5b58\u5728\uff0c\u4f46\u5b66\u4e60\u65f6\u5fc5\u987b\u4ece\u624b\u52a8\u8c03\u6574\u8d85\u53c2\u6570\u5f00\u59cb\uff0c\u5373\u624b\u5de5\u8c03\u6574\u3002\u624b\u52a8\u8c03\u6574\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5b66\u4e60\u57fa\u7840\u77e5\u8bc6\uff0c\u4f8b\u5982\uff0c\u5728\u68af\u5ea6\u63d0\u5347\u4e2d\uff0c\u5f53\u4f60\u589e\u52a0\u6df1\u5ea6\u65f6\uff0c\u4f60\u5e94\u8be5\u964d\u4f4e\u5b66\u4e60\u7387\u3002\u5982\u679c\u4f7f\u7528\u81ea\u52a8\u5de5\u5177\uff0c\u5c31\u65e0\u6cd5\u5b66\u4e60\u5230\u8fd9\u4e00\u70b9\u3002\u8bf7\u53c2\u8003\u4e0b\u8868\uff0c\u4e86\u89e3\u5e94\u5982\u4f55\u8c03\u6574\u3002RS* \u8868\u793a\u968f\u673a\u641c\u7d22\u5e94\u8be5\u66f4\u597d\u3002 \u4e00\u65e6\u4f60\u80fd\u66f4\u597d\u5730\u624b\u52a8\u8c03\u6574\u53c2\u6570\uff0c\u4f60\u751a\u81f3\u53ef\u80fd\u4e0d\u9700\u8981\u4efb\u4f55\u81ea\u52a8\u8d85\u53c2\u6570\u8c03\u6574\u3002\u521b\u5efa\u5927\u578b\u6a21\u578b\u6216\u5f15\u5165\u5927\u91cf\u7279\u5f81\u65f6\uff0c\u4e5f\u5bb9\u6613\u9020\u6210\u8bad\u7ec3\u6570\u636e\u7684\u8fc7\u5ea6\u62df\u5408\u3002\u4e3a\u907f\u514d\u8fc7\u5ea6\u62df\u5408\uff0c\u9700\u8981\u5728\u8bad\u7ec3\u6570\u636e\u7279\u5f81\u4e2d\u5f15\u5165\u566a\u58f0\u6216\u5bf9\u4ee3\u4ef7\u51fd\u6570\u8fdb\u884c\u60e9\u7f5a\u3002\u8fd9\u79cd\u60e9\u7f5a\u79f0\u4e3a \u6b63\u5219\u5316 \uff0c\u6709\u52a9\u4e8e\u6cdb\u5316\u6a21\u578b\u3002\u5728\u7ebf\u6027\u6a21\u578b\u4e2d\uff0c\u6700\u5e38\u89c1\u7684\u6b63\u5219\u5316\u7c7b\u578b\u662f L1 \u548c L2\u3002L1 \u4e5f\u79f0\u4e3a Lasso \u56de\u5f52\uff0cL2 \u79f0\u4e3a Ridge \u56de\u5f52\u3002\u8bf4\u5230\u795e\u7ecf\u7f51\u7edc\uff0c\u6211\u4eec\u4f1a\u4f7f\u7528dropout\u3001\u6dfb\u52a0\u589e\u5f3a\u3001\u566a\u58f0\u7b49\u65b9\u6cd5\u5bf9\u6a21\u578b\u8fdb\u884c\u6b63\u5219\u5316\u3002\u5229\u7528\u8d85\u53c2\u6570\u4f18\u5316\uff0c\u8fd8\u53ef\u4ee5\u627e\u5230\u6b63\u786e\u7684\u60e9\u7f5a\u65b9\u6cd5\u3002 Model Optimize Range of values Linear Regression - fit_intercept - normalize - True/False - True/False Ridge - alpha - fit_intercept - normalize - 0.01, 0.1, 1.0, 10, 100 - True/False - True/False k-neighbors - n_neighbors - p - 2, 4, 8, 16, ... - 2, 3, ... SVM - C - gamma - class_weight - 0.001, 0.01, ...,10, 100, 1000 - 'auto', RS* - 'balanced', None Logistic Regression - Penalyt - C - L1 or L2 - 0.001, 0.01, ..., 10, ..., 100 Lasso - Alpha - Normalize - 0.1, 1.0, 10 - True/False Random Forest - n_estimators - max_depth - min_samples_split - min_samples_leaf - max features - 120, 300, 500, 800, 1200 - 5, 8, 15, 25, 30, None - 1, 2, 5, 10, 15, 100 - log2, sqrt, None XGBoost - eta - gamma - max_depth - min_child_weight - subsample - colsample_bytree - lambda - alpha - 0.01, 0.015, 0.025, 0.05, 0.1 - 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0 - 3, 5, 7, 9, 12, 15, 17, 25 - 1, 3, 5, 7 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.01, 0.1, 1.0, RS - 0, 0.1, 0.5, 1.0, RS","title":"\u8d85\u53c2\u6570\u4f18\u5316"},{"location":"%E8%B6%85%E5%8F%82%E6%95%B0%E4%BC%98%E5%8C%96/#_1","text":"\u6709\u4e86\u4f18\u79c0\u7684\u6a21\u578b\uff0c\u5c31\u6709\u4e86\u4f18\u5316\u8d85\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u5f97\u5206\u6a21\u578b\u7684\u96be\u9898\u3002\u90a3\u4e48\uff0c\u4ec0\u4e48\u662f\u8d85\u53c2\u6570\u4f18\u5316\u5462\uff1f\u5047\u8bbe\u60a8\u7684\u673a\u5668\u5b66\u4e60\u9879\u76ee\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u6d41\u7a0b\u3002\u6709\u4e00\u4e2a\u6570\u636e\u96c6\uff0c\u4f60\u76f4\u63a5\u5e94\u7528\u4e00\u4e2a\u6a21\u578b\uff0c\u7136\u540e\u5f97\u5230\u7ed3\u679c\u3002\u6a21\u578b\u5728\u8fd9\u91cc\u7684\u53c2\u6570\u88ab\u79f0\u4e3a\u8d85\u53c2\u6570\uff0c\u5373\u63a7\u5236\u6a21\u578b\u8bad\u7ec3/\u62df\u5408\u8fc7\u7a0b\u7684\u53c2\u6570\u3002\u5982\u679c\u6211\u4eec\u7528 SGD \u8bad\u7ec3\u7ebf\u6027\u56de\u5f52\uff0c\u6a21\u578b\u7684\u53c2\u6570\u662f\u659c\u7387\u548c\u504f\u5dee\uff0c\u8d85\u53c2\u6570\u662f\u5b66\u4e60\u7387\u3002\u4f60\u4f1a\u53d1\u73b0\u6211\u5728\u672c\u7ae0\u548c\u672c\u4e66\u4e2d\u4ea4\u66ff\u4f7f\u7528\u8fd9\u4e9b\u672f\u8bed\u3002\u5047\u8bbe\u6a21\u578b\u4e2d\u6709\u4e09\u4e2a\u53c2\u6570 a\u3001b\u3001c\uff0c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u90fd\u53ef\u4ee5\u662f 1 \u5230 10 \u4e4b\u95f4\u7684\u6574\u6570\u3002\u8fd9\u4e9b\u53c2\u6570\u7684 \"\u6b63\u786e \"\u7ec4\u5408\u5c06\u4e3a\u60a8\u63d0\u4f9b\u6700\u4f73\u7ed3\u679c\u3002\u56e0\u6b64\uff0c\u8fd9\u5c31\u6709\u70b9\u50cf\u4e00\u4e2a\u88c5\u6709\u4e09\u62e8\u5bc6\u7801\u9501\u7684\u624b\u63d0\u7bb1\u3002\u4e0d\u8fc7\uff0c\u4e09\u62e8\u5bc6\u7801\u9501\u53ea\u6709\u4e00\u4e2a\u6b63\u786e\u7b54\u6848\u3002\u800c\u6a21\u578b\u6709\u5f88\u591a\u6b63\u786e\u7b54\u6848\u3002\u90a3\u4e48\uff0c\u5982\u4f55\u627e\u5230\u6700\u4f73\u53c2\u6570\u5462\uff1f\u4e00\u79cd\u65b9\u6cd5\u662f\u5bf9\u6240\u6709\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\uff0c\u770b\u54ea\u79cd\u7ec4\u5408\u80fd\u63d0\u9ad8\u6307\u6807\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u521d\u59cb\u5316\u6700\u4f73\u51c6\u786e\u5ea6 best_accuracy = 0 # \u521d\u59cb\u5316\u6700\u4f73\u53c2\u6570\u7684\u5b57\u5178 best_parameters = { \"a\" : 0 , \"b\" : 0 , \"c\" : 0 } # \u5faa\u73af\u904d\u5386 a \u7684\u53d6\u503c\u8303\u56f4 1~10 for a in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 b \u7684\u53d6\u503c\u8303\u56f4 1~10 for b in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 c \u7684\u53d6\u503c\u8303\u56f4 1~10 for c in range ( 1 , 11 ): # \u521b\u5efa\u6a21\u578b\uff0c\u4f7f\u7528 a\u3001b\u3001c \u53c2\u6570 model = MODEL ( a , b , c ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408\u6a21\u578b model . fit ( training_data ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u9a8c\u8bc1\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . predict ( validation_data ) # \u8ba1\u7b97\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 accuracy = metrics . accuracy_score ( targets , preds ) # \u5982\u679c\u5f53\u524d\u51c6\u786e\u5ea6\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u5219\u66f4\u65b0\u6700\u4f73\u51c6\u786e\u5ea6\u548c\u6700\u4f73\u53c2\u6570 if accuracy > best_accuracy : best_accuracy = accuracy best_parameters [ \"a\" ] = a best_parameters [ \"b\" ] = b best_parameters [ \"c\" ] = c \u5728\u4e0a\u8ff0\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4ece 1 \u5230 10 \u5bf9\u6240\u6709\u53c2\u6570\u8fdb\u884c\u4e86\u62df\u5408\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u603b\u5171\u8981\u5bf9\u6a21\u578b\u8fdb\u884c 1000 \u6b21\uff0810 x 10 x 10\uff09\u62df\u5408\u3002\u8fd9\u53ef\u80fd\u4f1a\u5f88\u6602\u8d35\uff0c\u56e0\u4e3a\u6a21\u578b\u7684\u8bad\u7ec3\u9700\u8981\u5f88\u957f\u65f6\u95f4\u3002\u4e0d\u8fc7\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\u5e94\u8be5\u6ca1\u95ee\u9898\uff0c\u4f46\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\uff0c\u5e76\u4e0d\u662f\u53ea\u6709\u4e09\u4e2a\u53c2\u6570\uff0c\u6bcf\u4e2a\u53c2\u6570\u4e5f\u4e0d\u662f\u53ea\u6709\u5341\u4e2a\u503c\u3002 \u5927\u591a\u6570\u6a21\u578b\u53c2\u6570\u90fd\u662f\u5b9e\u6570\uff0c\u4e0d\u540c\u53c2\u6570\u7684\u7ec4\u5408\u53ef\u4ee5\u662f\u65e0\u9650\u7684\u3002 \u8ba9\u6211\u4eec\u770b\u770b scikit-learn \u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002 RandomForestClassifier ( n_estimators = 100 , criterion = 'gini' , max_depth = None , min_samples_split = 2 , min_samples_leaf = 1 , min_weight_fraction_leaf = 0.0 , max_features = 'auto' , max_leaf_nodes = None , min_impurity_decrease = 0.0 , min_impurity_split = None , bootstrap = True , oob_score = False , n_jobs = None , random_state = None , verbose = 0 , warm_start = False , class_weight = None , ccp_alpha = 0.0 , max_samples = None , ) \u6709 19 \u4e2a\u53c2\u6570\uff0c\u800c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u7684\u6240\u6709\u7ec4\u5408\uff0c\u4ee5\u53ca\u5b83\u4eec\u53ef\u4ee5\u627f\u62c5\u7684\u6240\u6709\u503c\uff0c\u90fd\u5c06\u662f\u65e0\u7a77\u65e0\u5c3d\u7684\u3002\u901a\u5e38\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u6ca1\u6709\u8db3\u591f\u7684\u8d44\u6e90\u548c\u65f6\u95f4\u6765\u505a\u8fd9\u4ef6\u4e8b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6307\u5b9a\u4e86\u4e00\u4e2a\u53c2\u6570\u7f51\u683c\u3002\u5728\u8fd9\u4e2a\u7f51\u683c\u4e0a\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408\u7684\u641c\u7d22\u79f0\u4e3a\u7f51\u683c\u641c\u7d22\u3002\u6211\u4eec\u53ef\u4ee5\u8bf4\uff0cn_estimators \u53ef\u4ee5\u662f 100\u3001200\u3001250\u3001300\u3001400\u3001500\uff1bmax_depth \u53ef\u4ee5\u662f 1\u30012\u30015\u30017\u300111\u300115\uff1bcriterion \u53ef\u4ee5\u662f gini \u6216 entropy\u3002\u8fd9\u4e9b\u53c2\u6570\u770b\u8d77\u6765\u5e76\u4e0d\u591a\uff0c\u4f46\u5982\u679c\u6570\u636e\u96c6\u8fc7\u5927\uff0c\u8ba1\u7b97\u8d77\u6765\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u53ef\u4ee5\u50cf\u4e4b\u524d\u4e00\u6837\u521b\u5efa\u4e09\u4e2a for \u5faa\u73af\uff0c\u5e76\u5728\u9a8c\u8bc1\u96c6\u4e0a\u8ba1\u7b97\u5f97\u5206\uff0c\u8fd9\u6837\u5c31\u80fd\u5b9e\u73b0\u7f51\u683c\u641c\u7d22\u3002\u8fd8\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u8981\u8fdb\u884c k \u6298\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5219\u9700\u8981\u66f4\u591a\u7684\u5faa\u73af\uff0c\u8fd9\u610f\u5473\u7740\u9700\u8981\u66f4\u591a\u7684\u65f6\u95f4\u6765\u627e\u5230\u5b8c\u7f8e\u7684\u53c2\u6570\u3002\u56e0\u6b64\uff0c\u7f51\u683c\u641c\u7d22\u5e76\u4e0d\u6d41\u884c\u3002\u8ba9\u6211\u4eec\u4ee5\u6839\u636e \u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4 \u6570\u636e\u96c6\u4e3a\u4f8b\uff0c\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 \u56fe 1\uff1a\u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4\u6570\u636e\u96c6\u5c55\u793a \u8bad\u7ec3\u96c6\u4e2d\u53ea\u6709 2000 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u53ef\u4ee5\u8f7b\u677e\u5730\u4f7f\u7528\u5206\u5c42 kfold \u548c\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5177\u6709\u4e0a\u8ff0\u53c2\u6570\u8303\u56f4\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\uff0c\u5e76\u5728\u4e0b\u9762\u7684\u793a\u4f8b\u4e2d\u4e86\u89e3\u5982\u4f55\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u3002 # rf_grid_search.py import numpy as np import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u5220\u9664 price_range \u5217 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u53d6\u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\uff0c\u4f7f\u7528\u6240\u6709\u53ef\u7528\u7684 CPU \u6838\u5fc3\u8fdb\u884c\u8bad\u7ec3 classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid = { \"n_estimators\" : [ 100 , 200 , 250 , 300 , 400 , 500 ], \"max_depth\" : [ 1 , 2 , 5 , 7 , 11 , 15 ], \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22 model = model_selection . GridSearchCV ( estimator = classifier , param_grid = param_grid , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( f \"Best score: { model . best_score_ } \" ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u8fd9\u91cc\u6253\u5370\u4e86\u5f88\u591a\u5185\u5bb9\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u6700\u540e\u51e0\u884c\u3002 [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.895 , total = 1.0 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.890 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.910 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.880 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.870 , total = 1.1 s [ Parallel ( n_jobs = 1 )]: Done 360 out of 360 | elapsed : 3.7 min finished Best score : 0.889 Best parameters set : criterion : 'entropy' max_depth : 15 n_estimators : 500 \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c5\u6298\u4ea4\u53c9\u68c0\u9a8c\u6700\u4f73\u5f97\u5206\u662f 0.889\uff0c\u6211\u4eec\u7684\u7f51\u683c\u641c\u7d22\u5f97\u5230\u4e86\u6700\u4f73\u53c2\u6570\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u7684\u4e0b\u4e00\u4e2a\u6700\u4f73\u65b9\u6cd5\u662f \u968f\u673a\u641c\u7d22 \u3002\u5728\u968f\u673a\u641c\u7d22\u4e2d\uff0c\u6211\u4eec\u968f\u673a\u9009\u62e9\u4e00\u4e2a\u53c2\u6570\u7ec4\u5408\uff0c\u7136\u540e\u8ba1\u7b97\u4ea4\u53c9\u9a8c\u8bc1\u5f97\u5206\u3002\u8fd9\u91cc\u6d88\u8017\u7684\u65f6\u95f4\u6bd4\u7f51\u683c\u641c\u7d22\u5c11\uff0c\u56e0\u4e3a\u6211\u4eec\u4e0d\u5bf9\u6240\u6709\u4e0d\u540c\u7684\u53c2\u6570\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\u3002\u6211\u4eec\u9009\u62e9\u8981\u5bf9\u6a21\u578b\u8fdb\u884c\u591a\u5c11\u6b21\u8bc4\u4f30\uff0c\u8fd9\u5c31\u51b3\u5b9a\u4e86\u641c\u7d22\u6240\u9700\u7684\u65f6\u95f4\u3002\u4ee3\u7801\u4e0e\u4e0a\u9762\u7684\u5dee\u522b\u4e0d\u5927\u3002\u9664 GridSearchCV \u5916\uff0c\u6211\u4eec\u4f7f\u7528 RandomizedSearchCV\u3002 if __name__ == \"__main__\" : classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u66f4\u6539\u641c\u7d22\u7a7a\u95f4 param_grid = { \"n_estimators\" : np . arange ( 100 , 1500 , 100 ), \"max_depth\" : np . arange ( 1 , 31 ), \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u968f\u673a\u53c2\u6570\u641c\u7d22 model = model_selection . RandomizedSearchCV ( estimator = classifier , param_distributions = param_grid , n_iter = 20 , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) print ( f \"Best score: { model . best_score_ } \" ) print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u6211\u4eec\u66f4\u6539\u4e86\u968f\u673a\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c\uff0c\u7ed3\u679c\u4f3c\u4e4e\u6709\u4e86\u4e9b\u8bb8\u6539\u8fdb\u3002 Best score : 0.8905 Best parameters set : criterion : entropy max_depth : 25 n_estimators : 300 \u5982\u679c\u8fed\u4ee3\u6b21\u6570\u8f83\u5c11\uff0c\u968f\u673a\u641c\u7d22\u6bd4\u7f51\u683c\u641c\u7d22\u66f4\u5feb\u3002\u4f7f\u7528\u8fd9\u4e24\u79cd\u65b9\u6cd5\uff0c\u4f60\u53ef\u4ee5\u4e3a\u5404\u79cd\u6a21\u578b\u627e\u5230\u6700\u4f18\u53c2\u6570\uff0c\u53ea\u8981\u5b83\u4eec\u6709\u62df\u5408\u548c\u9884\u6d4b\u529f\u80fd\uff0c\u8fd9\u4e5f\u662f scikit-learn \u7684\u6807\u51c6\u3002\u6709\u65f6\uff0c\u4f60\u53ef\u80fd\u60f3\u4f7f\u7528\u7ba1\u9053\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u7531\u4e24\u5217\u6587\u672c\u7ec4\u6210\uff0c\u4f60\u9700\u8981\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u7c7b\u522b\u3002\u8ba9\u6211\u4eec\u5047\u8bbe\u4f60\u9009\u62e9\u7684\u7ba1\u9053\u662f\u9996\u5148\u4ee5\u534a\u76d1\u7763\u7684\u65b9\u5f0f\u5e94\u7528 tf-idf\uff0c\u7136\u540e\u4f7f\u7528 SVD \u548c SVM \u5206\u7c7b\u5668\u3002\u73b0\u5728\u7684\u95ee\u9898\u662f\uff0c\u6211\u4eec\u5fc5\u987b\u9009\u62e9 SVD \u7684\u6210\u5206\uff0c\u8fd8\u9700\u8981\u8c03\u6574 SVM \u7684\u53c2\u6570\u3002\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u5c55\u793a\u4e86\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 import numpy as np import pandas as pd from sklearn import metrics from sklearn import model_selection from sklearn import pipeline from sklearn.decomposition import TruncatedSVD from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC # \u8ba1\u7b97\u52a0\u6743\u4e8c\u6b21 Kappa \u5206\u6570 def quadratic_weighted_kappa ( y_true , y_pred ): return metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) if __name__ == '__main__' : # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( '../input/train.csv' ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u63d0\u53d6 id \u5217\u7684\u503c\uff0c\u5e76\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u7c7b\u578b\uff0c\u5b58\u50a8\u5728\u53d8\u91cf idx \u4e2d idx = test . id . values . astype ( int ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 train = train . drop ( 'id' , axis = 1 ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 test = test . drop ( 'id' , axis = 1 ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf 'relevance' \uff0c\u5b58\u50a8\u5728\u53d8\u91cf y \u4e2d y = train . relevance . values # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 traindata \u4e2d traindata = list ( train . apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 testdata \u4e2d testdata = list ( test . apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u521b\u5efa\u4e00\u4e2a TfidfVectorizer \u5bf9\u8c61 tfv\uff0c\u7528\u4e8e\u5c06\u6587\u672c\u6570\u636e\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv = TfidfVectorizer ( min_df = 3 , max_features = None , strip_accents = 'unicode' , analyzer = 'word' , token_pattern = r '\\w{1,}' , ngram_range = ( 1 , 3 ), use_idf = 1 , smooth_idf = 1 , sublinear_tf = 1 , stop_words = 'english' ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408 TfidfVectorizer\uff0c\u5c06\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv . fit ( traindata ) # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X X = tfv . transform ( traindata ) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X_test X_test = tfv . transform ( testdata ) # \u521b\u5efa TruncatedSVD \u5bf9\u8c61 svd\uff0c\u7528\u4e8e\u8fdb\u884c\u5947\u5f02\u503c\u5206\u89e3 svd = TruncatedSVD () # \u521b\u5efa StandardScaler \u5bf9\u8c61 scl\uff0c\u7528\u4e8e\u8fdb\u884c\u7279\u5f81\u7f29\u653e scl = StandardScaler () # \u521b\u5efa\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668\u5bf9\u8c61 svm_model svm_model = SVC () # \u521b\u5efa\u673a\u5668\u5b66\u4e60\u7ba1\u9053 clf\uff0c\u5305\u542b\u5947\u5f02\u503c\u5206\u89e3\u3001\u7279\u5f81\u7f29\u653e\u548c\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668 clf = pipeline . Pipeline ( [ ( 'svd' , svd ), ( 'scl' , scl ), ( 'svm' , svm_model ) ] ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid param_grid = { 'svd__n_components' : [ 200 , 300 ], 'svm__C' : [ 10 , 12 ] } # \u521b\u5efa\u81ea\u5b9a\u4e49\u7684\u8bc4\u5206\u51fd\u6570 kappa_scorer\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd kappa_scorer = metrics . make_scorer ( quadratic_weighted_kappa , greater_is_better = True ) # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model = model_selection . GridSearchCV ( estimator = clf , param_grid = param_grid , scoring = kappa_scorer , verbose = 10 , n_jobs =- 1 , refit = True , cv = 5 ) # \u4f7f\u7528 GridSearchCV \u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( \"Best score: %0.3f \" % model . best_score_ ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( \" \\t %s : %r \" % ( param_name , best_parameters [ param_name ])) # \u83b7\u53d6\u6700\u4f73\u6a21\u578b best_model = model . best_estimator_ best_model . fit ( X , y ) # \u4f7f\u7528\u6700\u4f73\u6a21\u578b\u8fdb\u884c\u9884\u6d4b preds = best_model . predict ( ... ) \u8fd9\u91cc\u663e\u793a\u7684\u7ba1\u9053\u5305\u62ec SVD\uff08\u5947\u5f02\u503c\u5206\u89e3\uff09\u3001\u6807\u51c6\u7f29\u653e\u548c SVM\uff08\u652f\u6301\u5411\u91cf\u673a\uff09\u6a21\u578b\u3002\u8bf7\u6ce8\u610f\uff0c\u7531\u4e8e\u6ca1\u6709\u8bad\u7ec3\u6570\u636e\uff0c\u60a8\u65e0\u6cd5\u6309\u539f\u6837\u8fd0\u884c\u4e0a\u8ff0\u4ee3\u7801\u3002\u5f53\u6211\u4eec\u8fdb\u5165\u9ad8\u7ea7\u8d85\u53c2\u6570\u4f18\u5316\u6280\u672f\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4e0d\u540c\u7c7b\u578b\u7684 \u6700\u5c0f\u5316\u7b97\u6cd5 \u6765\u7814\u7a76\u51fd\u6570\u7684\u6700\u5c0f\u5316\u3002\u8fd9\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u591a\u79cd\u6700\u5c0f\u5316\u51fd\u6570\u6765\u5b9e\u73b0\uff0c\u5982\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u3001\u5185\u5c14\u5fb7-\u6885\u5fb7\u4f18\u5316\u7b97\u6cd5\u3001\u4f7f\u7528\u8d1d\u53f6\u65af\u6280\u672f\u548c\u9ad8\u65af\u8fc7\u7a0b\u5bfb\u627e\u6700\u4f18\u53c2\u6570\u6216\u4f7f\u7528\u9057\u4f20\u7b97\u6cd5\u3002\u6211\u5c06\u5728 \"\u96c6\u5408\u4e0e\u5806\u53e0\uff08ensembling and stacking\uff09 \"\u4e00\u7ae0\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u548c Nelder-Mead \u7b97\u6cd5\u7684\u5e94\u7528\u3002\u9996\u5148\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u9ad8\u65af\u8fc7\u7a0b\u5982\u4f55\u7528\u4e8e\u8d85\u53c2\u6570\u4f18\u5316\u3002\u8fd9\u7c7b\u7b97\u6cd5\u9700\u8981\u4e00\u4e2a\u53ef\u4ee5\u4f18\u5316\u7684\u51fd\u6570\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u90fd\u662f\u6700\u5c0f\u5316\u8fd9\u4e2a\u51fd\u6570\uff0c\u5c31\u50cf\u6211\u4eec\u6700\u5c0f\u5316\u635f\u5931\u4e00\u6837\u3002 \u56e0\u6b64\uff0c\u6bd4\u65b9\u8bf4\uff0c\u4f60\u60f3\u627e\u5230\u6700\u4f73\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u663e\u7136\uff0c\u51c6\u786e\u5ea6\u8d8a\u9ad8\u8d8a\u597d\u3002\u73b0\u5728\uff0c\u6211\u4eec\u4e0d\u80fd\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\uff0c\u4f46\u6211\u4eec\u53ef\u4ee5\u5c06\u7cbe\u786e\u5ea6\u4e58\u4ee5-1\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u662f\u5728\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\u7684\u8d1f\u503c\uff0c\u4f46\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u662f\u5728\u6700\u5927\u5316\u7cbe\u786e\u5ea6\u3002 \u5728\u9ad8\u65af\u8fc7\u7a0b\u4e2d\u4f7f\u7528\u8d1d\u53f6\u65af\u4f18\u5316\uff0c\u53ef\u4ee5\u4f7f\u7528 scikit-optimize (skopt) \u5e93\u4e2d\u7684 gp_minimize \u51fd\u6570\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4f7f\u7528\u8be5\u51fd\u6570\u8c03\u6574\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u53c2\u6570\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from skopt import gp_minimize from skopt import space def optimize ( params , param_names , x , y ): # \u5c06\u53c2\u6570\u540d\u79f0\u548c\u5bf9\u5e94\u7684\u503c\u6253\u5305\u6210\u5b57\u5178 params = dict ( zip ( param_names , params )) # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u6a21\u578b\uff0c\u4f7f\u7528\u4f20\u5165\u7684\u53c2\u6570\u914d\u7f6e model = ensemble . RandomForestClassifier ( ** params ) # \u521b\u5efa StratifiedKFold \u4ea4\u53c9\u9a8c\u8bc1\u5bf9\u8c61\uff0c\u5c06\u6570\u636e\u5206\u4e3a 5 \u6298 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u6bcf\u4e2a\u6298\u53e0\u7684\u51c6\u786e\u5ea6\u7684\u5217\u8868 accuracies = [] # \u5faa\u73af\u904d\u5386\u6bcf\u4e2a\u6298\u53e0\u7684\u8bad\u7ec3\u548c\u6d4b\u8bd5\u6570\u636e for idx in kf . split ( X = x , y = y ): train_idx , test_idx = idx [ 0 ], idx [ 1 ] xtrain = x [ train_idx ] ytrain = y [ train_idx ] xtest = x [ test_idx ] ytest = y [ test_idx ] # \u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u62df\u5408\u6a21\u578b model . fit ( xtrain , ytrain ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u6d4b\u8bd5\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . predict ( xtest ) # \u8ba1\u7b97\u6298\u53e0\u7684\u51c6\u786e\u5ea6 fold_accuracy = metrics . accuracy_score ( ytest , preds ) accuracies . append ( fold_accuracy ) # \u8fd4\u56de\u5e73\u5747\u51c6\u786e\u5ea6\u7684\u8d1f\u6570\uff08\u56e0\u4e3a skopt \u4f7f\u7528\u8d1f\u6570\u6765\u6700\u5c0f\u5316\u76ee\u6807\u51fd\u6570\uff09 return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u53d6\u7279\u5f81\u77e9\u9635 X\uff08\u53bb\u6389\"price_range\"\u5217\uff09 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u5b9a\u4e49\u8d85\u53c2\u6570\u641c\u7d22\u7a7a\u95f4 param_space param_space = [ space . Integer ( 3 , 15 , name = \"max_depth\" ), space . Integer ( 100 , 1500 , name = \"n_estimators\" ), space . Categorical ([ \"gini\" , \"entropy\" ], name = \"criterion\" ), space . Real ( 0.01 , 1 , prior = \"uniform\" , name = \"max_features\" ) ] # \u5b9a\u4e49\u8d85\u53c2\u6570\u7684\u540d\u79f0\u5217\u8868 param_names param_names = [ \"max_depth\" , \"n_estimators\" , \"criterion\" , \"max_features\" ] # \u521b\u5efa\u51fd\u6570 optimization_function\uff0c\u7528\u4e8e\u4f20\u9012\u7ed9 gp_minimize optimization_function = partial ( optimize , param_names = param_names , x = X , y = y ) # \u4f7f\u7528 Bayesian Optimization\uff08\u57fa\u4e8e\u8d1d\u53f6\u65af\u4f18\u5316\uff09\u6765\u641c\u7d22\u6700\u4f73\u8d85\u53c2\u6570 result = gp_minimize ( optimization_function , dimensions = param_space , n_calls = 15 , n_random_starts = 10 , verbose = 10 ) # \u83b7\u53d6\u6700\u4f73\u8d85\u53c2\u6570\u7684\u5b57\u5178 best_params = dict ( zip ( param_names , result . x ) ) # \u6253\u5370\u51fa\u627e\u5230\u7684\u6700\u4f73\u8d85\u53c2\u6570 print ( best_params ) \u8fd9\u540c\u6837\u4f1a\u4ea7\u751f\u5927\u91cf\u8f93\u51fa\uff0c\u6700\u540e\u4e00\u90e8\u5206\u5982\u4e0b\u6240\u793a\u3002 Iteration No : 14 started . Searching for the next optimal point . Iteration No : 14 ended . Search finished for the next optimal point . Time taken : 4.7793 Function value obtained : - 0.9075 Current minimum : - 0.9075 Iteration No : 15 started . Searching for the next optimal point . Iteration No : 15 ended . Search finished for the next optimal point . Time taken : 49.4186 Function value obtained : - 0.9075 Current minimum : - 0.9075 { 'max_depth' : 12 , 'n_estimators' : 100 , 'criterion' : 'entropy' , 'max_features' : 1.0 } \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6210\u529f\u7a81\u7834\u4e86 0.90 \u7684\u51c6\u786e\u7387\u3002\u8fd9\u771f\u662f\u592a\u795e\u5947\u4e86\uff01 \u6211\u4eec\u8fd8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u6bb5\u67e5\u770b\uff08\u7ed8\u5236\uff09\u6211\u4eec\u662f\u5982\u4f55\u5b9e\u73b0\u6536\u655b\u7684\u3002 from skopt.plots import plot_convergence plot_convergence ( result ) \u6536\u655b\u56fe\u5982\u56fe 2 \u6240\u793a\u3002 \u56fe 2\uff1a\u968f\u673a\u68ee\u6797\u53c2\u6570\u4f18\u5316\u7684\u6536\u655b\u56fe Scikit- optimize \u5c31\u662f\u8fd9\u6837\u4e00\u4e2a\u5e93\u3002 hyperopt \u4f7f\u7528\u6811\u72b6\u7ed3\u6784\u8d1d\u53f6\u65af\u4f30\u8ba1\u5668\uff08TPE\uff09\u6765\u627e\u5230\u6700\u4f18\u53c2\u6570\u3002\u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u7247\u6bb5\uff0c\u6211\u5728\u4f7f\u7528 hyperopt \u65f6\u5bf9\u4e4b\u524d\u7684\u4ee3\u7801\u505a\u4e86\u6700\u5c0f\u7684\u6539\u52a8\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from hyperopt import hp , fmin , tpe , Trials from hyperopt.pyll.base import scope def optimize ( params , x , y ): model = ensemble . RandomForestClassifier ( ** params ) kf = model_selection . StratifiedKFold ( n_splits = 5 ) ... return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/mobile_train.csv\" ) X = df . drop ( \"price_range\" , axis = 1 ) . values y = df . price_range . values # \u5b9a\u4e49\u641c\u7d22\u7a7a\u95f4\uff08\u6574\u578b\u3001\u6d6e\u70b9\u6570\u578b\u3001\u9009\u62e9\u578b\uff09 param_space = { \"max_depth\" : scope . int ( hp . quniform ( \"max_depth\" , 1 , 15 , 1 )), \"n_estimators\" : scope . int ( hp . quniform ( \"n_estimators\" , 100 , 1500 , 1 ) ), \"criterion\" : hp . choice ( \"criterion\" , [ \"gini\" , \"entropy\" ]), \"max_features\" : hp . uniform ( \"max_features\" , 0 , 1 ) } # \u5305\u88c5\u51fd\u6570 optimization_function = partial ( optimize , x = X , y = y ) # \u5f00\u59cb\u8bad\u7ec3 trials = Trials () # \u6700\u5c0f\u5316\u76ee\u6807\u503c hopt = fmin ( fn = optimization_function , space = param_space , algo = tpe . suggest , max_evals = 15 , trials = trials ) #\u6253\u5370\u6700\u4f73\u53c2\u6570 print ( hopt ) \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u4e0e\u4e4b\u524d\u7684\u4ee3\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u4f60\u5fc5\u987b\u4ee5\u4e0d\u540c\u7684\u683c\u5f0f\u5b9a\u4e49\u53c2\u6570\u7a7a\u95f4\uff0c\u8fd8\u9700\u8981\u6539\u53d8\u5b9e\u9645\u4f18\u5316\u90e8\u5206\uff0c\u7528 hyperopt \u4ee3\u66ff gp_minimize\u3002\u7ed3\u679c\u76f8\u5f53\u4e0d\u9519\uff01 \u276f python rf_hyperopt . py 100 %| \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 15 / 15 [ 04 : 38 < 00 : 00 , 18.57 s / trial , best loss : - 0.9095000000000001 ] { 'criterion' : 1 , 'max_depth' : 11.0 , 'max_features' : 0.821163568049807 , 'n_estimators' : 806.0 } \u6211\u4eec\u5f97\u5230\u4e86\u6bd4\u4ee5\u524d\u66f4\u597d\u7684\u51c6\u786e\u5ea6\u548c\u4e00\u7ec4\u53ef\u4ee5\u4f7f\u7528\u7684\u53c2\u6570\u3002\u8bf7\u6ce8\u610f\uff0c\u6700\u7ec8\u7ed3\u679c\u4e2d\u7684\u6807\u51c6\u662f 1\u3002\u8fd9\u610f\u5473\u7740\u9009\u62e9\u4e86 1\uff0c\u5373\u71b5\u3002 \u4e0a\u8ff0\u8c03\u6574\u8d85\u53c2\u6570\u7684\u65b9\u6cd5\u662f\u6700\u5e38\u89c1\u7684\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u6a21\u578b\uff1a\u7ebf\u6027\u56de\u5f52\u3001\u903b\u8f91\u56de\u5f52\u3001\u57fa\u4e8e\u6811\u7684\u65b9\u6cd5\u3001\u68af\u5ea6\u63d0\u5347\u6a21\u578b\uff08\u5982 xgboost\u3001lightgbm\uff09\uff0c\u751a\u81f3\u795e\u7ecf\u7f51\u7edc\uff01 \u867d\u7136\u8fd9\u4e9b\u65b9\u6cd5\u5df2\u7ecf\u5b58\u5728\uff0c\u4f46\u5b66\u4e60\u65f6\u5fc5\u987b\u4ece\u624b\u52a8\u8c03\u6574\u8d85\u53c2\u6570\u5f00\u59cb\uff0c\u5373\u624b\u5de5\u8c03\u6574\u3002\u624b\u52a8\u8c03\u6574\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5b66\u4e60\u57fa\u7840\u77e5\u8bc6\uff0c\u4f8b\u5982\uff0c\u5728\u68af\u5ea6\u63d0\u5347\u4e2d\uff0c\u5f53\u4f60\u589e\u52a0\u6df1\u5ea6\u65f6\uff0c\u4f60\u5e94\u8be5\u964d\u4f4e\u5b66\u4e60\u7387\u3002\u5982\u679c\u4f7f\u7528\u81ea\u52a8\u5de5\u5177\uff0c\u5c31\u65e0\u6cd5\u5b66\u4e60\u5230\u8fd9\u4e00\u70b9\u3002\u8bf7\u53c2\u8003\u4e0b\u8868\uff0c\u4e86\u89e3\u5e94\u5982\u4f55\u8c03\u6574\u3002RS* \u8868\u793a\u968f\u673a\u641c\u7d22\u5e94\u8be5\u66f4\u597d\u3002 \u4e00\u65e6\u4f60\u80fd\u66f4\u597d\u5730\u624b\u52a8\u8c03\u6574\u53c2\u6570\uff0c\u4f60\u751a\u81f3\u53ef\u80fd\u4e0d\u9700\u8981\u4efb\u4f55\u81ea\u52a8\u8d85\u53c2\u6570\u8c03\u6574\u3002\u521b\u5efa\u5927\u578b\u6a21\u578b\u6216\u5f15\u5165\u5927\u91cf\u7279\u5f81\u65f6\uff0c\u4e5f\u5bb9\u6613\u9020\u6210\u8bad\u7ec3\u6570\u636e\u7684\u8fc7\u5ea6\u62df\u5408\u3002\u4e3a\u907f\u514d\u8fc7\u5ea6\u62df\u5408\uff0c\u9700\u8981\u5728\u8bad\u7ec3\u6570\u636e\u7279\u5f81\u4e2d\u5f15\u5165\u566a\u58f0\u6216\u5bf9\u4ee3\u4ef7\u51fd\u6570\u8fdb\u884c\u60e9\u7f5a\u3002\u8fd9\u79cd\u60e9\u7f5a\u79f0\u4e3a \u6b63\u5219\u5316 \uff0c\u6709\u52a9\u4e8e\u6cdb\u5316\u6a21\u578b\u3002\u5728\u7ebf\u6027\u6a21\u578b\u4e2d\uff0c\u6700\u5e38\u89c1\u7684\u6b63\u5219\u5316\u7c7b\u578b\u662f L1 \u548c L2\u3002L1 \u4e5f\u79f0\u4e3a Lasso \u56de\u5f52\uff0cL2 \u79f0\u4e3a Ridge \u56de\u5f52\u3002\u8bf4\u5230\u795e\u7ecf\u7f51\u7edc\uff0c\u6211\u4eec\u4f1a\u4f7f\u7528dropout\u3001\u6dfb\u52a0\u589e\u5f3a\u3001\u566a\u58f0\u7b49\u65b9\u6cd5\u5bf9\u6a21\u578b\u8fdb\u884c\u6b63\u5219\u5316\u3002\u5229\u7528\u8d85\u53c2\u6570\u4f18\u5316\uff0c\u8fd8\u53ef\u4ee5\u627e\u5230\u6b63\u786e\u7684\u60e9\u7f5a\u65b9\u6cd5\u3002 Model Optimize Range of values Linear Regression - fit_intercept - normalize - True/False - True/False Ridge - alpha - fit_intercept - normalize - 0.01, 0.1, 1.0, 10, 100 - True/False - True/False k-neighbors - n_neighbors - p - 2, 4, 8, 16, ... - 2, 3, ... SVM - C - gamma - class_weight - 0.001, 0.01, ...,10, 100, 1000 - 'auto', RS* - 'balanced', None Logistic Regression - Penalyt - C - L1 or L2 - 0.001, 0.01, ..., 10, ..., 100 Lasso - Alpha - Normalize - 0.1, 1.0, 10 - True/False Random Forest - n_estimators - max_depth - min_samples_split - min_samples_leaf - max features - 120, 300, 500, 800, 1200 - 5, 8, 15, 25, 30, None - 1, 2, 5, 10, 15, 100 - log2, sqrt, None XGBoost - eta - gamma - max_depth - min_child_weight - subsample - colsample_bytree - lambda - alpha - 0.01, 0.015, 0.025, 0.05, 0.1 - 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0 - 3, 5, 7, 9, 12, 15, 17, 25 - 1, 3, 5, 7 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.01, 0.1, 1.0, RS - 0, 0.1, 0.5, 1.0, RS","title":"\u8d85\u53c2\u6570\u4f18\u5316"}]} \ No newline at end of file diff --git a/sitemap.xml.gz b/sitemap.xml.gz index 7ca9d5d..83c6602 100644 Binary files a/sitemap.xml.gz and b/sitemap.xml.gz differ diff --git "a/\350\257\204\344\274\260\346\214\207\346\240\207/index.html" "b/\350\257\204\344\274\260\346\214\207\346\240\207/index.html" index 92c9e5c..7cca71b 100644 --- "a/\350\257\204\344\274\260\346\214\207\346\240\207/index.html" +++ "b/\350\257\204\344\274\260\346\214\207\346\240\207/index.html" @@ -691,7 +691,7 @@

评估指标

Out[X]: 0.5

这与我们的计算值相符!

-

对于一个 "好 "模型来说,精确率和召回值都应该很高。我们看到,在上面的例子中,召回值相当高。但是,精确率却很低!我们的模型产生了大量的误报,但误报较少。在这类问题中,假阴性较少是好事,因为你不想在病人有气胸的情况下却说他们没有气胸。这样做会造成更大的伤害。但我们也有很多假阳性结果,这也不是好事。

+

对于一个 "好 "模型来说,精确率和召回值都应该很高。我们看到,在上面的例子中,召回值相当高。但是,精确率却很低!我们的模型产生了大量的误报,但漏报较少。在这类问题中,假阴性较少是好事,因为你不想在病人有气胸的情况下却说他们没有气胸。这样做会造成更大的伤害。但我们也有很多假阳性结果,这也不是好事。

大多数模型都会预测一个概率,当我们预测时,通常会将这个阈值选为 0.5。这个阈值并不总是理想的,根据这个阈值,精确率和召回率的值可能会发生很大的变化。如果我们选择的每个阈值都能计算出精确率和召回率,那么我们就可以在这些值之间绘制出曲线图。这幅图或曲线被称为 "精确率-召回率曲线"。

在研究精确率-调用曲线之前,我们先假设有两个列表。

In [X]: y_true = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0,