diff --git a/search/search_index.json b/search/search_index.json index db98732..a959b1f 100644 --- a/search/search_index.json +++ b/search/search_index.json @@ -1 +1 @@ -{"config":{"indexing":"full","lang":["en"],"min_search_length":3,"prebuild_index":false,"separator":"[\\s\\-]+"},"docs":[{"location":"","text":"AAAMLP-CN \u65b0\u7279\u6027 - 2023.09.07 \u26a1 \u4fee\u6b63\u90e8\u5206\u5df2\u77e5\u6587\u5b57\u9519\u8bef\u548c\u4ee3\u7801\u9519\u8bef \ud83e\udd17 \u6dfb\u52a0 \u5728\u7ebf\u6587\u6863 \u7ffb\u8bd1\u8fdb\u7a0b 2023.09.12 \u6dfb\u52a0\u7ae0\u8282\uff1a \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5 \u3001 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5 \u7b80\u4ecb Abhishek Thakur\uff0c\u5f88\u591a kaggler \u5bf9\u4ed6\u90fd\u975e\u5e38\u719f\u6089\uff0c2017 \u5e74\uff0c\u4ed6\u5728 Linkedin \u53d1\u8868\u4e86\u4e00\u7bc7\u540d\u4e3a Approaching (Almost) Any Machine Learning Problem \u7684\u6587\u7ae0\uff0c\u4ecb\u7ecd\u4ed6\u5efa\u7acb\u7684\u4e00\u4e2a\u81ea\u52a8\u7684\u673a\u5668\u5b66\u4e60\u6846\u67b6\uff0c\u51e0\u4e4e\u53ef\u4ee5\u89e3\u51b3\u4efb\u4f55\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u8fd9\u7bc7\u6587\u7ae0\u66fe\u706b\u904d Kaggle\u3002 Abhishek \u5728 Kaggle \u4e0a\u7684\u6210\u5c31\uff1a Competitions Grandmaster\uff0817 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 3\uff09 Kernels Expert \uff08Kagglers \u6392\u540d\u524d 1\uff05\uff09 Discussion Grandmaster\uff0865 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 2\uff09 \u76ee\u524d\uff0cAbhishek \u5728\u632a\u5a01 boost \u516c\u53f8\u62c5\u4efb\u9996\u5e2d\u6570\u636e\u79d1\u5b66\u5bb6\u7684\u804c\u4f4d\uff0c\u8fd9\u662f\u4e00\u5bb6\u4e13\u95e8\u4ece\u4e8b\u4f1a\u8bdd\u4eba\u5de5\u667a\u80fd\u7684\u8f6f\u4ef6\u516c\u53f8\u3002 \u672c\u6587\u5bf9 Approaching (Almost) Any Machine Learning Problem \u8fdb\u884c\u4e86 \u4e2d\u6587\u7ffb\u8bd1 
\uff0c\u7531\u4e8e\u672c\u4eba\u6c34\u5e73\u6709\u9650\uff0c\u4e14\u672a\u4f7f\u7528\u673a\u5668\u7ffb\u8bd1\uff0c\u53ef\u80fd\u6709\u90e8\u5206\u8a00\u8bed\u4e0d\u901a\u987a\u6216\u672c\u571f\u5316\u7a0b\u5ea6\u4e0d\u8db3\uff0c\u4e5f\u8bf7\u5927\u5bb6\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u591a\u63d0\u4f9b\u5b9d\u8d35\u610f\u89c1\u3002\u53e6\u9644\u4e0a\u4e66\u7c4d\u539f \u9879\u76ee\u5730\u5740 \uff0c \u8f6c\u8f7d\u8bf7\u4e00\u5b9a\u6807\u660e\u51fa\u5904\uff01 \u672c\u9879\u76ee \u652f\u6301\u5728\u7ebf\u9605\u8bfb \uff0c\u65b9\u4fbf\u60a8\u968f\u65f6\u968f\u5730\u8fdb\u884c\u67e5\u9605\u3002 \u56e0\u4e3a\u6709\u51e0\u7ae0\u5185\u5bb9\u592a\u8fc7\u57fa\u7840\uff0c\u6240\u4ee5\u672a\u8fdb\u884c\u7ffb\u8bd1\uff0c\u8be6\u7ec6\u60c5\u51b5\u8bf7\u53c2\u7167\u4e66\u7c4d\u76ee\u5f55\uff1a \u51c6\u5907\u73af\u5883\uff08\u672a\u7ffb\u8bd1\uff09 \u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60\uff08\u672a\u7ffb\u8bd1\uff09 \u4ea4\u53c9\u68c0\u9a8c\uff08\u5df2\u7ffb\u8bd1\uff09 \u8bc4\u4f30\u6307\u6807\uff08\u5df2\u7ffb\u8bd1\uff09 - \u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\uff08\u5df2\u7ffb\u8bd1\uff09 \u5904\u7406\u5206\u7c7b\u53d8\u91cf\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u5de5\u7a0b\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u9009\u62e9\uff08\u5df2\u7ffb\u8bd1\uff09 \u8d85\u53c2\u6570\u4f18\u5316\uff08\u5df2\u7ffb\u8bd1\uff09 \u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5\uff08\u672a\u7ffb\u8bd1\uff09 \u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5\uff08\u672a\u7ffb\u8bd1\uff09 \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u6211\u5c06\u4f1a\u628a\u5b8c\u6574\u7684\u7ffb\u8bd1\u7248 Markdown \u6587\u4ef6\u4e0a\u4f20\u5230 GitHub\uff0c\u4ee5\u4f9b\u5927\u5bb6\u514d\u8d39\u4e0b\u8f7d\u548c\u9605\u8bfb\u3002\u4e3a\u4e86\u6700\u4f73\u7684\u9605\u8bfb\u4f53\u9a8c\uff0c\u63a8\u8350\u4f7f\u7528 PDF 
\u683c\u5f0f\u6216\u662f\u5728\u7ebf\u9605\u8bfb\u8fdb\u884c\u67e5\u770b \u82e5\u60a8\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u53d1\u73b0\u4efb\u4f55\u9519\u8bef\u6216\u4e0d\u51c6\u786e\u4e4b\u5904\uff0c\u975e\u5e38\u6b22\u8fce\u901a\u8fc7\u63d0\u4ea4 Issue \u6216 Pull Request \u6765\u534f\u52a9\u6211\u8fdb\u884c\u4fee\u6b63\u3002 \u968f\u7740\u65f6\u95f4\u63a8\u79fb\uff0c\u6211\u53ef\u80fd\u4f1a \u7ee7\u7eed\u7ffb\u8bd1\u5c1a\u672a\u5b8c\u6210\u7684\u7ae0\u8282 \u3002\u5982\u679c\u60a8\u89c9\u5f97\u8fd9\u4e2a\u9879\u76ee\u5bf9\u60a8\u6709\u5e2e\u52a9\uff0c\u8bf7\u4e0d\u541d\u7ed9\u4e88 Star \u6216\u8005\u8fdb\u884c\u5173\u6ce8\u3002","title":"\u524d\u8a00"},{"location":"#aaamlp-cn","text":"","title":"AAAMLP-CN"},{"location":"#-20230907","text":"\u26a1 \u4fee\u6b63\u90e8\u5206\u5df2\u77e5\u6587\u5b57\u9519\u8bef\u548c\u4ee3\u7801\u9519\u8bef \ud83e\udd17 \u6dfb\u52a0 \u5728\u7ebf\u6587\u6863","title":"\u65b0\u7279\u6027 - 2023.09.07"},{"location":"#_1","text":"2023.09.12 \u6dfb\u52a0\u7ae0\u8282\uff1a \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5 \u3001 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5","title":"\u7ffb\u8bd1\u8fdb\u7a0b"},{"location":"#_2","text":"Abhishek Thakur\uff0c\u5f88\u591a kaggler \u5bf9\u4ed6\u90fd\u975e\u5e38\u719f\u6089\uff0c2017 \u5e74\uff0c\u4ed6\u5728 Linkedin \u53d1\u8868\u4e86\u4e00\u7bc7\u540d\u4e3a Approaching (Almost) Any Machine Learning Problem \u7684\u6587\u7ae0\uff0c\u4ecb\u7ecd\u4ed6\u5efa\u7acb\u7684\u4e00\u4e2a\u81ea\u52a8\u7684\u673a\u5668\u5b66\u4e60\u6846\u67b6\uff0c\u51e0\u4e4e\u53ef\u4ee5\u89e3\u51b3\u4efb\u4f55\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u8fd9\u7bc7\u6587\u7ae0\u66fe\u706b\u904d Kaggle\u3002 Abhishek \u5728 Kaggle \u4e0a\u7684\u6210\u5c31\uff1a Competitions Grandmaster\uff0817 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 3\uff09 Kernels Expert \uff08Kagglers \u6392\u540d\u524d 1\uff05\uff09 Discussion Grandmaster\uff0865 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 2\uff09 
\u76ee\u524d\uff0cAbhishek \u5728\u632a\u5a01 boost \u516c\u53f8\u62c5\u4efb\u9996\u5e2d\u6570\u636e\u79d1\u5b66\u5bb6\u7684\u804c\u4f4d\uff0c\u8fd9\u662f\u4e00\u5bb6\u4e13\u95e8\u4ece\u4e8b\u4f1a\u8bdd\u4eba\u5de5\u667a\u80fd\u7684\u8f6f\u4ef6\u516c\u53f8\u3002 \u672c\u6587\u5bf9 Approaching (Almost) Any Machine Learning Problem \u8fdb\u884c\u4e86 \u4e2d\u6587\u7ffb\u8bd1 \uff0c\u7531\u4e8e\u672c\u4eba\u6c34\u5e73\u6709\u9650\uff0c\u4e14\u672a\u4f7f\u7528\u673a\u5668\u7ffb\u8bd1\uff0c\u53ef\u80fd\u6709\u90e8\u5206\u8a00\u8bed\u4e0d\u901a\u987a\u6216\u672c\u571f\u5316\u7a0b\u5ea6\u4e0d\u8db3\uff0c\u4e5f\u8bf7\u5927\u5bb6\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u591a\u63d0\u4f9b\u5b9d\u8d35\u610f\u89c1\u3002\u53e6\u9644\u4e0a\u4e66\u7c4d\u539f \u9879\u76ee\u5730\u5740 \uff0c \u8f6c\u8f7d\u8bf7\u4e00\u5b9a\u6807\u660e\u51fa\u5904\uff01 \u672c\u9879\u76ee \u652f\u6301\u5728\u7ebf\u9605\u8bfb \uff0c\u65b9\u4fbf\u60a8\u968f\u65f6\u968f\u5730\u8fdb\u884c\u67e5\u9605\u3002 \u56e0\u4e3a\u6709\u51e0\u7ae0\u5185\u5bb9\u592a\u8fc7\u57fa\u7840\uff0c\u6240\u4ee5\u672a\u8fdb\u884c\u7ffb\u8bd1\uff0c\u8be6\u7ec6\u60c5\u51b5\u8bf7\u53c2\u7167\u4e66\u7c4d\u76ee\u5f55\uff1a \u51c6\u5907\u73af\u5883\uff08\u672a\u7ffb\u8bd1\uff09 \u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60\uff08\u672a\u7ffb\u8bd1\uff09 \u4ea4\u53c9\u68c0\u9a8c\uff08\u5df2\u7ffb\u8bd1\uff09 \u8bc4\u4f30\u6307\u6807\uff08\u5df2\u7ffb\u8bd1\uff09 - \u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\uff08\u5df2\u7ffb\u8bd1\uff09 \u5904\u7406\u5206\u7c7b\u53d8\u91cf\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u5de5\u7a0b\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u9009\u62e9\uff08\u5df2\u7ffb\u8bd1\uff09 \u8d85\u53c2\u6570\u4f18\u5316\uff08\u5df2\u7ffb\u8bd1\uff09 \u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5\uff08\u672a\u7ffb\u8bd1\uff09 \u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5\uff08\u672a\u7ffb\u8bd1\uff09 \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 
\u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u6211\u5c06\u4f1a\u628a\u5b8c\u6574\u7684\u7ffb\u8bd1\u7248 Markdown \u6587\u4ef6\u4e0a\u4f20\u5230 GitHub\uff0c\u4ee5\u4f9b\u5927\u5bb6\u514d\u8d39\u4e0b\u8f7d\u548c\u9605\u8bfb\u3002\u4e3a\u4e86\u6700\u4f73\u7684\u9605\u8bfb\u4f53\u9a8c\uff0c\u63a8\u8350\u4f7f\u7528 PDF \u683c\u5f0f\u6216\u662f\u5728\u7ebf\u9605\u8bfb\u8fdb\u884c\u67e5\u770b \u82e5\u60a8\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u53d1\u73b0\u4efb\u4f55\u9519\u8bef\u6216\u4e0d\u51c6\u786e\u4e4b\u5904\uff0c\u975e\u5e38\u6b22\u8fce\u901a\u8fc7\u63d0\u4ea4 Issue \u6216 Pull Request \u6765\u534f\u52a9\u6211\u8fdb\u884c\u4fee\u6b63\u3002 \u968f\u7740\u65f6\u95f4\u63a8\u79fb\uff0c\u6211\u53ef\u80fd\u4f1a \u7ee7\u7eed\u7ffb\u8bd1\u5c1a\u672a\u5b8c\u6210\u7684\u7ae0\u8282 \u3002\u5982\u679c\u60a8\u89c9\u5f97\u8fd9\u4e2a\u9879\u76ee\u5bf9\u60a8\u6709\u5e2e\u52a9\uff0c\u8bf7\u4e0d\u541d\u7ed9\u4e88 Star \u6216\u8005\u8fdb\u884c\u5173\u6ce8\u3002","title":"\u7b80\u4ecb"},{"location":"%E4%BA%A4%E5%8F%89%E6%A3%80%E9%AA%8C/","text":"\u4ea4\u53c9\u68c0\u9a8c \u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u6ca1\u6709\u5efa\u7acb\u4efb\u4f55\u6a21\u578b\u3002\u539f\u56e0\u5f88\u7b80\u5355\uff0c\u5728\u521b\u5efa\u4efb\u4f55\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u77e5\u9053\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4ee5\u53ca\u5982\u4f55\u6839\u636e\u6570\u636e\u96c6\u9009\u62e9\u6700\u4f73\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e\u96c6\u3002 \u90a3\u4e48\uff0c\u4ec0\u4e48\u662f \u4ea4\u53c9\u68c0\u9a8c \uff0c\u6211\u4eec\u4e3a\u4ec0\u4e48\u8981\u5173\u6ce8\u5b83\uff1f 
\u5173\u4e8e\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u53ef\u4ee5\u627e\u5230\u591a\u79cd\u5b9a\u4e49\u3002\u6211\u7684\u5b9a\u4e49\u53ea\u6709\u4e00\u53e5\u8bdd\uff1a\u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fc7\u7a0b\u4e2d\u7684\u4e00\u4e2a\u6b65\u9aa4\uff0c\u5b83\u53ef\u4ee5\u5e2e\u52a9\u6211\u4eec\u786e\u4fdd\u6a21\u578b\u51c6\u786e\u62df\u5408\u6570\u636e\uff0c\u540c\u65f6\u786e\u4fdd\u6211\u4eec\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u4f46\u8fd9\u53c8\u5f15\u51fa\u4e86\u53e6\u4e00\u4e2a\u8bcd\uff1a \u8fc7\u62df\u5408 \u3002 \u8981\u89e3\u91ca\u8fc7\u62df\u5408\uff0c\u6211\u8ba4\u4e3a\u6700\u597d\u5148\u770b\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6709\u4e00\u4e2a\u76f8\u5f53\u6709\u540d\u7684\u7ea2\u9152\u8d28\u91cf\u6570\u636e\u96c6\uff08 red wine quality dataset \uff09\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 11 \u4e2a\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u8fd9\u4e9b\u7279\u5f81\u51b3\u5b9a\u4e86\u7ea2\u9152\u7684\u8d28\u91cf\u3002 \u8fd9\u4e9b\u5c5e\u6027\u5305\u62ec\uff1a \u56fa\u5b9a\u9178\u5ea6\uff08fixed acidity\uff09 \u6325\u53d1\u6027\u9178\u5ea6\uff08volatile acidity\uff09 \u67e0\u6aac\u9178\uff08citric acid\uff09 \u6b8b\u7559\u7cd6\uff08residual sugar\uff09 \u6c2f\u5316\u7269\uff08chlorides\uff09 \u6e38\u79bb\u4e8c\u6c27\u5316\u786b\uff08free sulfur dioxide\uff09 \u4e8c\u6c27\u5316\u786b\u603b\u91cf\uff08total sulfur dioxide\uff09 \u5bc6\u5ea6\uff08density\uff09 PH \u503c\uff08pH\uff09 \u786b\u9178\u76d0\uff08sulphates\uff09 \u9152\u7cbe\uff08alcohol\uff09 \u6839\u636e\u8fd9\u4e9b\u4e0d\u540c\u7279\u5f81\uff0c\u6211\u4eec\u9700\u8981\u9884\u6d4b\u7ea2\u8461\u8404\u9152\u7684\u8d28\u91cf\uff0c\u8d28\u91cf\u503c\u4ecb\u4e8e 0 \u5230 10 \u4e4b\u95f4\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u6570\u636e\u662f\u600e\u6837\u7684\u3002 import pandas as pd df = pd . 
read_csv ( \"winequality-red.csv\" ) \u56fe 1:\u7ea2\u8461\u8404\u9152\u8d28\u91cf\u6570\u636e\u96c6\u7b80\u5355\u5c55\u793a \u6211\u4eec\u53ef\u4ee5\u5c06\u8fd9\u4e2a\u95ee\u9898\u89c6\u4e3a\u5206\u7c7b\u95ee\u9898\uff0c\u4e5f\u53ef\u4ee5\u89c6\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u9009\u62e9\u5206\u7c7b\u3002\u7136\u800c\uff0c\u8fd9\u4e2a\u6570\u636e\u96c6\u503c\u5305\u542b 6 \u79cd\u8d28\u91cf\u503c\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u6240\u6709\u8d28\u91cf\u503c\u6620\u5c04\u5230 0 \u5230 5 \u4e4b\u95f4\u3002 # \u4e00\u4e2a\u6620\u5c04\u5b57\u5178\uff0c\u7528\u4e8e\u5c06\u8d28\u91cf\u503c\u4ece 0 \u5230 5 \u8fdb\u884c\u6620\u5c04 quality_mapping = { 3 : 0 , 4 : 1 , 5 : 2 , 6 : 3 , 7 : 4 , 8 : 5 } # \u4f60\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 map \u51fd\u6570\u4ee5\u53ca\u4efb\u4f55\u5b57\u5178\uff0c # \u6765\u8f6c\u6362\u7ed9\u5b9a\u5217\u4e2d\u7684\u503c\u4e3a\u5b57\u5178\u4e2d\u7684\u503c df . loc [:, \"quality\" ] = df . quality . 
map ( quality_mapping ) \u5f53\u6211\u4eec\u770b\u5927\u8fd9\u4e9b\u6570\u636e\u5e76\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u65f6\uff0c\u6211\u4eec\u8111\u6d77\u4e2d\u4f1a\u6d6e\u73b0\u51fa\u5f88\u591a\u53ef\u4ee5\u5e94\u7528\u7684\u7b97\u6cd5\uff0c\u4e5f\u8bb8\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u4ece\u4e00\u5f00\u59cb\u5c31\u6df1\u5165\u7814\u7a76\u795e\u7ecf\u7f51\u7edc\uff0c\u90a3\u5c31\u6709\u70b9\u7275\u5f3a\u4e86\u3002\u6240\u4ee5\uff0c\u8ba9\u6211\u4eec\u4ece\u7b80\u5355\u7684\u3001\u6211\u4eec\u4e5f\u80fd\u53ef\u89c6\u5316\u7684\u4e1c\u897f\u5f00\u59cb\uff1a\u51b3\u7b56\u6811\u3002 \u5728\u5f00\u59cb\u4e86\u89e3\u4ec0\u4e48\u662f\u8fc7\u62df\u5408\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 1599 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u4fdd\u7559 1000 \u4e2a\u6837\u672c\u7528\u4e8e\u8bad\u7ec3\uff0c599 \u4e2a\u6837\u672c\u4f5c\u4e3a\u4e00\u4e2a\u5355\u72ec\u7684\u96c6\u5408\u3002 \u4ee5\u4e0b\u4ee3\u7801\u53ef\u4ee5\u8f7b\u677e\u5b8c\u6210\u5212\u5206\uff1a # \u4f7f\u7528 frac=1 \u7684 sample \u65b9\u6cd5\u6765\u6253\u4e71 dataframe # \u7531\u4e8e\u6253\u4e71\u540e\u7d22\u5f15\u4f1a\u6539\u53d8\uff0c\u6240\u4ee5\u6211\u4eec\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u9009\u53d6\u524d 1000 \u884c\u4f5c\u4e3a\u8bad\u7ec3\u6570\u636e df_train = df . head ( 1000 ) # \u9009\u53d6\u6700\u540e\u7684 599 \u884c\u4f5c\u4e3a\u6d4b\u8bd5/\u9a8c\u8bc1\u6570\u636e df_test = df . 
tail ( 599 ) \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u5728\u8bad\u7ec3\u96c6\u4e0a\u4f7f\u7528 scikit-learn \u8bad\u7ec3\u4e00\u4e2a\u51b3\u7b56\u6811\u6a21\u578b\u3002 # \u4ece scikit-learn \u5bfc\u5165\u9700\u8981\u7684\u6a21\u5757 from sklearn import tree from sklearn import metrics # \u521d\u59cb\u5316\u4e00\u4e2a\u51b3\u7b56\u6811\u5206\u7c7b\u5668\uff0c\u8bbe\u7f6e\u6700\u5927\u6df1\u5ea6\u4e3a 3 clf = tree . DecisionTreeClassifier ( max_depth = 3 ) # \u9009\u62e9\u4f60\u60f3\u8981\u8bad\u7ec3\u6a21\u578b\u7684\u5217 # \u8fd9\u4e9b\u5217\u4f5c\u4e3a\u6a21\u578b\u7684\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u4f7f\u7528\u4e4b\u524d\u6620\u5c04\u7684\u8d28\u91cf\u4ee5\u53ca\u63d0\u4f9b\u7684\u7279\u5f81\u6765\u8bad\u7ec3\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5c06\u51b3\u7b56\u6811\u5206\u7c7b\u5668\u7684\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u8bbe\u4e3a 3\u3002\u8be5\u6a21\u578b\u7684\u6240\u6709\u5176\u4ed6\u53c2\u6570\u5747\u4fdd\u6301\u9ed8\u8ba4\u503c\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u6d4b\u8bd5\u8be5\u6a21\u578b\u7684\u51c6\u786e\u6027\uff1a # \u5728\u8bad\u7ec3\u96c6\u4e0a\u751f\u6210\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) # \u5728\u6d4b\u8bd5\u96c6\u4e0a\u751f\u6210\u9884\u6d4b test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) # \u8ba1\u7b97\u6d4b\u8bd5\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 test_accuracy = metrics . accuracy_score ( df_test . 
quality , test_predictions ) \u8bad\u7ec3\u548c\u6d4b\u8bd5\u7684\u51c6\u786e\u7387\u5206\u522b\u4e3a 58.9%\u548c 54.25%\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5c06\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u589e\u52a0\u5230 7\uff0c\u5e76\u91cd\u590d\u4e0a\u8ff0\u8fc7\u7a0b\u3002\u8fd9\u6837\uff0c\u8bad\u7ec3\u51c6\u786e\u7387\u4e3a 76.6%\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4e3a 57.3%\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4f7f\u7528\u51c6\u786e\u7387\uff0c\u4e3b\u8981\u662f\u56e0\u4e3a\u5b83\u662f\u6700\u76f4\u63a5\u7684\u6307\u6807\u3002\u5bf9\u4e8e\u8fd9\u4e2a\u95ee\u9898\u6765\u8bf4\uff0c\u5b83\u53ef\u80fd\u4e0d\u662f\u6700\u597d\u7684\u6307\u6807\u3002\u6211\u4eec\u53ef\u4ee5\u6839\u636e\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u4e0d\u540c\u503c\u6765\u8ba1\u7b97\u8fd9\u4e9b\u51c6\u786e\u7387\uff0c\u5e76\u7ed8\u5236\u66f2\u7ebf\u56fe\u3002 # \u6ce8\u610f\uff1a\u8fd9\u6bb5\u4ee3\u7801\u5728 Jupyter \u7b14\u8bb0\u672c\u4e2d\u7f16\u5199 # \u5bfc\u5165 scikit-learn \u7684 tree \u548c metrics from sklearn import tree from sklearn import metrics # \u5bfc\u5165 matplotlib \u548c seaborn # \u7528\u4e8e\u7ed8\u56fe import matplotlib import matplotlib.pyplot as plt import seaborn as sns # \u8bbe\u7f6e\u5168\u5c40\u6807\u7b7e\u6587\u672c\u7684\u5927\u5c0f matplotlib . rc ( 'xtick' , labelsize = 20 ) matplotlib . rc ( 'ytick' , labelsize = 20 ) # \u786e\u4fdd\u56fe\u8868\u76f4\u63a5\u5728\u7b14\u8bb0\u672c\u5185\u663e\u793a % matplotlib inline # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6\u7684\u5217\u8868 # \u6211\u4eec\u4ece 50% \u7684\u51c6\u786e\u5ea6\u5f00\u59cb train_accuracies = [ 0.5 ] test_accuracies = [ 0.5 ] # \u904d\u5386\u51e0\u4e2a\u4e0d\u540c\u7684\u6811\u6df1\u5ea6\u503c for depth in range ( 1 , 25 ): # \u521d\u59cb\u5316\u6a21\u578b clf = tree . 
DecisionTreeClassifier ( max_depth = depth ) # \u9009\u62e9\u7528\u4e8e\u8bad\u7ec3\u7684\u5217/\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u5728\u7ed9\u5b9a\u7279\u5f81\u4e0a\u62df\u5408\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) # \u521b\u5efa\u8bad\u7ec3\u548c\u6d4b\u8bd5\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) test_accuracy = metrics . accuracy_score ( df_test . quality , test_predictions ) # \u6dfb\u52a0\u51c6\u786e\u5ea6\u5230\u5217\u8868 train_accuracies . append ( train_accuracy ) test_accuracies . append ( test_accuracy ) # \u4f7f\u7528 matplotlib \u548c seaborn \u521b\u5efa\u4e24\u4e2a\u56fe plt . figure ( figsize = ( 10 , 5 )) sns . set_style ( \"whitegrid\" ) plt . plot ( train_accuracies , label = \"train accuracy\" ) plt . plot ( test_accuracies , label = \"test accuracy\" ) plt . legend ( loc = \"upper left\" , prop = { 'size' : 15 }) plt . xticks ( range ( 0 , 26 , 5 )) plt . xlabel ( \"max_depth\" , size = 20 ) plt . ylabel ( \"accuracy\" , size = 20 ) plt . 
show () \u8fd9\u5c06\u751f\u6210\u5982\u56fe 2 \u6240\u793a\u7684\u66f2\u7ebf\u56fe\u3002 \u56fe 2\uff1a\u4e0d\u540c max_depth \u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u7387\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5f53\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u503c\u4e3a 14 \u65f6\uff0c\u6d4b\u8bd5\u6570\u636e\u7684\u5f97\u5206\u6700\u9ad8\u3002\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u589e\u52a0\u8fd9\u4e2a\u53c2\u6570\u7684\u503c\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4f1a\u4fdd\u6301\u4e0d\u53d8\u6216\u53d8\u5dee\uff0c\u4f46\u8bad\u7ec3\u51c6\u786e\u7387\u4f1a\u4e0d\u65ad\u63d0\u9ad8\u3002\u8fd9\u8bf4\u660e\uff0c\u968f\u7740\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u589e\u52a0\uff0c\u51b3\u7b56\u6811\u6a21\u578b\u5bf9\u8bad\u7ec3\u6570\u636e\u7684\u5b66\u4e60\u6548\u679c\u8d8a\u6765\u8d8a\u597d\uff0c\u4f46\u6d4b\u8bd5\u6570\u636e\u7684\u6027\u80fd\u5374\u4e1d\u6beb\u6ca1\u6709\u63d0\u9ad8\u3002 \u8fd9\u5c31\u662f\u6240\u8c13\u7684\u8fc7\u62df\u5408 \u3002 \u6a21\u578b\u5728\u8bad\u7ec3\u96c6\u4e0a\u5b8c\u5168\u62df\u5408\uff0c\u800c\u5728\u6d4b\u8bd5\u96c6\u4e0a\u5374\u8868\u73b0\u4e0d\u4f73\u3002\u8fd9\u610f\u5473\u7740\u6a21\u578b\u53ef\u4ee5\u5f88\u597d\u5730\u5b66\u4e60\u8bad\u7ec3\u6570\u636e\uff0c\u4f46\u65e0\u6cd5\u6cdb\u5316\u5230\u672a\u89c1\u8fc7\u7684\u6837\u672c\u4e0a\u3002\u5728\u4e0a\u9762\u7684\u6570\u636e\u96c6\u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u975e\u5e38\u9ad8\u7684\u6a21\u578b\uff0c\u5b83\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u4f1a\u6709\u51fa\u8272\u7684\u7ed3\u679c\uff0c\u4f46\u8fd9\u79cd\u6a21\u578b\u5e76\u4e0d\u5b9e\u7528\uff0c\u56e0\u4e3a\u5b83\u5728\u771f\u5b9e\u4e16\u754c\u7684\u6837\u672c\u6216\u5b9e\u65f6\u6570\u636e\u4e0a\u4e0d\u4f1a\u63d0\u4f9b\u7c7b\u4f3c\u7684\u7ed3\u679c\u3002 
\u6709\u4eba\u53ef\u80fd\u4f1a\u8bf4\uff0c\u8fd9\u79cd\u65b9\u6cd5\u5e76\u6ca1\u6709\u8fc7\u62df\u5408\uff0c\u56e0\u4e3a\u6d4b\u8bd5\u96c6\u7684\u51c6\u786e\u7387\u57fa\u672c\u4fdd\u6301\u4e0d\u53d8\u3002\u8fc7\u62df\u5408\u7684\u53e6\u4e00\u4e2a\u5b9a\u4e49\u662f\uff0c\u5f53\u6211\u4eec\u4e0d\u65ad\u63d0\u9ad8\u8bad\u7ec3\u635f\u5931\u65f6\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u5728\u589e\u52a0\u3002\u8fd9\u79cd\u60c5\u51b5\u5728\u795e\u7ecf\u7f51\u7edc\u4e2d\u975e\u5e38\u5e38\u89c1\u3002 \u6bcf\u5f53\u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u795e\u7ecf\u7f51\u7edc\u65f6\uff0c\u90fd\u5fc5\u987b\u5728\u8bad\u7ec3\u671f\u95f4\u76d1\u63a7\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u3002\u5982\u679c\u6211\u4eec\u6709\u4e00\u4e2a\u975e\u5e38\u5927\u7684\u7f51\u7edc\u6765\u5904\u7406\u4e00\u4e2a\u975e\u5e38\u5c0f\u7684\u6570\u636e\u96c6\uff08\u5373\u6837\u672c\u6570\u975e\u5e38\u5c11\uff09\uff0c\u6211\u4eec\u5c31\u4f1a\u89c2\u5bdf\u5230\uff0c\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u8bad\u7ec3\uff0c\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u90fd\u4f1a\u51cf\u5c11\u3002\u4f46\u662f\uff0c\u5728\u67d0\u4e2a\u65f6\u523b\uff0c\u6d4b\u8bd5\u635f\u5931\u4f1a\u8fbe\u5230\u6700\u5c0f\u503c\uff0c\u4e4b\u540e\uff0c\u5373\u4f7f\u8bad\u7ec3\u635f\u5931\u8fdb\u4e00\u6b65\u51cf\u5c11\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u4f1a\u5f00\u59cb\u589e\u52a0\u3002\u6211\u4eec\u5fc5\u987b\u5728\u9a8c\u8bc1\u635f\u5931\u8fbe\u5230\u6700\u5c0f\u503c\u65f6\u505c\u6b62\u8bad\u7ec3\u3002 \u8fd9\u662f\u5bf9\u8fc7\u62df\u5408\u6700\u5e38\u89c1\u7684\u89e3\u91ca \u3002 
\u5965\u5361\u59c6\u5243\u5200\u7528\u7b80\u5355\u7684\u8bdd\u8bf4\uff0c\u5c31\u662f\u4e0d\u8981\u8bd5\u56fe\u628a\u53ef\u4ee5\u7528\u7b80\u5355\u5f97\u591a\u7684\u65b9\u6cd5\u89e3\u51b3\u7684\u4e8b\u60c5\u590d\u6742\u5316\u3002\u6362\u53e5\u8bdd\u8bf4\uff0c\u6700\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6848\u5c31\u662f\u6700\u5177\u901a\u7528\u6027\u7684\u89e3\u51b3\u65b9\u6848\u3002\u4e00\u822c\u6765\u8bf4\uff0c\u53ea\u8981\u4f60\u7684\u6a21\u578b\u4e0d\u7b26\u5408\u5965\u5361\u59c6\u5243\u5200\u539f\u5219\uff0c\u5c31\u5f88\u53ef\u80fd\u662f\u8fc7\u62df\u5408\u3002 \u56fe 3\uff1a\u8fc7\u62df\u5408\u7684\u6700\u4e00\u822c\u5b9a\u4e49 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u56de\u5230\u4ea4\u53c9\u68c0\u9a8c\u3002 \u5728\u89e3\u91ca\u8fc7\u62df\u5408\u65f6\uff0c\u6211\u51b3\u5b9a\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u6211\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u53e6\u4e00\u90e8\u5206\u4e0a\u68c0\u67e5\u5176\u6027\u80fd\u3002\u8fd9\u4e5f\u662f\u4ea4\u53c9\u68c0\u9a8c\u7684\u4e00\u79cd\uff0c\u901a\u5e38\u88ab\u79f0\u4e3a \"\u6682\u7559\u96c6\"\uff08 hold-out set \uff09\u3002\u5f53\u6211\u4eec\u62e5\u6709\u5927\u91cf\u6570\u636e\uff0c\u800c\u6a21\u578b\u63a8\u7406\u662f\u4e00\u4e2a\u8017\u65f6\u7684\u8fc7\u7a0b\u65f6\uff0c\u6211\u4eec\u5c31\u4f1a\u4f7f\u7528\u8fd9\u79cd\uff08\u4ea4\u53c9\uff09\u9a8c\u8bc1\u3002 \u4ea4\u53c9\u68c0\u9a8c\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\uff0c\u5b83\u662f\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u6b65\u9aa4\u3002 \u9009\u62e9\u6b63\u786e\u7684\u4ea4\u53c9\u68c0\u9a8c 
\u53d6\u51b3\u4e8e\u6240\u5904\u7406\u7684\u6570\u636e\u96c6\uff0c\u5728\u4e00\u4e2a\u6570\u636e\u96c6\u4e0a\u9002\u7528\u7684\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u53ef\u80fd\u4e0d\u9002\u7528\u4e8e\u5176\u4ed6\u6570\u636e\u96c6\u3002\u4e0d\u8fc7\uff0c\u6709\u51e0\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u6700\u4e3a\u6d41\u884c\u548c\u5e7f\u6cdb\u4f7f\u7528\u3002 \u5176\u4e2d\u5305\u62ec\uff1a k \u6298\u4ea4\u53c9\u68c0\u9a8c \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c \u5206\u7ec4 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u4ea4\u53c9\u68c0\u9a8c\u662f\u5c06\u8bad\u7ec3\u6570\u636e\u5206\u5c42\u51e0\u4e2a\u90e8\u5206\uff0c\u6211\u4eec\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u5176\u4f59\u90e8\u5206\u4e0a\u8fdb\u884c\u6d4b\u8bd5\u3002\u8bf7\u770b\u56fe 4\u3002 \u56fe 4\uff1a\u5c06\u6570\u636e\u96c6\u62c6\u5206\u4e3a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u56fe 4 \u548c\u56fe 5 \u8bf4\u660e\uff0c\u5f53\u4f60\u5f97\u5230\u4e00\u4e2a\u6570\u636e\u96c6\u6765\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u65f6\uff0c\u4f60\u4f1a\u628a\u5b83\u4eec\u5206\u6210 \u4e24\u4e2a\u4e0d\u540c\u7684\u96c6\uff1a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u3002\u5f88\u591a\u4eba\u8fd8\u4f1a\u5c06\u5176\u5206\u6210\u7b2c\u4e09\u7ec4\uff0c\u79f0\u4e4b\u4e3a\u6d4b\u8bd5\u96c6\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u5c06\u53ea\u4f7f\u7528\u4e24\u4e2a\u96c6\u3002\u5982\u4f60\u6240\u89c1\uff0c\u6211\u4eec\u5c06\u6837\u672c\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\u8fdb\u884c\u4e86\u5212\u5206\u3002\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a k \u4e2a\u4e92\u4e0d\u5173\u8054\u7684\u4e0d\u540c\u96c6\u5408\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002 \u56fe 5\uff1aK \u6298\u4ea4\u53c9\u68c0\u9a8c \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 KFold 
\u5c06\u4efb\u4f55\u6570\u636e\u5206\u5272\u6210 k \u4e2a\u76f8\u7b49\u7684\u90e8\u5206\u3002\u6bcf\u4e2a\u6837\u672c\u5206\u914d\u4e00\u4e2a\u4ece 0 \u5230 k-1 \u7684\u503c\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u5b58\u50a8\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6211\u4eec\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a kfold \u7684\u65b0\u5217\uff0c\u5e76\u7528 -1 \u586b\u5145 df [ \"kfold\" ] = - 1 # \u63a5\u4e0b\u6765\u7684\u6b65\u9aa4\u662f\u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4ece model_selection \u6a21\u5757\u521d\u59cb\u5316 kfold \u7c7b kf = model_selection . KFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684 kfold \u5217\uff08enumerate\u7684\u4f5c\u7528\u662f\u8fd4\u56de\u4e00\u4e2a\u8fed\u4ee3\u5668\uff09 for fold , ( trn_ , val_ ) in enumerate ( kf . split ( X = df )): df . loc [ val_ , 'kfold' ] = fold # \u4fdd\u5b58\u5e26\u6709 kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . 
to_csv ( \"train_folds.csv\" , index = False ) \u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u636e\u96c6\u90fd\u53ef\u4ee5\u4f7f\u7528\u6b64\u6d41\u7a0b\u3002\u4f8b\u5982\uff0c\u5f53\u6570\u636e\u56fe\u50cf\u65f6\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u56fe\u50cf ID\u3001\u56fe\u50cf\u4f4d\u7f6e\u548c\u56fe\u50cf\u6807\u7b7e\u7684 CSV\uff0c\u7136\u540e\u4f7f\u7528\u4e0a\u8ff0\u6d41\u7a0b\u3002 \u53e6\u4e00\u79cd\u91cd\u8981\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u662f \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u6570\u636e\u96c6\uff0c\u5176\u4e2d\u6b63\u6837\u672c\u5360 90%\uff0c\u8d1f\u6837\u672c\u53ea\u5360 10%\uff0c\u90a3\u4e48\u4f60\u5c31\u4e0d\u5e94\u8be5\u4f7f\u7528\u968f\u673a k \u6298\u4ea4\u53c9\u3002\u5bf9\u8fd9\u6837\u7684\u6570\u636e\u96c6\u4f7f\u7528\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u5bfc\u81f4\u6298\u53e0\u6837\u672c\u5168\u90e8\u4e3a\u8d1f\u6837\u672c\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u66f4\u503e\u5411\u4e8e\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u4ee5\u4fdd\u6301\u6bcf\u4e2a\u6298\u4e2d\u6807\u7b7e\u7684\u6bd4\u4f8b\u4e0d\u53d8\u3002\u56e0\u6b64\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u90fd\u4f1a\u6709\u76f8\u540c\u7684 90% \u6b63\u6837\u672c\u548c 10% \u8d1f\u6837\u672c\u3002\u56e0\u6b64\uff0c\u65e0\u8bba\u60a8\u9009\u62e9\u4ec0\u4e48\u6307\u6807\u8fdb\u884c\u8bc4\u4f30\uff0c\u90fd\u4f1a\u5728\u6240\u6709\u6298\u53e0\u4e2d\u5f97\u5230\u76f8\u4f3c\u7684\u7ed3\u679c\u3002 \u4fee\u6539\u521b\u5efa k \u6298\u4ea4\u53c9\u68c0\u9a8c\u7684\u4ee3\u7801\u4ee5\u521b\u5efa\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5f88\u5bb9\u6613\u3002\u6211\u4eec\u53ea\u9700\u5c06 model_selection.KFold \u66f4\u6539\u4e3a model_selection.StratifiedKFold \uff0c\u5e76\u5728 kf.split(...) 
\u51fd\u6570\u4e2d\u6307\u5b9a\u8981\u5206\u5c42\u7684\u76ee\u6807\u5217\u3002\u6211\u4eec\u5047\u8bbe CSV \u6570\u636e\u96c6\u6709\u4e00\u5217\u540d\u4e3a \"target\" \uff0c\u5e76\u4e14\u662f\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u4fdd\u5b58\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6dfb\u52a0\u4e00\u4e2a\u65b0\u5217 kfold\uff0c\u5e76\u7528 -1 \u521d\u59cb\u5316 df [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u83b7\u53d6\u76ee\u6807\u53d8\u91cf y = df . target . values # \u521d\u59cb\u5316 StratifiedKFold \u7c7b\uff0c\u8bbe\u7f6e\u6298\u6570\uff08folds\uff09\u4e3a 5 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u4f7f\u7528 StratifiedKFold \u5bf9\u8c61\u7684 split \u65b9\u6cd5\u6765\u83b7\u53d6\u8bad\u7ec3\u548c\u9a8c\u8bc1\u7d22\u5f15 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u5305\u542b kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . to_csv ( \"train_folds.csv\" , index = False ) \u5bf9\u4e8e\u8461\u8404\u9152\u6570\u636e\u96c6\uff0c\u6211\u4eec\u6765\u770b\u770b\u6807\u7b7e\u7684\u5206\u5e03\u60c5\u51b5\u3002 b = sns . countplot ( x = 'quality' , data = df ) b . set_xlabel ( \"quality\" , fontsize = 20 ) b . 
set_ylabel ( \"count\" , fontsize = 20 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u7ee7\u7eed\u4e0a\u9762\u7684\u4ee3\u7801\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u8f6c\u6362\u4e86\u76ee\u6807\u503c\u3002\u4ece\u56fe 6 \u4e2d\u6211\u4eec\u53ef\u4ee5\u770b\u51fa\uff0c\u8d28\u91cf\u504f\u5dee\u5f88\u5927\u3002\u6709\u4e9b\u7c7b\u522b\u6709\u5f88\u591a\u6837\u672c\uff0c\u6709\u4e9b\u5219\u6ca1\u6709\u90a3\u4e48\u591a\u3002\u5982\u679c\u6211\u4eec\u8fdb\u884c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u90a3\u4e48\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u76ee\u6807\u503c\u5206\u5e03\u90fd\u4e0d\u4f1a\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u56fe 6\uff1a\u8461\u8404\u9152\u6570\u636e\u96c6\u4e2d \"\u8d28\u91cf\" \u5206\u5e03\u60c5\u51b5 \u89c4\u5219\u5f88\u7b80\u5355\uff0c\u5982\u679c\u662f\u6807\u51c6\u5206\u7c7b\u95ee\u9898\uff0c\u5c31\u76f2\u76ee\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u4f46\u5982\u679c\u6570\u636e\u91cf\u5f88\u5927\uff0c\u8be5\u600e\u4e48\u529e\u5462\uff1f\u5047\u8bbe\u6211\u4eec\u6709 100 \u4e07\u4e2a\u6837\u672c\u30025 \u500d\u4ea4\u53c9\u68c0\u9a8c\u610f\u5473\u7740\u5728 800k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u5728 200k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u9a8c\u8bc1\u3002\u6839\u636e\u6211\u4eec\u9009\u62e9\u7684\u7b97\u6cd5\uff0c\u5bf9\u4e8e\u8fd9\u6837\u89c4\u6a21\u7684\u6570\u636e\u96c6\u6765\u8bf4\uff0c\u8bad\u7ec3\u751a\u81f3\u9a8c\u8bc1\u90fd\u53ef\u80fd\u975e\u5e38\u6602\u8d35\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u3002 \u521b\u5efa\u4fdd\u6301\u7ed3\u679c\u7684\u8fc7\u7a0b\u4e0e\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u76f8\u540c\u3002\u5bf9\u4e8e\u62e5\u6709 100 \u4e07\u4e2a\u6837\u672c\u7684\u6570\u636e\u96c6\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa 10 
\u4e2a\u6298\u53e0\u800c\u4e0d\u662f 5 \u4e2a\uff0c\u5e76\u4fdd\u7559\u5176\u4e2d\u4e00\u4e2a\u6298\u53e0\u4f5c\u4e3a\u4fdd\u7559\u6837\u672c\u3002\u8fd9\u610f\u5473\u7740\uff0c\u6211\u4eec\u5c06\u6709 10 \u4e07\u4e2a\u6837\u672c\u88ab\u4fdd\u7559\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u59cb\u7ec8\u5728\u8fd9\u4e2a\u6837\u672c\u96c6\u4e0a\u8ba1\u7b97\u635f\u5931\u3001\u51c6\u786e\u7387\u548c\u5176\u4ed6\u6307\u6807\uff0c\u5e76\u5728 90 \u4e07\u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\u3002 \u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u6682\u7559\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u975e\u5e38\u5e38\u7528\u3002\u5047\u8bbe\u6211\u4eec\u8981\u89e3\u51b3\u7684\u95ee\u9898\u662f\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97 2020 \u5e74\u7684\u9500\u552e\u989d\uff0c\u800c\u6211\u4eec\u5f97\u5230\u7684\u662f 2015-2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f60\u53ef\u4ee5\u9009\u62e9 2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u4f5c\u4e3a\u4fdd\u7559\u6570\u636e\uff0c\u7136\u540e\u5728 2015 \u5e74\u81f3 2018 \u5e74\u7684\u6240\u6709\u6570\u636e\u4e0a\u8bad\u7ec3\u4f60\u7684\u6a21\u578b\u3002 \u56fe 7\uff1a\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u793a\u4f8b \u5728\u56fe 7 \u6240\u793a\u7684\u793a\u4f8b\u4e2d\uff0c\u5047\u8bbe\u6211\u4eec\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u4ece\u65f6\u95f4\u6b65\u9aa4 31 \u5230 40 \u7684\u9500\u552e\u989d\u3002\u6211\u4eec\u53ef\u4ee5\u4fdd\u7559 21 \u81f3 30 \u6b65\u7684\u6570\u636e\uff0c\u7136\u540e\u4ece 0 \u6b65\u5230 20 \u6b65\u8bad\u7ec3\u6a21\u578b\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5728\u9884\u6d4b 31 \u6b65\u81f3 40 \u6b65\u65f6\uff0c\u5e94\u5c06 21 \u6b65\u81f3 30 \u6b65\u7684\u6570\u636e\u7eb3\u5165\u6a21\u578b\uff0c\u5426\u5219\uff0c\u6a21\u578b\u7684\u6027\u80fd\u5c06\u5927\u6253\u6298\u6263\u3002 
\u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u5fc5\u987b\u5904\u7406\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u800c\u521b\u5efa\u5927\u578b\u9a8c\u8bc1\u96c6\u610f\u5473\u7740\u6a21\u578b\u5b66\u4e60\u4f1a\u4e22\u5931\u5927\u91cf\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9\u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c\uff0c\u76f8\u5f53\u4e8e\u7279\u6b8a\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u5176\u4e2d k=N \uff0cN \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8fd9\u610f\u5473\u7740\u5728\u6240\u6709\u7684\u8bad\u7ec3\u6298\u53e0\u4e2d\uff0c\u6211\u4eec\u5c06\u5bf9\u9664 1 \u4e4b\u5916\u7684\u6240\u6709\u6570\u636e\u6837\u672c\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u7684\u6298\u53e0\u6570\u4e0e\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u76f8\u540c\u3002 \u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u6a21\u578b\u7684\u901f\u5ea6\u4e0d\u591f\u5feb\uff0c\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\uff0c\u4f46\u7531\u4e8e\u8fd9\u79cd\u4ea4\u53c9\u68c0\u9a8c\u53ea\u9002\u7528\u4e8e\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u56e0\u6b64\u5e76\u4e0d\u91cd\u8981\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u95ee\u9898\u4e86\u3002\u56de\u5f52\u95ee\u9898\u7684\u597d\u5904\u5728\u4e8e\uff0c\u9664\u4e86\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e4b\u5916\uff0c\u6211\u4eec\u53ef\u4ee5\u5728\u56de\u5f52\u95ee\u9898\u4e0a\u4f7f\u7528\u4e0a\u8ff0\u6240\u6709\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u6211\u4eec\u4e0d\u80fd\u76f4\u63a5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u6709\u4e00\u4e9b\u65b9\u6cd5\u53ef\u4ee5\u7a0d\u7a0d\u6539\u53d8\u95ee\u9898\uff0c\u4ece\u800c\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k 
\u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u9002\u7528\u4e8e\u4efb\u4f55\u56de\u5f52\u95ee\u9898\u3002\u4f46\u662f\uff0c\u5982\u679c\u53d1\u73b0\u76ee\u6807\u5206\u5e03\u4e0d\u4e00\u81f4\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u8981\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u5fc5\u987b\u5148\u5c06\u76ee\u6807\u5212\u5206\u4e3a\u82e5\u5e72\u4e2a\u5206\u5c42\uff0c\u7136\u540e\u518d\u4ee5\u5904\u7406\u5206\u7c7b\u95ee\u9898\u7684\u76f8\u540c\u65b9\u5f0f\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u9009\u62e9\u5408\u9002\u7684\u5206\u5c42\u6570\u6709\u51e0\u79cd\u9009\u62e9\u3002\u5982\u679c\u6837\u672c\u91cf\u5f88\u5927\uff08> 10k\uff0c> 100k\uff09\uff0c\u90a3\u4e48\u5c31\u4e0d\u9700\u8981\u8003\u8651\u5206\u5c42\u7684\u6570\u91cf\u3002\u53ea\u9700\u5c06\u6570\u636e\u5206\u4e3a 10 \u6216 20 \u5c42\u5373\u53ef\u3002\u5982\u679c\u6837\u672c\u6570\u4e0d\u591a\uff0c\u5219\u53ef\u4ee5\u4f7f\u7528 Sturge's Rule \u8fd9\u6837\u7684\u7b80\u5355\u89c4\u5219\u6765\u8ba1\u7b97\u9002\u5f53\u7684\u5206\u5c42\u6570\u3002 Sturge's Rule\uff1a \\[ Number of Bins = 1 + log_2(N) \\] \u5176\u4e2d \\(N\\) \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8be5\u51fd\u6570\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u5229\u7528\u65af\u7279\u683c\u6cd5\u5219\u7ed8\u5236\u6837\u672c\u4e0e\u7bb1\u6570\u5bf9\u6bd4\u56fe \u8ba9\u6211\u4eec\u5236\u4f5c\u4e00\u4e2a\u56de\u5f52\u6570\u636e\u96c6\u6837\u672c\uff0c\u5e76\u5c1d\u8bd5\u5e94\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 # stratified-kfold for regression # \u4e3a\u56de\u5f52\u95ee\u9898\u8fdb\u884c\u5206\u5c42K-\u6298\u4ea4\u53c9\u9a8c\u8bc1 # \u5bfc\u5165\u9700\u8981\u7684\u5e93 import numpy as np import pandas as pd from 
sklearn import datasets from sklearn import model_selection # \u521b\u5efa\u5206\u6298\uff08folds\uff09\u7684\u51fd\u6570 def create_folds ( data ): # \u521b\u5efa\u4e00\u4e2a\u65b0\u5217\u53eb\u505akfold\uff0c\u5e76\u7528-1\u6765\u586b\u5145 data [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c data = data . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4f7f\u7528Sturge\u89c4\u5219\u8ba1\u7b97bin\u7684\u6570\u91cf num_bins = int ( np . floor ( 1 + np . log2 ( len ( data )))) # \u4f7f\u7528pandas\u7684cut\u51fd\u6570\u8fdb\u884c\u76ee\u6807\u53d8\u91cf\uff08target\uff09\u7684\u5206\u7bb1 data . loc [:, \"bins\" ] = pd . cut ( data [ \"target\" ], bins = num_bins , labels = False ) # \u521d\u59cb\u5316StratifiedKFold\u7c7b kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684kfold\u5217 # \u6ce8\u610f\uff1a\u6211\u4eec\u4f7f\u7528\u7684\u662fbins\u800c\u4e0d\u662f\u5b9e\u9645\u7684\u76ee\u6807\u53d8\u91cf\uff08target\uff09\uff01 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = data , y = data . bins . values )): data . loc [ v_ , 'kfold' ] = f # \u5220\u9664bins\u5217 data = data . drop ( \"bins\" , axis = 1 ) # \u8fd4\u56de\u5305\u542bfolds\u7684\u6570\u636e return data # \u4e3b\u7a0b\u5e8f\u5f00\u59cb if __name__ == \"__main__\" : # \u521b\u5efa\u4e00\u4e2a\u5e26\u670915000\u4e2a\u6837\u672c\u3001100\u4e2a\u7279\u5f81\u548c1\u4e2a\u76ee\u6807\u53d8\u91cf\u7684\u6837\u672c\u6570\u636e\u96c6 X , y = datasets . make_regression ( n_samples = 15000 , n_features = 100 , n_targets = 1 ) # \u4f7f\u7528numpy\u6570\u7ec4\u521b\u5efa\u4e00\u4e2a\u6570\u636e\u6846 df = pd . DataFrame ( X , columns = [ f \"f_ { i } \" for i in range ( X . shape [ 1 ])] ) df . 
loc [:, \"target\" ] = y # \u521b\u5efafolds df = create_folds ( df ) \u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u7b2c\u4e00\u6b65\uff0c\u4e5f\u662f\u6700\u57fa\u672c\u7684\u4e00\u6b65\u3002\u5982\u679c\u8981\u505a\u7279\u5f81\u5de5\u7a0b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u8981\u5efa\u7acb\u6a21\u578b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u65b9\u6848\uff0c\u5176\u4e2d\u9a8c\u8bc1\u6570\u636e\u80fd\u591f\u4ee3\u8868\u8bad\u7ec3\u6570\u636e\u548c\u771f\u5b9e\u4e16\u754c\u7684\u6570\u636e\uff0c\u90a3\u4e48\u4f60\u5c31\u80fd\u5efa\u7acb\u4e00\u4e2a\u5177\u6709\u9ad8\u5ea6\u901a\u7528\u6027\u7684\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u3002 \u672c\u7ae0\u4ecb\u7ecd\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u6570\u636e\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u6839\u636e\u4f60\u7684\u95ee\u9898\u548c\u6570\u636e\u91c7\u7528\u65b0\u7684\u4ea4\u53c9\u68c0\u9a8c\u5f62\u5f0f\u3002 \u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u95ee\u9898\uff0c\u5e0c\u671b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u4ece\u60a3\u8005\u7684\u76ae\u80a4\u56fe\u50cf\u4e2d\u68c0\u6d4b\u51fa\u76ae\u80a4\u764c\u3002\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u5668\uff0c\u8be5\u5206\u7c7b\u5668\u63a5\u6536\u8f93\u5165\u56fe\u50cf\u5e76\u9884\u6d4b\u5176\u826f\u6027\u6216\u6076\u6027\u7684\u6982\u7387\u3002 
\u5728\u8fd9\u7c7b\u6570\u636e\u96c6\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u96c6\u4e2d\u53ef\u80fd\u6709\u540c\u4e00\u60a3\u8005\u7684\u591a\u5f20\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u8981\u5728\u8fd9\u91cc\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u7cfb\u7edf\uff0c\u5fc5\u987b\u6709\u5206\u5c42\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u4e5f\u5fc5\u987b\u786e\u4fdd\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u60a3\u8005\u4e0d\u4f1a\u51fa\u73b0\u5728\u9a8c\u8bc1\u6570\u636e\u4e2d\u3002\u5e78\u8fd0\u7684\u662f\uff0cscikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u79f0\u4e3a GroupKFold \u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u3002 \u5728\u8fd9\u91cc\uff0c\u60a3\u8005\u53ef\u4ee5\u88ab\u89c6\u4e3a\u7ec4\u3002 \u4f46\u9057\u61be\u7684\u662f\uff0cscikit-learn \u65e0\u6cd5\u5c06 GroupKFold \u4e0e StratifiedKFold \u7ed3\u5408\u8d77\u6765\u3002\u6240\u4ee5\u4f60\u9700\u8981\u81ea\u5df1\u52a8\u624b\u3002\u6211\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u8bfb\u8005\u7684\u7ec3\u4e60\u3002","title":"\u4ea4\u53c9\u68c0\u9a8c"},{"location":"%E4%BA%A4%E5%8F%89%E6%A3%80%E9%AA%8C/#_1","text":"\u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u6ca1\u6709\u5efa\u7acb\u4efb\u4f55\u6a21\u578b\u3002\u539f\u56e0\u5f88\u7b80\u5355\uff0c\u5728\u521b\u5efa\u4efb\u4f55\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u77e5\u9053\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4ee5\u53ca\u5982\u4f55\u6839\u636e\u6570\u636e\u96c6\u9009\u62e9\u6700\u4f73\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e\u96c6\u3002 \u90a3\u4e48\uff0c\u4ec0\u4e48\u662f \u4ea4\u53c9\u68c0\u9a8c \uff0c\u6211\u4eec\u4e3a\u4ec0\u4e48\u8981\u5173\u6ce8\u5b83\uff1f 
\u5173\u4e8e\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u53ef\u4ee5\u627e\u5230\u591a\u79cd\u5b9a\u4e49\u3002\u6211\u7684\u5b9a\u4e49\u53ea\u6709\u4e00\u53e5\u8bdd\uff1a\u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fc7\u7a0b\u4e2d\u7684\u4e00\u4e2a\u6b65\u9aa4\uff0c\u5b83\u53ef\u4ee5\u5e2e\u52a9\u6211\u4eec\u786e\u4fdd\u6a21\u578b\u51c6\u786e\u62df\u5408\u6570\u636e\uff0c\u540c\u65f6\u786e\u4fdd\u6211\u4eec\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u4f46\u8fd9\u53c8\u5f15\u51fa\u4e86\u53e6\u4e00\u4e2a\u8bcd\uff1a \u8fc7\u62df\u5408 \u3002 \u8981\u89e3\u91ca\u8fc7\u62df\u5408\uff0c\u6211\u8ba4\u4e3a\u6700\u597d\u5148\u770b\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6709\u4e00\u4e2a\u76f8\u5f53\u6709\u540d\u7684\u7ea2\u9152\u8d28\u91cf\u6570\u636e\u96c6\uff08 red wine quality dataset \uff09\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 11 \u4e2a\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u8fd9\u4e9b\u7279\u5f81\u51b3\u5b9a\u4e86\u7ea2\u9152\u7684\u8d28\u91cf\u3002 \u8fd9\u4e9b\u5c5e\u6027\u5305\u62ec\uff1a \u56fa\u5b9a\u9178\u5ea6\uff08fixed acidity\uff09 \u6325\u53d1\u6027\u9178\u5ea6\uff08volatile acidity\uff09 \u67e0\u6aac\u9178\uff08citric acid\uff09 \u6b8b\u7559\u7cd6\uff08residual sugar\uff09 \u6c2f\u5316\u7269\uff08chlorides\uff09 \u6e38\u79bb\u4e8c\u6c27\u5316\u786b\uff08free sulfur dioxide\uff09 \u4e8c\u6c27\u5316\u786b\u603b\u91cf\uff08total sulfur dioxide\uff09 \u5bc6\u5ea6\uff08density\uff09 PH \u503c\uff08pH\uff09 \u786b\u9178\u76d0\uff08sulphates\uff09 \u9152\u7cbe\uff08alcohol\uff09 \u6839\u636e\u8fd9\u4e9b\u4e0d\u540c\u7279\u5f81\uff0c\u6211\u4eec\u9700\u8981\u9884\u6d4b\u7ea2\u8461\u8404\u9152\u7684\u8d28\u91cf\uff0c\u8d28\u91cf\u503c\u4ecb\u4e8e 0 \u5230 10 \u4e4b\u95f4\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u6570\u636e\u662f\u600e\u6837\u7684\u3002 import pandas as pd df = pd . 
read_csv ( \"winequality-red.csv\" ) \u56fe 1:\u7ea2\u8461\u8404\u9152\u8d28\u91cf\u6570\u636e\u96c6\u7b80\u5355\u5c55\u793a \u6211\u4eec\u53ef\u4ee5\u5c06\u8fd9\u4e2a\u95ee\u9898\u89c6\u4e3a\u5206\u7c7b\u95ee\u9898\uff0c\u4e5f\u53ef\u4ee5\u89c6\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u9009\u62e9\u5206\u7c7b\u3002\u7136\u800c\uff0c\u8fd9\u4e2a\u6570\u636e\u96c6\u53ea\u5305\u542b 6 \u79cd\u8d28\u91cf\u503c\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u6240\u6709\u8d28\u91cf\u503c\u6620\u5c04\u5230 0 \u5230 5 \u4e4b\u95f4\u3002 # \u4e00\u4e2a\u6620\u5c04\u5b57\u5178\uff0c\u7528\u4e8e\u5c06\u8d28\u91cf\u503c\u4ece 0 \u5230 5 \u8fdb\u884c\u6620\u5c04 quality_mapping = { 3 : 0 , 4 : 1 , 5 : 2 , 6 : 3 , 7 : 4 , 8 : 5 } # \u4f60\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 map \u51fd\u6570\u4ee5\u53ca\u4efb\u4f55\u5b57\u5178\uff0c # \u6765\u8f6c\u6362\u7ed9\u5b9a\u5217\u4e2d\u7684\u503c\u4e3a\u5b57\u5178\u4e2d\u7684\u503c df . loc [:, \"quality\" ] = df . quality . 
map ( quality_mapping ) \u5f53\u6211\u4eec\u770b\u5230\u8fd9\u4e9b\u6570\u636e\u5e76\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u65f6\uff0c\u6211\u4eec\u8111\u6d77\u4e2d\u4f1a\u6d6e\u73b0\u51fa\u5f88\u591a\u53ef\u4ee5\u5e94\u7528\u7684\u7b97\u6cd5\uff0c\u4e5f\u8bb8\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u4ece\u4e00\u5f00\u59cb\u5c31\u6df1\u5165\u7814\u7a76\u795e\u7ecf\u7f51\u7edc\uff0c\u90a3\u5c31\u6709\u70b9\u7275\u5f3a\u4e86\u3002\u6240\u4ee5\uff0c\u8ba9\u6211\u4eec\u4ece\u7b80\u5355\u7684\u3001\u6211\u4eec\u4e5f\u80fd\u53ef\u89c6\u5316\u7684\u4e1c\u897f\u5f00\u59cb\uff1a\u51b3\u7b56\u6811\u3002 \u5728\u5f00\u59cb\u4e86\u89e3\u4ec0\u4e48\u662f\u8fc7\u62df\u5408\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 1599 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u4fdd\u7559 1000 \u4e2a\u6837\u672c\u7528\u4e8e\u8bad\u7ec3\uff0c599 \u4e2a\u6837\u672c\u4f5c\u4e3a\u4e00\u4e2a\u5355\u72ec\u7684\u96c6\u5408\u3002 \u4ee5\u4e0b\u4ee3\u7801\u53ef\u4ee5\u8f7b\u677e\u5b8c\u6210\u5212\u5206\uff1a # \u4f7f\u7528 frac=1 \u7684 sample \u65b9\u6cd5\u6765\u6253\u4e71 dataframe # \u7531\u4e8e\u6253\u4e71\u540e\u7d22\u5f15\u4f1a\u6539\u53d8\uff0c\u6240\u4ee5\u6211\u4eec\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u9009\u53d6\u524d 1000 \u884c\u4f5c\u4e3a\u8bad\u7ec3\u6570\u636e df_train = df . head ( 1000 ) # \u9009\u53d6\u6700\u540e\u7684 599 \u884c\u4f5c\u4e3a\u6d4b\u8bd5/\u9a8c\u8bc1\u6570\u636e df_test = df . 
tail ( 599 ) \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u5728\u8bad\u7ec3\u96c6\u4e0a\u4f7f\u7528 scikit-learn \u8bad\u7ec3\u4e00\u4e2a\u51b3\u7b56\u6811\u6a21\u578b\u3002 # \u4ece scikit-learn \u5bfc\u5165\u9700\u8981\u7684\u6a21\u5757 from sklearn import tree from sklearn import metrics # \u521d\u59cb\u5316\u4e00\u4e2a\u51b3\u7b56\u6811\u5206\u7c7b\u5668\uff0c\u8bbe\u7f6e\u6700\u5927\u6df1\u5ea6\u4e3a 3 clf = tree . DecisionTreeClassifier ( max_depth = 3 ) # \u9009\u62e9\u4f60\u60f3\u8981\u8bad\u7ec3\u6a21\u578b\u7684\u5217 # \u8fd9\u4e9b\u5217\u4f5c\u4e3a\u6a21\u578b\u7684\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u4f7f\u7528\u4e4b\u524d\u6620\u5c04\u7684\u8d28\u91cf\u4ee5\u53ca\u63d0\u4f9b\u7684\u7279\u5f81\u6765\u8bad\u7ec3\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5c06\u51b3\u7b56\u6811\u5206\u7c7b\u5668\u7684\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u8bbe\u4e3a 3\u3002\u8be5\u6a21\u578b\u7684\u6240\u6709\u5176\u4ed6\u53c2\u6570\u5747\u4fdd\u6301\u9ed8\u8ba4\u503c\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u6d4b\u8bd5\u8be5\u6a21\u578b\u7684\u51c6\u786e\u6027\uff1a # \u5728\u8bad\u7ec3\u96c6\u4e0a\u751f\u6210\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) # \u5728\u6d4b\u8bd5\u96c6\u4e0a\u751f\u6210\u9884\u6d4b test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) # \u8ba1\u7b97\u6d4b\u8bd5\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 test_accuracy = metrics . accuracy_score ( df_test . 
quality , test_predictions ) \u8bad\u7ec3\u548c\u6d4b\u8bd5\u7684\u51c6\u786e\u7387\u5206\u522b\u4e3a 58.9%\u548c 54.25%\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5c06\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u589e\u52a0\u5230 7\uff0c\u5e76\u91cd\u590d\u4e0a\u8ff0\u8fc7\u7a0b\u3002\u8fd9\u6837\uff0c\u8bad\u7ec3\u51c6\u786e\u7387\u4e3a 76.6%\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4e3a 57.3%\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4f7f\u7528\u51c6\u786e\u7387\uff0c\u4e3b\u8981\u662f\u56e0\u4e3a\u5b83\u662f\u6700\u76f4\u63a5\u7684\u6307\u6807\u3002\u5bf9\u4e8e\u8fd9\u4e2a\u95ee\u9898\u6765\u8bf4\uff0c\u5b83\u53ef\u80fd\u4e0d\u662f\u6700\u597d\u7684\u6307\u6807\u3002\u6211\u4eec\u53ef\u4ee5\u6839\u636e\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u4e0d\u540c\u503c\u6765\u8ba1\u7b97\u8fd9\u4e9b\u51c6\u786e\u7387\uff0c\u5e76\u7ed8\u5236\u66f2\u7ebf\u56fe\u3002 # \u6ce8\u610f\uff1a\u8fd9\u6bb5\u4ee3\u7801\u5728 Jupyter \u7b14\u8bb0\u672c\u4e2d\u7f16\u5199 # \u5bfc\u5165 scikit-learn \u7684 tree \u548c metrics from sklearn import tree from sklearn import metrics # \u5bfc\u5165 matplotlib \u548c seaborn # \u7528\u4e8e\u7ed8\u56fe import matplotlib import matplotlib.pyplot as plt import seaborn as sns # \u8bbe\u7f6e\u5168\u5c40\u6807\u7b7e\u6587\u672c\u7684\u5927\u5c0f matplotlib . rc ( 'xtick' , labelsize = 20 ) matplotlib . rc ( 'ytick' , labelsize = 20 ) # \u786e\u4fdd\u56fe\u8868\u76f4\u63a5\u5728\u7b14\u8bb0\u672c\u5185\u663e\u793a % matplotlib inline # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6\u7684\u5217\u8868 # \u6211\u4eec\u4ece 50% \u7684\u51c6\u786e\u5ea6\u5f00\u59cb train_accuracies = [ 0.5 ] test_accuracies = [ 0.5 ] # \u904d\u5386\u51e0\u4e2a\u4e0d\u540c\u7684\u6811\u6df1\u5ea6\u503c for depth in range ( 1 , 25 ): # \u521d\u59cb\u5316\u6a21\u578b clf = tree . 
DecisionTreeClassifier ( max_depth = depth ) # \u9009\u62e9\u7528\u4e8e\u8bad\u7ec3\u7684\u5217/\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u5728\u7ed9\u5b9a\u7279\u5f81\u4e0a\u62df\u5408\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) # \u521b\u5efa\u8bad\u7ec3\u548c\u6d4b\u8bd5\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) test_accuracy = metrics . accuracy_score ( df_test . quality , test_predictions ) # \u6dfb\u52a0\u51c6\u786e\u5ea6\u5230\u5217\u8868 train_accuracies . append ( train_accuracy ) test_accuracies . append ( test_accuracy ) # \u4f7f\u7528 matplotlib \u548c seaborn \u521b\u5efa\u4e24\u4e2a\u56fe plt . figure ( figsize = ( 10 , 5 )) sns . set_style ( \"whitegrid\" ) plt . plot ( train_accuracies , label = \"train accuracy\" ) plt . plot ( test_accuracies , label = \"test accuracy\" ) plt . legend ( loc = \"upper left\" , prop = { 'size' : 15 }) plt . xticks ( range ( 0 , 26 , 5 )) plt . xlabel ( \"max_depth\" , size = 20 ) plt . ylabel ( \"accuracy\" , size = 20 ) plt . 
show () \u8fd9\u5c06\u751f\u6210\u5982\u56fe 2 \u6240\u793a\u7684\u66f2\u7ebf\u56fe\u3002 \u56fe 2\uff1a\u4e0d\u540c max_depth \u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u7387\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5f53\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u503c\u4e3a 14 \u65f6\uff0c\u6d4b\u8bd5\u6570\u636e\u7684\u5f97\u5206\u6700\u9ad8\u3002\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u589e\u52a0\u8fd9\u4e2a\u53c2\u6570\u7684\u503c\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4f1a\u4fdd\u6301\u4e0d\u53d8\u6216\u53d8\u5dee\uff0c\u4f46\u8bad\u7ec3\u51c6\u786e\u7387\u4f1a\u4e0d\u65ad\u63d0\u9ad8\u3002\u8fd9\u8bf4\u660e\uff0c\u968f\u7740\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u589e\u52a0\uff0c\u51b3\u7b56\u6811\u6a21\u578b\u5bf9\u8bad\u7ec3\u6570\u636e\u7684\u5b66\u4e60\u6548\u679c\u8d8a\u6765\u8d8a\u597d\uff0c\u4f46\u6d4b\u8bd5\u6570\u636e\u7684\u6027\u80fd\u5374\u4e1d\u6beb\u6ca1\u6709\u63d0\u9ad8\u3002 \u8fd9\u5c31\u662f\u6240\u8c13\u7684\u8fc7\u62df\u5408 \u3002 \u6a21\u578b\u5728\u8bad\u7ec3\u96c6\u4e0a\u5b8c\u5168\u62df\u5408\uff0c\u800c\u5728\u6d4b\u8bd5\u96c6\u4e0a\u5374\u8868\u73b0\u4e0d\u4f73\u3002\u8fd9\u610f\u5473\u7740\u6a21\u578b\u53ef\u4ee5\u5f88\u597d\u5730\u5b66\u4e60\u8bad\u7ec3\u6570\u636e\uff0c\u4f46\u65e0\u6cd5\u6cdb\u5316\u5230\u672a\u89c1\u8fc7\u7684\u6837\u672c\u4e0a\u3002\u5728\u4e0a\u9762\u7684\u6570\u636e\u96c6\u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u975e\u5e38\u9ad8\u7684\u6a21\u578b\uff0c\u5b83\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u4f1a\u6709\u51fa\u8272\u7684\u7ed3\u679c\uff0c\u4f46\u8fd9\u79cd\u6a21\u578b\u5e76\u4e0d\u5b9e\u7528\uff0c\u56e0\u4e3a\u5b83\u5728\u771f\u5b9e\u4e16\u754c\u7684\u6837\u672c\u6216\u5b9e\u65f6\u6570\u636e\u4e0a\u4e0d\u4f1a\u63d0\u4f9b\u7c7b\u4f3c\u7684\u7ed3\u679c\u3002 
\u6709\u4eba\u53ef\u80fd\u4f1a\u8bf4\uff0c\u8fd9\u79cd\u65b9\u6cd5\u5e76\u6ca1\u6709\u8fc7\u62df\u5408\uff0c\u56e0\u4e3a\u6d4b\u8bd5\u96c6\u7684\u51c6\u786e\u7387\u57fa\u672c\u4fdd\u6301\u4e0d\u53d8\u3002\u8fc7\u62df\u5408\u7684\u53e6\u4e00\u4e2a\u5b9a\u4e49\u662f\uff0c\u5f53\u6211\u4eec\u4e0d\u65ad\u63d0\u9ad8\u8bad\u7ec3\u635f\u5931\u65f6\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u5728\u589e\u52a0\u3002\u8fd9\u79cd\u60c5\u51b5\u5728\u795e\u7ecf\u7f51\u7edc\u4e2d\u975e\u5e38\u5e38\u89c1\u3002 \u6bcf\u5f53\u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u795e\u7ecf\u7f51\u7edc\u65f6\uff0c\u90fd\u5fc5\u987b\u5728\u8bad\u7ec3\u671f\u95f4\u76d1\u63a7\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u3002\u5982\u679c\u6211\u4eec\u6709\u4e00\u4e2a\u975e\u5e38\u5927\u7684\u7f51\u7edc\u6765\u5904\u7406\u4e00\u4e2a\u975e\u5e38\u5c0f\u7684\u6570\u636e\u96c6\uff08\u5373\u6837\u672c\u6570\u975e\u5e38\u5c11\uff09\uff0c\u6211\u4eec\u5c31\u4f1a\u89c2\u5bdf\u5230\uff0c\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u8bad\u7ec3\uff0c\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u90fd\u4f1a\u51cf\u5c11\u3002\u4f46\u662f\uff0c\u5728\u67d0\u4e2a\u65f6\u523b\uff0c\u6d4b\u8bd5\u635f\u5931\u4f1a\u8fbe\u5230\u6700\u5c0f\u503c\uff0c\u4e4b\u540e\uff0c\u5373\u4f7f\u8bad\u7ec3\u635f\u5931\u8fdb\u4e00\u6b65\u51cf\u5c11\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u4f1a\u5f00\u59cb\u589e\u52a0\u3002\u6211\u4eec\u5fc5\u987b\u5728\u9a8c\u8bc1\u635f\u5931\u8fbe\u5230\u6700\u5c0f\u503c\u65f6\u505c\u6b62\u8bad\u7ec3\u3002 \u8fd9\u662f\u5bf9\u8fc7\u62df\u5408\u6700\u5e38\u89c1\u7684\u89e3\u91ca \u3002 
\u5965\u5361\u59c6\u5243\u5200\u7528\u7b80\u5355\u7684\u8bdd\u8bf4\uff0c\u5c31\u662f\u4e0d\u8981\u8bd5\u56fe\u628a\u53ef\u4ee5\u7528\u7b80\u5355\u5f97\u591a\u7684\u65b9\u6cd5\u89e3\u51b3\u7684\u4e8b\u60c5\u590d\u6742\u5316\u3002\u6362\u53e5\u8bdd\u8bf4\uff0c\u6700\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6848\u5c31\u662f\u6700\u5177\u901a\u7528\u6027\u7684\u89e3\u51b3\u65b9\u6848\u3002\u4e00\u822c\u6765\u8bf4\uff0c\u53ea\u8981\u4f60\u7684\u6a21\u578b\u4e0d\u7b26\u5408\u5965\u5361\u59c6\u5243\u5200\u539f\u5219\uff0c\u5c31\u5f88\u53ef\u80fd\u662f\u8fc7\u62df\u5408\u3002 \u56fe 3\uff1a\u8fc7\u62df\u5408\u7684\u6700\u4e00\u822c\u5b9a\u4e49 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u56de\u5230\u4ea4\u53c9\u68c0\u9a8c\u3002 \u5728\u89e3\u91ca\u8fc7\u62df\u5408\u65f6\uff0c\u6211\u51b3\u5b9a\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u6211\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u53e6\u4e00\u90e8\u5206\u4e0a\u68c0\u67e5\u5176\u6027\u80fd\u3002\u8fd9\u4e5f\u662f\u4ea4\u53c9\u68c0\u9a8c\u7684\u4e00\u79cd\uff0c\u901a\u5e38\u88ab\u79f0\u4e3a \"\u6682\u7559\u96c6\"\uff08 hold-out set \uff09\u3002\u5f53\u6211\u4eec\u62e5\u6709\u5927\u91cf\u6570\u636e\uff0c\u800c\u6a21\u578b\u63a8\u7406\u662f\u4e00\u4e2a\u8017\u65f6\u7684\u8fc7\u7a0b\u65f6\uff0c\u6211\u4eec\u5c31\u4f1a\u4f7f\u7528\u8fd9\u79cd\uff08\u4ea4\u53c9\uff09\u9a8c\u8bc1\u3002 \u4ea4\u53c9\u68c0\u9a8c\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\uff0c\u5b83\u662f\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u6b65\u9aa4\u3002 \u9009\u62e9\u6b63\u786e\u7684\u4ea4\u53c9\u68c0\u9a8c 
\u53d6\u51b3\u4e8e\u6240\u5904\u7406\u7684\u6570\u636e\u96c6\uff0c\u5728\u4e00\u4e2a\u6570\u636e\u96c6\u4e0a\u9002\u7528\u7684\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u53ef\u80fd\u4e0d\u9002\u7528\u4e8e\u5176\u4ed6\u6570\u636e\u96c6\u3002\u4e0d\u8fc7\uff0c\u6709\u51e0\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u6700\u4e3a\u6d41\u884c\u548c\u5e7f\u6cdb\u4f7f\u7528\u3002 \u5176\u4e2d\u5305\u62ec\uff1a k \u6298\u4ea4\u53c9\u68c0\u9a8c \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c \u5206\u7ec4 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u4ea4\u53c9\u68c0\u9a8c\u662f\u5c06\u8bad\u7ec3\u6570\u636e\u5206\u6210\u51e0\u4e2a\u90e8\u5206\uff0c\u6211\u4eec\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u5176\u4f59\u90e8\u5206\u4e0a\u8fdb\u884c\u6d4b\u8bd5\u3002\u8bf7\u770b\u56fe 4\u3002 \u56fe 4\uff1a\u5c06\u6570\u636e\u96c6\u62c6\u5206\u4e3a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u56fe 4 \u548c\u56fe 5 \u8bf4\u660e\uff0c\u5f53\u4f60\u5f97\u5230\u4e00\u4e2a\u6570\u636e\u96c6\u6765\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u65f6\uff0c\u4f60\u4f1a\u628a\u5b83\u4eec\u5206\u6210 \u4e24\u4e2a\u4e0d\u540c\u7684\u96c6\uff1a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u3002\u5f88\u591a\u4eba\u8fd8\u4f1a\u5c06\u5176\u5206\u6210\u7b2c\u4e09\u7ec4\uff0c\u79f0\u4e4b\u4e3a\u6d4b\u8bd5\u96c6\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u5c06\u53ea\u4f7f\u7528\u4e24\u4e2a\u96c6\u3002\u5982\u4f60\u6240\u89c1\uff0c\u6211\u4eec\u5c06\u6837\u672c\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\u8fdb\u884c\u4e86\u5212\u5206\u3002\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a k \u4e2a\u4e92\u4e0d\u5173\u8054\u7684\u4e0d\u540c\u96c6\u5408\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002 \u56fe 5\uff1aK \u6298\u4ea4\u53c9\u68c0\u9a8c \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 KFold 
\u5c06\u4efb\u4f55\u6570\u636e\u5206\u5272\u6210 k \u4e2a\u76f8\u7b49\u7684\u90e8\u5206\u3002\u6bcf\u4e2a\u6837\u672c\u5206\u914d\u4e00\u4e2a\u4ece 0 \u5230 k-1 \u7684\u503c\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u5b58\u50a8\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6211\u4eec\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a kfold \u7684\u65b0\u5217\uff0c\u5e76\u7528 -1 \u586b\u5145 df [ \"kfold\" ] = - 1 # \u63a5\u4e0b\u6765\u7684\u6b65\u9aa4\u662f\u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4ece model_selection \u6a21\u5757\u521d\u59cb\u5316 kfold \u7c7b kf = model_selection . KFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684 kfold \u5217\uff08enumerate\u7684\u4f5c\u7528\u662f\u8fd4\u56de\u4e00\u4e2a\u8fed\u4ee3\u5668\uff09 for fold , ( trn_ , val_ ) in enumerate ( kf . split ( X = df )): df . loc [ val_ , 'kfold' ] = fold # \u4fdd\u5b58\u5e26\u6709 kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . 
to_csv ( \"train_folds.csv\" , index = False ) \u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u636e\u96c6\u90fd\u53ef\u4ee5\u4f7f\u7528\u6b64\u6d41\u7a0b\u3002\u4f8b\u5982\uff0c\u5f53\u6570\u636e\u56fe\u50cf\u65f6\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u56fe\u50cf ID\u3001\u56fe\u50cf\u4f4d\u7f6e\u548c\u56fe\u50cf\u6807\u7b7e\u7684 CSV\uff0c\u7136\u540e\u4f7f\u7528\u4e0a\u8ff0\u6d41\u7a0b\u3002 \u53e6\u4e00\u79cd\u91cd\u8981\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u662f \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u6570\u636e\u96c6\uff0c\u5176\u4e2d\u6b63\u6837\u672c\u5360 90%\uff0c\u8d1f\u6837\u672c\u53ea\u5360 10%\uff0c\u90a3\u4e48\u4f60\u5c31\u4e0d\u5e94\u8be5\u4f7f\u7528\u968f\u673a k \u6298\u4ea4\u53c9\u3002\u5bf9\u8fd9\u6837\u7684\u6570\u636e\u96c6\u4f7f\u7528\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u5bfc\u81f4\u6298\u53e0\u6837\u672c\u5168\u90e8\u4e3a\u8d1f\u6837\u672c\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u66f4\u503e\u5411\u4e8e\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u4ee5\u4fdd\u6301\u6bcf\u4e2a\u6298\u4e2d\u6807\u7b7e\u7684\u6bd4\u4f8b\u4e0d\u53d8\u3002\u56e0\u6b64\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u90fd\u4f1a\u6709\u76f8\u540c\u7684 90% \u6b63\u6837\u672c\u548c 10% \u8d1f\u6837\u672c\u3002\u56e0\u6b64\uff0c\u65e0\u8bba\u60a8\u9009\u62e9\u4ec0\u4e48\u6307\u6807\u8fdb\u884c\u8bc4\u4f30\uff0c\u90fd\u4f1a\u5728\u6240\u6709\u6298\u53e0\u4e2d\u5f97\u5230\u76f8\u4f3c\u7684\u7ed3\u679c\u3002 \u4fee\u6539\u521b\u5efa k \u6298\u4ea4\u53c9\u68c0\u9a8c\u7684\u4ee3\u7801\u4ee5\u521b\u5efa\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5f88\u5bb9\u6613\u3002\u6211\u4eec\u53ea\u9700\u5c06 model_selection.KFold \u66f4\u6539\u4e3a model_selection.StratifiedKFold \uff0c\u5e76\u5728 kf.split(...) 
\u51fd\u6570\u4e2d\u6307\u5b9a\u8981\u5206\u5c42\u7684\u76ee\u6807\u5217\u3002\u6211\u4eec\u5047\u8bbe CSV \u6570\u636e\u96c6\u6709\u4e00\u5217\u540d\u4e3a \"target\" \uff0c\u5e76\u4e14\u662f\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u4fdd\u5b58\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6dfb\u52a0\u4e00\u4e2a\u65b0\u5217 kfold\uff0c\u5e76\u7528 -1 \u521d\u59cb\u5316 df [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u83b7\u53d6\u76ee\u6807\u53d8\u91cf y = df . target . values # \u521d\u59cb\u5316 StratifiedKFold \u7c7b\uff0c\u8bbe\u7f6e\u6298\u6570\uff08folds\uff09\u4e3a 5 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u4f7f\u7528 StratifiedKFold \u5bf9\u8c61\u7684 split \u65b9\u6cd5\u6765\u83b7\u53d6\u8bad\u7ec3\u548c\u9a8c\u8bc1\u7d22\u5f15 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u5305\u542b kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . to_csv ( \"train_folds.csv\" , index = False ) \u5bf9\u4e8e\u8461\u8404\u9152\u6570\u636e\u96c6\uff0c\u6211\u4eec\u6765\u770b\u770b\u6807\u7b7e\u7684\u5206\u5e03\u60c5\u51b5\u3002 b = sns . countplot ( x = 'quality' , data = df ) b . set_xlabel ( \"quality\" , fontsize = 20 ) b . 
set_ylabel ( \"count\" , fontsize = 20 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u7ee7\u7eed\u4e0a\u9762\u7684\u4ee3\u7801\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u8f6c\u6362\u4e86\u76ee\u6807\u503c\u3002\u4ece\u56fe 6 \u4e2d\u6211\u4eec\u53ef\u4ee5\u770b\u51fa\uff0c\u8d28\u91cf\u504f\u5dee\u5f88\u5927\u3002\u6709\u4e9b\u7c7b\u522b\u6709\u5f88\u591a\u6837\u672c\uff0c\u6709\u4e9b\u5219\u6ca1\u6709\u90a3\u4e48\u591a\u3002\u5982\u679c\u6211\u4eec\u8fdb\u884c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u90a3\u4e48\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u76ee\u6807\u503c\u5206\u5e03\u90fd\u4e0d\u4f1a\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u56fe 6\uff1a\u8461\u8404\u9152\u6570\u636e\u96c6\u4e2d \"\u8d28\u91cf\" \u5206\u5e03\u60c5\u51b5 \u89c4\u5219\u5f88\u7b80\u5355\uff0c\u5982\u679c\u662f\u6807\u51c6\u5206\u7c7b\u95ee\u9898\uff0c\u5c31\u76f2\u76ee\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u4f46\u5982\u679c\u6570\u636e\u91cf\u5f88\u5927\uff0c\u8be5\u600e\u4e48\u529e\u5462\uff1f\u5047\u8bbe\u6211\u4eec\u6709 100 \u4e07\u4e2a\u6837\u672c\u30025 \u500d\u4ea4\u53c9\u68c0\u9a8c\u610f\u5473\u7740\u5728 800k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u5728 200k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u9a8c\u8bc1\u3002\u6839\u636e\u6211\u4eec\u9009\u62e9\u7684\u7b97\u6cd5\uff0c\u5bf9\u4e8e\u8fd9\u6837\u89c4\u6a21\u7684\u6570\u636e\u96c6\u6765\u8bf4\uff0c\u8bad\u7ec3\u751a\u81f3\u9a8c\u8bc1\u90fd\u53ef\u80fd\u975e\u5e38\u6602\u8d35\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u3002 \u521b\u5efa\u4fdd\u6301\u7ed3\u679c\u7684\u8fc7\u7a0b\u4e0e\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u76f8\u540c\u3002\u5bf9\u4e8e\u62e5\u6709 100 \u4e07\u4e2a\u6837\u672c\u7684\u6570\u636e\u96c6\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa 10 
\u4e2a\u6298\u53e0\u800c\u4e0d\u662f 5 \u4e2a\uff0c\u5e76\u4fdd\u7559\u5176\u4e2d\u4e00\u4e2a\u6298\u53e0\u4f5c\u4e3a\u4fdd\u7559\u6837\u672c\u3002\u8fd9\u610f\u5473\u7740\uff0c\u6211\u4eec\u5c06\u6709 10 \u4e07\u4e2a\u6837\u672c\u88ab\u4fdd\u7559\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u59cb\u7ec8\u5728\u8fd9\u4e2a\u6837\u672c\u96c6\u4e0a\u8ba1\u7b97\u635f\u5931\u3001\u51c6\u786e\u7387\u548c\u5176\u4ed6\u6307\u6807\uff0c\u5e76\u5728 90 \u4e07\u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\u3002 \u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u6682\u7559\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u975e\u5e38\u5e38\u7528\u3002\u5047\u8bbe\u6211\u4eec\u8981\u89e3\u51b3\u7684\u95ee\u9898\u662f\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97 2020 \u5e74\u7684\u9500\u552e\u989d\uff0c\u800c\u6211\u4eec\u5f97\u5230\u7684\u662f 2015-2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f60\u53ef\u4ee5\u9009\u62e9 2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u4f5c\u4e3a\u4fdd\u7559\u6570\u636e\uff0c\u7136\u540e\u5728 2015 \u5e74\u81f3 2018 \u5e74\u7684\u6240\u6709\u6570\u636e\u4e0a\u8bad\u7ec3\u4f60\u7684\u6a21\u578b\u3002 \u56fe 7\uff1a\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u793a\u4f8b \u5728\u56fe 7 \u6240\u793a\u7684\u793a\u4f8b\u4e2d\uff0c\u5047\u8bbe\u6211\u4eec\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u4ece\u65f6\u95f4\u6b65\u9aa4 31 \u5230 40 \u7684\u9500\u552e\u989d\u3002\u6211\u4eec\u53ef\u4ee5\u4fdd\u7559 21 \u81f3 30 \u6b65\u7684\u6570\u636e\uff0c\u7136\u540e\u4ece 0 \u6b65\u5230 20 \u6b65\u8bad\u7ec3\u6a21\u578b\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5728\u9884\u6d4b 31 \u6b65\u81f3 40 \u6b65\u65f6\uff0c\u5e94\u5c06 21 \u6b65\u81f3 30 \u6b65\u7684\u6570\u636e\u7eb3\u5165\u6a21\u578b\uff0c\u5426\u5219\uff0c\u6a21\u578b\u7684\u6027\u80fd\u5c06\u5927\u6253\u6298\u6263\u3002 
\u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u5fc5\u987b\u5904\u7406\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u800c\u521b\u5efa\u5927\u578b\u9a8c\u8bc1\u96c6\u610f\u5473\u7740\u6a21\u578b\u5b66\u4e60\u4f1a\u4e22\u5931\u5927\u91cf\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9\u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c\uff0c\u76f8\u5f53\u4e8e\u7279\u6b8a\u7684 k \u5219\u4ea4\u53c9\u68c0\u9a8c\u5176\u4e2d k=N \uff0cN \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8fd9\u610f\u5473\u7740\u5728\u6240\u6709\u7684\u8bad\u7ec3\u6298\u53e0\u4e2d\uff0c\u6211\u4eec\u5c06\u5bf9\u9664 1 \u4e4b\u5916\u7684\u6240\u6709\u6570\u636e\u6837\u672c\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u7684\u6298\u53e0\u6570\u4e0e\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u76f8\u540c\u3002 \u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u6a21\u578b\u7684\u901f\u5ea6\u4e0d\u591f\u5feb\uff0c\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\uff0c\u4f46\u7531\u4e8e\u8fd9\u79cd\u4ea4\u53c9\u68c0\u9a8c\u53ea\u9002\u7528\u4e8e\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u56e0\u6b64\u5e76\u4e0d\u91cd\u8981\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u95ee\u9898\u4e86\u3002\u56de\u5f52\u95ee\u9898\u7684\u597d\u5904\u5728\u4e8e\uff0c\u9664\u4e86\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e4b\u5916\uff0c\u6211\u4eec\u53ef\u4ee5\u5728\u56de\u5f52\u95ee\u9898\u4e0a\u4f7f\u7528\u4e0a\u8ff0\u6240\u6709\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u6211\u4eec\u4e0d\u80fd\u76f4\u63a5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u6709\u4e00\u4e9b\u65b9\u6cd5\u53ef\u4ee5\u7a0d\u7a0d\u6539\u53d8\u95ee\u9898\uff0c\u4ece\u800c\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k 
\u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u9002\u7528\u4e8e\u4efb\u4f55\u56de\u5f52\u95ee\u9898\u3002\u4f46\u662f\uff0c\u5982\u679c\u53d1\u73b0\u76ee\u6807\u5206\u5e03\u4e0d\u4e00\u81f4\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u8981\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u5fc5\u987b\u5148\u5c06\u76ee\u6807\u5212\u5206\u4e3a\u82e5\u5e72\u4e2a\u5206\u5c42\uff0c\u7136\u540e\u518d\u4ee5\u5904\u7406\u5206\u7c7b\u95ee\u9898\u7684\u76f8\u540c\u65b9\u5f0f\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u9009\u62e9\u5408\u9002\u7684\u5206\u5c42\u6570\u6709\u51e0\u79cd\u9009\u62e9\u3002\u5982\u679c\u6837\u672c\u91cf\u5f88\u5927\uff08> 10k\uff0c> 100k\uff09\uff0c\u90a3\u4e48\u5c31\u4e0d\u9700\u8981\u8003\u8651\u5206\u5c42\u7684\u6570\u91cf\u3002\u53ea\u9700\u5c06\u6570\u636e\u5206\u4e3a 10 \u6216 20 \u5c42\u5373\u53ef\u3002\u5982\u679c\u6837\u672c\u6570\u4e0d\u591a\uff0c\u5219\u53ef\u4ee5\u4f7f\u7528 Sturge's Rule \u8fd9\u6837\u7684\u7b80\u5355\u89c4\u5219\u6765\u8ba1\u7b97\u9002\u5f53\u7684\u5206\u5c42\u6570\u3002 Sturge's Rule\uff1a \\[ Number of Bins = 1 + log_2(N) \\] \u5176\u4e2d \\(N\\) \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8be5\u51fd\u6570\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u5229\u7528\u65af\u7279\u683c\u6cd5\u5219\u7ed8\u5236\u6837\u672c\u4e0e\u7bb1\u6570\u5bf9\u6bd4\u56fe \u8ba9\u6211\u4eec\u5236\u4f5c\u4e00\u4e2a\u56de\u5f52\u6570\u636e\u96c6\u6837\u672c\uff0c\u5e76\u5c1d\u8bd5\u5e94\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 # stratified-kfold for regression # \u4e3a\u56de\u5f52\u95ee\u9898\u8fdb\u884c\u5206\u5c42K-\u6298\u4ea4\u53c9\u9a8c\u8bc1 # \u5bfc\u5165\u9700\u8981\u7684\u5e93 import numpy as np import pandas as pd from 
sklearn import datasets from sklearn import model_selection # \u521b\u5efa\u5206\u6298\uff08folds\uff09\u7684\u51fd\u6570 def create_folds ( data ): # \u521b\u5efa\u4e00\u4e2a\u65b0\u5217\u53eb\u505akfold\uff0c\u5e76\u7528-1\u6765\u586b\u5145 data [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c data = data . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4f7f\u7528Sturge\u89c4\u5219\u8ba1\u7b97bin\u7684\u6570\u91cf num_bins = int ( np . floor ( 1 + np . log2 ( len ( data )))) # \u4f7f\u7528pandas\u7684cut\u51fd\u6570\u8fdb\u884c\u76ee\u6807\u53d8\u91cf\uff08target\uff09\u7684\u5206\u7bb1 data . loc [:, \"bins\" ] = pd . cut ( data [ \"target\" ], bins = num_bins , labels = False ) # \u521d\u59cb\u5316StratifiedKFold\u7c7b kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684kfold\u5217 # \u6ce8\u610f\uff1a\u6211\u4eec\u4f7f\u7528\u7684\u662fbins\u800c\u4e0d\u662f\u5b9e\u9645\u7684\u76ee\u6807\u53d8\u91cf\uff08target\uff09\uff01 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = data , y = data . bins . values )): data . loc [ v_ , 'kfold' ] = f # \u5220\u9664bins\u5217 data = data . drop ( \"bins\" , axis = 1 ) # \u8fd4\u56de\u5305\u542bfolds\u7684\u6570\u636e return data # \u4e3b\u7a0b\u5e8f\u5f00\u59cb if __name__ == \"__main__\" : # \u521b\u5efa\u4e00\u4e2a\u5e26\u670915000\u4e2a\u6837\u672c\u3001100\u4e2a\u7279\u5f81\u548c1\u4e2a\u76ee\u6807\u53d8\u91cf\u7684\u6837\u672c\u6570\u636e\u96c6 X , y = datasets . make_regression ( n_samples = 15000 , n_features = 100 , n_targets = 1 ) # \u4f7f\u7528numpy\u6570\u7ec4\u521b\u5efa\u4e00\u4e2a\u6570\u636e\u6846 df = pd . DataFrame ( X , columns = [ f \"f_ { i } \" for i in range ( X . shape [ 1 ])] ) df . 
loc [:, \"target\" ] = y # \u521b\u5efafolds df = create_folds ( df ) \u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u7b2c\u4e00\u6b65\uff0c\u4e5f\u662f\u6700\u57fa\u672c\u7684\u4e00\u6b65\u3002\u5982\u679c\u8981\u505a\u7279\u5f81\u5de5\u7a0b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u8981\u5efa\u7acb\u6a21\u578b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u65b9\u6848\uff0c\u5176\u4e2d\u9a8c\u8bc1\u6570\u636e\u80fd\u591f\u4ee3\u8868\u8bad\u7ec3\u6570\u636e\u548c\u771f\u5b9e\u4e16\u754c\u7684\u6570\u636e\uff0c\u90a3\u4e48\u4f60\u5c31\u80fd\u5efa\u7acb\u4e00\u4e2a\u5177\u6709\u9ad8\u5ea6\u901a\u7528\u6027\u7684\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u3002 \u672c\u7ae0\u4ecb\u7ecd\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u6570\u636e\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u6839\u636e\u4f60\u7684\u95ee\u9898\u548c\u6570\u636e\u91c7\u7528\u65b0\u7684\u4ea4\u53c9\u68c0\u9a8c\u5f62\u5f0f\u3002 \u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u95ee\u9898\uff0c\u5e0c\u671b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u4ece\u60a3\u8005\u7684\u76ae\u80a4\u56fe\u50cf\u4e2d\u68c0\u6d4b\u51fa\u76ae\u80a4\u764c\u3002\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u5668\uff0c\u8be5\u5206\u7c7b\u5668\u63a5\u6536\u8f93\u5165\u56fe\u50cf\u5e76\u9884\u6d4b\u5176\u826f\u6027\u6216\u6076\u6027\u7684\u6982\u7387\u3002 
\u5728\u8fd9\u7c7b\u6570\u636e\u96c6\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u96c6\u4e2d\u53ef\u80fd\u6709\u540c\u4e00\u60a3\u8005\u7684\u591a\u5f20\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u8981\u5728\u8fd9\u91cc\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u7cfb\u7edf\uff0c\u5fc5\u987b\u6709\u5206\u5c42\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u4e5f\u5fc5\u987b\u786e\u4fdd\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u60a3\u8005\u4e0d\u4f1a\u51fa\u73b0\u5728\u9a8c\u8bc1\u6570\u636e\u4e2d\u3002\u5e78\u8fd0\u7684\u662f\uff0cscikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u79f0\u4e3a GroupKFold \u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u3002 \u5728\u8fd9\u91cc\uff0c\u60a3\u8005\u53ef\u4ee5\u88ab\u89c6\u4e3a\u7ec4\u3002 \u4f46\u9057\u61be\u7684\u662f\uff0cscikit-learn \u65e0\u6cd5\u5c06 GroupKFold \u4e0e StratifiedKFold \u7ed3\u5408\u8d77\u6765\u3002\u6240\u4ee5\u4f60\u9700\u8981\u81ea\u5df1\u52a8\u624b\u3002\u6211\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u8bfb\u8005\u7684\u7ec3\u4e60\u3002","title":"\u4ea4\u53c9\u68c0\u9a8c"},{"location":"%E5%87%86%E5%A4%87%E7%8E%AF%E5%A2%83/","text":"\u51c6\u5907\u73af\u5883 \u5728\u6211\u4eec\u5f00\u59cb\u7f16\u7a0b\u4e4b\u524d\uff0c\u5728\u4f60\u7684\u673a\u5668\u4e0a\u8bbe\u7f6e\u597d\u4e00\u5207\u662f\u975e\u5e38\u91cd\u8981\u7684\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Ubuntu 18.04 \u548c Python 3.7.6\u3002\u5982\u679c\u4f60\u662f Windows \u7528\u6237\uff0c\u53ef\u4ee5\u901a\u8fc7\u591a\u79cd\u65b9\u5f0f\u5b89\u88c5 Ubuntu\u3002\u4f8b\u5982\uff0c\u5728\u865a\u62df\u673a\u4e0a\u5b89\u88c5\u7531Oracle\u516c\u53f8\u63d0\u4f9b\u7684\u514d\u8d39\u8f6f\u4ef6 Virtual 
Box\u3002\u4e0eWindows\u4e00\u8d77\u4f5c\u4e3a\u53cc\u542f\u52a8\u7cfb\u7edf\u3002\u6211\u66f4\u559c\u6b22\u53cc\u542f\u52a8\uff0c\u56e0\u4e3a\u5b83\u662f\u539f\u751f\u7684\u3002\u5982\u679c\u4f60\u4e0d\u662fUbuntu\u7528\u6237\uff0c\u5728\u4f7f\u7528\u672c\u4e66\u4e2d\u7684\u67d0\u4e9bbash\u811a\u672c\u65f6\u53ef\u80fd\u4f1a\u9047\u5230\u95ee\u9898\u3002\u4e3a\u4e86\u907f\u514d\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u53ef\u4ee5\u5728\u865a\u62df\u673a\u4e2d\u5b89\u88c5Ubuntu\uff0c\u6216\u8005\u5728Windows\u4e0a\u5b89\u88c5Linux shell\u3002 \u7528 Anaconda \u5728\u4efb\u4f55\u673a\u5668\u4e0a\u5b89\u88c5 Python \u90fd\u5f88\u7b80\u5355\u3002\u6211\u7279\u522b\u559c\u6b22 Miniconda \uff0c\u5b83\u662f conda \u7684\u6700\u5c0f\u5b89\u88c5\u7a0b\u5e8f\u3002\u5b83\u9002\u7528\u4e8e Linux\u3001OSX \u548c Windows\u3002\u7531\u4e8e Python 2 \u652f\u6301\u5df2\u4e8e 2019 \u5e74\u5e95\u7ed3\u675f\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Python 3 \u53d1\u884c\u7248\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0cminiconda \u5e76\u4e0d\u50cf\u666e\u901a Anaconda \u9644\u5e26\u6240\u6709\u8f6f\u4ef6\u5305\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u968f\u65f6\u5b89\u88c5\u65b0\u8f6f\u4ef6\u5305\u3002\u5b89\u88c5 miniconda \u975e\u5e38\u7b80\u5355\u3002 \u9996\u5148\u8981\u505a\u7684\u662f\u5c06 Miniconda3 \u4e0b\u8f7d\u5230\u7cfb\u7edf\u4e2d\u3002 cd ~/Downloads wget https://repo.anaconda.com/miniconda/... 
\u5176\u4e2d wget \u547d\u4ee4\u540e\u7684 URL \u662f miniconda3 \u7f51\u9875\u7684 URL\u3002\u5bf9\u4e8e 64 \u4f4d Linux \u7cfb\u7edf\uff0c\u7f16\u5199\u672c\u4e66\u65f6\u7684 URL \u662f https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \u4e0b\u8f7d miniconda3 \u540e\uff0c\u53ef\u4ee5\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\uff1a sh Miniconda3-latest-Linux-x86_64.sh \u63a5\u4e0b\u6765\uff0c\u8bf7\u9605\u8bfb\u5e76\u6309\u7167\u5c4f\u5e55\u4e0a\u7684\u8bf4\u660e\u64cd\u4f5c\u3002\u5982\u679c\u5b89\u88c5\u6b63\u786e\uff0c\u4f60\u5e94\u8be5\u53ef\u4ee5\u901a\u8fc7\u5728\u7ec8\u7aef\u8f93\u5165 conda init \u6765\u542f\u52a8 conda \u73af\u5883\u3002\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u5728\u672c\u4e66\u4e2d\u4e00\u76f4\u4f7f\u7528\u7684 conda \u73af\u5883\u3002\u8981\u521b\u5efa conda \u73af\u5883\uff0c\u53ef\u4ee5\u8f93\u5165\uff1a conda create -n environment_name python = 3 .7.6 \u6b64\u547d\u4ee4\u5c06\u521b\u5efa\u540d\u4e3a environment_name \u7684 conda \u73af\u5883\uff0c\u53ef\u4ee5\u4f7f\u7528\uff1a conda activate environment_name \u73b0\u5728\u6211\u4eec\u7684\u73af\u5883\u5df2\u7ecf\u642d\u5efa\u5b8c\u6bd5\u3002\u662f\u65f6\u5019\u5b89\u88c5\u4e00\u4e9b\u6211\u4eec\u4f1a\u7528\u5230\u7684\u8f6f\u4ef6\u5305\u4e86\u3002\u5728 conda \u73af\u5883\u4e2d\uff0c\u5b89\u88c5\u8f6f\u4ef6\u5305\u6709\u4e24\u79cd\u4e0d\u540c\u7684\u65b9\u5f0f\u3002 \u4f60\u53ef\u4ee5\u4ece conda \u4ed3\u5e93\u6216 PyPi \u5b98\u65b9\u4ed3\u5e93\u5b89\u88c5\u8f6f\u4ef6\u5305\u3002 conda/pip install package_name \u6ce8\u610f\uff1a\u67d0\u4e9b\u8f6f\u4ef6\u5305\u53ef\u80fd\u65e0\u6cd5\u5728 conda \u8f6f\u4ef6\u4ed3\u5e93\u4e2d\u627e\u5230\u3002\u56e0\u6b64\uff0c\u5728\u672c\u4e66\u4e2d\uff0c\u4f7f\u7528 pip \u5b89\u88c5\u662f\u6700\u53ef\u53d6\u7684\u65b9\u6cd5\u3002\u6211\u5df2\u7ecf\u521b\u5efa\u4e86\u4e00\u4e2a\u7f16\u5199\u672c\u4e66\u65f6\u4f7f\u7528\u7684\u8f6f\u4ef6\u5305\u5217\u8868\uff0c\u4fdd\u5b58\u5728 environment.yml \u4e2d\u3002 
\u4f60\u53ef\u4ee5\u5728\u6211\u7684 GitHub \u4ed3\u5e93\u4e2d\u7684\u989d\u5916\u8d44\u6599\u4e2d\u627e\u5230\u5b83\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u521b\u5efa\u73af\u5883\uff1a conda env create -f environment.yml \u8be5\u547d\u4ee4\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a ml \u7684\u73af\u5883\u3002\u8981\u6fc0\u6d3b\u8be5\u73af\u5883\u5e76\u5f00\u59cb\u4f7f\u7528\uff0c\u5e94\u8fd0\u884c\uff1a conda activate ml \u73b0\u5728\u6211\u4eec\u5df2\u7ecf\u51c6\u5907\u5c31\u7eea\uff0c\u53ef\u4ee5\u8fdb\u884c\u4e00\u4e9b\u5e94\u7528\u673a\u5668\u5b66\u4e60\u7684\u5de5\u4f5c\u4e86\uff01\u5728\u4f7f\u7528\u672c\u4e66\u8fdb\u884c\u7f16\u7801\u65f6\uff0c\u8bf7\u59cb\u7ec8\u8bb0\u4f4f\u8981\u5728 \"ml \"\u73af\u5883\u4e0b\u8fdb\u884c\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5f00\u59cb\u5b66\u4e60\u771f\u6b63\u7684\u7b2c\u4e00\u7ae0\u3002","title":"\u51c6\u5907\u73af\u5883"},{"location":"%E5%87%86%E5%A4%87%E7%8E%AF%E5%A2%83/#_1","text":"\u5728\u6211\u4eec\u5f00\u59cb\u7f16\u7a0b\u4e4b\u524d\uff0c\u5728\u4f60\u7684\u673a\u5668\u4e0a\u8bbe\u7f6e\u597d\u4e00\u5207\u662f\u975e\u5e38\u91cd\u8981\u7684\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Ubuntu 18.04 \u548c Python 3.7.6\u3002\u5982\u679c\u4f60\u662f Windows \u7528\u6237\uff0c\u53ef\u4ee5\u901a\u8fc7\u591a\u79cd\u65b9\u5f0f\u5b89\u88c5 Ubuntu\u3002\u4f8b\u5982\uff0c\u5728\u865a\u62df\u673a\u4e0a\u5b89\u88c5\u7531Oracle\u516c\u53f8\u63d0\u4f9b\u7684\u514d\u8d39\u8f6f\u4ef6 Virtual 
Box\u3002\u4e0eWindows\u4e00\u8d77\u4f5c\u4e3a\u53cc\u542f\u52a8\u7cfb\u7edf\u3002\u6211\u66f4\u559c\u6b22\u53cc\u542f\u52a8\uff0c\u56e0\u4e3a\u5b83\u662f\u539f\u751f\u7684\u3002\u5982\u679c\u4f60\u4e0d\u662fUbuntu\u7528\u6237\uff0c\u5728\u4f7f\u7528\u672c\u4e66\u4e2d\u7684\u67d0\u4e9bbash\u811a\u672c\u65f6\u53ef\u80fd\u4f1a\u9047\u5230\u95ee\u9898\u3002\u4e3a\u4e86\u907f\u514d\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u53ef\u4ee5\u5728\u865a\u62df\u673a\u4e2d\u5b89\u88c5Ubuntu\uff0c\u6216\u8005\u5728Windows\u4e0a\u5b89\u88c5Linux shell\u3002 \u7528 Anaconda \u5728\u4efb\u4f55\u673a\u5668\u4e0a\u5b89\u88c5 Python \u90fd\u5f88\u7b80\u5355\u3002\u6211\u7279\u522b\u559c\u6b22 Miniconda \uff0c\u5b83\u662f conda \u7684\u6700\u5c0f\u5b89\u88c5\u7a0b\u5e8f\u3002\u5b83\u9002\u7528\u4e8e Linux\u3001OSX \u548c Windows\u3002\u7531\u4e8e Python 2 \u652f\u6301\u5df2\u4e8e 2019 \u5e74\u5e95\u7ed3\u675f\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Python 3 \u53d1\u884c\u7248\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0cminiconda \u5e76\u4e0d\u50cf\u666e\u901a Anaconda \u9644\u5e26\u6240\u6709\u8f6f\u4ef6\u5305\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u968f\u65f6\u5b89\u88c5\u65b0\u8f6f\u4ef6\u5305\u3002\u5b89\u88c5 miniconda \u975e\u5e38\u7b80\u5355\u3002 \u9996\u5148\u8981\u505a\u7684\u662f\u5c06 Miniconda3 \u4e0b\u8f7d\u5230\u7cfb\u7edf\u4e2d\u3002 cd ~/Downloads wget https://repo.anaconda.com/miniconda/... 
\u5176\u4e2d wget \u547d\u4ee4\u540e\u7684 URL \u662f miniconda3 \u7f51\u9875\u7684 URL\u3002\u5bf9\u4e8e 64 \u4f4d Linux \u7cfb\u7edf\uff0c\u7f16\u5199\u672c\u4e66\u65f6\u7684 URL \u662f https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \u4e0b\u8f7d miniconda3 \u540e\uff0c\u53ef\u4ee5\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\uff1a sh Miniconda3-latest-Linux-x86_64.sh \u63a5\u4e0b\u6765\uff0c\u8bf7\u9605\u8bfb\u5e76\u6309\u7167\u5c4f\u5e55\u4e0a\u7684\u8bf4\u660e\u64cd\u4f5c\u3002\u5982\u679c\u5b89\u88c5\u6b63\u786e\uff0c\u4f60\u5e94\u8be5\u53ef\u4ee5\u901a\u8fc7\u5728\u7ec8\u7aef\u8f93\u5165 conda init \u6765\u542f\u52a8 conda \u73af\u5883\u3002\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u5728\u672c\u4e66\u4e2d\u4e00\u76f4\u4f7f\u7528\u7684 conda \u73af\u5883\u3002\u8981\u521b\u5efa conda \u73af\u5883\uff0c\u53ef\u4ee5\u8f93\u5165\uff1a conda create -n environment_name python = 3 .7.6 \u6b64\u547d\u4ee4\u5c06\u521b\u5efa\u540d\u4e3a environment_name \u7684 conda \u73af\u5883\uff0c\u53ef\u4ee5\u4f7f\u7528\uff1a conda activate environment_name \u73b0\u5728\u6211\u4eec\u7684\u73af\u5883\u5df2\u7ecf\u642d\u5efa\u5b8c\u6bd5\u3002\u662f\u65f6\u5019\u5b89\u88c5\u4e00\u4e9b\u6211\u4eec\u4f1a\u7528\u5230\u7684\u8f6f\u4ef6\u5305\u4e86\u3002\u5728 conda \u73af\u5883\u4e2d\uff0c\u5b89\u88c5\u8f6f\u4ef6\u5305\u6709\u4e24\u79cd\u4e0d\u540c\u7684\u65b9\u5f0f\u3002 \u4f60\u53ef\u4ee5\u4ece conda \u4ed3\u5e93\u6216 PyPi \u5b98\u65b9\u4ed3\u5e93\u5b89\u88c5\u8f6f\u4ef6\u5305\u3002 conda/pip install package_name \u6ce8\u610f\uff1a\u67d0\u4e9b\u8f6f\u4ef6\u5305\u53ef\u80fd\u65e0\u6cd5\u5728 conda \u8f6f\u4ef6\u4ed3\u5e93\u4e2d\u627e\u5230\u3002\u56e0\u6b64\uff0c\u5728\u672c\u4e66\u4e2d\uff0c\u4f7f\u7528 pip \u5b89\u88c5\u662f\u6700\u53ef\u53d6\u7684\u65b9\u6cd5\u3002\u6211\u5df2\u7ecf\u521b\u5efa\u4e86\u4e00\u4e2a\u7f16\u5199\u672c\u4e66\u65f6\u4f7f\u7528\u7684\u8f6f\u4ef6\u5305\u5217\u8868\uff0c\u4fdd\u5b58\u5728 environment.yml \u4e2d\u3002 
\u4f60\u53ef\u4ee5\u5728\u6211\u7684 GitHub \u4ed3\u5e93\u4e2d\u7684\u989d\u5916\u8d44\u6599\u4e2d\u627e\u5230\u5b83\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u521b\u5efa\u73af\u5883\uff1a conda env create -f environment.yml \u8be5\u547d\u4ee4\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a ml \u7684\u73af\u5883\u3002\u8981\u6fc0\u6d3b\u8be5\u73af\u5883\u5e76\u5f00\u59cb\u4f7f\u7528\uff0c\u5e94\u8fd0\u884c\uff1a conda activate ml \u73b0\u5728\u6211\u4eec\u5df2\u7ecf\u51c6\u5907\u5c31\u7eea\uff0c\u53ef\u4ee5\u8fdb\u884c\u4e00\u4e9b\u5e94\u7528\u673a\u5668\u5b66\u4e60\u7684\u5de5\u4f5c\u4e86\uff01\u5728\u4f7f\u7528\u672c\u4e66\u8fdb\u884c\u7f16\u7801\u65f6\uff0c\u8bf7\u59cb\u7ec8\u8bb0\u4f4f\u8981\u5728 \"ml \"\u73af\u5883\u4e0b\u8fdb\u884c\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5f00\u59cb\u5b66\u4e60\u771f\u6b63\u7684\u7b2c\u4e00\u7ae0\u3002","title":"\u51c6\u5907\u73af\u5883"},{"location":"%E5%8F%AF%E9%87%8D%E5%A4%8D%E4%BB%A3%E7%A0%81%E5%92%8C%E6%A8%A1%E5%9E%8B%E6%96%B9%E6%B3%95/","text":"\u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5 \u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u5230\u4e86\u53ef\u4ee5\u5c06\u6a21\u578b/\u8bad\u7ec3\u4ee3\u7801\u5206\u53d1\u7ed9\u4ed6\u4eba\u4f7f\u7528\u7684\u9636\u6bb5\u3002\u60a8\u53ef\u4ee5\u7528\u8f6f\u76d8\u5206\u53d1\u6216\u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\uff0c\u4f46\u8fd9\u5e76\u4e0d\u7406\u60f3\u3002\u662f\u8fd9\u6837\u5417\uff1f\u4e5f\u8bb8\u5f88\u591a\u5e74\u524d\uff0c\u8fd9\u662f\u7406\u60f3\u7684\u505a\u6cd5\uff0c\u4f46\u73b0\u5728\u4e0d\u662f\u4e86\u3002 \u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\u548c\u534f\u4f5c\u7684\u9996\u9009\u65b9\u5f0f\u662f\u4f7f\u7528\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u3002Git \u662f\u6700\u6d41\u884c\u7684\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u4e4b\u4e00\u3002\u90a3\u4e48\uff0c\u5047\u8bbe\u4f60\u5df2\u7ecf\u5b66\u4f1a\u4e86 
Git\uff0c\u5e76\u6b63\u786e\u5730\u683c\u5f0f\u5316\u4e86\u4ee3\u7801\uff0c\u7f16\u5199\u4e86\u9002\u5f53\u7684\u6587\u6863\uff0c\u8fd8\u5f00\u6e90\u4e86\u4f60\u7684\u9879\u76ee\u3002\u8fd9\u5c31\u591f\u4e86\u5417\uff1f\u4e0d\uff0c\u8fd8\u4e0d\u591f\u3002\u56e0\u4e3a\u4f60\u5728\u81ea\u5df1\u7684\u7535\u8111\u4e0a\u5199\u7684\u4ee3\u7801\uff0c\u5728\u522b\u4eba\u7684\u7535\u8111\u4e0a\u53ef\u80fd\u4f1a\u56e0\u4e3a\u5404\u79cd\u539f\u56e0\u800c\u65e0\u6cd5\u8fd0\u884c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u5728\u53d1\u5e03\u4ee3\u7801\u65f6\u80fd\u590d\u5236\u81ea\u5df1\u7684\u7535\u8111\uff0c\u800c\u5176\u4ed6\u4eba\u5728\u5b89\u88c5\u60a8\u7684\u8f6f\u4ef6\u6216\u8fd0\u884c\u60a8\u7684\u4ee3\u7801\u65f6\u4e5f\u80fd\u590d\u5236\u60a8\u7684\u7535\u8111\uff0c\u90a3\u5c31\u518d\u597d\u4e0d\u8fc7\u4e86\u3002\u4e3a\u6b64\uff0c\u5982\u4eca\u6700\u6d41\u884c\u7684\u65b9\u6cd5\u662f\u4f7f\u7528 Docker \u5bb9\u5668\uff08Docker Containers\uff09\u3002\u8981\u4f7f\u7528 Docker \u5bb9\u5668\uff0c\u4f60\u9700\u8981\u5b89\u88c5 Docker\u3002 \u8ba9\u6211\u4eec\u7528\u4e0b\u9762\u7684\u547d\u4ee4\u6765\u5b89\u88c5 Docker\u3002 sudo apt install docker.io sudo systemctl start docker sudo systemctl enable docker sudo groupadd docker sudo usermod -aG docker $USER \u8fd9\u4e9b\u547d\u4ee4\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e0a\u8fd0\u884c\u3002Docker \u6700\u68d2\u7684\u5730\u65b9\u5728\u4e8e\u5b83\u53ef\u4ee5\u5b89\u88c5\u5728\u4efb\u4f55\u673a\u5668\u4e0a\uff1a Linux\u3001Windows\u3001OSX\u3002\u56e0\u6b64\uff0c\u5982\u679c\u4f60\u4e00\u76f4\u5728 Docker \u5bb9\u5668\u4e2d\u5de5\u4f5c\uff0c\u54ea\u53f0\u673a\u5668\u90fd\u6ca1\u5173\u7cfb\uff01 Docker 
\u5bb9\u5668\u53ef\u4ee5\u88ab\u89c6\u4e3a\u5c0f\u578b\u865a\u62df\u673a\u3002\u4f60\u53ef\u4ee5\u4e3a\u4f60\u7684\u4ee3\u7801\u521b\u5efa\u4e00\u4e2a\u5bb9\u5668\uff0c\u7136\u540e\u6bcf\u4e2a\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u548c\u8bbf\u95ee\u5b83\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u521b\u5efa\u53ef\u7528\u4e8e\u8bad\u7ec3\u6a21\u578b\u7684\u5bb9\u5668\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e00\u7ae0\u4e2d\u8bad\u7ec3\u7684 BERT \u6a21\u578b\uff0c\u5e76\u5c1d\u8bd5\u5c06\u8bad\u7ec3\u4ee3\u7801\u5bb9\u5668\u5316\u3002 \u9996\u5148\uff0c\u4f60\u9700\u8981\u4e00\u4e2a\u5305\u542b python \u9879\u76ee\u9700\u6c42\u7684\u6587\u4ef6\u3002\u9700\u6c42\u5305\u542b\u5728\u540d\u4e3a requirements.txt \u7684\u6587\u4ef6\u4e2d\u3002\u6587\u4ef6\u540d\u662f thestandard\u3002\u8be5\u6587\u4ef6\u5305\u542b\u9879\u76ee\u4e2d\u4f7f\u7528\u7684\u6240\u6709 python \u5e93\u3002\u4e5f\u5c31\u662f\u53ef\u4ee5\u901a\u8fc7 PyPI (pip) \u4e0b\u8f7d\u7684 python \u5e93\u3002\u7528\u4e8e \u8bad\u7ec3 BERT \u6a21\u578b\u4ee5\u68c0\u6d4b\u6b63/\u8d1f\u60c5\u611f\uff0c\u6211\u4eec\u4f7f\u7528\u4e86 torch\u3001transformers\u3001tqdm\u3001scikit-learn\u3001pandas \u548c numpy\u3002 \u8ba9\u6211\u4eec\u628a\u5b83\u4eec\u5199\u5165 requirements.txt \u4e2d\u3002\u4f60\u53ef\u4ee5\u53ea\u5199\u540d\u79f0\uff0c\u4e5f\u53ef\u4ee5\u5305\u62ec\u7248\u672c\u3002\u5305\u542b\u7248\u672c\u603b\u662f\u6700\u597d\u7684\uff0c\u8fd9\u4e5f\u662f\u4f60\u5e94\u8be5\u505a\u7684\u3002\u5305\u542b\u7248\u672c\u540e\uff0c\u53ef\u4ee5\u786e\u4fdd\u5176\u4ed6\u4eba\u4f7f\u7528\u7684\u7248\u672c\u4e0e\u4f60\u7684\u7248\u672c\u76f8\u540c\uff0c\u800c\u4e0d\u662f\u6700\u65b0\u7248\u672c\uff0c\u56e0\u4e3a\u6700\u65b0\u7248\u672c\u53ef\u80fd\u4f1a\u66f4\u6539\u67d0\u4e9b\u5185\u5bb9\uff0c\u5982\u679c\u662f\u8fd9\u6837\u7684\u8bdd\uff0c\u6a21\u578b\u7684\u8bad\u7ec3\u65b9\u5f0f\u5c31\u4e0d\u4f1a\u4e0e\u4f60\u7684\u76f8\u540c\u4e86\u3002 
\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u663e\u793a\u4e86 requirements.txt\u3002 # requirements.txt pandas == 1.0.4 scikit - learn == 0.22.1 torch == 1.5.0 transformers == 2.11.0 \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a Dockerfile \u7684 Docker \u6587\u4ef6\u3002\u6ca1\u6709\u6269\u5c55\u540d\u3002Dockerfile \u6709\u51e0\u4e2a\u5143\u7d20\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 # Dockerfile # First of all, we include where we are getting the image # from. Image can be thought of as an operating system. # You can do \"FROM ubuntu:18.04\" # this will start from a clean ubuntu 18.04 image. # All images are downloaded from dockerhub # Here are we grabbing image from nvidia's repo # they created a docker image using ubuntu 18.04 # and installed cuda 10.1 and cudnn7 in it. Thus, we don't have to # install it. Makes our life easy. FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 # this is the same apt-get command that you are used to # except the fact that, we have -y argument. Its because # when we build this container, we cannot press Y when asked for RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* # We add a new user called \"abhishek\" # this can be anything. Anything you want it # to be. Usually, we don't use our own name, # you can use \"user\" or \"ubuntu\" RUN useradd -m abhishek # make our user own its own home directory RUN chown -R abhishek:abhishek /home/abhishek/ # copy all files from this direrctory to a # directory called app inside the home of abhishek # and abhishek owns it. 
COPY --chown = abhishek *.* /home/abhishek/app/ # change to user abhishek USER abhishek RUN mkdir /home/abhishek/data/ # Now we install all the requirements # after moving to the app directory # PLEASE NOTE that ubuntu 18.04 image # has python 3.6.9 and not python 3.7.6 # you can also install conda python here and use that # however, to simplify it, I will be using python 3.6.9 # inside the docker container!!!! RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt # install mkl. its needed for transformers RUN pip3 install mkl # when we log into the docker container, # we will go inside this directory automatically WORKDIR /home/abhishek/app \u521b\u5efa\u597d Docker \u6587\u4ef6\u540e\uff0c\u6211\u4eec\u5c31\u9700\u8981\u6784\u5efa\u5b83\u3002\u6784\u5efa Docker \u5bb9\u5668\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u547d\u4ee4\u3002 docker build -f Dockerfile -t bert:train . \u8be5\u547d\u4ee4\u6839\u636e\u63d0\u4f9b\u7684 Dockerfile \u6784\u5efa\u4e00\u4e2a\u5bb9\u5668\u3002Docker \u5bb9\u5668\u7684\u540d\u79f0\u662f bert:train\u3002\u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a \u276f docker build -f Dockerfile -t bert:train . Sending build context to Docker daemon 19.97kB Step 1/7 : FROM nvidia/cuda:10.1-cudnn7-ubuntu18.04 ---> 3b55548ae91f Step 2/7 : RUN apt-get update && apt-get install -y git curl ca- certificates python3 python3-pip sudo && rm -rf /var/lib/apt/lists/* . . . . 
Removing intermediate container 8f6975dd08ba ---> d1802ac9f1b4 Step 7/7 : WORKDIR /home/abhishek/app ---> Running in 257ff09502ed Removing intermediate container 257ff09502ed ---> e5f6eb4cddd7 Successfully built e5f6eb4cddd7 Successfully tagged bert:train \u8bf7\u6ce8\u610f\uff0c\u6211\u5220\u9664\u4e86\u8f93\u51fa\u4e2d\u7684\u8bb8\u591a\u884c\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u767b\u5f55\u5bb9\u5668\u3002 docker run -ti bert:train /bin/bash \u4f60\u9700\u8981\u8bb0\u4f4f\uff0c\u4e00\u65e6\u9000\u51fa shell\uff0c\u4f60\u5728 shell \u4e2d\u6240\u505a\u7684\u4e00\u5207\u90fd\u5c06\u4e22\u5931\u3002\u4f60\u8fd8\u53ef\u4ee5\u5728 Docker \u5bb9\u5668\u4e2d\u4f7f\u7528\u3002 docker run -ti bert:train python3 train.py \u8f93\u51fa\u60c5\u51b5\uff1a Traceback (most recent call last): File \"train.py\", line 2, in import config File \"/home/abhishek/app/config.py\", line 28, in do_lower_case=True File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 393, in from_pretrained return cls._from_pretrained(*inputs, **kwargs) File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 496, in _from_pretrained list(cls.vocab_files_names.values()), OSError: Model name '../input/bert_base_uncased/' was not found in tokenizers model name list (bert-base-uncased, bert-large-uncased, bert- base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base- multilingual-cased, bert-base-chinese, bert-base-german-cased, bert- large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased- whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert- base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base- finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). 
We assumed '../input/bert_base_uncased/' was a path, a model identifier, or url to a directory containing vocabulary files named ['vocab.txt'] but couldn't find such vocabulary files at this path or url. \u54ce\u5440\uff0c\u51fa\u9519\u4e86\uff01 \u6211\u4e3a\u4ec0\u4e48\u8981\u628a\u9519\u8bef\u5370\u5728\u4e66\u4e0a\u5462\uff1f \u56e0\u4e3a\u7406\u89e3\u8fd9\u4e2a\u9519\u8bef\u975e\u5e38\u91cd\u8981\u3002\u8fd9\u4e2a\u9519\u8bef\u8bf4\u660e\u4ee3\u7801\u65e0\u6cd5\u627e\u5230\u76ee\u5f55\".../input/bert_base_cased\"\u3002\u4e3a\u4ec0\u4e48\u4f1a\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u5462\uff1f\u6211\u4eec\u53ef\u4ee5\u5728\u6ca1\u6709 Docker \u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u8bad\u7ec3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u76ee\u5f55\u548c\u6240\u6709\u6587\u4ef6\u90fd\u5b58\u5728\u3002\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u662f\u56e0\u4e3a Docker \u5c31\u50cf\u4e00\u4e2a\u865a\u62df\u673a\uff01\u5b83\u6709\u81ea\u5df1\u7684\u6587\u4ef6\u7cfb\u7edf\uff0c\u672c\u5730\u673a\u5668\u4e0a\u7684\u6587\u4ef6\u4e0d\u4f1a\u5171\u4eab\u7ed9 Docker \u5bb9\u5668\u3002\u5982\u679c\u4f60\u60f3\u4f7f\u7528\u672c\u5730\u673a\u5668\u4e0a\u7684\u8def\u5f84\u5e76\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\uff0c\u4f60\u9700\u8981\u5728\u8fd0\u884c Docker \u65f6\u5c06\u5176\u6302\u8f7d\u5230 Docker \u5bb9\u5668\u4e0a\u3002\u5f53\u6211\u4eec\u67e5\u770b\u8fd9\u4e2a\u6587\u4ef6\u5939\u7684\u8def\u5f84\u65f6\uff0c\u6211\u4eec\u77e5\u9053\u5b83\u4f4d\u4e8e\u540d\u4e3a input \u7684\u6587\u4ef6\u5939\u7684\u4e0a\u4e00\u7ea7\u3002\u8ba9\u6211\u4eec\u7a0d\u5fae\u4fee\u6539\u4e00\u4e0b config.py \u6587\u4ef6\uff01 # config.py import os import transformers # fetch home directory # in our docker container, it is # /home/abhishek HOME_DIR = os . path . expanduser ( \"~\" ) # this is the maximum number of tokens in the sentence MAX_LEN = 512 # batch sizes is low because model is huge! 
TRAIN_BATCH_SIZE = 8 VALID_BATCH_SIZE = 4 # let's train for a maximum of 10 epochs EPOCHS = 10 # define path to BERT model files # Now we assume that all the data is stored inside # /home/abhishek/data BERT_PATH = os . path . join ( HOME_DIR , \"data\" , \"bert_base_uncased\" ) # this is where you want to save the model MODEL_PATH = os . path . join ( HOME_DIR , \"data\" , \"model.bin\" ) # training file TRAINING_FILE = os . path . join ( HOME_DIR , \"data\" , \"imdb.csv\" ) TOKENIZER = transformers . BertTokenizer . from_pretrained ( BERT_PATH , do_lower_case = True ) \u73b0\u5728\uff0c\u4ee3\u7801\u5047\u5b9a\u6240\u6709\u5185\u5bb9\u90fd\u5728\u4e3b\u76ee\u5f55\u4e0b\u540d\u4e3a data \u7684\u6587\u4ef6\u5939\u4e2d\u3002 \u8bf7\u6ce8\u610f\uff0c\u5982\u679c Python \u811a\u672c\u6709\u4efb\u4f55\u6539\u52a8\uff0c\u90fd\u610f\u5473\u7740\u9700\u8981\u91cd\u5efa Docker \u5bb9\u5668\uff01\u56e0\u6b64\uff0c\u6211\u4eec\u91cd\u5efa\u5bb9\u5668\uff0c\u7136\u540e\u91cd\u65b0\u8fd0\u884c Docker \u547d\u4ee4\uff0c\u4f46\u8fd9\u6b21\u8981\u6709\u6240\u6539\u53d8\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u6ca1\u6709\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u8fd9\u4e5f\u662f\u884c\u4e0d\u901a\u7684\u3002\u522b\u62c5\u5fc3\uff0c\u8fd9\u53ea\u662f\u4e00\u4e2a Docker \u5bb9\u5668\u3002\u4f60\u53ea\u9700\u8981\u505a\u4e00\u6b21\u3002\u8981\u5b89\u88c5\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e2d\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\u3002 distribution = $( . 
/etc/os-release ; echo $ID$VERSION_ID ) curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - curl -s -L https://nvidia.github.io/nvidia-docker/ $distribution /nvidia- docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit sudo systemctl restart docker \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u6784\u5efa\u6211\u4eec\u7684\u5bb9\u5668\uff0c\u5e76\u5f00\u59cb\u8bad\u7ec3\u8fc7\u7a0b\uff1a docker run --gpus 1 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:train python3 train.py \u5176\u4e2d\uff0c-gpus 1 \u8868\u793a\u6211\u4eec\u5728 docker \u5bb9\u5668\u4e2d\u4f7f\u7528 1 \u4e2a GPU\uff0c-v \u8868\u793a\u6302\u8f7d\u5377\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u8981\u5c06\u672c\u5730\u76ee\u5f55 /home/abhishek/workspace/approaching_almost/input/ \u6302\u8f7d\u5230 docker \u5bb9\u5668\u4e2d\u7684 /home/abhishek/data/\u3002\u8fd9\u4e00\u6b65\u8981\u82b1\u70b9\u65f6\u95f4\uff0c\u4f46\u5b8c\u6210\u540e\uff0c\u672c\u5730\u6587\u4ef6\u5939\u4e2d\u5c31\u4f1a\u6709 model.bin\u3002 \u8fd9\u6837\uff0c\u53ea\u9700\u505a\u4e00\u4e9b\u7b80\u5355\u7684\u6539\u52a8\uff0c\u4f60\u7684\u8bad\u7ec3\u4ee3\u7801\u5c31\u5df2\u7ecf \"dockerized \"\u4e86\u3002\u73b0\u5728\uff0c\u4f60\u53ef\u4ee5\u5728\uff08\u51e0\u4e4e\uff09\u4efb\u4f55\u4f60\u60f3\u8981\u7684\u7cfb\u7edf\u4e0a\u4f7f\u7528\u8fd9\u4e9b\u4ee3\u7801\u8fdb\u884c\u8bad\u7ec3\u3002 \u4e0b\u4e00\u90e8\u5206\u662f\u5c06\u6211\u4eec\u8bad\u7ec3\u597d\u7684\u6a21\u578b \"\u63d0\u4f9b \"\u7ed9\u6700\u7ec8\u7528\u6237\u3002\u5047\u8bbe\u60a8\u60f3\u4ece\u63a5\u6536\u5230\u7684\u63a8\u6587\u6d41\u4e2d\u63d0\u53d6\u60c5\u611f\u4fe1\u606f\u3002\u8981\u5b8c\u6210\u8fd9\u9879\u4efb\u52a1\uff0c\u60a8\u5fc5\u987b\u521b\u5efa\u4e00\u4e2a API\uff0c\u7528\u4e8e\u8f93\u5165\u53e5\u5b50\uff0c\u7136\u540e\u8fd4\u56de\u5e26\u6709\u60c5\u611f\u6982\u7387\u7684\u8f93\u51fa\u3002\u4f7f\u7528 Python 
\u6784\u5efa API \u7684\u6700\u5e38\u89c1\u65b9\u6cd5\u662f\u4f7f\u7528 Flask \uff0c\u5b83\u662f\u4e00\u4e2a\u5fae\u578b\u7f51\u7edc\u670d\u52a1\u6846\u67b6\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) MODEL = None DEVICE = \"cuda\" def sentence_prediction ( sentence ): tokenizer = config . TOKENIZER max_len = config . MAX_LEN review = str ( sentence ) review = \" \" . join ( review . split ()) inputs = tokenizer . encode_plus ( review , None , add_special_tokens = True , max_length = max_len ) ids = inputs [ \"input_ids\" ] mask = inputs [ \"attention_mask\" ] token_type_ids = inputs [ \"token_type_ids\" ] padding_length = max_len - len ( ids ) ids = ids + ([ 0 ] * padding_length ) mask = mask + ([ 0 ] * padding_length ) token_type_ids = token_type_ids + ([ 0 ] * padding_length ) ids = torch . tensor ( ids , dtype = torch . long ) . unsqueeze ( 0 ) mask = torch . tensor ( mask , dtype = torch . long ) . unsqueeze ( 0 ) token_type_ids = torch . tensor ( token_type_ids , dtype = torch . long ) . unsqueeze ( 0 ) ids = ids . to ( DEVICE , dtype = torch . long ) token_type_ids = token_type_ids . to ( DEVICE , dtype = torch . long ) mask = mask . to ( DEVICE , dtype = torch . long ) outputs = MODEL ( ids = ids , mask = mask , token_type_ids = token_type_ids ) outputs = torch . sigmoid ( outputs ) . cpu () . detach () . numpy () return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): sentence = request . args . get ( \"sentence\" ) start_time = time . time () positive_prediction = sentence_prediction ( sentence ) negative_prediction = 1 - positive_prediction response = {} response [ \"response\" ] = { \"positive\" : str ( positive_prediction ), \"negative\" : str ( negative_prediction ), \"sentence\" : str ( sentence ), \"time_taken\" : str ( time . 
time () - start_time ), } return flask . jsonify ( response ) if __name__ == \"__main__\" : MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ) )) MODEL . to ( DEVICE ) MODEL . eval () app . run ( host = \"0.0.0.0\" ) \u7136\u540e\u8fd0\u884c \"python api.py \"\u547d\u4ee4\u542f\u52a8 API\u3002API \u5c06\u5728\u7aef\u53e3 5000 \u7684 localhost \u4e0a\u542f\u52a8\u3002cURL \u8bf7\u6c42\u53ca\u5176\u54cd\u5e94\u793a\u4f8b\u5982\u4e0b\u3002 \u276f curl $'http://192.168.86.48:5000/predict?sentence=this%20is%20the%20best%20boo k%20ever' {\"response\":{\"negative\":\"0.0032927393913269043\",\"positive\":\"0.99670726\",\" sentence\":\"this is the best book ever\",\"time_taken\":\"0.029126882553100586\"}} \u53ef\u4ee5\u770b\u5230\uff0c\u6211\u4eec\u5f97\u5230\u7684\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002 \u60a8\u8fd8\u53ef\u4ee5\u8bbf\u95ee http://127.0.0.1:5000/predict?sentence=this%20book%20is%20too%20complicated%20for%20me\u3002\u8fd9\u5c06\u518d\u6b21\u8fd4\u56de\u4e00\u4e2a JSON \u6587\u4ef6\u3002 { response : { negative : \"0.8646619468927383\" , positive : \"0.13533805\" , sentence : \"this book is too complicated for me\" , time_taken : \"0.03852701187133789\" } } \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\uff0c\u53ef\u4ee5\u7528\u6765\u4e3a\u5c11\u91cf\u7528\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u4e3a\u4ec0\u4e48\u662f\u5c11\u91cf\uff1f\u56e0\u4e3a\u8fd9\u4e2a API \u4e00\u6b21\u53ea\u670d\u52a1\u4e00\u4e2a\u8bf7\u6c42\u3002gunicorn \u662f UNIX \u4e0a\u7684 Python WSGI HTTP \u670d\u52a1\u5668\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528\u5b83\u7684 CPU \u6765\u5904\u7406\u591a\u4e2a\u5e76\u884c\u8bf7\u6c42\u3002Gunicorn \u53ef\u4ee5\u4e3a API 
\u521b\u5efa\u591a\u4e2a\u8fdb\u7a0b\uff0c\u56e0\u6b64\u6211\u4eec\u53ef\u4ee5\u540c\u65f6\u4e3a\u591a\u4e2a\u5ba2\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 \"pip install gunicorn \"\u5b89\u88c5 gunicorn\u3002 \u4e3a\u4e86\u5c06\u4ee3\u7801\u8f6c\u6362\u4e3a\u4e0e gunicorn \u517c\u5bb9\uff0c\u6211\u4eec\u9700\u8981\u79fb\u9664 init main\uff0c\u5e76\u5c06\u5176\u4e2d\u7684\u6240\u6709\u5185\u5bb9\u79fb\u81f3\u5168\u5c40\u8303\u56f4\u3002\u6b64\u5916\uff0c\u6211\u4eec\u73b0\u5728\u4f7f\u7528\u7684\u662f CPU \u800c\u4e0d\u662f GPU\u3002\u4fee\u6539\u540e\u7684\u4ee3\u7801\u5982\u4e0b\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) DEVICE = \"cpu\" MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ))) MODEL . to ( DEVICE ) MODEL . eval () def sentence_prediction ( sentence ): return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): return flask . 
jsonify ( response ) \u6211\u4eec\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002 gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8fd9\u610f\u5473\u7740\u6211\u4eec\u5728\u63d0\u4f9b\u7684 IP \u5730\u5740\u548c\u7aef\u53e3\u4e0a\u4f7f\u7528 4 \u4e2a Worker \u8fd0\u884c\u6211\u4eec\u7684 flask \u5e94\u7528\u7a0b\u5e8f\u3002\u7531\u4e8e\u6709 4 \u4e2a Worker\uff0c\u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u540c\u65f6\u5904\u7406 4 \u4e2a\u8bf7\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u73b0\u5728\u6211\u4eec\u7684\u7ec8\u7aef\u4f7f\u7528\u7684\u662f CPU\uff0c\u56e0\u6b64\u4e0d\u9700\u8981 GPU \u673a\u5668\uff0c\u53ef\u4ee5\u5728\u4efb\u4f55\u6807\u51c6\u670d\u52a1\u5668/\u865a\u62df\u673a\u4e0a\u8fd0\u884c\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u8fd8\u6709\u4e00\u4e2a\u95ee\u9898\uff1a\u6211\u4eec\u5df2\u7ecf\u5728\u672c\u5730\u673a\u5668\u4e0a\u5b8c\u6210\u4e86\u6240\u6709\u5de5\u4f5c\uff0c\u56e0\u6b64\u5fc5\u987b\u5c06\u5176\u575e\u5316\u3002\u770b\u770b\u4e0b\u9762\u8fd9\u4e2a\u672a\u6ce8\u91ca\u7684 Dockerfile\uff0c\u5b83\u53ef\u4ee5\u7528\u6765\u90e8\u7f72\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002\u8bf7\u6ce8\u610f\u7528\u4e8e\u57f9\u8bad\u7684\u65e7 Dockerfile \u548c\u8fd9\u4e2a Dockerfile \u4e4b\u95f4\u7684\u533a\u522b\u3002\u533a\u522b\u4e0d\u5927\u3002 # CPU Dockerfile FROM ubuntu:18.04 RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* RUN useradd -m abhishek RUN chown -R abhishek:abhishek /home/abhishek/ COPY --chown = abhishek *.* /home/abhishek/app/ USER abhishek RUN mkdir /home/abhishek/data/ RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt RUN pip3 install mkl WORKDIR /home/abhishek/app \u8ba9\u6211\u4eec\u6784\u5efa\u4e00\u4e2a\u65b0\u7684 Docker \u5bb9\u5668\u3002 docker build -f Dockerfile -t bert:api \u5f53 Docker 
\u5bb9\u5668\u6784\u5efa\u5b8c\u6210\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u76f4\u63a5\u8fd0\u884c API \u4e86\u3002 docker run -p 5000 :5000 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:api /home/abhishek/.local/bin/gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5c06\u5bb9\u5668\u5185\u7684 5000 \u7aef\u53e3\u66b4\u9732\u7ed9\u5bb9\u5668\u5916\u7684 5000 \u7aef\u53e3\u3002\u5982\u679c\u4f7f\u7528 docker-compose\uff0c\u4e5f\u53ef\u4ee5\u5f88\u597d\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002Dockercompose \u662f\u4e00\u4e2a\u53ef\u4ee5\u8ba9\u4f60\u540c\u65f6\u5728\u4e0d\u540c\u6216\u76f8\u540c\u5bb9\u5668\u4e2d\u8fd0\u884c\u4e0d\u540c\u670d\u52a1\u7684\u5de5\u5177\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528 \"pip install docker-compose \"\u5b89\u88c5 docker-compose\uff0c\u7136\u540e\u5728\u6784\u5efa\u5bb9\u5668\u540e\u8fd0\u884c \"docker-compose up\"\u3002\u8981\u4f7f\u7528 docker-compose\uff0c\u4f60\u9700\u8981\u4e00\u4e2a docker-compose.yml \u6587\u4ef6\u3002 # docker-compose.yml # specify a version of the compose version : '3.7' # you can add multiple services services : # specify service name. 
we call our service: api api : # specify image name image : bert:api # the command that you would like to run inside the container command : /home/abhishek/.local/bin/gunicorn api:app --bind 0.0.0.0:5000 --workers 4 # mount the volume volumes : - /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ # this ensures that our ports from container will be # exposed as it is network_mode : host \u73b0\u5728\uff0c\u60a8\u53ea\u9700\u4f7f\u7528\u4e0a\u8ff0\u547d\u4ee4\u5373\u53ef\u91cd\u65b0\u8fd0\u884c API\uff0c\u5176\u8fd0\u884c\u65b9\u5f0f\u4e0e\u4e4b\u524d\u76f8\u540c\u3002\u606d\u559c\u4f60\uff0c\u73b0\u5728\uff0c\u4f60\u4e5f\u5df2\u7ecf\u6210\u529f\u5730\u5c06\u9884\u6d4b API \u8fdb\u884c\u4e86 Docker \u5316\uff0c\u53ef\u4ee5\u968f\u65f6\u968f\u5730\u90e8\u7f72\u4e86\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u5b66\u4e60\u4e86 Docker\u3001\u4f7f\u7528 flask \u6784\u5efa API\u3001\u4f7f\u7528 gunicorn \u548c Docker \u670d\u52a1 API \u4ee5\u53ca docker-compose\u3002\u5173\u4e8e docker \u7684\u77e5\u8bc6\u8fdc\u4e0d\u6b62\u8fd9\u4e9b\uff0c\u4f46\u8fd9\u5e94\u8be5\u662f\u4e00\u4e2a\u5f00\u59cb\u3002\u5176\u4ed6\u5185\u5bb9\u53ef\u4ee5\u5728\u5b66\u4e60\u8fc7\u7a0b\u4e2d\u9010\u6e10\u638c\u63e1\u3002 \u6211\u4eec\u8fd8\u8df3\u8fc7\u4e86\u8bb8\u591a\u5de5\u5177\uff0c\u5982 kubernetes\u3001bean-stalk\u3001sagemaker\u3001heroku \u548c\u8bb8\u591a\u5176\u4ed6\u5de5\u5177\uff0c\u8fd9\u4e9b\u5de5\u5177\u5982\u4eca\u88ab\u4eba\u4eec\u7528\u6765\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u6a21\u578b\u3002\"\u6211\u8981\u5199\u4ec0\u4e48\uff1f\u70b9\u51fb\u4fee\u6539\u56fe X \u4e2d\u7684 docker \u5bb9\u5668\"\uff1f\u5728\u4e66\u4e2d\u63cf\u8ff0\u8fd9\u4e9b\u662f\u4e0d\u53ef\u884c\u7684\uff0c\u4e5f\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u6240\u4ee5\u6211\u5c06\u4f7f\u7528\u4e0d\u540c\u7684\u5a92\u4ecb\u6765\u8d5e\u7f8e\u672c\u4e66\u7684\u8fd9\u4e00\u90e8\u5206\u3002\u8bf7\u8bb0\u4f4f\uff0c\u4e00\u65e6\u4f60\u5bf9\u5e94\u7528\u7a0b\u5e8f\u8fdb\u884c\u4e86 Docker 
\u5316\uff0c\u4f7f\u7528\u8fd9\u4e9b\u6280\u672f/\u5e73\u53f0\u8fdb\u884c\u90e8\u7f72\u5c31\u53d8\u5f97\u6613\u5982\u53cd\u638c\u4e86\u3002\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u8981\u8ba9\u4f60\u7684\u4ee3\u7801\u548c\u6a21\u578b\u5bf9\u4ed6\u4eba\u53ef\u7528\uff0c\u5e76\u505a\u597d\u6587\u6863\u8bb0\u5f55\uff0c\u8fd9\u6837\u4efb\u4f55\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u4f60\u5f00\u53d1\u7684\u4e1c\u897f\uff0c\u800c\u65e0\u9700\u591a\u6b21\u8be2\u95ee\u4f60\u3002\u8fd9\u4e0d\u4ec5\u80fd\u8282\u7701\u60a8\u7684\u65f6\u95f4\uff0c\u8fd8\u80fd\u8282\u7701\u4ed6\u4eba\u7684\u65f6\u95f4\u3002\u597d\u7684\u3001\u5f00\u6e90\u7684\u3001\u53ef\u91cd\u590d\u4f7f\u7528\u7684\u4ee3\u7801\u5728\u60a8\u7684\u4f5c\u54c1\u96c6\u4e2d\u4e5f\u975e\u5e38\u91cd\u8981\u3002","title":"\u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5"},{"location":"%E5%8F%AF%E9%87%8D%E5%A4%8D%E4%BB%A3%E7%A0%81%E5%92%8C%E6%A8%A1%E5%9E%8B%E6%96%B9%E6%B3%95/#_1","text":"\u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u5230\u4e86\u53ef\u4ee5\u5c06\u6a21\u578b/\u8bad\u7ec3\u4ee3\u7801\u5206\u53d1\u7ed9\u4ed6\u4eba\u4f7f\u7528\u7684\u9636\u6bb5\u3002\u60a8\u53ef\u4ee5\u7528\u8f6f\u76d8\u5206\u53d1\u6216\u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\uff0c\u4f46\u8fd9\u5e76\u4e0d\u7406\u60f3\u3002\u662f\u8fd9\u6837\u5417\uff1f\u4e5f\u8bb8\u5f88\u591a\u5e74\u524d\uff0c\u8fd9\u662f\u7406\u60f3\u7684\u505a\u6cd5\uff0c\u4f46\u73b0\u5728\u4e0d\u662f\u4e86\u3002 \u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\u548c\u534f\u4f5c\u7684\u9996\u9009\u65b9\u5f0f\u662f\u4f7f\u7528\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u3002Git \u662f\u6700\u6d41\u884c\u7684\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u4e4b\u4e00\u3002\u90a3\u4e48\uff0c\u5047\u8bbe\u4f60\u5df2\u7ecf\u5b66\u4f1a\u4e86 
Git\uff0c\u5e76\u6b63\u786e\u5730\u683c\u5f0f\u5316\u4e86\u4ee3\u7801\uff0c\u7f16\u5199\u4e86\u9002\u5f53\u7684\u6587\u6863\uff0c\u8fd8\u5f00\u6e90\u4e86\u4f60\u7684\u9879\u76ee\u3002\u8fd9\u5c31\u591f\u4e86\u5417\uff1f\u4e0d\uff0c\u8fd8\u4e0d\u591f\u3002\u56e0\u4e3a\u4f60\u5728\u81ea\u5df1\u7684\u7535\u8111\u4e0a\u5199\u7684\u4ee3\u7801\uff0c\u5728\u522b\u4eba\u7684\u7535\u8111\u4e0a\u53ef\u80fd\u4f1a\u56e0\u4e3a\u5404\u79cd\u539f\u56e0\u800c\u65e0\u6cd5\u8fd0\u884c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u5728\u53d1\u5e03\u4ee3\u7801\u65f6\u80fd\u590d\u5236\u81ea\u5df1\u7684\u7535\u8111\uff0c\u800c\u5176\u4ed6\u4eba\u5728\u5b89\u88c5\u60a8\u7684\u8f6f\u4ef6\u6216\u8fd0\u884c\u60a8\u7684\u4ee3\u7801\u65f6\u4e5f\u80fd\u590d\u5236\u60a8\u7684\u7535\u8111\uff0c\u90a3\u5c31\u518d\u597d\u4e0d\u8fc7\u4e86\u3002\u4e3a\u6b64\uff0c\u5982\u4eca\u6700\u6d41\u884c\u7684\u65b9\u6cd5\u662f\u4f7f\u7528 Docker \u5bb9\u5668\uff08Docker Containers\uff09\u3002\u8981\u4f7f\u7528 Docker \u5bb9\u5668\uff0c\u4f60\u9700\u8981\u5b89\u88c5 Docker\u3002 \u8ba9\u6211\u4eec\u7528\u4e0b\u9762\u7684\u547d\u4ee4\u6765\u5b89\u88c5 Docker\u3002 sudo apt install docker.io sudo systemctl start docker sudo systemctl enable docker sudo groupadd docker sudo usermod -aG docker $USER \u8fd9\u4e9b\u547d\u4ee4\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e0a\u8fd0\u884c\u3002Docker \u6700\u68d2\u7684\u5730\u65b9\u5728\u4e8e\u5b83\u53ef\u4ee5\u5b89\u88c5\u5728\u4efb\u4f55\u673a\u5668\u4e0a\uff1a Linux\u3001Windows\u3001OSX\u3002\u56e0\u6b64\uff0c\u5982\u679c\u4f60\u4e00\u76f4\u5728 Docker \u5bb9\u5668\u4e2d\u5de5\u4f5c\uff0c\u54ea\u53f0\u673a\u5668\u90fd\u6ca1\u5173\u7cfb\uff01 Docker 
\u5bb9\u5668\u53ef\u4ee5\u88ab\u89c6\u4e3a\u5c0f\u578b\u865a\u62df\u673a\u3002\u4f60\u53ef\u4ee5\u4e3a\u4f60\u7684\u4ee3\u7801\u521b\u5efa\u4e00\u4e2a\u5bb9\u5668\uff0c\u7136\u540e\u6bcf\u4e2a\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u548c\u8bbf\u95ee\u5b83\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u521b\u5efa\u53ef\u7528\u4e8e\u8bad\u7ec3\u6a21\u578b\u7684\u5bb9\u5668\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e00\u7ae0\u4e2d\u8bad\u7ec3\u7684 BERT \u6a21\u578b\uff0c\u5e76\u5c1d\u8bd5\u5c06\u8bad\u7ec3\u4ee3\u7801\u5bb9\u5668\u5316\u3002 \u9996\u5148\uff0c\u4f60\u9700\u8981\u4e00\u4e2a\u5305\u542b python \u9879\u76ee\u9700\u6c42\u7684\u6587\u4ef6\u3002\u9700\u6c42\u5305\u542b\u5728\u540d\u4e3a requirements.txt \u7684\u6587\u4ef6\u4e2d\u3002\u6587\u4ef6\u540d\u662f thestandard\u3002\u8be5\u6587\u4ef6\u5305\u542b\u9879\u76ee\u4e2d\u4f7f\u7528\u7684\u6240\u6709 python \u5e93\u3002\u4e5f\u5c31\u662f\u53ef\u4ee5\u901a\u8fc7 PyPI (pip) \u4e0b\u8f7d\u7684 python \u5e93\u3002\u7528\u4e8e \u8bad\u7ec3 BERT \u6a21\u578b\u4ee5\u68c0\u6d4b\u6b63/\u8d1f\u60c5\u611f\uff0c\u6211\u4eec\u4f7f\u7528\u4e86 torch\u3001transformers\u3001tqdm\u3001scikit-learn\u3001pandas \u548c numpy\u3002 \u8ba9\u6211\u4eec\u628a\u5b83\u4eec\u5199\u5165 requirements.txt \u4e2d\u3002\u4f60\u53ef\u4ee5\u53ea\u5199\u540d\u79f0\uff0c\u4e5f\u53ef\u4ee5\u5305\u62ec\u7248\u672c\u3002\u5305\u542b\u7248\u672c\u603b\u662f\u6700\u597d\u7684\uff0c\u8fd9\u4e5f\u662f\u4f60\u5e94\u8be5\u505a\u7684\u3002\u5305\u542b\u7248\u672c\u540e\uff0c\u53ef\u4ee5\u786e\u4fdd\u5176\u4ed6\u4eba\u4f7f\u7528\u7684\u7248\u672c\u4e0e\u4f60\u7684\u7248\u672c\u76f8\u540c\uff0c\u800c\u4e0d\u662f\u6700\u65b0\u7248\u672c\uff0c\u56e0\u4e3a\u6700\u65b0\u7248\u672c\u53ef\u80fd\u4f1a\u66f4\u6539\u67d0\u4e9b\u5185\u5bb9\uff0c\u5982\u679c\u662f\u8fd9\u6837\u7684\u8bdd\uff0c\u6a21\u578b\u7684\u8bad\u7ec3\u65b9\u5f0f\u5c31\u4e0d\u4f1a\u4e0e\u4f60\u7684\u76f8\u540c\u4e86\u3002 
\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u663e\u793a\u4e86 requirements.txt\u3002 # requirements.txt pandas == 1.0.4 scikit - learn == 0.22.1 torch == 1.5.0 transformers == 2.11.0 \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a Dockerfile \u7684 Docker \u6587\u4ef6\u3002\u6ca1\u6709\u6269\u5c55\u540d\u3002Dockerfile \u6709\u51e0\u4e2a\u5143\u7d20\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 # Dockerfile # First of all, we include where we are getting the image # from. Image can be thought of as an operating system. # You can do \"FROM ubuntu:18.04\" # this will start from a clean ubuntu 18.04 image. # All images are downloaded from dockerhub # Here are we grabbing image from nvidia's repo # they created a docker image using ubuntu 18.04 # and installed cuda 10.1 and cudnn7 in it. Thus, we don't have to # install it. Makes our life easy. FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 # this is the same apt-get command that you are used to # except the fact that, we have -y argument. Its because # when we build this container, we cannot press Y when asked for RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* # We add a new user called \"abhishek\" # this can be anything. Anything you want it # to be. Usually, we don't use our own name, # you can use \"user\" or \"ubuntu\" RUN useradd -m abhishek # make our user own its own home directory RUN chown -R abhishek:abhishek /home/abhishek/ # copy all files from this direrctory to a # directory called app inside the home of abhishek # and abhishek owns it. 
COPY --chown = abhishek *.* /home/abhishek/app/ # change to user abhishek USER abhishek RUN mkdir /home/abhishek/data/ # Now we install all the requirements # after moving to the app directory # PLEASE NOTE that ubuntu 18.04 image # has python 3.6.9 and not python 3.7.6 # you can also install conda python here and use that # however, to simplify it, I will be using python 3.6.9 # inside the docker container!!!! RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt # install mkl. its needed for transformers RUN pip3 install mkl # when we log into the docker container, # we will go inside this directory automatically WORKDIR /home/abhishek/app \u521b\u5efa\u597d Docker \u6587\u4ef6\u540e\uff0c\u6211\u4eec\u5c31\u9700\u8981\u6784\u5efa\u5b83\u3002\u6784\u5efa Docker \u5bb9\u5668\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u547d\u4ee4\u3002 docker build -f Dockerfile -t bert:train . \u8be5\u547d\u4ee4\u6839\u636e\u63d0\u4f9b\u7684 Dockerfile \u6784\u5efa\u4e00\u4e2a\u5bb9\u5668\u3002Docker \u5bb9\u5668\u7684\u540d\u79f0\u662f bert:train\u3002\u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a \u276f docker build -f Dockerfile -t bert:train . Sending build context to Docker daemon 19.97kB Step 1/7 : FROM nvidia/cuda:10.1-cudnn7-ubuntu18.04 ---> 3b55548ae91f Step 2/7 : RUN apt-get update && apt-get install -y git curl ca- certificates python3 python3-pip sudo && rm -rf /var/lib/apt/lists/* . . . . 
Removing intermediate container 8f6975dd08ba ---> d1802ac9f1b4 Step 7/7 : WORKDIR /home/abhishek/app ---> Running in 257ff09502ed Removing intermediate container 257ff09502ed ---> e5f6eb4cddd7 Successfully built e5f6eb4cddd7 Successfully tagged bert:train \u8bf7\u6ce8\u610f\uff0c\u6211\u5220\u9664\u4e86\u8f93\u51fa\u4e2d\u7684\u8bb8\u591a\u884c\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u767b\u5f55\u5bb9\u5668\u3002 docker run -ti bert:train /bin/bash \u4f60\u9700\u8981\u8bb0\u4f4f\uff0c\u4e00\u65e6\u9000\u51fa shell\uff0c\u4f60\u5728 shell \u4e2d\u6240\u505a\u7684\u4e00\u5207\u90fd\u5c06\u4e22\u5931\u3002\u4f60\u8fd8\u53ef\u4ee5\u5728 Docker \u5bb9\u5668\u4e2d\u4f7f\u7528\u3002 docker run -ti bert:train python3 train.py \u8f93\u51fa\u60c5\u51b5\uff1a Traceback (most recent call last): File \"train.py\", line 2, in import config File \"/home/abhishek/app/config.py\", line 28, in do_lower_case=True File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 393, in from_pretrained return cls._from_pretrained(*inputs, **kwargs) File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 496, in _from_pretrained list(cls.vocab_files_names.values()), OSError: Model name '../input/bert_base_uncased/' was not found in tokenizers model name list (bert-base-uncased, bert-large-uncased, bert- base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base- multilingual-cased, bert-base-chinese, bert-base-german-cased, bert- large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased- whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert- base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base- finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). 
We assumed '../input/bert_base_uncased/' was a path, a model identifier, or url to a directory containing vocabulary files named ['vocab.txt'] but couldn't find such vocabulary files at this path or url. \u54ce\u5440\uff0c\u51fa\u9519\u4e86\uff01 \u6211\u4e3a\u4ec0\u4e48\u8981\u628a\u9519\u8bef\u5370\u5728\u4e66\u4e0a\u5462\uff1f \u56e0\u4e3a\u7406\u89e3\u8fd9\u4e2a\u9519\u8bef\u975e\u5e38\u91cd\u8981\u3002\u8fd9\u4e2a\u9519\u8bef\u8bf4\u660e\u4ee3\u7801\u65e0\u6cd5\u627e\u5230\u76ee\u5f55\".../input/bert_base_cased\"\u3002\u4e3a\u4ec0\u4e48\u4f1a\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u5462\uff1f\u6211\u4eec\u53ef\u4ee5\u5728\u6ca1\u6709 Docker \u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u8bad\u7ec3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u76ee\u5f55\u548c\u6240\u6709\u6587\u4ef6\u90fd\u5b58\u5728\u3002\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u662f\u56e0\u4e3a Docker \u5c31\u50cf\u4e00\u4e2a\u865a\u62df\u673a\uff01\u5b83\u6709\u81ea\u5df1\u7684\u6587\u4ef6\u7cfb\u7edf\uff0c\u672c\u5730\u673a\u5668\u4e0a\u7684\u6587\u4ef6\u4e0d\u4f1a\u5171\u4eab\u7ed9 Docker \u5bb9\u5668\u3002\u5982\u679c\u4f60\u60f3\u4f7f\u7528\u672c\u5730\u673a\u5668\u4e0a\u7684\u8def\u5f84\u5e76\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\uff0c\u4f60\u9700\u8981\u5728\u8fd0\u884c Docker \u65f6\u5c06\u5176\u6302\u8f7d\u5230 Docker \u5bb9\u5668\u4e0a\u3002\u5f53\u6211\u4eec\u67e5\u770b\u8fd9\u4e2a\u6587\u4ef6\u5939\u7684\u8def\u5f84\u65f6\uff0c\u6211\u4eec\u77e5\u9053\u5b83\u4f4d\u4e8e\u540d\u4e3a input \u7684\u6587\u4ef6\u5939\u7684\u4e0a\u4e00\u7ea7\u3002\u8ba9\u6211\u4eec\u7a0d\u5fae\u4fee\u6539\u4e00\u4e0b config.py \u6587\u4ef6\uff01 # config.py import os import transformers # fetch home directory # in our docker container, it is # /home/abhishek HOME_DIR = os . path . expanduser ( \"~\" ) # this is the maximum number of tokens in the sentence MAX_LEN = 512 # batch sizes is low because model is huge! 
TRAIN_BATCH_SIZE = 8 VALID_BATCH_SIZE = 4 # let's train for a maximum of 10 epochs EPOCHS = 10 # define path to BERT model files # Now we assume that all the data is stored inside # /home/abhishek/data BERT_PATH = os . path . join ( HOME_DIR , \"data\" , \"bert_base_uncased\" ) # this is where you want to save the model MODEL_PATH = os . path . join ( HOME_DIR , \"data\" , \"model.bin\" ) # training file TRAINING_FILE = os . path . join ( HOME_DIR , \"data\" , \"imdb.csv\" ) TOKENIZER = transformers . BertTokenizer . from_pretrained ( BERT_PATH , do_lower_case = True ) \u73b0\u5728\uff0c\u4ee3\u7801\u5047\u5b9a\u6240\u6709\u5185\u5bb9\u90fd\u5728\u4e3b\u76ee\u5f55\u4e0b\u540d\u4e3a data \u7684\u6587\u4ef6\u5939\u4e2d\u3002 \u8bf7\u6ce8\u610f\uff0c\u5982\u679c Python \u811a\u672c\u6709\u4efb\u4f55\u6539\u52a8\uff0c\u90fd\u610f\u5473\u7740\u9700\u8981\u91cd\u5efa Docker \u5bb9\u5668\uff01\u56e0\u6b64\uff0c\u6211\u4eec\u91cd\u5efa\u5bb9\u5668\uff0c\u7136\u540e\u91cd\u65b0\u8fd0\u884c Docker \u547d\u4ee4\uff0c\u4f46\u8fd9\u6b21\u8981\u6709\u6240\u6539\u53d8\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u6ca1\u6709\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u8fd9\u4e5f\u662f\u884c\u4e0d\u901a\u7684\u3002\u522b\u62c5\u5fc3\uff0c\u8fd9\u53ea\u662f\u4e00\u4e2a Docker \u5bb9\u5668\u3002\u4f60\u53ea\u9700\u8981\u505a\u4e00\u6b21\u3002\u8981\u5b89\u88c5\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e2d\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\u3002 distribution = $( . 
/etc/os-release ; echo $ID$VERSION_ID ) curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - curl -s -L https://nvidia.github.io/nvidia-docker/ $distribution /nvidia- docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit sudo systemctl restart docker \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u6784\u5efa\u6211\u4eec\u7684\u5bb9\u5668\uff0c\u5e76\u5f00\u59cb\u8bad\u7ec3\u8fc7\u7a0b\uff1a docker run --gpus 1 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:train python3 train.py \u5176\u4e2d\uff0c-gpus 1 \u8868\u793a\u6211\u4eec\u5728 docker \u5bb9\u5668\u4e2d\u4f7f\u7528 1 \u4e2a GPU\uff0c-v \u8868\u793a\u6302\u8f7d\u5377\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u8981\u5c06\u672c\u5730\u76ee\u5f55 /home/abhishek/workspace/approaching_almost/input/ \u6302\u8f7d\u5230 docker \u5bb9\u5668\u4e2d\u7684 /home/abhishek/data/\u3002\u8fd9\u4e00\u6b65\u8981\u82b1\u70b9\u65f6\u95f4\uff0c\u4f46\u5b8c\u6210\u540e\uff0c\u672c\u5730\u6587\u4ef6\u5939\u4e2d\u5c31\u4f1a\u6709 model.bin\u3002 \u8fd9\u6837\uff0c\u53ea\u9700\u505a\u4e00\u4e9b\u7b80\u5355\u7684\u6539\u52a8\uff0c\u4f60\u7684\u8bad\u7ec3\u4ee3\u7801\u5c31\u5df2\u7ecf \"dockerized \"\u4e86\u3002\u73b0\u5728\uff0c\u4f60\u53ef\u4ee5\u5728\uff08\u51e0\u4e4e\uff09\u4efb\u4f55\u4f60\u60f3\u8981\u7684\u7cfb\u7edf\u4e0a\u4f7f\u7528\u8fd9\u4e9b\u4ee3\u7801\u8fdb\u884c\u8bad\u7ec3\u3002 \u4e0b\u4e00\u90e8\u5206\u662f\u5c06\u6211\u4eec\u8bad\u7ec3\u597d\u7684\u6a21\u578b \"\u63d0\u4f9b \"\u7ed9\u6700\u7ec8\u7528\u6237\u3002\u5047\u8bbe\u60a8\u60f3\u4ece\u63a5\u6536\u5230\u7684\u63a8\u6587\u6d41\u4e2d\u63d0\u53d6\u60c5\u611f\u4fe1\u606f\u3002\u8981\u5b8c\u6210\u8fd9\u9879\u4efb\u52a1\uff0c\u60a8\u5fc5\u987b\u521b\u5efa\u4e00\u4e2a API\uff0c\u7528\u4e8e\u8f93\u5165\u53e5\u5b50\uff0c\u7136\u540e\u8fd4\u56de\u5e26\u6709\u60c5\u611f\u6982\u7387\u7684\u8f93\u51fa\u3002\u4f7f\u7528 Python 
\u6784\u5efa API \u7684\u6700\u5e38\u89c1\u65b9\u6cd5\u662f\u4f7f\u7528 Flask \uff0c\u5b83\u662f\u4e00\u4e2a\u5fae\u578b\u7f51\u7edc\u670d\u52a1\u6846\u67b6\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) MODEL = None DEVICE = \"cuda\" def sentence_prediction ( sentence ): tokenizer = config . TOKENIZER max_len = config . MAX_LEN review = str ( sentence ) review = \" \" . join ( review . split ()) inputs = tokenizer . encode_plus ( review , None , add_special_tokens = True , max_length = max_len ) ids = inputs [ \"input_ids\" ] mask = inputs [ \"attention_mask\" ] token_type_ids = inputs [ \"token_type_ids\" ] padding_length = max_len - len ( ids ) ids = ids + ([ 0 ] * padding_length ) mask = mask + ([ 0 ] * padding_length ) token_type_ids = token_type_ids + ([ 0 ] * padding_length ) ids = torch . tensor ( ids , dtype = torch . long ) . unsqueeze ( 0 ) mask = torch . tensor ( mask , dtype = torch . long ) . unsqueeze ( 0 ) token_type_ids = torch . tensor ( token_type_ids , dtype = torch . long ) . unsqueeze ( 0 ) ids = ids . to ( DEVICE , dtype = torch . long ) token_type_ids = token_type_ids . to ( DEVICE , dtype = torch . long ) mask = mask . to ( DEVICE , dtype = torch . long ) outputs = MODEL ( ids = ids , mask = mask , token_type_ids = token_type_ids ) outputs = torch . sigmoid ( outputs ) . cpu () . detach () . numpy () return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): sentence = request . args . get ( \"sentence\" ) start_time = time . time () positive_prediction = sentence_prediction ( sentence ) negative_prediction = 1 - positive_prediction response = {} response [ \"response\" ] = { \"positive\" : str ( positive_prediction ), \"negative\" : str ( negative_prediction ), \"sentence\" : str ( sentence ), \"time_taken\" : str ( time . 
time () - start_time ), } return flask . jsonify ( response ) if __name__ == \"__main__\" : MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ) )) MODEL . to ( DEVICE ) MODEL . eval () app . run ( host = \"0.0.0.0\" ) \u7136\u540e\u8fd0\u884c \"python api.py \"\u547d\u4ee4\u542f\u52a8 API\u3002API \u5c06\u5728\u7aef\u53e3 5000 \u7684 localhost \u4e0a\u542f\u52a8\u3002cURL \u8bf7\u6c42\u53ca\u5176\u54cd\u5e94\u793a\u4f8b\u5982\u4e0b\u3002 \u276f curl $'http://192.168.86.48:5000/predict?sentence=this%20is%20the%20best%20boo k%20ever' {\"response\":{\"negative\":\"0.0032927393913269043\",\"positive\":\"0.99670726\",\" sentence\":\"this is the best book ever\",\"time_taken\":\"0.029126882553100586\"}} \u53ef\u4ee5\u770b\u5230\uff0c\u6211\u4eec\u5f97\u5230\u7684\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002 \u60a8\u8fd8\u53ef\u4ee5\u8bbf\u95ee http://127.0.0.1:5000/predict?sentence=this%20book%20is%20too%20complicated%20for%20me\u3002\u8fd9\u5c06\u518d\u6b21\u8fd4\u56de\u4e00\u4e2a JSON \u6587\u4ef6\u3002 { response : { negative : \"0.8646619468927383\" , positive : \"0.13533805\" , sentence : \"this book is too complicated for me\" , time_taken : \"0.03852701187133789\" } } \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\uff0c\u53ef\u4ee5\u7528\u6765\u4e3a\u5c11\u91cf\u7528\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u4e3a\u4ec0\u4e48\u662f\u5c11\u91cf\uff1f\u56e0\u4e3a\u8fd9\u4e2a API \u4e00\u6b21\u53ea\u670d\u52a1\u4e00\u4e2a\u8bf7\u6c42\u3002gunicorn \u662f UNIX \u4e0a\u7684 Python WSGI HTTP \u670d\u52a1\u5668\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528\u5b83\u7684 CPU \u6765\u5904\u7406\u591a\u4e2a\u5e76\u884c\u8bf7\u6c42\u3002Gunicorn \u53ef\u4ee5\u4e3a API 
\u521b\u5efa\u591a\u4e2a\u8fdb\u7a0b\uff0c\u56e0\u6b64\u6211\u4eec\u53ef\u4ee5\u540c\u65f6\u4e3a\u591a\u4e2a\u5ba2\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 \"pip install gunicorn \"\u5b89\u88c5 gunicorn\u3002 \u4e3a\u4e86\u5c06\u4ee3\u7801\u8f6c\u6362\u4e3a\u4e0e gunicorn \u517c\u5bb9\uff0c\u6211\u4eec\u9700\u8981\u79fb\u9664 init main\uff0c\u5e76\u5c06\u5176\u4e2d\u7684\u6240\u6709\u5185\u5bb9\u79fb\u81f3\u5168\u5c40\u8303\u56f4\u3002\u6b64\u5916\uff0c\u6211\u4eec\u73b0\u5728\u4f7f\u7528\u7684\u662f CPU \u800c\u4e0d\u662f GPU\u3002\u4fee\u6539\u540e\u7684\u4ee3\u7801\u5982\u4e0b\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) DEVICE = \"cpu\" MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ))) MODEL . to ( DEVICE ) MODEL . eval () def sentence_prediction ( sentence ): return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): return flask . 
jsonify ( response ) \u6211\u4eec\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002 gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8fd9\u610f\u5473\u7740\u6211\u4eec\u5728\u63d0\u4f9b\u7684 IP \u5730\u5740\u548c\u7aef\u53e3\u4e0a\u4f7f\u7528 4 \u4e2a Worker \u8fd0\u884c\u6211\u4eec\u7684 flask \u5e94\u7528\u7a0b\u5e8f\u3002\u7531\u4e8e\u6709 4 \u4e2a Worker\uff0c\u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u540c\u65f6\u5904\u7406 4 \u4e2a\u8bf7\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u73b0\u5728\u6211\u4eec\u7684\u7ec8\u7aef\u4f7f\u7528\u7684\u662f CPU\uff0c\u56e0\u6b64\u4e0d\u9700\u8981 GPU \u673a\u5668\uff0c\u53ef\u4ee5\u5728\u4efb\u4f55\u6807\u51c6\u670d\u52a1\u5668/\u865a\u62df\u673a\u4e0a\u8fd0\u884c\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u8fd8\u6709\u4e00\u4e2a\u95ee\u9898\uff1a\u6211\u4eec\u5df2\u7ecf\u5728\u672c\u5730\u673a\u5668\u4e0a\u5b8c\u6210\u4e86\u6240\u6709\u5de5\u4f5c\uff0c\u56e0\u6b64\u5fc5\u987b\u5c06\u5176\u575e\u5316\u3002\u770b\u770b\u4e0b\u9762\u8fd9\u4e2a\u672a\u6ce8\u91ca\u7684 Dockerfile\uff0c\u5b83\u53ef\u4ee5\u7528\u6765\u90e8\u7f72\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002\u8bf7\u6ce8\u610f\u7528\u4e8e\u57f9\u8bad\u7684\u65e7 Dockerfile \u548c\u8fd9\u4e2a Dockerfile \u4e4b\u95f4\u7684\u533a\u522b\u3002\u533a\u522b\u4e0d\u5927\u3002 # CPU Dockerfile FROM ubuntu:18.04 RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* RUN useradd -m abhishek RUN chown -R abhishek:abhishek /home/abhishek/ COPY --chown = abhishek *.* /home/abhishek/app/ USER abhishek RUN mkdir /home/abhishek/data/ RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt RUN pip3 install mkl WORKDIR /home/abhishek/app \u8ba9\u6211\u4eec\u6784\u5efa\u4e00\u4e2a\u65b0\u7684 Docker \u5bb9\u5668\u3002 docker build -f Dockerfile -t bert:api \u5f53 Docker 
\u5bb9\u5668\u6784\u5efa\u5b8c\u6210\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u76f4\u63a5\u8fd0\u884c API \u4e86\u3002 docker run -p 5000 :5000 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:api /home/abhishek/.local/bin/gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5c06\u5bb9\u5668\u5185\u7684 5000 \u7aef\u53e3\u66b4\u9732\u7ed9\u5bb9\u5668\u5916\u7684 5000 \u7aef\u53e3\u3002\u5982\u679c\u4f7f\u7528 docker-compose\uff0c\u4e5f\u53ef\u4ee5\u5f88\u597d\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002Dockercompose \u662f\u4e00\u4e2a\u53ef\u4ee5\u8ba9\u4f60\u540c\u65f6\u5728\u4e0d\u540c\u6216\u76f8\u540c\u5bb9\u5668\u4e2d\u8fd0\u884c\u4e0d\u540c\u670d\u52a1\u7684\u5de5\u5177\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528 \"pip install docker-compose \"\u5b89\u88c5 docker-compose\uff0c\u7136\u540e\u5728\u6784\u5efa\u5bb9\u5668\u540e\u8fd0\u884c \"docker-compose up\"\u3002\u8981\u4f7f\u7528 docker-compose\uff0c\u4f60\u9700\u8981\u4e00\u4e2a docker-compose.yml \u6587\u4ef6\u3002 # docker-compose.yml # specify a version of the compose version : '3.7' # you can add multiple services services : # specify service name. 
we call our service: api api : # specify image name image : bert:api # the command that you would like to run inside the container command : /home/abhishek/.local/bin/gunicorn api:app --bind 0.0.0.0:5000 --workers 4 # mount the volume volumes : - /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ # this ensures that our ports from container will be # exposed as it is network_mode : host \u73b0\u5728\uff0c\u60a8\u53ea\u9700\u4f7f\u7528\u4e0a\u8ff0\u547d\u4ee4\u5373\u53ef\u91cd\u65b0\u8fd0\u884c API\uff0c\u5176\u8fd0\u884c\u65b9\u5f0f\u4e0e\u4e4b\u524d\u76f8\u540c\u3002\u606d\u559c\u4f60\uff0c\u73b0\u5728\uff0c\u4f60\u4e5f\u5df2\u7ecf\u6210\u529f\u5730\u5c06\u9884\u6d4b API \u8fdb\u884c\u4e86 Docker \u5316\uff0c\u53ef\u4ee5\u968f\u65f6\u968f\u5730\u90e8\u7f72\u4e86\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u5b66\u4e60\u4e86 Docker\u3001\u4f7f\u7528 flask \u6784\u5efa API\u3001\u4f7f\u7528 gunicorn \u548c Docker \u670d\u52a1 API \u4ee5\u53ca docker-compose\u3002\u5173\u4e8e docker \u7684\u77e5\u8bc6\u8fdc\u4e0d\u6b62\u8fd9\u4e9b\uff0c\u4f46\u8fd9\u5e94\u8be5\u662f\u4e00\u4e2a\u5f00\u59cb\u3002\u5176\u4ed6\u5185\u5bb9\u53ef\u4ee5\u5728\u5b66\u4e60\u8fc7\u7a0b\u4e2d\u9010\u6e10\u638c\u63e1\u3002 \u6211\u4eec\u8fd8\u8df3\u8fc7\u4e86\u8bb8\u591a\u5de5\u5177\uff0c\u5982 kubernetes\u3001bean-stalk\u3001sagemaker\u3001heroku \u548c\u8bb8\u591a\u5176\u4ed6\u5de5\u5177\uff0c\u8fd9\u4e9b\u5de5\u5177\u5982\u4eca\u88ab\u4eba\u4eec\u7528\u6765\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u6a21\u578b\u3002\"\u6211\u8981\u5199\u4ec0\u4e48\uff1f\u70b9\u51fb\u4fee\u6539\u56fe X \u4e2d\u7684 docker \u5bb9\u5668\"\uff1f\u5728\u4e66\u4e2d\u63cf\u8ff0\u8fd9\u4e9b\u662f\u4e0d\u53ef\u884c\u7684\uff0c\u4e5f\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u6240\u4ee5\u6211\u5c06\u4f7f\u7528\u4e0d\u540c\u7684\u5a92\u4ecb\u6765\u8d5e\u7f8e\u672c\u4e66\u7684\u8fd9\u4e00\u90e8\u5206\u3002\u8bf7\u8bb0\u4f4f\uff0c\u4e00\u65e6\u4f60\u5bf9\u5e94\u7528\u7a0b\u5e8f\u8fdb\u884c\u4e86 Docker 
\u5316\uff0c\u4f7f\u7528\u8fd9\u4e9b\u6280\u672f/\u5e73\u53f0\u8fdb\u884c\u90e8\u7f72\u5c31\u53d8\u5f97\u6613\u5982\u53cd\u638c\u4e86\u3002\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u8981\u8ba9\u4f60\u7684\u4ee3\u7801\u548c\u6a21\u578b\u5bf9\u4ed6\u4eba\u53ef\u7528\uff0c\u5e76\u505a\u597d\u6587\u6863\u8bb0\u5f55\uff0c\u8fd9\u6837\u4efb\u4f55\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u4f60\u5f00\u53d1\u7684\u4e1c\u897f\uff0c\u800c\u65e0\u9700\u591a\u6b21\u8be2\u95ee\u4f60\u3002\u8fd9\u4e0d\u4ec5\u80fd\u8282\u7701\u60a8\u7684\u65f6\u95f4\uff0c\u8fd8\u80fd\u8282\u7701\u4ed6\u4eba\u7684\u65f6\u95f4\u3002\u597d\u7684\u3001\u5f00\u6e90\u7684\u3001\u53ef\u91cd\u590d\u4f7f\u7528\u7684\u4ee3\u7801\u5728\u60a8\u7684\u4f5c\u54c1\u96c6\u4e2d\u4e5f\u975e\u5e38\u91cd\u8981\u3002","title":"\u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5"},{"location":"%E5%A4%84%E7%90%86%E5%88%86%E7%B1%BB%E5%8F%98%E9%87%8F/","text":"\u5904\u7406\u5206\u7c7b\u53d8\u91cf \u5f88\u591a\u4eba\u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u65f6\u90fd\u4f1a\u9047\u5230\u5f88\u591a\u56f0\u96be\uff0c\u56e0\u6b64\u8fd9\u503c\u5f97\u7528\u6574\u6574\u4e00\u7ae0\u7684\u7bc7\u5e45\u6765\u8ba8\u8bba\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u5c06\u8bb2\u8ff0\u4e0d\u540c\u7c7b\u578b\u7684\u5206\u7c7b\u6570\u636e\uff0c\u4ee5\u53ca\u5982\u4f55\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u3002 \u4ec0\u4e48\u662f\u5206\u7c7b\u53d8\u91cf\uff1f \u5206\u7c7b\u53d8\u91cf/\u7279\u5f81\u662f\u6307\u4efb\u4f55\u7279\u5f81\u7c7b\u578b\uff0c\u53ef\u5206\u4e3a\u4e24\u5927\u7c7b\uff1a - \u65e0\u5e8f - \u6709\u5e8f \u65e0\u5e8f\u53d8\u91cf \u662f\u6307\u6709\u4e24\u4e2a\u6216\u4e24\u4e2a\u4ee5\u4e0a\u7c7b\u522b\u7684\u53d8\u91cf\uff0c\u8fd9\u4e9b\u7c7b\u522b\u6ca1\u6709\u4efb\u4f55\u76f8\u5173\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u5982\u679c\u5c06\u6027\u522b\u5206\u4e3a\u4e24\u7ec4\uff0c\u5373\u7537\u6027\u548c\u5973\u6027\uff0c\u5219\u53ef\u5c06\u5176\u89c6\u4e3a\u540d\u4e49\u53d8\u91cf\u3002 \u6709\u5e8f\u53d8\u91cf 
\u5219\u6709 \"\u7b49\u7ea7 \"\u6216\u7c7b\u522b\uff0c\u5e76\u6709\u7279\u5b9a\u7684\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u4e00\u4e2a\u987a\u5e8f\u5206\u7c7b\u53d8\u91cf\u53ef\u4ee5\u662f\u4e00\u4e2a\u5177\u6709\u4f4e\u3001\u4e2d\u3001\u9ad8\u4e09\u4e2a\u4e0d\u540c\u7b49\u7ea7\u7684\u7279\u5f81\u3002\u987a\u5e8f\u5f88\u91cd\u8981\u3002 \u5c31\u5b9a\u4e49\u800c\u8a00\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06\u5206\u7c7b\u53d8\u91cf\u5206\u4e3a \u4e8c\u5143\u53d8\u91cf \uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u7c7b\u522b\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u6709\u4e9b\u4eba\u751a\u81f3\u628a\u5206\u7c7b\u53d8\u91cf\u79f0\u4e3a \" \u5faa\u73af \"\u53d8\u91cf\u3002\u5468\u671f\u53d8\u91cf\u4ee5 \"\u5468\u671f \"\u7684\u5f62\u5f0f\u5b58\u5728\uff0c\u4f8b\u5982\u4e00\u5468\u4e2d\u7684\u5929\u6570\uff1a \u5468\u65e5\u3001\u5468\u4e00\u3001\u5468\u4e8c\u3001\u5468\u4e09\u3001\u5468\u56db\u3001\u5468\u4e94\u548c\u5468\u516d\u3002\u5468\u516d\u8fc7\u540e\uff0c\u53c8\u662f\u5468\u65e5\u3002\u8fd9\u5c31\u662f\u4e00\u4e2a\u5faa\u73af\u3002\u53e6\u4e00\u4e2a\u4f8b\u5b50\u662f\u4e00\u5929\u4e2d\u7684\u5c0f\u65f6\u6570\uff0c\u5982\u679c\u6211\u4eec\u5c06\u5b83\u4eec\u89c6\u4e3a\u7c7b\u522b\u7684\u8bdd\u3002 \u5206\u7c7b\u53d8\u91cf\u6709\u5f88\u591a\u4e0d\u540c\u7684\u5b9a\u4e49\uff0c\u5f88\u591a\u4eba\u4e5f\u8c08\u5230\u8981\u6839\u636e\u5206\u7c7b\u53d8\u91cf\u7684\u7c7b\u578b\u6765\u5904\u7406\u4e0d\u540c\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u4e0d\u8fc7\uff0c\u6211\u8ba4\u4e3a\u6ca1\u6709\u5fc5\u8981\u8fd9\u6837\u505a\u3002\u6240\u6709\u6d89\u53ca\u5206\u7c7b\u53d8\u91cf\u7684\u95ee\u9898\u90fd\u53ef\u4ee5\u7528\u540c\u6837\u7684\u65b9\u6cd5\u5904\u7406\u3002 \u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e00\u4e2a\u6570\u636e\u96c6\uff08\u4e00\u5982\u65e2\u5f80\uff09\u3002\u8981\u4e86\u89e3\u5206\u7c7b\u53d8\u91cf\uff0c\u6700\u597d\u7684\u514d\u8d39\u6570\u636e\u96c6\u4e4b\u4e00\u662f Kaggle \u5206\u7c7b\u7279\u5f81\u7f16\u7801\u6311\u6218\u8d5b\u4e2d\u7684 cat-in-the-dat 
\u3002\u5171\u6709\u4e24\u4e2a\u6311\u6218\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6311\u6218\u7684\u6570\u636e\uff0c\u56e0\u4e3a\u5b83\u6bd4\u524d\u4e00\u4e2a\u7248\u672c\u6709\u66f4\u591a\u53d8\u91cf\uff0c\u96be\u5ea6\u4e5f\u66f4\u5927\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u3002 \u56fe 1\uff1aCat-in-the-dat-ii challenge\u90e8\u5206\u6570\u636e\u5c55\u793a \u6570\u636e\u96c6\u7531\u5404\u79cd\u5206\u7c7b\u53d8\u91cf\u7ec4\u6210\uff1a \u65e0\u5e8f \u6709\u5e8f \u5faa\u73af \u4e8c\u5143 \u5728\u56fe 1 \u4e2d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6240\u6709\u5b58\u5728\u7684\u53d8\u91cf\u548c\u76ee\u6807\u53d8\u91cf\u7684\u5b50\u96c6\u3002 \u8fd9\u662f\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002 \u76ee\u6807\u53d8\u91cf\u5bf9\u4e8e\u6211\u4eec\u5b66\u4e60\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\u5e76\u4e0d\u5341\u5206\u91cd\u8981\uff0c\u4f46\u6700\u7ec8\u6211\u4eec\u5c06\u5efa\u7acb\u4e00\u4e2a\u7aef\u5230\u7aef\u6a21\u578b\uff0c\u56e0\u6b64\u8ba9\u6211\u4eec\u770b\u770b\u56fe 2 \u4e2d\u7684\u76ee\u6807\u53d8\u91cf\u5206\u5e03\u3002\u6211\u4eec\u770b\u5230\u76ee\u6807\u662f \u504f\u659c \u7684\uff0c\u56e0\u6b64\u5bf9\u4e8e\u8fd9\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u6765\u8bf4\uff0c\u6700\u597d\u7684\u6307\u6807\u662f ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528\u7cbe\u786e\u5ea6\u548c\u53ec\u56de\u7387\uff0c\u4f46 AUC \u7ed3\u5408\u4e86\u8fd9\u4e24\u4e2a\u6307\u6807\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 AUC \u6765\u8bc4\u4f30\u6211\u4eec\u5728\u8be5\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u7684\u6a21\u578b\u3002 \u56fe 2\uff1a\u6807\u7b7e\u8ba1\u6570\u3002X \u8f74\u8868\u793a\u6807\u7b7e\uff0cY \u8f74\u8868\u793a\u6807\u7b7e\u8ba1\u6570 \u603b\u4f53\u800c\u8a00\uff0c\u6709\uff1a 5\u4e2a\u4e8c\u5143\u53d8\u91cf 10\u4e2a\u65e0\u5e8f\u53d8\u91cf 6\u4e2a\u6709\u5e8f\u53d8\u91cf 2\u4e2a\u5faa\u73af\u53d8\u91cf 1\u4e2a\u76ee\u6807\u53d8\u91cf 
\u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u96c6\u4e2d\u7684 ord_2 \u7279\u5f81\u3002\u5b83\u5305\u62ec6\u4e2a\u4e0d\u540c\u7684\u7c7b\u522b\uff1a - \u51b0\u51bb - \u6e29\u6696 - \u5bd2\u51b7 - \u8f83\u70ed - \u70ed - \u975e\u5e38\u70ed \u6211\u4eec\u5fc5\u987b\u77e5\u9053\uff0c\u8ba1\u7b97\u673a\u65e0\u6cd5\u7406\u89e3\u6587\u672c\u6570\u636e\uff0c\u56e0\u6b64\u6211\u4eec\u9700\u8981\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002\u4e00\u4e2a\u7b80\u5355\u7684\u65b9\u6cd5\u662f\u521b\u5efa\u4e00\u4e2a\u5b57\u5178\uff0c\u5c06\u8fd9\u4e9b\u503c\u6620\u5c04\u4e3a\u4ece 0 \u5230 N-1 \u7684\u6570\u5b57\uff0c\u5176\u4e2d N \u662f\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7c7b\u522b\u7684\u603b\u6570\u3002 # \u6620\u5c04\u5b57\u5178 mapping = { \"Freezing\" : 0 , \"Warm\" : 1 , \"Cold\" : 2 , \"Boiling Hot\" : 3 , \"Hot\" : 4 , \"Lava Hot\" : 5 } \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8bfb\u53d6\u6570\u636e\u96c6\uff0c\u5e76\u8f7b\u677e\u5730\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002 import pandas as pd # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u53d6*ord_2*\u5217\uff0c\u5e76\u4f7f\u7528\u6620\u5c04\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57 df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. map ( mapping ) \u6620\u5c04\u524d\u7684\u6570\u503c\u8ba1\u6570\uff1a df .* ord_2 *. 
value_counts () Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : * ord_2 * , dtype : int64 \u6620\u5c04\u540e\u7684\u6570\u503c\u8ba1\u6570\uff1a 0.0 142726 1.0 124239 2.0 97822 3.0 84790 4.0 67508 5.0 64840 Name : * ord_2 * , dtype : int64 \u8fd9\u79cd\u5206\u7c7b\u53d8\u91cf\u7684\u7f16\u7801\u65b9\u5f0f\u88ab\u79f0\u4e3a\u6807\u7b7e\u7f16\u7801\uff08Label Encoding\uff09\u6211\u4eec\u5c06\u6bcf\u4e2a\u7c7b\u522b\u7f16\u7801\u4e3a\u4e00\u4e2a\u6570\u5b57\u6807\u7b7e\u3002 \u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 LabelEncoder \u8fdb\u884c\u7f16\u7801\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u5c06\u7f3a\u5931\u503c\u586b\u5145\u4e3a\"NONE\" df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. fillna ( \"NONE\" ) # LabelEncoder\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u8f6c\u6362\u6570\u636e df . loc [:, \"*ord_2*\" ] = lbl_enc . fit_transform ( df .* ord_2 *. 
values ) \u4f60\u4f1a\u770b\u5230\u6211\u4f7f\u7528\u4e86 pandas \u7684 fillna\u3002\u539f\u56e0\u662f scikit-learn \u7684 LabelEncoder \u65e0\u6cd5\u5904\u7406 NaN \u503c\uff0c\u800c ord_2 \u5217\u4e2d\u6709 NaN \u503c\u3002 \u6211\u4eec\u53ef\u4ee5\u5728\u8bb8\u591a\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u4e2d\u76f4\u63a5\u4f7f\u7528\u5b83\uff1a - \u51b3\u7b56\u6811 - \u968f\u673a\u68ee\u6797 - \u63d0\u5347\u6811 - \u6216\u4efb\u4f55\u4e00\u79cd\u63d0\u5347\u6811\u6a21\u578b - XGBoost - GBM - LightGBM \u8fd9\u79cd\u7f16\u7801\u65b9\u5f0f\u4e0d\u80fd\u7528\u4e8e\u7ebf\u6027\u6a21\u578b\u3001\u652f\u6301\u5411\u91cf\u673a\u6216\u795e\u7ecf\u7f51\u7edc\uff0c\u56e0\u4e3a\u5b83\u4eec\u5e0c\u671b\u6570\u636e\u662f\u6807\u51c6\u5316\u7684\u3002 \u5bf9\u4e8e\u8fd9\u4e9b\u7c7b\u578b\u7684\u6a21\u578b\uff0c\u6211\u4eec\u53ef\u4ee5\u5bf9\u6570\u636e\u8fdb\u884c\u4e8c\u503c\u5316\uff08binarize\uff09\u5904\u7406\u3002 Freezing --> 0 --> 0 0 0 Warm --> 1 --> 0 0 1 Cold --> 2 --> 0 1 0 Boiling Hot --> 3 --> 0 1 1 Hot --> 4 --> 1 0 0 Lava Hot --> 5 --> 1 0 1 \u8fd9\u53ea\u662f\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\uff0c\u7136\u540e\u518d\u8f6c\u6362\u4e3a\u4e8c\u503c\u5316\u8868\u793a\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u628a\u4e00\u4e2a\u7279\u5f81\u5206\u6210\u4e86\u4e09\u4e2a\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u7279\u5f81\uff08\u6216\u5217\uff09\u3002\u5982\u679c\u6211\u4eec\u6709\u66f4\u591a\u7684\u7c7b\u522b\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u5206\u6210\u66f4\u591a\u7684\u5217\u3002 
\u5982\u679c\u6211\u4eec\u7528\u7a00\u758f\u683c\u5f0f\u5b58\u50a8\u5927\u91cf\u4e8c\u503c\u5316\u53d8\u91cf\uff0c\u5c31\u53ef\u4ee5\u8f7b\u677e\u5730\u5b58\u50a8\u8fd9\u4e9b\u53d8\u91cf\u3002\u7a00\u758f\u683c\u5f0f\u4e0d\u8fc7\u662f\u4e00\u79cd\u5728\u5185\u5b58\u4e2d\u5b58\u50a8\u6570\u636e\u7684\u8868\u793a\u6216\u65b9\u5f0f\uff0c\u5728\u8fd9\u79cd\u683c\u5f0f\u4e2d\uff0c\u4f60\u5e76\u4e0d\u5b58\u50a8\u6240\u6709\u7684\u503c\uff0c\u800c\u53ea\u5b58\u50a8\u91cd\u8981\u7684\u503c\u3002\u5728\u4e0a\u8ff0\u4e8c\u8fdb\u5236\u53d8\u91cf\u7684\u60c5\u51b5\u4e2d\uff0c\u6700\u91cd\u8981\u7684\u5c31\u662f\u6709 1 \u7684\u5730\u65b9\u3002 \u5f88\u96be\u60f3\u8c61\u8fd9\u6837\u7684\u683c\u5f0f\uff0c\u4f46\u4e3e\u4e2a\u4f8b\u5b50\u5c31\u4f1a\u660e\u767d\u3002 \u5047\u8bbe\u4e0a\u9762\u7684\u6570\u636e\u5e27\u4e2d\u53ea\u6709\u4e00\u4e2a\u7279\u5f81\uff1a ord_2 \u3002 Index Feature 0 Warm 1 Hot 2 Lava hot \u76ee\u524d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6570\u636e\u96c6\u4e2d\u7684\u4e09\u4e2a\u6837\u672c\u3002\u8ba9\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u4e8c\u503c\u8868\u793a\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u6709\u4e09\u4e2a\u9879\u76ee\u3002 \u8fd9\u4e09\u4e2a\u9879\u76ee\u5c31\u662f\u4e09\u4e2a\u7279\u5f81\u3002 Index Feature_0 Feature_1 Feature_2 0 0 0 1 1 1 0 0 2 1 0 1 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u7279\u5f81\u5b58\u50a8\u5728\u4e00\u4e2a\u6709 3 \u884c 3 \u5217\uff083x3\uff09\u7684\u77e9\u9635\u4e2d\u3002\u77e9\u9635\u7684\u6bcf\u4e2a\u5143\u7d20\u5360\u7528 8 \u4e2a\u5b57\u8282\u3002\u56e0\u6b64\uff0c\u8fd9\u4e2a\u6570\u7ec4\u7684\u603b\u5185\u5b58\u9700\u6c42\u4e3a 8x3x3 = 72 \u5b57\u8282\u3002 \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e2a\u7b80\u5355\u7684 python \u4ee3\u7801\u6bb5\u6765\u68c0\u67e5\u8fd9\u4e00\u70b9\u3002 import numpy as np example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) print ( example . 
nbytes ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u6253\u5370\u51fa 72\uff0c\u5c31\u50cf\u6211\u4eec\u4e4b\u524d\u8ba1\u7b97\u7684\u90a3\u6837\u3002\u4f46\u6211\u4eec\u9700\u8981\u5b58\u50a8\u8fd9\u4e2a\u77e9\u9635\u7684\u6240\u6709\u5143\u7d20\u5417\uff1f\u5982\u524d\u6240\u8ff0\uff0c\u6211\u4eec\u53ea\u5bf9 1 \u611f\u5174\u8da3\u30020 \u5e76\u4e0d\u91cd\u8981\uff0c\u56e0\u4e3a\u4efb\u4f55\u4e0e 0 \u76f8\u4e58\u7684\u5143\u7d20\u90fd\u662f 0\uff0c\u800c 0 \u4e0e\u4efb\u4f55\u5143\u7d20\u76f8\u52a0\u6216\u76f8\u51cf\u4e5f\u6ca1\u6709\u4efb\u4f55\u533a\u522b\u3002\u53ea\u7528 1 \u8868\u793a\u77e9\u9635\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u67d0\u79cd\u5b57\u5178\u65b9\u6cd5\uff0c\u5176\u4e2d\u952e\u662f\u884c\u548c\u5217\u7684\u7d22\u5f15\uff0c\u503c\u662f 1\uff1a ( 0 , 2 ) 1 ( 1 , 0 ) 1 ( 2 , 0 ) 1 ( 2 , 2 ) 1 \u8fd9\u6837\u7684\u7b26\u53f7\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u56e0\u4e3a\u5b83\u53ea\u9700\u5b58\u50a8\u56db\u4e2a\u503c\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u3002\u4f7f\u7528\u7684\u603b\u5185\u5b58\u4e3a 8x4 = 32 \u5b57\u8282\u3002\u4efb\u4f55 numpy \u6570\u7ec4\u90fd\u53ef\u4ee5\u901a\u8fc7\u7b80\u5355\u7684 python \u4ee3\u7801\u8f6c\u6362\u4e3a\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) sparse_example = sparse . csr_matrix ( example ) print ( sparse_example . data . nbytes ) \u8fd9\u5c06\u6253\u5370 32\uff0c\u6bd4\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u5c11\u4e86\u8fd9\u4e48\u591a\uff01\u7a00\u758f csr \u77e9\u9635\u7684\u603b\u5927\u5c0f\u662f\u4e09\u4e2a\u503c\u7684\u603b\u548c\u3002 print ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . 
nbytes ) \u8fd9\u5c06\u6253\u5370\u51fa 64 \u4e2a\u5143\u7d20\uff0c\u4ecd\u7136\u5c11\u4e8e\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u3002\u9057\u61be\u7684\u662f\uff0c\u6211\u4e0d\u4f1a\u8be6\u7ec6\u4ecb\u7ecd\u8fd9\u4e9b\u5143\u7d20\u3002\u4f60\u53ef\u4ee5\u5728 scipy \u6587\u6863\u4e2d\u4e86\u89e3\u66f4\u591a\u3002\u5f53\u6211\u4eec\u62e5\u6709\u66f4\u5927\u7684\u6570\u7ec4\u65f6\uff0c\u6bd4\u5982\u8bf4\u62e5\u6709\u6570\u5343\u4e2a\u6837\u672c\u548c\u6570\u4e07\u4e2a\u7279\u5f81\u7684\u6570\u7ec4\uff0c\u5927\u5c0f\u5dee\u5f02\u5c31\u4f1a\u53d8\u5f97\u975e\u5e38\u5927\u3002\u4f8b\u5982\uff0c\u6211\u4eec\u4f7f\u7528\u57fa\u4e8e\u8ba1\u6570\u7279\u5f81\u7684\u6587\u672c\u6570\u636e\u96c6\u3002 import numpy as np from scipy import sparse n_rows = 10000 n_cols = 100000 # \u751f\u6210\u7b26\u5408\u4f2f\u52aa\u5229\u5206\u5e03\u7684\u968f\u673a\u6570\u7ec4\uff0c\u7ef4\u5ea6\u4e3a[10000, 100000] example = np . random . binomial ( 1 , p = 0.05 , size = ( n_rows , n_cols )) print ( f \"Size of dense array: { example . nbytes } \" ) # \u5c06\u968f\u673a\u77e9\u9635\u8f6c\u6362\u4e3a\u6d17\u6f31\u77e9\u9635 sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . 
nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u8fd9\u5c06\u6253\u5370\uff1a Size of dense array : 8000000000 Size of sparse array : 399932496 Full size of sparse array : 599938748 \u56e0\u6b64\uff0c\u5bc6\u96c6\u9635\u5217\u9700\u8981 ~8000MB \u6216\u5927\u7ea6 8GB \u5185\u5b58\u3002\u800c\u7a00\u758f\u9635\u5217\u53ea\u5360\u7528 399MB \u5185\u5b58\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u5f53\u6211\u4eec\u7684\u7279\u5f81\u4e2d\u6709\u5927\u91cf\u96f6\u65f6\uff0c\u6211\u4eec\u66f4\u559c\u6b22\u7a00\u758f\u9635\u5217\u800c\u4e0d\u662f\u5bc6\u96c6\u9635\u5217\u7684\u539f\u56e0\u3002 \u8bf7\u6ce8\u610f\uff0c\u7a00\u758f\u77e9\u9635\u6709\u591a\u79cd\u4e0d\u540c\u7684\u8868\u793a\u65b9\u6cd5\u3002\u8fd9\u91cc\u6211\u53ea\u5c55\u793a\u4e86\u5176\u4e2d\u4e00\u79cd\uff08\u53ef\u80fd\u4e5f\u662f\u6700\u5e38\u7528\u7684\uff09\u65b9\u6cd5\u3002\u6df1\u5165\u63a2\u8ba8\u8fd9\u4e9b\u65b9\u6cd5\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\uff0c\u56e0\u6b64\u7559\u7ed9\u8bfb\u8005\u4e00\u4e2a\u7ec3\u4e60\u3002 \u5c3d\u7ba1\u4e8c\u503c\u5316\u7279\u5f81\u7684\u7a00\u758f\u8868\u793a\u6bd4\u5176\u5bc6\u96c6\u8868\u793a\u6240\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u4f46\u5bf9\u4e8e\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\uff0c\u8fd8\u6709\u4e00\u79cd\u8f6c\u6362\u6240\u5360\u7528\u7684\u5185\u5b58\u66f4\u5c11\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \" \u72ec\u70ed\u7f16\u7801 \"\u3002 \u72ec\u70ed\u7f16\u7801\u4e5f\u662f\u4e00\u79cd\u4e8c\u503c\u7f16\u7801\uff0c\u56e0\u4e3a\u53ea\u6709 0 \u548c 1 \u4e24\u4e2a\u503c\u3002\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5b83\u5e76\u4e0d\u662f\u4e8c\u503c\u8868\u793a\u6cd5\u3002\u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0b\u9762\u7684\u4f8b\u5b50\u6765\u7406\u89e3\u5b83\u7684\u8868\u793a\u6cd5\u3002 \u5047\u8bbe\u6211\u4eec\u7528\u4e00\u4e2a\u5411\u91cf\u6765\u8868\u793a ord_2 \u53d8\u91cf\u7684\u6bcf\u4e2a\u7c7b\u522b\u3002\u8fd9\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u4e0e ord_2 
\u53d8\u91cf\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u6bcf\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u90fd\u662f 6\uff0c\u5e76\u4e14\u9664\u4e86\u4e00\u4e2a\u4f4d\u7f6e\u5916\uff0c\u5176\u4ed6\u4f4d\u7f6e\u90fd\u662f 0\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u7279\u6b8a\u7684\u5411\u91cf\u8868\u3002 Freezing 0 0 0 0 0 1 Warm 0 0 0 0 1 0 Cold 0 0 0 1 0 0 Boiling Hot 0 0 1 0 0 0 Hot 0 1 0 0 0 0 Lava Hot 1 0 0 0 0 0 \u6211\u4eec\u770b\u5230\u5411\u91cf\u7684\u5927\u5c0f\u662f 1x6\uff0c\u5373\u5411\u91cf\u4e2d\u67096\u4e2a\u5143\u7d20\u3002\u8fd9\u4e2a\u6570\u5b57\u662f\u600e\u4e48\u6765\u7684\u5462\uff1f\u5982\u679c\u4f60\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u5982\u524d\u6240\u8ff0\uff0c\u67096\u4e2a\u7c7b\u522b\u3002\u5728\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u65f6\uff0c\u5411\u91cf\u7684\u5927\u5c0f\u5fc5\u987b\u4e0e\u6211\u4eec\u8981\u67e5\u770b\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u6bcf\u4e2a\u5411\u91cf\u90fd\u6709\u4e00\u4e2a 1\uff0c\u5176\u4f59\u6240\u6709\u503c\u90fd\u662f 0\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u7528\u8fd9\u4e9b\u7279\u5f81\u6765\u4ee3\u66ff\u4e4b\u524d\u7684\u4e8c\u503c\u5316\u7279\u5f81\uff0c\u770b\u770b\u80fd\u8282\u7701\u591a\u5c11\u5185\u5b58\u3002 \u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\u4ee5\u524d\u7684\u6570\u636e\uff0c\u5b83\u770b\u8d77\u6765\u5982\u4e0b\uff1a Index Feature 0 Warm 1 Hot 2 Lava hot \u6bcf\u4e2a\u6837\u672c\u67093\u4e2a\u7279\u5f81\u3002\u4f46\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u72ec\u70ed\u5411\u91cf\u7684\u5927\u5c0f\u4e3a 6\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u67096\u4e2a\u7279\u5f81\uff0c\u800c\u4e0d\u662f3\u4e2a\u3002 Index F_0 F_1 F_2 F_3 F_4 F_5 0 0 0 0 0 1 0 1 0 1 0 0 0 0 2 1 0 1 0 0 0 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 6 \u4e2a\u7279\u5f81\uff0c\u800c\u5728\u8fd9\u4e2a 3x6 \u6570\u7ec4\u4e2d\uff0c\u53ea\u6709 3 \u4e2a1\u3002\u4f7f\u7528 numpy 
\u8ba1\u7b97\u5927\u5c0f\u4e0e\u4e8c\u503c\u5316\u5927\u5c0f\u8ba1\u7b97\u811a\u672c\u975e\u5e38\u76f8\u4f3c\u3002\u4f60\u9700\u8981\u6539\u53d8\u7684\u53ea\u662f\u6570\u7ec4\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u6bb5\u4ee3\u7801\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 0 , 0 , 1 , 0 ], [ 0 , 1 , 0 , 0 , 0 , 0 ], [ 1 , 0 , 0 , 0 , 0 , 0 ] ] ) print ( f \"Size of dense array: { example . nbytes } \" ) sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u6253\u5370\u5185\u5b58\u5927\u5c0f\u4e3a\uff1a Size of dense array : 144 Size of sparse array : 24 Full size of sparse array : 52 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5bc6\u96c6\u77e9\u9635\u7684\u5927\u5c0f\u8fdc\u8fdc\u5927\u4e8e\u4e8c\u503c\u5316\u77e9\u9635\u7684\u5927\u5c0f\u3002\u4e0d\u8fc7\uff0c\u7a00\u758f\u6570\u7ec4\u7684\u5927\u5c0f\u8981\u66f4\u5c0f\u3002\u8ba9\u6211\u4eec\u7528\u66f4\u5927\u7684\u6570\u7ec4\u6765\u8bd5\u8bd5\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 OneHotEncoder \u5c06\u5305\u542b 1001 \u4e2a\u7c7b\u522b\u7684\u7279\u5f81\u6570\u7ec4\u8f6c\u6362\u4e3a\u5bc6\u96c6\u77e9\u9635\u548c\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from sklearn import preprocessing # \u751f\u6210\u7b26\u5408\u5747\u5300\u5206\u5e03\u7684\u968f\u673a\u6574\u6570\uff0c\u7ef4\u5ea6\u4e3a[1000000, 10000000] example = np . random . randint ( 1000 , size = 1000000 ) # \u72ec\u70ed\u7f16\u7801\uff0c\u975e\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = False ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of dense array: { ohe_example . 
nbytes } \" ) # \u72ec\u70ed\u7f16\u7801\uff0c\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = True ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of sparse array: { ohe_example . data . nbytes } \" ) full_size = ( ohe_example . data . nbytes + ohe_example . indptr . nbytes + ohe_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u4e0a\u9762\u4ee3\u7801\u6253\u5370\u7684\u8f93\u51fa\uff1a Size of dense array : 8000000000 Size of sparse array : 8000000 Full size of sparse array : 16000004 \u8fd9\u91cc\u7684\u5bc6\u96c6\u9635\u5217\u5927\u5c0f\u7ea6\u4e3a 8GB\uff0c\u7a00\u758f\u9635\u5217\u4e3a 8MB\u3002\u5982\u679c\u53ef\u4ee5\u9009\u62e9\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff1f\u5728\u6211\u770b\u6765\uff0c\u9009\u62e9\u5f88\u7b80\u5355\uff0c\u4e0d\u662f\u5417\uff1f \u8fd9\u4e09\u79cd\u65b9\u6cd5\uff08\u6807\u7b7e\u7f16\u7801\u3001\u7a00\u758f\u77e9\u9635\u3001\u72ec\u70ed\u7f16\u7801\uff09\u662f\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u7684\u6700\u91cd\u8981\u65b9\u6cd5\u3002\u4e0d\u8fc7\uff0c\u4f60\u8fd8\u53ef\u4ee5\u7528\u5f88\u591a\u5176\u4ed6\u4e0d\u540c\u7684\u65b9\u6cd5\u6765\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u3002\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u6570\u503c\u53d8\u91cf\u5c31\u662f\u5176\u4e2d\u7684\u4e00\u4e2a\u4f8b\u5b50\u3002 \u5047\u8bbe\u6211\u4eec\u56de\u5230\u4e4b\u524d\u7684\u5206\u7c7b\u7279\u5f81\u6570\u636e\uff08\u539f\u59cb\u6570\u636e\u4e2d\u7684 cat-in-the-dat-ii\uff09\u3002\u5728\u6570\u636e\u4e2d\uff0c ord_2 \u7684\u503c\u4e3a\u201c\u70ed\u201c\u7684 id \u6709\u591a\u5c11\uff1f \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u8ba1\u7b97\u6570\u636e\u7684\u5f62\u72b6\uff08shape\uff09\u8f7b\u677e\u8ba1\u7b97\u51fa\u8fd9\u4e2a\u503c\uff0c\u5176\u4e2d ord_2 \u5217\u7684\u503c\u4e3a Boiling Hot \u3002 In [ X ]: df [ df . ord_2 == \"Boiling Hot\" ] . 
shape Out [ X ]: ( 84790 , 25 ) \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 84790 \u6761\u8bb0\u5f55\u5177\u6709\u6b64\u503c\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 pandas \u4e2d\u7684 groupby \u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u8be5\u503c\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . count () Out [ X ]: ord_2 Boiling Hot 84790 Cold 97822 Freezing 142726 Hot 67508 Lava Hot 64840 Warm 124239 Name : id , dtype : int64 \u5982\u679c\u6211\u4eec\u53ea\u662f\u5c06 ord_2 \u5217\u66ff\u6362\u4e3a\u5176\u8ba1\u6570\u503c\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u5c06\u5176\u8f6c\u6362\u4e3a\u4e00\u79cd\u6570\u503c\u7279\u5f81\u4e86\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 transform \u51fd\u6570\u548c groupby \u6765\u521b\u5efa\u65b0\u5217\u6216\u66ff\u6362\u8fd9\u4e00\u5217\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . transform ( \"count\" ) Out [ X ]: 0 67508.0 1 124239.0 2 142726.0 3 64840.0 4 97822.0 ... 599995 142726.0 599996 84790.0 599997 142726.0 599998 124239.0 599999 84790.0 Name : id , Length : 600000 , dtype : float64 \u4f60\u53ef\u4ee5\u6dfb\u52a0\u6240\u6709\u7279\u5f81\u7684\u8ba1\u6570\uff0c\u4e5f\u53ef\u4ee5\u66ff\u6362\u5b83\u4eec\uff0c\u6216\u8005\u6839\u636e\u591a\u4e2a\u5217\u53ca\u5176\u8ba1\u6570\u8fdb\u884c\u5206\u7ec4\u3002\u4f8b\u5982\uff0c\u4ee5\u4e0b\u4ee3\u7801\u901a\u8fc7\u5bf9 ord_1 \u548c ord_2 \u5217\u5206\u7ec4\u8fdb\u884c\u8ba1\u6570\u3002 In [ X ]: df . groupby ( ... : [ ... : \"ord_1\" , ... : \"ord_2\" ... : ] ... : )[ \"id\" ] . count () . 
reset_index ( name = \"count\" ) Out [ X ]: ord_1 ord_2 count 0 Contributor Boiling Hot 15634 1 Contributor Cold 17734 2 Contributor Freezing 26082 3 Contributor Hot 12428 4 Contributor Lava Hot 11919 5 Contributor Warm 22774 6 Expert Boiling Hot 19477 7 Expert Cold 22956 8 Expert Freezing 33249 9 Expert Hot 15792 10 Expert Lava Hot 15078 11 Expert Warm 28900 12 Grandmaster Boiling Hot 13623 13 Grandmaster Cold 15464 14 Grandmaster Freezing 22818 15 Grandmaster Hot 10805 16 Grandmaster Lava Hot 10363 17 Grandmaster Warm 19899 18 Master Boiling Hot 10800 ... \u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u884c\uff0c\u4ee5\u4fbf\u5728\u4e00\u9875\u4e2d\u5bb9\u7eb3\u8fd9\u4e9b\u884c\u3002\u8fd9\u662f\u53e6\u4e00\u79cd\u53ef\u4ee5\u4f5c\u4e3a\u529f\u80fd\u6dfb\u52a0\u7684\u8ba1\u6570\u3002\u60a8\u73b0\u5728\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\uff0c\u6211\u4f7f\u7528 id \u5217\u8fdb\u884c\u8ba1\u6570\u3002\u4e0d\u8fc7\uff0c\u4f60\u4e5f\u53ef\u4ee5\u901a\u8fc7\u5bf9\u5217\u7684\u7ec4\u5408\u8fdb\u884c\u5206\u7ec4\uff0c\u5bf9\u5176\u4ed6\u5217\u8fdb\u884c\u8ba1\u6570\u3002 \u8fd8\u6709\u4e00\u4e2a\u5c0f\u7a8d\u95e8\uff0c\u5c31\u662f\u4ece\u8fd9\u4e9b\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\u3002\u4f60\u53ef\u4ee5\u4ece\u73b0\u6709\u7684\u7279\u5f81\u4e2d\u521b\u5efa\u65b0\u7684\u5206\u7c7b\u7279\u5f81\uff0c\u800c\u4e14\u53ef\u4ee5\u6beb\u4e0d\u8d39\u529b\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot 1 Grandmaster_Warm 2 nan_Freezing 3 Novice_Lava Hot 4 Grandmaster_Cold ... 
599999 Contributor_Boiling Hot Name : new_feature , Length : 600000 , dtype : object \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u7528\u4e0b\u5212\u7ebf\u5c06 ord_1 \u548c ord_2 \u5408\u5e76\uff0c\u7136\u540e\u5c06\u8fd9\u4e9b\u5217\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u7c7b\u578b\u3002\u8bf7\u6ce8\u610f\uff0cNaN \u4e5f\u4f1a\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u3002\u4e0d\u8fc7\u6ca1\u5173\u7cfb\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06 NaN \u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u6709\u4e86\u4e00\u4e2a\u7531\u8fd9\u4e24\u4e2a\u7279\u5f81\u7ec4\u5408\u800c\u6210\u7684\u65b0\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u5c06\u4e09\u5217\u4ee5\u4e0a\u6216\u56db\u5217\u751a\u81f3\u66f4\u591a\u5217\u7ec4\u5408\u5728\u4e00\u8d77\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : + \"_\" ... : + df . ord_3 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot_c 1 Grandmaster_Warm_e 2 nan_Freezing_n 3 Novice_Lava Hot_a 4 Grandmaster_Cold_h ... 
599999 Contributor_Boiling Hot_b Name : new_feature , Length : 600000 , dtype : object \u90a3\u4e48\uff0c\u6211\u4eec\u5e94\u8be5\u628a\u54ea\u4e9b\u7c7b\u522b\u7ed3\u5408\u8d77\u6765\u5462\uff1f\u8fd9\u5e76\u6ca1\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u7b54\u6848\u3002\u8fd9\u53d6\u51b3\u4e8e\u60a8\u7684\u6570\u636e\u548c\u7279\u5f81\u7c7b\u578b\u3002\u4e00\u4e9b\u9886\u57df\u77e5\u8bc6\u5bf9\u4e8e\u521b\u5efa\u8fd9\u6837\u7684\u7279\u5f81\u53ef\u80fd\u5f88\u6709\u7528\u3002\u4f46\u662f\uff0c\u5982\u679c\u4f60\u4e0d\u62c5\u5fc3\u5185\u5b58\u548c CPU \u7684\u4f7f\u7528\uff0c\u4f60\u53ef\u4ee5\u91c7\u7528\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\uff0c\u5373\u521b\u5efa\u8bb8\u591a\u8fd9\u6837\u7684\u7ec4\u5408\uff0c\u7136\u540e\u4f7f\u7528\u4e00\u4e2a\u6a21\u578b\u6765\u51b3\u5b9a\u54ea\u4e9b\u7279\u5f81\u662f\u6709\u7528\u7684\uff0c\u5e76\u4fdd\u7559\u5b83\u4eec\u3002\u6211\u4eec\u5c06\u5728\u672c\u4e66\u7a0d\u540e\u90e8\u5206\u4ecb\u7ecd\u8fd9\u79cd\u65b9\u6cd5\u3002 \u65e0\u8bba\u4f55\u65f6\u83b7\u5f97\u5206\u7c7b\u53d8\u91cf\uff0c\u90fd\u8981\u9075\u5faa\u4ee5\u4e0b\u7b80\u5355\u6b65\u9aa4\uff1a - \u586b\u5145 NaN \u503c\uff08\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff01\uff09\u3002 - \u4f7f\u7528 scikit-learn \u7684 LabelEncoder \u6216\u6620\u5c04\u5b57\u5178\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\uff0c\u5c06\u5b83\u4eec\u8f6c\u6362\u4e3a\u6574\u6570\u3002\u5982\u679c\u6ca1\u6709\u586b\u5145 NaN \u503c\uff0c\u53ef\u80fd\u9700\u8981\u5728\u8fd9\u4e00\u6b65\u4e2d\u8fdb\u884c\u5904\u7406 - \u521b\u5efa\u72ec\u70ed\u7f16\u7801\u3002\u662f\u7684\uff0c\u4f60\u53ef\u4ee5\u8df3\u8fc7\u4e8c\u503c\u5316\uff01 - \u5efa\u6a21\uff01\u6211\u6307\u7684\u662f\u673a\u5668\u5b66\u4e60\u3002 \u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u5904\u7406 NaN \u6570\u636e\u975e\u5e38\u91cd\u8981\uff0c\u5426\u5219\u60a8\u53ef\u80fd\u4f1a\u4ece scikit-learn \u7684 LabelEncoder \u4e2d\u5f97\u5230\u81ed\u540d\u662d\u8457\u7684\u9519\u8bef\u4fe1\u606f\uff1a ValueError: y 
\u5305\u542b\u4ee5\u524d\u672a\u89c1\u8fc7\u7684\u6807\u7b7e\uff1a [Nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan) \u8fd9\u4ec5\u4ec5\u610f\u5473\u7740\uff0c\u5728\u8f6c\u6362\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6570\u636e\u4e2d\u51fa\u73b0\u4e86 NaN \u503c\u3002\u8fd9\u662f\u56e0\u4e3a\u4f60\u5728\u8bad\u7ec3\u65f6\u5fd8\u8bb0\u4e86\u5904\u7406\u5b83\u4eec\u3002 \u5904\u7406 NaN \u503c \u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u4e22\u5f03\u5b83\u4eec\u3002\u867d\u7136\u7b80\u5355\uff0c\u4f46\u5e76\u4e0d\u7406\u60f3\u3002NaN \u503c\u4e2d\u53ef\u80fd\u5305\u542b\u5f88\u591a\u4fe1\u606f\uff0c\u5982\u679c\u53ea\u662f\u4e22\u5f03\u8fd9\u4e9b\u503c\uff0c\u5c31\u4f1a\u4e22\u5931\u8fd9\u4e9b\u4fe1\u606f\u3002\u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u5927\u90e8\u5206\u6570\u636e\u90fd\u662f NaN \u503c\uff0c\u56e0\u6b64\u4e0d\u80fd\u4e22\u5f03 NaN \u503c\u7684\u884c/\u6837\u672c\u3002\u5904\u7406 NaN \u503c\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u5c06\u5176\u4f5c\u4e3a\u4e00\u4e2a\u5168\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u662f\u5904\u7406 NaN \u503c\u6700\u5e38\u7528\u7684\u65b9\u6cd5\u3002\u5982\u679c\u4f7f\u7528 pandas\uff0c\u8fd8\u53ef\u4ee5\u901a\u8fc7\u975e\u5e38\u7b80\u5355\u7684\u65b9\u5f0f\u5b9e\u73b0\u3002 \u8bf7\u770b\u6211\u4eec\u4e4b\u524d\u67e5\u770b\u8fc7\u7684\u6570\u636e\u7684 ord_2 \u5217\u3002 In [ X ]: df . ord_2 . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : ord_2 , dtype : int64 \u586b\u5165 NaN \u503c\u540e\uff0c\u5c31\u53d8\u6210\u4e86 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . 
value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u54c7\uff01\u8fd9\u4e00\u5217\u4e2d\u6709 18075 \u4e2a NaN \u503c\uff0c\u800c\u6211\u4eec\u4e4b\u524d\u751a\u81f3\u90fd\u6ca1\u6709\u8003\u8651\u4f7f\u7528\u5b83\u4eec\u3002\u589e\u52a0\u4e86\u8fd9\u4e2a\u65b0\u7c7b\u522b\u540e\uff0c\u7c7b\u522b\u603b\u6570\u4ece 6 \u4e2a\u589e\u52a0\u5230\u4e86 7 \u4e2a\u3002\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u73b0\u5728\u6211\u4eec\u5728\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u4e5f\u4f1a\u8003\u8651 NaN\u3002\u76f8\u5173\u4fe1\u606f\u8d8a\u591a\uff0c\u6a21\u578b\u5c31\u8d8a\u597d\u3002 \u5047\u8bbe ord_2 \u6ca1\u6709\u4efb\u4f55 NaN \u503c\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u4e00\u5217\u4e2d\u7684\u6240\u6709\u7c7b\u522b\u90fd\u6709\u663e\u8457\u7684\u8ba1\u6570\u3002\u5176\u4e2d\u6ca1\u6709 \"\u7f55\u89c1 \"\u7c7b\u522b\uff0c\u5373\u53ea\u5728\u6837\u672c\u603b\u6570\u4e2d\u5360\u5f88\u5c0f\u6bd4\u4f8b\u7684\u7c7b\u522b\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5047\u8bbe\u60a8\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u4e86\u4f7f\u7528\u8fd9\u4e00\u5217\u7684\u6a21\u578b\uff0c\u5f53\u6a21\u578b\u6216\u9879\u76ee\u4e0a\u7ebf\u65f6\uff0c\u60a8\u5728 ord_2 \u5217\u4e2d\u5f97\u5230\u4e86\u4e00\u4e2a\u5728\u8bad\u7ec3\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6a21\u578b\u7ba1\u9053\u4f1a\u629b\u51fa\u4e00\u4e2a\u9519\u8bef\uff0c\u60a8\u5bf9\u6b64\u65e0\u80fd\u4e3a\u529b\u3002\u5982\u679c\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\uff0c\u90a3\u4e48\u53ef\u80fd\u662f\u751f\u4ea7\u4e2d\u7684\u7ba1\u9053\u51fa\u4e86\u95ee\u9898\u3002\u5982\u679c\u8fd9\u662f\u9884\u6599\u4e4b\u4e2d\u7684\uff0c\u90a3\u4e48\u60a8\u5c31\u5fc5\u987b\u4fee\u6539\u60a8\u7684\u6a21\u578b\u7ba1\u9053\uff0c\u5e76\u5728\u8fd9\u516d\u4e2a\u7c7b\u522b\u4e2d\u52a0\u5165\u4e00\u4e2a\u65b0\u7c7b\u522b\u3002 \u8fd9\u4e2a\u65b0\u7c7b\u522b\u88ab\u79f0\u4e3a 
\"\u7f55\u89c1 \"\u7c7b\u522b\u3002\u7f55\u89c1\u7c7b\u522b\u662f\u4e00\u79cd\u4e0d\u5e38\u89c1\u7684\u7c7b\u522b\uff0c\u53ef\u4ee5\u5305\u62ec\u8bb8\u591a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8fd1\u90bb\u6a21\u578b\u6765 \"\u9884\u6d4b \"\u672a\u77e5\u7c7b\u522b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u4e86\u8fd9\u4e2a\u7c7b\u522b\uff0c\u5b83\u5c31\u4f1a\u6210\u4e3a\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u4e00\u4e2a\u7c7b\u522b\u3002 \u56fe 3\uff1a\u5177\u6709\u4e0d\u540c\u7279\u5f81\u4e14\u65e0\u6807\u7b7e\u7684\u6570\u636e\u96c6\u793a\u610f\u56fe\uff0c\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u53ef\u80fd\u4f1a\u5728\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u51fa\u73b0\u65b0\u503c \u5f53\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u56fe 3 \u6240\u793a\u7684\u6570\u636e\u96c6\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u7b80\u5355\u7684\u6a21\u578b\uff0c\u5bf9\u9664 \"f3 \"\u4e4b\u5916\u7684\u6240\u6709\u7279\u5f81\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u6837\uff0c\u4f60\u5c06\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u4e0d\u77e5\u9053\u6216\u8bad\u7ec3\u4e2d\u6ca1\u6709 \"f3 \"\u65f6\u9884\u6d4b\u5b83\u3002\u6211\u4e0d\u6562\u8bf4\u8fd9\u6837\u7684\u6a21\u578b\u662f\u5426\u80fd\u5e26\u6765\u51fa\u8272\u7684\u6027\u80fd\uff0c\u4f46\u4e5f\u8bb8\u80fd\u5904\u7406\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u7684\u7f3a\u5931\u503c\uff0c\u5c31\u50cf\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u5176\u4ed6\u4e8b\u60c5\u4e00\u6837\uff0c\u4e0d\u5c1d\u8bd5\u4e00\u4e0b\u662f\u8bf4\u4e0d\u51c6\u7684\u3002 
\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u56fa\u5b9a\u7684\u6d4b\u8bd5\u96c6\uff0c\u4f60\u53ef\u4ee5\u5c06\u6d4b\u8bd5\u6570\u636e\u6dfb\u52a0\u5230\u8bad\u7ec3\u4e2d\uff0c\u4ee5\u4e86\u89e3\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u7c7b\u522b\u3002\u8fd9\u4e0e\u534a\u76d1\u7763\u5b66\u4e60\u975e\u5e38\u76f8\u4f3c\uff0c\u5373\u4f7f\u7528\u65e0\u6cd5\u7528\u4e8e\u8bad\u7ec3\u7684\u6570\u636e\u6765\u6539\u8fdb\u6a21\u578b\u3002\u8fd9\u4e5f\u4f1a\u7167\u987e\u5230\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u51fa\u73b0\u6b21\u6570\u6781\u5c11\u4f46\u5728\u6d4b\u8bd5\u6570\u636e\u4e2d\u5927\u91cf\u5b58\u5728\u7684\u7a00\u6709\u503c\u3002\u4f60\u7684\u6a21\u578b\u5c06\u66f4\u52a0\u7a33\u5065\u3002 \u5f88\u591a\u4eba\u8ba4\u4e3a\u8fd9\u79cd\u60f3\u6cd5\u4f1a\u8fc7\u5ea6\u62df\u5408\u3002\u53ef\u80fd\u8fc7\u62df\u5408\uff0c\u4e5f\u53ef\u80fd\u4e0d\u8fc7\u62df\u5408\u3002\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6cd5\u3002\u5982\u679c\u4f60\u5728\u8bbe\u8ba1\u4ea4\u53c9\u9a8c\u8bc1\u65f6\uff0c\u80fd\u591f\u5728\u6d4b\u8bd5\u6570\u636e\u4e0a\u8fd0\u884c\u6a21\u578b\u65f6\u590d\u5236\u9884\u6d4b\u8fc7\u7a0b\uff0c\u90a3\u4e48\u5b83\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u8fd9\u610f\u5473\u7740\u7b2c\u4e00\u6b65\u5e94\u8be5\u662f\u5206\u79bb\u6298\u53e0\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u4f60\u5e94\u8be5\u5e94\u7528\u4e0e\u6d4b\u8bd5\u6570\u636e\u76f8\u540c\u7684\u9884\u5904\u7406\u3002\u5047\u8bbe\u60a8\u60f3\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u6d4b\u8bd5\u6570\u636e\uff0c\u90a3\u4e48\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u60a8\u5fc5\u987b\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\uff0c\u5e76\u786e\u4fdd\u9a8c\u8bc1\u6570\u636e\u96c6\u590d\u5236\u4e86\u6d4b\u8bd5\u96c6\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u60a8\u5fc5\u987b\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u8bbe\u8ba1\u9a8c\u8bc1\u96c6\uff0c\u4f7f\u5176\u5305\u542b\u8bad\u7ec3\u96c6\u4e2d \"\u672a\u89c1 \"\u7684\u7c7b\u522b\u3002 
\u56fe 4\uff1a\u5bf9\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u8fdb\u884c\u7b80\u5355\u5408\u5e76\uff0c\u4ee5\u4e86\u89e3\u6d4b\u8bd5\u96c6\u4e2d\u5b58\u5728\u4f46\u8bad\u7ec3\u96c6\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u6216\u8bad\u7ec3\u96c6\u4e2d\u7f55\u89c1\u7684\u7c7b\u522b \u53ea\u8981\u770b\u4e00\u4e0b\u56fe 4 \u548c\u4e0b\u9762\u7684\u4ee3\u7801\uff0c\u5c31\u80fd\u5f88\u5bb9\u6613\u7406\u89e3\u5176\u5de5\u4f5c\u539f\u7406\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( \"../input/cat_train.csv\" ) # \u8bfb\u53d6\u6d4b\u8bd5\u96c6 test = pd . read_csv ( \"../input/cat_test.csv\" ) # \u5c06\u6d4b\u8bd5\u96c6\"target\"\u5217\u5168\u90e8\u7f6e\u4e3a-1 test . loc [:, \"target\" ] = - 1 # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u62fc\u63a5 data = pd . concat ([ train , test ]) . reset_index ( drop = True ) # \u5c06\u9664\"id\"\u548c\"target\"\u5217\u7684\u5176\u4ed6\u7279\u5f81\u5217\u540d\u53d6\u51fa features = [ x for x in train . columns if x not in [ \"id\" , \"target\" ]] # \u904d\u5386\u7279\u5f81 for feat in features : # \u6807\u7b7e\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\",\u5e76\u5c06\u8be5\u5217\u683c\u5f0f\u53d8\u4e3astr temp_col = data [ feat ] . fillna ( \"NONE\" ) . astype ( str ) . values # \u8f6c\u6362\u6570\u503c data . loc [:, feat ] = lbl_enc . fit_transform ( temp_col ) # \u6839\u636e\"target\"\u5217\u5c06\u8bad\u7ec3\u96c6\u4e0e\u6d4b\u8bd5\u96c6\u5206\u5f00 train = data [ data . target != - 1 ] . reset_index ( drop = True ) test = data [ data . target == - 1 ] . 
reset_index ( drop = True ) \u5f53\u60a8\u9047\u5230\u5df2\u7ecf\u6709\u6d4b\u8bd5\u6570\u636e\u96c6\u7684\u95ee\u9898\u65f6\uff0c\u8fd9\u4e2a\u6280\u5de7\u5c31\u4f1a\u8d77\u4f5c\u7528\u3002\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd9\u4e00\u62db\u5728\u5b9e\u65f6\u73af\u5883\u4e2d\u4e0d\u8d77\u4f5c\u7528\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u60a8\u6240\u5728\u7684\u516c\u53f8\u63d0\u4f9b\u5b9e\u65f6\u7ade\u4ef7\u89e3\u51b3\u65b9\u6848\uff08RTB\uff09\u3002RTB \u7cfb\u7edf\u4f1a\u5bf9\u5728\u7ebf\u770b\u5230\u7684\u6bcf\u4e2a\u7528\u6237\u8fdb\u884c\u7ade\u4ef7\uff0c\u4ee5\u8d2d\u4e70\u5e7f\u544a\u7a7a\u95f4\u3002\u8fd9\u79cd\u6a21\u5f0f\u53ef\u4f7f\u7528\u7684\u529f\u80fd\u53ef\u80fd\u5305\u62ec\u7f51\u7ad9\u4e2d\u6d4f\u89c8\u7684\u9875\u9762\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u7279\u5f81\u662f\u7528\u6237\u8bbf\u95ee\u7684\u6700\u540e\u4e94\u4e2a\u7c7b\u522b/\u9875\u9762\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5982\u679c\u7f51\u7ad9\u5f15\u5165\u4e86\u65b0\u7684\u7c7b\u522b\uff0c\u6211\u4eec\u5c06\u65e0\u6cd5\u518d\u51c6\u786e\u9884\u6d4b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u5c31\u4f1a\u5931\u6548\u3002\u8fd9\u79cd\u60c5\u51b5\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528 \"\u672a\u77e5 \"\u7c7b\u522b\u6765\u907f\u514d \u3002 \u5728\u6211\u4eec\u7684 cat-in-the-dat \u6570\u636e\u96c6\u4e2d\uff0c ord_2 \u5217\u4e2d\u5df2\u7ecf\u6709\u4e86\u672a\u77e5\u7c7b\u522b\u3002 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . 
value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u6211\u4eec\u53ef\u4ee5\u5c06 \"NONE \"\u89c6\u4e3a\u672a\u77e5\u3002\u56e0\u6b64\uff0c\u5982\u679c\u5728\u5b9e\u65f6\u6d4b\u8bd5\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u83b7\u5f97\u4e86\u4ee5\u524d\u4ece\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\uff0c\u6211\u4eec\u5c31\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a \"NONE\"\u3002 \u8fd9\u4e0e\u81ea\u7136\u8bed\u8a00\u5904\u7406\u95ee\u9898\u975e\u5e38\u76f8\u4f3c\u3002\u6211\u4eec\u603b\u662f\u57fa\u4e8e\u56fa\u5b9a\u7684\u8bcd\u6c47\u5efa\u7acb\u6a21\u578b\u3002\u589e\u52a0\u8bcd\u6c47\u91cf\u5c31\u4f1a\u589e\u52a0\u6a21\u578b\u7684\u5927\u5c0f\u3002\u50cf BERT \u8fd9\u6837\u7684\u8f6c\u6362\u5668\u6a21\u578b\u662f\u5728 ~30000 \u4e2a\u5355\u8bcd\uff08\u82f1\u8bed\uff09\u7684\u57fa\u7840\u4e0a\u8bad\u7ec3\u7684\u3002\u56e0\u6b64\uff0c\u5f53\u6709\u65b0\u8bcd\u8f93\u5165\u65f6\uff0c\u6211\u4eec\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a UNK\uff08\u672a\u77e5\uff09\u3002 \u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u5047\u8bbe\u6d4b\u8bd5\u6570\u636e\u4e0e\u8bad\u7ec3\u6570\u636e\u5177\u6709\u76f8\u540c\u7684\u7c7b\u522b\uff0c\u4e5f\u53ef\u4ee5\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u5f15\u5165\u7f55\u89c1\u6216\u672a\u77e5\u7c7b\u522b\uff0c\u4ee5\u5904\u7406\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u65b0\u7c7b\u522b\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u586b\u5165 NaN \u503c\u540e ord_4 \u5217\u7684\u503c\u8ba1\u6570\uff1a In [ X ]: df . ord_4 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 . . . 
K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 G 3404 V 3107 J 1950 L 1657 Name : ord_4 , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u6709\u4e9b\u6570\u503c\u53ea\u51fa\u73b0\u4e86\u51e0\u5343\u6b21\uff0c\u6709\u4e9b\u5219\u51fa\u73b0\u4e86\u8fd1 40000 \u6b21\u3002NaN \u4e5f\u7ecf\u5e38\u51fa\u73b0\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u503c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5b9a\u4e49\u5c06\u4e00\u4e2a\u503c\u79f0\u4e3a \" \u7f55\u89c1\uff08rare\uff09 \"\u7684\u6807\u51c6\u4e86\u3002\u6bd4\u65b9\u8bf4\uff0c\u5728\u8fd9\u4e00\u5217\u4e2d\uff0c\u7a00\u6709\u503c\u7684\u8981\u6c42\u662f\u8ba1\u6570\u5c0f\u4e8e 2000\u3002\u8fd9\u6837\u770b\u6765\uff0cJ \u548c L \u5c31\u53ef\u4ee5\u88ab\u6807\u8bb0\u4e3a\u7a00\u6709\u503c\u4e86\u3002\u4f7f\u7528 pandas\uff0c\u6839\u636e\u8ba1\u6570\u9608\u503c\u66ff\u6362\u7c7b\u522b\u975e\u5e38\u7b80\u5355\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 In [ X ]: df . ord_4 = df . ord_4 . fillna ( \"NONE\" ) In [ X ]: df . loc [ ... : df [ \"ord_4\" ] . value_counts ()[ df [ \"ord_4\" ]] . values < 2000 , ... : \"ord_4\" ... : ] = \"RARE\" In [ X ]: df . ord_4 . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 M 32504 . . . 
B 25212 E 21871 K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 RARE 3607 G 3404 V 3107 Name : ord_4 , dtype : int64 \u6211\u4eec\u8ba4\u4e3a\uff0c\u53ea\u8981\u67d0\u4e2a\u7c7b\u522b\u7684\u503c\u5c0f\u4e8e 2000\uff0c\u5c31\u5c06\u5176\u66ff\u6362\u4e3a\u7f55\u89c1\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u5728\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6240\u6709\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"RARE\"\uff0c\u800c\u6240\u6709\u7f3a\u5931\u503c\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"NONE\"\u3002 \u8fd9\u79cd\u65b9\u6cd5\u8fd8\u80fd\u786e\u4fdd\u5373\u4f7f\u6709\u65b0\u7684\u7c7b\u522b\uff0c\u6a21\u578b\u4e5f\u80fd\u5728\u5b9e\u9645\u73af\u5883\u4e2d\u6b63\u5e38\u5de5\u4f5c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u5177\u5907\u4e86\u5904\u7406\u4efb\u4f55\u5e26\u6709\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u6240\u9700\u7684\u4e00\u5207\u6761\u4ef6\u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5efa\u7acb\u7b2c\u4e00\u4e2a\u6a21\u578b\uff0c\u5e76\u9010\u6b65\u63d0\u9ad8\u5176\u6027\u80fd\u3002 \u5728\u6784\u5efa\u4efb\u4f55\u7c7b\u578b\u7684\u6a21\u578b\u4e4b\u524d\uff0c\u4ea4\u53c9\u68c0\u9a8c\u81f3\u5173\u91cd\u8981\u3002\u6211\u4eec\u5df2\u7ecf\u770b\u5230\u4e86\u6807\u7b7e/\u76ee\u6807\u5206\u5e03\uff0c\u77e5\u9053\u8fd9\u662f\u4e00\u4e2a\u76ee\u6807\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 StratifiedKFold \u6765\u5206\u5272\u6570\u636e\u3002 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u6dfb\u52a0\"kfold\"\u5217\uff0c\u5e76\u7f6e\u4e3a-1 df [ \"kfold\" ] = - 1 # \u6253\u4e71\u6570\u636e\u987a\u5e8f\uff0c\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u5c06\u76ee\u6807\u5217\u53d6\u51fa y = df . target . values # \u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c kf = model_selection . 
StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): # \u533a\u5206\u6298\u53e0 df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u6587\u4ef6 df . to_csv ( \"../input/cat_train_folds.csv\" , index = False ) \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u68c0\u67e5\u65b0\u7684\u6298\u53e0 csv\uff0c\u67e5\u770b\u6bcf\u4e2a\u6298\u53e0\u7684\u6837\u672c\u6570\uff1a In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) In [ X ]: df . kfold . value_counts () Out [ X ]: 4 120000 3 120000 2 120000 1 120000 0 120000 Name : kfold , dtype : int64 \u6240\u6709\u6298\u53e0\u90fd\u6709 120000 \u4e2a\u6837\u672c\u3002\u8fd9\u662f\u610f\u6599\u4e4b\u4e2d\u7684\uff0c\u56e0\u4e3a\u8bad\u7ec3\u6570\u636e\u6709 600000 \u4e2a\u6837\u672c\uff0c\u800c\u6211\u4eec\u505a\u4e865\u6b21\u6298\u53e0\u3002\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u4e00\u5207\u987a\u5229\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u68c0\u67e5\u6bcf\u4e2a\u6298\u53e0\u7684\u76ee\u6807\u5206\u5e03\u3002 In [ X ]: df [ df . kfold == 0 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 1 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 2 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 3 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 4 ] . target . 
value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u76ee\u6807\u7684\u5206\u5e03\u90fd\u662f\u4e00\u6837\u7684\u3002\u8fd9\u6b63\u662f\u6211\u4eec\u6240\u9700\u8981\u7684\u3002\u5b83\u4e5f\u53ef\u4ee5\u662f\u76f8\u4f3c\u7684\uff0c\u5e76\u4e0d\u4e00\u5b9a\u8981\u4e00\u76f4\u76f8\u540c\u3002\u73b0\u5728\uff0c\u5f53\u6211\u4eec\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u6807\u7b7e\u5206\u5e03\u90fd\u5c06\u76f8\u540c\u3002 \u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u7684\u6700\u7b80\u5355\u7684\u6a21\u578b\u4e4b\u4e00\u662f\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u5e76\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): # \u8bfb\u53d6\u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) # \u53d6\u9664\"id\", \"target\", \"kfold\"\u5916\u7684\u5176\u4ed6\u7279\u5f81\u5217 features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] # \u904d\u5386\u7279\u5f81\u5217\u8868 for col in features : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u9a8c\u8bc1\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . 
fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u6d4b\u8bd5\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u903b\u8f91\u56de\u5f52 model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . target . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( auc ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6298\u53e00 run ( 0 ) \u90a3\u4e48\uff0c\u53d1\u751f\u4e86\u4ec0\u4e48\u5462\uff1f \u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u51fd\u6570\uff0c\u5c06\u6570\u636e\u5206\u4e3a\u8bad\u7ec3\u548c\u9a8c\u8bc1\u4e24\u90e8\u5206\uff0c\u7ed9\u5b9a\u6298\u53e0\u6570\uff0c\u5904\u7406 NaN \u503c\uff0c\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u5355\u6b21\u7f16\u7801\uff0c\u5e76\u8bad\u7ec3\u4e00\u4e2a\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\u3002 \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u90e8\u5206\u4ee3\u7801\u65f6\uff0c\u4f1a\u4ea7\u751f\u5982\u4e0b\u8f93\u51fa\uff1a \u276f python ohe_logres . py / home / abhishek / miniconda3 / envs / ml / lib / python3 .7 / site - packages / sklearn / linear_model / _logistic . py : 939 : ConvergenceWarning : lbfgs failed to converge ( status = 1 ): STOP : TOTAL NO . of ITERATIONS REACHED LIMIT . Increase the number of iterations ( max_iter ) or scale the data as shown in : https : // scikit - learn . org / stable / modules / preprocessing . html . Please also refer to the documentation for alternative solver options : https : // scikit - learn . org / stable / modules / linear_model . 
html #logistic- regression extra_warning_msg = _LOGISTIC_SOLVER_CONVERGENCE_MSG ) 0.7847865042255127 \u6709\u4e00\u4e9b\u8b66\u544a\u3002\u903b\u8f91\u56de\u5f52\u4f3c\u4e4e\u6ca1\u6709\u6536\u655b\u5230\u6700\u5927\u8fed\u4ee3\u6b21\u6570\u3002\u6211\u4eec\u6ca1\u6709\u8c03\u6574\u53c2\u6570\uff0c\u6240\u4ee5\u6ca1\u6709\u95ee\u9898\u3002\u6211\u4eec\u770b\u5230 AUC \u4e3a 0.785\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u5bf9\u4ee3\u7801\u8fdb\u884c\u7b80\u5355\u4fee\u6539\uff0c\u8fd0\u884c\u6240\u6709\u6298\u53e0\u3002 .... model = linear_model . LogisticRegression () model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u5faa\u73af\u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5e76\u6ca1\u6709\u505a\u5f88\u5927\u7684\u6539\u52a8\uff0c\u6240\u4ee5\u6211\u53ea\u663e\u793a\u4e86\u90e8\u5206\u4ee3\u7801\u884c\uff0c\u5176\u4e2d\u4e00\u4e9b\u4ee3\u7801\u884c\u6709\u6539\u52a8\u3002 \u8fd9\u5c31\u6253\u5370\u51fa\u4e86\uff1a python - W ignore ohe_logres . 
py Fold = 0 , AUC = 0.7847865042255127 Fold = 1 , AUC = 0.7853553605899214 Fold = 2 , AUC = 0.7879321942914885 Fold = 3 , AUC = 0.7870315929550808 Fold = 4 , AUC = 0.7864668243125608 \u8bf7\u6ce8\u610f\uff0c\u6211\u4f7f\u7528\"-W ignore \"\u5ffd\u7565\u4e86\u6240\u6709\u8b66\u544a\u3002 \u6211\u4eec\u770b\u5230\uff0cAUC \u5206\u6570\u5728\u6240\u6709\u8936\u76b1\u4e2d\u90fd\u76f8\u5f53\u7a33\u5b9a\u3002\u5e73\u5747 AUC \u4e3a 0.78631449527\u3002\u5bf9\u4e8e\u6211\u4eec\u7684\u7b2c\u4e00\u4e2a\u6a21\u578b\u6765\u8bf4\u76f8\u5f53\u4e0d\u9519\uff01 \u5f88\u591a\u4eba\u5728\u9047\u5230\u8fd9\u79cd\u95ee\u9898\u65f6\u4f1a\u9996\u5148\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6bd4\u5982\u968f\u673a\u68ee\u6797\u3002\u5728\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u5e94\u7528\u968f\u673a\u68ee\u6797\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\uff08label encoding\uff09\uff0c\u5c06\u6bcf\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81\u90fd\u8f6c\u6362\u4e3a\u6574\u6570\uff0c\u800c\u4e0d\u662f\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u72ec\u70ed\u7f16\u7801\u3002 \u8fd9\u79cd\u7f16\u7801\u4e0e\u72ec\u70ed\u7f16\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . 
values # \u968f\u673a\u68ee\u6797\u6a21\u578b model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u4f7f\u7528 scikit-learn \u4e2d\u7684\u968f\u673a\u68ee\u6797\uff0c\u5e76\u53d6\u6d88\u4e86\u72ec\u70ed\u7f16\u7801\u3002\u6211\u4eec\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u4ee3\u66ff\u72ec\u70ed\u7f16\u7801\u3002\u5f97\u5206\u5982\u4e0b \u276f python lbl_rf . py Fold = 0 , AUC = 0.7167390828113697 Fold = 1 , AUC = 0.7165459672958506 Fold = 2 , AUC = 0.7159709909587376 Fold = 3 , AUC = 0.7161589664189556 Fold = 4 , AUC = 0.7156020216155978 \u54c7 \u5de8\u5927\u7684\u5dee\u5f02\uff01 \u968f\u673a\u68ee\u6797\u6a21\u578b\u5728\u6ca1\u6709\u4efb\u4f55\u8d85\u53c2\u6570\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u8868\u73b0\u8981\u6bd4\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u5dee\u5f88\u591a\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u603b\u662f\u5e94\u8be5\u5148\u4ece\u7b80\u5355\u6a21\u578b\u5f00\u59cb\u7684\u539f\u56e0\u3002\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u7c89\u4e1d\u4f1a\u4ece\u8fd9\u91cc\u5f00\u59cb\uff0c\u800c\u5ffd\u7565\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u8ba4\u4e3a\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u6a21\u578b\uff0c\u4e0d\u80fd\u5e26\u6765\u6bd4\u968f\u673a\u68ee\u6797\u66f4\u597d\u7684\u4ef7\u503c\u3002\u8fd9\u79cd\u4eba\u5c06\u4f1a\u72af\u4e0b\u5927\u9519\u3002\u5728\u6211\u4eec\u5b9e\u73b0\u968f\u673a\u68ee\u6797\u7684\u8fc7\u7a0b\u4e2d\uff0c\u4e0e\u903b\u8f91\u56de\u5f52\u76f8\u6bd4\uff0c\u6298\u53e0\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u624d\u80fd\u5b8c\u6210\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e0d\u4ec5\u635f\u5931\u4e86 
AUC\uff0c\u8fd8\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u6765\u5b8c\u6210\u8bad\u7ec3\u3002\u8bf7\u6ce8\u610f\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u8fdb\u884c\u63a8\u7406\u4e5f\u5f88\u8017\u65f6\uff0c\u800c\u4e14\u5360\u7528\u7684\u7a7a\u95f4\u4e5f\u66f4\u5927\u3002 \u5982\u679c\u6211\u4eec\u613f\u610f\uff0c\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u5728\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u6570\u636e\u4e0a\u8fd0\u884c\u968f\u673a\u68ee\u6797\uff0c\u4f46\u8fd9\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u5947\u5f02\u503c\u5206\u89e3\u6765\u51cf\u5c11\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u77e9\u9635\u3002\u8fd9\u662f\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u63d0\u53d6\u4e3b\u9898\u7684\u5e38\u7528\u65b9\u6cd5\u3002 import pandas as pd from scipy import sparse from sklearn import decomposition from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" )] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) x_train = ohe . transform ( df_train [ features ]) x_valid = ohe . transform ( df_valid [ features ]) # \u5947\u5f02\u503c\u5206\u89e3 svd = decomposition . TruncatedSVD ( n_components = 120 ) full_sparse = sparse . vstack (( x_train , x_valid )) svd . fit ( full_sparse ) x_train = svd . transform ( x_train ) x_valid = svd . transform ( x_valid ) model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . 
values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u5bf9\u5168\u90e8\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\uff0c\u7136\u540e\u7528\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\u5728\u7a00\u758f\u77e9\u9635\u4e0a\u62df\u5408 scikit-learn \u7684 TruncatedSVD\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c06\u9ad8\u7ef4\u7a00\u758f\u77e9\u9635\u51cf\u5c11\u5230 120 \u4e2a\u7279\u5f81\uff0c\u7136\u540e\u62df\u5408\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u3002 \u4ee5\u4e0b\u662f\u8be5\u6a21\u578b\u7684\u8f93\u51fa\u7ed3\u679c\uff1a \u276f python ohe_svd_rf . py Fold = 0 , AUC = 0.7064863038754249 Fold = 1 , AUC = 0.706050102937374 Fold = 2 , AUC = 0.7086069243167242 Fold = 3 , AUC = 0.7066819080085971 Fold = 4 , AUC = 0.7058154015055585 \u6211\u4eec\u53d1\u73b0\u60c5\u51b5\u66f4\u7cdf\u3002\u770b\u6765\uff0c\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\u7684\u6700\u4f73\u65b9\u6cd5\u662f\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u3002\u968f\u673a\u68ee\u6797\u4f3c\u4e4e\u8017\u65f6\u592a\u591a\u3002\u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u8bd5\u8bd5 XGBoost\u3002\u5982\u679c\u4f60\u4e0d\u77e5\u9053 XGBoost\uff0c\u5b83\u662f\u6700\u6d41\u884c\u7684\u68af\u5ea6\u63d0\u5347\u7b97\u6cd5\u4e4b\u4e00\u3002\u7531\u4e8e\u5b83\u662f\u4e00\u79cd\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u6570\u636e\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . 
fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 , n_estimators = 200 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0c\u6211\u5bf9 xgboost \u53c2\u6570\u505a\u4e86\u4e00\u4e9b\u4fee\u6539\u3002xgboost \u7684\u9ed8\u8ba4\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u662f 3\uff0c\u6211\u628a\u5b83\u6539\u6210\u4e86 7\uff0c\u8fd8\u628a\u4f30\u8ba1\u5668\u6570\u91cf\uff08n_estimators\uff09\u4ece 100 \u6539\u6210\u4e86 200\u3002 \u8be5\u6a21\u578b\u7684 5 \u6298\u4ea4\u53c9\u68c0\u9a8c\u5f97\u5206\u5982\u4e0b\uff1a \u276f python lbl_xgb . 
py Fold = 0 , AUC = 0.7656768851999011 Fold = 1 , AUC = 0.7633006564148015 Fold = 2 , AUC = 0.7654277821434345 Fold = 3 , AUC = 0.7663609758878182 Fold = 4 , AUC = 0.764914671468069 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u4e0d\u505a\u4efb\u4f55\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u5f97\u5206\u6bd4\u666e\u901a\u968f\u673a\u68ee\u6797\u8981\u9ad8\u5f97\u591a\u3002 \u60a8\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4e00\u4e9b\u7279\u5f81\u5de5\u7a0b\uff0c\u653e\u5f03\u67d0\u4e9b\u5bf9\u6a21\u578b\u6ca1\u6709\u4efb\u4f55\u4ef7\u503c\u7684\u5217\u7b49\u3002\u4f46\u4f3c\u4e4e\u6211\u4eec\u80fd\u505a\u7684\u4e0d\u591a\uff0c\u65e0\u6cd5\u8bc1\u660e\u6a21\u578b\u7684\u6539\u8fdb\u3002\u8ba9\u6211\u4eec\u628a\u6570\u636e\u96c6\u6362\u6210\u53e6\u4e00\u4e2a\u6709\u5927\u91cf\u5206\u7c7b\u53d8\u91cf\u7684\u6570\u636e\u96c6\u3002\u53e6\u4e00\u4e2a\u6709\u540d\u7684\u6570\u636e\u96c6\u662f \u7f8e\u56fd\u6210\u4eba\u4eba\u53e3\u666e\u67e5\u6570\u636e\uff08US adult census data\uff09 \u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u5305\u542b\u4e00\u4e9b\u7279\u5f81\uff0c\u800c\u4f60\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u5de5\u8d44\u7b49\u7ea7\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u3002\u56fe 5 \u663e\u793a\u4e86\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u4e00\u4e9b\u5217\u3002 \u56fe 5\uff1a\u90e8\u5206\u6570\u636e\u96c6\u5c55\u793a \u8be5\u6570\u636e\u96c6\u6709\u4ee5\u4e0b\u51e0\u5217\uff1a - \u5e74\u9f84\uff08age\uff09 \u5de5\u4f5c\u7c7b\u522b\uff08workclass\uff09 \u5b66\u5386\uff08fnlwgt\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education.num\uff09 \u5a5a\u59fb\u72b6\u51b5\uff08marital.status\uff09 \u804c\u4e1a\uff08occupation\uff09 \u5173\u7cfb\uff08relationship\uff09 \u79cd\u65cf\uff08race\uff09 \u6027\u522b\uff08sex\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 
\u539f\u7c4d\u56fd\uff08native.country\uff09 \u6536\u5165\uff08income\uff09 \u8fd9\u4e9b\u7279\u5f81\u5927\u591a\u4e0d\u8a00\u81ea\u660e\u3002\u90a3\u4e9b\u4e0d\u660e\u767d\u7684\uff0c\u6211\u4eec\u53ef\u4ee5\u4e0d\u8003\u8651\u3002\u8ba9\u6211\u4eec\u5148\u5c1d\u8bd5\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u3002 \u6211\u4eec\u770b\u5230\u6536\u5165\u5217\u662f\u4e00\u4e2a\u5b57\u7b26\u4e32\u3002\u8ba9\u6211\u4eec\u5bf9\u8fd9\u4e00\u5217\u8fdb\u884c\u6570\u503c\u7edf\u8ba1\u3002 In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/adult.csv\" ) In [ X ]: df . income . value_counts () Out [ X ]: <= 50 K 24720 > 50 K 7841 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 7841 \u4e2a\u5b9e\u4f8b\u7684\u6536\u5165\u8d85\u8fc7 5 \u4e07\u7f8e\u5143\u3002\u8fd9\u5360\u6837\u672c\u603b\u6570\u7684 24%\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4fdd\u6301\u4e0e\u732b\u6570\u636e\u96c6\u76f8\u540c\u7684\u8bc4\u4f30\u65b9\u6cd5\uff0c\u5373 AUC\u3002 \u5728\u5f00\u59cb\u5efa\u6a21\u4e4b\u524d\uff0c\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u5c06\u53bb\u6389\u51e0\u5217\u7279\u5f81\uff0c\u5373 \u5b66\u5386\uff08fnlwgt\uff09 \u5e74\u9f84\uff08age\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 \u8ba9\u6211\u4eec\u8bd5\u7740\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u5668\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002\u7b2c\u4e00\u6b65\u603b\u662f\u8981\u8fdb\u884c\u4ea4\u53c9\u9a8c\u8bc1\u3002\u6211\u4e0d\u4f1a\u5728\u8fd9\u91cc\u5c55\u793a\u8fd9\u90e8\u5206\u4ee3\u7801\u3002\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . 
read_csv ( \"../input/adult_folds.csv\" ) # \u9700\u8981\u5220\u9664\u7684\u5217 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) # \u6620\u5c04 target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } # \u4f7f\u7528\u6620\u5c04\u66ff\u6362 df . loc [:, \"income\" ] = df . income . map ( target_mapping ) # \u53d6\u9664\"kfold\", \"income\"\u5217\u7684\u5176\u4ed6\u5217\u540d features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u9a8c\u8bc1\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u6784\u5efa\u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . income . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . income . 
values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u65f6\uff0c\u6211\u4eec\u4f1a\u5f97\u5230 \u276f python - W ignore ohe_logres . py Fold = 0 , AUC = 0.8794809708119079 Fold = 1 , AUC = 0.8875785068274882 Fold = 2 , AUC = 0.8852609687685753 Fold = 3 , AUC = 0.8681236223251438 Fold = 4 , AUC = 0.8728581541840037 \u5bf9\u4e8e\u4e00\u4e2a\u5982\u6b64\u7b80\u5355\u7684\u6a21\u578b\u6765\u8bf4\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u4e0d\u9519\u7684 AUC\uff01 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5728\u4e0d\u8c03\u6574\u4efb\u4f55\u8d85\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c1d\u8bd5\u4e00\u4e0b\u6807\u7b7e\u7f16\u7801\u7684xgboost\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . 
roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8ba9\u6211\u4eec\u8fd0\u884c\u4e0a\u9762\u4ee3\u7801\uff1a \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8800810634234078 Fold = 1 , AUC = 0.886811884948154 Fold = 2 , AUC = 0.8854421433318472 Fold = 3 , AUC = 0.8676319549361007 Fold = 4 , AUC = 0.8714450054900602 \u8fd9\u770b\u8d77\u6765\u5df2\u7ecf\u76f8\u5f53\u4e0d\u9519\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b max_depth \u589e\u52a0\u5230 7 \u548c n_estimators \u589e\u52a0\u5230 200 \u65f6\u7684\u5f97\u5206\u3002 \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8764108944332032 Fold = 1 , AUC = 0.8840708537662638 Fold = 2 , AUC = 0.8816601162613102 Fold = 3 , AUC = 0.8662335762581732 Fold = 4 , AUC = 0.8698983461709926 \u770b\u8d77\u6765\u5e76\u6ca1\u6709\u6539\u5584\u3002 \u8fd9\u8868\u660e\uff0c\u4e00\u4e2a\u6570\u636e\u96c6\u7684\u53c2\u6570\u4e0d\u80fd\u79fb\u690d\u5230\u53e6\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6211\u4eec\u5fc5\u987b\u518d\u6b21\u5c1d\u8bd5\u8c03\u6574\u53c2\u6570\uff0c\u4f46\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u8bf4\u660e\u3002 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5728\u4e0d\u8c03\u6574\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c06\u6570\u503c\u7279\u5f81\u7eb3\u5165 xgboost \u6a21\u578b\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) # \u52a0\u5165\u6570\u503c\u7279\u5f81 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . 
columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : if col not in num_cols : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u56e0\u6b64\uff0c\u6211\u4eec\u4fdd\u7559\u6570\u5b57\u5217\uff0c\u53ea\u662f\u4e0d\u5bf9\u5176\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u7684\u6700\u7ec8\u7279\u5f81\u77e9\u9635\u5c31\u7531\u6570\u5b57\u5217\uff08\u539f\u6837\uff09\u548c\u7f16\u7801\u5206\u7c7b\u5217\u7ec4\u6210\u4e86\u3002\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\u90fd\u80fd\u8f7b\u677e\u5904\u7406\u8fd9\u79cd\u6df7\u5408\u3002 \u8bf7\u6ce8\u610f\uff0c\u5728\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65f6\uff0c\u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u6570\u636e\u8fdb\u884c\u5f52\u4e00\u5316\u5904\u7406\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff0c\u5728\u4f7f\u7528\u7ebf\u6027\u6a21\u578b\uff08\u5982\u903b\u8f91\u56de\u5f52\uff09\u65f6\u4e0d\u5bb9\u5ffd\u89c6\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u4e2a\u811a\u672c\uff01 \u276f python lbl_xgb_num . 
py Fold = 0 , AUC = 0.9209790185449889 Fold = 1 , AUC = 0.9247157449144706 Fold = 2 , AUC = 0.9269329887598243 Fold = 3 , AUC = 0.9119349082169275 Fold = 4 , AUC = 0.9166408030141667 \u54c7\u54e6 \u8fd9\u662f\u4e00\u4e2a\u5f88\u597d\u7684\u5206\u6570\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c1d\u8bd5\u6dfb\u52a0\u4e00\u4e9b\u529f\u80fd\u3002\u6211\u4eec\u5c06\u63d0\u53d6\u6240\u6709\u5206\u7c7b\u5217\uff0c\u5e76\u521b\u5efa\u6240\u6709\u4e8c\u5ea6\u7ec4\u5408\u3002\u8bf7\u770b\u4e0b\u9762\u4ee3\u7801\u6bb5\u4e2d\u7684 feature_engineering \u51fd\u6570\uff0c\u4e86\u89e3\u5982\u4f55\u5b9e\u73b0\u8fd9\u4e00\u70b9\u3002 import itertools import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def feature_engineering ( df , cat_cols ): # \u751f\u6210\u4e24\u4e2a\u7279\u5f81\u7684\u7ec4\u5408 combi = list ( itertools . combinations ( cat_cols , 2 )) for c1 , c2 in combi : df . loc [:, c1 + \"_\" + c2 ] = df [ c1 ] . astype ( str ) + \"_\" + df [ c2 ] . astype ( str ) return df def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) cat_cols = [ c for c in df . columns if c not in num_cols and c not in ( \"kfold\" , \"income\" )] # \u7279\u5f81\u5de5\u7a0b df = feature_engineering ( df , cat_cols ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" )] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . 
reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u8fd9\u662f\u4ece\u5206\u7c7b\u5217\u4e2d\u521b\u5efa\u7279\u5f81\u7684\u4e00\u79cd\u975e\u5e38\u5e7c\u7a1a\u7684\u65b9\u6cd5\u3002\u6211\u4eec\u5e94\u8be5\u4ed4\u7ec6\u7814\u7a76\u6570\u636e\uff0c\u770b\u770b\u54ea\u4e9b\u7ec4\u5408\u6700\u5408\u7406\u3002\u5982\u679c\u4f7f\u7528\u8fd9\u79cd\u65b9\u6cd5\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5c31\u9700\u8981\u4f7f\u7528\u67d0\u79cd\u7279\u5f81\u9009\u62e9\u6765\u9009\u51fa\u6700\u4f73\u7279\u5f81\u3002\u7a0d\u540e\u6211\u4eec\u5c06\u8be6\u7ec6\u4ecb\u7ecd\u7279\u5f81\u9009\u62e9\u3002\u73b0\u5728\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5206\u6570\u3002 \u276f python lbl_xgb_num_feat . py Fold = 0 , AUC = 0.9211483465031423 Fold = 1 , AUC = 0.9251499446866125 Fold = 2 , AUC = 0.9262344766486692 Fold = 3 , AUC = 0.9114264068794995 Fold = 4 , AUC = 0.9177914453099201 \u770b\u6765\uff0c\u5373\u4f7f\u4e0d\u6539\u53d8\u4efb\u4f55\u8d85\u53c2\u6570\uff0c\u53ea\u589e\u52a0\u4e00\u4e9b\u7279\u5f81\uff0c\u6211\u4eec\u4e5f\u80fd\u63d0\u9ad8\u4e00\u4e9b\u6298\u53e0\u5f97\u5206\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5c06 max_depth \u589e\u52a0\u5230 7 \u662f\u5426\u6709\u5e2e\u52a9\u3002 \u276f python lbl_xgb_num_feat . 
py Fold = 0 , AUC = 0.9286668430204137 Fold = 1 , AUC = 0.9329340656165378 Fold = 2 , AUC = 0.9319817543218744 Fold = 3 , AUC = 0.919046187194538 Fold = 4 , AUC = 0.9245692057162671 \u6211\u4eec\u518d\u6b21\u6539\u8fdb\u4e86\u6211\u4eec\u7684\u6a21\u578b\u3002 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u8fd8\u6ca1\u6709\u4f7f\u7528\u7a00\u6709\u503c\u3001\u4e8c\u503c\u5316\u3001\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u7279\u5f81\u7684\u7ec4\u5408\u4ee5\u53ca\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u3002 \u4ece\u5206\u7c7b\u7279\u5f81\u4e2d\u8fdb\u884c\u7279\u5f81\u5de5\u7a0b\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u4f7f\u7528 \u76ee\u6807\u7f16\u7801 \u3002\u4f46\u662f\uff0c\u60a8\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u8fd9\u53ef\u80fd\u4f1a\u4f7f\u60a8\u7684\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u76ee\u6807\u7f16\u7801\u662f\u4e00\u79cd\u5c06\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u6620\u5c04\u5230\u5176\u5e73\u5747\u76ee\u6807\u503c\u7684\u6280\u672f\uff0c\u4f46\u5fc5\u987b\u59cb\u7ec8\u4ee5\u4ea4\u53c9\u9a8c\u8bc1\u7684\u65b9\u5f0f\u8fdb\u884c\u3002\u8fd9\u610f\u5473\u7740\u9996\u5148\u8981\u521b\u5efa\u6298\u53e0\uff0c\u7136\u540e\u4f7f\u7528\u8fd9\u4e9b\u6298\u53e0\u4e3a\u6570\u636e\u7684\u4e0d\u540c\u5217\u521b\u5efa\u76ee\u6807\u7f16\u7801\u7279\u5f81\uff0c\u65b9\u6cd5\u4e0e\u5728\u6298\u53e0\u4e0a\u62df\u5408\u548c\u9884\u6d4b\u6a21\u578b\u7684\u65b9\u6cd5\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u521b\u5efa\u4e86 5 \u4e2a\u6298\u53e0\uff0c\u60a8\u5c31\u5fc5\u987b\u521b\u5efa 5 
\u6b21\u76ee\u6807\u7f16\u7801\uff0c\u8fd9\u6837\u6700\u7ec8\uff0c\u60a8\u5c31\u53ef\u4ee5\u4e3a\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u53d8\u91cf\u521b\u5efa\u7f16\u7801\uff0c\u800c\u8fd9\u4e9b\u53d8\u91cf\u5e76\u975e\u6765\u81ea\u540c\u4e00\u4e2a\u6298\u53e0\u3002\u7136\u540e\u5728\u62df\u5408\u6a21\u578b\u65f6\uff0c\u5fc5\u987b\u518d\u6b21\u4f7f\u7528\u76f8\u540c\u7684\u6298\u53e0\u3002\u672a\u89c1\u6d4b\u8bd5\u6570\u636e\u7684\u76ee\u6807\u7f16\u7801\u53ef\u4ee5\u6765\u81ea\u5168\u90e8\u8bad\u7ec3\u6570\u636e\uff0c\u4e5f\u53ef\u4ee5\u662f\u6240\u6709 5 \u4e2a\u6298\u53e0\u7684\u5e73\u5747\u503c\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u5728\u540c\u4e00\u4e2a\u6210\u4eba\u6570\u636e\u96c6\u4e0a\u4f7f\u7528\u76ee\u6807\u7f16\u7801\uff0c\u4ee5\u4fbf\u8fdb\u884c\u6bd4\u8f83\u3002 import copy import pandas as pd from sklearn import metrics from sklearn import preprocessing import xgboost as xgb def mean_target_encoding ( data ): df = copy . deepcopy ( data ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) and f not in num_cols ] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) encoded_dfs = [] for fold in range ( 5 ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) for column in features : # \u76ee\u6807\u7f16\u7801 mapping_dict = dict ( df_train . groupby ( column )[ \"income\" ] . mean () ) df_valid . loc [:, column + \"_enc\" ] = df_valid [ column ] . map ( mapping_dict ) encoded_dfs . 
append ( df_valid ) encoded_df = pd . concat ( encoded_dfs , axis = 0 ) return encoded_df def run ( df , fold ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/adult_folds.csv\" ) df = mean_target_encoding ( df ) for fold_ in range ( 5 ): run ( df , fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u4e0a\u8ff0\u7247\u6bb5\u4e2d\uff0c\u6211\u5728\u8fdb\u884c\u76ee\u6807\u7f16\u7801\u65f6\u5e76\u6ca1\u6709\u5220\u9664\u5206\u7c7b\u5217\u3002\u6211\u4fdd\u7559\u4e86\u6240\u6709\u7279\u5f81\uff0c\u5e76\u5728\u6b64\u57fa\u7840\u4e0a\u6dfb\u52a0\u4e86\u76ee\u6807\u7f16\u7801\u7279\u5f81\u3002\u6b64\u5916\uff0c\u6211\u8fd8\u4f7f\u7528\u4e86\u5e73\u5747\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5e73\u5747\u503c\u3001\u4e2d\u4f4d\u6570\u3001\u6807\u51c6\u504f\u5dee\u6216\u76ee\u6807\u7684\u4efb\u4f55\u5176\u4ed6\u51fd\u6570\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u7ed3\u679c\u3002 Fold = 0 , AUC = 0.9332240662017529 Fold = 1 , AUC = 0.9363551625140347 Fold = 2 , AUC = 0.9375013544556173 Fold = 3 , AUC = 0.92237621307625 Fold = 4 , AUC = 0.9292131180445478 
\u4e0d\u9519\uff01\u770b\u6765\u6211\u4eec\u53c8\u6709\u8fdb\u6b65\u4e86\u3002\u4e0d\u8fc7\uff0c\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u5b83\u592a\u5bb9\u6613\u51fa\u73b0\u8fc7\u5ea6\u62df\u5408\u3002\u5f53\u6211\u4eec\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\uff0c\u6700\u597d\u4f7f\u7528\u67d0\u79cd\u5e73\u6ed1\u65b9\u6cd5\u6216\u5728\u7f16\u7801\u503c\u4e2d\u6dfb\u52a0\u566a\u58f0\u3002 Scikit-learn \u7684\u8d21\u732e\u5e93\u4e2d\u6709\u5e26\u5e73\u6ed1\u7684\u76ee\u6807\u7f16\u7801\uff0c\u4f60\u4e5f\u53ef\u4ee5\u521b\u5efa\u81ea\u5df1\u7684\u5e73\u6ed1\u3002\u5e73\u6ed1\u4f1a\u5f15\u5165\u67d0\u79cd\u6b63\u5219\u5316\uff0c\u6709\u52a9\u4e8e\u907f\u514d\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u8fd9\u5e76\u4e0d\u96be\u3002 \u5904\u7406\u5206\u7c7b\u7279\u5f81\u662f\u4e00\u9879\u590d\u6742\u7684\u4efb\u52a1\u3002\u8bb8\u591a\u8d44\u6e90\u4e2d\u90fd\u6709\u5927\u91cf\u4fe1\u606f\u3002\u672c\u7ae0\u5e94\u8be5\u80fd\u5e2e\u52a9\u4f60\u5f00\u59cb\u89e3\u51b3\u5206\u7c7b\u53d8\u91cf\u7684\u4efb\u4f55\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u5bf9\u4e8e\u5927\u591a\u6570\u95ee\u9898\u6765\u8bf4\uff0c\u9664\u4e86\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u4e4b\u5916\uff0c\u4f60\u4e0d\u9700\u8981\u66f4\u591a\u7684\u4e1c\u897f\u3002 \u8981\u8fdb\u4e00\u6b65\u6539\u8fdb\u6a21\u578b\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u66f4\u591a\uff01 \u5728\u672c\u7ae0\u7684\u6700\u540e\uff0c\u6211\u4eec\u4e0d\u80fd\u4e0d\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e00\u79cd\u79f0\u4e3a \u5b9e\u4f53\u5d4c\u5165 \u7684\u6280\u672f\u3002\u5728\u5b9e\u4f53\u5d4c\u5165\u4e2d\uff0c\u7c7b\u522b\u7528\u5411\u91cf\u8868\u793a\u3002\u5728\u4e8c\u503c\u5316\u548c\u72ec\u70ed\u7f16\u7801\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u90fd\u662f\u7528\u5411\u91cf\u6765\u8868\u793a\u7c7b\u522b\u7684\u3002 
\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u6709\u6570\u4ee5\u4e07\u8ba1\u7684\u7c7b\u522b\u600e\u4e48\u529e\uff1f\u8fd9\u5c06\u4f1a\u4ea7\u751f\u5de8\u5927\u7684\u77e9\u9635\uff0c\u6211\u4eec\u5c06\u9700\u8981\u5f88\u957f\u65f6\u95f4\u6765\u8bad\u7ec3\u590d\u6742\u7684\u6a21\u578b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u7528\u5e26\u6709\u6d6e\u70b9\u503c\u7684\u5411\u91cf\u6765\u8868\u793a\u5b83\u4eec\u3002 \u8fd9\u4e2a\u60f3\u6cd5\u975e\u5e38\u7b80\u5355\u3002\u6bcf\u4e2a\u5206\u7c7b\u7279\u5f81\u90fd\u6709\u4e00\u4e2a\u5d4c\u5165\u5c42\u3002\u56e0\u6b64\uff0c\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u73b0\u5728\u90fd\u53ef\u4ee5\u6620\u5c04\u5230\u4e00\u4e2a\u5d4c\u5165\u5c42\uff08\u5c31\u50cf\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u5c06\u5355\u8bcd\u6620\u5c04\u5230\u5d4c\u5165\u5c42\u4e00\u6837\uff09\u3002\u7136\u540e\uff0c\u6839\u636e\u5176\u7ef4\u5ea6\u91cd\u5851\u8fd9\u4e9b\u5d4c\u5165\u5c42\uff0c\u4f7f\u5176\u6241\u5e73\u5316\uff0c\u7136\u540e\u5c06\u6240\u6709\u6241\u5e73\u5316\u7684\u8f93\u5165\u5d4c\u5165\u5c42\u8fde\u63a5\u8d77\u6765\u3002\u7136\u540e\u6dfb\u52a0\u4e00\u5806\u5bc6\u96c6\u5c42\u548c\u4e00\u4e2a\u8f93\u51fa\u5c42\uff0c\u5c31\u5927\u529f\u544a\u6210\u4e86\u3002 \u56fe 6\uff1a\u7c7b\u522b\u8f6c\u6362\u4e3a\u6d6e\u70b9\u6216\u5d4c\u5165\u5411\u91cf \u51fa\u4e8e\u67d0\u79cd\u539f\u56e0\uff0c\u6211\u53d1\u73b0\u4f7f\u7528 TF/Keras \u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528 TF/Keras \u5b9e\u73b0\u5b83\u3002\u6b64\u5916\uff0c\u8fd9\u662f\u672c\u4e66\u4e2d\u552f\u4e00\u4e00\u4e2a\u4f7f\u7528 TF/Keras \u7684\u793a\u4f8b\uff0c\u5c06\u5176\u8f6c\u6362\u4e3a PyTorch\uff08\u4f7f\u7528 cat-in-the-dat-ii \u6570\u636e\u96c6\uff09\u4e5f\u975e\u5e38\u5bb9\u6613 import os import gc import joblib import pandas as pd import numpy as np from sklearn import metrics , preprocessing from tensorflow.keras import layers from 
tensorflow.keras import optimizers from tensorflow.keras.models import Model , load_model from tensorflow.keras import callbacks from tensorflow.keras import backend as K from tensorflow.keras import utils def create_model ( data , catcols ): # \u521b\u5efa\u7a7a\u7684\u8f93\u5165\u5217\u8868\u548c\u8f93\u51fa\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6a21\u578b\u7684\u8f93\u5165\u548c\u8f93\u51fa inputs = [] outputs = [] # \u904d\u5386\u5206\u7c7b\u7279\u5f81\u5217\u8868\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81 for c in catcols : # \u8ba1\u7b97\u7279\u5f81\u4e2d\u552f\u4e00\u503c\u7684\u6570\u91cf num_unique_values = int ( data [ c ] . nunique ()) # \u8ba1\u7b97\u5d4c\u5165\u7ef4\u5ea6\uff0c\u6700\u5927\u4e0d\u8d85\u8fc750 embed_dim = int ( min ( np . ceil (( num_unique_values ) / 2 ), 50 )) # \u521b\u5efa\u6a21\u578b\u7684\u8f93\u5165\u5c42\uff0c\u6bcf\u4e2a\u7279\u5f81\u5bf9\u5e94\u4e00\u4e2a\u8f93\u5165 inp = layers . Input ( shape = ( 1 ,)) # \u521b\u5efa\u5d4c\u5165\u5c42\uff0c\u5c06\u5206\u7c7b\u7279\u5f81\u6620\u5c04\u5230\u4f4e\u7ef4\u5ea6\u7684\u8fde\u7eed\u5411\u91cf out = layers . Embedding ( num_unique_values + 1 , embed_dim , name = c )( inp ) # \u5bf9\u5d4c\u5165\u5c42\u8fdb\u884c\u7a7a\u95f4\u4e22\u5f03\uff08Dropout\uff09 out = layers . SpatialDropout1D ( 0.3 )( out ) # \u5c06\u5d4c\u5165\u5c42\u7684\u5f62\u72b6\u91cd\u65b0\u8c03\u6574\u4e3a\u4e00\u7ef4 out = layers . Reshape ( target_shape = ( embed_dim ,))( out ) # \u5c06\u8f93\u5165\u548c\u8f93\u51fa\u6dfb\u52a0\u5230\u5bf9\u5e94\u7684\u5217\u8868\u4e2d inputs . append ( inp ) outputs . append ( out ) # \u4f7f\u7528Concatenate\u5c42\u5c06\u6240\u6709\u7684\u5d4c\u5165\u5c42\u8f93\u51fa\u8fde\u63a5\u5728\u4e00\u8d77 x = layers . Concatenate ()( outputs ) # \u5bf9\u8fde\u63a5\u540e\u7684\u6570\u636e\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . 
BatchNormalization ()( x ) # \u6dfb\u52a0\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u6dfb\u52a0\u53e6\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u8f93\u51fa\u5c42\uff0c\u5177\u67092\u4e2a\u795e\u7ecf\u5143\uff08\u7528\u4e8e\u4e8c\u8fdb\u5236\u5206\u7c7b\uff09\uff0c\u5e76\u4f7f\u7528softmax\u6fc0\u6d3b\u51fd\u6570 y = layers . Dense ( 2 , activation = \"softmax\" )( x ) # \u521b\u5efa\u6a21\u578b\uff0c\u5c06\u8f93\u5165\u548c\u8f93\u51fa\u4f20\u9012\u7ed9Model\u6784\u9020\u51fd\u6570 model = Model ( inputs = inputs , outputs = y ) # \u7f16\u8bd1\u6a21\u578b\uff0c\u6307\u5b9a\u635f\u5931\u51fd\u6570\u548c\u4f18\u5316\u5668 model . compile ( loss = 'binary_crossentropy' , optimizer = 'adam' ) # \u8fd4\u56de\u521b\u5efa\u7684\u6a21\u578b return model def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for feat in features : lbl_enc = preprocessing . LabelEncoder () df . loc [:, feat ] = lbl_enc . fit_transform ( df [ feat ] . values ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . 
reset_index ( drop = True ) model = create_model ( df , features ) xtrain = [ df_train [ features ] . values [:, k ] for k in range ( len ( features ))] xvalid = [ df_valid [ features ] . values [:, k ] for k in range ( len ( features )) ] ytrain = df_train . target . values yvalid = df_valid . target . values ytrain_cat = utils . to_categorical ( ytrain ) yvalid_cat = utils . to_categorical ( yvalid ) model . fit ( xtrain , ytrain_cat , validation_data = ( xvalid , yvalid_cat ), verbose = 1 , batch_size = 1024 , epochs = 3 ) valid_preds = model . predict ( xvalid )[:, 1 ] print ( metrics . roc_auc_score ( yvalid , valid_preds )) K . clear_session () if __name__ == \"__main__\" : run ( 0 ) run ( 1 ) run ( 2 ) run ( 3 ) run ( 4 ) \u4f60\u4f1a\u53d1\u73b0\u8fd9\u79cd\u65b9\u6cd5\u6548\u679c\u6700\u597d\uff0c\u800c\u4e14\u5982\u679c\u4f60\u6709 GPU\uff0c\u901f\u5ea6\u4e5f\u8d85\u5feb\uff01\u8fd9\u79cd\u65b9\u6cd5\u8fd8\u53ef\u4ee5\u8fdb\u4e00\u6b65\u6539\u8fdb\uff0c\u800c\u4e14\u4f60\u65e0\u9700\u62c5\u5fc3\u7279\u5f81\u5de5\u7a0b\uff0c\u56e0\u4e3a\u795e\u7ecf\u7f51\u7edc\u4f1a\u81ea\u884c\u5904\u7406\u3002\u5728\u5904\u7406\u5927\u91cf\u5206\u7c7b\u7279\u5f81\u6570\u636e\u96c6\u65f6\uff0c\u8fd9\u7edd\u5bf9\u503c\u5f97\u4e00\u8bd5\u3002\u5f53\u5d4c\u5165\u5927\u5c0f\u4e0e\u552f\u4e00\u7c7b\u522b\u7684\u6570\u91cf\u76f8\u540c\u65f6\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u72ec\u70ed\u7f16\u7801\uff08one-hot-encoding\uff09\u3002 
\u672c\u7ae0\u57fa\u672c\u4e0a\u90fd\u662f\u5173\u4e8e\u7279\u5f81\u5de5\u7a0b\u7684\u3002\u8ba9\u6211\u4eec\u5728\u4e0b\u4e00\u7ae0\u4e2d\u770b\u770b\u5982\u4f55\u5728\u6570\u5b57\u7279\u5f81\u548c\u4e0d\u540c\u7c7b\u578b\u7279\u5f81\u7684\u7ec4\u5408\u65b9\u9762\u8fdb\u884c\u66f4\u591a\u7684\u7279\u5f81\u5de5\u7a0b\u3002","title":"\u5904\u7406\u5206\u7c7b\u53d8\u91cf"},{"location":"%E5%A4%84%E7%90%86%E5%88%86%E7%B1%BB%E5%8F%98%E9%87%8F/#_1","text":"\u5f88\u591a\u4eba\u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u65f6\u90fd\u4f1a\u9047\u5230\u5f88\u591a\u56f0\u96be\uff0c\u56e0\u6b64\u8fd9\u503c\u5f97\u7528\u6574\u6574\u4e00\u7ae0\u7684\u7bc7\u5e45\u6765\u8ba8\u8bba\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u5c06\u8bb2\u8ff0\u4e0d\u540c\u7c7b\u578b\u7684\u5206\u7c7b\u6570\u636e\uff0c\u4ee5\u53ca\u5982\u4f55\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u3002 \u4ec0\u4e48\u662f\u5206\u7c7b\u53d8\u91cf\uff1f \u5206\u7c7b\u53d8\u91cf/\u7279\u5f81\u662f\u6307\u4efb\u4f55\u7279\u5f81\u7c7b\u578b\uff0c\u53ef\u5206\u4e3a\u4e24\u5927\u7c7b\uff1a - \u65e0\u5e8f - \u6709\u5e8f \u65e0\u5e8f\u53d8\u91cf \u662f\u6307\u6709\u4e24\u4e2a\u6216\u4e24\u4e2a\u4ee5\u4e0a\u7c7b\u522b\u7684\u53d8\u91cf\uff0c\u8fd9\u4e9b\u7c7b\u522b\u6ca1\u6709\u4efb\u4f55\u76f8\u5173\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u5982\u679c\u5c06\u6027\u522b\u5206\u4e3a\u4e24\u7ec4\uff0c\u5373\u7537\u6027\u548c\u5973\u6027\uff0c\u5219\u53ef\u5c06\u5176\u89c6\u4e3a\u540d\u4e49\u53d8\u91cf\u3002 \u6709\u5e8f\u53d8\u91cf \u5219\u6709 \"\u7b49\u7ea7 \"\u6216\u7c7b\u522b\uff0c\u5e76\u6709\u7279\u5b9a\u7684\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u4e00\u4e2a\u987a\u5e8f\u5206\u7c7b\u53d8\u91cf\u53ef\u4ee5\u662f\u4e00\u4e2a\u5177\u6709\u4f4e\u3001\u4e2d\u3001\u9ad8\u4e09\u4e2a\u4e0d\u540c\u7b49\u7ea7\u7684\u7279\u5f81\u3002\u987a\u5e8f\u5f88\u91cd\u8981\u3002 \u5c31\u5b9a\u4e49\u800c\u8a00\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06\u5206\u7c7b\u53d8\u91cf\u5206\u4e3a \u4e8c\u5143\u53d8\u91cf 
\uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u7c7b\u522b\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u6709\u4e9b\u4eba\u751a\u81f3\u628a\u5206\u7c7b\u53d8\u91cf\u79f0\u4e3a \" \u5faa\u73af \"\u53d8\u91cf\u3002\u5468\u671f\u53d8\u91cf\u4ee5 \"\u5468\u671f \"\u7684\u5f62\u5f0f\u5b58\u5728\uff0c\u4f8b\u5982\u4e00\u5468\u4e2d\u7684\u5929\u6570\uff1a \u5468\u65e5\u3001\u5468\u4e00\u3001\u5468\u4e8c\u3001\u5468\u4e09\u3001\u5468\u56db\u3001\u5468\u4e94\u548c\u5468\u516d\u3002\u5468\u516d\u8fc7\u540e\uff0c\u53c8\u662f\u5468\u65e5\u3002\u8fd9\u5c31\u662f\u4e00\u4e2a\u5faa\u73af\u3002\u53e6\u4e00\u4e2a\u4f8b\u5b50\u662f\u4e00\u5929\u4e2d\u7684\u5c0f\u65f6\u6570\uff0c\u5982\u679c\u6211\u4eec\u5c06\u5b83\u4eec\u89c6\u4e3a\u7c7b\u522b\u7684\u8bdd\u3002 \u5206\u7c7b\u53d8\u91cf\u6709\u5f88\u591a\u4e0d\u540c\u7684\u5b9a\u4e49\uff0c\u5f88\u591a\u4eba\u4e5f\u8c08\u5230\u8981\u6839\u636e\u5206\u7c7b\u53d8\u91cf\u7684\u7c7b\u578b\u6765\u5904\u7406\u4e0d\u540c\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u4e0d\u8fc7\uff0c\u6211\u8ba4\u4e3a\u6ca1\u6709\u5fc5\u8981\u8fd9\u6837\u505a\u3002\u6240\u6709\u6d89\u53ca\u5206\u7c7b\u53d8\u91cf\u7684\u95ee\u9898\u90fd\u53ef\u4ee5\u7528\u540c\u6837\u7684\u65b9\u6cd5\u5904\u7406\u3002 \u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e00\u4e2a\u6570\u636e\u96c6\uff08\u4e00\u5982\u65e2\u5f80\uff09\u3002\u8981\u4e86\u89e3\u5206\u7c7b\u53d8\u91cf\uff0c\u6700\u597d\u7684\u514d\u8d39\u6570\u636e\u96c6\u4e4b\u4e00\u662f Kaggle \u5206\u7c7b\u7279\u5f81\u7f16\u7801\u6311\u6218\u8d5b\u4e2d\u7684 cat-in-the-dat \u3002\u5171\u6709\u4e24\u4e2a\u6311\u6218\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6311\u6218\u7684\u6570\u636e\uff0c\u56e0\u4e3a\u5b83\u6bd4\u524d\u4e00\u4e2a\u7248\u672c\u6709\u66f4\u591a\u53d8\u91cf\uff0c\u96be\u5ea6\u4e5f\u66f4\u5927\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u3002 \u56fe 1\uff1aCat-in-the-dat-ii challenge\u90e8\u5206\u6570\u636e\u5c55\u793a \u6570\u636e\u96c6\u7531\u5404\u79cd\u5206\u7c7b\u53d8\u91cf\u7ec4\u6210\uff1a 
\u65e0\u5e8f \u6709\u5e8f \u5faa\u73af \u4e8c\u5143 \u5728\u56fe 1 \u4e2d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6240\u6709\u5b58\u5728\u7684\u53d8\u91cf\u548c\u76ee\u6807\u53d8\u91cf\u7684\u5b50\u96c6\u3002 \u8fd9\u662f\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002 \u76ee\u6807\u53d8\u91cf\u5bf9\u4e8e\u6211\u4eec\u5b66\u4e60\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\u5e76\u4e0d\u5341\u5206\u91cd\u8981\uff0c\u4f46\u6700\u7ec8\u6211\u4eec\u5c06\u5efa\u7acb\u4e00\u4e2a\u7aef\u5230\u7aef\u6a21\u578b\uff0c\u56e0\u6b64\u8ba9\u6211\u4eec\u770b\u770b\u56fe 2 \u4e2d\u7684\u76ee\u6807\u53d8\u91cf\u5206\u5e03\u3002\u6211\u4eec\u770b\u5230\u76ee\u6807\u662f \u504f\u659c \u7684\uff0c\u56e0\u6b64\u5bf9\u4e8e\u8fd9\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u6765\u8bf4\uff0c\u6700\u597d\u7684\u6307\u6807\u662f ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528\u7cbe\u786e\u5ea6\u548c\u53ec\u56de\u7387\uff0c\u4f46 AUC \u7ed3\u5408\u4e86\u8fd9\u4e24\u4e2a\u6307\u6807\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 AUC \u6765\u8bc4\u4f30\u6211\u4eec\u5728\u8be5\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u7684\u6a21\u578b\u3002 \u56fe 2\uff1a\u6807\u7b7e\u8ba1\u6570\u3002X \u8f74\u8868\u793a\u6807\u7b7e\uff0cY \u8f74\u8868\u793a\u6807\u7b7e\u8ba1\u6570 \u603b\u4f53\u800c\u8a00\uff0c\u6709\uff1a 5\u4e2a\u4e8c\u5143\u53d8\u91cf 10\u4e2a\u65e0\u5e8f\u53d8\u91cf 6\u4e2a\u6709\u5e8f\u53d8\u91cf 2\u4e2a\u5faa\u73af\u53d8\u91cf 1\u4e2a\u76ee\u6807\u53d8\u91cf \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u96c6\u4e2d\u7684 ord_2 \u7279\u5f81\u3002\u5b83\u5305\u62ec6\u4e2a\u4e0d\u540c\u7684\u7c7b\u522b\uff1a - \u51b0\u51bb - \u6e29\u6696 - \u5bd2\u51b7 - \u8f83\u70ed - \u70ed - \u975e\u5e38\u70ed 
\u6211\u4eec\u5fc5\u987b\u77e5\u9053\uff0c\u8ba1\u7b97\u673a\u65e0\u6cd5\u7406\u89e3\u6587\u672c\u6570\u636e\uff0c\u56e0\u6b64\u6211\u4eec\u9700\u8981\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002\u4e00\u4e2a\u7b80\u5355\u7684\u65b9\u6cd5\u662f\u521b\u5efa\u4e00\u4e2a\u5b57\u5178\uff0c\u5c06\u8fd9\u4e9b\u503c\u6620\u5c04\u4e3a\u4ece 0 \u5230 N-1 \u7684\u6570\u5b57\uff0c\u5176\u4e2d N \u662f\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7c7b\u522b\u7684\u603b\u6570\u3002 # \u6620\u5c04\u5b57\u5178 mapping = { \"Freezing\" : 0 , \"Warm\" : 1 , \"Cold\" : 2 , \"Boiling Hot\" : 3 , \"Hot\" : 4 , \"Lava Hot\" : 5 } \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8bfb\u53d6\u6570\u636e\u96c6\uff0c\u5e76\u8f7b\u677e\u5730\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002 import pandas as pd # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u53d6*ord_2*\u5217\uff0c\u5e76\u4f7f\u7528\u6620\u5c04\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57 df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. map ( mapping ) \u6620\u5c04\u524d\u7684\u6570\u503c\u8ba1\u6570\uff1a df .* ord_2 *. value_counts () Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : * ord_2 * , dtype : int64 \u6620\u5c04\u540e\u7684\u6570\u503c\u8ba1\u6570\uff1a 0.0 142726 1.0 124239 2.0 97822 3.0 84790 4.0 67508 5.0 64840 Name : * ord_2 * , dtype : int64 \u8fd9\u79cd\u5206\u7c7b\u53d8\u91cf\u7684\u7f16\u7801\u65b9\u5f0f\u88ab\u79f0\u4e3a\u6807\u7b7e\u7f16\u7801\uff08Label Encoding\uff09\u6211\u4eec\u5c06\u6bcf\u4e2a\u7c7b\u522b\u7f16\u7801\u4e3a\u4e00\u4e2a\u6570\u5b57\u6807\u7b7e\u3002 \u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 LabelEncoder \u8fdb\u884c\u7f16\u7801\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u5c06\u7f3a\u5931\u503c\u586b\u5145\u4e3a\"NONE\" df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. 
fillna ( \"NONE\" ) # LabelEncoder\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u8f6c\u6362\u6570\u636e df . loc [:, \"*ord_2*\" ] = lbl_enc . fit_transform ( df .* ord_2 *. values ) \u4f60\u4f1a\u770b\u5230\u6211\u4f7f\u7528\u4e86 pandas \u7684 fillna\u3002\u539f\u56e0\u662f scikit-learn \u7684 LabelEncoder \u65e0\u6cd5\u5904\u7406 NaN \u503c\uff0c\u800c ord_2 \u5217\u4e2d\u6709 NaN \u503c\u3002 \u6211\u4eec\u53ef\u4ee5\u5728\u8bb8\u591a\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u4e2d\u76f4\u63a5\u4f7f\u7528\u5b83\uff1a - \u51b3\u7b56\u6811 - \u968f\u673a\u68ee\u6797 - \u63d0\u5347\u6811 - \u6216\u4efb\u4f55\u4e00\u79cd\u63d0\u5347\u6811\u6a21\u578b - XGBoost - GBM - LightGBM \u8fd9\u79cd\u7f16\u7801\u65b9\u5f0f\u4e0d\u80fd\u7528\u4e8e\u7ebf\u6027\u6a21\u578b\u3001\u652f\u6301\u5411\u91cf\u673a\u6216\u795e\u7ecf\u7f51\u7edc\uff0c\u56e0\u4e3a\u5b83\u4eec\u5e0c\u671b\u6570\u636e\u662f\u6807\u51c6\u5316\u7684\u3002 \u5bf9\u4e8e\u8fd9\u4e9b\u7c7b\u578b\u7684\u6a21\u578b\uff0c\u6211\u4eec\u53ef\u4ee5\u5bf9\u6570\u636e\u8fdb\u884c\u4e8c\u503c\u5316\uff08binarize\uff09\u5904\u7406\u3002 Freezing --> 0 --> 0 0 0 Warm --> 1 --> 0 0 1 Cold --> 2 --> 0 1 0 Boiling Hot --> 3 --> 0 1 1 Hot --> 4 --> 1 0 0 Lava Hot --> 5 --> 1 0 1 \u8fd9\u53ea\u662f\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\uff0c\u7136\u540e\u518d\u8f6c\u6362\u4e3a\u4e8c\u503c\u5316\u8868\u793a\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u628a\u4e00\u4e2a\u7279\u5f81\u5206\u6210\u4e86\u4e09\u4e2a\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u7279\u5f81\uff08\u6216\u5217\uff09\u3002\u5982\u679c\u6211\u4eec\u6709\u66f4\u591a\u7684\u7c7b\u522b\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u5206\u6210\u66f4\u591a\u7684\u5217\u3002 
\u5982\u679c\u6211\u4eec\u7528\u7a00\u758f\u683c\u5f0f\u5b58\u50a8\u5927\u91cf\u4e8c\u503c\u5316\u53d8\u91cf\uff0c\u5c31\u53ef\u4ee5\u8f7b\u677e\u5730\u5b58\u50a8\u8fd9\u4e9b\u53d8\u91cf\u3002\u7a00\u758f\u683c\u5f0f\u4e0d\u8fc7\u662f\u4e00\u79cd\u5728\u5185\u5b58\u4e2d\u5b58\u50a8\u6570\u636e\u7684\u8868\u793a\u6216\u65b9\u5f0f\uff0c\u5728\u8fd9\u79cd\u683c\u5f0f\u4e2d\uff0c\u4f60\u5e76\u4e0d\u5b58\u50a8\u6240\u6709\u7684\u503c\uff0c\u800c\u53ea\u5b58\u50a8\u91cd\u8981\u7684\u503c\u3002\u5728\u4e0a\u8ff0\u4e8c\u8fdb\u5236\u53d8\u91cf\u7684\u60c5\u51b5\u4e2d\uff0c\u6700\u91cd\u8981\u7684\u5c31\u662f\u6709 1 \u7684\u5730\u65b9\u3002 \u5f88\u96be\u60f3\u8c61\u8fd9\u6837\u7684\u683c\u5f0f\uff0c\u4f46\u4e3e\u4e2a\u4f8b\u5b50\u5c31\u4f1a\u660e\u767d\u3002 \u5047\u8bbe\u4e0a\u9762\u7684\u6570\u636e\u5e27\u4e2d\u53ea\u6709\u4e00\u4e2a\u7279\u5f81\uff1a ord_2 \u3002 Index Feature 0 Warm 1 Hot 2 Lava hot \u76ee\u524d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6570\u636e\u96c6\u4e2d\u7684\u4e09\u4e2a\u6837\u672c\u3002\u8ba9\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u4e8c\u503c\u8868\u793a\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u6709\u4e09\u4e2a\u9879\u76ee\u3002 \u8fd9\u4e09\u4e2a\u9879\u76ee\u5c31\u662f\u4e09\u4e2a\u7279\u5f81\u3002 Index Feature_0 Feature_1 Feature_2 0 0 0 1 1 1 0 0 2 1 0 1 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u7279\u5f81\u5b58\u50a8\u5728\u4e00\u4e2a\u6709 3 \u884c 3 \u5217\uff083x3\uff09\u7684\u77e9\u9635\u4e2d\u3002\u77e9\u9635\u7684\u6bcf\u4e2a\u5143\u7d20\u5360\u7528 8 \u4e2a\u5b57\u8282\u3002\u56e0\u6b64\uff0c\u8fd9\u4e2a\u6570\u7ec4\u7684\u603b\u5185\u5b58\u9700\u6c42\u4e3a 8x3x3 = 72 \u5b57\u8282\u3002 \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e2a\u7b80\u5355\u7684 python \u4ee3\u7801\u6bb5\u6765\u68c0\u67e5\u8fd9\u4e00\u70b9\u3002 import numpy as np example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) print ( example . 
nbytes ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u6253\u5370\u51fa 72\uff0c\u5c31\u50cf\u6211\u4eec\u4e4b\u524d\u8ba1\u7b97\u7684\u90a3\u6837\u3002\u4f46\u6211\u4eec\u9700\u8981\u5b58\u50a8\u8fd9\u4e2a\u77e9\u9635\u7684\u6240\u6709\u5143\u7d20\u5417\uff1f\u5982\u524d\u6240\u8ff0\uff0c\u6211\u4eec\u53ea\u5bf9 1 \u611f\u5174\u8da3\u30020 \u5e76\u4e0d\u91cd\u8981\uff0c\u56e0\u4e3a\u4efb\u4f55\u4e0e 0 \u76f8\u4e58\u7684\u5143\u7d20\u90fd\u662f 0\uff0c\u800c 0 \u4e0e\u4efb\u4f55\u5143\u7d20\u76f8\u52a0\u6216\u76f8\u51cf\u4e5f\u6ca1\u6709\u4efb\u4f55\u533a\u522b\u3002\u53ea\u7528 1 \u8868\u793a\u77e9\u9635\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u67d0\u79cd\u5b57\u5178\u65b9\u6cd5\uff0c\u5176\u4e2d\u952e\u662f\u884c\u548c\u5217\u7684\u7d22\u5f15\uff0c\u503c\u662f 1\uff1a ( 0 , 2 ) 1 ( 1 , 0 ) 1 ( 2 , 0 ) 1 ( 2 , 2 ) 1 \u8fd9\u6837\u7684\u7b26\u53f7\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u56e0\u4e3a\u5b83\u53ea\u9700\u5b58\u50a8\u56db\u4e2a\u503c\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u3002\u4f7f\u7528\u7684\u603b\u5185\u5b58\u4e3a 8x4 = 32 \u5b57\u8282\u3002\u4efb\u4f55 numpy \u6570\u7ec4\u90fd\u53ef\u4ee5\u901a\u8fc7\u7b80\u5355\u7684 python \u4ee3\u7801\u8f6c\u6362\u4e3a\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) sparse_example = sparse . csr_matrix ( example ) print ( sparse_example . data . nbytes ) \u8fd9\u5c06\u6253\u5370 32\uff0c\u6bd4\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u5c11\u4e86\u8fd9\u4e48\u591a\uff01\u7a00\u758f csr \u77e9\u9635\u7684\u603b\u5927\u5c0f\u662f\u4e09\u4e2a\u503c\u7684\u603b\u548c\u3002 print ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . 
nbytes ) \u8fd9\u5c06\u6253\u5370\u51fa 64 \u5b57\u8282\uff0c\u4ecd\u7136\u5c11\u4e8e\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u3002\u9057\u61be\u7684\u662f\uff0c\u6211\u4e0d\u4f1a\u8be6\u7ec6\u4ecb\u7ecd\u8fd9\u4e9b\u5143\u7d20\u3002\u4f60\u53ef\u4ee5\u5728 scipy \u6587\u6863\u4e2d\u4e86\u89e3\u66f4\u591a\u3002\u5f53\u6211\u4eec\u62e5\u6709\u66f4\u5927\u7684\u6570\u7ec4\u65f6\uff0c\u6bd4\u5982\u8bf4\u62e5\u6709\u6570\u5343\u4e2a\u6837\u672c\u548c\u6570\u4e07\u4e2a\u7279\u5f81\u7684\u6570\u7ec4\uff0c\u5927\u5c0f\u5dee\u5f02\u5c31\u4f1a\u53d8\u5f97\u975e\u5e38\u5927\u3002\u4f8b\u5982\uff0c\u6211\u4eec\u4f7f\u7528\u57fa\u4e8e\u8ba1\u6570\u7279\u5f81\u7684\u6587\u672c\u6570\u636e\u96c6\u3002 import numpy as np from scipy import sparse n_rows = 10000 n_cols = 100000 # \u751f\u6210\u7b26\u5408\u4f2f\u52aa\u5229\u5206\u5e03\u7684\u968f\u673a\u6570\u7ec4\uff0c\u7ef4\u5ea6\u4e3a[10000, 100000] example = np . random . binomial ( 1 , p = 0.05 , size = ( n_rows , n_cols )) print ( f \"Size of dense array: { example . nbytes } \" ) # \u5c06\u968f\u673a\u77e9\u9635\u8f6c\u6362\u4e3a\u7a00\u758f\u77e9\u9635 sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . 
nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u8fd9\u5c06\u6253\u5370\uff1a Size of dense array : 8000000000 Size of sparse array : 399932496 Full size of sparse array : 599938748 \u56e0\u6b64\uff0c\u5bc6\u96c6\u9635\u5217\u9700\u8981 ~8000MB \u6216\u5927\u7ea6 8GB \u5185\u5b58\u3002\u800c\u7a00\u758f\u9635\u5217\u53ea\u5360\u7528 399MB \u5185\u5b58\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u5f53\u6211\u4eec\u7684\u7279\u5f81\u4e2d\u6709\u5927\u91cf\u96f6\u65f6\uff0c\u6211\u4eec\u66f4\u559c\u6b22\u7a00\u758f\u9635\u5217\u800c\u4e0d\u662f\u5bc6\u96c6\u9635\u5217\u7684\u539f\u56e0\u3002 \u8bf7\u6ce8\u610f\uff0c\u7a00\u758f\u77e9\u9635\u6709\u591a\u79cd\u4e0d\u540c\u7684\u8868\u793a\u65b9\u6cd5\u3002\u8fd9\u91cc\u6211\u53ea\u5c55\u793a\u4e86\u5176\u4e2d\u4e00\u79cd\uff08\u53ef\u80fd\u4e5f\u662f\u6700\u5e38\u7528\u7684\uff09\u65b9\u6cd5\u3002\u6df1\u5165\u63a2\u8ba8\u8fd9\u4e9b\u65b9\u6cd5\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\uff0c\u56e0\u6b64\u7559\u7ed9\u8bfb\u8005\u4e00\u4e2a\u7ec3\u4e60\u3002 \u5c3d\u7ba1\u4e8c\u503c\u5316\u7279\u5f81\u7684\u7a00\u758f\u8868\u793a\u6bd4\u5176\u5bc6\u96c6\u8868\u793a\u6240\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u4f46\u5bf9\u4e8e\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\uff0c\u8fd8\u6709\u4e00\u79cd\u8f6c\u6362\u6240\u5360\u7528\u7684\u5185\u5b58\u66f4\u5c11\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \" \u72ec\u70ed\u7f16\u7801 \"\u3002 \u72ec\u70ed\u7f16\u7801\u4e5f\u662f\u4e00\u79cd\u4e8c\u503c\u7f16\u7801\uff0c\u56e0\u4e3a\u53ea\u6709 0 \u548c 1 \u4e24\u4e2a\u503c\u3002\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5b83\u5e76\u4e0d\u662f\u4e8c\u503c\u8868\u793a\u6cd5\u3002\u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0b\u9762\u7684\u4f8b\u5b50\u6765\u7406\u89e3\u5b83\u7684\u8868\u793a\u6cd5\u3002 \u5047\u8bbe\u6211\u4eec\u7528\u4e00\u4e2a\u5411\u91cf\u6765\u8868\u793a ord_2 \u53d8\u91cf\u7684\u6bcf\u4e2a\u7c7b\u522b\u3002\u8fd9\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u4e0e ord_2 
\u53d8\u91cf\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u6bcf\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u90fd\u662f 6\uff0c\u5e76\u4e14\u9664\u4e86\u4e00\u4e2a\u4f4d\u7f6e\u5916\uff0c\u5176\u4ed6\u4f4d\u7f6e\u90fd\u662f 0\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u7279\u6b8a\u7684\u5411\u91cf\u8868\u3002 Freezing 0 0 0 0 0 1 Warm 0 0 0 0 1 0 Cold 0 0 0 1 0 0 Boiling Hot 0 0 1 0 0 0 Hot 0 1 0 0 0 0 Lava Hot 1 0 0 0 0 0 \u6211\u4eec\u770b\u5230\u5411\u91cf\u7684\u5927\u5c0f\u662f 1x6\uff0c\u5373\u5411\u91cf\u4e2d\u67096\u4e2a\u5143\u7d20\u3002\u8fd9\u4e2a\u6570\u5b57\u662f\u600e\u4e48\u6765\u7684\u5462\uff1f\u5982\u679c\u4f60\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u5982\u524d\u6240\u8ff0\uff0c\u67096\u4e2a\u7c7b\u522b\u3002\u5728\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u65f6\uff0c\u5411\u91cf\u7684\u5927\u5c0f\u5fc5\u987b\u4e0e\u6211\u4eec\u8981\u67e5\u770b\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u6bcf\u4e2a\u5411\u91cf\u90fd\u6709\u4e00\u4e2a 1\uff0c\u5176\u4f59\u6240\u6709\u503c\u90fd\u662f 0\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u7528\u8fd9\u4e9b\u7279\u5f81\u6765\u4ee3\u66ff\u4e4b\u524d\u7684\u4e8c\u503c\u5316\u7279\u5f81\uff0c\u770b\u770b\u80fd\u8282\u7701\u591a\u5c11\u5185\u5b58\u3002 \u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\u4ee5\u524d\u7684\u6570\u636e\uff0c\u5b83\u770b\u8d77\u6765\u5982\u4e0b\uff1a Index Feature 0 Warm 1 Hot 2 Lava hot \u6bcf\u4e2a\u6837\u672c\u67093\u4e2a\u7279\u5f81\u3002\u4f46\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u72ec\u70ed\u5411\u91cf\u7684\u5927\u5c0f\u4e3a 6\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u67096\u4e2a\u7279\u5f81\uff0c\u800c\u4e0d\u662f3\u4e2a\u3002 Index F_0 F_1 F_2 F_3 F_4 F_5 0 0 0 0 0 1 0 1 0 1 0 0 0 0 2 1 0 1 0 0 0 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 6 \u4e2a\u7279\u5f81\uff0c\u800c\u5728\u8fd9\u4e2a 3x6 \u6570\u7ec4\u4e2d\uff0c\u53ea\u6709 3 \u4e2a1\u3002\u4f7f\u7528 numpy 
\u8ba1\u7b97\u5927\u5c0f\u4e0e\u4e8c\u503c\u5316\u5927\u5c0f\u8ba1\u7b97\u811a\u672c\u975e\u5e38\u76f8\u4f3c\u3002\u4f60\u9700\u8981\u6539\u53d8\u7684\u53ea\u662f\u6570\u7ec4\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u6bb5\u4ee3\u7801\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 0 , 0 , 1 , 0 ], [ 0 , 1 , 0 , 0 , 0 , 0 ], [ 1 , 0 , 0 , 0 , 0 , 0 ] ] ) print ( f \"Size of dense array: { example . nbytes } \" ) sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u6253\u5370\u5185\u5b58\u5927\u5c0f\u4e3a\uff1a Size of dense array : 144 Size of sparse array : 24 Full size of sparse array : 52 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5bc6\u96c6\u77e9\u9635\u7684\u5927\u5c0f\u8fdc\u8fdc\u5927\u4e8e\u4e8c\u503c\u5316\u77e9\u9635\u7684\u5927\u5c0f\u3002\u4e0d\u8fc7\uff0c\u7a00\u758f\u6570\u7ec4\u7684\u5927\u5c0f\u8981\u66f4\u5c0f\u3002\u8ba9\u6211\u4eec\u7528\u66f4\u5927\u7684\u6570\u7ec4\u6765\u8bd5\u8bd5\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 OneHotEncoder \u5c06\u5305\u542b 1001 \u4e2a\u7c7b\u522b\u7684\u7279\u5f81\u6570\u7ec4\u8f6c\u6362\u4e3a\u5bc6\u96c6\u77e9\u9635\u548c\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from sklearn import preprocessing # \u751f\u6210\u7b26\u5408\u5747\u5300\u5206\u5e03\u7684\u968f\u673a\u6574\u6570\uff0c\u7ef4\u5ea6\u4e3a[1000000, 10000000] example = np . random . randint ( 1000 , size = 1000000 ) # \u72ec\u70ed\u7f16\u7801\uff0c\u975e\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = False ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of dense array: { ohe_example . 
nbytes } \" ) # \u72ec\u70ed\u7f16\u7801\uff0c\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = True ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of sparse array: { ohe_example . data . nbytes } \" ) full_size = ( ohe_example . data . nbytes + ohe_example . indptr . nbytes + ohe_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u4e0a\u9762\u4ee3\u7801\u6253\u5370\u7684\u8f93\u51fa\uff1a Size of dense array : 8000000000 Size of sparse array : 8000000 Full size of sparse array : 16000004 \u8fd9\u91cc\u7684\u5bc6\u96c6\u9635\u5217\u5927\u5c0f\u7ea6\u4e3a 8GB\uff0c\u7a00\u758f\u9635\u5217\u4e3a 8MB\u3002\u5982\u679c\u53ef\u4ee5\u9009\u62e9\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff1f\u5728\u6211\u770b\u6765\uff0c\u9009\u62e9\u5f88\u7b80\u5355\uff0c\u4e0d\u662f\u5417\uff1f \u8fd9\u4e09\u79cd\u65b9\u6cd5\uff08\u6807\u7b7e\u7f16\u7801\u3001\u7a00\u758f\u77e9\u9635\u3001\u72ec\u70ed\u7f16\u7801\uff09\u662f\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u7684\u6700\u91cd\u8981\u65b9\u6cd5\u3002\u4e0d\u8fc7\uff0c\u4f60\u8fd8\u53ef\u4ee5\u7528\u5f88\u591a\u5176\u4ed6\u4e0d\u540c\u7684\u65b9\u6cd5\u6765\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u3002\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u6570\u503c\u53d8\u91cf\u5c31\u662f\u5176\u4e2d\u7684\u4e00\u4e2a\u4f8b\u5b50\u3002 \u5047\u8bbe\u6211\u4eec\u56de\u5230\u4e4b\u524d\u7684\u5206\u7c7b\u7279\u5f81\u6570\u636e\uff08\u539f\u59cb\u6570\u636e\u4e2d\u7684 cat-in-the-dat-ii\uff09\u3002\u5728\u6570\u636e\u4e2d\uff0c ord_2 \u7684\u503c\u4e3a\u201c\u70ed\u201c\u7684 id \u6709\u591a\u5c11\uff1f \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u8ba1\u7b97\u6570\u636e\u7684\u5f62\u72b6\uff08shape\uff09\u8f7b\u677e\u8ba1\u7b97\u51fa\u8fd9\u4e2a\u503c\uff0c\u5176\u4e2d ord_2 \u5217\u7684\u503c\u4e3a Boiling Hot \u3002 In [ X ]: df [ df . ord_2 == \"Boiling Hot\" ] . 
shape Out [ X ]: ( 84790 , 25 ) \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 84790 \u6761\u8bb0\u5f55\u5177\u6709\u6b64\u503c\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 pandas \u4e2d\u7684 groupby \u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u8be5\u503c\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . count () Out [ X ]: ord_2 Boiling Hot 84790 Cold 97822 Freezing 142726 Hot 67508 Lava Hot 64840 Warm 124239 Name : id , dtype : int64 \u5982\u679c\u6211\u4eec\u53ea\u662f\u5c06 ord_2 \u5217\u66ff\u6362\u4e3a\u5176\u8ba1\u6570\u503c\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u5c06\u5176\u8f6c\u6362\u4e3a\u4e00\u79cd\u6570\u503c\u7279\u5f81\u4e86\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 transform \u51fd\u6570\u548c groupby \u6765\u521b\u5efa\u65b0\u5217\u6216\u66ff\u6362\u8fd9\u4e00\u5217\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . transform ( \"count\" ) Out [ X ]: 0 67508.0 1 124239.0 2 142726.0 3 64840.0 4 97822.0 ... 599995 142726.0 599996 84790.0 599997 142726.0 599998 124239.0 599999 84790.0 Name : id , Length : 600000 , dtype : float64 \u4f60\u53ef\u4ee5\u6dfb\u52a0\u6240\u6709\u7279\u5f81\u7684\u8ba1\u6570\uff0c\u4e5f\u53ef\u4ee5\u66ff\u6362\u5b83\u4eec\uff0c\u6216\u8005\u6839\u636e\u591a\u4e2a\u5217\u53ca\u5176\u8ba1\u6570\u8fdb\u884c\u5206\u7ec4\u3002\u4f8b\u5982\uff0c\u4ee5\u4e0b\u4ee3\u7801\u901a\u8fc7\u5bf9 ord_1 \u548c ord_2 \u5217\u5206\u7ec4\u8fdb\u884c\u8ba1\u6570\u3002 In [ X ]: df . groupby ( ... : [ ... : \"ord_1\" , ... : \"ord_2\" ... : ] ... : )[ \"id\" ] . count () . 
reset_index ( name = \"count\" ) Out [ X ]: ord_1 ord_2 count 0 Contributor Boiling Hot 15634 1 Contributor Cold 17734 2 Contributor Freezing 26082 3 Contributor Hot 12428 4 Contributor Lava Hot 11919 5 Contributor Warm 22774 6 Expert Boiling Hot 19477 7 Expert Cold 22956 8 Expert Freezing 33249 9 Expert Hot 15792 10 Expert Lava Hot 15078 11 Expert Warm 28900 12 Grandmaster Boiling Hot 13623 13 Grandmaster Cold 15464 14 Grandmaster Freezing 22818 15 Grandmaster Hot 10805 16 Grandmaster Lava Hot 10363 17 Grandmaster Warm 19899 18 Master Boiling Hot 10800 ... \u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u884c\uff0c\u4ee5\u4fbf\u5728\u4e00\u9875\u4e2d\u5bb9\u7eb3\u8fd9\u4e9b\u884c\u3002\u8fd9\u662f\u53e6\u4e00\u79cd\u53ef\u4ee5\u4f5c\u4e3a\u529f\u80fd\u6dfb\u52a0\u7684\u8ba1\u6570\u3002\u60a8\u73b0\u5728\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\uff0c\u6211\u4f7f\u7528 id \u5217\u8fdb\u884c\u8ba1\u6570\u3002\u4e0d\u8fc7\uff0c\u4f60\u4e5f\u53ef\u4ee5\u901a\u8fc7\u5bf9\u5217\u7684\u7ec4\u5408\u8fdb\u884c\u5206\u7ec4\uff0c\u5bf9\u5176\u4ed6\u5217\u8fdb\u884c\u8ba1\u6570\u3002 \u8fd8\u6709\u4e00\u4e2a\u5c0f\u7a8d\u95e8\uff0c\u5c31\u662f\u4ece\u8fd9\u4e9b\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\u3002\u4f60\u53ef\u4ee5\u4ece\u73b0\u6709\u7684\u7279\u5f81\u4e2d\u521b\u5efa\u65b0\u7684\u5206\u7c7b\u7279\u5f81\uff0c\u800c\u4e14\u53ef\u4ee5\u6beb\u4e0d\u8d39\u529b\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot 1 Grandmaster_Warm 2 nan_Freezing 3 Novice_Lava Hot 4 Grandmaster_Cold ... 
599999 Contributor_Boiling Hot Name : new_feature , Length : 600000 , dtype : object \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u7528\u4e0b\u5212\u7ebf\u5c06 ord_1 \u548c ord_2 \u5408\u5e76\uff0c\u7136\u540e\u5c06\u8fd9\u4e9b\u5217\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u7c7b\u578b\u3002\u8bf7\u6ce8\u610f\uff0cNaN \u4e5f\u4f1a\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u3002\u4e0d\u8fc7\u6ca1\u5173\u7cfb\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06 NaN \u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u6709\u4e86\u4e00\u4e2a\u7531\u8fd9\u4e24\u4e2a\u7279\u5f81\u7ec4\u5408\u800c\u6210\u7684\u65b0\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u5c06\u4e09\u5217\u4ee5\u4e0a\u6216\u56db\u5217\u751a\u81f3\u66f4\u591a\u5217\u7ec4\u5408\u5728\u4e00\u8d77\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : + \"_\" ... : + df . ord_3 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot_c 1 Grandmaster_Warm_e 2 nan_Freezing_n 3 Novice_Lava Hot_a 4 Grandmaster_Cold_h ... 
599999 Contributor_Boiling Hot_b Name : new_feature , Length : 600000 , dtype : object \u90a3\u4e48\uff0c\u6211\u4eec\u5e94\u8be5\u628a\u54ea\u4e9b\u7c7b\u522b\u7ed3\u5408\u8d77\u6765\u5462\uff1f\u8fd9\u5e76\u6ca1\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u7b54\u6848\u3002\u8fd9\u53d6\u51b3\u4e8e\u60a8\u7684\u6570\u636e\u548c\u7279\u5f81\u7c7b\u578b\u3002\u4e00\u4e9b\u9886\u57df\u77e5\u8bc6\u5bf9\u4e8e\u521b\u5efa\u8fd9\u6837\u7684\u7279\u5f81\u53ef\u80fd\u5f88\u6709\u7528\u3002\u4f46\u662f\uff0c\u5982\u679c\u4f60\u4e0d\u62c5\u5fc3\u5185\u5b58\u548c CPU \u7684\u4f7f\u7528\uff0c\u4f60\u53ef\u4ee5\u91c7\u7528\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\uff0c\u5373\u521b\u5efa\u8bb8\u591a\u8fd9\u6837\u7684\u7ec4\u5408\uff0c\u7136\u540e\u4f7f\u7528\u4e00\u4e2a\u6a21\u578b\u6765\u51b3\u5b9a\u54ea\u4e9b\u7279\u5f81\u662f\u6709\u7528\u7684\uff0c\u5e76\u4fdd\u7559\u5b83\u4eec\u3002\u6211\u4eec\u5c06\u5728\u672c\u4e66\u7a0d\u540e\u90e8\u5206\u4ecb\u7ecd\u8fd9\u79cd\u65b9\u6cd5\u3002 \u65e0\u8bba\u4f55\u65f6\u83b7\u5f97\u5206\u7c7b\u53d8\u91cf\uff0c\u90fd\u8981\u9075\u5faa\u4ee5\u4e0b\u7b80\u5355\u6b65\u9aa4\uff1a - \u586b\u5145 NaN \u503c\uff08\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff01\uff09\u3002 - \u4f7f\u7528 scikit-learn \u7684 LabelEncoder \u6216\u6620\u5c04\u5b57\u5178\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\uff0c\u5c06\u5b83\u4eec\u8f6c\u6362\u4e3a\u6574\u6570\u3002\u5982\u679c\u6ca1\u6709\u586b\u5145 NaN \u503c\uff0c\u53ef\u80fd\u9700\u8981\u5728\u8fd9\u4e00\u6b65\u4e2d\u8fdb\u884c\u5904\u7406 - \u521b\u5efa\u72ec\u70ed\u7f16\u7801\u3002\u662f\u7684\uff0c\u4f60\u53ef\u4ee5\u8df3\u8fc7\u4e8c\u503c\u5316\uff01 - \u5efa\u6a21\uff01\u6211\u6307\u7684\u662f\u673a\u5668\u5b66\u4e60\u3002 \u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u5904\u7406 NaN \u6570\u636e\u975e\u5e38\u91cd\u8981\uff0c\u5426\u5219\u60a8\u53ef\u80fd\u4f1a\u4ece scikit-learn \u7684 LabelEncoder \u4e2d\u5f97\u5230\u81ed\u540d\u662d\u8457\u7684\u9519\u8bef\u4fe1\u606f\uff1a ValueError: y 
\u5305\u542b\u4ee5\u524d\u672a\u89c1\u8fc7\u7684\u6807\u7b7e\uff1a [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan] \u8fd9\u4ec5\u4ec5\u610f\u5473\u7740\uff0c\u5728\u8f6c\u6362\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6570\u636e\u4e2d\u51fa\u73b0\u4e86 NaN \u503c\u3002\u8fd9\u662f\u56e0\u4e3a\u4f60\u5728\u8bad\u7ec3\u65f6\u5fd8\u8bb0\u4e86\u5904\u7406\u5b83\u4eec\u3002 \u5904\u7406 NaN \u503c \u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u4e22\u5f03\u5b83\u4eec\u3002\u867d\u7136\u7b80\u5355\uff0c\u4f46\u5e76\u4e0d\u7406\u60f3\u3002NaN \u503c\u4e2d\u53ef\u80fd\u5305\u542b\u5f88\u591a\u4fe1\u606f\uff0c\u5982\u679c\u53ea\u662f\u4e22\u5f03\u8fd9\u4e9b\u503c\uff0c\u5c31\u4f1a\u4e22\u5931\u8fd9\u4e9b\u4fe1\u606f\u3002\u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u5927\u90e8\u5206\u6570\u636e\u90fd\u662f NaN \u503c\uff0c\u56e0\u6b64\u4e0d\u80fd\u4e22\u5f03 NaN \u503c\u7684\u884c/\u6837\u672c\u3002\u5904\u7406 NaN \u503c\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u5c06\u5176\u4f5c\u4e3a\u4e00\u4e2a\u5168\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u662f\u5904\u7406 NaN \u503c\u6700\u5e38\u7528\u7684\u65b9\u6cd5\u3002\u5982\u679c\u4f7f\u7528 pandas\uff0c\u8fd8\u53ef\u4ee5\u901a\u8fc7\u975e\u5e38\u7b80\u5355\u7684\u65b9\u5f0f\u5b9e\u73b0\u3002 \u8bf7\u770b\u6211\u4eec\u4e4b\u524d\u67e5\u770b\u8fc7\u7684\u6570\u636e\u7684 ord_2 \u5217\u3002 In [ X ]: df . ord_2 . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : ord_2 , dtype : int64 \u586b\u5165 NaN \u503c\u540e\uff0c\u5c31\u53d8\u6210\u4e86 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . 
value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u54c7\uff01\u8fd9\u4e00\u5217\u4e2d\u6709 18075 \u4e2a NaN \u503c\uff0c\u800c\u6211\u4eec\u4e4b\u524d\u751a\u81f3\u90fd\u6ca1\u6709\u8003\u8651\u4f7f\u7528\u5b83\u4eec\u3002\u589e\u52a0\u4e86\u8fd9\u4e2a\u65b0\u7c7b\u522b\u540e\uff0c\u7c7b\u522b\u603b\u6570\u4ece 6 \u4e2a\u589e\u52a0\u5230\u4e86 7 \u4e2a\u3002\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u73b0\u5728\u6211\u4eec\u5728\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u4e5f\u4f1a\u8003\u8651 NaN\u3002\u76f8\u5173\u4fe1\u606f\u8d8a\u591a\uff0c\u6a21\u578b\u5c31\u8d8a\u597d\u3002 \u5047\u8bbe ord_2 \u6ca1\u6709\u4efb\u4f55 NaN \u503c\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u4e00\u5217\u4e2d\u7684\u6240\u6709\u7c7b\u522b\u90fd\u6709\u663e\u8457\u7684\u8ba1\u6570\u3002\u5176\u4e2d\u6ca1\u6709 \"\u7f55\u89c1 \"\u7c7b\u522b\uff0c\u5373\u53ea\u5728\u6837\u672c\u603b\u6570\u4e2d\u5360\u5f88\u5c0f\u6bd4\u4f8b\u7684\u7c7b\u522b\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5047\u8bbe\u60a8\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u4e86\u4f7f\u7528\u8fd9\u4e00\u5217\u7684\u6a21\u578b\uff0c\u5f53\u6a21\u578b\u6216\u9879\u76ee\u4e0a\u7ebf\u65f6\uff0c\u60a8\u5728 ord_2 \u5217\u4e2d\u5f97\u5230\u4e86\u4e00\u4e2a\u5728\u8bad\u7ec3\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6a21\u578b\u7ba1\u9053\u4f1a\u629b\u51fa\u4e00\u4e2a\u9519\u8bef\uff0c\u60a8\u5bf9\u6b64\u65e0\u80fd\u4e3a\u529b\u3002\u5982\u679c\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\uff0c\u90a3\u4e48\u53ef\u80fd\u662f\u751f\u4ea7\u4e2d\u7684\u7ba1\u9053\u51fa\u4e86\u95ee\u9898\u3002\u5982\u679c\u8fd9\u662f\u9884\u6599\u4e4b\u4e2d\u7684\uff0c\u90a3\u4e48\u60a8\u5c31\u5fc5\u987b\u4fee\u6539\u60a8\u7684\u6a21\u578b\u7ba1\u9053\uff0c\u5e76\u5728\u8fd9\u516d\u4e2a\u7c7b\u522b\u4e2d\u52a0\u5165\u4e00\u4e2a\u65b0\u7c7b\u522b\u3002 \u8fd9\u4e2a\u65b0\u7c7b\u522b\u88ab\u79f0\u4e3a 
\"\u7f55\u89c1 \"\u7c7b\u522b\u3002\u7f55\u89c1\u7c7b\u522b\u662f\u4e00\u79cd\u4e0d\u5e38\u89c1\u7684\u7c7b\u522b\uff0c\u53ef\u4ee5\u5305\u62ec\u8bb8\u591a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8fd1\u90bb\u6a21\u578b\u6765 \"\u9884\u6d4b \"\u672a\u77e5\u7c7b\u522b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u4e86\u8fd9\u4e2a\u7c7b\u522b\uff0c\u5b83\u5c31\u4f1a\u6210\u4e3a\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u4e00\u4e2a\u7c7b\u522b\u3002 \u56fe 3\uff1a\u5177\u6709\u4e0d\u540c\u7279\u5f81\u4e14\u65e0\u6807\u7b7e\u7684\u6570\u636e\u96c6\u793a\u610f\u56fe\uff0c\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u53ef\u80fd\u4f1a\u5728\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u51fa\u73b0\u65b0\u503c \u5f53\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u56fe 3 \u6240\u793a\u7684\u6570\u636e\u96c6\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u7b80\u5355\u7684\u6a21\u578b\uff0c\u5bf9\u9664 \"f3 \"\u4e4b\u5916\u7684\u6240\u6709\u7279\u5f81\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u6837\uff0c\u4f60\u5c06\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u4e0d\u77e5\u9053\u6216\u8bad\u7ec3\u4e2d\u6ca1\u6709 \"f3 \"\u65f6\u9884\u6d4b\u5b83\u3002\u6211\u4e0d\u6562\u8bf4\u8fd9\u6837\u7684\u6a21\u578b\u662f\u5426\u80fd\u5e26\u6765\u51fa\u8272\u7684\u6027\u80fd\uff0c\u4f46\u4e5f\u8bb8\u80fd\u5904\u7406\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u7684\u7f3a\u5931\u503c\uff0c\u5c31\u50cf\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u5176\u4ed6\u4e8b\u60c5\u4e00\u6837\uff0c\u4e0d\u5c1d\u8bd5\u4e00\u4e0b\u662f\u8bf4\u4e0d\u51c6\u7684\u3002 
\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u56fa\u5b9a\u7684\u6d4b\u8bd5\u96c6\uff0c\u4f60\u53ef\u4ee5\u5c06\u6d4b\u8bd5\u6570\u636e\u6dfb\u52a0\u5230\u8bad\u7ec3\u4e2d\uff0c\u4ee5\u4e86\u89e3\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u7c7b\u522b\u3002\u8fd9\u4e0e\u534a\u76d1\u7763\u5b66\u4e60\u975e\u5e38\u76f8\u4f3c\uff0c\u5373\u4f7f\u7528\u65e0\u6cd5\u7528\u4e8e\u8bad\u7ec3\u7684\u6570\u636e\u6765\u6539\u8fdb\u6a21\u578b\u3002\u8fd9\u4e5f\u4f1a\u7167\u987e\u5230\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u51fa\u73b0\u6b21\u6570\u6781\u5c11\u4f46\u5728\u6d4b\u8bd5\u6570\u636e\u4e2d\u5927\u91cf\u5b58\u5728\u7684\u7a00\u6709\u503c\u3002\u4f60\u7684\u6a21\u578b\u5c06\u66f4\u52a0\u7a33\u5065\u3002 \u5f88\u591a\u4eba\u8ba4\u4e3a\u8fd9\u79cd\u60f3\u6cd5\u4f1a\u8fc7\u5ea6\u62df\u5408\u3002\u53ef\u80fd\u8fc7\u62df\u5408\uff0c\u4e5f\u53ef\u80fd\u4e0d\u8fc7\u62df\u5408\u3002\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6cd5\u3002\u5982\u679c\u4f60\u5728\u8bbe\u8ba1\u4ea4\u53c9\u9a8c\u8bc1\u65f6\uff0c\u80fd\u591f\u5728\u6d4b\u8bd5\u6570\u636e\u4e0a\u8fd0\u884c\u6a21\u578b\u65f6\u590d\u5236\u9884\u6d4b\u8fc7\u7a0b\uff0c\u90a3\u4e48\u5b83\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u8fd9\u610f\u5473\u7740\u7b2c\u4e00\u6b65\u5e94\u8be5\u662f\u5206\u79bb\u6298\u53e0\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u4f60\u5e94\u8be5\u5e94\u7528\u4e0e\u6d4b\u8bd5\u6570\u636e\u76f8\u540c\u7684\u9884\u5904\u7406\u3002\u5047\u8bbe\u60a8\u60f3\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u6d4b\u8bd5\u6570\u636e\uff0c\u90a3\u4e48\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u60a8\u5fc5\u987b\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\uff0c\u5e76\u786e\u4fdd\u9a8c\u8bc1\u6570\u636e\u96c6\u590d\u5236\u4e86\u6d4b\u8bd5\u96c6\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u60a8\u5fc5\u987b\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u8bbe\u8ba1\u9a8c\u8bc1\u96c6\uff0c\u4f7f\u5176\u5305\u542b\u8bad\u7ec3\u96c6\u4e2d \"\u672a\u89c1 \"\u7684\u7c7b\u522b\u3002 
\u56fe 4\uff1a\u5bf9\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u8fdb\u884c\u7b80\u5355\u5408\u5e76\uff0c\u4ee5\u4e86\u89e3\u6d4b\u8bd5\u96c6\u4e2d\u5b58\u5728\u4f46\u8bad\u7ec3\u96c6\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u6216\u8bad\u7ec3\u96c6\u4e2d\u7f55\u89c1\u7684\u7c7b\u522b \u53ea\u8981\u770b\u4e00\u4e0b\u56fe 4 \u548c\u4e0b\u9762\u7684\u4ee3\u7801\uff0c\u5c31\u80fd\u5f88\u5bb9\u6613\u7406\u89e3\u5176\u5de5\u4f5c\u539f\u7406\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( \"../input/cat_train.csv\" ) # \u8bfb\u53d6\u6d4b\u8bd5\u96c6 test = pd . read_csv ( \"../input/cat_test.csv\" ) # \u5c06\u6d4b\u8bd5\u96c6\"target\"\u5217\u5168\u90e8\u7f6e\u4e3a-1 test . loc [:, \"target\" ] = - 1 # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u62fc\u63a5 data = pd . concat ([ train , test ]) . reset_index ( drop = True ) # \u5c06\u9664\"id\"\u548c\"target\"\u5217\u7684\u5176\u4ed6\u7279\u5f81\u5217\u540d\u53d6\u51fa features = [ x for x in train . columns if x not in [ \"id\" , \"target\" ]] # \u904d\u5386\u7279\u5f81 for feat in features : # \u6807\u7b7e\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\",\u5e76\u5c06\u8be5\u5217\u683c\u5f0f\u53d8\u4e3astr temp_col = data [ feat ] . fillna ( \"NONE\" ) . astype ( str ) . values # \u8f6c\u6362\u6570\u503c data . loc [:, feat ] = lbl_enc . fit_transform ( temp_col ) # \u6839\u636e\"target\"\u5217\u5c06\u8bad\u7ec3\u96c6\u4e0e\u6d4b\u8bd5\u96c6\u5206\u5f00 train = data [ data . target != - 1 ] . reset_index ( drop = True ) test = data [ data . target == - 1 ] . 
reset_index ( drop = True ) \u5f53\u60a8\u9047\u5230\u5df2\u7ecf\u6709\u6d4b\u8bd5\u6570\u636e\u96c6\u7684\u95ee\u9898\u65f6\uff0c\u8fd9\u4e2a\u6280\u5de7\u5c31\u4f1a\u8d77\u4f5c\u7528\u3002\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd9\u4e00\u62db\u5728\u5b9e\u65f6\u73af\u5883\u4e2d\u4e0d\u8d77\u4f5c\u7528\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u60a8\u6240\u5728\u7684\u516c\u53f8\u63d0\u4f9b\u5b9e\u65f6\u7ade\u4ef7\u89e3\u51b3\u65b9\u6848\uff08RTB\uff09\u3002RTB \u7cfb\u7edf\u4f1a\u5bf9\u5728\u7ebf\u770b\u5230\u7684\u6bcf\u4e2a\u7528\u6237\u8fdb\u884c\u7ade\u4ef7\uff0c\u4ee5\u8d2d\u4e70\u5e7f\u544a\u7a7a\u95f4\u3002\u8fd9\u79cd\u6a21\u5f0f\u53ef\u4f7f\u7528\u7684\u529f\u80fd\u53ef\u80fd\u5305\u62ec\u7f51\u7ad9\u4e2d\u6d4f\u89c8\u7684\u9875\u9762\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u7279\u5f81\u662f\u7528\u6237\u8bbf\u95ee\u7684\u6700\u540e\u4e94\u4e2a\u7c7b\u522b/\u9875\u9762\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5982\u679c\u7f51\u7ad9\u5f15\u5165\u4e86\u65b0\u7684\u7c7b\u522b\uff0c\u6211\u4eec\u5c06\u65e0\u6cd5\u518d\u51c6\u786e\u9884\u6d4b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u5c31\u4f1a\u5931\u6548\u3002\u8fd9\u79cd\u60c5\u51b5\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528 \"\u672a\u77e5 \"\u7c7b\u522b\u6765\u907f\u514d \u3002 \u5728\u6211\u4eec\u7684 cat-in-the-dat \u6570\u636e\u96c6\u4e2d\uff0c ord_2 \u5217\u4e2d\u5df2\u7ecf\u6709\u4e86\u672a\u77e5\u7c7b\u522b\u3002 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . 
value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u6211\u4eec\u53ef\u4ee5\u5c06 \"NONE \"\u89c6\u4e3a\u672a\u77e5\u3002\u56e0\u6b64\uff0c\u5982\u679c\u5728\u5b9e\u65f6\u6d4b\u8bd5\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u83b7\u5f97\u4e86\u4ee5\u524d\u4ece\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\uff0c\u6211\u4eec\u5c31\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a \"NONE\"\u3002 \u8fd9\u4e0e\u81ea\u7136\u8bed\u8a00\u5904\u7406\u95ee\u9898\u975e\u5e38\u76f8\u4f3c\u3002\u6211\u4eec\u603b\u662f\u57fa\u4e8e\u56fa\u5b9a\u7684\u8bcd\u6c47\u5efa\u7acb\u6a21\u578b\u3002\u589e\u52a0\u8bcd\u6c47\u91cf\u5c31\u4f1a\u589e\u52a0\u6a21\u578b\u7684\u5927\u5c0f\u3002\u50cf BERT \u8fd9\u6837\u7684\u8f6c\u6362\u5668\u6a21\u578b\u662f\u5728 ~30000 \u4e2a\u5355\u8bcd\uff08\u82f1\u8bed\uff09\u7684\u57fa\u7840\u4e0a\u8bad\u7ec3\u7684\u3002\u56e0\u6b64\uff0c\u5f53\u6709\u65b0\u8bcd\u8f93\u5165\u65f6\uff0c\u6211\u4eec\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a UNK\uff08\u672a\u77e5\uff09\u3002 \u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u5047\u8bbe\u6d4b\u8bd5\u6570\u636e\u4e0e\u8bad\u7ec3\u6570\u636e\u5177\u6709\u76f8\u540c\u7684\u7c7b\u522b\uff0c\u4e5f\u53ef\u4ee5\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u5f15\u5165\u7f55\u89c1\u6216\u672a\u77e5\u7c7b\u522b\uff0c\u4ee5\u5904\u7406\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u65b0\u7c7b\u522b\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u586b\u5165 NaN \u503c\u540e ord_4 \u5217\u7684\u503c\u8ba1\u6570\uff1a In [ X ]: df . ord_4 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 . . . 
K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 G 3404 V 3107 J 1950 L 1657 Name : ord_4 , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u6709\u4e9b\u6570\u503c\u53ea\u51fa\u73b0\u4e86\u51e0\u5343\u6b21\uff0c\u6709\u4e9b\u5219\u51fa\u73b0\u4e86\u8fd1 40000 \u6b21\u3002NaN \u4e5f\u7ecf\u5e38\u51fa\u73b0\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u503c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5b9a\u4e49\u5c06\u4e00\u4e2a\u503c\u79f0\u4e3a \" \u7f55\u89c1\uff08rare\uff09 \"\u7684\u6807\u51c6\u4e86\u3002\u6bd4\u65b9\u8bf4\uff0c\u5728\u8fd9\u4e00\u5217\u4e2d\uff0c\u7a00\u6709\u503c\u7684\u8981\u6c42\u662f\u8ba1\u6570\u5c0f\u4e8e 2000\u3002\u8fd9\u6837\u770b\u6765\uff0cJ \u548c L \u5c31\u53ef\u4ee5\u88ab\u6807\u8bb0\u4e3a\u7a00\u6709\u503c\u4e86\u3002\u4f7f\u7528 pandas\uff0c\u6839\u636e\u8ba1\u6570\u9608\u503c\u66ff\u6362\u7c7b\u522b\u975e\u5e38\u7b80\u5355\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 In [ X ]: df . ord_4 = df . ord_4 . fillna ( \"NONE\" ) In [ X ]: df . loc [ ... : df [ \"ord_4\" ] . value_counts ()[ df [ \"ord_4\" ]] . values < 2000 , ... : \"ord_4\" ... : ] = \"RARE\" In [ X ]: df . ord_4 . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 M 32504 . . . 
B 25212 E 21871 K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 RARE 3607 G 3404 V 3107 Name : ord_4 , dtype : int64 \u6211\u4eec\u8ba4\u4e3a\uff0c\u53ea\u8981\u67d0\u4e2a\u7c7b\u522b\u7684\u503c\u5c0f\u4e8e 2000\uff0c\u5c31\u5c06\u5176\u66ff\u6362\u4e3a\u7f55\u89c1\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u5728\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6240\u6709\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"RARE\"\uff0c\u800c\u6240\u6709\u7f3a\u5931\u503c\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"NONE\"\u3002 \u8fd9\u79cd\u65b9\u6cd5\u8fd8\u80fd\u786e\u4fdd\u5373\u4f7f\u6709\u65b0\u7684\u7c7b\u522b\uff0c\u6a21\u578b\u4e5f\u80fd\u5728\u5b9e\u9645\u73af\u5883\u4e2d\u6b63\u5e38\u5de5\u4f5c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u5177\u5907\u4e86\u5904\u7406\u4efb\u4f55\u5e26\u6709\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u6240\u9700\u7684\u4e00\u5207\u6761\u4ef6\u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5efa\u7acb\u7b2c\u4e00\u4e2a\u6a21\u578b\uff0c\u5e76\u9010\u6b65\u63d0\u9ad8\u5176\u6027\u80fd\u3002 \u5728\u6784\u5efa\u4efb\u4f55\u7c7b\u578b\u7684\u6a21\u578b\u4e4b\u524d\uff0c\u4ea4\u53c9\u68c0\u9a8c\u81f3\u5173\u91cd\u8981\u3002\u6211\u4eec\u5df2\u7ecf\u770b\u5230\u4e86\u6807\u7b7e/\u76ee\u6807\u5206\u5e03\uff0c\u77e5\u9053\u8fd9\u662f\u4e00\u4e2a\u76ee\u6807\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 StratifiedKFold \u6765\u5206\u5272\u6570\u636e\u3002 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u6dfb\u52a0\"kfold\"\u5217\uff0c\u5e76\u7f6e\u4e3a-1 df [ \"kfold\" ] = - 1 # \u6253\u4e71\u6570\u636e\u987a\u5e8f\uff0c\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u5c06\u76ee\u6807\u5217\u53d6\u51fa y = df . target . values # \u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c kf = model_selection . 
StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): # \u533a\u5206\u6298\u53e0 df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u6587\u4ef6 df . to_csv ( \"../input/cat_train_folds.csv\" , index = False ) \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u68c0\u67e5\u65b0\u7684\u6298\u53e0 csv\uff0c\u67e5\u770b\u6bcf\u4e2a\u6298\u53e0\u7684\u6837\u672c\u6570\uff1a In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) In [ X ]: df . kfold . value_counts () Out [ X ]: 4 120000 3 120000 2 120000 1 120000 0 120000 Name : kfold , dtype : int64 \u6240\u6709\u6298\u53e0\u90fd\u6709 120000 \u4e2a\u6837\u672c\u3002\u8fd9\u662f\u610f\u6599\u4e4b\u4e2d\u7684\uff0c\u56e0\u4e3a\u8bad\u7ec3\u6570\u636e\u6709 600000 \u4e2a\u6837\u672c\uff0c\u800c\u6211\u4eec\u505a\u4e865\u6b21\u6298\u53e0\u3002\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u4e00\u5207\u987a\u5229\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u68c0\u67e5\u6bcf\u4e2a\u6298\u53e0\u7684\u76ee\u6807\u5206\u5e03\u3002 In [ X ]: df [ df . kfold == 0 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 1 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 2 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 3 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 4 ] . target . 
value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u76ee\u6807\u7684\u5206\u5e03\u90fd\u662f\u4e00\u6837\u7684\u3002\u8fd9\u6b63\u662f\u6211\u4eec\u6240\u9700\u8981\u7684\u3002\u5b83\u4e5f\u53ef\u4ee5\u662f\u76f8\u4f3c\u7684\uff0c\u5e76\u4e0d\u4e00\u5b9a\u8981\u4e00\u76f4\u76f8\u540c\u3002\u73b0\u5728\uff0c\u5f53\u6211\u4eec\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u6807\u7b7e\u5206\u5e03\u90fd\u5c06\u76f8\u540c\u3002 \u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u7684\u6700\u7b80\u5355\u7684\u6a21\u578b\u4e4b\u4e00\u662f\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u5e76\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): # \u8bfb\u53d6\u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) # \u53d6\u9664\"id\", \"target\", \"kfold\"\u5916\u7684\u5176\u4ed6\u7279\u5f81\u5217 features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] # \u904d\u5386\u7279\u5f81\u5217\u8868 for col in features : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u9a8c\u8bc1\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . 
fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u6d4b\u8bd5\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u903b\u8f91\u56de\u5f52 model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . target . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( auc ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6298\u53e00 run ( 0 ) \u90a3\u4e48\uff0c\u53d1\u751f\u4e86\u4ec0\u4e48\u5462\uff1f \u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u51fd\u6570\uff0c\u5c06\u6570\u636e\u5206\u4e3a\u8bad\u7ec3\u548c\u9a8c\u8bc1\u4e24\u90e8\u5206\uff0c\u7ed9\u5b9a\u6298\u53e0\u6570\uff0c\u5904\u7406 NaN \u503c\uff0c\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u5355\u6b21\u7f16\u7801\uff0c\u5e76\u8bad\u7ec3\u4e00\u4e2a\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\u3002 \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u90e8\u5206\u4ee3\u7801\u65f6\uff0c\u4f1a\u4ea7\u751f\u5982\u4e0b\u8f93\u51fa\uff1a \u276f python ohe_logres . py / home / abhishek / miniconda3 / envs / ml / lib / python3 .7 / site - packages / sklearn / linear_model / _logistic . py : 939 : ConvergenceWarning : lbfgs failed to converge ( status = 1 ): STOP : TOTAL NO . of ITERATIONS REACHED LIMIT . Increase the number of iterations ( max_iter ) or scale the data as shown in : https : // scikit - learn . org / stable / modules / preprocessing . html . Please also refer to the documentation for alternative solver options : https : // scikit - learn . org / stable / modules / linear_model . 
html #logistic- regression extra_warning_msg = _LOGISTIC_SOLVER_CONVERGENCE_MSG ) 0.7847865042255127 \u6709\u4e00\u4e9b\u8b66\u544a\u3002\u903b\u8f91\u56de\u5f52\u4f3c\u4e4e\u6ca1\u6709\u6536\u655b\u5230\u6700\u5927\u8fed\u4ee3\u6b21\u6570\u3002\u6211\u4eec\u6ca1\u6709\u8c03\u6574\u53c2\u6570\uff0c\u6240\u4ee5\u6ca1\u6709\u95ee\u9898\u3002\u6211\u4eec\u770b\u5230 AUC \u4e3a 0.785\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u5bf9\u4ee3\u7801\u8fdb\u884c\u7b80\u5355\u4fee\u6539\uff0c\u8fd0\u884c\u6240\u6709\u6298\u53e0\u3002 .... model = linear_model . LogisticRegression () model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u5faa\u73af\u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5e76\u6ca1\u6709\u505a\u5f88\u5927\u7684\u6539\u52a8\uff0c\u6240\u4ee5\u6211\u53ea\u663e\u793a\u4e86\u90e8\u5206\u4ee3\u7801\u884c\uff0c\u5176\u4e2d\u4e00\u4e9b\u4ee3\u7801\u884c\u6709\u6539\u52a8\u3002 \u8fd9\u5c31\u6253\u5370\u51fa\u4e86\uff1a python - W ignore ohe_logres . 
py Fold = 0 , AUC = 0.7847865042255127 Fold = 1 , AUC = 0.7853553605899214 Fold = 2 , AUC = 0.7879321942914885 Fold = 3 , AUC = 0.7870315929550808 Fold = 4 , AUC = 0.7864668243125608 \u8bf7\u6ce8\u610f\uff0c\u6211\u4f7f\u7528\"-W ignore \"\u5ffd\u7565\u4e86\u6240\u6709\u8b66\u544a\u3002 \u6211\u4eec\u770b\u5230\uff0cAUC \u5206\u6570\u5728\u6240\u6709\u8936\u76b1\u4e2d\u90fd\u76f8\u5f53\u7a33\u5b9a\u3002\u5e73\u5747 AUC \u4e3a 0.78631449527\u3002\u5bf9\u4e8e\u6211\u4eec\u7684\u7b2c\u4e00\u4e2a\u6a21\u578b\u6765\u8bf4\u76f8\u5f53\u4e0d\u9519\uff01 \u5f88\u591a\u4eba\u5728\u9047\u5230\u8fd9\u79cd\u95ee\u9898\u65f6\u4f1a\u9996\u5148\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6bd4\u5982\u968f\u673a\u68ee\u6797\u3002\u5728\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u5e94\u7528\u968f\u673a\u68ee\u6797\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\uff08label encoding\uff09\uff0c\u5c06\u6bcf\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81\u90fd\u8f6c\u6362\u4e3a\u6574\u6570\uff0c\u800c\u4e0d\u662f\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u72ec\u70ed\u7f16\u7801\u3002 \u8fd9\u79cd\u7f16\u7801\u4e0e\u72ec\u70ed\u7f16\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . 
values # \u968f\u673a\u68ee\u6797\u6a21\u578b model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u4f7f\u7528 scikit-learn \u4e2d\u7684\u968f\u673a\u68ee\u6797\uff0c\u5e76\u53d6\u6d88\u4e86\u72ec\u70ed\u7f16\u7801\u3002\u6211\u4eec\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u4ee3\u66ff\u72ec\u70ed\u7f16\u7801\u3002\u5f97\u5206\u5982\u4e0b \u276f python lbl_rf . py Fold = 0 , AUC = 0.7167390828113697 Fold = 1 , AUC = 0.7165459672958506 Fold = 2 , AUC = 0.7159709909587376 Fold = 3 , AUC = 0.7161589664189556 Fold = 4 , AUC = 0.7156020216155978 \u54c7 \u5de8\u5927\u7684\u5dee\u5f02\uff01 \u968f\u673a\u68ee\u6797\u6a21\u578b\u5728\u6ca1\u6709\u4efb\u4f55\u8d85\u53c2\u6570\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u8868\u73b0\u8981\u6bd4\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u5dee\u5f88\u591a\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u603b\u662f\u5e94\u8be5\u5148\u4ece\u7b80\u5355\u6a21\u578b\u5f00\u59cb\u7684\u539f\u56e0\u3002\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u7c89\u4e1d\u4f1a\u4ece\u8fd9\u91cc\u5f00\u59cb\uff0c\u800c\u5ffd\u7565\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u8ba4\u4e3a\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u6a21\u578b\uff0c\u4e0d\u80fd\u5e26\u6765\u6bd4\u968f\u673a\u68ee\u6797\u66f4\u597d\u7684\u4ef7\u503c\u3002\u8fd9\u79cd\u4eba\u5c06\u4f1a\u72af\u4e0b\u5927\u9519\u3002\u5728\u6211\u4eec\u5b9e\u73b0\u968f\u673a\u68ee\u6797\u7684\u8fc7\u7a0b\u4e2d\uff0c\u4e0e\u903b\u8f91\u56de\u5f52\u76f8\u6bd4\uff0c\u6298\u53e0\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u624d\u80fd\u5b8c\u6210\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e0d\u4ec5\u635f\u5931\u4e86 
AUC\uff0c\u8fd8\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u6765\u5b8c\u6210\u8bad\u7ec3\u3002\u8bf7\u6ce8\u610f\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u8fdb\u884c\u63a8\u7406\u4e5f\u5f88\u8017\u65f6\uff0c\u800c\u4e14\u5360\u7528\u7684\u7a7a\u95f4\u4e5f\u66f4\u5927\u3002 \u5982\u679c\u6211\u4eec\u613f\u610f\uff0c\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u5728\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u6570\u636e\u4e0a\u8fd0\u884c\u968f\u673a\u68ee\u6797\uff0c\u4f46\u8fd9\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u5947\u5f02\u503c\u5206\u89e3\u6765\u51cf\u5c11\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u77e9\u9635\u3002\u8fd9\u662f\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u63d0\u53d6\u4e3b\u9898\u7684\u5e38\u7528\u65b9\u6cd5\u3002 import pandas as pd from scipy import sparse from sklearn import decomposition from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" )] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) x_train = ohe . transform ( df_train [ features ]) x_valid = ohe . transform ( df_valid [ features ]) # \u5947\u5f02\u503c\u5206\u89e3 svd = decomposition . TruncatedSVD ( n_components = 120 ) full_sparse = sparse . vstack (( x_train , x_valid )) svd . fit ( full_sparse ) x_train = svd . transform ( x_train ) x_valid = svd . transform ( x_valid ) model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . 
values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u5bf9\u5168\u90e8\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\uff0c\u7136\u540e\u7528\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\u5728\u7a00\u758f\u77e9\u9635\u4e0a\u62df\u5408 scikit-learn \u7684 TruncatedSVD\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c06\u9ad8\u7ef4\u7a00\u758f\u77e9\u9635\u51cf\u5c11\u5230 120 \u4e2a\u7279\u5f81\uff0c\u7136\u540e\u62df\u5408\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u3002 \u4ee5\u4e0b\u662f\u8be5\u6a21\u578b\u7684\u8f93\u51fa\u7ed3\u679c\uff1a \u276f python ohe_svd_rf . py Fold = 0 , AUC = 0.7064863038754249 Fold = 1 , AUC = 0.706050102937374 Fold = 2 , AUC = 0.7086069243167242 Fold = 3 , AUC = 0.7066819080085971 Fold = 4 , AUC = 0.7058154015055585 \u6211\u4eec\u53d1\u73b0\u60c5\u51b5\u66f4\u7cdf\u3002\u770b\u6765\uff0c\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\u7684\u6700\u4f73\u65b9\u6cd5\u662f\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u3002\u968f\u673a\u68ee\u6797\u4f3c\u4e4e\u8017\u65f6\u592a\u591a\u3002\u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u8bd5\u8bd5 XGBoost\u3002\u5982\u679c\u4f60\u4e0d\u77e5\u9053 XGBoost\uff0c\u5b83\u662f\u6700\u6d41\u884c\u7684\u68af\u5ea6\u63d0\u5347\u7b97\u6cd5\u4e4b\u4e00\u3002\u7531\u4e8e\u5b83\u662f\u4e00\u79cd\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u6570\u636e\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . 
fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 , n_estimators = 200 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0c\u6211\u5bf9 xgboost \u53c2\u6570\u505a\u4e86\u4e00\u4e9b\u4fee\u6539\u3002xgboost \u7684\u9ed8\u8ba4\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u662f 3\uff0c\u6211\u628a\u5b83\u6539\u6210\u4e86 7\uff0c\u8fd8\u628a\u4f30\u8ba1\u5668\u6570\u91cf\uff08n_estimators\uff09\u4ece 100 \u6539\u6210\u4e86 200\u3002 \u8be5\u6a21\u578b\u7684 5 \u6298\u4ea4\u53c9\u68c0\u9a8c\u5f97\u5206\u5982\u4e0b\uff1a \u276f python lbl_xgb . 
py Fold = 0 , AUC = 0.7656768851999011 Fold = 1 , AUC = 0.7633006564148015 Fold = 2 , AUC = 0.7654277821434345 Fold = 3 , AUC = 0.7663609758878182 Fold = 4 , AUC = 0.764914671468069 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u4e0d\u505a\u4efb\u4f55\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u5f97\u5206\u6bd4\u666e\u901a\u968f\u673a\u68ee\u6797\u8981\u9ad8\u5f97\u591a\u3002 \u60a8\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4e00\u4e9b\u7279\u5f81\u5de5\u7a0b\uff0c\u653e\u5f03\u67d0\u4e9b\u5bf9\u6a21\u578b\u6ca1\u6709\u4efb\u4f55\u4ef7\u503c\u7684\u5217\u7b49\u3002\u4f46\u4f3c\u4e4e\u6211\u4eec\u80fd\u505a\u7684\u4e0d\u591a\uff0c\u65e0\u6cd5\u8bc1\u660e\u6a21\u578b\u7684\u6539\u8fdb\u3002\u8ba9\u6211\u4eec\u628a\u6570\u636e\u96c6\u6362\u6210\u53e6\u4e00\u4e2a\u6709\u5927\u91cf\u5206\u7c7b\u53d8\u91cf\u7684\u6570\u636e\u96c6\u3002\u53e6\u4e00\u4e2a\u6709\u540d\u7684\u6570\u636e\u96c6\u662f \u7f8e\u56fd\u6210\u4eba\u4eba\u53e3\u666e\u67e5\u6570\u636e\uff08US adult census data\uff09 \u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u5305\u542b\u4e00\u4e9b\u7279\u5f81\uff0c\u800c\u4f60\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u5de5\u8d44\u7b49\u7ea7\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u3002\u56fe 5 \u663e\u793a\u4e86\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u4e00\u4e9b\u5217\u3002 \u56fe 5\uff1a\u90e8\u5206\u6570\u636e\u96c6\u5c55\u793a \u8be5\u6570\u636e\u96c6\u6709\u4ee5\u4e0b\u51e0\u5217\uff1a - \u5e74\u9f84\uff08age\uff09 \u5de5\u4f5c\u7c7b\u522b\uff08workclass\uff09 \u5b66\u5386\uff08fnlwgt\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education.num\uff09 \u5a5a\u59fb\u72b6\u51b5\uff08marital.status\uff09 \u804c\u4e1a\uff08occupation\uff09 \u5173\u7cfb\uff08relationship\uff09 \u79cd\u65cf\uff08race\uff09 \u6027\u522b\uff08sex\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 
\u539f\u7c4d\u56fd\uff08native.country\uff09 \u6536\u5165\uff08income\uff09 \u8fd9\u4e9b\u7279\u5f81\u5927\u591a\u4e0d\u8a00\u81ea\u660e\u3002\u90a3\u4e9b\u4e0d\u660e\u767d\u7684\uff0c\u6211\u4eec\u53ef\u4ee5\u4e0d\u8003\u8651\u3002\u8ba9\u6211\u4eec\u5148\u5c1d\u8bd5\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u3002 \u6211\u4eec\u770b\u5230\u6536\u5165\u5217\u662f\u4e00\u4e2a\u5b57\u7b26\u4e32\u3002\u8ba9\u6211\u4eec\u5bf9\u8fd9\u4e00\u5217\u8fdb\u884c\u6570\u503c\u7edf\u8ba1\u3002 In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/adult.csv\" ) In [ X ]: df . income . value_counts () Out [ X ]: <= 50 K 24720 > 50 K 7841 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 7841 \u4e2a\u5b9e\u4f8b\u7684\u6536\u5165\u8d85\u8fc7 5 \u4e07\u7f8e\u5143\u3002\u8fd9\u5360\u6837\u672c\u603b\u6570\u7684 24%\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4fdd\u6301\u4e0e\u732b\u6570\u636e\u96c6\u76f8\u540c\u7684\u8bc4\u4f30\u65b9\u6cd5\uff0c\u5373 AUC\u3002 \u5728\u5f00\u59cb\u5efa\u6a21\u4e4b\u524d\uff0c\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u5c06\u53bb\u6389\u51e0\u5217\u7279\u5f81\uff0c\u5373 \u5b66\u5386\uff08fnlwgt\uff09 \u5e74\u9f84\uff08age\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 \u8ba9\u6211\u4eec\u8bd5\u7740\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u5668\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002\u7b2c\u4e00\u6b65\u603b\u662f\u8981\u8fdb\u884c\u4ea4\u53c9\u9a8c\u8bc1\u3002\u6211\u4e0d\u4f1a\u5728\u8fd9\u91cc\u5c55\u793a\u8fd9\u90e8\u5206\u4ee3\u7801\u3002\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . 
read_csv ( \"../input/adult_folds.csv\" ) # \u9700\u8981\u5220\u9664\u7684\u5217 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) # \u6620\u5c04 target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } # \u4f7f\u7528\u6620\u5c04\u66ff\u6362 df . loc [:, \"income\" ] = df . income . map ( target_mapping ) # \u53d6\u9664\"kfold\", \"income\"\u5217\u7684\u5176\u4ed6\u5217\u540d features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u9a8c\u8bc1\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u6784\u5efa\u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . income . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . income . 
values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u65f6\uff0c\u6211\u4eec\u4f1a\u5f97\u5230 \u276f python - W ignore ohe_logres . py Fold = 0 , AUC = 0.8794809708119079 Fold = 1 , AUC = 0.8875785068274882 Fold = 2 , AUC = 0.8852609687685753 Fold = 3 , AUC = 0.8681236223251438 Fold = 4 , AUC = 0.8728581541840037 \u5bf9\u4e8e\u4e00\u4e2a\u5982\u6b64\u7b80\u5355\u7684\u6a21\u578b\u6765\u8bf4\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u4e0d\u9519\u7684 AUC\uff01 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5728\u4e0d\u8c03\u6574\u4efb\u4f55\u8d85\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c1d\u8bd5\u4e00\u4e0b\u6807\u7b7e\u7f16\u7801\u7684xgboost\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . 
roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8ba9\u6211\u4eec\u8fd0\u884c\u4e0a\u9762\u4ee3\u7801\uff1a \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8800810634234078 Fold = 1 , AUC = 0.886811884948154 Fold = 2 , AUC = 0.8854421433318472 Fold = 3 , AUC = 0.8676319549361007 Fold = 4 , AUC = 0.8714450054900602 \u8fd9\u770b\u8d77\u6765\u5df2\u7ecf\u76f8\u5f53\u4e0d\u9519\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b max_depth \u589e\u52a0\u5230 7 \u548c n_estimators \u589e\u52a0\u5230 200 \u65f6\u7684\u5f97\u5206\u3002 \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8764108944332032 Fold = 1 , AUC = 0.8840708537662638 Fold = 2 , AUC = 0.8816601162613102 Fold = 3 , AUC = 0.8662335762581732 Fold = 4 , AUC = 0.8698983461709926 \u770b\u8d77\u6765\u5e76\u6ca1\u6709\u6539\u5584\u3002 \u8fd9\u8868\u660e\uff0c\u4e00\u4e2a\u6570\u636e\u96c6\u7684\u53c2\u6570\u4e0d\u80fd\u79fb\u690d\u5230\u53e6\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6211\u4eec\u5fc5\u987b\u518d\u6b21\u5c1d\u8bd5\u8c03\u6574\u53c2\u6570\uff0c\u4f46\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u8bf4\u660e\u3002 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5728\u4e0d\u8c03\u6574\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c06\u6570\u503c\u7279\u5f81\u7eb3\u5165 xgboost \u6a21\u578b\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) # \u52a0\u5165\u6570\u503c\u7279\u5f81 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . 
columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : if col not in num_cols : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u56e0\u6b64\uff0c\u6211\u4eec\u4fdd\u7559\u6570\u5b57\u5217\uff0c\u53ea\u662f\u4e0d\u5bf9\u5176\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u7684\u6700\u7ec8\u7279\u5f81\u77e9\u9635\u5c31\u7531\u6570\u5b57\u5217\uff08\u539f\u6837\uff09\u548c\u7f16\u7801\u5206\u7c7b\u5217\u7ec4\u6210\u4e86\u3002\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\u90fd\u80fd\u8f7b\u677e\u5904\u7406\u8fd9\u79cd\u6df7\u5408\u3002 \u8bf7\u6ce8\u610f\uff0c\u5728\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65f6\uff0c\u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u6570\u636e\u8fdb\u884c\u5f52\u4e00\u5316\u5904\u7406\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff0c\u5728\u4f7f\u7528\u7ebf\u6027\u6a21\u578b\uff08\u5982\u903b\u8f91\u56de\u5f52\uff09\u65f6\u4e0d\u5bb9\u5ffd\u89c6\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u4e2a\u811a\u672c\uff01 \u276f python lbl_xgb_num . 
py Fold = 0 , AUC = 0.9209790185449889 Fold = 1 , AUC = 0.9247157449144706 Fold = 2 , AUC = 0.9269329887598243 Fold = 3 , AUC = 0.9119349082169275 Fold = 4 , AUC = 0.9166408030141667 \u54c7\u54e6 \u8fd9\u662f\u4e00\u4e2a\u5f88\u597d\u7684\u5206\u6570\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c1d\u8bd5\u6dfb\u52a0\u4e00\u4e9b\u529f\u80fd\u3002\u6211\u4eec\u5c06\u63d0\u53d6\u6240\u6709\u5206\u7c7b\u5217\uff0c\u5e76\u521b\u5efa\u6240\u6709\u4e8c\u5ea6\u7ec4\u5408\u3002\u8bf7\u770b\u4e0b\u9762\u4ee3\u7801\u6bb5\u4e2d\u7684 feature_engineering \u51fd\u6570\uff0c\u4e86\u89e3\u5982\u4f55\u5b9e\u73b0\u8fd9\u4e00\u70b9\u3002 import itertools import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def feature_engineering ( df , cat_cols ): # \u751f\u6210\u4e24\u4e2a\u7279\u5f81\u7684\u7ec4\u5408 combi = list ( itertools . combinations ( cat_cols , 2 )) for c1 , c2 in combi : df . loc [:, c1 + \"_\" + c2 ] = df [ c1 ] . astype ( str ) + \"_\" + df [ c2 ] . astype ( str ) return df def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) cat_cols = [ c for c in df . columns if c not in num_cols and c not in ( \"kfold\" , \"income\" )] # \u7279\u5f81\u5de5\u7a0b df = feature_engineering ( df , cat_cols ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" )] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . 
reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u8fd9\u662f\u4ece\u5206\u7c7b\u5217\u4e2d\u521b\u5efa\u7279\u5f81\u7684\u4e00\u79cd\u975e\u5e38\u5e7c\u7a1a\u7684\u65b9\u6cd5\u3002\u6211\u4eec\u5e94\u8be5\u4ed4\u7ec6\u7814\u7a76\u6570\u636e\uff0c\u770b\u770b\u54ea\u4e9b\u7ec4\u5408\u6700\u5408\u7406\u3002\u5982\u679c\u4f7f\u7528\u8fd9\u79cd\u65b9\u6cd5\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5c31\u9700\u8981\u4f7f\u7528\u67d0\u79cd\u7279\u5f81\u9009\u62e9\u6765\u9009\u51fa\u6700\u4f73\u7279\u5f81\u3002\u7a0d\u540e\u6211\u4eec\u5c06\u8be6\u7ec6\u4ecb\u7ecd\u7279\u5f81\u9009\u62e9\u3002\u73b0\u5728\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5206\u6570\u3002 \u276f python lbl_xgb_num_feat . py Fold = 0 , AUC = 0.9211483465031423 Fold = 1 , AUC = 0.9251499446866125 Fold = 2 , AUC = 0.9262344766486692 Fold = 3 , AUC = 0.9114264068794995 Fold = 4 , AUC = 0.9177914453099201 \u770b\u6765\uff0c\u5373\u4f7f\u4e0d\u6539\u53d8\u4efb\u4f55\u8d85\u53c2\u6570\uff0c\u53ea\u589e\u52a0\u4e00\u4e9b\u7279\u5f81\uff0c\u6211\u4eec\u4e5f\u80fd\u63d0\u9ad8\u4e00\u4e9b\u6298\u53e0\u5f97\u5206\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5c06 max_depth \u589e\u52a0\u5230 7 \u662f\u5426\u6709\u5e2e\u52a9\u3002 \u276f python lbl_xgb_num_feat . 
py Fold = 0 , AUC = 0.9286668430204137 Fold = 1 , AUC = 0.9329340656165378 Fold = 2 , AUC = 0.9319817543218744 Fold = 3 , AUC = 0.919046187194538 Fold = 4 , AUC = 0.9245692057162671 \u6211\u4eec\u518d\u6b21\u6539\u8fdb\u4e86\u6211\u4eec\u7684\u6a21\u578b\u3002 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u8fd8\u6ca1\u6709\u4f7f\u7528\u7a00\u6709\u503c\u3001\u4e8c\u503c\u5316\u3001\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u7279\u5f81\u7684\u7ec4\u5408\u4ee5\u53ca\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u3002 \u4ece\u5206\u7c7b\u7279\u5f81\u4e2d\u8fdb\u884c\u7279\u5f81\u5de5\u7a0b\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u4f7f\u7528 \u76ee\u6807\u7f16\u7801 \u3002\u4f46\u662f\uff0c\u60a8\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u8fd9\u53ef\u80fd\u4f1a\u4f7f\u60a8\u7684\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u76ee\u6807\u7f16\u7801\u662f\u4e00\u79cd\u5c06\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u6620\u5c04\u5230\u5176\u5e73\u5747\u76ee\u6807\u503c\u7684\u6280\u672f\uff0c\u4f46\u5fc5\u987b\u59cb\u7ec8\u4ee5\u4ea4\u53c9\u9a8c\u8bc1\u7684\u65b9\u5f0f\u8fdb\u884c\u3002\u8fd9\u610f\u5473\u7740\u9996\u5148\u8981\u521b\u5efa\u6298\u53e0\uff0c\u7136\u540e\u4f7f\u7528\u8fd9\u4e9b\u6298\u53e0\u4e3a\u6570\u636e\u7684\u4e0d\u540c\u5217\u521b\u5efa\u76ee\u6807\u7f16\u7801\u7279\u5f81\uff0c\u65b9\u6cd5\u4e0e\u5728\u6298\u53e0\u4e0a\u62df\u5408\u548c\u9884\u6d4b\u6a21\u578b\u7684\u65b9\u6cd5\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u521b\u5efa\u4e86 5 \u4e2a\u6298\u53e0\uff0c\u60a8\u5c31\u5fc5\u987b\u521b\u5efa 5 
\u6b21\u76ee\u6807\u7f16\u7801\uff0c\u8fd9\u6837\u6700\u7ec8\uff0c\u60a8\u5c31\u53ef\u4ee5\u4e3a\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u53d8\u91cf\u521b\u5efa\u7f16\u7801\uff0c\u800c\u8fd9\u4e9b\u53d8\u91cf\u5e76\u975e\u6765\u81ea\u540c\u4e00\u4e2a\u6298\u53e0\u3002\u7136\u540e\u5728\u62df\u5408\u6a21\u578b\u65f6\uff0c\u5fc5\u987b\u518d\u6b21\u4f7f\u7528\u76f8\u540c\u7684\u6298\u53e0\u3002\u672a\u89c1\u6d4b\u8bd5\u6570\u636e\u7684\u76ee\u6807\u7f16\u7801\u53ef\u4ee5\u6765\u81ea\u5168\u90e8\u8bad\u7ec3\u6570\u636e\uff0c\u4e5f\u53ef\u4ee5\u662f\u6240\u6709 5 \u4e2a\u6298\u53e0\u7684\u5e73\u5747\u503c\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u5728\u540c\u4e00\u4e2a\u6210\u4eba\u6570\u636e\u96c6\u4e0a\u4f7f\u7528\u76ee\u6807\u7f16\u7801\uff0c\u4ee5\u4fbf\u8fdb\u884c\u6bd4\u8f83\u3002 import copy import pandas as pd from sklearn import metrics from sklearn import preprocessing import xgboost as xgb def mean_target_encoding ( data ): df = copy . deepcopy ( data ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) and f not in num_cols ] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) encoded_dfs = [] for fold in range ( 5 ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) for column in features : # \u76ee\u6807\u7f16\u7801 mapping_dict = dict ( df_train . groupby ( column )[ \"income\" ] . mean () ) df_valid . loc [:, column + \"_enc\" ] = df_valid [ column ] . map ( mapping_dict ) encoded_dfs . 
append ( df_valid ) encoded_df = pd . concat ( encoded_dfs , axis = 0 ) return encoded_df def run ( df , fold ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/adult_folds.csv\" ) df = mean_target_encoding ( df ) for fold_ in range ( 5 ): run ( df , fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u4e0a\u8ff0\u7247\u6bb5\u4e2d\uff0c\u6211\u5728\u8fdb\u884c\u76ee\u6807\u7f16\u7801\u65f6\u5e76\u6ca1\u6709\u5220\u9664\u5206\u7c7b\u5217\u3002\u6211\u4fdd\u7559\u4e86\u6240\u6709\u7279\u5f81\uff0c\u5e76\u5728\u6b64\u57fa\u7840\u4e0a\u6dfb\u52a0\u4e86\u76ee\u6807\u7f16\u7801\u7279\u5f81\u3002\u6b64\u5916\uff0c\u6211\u8fd8\u4f7f\u7528\u4e86\u5e73\u5747\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5e73\u5747\u503c\u3001\u4e2d\u4f4d\u6570\u3001\u6807\u51c6\u504f\u5dee\u6216\u76ee\u6807\u7684\u4efb\u4f55\u5176\u4ed6\u51fd\u6570\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u7ed3\u679c\u3002 Fold = 0 , AUC = 0.9332240662017529 Fold = 1 , AUC = 0.9363551625140347 Fold = 2 , AUC = 0.9375013544556173 Fold = 3 , AUC = 0.92237621307625 Fold = 4 , AUC = 0.9292131180445478 
\u4e0d\u9519\uff01\u770b\u6765\u6211\u4eec\u53c8\u6709\u8fdb\u6b65\u4e86\u3002\u4e0d\u8fc7\uff0c\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u5b83\u592a\u5bb9\u6613\u51fa\u73b0\u8fc7\u5ea6\u62df\u5408\u3002\u5f53\u6211\u4eec\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\uff0c\u6700\u597d\u4f7f\u7528\u67d0\u79cd\u5e73\u6ed1\u65b9\u6cd5\u6216\u5728\u7f16\u7801\u503c\u4e2d\u6dfb\u52a0\u566a\u58f0\u3002 Scikit-learn \u7684\u8d21\u732e\u5e93\u4e2d\u6709\u5e26\u5e73\u6ed1\u7684\u76ee\u6807\u7f16\u7801\uff0c\u4f60\u4e5f\u53ef\u4ee5\u521b\u5efa\u81ea\u5df1\u7684\u5e73\u6ed1\u3002\u5e73\u6ed1\u4f1a\u5f15\u5165\u67d0\u79cd\u6b63\u5219\u5316\uff0c\u6709\u52a9\u4e8e\u907f\u514d\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u8fd9\u5e76\u4e0d\u96be\u3002 \u5904\u7406\u5206\u7c7b\u7279\u5f81\u662f\u4e00\u9879\u590d\u6742\u7684\u4efb\u52a1\u3002\u8bb8\u591a\u8d44\u6e90\u4e2d\u90fd\u6709\u5927\u91cf\u4fe1\u606f\u3002\u672c\u7ae0\u5e94\u8be5\u80fd\u5e2e\u52a9\u4f60\u5f00\u59cb\u89e3\u51b3\u5206\u7c7b\u53d8\u91cf\u7684\u4efb\u4f55\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u5bf9\u4e8e\u5927\u591a\u6570\u95ee\u9898\u6765\u8bf4\uff0c\u9664\u4e86\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u4e4b\u5916\uff0c\u4f60\u4e0d\u9700\u8981\u66f4\u591a\u7684\u4e1c\u897f\u3002 \u8981\u8fdb\u4e00\u6b65\u6539\u8fdb\u6a21\u578b\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u66f4\u591a\uff01 \u5728\u672c\u7ae0\u7684\u6700\u540e\uff0c\u6211\u4eec\u4e0d\u80fd\u4e0d\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e00\u79cd\u79f0\u4e3a \u5b9e\u4f53\u5d4c\u5165 \u7684\u6280\u672f\u3002\u5728\u5b9e\u4f53\u5d4c\u5165\u4e2d\uff0c\u7c7b\u522b\u7528\u5411\u91cf\u8868\u793a\u3002\u5728\u4e8c\u503c\u5316\u548c\u72ec\u70ed\u7f16\u7801\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u90fd\u662f\u7528\u5411\u91cf\u6765\u8868\u793a\u7c7b\u522b\u7684\u3002 
\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u6709\u6570\u4ee5\u4e07\u8ba1\u7684\u7c7b\u522b\u600e\u4e48\u529e\uff1f\u8fd9\u5c06\u4f1a\u4ea7\u751f\u5de8\u5927\u7684\u77e9\u9635\uff0c\u6211\u4eec\u5c06\u9700\u8981\u5f88\u957f\u65f6\u95f4\u6765\u8bad\u7ec3\u590d\u6742\u7684\u6a21\u578b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u7528\u5e26\u6709\u6d6e\u70b9\u503c\u7684\u5411\u91cf\u6765\u8868\u793a\u5b83\u4eec\u3002 \u8fd9\u4e2a\u60f3\u6cd5\u975e\u5e38\u7b80\u5355\u3002\u6bcf\u4e2a\u5206\u7c7b\u7279\u5f81\u90fd\u6709\u4e00\u4e2a\u5d4c\u5165\u5c42\u3002\u56e0\u6b64\uff0c\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u73b0\u5728\u90fd\u53ef\u4ee5\u6620\u5c04\u5230\u4e00\u4e2a\u5d4c\u5165\u5c42\uff08\u5c31\u50cf\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u5c06\u5355\u8bcd\u6620\u5c04\u5230\u5d4c\u5165\u5c42\u4e00\u6837\uff09\u3002\u7136\u540e\uff0c\u6839\u636e\u5176\u7ef4\u5ea6\u91cd\u5851\u8fd9\u4e9b\u5d4c\u5165\u5c42\uff0c\u4f7f\u5176\u6241\u5e73\u5316\uff0c\u7136\u540e\u5c06\u6240\u6709\u6241\u5e73\u5316\u7684\u8f93\u5165\u5d4c\u5165\u5c42\u8fde\u63a5\u8d77\u6765\u3002\u7136\u540e\u6dfb\u52a0\u4e00\u5806\u5bc6\u96c6\u5c42\u548c\u4e00\u4e2a\u8f93\u51fa\u5c42\uff0c\u5c31\u5927\u529f\u544a\u6210\u4e86\u3002 \u56fe 6\uff1a\u7c7b\u522b\u8f6c\u6362\u4e3a\u6d6e\u70b9\u6216\u5d4c\u5165\u5411\u91cf \u51fa\u4e8e\u67d0\u79cd\u539f\u56e0\uff0c\u6211\u53d1\u73b0\u4f7f\u7528 TF/Keras \u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528 TF/Keras \u5b9e\u73b0\u5b83\u3002\u6b64\u5916\uff0c\u8fd9\u662f\u672c\u4e66\u4e2d\u552f\u4e00\u4e00\u4e2a\u4f7f\u7528 TF/Keras \u7684\u793a\u4f8b\uff0c\u5c06\u5176\u8f6c\u6362\u4e3a PyTorch\uff08\u4f7f\u7528 cat-in-the-dat-ii \u6570\u636e\u96c6\uff09\u4e5f\u975e\u5e38\u5bb9\u6613 import os import gc import joblib import pandas as pd import numpy as np from sklearn import metrics , preprocessing from tensorflow.keras import layers from 
tensorflow.keras import optimizers from tensorflow.keras.models import Model , load_model from tensorflow.keras import callbacks from tensorflow.keras import backend as K from tensorflow.keras import utils def create_model ( data , catcols ): # \u521b\u5efa\u7a7a\u7684\u8f93\u5165\u5217\u8868\u548c\u8f93\u51fa\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6a21\u578b\u7684\u8f93\u5165\u548c\u8f93\u51fa inputs = [] outputs = [] # \u904d\u5386\u5206\u7c7b\u7279\u5f81\u5217\u8868\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81 for c in catcols : # \u8ba1\u7b97\u7279\u5f81\u4e2d\u552f\u4e00\u503c\u7684\u6570\u91cf num_unique_values = int ( data [ c ] . nunique ()) # \u8ba1\u7b97\u5d4c\u5165\u7ef4\u5ea6\uff0c\u6700\u5927\u4e0d\u8d85\u8fc750 embed_dim = int ( min ( np . ceil (( num_unique_values ) / 2 ), 50 )) # \u521b\u5efa\u6a21\u578b\u7684\u8f93\u5165\u5c42\uff0c\u6bcf\u4e2a\u7279\u5f81\u5bf9\u5e94\u4e00\u4e2a\u8f93\u5165 inp = layers . Input ( shape = ( 1 ,)) # \u521b\u5efa\u5d4c\u5165\u5c42\uff0c\u5c06\u5206\u7c7b\u7279\u5f81\u6620\u5c04\u5230\u4f4e\u7ef4\u5ea6\u7684\u8fde\u7eed\u5411\u91cf out = layers . Embedding ( num_unique_values + 1 , embed_dim , name = c )( inp ) # \u5bf9\u5d4c\u5165\u5c42\u8fdb\u884c\u7a7a\u95f4\u4e22\u5f03\uff08Dropout\uff09 out = layers . SpatialDropout1D ( 0.3 )( out ) # \u5c06\u5d4c\u5165\u5c42\u7684\u5f62\u72b6\u91cd\u65b0\u8c03\u6574\u4e3a\u4e00\u7ef4 out = layers . Reshape ( target_shape = ( embed_dim ,))( out ) # \u5c06\u8f93\u5165\u548c\u8f93\u51fa\u6dfb\u52a0\u5230\u5bf9\u5e94\u7684\u5217\u8868\u4e2d inputs . append ( inp ) outputs . append ( out ) # \u4f7f\u7528Concatenate\u5c42\u5c06\u6240\u6709\u7684\u5d4c\u5165\u5c42\u8f93\u51fa\u8fde\u63a5\u5728\u4e00\u8d77 x = layers . Concatenate ()( outputs ) # \u5bf9\u8fde\u63a5\u540e\u7684\u6570\u636e\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . 
BatchNormalization ()( x ) # \u6dfb\u52a0\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u6dfb\u52a0\u53e6\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u8f93\u51fa\u5c42\uff0c\u5177\u67092\u4e2a\u795e\u7ecf\u5143\uff08\u7528\u4e8e\u4e8c\u8fdb\u5236\u5206\u7c7b\uff09\uff0c\u5e76\u4f7f\u7528softmax\u6fc0\u6d3b\u51fd\u6570 y = layers . Dense ( 2 , activation = \"softmax\" )( x ) # \u521b\u5efa\u6a21\u578b\uff0c\u5c06\u8f93\u5165\u548c\u8f93\u51fa\u4f20\u9012\u7ed9Model\u6784\u9020\u51fd\u6570 model = Model ( inputs = inputs , outputs = y ) # \u7f16\u8bd1\u6a21\u578b\uff0c\u6307\u5b9a\u635f\u5931\u51fd\u6570\u548c\u4f18\u5316\u5668 model . compile ( loss = 'binary_crossentropy' , optimizer = 'adam' ) # \u8fd4\u56de\u521b\u5efa\u7684\u6a21\u578b return model def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for feat in features : lbl_enc = preprocessing . LabelEncoder () df . loc [:, feat ] = lbl_enc . fit_transform ( df [ feat ] . values ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . 
reset_index ( drop = True ) model = create_model ( df , features ) xtrain = [ df_train [ features ] . values [:, k ] for k in range ( len ( features ))] xvalid = [ df_valid [ features ] . values [:, k ] for k in range ( len ( features )) ] ytrain = df_train . target . values yvalid = df_valid . target . values ytrain_cat = utils . to_categorical ( ytrain ) yvalid_cat = utils . to_categorical ( yvalid ) model . fit ( xtrain , ytrain_cat , validation_data = ( xvalid , yvalid_cat ), verbose = 1 , batch_size = 1024 , epochs = 3 ) valid_preds = model . predict ( xvalid )[:, 1 ] print ( metrics . roc_auc_score ( yvalid , valid_preds )) K . clear_session () if __name__ == \"__main__\" : run ( 0 ) run ( 1 ) run ( 2 ) run ( 3 ) run ( 4 ) \u4f60\u4f1a\u53d1\u73b0\u8fd9\u79cd\u65b9\u6cd5\u6548\u679c\u6700\u597d\uff0c\u800c\u4e14\u5982\u679c\u4f60\u6709 GPU\uff0c\u901f\u5ea6\u4e5f\u8d85\u5feb\uff01\u8fd9\u79cd\u65b9\u6cd5\u8fd8\u53ef\u4ee5\u8fdb\u4e00\u6b65\u6539\u8fdb\uff0c\u800c\u4e14\u4f60\u65e0\u9700\u62c5\u5fc3\u7279\u5f81\u5de5\u7a0b\uff0c\u56e0\u4e3a\u795e\u7ecf\u7f51\u7edc\u4f1a\u81ea\u884c\u5904\u7406\u3002\u5728\u5904\u7406\u5927\u91cf\u5206\u7c7b\u7279\u5f81\u6570\u636e\u96c6\u65f6\uff0c\u8fd9\u7edd\u5bf9\u503c\u5f97\u4e00\u8bd5\u3002\u5f53\u5d4c\u5165\u5927\u5c0f\u4e0e\u552f\u4e00\u7c7b\u522b\u7684\u6570\u91cf\u76f8\u540c\u65f6\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u72ec\u70ed\u7f16\u7801\uff08one-hot-encoding\uff09\u3002 \u672c\u7ae0\u57fa\u672c\u4e0a\u90fd\u662f\u5173\u4e8e\u7279\u5f81\u5de5\u7a0b\u7684\u3002\u8ba9\u6211\u4eec\u5728\u4e0b\u4e00\u7ae0\u4e2d\u770b\u770b\u5982\u4f55\u5728\u6570\u5b57\u7279\u5f81\u548c\u4e0d\u540c\u7c7b\u578b\u7279\u5f81\u7684\u7ec4\u5408\u65b9\u9762\u8fdb\u884c\u66f4\u591a\u7684\u7279\u5f81\u5de5\u7a0b\u3002","title":"\u5904\u7406\u5206\u7c7b\u53d8\u91cf"},{"location":"%E6%97%A0%E7%9B%91%E7%9D%A3%E5%92%8C%E6%9C%89%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0/","text":"\u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60 
\u5728\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u65f6\uff0c\u901a\u5e38\u6709\u4e24\u7c7b\u6570\u636e\uff08\u548c\u673a\u5668\u5b66\u4e60\u6a21\u578b\uff09\uff1a \u76d1\u7763\u6570\u636e\uff1a\u603b\u662f\u6709\u4e00\u4e2a\u6216\u591a\u4e2a\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807 \u65e0\u76d1\u7763\u6570\u636e\uff1a\u6ca1\u6709\u4efb\u4f55\u76ee\u6807\u53d8\u91cf\u3002 \u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\u3002\u6211\u4eec\u9700\u8981\u9884\u6d4b\u4e00\u4e2a\u503c\u7684\u95ee\u9898\u88ab\u79f0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u3002\u4f8b\u5982\uff0c\u5982\u679c\u95ee\u9898\u662f\u6839\u636e\u5386\u53f2\u623f\u4ef7\u9884\u6d4b\u623f\u4ef7\uff0c\u90a3\u4e48\u533b\u9662\u3001\u5b66\u6821\u6216\u8d85\u5e02\u7684\u5b58\u5728\uff0c\u4e0e\u6700\u8fd1\u516c\u5171\u4ea4\u901a\u7684\u8ddd\u79bb\u7b49\u7279\u5f81\u5c31\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u540c\u6837\uff0c\u5f53\u6211\u4eec\u5f97\u5230\u732b\u548c\u72d7\u7684\u56fe\u50cf\u65f6\uff0c\u6211\u4eec\u4e8b\u5148\u77e5\u9053\u54ea\u4e9b\u662f\u732b\uff0c\u54ea\u4e9b\u662f\u72d7\uff0c\u5982\u679c\u4efb\u52a1\u662f\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u6240\u63d0\u4f9b\u7684\u56fe\u50cf\u662f\u732b\u8fd8\u662f\u72d7\uff0c\u90a3\u4e48\u8fd9\u4e2a\u95ee\u9898\u5c31\u88ab\u8ba4\u4e3a\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002 \u56fe 1\uff1a\u6709\u76d1\u7763\u5b66\u4e60\u6570\u636e \u5982\u56fe 1 
\u6240\u793a\uff0c\u6570\u636e\u7684\u6bcf\u4e00\u884c\u90fd\u4e0e\u4e00\u4e2a\u76ee\u6807\u6216\u6807\u7b7e\u76f8\u5173\u8054\u3002\u5217\u662f\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u884c\u4ee3\u8868\u4e0d\u540c\u7684\u6570\u636e\u70b9\uff0c\u901a\u5e38\u79f0\u4e3a\u6837\u672c\u3002\u793a\u4f8b\u4e2d\u7684\u5341\u4e2a\u6837\u672c\u6709\u5341\u4e2a\u7279\u5f81\u548c\u4e00\u4e2a\u76ee\u6807\u53d8\u91cf\uff0c\u76ee\u6807\u53d8\u91cf\u53ef\u4ee5\u662f\u6570\u5b57\u6216\u7c7b\u522b\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5206\u7c7b\u53d8\u91cf\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5b9e\u6570\uff0c\u95ee\u9898\u5c31\u88ab\u5b9a\u4e49\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6709\u76d1\u7763\u95ee\u9898\u53ef\u5206\u4e3a\u4e24\u4e2a\u5b50\u7c7b\uff1a \u5206\u7c7b\uff1a\u9884\u6d4b\u7c7b\u522b\uff0c\u5982\u732b\u6216\u72d7 \u56de\u5f52\uff1a\u9884\u6d4b\u503c\uff0c\u5982\u623f\u4ef7 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u6709\u65f6\u6211\u4eec\u53ef\u80fd\u4f1a\u5728\u5206\u7c7b\u8bbe\u7f6e\u4e2d\u4f7f\u7528\u56de\u5f52\uff0c\u8fd9\u53d6\u51b3\u4e8e\u7528\u4e8e\u8bc4\u4f30\u7684\u6307\u6807\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u7a0d\u540e\u4f1a\u8ba8\u8bba\u8fd9\u4e2a\u95ee\u9898\u3002 \u53e6\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u95ee\u9898\u662f\u65e0\u76d1\u7763\u7c7b\u578b\u3002 \u65e0\u76d1\u7763 \u6570\u636e\u96c6\u6ca1\u6709\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\uff0c\u4e00\u822c\u6765\u8bf4\uff0c\u4e0e\u6709\u76d1\u7763\u95ee\u9898\u76f8\u6bd4\uff0c\u5904\u7406\u65e0\u76d1\u7763\u6570\u636e\u96c6\u66f4\u5177\u6311\u6218\u6027\u3002 
\u5047\u8bbe\u4f60\u5728\u4e00\u5bb6\u5904\u7406\u4fe1\u7528\u5361\u4ea4\u6613\u7684\u91d1\u878d\u516c\u53f8\u5de5\u4f5c\u3002\u6bcf\u79d2\u949f\u90fd\u6709\u5927\u91cf\u6570\u636e\u6d8c\u5165\u3002\u552f\u4e00\u7684\u95ee\u9898\u662f\uff0c\u5f88\u96be\u627e\u5230\u4e00\u4e2a\u4eba\u6765\u5c06\u6bcf\u7b14\u4ea4\u6613\u6807\u8bb0\u4e3a\u6709\u6548\u4ea4\u6613\u3001\u771f\u5b9e\u4ea4\u6613\u6216\u6b3a\u8bc8\u4ea4\u6613\u3002\u5f53\u6211\u4eec\u6ca1\u6709\u4efb\u4f55\u5173\u4e8e\u4ea4\u6613\u662f\u6b3a\u8bc8\u8fd8\u662f\u771f\u5b9e\u7684\u4fe1\u606f\u65f6\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u65e0\u76d1\u7763\u95ee\u9898\u3002\u8981\u89e3\u51b3\u8fd9\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5fc5\u987b\u8003\u8651\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a\u591a\u5c11\u4e2a \u805a\u7c7b \u3002\u805a\u7c7b\u662f\u89e3\u51b3\u6b64\u7c7b\u95ee\u9898\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd8\u6709\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u53ef\u4ee5\u5e94\u7528\u4e8e\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5bf9\u4e8e\u6b3a\u8bc8\u68c0\u6d4b\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u8bf4\u6570\u636e\u53ef\u4ee5\u5206\u4e3a\u4e24\u7c7b\uff08\u6b3a\u8bc8\u6216\u771f\u5b9e\uff09\u3002 \u5f53\u6211\u4eec\u77e5\u9053\u805a\u7c7b\u7684\u6570\u91cf\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u805a\u7c7b\u7b97\u6cd5\u6765\u89e3\u51b3\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5728\u56fe 2 \u4e2d\uff0c\u5047\u8bbe\u6570\u636e\u5206\u4e3a\u4e24\u7c7b\uff0c\u6df1\u8272\u4ee3\u8868\u6b3a\u8bc8\uff0c\u6d45\u8272\u4ee3\u8868\u771f\u5b9e\u4ea4\u6613\u3002\u7136\u800c\uff0c\u5728\u4f7f\u7528\u805a\u7c7b\u65b9\u6cd5\u4e4b\u524d\uff0c\u6211\u4eec\u5e76\u4e0d\u77e5\u9053\u8fd9\u4e9b\u7c7b\u522b\u3002\u5e94\u7528\u805a\u7c7b\u7b97\u6cd5\u540e\uff0c\u6211\u4eec\u5e94\u8be5\u80fd\u591f\u533a\u5206\u8fd9\u4e24\u4e2a\u5047\u5b9a\u76ee\u6807\u3002 
\u4e3a\u4e86\u7406\u89e3\u65e0\u76d1\u7763\u95ee\u9898\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u8bb8\u591a\u5206\u89e3\u6280\u672f\uff0c\u5982 \u4e3b\u6210\u5206\u5206\u6790\uff08PCA\uff09\u3001t-\u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09 \u7b49\u3002 \u6709\u76d1\u7763\u7684\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\uff0c\u56e0\u4e3a\u5b83\u4eec\u5f88\u5bb9\u6613\u8bc4\u4f30\u3002\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u8bc4\u4f30\u6280\u672f\u3002\u7136\u800c\uff0c\u5bf9\u65e0\u76d1\u7763\u7b97\u6cd5\u7684\u7ed3\u679c\u8fdb\u884c\u8bc4\u4f30\u5177\u6709\u6311\u6218\u6027\uff0c\u9700\u8981\u5927\u91cf\u7684\u4eba\u4e3a\u5e72\u9884\u6216\u542f\u53d1\u5f0f\u65b9\u6cd5\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4e3b\u8981\u5173\u6ce8\u6709\u76d1\u7763\u6570\u636e\u548c\u6a21\u578b\uff0c\u4f46\u8fd9\u5e76\u4e0d\u610f\u5473\u7740\u6211\u4eec\u4f1a\u5ffd\u7565\u65e0\u76d1\u7763\u6570\u636e\u95ee\u9898\u3002 \u56fe 2\uff1a\u65e0\u76d1\u7763\u5b66\u4e60\u6570\u636e\u96c6 
\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u5f53\u4eba\u4eec\u5f00\u59cb\u5b66\u4e60\u6570\u636e\u79d1\u5b66\u6216\u673a\u5668\u5b66\u4e60\u65f6\uff0c\u90fd\u4f1a\u4ece\u975e\u5e38\u8457\u540d\u7684\u6570\u636e\u96c6\u5f00\u59cb\uff0c\u4f8b\u5982\u6cf0\u5766\u5c3c\u514b\u6570\u636e\u96c6\u6216\u8679\u819c\u6570\u636e\u96c6\uff0c\u8fd9\u4e9b\u90fd\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u5728\u6cf0\u5766\u5c3c\u514b\u53f7\u6570\u636e\u96c6\u4e2d\uff0c\u4f60\u5fc5\u987b\u6839\u636e\u8239\u7968\u7b49\u7ea7\u3001\u6027\u522b\u3001\u5e74\u9f84\u7b49\u56e0\u7d20\u9884\u6d4b\u6cf0\u5766\u5c3c\u514b\u53f7\u4e0a\u4e58\u5ba2\u7684\u5b58\u6d3b\u7387\u3002\u540c\u6837\uff0c\u5728\u9e22\u5c3e\u82b1\u6570\u636e\u96c6\u4e2d\uff0c\u60a8\u5fc5\u987b\u6839\u636e\u843c\u7247\u5bbd\u5ea6\u3001\u82b1\u74e3\u957f\u5ea6\u3001\u843c\u7247\u957f\u5ea6\u548c\u82b1\u74e3\u5bbd\u5ea6\u7b49\u56e0\u7d20\u9884\u6d4b\u82b1\u7684\u79cd\u7c7b\u3002 \u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u80fd\u5305\u62ec\u7528\u4e8e\u5ba2\u6237\u7ec6\u5206\u7684\u6570\u636e\u96c6\u3002 \u4f8b\u5982\uff0c\u60a8\u62e5\u6709\u8bbf\u95ee\u60a8\u7684\u7535\u5b50\u5546\u52a1\u7f51\u7ad9\u7684\u5ba2\u6237\u6570\u636e\uff0c\u6216\u8005\u8bbf\u95ee\u5546\u5e97\u6216\u5546\u573a\u7684\u5ba2\u6237\u6570\u636e\uff0c\u800c\u60a8\u5e0c\u671b\u5c06\u5b83\u4eec\u7ec6\u5206\u6216\u805a\u7c7b\u4e3a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u65e0\u76d1\u7763\u6570\u636e\u96c6\u7684\u53e6\u4e00\u4e2a\u4f8b\u5b50\u53ef\u80fd\u5305\u62ec\u4fe1\u7528\u5361\u6b3a\u8bc8\u68c0\u6d4b\u6216\u5bf9\u51e0\u5f20\u56fe\u7247\u8fdb\u884c\u805a\u7c7b\u7b49\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd8\u53ef\u4ee5\u5c06\u6709\u76d1\u7763\u6570\u636e\u96c6\u8f6c\u6362\u4e3a\u65e0\u76d1\u7763\u6570\u636e\u96c6\uff0c\u4ee5\u67e5\u770b\u5b83\u4eec\u5728\u7ed8\u5236\u65f6\u7684\u6548\u679c\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u56fe 3 \u4e2d\u7684\u6570\u636e\u96c6\u3002\u56fe 3 \u663e\u793a\u7684\u662f MNIST 
\u6570\u636e\u96c6\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u6d41\u884c\u7684\u624b\u5199\u6570\u5b57\u6570\u636e\u96c6\uff0c\u5b83\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\uff0c\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u4f60\u4f1a\u5f97\u5230\u6570\u5b57\u56fe\u50cf\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u6b63\u786e\u6807\u7b7e\u3002\u4f60\u5fc5\u987b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u53ea\u63d0\u4f9b\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\u8bc6\u522b\u51fa\u54ea\u4e2a\u6570\u5b57\u662f\u5b83\u3002 \u56fe 3\uff1aMNIST\u6570\u636e\u96c6 \u5982\u679c\u6211\u4eec\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c t \u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09\u5206\u89e3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u53ea\u9700\u5728\u56fe\u50cf\u50cf\u7d20\u4e0a\u964d\u7ef4\u81f3 2 \u4e2a\u7ef4\u5ea6\uff0c\u5c31\u80fd\u5728\u4e00\u5b9a\u7a0b\u5ea6\u4e0a\u5206\u79bb\u56fe\u50cf\u3002\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1aMNIST \u6570\u636e\u96c6\u7684 t-SNE \u53ef\u89c6\u5316\u3002\u4f7f\u7528\u4e86 3000 \u5e45\u56fe\u50cf\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002\u9996\u5148\u662f\u5bfc\u5165\u6240\u6709\u9700\u8981\u7684\u5e93\u3002 import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns from sklearn import datasets from sklearn import manifold % matplotlib inline \u6211\u4eec\u4f7f\u7528 matplotlib \u548c seaborn \u8fdb\u884c\u7ed8\u56fe\uff0c\u4f7f\u7528 numpy \u5904\u7406\u6570\u503c\u6570\u7ec4\uff0c\u4f7f\u7528 pandas \u4ece\u6570\u503c\u6570\u7ec4\u521b\u5efa\u6570\u636e\u5e27\uff0c\u4f7f\u7528 scikit-learn (sklearn) \u83b7\u53d6\u6570\u636e\u5e76\u6267\u884c t-SNE\u3002 \u5bfc\u5165\u540e\uff0c\u6211\u4eec\u9700\u8981\u4e0b\u8f7d\u6570\u636e\u5e76\u5355\u72ec\u8bfb\u53d6\uff0c\u6216\u8005\u4f7f\u7528 sklearn \u7684\u5185\u7f6e\u51fd\u6570\u6765\u63d0\u4f9b MNIST \u6570\u636e\u96c6\u3002 data = datasets . 
fetch_openml ( 'mnist_784' , version = 1 , return_X_y = True ) pixel_values , targets = data targets = targets . astype ( int ) \u5728\u8fd9\u90e8\u5206\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528 sklearn \u6570\u636e\u96c6\u83b7\u53d6\u4e86\u6570\u636e\uff0c\u5e76\u83b7\u5f97\u4e86\u4e00\u4e2a\u50cf\u7d20\u503c\u6570\u7ec4\u548c\u53e6\u4e00\u4e2a\u76ee\u6807\u6570\u7ec4\u3002\u7531\u4e8e\u76ee\u6807\u662f\u5b57\u7b26\u4e32\u7c7b\u578b\uff0c\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u3002 pixel_values \u662f\u4e00\u4e2a\u5f62\u72b6\u4e3a 70000x784 \u7684\u4e8c\u7ef4\u6570\u7ec4\u3002 \u5171\u6709 70000 \u5f20\u4e0d\u540c\u7684\u56fe\u50cf\uff0c\u6bcf\u5f20\u56fe\u50cf\u5927\u5c0f\u4e3a 28x28 \u50cf\u7d20\u3002\u5e73\u94fa 28x28 \u540e\u5f97\u5230 784 \u4e2a\u6570\u636e\u70b9\u3002 \u6211\u4eec\u53ef\u4ee5\u5c06\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u91cd\u5851\u4e3a\u539f\u6765\u7684\u5f62\u72b6\uff0c\u7136\u540e\u4f7f\u7528 matplotlib \u7ed8\u5236\u6210\u56fe\u8868\uff0c\u4ece\u800c\u5c06\u5176\u53ef\u89c6\u5316\u3002 single_image = pixel_values [ 1 , :] . reshape ( 28 , 28 ) plt . imshow ( single_image , cmap = 'gray' ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u7ed8\u5236\u5982\u4e0b\u56fe\u50cf\uff1a \u56fe 5\uff1a\u7ed8\u5236MNIST\u6570\u636e\u96c6\u5355\u5f20\u56fe\u7247 \u6700\u91cd\u8981\u7684\u4e00\u6b65\u662f\u5728\u6211\u4eec\u83b7\u53d6\u6570\u636e\u4e4b\u540e\u3002 tsne = manifold . TSNE ( n_components = 2 , random_state = 42 ) transformed_data = tsne . 
fit_transform ( pixel_values [: 3000 , :]) \u8fd9\u4e00\u6b65\u521b\u5efa\u4e86\u6570\u636e\u7684 t-SNE \u53d8\u6362\u3002\u6211\u4eec\u53ea\u4f7f\u7528 2 \u4e2a\u7ef4\u5ea6\uff0c\u56e0\u4e3a\u5728\u4e8c\u7ef4\u73af\u5883\u4e2d\u53ef\u4ee5\u5f88\u597d\u5730\u5c06\u5b83\u4eec\u53ef\u89c6\u5316\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u8f6c\u6362\u540e\u7684\u6570\u636e\u662f\u4e00\u4e2a 3000x2 \u5f62\u72b6\u7684\u6570\u7ec4\uff083000 \u884c 2 \u5217\uff09\u3002\u5728\u6570\u7ec4\u4e0a\u8c03\u7528 pd.DataFrame \u53ef\u4ee5\u5c06\u8fd9\u6837\u7684\u6570\u636e\u8f6c\u6362\u4e3a pandas \u6570\u636e\u5e27\u3002 tsne_df = pd . DataFrame ( np . column_stack (( transformed_data , targets [: 3000 ])), columns = [ \"x\" , \"y\" , \"targets\" ]) tsne_df . loc [:, \"targets\" ] = tsne_df . targets . astype ( int ) \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a numpy \u6570\u7ec4\u521b\u5efa\u4e00\u4e2a pandas \u6570\u636e\u5e27\u3002x \u548c y \u662f t-SNE \u5206\u89e3\u7684\u4e24\u4e2a\u7ef4\u5ea6\uff0ctarget \u662f\u5b9e\u9645\u6570\u5b57\u3002\u8fd9\u6837\u6211\u4eec\u5c31\u5f97\u5230\u4e86\u5982\u56fe 6 \u6240\u793a\u7684\u6570\u636e\u5e27\u3002 \u56fe 6\uff1at-SNE\u540e\u6570\u636e\u524d10\u884c \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 seaborn \u548c matplotlib \u7ed8\u5236\u5b83\u3002 grid = sns . FacetGrid ( tsne_df , hue = \"targets\" , size = 8 ) grid . map ( plt . scatter , \"x\" , \"y\" ) . 
add_legend () \u8fd9\u662f\u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u89c6\u5316\u7684\u4e00\u79cd\u65b9\u6cd5\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5728\u540c\u4e00\u6570\u636e\u96c6\u4e0a\u8fdb\u884c k-means \u805a\u7c7b\uff0c\u770b\u770b\u5b83\u5728\u65e0\u76d1\u7763\u73af\u5883\u4e0b\u7684\u8868\u73b0\u5982\u4f55\u3002\u4e00\u4e2a\u7ecf\u5e38\u51fa\u73b0\u7684\u95ee\u9898\u662f\uff0c\u5982\u4f55\u5728 k-means \u805a\u7c7b\u4e2d\u627e\u5230\u6700\u4f73\u7684\u7c07\u6570\u3002\u8fd9\u4e2a\u95ee\u9898\u6ca1\u6709\u6b63\u786e\u7b54\u6848\u3002\u4f60\u5fc5\u987b\u901a\u8fc7\u4ea4\u53c9\u9a8c\u8bc1\u6765\u627e\u5230\u6700\u4f73\u7c07\u6570\u3002\u672c\u4e66\u7a0d\u540e\u5c06\u8ba8\u8bba\u4ea4\u53c9\u9a8c\u8bc1\u3002\u8bf7\u6ce8\u610f\uff0c\u4e0a\u8ff0\u4ee3\u7801\u662f\u5728 jupyter \u7b14\u8bb0\u672c\u4e2d\u8fd0\u884c\u7684\u3002 \u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 jupyter \u505a\u4e00\u4e9b\u7b80\u5355\u7684\u4e8b\u60c5\uff0c\u6bd4\u5982\u4e0a\u9762\u7684\u4f8b\u5b50\u548c \u7ed8\u56fe\u3002\u5bf9\u4e8e\u672c\u4e66\u4e2d\u7684\u5927\u90e8\u5206\u5185\u5bb9\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 python \u811a\u672c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5176\u4ed6 IDE \u56e0\u4e3a\u7ed3\u679c\u90fd\u662f\u4e00\u6837\u7684\u3002 MNIST 
\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u628a\u5b83\u8f6c\u6362\u6210\u4e00\u4e2a\u65e0\u76d1\u7763\u7684\u95ee\u9898\uff0c\u53ea\u662f\u4e3a\u4e86\u68c0\u67e5\u5b83\u662f\u5426\u80fd\u5e26\u6765\u4efb\u4f55\u597d\u7684\u7ed3\u679c\u3002\u5982\u679c\u6211\u4eec\u4f7f\u7528\u5206\u7c7b\u7b97\u6cd5\uff0c\u6548\u679c\u4f1a\u66f4\u597d\u3002\u8ba9\u6211\u4eec\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u4e00\u63a2\u7a76\u7adf\u3002","title":"\u6709\u76d1\u7763\u548c\u65e0\u76d1\u7763\u5b66\u4e60"},{"location":"%E6%97%A0%E7%9B%91%E7%9D%A3%E5%92%8C%E6%9C%89%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0/#_1","text":"\u5728\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u65f6\uff0c\u901a\u5e38\u6709\u4e24\u7c7b\u6570\u636e\uff08\u548c\u673a\u5668\u5b66\u4e60\u6a21\u578b\uff09\uff1a \u76d1\u7763\u6570\u636e\uff1a\u603b\u662f\u6709\u4e00\u4e2a\u6216\u591a\u4e2a\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807 \u65e0\u76d1\u7763\u6570\u636e\uff1a\u6ca1\u6709\u4efb\u4f55\u76ee\u6807\u53d8\u91cf\u3002 
\u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\u3002\u6211\u4eec\u9700\u8981\u9884\u6d4b\u4e00\u4e2a\u503c\u7684\u95ee\u9898\u88ab\u79f0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u3002\u4f8b\u5982\uff0c\u5982\u679c\u95ee\u9898\u662f\u6839\u636e\u5386\u53f2\u623f\u4ef7\u9884\u6d4b\u623f\u4ef7\uff0c\u90a3\u4e48\u533b\u9662\u3001\u5b66\u6821\u6216\u8d85\u5e02\u7684\u5b58\u5728\uff0c\u4e0e\u6700\u8fd1\u516c\u5171\u4ea4\u901a\u7684\u8ddd\u79bb\u7b49\u7279\u5f81\u5c31\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u540c\u6837\uff0c\u5f53\u6211\u4eec\u5f97\u5230\u732b\u548c\u72d7\u7684\u56fe\u50cf\u65f6\uff0c\u6211\u4eec\u4e8b\u5148\u77e5\u9053\u54ea\u4e9b\u662f\u732b\uff0c\u54ea\u4e9b\u662f\u72d7\uff0c\u5982\u679c\u4efb\u52a1\u662f\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u6240\u63d0\u4f9b\u7684\u56fe\u50cf\u662f\u732b\u8fd8\u662f\u72d7\uff0c\u90a3\u4e48\u8fd9\u4e2a\u95ee\u9898\u5c31\u88ab\u8ba4\u4e3a\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002 \u56fe 1\uff1a\u6709\u76d1\u7763\u5b66\u4e60\u6570\u636e \u5982\u56fe 1 \u6240\u793a\uff0c\u6570\u636e\u7684\u6bcf\u4e00\u884c\u90fd\u4e0e\u4e00\u4e2a\u76ee\u6807\u6216\u6807\u7b7e\u76f8\u5173\u8054\u3002\u5217\u662f\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u884c\u4ee3\u8868\u4e0d\u540c\u7684\u6570\u636e\u70b9\uff0c\u901a\u5e38\u79f0\u4e3a\u6837\u672c\u3002\u793a\u4f8b\u4e2d\u7684\u5341\u4e2a\u6837\u672c\u6709\u5341\u4e2a\u7279\u5f81\u548c\u4e00\u4e2a\u76ee\u6807\u53d8\u91cf\uff0c\u76ee\u6807\u53d8\u91cf\u53ef\u4ee5\u662f\u6570\u5b57\u6216\u7c7b\u522b\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5206\u7c7b\u53d8\u91cf\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5b9e\u6570\uff0c\u95ee\u9898\u5c31\u88ab\u5b9a\u4e49\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6709\u76d1\u7763\u95ee\u9898\u53ef\u5206\u4e3a\u4e24\u4e2a\u5b50\u7c7b\uff1a 
\u5206\u7c7b\uff1a\u9884\u6d4b\u7c7b\u522b\uff0c\u5982\u732b\u6216\u72d7 \u56de\u5f52\uff1a\u9884\u6d4b\u503c\uff0c\u5982\u623f\u4ef7 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u6709\u65f6\u6211\u4eec\u53ef\u80fd\u4f1a\u5728\u5206\u7c7b\u8bbe\u7f6e\u4e2d\u4f7f\u7528\u56de\u5f52\uff0c\u8fd9\u53d6\u51b3\u4e8e\u7528\u4e8e\u8bc4\u4f30\u7684\u6307\u6807\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u7a0d\u540e\u4f1a\u8ba8\u8bba\u8fd9\u4e2a\u95ee\u9898\u3002 \u53e6\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u95ee\u9898\u662f\u65e0\u76d1\u7763\u7c7b\u578b\u3002 \u65e0\u76d1\u7763 \u6570\u636e\u96c6\u6ca1\u6709\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\uff0c\u4e00\u822c\u6765\u8bf4\uff0c\u4e0e\u6709\u76d1\u7763\u95ee\u9898\u76f8\u6bd4\uff0c\u5904\u7406\u65e0\u76d1\u7763\u6570\u636e\u96c6\u66f4\u5177\u6311\u6218\u6027\u3002 \u5047\u8bbe\u4f60\u5728\u4e00\u5bb6\u5904\u7406\u4fe1\u7528\u5361\u4ea4\u6613\u7684\u91d1\u878d\u516c\u53f8\u5de5\u4f5c\u3002\u6bcf\u79d2\u949f\u90fd\u6709\u5927\u91cf\u6570\u636e\u6d8c\u5165\u3002\u552f\u4e00\u7684\u95ee\u9898\u662f\uff0c\u5f88\u96be\u627e\u5230\u4e00\u4e2a\u4eba\u6765\u5c06\u6bcf\u7b14\u4ea4\u6613\u6807\u8bb0\u4e3a\u6709\u6548\u4ea4\u6613\u3001\u771f\u5b9e\u4ea4\u6613\u6216\u6b3a\u8bc8\u4ea4\u6613\u3002\u5f53\u6211\u4eec\u6ca1\u6709\u4efb\u4f55\u5173\u4e8e\u4ea4\u6613\u662f\u6b3a\u8bc8\u8fd8\u662f\u771f\u5b9e\u7684\u4fe1\u606f\u65f6\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u65e0\u76d1\u7763\u95ee\u9898\u3002\u8981\u89e3\u51b3\u8fd9\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5fc5\u987b\u8003\u8651\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a\u591a\u5c11\u4e2a \u805a\u7c7b 
\u3002\u805a\u7c7b\u662f\u89e3\u51b3\u6b64\u7c7b\u95ee\u9898\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd8\u6709\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u53ef\u4ee5\u5e94\u7528\u4e8e\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5bf9\u4e8e\u6b3a\u8bc8\u68c0\u6d4b\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u8bf4\u6570\u636e\u53ef\u4ee5\u5206\u4e3a\u4e24\u7c7b\uff08\u6b3a\u8bc8\u6216\u771f\u5b9e\uff09\u3002 \u5f53\u6211\u4eec\u77e5\u9053\u805a\u7c7b\u7684\u6570\u91cf\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u805a\u7c7b\u7b97\u6cd5\u6765\u89e3\u51b3\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5728\u56fe 2 \u4e2d\uff0c\u5047\u8bbe\u6570\u636e\u5206\u4e3a\u4e24\u7c7b\uff0c\u6df1\u8272\u4ee3\u8868\u6b3a\u8bc8\uff0c\u6d45\u8272\u4ee3\u8868\u771f\u5b9e\u4ea4\u6613\u3002\u7136\u800c\uff0c\u5728\u4f7f\u7528\u805a\u7c7b\u65b9\u6cd5\u4e4b\u524d\uff0c\u6211\u4eec\u5e76\u4e0d\u77e5\u9053\u8fd9\u4e9b\u7c7b\u522b\u3002\u5e94\u7528\u805a\u7c7b\u7b97\u6cd5\u540e\uff0c\u6211\u4eec\u5e94\u8be5\u80fd\u591f\u533a\u5206\u8fd9\u4e24\u4e2a\u5047\u5b9a\u76ee\u6807\u3002 \u4e3a\u4e86\u7406\u89e3\u65e0\u76d1\u7763\u95ee\u9898\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u8bb8\u591a\u5206\u89e3\u6280\u672f\uff0c\u5982 \u4e3b\u6210\u5206\u5206\u6790\uff08PCA\uff09\u3001t-\u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09 \u7b49\u3002 
\u6709\u76d1\u7763\u7684\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\uff0c\u56e0\u4e3a\u5b83\u4eec\u5f88\u5bb9\u6613\u8bc4\u4f30\u3002\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u8bc4\u4f30\u6280\u672f\u3002\u7136\u800c\uff0c\u5bf9\u65e0\u76d1\u7763\u7b97\u6cd5\u7684\u7ed3\u679c\u8fdb\u884c\u8bc4\u4f30\u5177\u6709\u6311\u6218\u6027\uff0c\u9700\u8981\u5927\u91cf\u7684\u4eba\u4e3a\u5e72\u9884\u6216\u542f\u53d1\u5f0f\u65b9\u6cd5\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4e3b\u8981\u5173\u6ce8\u6709\u76d1\u7763\u6570\u636e\u548c\u6a21\u578b\uff0c\u4f46\u8fd9\u5e76\u4e0d\u610f\u5473\u7740\u6211\u4eec\u4f1a\u5ffd\u7565\u65e0\u76d1\u7763\u6570\u636e\u95ee\u9898\u3002 \u56fe 2\uff1a\u65e0\u76d1\u7763\u5b66\u4e60\u6570\u636e\u96c6 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u5f53\u4eba\u4eec\u5f00\u59cb\u5b66\u4e60\u6570\u636e\u79d1\u5b66\u6216\u673a\u5668\u5b66\u4e60\u65f6\uff0c\u90fd\u4f1a\u4ece\u975e\u5e38\u8457\u540d\u7684\u6570\u636e\u96c6\u5f00\u59cb\uff0c\u4f8b\u5982\u6cf0\u5766\u5c3c\u514b\u6570\u636e\u96c6\u6216\u8679\u819c\u6570\u636e\u96c6\uff0c\u8fd9\u4e9b\u90fd\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u5728\u6cf0\u5766\u5c3c\u514b\u53f7\u6570\u636e\u96c6\u4e2d\uff0c\u4f60\u5fc5\u987b\u6839\u636e\u8239\u7968\u7b49\u7ea7\u3001\u6027\u522b\u3001\u5e74\u9f84\u7b49\u56e0\u7d20\u9884\u6d4b\u6cf0\u5766\u5c3c\u514b\u53f7\u4e0a\u4e58\u5ba2\u7684\u5b58\u6d3b\u7387\u3002\u540c\u6837\uff0c\u5728\u9e22\u5c3e\u82b1\u6570\u636e\u96c6\u4e2d\uff0c\u60a8\u5fc5\u987b\u6839\u636e\u843c\u7247\u5bbd\u5ea6\u3001\u82b1\u74e3\u957f\u5ea6\u3001\u843c\u7247\u957f\u5ea6\u548c\u82b1\u74e3\u5bbd\u5ea6\u7b49\u56e0\u7d20\u9884\u6d4b\u82b1\u7684\u79cd\u7c7b\u3002 \u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u80fd\u5305\u62ec\u7528\u4e8e\u5ba2\u6237\u7ec6\u5206\u7684\u6570\u636e\u96c6\u3002 
\u4f8b\u5982\uff0c\u60a8\u62e5\u6709\u8bbf\u95ee\u60a8\u7684\u7535\u5b50\u5546\u52a1\u7f51\u7ad9\u7684\u5ba2\u6237\u6570\u636e\uff0c\u6216\u8005\u8bbf\u95ee\u5546\u5e97\u6216\u5546\u573a\u7684\u5ba2\u6237\u6570\u636e\uff0c\u800c\u60a8\u5e0c\u671b\u5c06\u5b83\u4eec\u7ec6\u5206\u6216\u805a\u7c7b\u4e3a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u65e0\u76d1\u7763\u6570\u636e\u96c6\u7684\u53e6\u4e00\u4e2a\u4f8b\u5b50\u53ef\u80fd\u5305\u62ec\u4fe1\u7528\u5361\u6b3a\u8bc8\u68c0\u6d4b\u6216\u5bf9\u51e0\u5f20\u56fe\u7247\u8fdb\u884c\u805a\u7c7b\u7b49\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd8\u53ef\u4ee5\u5c06\u6709\u76d1\u7763\u6570\u636e\u96c6\u8f6c\u6362\u4e3a\u65e0\u76d1\u7763\u6570\u636e\u96c6\uff0c\u4ee5\u67e5\u770b\u5b83\u4eec\u5728\u7ed8\u5236\u65f6\u7684\u6548\u679c\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u56fe 3 \u4e2d\u7684\u6570\u636e\u96c6\u3002\u56fe 3 \u663e\u793a\u7684\u662f MNIST \u6570\u636e\u96c6\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u6d41\u884c\u7684\u624b\u5199\u6570\u5b57\u6570\u636e\u96c6\uff0c\u5b83\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\uff0c\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u4f60\u4f1a\u5f97\u5230\u6570\u5b57\u56fe\u50cf\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u6b63\u786e\u6807\u7b7e\u3002\u4f60\u5fc5\u987b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u53ea\u63d0\u4f9b\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\u8bc6\u522b\u51fa\u54ea\u4e2a\u6570\u5b57\u662f\u5b83\u3002 \u56fe 3\uff1aMNIST\u6570\u636e\u96c6 \u5982\u679c\u6211\u4eec\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c t \u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09\u5206\u89e3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u53ea\u9700\u5728\u56fe\u50cf\u50cf\u7d20\u4e0a\u964d\u7ef4\u81f3 2 \u4e2a\u7ef4\u5ea6\uff0c\u5c31\u80fd\u5728\u4e00\u5b9a\u7a0b\u5ea6\u4e0a\u5206\u79bb\u56fe\u50cf\u3002\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1aMNIST \u6570\u636e\u96c6\u7684 t-SNE \u53ef\u89c6\u5316\u3002\u4f7f\u7528\u4e86 3000 
\u5e45\u56fe\u50cf\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002\u9996\u5148\u662f\u5bfc\u5165\u6240\u6709\u9700\u8981\u7684\u5e93\u3002 import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns from sklearn import datasets from sklearn import manifold % matplotlib inline \u6211\u4eec\u4f7f\u7528 matplotlib \u548c seaborn \u8fdb\u884c\u7ed8\u56fe\uff0c\u4f7f\u7528 numpy \u5904\u7406\u6570\u503c\u6570\u7ec4\uff0c\u4f7f\u7528 pandas \u4ece\u6570\u503c\u6570\u7ec4\u521b\u5efa\u6570\u636e\u5e27\uff0c\u4f7f\u7528 scikit-learn (sklearn) \u83b7\u53d6\u6570\u636e\u5e76\u6267\u884c t-SNE\u3002 \u5bfc\u5165\u540e\uff0c\u6211\u4eec\u9700\u8981\u4e0b\u8f7d\u6570\u636e\u5e76\u5355\u72ec\u8bfb\u53d6\uff0c\u6216\u8005\u4f7f\u7528 sklearn \u7684\u5185\u7f6e\u51fd\u6570\u6765\u63d0\u4f9b MNIST \u6570\u636e\u96c6\u3002 data = datasets . fetch_openml ( 'mnist_784' , version = 1 , return_X_y = True ) pixel_values , targets = data targets = targets . 
astype ( int ) \u5728\u8fd9\u90e8\u5206\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528 sklearn \u6570\u636e\u96c6\u83b7\u53d6\u4e86\u6570\u636e\uff0c\u5e76\u83b7\u5f97\u4e86\u4e00\u4e2a\u50cf\u7d20\u503c\u6570\u7ec4\u548c\u53e6\u4e00\u4e2a\u76ee\u6807\u6570\u7ec4\u3002\u7531\u4e8e\u76ee\u6807\u662f\u5b57\u7b26\u4e32\u7c7b\u578b\uff0c\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u3002 pixel_values \u662f\u4e00\u4e2a\u5f62\u72b6\u4e3a 70000x784 \u7684\u4e8c\u7ef4\u6570\u7ec4\u3002 \u5171\u6709 70000 \u5f20\u4e0d\u540c\u7684\u56fe\u50cf\uff0c\u6bcf\u5f20\u56fe\u50cf\u5927\u5c0f\u4e3a 28x28 \u50cf\u7d20\u3002\u5e73\u94fa 28x28 \u540e\u5f97\u5230 784 \u4e2a\u6570\u636e\u70b9\u3002 \u6211\u4eec\u53ef\u4ee5\u5c06\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u91cd\u5851\u4e3a\u539f\u6765\u7684\u5f62\u72b6\uff0c\u7136\u540e\u4f7f\u7528 matplotlib \u7ed8\u5236\u6210\u56fe\u8868\uff0c\u4ece\u800c\u5c06\u5176\u53ef\u89c6\u5316\u3002 single_image = pixel_values [ 1 , :] . reshape ( 28 , 28 ) plt . imshow ( single_image , cmap = 'gray' ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u7ed8\u5236\u5982\u4e0b\u56fe\u50cf\uff1a \u56fe 5\uff1a\u7ed8\u5236MNIST\u6570\u636e\u96c6\u5355\u5f20\u56fe\u7247 \u6700\u91cd\u8981\u7684\u4e00\u6b65\u662f\u5728\u6211\u4eec\u83b7\u53d6\u6570\u636e\u4e4b\u540e\u3002 tsne = manifold . TSNE ( n_components = 2 , random_state = 42 ) transformed_data = tsne . 
fit_transform ( pixel_values [: 3000 , :]) \u8fd9\u4e00\u6b65\u521b\u5efa\u4e86\u6570\u636e\u7684 t-SNE \u53d8\u6362\u3002\u6211\u4eec\u53ea\u4f7f\u7528 2 \u4e2a\u7ef4\u5ea6\uff0c\u56e0\u4e3a\u5728\u4e8c\u7ef4\u73af\u5883\u4e2d\u53ef\u4ee5\u5f88\u597d\u5730\u5c06\u5b83\u4eec\u53ef\u89c6\u5316\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u8f6c\u6362\u540e\u7684\u6570\u636e\u662f\u4e00\u4e2a 3000x2 \u5f62\u72b6\u7684\u6570\u7ec4\uff083000 \u884c 2 \u5217\uff09\u3002\u5728\u6570\u7ec4\u4e0a\u8c03\u7528 pd.DataFrame \u53ef\u4ee5\u5c06\u8fd9\u6837\u7684\u6570\u636e\u8f6c\u6362\u4e3a pandas \u6570\u636e\u5e27\u3002 tsne_df = pd . DataFrame ( np . column_stack (( transformed_data , targets [: 3000 ])), columns = [ \"x\" , \"y\" , \"targets\" ]) tsne_df . loc [:, \"targets\" ] = tsne_df . targets . astype ( int ) \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a numpy \u6570\u7ec4\u521b\u5efa\u4e00\u4e2a pandas \u6570\u636e\u5e27\u3002x \u548c y \u662f t-SNE \u5206\u89e3\u7684\u4e24\u4e2a\u7ef4\u5ea6\uff0ctarget \u662f\u5b9e\u9645\u6570\u5b57\u3002\u8fd9\u6837\u6211\u4eec\u5c31\u5f97\u5230\u4e86\u5982\u56fe 6 \u6240\u793a\u7684\u6570\u636e\u5e27\u3002 \u56fe 6\uff1at-SNE\u540e\u6570\u636e\u524d10\u884c \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 seaborn \u548c matplotlib \u7ed8\u5236\u5b83\u3002 grid = sns . FacetGrid ( tsne_df , hue = \"targets\" , size = 8 ) grid . map ( plt . scatter , \"x\" , \"y\" ) . 
add_legend () \u8fd9\u662f\u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u89c6\u5316\u7684\u4e00\u79cd\u65b9\u6cd5\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5728\u540c\u4e00\u6570\u636e\u96c6\u4e0a\u8fdb\u884c k-means \u805a\u7c7b\uff0c\u770b\u770b\u5b83\u5728\u65e0\u76d1\u7763\u73af\u5883\u4e0b\u7684\u8868\u73b0\u5982\u4f55\u3002\u4e00\u4e2a\u7ecf\u5e38\u51fa\u73b0\u7684\u95ee\u9898\u662f\uff0c\u5982\u4f55\u5728 k-means \u805a\u7c7b\u4e2d\u627e\u5230\u6700\u4f73\u7684\u7c07\u6570\u3002\u8fd9\u4e2a\u95ee\u9898\u6ca1\u6709\u6b63\u786e\u7b54\u6848\u3002\u4f60\u5fc5\u987b\u901a\u8fc7\u4ea4\u53c9\u9a8c\u8bc1\u6765\u627e\u5230\u6700\u4f73\u7c07\u6570\u3002\u672c\u4e66\u7a0d\u540e\u5c06\u8ba8\u8bba\u4ea4\u53c9\u9a8c\u8bc1\u3002\u8bf7\u6ce8\u610f\uff0c\u4e0a\u8ff0\u4ee3\u7801\u662f\u5728 jupyter \u7b14\u8bb0\u672c\u4e2d\u8fd0\u884c\u7684\u3002 \u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 jupyter \u505a\u4e00\u4e9b\u7b80\u5355\u7684\u4e8b\u60c5\uff0c\u6bd4\u5982\u4e0a\u9762\u7684\u4f8b\u5b50\u548c \u7ed8\u56fe\u3002\u5bf9\u4e8e\u672c\u4e66\u4e2d\u7684\u5927\u90e8\u5206\u5185\u5bb9\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 python \u811a\u672c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5176\u4ed6 IDE \u56e0\u4e3a\u7ed3\u679c\u90fd\u662f\u4e00\u6837\u7684\u3002 MNIST \u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u628a\u5b83\u8f6c\u6362\u6210\u4e00\u4e2a\u65e0\u76d1\u7763\u7684\u95ee\u9898\uff0c\u53ea\u662f\u4e3a\u4e86\u68c0\u67e5\u5b83\u662f\u5426\u80fd\u5e26\u6765\u4efb\u4f55\u597d\u7684\u7ed3\u679c\u3002\u5982\u679c\u6211\u4eec\u4f7f\u7528\u5206\u7c7b\u7b97\u6cd5\uff0c\u6548\u679c\u4f1a\u66f4\u597d\u3002\u8ba9\u6211\u4eec\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u4e00\u63a2\u7a76\u7adf\u3002","title":"\u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60"},{"location":"%E7%89%B9%E5%BE%81%E5%B7%A5%E7%A8%8B/","text":"\u7279\u5f81\u5de5\u7a0b 
\u7279\u5f81\u5de5\u7a0b\u662f\u6784\u5efa\u826f\u597d\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u90e8\u5206\u4e4b\u4e00\u3002\u5982\u679c\u6211\u4eec\u62e5\u6709\u6709\u7528\u7684\u7279\u5f81\uff0c\u6a21\u578b\u5c31\u4f1a\u8868\u73b0\u5f97\u66f4\u597d\u3002\u5728\u8bb8\u591a\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u907f\u514d\u4f7f\u7528\u5927\u578b\u590d\u6742\u6a21\u578b\uff0c\u800c\u4f7f\u7528\u5177\u6709\u5173\u952e\u5de5\u7a0b\u7279\u5f81\u7684\u7b80\u5355\u6a21\u578b\u3002\u6211\u4eec\u5fc5\u987b\u7262\u8bb0\uff0c\u53ea\u6709\u5f53\u4f60\u5bf9\u95ee\u9898\u7684\u9886\u57df\u6709\u4e00\u5b9a\u7684\u4e86\u89e3\uff0c\u5e76\u4e14\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u76f8\u5173\u6570\u636e\u65f6\uff0c\u624d\u80fd\u4ee5\u6700\u4f73\u65b9\u5f0f\u5b8c\u6210\u7279\u5f81\u5de5\u7a0b\u3002\u4e0d\u8fc7\uff0c\u60a8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u4e00\u4e9b\u901a\u7528\u6280\u672f\uff0c\u4ece\u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u7279\u5f81\u3002\u7279\u5f81\u5de5\u7a0b\u4e0d\u4ec5\u4ec5\u662f\u4ece\u6570\u636e\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\uff0c\u8fd8\u5305\u62ec\u4e0d\u540c\u7c7b\u578b\u7684\u5f52\u4e00\u5316\u548c\u8f6c\u6362\u3002 \u5728\u6709\u5173\u5206\u7c7b\u7279\u5f81\u7684\u7ae0\u8282\u4e2d\uff0c\u6211\u4eec\u5df2\u7ecf\u4e86\u89e3\u4e86\u7ed3\u5408\u4e0d\u540c\u5206\u7c7b\u53d8\u91cf\u7684\u65b9\u6cd5\u3001\u5982\u4f55\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u8ba1\u6570\u3001\u6807\u7b7e\u7f16\u7801\u548c\u4f7f\u7528\u5d4c\u5165\u3002\u8fd9\u4e9b\u51e0\u4e4e\u90fd\u662f\u5229\u7528\u5206\u7c7b\u53d8\u91cf\u8bbe\u8ba1\u7279\u5f81\u7684\u65b9\u6cd5\u3002\u56e0\u6b64\uff0c\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u7684\u91cd\u70b9\u5c06\u4ec5\u9650\u4e8e\u6570\u503c\u53d8\u91cf\u4ee5\u53ca\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u7684\u7ec4\u5408\u3002 
\u8ba9\u6211\u4eec\u4ece\u6700\u7b80\u5355\u4f46\u5e94\u7528\u6700\u5e7f\u6cdb\u7684\u7279\u5f81\u5de5\u7a0b\u6280\u672f\u5f00\u59cb\u3002\u5047\u8bbe\u4f60\u6b63\u5728\u5904\u7406\u65e5\u671f\u548c\u65f6\u95f4\u6570\u636e\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709\u4e00\u4e2a\u5e26\u6709\u65e5\u671f\u7c7b\u578b\u5217\u7684 pandas \u6570\u636e\u5e27\u3002\u5229\u7528\u8fd9\u4e00\u5217\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4ee5\u4e0b\u7279\u5f81\uff1a \u5e74 \u5e74\u4e2d\u7684\u5468 \u6708 \u661f\u671f \u5468\u672b \u5c0f\u65f6 \u8fd8\u6709\u66f4\u591a \u800c\u4f7f\u7528pandas\u5c31\u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u6dfb\u52a0'year'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5e74\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'year' ] = df [ 'datetime_column' ] . dt . year # \u6dfb\u52a0'weekofyear'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5468\u6570\u63d0\u53d6\u51fa\u6765 df . loc [:, 'weekofyear' ] = df [ 'datetime_column' ] . dt . weekofyear # \u6dfb\u52a0'month'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u6708\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'month' ] = df [ 'datetime_column' ] . dt . month # \u6dfb\u52a0'dayofweek'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u661f\u671f\u51e0\u63d0\u53d6\u51fa\u6765 df . loc [:, 'dayofweek' ] = df [ 'datetime_column' ] . dt . dayofweek # \u6dfb\u52a0'weekend'\u5217\uff0c\u5224\u65ad\u5f53\u5929\u662f\u5426\u4e3a\u5468\u672b df . loc [:, 'weekend' ] = ( df . datetime_column . dt . weekday >= 5 ) . astype ( int ) # \u6dfb\u52a0 'hour' \u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5c0f\u65f6\u63d0\u53d6\u51fa\u6765 df . loc [:, 'hour' ] = df [ 'datetime_column' ] . dt . 
hour \u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u65e5\u671f\u65f6\u95f4\u5217\u521b\u5efa\u4e00\u7cfb\u5217\u65b0\u5217\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u53ef\u4ee5\u521b\u5efa\u7684\u4e00\u4e9b\u793a\u4f8b\u529f\u80fd\u3002 import pandas as pd # \u521b\u5efa\u65e5\u671f\u65f6\u95f4\u5e8f\u5217\uff0c\u5305\u542b\u4e86\u4ece '2020-01-06' \u5230 '2020-01-10' \u7684\u65e5\u671f\u65f6\u95f4\u70b9\uff0c\u65f6\u95f4\u95f4\u9694\u4e3a10\u5c0f\u65f6 s = pd . date_range ( '2020-01-06' , '2020-01-10' , freq = '10H' ) . to_series () # \u63d0\u53d6\u5bf9\u5e94\u65f6\u95f4\u7279\u5f81 features = { \"dayofweek\" : s . dt . dayofweek . values , \"dayofyear\" : s . dt . dayofyear . values , \"hour\" : s . dt . hour . values , \"is_leap_year\" : s . dt . is_leap_year . values , \"quarter\" : s . dt . quarter . values , \"weekofyear\" : s . dt . weekofyear . values } \u8fd9\u5c06\u4ece\u7ed9\u5b9a\u7cfb\u5217\u4e2d\u751f\u6210\u4e00\u4e2a\u7279\u5f81\u5b57\u5178\u3002\u60a8\u53ef\u4ee5\u5c06\u6b64\u5e94\u7528\u4e8e pandas \u6570\u636e\u4e2d\u7684\u4efb\u4f55\u65e5\u671f\u65f6\u95f4\u5217\u3002\u8fd9\u4e9b\u662f pandas \u63d0\u4f9b\u7684\u4f17\u591a\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u4e2d\u7684\u4e00\u90e8\u5206\u3002\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u975e\u5e38\u91cd\u8981\uff0c\u4f8b\u5982\uff0c\u5728\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97\u7684\u9500\u552e\u989d\u65f6\uff0c\u5982\u679c\u60f3\u5728\u805a\u5408\u7279\u5f81\u4e0a\u4f7f\u7528 xgboost \u7b49\u6a21\u578b\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u5c31\u975e\u5e38\u91cd\u8981\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u4e0b\u6240\u793a\u7684\u6570\u636e\uff1a \u56fe 1\uff1a\u5305\u542b\u5206\u7c7b\u548c\u65e5\u671f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e \u5728\u56fe 1 
\u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u6709\u4e00\u4e2a\u65e5\u671f\u5217\uff0c\u4ece\u4e2d\u53ef\u4ee5\u8f7b\u677e\u63d0\u53d6\u5e74\u3001\u6708\u3001\u5b63\u5ea6\u7b49\u7279\u5f81\u3002\u7136\u540e\uff0c\u6211\u4eec\u6709\u4e00\u4e2a customer_id \u5217\uff0c\u8be5\u5217\u6709\u591a\u4e2a\u6761\u76ee\uff0c\u56e0\u6b64\u4e00\u4e2a\u5ba2\u6237\u4f1a\u88ab\u770b\u5230\u5f88\u591a\u6b21\uff08\u622a\u56fe\u4e2d\u770b\u4e0d\u5230\uff09\u3002\u6bcf\u4e2a\u65e5\u671f\u548c\u5ba2\u6237 ID \u90fd\u6709\u4e09\u4e2a\u5206\u7c7b\u7279\u5f81\u548c\u4e00\u4e2a\u6570\u5b57\u7279\u5f81\u3002\u6211\u4eec\u53ef\u4ee5\u4ece\u4e2d\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff1a - \u5ba2\u6237\u6700\u6d3b\u8dc3\u7684\u6708\u4efd\u662f\u51e0\u6708 - \u67d0\u4e2a\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u7684\u8ba1\u6570\u662f\u591a\u5c11 - \u67d0\u5e74\u67d0\u6708\u67d0\u5468\u67d0\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u6570\u91cf\u662f\u591a\u5c11\uff1f - \u67d0\u4e2a\u5ba2\u6237\u7684 num1 \u5e73\u5747\u503c\u662f\u591a\u5c11\uff1f - \u7b49\u7b49\u3002 \u4f7f\u7528 pandas \u4e2d\u7684\u805a\u5408\uff0c\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u521b\u5efa\u7c7b\u4f3c\u7684\u529f\u80fd\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u5b9e\u73b0\u3002 def generate_features ( df ): df . loc [:, 'year' ] = df [ 'date' ] . dt . year df . loc [:, 'weekofyear' ] = df [ 'date' ] . dt . weekofyear df . loc [:, 'month' ] = df [ 'date' ] . dt . month df . loc [:, 'dayofweek' ] = df [ 'date' ] . dt . dayofweek df . loc [:, 'weekend' ] = ( df [ 'date' ] . dt . weekday >= 5 ) . 
astype ( int ) aggs = {} # \u5bf9 'month' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'month' ] = [ 'nunique' , 'mean' ] # \u5bf9 'weekofyear' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'weekofyear' ] = [ 'nunique' , 'mean' ] # \u5bf9 'num1' \u5217\u8fdb\u884c sum\u3001max\u3001min\u3001mean \u805a\u5408 aggs [ 'num1' ] = [ 'sum' , 'max' , 'min' , 'mean' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c size \u805a\u5408 aggs [ 'customer_id' ] = [ 'size' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c nunique \u805a\u5408 aggs [ 'customer_id' ] = [ 'nunique' ] # \u5bf9\u6570\u636e\u5e94\u7528\u4e0d\u540c\u7684\u805a\u5408\u51fd\u6570 agg_df = df . groupby ( 'customer_id' ) . agg ( aggs ) # \u91cd\u7f6e\u7d22\u5f15 agg_df = agg_df . reset_index () return agg_df \u8bf7\u6ce8\u610f\uff0c\u5728\u4e0a\u8ff0\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u8df3\u8fc7\u4e86\u5206\u7c7b\u53d8\u91cf\uff0c\u4f46\u60a8\u53ef\u4ee5\u50cf\u4f7f\u7528\u5176\u4ed6\u805a\u5408\u53d8\u91cf\u4e00\u6837\u4f7f\u7528\u5b83\u4eec\u3002 \u56fe 2\uff1a\u603b\u4f53\u7279\u5f81\u548c\u5176\u4ed6\u7279\u5f81 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u56fe 2 \u4e2d\u7684\u6570\u636e\u4e0e\u5e26\u6709 customer_id \u5217\u7684\u539f\u59cb\u6570\u636e\u5e27\u8fde\u63a5\u8d77\u6765\uff0c\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5e76\u4e0d\u662f\u8981\u9884\u6d4b\u4ec0\u4e48\uff1b\u6211\u4eec\u53ea\u662f\u5728\u521b\u5efa\u901a\u7528\u7279\u5f81\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u8bd5\u56fe\u5728\u8fd9\u91cc\u9884\u6d4b\u4ec0\u4e48\uff0c\u521b\u5efa\u7279\u5f81\u4f1a\u66f4\u5bb9\u6613\u3002 \u4f8b\u5982\uff0c\u6709\u65f6\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u95ee\u9898\u65f6\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u7684\u7279\u5f81\u4e0d\u662f\u5355\u4e2a\u503c\uff0c\u800c\u662f\u4e00\u7cfb\u5217\u503c\u3002 
\u4f8b\u5982\uff0c\u5ba2\u6237\u5728\u7279\u5b9a\u65f6\u95f4\u6bb5\u5185\u7684\u4ea4\u6613\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4f1a\u521b\u5efa\u4e0d\u540c\u7c7b\u578b\u7684\u7279\u5f81\uff0c\u4f8b\u5982\uff1a\u4f7f\u7528\u6570\u503c\u7279\u5f81\u65f6\uff0c\u5728\u5bf9\u5206\u7c7b\u5217\u8fdb\u884c\u5206\u7ec4\u65f6\uff0c\u4f1a\u5f97\u5230\u7c7b\u4f3c\u4e8e\u65f6\u95f4\u5206\u5e03\u503c\u5217\u8868\u7684\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u7cfb\u5217\u7edf\u8ba1\u7279\u5f81\uff0c\u4f8b\u5982 \u5e73\u5747\u503c \u6700\u5927\u503c \u6700\u5c0f\u503c \u72ec\u7279\u6027 \u504f\u659c \u5cf0\u5ea6 Kstat \u767e\u5206\u4f4d\u6570 \u5b9a\u91cf \u5cf0\u503c\u5230\u5cf0\u503c \u4ee5\u53ca\u66f4\u591a \u8fd9\u4e9b\u53ef\u4ee5\u4f7f\u7528\u7b80\u5355\u7684 numpy \u51fd\u6570\u521b\u5efa\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 import numpy as np # \u521b\u5efa\u5b57\u5178\uff0c\u7528\u4e8e\u5b58\u50a8\u4e0d\u540c\u7684\u7edf\u8ba1\u7279\u5f81 feature_dict = {} # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5e73\u5747\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'mean' \u952e\u4e0b feature_dict [ 'mean' ] = np . mean ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5927\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'max' \u952e\u4e0b feature_dict [ 'max' ] = np . max ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5c0f\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'min' \u952e\u4e0b feature_dict [ 'min' ] = np . min ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6807\u51c6\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'std' \u952e\u4e0b feature_dict [ 'std' ] = np . 
std ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u65b9\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'var' \u952e\u4e0b feature_dict [ 'var' ] = np . var ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5dee\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'ptp' \u952e\u4e0b feature_dict [ 'ptp' ] = np . ptp ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c10\u767e\u5206\u4f4d\u6570\uff08\u5373\u767e\u5206\u4e4b10\u5206\u4f4d\u6570\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_10' \u952e\u4e0b feature_dict [ 'percentile_10' ] = np . percentile ( x , 10 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c60\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_60' \u952e\u4e0b feature_dict [ 'percentile_60' ] = np . percentile ( x , 60 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c90\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_90' \u952e\u4e0b feature_dict [ 'percentile_90' ] = np . percentile ( x , 90 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u76845%\u5206\u4f4d\u6570\uff08\u53730.05\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_5' \u952e\u4e0b feature_dict [ 'quantile_5' ] = np . quantile ( x , 0.05 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768495%\u5206\u4f4d\u6570\uff08\u53730.95\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_95' \u952e\u4e0b feature_dict [ 'quantile_95' ] = np . quantile ( x , 0.95 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768499%\u5206\u4f4d\u6570\uff08\u53730.99\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_99' \u952e\u4e0b feature_dict [ 'quantile_99' ] = np . 
quantile ( x , 0.99 ) \u65f6\u95f4\u5e8f\u5217\u6570\u636e\uff08\u6570\u503c\u5217\u8868\uff09\u53ef\u4ee5\u8f6c\u6362\u6210\u8bb8\u591a\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4e00\u4e2a\u540d\u4e3a tsfresh \u7684 python \u5e93\u975e\u5e38\u6709\u7528\u3002 from tsfresh.feature_extraction import feature_calculators as fc # \u8ba1\u7b97 x \u6570\u5217\u7684\u7edd\u5bf9\u80fd\u91cf\uff08abs_energy\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'abs_energy' \u952e\u4e0b feature_dict [ 'abs_energy' ] = fc . abs_energy ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u9ad8\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_above_mean' \u952e\u4e0b feature_dict [ 'count_above_mean' ] = fc . count_above_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u4f4e\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_below_mean' \u952e\u4e0b feature_dict [ 'count_below_mean' ] = fc . count_below_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u7edd\u5bf9\u53d8\u5316\uff08mean_abs_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_abs_change' \u952e\u4e0b feature_dict [ 'mean_abs_change' ] = fc . mean_abs_change ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u53d8\u5316\u7387\uff08mean_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_change' \u952e\u4e0b feature_dict [ 'mean_change' ] = fc . 
mean_change ( x ) \u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\uff1btsfresh \u63d0\u4f9b\u4e86\u6570\u767e\u79cd\u7279\u5f81\u548c\u6570\u5341\u79cd\u4e0d\u540c\u7279\u5f81\u7684\u53d8\u4f53\uff0c\u4f60\u53ef\u4ee5\u5c06\u5b83\u4eec\u7528\u4e8e\u57fa\u4e8e\u65f6\u95f4\u5e8f\u5217\uff08\u503c\u5217\u8868\uff09\u7684\u7279\u5f81\u3002\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0cx \u662f\u4e00\u4e2a\u503c\u5217\u8868\u3002\u4f46\u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\u3002\u60a8\u8fd8\u53ef\u4ee5\u4e3a\u5305\u542b\u6216\u4e0d\u5305\u542b\u5206\u7c7b\u6570\u636e\u7684\u6570\u503c\u6570\u636e\u521b\u5efa\u8bb8\u591a\u5176\u4ed6\u7279\u5f81\u3002\u751f\u6210\u8bb8\u591a\u7279\u5f81\u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u521b\u5efa\u4e00\u5806\u591a\u9879\u5f0f\u7279\u5f81\u3002\u4f8b\u5982\uff0c\u4ece\u4e24\u4e2a\u7279\u5f81 \"a \"\u548c \"b \"\u751f\u6210\u7684\u4e8c\u7ea7\u591a\u9879\u5f0f\u7279\u5f81\u5305\u62ec \"a\"\u3001\"b\"\u3001\"ab\"\u3001\"a^2 \"\u548c \"b^2\"\u3002 import numpy as np df = pd . DataFrame ( np . random . rand ( 100 , 2 ), columns = [ f \"f_ { i } \" for i in range ( 1 , 3 )]) \u5982\u56fe 3 \u6240\u793a\uff0c\u5b83\u7ed9\u51fa\u4e86\u4e00\u4e2a\u6570\u636e\u8868\u3002 \u56fe 3\uff1a\u5305\u542b\u4e24\u4e2a\u6570\u5b57\u7279\u5f81\u7684\u968f\u673a\u6570\u636e\u8868 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u7684 PolynomialFeatures \u521b\u5efa\u4e24\u6b21\u591a\u9879\u5f0f\u7279\u5f81\u3002 from sklearn import preprocessing # \u6307\u5b9a\u591a\u9879\u5f0f\u7684\u6b21\u6570\u4e3a 2\uff0c\u4e0d\u4ec5\u8003\u8651\u4ea4\u4e92\u9879\uff0c\u4e0d\u5305\u62ec\u504f\u5dee\uff08include_bias=False\uff09 pf = preprocessing . PolynomialFeatures ( degree = 2 , interaction_only = False , include_bias = False ) # \u62df\u5408\uff0c\u521b\u5efa\u591a\u9879\u5f0f\u7279\u5f81 pf . fit ( df ) # \u8f6c\u6362\u6570\u636e poly_feats = pf . 
transform ( df ) # \u83b7\u53d6\u751f\u6210\u7684\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf num_feats = poly_feats . shape [ 1 ] # \u4e3a\u65b0\u751f\u6210\u7684\u7279\u5f81\u547d\u540d df_transformed = pd . DataFrame ( poly_feats , columns = [ f \"f_ { i } \" for i in range ( 1 , num_feats + 1 )] ) \u8fd9\u6837\u5c31\u5f97\u5230\u4e86\u4e00\u4e2a\u6570\u636e\u8868\uff0c\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1a\u5e26\u6709\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e\u8868 \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e9b\u591a\u9879\u5f0f\u7279\u5f81\u3002\u5982\u679c\u521b\u5efa\u7684\u662f\u4e09\u6b21\u591a\u9879\u5f0f\u7279\u5f81\uff0c\u6700\u7ec8\u603b\u5171\u4f1a\u6709\u4e5d\u4e2a\u7279\u5f81\u3002\u7279\u5f81\u7684\u6570\u91cf\u8d8a\u591a\uff0c\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf\u4e5f\u5c31\u8d8a\u591a\uff0c\u800c\u4e14\u4f60\u8fd8\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5982\u679c\u6570\u636e\u96c6\u4e2d\u6709\u5f88\u591a\u6837\u672c\uff0c\u90a3\u4e48\u521b\u5efa\u8fd9\u7c7b\u7279\u5f81\u5c31\u9700\u8981\u82b1\u8d39\u4e00\u4e9b\u65f6\u95f4\u3002 \u56fe 5\uff1a\u6570\u5b57\u7279\u5f81\u5217\u7684\u76f4\u65b9\u56fe \u53e6\u4e00\u4e2a\u6709\u8da3\u7684\u529f\u80fd\u662f\u5c06\u6570\u5b57\u8f6c\u6362\u4e3a\u7c7b\u522b\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \u5206\u7bb1 \u3002\u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u56fe 5\uff0c\u5b83\u663e\u793a\u4e86\u4e00\u4e2a\u968f\u673a\u6570\u5b57\u7279\u5f81\u7684\u6837\u672c\u76f4\u65b9\u56fe\u3002\u6211\u4eec\u5728\u8be5\u56fe\u4e2d\u4f7f\u7528\u4e8610\u4e2a\u5206\u7bb1\uff0c\u53ef\u4ee5\u770b\u5230\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a10\u4e2a\u90e8\u5206\u3002\u8fd9\u53ef\u4ee5\u4f7f\u7528 pandas \u7684cat\u51fd\u6570\u6765\u5b9e\u73b0\u3002 # \u521b\u5efa10\u4e2a\u5206\u7bb1 df [ \"f_bin_10\" ] = pd . cut ( df [ \"f_1\" ], bins = 10 , labels = False ) # \u521b\u5efa100\u4e2a\u5206\u7bb1 df [ \"f_bin_100\" ] = pd . 
cut ( df [ \"f_1\" ], bins = 100 , labels = False ) \u5982\u56fe 6 \u6240\u793a\uff0c\u8fd9\u5c06\u5728\u6570\u636e\u5e27\u4e2d\u751f\u6210\u4e24\u4e2a\u65b0\u7279\u5f81\u3002 \u56fe 6\uff1a\u6570\u503c\u7279\u5f81\u5206\u7bb1 \u5f53\u4f60\u8fdb\u884c\u5206\u7c7b\u65f6\uff0c\u53ef\u4ee5\u540c\u65f6\u4f7f\u7528\u5206\u7bb1\u548c\u539f\u59cb\u7279\u5f81\u3002\u6211\u4eec\u5c06\u5728\u672c\u7ae0\u540e\u534a\u90e8\u5206\u5b66\u4e60\u66f4\u591a\u5173\u4e8e\u9009\u62e9\u7279\u5f81\u7684\u77e5\u8bc6\u3002\u5206\u7bb1\u8fd8\u53ef\u4ee5\u5c06\u6570\u5b57\u7279\u5f81\u89c6\u4e3a\u5206\u7c7b\u7279\u5f81\u3002 \u53e6\u4e00\u79cd\u53ef\u4ee5\u4ece\u6570\u503c\u7279\u5f81\u4e2d\u521b\u5efa\u7684\u6709\u8da3\u7279\u5f81\u7c7b\u578b\u662f\u5bf9\u6570\u53d8\u6362\u3002\u8bf7\u770b\u56fe 7 \u4e2d\u7684\u7279\u5f81 f_3\u3002 \u4e0e\u5176\u4ed6\u65b9\u5dee\u8f83\u5c0f\u7684\u7279\u5f81\u76f8\u6bd4\uff08\u5047\u8bbe\u5982\u6b64\uff09\uff0cf_3 \u662f\u4e00\u79cd\u65b9\u5dee\u975e\u5e38\u5927\u7684\u7279\u6b8a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5e0c\u671b\u964d\u4f4e\u8fd9\u4e00\u5217\u7684\u65b9\u5dee\uff0c\u8fd9\u53ef\u4ee5\u901a\u8fc7\u5bf9\u6570\u53d8\u6362\u6765\u5b9e\u73b0\u3002 f_3 \u5217\u7684\u503c\u8303\u56f4\u4e3a 0 \u5230 10000\uff0c\u76f4\u65b9\u56fe\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u7279\u5f81 f_3 \u7684\u76f4\u65b9\u56fe \u6211\u4eec\u53ef\u4ee5\u5bf9\u8fd9\u4e00\u5217\u5e94\u7528 log(1 + x) \u6765\u51cf\u5c11\u5176\u65b9\u5dee\u3002\u56fe 9 \u663e\u793a\u4e86\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u76f4\u65b9\u56fe\u7684\u53d8\u5316\u3002 \u56fe 9\uff1a\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u7684 f_3 \u76f4\u65b9\u56fe \u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e0d\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u548c\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u7684\u65b9\u5dee\u3002 In [ X ]: df . f_3 . var () Out [ X ]: 8077265.875858586 In [ X ]: df . f_3 . apply ( lambda x : np . log ( 1 + x )) . 
var () Out [ X ]: 0.6058771732119975 \u6709\u65f6\uff0c\u4e5f\u53ef\u4ee5\u7528\u6307\u6570\u6765\u4ee3\u66ff\u5bf9\u6570\u3002\u4e00\u79cd\u975e\u5e38\u6709\u8da3\u7684\u60c5\u51b5\u662f\uff0c\u60a8\u4f7f\u7528\u57fa\u4e8e\u5bf9\u6570\u7684\u8bc4\u4f30\u6307\u6807\uff0c\u4f8b\u5982 RMSLE\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u5728\u5bf9\u6570\u53d8\u6362\u7684\u76ee\u6807\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u7136\u540e\u5728\u9884\u6d4b\u65f6\u4f7f\u7528\u6307\u6570\u503c\u8f6c\u6362\u56de\u539f\u59cb\u503c\u3002\u8fd9\u5c06\u6709\u52a9\u4e8e\u9488\u5bf9\u6307\u6807\u4f18\u5316\u6a21\u578b\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u7c7b\u6570\u5b57\u7279\u5f81\u90fd\u662f\u57fa\u4e8e\u76f4\u89c9\u521b\u5efa\u7684\u3002\u6ca1\u6709\u516c\u5f0f\u53ef\u5faa\u3002\u5982\u679c\u60a8\u4ece\u4e8b\u7684\u662f\u67d0\u4e00\u884c\u4e1a\uff0c\u60a8\u5c06\u521b\u5efa\u7279\u5b9a\u884c\u4e1a\u7684\u7279\u5f81\u3002 \u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u548c\u6570\u503c\u53d8\u91cf\u65f6\uff0c\u53ef\u80fd\u4f1a\u9047\u5230\u7f3a\u5931\u503c\u3002\u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u4e00\u4e9b\u5904\u7406\u5206\u7c7b\u7279\u5f81\u4e2d\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\uff0c\u4f46\u8fd8\u6709\u66f4\u591a\u65b9\u6cd5\u53ef\u4ee5\u5904\u7406\u7f3a\u5931\u503c/NaN \u503c\u3002\u8fd9\u4e5f\u88ab\u89c6\u4e3a\u7279\u5f81\u5de5\u7a0b\u3002 \u5982\u679c\u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u9047\u5230\u7f3a\u5931\u503c\uff0c\u5c31\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\uff01\u8fd9\u6837\u505a\u867d\u7136\u7b80\u5355\uff0c\u4f46\uff08\u51e0\u4e4e\uff09\u603b\u662f\u6709\u6548\u7684\uff01 
\u5728\u6570\u503c\u6570\u636e\u4e2d\u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u9009\u62e9\u4e00\u4e2a\u5728\u7279\u5b9a\u7279\u5f81\u4e2d\u6ca1\u6709\u51fa\u73b0\u7684\u503c\uff0c\u7136\u540e\u7528\u5b83\u6765\u586b\u8865\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u7279\u5f81\u4e2d\u6ca1\u6709 0\u3002\u8fd9\u662f\u5176\u4e2d\u4e00\u79cd\u65b9\u6cd5\uff0c\u4f46\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\u3002\u5bf9\u4e8e\u6570\u503c\u6570\u636e\u6765\u8bf4\uff0c\u6bd4\u586b\u5145 0 \u66f4\u6709\u6548\u7684\u65b9\u6cd5\u4e4b\u4e00\u662f\u4f7f\u7528\u5e73\u5747\u503c\u8fdb\u884c\u586b\u5145\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8be5\u7279\u5f81\u6240\u6709\u503c\u7684\u4e2d\u4f4d\u6570\u6765\u586b\u5145\uff0c\u6216\u8005\u4f7f\u7528\u6700\u5e38\u89c1\u7684\u503c\u6765\u586b\u5145\u7f3a\u5931\u503c\u3002\u8fd9\u6837\u505a\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002 \u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u9ad8\u7ea7\u65b9\u6cd5\u662f\u4f7f\u7528 K \u8fd1\u90bb\u6cd5 \u3002 \u60a8\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\uff0c\u7136\u540e\u5229\u7528\u67d0\u79cd\u8ddd\u79bb\u5ea6\u91cf\uff08\u4f8b\u5982\u6b27\u6c0f\u8ddd\u79bb\uff09\u627e\u5230\u6700\u8fd1\u7684\u90bb\u5c45\u3002\u7136\u540e\u53d6\u6240\u6709\u8fd1\u90bb\u7684\u5e73\u5747\u503c\u6765\u586b\u8865\u7f3a\u5931\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 KNN \u6765\u586b\u8865\u8fd9\u6837\u7684\u7f3a\u5931\u503c\u3002 \u56fe 10\uff1a\u6709\u7f3a\u5931\u503c\u7684\u4e8c\u7ef4\u6570\u7ec4 \u8ba9\u6211\u4eec\u770b\u770b KNN \u662f\u5982\u4f55\u5904\u7406\u56fe 10 \u6240\u793a\u7684\u7f3a\u5931\u503c\u77e9\u9635\u7684\u3002 import numpy as np from sklearn import impute # \u751f\u6210\u7ef4\u5ea6\u4e3a (10, 6) \u7684\u968f\u673a\u6574\u6570\u77e9\u9635 X\uff0c\u6570\u503c\u8303\u56f4\u5728 1 \u5230 14 \u4e4b\u95f4 X = np . random . randint ( 1 , 15 , ( 10 , 6 )) # \u6570\u636e\u7c7b\u578b\u8f6c\u6362\u4e3a float X = X . 
astype ( float ) # \u5728\u77e9\u9635 X \u4e2d\u968f\u673a\u9009\u62e9 10 \u4e2a\u4f4d\u7f6e\uff0c\u5c06\u8fd9\u4e9b\u4f4d\u7f6e\u7684\u5143\u7d20\u8bbe\u7f6e\u4e3a NaN\uff08\u7f3a\u5931\u503c\uff09 X . ravel ()[ np . random . choice ( X . size , 10 , replace = False )] = np . nan # \u521b\u5efa\u4e00\u4e2a KNNImputer \u5bf9\u8c61 knn_imputer\uff0c\u6307\u5b9a\u90bb\u5c45\u6570\u91cf\u4e3a 2 knn_imputer = impute . KNNImputer ( n_neighbors = 2 ) # # \u4f7f\u7528 knn_imputer \u5bf9\u77e9\u9635 X \u8fdb\u884c\u62df\u5408\u548c\u8f6c\u6362\uff0c\u7528 K-\u6700\u8fd1\u90bb\u65b9\u6cd5\u586b\u8865\u7f3a\u5931\u503c knn_imputer . fit_transform ( X ) \u5982\u56fe 11 \u6240\u793a\uff0c\u5b83\u586b\u5145\u4e86\u4e0a\u8ff0\u77e9\u9635\u3002 \u56fe 11\uff1aKNN\u4f30\u7b97\u7684\u6570\u503c \u53e6\u4e00\u79cd\u5f25\u8865\u5217\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\u662f\u8bad\u7ec3\u56de\u5f52\u6a21\u578b\uff0c\u8bd5\u56fe\u6839\u636e\u5176\u4ed6\u5217\u9884\u6d4b\u67d0\u5217\u7684\u7f3a\u5931\u503c\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u4ece\u6709\u7f3a\u5931\u503c\u7684\u4e00\u5217\u5f00\u59cb\uff0c\u5c06\u8fd9\u4e00\u5217\u4f5c\u4e3a\u65e0\u7f3a\u5931\u503c\u56de\u5f52\u6a21\u578b\u7684\u76ee\u6807\u5217\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u6240\u6709\u5176\u4ed6\u5217\uff0c\u5bf9\u76f8\u5173\u5217\u4e2d\u6ca1\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\uff0c\u7136\u540e\u5c1d\u8bd5\u9884\u6d4b\u4e4b\u524d\u5220\u9664\u7684\u6837\u672c\u7684\u76ee\u6807\u5217\uff08\u540c\u4e00\u5217\uff09\u3002\u8fd9\u6837\uff0c\u57fa\u4e8e\u6a21\u578b\u7684\u4f30\u7b97\u5c31\u4f1a\u66f4\u52a0\u7a33\u5065\u3002 \u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u5bf9\u4e8e\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6ca1\u6709\u5fc5\u8981\u8fdb\u884c\u6570\u503c\u5f52\u4e00\u5316\uff0c\u56e0\u4e3a\u5b83\u4eec\u53ef\u4ee5\u81ea\u884c\u5904\u7406\u3002 
\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u6240\u5c55\u793a\u7684\u53ea\u662f\u521b\u5efa\u4e00\u822c\u7279\u5f81\u7684\u4e00\u4e9b\u65b9\u6cd5\u3002\u73b0\u5728\uff0c\u5047\u8bbe\u60a8\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u9884\u6d4b\u4e0d\u540c\u5546\u54c1\uff08\u6bcf\u5468\u6216\u6bcf\u6708\uff09\u5546\u5e97\u9500\u552e\u989d\u7684\u95ee\u9898\u3002\u60a8\u6709\u5546\u54c1\uff0c\u4e5f\u6709\u5546\u5e97 ID\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u6bcf\u4e2a\u5546\u5e97\u7684\u5546\u54c1\u7b49\u7279\u5f81\u3002\u73b0\u5728\uff0c\u8fd9\u662f\u4e0a\u6587\u6ca1\u6709\u8ba8\u8bba\u7684\u7279\u5f81\u4e4b\u4e00\u3002\u8fd9\u7c7b\u7279\u5f81\u4e0d\u80fd\u4e00\u6982\u800c\u8bba\uff0c\u5b8c\u5168\u6765\u81ea\u4e8e\u9886\u57df\u3001\u6570\u636e\u548c\u4e1a\u52a1\u77e5\u8bc6\u3002\u67e5\u770b\u6570\u636e\uff0c\u627e\u51fa\u9002\u5408\u7684\u7279\u5f81\uff0c\u7136\u540e\u521b\u5efa\u76f8\u5e94\u7684\u7279\u5f81\u3002\u5982\u679c\u60a8\u4f7f\u7528\u7684\u662f\u903b\u8f91\u56de\u5f52\u7b49\u7ebf\u6027\u6a21\u578b\u6216 SVM 
\u7b49\u6a21\u578b\uff0c\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\u5bf9\u7279\u5f81\u8fdb\u884c\u7f29\u653e\u6216\u5f52\u4e00\u5316\u5904\u7406\u3002\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65e0\u9700\u5bf9\u7279\u5f81\u8fdb\u884c\u4efb\u4f55\u5f52\u4e00\u5316\u5904\u7406\u5373\u53ef\u6b63\u5e38\u5de5\u4f5c\u3002","title":"\u7279\u5f81\u5de5\u7a0b"},{"location":"%E7%89%B9%E5%BE%81%E5%B7%A5%E7%A8%8B/#_1","text":"\u7279\u5f81\u5de5\u7a0b\u662f\u6784\u5efa\u826f\u597d\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u90e8\u5206\u4e4b\u4e00\u3002\u5982\u679c\u6211\u4eec\u62e5\u6709\u6709\u7528\u7684\u7279\u5f81\uff0c\u6a21\u578b\u5c31\u4f1a\u8868\u73b0\u5f97\u66f4\u597d\u3002\u5728\u8bb8\u591a\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u907f\u514d\u4f7f\u7528\u5927\u578b\u590d\u6742\u6a21\u578b\uff0c\u800c\u4f7f\u7528\u5177\u6709\u5173\u952e\u5de5\u7a0b\u7279\u5f81\u7684\u7b80\u5355\u6a21\u578b\u3002\u6211\u4eec\u5fc5\u987b\u7262\u8bb0\uff0c\u53ea\u6709\u5f53\u4f60\u5bf9\u95ee\u9898\u7684\u9886\u57df\u6709\u4e00\u5b9a\u7684\u4e86\u89e3\uff0c\u5e76\u4e14\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u76f8\u5173\u6570\u636e\u65f6\uff0c\u624d\u80fd\u4ee5\u6700\u4f73\u65b9\u5f0f\u5b8c\u6210\u7279\u5f81\u5de5\u7a0b\u3002\u4e0d\u8fc7\uff0c\u60a8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u4e00\u4e9b\u901a\u7528\u6280\u672f\uff0c\u4ece\u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u7279\u5f81\u3002\u7279\u5f81\u5de5\u7a0b\u4e0d\u4ec5\u4ec5\u662f\u4ece\u6570\u636e\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\uff0c\u8fd8\u5305\u62ec\u4e0d\u540c\u7c7b\u578b\u7684\u5f52\u4e00\u5316\u548c\u8f6c\u6362\u3002 
\u5728\u6709\u5173\u5206\u7c7b\u7279\u5f81\u7684\u7ae0\u8282\u4e2d\uff0c\u6211\u4eec\u5df2\u7ecf\u4e86\u89e3\u4e86\u7ed3\u5408\u4e0d\u540c\u5206\u7c7b\u53d8\u91cf\u7684\u65b9\u6cd5\u3001\u5982\u4f55\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u8ba1\u6570\u3001\u6807\u7b7e\u7f16\u7801\u548c\u4f7f\u7528\u5d4c\u5165\u3002\u8fd9\u4e9b\u51e0\u4e4e\u90fd\u662f\u5229\u7528\u5206\u7c7b\u53d8\u91cf\u8bbe\u8ba1\u7279\u5f81\u7684\u65b9\u6cd5\u3002\u56e0\u6b64\uff0c\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u7684\u91cd\u70b9\u5c06\u4ec5\u9650\u4e8e\u6570\u503c\u53d8\u91cf\u4ee5\u53ca\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u7684\u7ec4\u5408\u3002 \u8ba9\u6211\u4eec\u4ece\u6700\u7b80\u5355\u4f46\u5e94\u7528\u6700\u5e7f\u6cdb\u7684\u7279\u5f81\u5de5\u7a0b\u6280\u672f\u5f00\u59cb\u3002\u5047\u8bbe\u4f60\u6b63\u5728\u5904\u7406\u65e5\u671f\u548c\u65f6\u95f4\u6570\u636e\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709\u4e00\u4e2a\u5e26\u6709\u65e5\u671f\u7c7b\u578b\u5217\u7684 pandas \u6570\u636e\u5e27\u3002\u5229\u7528\u8fd9\u4e00\u5217\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4ee5\u4e0b\u7279\u5f81\uff1a \u5e74 \u5e74\u4e2d\u7684\u5468 \u6708 \u661f\u671f \u5468\u672b \u5c0f\u65f6 \u8fd8\u6709\u66f4\u591a \u800c\u4f7f\u7528pandas\u5c31\u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u6dfb\u52a0'year'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5e74\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'year' ] = df [ 'datetime_column' ] . dt . year # \u6dfb\u52a0'weekofyear'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5468\u6570\u63d0\u53d6\u51fa\u6765 df . loc [:, 'weekofyear' ] = df [ 'datetime_column' ] . dt . weekofyear # \u6dfb\u52a0'month'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u6708\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'month' ] = df [ 'datetime_column' ] . dt . month # \u6dfb\u52a0'dayofweek'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u661f\u671f\u51e0\u63d0\u53d6\u51fa\u6765 df . 
loc [:, 'dayofweek' ] = df [ 'datetime_column' ] . dt . dayofweek # \u6dfb\u52a0'weekend'\u5217\uff0c\u5224\u65ad\u5f53\u5929\u662f\u5426\u4e3a\u5468\u672b df . loc [:, 'weekend' ] = ( df . datetime_column . dt . weekday >= 5 ) . astype ( int ) # \u6dfb\u52a0 'hour' \u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5c0f\u65f6\u63d0\u53d6\u51fa\u6765 df . loc [:, 'hour' ] = df [ 'datetime_column' ] . dt . hour \u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u65e5\u671f\u65f6\u95f4\u5217\u521b\u5efa\u4e00\u7cfb\u5217\u65b0\u5217\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u53ef\u4ee5\u521b\u5efa\u7684\u4e00\u4e9b\u793a\u4f8b\u529f\u80fd\u3002 import pandas as pd # \u521b\u5efa\u65e5\u671f\u65f6\u95f4\u5e8f\u5217\uff0c\u5305\u542b\u4e86\u4ece '2020-01-06' \u5230 '2020-01-10' \u7684\u65e5\u671f\u65f6\u95f4\u70b9\uff0c\u65f6\u95f4\u95f4\u9694\u4e3a10\u5c0f\u65f6 s = pd . date_range ( '2020-01-06' , '2020-01-10' , freq = '10H' ) . to_series () # \u63d0\u53d6\u5bf9\u5e94\u65f6\u95f4\u7279\u5f81 features = { \"dayofweek\" : s . dt . dayofweek . values , \"dayofyear\" : s . dt . dayofyear . values , \"hour\" : s . dt . hour . values , \"is_leap_year\" : s . dt . is_leap_year . values , \"quarter\" : s . dt . quarter . values , \"weekofyear\" : s . dt . weekofyear . 
values } \u8fd9\u5c06\u4ece\u7ed9\u5b9a\u7cfb\u5217\u4e2d\u751f\u6210\u4e00\u4e2a\u7279\u5f81\u5b57\u5178\u3002\u60a8\u53ef\u4ee5\u5c06\u6b64\u5e94\u7528\u4e8e pandas \u6570\u636e\u4e2d\u7684\u4efb\u4f55\u65e5\u671f\u65f6\u95f4\u5217\u3002\u8fd9\u4e9b\u662f pandas \u63d0\u4f9b\u7684\u4f17\u591a\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u4e2d\u7684\u4e00\u90e8\u5206\u3002\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u975e\u5e38\u91cd\u8981\uff0c\u4f8b\u5982\uff0c\u5728\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97\u7684\u9500\u552e\u989d\u65f6\uff0c\u5982\u679c\u60f3\u5728\u805a\u5408\u7279\u5f81\u4e0a\u4f7f\u7528 xgboost \u7b49\u6a21\u578b\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u5c31\u975e\u5e38\u91cd\u8981\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u4e0b\u6240\u793a\u7684\u6570\u636e\uff1a \u56fe 1\uff1a\u5305\u542b\u5206\u7c7b\u548c\u65e5\u671f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e \u5728\u56fe 1 \u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u6709\u4e00\u4e2a\u65e5\u671f\u5217\uff0c\u4ece\u4e2d\u53ef\u4ee5\u8f7b\u677e\u63d0\u53d6\u5e74\u3001\u6708\u3001\u5b63\u5ea6\u7b49\u7279\u5f81\u3002\u7136\u540e\uff0c\u6211\u4eec\u6709\u4e00\u4e2a customer_id \u5217\uff0c\u8be5\u5217\u6709\u591a\u4e2a\u6761\u76ee\uff0c\u56e0\u6b64\u4e00\u4e2a\u5ba2\u6237\u4f1a\u88ab\u770b\u5230\u5f88\u591a\u6b21\uff08\u622a\u56fe\u4e2d\u770b\u4e0d\u5230\uff09\u3002\u6bcf\u4e2a\u65e5\u671f\u548c\u5ba2\u6237 ID \u90fd\u6709\u4e09\u4e2a\u5206\u7c7b\u7279\u5f81\u548c\u4e00\u4e2a\u6570\u5b57\u7279\u5f81\u3002\u6211\u4eec\u53ef\u4ee5\u4ece\u4e2d\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff1a - \u5ba2\u6237\u6700\u6d3b\u8dc3\u7684\u6708\u4efd\u662f\u51e0\u6708 - \u67d0\u4e2a\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u7684\u8ba1\u6570\u662f\u591a\u5c11 - \u67d0\u5e74\u67d0\u6708\u67d0\u5468\u67d0\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u6570\u91cf\u662f\u591a\u5c11\uff1f - \u67d0\u4e2a\u5ba2\u6237\u7684 num1 
\u5e73\u5747\u503c\u662f\u591a\u5c11\uff1f - \u7b49\u7b49\u3002 \u4f7f\u7528 pandas \u4e2d\u7684\u805a\u5408\uff0c\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u521b\u5efa\u7c7b\u4f3c\u7684\u529f\u80fd\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u5b9e\u73b0\u3002 def generate_features ( df ): df . loc [:, 'year' ] = df [ 'date' ] . dt . year df . loc [:, 'weekofyear' ] = df [ 'date' ] . dt . weekofyear df . loc [:, 'month' ] = df [ 'date' ] . dt . month df . loc [:, 'dayofweek' ] = df [ 'date' ] . dt . dayofweek df . loc [:, 'weekend' ] = ( df [ 'date' ] . dt . weekday >= 5 ) . astype ( int ) aggs = {} # \u5bf9 'month' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'month' ] = [ 'nunique' , 'mean' ] # \u5bf9 'weekofyear' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'weekofyear' ] = [ 'nunique' , 'mean' ] # \u5bf9 'num1' \u5217\u8fdb\u884c sum\u3001max\u3001min\u3001mean \u805a\u5408 aggs [ 'num1' ] = [ 'sum' , 'max' , 'min' , 'mean' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c size \u805a\u5408 aggs [ 'customer_id' ] = [ 'size' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c nunique \u805a\u5408 aggs [ 'customer_id' ] = [ 'nunique' ] # \u5bf9\u6570\u636e\u5e94\u7528\u4e0d\u540c\u7684\u805a\u5408\u51fd\u6570 agg_df = df . groupby ( 'customer_id' ) . agg ( aggs ) # \u91cd\u7f6e\u7d22\u5f15 agg_df = agg_df . 
reset_index () return agg_df \u8bf7\u6ce8\u610f\uff0c\u5728\u4e0a\u8ff0\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u8df3\u8fc7\u4e86\u5206\u7c7b\u53d8\u91cf\uff0c\u4f46\u60a8\u53ef\u4ee5\u50cf\u4f7f\u7528\u5176\u4ed6\u805a\u5408\u53d8\u91cf\u4e00\u6837\u4f7f\u7528\u5b83\u4eec\u3002 \u56fe 2\uff1a\u603b\u4f53\u7279\u5f81\u548c\u5176\u4ed6\u7279\u5f81 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u56fe 2 \u4e2d\u7684\u6570\u636e\u4e0e\u5e26\u6709 customer_id \u5217\u7684\u539f\u59cb\u6570\u636e\u5e27\u8fde\u63a5\u8d77\u6765\uff0c\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5e76\u4e0d\u662f\u8981\u9884\u6d4b\u4ec0\u4e48\uff1b\u6211\u4eec\u53ea\u662f\u5728\u521b\u5efa\u901a\u7528\u7279\u5f81\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u8bd5\u56fe\u5728\u8fd9\u91cc\u9884\u6d4b\u4ec0\u4e48\uff0c\u521b\u5efa\u7279\u5f81\u4f1a\u66f4\u5bb9\u6613\u3002 \u4f8b\u5982\uff0c\u6709\u65f6\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u95ee\u9898\u65f6\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u7684\u7279\u5f81\u4e0d\u662f\u5355\u4e2a\u503c\uff0c\u800c\u662f\u4e00\u7cfb\u5217\u503c\u3002 \u4f8b\u5982\uff0c\u5ba2\u6237\u5728\u7279\u5b9a\u65f6\u95f4\u6bb5\u5185\u7684\u4ea4\u6613\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4f1a\u521b\u5efa\u4e0d\u540c\u7c7b\u578b\u7684\u7279\u5f81\uff0c\u4f8b\u5982\uff1a\u4f7f\u7528\u6570\u503c\u7279\u5f81\u65f6\uff0c\u5728\u5bf9\u5206\u7c7b\u5217\u8fdb\u884c\u5206\u7ec4\u65f6\uff0c\u4f1a\u5f97\u5230\u7c7b\u4f3c\u4e8e\u65f6\u95f4\u5206\u5e03\u503c\u5217\u8868\u7684\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u7cfb\u5217\u7edf\u8ba1\u7279\u5f81\uff0c\u4f8b\u5982 \u5e73\u5747\u503c \u6700\u5927\u503c \u6700\u5c0f\u503c \u72ec\u7279\u6027 \u504f\u659c \u5cf0\u5ea6 Kstat \u767e\u5206\u4f4d\u6570 \u5b9a\u91cf \u5cf0\u503c\u5230\u5cf0\u503c \u4ee5\u53ca\u66f4\u591a \u8fd9\u4e9b\u53ef\u4ee5\u4f7f\u7528\u7b80\u5355\u7684 numpy 
\u51fd\u6570\u521b\u5efa\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 import numpy as np # \u521b\u5efa\u5b57\u5178\uff0c\u7528\u4e8e\u5b58\u50a8\u4e0d\u540c\u7684\u7edf\u8ba1\u7279\u5f81 feature_dict = {} # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5e73\u5747\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'mean' \u952e\u4e0b feature_dict [ 'mean' ] = np . mean ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5927\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'max' \u952e\u4e0b feature_dict [ 'max' ] = np . max ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5c0f\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'min' \u952e\u4e0b feature_dict [ 'min' ] = np . min ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6807\u51c6\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'std' \u952e\u4e0b feature_dict [ 'std' ] = np . std ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u65b9\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'var' \u952e\u4e0b feature_dict [ 'var' ] = np . var ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5dee\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'ptp' \u952e\u4e0b feature_dict [ 'ptp' ] = np . ptp ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c10\u767e\u5206\u4f4d\u6570\uff08\u5373\u767e\u5206\u4e4b10\u5206\u4f4d\u6570\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_10' \u952e\u4e0b feature_dict [ 'percentile_10' ] = np . percentile ( x , 10 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c60\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_60' \u952e\u4e0b feature_dict [ 'percentile_60' ] = np . 
percentile ( x , 60 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c90\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_90' \u952e\u4e0b feature_dict [ 'percentile_90' ] = np . percentile ( x , 90 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u76845%\u5206\u4f4d\u6570\uff08\u53730.05\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_5' \u952e\u4e0b feature_dict [ 'quantile_5' ] = np . quantile ( x , 0.05 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768495%\u5206\u4f4d\u6570\uff08\u53730.95\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_95' \u952e\u4e0b feature_dict [ 'quantile_95' ] = np . quantile ( x , 0.95 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768499%\u5206\u4f4d\u6570\uff08\u53730.99\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_99' \u952e\u4e0b feature_dict [ 'quantile_99' ] = np . quantile ( x , 0.99 ) \u65f6\u95f4\u5e8f\u5217\u6570\u636e\uff08\u6570\u503c\u5217\u8868\uff09\u53ef\u4ee5\u8f6c\u6362\u6210\u8bb8\u591a\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4e00\u4e2a\u540d\u4e3a tsfresh \u7684 python \u5e93\u975e\u5e38\u6709\u7528\u3002 from tsfresh.feature_extraction import feature_calculators as fc # \u8ba1\u7b97 x \u6570\u5217\u7684\u7edd\u5bf9\u80fd\u91cf\uff08abs_energy\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'abs_energy' \u952e\u4e0b feature_dict [ 'abs_energy' ] = fc . abs_energy ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u9ad8\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_above_mean' \u952e\u4e0b feature_dict [ 'count_above_mean' ] = fc . 
count_above_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u4f4e\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_below_mean' \u952e\u4e0b feature_dict [ 'count_below_mean' ] = fc . count_below_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u7edd\u5bf9\u53d8\u5316\uff08mean_abs_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_abs_change' \u952e\u4e0b feature_dict [ 'mean_abs_change' ] = fc . mean_abs_change ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u53d8\u5316\u7387\uff08mean_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_change' \u952e\u4e0b feature_dict [ 'mean_change' ] = fc . mean_change ( x ) \u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\uff1btsfresh \u63d0\u4f9b\u4e86\u6570\u767e\u79cd\u7279\u5f81\u548c\u6570\u5341\u79cd\u4e0d\u540c\u7279\u5f81\u7684\u53d8\u4f53\uff0c\u4f60\u53ef\u4ee5\u5c06\u5b83\u4eec\u7528\u4e8e\u57fa\u4e8e\u65f6\u95f4\u5e8f\u5217\uff08\u503c\u5217\u8868\uff09\u7684\u7279\u5f81\u3002\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0cx \u662f\u4e00\u4e2a\u503c\u5217\u8868\u3002\u4f46\u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\u3002\u60a8\u8fd8\u53ef\u4ee5\u4e3a\u5305\u542b\u6216\u4e0d\u5305\u542b\u5206\u7c7b\u6570\u636e\u7684\u6570\u503c\u6570\u636e\u521b\u5efa\u8bb8\u591a\u5176\u4ed6\u7279\u5f81\u3002\u751f\u6210\u8bb8\u591a\u7279\u5f81\u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u521b\u5efa\u4e00\u5806\u591a\u9879\u5f0f\u7279\u5f81\u3002\u4f8b\u5982\uff0c\u4ece\u4e24\u4e2a\u7279\u5f81 \"a \"\u548c \"b \"\u751f\u6210\u7684\u4e8c\u7ea7\u591a\u9879\u5f0f\u7279\u5f81\u5305\u62ec \"a\"\u3001\"b\"\u3001\"ab\"\u3001\"a^2 \"\u548c \"b^2\"\u3002 import numpy as np df = pd . DataFrame ( np . random . 
rand ( 100 , 2 ), columns = [ f \"f_ { i } \" for i in range ( 1 , 3 )]) \u5982\u56fe 3 \u6240\u793a\uff0c\u5b83\u7ed9\u51fa\u4e86\u4e00\u4e2a\u6570\u636e\u8868\u3002 \u56fe 3\uff1a\u5305\u542b\u4e24\u4e2a\u6570\u5b57\u7279\u5f81\u7684\u968f\u673a\u6570\u636e\u8868 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u7684 PolynomialFeatures \u521b\u5efa\u4e24\u6b21\u591a\u9879\u5f0f\u7279\u5f81\u3002 from sklearn import preprocessing # \u6307\u5b9a\u591a\u9879\u5f0f\u7684\u6b21\u6570\u4e3a 2\uff0c\u4e0d\u4ec5\u8003\u8651\u4ea4\u4e92\u9879\uff0c\u4e0d\u5305\u62ec\u504f\u5dee\uff08include_bias=False\uff09 pf = preprocessing . PolynomialFeatures ( degree = 2 , interaction_only = False , include_bias = False ) # \u62df\u5408\uff0c\u521b\u5efa\u591a\u9879\u5f0f\u7279\u5f81 pf . fit ( df ) # \u8f6c\u6362\u6570\u636e poly_feats = pf . transform ( df ) # \u83b7\u53d6\u751f\u6210\u7684\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf num_feats = poly_feats . shape [ 1 ] # \u4e3a\u65b0\u751f\u6210\u7684\u7279\u5f81\u547d\u540d df_transformed = pd . 
DataFrame ( poly_feats , columns = [ f \"f_ { i } \" for i in range ( 1 , num_feats + 1 )] ) \u8fd9\u6837\u5c31\u5f97\u5230\u4e86\u4e00\u4e2a\u6570\u636e\u8868\uff0c\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1a\u5e26\u6709\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e\u8868 \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e9b\u591a\u9879\u5f0f\u7279\u5f81\u3002\u5982\u679c\u521b\u5efa\u7684\u662f\u4e09\u6b21\u591a\u9879\u5f0f\u7279\u5f81\uff0c\u6700\u7ec8\u603b\u5171\u4f1a\u6709\u4e5d\u4e2a\u7279\u5f81\u3002\u7279\u5f81\u7684\u6570\u91cf\u8d8a\u591a\uff0c\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf\u4e5f\u5c31\u8d8a\u591a\uff0c\u800c\u4e14\u4f60\u8fd8\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5982\u679c\u6570\u636e\u96c6\u4e2d\u6709\u5f88\u591a\u6837\u672c\uff0c\u90a3\u4e48\u521b\u5efa\u8fd9\u7c7b\u7279\u5f81\u5c31\u9700\u8981\u82b1\u8d39\u4e00\u4e9b\u65f6\u95f4\u3002 \u56fe 5\uff1a\u6570\u5b57\u7279\u5f81\u5217\u7684\u76f4\u65b9\u56fe \u53e6\u4e00\u4e2a\u6709\u8da3\u7684\u529f\u80fd\u662f\u5c06\u6570\u5b57\u8f6c\u6362\u4e3a\u7c7b\u522b\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \u5206\u7bb1 \u3002\u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u56fe 5\uff0c\u5b83\u663e\u793a\u4e86\u4e00\u4e2a\u968f\u673a\u6570\u5b57\u7279\u5f81\u7684\u6837\u672c\u76f4\u65b9\u56fe\u3002\u6211\u4eec\u5728\u8be5\u56fe\u4e2d\u4f7f\u7528\u4e8610\u4e2a\u5206\u7bb1\uff0c\u53ef\u4ee5\u770b\u5230\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a10\u4e2a\u90e8\u5206\u3002\u8fd9\u53ef\u4ee5\u4f7f\u7528 pandas \u7684cat\u51fd\u6570\u6765\u5b9e\u73b0\u3002 # \u521b\u5efa10\u4e2a\u5206\u7bb1 df [ \"f_bin_10\" ] = pd . cut ( df [ \"f_1\" ], bins = 10 , labels = False ) # \u521b\u5efa100\u4e2a\u5206\u7bb1 df [ \"f_bin_100\" ] = pd . 
cut ( df [ \"f_1\" ], bins = 100 , labels = False ) \u5982\u56fe 6 \u6240\u793a\uff0c\u8fd9\u5c06\u5728\u6570\u636e\u5e27\u4e2d\u751f\u6210\u4e24\u4e2a\u65b0\u7279\u5f81\u3002 \u56fe 6\uff1a\u6570\u503c\u7279\u5f81\u5206\u7bb1 \u5f53\u4f60\u8fdb\u884c\u5206\u7c7b\u65f6\uff0c\u53ef\u4ee5\u540c\u65f6\u4f7f\u7528\u5206\u7bb1\u548c\u539f\u59cb\u7279\u5f81\u3002\u6211\u4eec\u5c06\u5728\u672c\u7ae0\u540e\u534a\u90e8\u5206\u5b66\u4e60\u66f4\u591a\u5173\u4e8e\u9009\u62e9\u7279\u5f81\u7684\u77e5\u8bc6\u3002\u5206\u7bb1\u8fd8\u53ef\u4ee5\u5c06\u6570\u5b57\u7279\u5f81\u89c6\u4e3a\u5206\u7c7b\u7279\u5f81\u3002 \u53e6\u4e00\u79cd\u53ef\u4ee5\u4ece\u6570\u503c\u7279\u5f81\u4e2d\u521b\u5efa\u7684\u6709\u8da3\u7279\u5f81\u7c7b\u578b\u662f\u5bf9\u6570\u53d8\u6362\u3002\u8bf7\u770b\u56fe 7 \u4e2d\u7684\u7279\u5f81 f_3\u3002 \u4e0e\u5176\u4ed6\u65b9\u5dee\u8f83\u5c0f\u7684\u7279\u5f81\u76f8\u6bd4\uff08\u5047\u8bbe\u5982\u6b64\uff09\uff0cf_3 \u662f\u4e00\u79cd\u65b9\u5dee\u975e\u5e38\u5927\u7684\u7279\u6b8a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5e0c\u671b\u964d\u4f4e\u8fd9\u4e00\u5217\u7684\u65b9\u5dee\uff0c\u8fd9\u53ef\u4ee5\u901a\u8fc7\u5bf9\u6570\u53d8\u6362\u6765\u5b9e\u73b0\u3002 f_3 \u5217\u7684\u503c\u8303\u56f4\u4e3a 0 \u5230 10000\uff0c\u76f4\u65b9\u56fe\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u7279\u5f81 f_3 \u7684\u76f4\u65b9\u56fe \u6211\u4eec\u53ef\u4ee5\u5bf9\u8fd9\u4e00\u5217\u5e94\u7528 log(1 + x) \u6765\u51cf\u5c11\u5176\u65b9\u5dee\u3002\u56fe 9 \u663e\u793a\u4e86\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u76f4\u65b9\u56fe\u7684\u53d8\u5316\u3002 \u56fe 9\uff1a\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u7684 f_3 \u76f4\u65b9\u56fe \u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e0d\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u548c\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u7684\u65b9\u5dee\u3002 In [ X ]: df . f_3 . var () Out [ X ]: 8077265.875858586 In [ X ]: df . f_3 . apply ( lambda x : np . log ( 1 + x )) . 
var () Out [ X ]: 0.6058771732119975 \u6709\u65f6\uff0c\u4e5f\u53ef\u4ee5\u7528\u6307\u6570\u6765\u4ee3\u66ff\u5bf9\u6570\u3002\u4e00\u79cd\u975e\u5e38\u6709\u8da3\u7684\u60c5\u51b5\u662f\uff0c\u60a8\u4f7f\u7528\u57fa\u4e8e\u5bf9\u6570\u7684\u8bc4\u4f30\u6307\u6807\uff0c\u4f8b\u5982 RMSLE\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u5728\u5bf9\u6570\u53d8\u6362\u7684\u76ee\u6807\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u7136\u540e\u5728\u9884\u6d4b\u65f6\u4f7f\u7528\u6307\u6570\u503c\u8f6c\u6362\u56de\u539f\u59cb\u503c\u3002\u8fd9\u5c06\u6709\u52a9\u4e8e\u9488\u5bf9\u6307\u6807\u4f18\u5316\u6a21\u578b\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u7c7b\u6570\u5b57\u7279\u5f81\u90fd\u662f\u57fa\u4e8e\u76f4\u89c9\u521b\u5efa\u7684\u3002\u6ca1\u6709\u516c\u5f0f\u53ef\u5faa\u3002\u5982\u679c\u60a8\u4ece\u4e8b\u7684\u662f\u67d0\u4e00\u884c\u4e1a\uff0c\u60a8\u5c06\u521b\u5efa\u7279\u5b9a\u884c\u4e1a\u7684\u7279\u5f81\u3002 \u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u548c\u6570\u503c\u53d8\u91cf\u65f6\uff0c\u53ef\u80fd\u4f1a\u9047\u5230\u7f3a\u5931\u503c\u3002\u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u4e00\u4e9b\u5904\u7406\u5206\u7c7b\u7279\u5f81\u4e2d\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\uff0c\u4f46\u8fd8\u6709\u66f4\u591a\u65b9\u6cd5\u53ef\u4ee5\u5904\u7406\u7f3a\u5931\u503c/NaN \u503c\u3002\u8fd9\u4e5f\u88ab\u89c6\u4e3a\u7279\u5f81\u5de5\u7a0b\u3002 \u5982\u679c\u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u9047\u5230\u7f3a\u5931\u503c\uff0c\u5c31\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\uff01\u8fd9\u6837\u505a\u867d\u7136\u7b80\u5355\uff0c\u4f46\uff08\u51e0\u4e4e\uff09\u603b\u662f\u6709\u6548\u7684\uff01 
\u5728\u6570\u503c\u6570\u636e\u4e2d\u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u9009\u62e9\u4e00\u4e2a\u5728\u7279\u5b9a\u7279\u5f81\u4e2d\u6ca1\u6709\u51fa\u73b0\u7684\u503c\uff0c\u7136\u540e\u7528\u5b83\u6765\u586b\u8865\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u7279\u5f81\u4e2d\u6ca1\u6709 0\u3002\u8fd9\u662f\u5176\u4e2d\u4e00\u79cd\u65b9\u6cd5\uff0c\u4f46\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\u3002\u5bf9\u4e8e\u6570\u503c\u6570\u636e\u6765\u8bf4\uff0c\u6bd4\u586b\u5145 0 \u66f4\u6709\u6548\u7684\u65b9\u6cd5\u4e4b\u4e00\u662f\u4f7f\u7528\u5e73\u5747\u503c\u8fdb\u884c\u586b\u5145\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8be5\u7279\u5f81\u6240\u6709\u503c\u7684\u4e2d\u4f4d\u6570\u6765\u586b\u5145\uff0c\u6216\u8005\u4f7f\u7528\u6700\u5e38\u89c1\u7684\u503c\u6765\u586b\u5145\u7f3a\u5931\u503c\u3002\u8fd9\u6837\u505a\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002 \u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u9ad8\u7ea7\u65b9\u6cd5\u662f\u4f7f\u7528 K \u8fd1\u90bb\u6cd5 \u3002 \u60a8\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\uff0c\u7136\u540e\u5229\u7528\u67d0\u79cd\u8ddd\u79bb\u5ea6\u91cf\uff08\u4f8b\u5982\u6b27\u6c0f\u8ddd\u79bb\uff09\u627e\u5230\u6700\u8fd1\u7684\u90bb\u5c45\u3002\u7136\u540e\u53d6\u6240\u6709\u8fd1\u90bb\u7684\u5e73\u5747\u503c\u6765\u586b\u8865\u7f3a\u5931\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 KNN \u6765\u586b\u8865\u8fd9\u6837\u7684\u7f3a\u5931\u503c\u3002 \u56fe 10\uff1a\u6709\u7f3a\u5931\u503c\u7684\u4e8c\u7ef4\u6570\u7ec4 \u8ba9\u6211\u4eec\u770b\u770b KNN \u662f\u5982\u4f55\u5904\u7406\u56fe 10 \u6240\u793a\u7684\u7f3a\u5931\u503c\u77e9\u9635\u7684\u3002 import numpy as np from sklearn import impute # \u751f\u6210\u7ef4\u5ea6\u4e3a (10, 6) \u7684\u968f\u673a\u6574\u6570\u77e9\u9635 X\uff0c\u6570\u503c\u8303\u56f4\u5728 1 \u5230 14 \u4e4b\u95f4 X = np . random . randint ( 1 , 15 , ( 10 , 6 )) # \u6570\u636e\u7c7b\u578b\u8f6c\u6362\u4e3a float X = X . 
astype ( float ) # \u5728\u77e9\u9635 X \u4e2d\u968f\u673a\u9009\u62e9 10 \u4e2a\u4f4d\u7f6e\uff0c\u5c06\u8fd9\u4e9b\u4f4d\u7f6e\u7684\u5143\u7d20\u8bbe\u7f6e\u4e3a NaN\uff08\u7f3a\u5931\u503c\uff09 X . ravel ()[ np . random . choice ( X . size , 10 , replace = False )] = np . nan # \u521b\u5efa\u4e00\u4e2a KNNImputer \u5bf9\u8c61 knn_imputer\uff0c\u6307\u5b9a\u90bb\u5c45\u6570\u91cf\u4e3a 2 knn_imputer = impute . KNNImputer ( n_neighbors = 2 ) # # \u4f7f\u7528 knn_imputer \u5bf9\u77e9\u9635 X \u8fdb\u884c\u62df\u5408\u548c\u8f6c\u6362\uff0c\u7528 K-\u6700\u8fd1\u90bb\u65b9\u6cd5\u586b\u8865\u7f3a\u5931\u503c knn_imputer . fit_transform ( X ) \u5982\u56fe 11 \u6240\u793a\uff0c\u5b83\u586b\u5145\u4e86\u4e0a\u8ff0\u77e9\u9635\u3002 \u56fe 11\uff1aKNN\u4f30\u7b97\u7684\u6570\u503c \u53e6\u4e00\u79cd\u5f25\u8865\u5217\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\u662f\u8bad\u7ec3\u56de\u5f52\u6a21\u578b\uff0c\u8bd5\u56fe\u6839\u636e\u5176\u4ed6\u5217\u9884\u6d4b\u67d0\u5217\u7684\u7f3a\u5931\u503c\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u4ece\u6709\u7f3a\u5931\u503c\u7684\u4e00\u5217\u5f00\u59cb\uff0c\u5c06\u8fd9\u4e00\u5217\u4f5c\u4e3a\u65e0\u7f3a\u5931\u503c\u56de\u5f52\u6a21\u578b\u7684\u76ee\u6807\u5217\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u6240\u6709\u5176\u4ed6\u5217\uff0c\u5bf9\u76f8\u5173\u5217\u4e2d\u6ca1\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\uff0c\u7136\u540e\u5c1d\u8bd5\u9884\u6d4b\u4e4b\u524d\u5220\u9664\u7684\u6837\u672c\u7684\u76ee\u6807\u5217\uff08\u540c\u4e00\u5217\uff09\u3002\u8fd9\u6837\uff0c\u57fa\u4e8e\u6a21\u578b\u7684\u4f30\u7b97\u5c31\u4f1a\u66f4\u52a0\u7a33\u5065\u3002 \u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u5bf9\u4e8e\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6ca1\u6709\u5fc5\u8981\u8fdb\u884c\u6570\u503c\u5f52\u4e00\u5316\uff0c\u56e0\u4e3a\u5b83\u4eec\u53ef\u4ee5\u81ea\u884c\u5904\u7406\u3002 
\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u6240\u5c55\u793a\u7684\u53ea\u662f\u521b\u5efa\u4e00\u822c\u7279\u5f81\u7684\u4e00\u4e9b\u65b9\u6cd5\u3002\u73b0\u5728\uff0c\u5047\u8bbe\u60a8\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u9884\u6d4b\u4e0d\u540c\u5546\u54c1\uff08\u6bcf\u5468\u6216\u6bcf\u6708\uff09\u5546\u5e97\u9500\u552e\u989d\u7684\u95ee\u9898\u3002\u60a8\u6709\u5546\u54c1\uff0c\u4e5f\u6709\u5546\u5e97 ID\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u6bcf\u4e2a\u5546\u5e97\u7684\u5546\u54c1\u7b49\u7279\u5f81\u3002\u73b0\u5728\uff0c\u8fd9\u662f\u4e0a\u6587\u6ca1\u6709\u8ba8\u8bba\u7684\u7279\u5f81\u4e4b\u4e00\u3002\u8fd9\u7c7b\u7279\u5f81\u4e0d\u80fd\u4e00\u6982\u800c\u8bba\uff0c\u5b8c\u5168\u6765\u81ea\u4e8e\u9886\u57df\u3001\u6570\u636e\u548c\u4e1a\u52a1\u77e5\u8bc6\u3002\u67e5\u770b\u6570\u636e\uff0c\u627e\u51fa\u9002\u5408\u7684\u7279\u5f81\uff0c\u7136\u540e\u521b\u5efa\u76f8\u5e94\u7684\u7279\u5f81\u3002\u5982\u679c\u60a8\u4f7f\u7528\u7684\u662f\u903b\u8f91\u56de\u5f52\u7b49\u7ebf\u6027\u6a21\u578b\u6216 SVM \u7b49\u6a21\u578b\uff0c\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\u5bf9\u7279\u5f81\u8fdb\u884c\u7f29\u653e\u6216\u5f52\u4e00\u5316\u5904\u7406\u3002\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65e0\u9700\u5bf9\u7279\u5f81\u8fdb\u884c\u4efb\u4f55\u5f52\u4e00\u5316\u5904\u7406\u5373\u53ef\u6b63\u5e38\u5de5\u4f5c\u3002","title":"\u7279\u5f81\u5de5\u7a0b"},{"location":"%E7%89%B9%E5%BE%81%E9%80%89%E6%8B%A9/","text":"\u7279\u5f81\u9009\u62e9 \u5f53\u4f60\u521b\u5efa\u4e86\u6210\u5343\u4e0a\u4e07\u4e2a\u7279\u5f81\u540e\uff0c\u5c31\u8be5\u4ece\u4e2d\u6311\u9009\u51fa\u51e0\u4e2a\u4e86\u3002\u4f46\u662f\uff0c\u6211\u4eec\u7edd\u4e0d\u5e94\u8be5\u521b\u5efa\u6210\u767e\u4e0a\u5343\u4e2a\u65e0\u7528\u7684\u7279\u5f81\u3002\u7279\u5f81\u8fc7\u591a\u4f1a\u5e26\u6765\u4e00\u4e2a\u4f17\u6240\u5468\u77e5\u7684\u95ee\u9898\uff0c\u5373 
\"\u7ef4\u5ea6\u8bc5\u5492\"\u3002\u5982\u679c\u4f60\u6709\u5f88\u591a\u7279\u5f81\uff0c\u4f60\u4e5f\u5fc5\u987b\u6709\u5f88\u591a\u8bad\u7ec3\u6837\u672c\u6765\u6355\u6349\u6240\u6709\u7279\u5f81\u3002\u4ec0\u4e48\u662f \"\u5927\u91cf \"\u5e76\u6ca1\u6709\u6b63\u786e\u7684\u5b9a\u4e49\uff0c\u8fd9\u9700\u8981\u60a8\u901a\u8fc7\u6b63\u786e\u9a8c\u8bc1\u60a8\u7684\u6a21\u578b\u548c\u68c0\u67e5\u8bad\u7ec3\u6a21\u578b\u6240\u9700\u7684\u65f6\u95f4\u6765\u786e\u5b9a\u3002 \u9009\u62e9\u7279\u5f81\u7684\u6700\u7b80\u5355\u65b9\u6cd5\u662f \u5220\u9664\u65b9\u5dee\u975e\u5e38\u5c0f\u7684\u7279\u5f81 \u3002\u5982\u679c\u7279\u5f81\u7684\u65b9\u5dee\u975e\u5e38\u5c0f\uff08\u5373\u975e\u5e38\u63a5\u8fd1\u4e8e 0\uff09\uff0c\u5b83\u4eec\u5c31\u63a5\u8fd1\u4e8e\u5e38\u91cf\uff0c\u56e0\u6b64\u6839\u672c\u4e0d\u4f1a\u7ed9\u4efb\u4f55\u6a21\u578b\u589e\u52a0\u4efb\u4f55\u4ef7\u503c\u3002\u6700\u597d\u7684\u529e\u6cd5\u5c31\u662f\u53bb\u6389\u5b83\u4eec\uff0c\u4ece\u800c\u964d\u4f4e\u590d\u6742\u5ea6\u3002\u8bf7\u6ce8\u610f\uff0c\u65b9\u5dee\u4e5f\u53d6\u51b3\u4e8e\u6570\u636e\u7684\u7f29\u653e\u3002 Scikit-learn \u7684 VarianceThreshold \u5b9e\u73b0\u4e86\u8fd9\u4e00\u70b9\u3002 from sklearn.feature_selection import VarianceThreshold data = ... # \u521b\u5efa VarianceThreshold \u5bf9\u8c61 var_thresh\uff0c\u6307\u5b9a\u65b9\u5dee\u9608\u503c\u4e3a 0.1 var_thresh = VarianceThreshold ( threshold = 0.1 ) # \u4f7f\u7528 var_thresh \u5bf9\u6570\u636e data \u8fdb\u884c\u62df\u5408\u548c\u53d8\u6362\uff0c\u5c06\u65b9\u5dee\u4f4e\u4e8e\u9608\u503c\u7684\u7279\u5f81\u79fb\u9664 transformed_data = var_thresh . 
fit_transform ( data ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u5220\u9664\u76f8\u5173\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u3002\u8981\u8ba1\u7b97\u4e0d\u540c\u6570\u5b57\u7279\u5f81\u4e4b\u95f4\u7684\u76f8\u5173\u6027\uff0c\u53ef\u4ee5\u4f7f\u7528\u76ae\u5c14\u900a\u76f8\u5173\u6027\u3002 import pandas as pd from sklearn.datasets import fetch_california_housing # \u52a0\u8f7d\u6570\u636e data = fetch_california_housing () # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u77e9\u9635 X X = data [ \"data\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u7684\u5217\u540d col_names = data [ \"feature_names\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf y y = data [ \"target\" ] df = pd . DataFrame ( X , columns = col_names ) # \u6dfb\u52a0 MedInc_Sqrt \u5217\uff0c\u662f MedInc \u5217\u4e2d\u6bcf\u4e2a\u5143\u7d20\u8fdb\u884c\u5e73\u65b9\u6839\u8fd0\u7b97\u7684\u7ed3\u679c df . loc [:, \"MedInc_Sqrt\" ] = df . MedInc . apply ( np . sqrt ) # \u8ba1\u7b97\u76ae\u5c14\u900a\u76f8\u5173\u6027\u77e9\u9635 df . 
corr () \u5f97\u51fa\u76f8\u5173\u77e9\u9635\uff0c\u5982\u56fe 1 \u6240\u793a\u3002 \u56fe 1\uff1a\u76ae\u5c14\u900a\u76f8\u5173\u77e9\u9635\u6837\u672c \u6211\u4eec\u770b\u5230\uff0cMedInc_Sqrt \u4e0e MedInc \u7684\u76f8\u5173\u6027\u975e\u5e38\u9ad8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5220\u9664\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u4e00\u4e9b \u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5 \u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u53ea\u4e0d\u8fc7\u662f\u9488\u5bf9\u7ed9\u5b9a\u76ee\u6807\u5bf9\u6bcf\u4e2a\u7279\u5f81\u8fdb\u884c\u8bc4\u5206\u3002 \u4e92\u4fe1\u606f \u3001 \u65b9\u5dee\u5206\u6790 F \u68c0\u9a8c\u548c chi2 \u662f\u4e00\u4e9b\u6700\u5e38\u7528\u7684\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u3002\u5728 scikit- learn \u4e2d\uff0c\u6709\u4e24\u79cd\u65b9\u6cd5\u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e9b\u65b9\u6cd5\u3002 - SelectKBest\uff1a\u4fdd\u7559\u5f97\u5206\u6700\u9ad8\u7684 k \u4e2a\u7279\u5f81 - SelectPercentile\uff1a\u4fdd\u7559\u7528\u6237\u6307\u5b9a\u767e\u5206\u6bd4\u5185\u7684\u9876\u7ea7\u7279\u5f81\u3002 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u53ea\u6709\u975e\u8d1f\u6570\u636e\u624d\u80fd\u4f7f\u7528 chi2\u3002\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u5f53\u6211\u4eec\u6709\u4e00\u4e9b\u5355\u8bcd\u6216\u57fa\u4e8e tf-idf \u7684\u7279\u5f81\u65f6\uff0c\u8fd9\u662f\u4e00\u79cd\u7279\u522b\u6709\u7528\u7684\u7279\u5f81\u9009\u62e9\u6280\u672f\u3002\u6700\u597d\u4e3a\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u521b\u5efa\u4e00\u4e2a\u5305\u88c5\u5668\uff0c\u51e0\u4e4e\u53ef\u4ee5\u7528\u4e8e\u4efb\u4f55\u65b0\u95ee\u9898\u3002 from sklearn.feature_selection import chi2 from sklearn.feature_selection import f_classif from sklearn.feature_selection import f_regression from sklearn.feature_selection import mutual_info_classif from sklearn.feature_selection import mutual_info_regression from sklearn.feature_selection import SelectKBest 
from sklearn.feature_selection import SelectPercentile class UnivariateFeatureSelction : def __init__ ( self , n_features , problem_type , scoring ): # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u5206\u7c7b\u95ee\u9898 if problem_type == \"classification\" : # \u521b\u5efa\u5b57\u5178 valid_scoring \uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_classif\" : f_classif , \"chi2\" : chi2 , \"mutual_info_classif\" : mutual_info_classif } # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u56de\u5f52\u95ee\u9898 else : # \u521b\u5efa\u5b57\u5178 valid_scoring\uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_regression\" : f_regression , \"mutual_info_regression\" : mutual_info_regression } # \u68c0\u67e5\u7279\u5f81\u91cd\u8981\u6027\u65b9\u5f0f\u662f\u5426\u5728\u5b57\u5178\u4e2d if scoring not in valid_scoring : raise Exception ( \"Invalid scoring function\" ) # \u68c0\u67e5 n_features \u7684\u7c7b\u578b\uff0c\u5982\u679c\u662f\u6574\u6570\uff0c\u5219\u4f7f\u7528 SelectKBest \u8fdb\u884c\u7279\u5f81\u9009\u62e9 if isinstance ( n_features , int ): self . selection = SelectKBest ( valid_scoring [ scoring ], k = n_features ) # \u5982\u679c n_features \u662f\u6d6e\u70b9\u6570\uff0c\u5219\u4f7f\u7528 SelectPercentile \u8fdb\u884c\u7279\u5f81\u9009\u62e9 elif isinstance ( n_features , float ): self . selection = SelectPercentile ( valid_scoring [ scoring ], percentile = int ( n_features * 100 ) ) # \u5982\u679c n_features \u7c7b\u578b\u65e0\u6548\uff0c\u5f15\u53d1\u5f02\u5e38 else : raise Exception ( \"Invalid type of feature\" ) # \u5b9a\u4e49 fit \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 def fit ( self , X , y ): return self . selection . fit ( X , y ) # \u5b9a\u4e49 transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u5bf9\u6570\u636e\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def transform ( self , X ): return self . selection . 
transform ( X ) # \u5b9a\u4e49 fit_transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668\u5e76\u540c\u65f6\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def fit_transform ( self , X , y ): return self . selection . fit_transform ( X , y ) \u4f7f\u7528\u8be5\u7c7b\u975e\u5e38\u7b80\u5355\u3002 # \u5b9e\u4f8b\u5316\u7279\u5f81\u9009\u62e9\u5668\uff0c\u4fdd\u7559\u524d10%\u7684\u7279\u5f81\uff0c\u56de\u5f52\u95ee\u9898\uff0c\u4f7f\u7528f_regression\u8861\u91cf\u7279\u5f81\u91cd\u8981\u6027 ufs = UnivariateFeatureSelction ( n_features = 0.1 , problem_type = \"regression\" , scoring = \"f_regression\" ) # \u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 ufs . fit ( X , y ) # \u7279\u5f81\u8f6c\u6362 X_transformed = ufs . transform ( X ) \u8fd9\u6837\u5c31\u80fd\u6ee1\u8db3\u5927\u90e8\u5206\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u7684\u9700\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u521b\u5efa\u8f83\u5c11\u800c\u91cd\u8981\u7684\u7279\u5f81\u901a\u5e38\u6bd4\u521b\u5efa\u6570\u4ee5\u767e\u8ba1\u7684\u7279\u5f81\u8981\u597d\u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u4e0d\u4e00\u5b9a\u603b\u662f\u8868\u73b0\u826f\u597d\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u4eba\u4eec\u66f4\u559c\u6b22\u4f7f\u7528\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 \u4f7f\u7528\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u7684\u6700\u7b80\u5355\u5f62\u5f0f\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u3002\u5728\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u4e2d\uff0c\u7b2c\u4e00\u6b65\u662f\u9009\u62e9\u4e00\u4e2a\u6a21\u578b\u3002\u7b2c\u4e8c\u6b65\u662f\u9009\u62e9\u635f\u5931/\u8bc4\u5206\u51fd\u6570\u3002\u7b2c\u4e09\u6b65\u4e5f\u662f\u6700\u540e\u4e00\u6b65\u662f\u53cd\u590d\u8bc4\u4f30\u6bcf\u4e2a\u7279\u5f81\uff0c\u5982\u679c\u80fd\u63d0\u9ad8\u635f\u5931/\u8bc4\u5206\uff0c\u5c31\u5c06\u5176\u6dfb\u52a0\u5230 \"\u597d 
\"\u7279\u5f81\u5217\u8868\u4e2d\u3002\u6ca1\u6709\u6bd4\u8fd9\u66f4\u7b80\u5355\u7684\u4e86\u3002\u4f46\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u8fd9\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u662f\u6709\u539f\u56e0\u7684\u3002\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u8fc7\u7a0b\u5728\u6bcf\u6b21\u8bc4\u4f30\u7279\u5f81\u65f6\u90fd\u4f1a\u9002\u5408\u7ed9\u5b9a\u7684\u6a21\u578b\u3002\u8fd9\u79cd\u65b9\u6cd5\u7684\u8ba1\u7b97\u6210\u672c\u975e\u5e38\u9ad8\u3002\u5b8c\u6210\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u4e5f\u9700\u8981\u5927\u91cf\u65f6\u95f4\u3002\u5982\u679c\u4e0d\u6b63\u786e\u4f7f\u7528\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\uff0c\u751a\u81f3\u4f1a\u5bfc\u81f4\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn.datasets import make_classification class GreedyFeatureSelection : # \u5b9a\u4e49\u8bc4\u4f30\u5206\u6570\u7684\u65b9\u6cd5\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd def evaluate_score ( self , X , y ): # \u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u8bad\u7ec3\u6a21\u578b model . fit ( X , y ) # \u9884\u6d4b\u6982\u7387\u503c predictions = model . predict_proba ( X )[:, 1 ] # \u8ba1\u7b97 AUC \u5206\u6570 auc = metrics . roc_auc_score ( y , predictions ) return auc # \u7279\u5f81\u9009\u62e9\u51fd\u6570 def _feature_selection ( self , X , y ): # \u521d\u59cb\u5316\u7a7a\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6700\u4f73\u7279\u5f81\u548c\u6700\u4f73\u5206\u6570 good_features = [] best_scores = [] # \u83b7\u53d6\u7279\u5f81\u6570\u91cf num_features = X . 
shape [ 1 ] # \u5f00\u59cb\u7279\u5f81\u9009\u62e9\u7684\u5faa\u73af while True : this_feature = None best_score = 0 # \u904d\u5386\u6bcf\u4e2a\u7279\u5f81 for feature in range ( num_features ): if feature in good_features : continue selected_features = good_features + [ feature ] xtrain = X [:, selected_features ] score = self . evaluate_score ( xtrain , y ) # \u5982\u679c\u5f53\u524d\u7279\u5f81\u7684\u5f97\u5206\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u5f97\u5206\uff0c\u5219\u66f4\u65b0 if score > best_score : this_feature = feature best_score = score # \u82e5\u627e\u5230\u4e86\u65b0\u7684\u6700\u4f73\u7279\u5f81 if this_feature != None : # \u7279\u5f81\u6dfb\u52a0\u5230 good_features \u5217\u8868 good_features . append ( this_feature ) # \u5f97\u5206\u6dfb\u52a0\u5230 best_scores \u5217\u8868 best_scores . append ( best_score ) # \u5982\u679c best_scores \u5217\u8868\u957f\u5ea6\u5927\u4e8e2\uff0c\u5e76\u4e14\u6700\u540e\u4e24\u4e2a\u5f97\u5206\u76f8\u6bd4\u8f83\u5dee\uff0c\u5219\u7ed3\u675f\u5faa\u73af if len ( best_scores ) > 2 : if best_scores [ - 1 ] < best_scores [ - 2 ]: break # \u8fd4\u56de\u6700\u4f73\u7279\u5f81\u7684\u5f97\u5206\u5217\u8868\u548c\u6700\u4f73\u7279\u5f81\u5217\u8868 return best_scores [: - 1 ], good_features [: - 1 ] # \u5b9a\u4e49\u7c7b\u7684\u8c03\u7528\u65b9\u6cd5\uff0c\u7528\u4e8e\u6267\u884c\u7279\u5f81\u9009\u62e9 def __call__ ( self , X , y ): scores , features = self . 
_feature_selection ( X , y ) return X [:, features ], scores if __name__ == \"__main__\" : # \u751f\u6210\u4e00\u4e2a\u793a\u4f8b\u7684\u5206\u7c7b\u6570\u636e\u96c6 X \u548c\u6807\u7b7e y X , y = make_classification ( n_samples = 1000 , n_features = 100 ) # \u5b9e\u4f8b\u5316 GreedyFeatureSelection \u7c7b\uff0c\u5e76\u4f7f\u7528 __call__ \u65b9\u6cd5\u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed , scores = GreedyFeatureSelection ()( X , y ) \u8fd9\u79cd\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u4f1a\u8fd4\u56de\u5206\u6570\u548c\u7279\u5f81\u7d22\u5f15\u5217\u8868\u3002\u56fe 2 \u663e\u793a\u4e86\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u589e\u52a0\u4e00\u4e2a\u65b0\u7279\u5f81\u540e\uff0c\u5206\u6570\u662f\u5982\u4f55\u63d0\u9ad8\u7684\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u67d0\u4e00\u70b9\u4e4b\u540e\uff0c\u6211\u4eec\u5c31\u65e0\u6cd5\u63d0\u9ad8\u5206\u6570\u4e86\uff0c\u8fd9\u5c31\u662f\u6211\u4eec\u505c\u6b62\u7684\u5730\u65b9\u3002 \u53e6\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\u88ab\u79f0\u4e3a\u9012\u5f52\u7279\u5f81\u6d88\u9664\u6cd5\uff08RFE\uff09\u3002\u5728\u524d\u4e00\u79cd\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a\u7279\u5f81\u5f00\u59cb\uff0c\u7136\u540e\u4e0d\u65ad\u6dfb\u52a0\u65b0\u7684\u7279\u5f81\uff0c\u4f46\u5728 RFE 
\u4e2d\uff0c\u6211\u4eec\u4ece\u6240\u6709\u7279\u5f81\u5f00\u59cb\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u4e0d\u65ad\u53bb\u9664\u4e00\u4e2a\u5bf9\u7ed9\u5b9a\u6a21\u578b\u63d0\u4f9b\u6700\u5c0f\u503c\u7684\u7279\u5f81\u3002\u4f46\u6211\u4eec\u5982\u4f55\u77e5\u9053\u54ea\u4e2a\u7279\u5f81\u7684\u4ef7\u503c\u6700\u5c0f\u5462\uff1f\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7ebf\u6027\u652f\u6301\u5411\u91cf\u673a\uff08SVM\uff09\u6216\u903b\u8f91\u56de\u5f52\u7b49\u6a21\u578b\uff0c\u6211\u4eec\u4f1a\u4e3a\u6bcf\u4e2a\u7279\u5f81\u5f97\u5230\u4e00\u4e2a\u7cfb\u6570\uff0c\u8be5\u7cfb\u6570\u51b3\u5b9a\u4e86\u7279\u5f81\u7684\u91cd\u8981\u6027\u3002\u800c\u5bf9\u4e8e\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f\u7279\u5f81\u91cd\u8981\u6027\uff0c\u800c\u4e0d\u662f\u7cfb\u6570\u3002\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u53ef\u4ee5\u5254\u9664\u6700\u4e0d\u91cd\u8981\u7684\u7279\u5f81\uff0c\u76f4\u5230\u8fbe\u5230\u6240\u9700\u7684\u7279\u5f81\u6570\u91cf\u4e3a\u6b62\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u51b3\u5b9a\u8981\u4fdd\u7559\u591a\u5c11\u7279\u5f81\u3002 \u56fe 2\uff1a\u589e\u52a0\u65b0\u7279\u5f81\u540e\uff0c\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7684 AUC \u5206\u6570\u5982\u4f55\u53d8\u5316 \u5f53\u6211\u4eec\u8fdb\u884c\u9012\u5f52\u7279\u5f81\u5254\u9664\u65f6\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u4f1a\u5254\u9664\u7279\u5f81\u91cd\u8981\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u6216\u7cfb\u6570\u63a5\u8fd1 0 
\u7684\u7279\u5f81\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5f53\u4f60\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u8fd9\u6837\u7684\u6a21\u578b\u8fdb\u884c\u4e8c\u5143\u5206\u7c7b\u65f6\uff0c\u5982\u679c\u7279\u5f81\u5bf9\u6b63\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u6b63\uff0c\u800c\u5982\u679c\u7279\u5f81\u5bf9\u8d1f\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u8d1f\u3002\u4fee\u6539\u6211\u4eec\u7684\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7c7b\uff0c\u521b\u5efa\u4e00\u4e2a\u65b0\u7684\u9012\u5f52\u7279\u5f81\u6d88\u9664\u7c7b\u975e\u5e38\u5bb9\u6613\uff0c\u4f46 scikit-learn \u4e5f\u63d0\u4f9b\u4e86 RFE\u3002\u4e0b\u9762\u7684\u793a\u4f8b\u5c55\u793a\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u7528\u6cd5\u3002 import pandas as pd from sklearn.feature_selection import RFE from sklearn.linear_model import LinearRegression from sklearn.datasets import fetch_california_housing data = fetch_california_housing () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] model = LinearRegression () # \u521b\u5efa RFE\uff08\u9012\u5f52\u7279\u5f81\u6d88\u9664\uff09\uff0c\u6307\u5b9a\u6a21\u578b\u4e3a\u7ebf\u6027\u56de\u5f52\u6a21\u578b\uff0c\u8981\u9009\u62e9\u7684\u7279\u5f81\u6570\u91cf\u4e3a 3 rfe = RFE ( estimator = model , n_features_to_select = 3 ) # \u8bad\u7ec3\u6a21\u578b rfe . fit ( X , y ) # \u4f7f\u7528 RFE \u9009\u62e9\u7684\u7279\u5f81\u8fdb\u884c\u6570\u636e\u8f6c\u6362 X_transformed = rfe . 
transform ( X ) \u6211\u4eec\u770b\u5230\u4e86\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u7684\u4e24\u79cd\u4e0d\u540c\u7684\u8d2a\u5a6a\u65b9\u6cd5\u3002\u4f46\u4e5f\u53ef\u4ee5\u6839\u636e\u6570\u636e\u62df\u5408\u6a21\u578b\uff0c\u7136\u540e\u901a\u8fc7\u7279\u5f81\u7cfb\u6570\u6216\u7279\u5f81\u7684\u91cd\u8981\u6027\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u5982\u679c\u4f7f\u7528\u7cfb\u6570\uff0c\u5219\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u9608\u503c\uff0c\u5982\u679c\u7cfb\u6570\u9ad8\u4e8e\u8be5\u9608\u503c\uff0c\u5219\u53ef\u4ee5\u4fdd\u7559\u8be5\u7279\u5f81\uff0c\u5426\u5219\u5c06\u5176\u5254\u9664\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4ece\u968f\u673a\u68ee\u6797\u8fd9\u6837\u7684\u6a21\u578b\u4e2d\u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u5b9e\u4f8b\u5316\u968f\u673a\u68ee\u6797\u6a21\u578b model = RandomForestRegressor () # \u62df\u5408\u6a21\u578b model . fit ( X , y ) \u968f\u673a\u68ee\u6797\uff08\u6216\u4efb\u4f55\u6a21\u578b\uff09\u7684\u7279\u5f81\u91cd\u8981\u6027\u53ef\u6309\u5982\u4e0b\u65b9\u5f0f\u7ed8\u5236\u3002 # \u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027 importances = model . feature_importances_ # \u964d\u5e8f\u6392\u5217 idxs = np . argsort ( importances ) # \u8bbe\u5b9a\u6807\u9898 plt . title ( 'Feature Importances' ) # \u521b\u5efa\u76f4\u65b9\u56fe plt . barh ( range ( len ( idxs )), importances [ idxs ], align = 'center' ) # y\u8f74\u6807\u7b7e plt . yticks ( range ( len ( idxs )), [ col_names [ i ] for i in idxs ]) # x\u8f74\u6807\u7b7e plt . xlabel ( 'Random Forest Feature Importance' ) plt . 
show () \u7ed3\u679c\u5982\u56fe 3 \u6240\u793a\u3002 \u56fe 3\uff1a\u7279\u5f81\u91cd\u8981\u6027\u56fe \u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u6700\u4f73\u7279\u5f81\u5e76\u4e0d\u662f\u4ec0\u4e48\u65b0\u9c9c\u4e8b\u3002\u60a8\u53ef\u4ee5\u4ece\u4e00\u4e2a\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u53e6\u4e00\u4e2a\u6a21\u578b\u8fdb\u884c\u8bad\u7ec3\u3002\u4f8b\u5982\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u7cfb\u6570\u6765\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u968f\u673a\u68ee\u6797\uff08Random Forest\uff09\u5bf9\u6240\u9009\u7279\u5f81\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\u3002Scikit-learn \u8fd8\u63d0\u4f9b\u4e86 SelectFromModel \u7c7b\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4f60\u76f4\u63a5\u4ece\u7ed9\u5b9a\u7684\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u6839\u636e\u9700\u8981\u6307\u5b9a\u7cfb\u6570\u6216\u7279\u5f81\u91cd\u8981\u6027\u7684\u9608\u503c\uff0c\u4ee5\u53ca\u8981\u9009\u62e9\u7684\u7279\u5f81\u7684\u6700\u5927\u6570\u91cf\u3002 \u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\uff0c\u6211\u4eec\u4f7f\u7528 SelectFromModel \u4e2d\u7684\u9ed8\u8ba4\u53c2\u6570\u6765\u9009\u62e9\u7279\u5f81\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor from sklearn.feature_selection import SelectFromModel data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u521b\u5efa\u968f\u673a\u68ee\u6797\u6a21\u578b\u56de\u5f52\u6a21\u578b model = RandomForestRegressor () # \u521b\u5efa SelectFromModel \u5bf9\u8c61 sfm\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u6a21\u578b\u4f5c\u4e3a\u4f30\u7b97\u5668 sfm = SelectFromModel ( estimator = model ) # \u4f7f\u7528 sfm \u5bf9\u7279\u5f81\u77e9\u9635 X \u548c\u76ee\u6807\u53d8\u91cf y \u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed = sfm . 
fit_transform ( X , y ) # \u83b7\u53d6\u7ecf\u8fc7\u7279\u5f81\u9009\u62e9\u540e\u7684\u7279\u5f81\u63a9\u7801\uff08True \u8868\u793a\u7279\u5f81\u88ab\u9009\u62e9\uff0cFalse \u8868\u793a\u7279\u5f81\u672a\u88ab\u9009\u62e9\uff09 support = sfm . get_support () # \u6253\u5370\u88ab\u9009\u62e9\u7684\u7279\u5f81\u5217\u540d print ([ x for x , y in zip ( col_names , support ) if y == True ]) \u4e0a\u9762\u7a0b\u5e8f\u6253\u5370\u7ed3\u679c\uff1a ['bmi'\uff0c's5']\u3002\u6211\u4eec\u518d\u770b\u56fe 3\uff0c\u5c31\u4f1a\u53d1\u73b0\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u4e2a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u76f4\u63a5\u4ece\u968f\u673a\u68ee\u6797\u63d0\u4f9b\u7684\u7279\u5f81\u91cd\u8981\u6027\u4e2d\u8fdb\u884c\u9009\u62e9\u3002\u6211\u4eec\u8fd8\u7f3a\u5c11\u4e00\u4ef6\u4e8b\uff0c\u90a3\u5c31\u662f\u4f7f\u7528 L1\uff08Lasso\uff09\u60e9\u7f5a\u6a21\u578b \u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u5f53\u6211\u4eec\u4f7f\u7528 L1 \u60e9\u7f5a\u8fdb\u884c\u6b63\u5219\u5316\u65f6\uff0c\u5927\u90e8\u5206\u7cfb\u6570\u90fd\u5c06\u4e3a 0\uff08\u6216\u63a5\u8fd1 0\uff09\uff0c\u56e0\u6b64\u6211\u4eec\u8981\u9009\u62e9\u7cfb\u6570\u4e0d\u4e3a 0 \u7684\u7279\u5f81\u3002\u53ea\u9700\u5c06\u6a21\u578b\u9009\u62e9\u7247\u6bb5\u4e2d\u7684\u968f\u673a\u68ee\u6797\u66ff\u6362\u4e3a\u652f\u6301 L1 \u60e9\u7f5a\u7684\u6a21\u578b\uff08\u5982 lasso \u56de\u5f52\uff09\u5373\u53ef\u3002\u6240\u6709\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u90fd\u63d0\u4f9b\u7279\u5f81\u91cd\u8981\u6027\uff0c\u56e0\u6b64\u672c\u7ae0\u4e2d\u5c55\u793a\u7684\u6240\u6709\u57fa\u4e8e\u6a21\u578b\u7684\u7247\u6bb5\u90fd\u53ef\u7528\u4e8e XGBoost\u3001LightGBM \u6216 
CatBoost\u3002\u7279\u5f81\u91cd\u8981\u6027\u51fd\u6570\u7684\u540d\u79f0\u53ef\u80fd\u4e0d\u540c\uff0c\u4ea7\u751f\u7ed3\u679c\u7684\u683c\u5f0f\u4e5f\u53ef\u80fd\u4e0d\u540c\uff0c\u4f46\u7528\u6cd5\u662f\u4e00\u6837\u7684\u3002\u6700\u540e\uff0c\u5728\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u65f6\u5fc5\u987b\u5c0f\u5fc3\u8c28\u614e\u3002\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u9009\u62e9\u7279\u5f81\uff0c\u5e76\u5728\u9a8c\u8bc1\u6570\u636e\u4e0a\u9a8c\u8bc1\u6a21\u578b\uff0c\u4ee5\u4fbf\u5728\u4e0d\u8fc7\u5ea6\u62df\u5408\u6a21\u578b\u7684\u60c5\u51b5\u4e0b\u6b63\u786e\u9009\u62e9\u7279\u5f81\u3002","title":"\u7279\u5f81\u9009\u62e9"},{"location":"%E7%89%B9%E5%BE%81%E9%80%89%E6%8B%A9/#_1","text":"\u5f53\u4f60\u521b\u5efa\u4e86\u6210\u5343\u4e0a\u4e07\u4e2a\u7279\u5f81\u540e\uff0c\u5c31\u8be5\u4ece\u4e2d\u6311\u9009\u51fa\u51e0\u4e2a\u4e86\u3002\u4f46\u662f\uff0c\u6211\u4eec\u7edd\u4e0d\u5e94\u8be5\u521b\u5efa\u6210\u767e\u4e0a\u5343\u4e2a\u65e0\u7528\u7684\u7279\u5f81\u3002\u7279\u5f81\u8fc7\u591a\u4f1a\u5e26\u6765\u4e00\u4e2a\u4f17\u6240\u5468\u77e5\u7684\u95ee\u9898\uff0c\u5373 \"\u7ef4\u5ea6\u8bc5\u5492\"\u3002\u5982\u679c\u4f60\u6709\u5f88\u591a\u7279\u5f81\uff0c\u4f60\u4e5f\u5fc5\u987b\u6709\u5f88\u591a\u8bad\u7ec3\u6837\u672c\u6765\u6355\u6349\u6240\u6709\u7279\u5f81\u3002\u4ec0\u4e48\u662f \"\u5927\u91cf \"\u5e76\u6ca1\u6709\u6b63\u786e\u7684\u5b9a\u4e49\uff0c\u8fd9\u9700\u8981\u60a8\u901a\u8fc7\u6b63\u786e\u9a8c\u8bc1\u60a8\u7684\u6a21\u578b\u548c\u68c0\u67e5\u8bad\u7ec3\u6a21\u578b\u6240\u9700\u7684\u65f6\u95f4\u6765\u786e\u5b9a\u3002 \u9009\u62e9\u7279\u5f81\u7684\u6700\u7b80\u5355\u65b9\u6cd5\u662f \u5220\u9664\u65b9\u5dee\u975e\u5e38\u5c0f\u7684\u7279\u5f81 \u3002\u5982\u679c\u7279\u5f81\u7684\u65b9\u5dee\u975e\u5e38\u5c0f\uff08\u5373\u975e\u5e38\u63a5\u8fd1\u4e8e 
0\uff09\uff0c\u5b83\u4eec\u5c31\u63a5\u8fd1\u4e8e\u5e38\u91cf\uff0c\u56e0\u6b64\u6839\u672c\u4e0d\u4f1a\u7ed9\u4efb\u4f55\u6a21\u578b\u589e\u52a0\u4efb\u4f55\u4ef7\u503c\u3002\u6700\u597d\u7684\u529e\u6cd5\u5c31\u662f\u53bb\u6389\u5b83\u4eec\uff0c\u4ece\u800c\u964d\u4f4e\u590d\u6742\u5ea6\u3002\u8bf7\u6ce8\u610f\uff0c\u65b9\u5dee\u4e5f\u53d6\u51b3\u4e8e\u6570\u636e\u7684\u7f29\u653e\u3002 Scikit-learn \u7684 VarianceThreshold \u5b9e\u73b0\u4e86\u8fd9\u4e00\u70b9\u3002 from sklearn.feature_selection import VarianceThreshold data = ... # \u521b\u5efa VarianceThreshold \u5bf9\u8c61 var_thresh\uff0c\u6307\u5b9a\u65b9\u5dee\u9608\u503c\u4e3a 0.1 var_thresh = VarianceThreshold ( threshold = 0.1 ) # \u4f7f\u7528 var_thresh \u5bf9\u6570\u636e data \u8fdb\u884c\u62df\u5408\u548c\u53d8\u6362\uff0c\u5c06\u65b9\u5dee\u4f4e\u4e8e\u9608\u503c\u7684\u7279\u5f81\u79fb\u9664 transformed_data = var_thresh . fit_transform ( data ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u5220\u9664\u76f8\u5173\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u3002\u8981\u8ba1\u7b97\u4e0d\u540c\u6570\u5b57\u7279\u5f81\u4e4b\u95f4\u7684\u76f8\u5173\u6027\uff0c\u53ef\u4ee5\u4f7f\u7528\u76ae\u5c14\u900a\u76f8\u5173\u6027\u3002 import pandas as pd from sklearn.datasets import fetch_california_housing # \u52a0\u8f7d\u6570\u636e data = fetch_california_housing () # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u77e9\u9635 X X = data [ \"data\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u7684\u5217\u540d col_names = data [ \"feature_names\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf y y = data [ \"target\" ] df = pd . DataFrame ( X , columns = col_names ) # \u6dfb\u52a0 MedInc_Sqrt \u5217\uff0c\u662f MedInc \u5217\u4e2d\u6bcf\u4e2a\u5143\u7d20\u8fdb\u884c\u5e73\u65b9\u6839\u8fd0\u7b97\u7684\u7ed3\u679c df . loc [:, \"MedInc_Sqrt\" ] = df . MedInc . apply ( np . sqrt ) # \u8ba1\u7b97\u76ae\u5c14\u900a\u76f8\u5173\u6027\u77e9\u9635 df . 
corr () \u5f97\u51fa\u76f8\u5173\u77e9\u9635\uff0c\u5982\u56fe 1 \u6240\u793a\u3002 \u56fe 1\uff1a\u76ae\u5c14\u900a\u76f8\u5173\u77e9\u9635\u6837\u672c \u6211\u4eec\u770b\u5230\uff0cMedInc_Sqrt \u4e0e MedInc \u7684\u76f8\u5173\u6027\u975e\u5e38\u9ad8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5220\u9664\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u4e00\u4e9b \u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5 \u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u53ea\u4e0d\u8fc7\u662f\u9488\u5bf9\u7ed9\u5b9a\u76ee\u6807\u5bf9\u6bcf\u4e2a\u7279\u5f81\u8fdb\u884c\u8bc4\u5206\u3002 \u4e92\u4fe1\u606f \u3001 \u65b9\u5dee\u5206\u6790 F \u68c0\u9a8c\u548c chi2 \u662f\u4e00\u4e9b\u6700\u5e38\u7528\u7684\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u3002\u5728 scikit- learn \u4e2d\uff0c\u6709\u4e24\u79cd\u65b9\u6cd5\u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e9b\u65b9\u6cd5\u3002 - SelectKBest\uff1a\u4fdd\u7559\u5f97\u5206\u6700\u9ad8\u7684 k \u4e2a\u7279\u5f81 - SelectPercentile\uff1a\u4fdd\u7559\u7528\u6237\u6307\u5b9a\u767e\u5206\u6bd4\u5185\u7684\u9876\u7ea7\u7279\u5f81\u3002 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u53ea\u6709\u975e\u8d1f\u6570\u636e\u624d\u80fd\u4f7f\u7528 chi2\u3002\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u5f53\u6211\u4eec\u6709\u4e00\u4e9b\u5355\u8bcd\u6216\u57fa\u4e8e tf-idf \u7684\u7279\u5f81\u65f6\uff0c\u8fd9\u662f\u4e00\u79cd\u7279\u522b\u6709\u7528\u7684\u7279\u5f81\u9009\u62e9\u6280\u672f\u3002\u6700\u597d\u4e3a\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u521b\u5efa\u4e00\u4e2a\u5305\u88c5\u5668\uff0c\u51e0\u4e4e\u53ef\u4ee5\u7528\u4e8e\u4efb\u4f55\u65b0\u95ee\u9898\u3002 from sklearn.feature_selection import chi2 from sklearn.feature_selection import f_classif from sklearn.feature_selection import f_regression from sklearn.feature_selection import mutual_info_classif from sklearn.feature_selection import mutual_info_regression from sklearn.feature_selection import SelectKBest 
from sklearn.feature_selection import SelectPercentile class UnivariateFeatureSelction : def __init__ ( self , n_features , problem_type , scoring ): # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u5206\u7c7b\u95ee\u9898 if problem_type == \"classification\" : # \u521b\u5efa\u5b57\u5178 valid_scoring \uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_classif\" : f_classif , \"chi2\" : chi2 , \"mutual_info_classif\" : mutual_info_classif } # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u56de\u5f52\u95ee\u9898 else : # \u521b\u5efa\u5b57\u5178 valid_scoring\uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_regression\" : f_regression , \"mutual_info_regression\" : mutual_info_regression } # \u68c0\u67e5\u7279\u5f81\u91cd\u8981\u6027\u65b9\u5f0f\u662f\u5426\u5728\u5b57\u5178\u4e2d if scoring not in valid_scoring : raise Exception ( \"Invalid scoring function\" ) # \u68c0\u67e5 n_features \u7684\u7c7b\u578b\uff0c\u5982\u679c\u662f\u6574\u6570\uff0c\u5219\u4f7f\u7528 SelectKBest \u8fdb\u884c\u7279\u5f81\u9009\u62e9 if isinstance ( n_features , int ): self . selection = SelectKBest ( valid_scoring [ scoring ], k = n_features ) # \u5982\u679c n_features \u662f\u6d6e\u70b9\u6570\uff0c\u5219\u4f7f\u7528 SelectPercentile \u8fdb\u884c\u7279\u5f81\u9009\u62e9 elif isinstance ( n_features , float ): self . selection = SelectPercentile ( valid_scoring [ scoring ], percentile = int ( n_features * 100 ) ) # \u5982\u679c n_features \u7c7b\u578b\u65e0\u6548\uff0c\u5f15\u53d1\u5f02\u5e38 else : raise Exception ( \"Invalid type of feature\" ) # \u5b9a\u4e49 fit \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 def fit ( self , X , y ): return self . selection . fit ( X , y ) # \u5b9a\u4e49 transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u5bf9\u6570\u636e\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def transform ( self , X ): return self . selection . 
transform ( X ) # \u5b9a\u4e49 fit_transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668\u5e76\u540c\u65f6\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def fit_transform ( self , X , y ): return self . selection . fit_transform ( X , y ) \u4f7f\u7528\u8be5\u7c7b\u975e\u5e38\u7b80\u5355\u3002 # \u5b9e\u4f8b\u5316\u7279\u5f81\u9009\u62e9\u5668\uff0c\u4fdd\u7559\u524d10%\u7684\u7279\u5f81\uff0c\u56de\u5f52\u95ee\u9898\uff0c\u4f7f\u7528f_regression\u8861\u91cf\u7279\u5f81\u91cd\u8981\u6027 ufs = UnivariateFeatureSelction ( n_features = 0.1 , problem_type = \"regression\" , scoring = \"f_regression\" ) # \u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 ufs . fit ( X , y ) # \u7279\u5f81\u8f6c\u6362 X_transformed = ufs . transform ( X ) \u8fd9\u6837\u5c31\u80fd\u6ee1\u8db3\u5927\u90e8\u5206\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u7684\u9700\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u521b\u5efa\u8f83\u5c11\u800c\u91cd\u8981\u7684\u7279\u5f81\u901a\u5e38\u6bd4\u521b\u5efa\u6570\u4ee5\u767e\u8ba1\u7684\u7279\u5f81\u8981\u597d\u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u4e0d\u4e00\u5b9a\u603b\u662f\u8868\u73b0\u826f\u597d\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u4eba\u4eec\u66f4\u559c\u6b22\u4f7f\u7528\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 \u4f7f\u7528\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u7684\u6700\u7b80\u5355\u5f62\u5f0f\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u3002\u5728\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u4e2d\uff0c\u7b2c\u4e00\u6b65\u662f\u9009\u62e9\u4e00\u4e2a\u6a21\u578b\u3002\u7b2c\u4e8c\u6b65\u662f\u9009\u62e9\u635f\u5931/\u8bc4\u5206\u51fd\u6570\u3002\u7b2c\u4e09\u6b65\u4e5f\u662f\u6700\u540e\u4e00\u6b65\u662f\u53cd\u590d\u8bc4\u4f30\u6bcf\u4e2a\u7279\u5f81\uff0c\u5982\u679c\u80fd\u63d0\u9ad8\u635f\u5931/\u8bc4\u5206\uff0c\u5c31\u5c06\u5176\u6dfb\u52a0\u5230 \"\u597d 
\"\u7279\u5f81\u5217\u8868\u4e2d\u3002\u6ca1\u6709\u6bd4\u8fd9\u66f4\u7b80\u5355\u7684\u4e86\u3002\u4f46\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u8fd9\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u662f\u6709\u539f\u56e0\u7684\u3002\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u8fc7\u7a0b\u5728\u6bcf\u6b21\u8bc4\u4f30\u7279\u5f81\u65f6\u90fd\u4f1a\u9002\u5408\u7ed9\u5b9a\u7684\u6a21\u578b\u3002\u8fd9\u79cd\u65b9\u6cd5\u7684\u8ba1\u7b97\u6210\u672c\u975e\u5e38\u9ad8\u3002\u5b8c\u6210\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u4e5f\u9700\u8981\u5927\u91cf\u65f6\u95f4\u3002\u5982\u679c\u4e0d\u6b63\u786e\u4f7f\u7528\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\uff0c\u751a\u81f3\u4f1a\u5bfc\u81f4\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn.datasets import make_classification class GreedyFeatureSelection : # \u5b9a\u4e49\u8bc4\u4f30\u5206\u6570\u7684\u65b9\u6cd5\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd def evaluate_score ( self , X , y ): # \u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u8bad\u7ec3\u6a21\u578b model . fit ( X , y ) # \u9884\u6d4b\u6982\u7387\u503c predictions = model . predict_proba ( X )[:, 1 ] # \u8ba1\u7b97 AUC \u5206\u6570 auc = metrics . roc_auc_score ( y , predictions ) return auc # \u7279\u5f81\u9009\u62e9\u51fd\u6570 def _feature_selection ( self , X , y ): # \u521d\u59cb\u5316\u7a7a\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6700\u4f73\u7279\u5f81\u548c\u6700\u4f73\u5206\u6570 good_features = [] best_scores = [] # \u83b7\u53d6\u7279\u5f81\u6570\u91cf num_features = X . 
shape [ 1 ] # \u5f00\u59cb\u7279\u5f81\u9009\u62e9\u7684\u5faa\u73af while True : this_feature = None best_score = 0 # \u904d\u5386\u6bcf\u4e2a\u7279\u5f81 for feature in range ( num_features ): if feature in good_features : continue selected_features = good_features + [ feature ] xtrain = X [:, selected_features ] score = self . evaluate_score ( xtrain , y ) # \u5982\u679c\u5f53\u524d\u7279\u5f81\u7684\u5f97\u5206\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u5f97\u5206\uff0c\u5219\u66f4\u65b0 if score > best_score : this_feature = feature best_score = score # \u82e5\u627e\u5230\u4e86\u65b0\u7684\u6700\u4f73\u7279\u5f81 if this_feature != None : # \u7279\u5f81\u6dfb\u52a0\u5230 good_features \u5217\u8868 good_features . append ( this_feature ) # \u5f97\u5206\u6dfb\u52a0\u5230 best_scores \u5217\u8868 best_scores . append ( best_score ) # \u5982\u679c best_scores \u5217\u8868\u957f\u5ea6\u5927\u4e8e2\uff0c\u5e76\u4e14\u6700\u540e\u4e24\u4e2a\u5f97\u5206\u76f8\u6bd4\u8f83\u5dee\uff0c\u5219\u7ed3\u675f\u5faa\u73af if len ( best_scores ) > 2 : if best_scores [ - 1 ] < best_scores [ - 2 ]: break # \u8fd4\u56de\u6700\u4f73\u7279\u5f81\u7684\u5f97\u5206\u5217\u8868\u548c\u6700\u4f73\u7279\u5f81\u5217\u8868 return best_scores [: - 1 ], good_features [: - 1 ] # \u5b9a\u4e49\u7c7b\u7684\u8c03\u7528\u65b9\u6cd5\uff0c\u7528\u4e8e\u6267\u884c\u7279\u5f81\u9009\u62e9 def __call__ ( self , X , y ): scores , features = self . 
_feature_selection ( X , y ) return X [:, features ], scores if __name__ == \"__main__\" : # \u751f\u6210\u4e00\u4e2a\u793a\u4f8b\u7684\u5206\u7c7b\u6570\u636e\u96c6 X \u548c\u6807\u7b7e y X , y = make_classification ( n_samples = 1000 , n_features = 100 ) # \u5b9e\u4f8b\u5316 GreedyFeatureSelection \u7c7b\uff0c\u5e76\u4f7f\u7528 __call__ \u65b9\u6cd5\u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed , scores = GreedyFeatureSelection ()( X , y ) \u8fd9\u79cd\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u4f1a\u8fd4\u56de\u5206\u6570\u548c\u7279\u5f81\u7d22\u5f15\u5217\u8868\u3002\u56fe 2 \u663e\u793a\u4e86\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u589e\u52a0\u4e00\u4e2a\u65b0\u7279\u5f81\u540e\uff0c\u5206\u6570\u662f\u5982\u4f55\u63d0\u9ad8\u7684\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u67d0\u4e00\u70b9\u4e4b\u540e\uff0c\u6211\u4eec\u5c31\u65e0\u6cd5\u63d0\u9ad8\u5206\u6570\u4e86\uff0c\u8fd9\u5c31\u662f\u6211\u4eec\u505c\u6b62\u7684\u5730\u65b9\u3002 \u53e6\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\u88ab\u79f0\u4e3a\u9012\u5f52\u7279\u5f81\u6d88\u9664\u6cd5\uff08RFE\uff09\u3002\u5728\u524d\u4e00\u79cd\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a\u7279\u5f81\u5f00\u59cb\uff0c\u7136\u540e\u4e0d\u65ad\u6dfb\u52a0\u65b0\u7684\u7279\u5f81\uff0c\u4f46\u5728 RFE 
\u4e2d\uff0c\u6211\u4eec\u4ece\u6240\u6709\u7279\u5f81\u5f00\u59cb\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u4e0d\u65ad\u53bb\u9664\u4e00\u4e2a\u5bf9\u7ed9\u5b9a\u6a21\u578b\u63d0\u4f9b\u6700\u5c0f\u503c\u7684\u7279\u5f81\u3002\u4f46\u6211\u4eec\u5982\u4f55\u77e5\u9053\u54ea\u4e2a\u7279\u5f81\u7684\u4ef7\u503c\u6700\u5c0f\u5462\uff1f\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7ebf\u6027\u652f\u6301\u5411\u91cf\u673a\uff08SVM\uff09\u6216\u903b\u8f91\u56de\u5f52\u7b49\u6a21\u578b\uff0c\u6211\u4eec\u4f1a\u4e3a\u6bcf\u4e2a\u7279\u5f81\u5f97\u5230\u4e00\u4e2a\u7cfb\u6570\uff0c\u8be5\u7cfb\u6570\u51b3\u5b9a\u4e86\u7279\u5f81\u7684\u91cd\u8981\u6027\u3002\u800c\u5bf9\u4e8e\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f\u7279\u5f81\u91cd\u8981\u6027\uff0c\u800c\u4e0d\u662f\u7cfb\u6570\u3002\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u53ef\u4ee5\u5254\u9664\u6700\u4e0d\u91cd\u8981\u7684\u7279\u5f81\uff0c\u76f4\u5230\u8fbe\u5230\u6240\u9700\u7684\u7279\u5f81\u6570\u91cf\u4e3a\u6b62\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u51b3\u5b9a\u8981\u4fdd\u7559\u591a\u5c11\u7279\u5f81\u3002 \u56fe 2\uff1a\u589e\u52a0\u65b0\u7279\u5f81\u540e\uff0c\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7684 AUC \u5206\u6570\u5982\u4f55\u53d8\u5316 \u5f53\u6211\u4eec\u8fdb\u884c\u9012\u5f52\u7279\u5f81\u5254\u9664\u65f6\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u4f1a\u5254\u9664\u7279\u5f81\u91cd\u8981\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u6216\u7cfb\u6570\u63a5\u8fd1 0 
\u7684\u7279\u5f81\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5f53\u4f60\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u8fd9\u6837\u7684\u6a21\u578b\u8fdb\u884c\u4e8c\u5143\u5206\u7c7b\u65f6\uff0c\u5982\u679c\u7279\u5f81\u5bf9\u6b63\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u6b63\uff0c\u800c\u5982\u679c\u7279\u5f81\u5bf9\u8d1f\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u8d1f\u3002\u4fee\u6539\u6211\u4eec\u7684\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7c7b\uff0c\u521b\u5efa\u4e00\u4e2a\u65b0\u7684\u9012\u5f52\u7279\u5f81\u6d88\u9664\u7c7b\u975e\u5e38\u5bb9\u6613\uff0c\u4f46 scikit-learn \u4e5f\u63d0\u4f9b\u4e86 RFE\u3002\u4e0b\u9762\u7684\u793a\u4f8b\u5c55\u793a\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u7528\u6cd5\u3002 import pandas as pd from sklearn.feature_selection import RFE from sklearn.linear_model import LinearRegression from sklearn.datasets import fetch_california_housing data = fetch_california_housing () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] model = LinearRegression () # \u521b\u5efa RFE\uff08\u9012\u5f52\u7279\u5f81\u6d88\u9664\uff09\uff0c\u6307\u5b9a\u6a21\u578b\u4e3a\u7ebf\u6027\u56de\u5f52\u6a21\u578b\uff0c\u8981\u9009\u62e9\u7684\u7279\u5f81\u6570\u91cf\u4e3a 3 rfe = RFE ( estimator = model , n_features_to_select = 3 ) # \u8bad\u7ec3\u6a21\u578b rfe . fit ( X , y ) # \u4f7f\u7528 RFE \u9009\u62e9\u7684\u7279\u5f81\u8fdb\u884c\u6570\u636e\u8f6c\u6362 X_transformed = rfe . 
transform ( X ) \u6211\u4eec\u770b\u5230\u4e86\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u7684\u4e24\u79cd\u4e0d\u540c\u7684\u8d2a\u5a6a\u65b9\u6cd5\u3002\u4f46\u4e5f\u53ef\u4ee5\u6839\u636e\u6570\u636e\u62df\u5408\u6a21\u578b\uff0c\u7136\u540e\u901a\u8fc7\u7279\u5f81\u7cfb\u6570\u6216\u7279\u5f81\u7684\u91cd\u8981\u6027\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u5982\u679c\u4f7f\u7528\u7cfb\u6570\uff0c\u5219\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u9608\u503c\uff0c\u5982\u679c\u7cfb\u6570\u9ad8\u4e8e\u8be5\u9608\u503c\uff0c\u5219\u53ef\u4ee5\u4fdd\u7559\u8be5\u7279\u5f81\uff0c\u5426\u5219\u5c06\u5176\u5254\u9664\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4ece\u968f\u673a\u68ee\u6797\u8fd9\u6837\u7684\u6a21\u578b\u4e2d\u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u5b9e\u4f8b\u5316\u968f\u673a\u68ee\u6797\u6a21\u578b model = RandomForestRegressor () # \u62df\u5408\u6a21\u578b model . fit ( X , y ) \u968f\u673a\u68ee\u6797\uff08\u6216\u4efb\u4f55\u6a21\u578b\uff09\u7684\u7279\u5f81\u91cd\u8981\u6027\u53ef\u6309\u5982\u4e0b\u65b9\u5f0f\u7ed8\u5236\u3002 # \u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027 importances = model . feature_importances_ # \u964d\u5e8f\u6392\u5217 idxs = np . argsort ( importances ) # \u8bbe\u5b9a\u6807\u9898 plt . title ( 'Feature Importances' ) # \u521b\u5efa\u76f4\u65b9\u56fe plt . barh ( range ( len ( idxs )), importances [ idxs ], align = 'center' ) # y\u8f74\u6807\u7b7e plt . yticks ( range ( len ( idxs )), [ col_names [ i ] for i in idxs ]) # x\u8f74\u6807\u7b7e plt . xlabel ( 'Random Forest Feature Importance' ) plt . 
show () \u7ed3\u679c\u5982\u56fe 3 \u6240\u793a\u3002 \u56fe 3\uff1a\u7279\u5f81\u91cd\u8981\u6027\u56fe \u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u6700\u4f73\u7279\u5f81\u5e76\u4e0d\u662f\u4ec0\u4e48\u65b0\u9c9c\u4e8b\u3002\u60a8\u53ef\u4ee5\u4ece\u4e00\u4e2a\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u53e6\u4e00\u4e2a\u6a21\u578b\u8fdb\u884c\u8bad\u7ec3\u3002\u4f8b\u5982\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u7cfb\u6570\u6765\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u968f\u673a\u68ee\u6797\uff08Random Forest\uff09\u5bf9\u6240\u9009\u7279\u5f81\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\u3002Scikit-learn \u8fd8\u63d0\u4f9b\u4e86 SelectFromModel \u7c7b\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4f60\u76f4\u63a5\u4ece\u7ed9\u5b9a\u7684\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u6839\u636e\u9700\u8981\u6307\u5b9a\u7cfb\u6570\u6216\u7279\u5f81\u91cd\u8981\u6027\u7684\u9608\u503c\uff0c\u4ee5\u53ca\u8981\u9009\u62e9\u7684\u7279\u5f81\u7684\u6700\u5927\u6570\u91cf\u3002 \u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\uff0c\u6211\u4eec\u4f7f\u7528 SelectFromModel \u4e2d\u7684\u9ed8\u8ba4\u53c2\u6570\u6765\u9009\u62e9\u7279\u5f81\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor from sklearn.feature_selection import SelectFromModel data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u521b\u5efa\u968f\u673a\u68ee\u6797\u6a21\u578b\u56de\u5f52\u6a21\u578b model = RandomForestRegressor () # \u521b\u5efa SelectFromModel \u5bf9\u8c61 sfm\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u6a21\u578b\u4f5c\u4e3a\u4f30\u7b97\u5668 sfm = SelectFromModel ( estimator = model ) # \u4f7f\u7528 sfm \u5bf9\u7279\u5f81\u77e9\u9635 X \u548c\u76ee\u6807\u53d8\u91cf y \u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed = sfm . 
fit_transform ( X , y ) # \u83b7\u53d6\u7ecf\u8fc7\u7279\u5f81\u9009\u62e9\u540e\u7684\u7279\u5f81\u63a9\u7801\uff08True \u8868\u793a\u7279\u5f81\u88ab\u9009\u62e9\uff0cFalse \u8868\u793a\u7279\u5f81\u672a\u88ab\u9009\u62e9\uff09 support = sfm . get_support () # \u6253\u5370\u88ab\u9009\u62e9\u7684\u7279\u5f81\u5217\u540d print ([ x for x , y in zip ( col_names , support ) if y == True ]) \u4e0a\u9762\u7a0b\u5e8f\u6253\u5370\u7ed3\u679c\uff1a ['bmi'\uff0c's5']\u3002\u6211\u4eec\u518d\u770b\u56fe 3\uff0c\u5c31\u4f1a\u53d1\u73b0\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u4e2a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u76f4\u63a5\u4ece\u968f\u673a\u68ee\u6797\u63d0\u4f9b\u7684\u7279\u5f81\u91cd\u8981\u6027\u4e2d\u8fdb\u884c\u9009\u62e9\u3002\u6211\u4eec\u8fd8\u7f3a\u5c11\u4e00\u4ef6\u4e8b\uff0c\u90a3\u5c31\u662f\u4f7f\u7528 L1\uff08Lasso\uff09\u60e9\u7f5a\u6a21\u578b \u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u5f53\u6211\u4eec\u4f7f\u7528 L1 \u60e9\u7f5a\u8fdb\u884c\u6b63\u5219\u5316\u65f6\uff0c\u5927\u90e8\u5206\u7cfb\u6570\u90fd\u5c06\u4e3a 0\uff08\u6216\u63a5\u8fd1 0\uff09\uff0c\u56e0\u6b64\u6211\u4eec\u8981\u9009\u62e9\u7cfb\u6570\u4e0d\u4e3a 0 \u7684\u7279\u5f81\u3002\u53ea\u9700\u5c06\u6a21\u578b\u9009\u62e9\u7247\u6bb5\u4e2d\u7684\u968f\u673a\u68ee\u6797\u66ff\u6362\u4e3a\u652f\u6301 L1 \u60e9\u7f5a\u7684\u6a21\u578b\uff08\u5982 lasso \u56de\u5f52\uff09\u5373\u53ef\u3002\u6240\u6709\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u90fd\u63d0\u4f9b\u7279\u5f81\u91cd\u8981\u6027\uff0c\u56e0\u6b64\u672c\u7ae0\u4e2d\u5c55\u793a\u7684\u6240\u6709\u57fa\u4e8e\u6a21\u578b\u7684\u7247\u6bb5\u90fd\u53ef\u7528\u4e8e XGBoost\u3001LightGBM \u6216 
CatBoost\u3002\u7279\u5f81\u91cd\u8981\u6027\u51fd\u6570\u7684\u540d\u79f0\u53ef\u80fd\u4e0d\u540c\uff0c\u4ea7\u751f\u7ed3\u679c\u7684\u683c\u5f0f\u4e5f\u53ef\u80fd\u4e0d\u540c\uff0c\u4f46\u7528\u6cd5\u662f\u4e00\u6837\u7684\u3002\u6700\u540e\uff0c\u5728\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u65f6\u5fc5\u987b\u5c0f\u5fc3\u8c28\u614e\u3002\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u9009\u62e9\u7279\u5f81\uff0c\u5e76\u5728\u9a8c\u8bc1\u6570\u636e\u4e0a\u9a8c\u8bc1\u6a21\u578b\uff0c\u4ee5\u4fbf\u5728\u4e0d\u8fc7\u5ea6\u62df\u5408\u6a21\u578b\u7684\u60c5\u51b5\u4e0b\u6b63\u786e\u9009\u62e9\u7279\u5f81\u3002","title":"\u7279\u5f81\u9009\u62e9"},{"location":"%E7%BB%84%E5%90%88%E5%92%8C%E5%A0%86%E5%8F%A0%E6%96%B9%E6%B3%95/","text":"\u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5 \u542c\u5230\u4e0a\u9762\u4e24\u4e2a\u8bcd\uff0c\u6211\u4eec\u9996\u5148\u60f3\u5230\u7684\u5c31\u662f\u5728\u7ebf\uff08online\uff09/\u79bb\u7ebf\uff08offline\uff09\u673a\u5668\u5b66\u4e60\u7ade\u8d5b\u3002\u51e0\u5e74\u524d\u662f\u8fd9\u6837\uff0c\u4f46\u73b0\u5728\u968f\u7740\u8ba1\u7b97\u80fd\u529b\u7684\u8fdb\u6b65\u548c\u865a\u62df\u5b9e\u4f8b\u7684\u5ec9\u4ef7\uff0c\u4eba\u4eec\u751a\u81f3\u5f00\u59cb\u5728\u884c\u4e1a\u4e2d\u4f7f\u7528\u7ec4\u5408\u6a21\u578b\uff08ensemble models\uff09\u3002\u4f8b\u5982\uff0c\u90e8\u7f72\u591a\u4e2a\u795e\u7ecf\u7f51\u7edc\u5e76\u5b9e\u65f6\u4e3a\u5b83\u4eec\u63d0\u4f9b\u670d\u52a1\u975e\u5e38\u5bb9\u6613\uff0c\u54cd\u5e94\u65f6\u95f4\u5c0f\u4e8e 500 
\u6beb\u79d2\u3002\u6709\u65f6\uff0c\u4e00\u4e2a\u5e9e\u5927\u7684\u795e\u7ecf\u7f51\u7edc\u6216\u5927\u578b\u6a21\u578b\u4e5f\u53ef\u4ee5\u88ab\u5176\u4ed6\u51e0\u4e2a\u6a21\u578b\u53d6\u4ee3\uff0c\u8fd9\u4e9b\u6a21\u578b\u4f53\u79ef\u5c0f\uff0c\u6027\u80fd\u4e0e\u5927\u578b\u6a21\u578b\u76f8\u4f3c\uff0c\u901f\u5ea6\u5374\u5feb\u4e00\u500d\u3002\u5982\u679c\u662f\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff08\u4e9b\uff09\u6a21\u578b\u5462\uff1f\u6211\u4e2a\u4eba\u66f4\u503e\u5411\u4e8e\u9009\u62e9\u591a\u4e2a\u5c0f\u673a\u578b\uff0c\u5b83\u4eec\u901f\u5ea6\u66f4\u5feb\uff0c\u6027\u80fd\u4e0e\u5927\u673a\u578b\u548c\u6162\u673a\u578b\u76f8\u540c\u3002\u8bf7\u8bb0\u4f4f\uff0c\u8f83\u5c0f\u7684\u578b\u53f7\u4e5f\u66f4\u5bb9\u6613\u548c\u66f4\u5feb\u5730\u8fdb\u884c\u8c03\u6574\u3002 \u7ec4\u5408\uff08ensembling\uff09\u4e0d\u8fc7\u662f\u4e0d\u540c\u6a21\u578b\u7684\u7ec4\u5408\u3002\u6a21\u578b\u53ef\u4ee5\u901a\u8fc7\u9884\u6d4b/\u6982\u7387\u8fdb\u884c\u7ec4\u5408\u3002\u7ec4\u5408\u6a21\u578b\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6c42\u5e73\u5747\u503c\u3002 $$ Ensemble Probabilities = (M1_proba + M2_proba + ... 
+ Mn_Proba)/n $$ \u8fd9\u662f\u6700\u7b80\u5355\u4e5f\u662f\u6700\u6709\u6548\u7684\u7ec4\u5408\u6a21\u578b\u7684\u65b9\u6cd5\u3002\u5728\u7b80\u5355\u5e73\u5747\u6cd5\u4e2d\uff0c\u6240\u6709\u6a21\u578b\u7684\u6743\u91cd\u90fd\u662f\u76f8\u7b49\u7684\u3002\u65e0\u8bba\u91c7\u7528\u54ea\u79cd\u7ec4\u5408\u65b9\u6cd5\uff0c\u60a8\u90fd\u5e94\u8be5\u7262\u8bb0\u4e00\u70b9\uff0c\u90a3\u5c31\u662f\u60a8\u5e94\u8be5\u59cb\u7ec8\u5c06\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b/\u6982\u7387\u7ec4\u5408\u5728\u4e00\u8d77\u3002\u7b80\u5355\u5730\u8bf4\uff0c\u7ec4\u5408\u76f8\u5173\u6027\u4e0d\u9ad8\u7684\u6a21\u578b\u6bd4\u7ec4\u5408\u76f8\u5173\u6027\u5f88\u9ad8\u7684\u6a21\u578b\u6548\u679c\u66f4\u597d\u3002 \u5982\u679c\u6ca1\u6709\u6982\u7387\uff0c\u4e5f\u53ef\u4ee5\u7ec4\u5408\u9884\u6d4b\u3002\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6295\u7968\u3002\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u8fdb\u884c\u591a\u7c7b\u5206\u7c7b\uff0c\u6709\u4e09\u4e2a\u7c7b\u522b\uff1a 0\u30011 \u548c 2\u3002 [0, 0, 1] : \u6700\u9ad8\u7968\u6570\uff1a 0 [0, 1, 2] : \u6700\u9ad8\u7968\u7ea7\uff1a \u65e0\uff08\u968f\u673a\u9009\u62e9\u4e00\u4e2a\uff09 [2, 2, 2] : \u6700\u9ad8\u7968\u6570\uff1a 2 \u4ee5\u4e0b\u7b80\u5355\u51fd\u6570\u53ef\u4ee5\u5b8c\u6210\u8fd9\u4e9b\u7b80\u5355\u64cd\u4f5c\u3002 import numpy as np def mean_predictions ( probas ): # \u8ba1\u7b97\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u6bcf\u884c\u5e73\u5747\u503c return np . mean ( probas , axis = 1 ) def max_voting ( preds ): # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u67e5\u627e\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u7684\u7d22\u5f15 idxs = np . argmax ( preds , axis = 1 ) # \u6839\u636e\u7d22\u5f15\u53d6\u51fa\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u5143\u7d20 return np . 
take_along_axis ( preds , idxs [:, None ], axis = 1 ) \u8bf7\u6ce8\u610f\uff0cprobas \u7684\u6bcf\u4e00\u5217\u90fd\u53ea\u6709\u4e00\u4e2a\u6982\u7387\uff08\u5373\u4e8c\u5143\u5206\u7c7b\uff0c\u901a\u5e38\u4e3a\u7c7b\u522b 1\uff09\u3002\u56e0\u6b64\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u4e00\u4e2a\u65b0\u6a21\u578b\u3002\u540c\u6837\uff0c\u5bf9\u4e8e preds\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u6765\u81ea\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b\u503c\u3002\u8fd9\u4e24\u4e2a\u51fd\u6570\u90fd\u5047\u8bbe\u4e86\u4e00\u4e2a 2 \u7ef4 numpy \u6570\u7ec4\u3002\u60a8\u53ef\u4ee5\u6839\u636e\u81ea\u5df1\u7684\u9700\u6c42\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u6709\u4e00\u4e2a 2 \u7ef4\u6570\u7ec4\uff0c\u5176\u4e2d\u5305\u542b\u6bcf\u4e2a\u6a21\u578b\u7684\u6982\u7387\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u51fd\u6570\u4f1a\u6709\u4e00\u4e9b\u53d8\u5316\u3002 \u53e6\u4e00\u79cd\u7ec4\u5408\u591a\u4e2a\u6a21\u578b\u7684\u65b9\u6cd5\u662f\u901a\u8fc7\u5b83\u4eec\u7684 \u6982\u7387\u6392\u5e8f \u3002\u5f53\u76f8\u5173\u6307\u6807\u662f\u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u65f6\uff0c\u8fd9\u79cd\u7ec4\u5408\u65b9\u5f0f\u975e\u5e38\u6709\u6548\uff0c\u56e0\u4e3a AUC \u5c31\u662f\u5bf9\u6837\u672c\u8fdb\u884c\u6392\u5e8f\u3002 def rank_mean ( probas ): # \u521b\u5efa\u7a7a\u5217\u8868ranked\u5b58\u50a8\u6bcf\u4e2a\u7c7b\u522b\u6982\u7387\u503c\u6392\u540d ranked = [] # \u904d\u5386\u6982\u7387\u503c\u6bcf\u4e00\u5217\uff08\u6bcf\u4e2a\u7c7b\u522b\u7684\u6982\u7387\u503c\uff09 for i in range ( probas . shape [ 1 ]): # \u5f53\u524d\u5217\u6982\u7387\u503c\u6392\u540d\uff0crank_data\u662f\u6392\u540d\u7ed3\u679c rank_data = stats . rankdata ( probas [:, i ]) # \u5c06\u5f53\u524d\u5217\u6392\u540d\u7ed3\u679c\u6dfb\u52a0\u5230ranked\u5217\u8868\u4e2d ranked . append ( rank_data ) # \u5c06ranked\u5217\u8868\u4e2d\u6392\u540d\u7ed3\u679c\u6309\u5217\u5806\u53e0\uff0c\u5f62\u6210\u4e8c\u7ef4\u6570\u7ec4 ranked = np . 
column_stack ( ranked ) # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u8ba1\u7b97\u6837\u672c\u6392\u540d\u5e73\u5747\u503c return np . mean ( ranked , axis = 1 ) \u8bf7\u6ce8\u610f\uff0c\u5728 scipy \u7684 rankdata \u4e2d\uff0c\u7b49\u7ea7\u4ece 1 \u5f00\u59cb\u3002 \u4e3a\u4ec0\u4e48\u8fd9\u7c7b\u96c6\u5408\u6709\u6548\uff1f\u8ba9\u6211\u4eec\u770b\u770b\u56fe 1\u3002 \u56fe 1\uff1a\u4e09\u4eba\u731c\u5927\u8c61\u7684\u8eab\u9ad8 \u56fe 1 \u663e\u793a\uff0c\u5982\u679c\u6709\u4e09\u4e2a\u4eba\u5728\u731c\u5927\u8c61\u7684\u9ad8\u5ea6\uff0c\u90a3\u4e48\u539f\u59cb\u9ad8\u5ea6\u5c06\u975e\u5e38\u63a5\u8fd1\u4e09\u4e2a\u4eba\u731c\u6d4b\u7684\u5e73\u5747\u503c\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u4eba\u90fd\u80fd\u731c\u5230\u975e\u5e38\u63a5\u8fd1\u5927\u8c61\u539f\u6765\u7684\u9ad8\u5ea6\u3002\u63a5\u8fd1\u4f30\u8ba1\u503c\u610f\u5473\u7740\u8bef\u5dee\uff0c\u4f46\u5982\u679c\u6211\u4eec\u5c06\u4e09\u4e2a\u9884\u6d4b\u503c\u5e73\u5747\uff0c\u5c31\u80fd\u5c06\u8bef\u5dee\u964d\u5230\u6700\u4f4e\u3002\u8fd9\u5c31\u662f\u591a\u4e2a\u6a21\u578b\u5e73\u5747\u7684\u4e3b\u8981\u601d\u60f3\u3002 $$ Final\\ Probabilities = w_1 \\times M1_proba + w_2 \\times M2_proba + \\cdots + w_n \\times Mn_proba $$ \u5176\u4e2d \\((w_1 + w_2 + w_3 + \\cdots + w_n)=1.0\\) \u4f8b\u5982\uff0c\u5982\u679c\u4f60\u6709\u4e00\u4e2a AUC \u975e\u5e38\u9ad8\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u548c\u4e00\u4e2a AUC \u7a0d\u4f4e\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u4f60\u53ef\u4ee5\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0c\u968f\u673a\u68ee\u6797\u6a21\u578b\u5360 70%\uff0c\u903b\u8f91\u56de\u5f52\u6a21\u578b\u5360 30%\u3002\u90a3\u4e48\uff0c\u6211\u662f\u5982\u4f55\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u7684\u5462\uff1f\u8ba9\u6211\u4eec\u518d\u6dfb\u52a0\u4e00\u4e2a\u6a21\u578b\uff0c\u5047\u8bbe\u73b0\u5728\u6211\u4eec\u4e5f\u6709\u4e00\u4e2a xgboost \u6a21\u578b\uff0c\u5b83\u7684 AUC 
\u6bd4\u968f\u673a\u68ee\u6797\u9ad8\u3002\u73b0\u5728\uff0c\u6211\u5c06\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0cxgboost\uff1a\u968f\u673a\u68ee\u6797\uff1a\u903b\u8f91\u56de\u5f52\u7684\u6bd4\u4f8b\u4e3a 3:2:1\u3002\u5f88\u7b80\u5355\u5427\uff1f\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u6613\u5982\u53cd\u638c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u662f\u5982\u4f55\u505a\u5230\u7684\u3002 \u5047\u5b9a\u6211\u4eec\u6709\u4e09\u53ea\u7334\u5b50\uff0c\u4e09\u53ea\u65cb\u94ae\u7684\u6570\u503c\u5728 0 \u548c 1 \u4e4b\u95f4\u3002\u8fd9\u4e9b\u7334\u5b50\u8f6c\u52a8\u65cb\u94ae\uff0c\u6211\u4eec\u8ba1\u7b97\u5b83\u4eec\u6bcf\u8f6c\u5230\u4e00\u4e2a\u6570\u503c\u65f6\u7684 AUC \u5206\u6570\u3002\u6700\u7ec8\uff0c\u7334\u5b50\u4eec\u4f1a\u627e\u5230\u4e00\u4e2a\u80fd\u7ed9\u51fa\u6700\u4f73 AUC \u7684\u7ec4\u5408\u3002\u6ca1\u9519\uff0c\u8fd9\u5c31\u662f\u968f\u673a\u641c\u7d22\uff01\u5728\u8fdb\u884c\u8fd9\u7c7b\u641c\u7d22\u4e4b\u524d\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\u4e24\u4e2a\u6700\u91cd\u8981\u7684\u7ec4\u5408\u89c4\u5219\u3002 \u7ec4\u5408\u7684\u7b2c\u4e00\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u7ec4\u5408\u7684\u7b2c\u4e8c\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u662f\u7684\u3002\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u6761\u89c4\u5219\u3002\u7b2c\u4e00\u6b65\u662f\u521b\u5efa\u6298\u53e0\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u5047\u8bbe\u6211\u4eec\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\uff1a\u6298\u53e0 1 \u548c\u6298\u53e0 2\u3002\u8bf7\u6ce8\u610f\uff0c\u8fd9\u6837\u505a\u53ea\u662f\u4e3a\u4e86\u7b80\u5316\u89e3\u91ca\u3002\u5728\u5b9e\u9645\u5e94\u7528\u4e2d\uff0c\u60a8\u5e94\u8be5\u521b\u5efa\u66f4\u591a\u7684\u6298\u53e0\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5728\u6298\u53e0 1 
\u4e0a\u8bad\u7ec3\u968f\u673a\u68ee\u6797\u6a21\u578b\u3001\u903b\u8f91\u56de\u5f52\u6a21\u578b\u548c xgboost \u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 2 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u4e4b\u540e\uff0c\u6211\u4eec\u5728\u6298\u53e0 2 \u4e0a\u4ece\u5934\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 1 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u4e3a\u6240\u6709\u8bad\u7ec3\u6570\u636e\u521b\u5efa\u4e86\u9884\u6d4b\u7ed3\u679c\u3002\u73b0\u5728\uff0c\u4e3a\u4e86\u5408\u5e76\u8fd9\u4e9b\u6a21\u578b\uff0c\u6211\u4eec\u5c06\u6298\u53e0 1 \u548c\u6298\u53e0 1 \u7684\u6240\u6709\u9884\u6d4b\u6570\u636e\u5408\u5e76\u5728\u4e00\u8d77\uff0c\u7136\u540e\u521b\u5efa\u4e00\u4e2a\u4f18\u5316\u51fd\u6570\uff0c\u8bd5\u56fe\u627e\u5230\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4fbf\u9488\u5bf9\u6298\u53e0 2 \u7684\u76ee\u6807\u6700\u5c0f\u5316\u8bef\u5dee\u6216\u6700\u5927\u5316 AUC\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u662f\u7528\u4e09\u4e2a\u6a21\u578b\u7684\u9884\u6d4b\u6982\u7387\u5728\u6298\u53e0 1 \u4e0a\u8bad\u7ec3\u4e00\u4e2a\u4f18\u5316\u6a21\u578b\uff0c\u7136\u540e\u5728\u6298\u53e0 2 \u4e0a\u5bf9\u5176\u8fdb\u884c\u8bc4\u4f30\u3002\u8ba9\u6211\u4eec\u5148\u6765\u770b\u770b\u6211\u4eec\u53ef\u4ee5\u7528\u6765\u627e\u5230\u591a\u4e2a\u6a21\u578b\u7684\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4f18\u5316 AUC\uff08\u6216\u4efb\u4f55\u7c7b\u578b\u7684\u9884\u6d4b\u6307\u6807\u7ec4\u5408\uff09\u7684\u7c7b\u3002 import numpy as np from functools import partial from scipy.optimize import fmin from sklearn import metrics class OptimizeAUC : def __init__ ( self ): # \u521d\u59cb\u5316\u7cfb\u6570 self . coef_ = 0 def _auc ( self , coef , X , y ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u7cfb\u6570 x_coef = X * coef # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8ba1\u7b97AUC\u5206\u6570 auc_score = metrics . 
roc_auc_score ( y , predictions ) # \u8fd4\u56de\u8d1fAUC\u4ee5\u4fbf\u6700\u5c0f\u5316 return - 1.0 * auc_score def fit ( self , X , y ): # \u521b\u5efa\u5e26\u6709\u90e8\u5206\u53c2\u6570\u7684\u76ee\u6807\u51fd\u6570 loss_partial = partial ( self . _auc , X = X , y = y ) # \u521d\u59cb\u5316\u7cfb\u6570 initial_coef = np . random . dirichlet ( np . ones ( X . shape [ 1 ]), size = 1 ) # \u4f7f\u7528fmin\u51fd\u6570\u4f18\u5316AUC\u76ee\u6807\u51fd\u6570\uff0c\u627e\u5230\u6700\u4f18\u7cfb\u6570 self . coef_ = fmin ( loss_partial , initial_coef , disp = True ) def predict ( self , X ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u8bad\u7ec3\u597d\u7684\u7cfb\u6570 x_coef = X * self . coef_ # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8fd4\u56de\u9884\u6d4b\u7ed3\u679c return predictions \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528\u5b83\uff0c\u5e76\u5c06\u5176\u4e0e\u7b80\u5355\u5e73\u5747\u6cd5\u8fdb\u884c\u6bd4\u8f83\u3002 import xgboost as xgb from sklearn.datasets import make_classification from sklearn import ensemble from sklearn import linear_model from sklearn import metrics from sklearn import model_selection # \u751f\u6210\u4e00\u4e2a\u5206\u7c7b\u6570\u636e\u96c6 X , y = make_classification ( n_samples = 10000 , n_features = 25 ) # \u5212\u5206\u6570\u636e\u96c6\u4e3a\u4e24\u4e2a\u4ea4\u53c9\u9a8c\u8bc1\u6298\u53e0 xfold1 , xfold2 , yfold1 , yfold2 = model_selection . train_test_split ( X , y , test_size = 0.5 , stratify = y ) # \u521d\u59cb\u5316\u4e09\u4e2a\u4e0d\u540c\u7684\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold1 , yfold1 ) rf . fit ( xfold1 , yfold1 ) xgbc . 
fit ( xfold1 , yfold1 ) # \u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold2 )[:, 1 ] pred_rf = rf . predict_proba ( xfold2 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold2 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold2_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold2 = [] for i in range ( fold2_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold2 , fold2_preds [:, i ]) aucs_fold2 . append ( auc ) print ( f \"Fold-2: LR AUC = { aucs_fold2 [ 0 ] } \" ) print ( f \"Fold-2: RF AUC = { aucs_fold2 [ 1 ] } \" ) print ( f \"Fold-2: XGB AUC = { aucs_fold2 [ 2 ] } \" ) print ( f \"Fold-2: Average Pred AUC = { aucs_fold2 [ 3 ] } \" ) # \u91cd\u65b0\u521d\u59cb\u5316\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold2 , yfold2 ) rf . fit ( xfold2 , yfold2 ) xgbc . fit ( xfold2 , yfold2 ) # \u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold1 )[:, 1 ] pred_rf = rf . predict_proba ( xfold1 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold1 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold1_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold1 = [] for i in range ( fold1_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold1 , fold1_preds [:, i ]) aucs_fold1 . 
append ( auc ) print ( f \"Fold-1: LR AUC = { aucs_fold1 [ 0 ] } \" ) print ( f \"Fold-1: RF AUC = { aucs_fold1 [ 1 ] } \" ) print ( f \"Fold-1: XGB AUC = { aucs_fold1 [ 2 ] } \" ) print ( f \"Fold-1: Average prediction AUC = { aucs_fold1 [ 3 ] } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765\u8bad\u7ec3\u4f18\u5316\u5668 opt . fit ( fold1_preds [:, : - 1 ], yfold1 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold2 = opt . predict ( fold2_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold2 , opt_preds_fold2 ) print ( f \"Optimized AUC, Fold 2 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765 opt . fit ( fold2_preds [:, : - 1 ], yfold2 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold1 = opt . predict ( fold1_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold1 , opt_preds_fold1 ) print ( f \"Optimized AUC, Fold 1 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) \u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u8f93\u51fa\uff1a \u276f python auc_opt . py Fold - 2 : LR AUC = 0.9145446769443348 Fold - 2 : RF AUC = 0.9269918948683287 Fold - 2 : XGB AUC = 0.9302436595508696 Fold - 2 : Average Pred AUC = 0.927701495890154 Fold - 1 : LR AUC = 0.9050872233256017 Fold - 1 : RF AUC = 0.9179382818311258 Fold - 1 : XGB AUC = 0.9195837242005629 Fold - 1 : Average prediction AUC = 0.9189669233123695 Optimization terminated successfully . 
Current function value : - 0.920643 Iterations : 50 Function evaluations : 109 Optimized AUC , Fold 2 = 0.9305386199756128 Coefficients = [ - 0.00188194 0.19328336 0.35891836 ] Optimization terminated successfully . Current function value : - 0.931232 Iterations : 56 Function evaluations : 113 Optimized AUC , Fold 1 = 0.9192523637234037 Coefficients = [ - 0.15655124 0.22393151 0.58711366 ] \u6211\u4eec\u770b\u5230\uff0c\u5e73\u5747\u503c\u66f4\u597d\uff0c\u4f46\u4f7f\u7528\u4f18\u5316\u5668\u627e\u5230\u9608\u503c\u66f4\u597d\uff01\u6709\u65f6\uff0c\u5e73\u5747\u503c\u662f\u6700\u597d\u7684\u9009\u62e9\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u7cfb\u6570\u52a0\u8d77\u6765\u5e76\u6ca1\u6709\u8fbe\u5230 1.0\uff0c\u4f46\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u6211\u4eec\u8981\u5904\u7406\u7684\u662f AUC\uff0c\u800c AUC \u53ea\u5173\u5fc3\u7b49\u7ea7\u3002 \u5373\u4f7f\u968f\u673a\u68ee\u6797\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u968f\u673a\u68ee\u6797\u53ea\u662f\u8bb8\u591a\u7b80\u5355\u51b3\u7b56\u6811\u7684\u7ec4\u5408\u3002\u968f\u673a\u68ee\u6797\u5c5e\u4e8e\u96c6\u5408\u6a21\u578b\u7684\u4e00\u79cd\uff0c\u4e5f\u5c31\u662f\u4fd7\u79f0\u7684 \"bagging\" \u3002\u5728\u888b\u96c6\u6a21\u578b\u4e2d\uff0c\u6211\u4eec\u521b\u5efa\u5c0f\u6570\u636e\u5b50\u96c6\u5e76\u8bad\u7ec3\u591a\u4e2a\u7b80\u5355\u6a21\u578b\u3002\u6700\u7ec8\u7ed3\u679c\u7531\u6240\u6709\u8fd9\u4e9b\u5c0f\u6a21\u578b\u7684\u9884\u6d4b\u7ed3\u679c\uff08\u5982\u5e73\u5747\u503c\uff09\u7ec4\u5408\u800c\u6210\u3002 \u6211\u4eec\u4f7f\u7528\u7684 xgboost \u6a21\u578b\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u6240\u6709\u68af\u5ea6\u63d0\u5347\u6a21\u578b\u90fd\u662f\u96c6\u5408\u6a21\u578b\uff0c\u7edf\u79f0\u4e3a \u63d0\u5347\u6a21\u578b\uff08boosting models\uff09 
\u3002\u63d0\u5347\u6a21\u578b\u7684\u5de5\u4f5c\u539f\u7406\u4e0e\u88c5\u888b\u6a21\u578b\u7c7b\u4f3c\uff0c\u4e0d\u540c\u4e4b\u5904\u5728\u4e8e\u63d0\u5347\u6a21\u578b\u4e2d\u7684\u8fde\u7eed\u6a21\u578b\u662f\u6839\u636e\u8bef\u5dee\u6b8b\u5dee\u8bad\u7ec3\u7684\uff0c\u5e76\u503e\u5411\u4e8e\u6700\u5c0f\u5316\u524d\u9762\u6a21\u578b\u7684\u8bef\u5dee\u3002\u8fd9\u6837\uff0c\u63d0\u5347\u6a21\u578b\u5c31\u80fd\u5b8c\u7f8e\u5730\u5b66\u4e60\u6570\u636e\uff0c\u56e0\u6b64\u5bb9\u6613\u51fa\u73b0\u8fc7\u62df\u5408\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u770b\u5230\u7684\u4ee3\u7801\u7247\u6bb5\u53ea\u8003\u8651\u4e86\u4e00\u5217\u3002\u4f46\u60c5\u51b5\u5e76\u975e\u603b\u662f\u5982\u6b64\uff0c\u5f88\u591a\u65f6\u5019\u60a8\u9700\u8981\u5904\u7406\u591a\u5217\u9884\u6d4b\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u4f1a\u9047\u5230\u4ece\u591a\u4e2a\u7c7b\u522b\u4e2d\u9884\u6d4b\u4e00\u4e2a\u7c7b\u522b\u7684\u95ee\u9898\uff0c\u5373\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5bf9\u4e8e\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\uff0c\u4f60\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u9009\u62e9\u6295\u7968\u65b9\u6cd5\u3002\u4f46\u6295\u7968\u6cd5\u5e76\u4e0d\u603b\u662f\u6700\u4f73\u65b9\u6cd5\u3002\u5982\u679c\u8981\u7ec4\u5408\u6982\u7387\uff0c\u5c31\u4f1a\u6709\u4e00\u4e2a\u4e8c\u7ef4\u6570\u7ec4\uff0c\u800c\u4e0d\u662f\u50cf\u6211\u4eec\u4e4b\u524d\u4f18\u5316 AUC \u65f6\u7684\u5411\u91cf\u3002\u5982\u679c\u6709\u591a\u4e2a\u7c7b\u522b\uff0c\u53ef\u4ee5\u5c1d\u8bd5\u4f18\u5316\u5bf9\u6570\u635f\u5931\uff08\u6216\u5176\u4ed6\u4e0e\u4e1a\u52a1\u76f8\u5173\u7684\u6307\u6807\uff09\u3002 \u8981\u8fdb\u884c\u7ec4\u5408\uff0c\u53ef\u4ee5\u5728\u62df\u5408\u51fd\u6570 (X) \u4e2d\u4f7f\u7528 numpy \u6570\u7ec4\u5217\u8868\u800c\u4e0d\u662f numpy \u6570\u7ec4\uff0c\u968f\u540e\u8fd8\u9700\u8981\u66f4\u6539\u4f18\u5316\u5668\u548c\u9884\u6d4b\u51fd\u6570\u3002\u6211\u5c31\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u5927\u5bb6\u5427\u3002 
\u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8fdb\u5165\u4e0b\u4e00\u4e2a\u6709\u8da3\u7684\u8bdd\u9898\uff0c\u8fd9\u4e2a\u8bdd\u9898\u76f8\u5f53\u6d41\u884c\uff0c\u88ab\u79f0\u4e3a \u5806\u53e0 \u3002\u56fe 2 \u5c55\u793a\u4e86\u5982\u4f55\u5806\u53e0\u6a21\u578b\u3002 \u56fe2 : Stacking \u5806\u53e0\u4e0d\u50cf\u5236\u9020\u706b\u7bad\u3002\u5b83\u7b80\u5355\u660e\u4e86\u3002\u5982\u679c\u60a8\u8fdb\u884c\u4e86\u6b63\u786e\u7684\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5e76\u5728\u6574\u4e2a\u5efa\u6a21\u8fc7\u7a0b\u4e2d\u4fdd\u6301\u6298\u53e0\u4e0d\u53d8\uff0c\u90a3\u4e48\u5c31\u4e0d\u4f1a\u51fa\u73b0\u4efb\u4f55\u8fc7\u5ea6\u8d34\u5408\u7684\u60c5\u51b5\u3002 \u8ba9\u6211\u7528\u7b80\u5355\u7684\u8981\u70b9\u5411\u4f60\u63cf\u8ff0\u4e00\u4e0b\u8fd9\u4e2a\u60f3\u6cd5\u3002 - \u5c06\u8bad\u7ec3\u6570\u636e\u5206\u6210\u82e5\u5e72\u6298\u53e0\u3002 - \u8bad\u7ec3\u4e00\u5806\u6a21\u578b\uff1a M1\u3001M2.....Mn\u3002 - \u521b\u5efa\u5b8c\u6574\u7684\u8bad\u7ec3\u9884\u6d4b\uff08\u4f7f\u7528\u975e\u6298\u53e0\u8bad\u7ec3\uff09\uff0c\u5e76\u4f7f\u7528\u6240\u6709\u8fd9\u4e9b\u6a21\u578b\u8fdb\u884c\u6d4b\u8bd5\u9884\u6d4b\u3002 - \u76f4\u5230\u8fd9\u91cc\u662f\u7b2c 1 \u5c42 (L1)\u3002 - \u5c06\u8fd9\u4e9b\u6a21\u578b\u7684\u6298\u53e0\u9884\u6d4b\u4f5c\u4e3a\u53e6\u4e00\u4e2a\u6a21\u578b\u7684\u7279\u5f81\u3002\u8fd9\u5c31\u662f\u4e8c\u7ea7\u6a21\u578b\uff08L2\uff09\u3002 - \u4f7f\u7528\u4e0e\u4e4b\u524d\u76f8\u540c\u7684\u6298\u53e0\u6765\u8bad\u7ec3\u8fd9\u4e2a L2 \u6a21\u578b\u3002 - \u73b0\u5728\uff0c\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u521b\u5efa OOF\uff08\u6298\u53e0\u5916\uff09\u9884\u6d4b\u3002 - \u73b0\u5728\u60a8\u5c31\u6709\u4e86\u8bad\u7ec3\u6570\u636e\u7684 L2 \u9884\u6d4b\u548c\u6700\u7ec8\u6d4b\u8bd5\u96c6\u9884\u6d4b\u3002 \u60a8\u53ef\u4ee5\u4e0d\u65ad\u91cd\u590d L1 \u90e8\u5206\uff0c\u4e5f\u53ef\u4ee5\u521b\u5efa\u4efb\u610f\u591a\u7684\u5c42\u6b21\u3002 
\u6709\u65f6\uff0c\u4f60\u8fd8\u4f1a\u9047\u5230\u4e00\u4e2a\u53eb\u6df7\u5408\u7684\u672f\u8bed blending \u3002\u5982\u679c\u4f60\u9047\u5230\u4e86\uff0c\u4e0d\u7528\u592a\u62c5\u5fc3\u3002\u5b83\u53ea\u4e0d\u8fc7\u662f\u7528\u4e00\u4e2a\u4fdd\u7559\u7ec4\u6765\u5806\u53e0\uff0c\u800c\u4e0d\u662f\u591a\u91cd\u6298\u53e0\u3002\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u6211\u5728\u672c\u7ae0\u4e2d\u6240\u63cf\u8ff0\u7684\u5185\u5bb9\u53ef\u4ee5\u5e94\u7528\u4e8e\u4efb\u4f55\u7c7b\u578b\u7684\u95ee\u9898\uff1a\u5206\u7c7b\u3001\u56de\u5f52\u3001\u591a\u6807\u7b7e\u5206\u7c7b\u7b49\u3002","title":"\u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5"},{"location":"%E7%BB%84%E5%90%88%E5%92%8C%E5%A0%86%E5%8F%A0%E6%96%B9%E6%B3%95/#_1","text":"\u542c\u5230\u4e0a\u9762\u4e24\u4e2a\u8bcd\uff0c\u6211\u4eec\u9996\u5148\u60f3\u5230\u7684\u5c31\u662f\u5728\u7ebf\uff08online\uff09/\u79bb\u7ebf\uff08offline\uff09\u673a\u5668\u5b66\u4e60\u7ade\u8d5b\u3002\u51e0\u5e74\u524d\u662f\u8fd9\u6837\uff0c\u4f46\u73b0\u5728\u968f\u7740\u8ba1\u7b97\u80fd\u529b\u7684\u8fdb\u6b65\u548c\u865a\u62df\u5b9e\u4f8b\u7684\u5ec9\u4ef7\uff0c\u4eba\u4eec\u751a\u81f3\u5f00\u59cb\u5728\u884c\u4e1a\u4e2d\u4f7f\u7528\u7ec4\u5408\u6a21\u578b\uff08ensemble models\uff09\u3002\u4f8b\u5982\uff0c\u90e8\u7f72\u591a\u4e2a\u795e\u7ecf\u7f51\u7edc\u5e76\u5b9e\u65f6\u4e3a\u5b83\u4eec\u63d0\u4f9b\u670d\u52a1\u975e\u5e38\u5bb9\u6613\uff0c\u54cd\u5e94\u65f6\u95f4\u5c0f\u4e8e 500 
\u6beb\u79d2\u3002\u6709\u65f6\uff0c\u4e00\u4e2a\u5e9e\u5927\u7684\u795e\u7ecf\u7f51\u7edc\u6216\u5927\u578b\u6a21\u578b\u4e5f\u53ef\u4ee5\u88ab\u5176\u4ed6\u51e0\u4e2a\u6a21\u578b\u53d6\u4ee3\uff0c\u8fd9\u4e9b\u6a21\u578b\u4f53\u79ef\u5c0f\uff0c\u6027\u80fd\u4e0e\u5927\u578b\u6a21\u578b\u76f8\u4f3c\uff0c\u901f\u5ea6\u5374\u5feb\u4e00\u500d\u3002\u5982\u679c\u662f\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff08\u4e9b\uff09\u6a21\u578b\u5462\uff1f\u6211\u4e2a\u4eba\u66f4\u503e\u5411\u4e8e\u9009\u62e9\u591a\u4e2a\u5c0f\u673a\u578b\uff0c\u5b83\u4eec\u901f\u5ea6\u66f4\u5feb\uff0c\u6027\u80fd\u4e0e\u5927\u673a\u578b\u548c\u6162\u673a\u578b\u76f8\u540c\u3002\u8bf7\u8bb0\u4f4f\uff0c\u8f83\u5c0f\u7684\u578b\u53f7\u4e5f\u66f4\u5bb9\u6613\u548c\u66f4\u5feb\u5730\u8fdb\u884c\u8c03\u6574\u3002 \u7ec4\u5408\uff08ensembling\uff09\u4e0d\u8fc7\u662f\u4e0d\u540c\u6a21\u578b\u7684\u7ec4\u5408\u3002\u6a21\u578b\u53ef\u4ee5\u901a\u8fc7\u9884\u6d4b/\u6982\u7387\u8fdb\u884c\u7ec4\u5408\u3002\u7ec4\u5408\u6a21\u578b\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6c42\u5e73\u5747\u503c\u3002 $$ Ensemble Probabilities = (M1_proba + M2_proba + ... 
+ Mn_Proba)/n $$ \u8fd9\u662f\u6700\u7b80\u5355\u4e5f\u662f\u6700\u6709\u6548\u7684\u7ec4\u5408\u6a21\u578b\u7684\u65b9\u6cd5\u3002\u5728\u7b80\u5355\u5e73\u5747\u6cd5\u4e2d\uff0c\u6240\u6709\u6a21\u578b\u7684\u6743\u91cd\u90fd\u662f\u76f8\u7b49\u7684\u3002\u65e0\u8bba\u91c7\u7528\u54ea\u79cd\u7ec4\u5408\u65b9\u6cd5\uff0c\u60a8\u90fd\u5e94\u8be5\u7262\u8bb0\u4e00\u70b9\uff0c\u90a3\u5c31\u662f\u60a8\u5e94\u8be5\u59cb\u7ec8\u5c06\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b/\u6982\u7387\u7ec4\u5408\u5728\u4e00\u8d77\u3002\u7b80\u5355\u5730\u8bf4\uff0c\u7ec4\u5408\u76f8\u5173\u6027\u4e0d\u9ad8\u7684\u6a21\u578b\u6bd4\u7ec4\u5408\u76f8\u5173\u6027\u5f88\u9ad8\u7684\u6a21\u578b\u6548\u679c\u66f4\u597d\u3002 \u5982\u679c\u6ca1\u6709\u6982\u7387\uff0c\u4e5f\u53ef\u4ee5\u7ec4\u5408\u9884\u6d4b\u3002\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6295\u7968\u3002\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u8fdb\u884c\u591a\u7c7b\u5206\u7c7b\uff0c\u6709\u4e09\u4e2a\u7c7b\u522b\uff1a 0\u30011 \u548c 2\u3002 [0, 0, 1] : \u6700\u9ad8\u7968\u6570\uff1a 0 [0, 1, 2] : \u6700\u9ad8\u7968\u7ea7\uff1a \u65e0\uff08\u968f\u673a\u9009\u62e9\u4e00\u4e2a\uff09 [2, 2, 2] : \u6700\u9ad8\u7968\u6570\uff1a 2 \u4ee5\u4e0b\u7b80\u5355\u51fd\u6570\u53ef\u4ee5\u5b8c\u6210\u8fd9\u4e9b\u7b80\u5355\u64cd\u4f5c\u3002 import numpy as np def mean_predictions ( probas ): # \u8ba1\u7b97\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u6bcf\u884c\u5e73\u5747\u503c return np . mean ( probas , axis = 1 ) def max_voting ( preds ): # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u67e5\u627e\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u7684\u7d22\u5f15 idxs = np . argmax ( preds , axis = 1 ) # \u6839\u636e\u7d22\u5f15\u53d6\u51fa\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u5143\u7d20 return np . 
take_along_axis ( preds , idxs [:, None ], axis = 1 ) \u8bf7\u6ce8\u610f\uff0cprobas \u7684\u6bcf\u4e00\u5217\u90fd\u53ea\u6709\u4e00\u4e2a\u6982\u7387\uff08\u5373\u4e8c\u5143\u5206\u7c7b\uff0c\u901a\u5e38\u4e3a\u7c7b\u522b 1\uff09\u3002\u56e0\u6b64\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u4e00\u4e2a\u65b0\u6a21\u578b\u3002\u540c\u6837\uff0c\u5bf9\u4e8e preds\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u6765\u81ea\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b\u503c\u3002\u8fd9\u4e24\u4e2a\u51fd\u6570\u90fd\u5047\u8bbe\u4e86\u4e00\u4e2a 2 \u7ef4 numpy \u6570\u7ec4\u3002\u60a8\u53ef\u4ee5\u6839\u636e\u81ea\u5df1\u7684\u9700\u6c42\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u6709\u4e00\u4e2a 2 \u7ef4\u6570\u7ec4\uff0c\u5176\u4e2d\u5305\u542b\u6bcf\u4e2a\u6a21\u578b\u7684\u6982\u7387\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u51fd\u6570\u4f1a\u6709\u4e00\u4e9b\u53d8\u5316\u3002 \u53e6\u4e00\u79cd\u7ec4\u5408\u591a\u4e2a\u6a21\u578b\u7684\u65b9\u6cd5\u662f\u901a\u8fc7\u5b83\u4eec\u7684 \u6982\u7387\u6392\u5e8f \u3002\u5f53\u76f8\u5173\u6307\u6807\u662f\u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u65f6\uff0c\u8fd9\u79cd\u7ec4\u5408\u65b9\u5f0f\u975e\u5e38\u6709\u6548\uff0c\u56e0\u4e3a AUC \u5c31\u662f\u5bf9\u6837\u672c\u8fdb\u884c\u6392\u5e8f\u3002 def rank_mean ( probas ): # \u521b\u5efa\u7a7a\u5217\u8868ranked\u5b58\u50a8\u6bcf\u4e2a\u7c7b\u522b\u6982\u7387\u503c\u6392\u540d ranked = [] # \u904d\u5386\u6982\u7387\u503c\u6bcf\u4e00\u5217\uff08\u6bcf\u4e2a\u7c7b\u522b\u7684\u6982\u7387\u503c\uff09 for i in range ( probas . shape [ 1 ]): # \u5f53\u524d\u5217\u6982\u7387\u503c\u6392\u540d\uff0crank_data\u662f\u6392\u540d\u7ed3\u679c rank_data = stats . rankdata ( probas [:, i ]) # \u5c06\u5f53\u524d\u5217\u6392\u540d\u7ed3\u679c\u6dfb\u52a0\u5230ranked\u5217\u8868\u4e2d ranked . append ( rank_data ) # \u5c06ranked\u5217\u8868\u4e2d\u6392\u540d\u7ed3\u679c\u6309\u5217\u5806\u53e0\uff0c\u5f62\u6210\u4e8c\u7ef4\u6570\u7ec4 ranked = np . 
column_stack ( ranked ) # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u8ba1\u7b97\u6837\u672c\u6392\u540d\u5e73\u5747\u503c return np . mean ( ranked , axis = 1 ) \u8bf7\u6ce8\u610f\uff0c\u5728 scipy \u7684 rankdata \u4e2d\uff0c\u7b49\u7ea7\u4ece 1 \u5f00\u59cb\u3002 \u4e3a\u4ec0\u4e48\u8fd9\u7c7b\u96c6\u5408\u6709\u6548\uff1f\u8ba9\u6211\u4eec\u770b\u770b\u56fe 1\u3002 \u56fe 1\uff1a\u4e09\u4eba\u731c\u5927\u8c61\u7684\u8eab\u9ad8 \u56fe 1 \u663e\u793a\uff0c\u5982\u679c\u6709\u4e09\u4e2a\u4eba\u5728\u731c\u5927\u8c61\u7684\u9ad8\u5ea6\uff0c\u90a3\u4e48\u539f\u59cb\u9ad8\u5ea6\u5c06\u975e\u5e38\u63a5\u8fd1\u4e09\u4e2a\u4eba\u731c\u6d4b\u7684\u5e73\u5747\u503c\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u4eba\u90fd\u80fd\u731c\u5230\u975e\u5e38\u63a5\u8fd1\u5927\u8c61\u539f\u6765\u7684\u9ad8\u5ea6\u3002\u63a5\u8fd1\u4f30\u8ba1\u503c\u610f\u5473\u7740\u8bef\u5dee\uff0c\u4f46\u5982\u679c\u6211\u4eec\u5c06\u4e09\u4e2a\u9884\u6d4b\u503c\u5e73\u5747\uff0c\u5c31\u80fd\u5c06\u8bef\u5dee\u964d\u5230\u6700\u4f4e\u3002\u8fd9\u5c31\u662f\u591a\u4e2a\u6a21\u578b\u5e73\u5747\u7684\u4e3b\u8981\u601d\u60f3\u3002 $$ Final\\ Probabilities = w_1 \\times M1_proba + w_2 \\times M2_proba + \\cdots + w_n \\times Mn_proba $$ \u5176\u4e2d \\((w_1 + w_2 + w_3 + \\cdots + w_n)=1.0\\) \u4f8b\u5982\uff0c\u5982\u679c\u4f60\u6709\u4e00\u4e2a AUC \u975e\u5e38\u9ad8\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u548c\u4e00\u4e2a AUC \u7a0d\u4f4e\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u4f60\u53ef\u4ee5\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0c\u968f\u673a\u68ee\u6797\u6a21\u578b\u5360 70%\uff0c\u903b\u8f91\u56de\u5f52\u6a21\u578b\u5360 30%\u3002\u90a3\u4e48\uff0c\u6211\u662f\u5982\u4f55\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u7684\u5462\uff1f\u8ba9\u6211\u4eec\u518d\u6dfb\u52a0\u4e00\u4e2a\u6a21\u578b\uff0c\u5047\u8bbe\u73b0\u5728\u6211\u4eec\u4e5f\u6709\u4e00\u4e2a xgboost \u6a21\u578b\uff0c\u5b83\u7684 AUC 
\u6bd4\u968f\u673a\u68ee\u6797\u9ad8\u3002\u73b0\u5728\uff0c\u6211\u5c06\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0cxgboost\uff1a\u968f\u673a\u68ee\u6797\uff1a\u903b\u8f91\u56de\u5f52\u7684\u6bd4\u4f8b\u4e3a 3:2:1\u3002\u5f88\u7b80\u5355\u5427\uff1f\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u6613\u5982\u53cd\u638c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u662f\u5982\u4f55\u505a\u5230\u7684\u3002 \u5047\u5b9a\u6211\u4eec\u6709\u4e09\u53ea\u7334\u5b50\uff0c\u4e09\u53ea\u65cb\u94ae\u7684\u6570\u503c\u5728 0 \u548c 1 \u4e4b\u95f4\u3002\u8fd9\u4e9b\u7334\u5b50\u8f6c\u52a8\u65cb\u94ae\uff0c\u6211\u4eec\u8ba1\u7b97\u5b83\u4eec\u6bcf\u8f6c\u5230\u4e00\u4e2a\u6570\u503c\u65f6\u7684 AUC \u5206\u6570\u3002\u6700\u7ec8\uff0c\u7334\u5b50\u4eec\u4f1a\u627e\u5230\u4e00\u4e2a\u80fd\u7ed9\u51fa\u6700\u4f73 AUC \u7684\u7ec4\u5408\u3002\u6ca1\u9519\uff0c\u8fd9\u5c31\u662f\u968f\u673a\u641c\u7d22\uff01\u5728\u8fdb\u884c\u8fd9\u7c7b\u641c\u7d22\u4e4b\u524d\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\u4e24\u4e2a\u6700\u91cd\u8981\u7684\u7ec4\u5408\u89c4\u5219\u3002 \u7ec4\u5408\u7684\u7b2c\u4e00\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u7ec4\u5408\u7684\u7b2c\u4e8c\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u662f\u7684\u3002\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u6761\u89c4\u5219\u3002\u7b2c\u4e00\u6b65\u662f\u521b\u5efa\u6298\u53e0\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u5047\u8bbe\u6211\u4eec\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\uff1a\u6298\u53e0 1 \u548c\u6298\u53e0 2\u3002\u8bf7\u6ce8\u610f\uff0c\u8fd9\u6837\u505a\u53ea\u662f\u4e3a\u4e86\u7b80\u5316\u89e3\u91ca\u3002\u5728\u5b9e\u9645\u5e94\u7528\u4e2d\uff0c\u60a8\u5e94\u8be5\u521b\u5efa\u66f4\u591a\u7684\u6298\u53e0\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5728\u6298\u53e0 1 
\u4e0a\u8bad\u7ec3\u968f\u673a\u68ee\u6797\u6a21\u578b\u3001\u903b\u8f91\u56de\u5f52\u6a21\u578b\u548c xgboost \u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 2 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u4e4b\u540e\uff0c\u6211\u4eec\u5728\u6298\u53e0 2 \u4e0a\u4ece\u5934\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 1 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u4e3a\u6240\u6709\u8bad\u7ec3\u6570\u636e\u521b\u5efa\u4e86\u9884\u6d4b\u7ed3\u679c\u3002\u73b0\u5728\uff0c\u4e3a\u4e86\u5408\u5e76\u8fd9\u4e9b\u6a21\u578b\uff0c\u6211\u4eec\u5c06\u6298\u53e0 1 \u548c\u6298\u53e0 1 \u7684\u6240\u6709\u9884\u6d4b\u6570\u636e\u5408\u5e76\u5728\u4e00\u8d77\uff0c\u7136\u540e\u521b\u5efa\u4e00\u4e2a\u4f18\u5316\u51fd\u6570\uff0c\u8bd5\u56fe\u627e\u5230\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4fbf\u9488\u5bf9\u6298\u53e0 2 \u7684\u76ee\u6807\u6700\u5c0f\u5316\u8bef\u5dee\u6216\u6700\u5927\u5316 AUC\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u662f\u7528\u4e09\u4e2a\u6a21\u578b\u7684\u9884\u6d4b\u6982\u7387\u5728\u6298\u53e0 1 \u4e0a\u8bad\u7ec3\u4e00\u4e2a\u4f18\u5316\u6a21\u578b\uff0c\u7136\u540e\u5728\u6298\u53e0 2 \u4e0a\u5bf9\u5176\u8fdb\u884c\u8bc4\u4f30\u3002\u8ba9\u6211\u4eec\u5148\u6765\u770b\u770b\u6211\u4eec\u53ef\u4ee5\u7528\u6765\u627e\u5230\u591a\u4e2a\u6a21\u578b\u7684\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4f18\u5316 AUC\uff08\u6216\u4efb\u4f55\u7c7b\u578b\u7684\u9884\u6d4b\u6307\u6807\u7ec4\u5408\uff09\u7684\u7c7b\u3002 import numpy as np from functools import partial from scipy.optimize import fmin from sklearn import metrics class OptimizeAUC : def __init__ ( self ): # \u521d\u59cb\u5316\u7cfb\u6570 self . coef_ = 0 def _auc ( self , coef , X , y ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u7cfb\u6570 x_coef = X * coef # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8ba1\u7b97AUC\u5206\u6570 auc_score = metrics . 
roc_auc_score ( y , predictions ) # \u8fd4\u56de\u8d1fAUC\u4ee5\u4fbf\u6700\u5c0f\u5316 return - 1.0 * auc_score def fit ( self , X , y ): # \u521b\u5efa\u5e26\u6709\u90e8\u5206\u53c2\u6570\u7684\u76ee\u6807\u51fd\u6570 loss_partial = partial ( self . _auc , X = X , y = y ) # \u521d\u59cb\u5316\u7cfb\u6570 initial_coef = np . random . dirichlet ( np . ones ( X . shape [ 1 ]), size = 1 ) # \u4f7f\u7528fmin\u51fd\u6570\u4f18\u5316AUC\u76ee\u6807\u51fd\u6570\uff0c\u627e\u5230\u6700\u4f18\u7cfb\u6570 self . coef_ = fmin ( loss_partial , initial_coef , disp = True ) def predict ( self , X ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u8bad\u7ec3\u597d\u7684\u7cfb\u6570 x_coef = X * self . coef_ # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8fd4\u56de\u9884\u6d4b\u7ed3\u679c return predictions \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528\u5b83\uff0c\u5e76\u5c06\u5176\u4e0e\u7b80\u5355\u5e73\u5747\u6cd5\u8fdb\u884c\u6bd4\u8f83\u3002 import xgboost as xgb from sklearn.datasets import make_classification from sklearn import ensemble from sklearn import linear_model from sklearn import metrics from sklearn import model_selection # \u751f\u6210\u4e00\u4e2a\u5206\u7c7b\u6570\u636e\u96c6 X , y = make_classification ( n_samples = 10000 , n_features = 25 ) # \u5212\u5206\u6570\u636e\u96c6\u4e3a\u4e24\u4e2a\u4ea4\u53c9\u9a8c\u8bc1\u6298\u53e0 xfold1 , xfold2 , yfold1 , yfold2 = model_selection . train_test_split ( X , y , test_size = 0.5 , stratify = y ) # \u521d\u59cb\u5316\u4e09\u4e2a\u4e0d\u540c\u7684\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold1 , yfold1 ) rf . fit ( xfold1 , yfold1 ) xgbc . 
fit ( xfold1 , yfold1 ) # \u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold2 )[:, 1 ] pred_rf = rf . predict_proba ( xfold2 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold2 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold2_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold2 = [] for i in range ( fold2_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold2 , fold2_preds [:, i ]) aucs_fold2 . append ( auc ) print ( f \"Fold-2: LR AUC = { aucs_fold2 [ 0 ] } \" ) print ( f \"Fold-2: RF AUC = { aucs_fold2 [ 1 ] } \" ) print ( f \"Fold-2: XGB AUC = { aucs_fold2 [ 2 ] } \" ) print ( f \"Fold-2: Average Pred AUC = { aucs_fold2 [ 3 ] } \" ) # \u91cd\u65b0\u521d\u59cb\u5316\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold2 , yfold2 ) rf . fit ( xfold2 , yfold2 ) xgbc . fit ( xfold2 , yfold2 ) # \u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold1 )[:, 1 ] pred_rf = rf . predict_proba ( xfold1 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold1 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold1_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold1 = [] for i in range ( fold1_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold1 , fold1_preds [:, i ]) aucs_fold1 . 
append ( auc ) print ( f \"Fold-1: LR AUC = { aucs_fold1 [ 0 ] } \" ) print ( f \"Fold-1: RF AUC = { aucs_fold1 [ 1 ] } \" ) print ( f \"Fold-1: XGB AUC = { aucs_fold1 [ 2 ] } \" ) print ( f \"Fold-1: Average prediction AUC = { aucs_fold1 [ 3 ] } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765\u8bad\u7ec3\u4f18\u5316\u5668 opt . fit ( fold1_preds [:, : - 1 ], yfold1 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold2 = opt . predict ( fold2_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold2 , opt_preds_fold2 ) print ( f \"Optimized AUC, Fold 2 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765 opt . fit ( fold2_preds [:, : - 1 ], yfold2 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold1 = opt . predict ( fold1_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold1 , opt_preds_fold1 ) print ( f \"Optimized AUC, Fold 1 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) \u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u8f93\u51fa\uff1a \u276f python auc_opt . py Fold - 2 : LR AUC = 0.9145446769443348 Fold - 2 : RF AUC = 0.9269918948683287 Fold - 2 : XGB AUC = 0.9302436595508696 Fold - 2 : Average Pred AUC = 0.927701495890154 Fold - 1 : LR AUC = 0.9050872233256017 Fold - 1 : RF AUC = 0.9179382818311258 Fold - 1 : XGB AUC = 0.9195837242005629 Fold - 1 : Average prediction AUC = 0.9189669233123695 Optimization terminated successfully . 
Current function value : - 0.920643 Iterations : 50 Function evaluations : 109 Optimized AUC , Fold 2 = 0.9305386199756128 Coefficients = [ - 0.00188194 0.19328336 0.35891836 ] Optimization terminated successfully . Current function value : - 0.931232 Iterations : 56 Function evaluations : 113 Optimized AUC , Fold 1 = 0.9192523637234037 Coefficients = [ - 0.15655124 0.22393151 0.58711366 ] \u6211\u4eec\u770b\u5230\uff0c\u5e73\u5747\u503c\u66f4\u597d\uff0c\u4f46\u4f7f\u7528\u4f18\u5316\u5668\u627e\u5230\u9608\u503c\u66f4\u597d\uff01\u6709\u65f6\uff0c\u5e73\u5747\u503c\u662f\u6700\u597d\u7684\u9009\u62e9\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u7cfb\u6570\u52a0\u8d77\u6765\u5e76\u6ca1\u6709\u8fbe\u5230 1.0\uff0c\u4f46\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u6211\u4eec\u8981\u5904\u7406\u7684\u662f AUC\uff0c\u800c AUC \u53ea\u5173\u5fc3\u7b49\u7ea7\u3002 \u5373\u4f7f\u968f\u673a\u68ee\u6797\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u968f\u673a\u68ee\u6797\u53ea\u662f\u8bb8\u591a\u7b80\u5355\u51b3\u7b56\u6811\u7684\u7ec4\u5408\u3002\u968f\u673a\u68ee\u6797\u5c5e\u4e8e\u96c6\u5408\u6a21\u578b\u7684\u4e00\u79cd\uff0c\u4e5f\u5c31\u662f\u4fd7\u79f0\u7684 \"bagging\" \u3002\u5728\u888b\u96c6\u6a21\u578b\u4e2d\uff0c\u6211\u4eec\u521b\u5efa\u5c0f\u6570\u636e\u5b50\u96c6\u5e76\u8bad\u7ec3\u591a\u4e2a\u7b80\u5355\u6a21\u578b\u3002\u6700\u7ec8\u7ed3\u679c\u7531\u6240\u6709\u8fd9\u4e9b\u5c0f\u6a21\u578b\u7684\u9884\u6d4b\u7ed3\u679c\uff08\u5982\u5e73\u5747\u503c\uff09\u7ec4\u5408\u800c\u6210\u3002 \u6211\u4eec\u4f7f\u7528\u7684 xgboost \u6a21\u578b\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u6240\u6709\u68af\u5ea6\u63d0\u5347\u6a21\u578b\u90fd\u662f\u96c6\u5408\u6a21\u578b\uff0c\u7edf\u79f0\u4e3a \u63d0\u5347\u6a21\u578b\uff08boosting models\uff09 
\u3002\u63d0\u5347\u6a21\u578b\u7684\u5de5\u4f5c\u539f\u7406\u4e0e\u88c5\u888b\u6a21\u578b\u7c7b\u4f3c\uff0c\u4e0d\u540c\u4e4b\u5904\u5728\u4e8e\u63d0\u5347\u6a21\u578b\u4e2d\u7684\u8fde\u7eed\u6a21\u578b\u662f\u6839\u636e\u8bef\u5dee\u6b8b\u5dee\u8bad\u7ec3\u7684\uff0c\u5e76\u503e\u5411\u4e8e\u6700\u5c0f\u5316\u524d\u9762\u6a21\u578b\u7684\u8bef\u5dee\u3002\u8fd9\u6837\uff0c\u63d0\u5347\u6a21\u578b\u5c31\u80fd\u5b8c\u7f8e\u5730\u5b66\u4e60\u6570\u636e\uff0c\u56e0\u6b64\u5bb9\u6613\u51fa\u73b0\u8fc7\u62df\u5408\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u770b\u5230\u7684\u4ee3\u7801\u7247\u6bb5\u53ea\u8003\u8651\u4e86\u4e00\u5217\u3002\u4f46\u60c5\u51b5\u5e76\u975e\u603b\u662f\u5982\u6b64\uff0c\u5f88\u591a\u65f6\u5019\u60a8\u9700\u8981\u5904\u7406\u591a\u5217\u9884\u6d4b\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u4f1a\u9047\u5230\u4ece\u591a\u4e2a\u7c7b\u522b\u4e2d\u9884\u6d4b\u4e00\u4e2a\u7c7b\u522b\u7684\u95ee\u9898\uff0c\u5373\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5bf9\u4e8e\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\uff0c\u4f60\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u9009\u62e9\u6295\u7968\u65b9\u6cd5\u3002\u4f46\u6295\u7968\u6cd5\u5e76\u4e0d\u603b\u662f\u6700\u4f73\u65b9\u6cd5\u3002\u5982\u679c\u8981\u7ec4\u5408\u6982\u7387\uff0c\u5c31\u4f1a\u6709\u4e00\u4e2a\u4e8c\u7ef4\u6570\u7ec4\uff0c\u800c\u4e0d\u662f\u50cf\u6211\u4eec\u4e4b\u524d\u4f18\u5316 AUC \u65f6\u7684\u5411\u91cf\u3002\u5982\u679c\u6709\u591a\u4e2a\u7c7b\u522b\uff0c\u53ef\u4ee5\u5c1d\u8bd5\u4f18\u5316\u5bf9\u6570\u635f\u5931\uff08\u6216\u5176\u4ed6\u4e0e\u4e1a\u52a1\u76f8\u5173\u7684\u6307\u6807\uff09\u3002 \u8981\u8fdb\u884c\u7ec4\u5408\uff0c\u53ef\u4ee5\u5728\u62df\u5408\u51fd\u6570 (X) \u4e2d\u4f7f\u7528 numpy \u6570\u7ec4\u5217\u8868\u800c\u4e0d\u662f numpy \u6570\u7ec4\uff0c\u968f\u540e\u8fd8\u9700\u8981\u66f4\u6539\u4f18\u5316\u5668\u548c\u9884\u6d4b\u51fd\u6570\u3002\u6211\u5c31\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u5927\u5bb6\u5427\u3002 
\u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8fdb\u5165\u4e0b\u4e00\u4e2a\u6709\u8da3\u7684\u8bdd\u9898\uff0c\u8fd9\u4e2a\u8bdd\u9898\u76f8\u5f53\u6d41\u884c\uff0c\u88ab\u79f0\u4e3a \u5806\u53e0 \u3002\u56fe 2 \u5c55\u793a\u4e86\u5982\u4f55\u5806\u53e0\u6a21\u578b\u3002 \u56fe2 : Stacking \u5806\u53e0\u4e0d\u50cf\u5236\u9020\u706b\u7bad\u3002\u5b83\u7b80\u5355\u660e\u4e86\u3002\u5982\u679c\u60a8\u8fdb\u884c\u4e86\u6b63\u786e\u7684\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5e76\u5728\u6574\u4e2a\u5efa\u6a21\u8fc7\u7a0b\u4e2d\u4fdd\u6301\u6298\u53e0\u4e0d\u53d8\uff0c\u90a3\u4e48\u5c31\u4e0d\u4f1a\u51fa\u73b0\u4efb\u4f55\u8fc7\u5ea6\u8d34\u5408\u7684\u60c5\u51b5\u3002 \u8ba9\u6211\u7528\u7b80\u5355\u7684\u8981\u70b9\u5411\u4f60\u63cf\u8ff0\u4e00\u4e0b\u8fd9\u4e2a\u60f3\u6cd5\u3002 - \u5c06\u8bad\u7ec3\u6570\u636e\u5206\u6210\u82e5\u5e72\u6298\u53e0\u3002 - \u8bad\u7ec3\u4e00\u5806\u6a21\u578b\uff1a M1\u3001M2.....Mn\u3002 - \u521b\u5efa\u5b8c\u6574\u7684\u8bad\u7ec3\u9884\u6d4b\uff08\u4f7f\u7528\u975e\u6298\u53e0\u8bad\u7ec3\uff09\uff0c\u5e76\u4f7f\u7528\u6240\u6709\u8fd9\u4e9b\u6a21\u578b\u8fdb\u884c\u6d4b\u8bd5\u9884\u6d4b\u3002 - \u76f4\u5230\u8fd9\u91cc\u662f\u7b2c 1 \u5c42 (L1)\u3002 - \u5c06\u8fd9\u4e9b\u6a21\u578b\u7684\u6298\u53e0\u9884\u6d4b\u4f5c\u4e3a\u53e6\u4e00\u4e2a\u6a21\u578b\u7684\u7279\u5f81\u3002\u8fd9\u5c31\u662f\u4e8c\u7ea7\u6a21\u578b\uff08L2\uff09\u3002 - \u4f7f\u7528\u4e0e\u4e4b\u524d\u76f8\u540c\u7684\u6298\u53e0\u6765\u8bad\u7ec3\u8fd9\u4e2a L2 \u6a21\u578b\u3002 - \u73b0\u5728\uff0c\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u521b\u5efa OOF\uff08\u6298\u53e0\u5916\uff09\u9884\u6d4b\u3002 - \u73b0\u5728\u60a8\u5c31\u6709\u4e86\u8bad\u7ec3\u6570\u636e\u7684 L2 \u9884\u6d4b\u548c\u6700\u7ec8\u6d4b\u8bd5\u96c6\u9884\u6d4b\u3002 \u60a8\u53ef\u4ee5\u4e0d\u65ad\u91cd\u590d L1 \u90e8\u5206\uff0c\u4e5f\u53ef\u4ee5\u521b\u5efa\u4efb\u610f\u591a\u7684\u5c42\u6b21\u3002 
\u6709\u65f6\uff0c\u4f60\u8fd8\u4f1a\u9047\u5230\u4e00\u4e2a\u53eb\u6df7\u5408\u7684\u672f\u8bed blending \u3002\u5982\u679c\u4f60\u9047\u5230\u4e86\uff0c\u4e0d\u7528\u592a\u62c5\u5fc3\u3002\u5b83\u53ea\u4e0d\u8fc7\u662f\u7528\u4e00\u4e2a\u4fdd\u7559\u7ec4\u6765\u5806\u53e0\uff0c\u800c\u4e0d\u662f\u591a\u91cd\u6298\u53e0\u3002\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u6211\u5728\u672c\u7ae0\u4e2d\u6240\u63cf\u8ff0\u7684\u5185\u5bb9\u53ef\u4ee5\u5e94\u7528\u4e8e\u4efb\u4f55\u7c7b\u578b\u7684\u95ee\u9898\uff1a\u5206\u7c7b\u3001\u56de\u5f52\u3001\u591a\u6807\u7b7e\u5206\u7c7b\u7b49\u3002","title":"\u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5"},{"location":"%E7%BB%84%E7%BB%87%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E9%A1%B9%E7%9B%AE/","text":"\u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\u9879\u76ee \u7ec8\u4e8e\uff0c\u6211\u4eec\u53ef\u4ee5\u5f00\u59cb\u6784\u5efa\u7b2c\u4e00\u4e2a\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e86\u3002 \u662f\u8fd9\u6837\u5417\uff1f \u5728\u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u6ce8\u610f\u51e0\u4ef6\u4e8b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u6211\u4eec\u5c06\u5728\u96c6\u6210\u5f00\u53d1\u73af\u5883/\u6587\u672c\u7f16\u8f91\u5668\u4e2d\u5de5\u4f5c\uff0c\u800c\u4e0d\u662f\u5728 jupyter notebook\u4e2d\u3002\u4f60\u4e5f\u53ef\u4ee5\u5728 jupyter notebook\u4e2d\u5de5\u4f5c\uff0c\u8fd9\u5b8c\u5168\u53d6\u51b3\u4e8e\u4f60\u3002\u4e0d\u8fc7\uff0c\u6211\u5c06\u53ea\u4f7f\u7528 jupyter notebook\u6765\u63a2\u7d22\u6570\u636e\u3001\u7ed8\u5236\u56fe\u8868\u548c\u56fe\u5f62\u3002\u6211\u4eec\u5c06\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u6784\u5efa\u5206\u7c7b\u6846\u67b6\uff0c\u5373\u63d2\u5373\u7528\u3002\u60a8\u65e0\u9700\u5bf9\u4ee3\u7801\u505a\u592a\u591a\u6539\u52a8\u5c31\u80fd\u8bad\u7ec3\u6a21\u578b\uff0c\u800c\u4e14\u5f53\u60a8\u6539\u8fdb\u6a21\u578b\u65f6\uff0c\u8fd8\u80fd\u4f7f\u7528 git \u5bf9\u5176\u8fdb\u884c\u8ddf\u8e2a\u3002 
\u6211\u4eec\u9996\u5148\u6765\u770b\u770b\u6587\u4ef6\u7684\u7ed3\u6784\u3002\u5bf9\u4e8e\u4f60\u6b63\u5728\u505a\u7684\u4efb\u4f55\u9879\u76ee\uff0c\u90fd\u8981\u521b\u5efa\u4e00\u4e2a\u65b0\u6587\u4ef6\u5939\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u5c06\u9879\u76ee\u547d\u540d\u4e3a \"project\"\u3002 \u9879\u76ee\u6587\u4ef6\u5939\u5185\u90e8\u5e94\u8be5\u5982\u4e0b\u6240\u793a\u3002 input train.csv test.csv src create_folds.py train.py inference.py models.py config.py model_dispatcher.py models model_rf.bin model_et.bin notebooks exploration.ipynb check_data.ipynb README.md LICENSE \u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e9b\u6587\u4ef6\u5939\u548c\u6587\u4ef6\u7684\u5185\u5bb9\u3002 input/ \uff1a\u8be5\u6587\u4ef6\u5939\u5305\u542b\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7684\u6240\u6709\u8f93\u5165\u6587\u4ef6\u548c\u6570\u636e\u3002\u5982\u679c\u60a8\u6b63\u5728\u5f00\u53d1 NLP \u9879\u76ee\uff0c\u60a8\u53ef\u4ee5\u5c06embeddings\u653e\u5728\u8fd9\u91cc\u3002\u5982\u679c\u662f\u56fe\u50cf\u9879\u76ee\uff0c\u6240\u6709\u56fe\u50cf\u90fd\u653e\u5728\u8be5\u6587\u4ef6\u5939\u4e0b\u7684\u5b50\u6587\u4ef6\u5939\u4e2d\u3002 src/ \uff1a\u6211\u4eec\u5c06\u5728\u8fd9\u91cc\u4fdd\u5b58\u4e0e\u9879\u76ee\u76f8\u5173\u7684\u6240\u6709 python \u811a\u672c\u3002\u5982\u679c\u6211\u8bf4\u7684\u662f\u4e00\u4e2a python \u811a\u672c\uff0c\u5373\u4efb\u4f55 *.py \u6587\u4ef6\uff0c\u5b83\u90fd\u5b58\u50a8\u5728 src \u6587\u4ef6\u5939\u4e2d\u3002 models/ \uff1a\u8be5\u6587\u4ef6\u5939\u4fdd\u5b58\u6240\u6709\u8bad\u7ec3\u8fc7\u7684\u6a21\u578b\u3002 notebook/ \uff1a\u6240\u6709 jupyter notebook\uff08\u5373\u4efb\u4f55 *.ipynb \u6587\u4ef6\uff09\u90fd\u5b58\u50a8\u5728\u7b14\u8bb0\u672c \u6587\u4ef6\u5939\u4e2d\u3002 README.md 
\uff1a\u8fd9\u662f\u4e00\u4e2a\u6807\u8bb0\u7b26\u6587\u4ef6\uff0c\u60a8\u53ef\u4ee5\u5728\u5176\u4e2d\u63cf\u8ff0\u60a8\u7684\u9879\u76ee\uff0c\u5e76\u5199\u660e\u5982\u4f55\u8bad\u7ec3\u6a21\u578b\u6216\u5728\u751f\u4ea7\u73af\u5883\u4e2d\u4f7f\u7528\u3002 LICENSE \uff1a\u8fd9\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6587\u672c\u6587\u4ef6\uff0c\u5305\u542b\u9879\u76ee\u7684\u8bb8\u53ef\u8bc1\uff0c\u5982 MIT\u3001Apache \u7b49\u3002\u5173\u4e8e\u8bb8\u53ef\u8bc1\u7684\u8be6\u7ec6\u4ecb\u7ecd\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 \u5047\u8bbe\u4f60\u6b63\u5728\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u5bf9 MNIST \u6570\u636e\u96c6\uff08\u51e0\u4e4e\u6bcf\u672c\u673a\u5668\u5b66\u4e60\u4e66\u7c4d\u90fd\u4f1a\u7528\u5230\u7684\u6570\u636e\u96c6\uff09\u8fdb\u884c\u5206\u7c7b\u3002\u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\uff0c\u6211\u4eec\u5728\u4ea4\u53c9\u68c0\u9a8c\u4e00\u7ae0\u4e2d\u4e5f\u63d0\u5230\u8fc7 MNIST \u6570\u636e\u96c6\u3002\u6240\u4ee5\uff0c\u6211\u5c31\u4e0d\u89e3\u91ca\u8fd9\u4e2a\u6570\u636e\u96c6\u662f\u4ec0\u4e48\u6837\u5b50\u4e86\u3002\u7f51\u4e0a\u6709\u8bb8\u591a\u4e0d\u540c\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6\uff0c\u4f46\u6211\u4eec\u5c06\u4f7f\u7528 CSV \u683c\u5f0f\u7684\u6570\u636e\u96c6\u3002 \u5728\u8fd9\u79cd\u683c\u5f0f\u7684\u6570\u636e\u96c6\u4e2d\uff0cCSV \u7684\u6bcf\u4e00\u884c\u90fd\u5305\u542b\u56fe\u50cf\u7684\u6807\u7b7e\u548c 784 \u4e2a\u50cf\u7d20\u503c\uff0c\u50cf\u7d20\u503c\u8303\u56f4\u4ece 0 \u5230 255\u3002\u6570\u636e\u96c6\u5305\u542b 60000 \u5f20\u8fd9\u79cd\u683c\u5f0f\u7684\u56fe\u50cf\u3002 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u8f7b\u677e\u8bfb\u53d6\u8fd9\u79cd\u6570\u636e\u683c\u5f0f\u3002 \u8bf7\u6ce8\u610f\uff0c\u5c3d\u7ba1\u56fe 1 \u663e\u793a\u6240\u6709\u50cf\u7d20\u503c\u5747\u4e3a\u96f6\uff0c\u4f46\u4e8b\u5b9e\u5e76\u975e\u5982\u6b64\u3002 \u56fe 1\uff1aCSV\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6 
\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u6807\u7b7e\u5217\u7684\u8ba1\u6570\u3002 \u56fe 2\uff1aMNIST \u6570\u636e\u96c6\u4e2d\u7684\u6807\u7b7e\u8ba1\u6570 \u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c\u66f4\u591a\u7684\u63a2\u7d22\u3002\u6211\u4eec\u5df2\u7ecf\u77e5\u9053\u4e86\u6211\u4eec\u6240\u62e5\u6709\u7684\u6570\u636e\uff0c\u6ca1\u6709\u5fc5\u8981\u518d\u5bf9\u4e0d\u540c\u7684\u50cf\u7d20\u503c\u8fdb\u884c\u7ed8\u56fe\u3002\u4ece\u56fe 2 \u4e2d\u53ef\u4ee5\u6e05\u695a\u5730\u770b\u51fa\uff0c\u6807\u7b7e\u7684\u5206\u5e03\u76f8\u5f53\u5747\u5300\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u51c6\u786e\u7387/F1 \u4f5c\u4e3a\u8861\u91cf\u6807\u51c6\u3002\u8fd9\u5c31\u662f\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u7684\u7b2c\u4e00\u6b65\uff1a\u786e\u5b9a\u8861\u91cf\u6807\u51c6\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7f16\u5199\u4e00\u4e9b\u4ee3\u7801\u4e86\u3002\u6211\u4eec\u9700\u8981\u521b\u5efa src/ \u6587\u4ef6\u5939\u548c\u4e00\u4e9b python \u811a\u672c\u3002 \u8bf7\u6ce8\u610f\uff0c\u8bad\u7ec3 CSV \u6587\u4ef6\u4f4d\u4e8e input/ \u6587\u4ef6\u5939\u4e2d\uff0c\u540d\u4e3a mnist_train.csv \u3002 \u5bf9\u4e8e\u8fd9\u6837\u4e00\u4e2a\u9879\u76ee\uff0c\u8fd9\u4e9b\u6587\u4ef6\u5e94\u8be5\u662f\u4ec0\u4e48\u6837\u7684\u5462\uff1f \u9996\u5148\u8981\u521b\u5efa\u7684\u811a\u672c\u662f create_folds.py \u3002 \u8fd9\u5c06\u5728 input/ \u6587\u4ef6\u5939\u4e2d\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a mnist_train_folds.csv \u7684\u65b0\u6587\u4ef6\uff0c\u4e0e mnist_train.csv \u76f8\u540c\u3002\u552f\u4e00\u4e0d\u540c\u7684\u662f\uff0c\u8fd9\u4e2a CSV \u6587\u4ef6\u7ecf\u8fc7\u4e86\u968f\u673a\u6392\u5e8f\uff0c\u5e76\u65b0\u589e\u4e86\u4e00\u5217\u540d\u4e3a kfold \u7684\u5185\u5bb9\u3002 
\u4e00\u65e6\u6211\u4eec\u51b3\u5b9a\u4e86\u8981\u4f7f\u7528\u54ea\u79cd\u8bc4\u4f30\u6307\u6807\u5e76\u521b\u5efa\u4e86\u6298\u53e0\uff0c\u5c31\u53ef\u4ee5\u5f00\u59cb\u521b\u5efa\u57fa\u672c\u6a21\u578b\u4e86\u3002\u8fd9\u53ef\u4ee5\u5728 train.py \u4e2d\u5b8c\u6210\u3002 import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/mnist_train_folds.csv\" ) # \u9009\u53d6df\u4e2dkfold\u5217\u4e0d\u7b49\u4e8efold df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u9009\u53d6df\u4e2dkfold\u5217\u7b49\u4e8efold df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u8bad\u7ec3\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_train = df_train . drop ( \"label\" , axis = 1 ) . values # \u8bad\u7ec3\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_train = df_train . label . values # \u9a8c\u8bc1\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values # \u9a8c\u8bc1\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_valid = df_valid . label . values # \u5b9e\u4f8b\u5316\u51b3\u7b56\u6811\u6a21\u578b clf = tree . DecisionTreeClassifier () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b clf . fit ( x_train , y_train ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u8f93\u5165\u5f97\u5230\u9884\u6d4b\u7ed3\u679c preds = clf . predict ( x_valid ) # \u8ba1\u7b97\u9a8c\u8bc1\u96c6\u51c6\u786e\u7387 accuracy = metrics . accuracy_score ( y_valid , preds ) # \u6253\u5370fold\u4fe1\u606f\u548c\u51c6\u786e\u7387 print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) # \u4fdd\u5b58\u6a21\u578b joblib . 
dump ( clf , f \"../models/dt_ { fold } .bin\" ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u60a8\u53ef\u4ee5\u5728\u63a7\u5236\u53f0\u8c03\u7528 python train.py \u8fd0\u884c\u8be5\u811a\u672c\u3002 \u276f python train . py Fold = 0 , Accuracy = 0.8680833333333333 Fold = 1 , Accuracy = 0.8685 Fold = 2 , Accuracy = 0.8674166666666666 Fold = 3 , Accuracy = 0.8703333333333333 Fold = 4 , Accuracy = 0.8699166666666667 \u67e5\u770b\u8bad\u7ec3\u811a\u672c\u65f6\uff0c\u60a8\u4f1a\u53d1\u73b0\u8fd8\u6709\u4e00\u4e9b\u5185\u5bb9\u662f\u786c\u7f16\u7801\u7684\uff0c\u4f8b\u5982\u6298\u53e0\u6570\u3001\u8bad\u7ec3\u6587\u4ef6\u548c\u8f93\u51fa\u6587\u4ef6\u5939\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u6240\u6709\u8fd9\u4e9b\u4fe1\u606f\u7684\u914d\u7f6e\u6587\u4ef6\uff1a config.py \u3002 TRAINING_FILE = \"../input/mnist_train_folds.csv\" MODEL_OUTPUT = \"../models/\" \u6211\u4eec\u8fd8\u5bf9\u8bad\u7ec3\u811a\u672c\u8fdb\u884c\u4e86\u4e00\u4e9b\u4fee\u6539\u3002\u8bad\u7ec3\u6587\u4ef6\u73b0\u5728\u4f7f\u7528\u914d\u7f6e\u6587\u4ef6\u3002\u8fd9\u6837\uff0c\u66f4\u6539\u6570\u636e\u6216\u6a21\u578b\u8f93\u51fa\u5c31\u66f4\u5bb9\u6613\u4e86\u3002 import os import config import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u4f7f\u7528config\u4e2d\u7684\u8def\u5f84\u8bfb\u53d6\u6570\u636e df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values clf = tree . DecisionTreeClassifier () clf . fit ( x_train , y_train ) preds = clf . 
predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" ) ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5e76\u6ca1\u6709\u5c55\u793a\u8fd9\u4e2a\u57f9\u8bad\u811a\u672c\u4e0e\u4e4b\u524d\u811a\u672c\u7684\u533a\u522b\u3002\u8bf7\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e24\u4e2a\u811a\u672c\uff0c\u81ea\u5df1\u627e\u51fa\u4e0d\u540c\u4e4b\u5904\u3002\u533a\u522b\u5e76\u4e0d\u591a\u3002 \u4e0e\u8bad\u7ec3\u811a\u672c\u76f8\u5173\u7684\u8fd8\u6709\u4e00\u70b9\u53ef\u4ee5\u6539\u8fdb\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u6211\u4eec\u4e3a\u6bcf\u4e2a\u6298\u53e0\u591a\u6b21\u8c03\u7528\u8fd0\u884c\u51fd\u6570\u3002\u6709\u65f6\uff0c\u5728\u540c\u4e00\u4e2a\u811a\u672c\u4e2d\u8fd0\u884c\u591a\u4e2a\u6298\u53e0\u5e76\u4e0d\u53ef\u53d6\uff0c\u56e0\u4e3a\u5185\u5b58\u6d88\u8017\u53ef\u80fd\u4f1a\u4e0d\u65ad\u589e\u52a0\uff0c\u7a0b\u5e8f\u53ef\u80fd\u4f1a\u5d29\u6e83\u3002\u4e3a\u4e86\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u5411\u8bad\u7ec3\u811a\u672c\u4f20\u9012\u53c2\u6570\u3002\u6211\u559c\u6b22\u4f7f\u7528 argparse\u3002 import argparse if __name__ == \"__main__\" : # \u5b9e\u4f8b\u5316\u53c2\u6570\u73af\u5883 parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # \u8bfb\u53d6\u53c2\u6570 args = parser . parse_args () run ( fold = args . fold ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u8fd0\u884c python \u811a\u672c\uff0c\u4f46\u4ec5\u9650\u4e8e\u7ed9\u5b9a\u7684\u6298\u53e0\u3002 \u276f python train . 
py -- fold 0 Fold = 0 , Accuracy = 0.8656666666666667 \u4ed4\u7ec6\u89c2\u5bdf\uff0c\u6211\u4eec\u7684\u7b2c 0 \u6298\u5f97\u5206\u4e0e\u4e4b\u524d\u6709\u4e9b\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6a21\u578b\u4e2d\u5b58\u5728\u968f\u673a\u6027\u3002\u6211\u4eec\u5c06\u5728\u540e\u9762\u7684\u7ae0\u8282\u4e2d\u8ba8\u8bba\u5982\u4f55\u5904\u7406\u968f\u673a\u6027\u3002 \u73b0\u5728\uff0c\u5982\u679c\u4f60\u613f\u610f\uff0c\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a shell \u811a\u672c \uff0c\u9488\u5bf9\u4e0d\u540c\u7684\u6298\u53e0\u4f7f\u7528\u4e0d\u540c\u7684\u547d\u4ee4\uff0c\u7136\u540e\u4e00\u8d77\u8fd0\u884c\uff0c\u5982\u4e0b\u56fe\u6240\u793a\u3002 python train . py -- fold 0 python train . py -- fold 1 python train . py -- fold 2 python train . py -- fold 3 python train . py -- fold 4 \u60a8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u5b83\u3002 \u276f sh run . sh Fold = 0 , Accuracy = 0.8675 Fold = 1 , Accuracy = 0.8693333333333333 Fold = 2 , Accuracy = 0.8683333333333333 Fold = 3 , Accuracy = 0.8704166666666666 Fold = 4 , Accuracy = 0.8685 \u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u53d6\u5f97\u4e86\u4e00\u4e9b\u8fdb\u5c55\uff0c\u4f46\u5982\u679c\u6211\u4eec\u770b\u4e00\u4e0b\u6211\u4eec\u7684\u8bad\u7ec3\u811a\u672c\uff0c\u6211\u4eec\u4ecd\u7136\u53d7\u5230\u4e00\u4e9b\u4e1c\u897f\u7684\u9650\u5236\uff0c\u4f8b\u5982\u6a21\u578b\u3002\u6a21\u578b\u662f\u786c\u7f16\u7801\u5728\u8bad\u7ec3\u811a\u672c\u4e2d\u7684\uff0c\u53ea\u6709\u4fee\u6539\u811a\u672c\u624d\u80fd\u6539\u53d8\u5b83\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u65b0\u7684 python \u811a\u672c\uff0c\u540d\u4e3a model_dispatcher.py \u3002model_dispatcher.py\uff0c\u987e\u540d\u601d\u4e49\uff0c\u5c06\u8c03\u5ea6\u6211\u4eec\u7684\u6a21\u578b\u5230\u8bad\u7ec3\u811a\u672c\u4e2d\u3002 from sklearn import tree models = { # \u4ee5gini\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_gini\" : tree . 
DecisionTreeClassifier ( criterion = \"gini\" ), # \u4ee5entropy\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), } model_dispatcher.py \u4ece scikit-learn \u4e2d\u5bfc\u5165\u4e86 tree\uff0c\u5e76\u5b9a\u4e49\u4e86\u4e00\u4e2a\u5b57\u5178\uff0c\u5176\u4e2d\u952e\u662f\u6a21\u578b\u7684\u540d\u79f0\uff0c\u503c\u662f\u6a21\u578b\u672c\u8eab\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5b9a\u4e49\u4e86\u4e24\u79cd\u4e0d\u540c\u7684\u51b3\u7b56\u6811\uff0c\u4e00\u79cd\u4f7f\u7528\u57fa\u5c3c\u6807\u51c6\uff0c\u53e6\u4e00\u79cd\u4f7f\u7528\u71b5\u6807\u51c6\u3002\u8981\u4f7f\u7528 py\uff0c\u6211\u4eec\u9700\u8981\u5bf9\u8bad\u7ec3\u811a\u672c\u505a\u4e00\u4e9b\u4fee\u6539\u3002 import argparse import os import joblib import pandas as pd from sklearn import metrics import config import model_dispatcher def run ( fold , model ): df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values # \u6839\u636emodel\u53c2\u6570\u9009\u62e9\u6a21\u578b clf = model_dispatcher . models [ model ] clf . fit ( x_train , y_train ) preds = clf . predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" )) if __name__ == \"__main__\" : parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # model\u53c2\u6570 parser . add_argument ( \"--model\" , type = str ) args = parser . parse_args () run ( fold = args . fold , model = args . 
model ) train.py \u6709\u51e0\u5904\u91cd\u5927\u6539\u52a8\uff1a - \u5bfc\u5165 model_dispatcher - \u4e3a ArgumentParser \u6dfb\u52a0 --model \u53c2\u6570 - \u4e3a run() \u51fd\u6570\u6dfb\u52a0model\u53c2\u6570 - \u4f7f\u7528\u8c03\u5ea6\u7a0b\u5e8f\u83b7\u53d6\u6307\u5b9a\u540d\u79f0\u7684\u6a21\u578b \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u811a\u672c\uff1a \u276f python train . py -- fold 0 -- model decision_tree_gini Fold = 0 , Accuracy = 0.8665833333333334 \u6216\u6267\u884c\u4ee5\u4e0b\u547d\u4ee4 \u276f python train . py -- fold 0 -- model decision_tree_entropy Fold = 0 , Accuracy = 0.8705833333333334 \u73b0\u5728\uff0c\u5982\u679c\u8981\u6dfb\u52a0\u65b0\u6a21\u578b\uff0c\u53ea\u9700\u4fee\u6539 model_dispatcher.py \u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u6dfb\u52a0\u968f\u673a\u68ee\u6797\uff0c\u770b\u770b\u51c6\u786e\u7387\u4f1a\u6709\u4ec0\u4e48\u53d8\u5316\u3002 from sklearn import ensemble from sklearn import tree models = { \"decision_tree_gini\" : tree . DecisionTreeClassifier ( criterion = \"gini\" ), \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), # \u968f\u673a\u68ee\u6797\u6a21\u578b \"rf\" : ensemble . RandomForestClassifier (), } \u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u3002 \u276f python train . py -- fold 0 -- model rf Fold = 0 , Accuracy = 0.9670833333333333 \u54c7\uff0c\u4e00\u4e2a\u7b80\u5355\u7684\u6539\u52a8\u5c31\u80fd\u8ba9\u5206\u6570\u6709\u5982\u6b64\u5927\u7684\u63d0\u5347\uff01\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528 run.sh \u811a\u672c\u8fd0\u884c 5 \u4e2a\u6298\u53e0\uff01 python train . py -- fold 0 -- model rf python train . py -- fold 1 -- model rf python train . py -- fold 2 -- model rf python train . py -- fold 3 -- model rf python train . py -- fold 4 -- model rf \u5f97\u5206\u60c5\u51b5\u5982\u4e0b \u276f sh run . 
sh Fold = 0 , Accuracy = 0.9674166666666667 Fold = 1 , Accuracy = 0.9698333333333333 Fold = 2 , Accuracy = 0.96575 Fold = 3 , Accuracy = 0.9684166666666667 Fold = 4 , Accuracy = 0.9666666666666667 MNIST \u51e0\u4e4e\u662f\u6bcf\u672c\u4e66\u548c\u6bcf\u7bc7\u535a\u5ba2\u90fd\u4f1a\u8ba8\u8bba\u7684\u95ee\u9898\u3002\u4f46\u6211\u8bd5\u56fe\u5c06\u8fd9\u4e2a\u95ee\u9898\u8f6c\u6362\u5f97\u66f4\u6709\u8da3\uff0c\u5e76\u5411\u4f60\u5c55\u793a\u5982\u4f55\u4e3a\u4f60\u6b63\u5728\u505a\u7684\u6216\u8ba1\u5212\u5728\u4e0d\u4e45\u7684\u5c06\u6765\u505a\u7684\u51e0\u4e4e\u6240\u6709\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7f16\u5199\u4e00\u4e2a\u57fa\u672c\u6846\u67b6\u3002\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\u53ef\u4ee5\u6539\u8fdb\u8fd9\u4e2a MNIST \u6a21\u578b\u548c\u8fd9\u4e2a\u6846\u67b6\uff0c\u6211\u4eec\u5c06\u5728\u4ee5\u540e\u7684\u7ae0\u8282\u4e2d\u770b\u5230\u3002 \u6211\u4f7f\u7528\u4e86\u4e00\u4e9b\u811a\u672c\uff0c\u5982 model_dispatcher.py \u548c config.py \uff0c\u5e76\u5c06\u5b83\u4eec\u5bfc\u5165\u5230\u6211\u7684\u8bad\u7ec3\u811a\u672c\u4e2d\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u6ca1\u6709\u5bfc\u5165 \uff0c\u4f60\u4e5f\u4e0d\u5e94\u8be5\u5bfc\u5165\u3002\u5982\u679c\u6211\u5bfc\u5165\u4e86 
\uff0c\u4f60\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u77e5\u9053\u6a21\u578b\u5b57\u5178\u662f\u4ece\u54ea\u91cc\u6765\u7684\u3002\u7f16\u5199\u4f18\u79c0\u3001\u6613\u61c2\u7684\u4ee3\u7801\u662f\u4e00\u4e2a\u4eba\u5fc5\u987b\u5177\u5907\u7684\u57fa\u672c\u7d20\u8d28\uff0c\u4f46\u8bb8\u591a\u6570\u636e\u79d1\u5b66\u5bb6\u5374\u5ffd\u89c6\u4e86\u8fd9\u4e00\u70b9\u3002\u5982\u679c\u4f60\u6240\u505a\u7684\u9879\u76ee\u80fd\u8ba9\u5176\u4ed6\u4eba\u7406\u89e3\u5e76\u4f7f\u7528\uff0c\u800c\u65e0\u9700\u54a8\u8be2\u4f60\u7684\u610f\u89c1\uff0c\u90a3\u4e48\u4f60\u5c31\u8282\u7701\u4e86\u4ed6\u4eec\u7684\u65f6\u95f4\u548c\u81ea\u5df1\u7684\u65f6\u95f4\uff0c\u53ef\u4ee5\u5c06\u8fd9\u4e9b\u65f6\u95f4\u6295\u5165\u5230\u6539\u8fdb\u4f60\u7684\u9879\u76ee\u6216\u5f00\u53d1\u65b0\u9879\u76ee\u4e2d\u53bb\u3002","title":"\u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\u9879\u76ee"},{"location":"%E7%BB%84%E7%BB%87%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E9%A1%B9%E7%9B%AE/#_1","text":"\u7ec8\u4e8e\uff0c\u6211\u4eec\u53ef\u4ee5\u5f00\u59cb\u6784\u5efa\u7b2c\u4e00\u4e2a\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e86\u3002 \u662f\u8fd9\u6837\u5417\uff1f \u5728\u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u6ce8\u610f\u51e0\u4ef6\u4e8b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u6211\u4eec\u5c06\u5728\u96c6\u6210\u5f00\u53d1\u73af\u5883/\u6587\u672c\u7f16\u8f91\u5668\u4e2d\u5de5\u4f5c\uff0c\u800c\u4e0d\u662f\u5728 jupyter notebook\u4e2d\u3002\u4f60\u4e5f\u53ef\u4ee5\u5728 jupyter notebook\u4e2d\u5de5\u4f5c\uff0c\u8fd9\u5b8c\u5168\u53d6\u51b3\u4e8e\u4f60\u3002\u4e0d\u8fc7\uff0c\u6211\u5c06\u53ea\u4f7f\u7528 jupyter 
notebook\u6765\u63a2\u7d22\u6570\u636e\u3001\u7ed8\u5236\u56fe\u8868\u548c\u56fe\u5f62\u3002\u6211\u4eec\u5c06\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u6784\u5efa\u5206\u7c7b\u6846\u67b6\uff0c\u5373\u63d2\u5373\u7528\u3002\u60a8\u65e0\u9700\u5bf9\u4ee3\u7801\u505a\u592a\u591a\u6539\u52a8\u5c31\u80fd\u8bad\u7ec3\u6a21\u578b\uff0c\u800c\u4e14\u5f53\u60a8\u6539\u8fdb\u6a21\u578b\u65f6\uff0c\u8fd8\u80fd\u4f7f\u7528 git \u5bf9\u5176\u8fdb\u884c\u8ddf\u8e2a\u3002 \u6211\u4eec\u9996\u5148\u6765\u770b\u770b\u6587\u4ef6\u7684\u7ed3\u6784\u3002\u5bf9\u4e8e\u4f60\u6b63\u5728\u505a\u7684\u4efb\u4f55\u9879\u76ee\uff0c\u90fd\u8981\u521b\u5efa\u4e00\u4e2a\u65b0\u6587\u4ef6\u5939\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u5c06\u9879\u76ee\u547d\u540d\u4e3a \"project\"\u3002 \u9879\u76ee\u6587\u4ef6\u5939\u5185\u90e8\u5e94\u8be5\u5982\u4e0b\u6240\u793a\u3002 input train.csv test.csv src create_folds.py train.py inference.py models.py config.py model_dispatcher.py models model_rf.bin model_et.bin notebooks exploration.ipynb check_data.ipynb README.md LICENSE \u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e9b\u6587\u4ef6\u5939\u548c\u6587\u4ef6\u7684\u5185\u5bb9\u3002 input/ \uff1a\u8be5\u6587\u4ef6\u5939\u5305\u542b\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7684\u6240\u6709\u8f93\u5165\u6587\u4ef6\u548c\u6570\u636e\u3002\u5982\u679c\u60a8\u6b63\u5728\u5f00\u53d1 NLP \u9879\u76ee\uff0c\u60a8\u53ef\u4ee5\u5c06embeddings\u653e\u5728\u8fd9\u91cc\u3002\u5982\u679c\u662f\u56fe\u50cf\u9879\u76ee\uff0c\u6240\u6709\u56fe\u50cf\u90fd\u653e\u5728\u8be5\u6587\u4ef6\u5939\u4e0b\u7684\u5b50\u6587\u4ef6\u5939\u4e2d\u3002 src/ \uff1a\u6211\u4eec\u5c06\u5728\u8fd9\u91cc\u4fdd\u5b58\u4e0e\u9879\u76ee\u76f8\u5173\u7684\u6240\u6709 python \u811a\u672c\u3002\u5982\u679c\u6211\u8bf4\u7684\u662f\u4e00\u4e2a python \u811a\u672c\uff0c\u5373\u4efb\u4f55 *.py \u6587\u4ef6\uff0c\u5b83\u90fd\u5b58\u50a8\u5728 src \u6587\u4ef6\u5939\u4e2d\u3002 models/ 
\uff1a\u8be5\u6587\u4ef6\u5939\u4fdd\u5b58\u6240\u6709\u8bad\u7ec3\u8fc7\u7684\u6a21\u578b\u3002 notebook/ \uff1a\u6240\u6709 jupyter notebook\uff08\u5373\u4efb\u4f55 *.ipynb \u6587\u4ef6\uff09\u90fd\u5b58\u50a8\u5728\u7b14\u8bb0\u672c \u6587\u4ef6\u5939\u4e2d\u3002 README.md \uff1a\u8fd9\u662f\u4e00\u4e2a\u6807\u8bb0\u7b26\u6587\u4ef6\uff0c\u60a8\u53ef\u4ee5\u5728\u5176\u4e2d\u63cf\u8ff0\u60a8\u7684\u9879\u76ee\uff0c\u5e76\u5199\u660e\u5982\u4f55\u8bad\u7ec3\u6a21\u578b\u6216\u5728\u751f\u4ea7\u73af\u5883\u4e2d\u4f7f\u7528\u3002 LICENSE \uff1a\u8fd9\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6587\u672c\u6587\u4ef6\uff0c\u5305\u542b\u9879\u76ee\u7684\u8bb8\u53ef\u8bc1\uff0c\u5982 MIT\u3001Apache \u7b49\u3002\u5173\u4e8e\u8bb8\u53ef\u8bc1\u7684\u8be6\u7ec6\u4ecb\u7ecd\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 \u5047\u8bbe\u4f60\u6b63\u5728\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u5bf9 MNIST \u6570\u636e\u96c6\uff08\u51e0\u4e4e\u6bcf\u672c\u673a\u5668\u5b66\u4e60\u4e66\u7c4d\u90fd\u4f1a\u7528\u5230\u7684\u6570\u636e\u96c6\uff09\u8fdb\u884c\u5206\u7c7b\u3002\u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\uff0c\u6211\u4eec\u5728\u4ea4\u53c9\u68c0\u9a8c\u4e00\u7ae0\u4e2d\u4e5f\u63d0\u5230\u8fc7 MNIST \u6570\u636e\u96c6\u3002\u6240\u4ee5\uff0c\u6211\u5c31\u4e0d\u89e3\u91ca\u8fd9\u4e2a\u6570\u636e\u96c6\u662f\u4ec0\u4e48\u6837\u5b50\u4e86\u3002\u7f51\u4e0a\u6709\u8bb8\u591a\u4e0d\u540c\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6\uff0c\u4f46\u6211\u4eec\u5c06\u4f7f\u7528 CSV \u683c\u5f0f\u7684\u6570\u636e\u96c6\u3002 \u5728\u8fd9\u79cd\u683c\u5f0f\u7684\u6570\u636e\u96c6\u4e2d\uff0cCSV \u7684\u6bcf\u4e00\u884c\u90fd\u5305\u542b\u56fe\u50cf\u7684\u6807\u7b7e\u548c 784 \u4e2a\u50cf\u7d20\u503c\uff0c\u50cf\u7d20\u503c\u8303\u56f4\u4ece 0 \u5230 255\u3002\u6570\u636e\u96c6\u5305\u542b 60000 \u5f20\u8fd9\u79cd\u683c\u5f0f\u7684\u56fe\u50cf\u3002 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u8f7b\u677e\u8bfb\u53d6\u8fd9\u79cd\u6570\u636e\u683c\u5f0f\u3002 
\u8bf7\u6ce8\u610f\uff0c\u5c3d\u7ba1\u56fe 1 \u663e\u793a\u6240\u6709\u50cf\u7d20\u503c\u5747\u4e3a\u96f6\uff0c\u4f46\u4e8b\u5b9e\u5e76\u975e\u5982\u6b64\u3002 \u56fe 1\uff1aCSV\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u6807\u7b7e\u5217\u7684\u8ba1\u6570\u3002 \u56fe 2\uff1aMNIST \u6570\u636e\u96c6\u4e2d\u7684\u6807\u7b7e\u8ba1\u6570 \u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c\u66f4\u591a\u7684\u63a2\u7d22\u3002\u6211\u4eec\u5df2\u7ecf\u77e5\u9053\u4e86\u6211\u4eec\u6240\u62e5\u6709\u7684\u6570\u636e\uff0c\u6ca1\u6709\u5fc5\u8981\u518d\u5bf9\u4e0d\u540c\u7684\u50cf\u7d20\u503c\u8fdb\u884c\u7ed8\u56fe\u3002\u4ece\u56fe 2 \u4e2d\u53ef\u4ee5\u6e05\u695a\u5730\u770b\u51fa\uff0c\u6807\u7b7e\u7684\u5206\u5e03\u76f8\u5f53\u5747\u5300\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u51c6\u786e\u7387/F1 \u4f5c\u4e3a\u8861\u91cf\u6807\u51c6\u3002\u8fd9\u5c31\u662f\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u7684\u7b2c\u4e00\u6b65\uff1a\u786e\u5b9a\u8861\u91cf\u6807\u51c6\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7f16\u5199\u4e00\u4e9b\u4ee3\u7801\u4e86\u3002\u6211\u4eec\u9700\u8981\u521b\u5efa src/ \u6587\u4ef6\u5939\u548c\u4e00\u4e9b python \u811a\u672c\u3002 \u8bf7\u6ce8\u610f\uff0c\u8bad\u7ec3 CSV \u6587\u4ef6\u4f4d\u4e8e input/ \u6587\u4ef6\u5939\u4e2d\uff0c\u540d\u4e3a mnist_train.csv \u3002 \u5bf9\u4e8e\u8fd9\u6837\u4e00\u4e2a\u9879\u76ee\uff0c\u8fd9\u4e9b\u6587\u4ef6\u5e94\u8be5\u662f\u4ec0\u4e48\u6837\u7684\u5462\uff1f \u9996\u5148\u8981\u521b\u5efa\u7684\u811a\u672c\u662f create_folds.py \u3002 \u8fd9\u5c06\u5728 input/ \u6587\u4ef6\u5939\u4e2d\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a mnist_train_folds.csv \u7684\u65b0\u6587\u4ef6\uff0c\u4e0e mnist_train.csv \u76f8\u540c\u3002\u552f\u4e00\u4e0d\u540c\u7684\u662f\uff0c\u8fd9\u4e2a CSV 
\u6587\u4ef6\u7ecf\u8fc7\u4e86\u968f\u673a\u6392\u5e8f\uff0c\u5e76\u65b0\u589e\u4e86\u4e00\u5217\u540d\u4e3a kfold \u7684\u5185\u5bb9\u3002 \u4e00\u65e6\u6211\u4eec\u51b3\u5b9a\u4e86\u8981\u4f7f\u7528\u54ea\u79cd\u8bc4\u4f30\u6307\u6807\u5e76\u521b\u5efa\u4e86\u6298\u53e0\uff0c\u5c31\u53ef\u4ee5\u5f00\u59cb\u521b\u5efa\u57fa\u672c\u6a21\u578b\u4e86\u3002\u8fd9\u53ef\u4ee5\u5728 train.py \u4e2d\u5b8c\u6210\u3002 import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/mnist_train_folds.csv\" ) # \u9009\u53d6df\u4e2dkfold\u5217\u4e0d\u7b49\u4e8efold df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u9009\u53d6df\u4e2dkfold\u5217\u7b49\u4e8efold df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u8bad\u7ec3\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_train = df_train . drop ( \"label\" , axis = 1 ) . values # \u8bad\u7ec3\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_train = df_train . label . values # \u9a8c\u8bc1\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values # \u9a8c\u8bc1\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_valid = df_valid . label . values # \u5b9e\u4f8b\u5316\u51b3\u7b56\u6811\u6a21\u578b clf = tree . DecisionTreeClassifier () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b clf . fit ( x_train , y_train ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u8f93\u5165\u5f97\u5230\u9884\u6d4b\u7ed3\u679c preds = clf . predict ( x_valid ) # \u8ba1\u7b97\u9a8c\u8bc1\u96c6\u51c6\u786e\u7387 accuracy = metrics . accuracy_score ( y_valid , preds ) # \u6253\u5370fold\u4fe1\u606f\u548c\u51c6\u786e\u7387 print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) # \u4fdd\u5b58\u6a21\u578b joblib . 
dump ( clf , f \"../models/dt_ { fold } .bin\" ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u60a8\u53ef\u4ee5\u5728\u63a7\u5236\u53f0\u8c03\u7528 python train.py \u8fd0\u884c\u8be5\u811a\u672c\u3002 \u276f python train . py Fold = 0 , Accuracy = 0.8680833333333333 Fold = 1 , Accuracy = 0.8685 Fold = 2 , Accuracy = 0.8674166666666666 Fold = 3 , Accuracy = 0.8703333333333333 Fold = 4 , Accuracy = 0.8699166666666667 \u67e5\u770b\u8bad\u7ec3\u811a\u672c\u65f6\uff0c\u60a8\u4f1a\u53d1\u73b0\u8fd8\u6709\u4e00\u4e9b\u5185\u5bb9\u662f\u786c\u7f16\u7801\u7684\uff0c\u4f8b\u5982\u6298\u53e0\u6570\u3001\u8bad\u7ec3\u6587\u4ef6\u548c\u8f93\u51fa\u6587\u4ef6\u5939\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u6240\u6709\u8fd9\u4e9b\u4fe1\u606f\u7684\u914d\u7f6e\u6587\u4ef6\uff1a config.py \u3002 TRAINING_FILE = \"../input/mnist_train_folds.csv\" MODEL_OUTPUT = \"../models/\" \u6211\u4eec\u8fd8\u5bf9\u8bad\u7ec3\u811a\u672c\u8fdb\u884c\u4e86\u4e00\u4e9b\u4fee\u6539\u3002\u8bad\u7ec3\u6587\u4ef6\u73b0\u5728\u4f7f\u7528\u914d\u7f6e\u6587\u4ef6\u3002\u8fd9\u6837\uff0c\u66f4\u6539\u6570\u636e\u6216\u6a21\u578b\u8f93\u51fa\u5c31\u66f4\u5bb9\u6613\u4e86\u3002 import os import config import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u4f7f\u7528config\u4e2d\u7684\u8def\u5f84\u8bfb\u53d6\u6570\u636e df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values clf = tree . DecisionTreeClassifier () clf . fit ( x_train , y_train ) preds = clf . 
predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" ) ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5e76\u6ca1\u6709\u5c55\u793a\u8fd9\u4e2a\u57f9\u8bad\u811a\u672c\u4e0e\u4e4b\u524d\u811a\u672c\u7684\u533a\u522b\u3002\u8bf7\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e24\u4e2a\u811a\u672c\uff0c\u81ea\u5df1\u627e\u51fa\u4e0d\u540c\u4e4b\u5904\u3002\u533a\u522b\u5e76\u4e0d\u591a\u3002 \u4e0e\u8bad\u7ec3\u811a\u672c\u76f8\u5173\u7684\u8fd8\u6709\u4e00\u70b9\u53ef\u4ee5\u6539\u8fdb\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u6211\u4eec\u4e3a\u6bcf\u4e2a\u6298\u53e0\u591a\u6b21\u8c03\u7528\u8fd0\u884c\u51fd\u6570\u3002\u6709\u65f6\uff0c\u5728\u540c\u4e00\u4e2a\u811a\u672c\u4e2d\u8fd0\u884c\u591a\u4e2a\u6298\u53e0\u5e76\u4e0d\u53ef\u53d6\uff0c\u56e0\u4e3a\u5185\u5b58\u6d88\u8017\u53ef\u80fd\u4f1a\u4e0d\u65ad\u589e\u52a0\uff0c\u7a0b\u5e8f\u53ef\u80fd\u4f1a\u5d29\u6e83\u3002\u4e3a\u4e86\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u5411\u8bad\u7ec3\u811a\u672c\u4f20\u9012\u53c2\u6570\u3002\u6211\u559c\u6b22\u4f7f\u7528 argparse\u3002 import argparse if __name__ == \"__main__\" : # \u5b9e\u4f8b\u5316\u53c2\u6570\u73af\u5883 parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # \u8bfb\u53d6\u53c2\u6570 args = parser . parse_args () run ( fold = args . fold ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u8fd0\u884c python \u811a\u672c\uff0c\u4f46\u4ec5\u9650\u4e8e\u7ed9\u5b9a\u7684\u6298\u53e0\u3002 \u276f python train . 
py -- fold 0 Fold = 0 , Accuracy = 0.8656666666666667 \u4ed4\u7ec6\u89c2\u5bdf\uff0c\u6211\u4eec\u7684\u7b2c 0 \u6298\u5f97\u5206\u4e0e\u4e4b\u524d\u6709\u4e9b\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6a21\u578b\u4e2d\u5b58\u5728\u968f\u673a\u6027\u3002\u6211\u4eec\u5c06\u5728\u540e\u9762\u7684\u7ae0\u8282\u4e2d\u8ba8\u8bba\u5982\u4f55\u5904\u7406\u968f\u673a\u6027\u3002 \u73b0\u5728\uff0c\u5982\u679c\u4f60\u613f\u610f\uff0c\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a shell \u811a\u672c \uff0c\u9488\u5bf9\u4e0d\u540c\u7684\u6298\u53e0\u4f7f\u7528\u4e0d\u540c\u7684\u547d\u4ee4\uff0c\u7136\u540e\u4e00\u8d77\u8fd0\u884c\uff0c\u5982\u4e0b\u56fe\u6240\u793a\u3002 python train . py -- fold 0 python train . py -- fold 1 python train . py -- fold 2 python train . py -- fold 3 python train . py -- fold 4 \u60a8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u5b83\u3002 \u276f sh run . sh Fold = 0 , Accuracy = 0.8675 Fold = 1 , Accuracy = 0.8693333333333333 Fold = 2 , Accuracy = 0.8683333333333333 Fold = 3 , Accuracy = 0.8704166666666666 Fold = 4 , Accuracy = 0.8685 \u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u53d6\u5f97\u4e86\u4e00\u4e9b\u8fdb\u5c55\uff0c\u4f46\u5982\u679c\u6211\u4eec\u770b\u4e00\u4e0b\u6211\u4eec\u7684\u8bad\u7ec3\u811a\u672c\uff0c\u6211\u4eec\u4ecd\u7136\u53d7\u5230\u4e00\u4e9b\u4e1c\u897f\u7684\u9650\u5236\uff0c\u4f8b\u5982\u6a21\u578b\u3002\u6a21\u578b\u662f\u786c\u7f16\u7801\u5728\u8bad\u7ec3\u811a\u672c\u4e2d\u7684\uff0c\u53ea\u6709\u4fee\u6539\u811a\u672c\u624d\u80fd\u6539\u53d8\u5b83\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u65b0\u7684 python \u811a\u672c\uff0c\u540d\u4e3a model_dispatcher.py \u3002model_dispatcher.py\uff0c\u987e\u540d\u601d\u4e49\uff0c\u5c06\u8c03\u5ea6\u6211\u4eec\u7684\u6a21\u578b\u5230\u8bad\u7ec3\u811a\u672c\u4e2d\u3002 from sklearn import tree models = { # \u4ee5gini\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_gini\" : tree . 
DecisionTreeClassifier ( criterion = \"gini\" ), # \u4ee5entropy\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), } model_dispatcher.py \u4ece scikit-learn \u4e2d\u5bfc\u5165\u4e86 tree\uff0c\u5e76\u5b9a\u4e49\u4e86\u4e00\u4e2a\u5b57\u5178\uff0c\u5176\u4e2d\u952e\u662f\u6a21\u578b\u7684\u540d\u79f0\uff0c\u503c\u662f\u6a21\u578b\u672c\u8eab\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5b9a\u4e49\u4e86\u4e24\u79cd\u4e0d\u540c\u7684\u51b3\u7b56\u6811\uff0c\u4e00\u79cd\u4f7f\u7528\u57fa\u5c3c\u6807\u51c6\uff0c\u53e6\u4e00\u79cd\u4f7f\u7528\u71b5\u6807\u51c6\u3002\u8981\u4f7f\u7528 py\uff0c\u6211\u4eec\u9700\u8981\u5bf9\u8bad\u7ec3\u811a\u672c\u505a\u4e00\u4e9b\u4fee\u6539\u3002 import argparse import os import joblib import pandas as pd from sklearn import metrics import config import model_dispatcher def run ( fold , model ): df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values # \u6839\u636emodel\u53c2\u6570\u9009\u62e9\u6a21\u578b clf = model_dispatcher . models [ model ] clf . fit ( x_train , y_train ) preds = clf . predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" )) if __name__ == \"__main__\" : parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # model\u53c2\u6570 parser . add_argument ( \"--model\" , type = str ) args = parser . parse_args () run ( fold = args . fold , model = args . 
model ) train.py \u6709\u51e0\u5904\u91cd\u5927\u6539\u52a8\uff1a - \u5bfc\u5165 model_dispatcher - \u4e3a ArgumentParser \u6dfb\u52a0 --model \u53c2\u6570 - \u4e3a run() \u51fd\u6570\u6dfb\u52a0model\u53c2\u6570 - \u4f7f\u7528\u8c03\u5ea6\u7a0b\u5e8f\u83b7\u53d6\u6307\u5b9a\u540d\u79f0\u7684\u6a21\u578b \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u811a\u672c\uff1a \u276f python train . py -- fold 0 -- model decision_tree_gini Fold = 0 , Accuracy = 0.8665833333333334 \u6216\u6267\u884c\u4ee5\u4e0b\u547d\u4ee4 \u276f python train . py -- fold 0 -- model decision_tree_entropy Fold = 0 , Accuracy = 0.8705833333333334 \u73b0\u5728\uff0c\u5982\u679c\u8981\u6dfb\u52a0\u65b0\u6a21\u578b\uff0c\u53ea\u9700\u4fee\u6539 model_dispatcher.py \u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u6dfb\u52a0\u968f\u673a\u68ee\u6797\uff0c\u770b\u770b\u51c6\u786e\u7387\u4f1a\u6709\u4ec0\u4e48\u53d8\u5316\u3002 from sklearn import ensemble from sklearn import tree models = { \"decision_tree_gini\" : tree . DecisionTreeClassifier ( criterion = \"gini\" ), \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), # \u968f\u673a\u68ee\u6797\u6a21\u578b \"rf\" : ensemble . RandomForestClassifier (), } \u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u3002 \u276f python train . py -- fold 0 -- model rf Fold = 0 , Accuracy = 0.9670833333333333 \u54c7\uff0c\u4e00\u4e2a\u7b80\u5355\u7684\u6539\u52a8\u5c31\u80fd\u8ba9\u5206\u6570\u6709\u5982\u6b64\u5927\u7684\u63d0\u5347\uff01\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528 run.sh \u811a\u672c\u8fd0\u884c 5 \u4e2a\u6298\u53e0\uff01 python train . py -- fold 0 -- model rf python train . py -- fold 1 -- model rf python train . py -- fold 2 -- model rf python train . py -- fold 3 -- model rf python train . py -- fold 4 -- model rf \u5f97\u5206\u60c5\u51b5\u5982\u4e0b \u276f sh run . 
sh Fold = 0 , Accuracy = 0.9674166666666667 Fold = 1 , Accuracy = 0.9698333333333333 Fold = 2 , Accuracy = 0.96575 Fold = 3 , Accuracy = 0.9684166666666667 Fold = 4 , Accuracy = 0.9666666666666667 MNIST \u51e0\u4e4e\u662f\u6bcf\u672c\u4e66\u548c\u6bcf\u7bc7\u535a\u5ba2\u90fd\u4f1a\u8ba8\u8bba\u7684\u95ee\u9898\u3002\u4f46\u6211\u8bd5\u56fe\u5c06\u8fd9\u4e2a\u95ee\u9898\u8f6c\u6362\u5f97\u66f4\u6709\u8da3\uff0c\u5e76\u5411\u4f60\u5c55\u793a\u5982\u4f55\u4e3a\u4f60\u6b63\u5728\u505a\u7684\u6216\u8ba1\u5212\u5728\u4e0d\u4e45\u7684\u5c06\u6765\u505a\u7684\u51e0\u4e4e\u6240\u6709\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7f16\u5199\u4e00\u4e2a\u57fa\u672c\u6846\u67b6\u3002\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\u53ef\u4ee5\u6539\u8fdb\u8fd9\u4e2a MNIST \u6a21\u578b\u548c\u8fd9\u4e2a\u6846\u67b6\uff0c\u6211\u4eec\u5c06\u5728\u4ee5\u540e\u7684\u7ae0\u8282\u4e2d\u770b\u5230\u3002 \u6211\u4f7f\u7528\u4e86\u4e00\u4e9b\u811a\u672c\uff0c\u5982 model_dispatcher.py \u548c config.py \uff0c\u5e76\u5c06\u5b83\u4eec\u5bfc\u5165\u5230\u6211\u7684\u8bad\u7ec3\u811a\u672c\u4e2d\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u6ca1\u6709\u5bfc\u5165 \uff0c\u4f60\u4e5f\u4e0d\u5e94\u8be5\u5bfc\u5165\u3002\u5982\u679c\u6211\u5bfc\u5165\u4e86 
\uff0c\u4f60\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u77e5\u9053\u6a21\u578b\u5b57\u5178\u662f\u4ece\u54ea\u91cc\u6765\u7684\u3002\u7f16\u5199\u4f18\u79c0\u3001\u6613\u61c2\u7684\u4ee3\u7801\u662f\u4e00\u4e2a\u4eba\u5fc5\u987b\u5177\u5907\u7684\u57fa\u672c\u7d20\u8d28\uff0c\u4f46\u8bb8\u591a\u6570\u636e\u79d1\u5b66\u5bb6\u5374\u5ffd\u89c6\u4e86\u8fd9\u4e00\u70b9\u3002\u5982\u679c\u4f60\u6240\u505a\u7684\u9879\u76ee\u80fd\u8ba9\u5176\u4ed6\u4eba\u7406\u89e3\u5e76\u4f7f\u7528\uff0c\u800c\u65e0\u9700\u54a8\u8be2\u4f60\u7684\u610f\u89c1\uff0c\u90a3\u4e48\u4f60\u5c31\u8282\u7701\u4e86\u4ed6\u4eec\u7684\u65f6\u95f4\u548c\u81ea\u5df1\u7684\u65f6\u95f4\uff0c\u53ef\u4ee5\u5c06\u8fd9\u4e9b\u65f6\u95f4\u6295\u5165\u5230\u6539\u8fdb\u4f60\u7684\u9879\u76ee\u6216\u5f00\u53d1\u65b0\u9879\u76ee\u4e2d\u53bb\u3002","title":"\u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\u9879\u76ee"},{"location":"%E8%AF%84%E4%BC%B0%E6%8C%87%E6%A0%87/","text":"\u8bc4\u4f30\u6307\u6807 \u8bf4\u5230\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u4f60\u4f1a\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\u9047\u5230\u5f88\u591a\u4e0d\u540c\u7c7b\u578b\u7684\u6307\u6807\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u751a\u81f3\u4f1a\u6839\u636e\u4e1a\u52a1\u95ee\u9898\u521b\u5efa\u5ea6\u91cf\u6807\u51c6\u3002\u9010\u4e00\u4ecb\u7ecd\u548c\u89e3\u91ca\u6bcf\u4e00\u79cd\u5ea6\u91cf\u7c7b\u578b\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002\u76f8\u53cd\uff0c\u6211\u4eec\u5c06\u4ecb\u7ecd\u4e00\u4e9b\u6700\u5e38\u89c1\u7684\u5ea6\u91cf\u6807\u51c6\uff0c\u4f9b\u4f60\u5728\u6700\u521d\u7684\u51e0\u4e2a\u9879\u76ee\u4e2d\u4f7f\u7528\u3002 
\u5728\u672c\u4e66\u7684\u5f00\u5934\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u76d1\u7763\u5b66\u4e60\u548c\u975e\u76d1\u7763\u5b66\u4e60\u3002\u867d\u7136\u65e0\u76d1\u7763\u5b66\u4e60\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e9b\u6307\u6807\uff0c\u4f46\u6211\u4eec\u5c06\u53ea\u5173\u6ce8\u6709\u76d1\u7763\u5b66\u4e60\u3002\u8fd9\u662f\u56e0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u591a\uff0c\u800c\u4e14\u5bf9\u65e0\u76d1\u7763\u65b9\u6cd5\u7684\u8bc4\u4f30\u76f8\u5f53\u4e3b\u89c2\u3002 \u5982\u679c\u6211\u4eec\u8c08\u8bba\u5206\u7c7b\u95ee\u9898\uff0c\u6700\u5e38\u7528\u7684\u6307\u6807\u662f\uff1a \u51c6\u786e\u7387\uff08Accuracy\uff09 \u7cbe\u786e\u7387\uff08P\uff09 \u53ec\u56de\u7387\uff08R\uff09 F1 \u5206\u6570\uff08F1\uff09 AUC\uff08AUC\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u8bf4\u5230\u56de\u5f52\uff0c\u6700\u5e38\u7528\u7684\u8bc4\u4ef7\u6307\u6807\u662f \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee \uff08MAE\uff09 \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u5747\u65b9\u6839\u8bef\u5dee \uff08RMSE\uff09 \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \uff08RMSLE\uff09 \u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee \uff08MPE\uff09 \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee \uff08MAPE\uff09 R2 \u4e86\u89e3\u4e0a\u8ff0\u6307\u6807\u7684\u5de5\u4f5c\u539f\u7406\u5e76\u4e0d\u662f\u6211\u4eec\u5fc5\u987b\u4e86\u89e3\u7684\u552f\u4e00\u4e8b\u60c5\u3002\u6211\u4eec\u8fd8\u5fc5\u987b\u77e5\u9053\u4f55\u65f6\u4f7f\u7528\u54ea\u4e9b\u6307\u6807\uff0c\u800c\u8fd9\u53d6\u51b3\u4e8e\u4f60\u6709\u4ec0\u4e48\u6837\u7684\u6570\u636e\u548c\u76ee\u6807\u3002\u6211\u8ba4\u4e3a\u8fd9\u4e0e\u76ee\u6807\u6709\u5173\uff0c\u800c\u4e0e\u6570\u636e\u65e0\u5173\u3002 
\u8981\u8fdb\u4e00\u6b65\u4e86\u89e3\u8fd9\u4e9b\u6307\u6807\uff0c\u8ba9\u6211\u4eec\u4ece\u4e00\u4e2a\u7b80\u5355\u7684\u95ee\u9898\u5f00\u59cb\u3002\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a \u4e8c\u5143\u5206\u7c7b \u95ee\u9898\uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u76ee\u6807\u7684\u95ee\u9898\uff0c\u5047\u8bbe\u8fd9\u662f\u4e00\u4e2a\u80f8\u90e8 X \u5149\u56fe\u50cf\u5206\u7c7b\u95ee\u9898\u3002\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6ca1\u6709\u95ee\u9898\uff0c\u800c\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6709\u80ba\u584c\u9677\uff0c\u4e5f\u5c31\u662f\u6240\u8c13\u7684\u6c14\u80f8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u5206\u7c7b\u5668\uff0c\u5728\u7ed9\u5b9a\u80f8\u90e8 X \u5149\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\uff0c\u5b83\u80fd\u68c0\u6d4b\u51fa\u56fe\u50cf\u662f\u5426\u6709\u6c14\u80f8\u3002 \u56fe 1\uff1a\u6c14\u80f8\u80ba\u90e8\u56fe\u50cf \u6211\u4eec\u8fd8\u5047\u8bbe\u6709\u76f8\u540c\u6570\u91cf\u7684\u6c14\u80f8\u548c\u975e\u6c14\u80f8\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u6bd4\u5982\u5404 100 \u5f20\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709 100 \u5f20\u9633\u6027\u6837\u672c\u548c 100 \u5f20\u9634\u6027\u6837\u672c\uff0c\u5171\u8ba1 200 \u5f20\u56fe\u50cf\u3002 \u7b2c\u4e00\u6b65\u662f\u5c06\u4e0a\u8ff0\u6570\u636e\u5206\u4e3a\u4e24\u7ec4\uff0c\u6bcf\u7ec4 100 \u5f20\u56fe\u50cf\uff0c\u5373\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u8fd9\u4e24\u4e2a\u96c6\u5408\u4e2d\uff0c\u6211\u4eec\u90fd\u6709 50 \u4e2a\u6b63\u6837\u672c\u548c 50 \u4e2a\u8d1f\u6837\u672c\u3002 \u5728\u4e8c\u5143\u5206\u7c7b\u6307\u6807\u4e2d\uff0c\u5f53\u6b63\u8d1f\u6837\u672c\u6570\u91cf\u76f8\u7b49\u65f6\uff0c\u6211\u4eec\u901a\u5e38\u4f7f\u7528\u51c6\u786e\u7387\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002 \u51c6\u786e\u7387 
\uff1a\u8fd9\u662f\u673a\u5668\u5b66\u4e60\u4e2d\u6700\u76f4\u63a5\u7684\u6307\u6807\u4e4b\u4e00\u3002\u5b83\u5b9a\u4e49\u4e86\u6a21\u578b\u7684\u51c6\u786e\u5ea6\u3002\u5bf9\u4e8e\u4e0a\u8ff0\u95ee\u9898\uff0c\u5982\u679c\u4f60\u5efa\u7acb\u7684\u6a21\u578b\u80fd\u51c6\u786e\u5206\u7c7b 90 \u5f20\u56fe\u7247\uff0c\u90a3\u4e48\u4f60\u7684\u51c6\u786e\u7387\u5c31\u662f 90% \u6216 0.90\u3002\u5982\u679c\u53ea\u6709 83 \u5e45\u56fe\u50cf\u88ab\u6b63\u786e\u5206\u7c7b\uff0c\u90a3\u4e48\u6a21\u578b\u7684\u51c6\u786e\u7387\u5c31\u662f 83% \u6216 0.83\u3002 \u8ba1\u7b97\u51c6\u786e\u7387\u7684 Python \u4ee3\u7801\u4e5f\u975e\u5e38\u7b80\u5355\u3002 def accuracy ( y_true , y_pred ): # \u4e3a\u6b63\u786e\u9884\u6d4b\u6570\u521d\u59cb\u5316\u4e00\u4e2a\u7b80\u5355\u8ba1\u6570\u5668 correct_counter = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): if yt == yp : # \u5982\u679c\u9884\u6d4b\u6807\u7b7e\u4e0e\u771f\u5b9e\u6807\u7b7e\u76f8\u540c\uff0c\u5219\u589e\u52a0\u8ba1\u6570\u5668 correct_counter += 1 # \u8fd4\u56de\u6b63\u786e\u7387\uff0c\u6b63\u786e\u6807\u7b7e\u6570/\u603b\u6807\u7b7e\u6570 return correct_counter / len ( y_true ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u8ba1\u7b97\u51c6\u786e\u7387\u3002 In [ X ]: from sklearn import metrics ... : l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] ... : metrics . 
accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u73b0\u5728\uff0c\u5047\u8bbe\u6211\u4eec\u628a\u6570\u636e\u96c6\u7a0d\u5fae\u6539\u52a8\u4e00\u4e0b\uff0c\u6709 180 \u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u53ea\u6709 20 \u5f20\u6709\u6c14\u80f8\u3002\u5373\u4f7f\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4e5f\u8981\u521b\u5efa\u6b63\u8d1f\uff08\u6c14\u80f8\u4e0e\u975e\u6c14\u80f8\uff09\u76ee\u6807\u6bd4\u4f8b\u76f8\u540c\u7684\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u6bcf\u4e00\u7ec4\u4e2d\uff0c\u6211\u4eec\u6709 90 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u5982\u679c\u8bf4\u9a8c\u8bc1\u96c6\u4e2d\u7684\u6240\u6709\u56fe\u50cf\u90fd\u662f\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u90a3\u4e48\u60a8\u7684\u51c6\u786e\u7387\u4f1a\u662f\u591a\u5c11\u5462\uff1f\u8ba9\u6211\u4eec\u6765\u770b\u770b\uff1b\u60a8\u5bf9 90% \u7684\u56fe\u50cf\u8fdb\u884c\u4e86\u6b63\u786e\u5206\u7c7b\u3002\u56e0\u6b64\uff0c\u60a8\u7684\u51c6\u786e\u7387\u662f 90%\u3002 \u4f46\u8bf7\u518d\u770b\u4e00\u904d\u3002 \u4f60\u751a\u81f3\u6ca1\u6709\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5c31\u5f97\u5230\u4e86 90% 
\u7684\u51c6\u786e\u7387\u3002\u8fd9\u4f3c\u4e4e\u6709\u70b9\u6ca1\u7528\u3002\u5982\u679c\u6211\u4eec\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u6570\u636e\u96c6\u662f\u504f\u659c\u7684\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u6bd4\u53e6\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u591a\u5f88\u591a\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f7f\u7528\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u56e0\u4e3a\u5b83\u4e0d\u80fd\u4ee3\u8868\u6570\u636e\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u80fd\u4f1a\u83b7\u5f97\u5f88\u9ad8\u7684\u51c6\u786e\u7387\uff0c\u4f46\u60a8\u7684\u6a21\u578b\u5728\u5b9e\u9645\u6837\u672c\u4e2d\u7684\u8868\u73b0\u53ef\u80fd\u5e76\u4e0d\u7406\u60f3\uff0c\u800c\u4e14\u60a8\u4e5f\u65e0\u6cd5\u5411\u7ecf\u7406\u89e3\u91ca\u539f\u56e0\u3002 \u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b \u7cbe\u786e\u7387 \u7b49\u5176\u4ed6\u6307\u6807\u3002 \u5728\u5b66\u4e60\u7cbe\u786e\u7387\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e86\u89e3\u4e00\u4e9b\u672f\u8bed\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5047\u8bbe\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e3a\u6b63\u7c7b (1)\uff0c\u6ca1\u6709\u6c14\u80f8\u7684\u4e3a\u8d1f\u7c7b (0)\u3002 \u771f\u9633\u6027 \uff08TP\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6709\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9633\u6027\u3002 \u771f\u9634\u6027 \uff08TN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u800c\u5b9e\u9645\u76ee\u6807\u663e\u793a\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9634\u6027\u3002 
\u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u6b63\u786e\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9633\u6027\uff1b\u5982\u679c\u60a8\u7684\u6a21\u578b\u51c6\u786e\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9634\u6027\u3002 \u5047\u9633\u6027 \uff08FP\uff09 \uff1a\u7ed9\u5b9a\u4e00\u5f20\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u975e\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9633\u6027\u3002 \u5047\u9634\u6027 \uff08FN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u975e\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9634\u6027\u3002 \u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\uff0c\u90a3\u4e48\u5b83\u5c31\u662f\u5047\u9633\u6027\u3002\u5982\u679c\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5219\u662f\u5047\u9634\u6027\u3002 \u8ba9\u6211\u4eec\u9010\u4e00\u770b\u770b\u8fd9\u4e9b\u5b9e\u73b0\u3002 def true_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u8ba1\u6570\u5668 tp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 1 : tp += 1 # \u8fd4\u56de\u771f\u9633\u6027\u6837\u672c\u6570 return tp def true_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9634\u6027\u6837\u672c\u8ba1\u6570\u5668 tn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # 
\u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 0 : tn += 1 # \u8fd4\u56de\u771f\u9634\u6027\u6837\u672c\u6570 return tn def false_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9633\u6027\u8ba1\u6570\u5668 fp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 1 : fp += 1 # \u8fd4\u56de\u5047\u9633\u6027\u6837\u672c\u6570 return fp def false_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9634\u6027\u8ba1\u6570\u5668 fn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 0 : fn += 1 # \u8fd4\u56de\u5047\u9634\u6027\u6570 return fn \u6211\u5728\u8fd9\u91cc\u5b9e\u73b0\u8fd9\u4e9b\u529f\u80fd\u7684\u65b9\u6cd5\u975e\u5e38\u7b80\u5355\uff0c\u800c\u4e14\u53ea\u9002\u7528\u4e8e\u4e8c\u5143\u5206\u7c7b\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u8fd9\u4e9b\u51fd\u6570\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... 
: l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: true_positive ( l1 , l2 ) Out [ X ]: 2 In [ X ]: false_positive ( l1 , l2 ) Out [ X ]: 1 In [ X ]: false_negative ( l1 , l2 ) Out [ X ]: 2 In [ X ]: true_negative ( l1 , l2 ) Out [ X ]: 3 \u5982\u679c\u6211\u4eec\u5fc5\u987b\u7528\u4e0a\u8ff0\u672f\u8bed\u6765\u5b9a\u4e49\u7cbe\u786e\u7387\uff0c\u6211\u4eec\u53ef\u4ee5\u5199\u4e3a\uff1a \\[ Accuracy Score = (TP + TN)/(TP + TN + FP +FN) \\] \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5728 python \u4e2d\u4f7f\u7528 TP\u3001TN\u3001FP \u548c FN \u5feb\u901f\u5b9e\u73b0\u51c6\u786e\u5ea6\u5f97\u5206\u3002\u6211\u4eec\u5c06\u5176\u79f0\u4e3a accuracy_v2\u3002 def accuracy_v2 ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u51c6\u786e\u7387 accuracy_score = ( tp + tn ) / ( tp + tn + fp + fn ) return accuracy_score \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0e\u4e4b\u524d\u7684\u5b9e\u73b0\u548c scikit-learn \u7248\u672c\u8fdb\u884c\u6bd4\u8f83\uff0c\u5feb\u901f\u68c0\u67e5\u8be5\u51fd\u6570\u7684\u6b63\u786e\u6027\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: accuracy ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: accuracy_v2 ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: metrics . 
accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u8bf7\u6ce8\u610f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0cmetrics.accuracy_score \u6765\u81ea scikit-learn\u3002 \u5f88\u597d\u3002\u6240\u6709\u503c\u90fd\u5339\u914d\u3002\u8fd9\u8bf4\u660e\u6211\u4eec\u5728\u5b9e\u73b0\u8fc7\u7a0b\u4e2d\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u5176\u4ed6\u91cd\u8981\u6307\u6807\u3002 \u9996\u5148\u662f\u7cbe\u786e\u7387\u3002\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u662f \\[ Precision = TP/(TP + FP) \\] \u5047\u8bbe\u6211\u4eec\u5728\u65b0\u7684\u504f\u659c\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u4e86\u4e00\u4e2a\u65b0\u6a21\u578b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 90 \u5f20\u56fe\u50cf\u4e2d\u7684 80 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u56fe\u50cf\u4e2d\u7684 8 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6210\u529f\u8bc6\u522b\u4e86 100 \u5f20\u56fe\u50cf\u4e2d\u7684 88 \u5f20\u3002\u56e0\u6b64\uff0c\u51c6\u786e\u7387\u4e3a 0.88 \u6216 88%\u3002 \u4f46\u662f\uff0c\u5728\u8fd9 100 \u5f20\u6837\u672c\u4e2d\uff0c\u6709 10 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u6c14\u80f8\uff0c2 \u5f20\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u975e\u6c14\u80f8\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 TP : 8 TN: 80 FP: 10 FN: 2 \u7cbe\u786e\u7387\u4e3a 8 / (8 + 10) = 0.444\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u5728\u8bc6\u522b\u9633\u6027\u6837\u672c\uff08\u6c14\u80f8\uff09\u65f6\u6709 44.4% \u7684\u6b63\u786e\u7387\u3002 \u73b0\u5728\uff0c\u65e2\u7136\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86 TP\u3001TN\u3001FP \u548c FN\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5728 python \u4e2d\u5b9e\u73b0\u7cbe\u786e\u7387\u4e86\u3002 def precision ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true 
, y_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8ba9\u6211\u4eec\u8bd5\u8bd5\u8fd9\u79cd\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u65b9\u5f0f\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: precision ( l1 , l2 ) Out [ X ]: 0.6666666666666666 \u8fd9\u4f3c\u4e4e\u6ca1\u6709\u95ee\u9898\u3002 \u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u6765\u770b \u53ec\u56de\u7387 \u3002\u53ec\u56de\u7387\u7684\u5b9a\u4e49\u662f\uff1a \\[ Recall = TP/(TP + FN) \\] \u5728\u4e0a\u8ff0\u60c5\u51b5\u4e0b\uff0c\u53ec\u56de\u7387\u4e3a 8 / (8 + 2) = 0.80\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 80% \u7684\u9633\u6027\u6837\u672c\u3002 def recall ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u53ec\u56de\u7387 recall = tp / ( tp + fn ) return recall \u5c31\u6211\u4eec\u7684\u4e24\u4e2a\u5c0f\u5217\u8868\u800c\u8a00\uff0c\u53ec\u56de\u7387\u5e94\u8be5\u662f 0.5\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... 
: l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: recall ( l1 , l2 ) Out [ X ]: 0.5 \u8fd9\u4e0e\u6211\u4eec\u7684\u8ba1\u7b97\u503c\u76f8\u7b26\uff01 \u5bf9\u4e8e\u4e00\u4e2a \"\u597d \"\u6a21\u578b\u6765\u8bf4\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u503c\u90fd\u5e94\u8be5\u5f88\u9ad8\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0c\u53ec\u56de\u503c\u76f8\u5f53\u9ad8\u3002\u4f46\u662f\uff0c\u7cbe\u786e\u7387\u5374\u5f88\u4f4e\uff01\u6211\u4eec\u7684\u6a21\u578b\u4ea7\u751f\u4e86\u5927\u91cf\u7684\u8bef\u62a5\uff0c\u4f46\u8bef\u62a5\u8f83\u5c11\u3002\u5728\u8fd9\u7c7b\u95ee\u9898\u4e2d\uff0c\u5047\u9634\u6027\u8f83\u5c11\u662f\u597d\u4e8b\uff0c\u56e0\u4e3a\u4f60\u4e0d\u60f3\u5728\u75c5\u4eba\u6709\u6c14\u80f8\u7684\u60c5\u51b5\u4e0b\u5374\u8bf4\u4ed6\u4eec\u6ca1\u6709\u6c14\u80f8\u3002\u8fd9\u6837\u505a\u4f1a\u9020\u6210\u66f4\u5927\u7684\u4f24\u5bb3\u3002\u4f46\u6211\u4eec\u4e5f\u6709\u5f88\u591a\u5047\u9633\u6027\u7ed3\u679c\uff0c\u8fd9\u4e5f\u4e0d\u662f\u597d\u4e8b\u3002 \u5927\u591a\u6570\u6a21\u578b\u90fd\u4f1a\u9884\u6d4b\u4e00\u4e2a\u6982\u7387\uff0c\u5f53\u6211\u4eec\u9884\u6d4b\u65f6\uff0c\u901a\u5e38\u4f1a\u5c06\u8fd9\u4e2a\u9608\u503c\u9009\u4e3a 0.5\u3002\u8fd9\u4e2a\u9608\u503c\u5e76\u4e0d\u603b\u662f\u7406\u60f3\u7684\uff0c\u6839\u636e\u8fd9\u4e2a\u9608\u503c\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u503c\u53ef\u80fd\u4f1a\u53d1\u751f\u5f88\u5927\u7684\u53d8\u5316\u3002\u5982\u679c\u6211\u4eec\u9009\u62e9\u7684\u6bcf\u4e2a\u9608\u503c\u90fd\u80fd\u8ba1\u7b97\u51fa\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u53ef\u4ee5\u5728\u8fd9\u4e9b\u503c\u4e4b\u95f4\u7ed8\u5236\u51fa\u66f2\u7ebf\u56fe\u3002\u8fd9\u5e45\u56fe\u6216\u66f2\u7ebf\u88ab\u79f0\u4e3a \"\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\"\u3002 \u5728\u7814\u7a76\u7cbe\u786e\u7387-\u8c03\u7528\u66f2\u7ebf\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5047\u8bbe\u6709\u4e24\u4e2a\u5217\u8868\u3002 In [ X ]: 
y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0.02638412 , 0.11114267 , 0.31620708 , ... : 0.0490937 , 0.0191491 , 0.17554844 , ... : 0.15952202 , 0.03819563 , 0.11639273 , ... : 0.079377 , 0.08584789 , 0.39095342 , ... : 0.27259048 , 0.03447096 , 0.04644807 , ... : 0.03543574 , 0.18521942 , 0.05934905 , ... : 0.61977213 , 0.33056815 ] \u56e0\u6b64\uff0cy_true \u662f\u6211\u4eec\u7684\u76ee\u6807\u503c\uff0c\u800c y_pred \u662f\u6837\u672c\u88ab\u8d4b\u503c\u4e3a 1 \u7684\u6982\u7387\u503c\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u6211\u4eec\u8981\u770b\u7684\u662f\u9884\u6d4b\u4e2d\u7684\u6982\u7387\uff0c\u800c\u4e0d\u662f\u9884\u6d4b\u503c\uff08\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u9884\u6d4b\u503c\u7684\u8ba1\u7b97\u9608\u503c\u4e3a 0.5\uff09\u3002 precisions = [] recalls = [] thresholds = [ 0.0490937 , 0.05934905 , 0.079377 , 0.08584789 , 0.11114267 , 0.11639273 , 0.15952202 , 0.17554844 , 0.18521942 , 0.27259048 , 0.31620708 , 0.33056815 , 0.39095342 , 0.61977213 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for i in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_prediction = [ 1 if x >= i else 0 for x in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , temp_prediction ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , temp_prediction ) # \u52a0\u5165\u7cbe\u786e\u7387\u5217\u8868 precisions . append ( p ) # \u52a0\u5165\u53ec\u56de\u7387\u5217\u8868 recalls . append ( r ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7ed8\u5236\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 7 , 7 )) # x\u8f74\u4e3a\u53ec\u56de\u7387\uff0cy\u8f74\u4e3a\u7cbe\u786e\u7387 plt . plot ( recalls , precisions ) # \u6dfb\u52a0x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a15 plt . 
xlabel ( 'Recall' , fontsize = 15 ) # \u6dfb\u52a0y\u8f74\u6807\u7b7e\uff0c\u5b57\u6761\u5927\u5c0f\u4e3a15 plt . ylabel ( 'Precision' , fontsize = 15 ) \u56fe 2 \u663e\u793a\u4e86\u6211\u4eec\u901a\u8fc7\u8fd9\u79cd\u65b9\u6cd5\u5f97\u5230\u7684\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 \u56fe 2\uff1a\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u8fd9\u6761 \u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u4e0e\u60a8\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230\u7684\u66f2\u7ebf\u622a\u7136\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6211\u4eec\u53ea\u6709 20 \u4e2a\u6837\u672c\uff0c\u5176\u4e2d\u53ea\u6709 3 \u4e2a\u662f\u9633\u6027\u6837\u672c\u3002\u4f46\u8fd9\u6ca1\u4ec0\u4e48\u597d\u62c5\u5fc3\u7684\u3002\u8fd9\u8fd8\u662f\u90a3\u6761\u7cbe\u786e\u7387-\u53ec\u56de\u66f2\u7ebf\u3002 \u4f60\u4f1a\u53d1\u73b0\uff0c\u9009\u62e9\u4e00\u4e2a\u65e2\u80fd\u63d0\u4f9b\u826f\u597d\u7cbe\u786e\u7387\u53c8\u80fd\u63d0\u4f9b\u53ec\u56de\u503c\u7684\u9608\u503c\u662f\u5f88\u6709\u6311\u6218\u6027\u7684\u3002\u5982\u679c\u9608\u503c\u8fc7\u9ad8\uff0c\u771f\u9633\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u51cf\u5c11\uff0c\u800c\u5047\u9634\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u589e\u52a0\u3002\u8fd9\u4f1a\u964d\u4f4e\u53ec\u56de\u7387\uff0c\u4f46\u7cbe\u786e\u7387\u5f97\u5206\u4f1a\u5f88\u9ad8\u3002\u5982\u679c\u5c06\u9608\u503c\u964d\u5f97\u592a\u4f4e\uff0c\u5219\u8bef\u62a5\u4f1a\u5927\u91cf\u589e\u52a0\uff0c\u7cbe\u786e\u7387\u4e5f\u4f1a\u964d\u4f4e\u3002 \u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u8d8a\u63a5\u8fd1 1 \u8d8a\u597d\u3002 F1 \u5206\u6570\u662f\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7efc\u5408\u6307\u6807\u3002\u5b83\u88ab\u5b9a\u4e49\u4e3a\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7b80\u5355\u52a0\u6743\u5e73\u5747\u503c\uff08\u8c03\u548c\u5e73\u5747\u503c\uff09\u3002\u5982\u679c\u6211\u4eec\u7528 P \u8868\u793a\u7cbe\u786e\u7387\uff0c\u7528 R 
\u8868\u793a\u53ec\u56de\u7387\uff0c\u90a3\u4e48 F1 \u5206\u6570\u53ef\u4ee5\u8868\u793a\u4e3a\uff1a \\[ F1 = 2PR/(P + R) \\] \u6839\u636e TP\u3001FP \u548c FN\uff0c\u7a0d\u52a0\u6570\u5b66\u8ba1\u7b97\u5c31\u80fd\u5f97\u51fa\u4ee5\u4e0b F1 \u7b49\u5f0f\uff1a \\[ F1 = 2TP/(2TP + FP + FN) \\] Python \u5b9e\u73b0\u5f88\u7b80\u5355\uff0c\u56e0\u4e3a\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86\u8fd9\u4e9b def f1 ( y_true , y_pred ): # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , y_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , y_pred ) # \u8ba1\u7b97f1\u503c score = 2 * p * r / ( p + r ) return score \u8ba9\u6211\u4eec\u770b\u770b\u5176\u7ed3\u679c\uff0c\u5e76\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\u3002 In [ X ]: y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0 , 0 , 1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: f1 ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u901a\u8fc7 scikit learn\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u76f8\u540c\u7684\u5217\u8868\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . 
f1_score ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u4e0e\u5176\u5355\u72ec\u770b\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u60a8\u8fd8\u53ef\u4ee5\u53ea\u770b F1 \u5206\u6570\u3002\u4e0e\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c\u51c6\u786e\u5ea6\u4e00\u6837\uff0cF1 \u5206\u6570\u7684\u8303\u56f4\u4e5f\u662f\u4ece 0 \u5230 1\uff0c\u5b8c\u7f8e\u9884\u6d4b\u6a21\u578b\u7684 F1 \u5206\u6570\u4e3a 1\u3002 \u6b64\u5916\uff0c\u6211\u4eec\u8fd8\u5e94\u8be5\u4e86\u89e3\u5176\u4ed6\u4e00\u4e9b\u5173\u952e\u672f\u8bed\u3002 \u7b2c\u4e00\u4e2a\u672f\u8bed\u662f TPR \u6216\u771f\u9633\u6027\u7387\uff08True Positive Rate\uff09\uff0c\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\u3002 \\[ TPR = TP/(TP + FN) \\] \u5c3d\u7ba1\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\uff0c\u4f46\u6211\u4eec\u5c06\u4e3a\u5b83\u521b\u5efa\u4e00\u4e2a python \u51fd\u6570\uff0c\u4ee5\u4fbf\u4eca\u540e\u4f7f\u7528\u8fd9\u4e2a\u540d\u79f0\u3002 def tpr ( y_true , y_pred ): # \u771f\u9633\u6027\u7387\uff08TPR\uff09\uff0c\u4e0e\u53ec\u56de\u7387\u8ba1\u7b97\u516c\u5f0f\u4e00\u81f4 return recall ( y_true , y_pred ) TPR \u6216\u53ec\u56de\u7387\u4e5f\u88ab\u79f0\u4e3a\u7075\u654f\u5ea6\u3002 \u800c FPR \u6216\u5047\u9633\u6027\u7387\uff08False Positive Rate\uff09\u7684\u5b9a\u4e49\u662f\uff1a \\[ FPR = FP / (TN + FP) \\] def fpr ( y_true , y_pred ): # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u8fd4\u56de\u5047\u9633\u6027\u7387\uff08FPR\uff09 return fp / ( tn + fp ) 1 - FPR \u88ab\u79f0\u4e3a\u7279\u5f02\u6027\u6216\u771f\u9634\u6027\u7387\u6216 TNR\u3002\u8fd9\u4e9b\u672f\u8bed\u5f88\u591a\uff0c\u4f46\u5176\u4e2d\u6700\u91cd\u8981\u7684\u53ea\u6709 TPR \u548c FPR\u3002\u5047\u8bbe\u6211\u4eec\u53ea\u6709 15 \u4e2a\u6837\u672c\uff0c\u5176\u76ee\u6807\u503c\u4e3a\u4e8c\u5143\uff1a Actual targets : [0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1] 
\u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u7c7b\u4f3c\u968f\u673a\u68ee\u6797\u7684\u6a21\u578b\uff0c\u5c31\u80fd\u5f97\u5230\u6837\u672c\u5448\u9633\u6027\u7684\u6982\u7387\u3002 Predicted probabilities for 1: [0.1, 0.3, 0.2, 0.6, 0.8, 0.05, 0.9, 0.5, 0.3, 0.66, 0.3, 0.2, 0.85, 0.15, 0.99] \u5bf9\u4e8e >= 0.5 \u7684\u5178\u578b\u9608\u503c\uff0c\u6211\u4eec\u53ef\u4ee5\u8bc4\u4f30\u4e0a\u8ff0\u6240\u6709\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387/TPR\u3001F1 \u548c FPR \u503c\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u5c06\u9608\u503c\u9009\u4e3a 0.4 \u6216 0.6\uff0c\u4e5f\u53ef\u4ee5\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0 \u5230 1 \u4e4b\u95f4\u7684\u4efb\u4f55\u503c\uff0c\u5e76\u8ba1\u7b97\u4e0a\u8ff0\u6240\u6709\u6307\u6807\u3002 \u4e0d\u8fc7\uff0c\u6211\u4eec\u53ea\u8ba1\u7b97\u4e24\u4e2a\u503c\uff1a TPR \u548c FPR\u3002 # \u521d\u59cb\u5316\u771f\u9633\u6027\u7387\u5217\u8868 tpr_list = [] # \u521d\u59cb\u5316\u5047\u9633\u6027\u7387\u5217\u8868 fpr_list = [] # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u7387 temp_tpr = tpr ( y_true , temp_pred ) # \u5047\u9633\u6027\u7387 temp_fpr = fpr ( y_true , temp_pred ) # \u5c06\u771f\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 tpr_list . append ( temp_tpr ) # \u5c06\u5047\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 fpr_list . 
append ( temp_fpr ) \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u6bcf\u4e2a\u9608\u503c\u7684 TPR \u503c\u548c FPR \u503c\u3002 \u56fe 3\uff1a\u9608\u503c\u3001TPR \u548c FPR \u503c\u8868 \u5982\u679c\u6211\u4eec\u7ed8\u5236\u5982\u56fe 3 \u6240\u793a\u7684\u8868\u683c\uff0c\u5373\u4ee5 TPR \u4e3a Y \u8f74\uff0cFPR \u4e3a X \u8f74\uff0c\u5c31\u4f1a\u5f97\u5230\u5982\u56fe 4 \u6240\u793a\u7684\u66f2\u7ebf\u3002 \u56fe 4\uff1aROC\u66f2\u7ebf \u8fd9\u6761\u66f2\u7ebf\u4e5f\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u3002\u5982\u679c\u6211\u4eec\u8ba1\u7b97\u8fd9\u6761 ROC \u66f2\u7ebf\u4e0b\u7684\u9762\u79ef\uff0c\u5c31\u662f\u5728\u8ba1\u7b97\u53e6\u4e00\u4e2a\u6307\u6807\uff0c\u5f53\u6570\u636e\u96c6\u7684\u4e8c\u5143\u76ee\u6807\u504f\u659c\u65f6\uff0c\u8fd9\u4e2a\u6307\u6807\u5c31\u4f1a\u975e\u5e38\u5e38\u7528\u3002 \u8fd9\u4e2a\u6307\u6807\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u6216\u66f2\u7ebf\u4e0b\u9762\u79ef\uff0c\u7b80\u79f0 AUC\u3002\u8ba1\u7b97 ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002\u5728\u6b64\uff0c\u6211\u4eec\u5c06\u91c7\u7528 scikit- learn \u7684\u5947\u5999\u5b9e\u73b0\u65b9\u6cd5\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: metrics . 
roc_auc_score ( y_true , y_pred ) Out [ X ]: 0.8300000000000001 AUC \u503c\u4ece 0 \u5230 1 \u4e0d\u7b49\u3002 AUC = 1 \u610f\u5473\u7740\u60a8\u62e5\u6709\u4e00\u4e2a\u5b8c\u7f8e\u7684\u6a21\u578b\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u610f\u5473\u7740\u4f60\u5728\u9a8c\u8bc1\u65f6\u72af\u4e86\u4e00\u4e9b\u9519\u8bef\uff0c\u5e94\u8be5\u91cd\u65b0\u5ba1\u89c6\u6570\u636e\u5904\u7406\u548c\u9a8c\u8bc1\u6d41\u7a0b\u3002\u5982\u679c\u4f60\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\uff0c\u90a3\u4e48\u606d\u559c\u4f60\uff0c\u4f60\u5df2\u7ecf\u62e5\u6709\u4e86\u9488\u5bf9\u6570\u636e\u96c6\u5efa\u7acb\u7684\u6700\u4f73\u6a21\u578b\u3002 AUC = 0 \u610f\u5473\u7740\u60a8\u7684\u6a21\u578b\u975e\u5e38\u7cdf\u7cd5\uff08\u6216\u975e\u5e38\u597d\uff01\uff09\u3002\u8bd5\u7740\u53cd\u8f6c\u9884\u6d4b\u7684\u6982\u7387\uff0c\u4f8b\u5982\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u6b63\u7c7b\u7684\u6982\u7387\u662f p\uff0c\u8bd5\u7740\u7528 1-p \u4ee3\u66ff\u5b83\u3002\u8fd9\u79cd AUC \u4e5f\u53ef\u80fd\u610f\u5473\u7740\u60a8\u7684\u9a8c\u8bc1\u6216\u6570\u636e\u5904\u7406\u5b58\u5728\u95ee\u9898\u3002 AUC = 0.5 \u610f\u5473\u7740\u4f60\u7684\u9884\u6d4b\u662f\u968f\u673a\u7684\u3002\u56e0\u6b64\uff0c\u5bf9\u4e8e\u4efb\u4f55\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u5982\u679c\u6211\u5c06\u6240\u6709\u76ee\u6807\u90fd\u9884\u6d4b\u4e3a 0.5\uff0c\u6211\u5c06\u5f97\u5230 0.5 \u7684 AUC\u3002 AUC \u503c\u4ecb\u4e8e 0 \u548c 0.5 \u4e4b\u95f4\uff0c\u610f\u5473\u7740\u4f60\u7684\u6a21\u578b\u6bd4\u968f\u673a\u6a21\u578b\u66f4\u5dee\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u662f\u56e0\u4e3a\u4f60\u98a0\u5012\u4e86\u7c7b\u522b\u3002 \u5982\u679c\u60a8\u5c1d\u8bd5\u53cd\u8f6c\u9884\u6d4b\uff0c\u60a8\u7684 AUC \u503c\u53ef\u80fd\u4f1a\u8d85\u8fc7 0.5\u3002\u63a5\u8fd1 1 \u7684 AUC \u503c\u88ab\u8ba4\u4e3a\u662f\u597d\u503c\u3002 \u4f46 AUC \u5bf9\u6211\u4eec\u7684\u6a21\u578b\u6709\u4ec0\u4e48\u5f71\u54cd\u5462\uff1f 
\u5047\u8bbe\u60a8\u5efa\u7acb\u4e86\u4e00\u4e2a\u4ece\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e2d\u68c0\u6d4b\u6c14\u80f8\u7684\u6a21\u578b\uff0c\u5176 AUC \u503c\u4e3a 0.85\u3002\u8fd9\u610f\u5473\u7740\uff0c\u5982\u679c\u60a8\u4ece\u6570\u636e\u96c6\u4e2d\u968f\u673a\u9009\u62e9\u4e00\u5f20\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9633\u6027\u6837\u672c\uff09\u548c\u53e6\u4e00\u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9634\u6027\u6837\u672c\uff09\uff0c\u90a3\u4e48\u6c14\u80f8\u56fe\u50cf\u7684\u6392\u540d\u5c06\u9ad8\u4e8e\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u6982\u7387\u4e3a 0.85\u3002 \u8ba1\u7b97\u6982\u7387\u548c AUC \u540e\uff0c\u60a8\u9700\u8981\u5bf9\u6d4b\u8bd5\u96c6\u8fdb\u884c\u9884\u6d4b\u3002\u6839\u636e\u95ee\u9898\u548c\u4f7f\u7528\u60c5\u51b5\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u6982\u7387\u6216\u5b9e\u9645\u7c7b\u522b\u3002\u5982\u679c\u4f60\u60f3\u8981\u6982\u7387\uff0c\u8fd9\u5e76\u4e0d\u96be\u3002\u5982\u679c\u60a8\u60f3\u8981\u7c7b\u522b\uff0c\u5219\u9700\u8981\u9009\u62e9\u4e00\u4e2a\u9608\u503c\u3002\u5728\u4e8c\u5143\u5206\u7c7b\u7684\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u91c7\u7528\u7c7b\u4f3c\u4e0b\u9762\u7684\u65b9\u6cd5\u3002 \\[ Prediction = Probability >= Threshold \\] \u4e5f\u5c31\u662f\u8bf4\uff0c\u9884\u6d4b\u662f\u4e00\u4e2a\u53ea\u5305\u542b\u4e8c\u5143\u53d8\u91cf\u7684\u65b0\u5217\u8868\u3002\u5982\u679c\u6982\u7387\u5927\u4e8e\u6216\u7b49\u4e8e\u7ed9\u5b9a\u7684\u9608\u503c\uff0c\u5219\u9884\u6d4b\u4e2d\u7684\u4e00\u9879\u4e3a 1\uff0c\u5426\u5219\u4e3a 0\u3002 \u4f60\u731c\u600e\u4e48\u7740\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528 ROC \u66f2\u7ebf\u6765\u9009\u62e9\u8fd9\u4e2a\u9608\u503c\uff01ROC \u66f2\u7ebf\u4f1a\u544a\u8bc9\u60a8\u9608\u503c\u5bf9\u5047\u9633\u6027\u7387\u548c\u771f\u9633\u6027\u7387\u7684\u5f71\u54cd\uff0c\u8fdb\u800c\u5f71\u54cd\u5047\u9633\u6027\u548c\u771f\u9633\u6027\u3002\u60a8\u5e94\u8be5\u9009\u62e9\u6700\u9002\u5408\u60a8\u7684\u95ee\u9898\u548c\u6570\u636e\u96c6\u7684\u9608\u503c\u3002 
\u4f8b\u5982\uff0c\u5982\u679c\u60a8\u4e0d\u5e0c\u671b\u6709\u592a\u591a\u7684\u8bef\u62a5\uff0c\u90a3\u4e48\u9608\u503c\u5c31\u5e94\u8be5\u9ad8\u4e00\u4e9b\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e5f\u4f1a\u5e26\u6765\u66f4\u591a\u7684\u6f0f\u62a5\uff08\u5047\u9634\u6027\uff09\u3002\u6ce8\u610f\u6743\u8861\u5229\u5f0a\uff0c\u9009\u62e9\u6700\u4f73\u9608\u503c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u9608\u503c\u5982\u4f55\u5f71\u54cd\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u503c\u3002 # \u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list = [] # \u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list = [] # \u771f\u5b9e\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 temp_tp = true_positive ( y_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 temp_fp = false_positive ( y_true , temp_pred ) # \u52a0\u5165\u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list . append ( temp_tp ) # \u52a0\u5165\u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list . 
append ( temp_fp ) \u5229\u7528\u8fd9\u4e00\u70b9\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u8868\u683c\uff0c\u5982\u56fe 5 \u6240\u793a\u3002 \u56fe 5\uff1a\u4e0d\u540c\u9608\u503c\u7684 TP \u503c\u548c FP \u503c \u5982\u56fe 6 \u6240\u793a\uff0c\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0cROC \u66f2\u7ebf\u5de6\u4e0a\u89d2\u7684\u503c\u5e94\u8be5\u662f\u4e00\u4e2a\u76f8\u5f53\u4e0d\u9519\u7684\u9608\u503c\u3002 \u5bf9\u6bd4\u8868\u683c\u548c ROC \u66f2\u7ebf\uff0c\u6211\u4eec\u53ef\u4ee5\u53d1\u73b0\uff0c0.6 \u5de6\u53f3\u7684\u9608\u503c\u76f8\u5f53\u4e0d\u9519\uff0c\u65e2\u4e0d\u4f1a\u4e22\u5931\u5927\u91cf\u7684\u771f\u9633\u6027\u7ed3\u679c\uff0c\u4e5f\u4e0d\u4f1a\u51fa\u73b0\u5927\u91cf\u7684\u5047\u9633\u6027\u7ed3\u679c\u3002 \u56fe 6\uff1a\u4ece ROC \u66f2\u7ebf\u6700\u5de6\u4fa7\u7684\u9876\u70b9\u9009\u62e9\u6700\u4f73\u9608\u503c AUC \u662f\u4e1a\u5185\u5e7f\u6cdb\u5e94\u7528\u4e8e\u504f\u659c\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6307\u6807\uff0c\u4e5f\u662f\u6bcf\u4e2a\u4eba\u90fd\u5e94\u8be5\u4e86\u89e3\u7684\u6307\u6807\u3002\u4e00\u65e6\u7406\u89e3\u4e86 AUC \u80cc\u540e\u7684\u7406\u5ff5\uff08\u5982\u4e0a\u6587\u6240\u8ff0\uff09\uff0c\u4e5f\u5c31\u5f88\u5bb9\u6613\u5411\u4e1a\u754c\u53ef\u80fd\u4f1a\u8bc4\u4f30\u60a8\u7684\u6a21\u578b\u7684\u975e\u6280\u672f\u4eba\u5458\u89e3\u91ca\u5b83\u4e86\u3002 \u5b66\u4e60 AUC \u540e\uff0c\u4f60\u5e94\u8be5\u5b66\u4e60\u7684\u53e6\u4e00\u4e2a\u91cd\u8981\u6307\u6807\u662f\u5bf9\u6570\u635f\u5931\u3002\u5bf9\u4e8e\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5c06\u5bf9\u6570\u635f\u5931\u5b9a\u4e49\u4e3a\uff1a \\[ LogLoss = -1.0 \\times (target \\times log(prediction) + (1-target) \\times log(1-prediction)) \\] \u5176\u4e2d\uff0c\u76ee\u6807\u503c\u4e3a 0 \u6216 1\uff0c\u9884\u6d4b\u503c\u4e3a\u6837\u672c\u5c5e\u4e8e\u7c7b\u522b 1 \u7684\u6982\u7387\u3002 
\u5bf9\u4e8e\u6570\u636e\u96c6\u4e2d\u7684\u591a\u4e2a\u6837\u672c\uff0c\u6240\u6709\u6837\u672c\u7684\u5bf9\u6570\u635f\u5931\u53ea\u662f\u6240\u6709\u5355\u4e2a\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u9700\u8981\u8bb0\u4f4f\u7684\u4e00\u70b9\u662f\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u4e0d\u6b63\u786e\u6216\u504f\u5dee\u8f83\u5927\u7684\u9884\u6d4b\u8fdb\u884c\u76f8\u5f53\u9ad8\u7684\u60e9\u7f5a\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u975e\u5e38\u786e\u5b9a\u548c\u975e\u5e38\u9519\u8bef\u7684\u9884\u6d4b\u8fdb\u884c\u60e9\u7f5a\u3002 import numpy as np def log_loss ( y_true , y_proba ): # \u6781\u5c0f\u503c\uff0c\u9632\u6b620\u505a\u5206\u6bcd epsilon = 1e-15 # \u5bf9\u6570\u635f\u5931\u5217\u8868 loss = [] # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_proba ): # \u9650\u5236yp\u8303\u56f4\uff0c\u6700\u5c0f\u4e3aepsilon\uff0c\u6700\u5927\u4e3a1-epsilon yp = np . clip ( yp , epsilon , 1 - epsilon ) # \u8ba1\u7b97\u5bf9\u6570\u635f\u5931 temp_loss = - 1.0 * ( yt * np . log ( yp ) + ( 1 - yt ) * np . log ( 1 - yp )) # \u52a0\u5165\u5bf9\u6570\u635f\u5931\u5217\u8868 loss . append ( temp_loss ) return np . mean ( loss ) \u8ba9\u6211\u4eec\u6d4b\u8bd5\u4e00\u4e0b\u51fd\u6570\u6267\u884c\u60c5\u51b5\uff1a In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_proba = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u6211\u4eec\u53ef\u4ee5\u5c06\u5176\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . 
log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u5b9e\u73b0\u662f\u6b63\u786e\u7684\u3002 \u5bf9\u6570\u635f\u5931\u7684\u5b9e\u73b0\u5f88\u5bb9\u6613\u3002\u89e3\u91ca\u8d77\u6765\u4f3c\u4e4e\u6709\u70b9\u56f0\u96be\u3002\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5bf9\u6570\u635f\u5931\u7684\u60e9\u7f5a\u8981\u6bd4\u5176\u4ed6\u6307\u6807\u5927\u5f97\u591a\u3002 \u4f8b\u5982\uff0c\u5982\u679c\u60a8\u6709 51% \u7684\u628a\u63e1\u8ba4\u4e3a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\uff0c\u90a3\u4e48\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (1 \\times log(0.51) + (1 - 1) \\times log(1 - 0.51))=0.67 \\] \u5982\u679c\u4f60\u5bf9\u5c5e\u4e8e 0 \u7c7b\u7684\u6837\u672c\u6709 49% \u7684\u628a\u63e1\uff0c\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (0 \\times log(0.49) + (1 - 0) \\times log(1 - 0.49))=0.67 \\] \u56e0\u6b64\uff0c\u5373\u4f7f\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0.5 \u7684\u622a\u65ad\u503c\u5e76\u5f97\u5230\u5b8c\u7f8e\u7684\u9884\u6d4b\u7ed3\u679c\uff0c\u6211\u4eec\u4ecd\u7136\u4f1a\u6709\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002\u56e0\u6b64\uff0c\u5728\u5904\u7406\u5bf9\u6570\u635f\u5931\u65f6\uff0c\u4f60\u9700\u8981\u975e\u5e38\u5c0f\u5fc3\uff1b\u4efb\u4f55\u4e0d\u786e\u5b9a\u7684\u9884\u6d4b\u90fd\u4f1a\u4ea7\u751f\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002 \u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u5927\u591a\u6570\u6307\u6807\u90fd\u53ef\u4ee5\u8f6c\u6362\u6210\u591a\u7c7b\u7248\u672c\u3002\u8fd9\u4e2a\u60f3\u6cd5\u5f88\u7b80\u5355\u3002\u4ee5\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u4e3a\u4f8b\u3002\u6211\u4eec\u53ef\u4ee5\u8ba1\u7b97\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u4e2d\u6bcf\u4e00\u7c7b\u7684\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 
\u6709\u4e09\u79cd\u4e0d\u540c\u7684\u8ba1\u7b97\u65b9\u6cd5\uff0c\u6709\u65f6\u53ef\u80fd\u4f1a\u4ee4\u4eba\u56f0\u60d1\u3002\u5047\u8bbe\u6211\u4eec\u9996\u5148\u5bf9\u7cbe\u786e\u7387\u611f\u5174\u8da3\u3002\u6211\u4eec\u77e5\u9053\uff0c\u7cbe\u786e\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u3002 \u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Macro averaged precision\uff09\uff1a\u5206\u522b\u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u7cbe\u786e\u7387\u7136\u540e\u6c42\u5e73\u5747\u503c \u5fae\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Micro averaged precision\uff09\uff1a\u6309\u7c7b\u8ba1\u7b97\u771f\u9633\u6027\u548c\u5047\u9633\u6027\uff0c\u7136\u540e\u7528\u5176\u8ba1\u7b97\u603b\u4f53\u7cbe\u786e\u7387\u3002 \u52a0\u6743\u7cbe\u786e\u7387 \uff08Weighted precision\uff09\uff1a\u4e0e\u5b8f\u89c2\u7cbe\u786e\u7387\u76f8\u540c\uff0c\u4f46\u8fd9\u91cc\u662f\u52a0\u6743\u5e73\u5747\u7cbe\u786e\u7387 \u53d6\u51b3\u4e8e\u6bcf\u4e2a\u7c7b\u522b\u4e2d\u7684\u9879\u76ee\u6570 \u8fd9\u770b\u4f3c\u590d\u6742\uff0c\u4f46\u5728 python \u5b9e\u73b0\u4e2d\u5f88\u5bb9\u6613\u7406\u89e3\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import numpy as np def macro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u5982\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u8ba1\u7b97\u7cbe\u786e\u5ea6 temp_precision = tp / ( tp + fp ) # \u5404\u7c7b\u7cbe\u786e\u7387\u76f8\u52a0 precision += temp_precision # \u8ba1\u7b97\u5e73\u5747\u503c precision /= num_classes return precision \u4f60\u4f1a\u53d1\u73b0\u8fd9\u5e76\u4e0d\u96be\u3002\u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5fae\u5e73\u5747\u7cbe\u786e\u7387\u5206\u6570\u3002 import numpy as np def micro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u6570 tp = 0 # \u521d\u59cb\u5316\u5047\u9633\u6027\u6837\u672c\u6570 fp = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 tp += true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 fp += false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8fd9\u4e5f\u4e0d\u96be\u3002\u90a3\u4ec0\u4e48\u96be\uff1f\u4ec0\u4e48\u90fd\u4e0d\u96be\u3002\u673a\u5668\u5b66\u4e60\u5f88\u7b80\u5355\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u52a0\u6743\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u3002 from collections import Counter import numpy as np def weighted_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 temp_precision = tp / ( tp + fp ) # \u6839\u636e\u8be5\u79cd\u7c7b\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_precision = class_counts [ class_ ] * temp_precision # \u52a0\u6743\u7cbe\u786e\u7387\u6c42\u548c precision += weighted_precision # \u8ba1\u7b97\u5e73\u5747\u7cbe\u786e\u7387 overall_precision = precision / len ( y_true ) return overall_precision \u5c06\u6211\u4eec\u7684\u5b9e\u73b0\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff0c\u4ee5\u4e86\u89e3\u5b9e\u73b0\u662f\u5426\u6b63\u786e\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: macro_precision ( y_true , y_pred ) Out [ X ]: 0.3611111111111111 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"macro\" ) Out [ X ]: 0.3611111111111111 In [ X ]: micro_precision ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"micro\" ) Out [ X ]: 0.4444444444444444 In [ X ]: weighted_precision ( y_true , y_pred ) Out [ X ]: 0.39814814814814814 In [ X ]: metrics . 
precision_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.39814814814814814 \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6b63\u786e\u5730\u5b9e\u73b0\u4e86\u4e00\u5207\u3002 \u8bf7\u6ce8\u610f\uff0c\u8fd9\u91cc\u5c55\u793a\u7684\u5b9e\u73b0\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\uff0c\u4f46\u5374\u662f\u6700\u5bb9\u6613\u7406\u89e3\u7684\u3002 \u540c\u6837\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5b9e\u73b0 \u591a\u7c7b\u522b\u7684\u53ec\u56de\u7387\u6307\u6807 \u3002\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u3001\u5047\u9633\u6027\u548c\u5047\u9634\u6027\uff0c\u800c F1 \u5219\u53d6\u51b3\u4e8e\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 \u53ec\u56de\u7387\u7684\u5b9e\u73b0\u65b9\u6cd5\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\uff0c\u8fd9\u91cc\u5b9e\u73b0\u7684\u662f\u591a\u7c7b F1 \u7684\u4e00\u4e2a\u7248\u672c\uff0c\u5373\u52a0\u6743\u5e73\u5747\u503c\u3002 from collections import Counter import numpy as np def weighted_f1 ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316F1\u503c f1 = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( temp_true , temp_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( temp_true , temp_pred ) # \u82e5\u7cbe\u786e\u7387+\u53ec\u56de\u7387\u4e0d\u4e3a0\uff0c\u5219\u4f7f\u7528\u516c\u5f0f\u8ba1\u7b97F1\u503c if p + r != 0 : temp_f1 = 2 * p * r / ( p + r ) # \u5426\u5219\u76f4\u63a5\u4e3a0 else : temp_f1 = 0 # \u6839\u636e\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_f1 = class_counts [ class_ ] * temp_f1 # \u52a0\u6743F1\u503c\u76f8\u52a0 f1 += weighted_f1 # \u8ba1\u7b97\u52a0\u6743\u5e73\u5747F1\u503c overall_f1 = f1 / len ( y_true ) return overall_f1 \u8bf7\u6ce8\u610f\uff0c\u4e0a\u9762\u6709\u51e0\u884c\u4ee3\u7801\u662f\u65b0\u5199\u7684\u3002\u56e0\u6b64\uff0c\u4f60\u5e94\u8be5\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e9b\u4ee3\u7801\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: weighted_f1 ( y_true , y_pred ) Out [ X ]: 0.41269841269841273 In [ X ]: metrics . 
f1_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.41269841269841273 \u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u591a\u7c7b\u95ee\u9898\u5b9e\u73b0\u4e86\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002\u540c\u6837\uff0c\u60a8\u4e5f\u53ef\u4ee5\u5c06 AUC \u548c\u5bf9\u6570\u635f\u5931\u8f6c\u6362\u4e3a\u591a\u7c7b\u683c\u5f0f\u3002\u8fd9\u79cd\u8f6c\u6362\u683c\u5f0f\u88ab\u79f0\u4e3a one-vs-all \u3002\u8fd9\u91cc\u6211\u4e0d\u6253\u7b97\u5b9e\u73b0\u5b83\u4eec\uff0c\u56e0\u4e3a\u5b9e\u73b0\u65b9\u6cd5\u4e0e\u6211\u4eec\u5df2\u7ecf\u8ba8\u8bba\u8fc7\u7684\u5f88\u76f8\u4f3c\u3002 \u5728\u4e8c\u5143\u6216\u591a\u7c7b\u5206\u7c7b\u4e2d\uff0c\u770b\u4e00\u4e0b \u6df7\u6dc6\u77e9\u9635 \u4e5f\u5f88\u6d41\u884c\u3002\u4e0d\u8981\u56f0\u60d1\uff0c\u8fd9\u5f88\u7b80\u5355\u3002\u6df7\u6dc6\u77e9\u9635\u53ea\u4e0d\u8fc7\u662f\u4e00\u4e2a\u5305\u542b TP\u3001FP\u3001TN \u548c FN \u7684\u8868\u683c\u3002\u4f7f\u7528\u6df7\u6dc6\u77e9\u9635\uff0c\u60a8\u53ef\u4ee5\u5feb\u901f\u67e5\u770b\u6709\u591a\u5c11\u6837\u672c\u88ab\u9519\u8bef\u5206\u7c7b\uff0c\u6709\u591a\u5c11\u6837\u672c\u88ab\u6b63\u786e\u5206\u7c7b\u3002\u4e5f\u8bb8\u6709\u4eba\u4f1a\u8bf4\uff0c\u6df7\u6dc6\u77e9\u9635\u5e94\u8be5\u5728\u672c\u7ae0\u5f88\u65e9\u5c31\u8bb2\u5230\uff0c\u4f46\u6211\u6ca1\u6709\u8fd9\u4e48\u505a\u3002\u5982\u679c\u4e86\u89e3\u4e86 TP\u3001FP\u3001TN\u3001FN\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c AUC\uff0c\u5c31\u5f88\u5bb9\u6613\u7406\u89e3\u548c\u89e3\u91ca\u6df7\u6dc6\u77e9\u9635\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b\u56fe 7 \u4e2d\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6df7\u6dc6\u77e9\u9635\u7531 TP\u3001FP\u3001FN \u548c TN \u7ec4\u6210\u3002\u6211\u4eec\u53ea\u9700\u8981\u8fd9\u4e9b\u503c\u6765\u8ba1\u7b97\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u3001F1 \u5206\u6570\u548c AUC\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u4e5f\u559c\u6b22\u628a FP \u79f0\u4e3a 
\u7b2c\u4e00\u7c7b\u9519\u8bef \uff0c\u628a FN \u79f0\u4e3a \u7b2c\u4e8c\u7c7b\u9519\u8bef \u3002 \u56fe 7\uff1a\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6df7\u6dc6\u77e9\u9635 \u6211\u4eec\u8fd8\u53ef\u4ee5\u5c06\u4e8c\u5143\u6df7\u6dc6\u77e9\u9635\u6269\u5c55\u4e3a\u591a\u7c7b\u6df7\u6dc6\u77e9\u9635\u3002\u5b83\u4f1a\u662f\u4ec0\u4e48\u6837\u5b50\u5462\uff1f\u5982\u679c\u6211\u4eec\u6709 N \u4e2a\u7c7b\u522b\uff0c\u5b83\u5c06\u662f\u4e00\u4e2a\u5927\u5c0f\u4e3a NxN \u7684\u77e9\u9635\u3002\u5bf9\u4e8e\u6bcf\u4e2a\u7c7b\u522b\uff0c\u6211\u4eec\u90fd\u8981\u8ba1\u7b97\u76f8\u5173\u7c7b\u522b\u548c\u5176\u4ed6\u7c7b\u522b\u7684\u6837\u672c\u603b\u6570\u3002\u4e3e\u4e2a\u4f8b\u5b50\u53ef\u4ee5\u8ba9\u6211\u4eec\u66f4\u597d\u5730\u7406\u89e3\u8fd9\u4e00\u70b9\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4ee5\u4e0b\u771f\u5b9e\u6807\u7b7e\uff1a \\[ [0, 1, 2, 0, 1, 2, 0, 2, 2] \\] \u6211\u4eec\u7684\u9884\u6d4b\u6807\u7b7e\u662f\uff1a \\[ [0, 2, 1, 0, 2, 1, 0, 0, 2] \\] \u90a3\u4e48\uff0c\u6211\u4eec\u7684\u6df7\u6dc6\u77e9\u9635\u5c06\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u591a\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635 \u56fe 8 \u8bf4\u660e\u4e86\u4ec0\u4e48\uff1f \u8ba9\u6211\u4eec\u6765\u770b\u770b 0 \u7c7b\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e 0 \u7c7b\u3002\u7136\u800c\uff0c\u5728\u9884\u6d4b\u4e2d\uff0c\u6211\u4eec\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 0 \u7c7b\uff0c1 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\u3002\u7406\u60f3\u60c5\u51b5\u4e0b\uff0c\u5bf9\u4e8e\u771f\u5b9e\u6807\u7b7e\u4e2d\u7684\u7c7b\u522b 0\uff0c\u9884\u6d4b\u6807\u7b7e 1 \u548c 2 \u5e94\u8be5\u6ca1\u6709\u4efb\u4f55\u6837\u672c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u7c7b\u522b 2\u3002\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 4\uff0c\u800c\u5728\u9884\u6d4b\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 3\u3002 
\u4e00\u4e2a\u5b8c\u7f8e\u7684\u6df7\u6dc6\u77e9\u9635\u53ea\u80fd\u4ece\u5de6\u5230\u53f3\u659c\u5411\u586b\u5145\u3002 \u6df7\u6dc6\u77e9\u9635 \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u7684\u65b9\u6cd5\u6765\u8ba1\u7b97\u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u4e0d\u540c\u6307\u6807\u3002Scikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u76f4\u63a5\u7684\u65b9\u6cd5\u6765\u751f\u6210\u6df7\u6dc6\u77e9\u9635\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5728\u56fe 8 \u4e2d\u663e\u793a\u7684\u6df7\u6dc6\u77e9\u9635\u662f scikit-learn \u6df7\u6dc6\u77e9\u9635\u7684\u8f6c\u7f6e\uff0c\u539f\u59cb\u7248\u672c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u7ed8\u5236\u3002 import matplotlib.pyplot as plt import seaborn as sns from sklearn import metrics # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] # \u9884\u6d4b\u6837\u672c\u6807\u7b7e y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] # \u8ba1\u7b97\u6df7\u6dc6\u77e9\u9635 cm = metrics . confusion_matrix ( y_true , y_pred ) # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 10 , 10 )) # \u521b\u5efa\u65b9\u683c cmap = sns . cubehelix_palette ( 50 , hue = 0.05 , rot = 0 , light = 0.9 , dark = 0 , as_cmap = True ) # \u89c4\u5b9a\u5b57\u4f53\u5927\u5c0f sns . set ( font_scale = 2.5 ) # \u7ed8\u5236\u70ed\u56fe sns . heatmap ( cm , annot = True , cmap = cmap , cbar = False ) # y\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . ylabel ( 'Actual Labels' , fontsize = 20 ) # x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . 
xlabel ( 'Predicted Labels' , fontsize = 20 ) \u56e0\u6b64\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u89e3\u51b3\u4e86\u4e8c\u5143\u5206\u7c7b\u548c\u591a\u7c7b\u5206\u7c7b\u7684\u5ea6\u91cf\u95ee\u9898\u3002\u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u8ba8\u8bba\u53e6\u4e00\u79cd\u7c7b\u578b\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u5373\u591a\u6807\u7b7e\u5206\u7c7b\u3002\u5728\u591a\u6807\u7b7e\u5206\u7c7b\u4e2d\uff0c\u6bcf\u4e2a\u6837\u672c\u90fd\u53ef\u80fd\u4e0e\u4e00\u4e2a\u6216\u591a\u4e2a\u7c7b\u522b\u76f8\u5173\u8054\u3002\u8fd9\u7c7b\u95ee\u9898\u7684\u4e00\u4e2a\u7b80\u5355\u4f8b\u5b50\u5c31\u662f\u8981\u6c42\u4f60\u9884\u6d4b\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53\u3002 \u56fe 9 \u663e\u793a\u4e86\u4e00\u4e2a\u8457\u540d\u6570\u636e\u96c6\u7684\u56fe\u50cf\u793a\u4f8b\u3002\u8bf7\u6ce8\u610f\uff0c\u8be5\u6570\u636e\u96c6\u7684\u76ee\u6807\u6709\u6240\u4e0d\u540c\uff0c\u4f46\u6211\u4eec\u6682\u4e14\u4e0d\u53bb\u8ba8\u8bba\u5b83\u3002\u6211\u4eec\u5047\u8bbe\u5176\u76ee\u7684\u53ea\u662f\u9884\u6d4b\u56fe\u50cf\u4e2d\u662f\u5426\u5b58\u5728\u67d0\u4e2a\u7269\u4f53\u3002\u5728\u56fe 9 \u4e2d\uff0c\u6211\u4eec\u6709\u6905\u5b50\u3001\u82b1\u76c6\u3001\u7a97\u6237\uff0c\u4f46\u6ca1\u6709\u5176\u4ed6\u7269\u4f53\uff0c\u5982\u7535\u8111\u3001\u5e8a\u3001\u7535\u89c6\u7b49\u3002\u56e0\u6b64\uff0c\u4e00\u5e45\u56fe\u50cf\u53ef\u80fd\u6709\u591a\u4e2a\u76f8\u5173\u76ee\u6807\u3002\u8fd9\u7c7b\u95ee\u9898\u5c31\u662f\u591a\u6807\u7b7e\u5206\u7c7b\u95ee\u9898\u3002 \u56fe 9\uff1a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53 \u8fd9\u7c7b\u5206\u7c7b\u95ee\u9898\u7684\u8861\u91cf\u6807\u51c6\u6709\u4e9b\u4e0d\u540c\u3002\u4e00\u4e9b\u5408\u9002\u7684 \u6700\u5e38\u89c1\u7684\u6307\u6807\u6709\uff1a k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u786e\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 \u8ba9\u6211\u4eec\u4ece k 
\u7cbe\u786e\u7387\u6216\u8005 P@k \u6211\u4eec\u4e0d\u80fd\u5c06\u8fd9\u4e00\u7cbe\u786e\u7387\u4e0e\u524d\u9762\u8ba8\u8bba\u7684\u7cbe\u786e\u7387\u6df7\u6dc6\u3002\u5982\u679c\u60a8\u6709\u4e00\u4e2a\u7ed9\u5b9a\u6837\u672c\u7684\u539f\u59cb\u7c7b\u522b\u5217\u8868\u548c\u540c\u4e00\u4e2a\u6837\u672c\u7684\u9884\u6d4b\u7c7b\u522b\u5217\u8868\uff0c\u90a3\u4e48\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u5c31\u662f\u9884\u6d4b\u5217\u8868\u4e2d\u4ec5\u8003\u8651\u524d k \u4e2a\u9884\u6d4b\u7ed3\u679c\u7684\u547d\u4e2d\u6570\u9664\u4ee5 k\u3002 \u5982\u679c\u60a8\u5bf9\u6b64\u611f\u5230\u56f0\u60d1\uff0c\u4f7f\u7528 python \u4ee3\u7801\u540e\u5c31\u4f1a\u660e\u767d\u3002 def pk ( y_true , y_pred , k ): # \u5982\u679ck\u4e3a0 if k == 0 : # \u8fd4\u56de0 return 0 # \u53d6\u9884\u6d4b\u6807\u7b7e\u524dk\u4e2a y_pred = y_pred [: k ] # \u5c06\u9884\u6d4b\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 pred_set = set ( y_pred ) # \u5c06\u771f\u5b9e\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 true_set = set ( y_true ) # \u9884\u6d4b\u6807\u7b7e\u96c6\u5408\u4e0e\u771f\u5b9e\u6807\u7b7e\u96c6\u5408\u4ea4\u96c6 common_values = pred_set . intersection ( true_set ) # \u8ba1\u7b97\u7cbe\u786e\u7387 return len ( common_values ) / len ( y_pred [: k ]) \u6709\u4e86\u4ee3\u7801\uff0c\u4e00\u5207\u90fd\u53d8\u5f97\u66f4\u5bb9\u6613\u7406\u89e3\u4e86\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6709\u4e86 k \u5e73\u5747\u7cbe\u786e\u7387\u6216 AP@k \u3002AP@k \u662f\u901a\u8fc7 P@k \u8ba1\u7b97\u5f97\u51fa\u7684\u3002\u4f8b\u5982\uff0c\u5982\u679c\u8981\u8ba1\u7b97 AP@3\uff0c\u6211\u4eec\u8981\u5148\u8ba1\u7b97 P@1\u3001P@2 \u548c P@3\uff0c\u7136\u540e\u5c06\u603b\u548c\u9664\u4ee5 3\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u7684\u5b9e\u73b0\u3002 def apk ( y_true , y_pred , k ): # \u521d\u59cb\u5316P@k\u5217\u8868 pk_values = [] # \u904d\u53861~k for i in range ( 1 , k + 1 ): # \u5c06P@k\u52a0\u5165\u5217\u8868 pk_values . 
append ( pk ( y_true , y_pred , i )) # \u82e5\u957f\u5ea6\u4e3a0 if len ( pk_values ) == 0 : # \u8fd4\u56de0 return 0 # \u5426\u5219\u8ba1\u7b97AP@K return sum ( pk_values ) / len ( pk_values ) \u8fd9\u4e24\u4e2a\u51fd\u6570\u53ef\u4ee5\u7528\u6765\u8ba1\u7b97\u4e24\u4e2a\u7ed9\u5b9a\u5217\u8868\u7684 k \u5e73\u5747\u7cbe\u786e\u7387 (AP@k)\uff1b\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u8ba1\u7b97\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: for i in range ( len ( y_true )): ... : for j in range ( 1 , 4 ): ... : print ( ... : f \"\"\" ...: y_true= { y_true [ i ] } , ...: y_pred= { y_pred [ i ] } , ...: AP@ { j } = { apk ( y_true [ i ], y_pred [ i ], k = j ) } ...: \"\"\" ... : ) ... : y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 1 = 0.0 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 2 = 0.25 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 3 = 0.38888888888888884 \u8bf7\u6ce8\u610f\uff0c\u6211\u7701\u7565\u4e86\u8f93\u51fa\u7ed3\u679c\u4e2d\u7684\u8bb8\u591a\u6570\u503c\uff0c\u4f46\u4f60\u4f1a\u660e\u767d\u5176\u4e2d\u7684\u610f\u601d\u3002\u8fd9\u5c31\u662f\u6211\u4eec\u5982\u4f55\u8ba1\u7b97 AP@k \u7684\u65b9\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u7684 AP@k\u3002\u5728\u673a\u5668\u5b66\u4e60\u4e2d\uff0c\u6211\u4eec\u5bf9\u6240\u6709\u6837\u672c\u90fd\u611f\u5174\u8da3\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u6709 \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387 k \u6216 MAP@k \u3002MAP@k \u53ea\u662f AP@k \u7684\u5e73\u5747\u503c\uff0c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b python \u4ee3\u7801\u8f7b\u677e\u8ba1\u7b97\u3002 def mapk ( y_true , y_pred , k ): # \u521d\u59cb\u5316AP@k\u5217\u8868 apk_values = [] # \u904d\u53860~\uff08\u771f\u5b9e\u6807\u7b7e\u6570-1\uff09 for i in range ( len ( y_true )): 
# \u5c06AP@K\u52a0\u5165\u5217\u8868 apk_values . append ( apk ( y_true [ i ], y_pred [ i ], k = k ) ) # \u8ba1\u7b97\u5e73\u5747AP@k return sum ( apk_values ) / len ( apk_values ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u9488\u5bf9\u76f8\u540c\u7684\u5217\u8868\u8ba1\u7b97 k=1\u30012\u30013 \u548c 4 \u65f6\u7684 MAP@k\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: mapk ( y_true , y_pred , k = 1 ) Out [ X ]: 0.3333333333333333 In [ X ]: mapk ( y_true , y_pred , k = 2 ) Out [ X ]: 0.375 In [ X ]: mapk ( y_true , y_pred , k = 3 ) Out [ X ]: 0.3611111111111111 In [ X ]: mapk ( y_true , y_pred , k = 4 ) Out [ X ]: 0.34722222222222215 P@k\u3001AP@k \u548c MAP@k \u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u5176\u4e2d 1 \u4e3a\u6700\u4f73\u3002 \u8bf7\u6ce8\u610f\uff0c\u6709\u65f6\u60a8\u53ef\u80fd\u4f1a\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230 P@k \u548c AP@k \u7684\u4e0d\u540c\u5b9e\u73b0\u65b9\u5f0f\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5176\u4e2d\u4e00\u79cd\u5b9e\u73b0\u65b9\u5f0f\u3002 import numpy as np def apk ( actual , predicted , k = 10 ): # \u82e5\u9884\u6d4b\u6807\u7b7e\u957f\u5ea6\u5927\u4e8ek if len ( predicted ) > k : # \u53d6\u524dk\u4e2a\u6807\u7b7e predicted = predicted [: k ] score = 0.0 num_hits = 0.0 for i , p in enumerate ( predicted ): if p in actual and p not in predicted [: i ]: num_hits += 1.0 score += num_hits / ( i + 1.0 ) if not actual : return 0.0 return score / min ( len ( actual ), k ) \u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u662f AP@k 
\u7684\u53e6\u4e00\u4e2a\u7248\u672c\uff0c\u5176\u4e2d\u987a\u5e8f\u5f88\u91cd\u8981\uff0c\u6211\u4eec\u8981\u6743\u8861\u9884\u6d4b\u7ed3\u679c\u3002\u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u7684\u7ed3\u679c\u4e0e\u6211\u7684\u4ecb\u7ecd\u7565\u6709\u4e0d\u540c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6765\u770b\u770b \u591a\u6807\u7b7e\u5206\u7c7b\u7684\u5bf9\u6570\u635f\u5931 \u3002\u8fd9\u5f88\u5bb9\u6613\u3002\u60a8\u53ef\u4ee5\u5c06\u76ee\u6807\u8f6c\u6362\u4e3a\u4e8c\u5143\u5206\u7c7b\uff0c\u7136\u540e\u5bf9\u6bcf\u4e00\u5217\u4f7f\u7528\u5bf9\u6570\u635f\u5931\u3002\u6700\u540e\uff0c\u4f60\u53ef\u4ee5\u6c42\u51fa\u6bcf\u5217\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u8fd9\u4e5f\u88ab\u79f0\u4e3a\u5e73\u5747\u5217\u5bf9\u6570\u635f\u5931\u3002\u5f53\u7136\uff0c\u8fd8\u6709\u5176\u4ed6\u65b9\u6cd5\u53ef\u4ee5\u5b9e\u73b0\u8fd9\u4e00\u70b9\uff0c\u4f60\u5e94\u8be5\u5728\u9047\u5230\u65f6\u52a0\u4ee5\u63a2\u7d22\u3002 \u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u8bf4\u5df2\u7ecf\u638c\u63e1\u4e86\u6240\u6709\u4e8c\u5143\u5206\u7c7b\u3001\u591a\u7c7b\u5206\u7c7b\u548c\u591a\u6807\u7b7e\u5206\u7c7b\u6307\u6807\uff0c\u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u6307\u6807\u3002 \u56de\u5f52\u4e2d\u6700\u5e38\u89c1\u7684\u6307\u6807\u662f \u8bef\u5dee\uff08Error\uff09 \u3002\u8bef\u5dee\u5f88\u7b80\u5355\uff0c\u4e5f\u5f88\u5bb9\u6613\u7406\u89e3\u3002 \\[ Error = True\\ Value - Predicted\\ Value \\] \u7edd\u5bf9\u8bef\u5dee\uff08Absolute error\uff09 \u53ea\u662f\u4e0a\u8ff0\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\u3002 \\[ Absolute\\ Error = Abs(True\\ Value - Predicted\\ Value) \\] \u63a5\u4e0b\u6765\u6211\u4eec\u8ba8\u8bba \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\uff08MAE\uff09 \u3002\u5b83\u53ea\u662f\u6240\u6709\u7edd\u5bf9\u8bef\u5dee\u7684\u5e73\u5747\u503c\u3002 import numpy as np def mean_absolute_error ( y_true , y_pred ): #\u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # 
\u7d2f\u52a0\u7edd\u5bf9\u8bef\u5dee error += np . abs ( yt - yp ) # \u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee return error / len ( y_true ) \u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5e73\u65b9\u8bef\u5dee\u548c \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u3002 \\[ Squared\\ Error = (True Value - Predicted\\ Value)^2 \\] \u5747\u65b9\u8bef\u5dee\uff08MSE\uff09\u7684\u8ba1\u7b97\u65b9\u5f0f\u5982\u4e0b def mean_squared_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u7d2f\u52a0\u8bef\u5dee\u5e73\u65b9\u548c error += ( yt - yp ) ** 2 # \u8ba1\u7b97\u5747\u65b9\u8bef\u5dee return error / len ( y_true ) MSE \u548c RMSE\uff08\u5747\u65b9\u6839\u8bef\u5dee\uff09 \u662f\u8bc4\u4f30\u56de\u5f52\u6a21\u578b\u6700\u5e38\u7528\u7684\u6307\u6807\u3002 \\[ RMSE = SQRT(MSE) \\] \u540c\u4e00\u7c7b\u8bef\u5dee\u7684\u53e6\u4e00\u79cd\u7c7b\u578b\u662f \u5e73\u65b9\u5bf9\u6570\u8bef\u5dee \u3002\u6709\u4eba\u79f0\u5176\u4e3a SLE \uff0c\u5f53\u6211\u4eec\u53d6\u6240\u6709\u6837\u672c\u4e2d\u8fd9\u4e00\u8bef\u5dee\u7684\u5e73\u5747\u503c\u65f6\uff0c\u5b83\u88ab\u79f0\u4e3a MSLE\uff08\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee\uff09\uff0c\u5b9e\u73b0\u65b9\u6cd5\u5982\u4e0b\u3002 import numpy as np def mean_squared_log_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee error += ( np . log ( 1 + yt ) - np . 
log ( 1 + yp )) ** 2 # \u8ba1\u7b97\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee return error / len ( y_true ) \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \u53ea\u662f\u5176\u5e73\u65b9\u6839\u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a RMSLE \u3002 \u7136\u540e\u662f\u767e\u5206\u6bd4\u8bef\u5dee\uff1a \\[ Percentage\\ Error = (( True\\ Value \u2013 Predicted\\ Value ) / True\\ Value ) \\times 100 \\] \u540c\u6837\u53ef\u4ee5\u8f6c\u6362\u4e3a\u6240\u6709\u6837\u672c\u7684\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee\u3002 def mean_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u767e\u5206\u6bd4\u8bef\u5dee error += ( yt - yp ) / yt # \u8fd4\u56de\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u7edd\u5bf9\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\uff08\u4e5f\u662f\u66f4\u5e38\u89c1\u7684\u7248\u672c\uff09\u88ab\u79f0\u4e3a \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee\u6216 MAPE \u3002 import numpy as np def mean_abs_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee error += np . 
abs ( yt - yp ) / yt #\u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u56de\u5f52\u7684\u6700\u5927\u4f18\u70b9\u662f\uff0c\u53ea\u6709\u51e0\u4e2a\u6700\u5e38\u7528\u7684\u6307\u6807\uff0c\u51e0\u4e4e\u53ef\u4ee5\u5e94\u7528\u4e8e\u6240\u6709\u56de\u5f52\u95ee\u9898\u3002\u4e0e\u5206\u7c7b\u6307\u6807\u76f8\u6bd4\uff0c\u56de\u5f52\u6307\u6807\u66f4\u5bb9\u6613\u7406\u89e3\u3002 \u8ba9\u6211\u4eec\u6765\u8c08\u8c08\u53e6\u4e00\u4e2a\u56de\u5f52\u6307\u6807 \\(R^2\\) \uff08R \u65b9\uff09\uff0c\u4e5f\u79f0\u4e3a \u5224\u5b9a\u7cfb\u6570 \u3002 \u7b80\u5355\u5730\u8bf4\uff0cR \u65b9\u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u3002R \u65b9\u63a5\u8fd1 1.0 \u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u76f8\u5f53\u597d\uff0c\u800c\u63a5\u8fd1 0 \u5219\u8868\u793a\u6a21\u578b\u4e0d\u662f\u90a3\u4e48\u597d\u3002\u5f53\u6a21\u578b\u53ea\u662f\u505a\u51fa\u8352\u8c2c\u7684\u9884\u6d4b\u65f6\uff0cR \u65b9\u4e5f\u53ef\u80fd\u662f\u8d1f\u503c\u3002 R \u65b9\u7684\u8ba1\u7b97\u516c\u5f0f\u5982\u4e0b\u6240\u793a\uff0c\u4f46 Python \u7684\u5b9e\u73b0\u603b\u662f\u80fd\u8ba9\u4e00\u5207\u66f4\u52a0\u6e05\u6670\u3002 \\[ R^2 = \\frac{\\sum^{N}_{i=1}(y_{t_i}-y_{p_i})^2}{\\sum^{N}_{i=1}(y_{t_i} - y_{t_{mean}})} \\] import numpy as np def r2 ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u5747\u771f\u5b9e\u503c mean_true_value = np . 
mean ( y_true ) # \u521d\u59cb\u5316\u5e73\u65b9\u8bef\u5dee numerator = 0 denominator = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): numerator += ( yt - yp ) ** 2 denominator += ( yt - mean_true_value ) ** 2 ratio = numerator / denominator # \u8ba1\u7b97R\u65b9 return 1 \u2013 ratio \u8fd8\u6709\u66f4\u591a\u7684\u8bc4\u4ef7\u6307\u6807\uff0c\u8fd9\u4e2a\u6e05\u5355\u6c38\u8fdc\u4e5f\u5217\u4e0d\u5b8c\u3002\u6211\u53ef\u4ee5\u5199\u4e00\u672c\u4e66\uff0c\u53ea\u4ecb\u7ecd\u4e0d\u540c\u7684\u8bc4\u4ef7\u6307\u6807\u3002\u4e5f\u8bb8\u6211\u4f1a\u7684\u3002\u73b0\u5728\uff0c\u8fd9\u4e9b\u8bc4\u4f30\u6307\u6807\u51e0\u4e4e\u53ef\u4ee5\u6ee1\u8db3\u4f60\u60f3\u5c1d\u8bd5\u89e3\u51b3\u7684\u6240\u6709\u95ee\u9898\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ee5\u6700\u76f4\u63a5\u7684\u65b9\u5f0f\u5b9e\u73b0\u4e86\u8fd9\u4e9b\u6307\u6807\uff0c\u8fd9\u610f\u5473\u7740\u5b83\u4eec\u4e0d\u591f\u9ad8\u6548\u3002\u4f60\u53ef\u4ee5\u901a\u8fc7\u6b63\u786e\u4f7f\u7528 numpy \u4ee5\u975e\u5e38\u9ad8\u6548\u7684\u65b9\u5f0f\u5b9e\u73b0\u5176\u4e2d\u5927\u90e8\u5206\u6307\u6807\u3002\u4f8b\u5982\uff0c\u770b\u770b\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\u7684\u5b9e\u73b0\uff0c\u4e0d\u9700\u8981\u4efb\u4f55\u5faa\u73af\u3002 import numpy as np def mae_np ( y_true , y_pred ): return np . mean ( np . 
abs ( y_true - y_pred )) \u6211\u672c\u53ef\u4ee5\u7528\u8fd9\u79cd\u65b9\u6cd5\u5b9e\u73b0\u6240\u6709\u6307\u6807\uff0c\u4f46\u4e3a\u4e86\u5b66\u4e60\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b\u5e95\u5c42\u5b9e\u73b0\u3002\u4e00\u65e6\u4f60\u5b66\u4f1a\u4e86\u7eaf python \u7684\u5e95\u5c42\u5b9e\u73b0\uff0c\u5e76\u4e14\u4e0d\u4f7f\u7528\u5927\u91cf numpy\uff0c\u4f60\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5c06\u5176\u8f6c\u6362\u4e3a numpy\uff0c\u5e76\u4f7f\u5176\u53d8\u5f97\u66f4\u5feb\u3002 \u7136\u540e\u662f\u4e00\u4e9b\u9ad8\u7ea7\u5ea6\u91cf\u3002 \u5176\u4e2d\u4e00\u4e2a\u5e94\u7528\u76f8\u5f53\u5e7f\u6cdb\u7684\u6307\u6807\u662f \u4e8c\u6b21\u52a0\u6743\u5361\u5e15 \uff0c\u4e5f\u79f0\u4e3a QWK \u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a\u79d1\u6069\u5361\u5e15\u3002 QWK \u8861\u91cf\u4e24\u4e2a \"\u8bc4\u5206 \"\u4e4b\u95f4\u7684 \"\u4e00\u81f4\u6027\"\u3002\u8bc4\u5206\u53ef\u4ee5\u662f 0 \u5230 N \u4e4b\u95f4\u7684\u4efb\u4f55\u5b9e\u6570\uff0c\u9884\u6d4b\u4e5f\u5728\u540c\u4e00\u8303\u56f4\u5185\u3002\u4e00\u81f4\u6027\u53ef\u4ee5\u5b9a\u4e49\u4e3a\u8fd9\u4e9b\u8bc4\u7ea7\u4e4b\u95f4\u7684\u63a5\u8fd1\u7a0b\u5ea6\u3002\u56e0\u6b64\uff0c\u5b83\u9002\u7528\u4e8e\u6709 N \u4e2a\u4e0d\u540c\u7c7b\u522b\u7684\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u4e00\u81f4\u5ea6\u9ad8\uff0c\u5206\u6570\u5c31\u66f4\u63a5\u8fd1 1.0\u3002Cohen's kappa \u5728 scikit-learn \u4e2d\u6709\u5f88\u597d\u7684\u5b9e\u73b0\uff0c\u5173\u4e8e\u8be5\u6307\u6807\u7684\u8be6\u7ec6\u8ba8\u8bba\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3 ] In [ X ]: y_pred = [ 2 , 1 , 3 , 1 , 2 , 3 , 3 , 1 , 2 ] In [ X ]: metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) Out [ X ]: 0.33333333333333337 In [ X ]: metrics . 
accuracy_score ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 \u60a8\u53ef\u4ee5\u770b\u5230\uff0c\u5c3d\u7ba1\u51c6\u786e\u5ea6\u5f88\u9ad8\uff0c\u4f46 QWK \u5374\u5f88\u4f4e\u3002QWK \u5927\u4e8e 0.85 \u5373\u4e3a\u975e\u5e38\u597d\uff01 \u4e00\u4e2a\u91cd\u8981\u7684\u6307\u6807\u662f \u9a6c\u4fee\u76f8\u5173\u7cfb\u6570\uff08MCC\uff09 \u30021 \u4ee3\u8868\u5b8c\u7f8e\u9884\u6d4b\uff0c-1 \u4ee3\u8868\u4e0d\u5b8c\u7f8e\u9884\u6d4b\uff0c0 \u4ee3\u8868\u968f\u673a\u9884\u6d4b\u3002MCC \u7684\u8ba1\u7b97\u516c\u5f0f\u975e\u5e38\u7b80\u5355\u3002 \\[ MCC = \\frac{TP \\times TN - FP \\times FN}{\\sqrt{(TP + FP) \\times (FN + TN) \\times (FP + TN) \\times (TP + FN)}} \\] \u6211\u4eec\u770b\u5230\uff0cMCC \u8003\u8651\u4e86 TP\u3001FP\u3001TN \u548c FN\uff0c\u56e0\u6b64\u53ef\u7528\u4e8e\u5904\u7406\u7c7b\u504f\u659c\u7684\u95ee\u9898\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u7684\u65b9\u6cd5\u5728 python \u4e2d\u5feb\u901f\u5b9e\u73b0\u5b83\u3002 def mcc ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) numerator = ( tp * tn ) - ( fp * fn ) denominator = ( ( tp + fp ) * ( fn + tn ) * ( fp + tn ) * ( tp + fn ) ) denominator = denominator ** 0.5 return numerator / denominator \u8fd9\u4e9b\u6307\u6807\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5165\u95e8\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002 
\u9700\u8981\u6ce8\u610f\u7684\u4e00\u70b9\u662f\uff0c\u5728\u8bc4\u4f30\u975e\u76d1\u7763\u65b9\u6cd5\uff08\u4f8b\u5982\u67d0\u79cd\u805a\u7c7b\uff09\u65f6\uff0c\u6700\u597d\u521b\u5efa\u6216\u624b\u52a8\u6807\u8bb0\u6d4b\u8bd5\u96c6\uff0c\u5e76\u5c06\u5176\u4e0e\u5efa\u6a21\u90e8\u5206\u7684\u6240\u6709\u5185\u5bb9\u5206\u5f00\u3002\u5b8c\u6210\u805a\u7c7b\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u4efb\u4f55\u4e00\u79cd\u76d1\u7763\u5b66\u4e60\u6307\u6807\u6765\u8bc4\u4f30\u6d4b\u8bd5\u96c6\u7684\u6027\u80fd\u4e86\u3002 \u4e00\u65e6\u6211\u4eec\u4e86\u89e3\u4e86\u7279\u5b9a\u95ee\u9898\u5e94\u8be5\u4f7f\u7528\u4ec0\u4e48\u6307\u6807\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f00\u59cb\u66f4\u6df1\u5165\u5730\u7814\u7a76\u6211\u4eec\u7684\u6a21\u578b\uff0c\u4ee5\u6c42\u6539\u8fdb\u3002","title":"\u8bc4\u4f30\u6307\u6807"},{"location":"%E8%AF%84%E4%BC%B0%E6%8C%87%E6%A0%87/#_1","text":"\u8bf4\u5230\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u4f60\u4f1a\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\u9047\u5230\u5f88\u591a\u4e0d\u540c\u7c7b\u578b\u7684\u6307\u6807\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u751a\u81f3\u4f1a\u6839\u636e\u4e1a\u52a1\u95ee\u9898\u521b\u5efa\u5ea6\u91cf\u6807\u51c6\u3002\u9010\u4e00\u4ecb\u7ecd\u548c\u89e3\u91ca\u6bcf\u4e00\u79cd\u5ea6\u91cf\u7c7b\u578b\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002\u76f8\u53cd\uff0c\u6211\u4eec\u5c06\u4ecb\u7ecd\u4e00\u4e9b\u6700\u5e38\u89c1\u7684\u5ea6\u91cf\u6807\u51c6\uff0c\u4f9b\u4f60\u5728\u6700\u521d\u7684\u51e0\u4e2a\u9879\u76ee\u4e2d\u4f7f\u7528\u3002 
\u5728\u672c\u4e66\u7684\u5f00\u5934\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u76d1\u7763\u5b66\u4e60\u548c\u975e\u76d1\u7763\u5b66\u4e60\u3002\u867d\u7136\u65e0\u76d1\u7763\u5b66\u4e60\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e9b\u6307\u6807\uff0c\u4f46\u6211\u4eec\u5c06\u53ea\u5173\u6ce8\u6709\u76d1\u7763\u5b66\u4e60\u3002\u8fd9\u662f\u56e0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u591a\uff0c\u800c\u4e14\u5bf9\u65e0\u76d1\u7763\u65b9\u6cd5\u7684\u8bc4\u4f30\u76f8\u5f53\u4e3b\u89c2\u3002 \u5982\u679c\u6211\u4eec\u8c08\u8bba\u5206\u7c7b\u95ee\u9898\uff0c\u6700\u5e38\u7528\u7684\u6307\u6807\u662f\uff1a \u51c6\u786e\u7387\uff08Accuracy\uff09 \u7cbe\u786e\u7387\uff08P\uff09 \u53ec\u56de\u7387\uff08R\uff09 F1 \u5206\u6570\uff08F1\uff09 AUC\uff08AUC\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u8bf4\u5230\u56de\u5f52\uff0c\u6700\u5e38\u7528\u7684\u8bc4\u4ef7\u6307\u6807\u662f \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee \uff08MAE\uff09 \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u5747\u65b9\u6839\u8bef\u5dee \uff08RMSE\uff09 \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \uff08RMSLE\uff09 \u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee \uff08MPE\uff09 \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee \uff08MAPE\uff09 R2 \u4e86\u89e3\u4e0a\u8ff0\u6307\u6807\u7684\u5de5\u4f5c\u539f\u7406\u5e76\u4e0d\u662f\u6211\u4eec\u5fc5\u987b\u4e86\u89e3\u7684\u552f\u4e00\u4e8b\u60c5\u3002\u6211\u4eec\u8fd8\u5fc5\u987b\u77e5\u9053\u4f55\u65f6\u4f7f\u7528\u54ea\u4e9b\u6307\u6807\uff0c\u800c\u8fd9\u53d6\u51b3\u4e8e\u4f60\u6709\u4ec0\u4e48\u6837\u7684\u6570\u636e\u548c\u76ee\u6807\u3002\u6211\u8ba4\u4e3a\u8fd9\u4e0e\u76ee\u6807\u6709\u5173\uff0c\u800c\u4e0e\u6570\u636e\u65e0\u5173\u3002 
\u8981\u8fdb\u4e00\u6b65\u4e86\u89e3\u8fd9\u4e9b\u6307\u6807\uff0c\u8ba9\u6211\u4eec\u4ece\u4e00\u4e2a\u7b80\u5355\u7684\u95ee\u9898\u5f00\u59cb\u3002\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a \u4e8c\u5143\u5206\u7c7b \u95ee\u9898\uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u76ee\u6807\u7684\u95ee\u9898\uff0c\u5047\u8bbe\u8fd9\u662f\u4e00\u4e2a\u80f8\u90e8 X \u5149\u56fe\u50cf\u5206\u7c7b\u95ee\u9898\u3002\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6ca1\u6709\u95ee\u9898\uff0c\u800c\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6709\u80ba\u584c\u9677\uff0c\u4e5f\u5c31\u662f\u6240\u8c13\u7684\u6c14\u80f8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u5206\u7c7b\u5668\uff0c\u5728\u7ed9\u5b9a\u80f8\u90e8 X \u5149\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\uff0c\u5b83\u80fd\u68c0\u6d4b\u51fa\u56fe\u50cf\u662f\u5426\u6709\u6c14\u80f8\u3002 \u56fe 1\uff1a\u6c14\u80f8\u80ba\u90e8\u56fe\u50cf \u6211\u4eec\u8fd8\u5047\u8bbe\u6709\u76f8\u540c\u6570\u91cf\u7684\u6c14\u80f8\u548c\u975e\u6c14\u80f8\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u6bd4\u5982\u5404 100 \u5f20\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709 100 \u5f20\u9633\u6027\u6837\u672c\u548c 100 \u5f20\u9634\u6027\u6837\u672c\uff0c\u5171\u8ba1 200 \u5f20\u56fe\u50cf\u3002 \u7b2c\u4e00\u6b65\u662f\u5c06\u4e0a\u8ff0\u6570\u636e\u5206\u4e3a\u4e24\u7ec4\uff0c\u6bcf\u7ec4 100 \u5f20\u56fe\u50cf\uff0c\u5373\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u8fd9\u4e24\u4e2a\u96c6\u5408\u4e2d\uff0c\u6211\u4eec\u90fd\u6709 50 \u4e2a\u6b63\u6837\u672c\u548c 50 \u4e2a\u8d1f\u6837\u672c\u3002 \u5728\u4e8c\u5143\u5206\u7c7b\u6307\u6807\u4e2d\uff0c\u5f53\u6b63\u8d1f\u6837\u672c\u6570\u91cf\u76f8\u7b49\u65f6\uff0c\u6211\u4eec\u901a\u5e38\u4f7f\u7528\u51c6\u786e\u7387\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002 \u51c6\u786e\u7387 
\uff1a\u8fd9\u662f\u673a\u5668\u5b66\u4e60\u4e2d\u6700\u76f4\u63a5\u7684\u6307\u6807\u4e4b\u4e00\u3002\u5b83\u5b9a\u4e49\u4e86\u6a21\u578b\u7684\u51c6\u786e\u5ea6\u3002\u5bf9\u4e8e\u4e0a\u8ff0\u95ee\u9898\uff0c\u5982\u679c\u4f60\u5efa\u7acb\u7684\u6a21\u578b\u80fd\u51c6\u786e\u5206\u7c7b 90 \u5f20\u56fe\u7247\uff0c\u90a3\u4e48\u4f60\u7684\u51c6\u786e\u7387\u5c31\u662f 90% \u6216 0.90\u3002\u5982\u679c\u53ea\u6709 83 \u5e45\u56fe\u50cf\u88ab\u6b63\u786e\u5206\u7c7b\uff0c\u90a3\u4e48\u6a21\u578b\u7684\u51c6\u786e\u7387\u5c31\u662f 83% \u6216 0.83\u3002 \u8ba1\u7b97\u51c6\u786e\u7387\u7684 Python \u4ee3\u7801\u4e5f\u975e\u5e38\u7b80\u5355\u3002 def accuracy ( y_true , y_pred ): # \u4e3a\u6b63\u786e\u9884\u6d4b\u6570\u521d\u59cb\u5316\u4e00\u4e2a\u7b80\u5355\u8ba1\u6570\u5668 correct_counter = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): if yt == yp : # \u5982\u679c\u9884\u6d4b\u6807\u7b7e\u4e0e\u771f\u5b9e\u6807\u7b7e\u76f8\u540c\uff0c\u5219\u589e\u52a0\u8ba1\u6570\u5668 correct_counter += 1 # \u8fd4\u56de\u6b63\u786e\u7387\uff0c\u6b63\u786e\u6807\u7b7e\u6570/\u603b\u6807\u7b7e\u6570 return correct_counter / len ( y_true ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u8ba1\u7b97\u51c6\u786e\u7387\u3002 In [ X ]: from sklearn import metrics ... : l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] ... : metrics . 
accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u73b0\u5728\uff0c\u5047\u8bbe\u6211\u4eec\u628a\u6570\u636e\u96c6\u7a0d\u5fae\u6539\u52a8\u4e00\u4e0b\uff0c\u6709 180 \u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u53ea\u6709 20 \u5f20\u6709\u6c14\u80f8\u3002\u5373\u4f7f\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4e5f\u8981\u521b\u5efa\u6b63\u8d1f\uff08\u6c14\u80f8\u4e0e\u975e\u6c14\u80f8\uff09\u76ee\u6807\u6bd4\u4f8b\u76f8\u540c\u7684\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u6bcf\u4e00\u7ec4\u4e2d\uff0c\u6211\u4eec\u6709 90 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u5982\u679c\u8bf4\u9a8c\u8bc1\u96c6\u4e2d\u7684\u6240\u6709\u56fe\u50cf\u90fd\u662f\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u90a3\u4e48\u60a8\u7684\u51c6\u786e\u7387\u4f1a\u662f\u591a\u5c11\u5462\uff1f\u8ba9\u6211\u4eec\u6765\u770b\u770b\uff1b\u60a8\u5bf9 90% \u7684\u56fe\u50cf\u8fdb\u884c\u4e86\u6b63\u786e\u5206\u7c7b\u3002\u56e0\u6b64\uff0c\u60a8\u7684\u51c6\u786e\u7387\u662f 90%\u3002 \u4f46\u8bf7\u518d\u770b\u4e00\u904d\u3002 \u4f60\u751a\u81f3\u6ca1\u6709\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5c31\u5f97\u5230\u4e86 90% 
\u7684\u51c6\u786e\u7387\u3002\u8fd9\u4f3c\u4e4e\u6709\u70b9\u6ca1\u7528\u3002\u5982\u679c\u6211\u4eec\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u6570\u636e\u96c6\u662f\u504f\u659c\u7684\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u6bd4\u53e6\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u591a\u5f88\u591a\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f7f\u7528\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u56e0\u4e3a\u5b83\u4e0d\u80fd\u4ee3\u8868\u6570\u636e\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u80fd\u4f1a\u83b7\u5f97\u5f88\u9ad8\u7684\u51c6\u786e\u7387\uff0c\u4f46\u60a8\u7684\u6a21\u578b\u5728\u5b9e\u9645\u6837\u672c\u4e2d\u7684\u8868\u73b0\u53ef\u80fd\u5e76\u4e0d\u7406\u60f3\uff0c\u800c\u4e14\u60a8\u4e5f\u65e0\u6cd5\u5411\u7ecf\u7406\u89e3\u91ca\u539f\u56e0\u3002 \u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b \u7cbe\u786e\u7387 \u7b49\u5176\u4ed6\u6307\u6807\u3002 \u5728\u5b66\u4e60\u7cbe\u786e\u7387\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e86\u89e3\u4e00\u4e9b\u672f\u8bed\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5047\u8bbe\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e3a\u6b63\u7c7b (1)\uff0c\u6ca1\u6709\u6c14\u80f8\u7684\u4e3a\u8d1f\u7c7b (0)\u3002 \u771f\u9633\u6027 \uff08TP\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6709\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9633\u6027\u3002 \u771f\u9634\u6027 \uff08TN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u800c\u5b9e\u9645\u76ee\u6807\u663e\u793a\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9634\u6027\u3002 
\u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u6b63\u786e\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9633\u6027\uff1b\u5982\u679c\u60a8\u7684\u6a21\u578b\u51c6\u786e\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9634\u6027\u3002 \u5047\u9633\u6027 \uff08FP\uff09 \uff1a\u7ed9\u5b9a\u4e00\u5f20\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u975e\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9633\u6027\u3002 \u5047\u9634\u6027 \uff08FN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u975e\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9634\u6027\u3002 \u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\uff0c\u90a3\u4e48\u5b83\u5c31\u662f\u5047\u9633\u6027\u3002\u5982\u679c\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5219\u662f\u5047\u9634\u6027\u3002 \u8ba9\u6211\u4eec\u9010\u4e00\u770b\u770b\u8fd9\u4e9b\u5b9e\u73b0\u3002 def true_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u8ba1\u6570\u5668 tp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 1 : tp += 1 # \u8fd4\u56de\u771f\u9633\u6027\u6837\u672c\u6570 return tp def true_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9634\u6027\u6837\u672c\u8ba1\u6570\u5668 tn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # 
\u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 0 : tn += 1 # \u8fd4\u56de\u771f\u9634\u6027\u6837\u672c\u6570 return tn def false_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9633\u6027\u8ba1\u6570\u5668 fp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 1 : fp += 1 # \u8fd4\u56de\u5047\u9633\u6027\u6837\u672c\u6570 return fp def false_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9634\u6027\u8ba1\u6570\u5668 fn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 0 : fn += 1 # \u8fd4\u56de\u5047\u9634\u6027\u6570 return fn \u6211\u5728\u8fd9\u91cc\u5b9e\u73b0\u8fd9\u4e9b\u529f\u80fd\u7684\u65b9\u6cd5\u975e\u5e38\u7b80\u5355\uff0c\u800c\u4e14\u53ea\u9002\u7528\u4e8e\u4e8c\u5143\u5206\u7c7b\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u8fd9\u4e9b\u51fd\u6570\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... 
: l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: true_positive ( l1 , l2 ) Out [ X ]: 2 In [ X ]: false_positive ( l1 , l2 ) Out [ X ]: 1 In [ X ]: false_negative ( l1 , l2 ) Out [ X ]: 2 In [ X ]: true_negative ( l1 , l2 ) Out [ X ]: 3 \u5982\u679c\u6211\u4eec\u5fc5\u987b\u7528\u4e0a\u8ff0\u672f\u8bed\u6765\u5b9a\u4e49\u7cbe\u786e\u7387\uff0c\u6211\u4eec\u53ef\u4ee5\u5199\u4e3a\uff1a \\[ Accuracy Score = (TP + TN)/(TP + TN + FP +FN) \\] \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5728 python \u4e2d\u4f7f\u7528 TP\u3001TN\u3001FP \u548c FN \u5feb\u901f\u5b9e\u73b0\u51c6\u786e\u5ea6\u5f97\u5206\u3002\u6211\u4eec\u5c06\u5176\u79f0\u4e3a accuracy_v2\u3002 def accuracy_v2 ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u51c6\u786e\u7387 accuracy_score = ( tp + tn ) / ( tp + tn + fp + fn ) return accuracy_score \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0e\u4e4b\u524d\u7684\u5b9e\u73b0\u548c scikit-learn \u7248\u672c\u8fdb\u884c\u6bd4\u8f83\uff0c\u5feb\u901f\u68c0\u67e5\u8be5\u51fd\u6570\u7684\u6b63\u786e\u6027\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: accuracy ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: accuracy_v2 ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: metrics . 
accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u8bf7\u6ce8\u610f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0cmetrics.accuracy_score \u6765\u81ea scikit-learn\u3002 \u5f88\u597d\u3002\u6240\u6709\u503c\u90fd\u5339\u914d\u3002\u8fd9\u8bf4\u660e\u6211\u4eec\u5728\u5b9e\u73b0\u8fc7\u7a0b\u4e2d\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u5176\u4ed6\u91cd\u8981\u6307\u6807\u3002 \u9996\u5148\u662f\u7cbe\u786e\u7387\u3002\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u662f \\[ Precision = TP/(TP + FP) \\] \u5047\u8bbe\u6211\u4eec\u5728\u65b0\u7684\u504f\u659c\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u4e86\u4e00\u4e2a\u65b0\u6a21\u578b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 90 \u5f20\u56fe\u50cf\u4e2d\u7684 80 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u56fe\u50cf\u4e2d\u7684 8 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6210\u529f\u8bc6\u522b\u4e86 100 \u5f20\u56fe\u50cf\u4e2d\u7684 88 \u5f20\u3002\u56e0\u6b64\uff0c\u51c6\u786e\u7387\u4e3a 0.88 \u6216 88%\u3002 \u4f46\u662f\uff0c\u5728\u8fd9 100 \u5f20\u6837\u672c\u4e2d\uff0c\u6709 10 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u6c14\u80f8\uff0c2 \u5f20\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u975e\u6c14\u80f8\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 TP : 8 TN: 80 FP: 10 FN: 2 \u7cbe\u786e\u7387\u4e3a 8 / (8 + 10) = 0.444\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u5728\u8bc6\u522b\u9633\u6027\u6837\u672c\uff08\u6c14\u80f8\uff09\u65f6\u6709 44.4% \u7684\u6b63\u786e\u7387\u3002 \u73b0\u5728\uff0c\u65e2\u7136\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86 TP\u3001TN\u3001FP \u548c FN\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5728 python \u4e2d\u5b9e\u73b0\u7cbe\u786e\u7387\u4e86\u3002 def precision ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true 
, y_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8ba9\u6211\u4eec\u8bd5\u8bd5\u8fd9\u79cd\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u65b9\u5f0f\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: precision ( l1 , l2 ) Out [ X ]: 0.6666666666666666 \u8fd9\u4f3c\u4e4e\u6ca1\u6709\u95ee\u9898\u3002 \u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u6765\u770b \u53ec\u56de\u7387 \u3002\u53ec\u56de\u7387\u7684\u5b9a\u4e49\u662f\uff1a \\[ Recall = TP/(TP + FN) \\] \u5728\u4e0a\u8ff0\u60c5\u51b5\u4e0b\uff0c\u53ec\u56de\u7387\u4e3a 8 / (8 + 2) = 0.80\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 80% \u7684\u9633\u6027\u6837\u672c\u3002 def recall ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u53ec\u56de\u7387 recall = tp / ( tp + fn ) return recall \u5c31\u6211\u4eec\u7684\u4e24\u4e2a\u5c0f\u5217\u8868\u800c\u8a00\uff0c\u53ec\u56de\u7387\u5e94\u8be5\u662f 0.5\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... 
: l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: recall ( l1 , l2 ) Out [ X ]: 0.5 \u8fd9\u4e0e\u6211\u4eec\u7684\u8ba1\u7b97\u503c\u76f8\u7b26\uff01 \u5bf9\u4e8e\u4e00\u4e2a \"\u597d \"\u6a21\u578b\u6765\u8bf4\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u503c\u90fd\u5e94\u8be5\u5f88\u9ad8\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0c\u53ec\u56de\u503c\u76f8\u5f53\u9ad8\u3002\u4f46\u662f\uff0c\u7cbe\u786e\u7387\u5374\u5f88\u4f4e\uff01\u6211\u4eec\u7684\u6a21\u578b\u4ea7\u751f\u4e86\u5927\u91cf\u7684\u8bef\u62a5\uff0c\u4f46\u8bef\u62a5\u8f83\u5c11\u3002\u5728\u8fd9\u7c7b\u95ee\u9898\u4e2d\uff0c\u5047\u9634\u6027\u8f83\u5c11\u662f\u597d\u4e8b\uff0c\u56e0\u4e3a\u4f60\u4e0d\u60f3\u5728\u75c5\u4eba\u6709\u6c14\u80f8\u7684\u60c5\u51b5\u4e0b\u5374\u8bf4\u4ed6\u4eec\u6ca1\u6709\u6c14\u80f8\u3002\u8fd9\u6837\u505a\u4f1a\u9020\u6210\u66f4\u5927\u7684\u4f24\u5bb3\u3002\u4f46\u6211\u4eec\u4e5f\u6709\u5f88\u591a\u5047\u9633\u6027\u7ed3\u679c\uff0c\u8fd9\u4e5f\u4e0d\u662f\u597d\u4e8b\u3002 \u5927\u591a\u6570\u6a21\u578b\u90fd\u4f1a\u9884\u6d4b\u4e00\u4e2a\u6982\u7387\uff0c\u5f53\u6211\u4eec\u9884\u6d4b\u65f6\uff0c\u901a\u5e38\u4f1a\u5c06\u8fd9\u4e2a\u9608\u503c\u9009\u4e3a 0.5\u3002\u8fd9\u4e2a\u9608\u503c\u5e76\u4e0d\u603b\u662f\u7406\u60f3\u7684\uff0c\u6839\u636e\u8fd9\u4e2a\u9608\u503c\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u503c\u53ef\u80fd\u4f1a\u53d1\u751f\u5f88\u5927\u7684\u53d8\u5316\u3002\u5982\u679c\u6211\u4eec\u9009\u62e9\u7684\u6bcf\u4e2a\u9608\u503c\u90fd\u80fd\u8ba1\u7b97\u51fa\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u53ef\u4ee5\u5728\u8fd9\u4e9b\u503c\u4e4b\u95f4\u7ed8\u5236\u51fa\u66f2\u7ebf\u56fe\u3002\u8fd9\u5e45\u56fe\u6216\u66f2\u7ebf\u88ab\u79f0\u4e3a \"\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\"\u3002 \u5728\u7814\u7a76\u7cbe\u786e\u7387-\u8c03\u7528\u66f2\u7ebf\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5047\u8bbe\u6709\u4e24\u4e2a\u5217\u8868\u3002 In [ X ]: 
y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0.02638412 , 0.11114267 , 0.31620708 , ... : 0.0490937 , 0.0191491 , 0.17554844 , ... : 0.15952202 , 0.03819563 , 0.11639273 , ... : 0.079377 , 0.08584789 , 0.39095342 , ... : 0.27259048 , 0.03447096 , 0.04644807 , ... : 0.03543574 , 0.18521942 , 0.05934905 , ... : 0.61977213 , 0.33056815 ] \u56e0\u6b64\uff0cy_true \u662f\u6211\u4eec\u7684\u76ee\u6807\u503c\uff0c\u800c y_pred \u662f\u6837\u672c\u88ab\u8d4b\u503c\u4e3a 1 \u7684\u6982\u7387\u503c\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u6211\u4eec\u8981\u770b\u7684\u662f\u9884\u6d4b\u4e2d\u7684\u6982\u7387\uff0c\u800c\u4e0d\u662f\u9884\u6d4b\u503c\uff08\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u9884\u6d4b\u503c\u7684\u8ba1\u7b97\u9608\u503c\u4e3a 0.5\uff09\u3002 precisions = [] recalls = [] thresholds = [ 0.0490937 , 0.05934905 , 0.079377 , 0.08584789 , 0.11114267 , 0.11639273 , 0.15952202 , 0.17554844 , 0.18521942 , 0.27259048 , 0.31620708 , 0.33056815 , 0.39095342 , 0.61977213 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for i in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_prediction = [ 1 if x >= i else 0 for x in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , temp_prediction ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , temp_prediction ) # \u52a0\u5165\u7cbe\u786e\u7387\u5217\u8868 precisions . append ( p ) # \u52a0\u5165\u53ec\u56de\u7387\u5217\u8868 recalls . append ( r ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7ed8\u5236\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 7 , 7 )) # x\u8f74\u4e3a\u53ec\u56de\u7387\uff0cy\u8f74\u4e3a\u7cbe\u786e\u7387 plt . plot ( recalls , precisions ) # \u6dfb\u52a0x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a15 plt . 
xlabel ( 'Recall' , fontsize = 15 ) # \u6dfb\u52a0y\u8f74\u6807\u7b7e\uff0c\u5b57\u6761\u5927\u5c0f\u4e3a15 plt . ylabel ( 'Precision' , fontsize = 15 ) \u56fe 2 \u663e\u793a\u4e86\u6211\u4eec\u901a\u8fc7\u8fd9\u79cd\u65b9\u6cd5\u5f97\u5230\u7684\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 \u56fe 2\uff1a\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u8fd9\u6761 \u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u4e0e\u60a8\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230\u7684\u66f2\u7ebf\u622a\u7136\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6211\u4eec\u53ea\u6709 20 \u4e2a\u6837\u672c\uff0c\u5176\u4e2d\u53ea\u6709 3 \u4e2a\u662f\u9633\u6027\u6837\u672c\u3002\u4f46\u8fd9\u6ca1\u4ec0\u4e48\u597d\u62c5\u5fc3\u7684\u3002\u8fd9\u8fd8\u662f\u90a3\u6761\u7cbe\u786e\u7387-\u53ec\u56de\u66f2\u7ebf\u3002 \u4f60\u4f1a\u53d1\u73b0\uff0c\u9009\u62e9\u4e00\u4e2a\u65e2\u80fd\u63d0\u4f9b\u826f\u597d\u7cbe\u786e\u7387\u53c8\u80fd\u63d0\u4f9b\u53ec\u56de\u503c\u7684\u9608\u503c\u662f\u5f88\u6709\u6311\u6218\u6027\u7684\u3002\u5982\u679c\u9608\u503c\u8fc7\u9ad8\uff0c\u771f\u9633\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u51cf\u5c11\uff0c\u800c\u5047\u9634\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u589e\u52a0\u3002\u8fd9\u4f1a\u964d\u4f4e\u53ec\u56de\u7387\uff0c\u4f46\u7cbe\u786e\u7387\u5f97\u5206\u4f1a\u5f88\u9ad8\u3002\u5982\u679c\u5c06\u9608\u503c\u964d\u5f97\u592a\u4f4e\uff0c\u5219\u8bef\u62a5\u4f1a\u5927\u91cf\u589e\u52a0\uff0c\u7cbe\u786e\u7387\u4e5f\u4f1a\u964d\u4f4e\u3002 \u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u8d8a\u63a5\u8fd1 1 \u8d8a\u597d\u3002 F1 \u5206\u6570\u662f\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7efc\u5408\u6307\u6807\u3002\u5b83\u88ab\u5b9a\u4e49\u4e3a\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7b80\u5355\u52a0\u6743\u5e73\u5747\u503c\uff08\u8c03\u548c\u5e73\u5747\u503c\uff09\u3002\u5982\u679c\u6211\u4eec\u7528 P \u8868\u793a\u7cbe\u786e\u7387\uff0c\u7528 R 
\u8868\u793a\u53ec\u56de\u7387\uff0c\u90a3\u4e48 F1 \u5206\u6570\u53ef\u4ee5\u8868\u793a\u4e3a\uff1a \\[ F1 = 2PR/(P + R) \\] \u6839\u636e TP\u3001FP \u548c FN\uff0c\u7a0d\u52a0\u6570\u5b66\u8ba1\u7b97\u5c31\u80fd\u5f97\u51fa\u4ee5\u4e0b F1 \u7b49\u5f0f\uff1a \\[ F1 = 2TP/(2TP + FP + FN) \\] Python \u5b9e\u73b0\u5f88\u7b80\u5355\uff0c\u56e0\u4e3a\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86\u8fd9\u4e9b def f1 ( y_true , y_pred ): # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , y_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , y_pred ) # \u8ba1\u7b97f1\u503c score = 2 * p * r / ( p + r ) return score \u8ba9\u6211\u4eec\u770b\u770b\u5176\u7ed3\u679c\uff0c\u5e76\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\u3002 In [ X ]: y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0 , 0 , 1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: f1 ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u901a\u8fc7 scikit learn\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u76f8\u540c\u7684\u5217\u8868\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . 
f1_score ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u4e0e\u5176\u5355\u72ec\u770b\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u60a8\u8fd8\u53ef\u4ee5\u53ea\u770b F1 \u5206\u6570\u3002\u4e0e\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c\u51c6\u786e\u5ea6\u4e00\u6837\uff0cF1 \u5206\u6570\u7684\u8303\u56f4\u4e5f\u662f\u4ece 0 \u5230 1\uff0c\u5b8c\u7f8e\u9884\u6d4b\u6a21\u578b\u7684 F1 \u5206\u6570\u4e3a 1\u3002 \u6b64\u5916\uff0c\u6211\u4eec\u8fd8\u5e94\u8be5\u4e86\u89e3\u5176\u4ed6\u4e00\u4e9b\u5173\u952e\u672f\u8bed\u3002 \u7b2c\u4e00\u4e2a\u672f\u8bed\u662f TPR \u6216\u771f\u9633\u6027\u7387\uff08True Positive Rate\uff09\uff0c\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\u3002 \\[ TPR = TP/(TP + FN) \\] \u5c3d\u7ba1\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\uff0c\u4f46\u6211\u4eec\u5c06\u4e3a\u5b83\u521b\u5efa\u4e00\u4e2a python \u51fd\u6570\uff0c\u4ee5\u4fbf\u4eca\u540e\u4f7f\u7528\u8fd9\u4e2a\u540d\u79f0\u3002 def tpr ( y_true , y_pred ): # \u771f\u9633\u6027\u7387\uff08TPR\uff09\uff0c\u4e0e\u53ec\u56de\u7387\u8ba1\u7b97\u516c\u5f0f\u4e00\u81f4 return recall ( y_true , y_pred ) TPR \u6216\u53ec\u56de\u7387\u4e5f\u88ab\u79f0\u4e3a\u7075\u654f\u5ea6\u3002 \u800c FPR \u6216\u5047\u9633\u6027\u7387\uff08False Positive Rate\uff09\u7684\u5b9a\u4e49\u662f\uff1a \\[ FPR = FP / (TN + FP) \\] def fpr ( y_true , y_pred ): # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u8fd4\u56de\u5047\u9633\u6027\u7387\uff08FPR\uff09 return fp / ( tn + fp ) 1 - FPR \u88ab\u79f0\u4e3a\u7279\u5f02\u6027\u6216\u771f\u9634\u6027\u7387\u6216 TNR\u3002\u8fd9\u4e9b\u672f\u8bed\u5f88\u591a\uff0c\u4f46\u5176\u4e2d\u6700\u91cd\u8981\u7684\u53ea\u6709 TPR \u548c FPR\u3002\u5047\u8bbe\u6211\u4eec\u53ea\u6709 15 \u4e2a\u6837\u672c\uff0c\u5176\u76ee\u6807\u503c\u4e3a\u4e8c\u5143\uff1a Actual targets : [0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1] 
\u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u7c7b\u4f3c\u968f\u673a\u68ee\u6797\u7684\u6a21\u578b\uff0c\u5c31\u80fd\u5f97\u5230\u6837\u672c\u5448\u9633\u6027\u7684\u6982\u7387\u3002 Predicted probabilities for 1: [0.1, 0.3, 0.2, 0.6, 0.8, 0.05, 0.9, 0.5, 0.3, 0.66, 0.3, 0.2, 0.85, 0.15, 0.99] \u5bf9\u4e8e >= 0.5 \u7684\u5178\u578b\u9608\u503c\uff0c\u6211\u4eec\u53ef\u4ee5\u8bc4\u4f30\u4e0a\u8ff0\u6240\u6709\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387/TPR\u3001F1 \u548c FPR \u503c\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u5c06\u9608\u503c\u9009\u4e3a 0.4 \u6216 0.6\uff0c\u4e5f\u53ef\u4ee5\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0 \u5230 1 \u4e4b\u95f4\u7684\u4efb\u4f55\u503c\uff0c\u5e76\u8ba1\u7b97\u4e0a\u8ff0\u6240\u6709\u6307\u6807\u3002 \u4e0d\u8fc7\uff0c\u6211\u4eec\u53ea\u8ba1\u7b97\u4e24\u4e2a\u503c\uff1a TPR \u548c FPR\u3002 # \u521d\u59cb\u5316\u771f\u9633\u6027\u7387\u5217\u8868 tpr_list = [] # \u521d\u59cb\u5316\u5047\u9633\u6027\u7387\u5217\u8868 fpr_list = [] # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u7387 temp_tpr = tpr ( y_true , temp_pred ) # \u5047\u9633\u6027\u7387 temp_fpr = fpr ( y_true , temp_pred ) # \u5c06\u771f\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 tpr_list . append ( temp_tpr ) # \u5c06\u5047\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 fpr_list . 
append ( temp_fpr ) \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u6bcf\u4e2a\u9608\u503c\u7684 TPR \u503c\u548c FPR \u503c\u3002 \u56fe 3\uff1a\u9608\u503c\u3001TPR \u548c FPR \u503c\u8868 \u5982\u679c\u6211\u4eec\u7ed8\u5236\u5982\u56fe 3 \u6240\u793a\u7684\u8868\u683c\uff0c\u5373\u4ee5 TPR \u4e3a Y \u8f74\uff0cFPR \u4e3a X \u8f74\uff0c\u5c31\u4f1a\u5f97\u5230\u5982\u56fe 4 \u6240\u793a\u7684\u66f2\u7ebf\u3002 \u56fe 4\uff1aROC\u66f2\u7ebf \u8fd9\u6761\u66f2\u7ebf\u4e5f\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u3002\u5982\u679c\u6211\u4eec\u8ba1\u7b97\u8fd9\u6761 ROC \u66f2\u7ebf\u4e0b\u7684\u9762\u79ef\uff0c\u5c31\u662f\u5728\u8ba1\u7b97\u53e6\u4e00\u4e2a\u6307\u6807\uff0c\u5f53\u6570\u636e\u96c6\u7684\u4e8c\u5143\u76ee\u6807\u504f\u659c\u65f6\uff0c\u8fd9\u4e2a\u6307\u6807\u5c31\u4f1a\u975e\u5e38\u5e38\u7528\u3002 \u8fd9\u4e2a\u6307\u6807\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u6216\u66f2\u7ebf\u4e0b\u9762\u79ef\uff0c\u7b80\u79f0 AUC\u3002\u8ba1\u7b97 ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002\u5728\u6b64\uff0c\u6211\u4eec\u5c06\u91c7\u7528 scikit- learn \u7684\u5947\u5999\u5b9e\u73b0\u65b9\u6cd5\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: metrics . 
roc_auc_score ( y_true , y_pred ) Out [ X ]: 0.8300000000000001 AUC \u503c\u4ece 0 \u5230 1 \u4e0d\u7b49\u3002 AUC = 1 \u610f\u5473\u7740\u60a8\u62e5\u6709\u4e00\u4e2a\u5b8c\u7f8e\u7684\u6a21\u578b\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u610f\u5473\u7740\u4f60\u5728\u9a8c\u8bc1\u65f6\u72af\u4e86\u4e00\u4e9b\u9519\u8bef\uff0c\u5e94\u8be5\u91cd\u65b0\u5ba1\u89c6\u6570\u636e\u5904\u7406\u548c\u9a8c\u8bc1\u6d41\u7a0b\u3002\u5982\u679c\u4f60\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\uff0c\u90a3\u4e48\u606d\u559c\u4f60\uff0c\u4f60\u5df2\u7ecf\u62e5\u6709\u4e86\u9488\u5bf9\u6570\u636e\u96c6\u5efa\u7acb\u7684\u6700\u4f73\u6a21\u578b\u3002 AUC = 0 \u610f\u5473\u7740\u60a8\u7684\u6a21\u578b\u975e\u5e38\u7cdf\u7cd5\uff08\u6216\u975e\u5e38\u597d\uff01\uff09\u3002\u8bd5\u7740\u53cd\u8f6c\u9884\u6d4b\u7684\u6982\u7387\uff0c\u4f8b\u5982\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u6b63\u7c7b\u7684\u6982\u7387\u662f p\uff0c\u8bd5\u7740\u7528 1-p \u4ee3\u66ff\u5b83\u3002\u8fd9\u79cd AUC \u4e5f\u53ef\u80fd\u610f\u5473\u7740\u60a8\u7684\u9a8c\u8bc1\u6216\u6570\u636e\u5904\u7406\u5b58\u5728\u95ee\u9898\u3002 AUC = 0.5 \u610f\u5473\u7740\u4f60\u7684\u9884\u6d4b\u662f\u968f\u673a\u7684\u3002\u56e0\u6b64\uff0c\u5bf9\u4e8e\u4efb\u4f55\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u5982\u679c\u6211\u5c06\u6240\u6709\u76ee\u6807\u90fd\u9884\u6d4b\u4e3a 0.5\uff0c\u6211\u5c06\u5f97\u5230 0.5 \u7684 AUC\u3002 AUC \u503c\u4ecb\u4e8e 0 \u548c 0.5 \u4e4b\u95f4\uff0c\u610f\u5473\u7740\u4f60\u7684\u6a21\u578b\u6bd4\u968f\u673a\u6a21\u578b\u66f4\u5dee\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u662f\u56e0\u4e3a\u4f60\u98a0\u5012\u4e86\u7c7b\u522b\u3002 \u5982\u679c\u60a8\u5c1d\u8bd5\u53cd\u8f6c\u9884\u6d4b\uff0c\u60a8\u7684 AUC \u503c\u53ef\u80fd\u4f1a\u8d85\u8fc7 0.5\u3002\u63a5\u8fd1 1 \u7684 AUC \u503c\u88ab\u8ba4\u4e3a\u662f\u597d\u503c\u3002 \u4f46 AUC \u5bf9\u6211\u4eec\u7684\u6a21\u578b\u6709\u4ec0\u4e48\u5f71\u54cd\u5462\uff1f 
\u5047\u8bbe\u60a8\u5efa\u7acb\u4e86\u4e00\u4e2a\u4ece\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e2d\u68c0\u6d4b\u6c14\u80f8\u7684\u6a21\u578b\uff0c\u5176 AUC \u503c\u4e3a 0.85\u3002\u8fd9\u610f\u5473\u7740\uff0c\u5982\u679c\u60a8\u4ece\u6570\u636e\u96c6\u4e2d\u968f\u673a\u9009\u62e9\u4e00\u5f20\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9633\u6027\u6837\u672c\uff09\u548c\u53e6\u4e00\u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9634\u6027\u6837\u672c\uff09\uff0c\u90a3\u4e48\u6c14\u80f8\u56fe\u50cf\u7684\u6392\u540d\u5c06\u9ad8\u4e8e\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u6982\u7387\u4e3a 0.85\u3002 \u8ba1\u7b97\u6982\u7387\u548c AUC \u540e\uff0c\u60a8\u9700\u8981\u5bf9\u6d4b\u8bd5\u96c6\u8fdb\u884c\u9884\u6d4b\u3002\u6839\u636e\u95ee\u9898\u548c\u4f7f\u7528\u60c5\u51b5\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u6982\u7387\u6216\u5b9e\u9645\u7c7b\u522b\u3002\u5982\u679c\u4f60\u60f3\u8981\u6982\u7387\uff0c\u8fd9\u5e76\u4e0d\u96be\u3002\u5982\u679c\u60a8\u60f3\u8981\u7c7b\u522b\uff0c\u5219\u9700\u8981\u9009\u62e9\u4e00\u4e2a\u9608\u503c\u3002\u5728\u4e8c\u5143\u5206\u7c7b\u7684\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u91c7\u7528\u7c7b\u4f3c\u4e0b\u9762\u7684\u65b9\u6cd5\u3002 \\[ Prediction = Probability >= Threshold \\] \u4e5f\u5c31\u662f\u8bf4\uff0c\u9884\u6d4b\u662f\u4e00\u4e2a\u53ea\u5305\u542b\u4e8c\u5143\u53d8\u91cf\u7684\u65b0\u5217\u8868\u3002\u5982\u679c\u6982\u7387\u5927\u4e8e\u6216\u7b49\u4e8e\u7ed9\u5b9a\u7684\u9608\u503c\uff0c\u5219\u9884\u6d4b\u4e2d\u7684\u4e00\u9879\u4e3a 1\uff0c\u5426\u5219\u4e3a 0\u3002 \u4f60\u731c\u600e\u4e48\u7740\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528 ROC \u66f2\u7ebf\u6765\u9009\u62e9\u8fd9\u4e2a\u9608\u503c\uff01ROC \u66f2\u7ebf\u4f1a\u544a\u8bc9\u60a8\u9608\u503c\u5bf9\u5047\u9633\u6027\u7387\u548c\u771f\u9633\u6027\u7387\u7684\u5f71\u54cd\uff0c\u8fdb\u800c\u5f71\u54cd\u5047\u9633\u6027\u548c\u771f\u9633\u6027\u3002\u60a8\u5e94\u8be5\u9009\u62e9\u6700\u9002\u5408\u60a8\u7684\u95ee\u9898\u548c\u6570\u636e\u96c6\u7684\u9608\u503c\u3002 
\u4f8b\u5982\uff0c\u5982\u679c\u60a8\u4e0d\u5e0c\u671b\u6709\u592a\u591a\u7684\u8bef\u62a5\uff0c\u90a3\u4e48\u9608\u503c\u5c31\u5e94\u8be5\u9ad8\u4e00\u4e9b\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e5f\u4f1a\u5e26\u6765\u66f4\u591a\u7684\u8bef\u62a5\u3002\u6ce8\u610f\u6743\u8861\u5229\u5f0a\uff0c\u9009\u62e9\u6700\u4f73\u9608\u503c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u9608\u503c\u5982\u4f55\u5f71\u54cd\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u503c\u3002 # \u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list = [] # \u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list = [] # \u771f\u5b9e\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 temp_tp = true_positive ( y_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 temp_fp = false_positive ( y_true , temp_pred ) # \u52a0\u5165\u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list . append ( temp_tp ) # \u52a0\u5165\u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list . 
append ( temp_fp ) \u5229\u7528\u8fd9\u4e00\u70b9\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u8868\u683c\uff0c\u5982\u56fe 5 \u6240\u793a\u3002 \u56fe 5\uff1a\u4e0d\u540c\u9608\u503c\u7684 TP \u503c\u548c FP \u503c \u5982\u56fe 6 \u6240\u793a\uff0c\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0cROC \u66f2\u7ebf\u5de6\u4e0a\u89d2\u7684\u503c\u5e94\u8be5\u662f\u4e00\u4e2a\u76f8\u5f53\u4e0d\u9519\u7684\u9608\u503c\u3002 \u5bf9\u6bd4\u8868\u683c\u548c ROC \u66f2\u7ebf\uff0c\u6211\u4eec\u53ef\u4ee5\u53d1\u73b0\uff0c0.6 \u5de6\u53f3\u7684\u9608\u503c\u76f8\u5f53\u4e0d\u9519\uff0c\u65e2\u4e0d\u4f1a\u4e22\u5931\u5927\u91cf\u7684\u771f\u9633\u6027\u7ed3\u679c\uff0c\u4e5f\u4e0d\u4f1a\u51fa\u73b0\u5927\u91cf\u7684\u5047\u9633\u6027\u7ed3\u679c\u3002 \u56fe 6\uff1a\u4ece ROC \u66f2\u7ebf\u6700\u5de6\u4fa7\u7684\u9876\u70b9\u9009\u62e9\u6700\u4f73\u9608\u503c AUC \u662f\u4e1a\u5185\u5e7f\u6cdb\u5e94\u7528\u4e8e\u504f\u659c\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6307\u6807\uff0c\u4e5f\u662f\u6bcf\u4e2a\u4eba\u90fd\u5e94\u8be5\u4e86\u89e3\u7684\u6307\u6807\u3002\u4e00\u65e6\u7406\u89e3\u4e86 AUC \u80cc\u540e\u7684\u7406\u5ff5\uff08\u5982\u4e0a\u6587\u6240\u8ff0\uff09\uff0c\u4e5f\u5c31\u5f88\u5bb9\u6613\u5411\u4e1a\u754c\u53ef\u80fd\u4f1a\u8bc4\u4f30\u60a8\u7684\u6a21\u578b\u7684\u975e\u6280\u672f\u4eba\u5458\u89e3\u91ca\u5b83\u4e86\u3002 \u5b66\u4e60 AUC \u540e\uff0c\u4f60\u5e94\u8be5\u5b66\u4e60\u7684\u53e6\u4e00\u4e2a\u91cd\u8981\u6307\u6807\u662f\u5bf9\u6570\u635f\u5931\u3002\u5bf9\u4e8e\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5c06\u5bf9\u6570\u635f\u5931\u5b9a\u4e49\u4e3a\uff1a \\[ LogLoss = -1.0 \\times (target \\times log(prediction) + (1-target) \\times log(1-prediction)) \\] \u5176\u4e2d\uff0c\u76ee\u6807\u503c\u4e3a 0 \u6216 1\uff0c\u9884\u6d4b\u503c\u4e3a\u6837\u672c\u5c5e\u4e8e\u7c7b\u522b 1 \u7684\u6982\u7387\u3002 
\u5bf9\u4e8e\u6570\u636e\u96c6\u4e2d\u7684\u591a\u4e2a\u6837\u672c\uff0c\u6240\u6709\u6837\u672c\u7684\u5bf9\u6570\u635f\u5931\u53ea\u662f\u6240\u6709\u5355\u4e2a\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u9700\u8981\u8bb0\u4f4f\u7684\u4e00\u70b9\u662f\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u4e0d\u6b63\u786e\u6216\u504f\u5dee\u8f83\u5927\u7684\u9884\u6d4b\u8fdb\u884c\u76f8\u5f53\u9ad8\u7684\u60e9\u7f5a\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u975e\u5e38\u786e\u5b9a\u548c\u975e\u5e38\u9519\u8bef\u7684\u9884\u6d4b\u8fdb\u884c\u60e9\u7f5a\u3002 import numpy as np def log_loss ( y_true , y_proba ): # \u6781\u5c0f\u503c\uff0c\u9632\u6b620\u505a\u5206\u6bcd epsilon = 1e-15 # \u5bf9\u6570\u635f\u5931\u5217\u8868 loss = [] # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_proba ): # \u9650\u5236yp\u8303\u56f4\uff0c\u6700\u5c0f\u4e3aepsilon\uff0c\u6700\u5927\u4e3a1-epsilon yp = np . clip ( yp , epsilon , 1 - epsilon ) # \u8ba1\u7b97\u5bf9\u6570\u635f\u5931 temp_loss = - 1.0 * ( yt * np . log ( yp ) + ( 1 - yt ) * np . log ( 1 - yp )) # \u52a0\u5165\u5bf9\u6570\u635f\u5931\u5217\u8868 loss . append ( temp_loss ) return np . mean ( loss ) \u8ba9\u6211\u4eec\u6d4b\u8bd5\u4e00\u4e0b\u51fd\u6570\u6267\u884c\u60c5\u51b5\uff1a In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_proba = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u6211\u4eec\u53ef\u4ee5\u5c06\u5176\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . 
log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u5b9e\u73b0\u662f\u6b63\u786e\u7684\u3002 \u5bf9\u6570\u635f\u5931\u7684\u5b9e\u73b0\u5f88\u5bb9\u6613\u3002\u89e3\u91ca\u8d77\u6765\u4f3c\u4e4e\u6709\u70b9\u56f0\u96be\u3002\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5bf9\u6570\u635f\u5931\u7684\u60e9\u7f5a\u8981\u6bd4\u5176\u4ed6\u6307\u6807\u5927\u5f97\u591a\u3002 \u4f8b\u5982\uff0c\u5982\u679c\u60a8\u6709 51% \u7684\u628a\u63e1\u8ba4\u4e3a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\uff0c\u90a3\u4e48\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (1 \\times log(0.51) + (1 - 1) \\times log(1 - 0.51))=0.67 \\] \u5982\u679c\u4f60\u5bf9\u5c5e\u4e8e 0 \u7c7b\u7684\u6837\u672c\u6709 49% \u7684\u628a\u63e1\uff0c\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (1 \\times log(0.49) + (1 - 1) \\times log(1 - 0.49))=0.67 \\] \u56e0\u6b64\uff0c\u5373\u4f7f\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0.5 \u7684\u622a\u65ad\u503c\u5e76\u5f97\u5230\u5b8c\u7f8e\u7684\u9884\u6d4b\u7ed3\u679c\uff0c\u6211\u4eec\u4ecd\u7136\u4f1a\u6709\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002\u56e0\u6b64\uff0c\u5728\u5904\u7406\u5bf9\u6570\u635f\u5931\u65f6\uff0c\u4f60\u9700\u8981\u975e\u5e38\u5c0f\u5fc3\uff1b\u4efb\u4f55\u4e0d\u786e\u5b9a\u7684\u9884\u6d4b\u90fd\u4f1a\u4ea7\u751f\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002 \u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u5927\u591a\u6570\u6307\u6807\u90fd\u53ef\u4ee5\u8f6c\u6362\u6210\u591a\u7c7b\u7248\u672c\u3002\u8fd9\u4e2a\u60f3\u6cd5\u5f88\u7b80\u5355\u3002\u4ee5\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u4e3a\u4f8b\u3002\u6211\u4eec\u53ef\u4ee5\u8ba1\u7b97\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u4e2d\u6bcf\u4e00\u7c7b\u7684\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 
\u6709\u4e09\u79cd\u4e0d\u540c\u7684\u8ba1\u7b97\u65b9\u6cd5\uff0c\u6709\u65f6\u53ef\u80fd\u4f1a\u4ee4\u4eba\u56f0\u60d1\u3002\u5047\u8bbe\u6211\u4eec\u9996\u5148\u5bf9\u7cbe\u786e\u7387\u611f\u5174\u8da3\u3002\u6211\u4eec\u77e5\u9053\uff0c\u7cbe\u786e\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u3002 \u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Macro averaged precision\uff09\uff1a\u5206\u522b\u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u7cbe\u786e\u7387\u7136\u540e\u6c42\u5e73\u5747\u503c \u5fae\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Micro averaged precision\uff09\uff1a\u6309\u7c7b\u8ba1\u7b97\u771f\u9633\u6027\u548c\u5047\u9633\u6027\uff0c\u7136\u540e\u7528\u5176\u8ba1\u7b97\u603b\u4f53\u7cbe\u786e\u7387\u3002\u7136\u540e\u4ee5\u6b64\u8ba1\u7b97\u603b\u4f53\u7cbe\u786e\u7387 \u52a0\u6743\u7cbe\u786e\u7387 \uff08Weighted precision\uff09\uff1a\u4e0e\u5b8f\u89c2\u7cbe\u786e\u7387\u76f8\u540c\uff0c\u4f46\u8fd9\u91cc\u662f\u52a0\u6743\u5e73\u5747\u7cbe\u786e\u7387 \u53d6\u51b3\u4e8e\u6bcf\u4e2a\u7c7b\u522b\u4e2d\u7684\u9879\u76ee\u6570 \u8fd9\u770b\u4f3c\u590d\u6742\uff0c\u4f46\u5728 python \u5b9e\u73b0\u4e2d\u5f88\u5bb9\u6613\u7406\u89e3\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import numpy as np def macro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u5982\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u8ba1\u7b97\u7cbe\u786e\u5ea6 temp_precision = tp / ( tp + fp ) # \u5404\u7c7b\u7cbe\u786e\u7387\u76f8\u52a0 precision += temp_precision # \u8ba1\u7b97\u5e73\u5747\u503c precision /= num_classes return precision \u4f60\u4f1a\u53d1\u73b0\u8fd9\u5e76\u4e0d\u96be\u3002\u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5fae\u5e73\u5747\u7cbe\u786e\u7387\u5206\u6570\u3002 import numpy as np def micro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u6570 tp = 0 # \u521d\u59cb\u5316\u5047\u9633\u6027\u6837\u672c\u6570 fp = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 tp += true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 fp += false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8fd9\u4e5f\u4e0d\u96be\u3002\u90a3\u4ec0\u4e48\u96be\uff1f\u4ec0\u4e48\u90fd\u4e0d\u96be\u3002\u673a\u5668\u5b66\u4e60\u5f88\u7b80\u5355\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u52a0\u6743\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u3002 from collections import Counter import numpy as np def weighted_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 temp_precision = tp / ( tp + fp ) # \u6839\u636e\u8be5\u79cd\u7c7b\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_precision = class_counts [ class_ ] * temp_precision # \u52a0\u6743\u7cbe\u786e\u7387\u6c42\u548c precision += weighted_precision # \u8ba1\u7b97\u5e73\u5747\u7cbe\u786e\u7387 overall_precision = precision / len ( y_true ) return overall_precision \u5c06\u6211\u4eec\u7684\u5b9e\u73b0\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff0c\u4ee5\u4e86\u89e3\u5b9e\u73b0\u662f\u5426\u6b63\u786e\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: macro_precision ( y_true , y_pred ) Out [ X ]: 0.3611111111111111 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"macro\" ) Out [ X ]: 0.3611111111111111 In [ X ]: micro_precision ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"micro\" ) Out [ X ]: 0.4444444444444444 In [ X ]: weighted_precision ( y_true , y_pred ) Out [ X ]: 0.39814814814814814 In [ X ]: metrics . 
precision_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.39814814814814814 \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6b63\u786e\u5730\u5b9e\u73b0\u4e86\u4e00\u5207\u3002 \u8bf7\u6ce8\u610f\uff0c\u8fd9\u91cc\u5c55\u793a\u7684\u5b9e\u73b0\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\uff0c\u4f46\u5374\u662f\u6700\u5bb9\u6613\u7406\u89e3\u7684\u3002 \u540c\u6837\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5b9e\u73b0 \u591a\u7c7b\u522b\u7684\u53ec\u56de\u7387\u6307\u6807 \u3002\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u3001\u5047\u9633\u6027\u548c\u5047\u9634\u6027\uff0c\u800c F1 \u5219\u53d6\u51b3\u4e8e\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 \u53ec\u56de\u7387\u7684\u5b9e\u73b0\u65b9\u6cd5\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\uff0c\u8fd9\u91cc\u5b9e\u73b0\u7684\u662f\u591a\u7c7b F1 \u7684\u4e00\u4e2a\u7248\u672c\uff0c\u5373\u52a0\u6743\u5e73\u5747\u503c\u3002 from collections import Counter import numpy as np def weighted_f1 ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316F1\u503c f1 = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( temp_true , temp_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( temp_true , temp_pred ) # \u82e5\u7cbe\u786e\u7387+\u53ec\u56de\u7387\u4e0d\u4e3a0\uff0c\u5219\u4f7f\u7528\u516c\u5f0f\u8ba1\u7b97F1\u503c if p + r != 0 : temp_f1 = 2 * p * r / ( p + r ) # \u5426\u5219\u76f4\u63a5\u4e3a0 else : temp_f1 = 0 # \u6839\u636e\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_f1 = class_counts [ class_ ] * temp_f1 # \u52a0\u6743F1\u503c\u76f8\u52a0 f1 += weighted_f1 # \u8ba1\u7b97\u52a0\u6743\u5e73\u5747F1\u503c overall_f1 = f1 / len ( y_true ) return overall_f1 \u8bf7\u6ce8\u610f\uff0c\u4e0a\u9762\u6709\u51e0\u884c\u4ee3\u7801\u662f\u65b0\u5199\u7684\u3002\u56e0\u6b64\uff0c\u4f60\u5e94\u8be5\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e9b\u4ee3\u7801\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: weighted_f1 ( y_true , y_pred ) Out [ X ]: 0.41269841269841273 In [ X ]: metrics . 
f1_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.41269841269841273 \u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u591a\u7c7b\u95ee\u9898\u5b9e\u73b0\u4e86\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002\u540c\u6837\uff0c\u60a8\u4e5f\u53ef\u4ee5\u5c06 AUC \u548c\u5bf9\u6570\u635f\u5931\u8f6c\u6362\u4e3a\u591a\u7c7b\u683c\u5f0f\u3002\u8fd9\u79cd\u8f6c\u6362\u683c\u5f0f\u88ab\u79f0\u4e3a one-vs-all \u3002\u8fd9\u91cc\u6211\u4e0d\u6253\u7b97\u5b9e\u73b0\u5b83\u4eec\uff0c\u56e0\u4e3a\u5b9e\u73b0\u65b9\u6cd5\u4e0e\u6211\u4eec\u5df2\u7ecf\u8ba8\u8bba\u8fc7\u7684\u5f88\u76f8\u4f3c\u3002 \u5728\u4e8c\u5143\u6216\u591a\u7c7b\u5206\u7c7b\u4e2d\uff0c\u770b\u4e00\u4e0b \u6df7\u6dc6\u77e9\u9635 \u4e5f\u5f88\u6d41\u884c\u3002\u4e0d\u8981\u56f0\u60d1\uff0c\u8fd9\u5f88\u7b80\u5355\u3002\u6df7\u6dc6\u77e9\u9635\u53ea\u4e0d\u8fc7\u662f\u4e00\u4e2a\u5305\u542b TP\u3001FP\u3001TN \u548c FN \u7684\u8868\u683c\u3002\u4f7f\u7528\u6df7\u6dc6\u77e9\u9635\uff0c\u60a8\u53ef\u4ee5\u5feb\u901f\u67e5\u770b\u6709\u591a\u5c11\u6837\u672c\u88ab\u9519\u8bef\u5206\u7c7b\uff0c\u6709\u591a\u5c11\u6837\u672c\u88ab\u6b63\u786e\u5206\u7c7b\u3002\u4e5f\u8bb8\u6709\u4eba\u4f1a\u8bf4\uff0c\u6df7\u6dc6\u77e9\u9635\u5e94\u8be5\u5728\u672c\u7ae0\u5f88\u65e9\u5c31\u8bb2\u5230\uff0c\u4f46\u6211\u6ca1\u6709\u8fd9\u4e48\u505a\u3002\u5982\u679c\u4e86\u89e3\u4e86 TP\u3001FP\u3001TN\u3001FN\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c AUC\uff0c\u5c31\u5f88\u5bb9\u6613\u7406\u89e3\u548c\u89e3\u91ca\u6df7\u6dc6\u77e9\u9635\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b\u56fe 7 \u4e2d\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6df7\u6dc6\u77e9\u9635\u7531 TP\u3001FP\u3001FN \u548c TN \u7ec4\u6210\u3002\u6211\u4eec\u53ea\u9700\u8981\u8fd9\u4e9b\u503c\u6765\u8ba1\u7b97\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u3001F1 \u5206\u6570\u548c AUC\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u4e5f\u559c\u6b22\u628a FP \u79f0\u4e3a 
\u7b2c\u4e00\u7c7b\u9519\u8bef \uff0c\u628a FN \u79f0\u4e3a \u7b2c\u4e8c\u7c7b\u9519\u8bef \u3002 \u56fe 7\uff1a\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6df7\u6dc6\u77e9\u9635 \u6211\u4eec\u8fd8\u53ef\u4ee5\u5c06\u4e8c\u5143\u6df7\u6dc6\u77e9\u9635\u6269\u5c55\u4e3a\u591a\u7c7b\u6df7\u6dc6\u77e9\u9635\u3002\u5b83\u4f1a\u662f\u4ec0\u4e48\u6837\u5b50\u5462\uff1f\u5982\u679c\u6211\u4eec\u6709 N \u4e2a\u7c7b\u522b\uff0c\u5b83\u5c06\u662f\u4e00\u4e2a\u5927\u5c0f\u4e3a NxN \u7684\u77e9\u9635\u3002\u5bf9\u4e8e\u6bcf\u4e2a\u7c7b\u522b\uff0c\u6211\u4eec\u90fd\u8981\u8ba1\u7b97\u76f8\u5173\u7c7b\u522b\u548c\u5176\u4ed6\u7c7b\u522b\u7684\u6837\u672c\u603b\u6570\u3002\u4e3e\u4e2a\u4f8b\u5b50\u53ef\u4ee5\u8ba9\u6211\u4eec\u66f4\u597d\u5730\u7406\u89e3\u8fd9\u4e00\u70b9\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4ee5\u4e0b\u771f\u5b9e\u6807\u7b7e\uff1a \\[ [0, 1, 2, 0, 1, 2, 0, 2, 2] \\] \u6211\u4eec\u7684\u9884\u6d4b\u6807\u7b7e\u662f\uff1a \\[ [0, 2, 1, 0, 2, 1, 0, 0, 2] \\] \u90a3\u4e48\uff0c\u6211\u4eec\u7684\u6df7\u6dc6\u77e9\u9635\u5c06\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u591a\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635 \u56fe 8 \u8bf4\u660e\u4e86\u4ec0\u4e48\uff1f \u8ba9\u6211\u4eec\u6765\u770b\u770b 0 \u7c7b\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e 0 \u7c7b\u3002\u7136\u800c\uff0c\u5728\u9884\u6d4b\u4e2d\uff0c\u6211\u4eec\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 0 \u7c7b\uff0c1 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\u3002\u7406\u60f3\u60c5\u51b5\u4e0b\uff0c\u5bf9\u4e8e\u771f\u5b9e\u6807\u7b7e\u4e2d\u7684\u7c7b\u522b 0\uff0c\u9884\u6d4b\u6807\u7b7e 1 \u548c 2 \u5e94\u8be5\u6ca1\u6709\u4efb\u4f55\u6837\u672c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u7c7b\u522b 2\u3002\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 4\uff0c\u800c\u5728\u9884\u6d4b\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 3\u3002 
\u4e00\u4e2a\u5b8c\u7f8e\u7684\u6df7\u6dc6\u77e9\u9635\u53ea\u80fd\u4ece\u5de6\u5230\u53f3\u659c\u5411\u586b\u5145\u3002 \u6df7\u6dc6\u77e9\u9635 \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u7684\u65b9\u6cd5\u6765\u8ba1\u7b97\u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u4e0d\u540c\u6307\u6807\u3002Scikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u76f4\u63a5\u7684\u65b9\u6cd5\u6765\u751f\u6210\u6df7\u6dc6\u77e9\u9635\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5728\u56fe 8 \u4e2d\u663e\u793a\u7684\u6df7\u6dc6\u77e9\u9635\u662f scikit-learn \u6df7\u6dc6\u77e9\u9635\u7684\u8f6c\u7f6e\uff0c\u539f\u59cb\u7248\u672c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u7ed8\u5236\u3002 import matplotlib.pyplot as plt import seaborn as sns from sklearn import metrics # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] # \u9884\u6d4b\u6837\u672c\u6807\u7b7e y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] # \u8ba1\u7b97\u6df7\u6dc6\u77e9\u9635 cm = metrics . confusion_matrix ( y_true , y_pred ) # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 10 , 10 )) # \u521b\u5efa\u65b9\u683c cmap = sns . cubehelix_palette ( 50 , hue = 0.05 , rot = 0 , light = 0.9 , dark = 0 , as_cmap = True ) # \u89c4\u5b9a\u5b57\u4f53\u5927\u5c0f sns . set ( font_scale = 2.5 ) # \u7ed8\u5236\u70ed\u56fe sns . heatmap ( cm , annot = True , cmap = cmap , cbar = False ) # y\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . ylabel ( 'Actual Labels' , fontsize = 20 ) # x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . 
xlabel ( 'Predicted Labels' , fontsize = 20 ) \u56e0\u6b64\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u89e3\u51b3\u4e86\u4e8c\u5143\u5206\u7c7b\u548c\u591a\u7c7b\u5206\u7c7b\u7684\u5ea6\u91cf\u95ee\u9898\u3002\u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u8ba8\u8bba\u53e6\u4e00\u79cd\u7c7b\u578b\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u5373\u591a\u6807\u7b7e\u5206\u7c7b\u3002\u5728\u591a\u6807\u7b7e\u5206\u7c7b\u4e2d\uff0c\u6bcf\u4e2a\u6837\u672c\u90fd\u53ef\u80fd\u4e0e\u4e00\u4e2a\u6216\u591a\u4e2a\u7c7b\u522b\u76f8\u5173\u8054\u3002\u8fd9\u7c7b\u95ee\u9898\u7684\u4e00\u4e2a\u7b80\u5355\u4f8b\u5b50\u5c31\u662f\u8981\u6c42\u4f60\u9884\u6d4b\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53\u3002 \u56fe 9 \u663e\u793a\u4e86\u4e00\u4e2a\u8457\u540d\u6570\u636e\u96c6\u7684\u56fe\u50cf\u793a\u4f8b\u3002\u8bf7\u6ce8\u610f\uff0c\u8be5\u6570\u636e\u96c6\u7684\u76ee\u6807\u6709\u6240\u4e0d\u540c\uff0c\u4f46\u6211\u4eec\u6682\u4e14\u4e0d\u53bb\u8ba8\u8bba\u5b83\u3002\u6211\u4eec\u5047\u8bbe\u5176\u76ee\u7684\u53ea\u662f\u9884\u6d4b\u56fe\u50cf\u4e2d\u662f\u5426\u5b58\u5728\u67d0\u4e2a\u7269\u4f53\u3002\u5728\u56fe 9 \u4e2d\uff0c\u6211\u4eec\u6709\u6905\u5b50\u3001\u82b1\u76c6\u3001\u7a97\u6237\uff0c\u4f46\u6ca1\u6709\u5176\u4ed6\u7269\u4f53\uff0c\u5982\u7535\u8111\u3001\u5e8a\u3001\u7535\u89c6\u7b49\u3002\u56e0\u6b64\uff0c\u4e00\u5e45\u56fe\u50cf\u53ef\u80fd\u6709\u591a\u4e2a\u76f8\u5173\u76ee\u6807\u3002\u8fd9\u7c7b\u95ee\u9898\u5c31\u662f\u591a\u6807\u7b7e\u5206\u7c7b\u95ee\u9898\u3002 \u56fe 9\uff1a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53 \u8fd9\u7c7b\u5206\u7c7b\u95ee\u9898\u7684\u8861\u91cf\u6807\u51c6\u6709\u4e9b\u4e0d\u540c\u3002\u4e00\u4e9b\u5408\u9002\u7684 \u6700\u5e38\u89c1\u7684\u6307\u6807\u6709\uff1a k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u786e\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 \u8ba9\u6211\u4eec\u4ece k 
\u7cbe\u786e\u7387\u6216\u8005 P@k \u6211\u4eec\u4e0d\u80fd\u5c06\u8fd9\u4e00\u7cbe\u786e\u7387\u4e0e\u524d\u9762\u8ba8\u8bba\u7684\u7cbe\u786e\u7387\u6df7\u6dc6\u3002\u5982\u679c\u60a8\u6709\u4e00\u4e2a\u7ed9\u5b9a\u6837\u672c\u7684\u539f\u59cb\u7c7b\u522b\u5217\u8868\u548c\u540c\u4e00\u4e2a\u6837\u672c\u7684\u9884\u6d4b\u7c7b\u522b\u5217\u8868\uff0c\u90a3\u4e48\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u5c31\u662f\u9884\u6d4b\u5217\u8868\u4e2d\u4ec5\u8003\u8651\u524d k \u4e2a\u9884\u6d4b\u7ed3\u679c\u7684\u547d\u4e2d\u6570\u9664\u4ee5 k\u3002 \u5982\u679c\u60a8\u5bf9\u6b64\u611f\u5230\u56f0\u60d1\uff0c\u4f7f\u7528 python \u4ee3\u7801\u540e\u5c31\u4f1a\u660e\u767d\u3002 def pk ( y_true , y_pred , k ): # \u5982\u679ck\u4e3a0 if k == 0 : # \u8fd4\u56de0 return 0 # \u53d6\u9884\u6d4b\u6807\u7b7e\u524dk\u4e2a y_pred = y_pred [: k ] # \u5c06\u9884\u6d4b\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 pred_set = set ( y_pred ) # \u5c06\u771f\u5b9e\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 true_set = set ( y_true ) # \u9884\u6d4b\u6807\u7b7e\u96c6\u5408\u4e0e\u771f\u5b9e\u6807\u7b7e\u96c6\u5408\u4ea4\u96c6 common_values = pred_set . intersection ( true_set ) # \u8ba1\u7b97\u7cbe\u786e\u7387 return len ( common_values ) / len ( y_pred [: k ]) \u6709\u4e86\u4ee3\u7801\uff0c\u4e00\u5207\u90fd\u53d8\u5f97\u66f4\u5bb9\u6613\u7406\u89e3\u4e86\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6709\u4e86 k \u5e73\u5747\u7cbe\u786e\u7387\u6216 AP@k \u3002AP@k \u662f\u901a\u8fc7 P@k \u8ba1\u7b97\u5f97\u51fa\u7684\u3002\u4f8b\u5982\uff0c\u5982\u679c\u8981\u8ba1\u7b97 AP@3\uff0c\u6211\u4eec\u8981\u5148\u8ba1\u7b97 P@1\u3001P@2 \u548c P@3\uff0c\u7136\u540e\u5c06\u603b\u548c\u9664\u4ee5 3\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u7684\u5b9e\u73b0\u3002 def apk ( y_true , y_pred , k ): # \u521d\u59cb\u5316P@k\u5217\u8868 pk_values = [] # \u904d\u53861~k for i in range ( 1 , k + 1 ): # \u5c06P@k\u52a0\u5165\u5217\u8868 pk_values . 
append ( pk ( y_true , y_pred , i )) # \u82e5\u957f\u5ea6\u4e3a0 if len ( pk_values ) == 0 : # \u8fd4\u56de0 return 0 # \u5426\u5219\u8ba1\u7b97AP@K return sum ( pk_values ) / len ( pk_values ) \u8fd9\u4e24\u4e2a\u51fd\u6570\u53ef\u4ee5\u7528\u6765\u8ba1\u7b97\u4e24\u4e2a\u7ed9\u5b9a\u5217\u8868\u7684 k \u5e73\u5747\u7cbe\u786e\u7387 (AP@k)\uff1b\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u8ba1\u7b97\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: for i in range ( len ( y_true )): ... : for j in range ( 1 , 4 ): ... : print ( ... : f \"\"\" ...: y_true= { y_true [ i ] } , ...: y_pred= { y_pred [ i ] } , ...: AP@ { j } = { apk ( y_true [ i ], y_pred [ i ], k = j ) } ...: \"\"\" ... : ) ... : y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 1 = 0.0 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 2 = 0.25 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 3 = 0.38888888888888884 \u8bf7\u6ce8\u610f\uff0c\u6211\u7701\u7565\u4e86\u8f93\u51fa\u7ed3\u679c\u4e2d\u7684\u8bb8\u591a\u6570\u503c\uff0c\u4f46\u4f60\u4f1a\u660e\u767d\u5176\u4e2d\u7684\u610f\u601d\u3002\u8fd9\u5c31\u662f\u6211\u4eec\u5982\u4f55\u8ba1\u7b97 AP@k \u7684\u65b9\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u7684 AP@k\u3002\u5728\u673a\u5668\u5b66\u4e60\u4e2d\uff0c\u6211\u4eec\u5bf9\u6240\u6709\u6837\u672c\u90fd\u611f\u5174\u8da3\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u6709 \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387 k \u6216 MAP@k \u3002MAP@k \u53ea\u662f AP@k \u7684\u5e73\u5747\u503c\uff0c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b python \u4ee3\u7801\u8f7b\u677e\u8ba1\u7b97\u3002 def mapk ( y_true , y_pred , k ): # \u521d\u59cb\u5316AP@k\u5217\u8868 apk_values = [] # \u904d\u53860~\uff08\u771f\u5b9e\u6807\u7b7e\u6570-1\uff09 for i in range ( len ( y_true )): 
# \u5c06AP@K\u52a0\u5165\u5217\u8868 apk_values . append ( apk ( y_true [ i ], y_pred [ i ], k = k ) ) # \u8ba1\u7b97\u5e73\u5747AP@k return sum ( apk_values ) / len ( apk_values ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u9488\u5bf9\u76f8\u540c\u7684\u5217\u8868\u8ba1\u7b97 k=1\u30012\u30013 \u548c 4 \u65f6\u7684 MAP@k\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: mapk ( y_true , y_pred , k = 1 ) Out [ X ]: 0.3333333333333333 In [ X ]: mapk ( y_true , y_pred , k = 2 ) Out [ X ]: 0.375 In [ X ]: mapk ( y_true , y_pred , k = 3 ) Out [ X ]: 0.3611111111111111 In [ X ]: mapk ( y_true , y_pred , k = 4 ) Out [ X ]: 0.34722222222222215 P@k\u3001AP@k \u548c MAP@k \u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u5176\u4e2d 1 \u4e3a\u6700\u4f73\u3002 \u8bf7\u6ce8\u610f\uff0c\u6709\u65f6\u60a8\u53ef\u80fd\u4f1a\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230 P@k \u548c AP@k \u7684\u4e0d\u540c\u5b9e\u73b0\u65b9\u5f0f\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5176\u4e2d\u4e00\u79cd\u5b9e\u73b0\u65b9\u5f0f\u3002 import numpy as np def apk ( actual , predicted , k = 10 ): # \u82e5\u9884\u6d4b\u6807\u7b7e\u957f\u5ea6\u5927\u4e8ek if len ( predicted ) > k : # \u53d6\u524dk\u4e2a\u6807\u7b7e predicted = predicted [: k ] score = 0.0 num_hits = 0.0 for i , p in enumerate ( predicted ): if p in actual and p not in predicted [: i ]: num_hits += 1.0 score += num_hits / ( i + 1.0 ) if not actual : return 0.0 return score / min ( len ( actual ), k ) \u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u662f AP@k 
\u7684\u53e6\u4e00\u4e2a\u7248\u672c\uff0c\u5176\u4e2d\u987a\u5e8f\u5f88\u91cd\u8981\uff0c\u6211\u4eec\u8981\u6743\u8861\u9884\u6d4b\u7ed3\u679c\u3002\u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u7684\u7ed3\u679c\u4e0e\u6211\u7684\u4ecb\u7ecd\u7565\u6709\u4e0d\u540c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6765\u770b\u770b \u591a\u6807\u7b7e\u5206\u7c7b\u7684\u5bf9\u6570\u635f\u5931 \u3002\u8fd9\u5f88\u5bb9\u6613\u3002\u60a8\u53ef\u4ee5\u5c06\u76ee\u6807\u8f6c\u6362\u4e3a\u4e8c\u5143\u5206\u7c7b\uff0c\u7136\u540e\u5bf9\u6bcf\u4e00\u5217\u4f7f\u7528\u5bf9\u6570\u635f\u5931\u3002\u6700\u540e\uff0c\u4f60\u53ef\u4ee5\u6c42\u51fa\u6bcf\u5217\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u8fd9\u4e5f\u88ab\u79f0\u4e3a\u5e73\u5747\u5217\u5bf9\u6570\u635f\u5931\u3002\u5f53\u7136\uff0c\u8fd8\u6709\u5176\u4ed6\u65b9\u6cd5\u53ef\u4ee5\u5b9e\u73b0\u8fd9\u4e00\u70b9\uff0c\u4f60\u5e94\u8be5\u5728\u9047\u5230\u65f6\u52a0\u4ee5\u63a2\u7d22\u3002 \u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u8bf4\u5df2\u7ecf\u638c\u63e1\u4e86\u6240\u6709\u4e8c\u5143\u5206\u7c7b\u3001\u591a\u7c7b\u5206\u7c7b\u548c\u591a\u6807\u7b7e\u5206\u7c7b\u6307\u6807\uff0c\u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u6307\u6807\u3002 \u56de\u5f52\u4e2d\u6700\u5e38\u89c1\u7684\u6307\u6807\u662f \u8bef\u5dee\uff08Error\uff09 \u3002\u8bef\u5dee\u5f88\u7b80\u5355\uff0c\u4e5f\u5f88\u5bb9\u6613\u7406\u89e3\u3002 \\[ Error = True\\ Value - Predicted\\ Value \\] \u7edd\u5bf9\u8bef\u5dee\uff08Absolute error\uff09 \u53ea\u662f\u4e0a\u8ff0\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\u3002 \\[ Absolute\\ Error = Abs(True\\ Value - Predicted\\ Value) \\] \u63a5\u4e0b\u6765\u6211\u4eec\u8ba8\u8bba \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\uff08MAE\uff09 \u3002\u5b83\u53ea\u662f\u6240\u6709\u7edd\u5bf9\u8bef\u5dee\u7684\u5e73\u5747\u503c\u3002 import numpy as np def mean_absolute_error ( y_true , y_pred ): #\u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # 
\u7d2f\u52a0\u7edd\u5bf9\u8bef\u5dee error += np . abs ( yt - yp ) # \u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee return error / len ( y_true ) \u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5e73\u65b9\u8bef\u5dee\u548c \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u3002 \\[ Squared\\ Error = (True\\ Value - Predicted\\ Value)^2 \\] \u5747\u65b9\u8bef\u5dee\uff08MSE\uff09\u7684\u8ba1\u7b97\u65b9\u5f0f\u5982\u4e0b def mean_squared_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u7d2f\u52a0\u8bef\u5dee\u5e73\u65b9\u548c error += ( yt - yp ) ** 2 # \u8ba1\u7b97\u5747\u65b9\u8bef\u5dee return error / len ( y_true ) MSE \u548c RMSE\uff08\u5747\u65b9\u6839\u8bef\u5dee\uff09 \u662f\u8bc4\u4f30\u56de\u5f52\u6a21\u578b\u6700\u5e38\u7528\u7684\u6307\u6807\u3002 \\[ RMSE = SQRT(MSE) \\] \u540c\u4e00\u7c7b\u8bef\u5dee\u7684\u53e6\u4e00\u79cd\u7c7b\u578b\u662f \u5e73\u65b9\u5bf9\u6570\u8bef\u5dee \u3002\u6709\u4eba\u79f0\u5176\u4e3a SLE \uff0c\u5f53\u6211\u4eec\u53d6\u6240\u6709\u6837\u672c\u4e2d\u8fd9\u4e00\u8bef\u5dee\u7684\u5e73\u5747\u503c\u65f6\uff0c\u5b83\u88ab\u79f0\u4e3a MSLE\uff08\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee\uff09\uff0c\u5b9e\u73b0\u65b9\u6cd5\u5982\u4e0b\u3002 import numpy as np def mean_squared_log_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee error += ( np . log ( 1 + yt ) - np . 
log ( 1 + yp )) ** 2 # \u8ba1\u7b97\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee return error / len ( y_true ) \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \u53ea\u662f\u5176\u5e73\u65b9\u6839\u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a RMSLE \u3002 \u7136\u540e\u662f\u767e\u5206\u6bd4\u8bef\u5dee\uff1a \\[ Percentage\\ Error = (( True\\ Value \u2013 Predicted\\ Value ) / True\\ Value ) \\times 100 \\] \u540c\u6837\u53ef\u4ee5\u8f6c\u6362\u4e3a\u6240\u6709\u6837\u672c\u7684\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee\u3002 def mean_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u767e\u5206\u6bd4\u8bef\u5dee error += ( yt - yp ) / yt # \u8fd4\u56de\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u7edd\u5bf9\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\uff08\u4e5f\u662f\u66f4\u5e38\u89c1\u7684\u7248\u672c\uff09\u88ab\u79f0\u4e3a \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee\u6216 MAPE \u3002 import numpy as np def mean_abs_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee error += np . 
abs ( yt - yp ) / yt #\u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u56de\u5f52\u7684\u6700\u5927\u4f18\u70b9\u662f\uff0c\u53ea\u6709\u51e0\u4e2a\u6700\u5e38\u7528\u7684\u6307\u6807\uff0c\u51e0\u4e4e\u53ef\u4ee5\u5e94\u7528\u4e8e\u6240\u6709\u56de\u5f52\u95ee\u9898\u3002\u4e0e\u5206\u7c7b\u6307\u6807\u76f8\u6bd4\uff0c\u56de\u5f52\u6307\u6807\u66f4\u5bb9\u6613\u7406\u89e3\u3002 \u8ba9\u6211\u4eec\u6765\u8c08\u8c08\u53e6\u4e00\u4e2a\u56de\u5f52\u6307\u6807 \\(R^2\\) \uff08R \u65b9\uff09\uff0c\u4e5f\u79f0\u4e3a \u5224\u5b9a\u7cfb\u6570 \u3002 \u7b80\u5355\u5730\u8bf4\uff0cR \u65b9\u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u3002R \u65b9\u63a5\u8fd1 1.0 \u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u76f8\u5f53\u597d\uff0c\u800c\u63a5\u8fd1 0 \u5219\u8868\u793a\u6a21\u578b\u4e0d\u662f\u90a3\u4e48\u597d\u3002\u5f53\u6a21\u578b\u53ea\u662f\u505a\u51fa\u8352\u8c2c\u7684\u9884\u6d4b\u65f6\uff0cR \u65b9\u4e5f\u53ef\u80fd\u662f\u8d1f\u503c\u3002 R \u65b9\u7684\u8ba1\u7b97\u516c\u5f0f\u5982\u4e0b\u6240\u793a\uff0c\u4f46 Python \u7684\u5b9e\u73b0\u603b\u662f\u80fd\u8ba9\u4e00\u5207\u66f4\u52a0\u6e05\u6670\u3002 \\[ R^2 = 1 - \\frac{\\sum^{N}_{i=1}(y_{t_i}-y_{p_i})^2}{\\sum^{N}_{i=1}(y_{t_i} - y_{t_{mean}})^2} \\] import numpy as np def r2 ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u5747\u771f\u5b9e\u503c mean_true_value = np . 
mean ( y_true ) # \u521d\u59cb\u5316\u5e73\u65b9\u8bef\u5dee numerator = 0 denominator = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): numerator += ( yt - yp ) ** 2 denominator += ( yt - mean_true_value ) ** 2 ratio = numerator / denominator # \u8ba1\u7b97R\u65b9 return 1 - ratio \u8fd8\u6709\u66f4\u591a\u7684\u8bc4\u4ef7\u6307\u6807\uff0c\u8fd9\u4e2a\u6e05\u5355\u6c38\u8fdc\u4e5f\u5217\u4e0d\u5b8c\u3002\u6211\u53ef\u4ee5\u5199\u4e00\u672c\u4e66\uff0c\u53ea\u4ecb\u7ecd\u4e0d\u540c\u7684\u8bc4\u4ef7\u6307\u6807\u3002\u4e5f\u8bb8\u6211\u4f1a\u7684\u3002\u73b0\u5728\uff0c\u8fd9\u4e9b\u8bc4\u4f30\u6307\u6807\u51e0\u4e4e\u53ef\u4ee5\u6ee1\u8db3\u4f60\u60f3\u5c1d\u8bd5\u89e3\u51b3\u7684\u6240\u6709\u95ee\u9898\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ee5\u6700\u76f4\u63a5\u7684\u65b9\u5f0f\u5b9e\u73b0\u4e86\u8fd9\u4e9b\u6307\u6807\uff0c\u8fd9\u610f\u5473\u7740\u5b83\u4eec\u4e0d\u591f\u9ad8\u6548\u3002\u4f60\u53ef\u4ee5\u901a\u8fc7\u6b63\u786e\u4f7f\u7528 numpy \u4ee5\u975e\u5e38\u9ad8\u6548\u7684\u65b9\u5f0f\u5b9e\u73b0\u5176\u4e2d\u5927\u90e8\u5206\u6307\u6807\u3002\u4f8b\u5982\uff0c\u770b\u770b\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\u7684\u5b9e\u73b0\uff0c\u4e0d\u9700\u8981\u4efb\u4f55\u5faa\u73af\u3002 import numpy as np def mae_np ( y_true , y_pred ): return np . mean ( np . 
abs ( y_true - y_pred )) \u6211\u672c\u53ef\u4ee5\u7528\u8fd9\u79cd\u65b9\u6cd5\u5b9e\u73b0\u6240\u6709\u6307\u6807\uff0c\u4f46\u4e3a\u4e86\u5b66\u4e60\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b\u5e95\u5c42\u5b9e\u73b0\u3002\u4e00\u65e6\u4f60\u5b66\u4f1a\u4e86\u7eaf python \u7684\u5e95\u5c42\u5b9e\u73b0\uff0c\u5e76\u4e14\u4e0d\u4f7f\u7528\u5927\u91cf numpy\uff0c\u4f60\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5c06\u5176\u8f6c\u6362\u4e3a numpy\uff0c\u5e76\u4f7f\u5176\u53d8\u5f97\u66f4\u5feb\u3002 \u7136\u540e\u662f\u4e00\u4e9b\u9ad8\u7ea7\u5ea6\u91cf\u3002 \u5176\u4e2d\u4e00\u4e2a\u5e94\u7528\u76f8\u5f53\u5e7f\u6cdb\u7684\u6307\u6807\u662f \u4e8c\u6b21\u52a0\u6743\u5361\u5e15 \uff0c\u4e5f\u79f0\u4e3a QWK \u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a\u79d1\u6069\u5361\u5e15\u3002 QWK \u8861\u91cf\u4e24\u4e2a \"\u8bc4\u5206 \"\u4e4b\u95f4\u7684 \"\u4e00\u81f4\u6027\"\u3002\u8bc4\u5206\u53ef\u4ee5\u662f 0 \u5230 N \u4e4b\u95f4\u7684\u4efb\u4f55\u5b9e\u6570\uff0c\u9884\u6d4b\u4e5f\u5728\u540c\u4e00\u8303\u56f4\u5185\u3002\u4e00\u81f4\u6027\u53ef\u4ee5\u5b9a\u4e49\u4e3a\u8fd9\u4e9b\u8bc4\u7ea7\u4e4b\u95f4\u7684\u63a5\u8fd1\u7a0b\u5ea6\u3002\u56e0\u6b64\uff0c\u5b83\u9002\u7528\u4e8e\u6709 N \u4e2a\u4e0d\u540c\u7c7b\u522b\u7684\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u4e00\u81f4\u5ea6\u9ad8\uff0c\u5206\u6570\u5c31\u66f4\u63a5\u8fd1 1.0\u3002Cohen's kappa \u5728 scikit-learn \u4e2d\u6709\u5f88\u597d\u7684\u5b9e\u73b0\uff0c\u5173\u4e8e\u8be5\u6307\u6807\u7684\u8be6\u7ec6\u8ba8\u8bba\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3 ] In [ X ]: y_pred = [ 2 , 1 , 3 , 1 , 2 , 3 , 3 , 1 , 2 ] In [ X ]: metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) Out [ X ]: 0.33333333333333337 In [ X ]: metrics . 
accuracy_score ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 \u60a8\u53ef\u4ee5\u770b\u5230\uff0c\u5c3d\u7ba1\u51c6\u786e\u5ea6\u5f88\u9ad8\uff0c\u4f46 QWK \u5374\u5f88\u4f4e\u3002QWK \u5927\u4e8e 0.85 \u5373\u4e3a\u975e\u5e38\u597d\uff01 \u4e00\u4e2a\u91cd\u8981\u7684\u6307\u6807\u662f \u9a6c\u4fee\u76f8\u5173\u7cfb\u6570\uff08MCC\uff09 \u30021 \u4ee3\u8868\u5b8c\u7f8e\u9884\u6d4b\uff0c-1 \u4ee3\u8868\u4e0d\u5b8c\u7f8e\u9884\u6d4b\uff0c0 \u4ee3\u8868\u968f\u673a\u9884\u6d4b\u3002MCC \u7684\u8ba1\u7b97\u516c\u5f0f\u975e\u5e38\u7b80\u5355\u3002 \\[ MCC = \\frac{TP \\times TN - FP \\times FN}{\\sqrt{(TP + FP) \\times (FN + TN) \\times (FP + TN) \\times (TP + FN)}} \\] \u6211\u4eec\u770b\u5230\uff0cMCC \u8003\u8651\u4e86 TP\u3001FP\u3001TN \u548c FN\uff0c\u56e0\u6b64\u53ef\u7528\u4e8e\u5904\u7406\u7c7b\u504f\u659c\u7684\u95ee\u9898\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u7684\u65b9\u6cd5\u5728 python \u4e2d\u5feb\u901f\u5b9e\u73b0\u5b83\u3002 def mcc ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) numerator = ( tp * tn ) - ( fp * fn ) denominator = ( ( tp + fp ) * ( fn + tn ) * ( fp + tn ) * ( tp + fn ) ) denominator = denominator ** 0.5 return numerator / denominator \u8fd9\u4e9b\u6307\u6807\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5165\u95e8\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002 
\u9700\u8981\u6ce8\u610f\u7684\u4e00\u70b9\u662f\uff0c\u5728\u8bc4\u4f30\u975e\u76d1\u7763\u65b9\u6cd5\uff08\u4f8b\u5982\u67d0\u79cd\u805a\u7c7b\uff09\u65f6\uff0c\u6700\u597d\u521b\u5efa\u6216\u624b\u52a8\u6807\u8bb0\u6d4b\u8bd5\u96c6\uff0c\u5e76\u5c06\u5176\u4e0e\u5efa\u6a21\u90e8\u5206\u7684\u6240\u6709\u5185\u5bb9\u5206\u5f00\u3002\u5b8c\u6210\u805a\u7c7b\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u4efb\u4f55\u4e00\u79cd\u76d1\u7763\u5b66\u4e60\u6307\u6807\u6765\u8bc4\u4f30\u6d4b\u8bd5\u96c6\u7684\u6027\u80fd\u4e86\u3002 \u4e00\u65e6\u6211\u4eec\u4e86\u89e3\u4e86\u7279\u5b9a\u95ee\u9898\u5e94\u8be5\u4f7f\u7528\u4ec0\u4e48\u6307\u6807\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f00\u59cb\u66f4\u6df1\u5165\u5730\u7814\u7a76\u6211\u4eec\u7684\u6a21\u578b\uff0c\u4ee5\u6c42\u6539\u8fdb\u3002","title":"\u8bc4\u4f30\u6307\u6807"},{"location":"%E8%B6%85%E5%8F%82%E6%95%B0%E4%BC%98%E5%8C%96/","text":"\u8d85\u53c2\u6570\u4f18\u5316 \u6709\u4e86\u4f18\u79c0\u7684\u6a21\u578b\uff0c\u5c31\u6709\u4e86\u4f18\u5316\u8d85\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u5f97\u5206\u6a21\u578b\u7684\u96be\u9898\u3002\u90a3\u4e48\uff0c\u4ec0\u4e48\u662f\u8d85\u53c2\u6570\u4f18\u5316\u5462\uff1f\u5047\u8bbe\u60a8\u7684\u673a\u5668\u5b66\u4e60\u9879\u76ee\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u6d41\u7a0b\u3002\u6709\u4e00\u4e2a\u6570\u636e\u96c6\uff0c\u4f60\u76f4\u63a5\u5e94\u7528\u4e00\u4e2a\u6a21\u578b\uff0c\u7136\u540e\u5f97\u5230\u7ed3\u679c\u3002\u6a21\u578b\u5728\u8fd9\u91cc\u7684\u53c2\u6570\u88ab\u79f0\u4e3a\u8d85\u53c2\u6570\uff0c\u5373\u63a7\u5236\u6a21\u578b\u8bad\u7ec3/\u62df\u5408\u8fc7\u7a0b\u7684\u53c2\u6570\u3002\u5982\u679c\u6211\u4eec\u7528 SGD 
\u8bad\u7ec3\u7ebf\u6027\u56de\u5f52\uff0c\u6a21\u578b\u7684\u53c2\u6570\u662f\u659c\u7387\u548c\u504f\u5dee\uff0c\u8d85\u53c2\u6570\u662f\u5b66\u4e60\u7387\u3002\u4f60\u4f1a\u53d1\u73b0\u6211\u5728\u672c\u7ae0\u548c\u672c\u4e66\u4e2d\u4ea4\u66ff\u4f7f\u7528\u8fd9\u4e9b\u672f\u8bed\u3002\u5047\u8bbe\u6a21\u578b\u4e2d\u6709\u4e09\u4e2a\u53c2\u6570 a\u3001b\u3001c\uff0c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u90fd\u53ef\u4ee5\u662f 1 \u5230 10 \u4e4b\u95f4\u7684\u6574\u6570\u3002\u8fd9\u4e9b\u53c2\u6570\u7684 \"\u6b63\u786e \"\u7ec4\u5408\u5c06\u4e3a\u60a8\u63d0\u4f9b\u6700\u4f73\u7ed3\u679c\u3002\u56e0\u6b64\uff0c\u8fd9\u5c31\u6709\u70b9\u50cf\u4e00\u4e2a\u88c5\u6709\u4e09\u62e8\u5bc6\u7801\u9501\u7684\u624b\u63d0\u7bb1\u3002\u4e0d\u8fc7\uff0c\u4e09\u62e8\u5bc6\u7801\u9501\u53ea\u6709\u4e00\u4e2a\u6b63\u786e\u7b54\u6848\u3002\u800c\u6a21\u578b\u6709\u5f88\u591a\u6b63\u786e\u7b54\u6848\u3002\u90a3\u4e48\uff0c\u5982\u4f55\u627e\u5230\u6700\u4f73\u53c2\u6570\u5462\uff1f\u4e00\u79cd\u65b9\u6cd5\u662f\u5bf9\u6240\u6709\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\uff0c\u770b\u54ea\u79cd\u7ec4\u5408\u80fd\u63d0\u9ad8\u6307\u6807\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u521d\u59cb\u5316\u6700\u4f73\u51c6\u786e\u5ea6 best_accuracy = 0 # \u521d\u59cb\u5316\u6700\u4f73\u53c2\u6570\u7684\u5b57\u5178 best_parameters = { \"a\" : 0 , \"b\" : 0 , \"c\" : 0 } # \u5faa\u73af\u904d\u5386 a \u7684\u53d6\u503c\u8303\u56f4 1~10 for a in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 b \u7684\u53d6\u503c\u8303\u56f4 1~10 for b in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 c \u7684\u53d6\u503c\u8303\u56f4 1~10 for c in range ( 1 , 11 ): # \u521b\u5efa\u6a21\u578b\uff0c\u4f7f\u7528 a\u3001b\u3001c \u53c2\u6570 model = MODEL ( a , b , c ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408\u6a21\u578b model . fit ( training_data ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u9a8c\u8bc1\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . 
predict ( validation_data ) # \u8ba1\u7b97\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 accuracy = metrics . accuracy_score ( targets , preds ) # \u5982\u679c\u5f53\u524d\u51c6\u786e\u5ea6\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u5219\u66f4\u65b0\u6700\u4f73\u51c6\u786e\u5ea6\u548c\u6700\u4f73\u53c2\u6570 if accuracy > best_accuracy : best_accuracy = accuracy best_parameters [ \"a\" ] = a best_parameters [ \"b\" ] = b best_parameters [ \"c\" ] = c \u5728\u4e0a\u8ff0\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4ece 1 \u5230 10 \u5bf9\u6240\u6709\u53c2\u6570\u8fdb\u884c\u4e86\u62df\u5408\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u603b\u5171\u8981\u5bf9\u6a21\u578b\u8fdb\u884c 1000 \u6b21\uff0810 x 10 x 10\uff09\u62df\u5408\u3002\u8fd9\u53ef\u80fd\u4f1a\u5f88\u6602\u8d35\uff0c\u56e0\u4e3a\u6a21\u578b\u7684\u8bad\u7ec3\u9700\u8981\u5f88\u957f\u65f6\u95f4\u3002\u4e0d\u8fc7\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\u5e94\u8be5\u6ca1\u95ee\u9898\uff0c\u4f46\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\uff0c\u5e76\u4e0d\u662f\u53ea\u6709\u4e09\u4e2a\u53c2\u6570\uff0c\u6bcf\u4e2a\u53c2\u6570\u4e5f\u4e0d\u662f\u53ea\u6709\u5341\u4e2a\u503c\u3002 \u5927\u591a\u6570\u6a21\u578b\u53c2\u6570\u90fd\u662f\u5b9e\u6570\uff0c\u4e0d\u540c\u53c2\u6570\u7684\u7ec4\u5408\u53ef\u4ee5\u662f\u65e0\u9650\u7684\u3002 \u8ba9\u6211\u4eec\u770b\u770b scikit-learn \u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002 RandomForestClassifier ( n_estimators = 100 , criterion = 'gini' , max_depth = None , min_samples_split = 2 , min_samples_leaf = 1 , min_weight_fraction_leaf = 0.0 , max_features = 'auto' , max_leaf_nodes = None , min_impurity_decrease = 0.0 , min_impurity_split = None , bootstrap = True , oob_score = False , n_jobs = None , random_state = None , verbose = 0 , warm_start = False , class_weight = None , ccp_alpha = 0.0 , max_samples = None , ) \u6709 19 
\u4e2a\u53c2\u6570\uff0c\u800c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u7684\u6240\u6709\u7ec4\u5408\uff0c\u4ee5\u53ca\u5b83\u4eec\u53ef\u4ee5\u627f\u62c5\u7684\u6240\u6709\u503c\uff0c\u90fd\u5c06\u662f\u65e0\u7a77\u65e0\u5c3d\u7684\u3002\u901a\u5e38\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u6ca1\u6709\u8db3\u591f\u7684\u8d44\u6e90\u548c\u65f6\u95f4\u6765\u505a\u8fd9\u4ef6\u4e8b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6307\u5b9a\u4e86\u4e00\u4e2a\u53c2\u6570\u7f51\u683c\u3002\u5728\u8fd9\u4e2a\u7f51\u683c\u4e0a\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408\u7684\u641c\u7d22\u79f0\u4e3a\u7f51\u683c\u641c\u7d22\u3002\u6211\u4eec\u53ef\u4ee5\u8bf4\uff0cn_estimators \u53ef\u4ee5\u662f 100\u3001200\u3001250\u3001300\u3001400\u3001500\uff1bmax_depth \u53ef\u4ee5\u662f 1\u30012\u30015\u30017\u300111\u300115\uff1bcriterion \u53ef\u4ee5\u662f gini \u6216 entropy\u3002\u8fd9\u4e9b\u53c2\u6570\u770b\u8d77\u6765\u5e76\u4e0d\u591a\uff0c\u4f46\u5982\u679c\u6570\u636e\u96c6\u8fc7\u5927\uff0c\u8ba1\u7b97\u8d77\u6765\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u53ef\u4ee5\u50cf\u4e4b\u524d\u4e00\u6837\u521b\u5efa\u4e09\u4e2a for \u5faa\u73af\uff0c\u5e76\u5728\u9a8c\u8bc1\u96c6\u4e0a\u8ba1\u7b97\u5f97\u5206\uff0c\u8fd9\u6837\u5c31\u80fd\u5b9e\u73b0\u7f51\u683c\u641c\u7d22\u3002\u8fd8\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u8981\u8fdb\u884c k \u6298\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5219\u9700\u8981\u66f4\u591a\u7684\u5faa\u73af\uff0c\u8fd9\u610f\u5473\u7740\u9700\u8981\u66f4\u591a\u7684\u65f6\u95f4\u6765\u627e\u5230\u5b8c\u7f8e\u7684\u53c2\u6570\u3002\u56e0\u6b64\uff0c\u7f51\u683c\u641c\u7d22\u5e76\u4e0d\u6d41\u884c\u3002\u8ba9\u6211\u4eec\u4ee5\u6839\u636e \u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4 \u6570\u636e\u96c6\u4e3a\u4f8b\uff0c\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 \u56fe 1\uff1a\u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4\u6570\u636e\u96c6\u5c55\u793a 
\u8bad\u7ec3\u96c6\u4e2d\u53ea\u6709 2000 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u53ef\u4ee5\u8f7b\u677e\u5730\u4f7f\u7528\u5206\u5c42 kfold \u548c\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5177\u6709\u4e0a\u8ff0\u53c2\u6570\u8303\u56f4\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\uff0c\u5e76\u5728\u4e0b\u9762\u7684\u793a\u4f8b\u4e2d\u4e86\u89e3\u5982\u4f55\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u3002 # rf_grid_search.py import numpy as np import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u5220\u9664 price_range \u5217 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u53d6\u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\uff0c\u4f7f\u7528\u6240\u6709\u53ef\u7528\u7684 CPU \u6838\u5fc3\u8fdb\u884c\u8bad\u7ec3 classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid = { \"n_estimators\" : [ 100 , 200 , 250 , 300 , 400 , 500 ], \"max_depth\" : [ 1 , 2 , 5 , 7 , 11 , 15 ], \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22 model = model_selection . GridSearchCV ( estimator = classifier , param_grid = param_grid , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( f \"Best score: { model . 
best_score_ } \" ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u8fd9\u91cc\u6253\u5370\u4e86\u5f88\u591a\u5185\u5bb9\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u6700\u540e\u51e0\u884c\u3002 [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.895 , total = 1.0 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.890 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.910 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.880 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... 
[ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.870 , total = 1.1 s [ Parallel ( n_jobs = 1 )]: Done 360 out of 360 | elapsed : 3.7 min finished Best score : 0.889 Best parameters set : criterion : 'entropy' max_depth : 15 n_estimators : 500 \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c5\u6298\u4ea4\u53c9\u68c0\u9a8c\u6700\u4f73\u5f97\u5206\u662f 0.889\uff0c\u6211\u4eec\u7684\u7f51\u683c\u641c\u7d22\u5f97\u5230\u4e86\u6700\u4f73\u53c2\u6570\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u7684\u4e0b\u4e00\u4e2a\u6700\u4f73\u65b9\u6cd5\u662f \u968f\u673a\u641c\u7d22 \u3002\u5728\u968f\u673a\u641c\u7d22\u4e2d\uff0c\u6211\u4eec\u968f\u673a\u9009\u62e9\u4e00\u4e2a\u53c2\u6570\u7ec4\u5408\uff0c\u7136\u540e\u8ba1\u7b97\u4ea4\u53c9\u9a8c\u8bc1\u5f97\u5206\u3002\u8fd9\u91cc\u6d88\u8017\u7684\u65f6\u95f4\u6bd4\u7f51\u683c\u641c\u7d22\u5c11\uff0c\u56e0\u4e3a\u6211\u4eec\u4e0d\u5bf9\u6240\u6709\u4e0d\u540c\u7684\u53c2\u6570\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\u3002\u6211\u4eec\u9009\u62e9\u8981\u5bf9\u6a21\u578b\u8fdb\u884c\u591a\u5c11\u6b21\u8bc4\u4f30\uff0c\u8fd9\u5c31\u51b3\u5b9a\u4e86\u641c\u7d22\u6240\u9700\u7684\u65f6\u95f4\u3002\u4ee3\u7801\u4e0e\u4e0a\u9762\u7684\u5dee\u522b\u4e0d\u5927\u3002\u9664 GridSearchCV \u5916\uff0c\u6211\u4eec\u4f7f\u7528 RandomizedSearchCV\u3002 if __name__ == \"__main__\" : classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u66f4\u6539\u641c\u7d22\u7a7a\u95f4 param_grid = { \"n_estimators\" : np . arange ( 100 , 1500 , 100 ), \"max_depth\" : np . arange ( 1 , 31 ), \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u968f\u673a\u53c2\u6570\u641c\u7d22 model = model_selection . RandomizedSearchCV ( estimator = classifier , param_distributions = param_grid , n_iter = 20 , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . 
fit ( X , y ) print ( f \"Best score: { model . best_score_ } \" ) print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u6211\u4eec\u66f4\u6539\u4e86\u968f\u673a\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c\uff0c\u7ed3\u679c\u4f3c\u4e4e\u6709\u4e86\u4e9b\u8bb8\u6539\u8fdb\u3002 Best score : 0.8905 Best parameters set : criterion : entropy max_depth : 25 n_estimators : 300 \u5982\u679c\u8fed\u4ee3\u6b21\u6570\u8f83\u5c11\uff0c\u968f\u673a\u641c\u7d22\u6bd4\u7f51\u683c\u641c\u7d22\u66f4\u5feb\u3002\u4f7f\u7528\u8fd9\u4e24\u79cd\u65b9\u6cd5\uff0c\u4f60\u53ef\u4ee5\u4e3a\u5404\u79cd\u6a21\u578b\u627e\u5230\u6700\u4f18\u53c2\u6570\uff0c\u53ea\u8981\u5b83\u4eec\u6709\u62df\u5408\u548c\u9884\u6d4b\u529f\u80fd\uff0c\u8fd9\u4e5f\u662f scikit-learn \u7684\u6807\u51c6\u3002\u6709\u65f6\uff0c\u4f60\u53ef\u80fd\u60f3\u4f7f\u7528\u7ba1\u9053\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u7531\u4e24\u5217\u6587\u672c\u7ec4\u6210\uff0c\u4f60\u9700\u8981\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u7c7b\u522b\u3002\u8ba9\u6211\u4eec\u5047\u8bbe\u4f60\u9009\u62e9\u7684\u7ba1\u9053\u662f\u9996\u5148\u4ee5\u534a\u76d1\u7763\u7684\u65b9\u5f0f\u5e94\u7528 tf-idf\uff0c\u7136\u540e\u4f7f\u7528 SVD \u548c SVM \u5206\u7c7b\u5668\u3002\u73b0\u5728\u7684\u95ee\u9898\u662f\uff0c\u6211\u4eec\u5fc5\u987b\u9009\u62e9 SVD \u7684\u6210\u5206\uff0c\u8fd8\u9700\u8981\u8c03\u6574 SVM \u7684\u53c2\u6570\u3002\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u5c55\u793a\u4e86\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 import numpy as np import pandas as pd from sklearn import metrics from sklearn import model_selection from sklearn import pipeline from sklearn.decomposition import TruncatedSVD from 
sklearn.feature_extraction.text import TfidfVectorizer from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC # \u8ba1\u7b97\u52a0\u6743\u4e8c\u6b21 Kappa \u5206\u6570 def quadratic_weighted_kappa ( y_true , y_pred ): return metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) if __name__ == '__main__' : # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( '../input/train.csv' ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u63d0\u53d6 id \u5217\u7684\u503c\uff0c\u5e76\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u7c7b\u578b\uff0c\u5b58\u50a8\u5728\u53d8\u91cf idx \u4e2d idx = test . id . values . astype ( int ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 train = train . drop ( 'id' , axis = 1 ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 test = test . drop ( 'id' , axis = 1 ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf 'relevance' \uff0c\u5b58\u50a8\u5728\u53d8\u91cf y \u4e2d y = train . relevance . values # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 traindata \u4e2d traindata = list ( train . apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 testdata \u4e2d testdata = list ( test . 
apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u521b\u5efa\u4e00\u4e2a TfidfVectorizer \u5bf9\u8c61 tfv\uff0c\u7528\u4e8e\u5c06\u6587\u672c\u6570\u636e\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv = TfidfVectorizer ( min_df = 3 , max_features = None , strip_accents = 'unicode' , analyzer = 'word' , token_pattern = r '\\w{1,}' , ngram_range = ( 1 , 3 ), use_idf = 1 , smooth_idf = 1 , sublinear_tf = 1 , stop_words = 'english' ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408 TfidfVectorizer\uff0c\u5c06\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv . fit ( traindata ) # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X X = tfv . transform ( traindata ) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X_test X_test = tfv . transform ( testdata ) # \u521b\u5efa TruncatedSVD \u5bf9\u8c61 svd\uff0c\u7528\u4e8e\u8fdb\u884c\u5947\u5f02\u503c\u5206\u89e3 svd = TruncatedSVD () # \u521b\u5efa StandardScaler \u5bf9\u8c61 scl\uff0c\u7528\u4e8e\u8fdb\u884c\u7279\u5f81\u7f29\u653e scl = StandardScaler () # \u521b\u5efa\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668\u5bf9\u8c61 svm_model svm_model = SVC () # \u521b\u5efa\u673a\u5668\u5b66\u4e60\u7ba1\u9053 clf\uff0c\u5305\u542b\u5947\u5f02\u503c\u5206\u89e3\u3001\u7279\u5f81\u7f29\u653e\u548c\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668 clf = pipeline . Pipeline ( [ ( 'svd' , svd ), ( 'scl' , scl ), ( 'svm' , svm_model ) ] ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid param_grid = { 'svd__n_components' : [ 200 , 300 ], 'svm__C' : [ 10 , 12 ] } # \u521b\u5efa\u81ea\u5b9a\u4e49\u7684\u8bc4\u5206\u51fd\u6570 kappa_scorer\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd kappa_scorer = metrics . 
make_scorer ( quadratic_weighted_kappa , greater_is_better = True ) # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model = model_selection . GridSearchCV ( estimator = clf , param_grid = param_grid , scoring = kappa_scorer , verbose = 10 , n_jobs =- 1 , refit = True , cv = 5 ) # \u4f7f\u7528 GridSearchCV \u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( \"Best score: %0.3f \" % model . best_score_ ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( \" \\t %s : %r \" % ( param_name , best_parameters [ param_name ])) # \u83b7\u53d6\u6700\u4f73\u6a21\u578b best_model = model . best_estimator_ best_model . fit ( X , y ) # \u4f7f\u7528\u6700\u4f73\u6a21\u578b\u8fdb\u884c\u9884\u6d4b preds = best_model . predict ( ... 
) \u8fd9\u91cc\u663e\u793a\u7684\u7ba1\u9053\u5305\u62ec SVD\uff08\u5947\u5f02\u503c\u5206\u89e3\uff09\u3001\u6807\u51c6\u7f29\u653e\u548c SVM\uff08\u652f\u6301\u5411\u91cf\u673a\uff09\u6a21\u578b\u3002\u8bf7\u6ce8\u610f\uff0c\u7531\u4e8e\u6ca1\u6709\u8bad\u7ec3\u6570\u636e\uff0c\u60a8\u65e0\u6cd5\u6309\u539f\u6837\u8fd0\u884c\u4e0a\u8ff0\u4ee3\u7801\u3002\u5f53\u6211\u4eec\u8fdb\u5165\u9ad8\u7ea7\u8d85\u53c2\u6570\u4f18\u5316\u6280\u672f\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4e0d\u540c\u7c7b\u578b\u7684 \u6700\u5c0f\u5316\u7b97\u6cd5 \u6765\u7814\u7a76\u51fd\u6570\u7684\u6700\u5c0f\u5316\u3002\u8fd9\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u591a\u79cd\u6700\u5c0f\u5316\u51fd\u6570\u6765\u5b9e\u73b0\uff0c\u5982\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u3001\u5185\u5c14\u5fb7-\u6885\u5fb7\u4f18\u5316\u7b97\u6cd5\u3001\u4f7f\u7528\u8d1d\u53f6\u65af\u6280\u672f\u548c\u9ad8\u65af\u8fc7\u7a0b\u5bfb\u627e\u6700\u4f18\u53c2\u6570\u6216\u4f7f\u7528\u9057\u4f20\u7b97\u6cd5\u3002\u6211\u5c06\u5728 \"\u96c6\u5408\u4e0e\u5806\u53e0\uff08ensembling and stacking\uff09 \"\u4e00\u7ae0\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u548c Nelder-Mead \u7b97\u6cd5\u7684\u5e94\u7528\u3002\u9996\u5148\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u9ad8\u65af\u8fc7\u7a0b\u5982\u4f55\u7528\u4e8e\u8d85\u53c2\u6570\u4f18\u5316\u3002\u8fd9\u7c7b\u7b97\u6cd5\u9700\u8981\u4e00\u4e2a\u53ef\u4ee5\u4f18\u5316\u7684\u51fd\u6570\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u90fd\u662f\u6700\u5c0f\u5316\u8fd9\u4e2a\u51fd\u6570\uff0c\u5c31\u50cf\u6211\u4eec\u6700\u5c0f\u5316\u635f\u5931\u4e00\u6837\u3002 
\u56e0\u6b64\uff0c\u6bd4\u65b9\u8bf4\uff0c\u4f60\u60f3\u627e\u5230\u6700\u4f73\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u663e\u7136\uff0c\u51c6\u786e\u5ea6\u8d8a\u9ad8\u8d8a\u597d\u3002\u73b0\u5728\uff0c\u6211\u4eec\u4e0d\u80fd\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\uff0c\u4f46\u6211\u4eec\u53ef\u4ee5\u5c06\u7cbe\u786e\u5ea6\u4e58\u4ee5-1\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u662f\u5728\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\u7684\u8d1f\u503c\uff0c\u4f46\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u662f\u5728\u6700\u5927\u5316\u7cbe\u786e\u5ea6\u3002 \u5728\u9ad8\u65af\u8fc7\u7a0b\u4e2d\u4f7f\u7528\u8d1d\u53f6\u65af\u4f18\u5316\uff0c\u53ef\u4ee5\u4f7f\u7528 scikit-optimize (skopt) \u5e93\u4e2d\u7684 gp_minimize \u51fd\u6570\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4f7f\u7528\u8be5\u51fd\u6570\u8c03\u6574\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u53c2\u6570\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from skopt import gp_minimize from skopt import space def optimize ( params , param_names , x , y ): # \u5c06\u53c2\u6570\u540d\u79f0\u548c\u5bf9\u5e94\u7684\u503c\u6253\u5305\u6210\u5b57\u5178 params = dict ( zip ( param_names , params )) # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u6a21\u578b\uff0c\u4f7f\u7528\u4f20\u5165\u7684\u53c2\u6570\u914d\u7f6e model = ensemble . RandomForestClassifier ( ** params ) # \u521b\u5efa StratifiedKFold \u4ea4\u53c9\u9a8c\u8bc1\u5bf9\u8c61\uff0c\u5c06\u6570\u636e\u5206\u4e3a 5 \u6298 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u6bcf\u4e2a\u6298\u53e0\u7684\u51c6\u786e\u5ea6\u7684\u5217\u8868 accuracies = [] # \u5faa\u73af\u904d\u5386\u6bcf\u4e2a\u6298\u53e0\u7684\u8bad\u7ec3\u548c\u6d4b\u8bd5\u6570\u636e for idx in kf . 
split ( X = x , y = y ): train_idx , test_idx = idx [ 0 ], idx [ 1 ] xtrain = x [ train_idx ] ytrain = y [ train_idx ] xtest = x [ test_idx ] ytest = y [ test_idx ] # \u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u62df\u5408\u6a21\u578b model . fit ( xtrain , ytrain ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u6d4b\u8bd5\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . predict ( xtest ) # \u8ba1\u7b97\u6298\u53e0\u7684\u51c6\u786e\u5ea6 fold_accuracy = metrics . accuracy_score ( ytest , preds ) accuracies . append ( fold_accuracy ) # \u8fd4\u56de\u5e73\u5747\u51c6\u786e\u5ea6\u7684\u8d1f\u6570\uff08\u56e0\u4e3a skopt \u4f7f\u7528\u8d1f\u6570\u6765\u6700\u5c0f\u5316\u76ee\u6807\u51fd\u6570\uff09 return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u53d6\u7279\u5f81\u77e9\u9635 X\uff08\u53bb\u6389\"price_range\"\u5217\uff09 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u5b9a\u4e49\u8d85\u53c2\u6570\u641c\u7d22\u7a7a\u95f4 param_space param_space = [ space . Integer ( 3 , 15 , name = \"max_depth\" ), space . Integer ( 100 , 1500 , name = \"n_estimators\" ), space . Categorical ([ \"gini\" , \"entropy\" ], name = \"criterion\" ), space . 
Real ( 0.01 , 1 , prior = \"uniform\" , name = \"max_features\" ) ] # \u5b9a\u4e49\u8d85\u53c2\u6570\u7684\u540d\u79f0\u5217\u8868 param_names param_names = [ \"max_depth\" , \"n_estimators\" , \"criterion\" , \"max_features\" ] # \u521b\u5efa\u51fd\u6570 optimization_function\uff0c\u7528\u4e8e\u4f20\u9012\u7ed9 gp_minimize optimization_function = partial ( optimize , param_names = param_names , x = X , y = y ) # \u4f7f\u7528 Bayesian Optimization\uff08\u57fa\u4e8e\u8d1d\u53f6\u65af\u4f18\u5316\uff09\u6765\u641c\u7d22\u6700\u4f73\u8d85\u53c2\u6570 result = gp_minimize ( optimization_function , dimensions = param_space , n_calls = 15 , n_random_starts = 10 , verbose = 10 ) # \u83b7\u53d6\u6700\u4f73\u8d85\u53c2\u6570\u7684\u5b57\u5178 best_params = dict ( zip ( param_names , result . x ) ) # \u6253\u5370\u51fa\u627e\u5230\u7684\u6700\u4f73\u8d85\u53c2\u6570 print ( best_params ) \u8fd9\u540c\u6837\u4f1a\u4ea7\u751f\u5927\u91cf\u8f93\u51fa\uff0c\u6700\u540e\u4e00\u90e8\u5206\u5982\u4e0b\u6240\u793a\u3002 Iteration No : 14 started . Searching for the next optimal point . Iteration No : 14 ended . Search finished for the next optimal point . Time taken : 4.7793 Function value obtained : - 0.9075 Current minimum : - 0.9075 Iteration No : 15 started . Searching for the next optimal point . Iteration No : 15 ended . Search finished for the next optimal point . 
Time taken : 49.4186 Function value obtained : - 0.9075 Current minimum : - 0.9075 { 'max_depth' : 12 , 'n_estimators' : 100 , 'criterion' : 'entropy' , 'max_features' : 1.0 } \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6210\u529f\u7a81\u7834\u4e86 0.90 \u7684\u51c6\u786e\u7387\u3002\u8fd9\u771f\u662f\u592a\u795e\u5947\u4e86\uff01 \u6211\u4eec\u8fd8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u6bb5\u67e5\u770b\uff08\u7ed8\u5236\uff09\u6211\u4eec\u662f\u5982\u4f55\u5b9e\u73b0\u6536\u655b\u7684\u3002 from skopt.plots import plot_convergence plot_convergence ( result ) \u6536\u655b\u56fe\u5982\u56fe 2 \u6240\u793a\u3002 \u56fe 2\uff1a\u968f\u673a\u68ee\u6797\u53c2\u6570\u4f18\u5316\u7684\u6536\u655b\u56fe Scikit- optimize \u5c31\u662f\u8fd9\u6837\u4e00\u4e2a\u5e93\u3002 hyperopt \u4f7f\u7528\u6811\u72b6\u7ed3\u6784\u8d1d\u53f6\u65af\u4f30\u8ba1\u5668\uff08TPE\uff09\u6765\u627e\u5230\u6700\u4f18\u53c2\u6570\u3002\u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u7247\u6bb5\uff0c\u6211\u5728\u4f7f\u7528 hyperopt \u65f6\u5bf9\u4e4b\u524d\u7684\u4ee3\u7801\u505a\u4e86\u6700\u5c0f\u7684\u6539\u52a8\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from hyperopt import hp , fmin , tpe , Trials from hyperopt.pyll.base import scope def optimize ( params , x , y ): model = ensemble . RandomForestClassifier ( ** params ) kf = model_selection . StratifiedKFold ( n_splits = 5 ) ... return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/mobile_train.csv\" ) X = df . drop ( \"price_range\" , axis = 1 ) . values y = df . price_range . values # \u5b9a\u4e49\u641c\u7d22\u7a7a\u95f4\uff08\u6574\u578b\u3001\u6d6e\u70b9\u6570\u578b\u3001\u9009\u62e9\u578b\uff09 param_space = { \"max_depth\" : scope . int ( hp . quniform ( \"max_depth\" , 1 , 15 , 1 )), \"n_estimators\" : scope . int ( hp . 
quniform ( \"n_estimators\" , 100 , 1500 , 1 ) ), \"criterion\" : hp . choice ( \"criterion\" , [ \"gini\" , \"entropy\" ]), \"max_features\" : hp . uniform ( \"max_features\" , 0 , 1 ) } # \u5305\u88c5\u51fd\u6570 optimization_function = partial ( optimize , x = X , y = y ) # \u5f00\u59cb\u8bad\u7ec3 trials = Trials () # \u6700\u5c0f\u5316\u76ee\u6807\u503c hopt = fmin ( fn = optimization_function , space = param_space , algo = tpe . suggest , max_evals = 15 , trials = trials ) #\u6253\u5370\u6700\u4f73\u53c2\u6570 print ( hopt ) \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u4e0e\u4e4b\u524d\u7684\u4ee3\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u4f60\u5fc5\u987b\u4ee5\u4e0d\u540c\u7684\u683c\u5f0f\u5b9a\u4e49\u53c2\u6570\u7a7a\u95f4\uff0c\u8fd8\u9700\u8981\u6539\u53d8\u5b9e\u9645\u4f18\u5316\u90e8\u5206\uff0c\u7528 hyperopt \u4ee3\u66ff gp_minimize\u3002\u7ed3\u679c\u76f8\u5f53\u4e0d\u9519\uff01 \u276f python rf_hyperopt . py 100 %| \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 15 / 15 [ 04 : 38 < 00 : 00 , 18.57 s / trial , best loss : - 0.9095000000000001 ] { 'criterion' : 1 , 'max_depth' : 11.0 , 'max_features' : 0.821163568049807 , 'n_estimators' : 806.0 } \u6211\u4eec\u5f97\u5230\u4e86\u6bd4\u4ee5\u524d\u66f4\u597d\u7684\u51c6\u786e\u5ea6\u548c\u4e00\u7ec4\u53ef\u4ee5\u4f7f\u7528\u7684\u53c2\u6570\u3002\u8bf7\u6ce8\u610f\uff0c\u6700\u7ec8\u7ed3\u679c\u4e2d\u7684\u6807\u51c6\u662f 1\u3002\u8fd9\u610f\u5473\u7740\u9009\u62e9\u4e86 1\uff0c\u5373\u71b5\u3002 \u4e0a\u8ff0\u8c03\u6574\u8d85\u53c2\u6570\u7684\u65b9\u6cd5\u662f\u6700\u5e38\u89c1\u7684\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u6a21\u578b\uff1a\u7ebf\u6027\u56de\u5f52\u3001\u903b\u8f91\u56de\u5f52\u3001\u57fa\u4e8e\u6811\u7684\u65b9\u6cd5\u3001\u68af\u5ea6\u63d0\u5347\u6a21\u578b\uff08\u5982 xgboost\u3001lightgbm\uff09\uff0c\u751a\u81f3\u795e\u7ecf\u7f51\u7edc\uff01 
\u867d\u7136\u8fd9\u4e9b\u65b9\u6cd5\u5df2\u7ecf\u5b58\u5728\uff0c\u4f46\u5b66\u4e60\u65f6\u5fc5\u987b\u4ece\u624b\u52a8\u8c03\u6574\u8d85\u53c2\u6570\u5f00\u59cb\uff0c\u5373\u624b\u5de5\u8c03\u6574\u3002\u624b\u52a8\u8c03\u6574\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5b66\u4e60\u57fa\u7840\u77e5\u8bc6\uff0c\u4f8b\u5982\uff0c\u5728\u68af\u5ea6\u63d0\u5347\u4e2d\uff0c\u5f53\u4f60\u589e\u52a0\u6df1\u5ea6\u65f6\uff0c\u4f60\u5e94\u8be5\u964d\u4f4e\u5b66\u4e60\u7387\u3002\u5982\u679c\u4f7f\u7528\u81ea\u52a8\u5de5\u5177\uff0c\u5c31\u65e0\u6cd5\u5b66\u4e60\u5230\u8fd9\u4e00\u70b9\u3002\u8bf7\u53c2\u8003\u4e0b\u8868\uff0c\u4e86\u89e3\u5e94\u5982\u4f55\u8c03\u6574\u3002RS* \u8868\u793a\u968f\u673a\u641c\u7d22\u5e94\u8be5\u66f4\u597d\u3002 \u4e00\u65e6\u4f60\u80fd\u66f4\u597d\u5730\u624b\u52a8\u8c03\u6574\u53c2\u6570\uff0c\u4f60\u751a\u81f3\u53ef\u80fd\u4e0d\u9700\u8981\u4efb\u4f55\u81ea\u52a8\u8d85\u53c2\u6570\u8c03\u6574\u3002\u521b\u5efa\u5927\u578b\u6a21\u578b\u6216\u5f15\u5165\u5927\u91cf\u7279\u5f81\u65f6\uff0c\u4e5f\u5bb9\u6613\u9020\u6210\u8bad\u7ec3\u6570\u636e\u7684\u8fc7\u5ea6\u62df\u5408\u3002\u4e3a\u907f\u514d\u8fc7\u5ea6\u62df\u5408\uff0c\u9700\u8981\u5728\u8bad\u7ec3\u6570\u636e\u7279\u5f81\u4e2d\u5f15\u5165\u566a\u58f0\u6216\u5bf9\u4ee3\u4ef7\u51fd\u6570\u8fdb\u884c\u60e9\u7f5a\u3002\u8fd9\u79cd\u60e9\u7f5a\u79f0\u4e3a \u6b63\u5219\u5316 \uff0c\u6709\u52a9\u4e8e\u6cdb\u5316\u6a21\u578b\u3002\u5728\u7ebf\u6027\u6a21\u578b\u4e2d\uff0c\u6700\u5e38\u89c1\u7684\u6b63\u5219\u5316\u7c7b\u578b\u662f L1 \u548c L2\u3002L1 \u4e5f\u79f0\u4e3a Lasso \u56de\u5f52\uff0cL2 \u79f0\u4e3a Ridge \u56de\u5f52\u3002\u8bf4\u5230\u795e\u7ecf\u7f51\u7edc\uff0c\u6211\u4eec\u4f1a\u4f7f\u7528dropout\u3001\u6dfb\u52a0\u589e\u5f3a\u3001\u566a\u58f0\u7b49\u65b9\u6cd5\u5bf9\u6a21\u578b\u8fdb\u884c\u6b63\u5219\u5316\u3002\u5229\u7528\u8d85\u53c2\u6570\u4f18\u5316\uff0c\u8fd8\u53ef\u4ee5\u627e\u5230\u6b63\u786e\u7684\u60e9\u7f5a\u65b9\u6cd5\u3002 Model Optimize Range of values Linear Regression - 
fit_intercept - normalize - True/False - True/False Ridge - alpha - fit_intercept - normalize - 0.01, 0.1, 1.0, 10, 100 - True/False - True/False k-neighbors - n_neighbors - p - 2, 4, 8, 16, ... - 2, 3, ... SVM - C - gamma - class_weight - 0.001, 0.01, ...,10, 100, 1000 - 'auto', RS* - 'balanced', None Logistic Regression - Penalyt - C - L1 or L2 - 0.001, 0.01, ..., 10, ..., 100 Lasso - Alpha - Normalize - 0.1, 1.0, 10 - True/False Random Forest - n_estimators - max_depth - min_samples_split - min_samples_leaf - max features - 120, 300, 500, 800, 1200 - 5, 8, 15, 25, 30, None - 1, 2, 5, 10, 15, 100 - log2, sqrt, None XGBoost - eta - gamma - max_depth - min_child_weight - subsample - colsample_bytree - lambda - alpha - 0.01, 0.015, 0.025, 0.05, 0.1 - 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0 - 3, 5, 7, 9, 12, 15, 17, 25 - 1, 3, 5, 7 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.01, 0.1, 1.0, RS - 0, 0.1, 0.5, 1.0, RS","title":"\u8d85\u53c2\u6570\u4f18\u5316"},{"location":"%E8%B6%85%E5%8F%82%E6%95%B0%E4%BC%98%E5%8C%96/#_1","text":"\u6709\u4e86\u4f18\u79c0\u7684\u6a21\u578b\uff0c\u5c31\u6709\u4e86\u4f18\u5316\u8d85\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u5f97\u5206\u6a21\u578b\u7684\u96be\u9898\u3002\u90a3\u4e48\uff0c\u4ec0\u4e48\u662f\u8d85\u53c2\u6570\u4f18\u5316\u5462\uff1f\u5047\u8bbe\u60a8\u7684\u673a\u5668\u5b66\u4e60\u9879\u76ee\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u6d41\u7a0b\u3002\u6709\u4e00\u4e2a\u6570\u636e\u96c6\uff0c\u4f60\u76f4\u63a5\u5e94\u7528\u4e00\u4e2a\u6a21\u578b\uff0c\u7136\u540e\u5f97\u5230\u7ed3\u679c\u3002\u6a21\u578b\u5728\u8fd9\u91cc\u7684\u53c2\u6570\u88ab\u79f0\u4e3a\u8d85\u53c2\u6570\uff0c\u5373\u63a7\u5236\u6a21\u578b\u8bad\u7ec3/\u62df\u5408\u8fc7\u7a0b\u7684\u53c2\u6570\u3002\u5982\u679c\u6211\u4eec\u7528 SGD 
\u8bad\u7ec3\u7ebf\u6027\u56de\u5f52\uff0c\u6a21\u578b\u7684\u53c2\u6570\u662f\u659c\u7387\u548c\u504f\u5dee\uff0c\u8d85\u53c2\u6570\u662f\u5b66\u4e60\u7387\u3002\u4f60\u4f1a\u53d1\u73b0\u6211\u5728\u672c\u7ae0\u548c\u672c\u4e66\u4e2d\u4ea4\u66ff\u4f7f\u7528\u8fd9\u4e9b\u672f\u8bed\u3002\u5047\u8bbe\u6a21\u578b\u4e2d\u6709\u4e09\u4e2a\u53c2\u6570 a\u3001b\u3001c\uff0c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u90fd\u53ef\u4ee5\u662f 1 \u5230 10 \u4e4b\u95f4\u7684\u6574\u6570\u3002\u8fd9\u4e9b\u53c2\u6570\u7684 \"\u6b63\u786e \"\u7ec4\u5408\u5c06\u4e3a\u60a8\u63d0\u4f9b\u6700\u4f73\u7ed3\u679c\u3002\u56e0\u6b64\uff0c\u8fd9\u5c31\u6709\u70b9\u50cf\u4e00\u4e2a\u88c5\u6709\u4e09\u62e8\u5bc6\u7801\u9501\u7684\u624b\u63d0\u7bb1\u3002\u4e0d\u8fc7\uff0c\u4e09\u62e8\u5bc6\u7801\u9501\u53ea\u6709\u4e00\u4e2a\u6b63\u786e\u7b54\u6848\u3002\u800c\u6a21\u578b\u6709\u5f88\u591a\u6b63\u786e\u7b54\u6848\u3002\u90a3\u4e48\uff0c\u5982\u4f55\u627e\u5230\u6700\u4f73\u53c2\u6570\u5462\uff1f\u4e00\u79cd\u65b9\u6cd5\u662f\u5bf9\u6240\u6709\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\uff0c\u770b\u54ea\u79cd\u7ec4\u5408\u80fd\u63d0\u9ad8\u6307\u6807\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u521d\u59cb\u5316\u6700\u4f73\u51c6\u786e\u5ea6 best_accuracy = 0 # \u521d\u59cb\u5316\u6700\u4f73\u53c2\u6570\u7684\u5b57\u5178 best_parameters = { \"a\" : 0 , \"b\" : 0 , \"c\" : 0 } # \u5faa\u73af\u904d\u5386 a \u7684\u53d6\u503c\u8303\u56f4 1~10 for a in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 b \u7684\u53d6\u503c\u8303\u56f4 1~10 for b in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 c \u7684\u53d6\u503c\u8303\u56f4 1~10 for c in range ( 1 , 11 ): # \u521b\u5efa\u6a21\u578b\uff0c\u4f7f\u7528 a\u3001b\u3001c \u53c2\u6570 model = MODEL ( a , b , c ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408\u6a21\u578b model . fit ( training_data ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u9a8c\u8bc1\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . 
predict ( validation_data ) # \u8ba1\u7b97\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 accuracy = metrics . accuracy_score ( targets , preds ) # \u5982\u679c\u5f53\u524d\u51c6\u786e\u5ea6\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u5219\u66f4\u65b0\u6700\u4f73\u51c6\u786e\u5ea6\u548c\u6700\u4f73\u53c2\u6570 if accuracy > best_accuracy : best_accuracy = accuracy best_parameters [ \"a\" ] = a best_parameters [ \"b\" ] = b best_parameters [ \"c\" ] = c \u5728\u4e0a\u8ff0\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4ece 1 \u5230 10 \u5bf9\u6240\u6709\u53c2\u6570\u8fdb\u884c\u4e86\u62df\u5408\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u603b\u5171\u8981\u5bf9\u6a21\u578b\u8fdb\u884c 1000 \u6b21\uff0810 x 10 x 10\uff09\u62df\u5408\u3002\u8fd9\u53ef\u80fd\u4f1a\u5f88\u6602\u8d35\uff0c\u56e0\u4e3a\u6a21\u578b\u7684\u8bad\u7ec3\u9700\u8981\u5f88\u957f\u65f6\u95f4\u3002\u4e0d\u8fc7\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\u5e94\u8be5\u6ca1\u95ee\u9898\uff0c\u4f46\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\uff0c\u5e76\u4e0d\u662f\u53ea\u6709\u4e09\u4e2a\u53c2\u6570\uff0c\u6bcf\u4e2a\u53c2\u6570\u4e5f\u4e0d\u662f\u53ea\u6709\u5341\u4e2a\u503c\u3002 \u5927\u591a\u6570\u6a21\u578b\u53c2\u6570\u90fd\u662f\u5b9e\u6570\uff0c\u4e0d\u540c\u53c2\u6570\u7684\u7ec4\u5408\u53ef\u4ee5\u662f\u65e0\u9650\u7684\u3002 \u8ba9\u6211\u4eec\u770b\u770b scikit-learn \u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002 RandomForestClassifier ( n_estimators = 100 , criterion = 'gini' , max_depth = None , min_samples_split = 2 , min_samples_leaf = 1 , min_weight_fraction_leaf = 0.0 , max_features = 'auto' , max_leaf_nodes = None , min_impurity_decrease = 0.0 , min_impurity_split = None , bootstrap = True , oob_score = False , n_jobs = None , random_state = None , verbose = 0 , warm_start = False , class_weight = None , ccp_alpha = 0.0 , max_samples = None , ) \u6709 19 
\u4e2a\u53c2\u6570\uff0c\u800c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u7684\u6240\u6709\u7ec4\u5408\uff0c\u4ee5\u53ca\u5b83\u4eec\u53ef\u4ee5\u627f\u62c5\u7684\u6240\u6709\u503c\uff0c\u90fd\u5c06\u662f\u65e0\u7a77\u65e0\u5c3d\u7684\u3002\u901a\u5e38\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u6ca1\u6709\u8db3\u591f\u7684\u8d44\u6e90\u548c\u65f6\u95f4\u6765\u505a\u8fd9\u4ef6\u4e8b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6307\u5b9a\u4e86\u4e00\u4e2a\u53c2\u6570\u7f51\u683c\u3002\u5728\u8fd9\u4e2a\u7f51\u683c\u4e0a\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408\u7684\u641c\u7d22\u79f0\u4e3a\u7f51\u683c\u641c\u7d22\u3002\u6211\u4eec\u53ef\u4ee5\u8bf4\uff0cn_estimators \u53ef\u4ee5\u662f 100\u3001200\u3001250\u3001300\u3001400\u3001500\uff1bmax_depth \u53ef\u4ee5\u662f 1\u30012\u30015\u30017\u300111\u300115\uff1bcriterion \u53ef\u4ee5\u662f gini \u6216 entropy\u3002\u8fd9\u4e9b\u53c2\u6570\u770b\u8d77\u6765\u5e76\u4e0d\u591a\uff0c\u4f46\u5982\u679c\u6570\u636e\u96c6\u8fc7\u5927\uff0c\u8ba1\u7b97\u8d77\u6765\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u53ef\u4ee5\u50cf\u4e4b\u524d\u4e00\u6837\u521b\u5efa\u4e09\u4e2a for \u5faa\u73af\uff0c\u5e76\u5728\u9a8c\u8bc1\u96c6\u4e0a\u8ba1\u7b97\u5f97\u5206\uff0c\u8fd9\u6837\u5c31\u80fd\u5b9e\u73b0\u7f51\u683c\u641c\u7d22\u3002\u8fd8\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u8981\u8fdb\u884c k \u6298\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5219\u9700\u8981\u66f4\u591a\u7684\u5faa\u73af\uff0c\u8fd9\u610f\u5473\u7740\u9700\u8981\u66f4\u591a\u7684\u65f6\u95f4\u6765\u627e\u5230\u5b8c\u7f8e\u7684\u53c2\u6570\u3002\u56e0\u6b64\uff0c\u7f51\u683c\u641c\u7d22\u5e76\u4e0d\u6d41\u884c\u3002\u8ba9\u6211\u4eec\u4ee5\u6839\u636e \u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4 \u6570\u636e\u96c6\u4e3a\u4f8b\uff0c\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 \u56fe 1\uff1a\u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4\u6570\u636e\u96c6\u5c55\u793a 
\u8bad\u7ec3\u96c6\u4e2d\u53ea\u6709 2000 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u53ef\u4ee5\u8f7b\u677e\u5730\u4f7f\u7528\u5206\u5c42 kfold \u548c\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5177\u6709\u4e0a\u8ff0\u53c2\u6570\u8303\u56f4\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\uff0c\u5e76\u5728\u4e0b\u9762\u7684\u793a\u4f8b\u4e2d\u4e86\u89e3\u5982\u4f55\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u3002 # rf_grid_search.py import numpy as np import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u5220\u9664 price_range \u5217 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u53d6\u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\uff0c\u4f7f\u7528\u6240\u6709\u53ef\u7528\u7684 CPU \u6838\u5fc3\u8fdb\u884c\u8bad\u7ec3 classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid = { \"n_estimators\" : [ 100 , 200 , 250 , 300 , 400 , 500 ], \"max_depth\" : [ 1 , 2 , 5 , 7 , 11 , 15 ], \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22 model = model_selection . GridSearchCV ( estimator = classifier , param_grid = param_grid , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( f \"Best score: { model . 
best_score_ } \" ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u8fd9\u91cc\u6253\u5370\u4e86\u5f88\u591a\u5185\u5bb9\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u6700\u540e\u51e0\u884c\u3002 [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.895 , total = 1.0 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.890 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.910 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.880 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... 
[ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.870 , total = 1.1 s [ Parallel ( n_jobs = 1 )]: Done 360 out of 360 | elapsed : 3.7 min finished Best score : 0.889 Best parameters set : criterion : 'entropy' max_depth : 15 n_estimators : 500 \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c5\u6298\u4ea4\u53c9\u68c0\u9a8c\u6700\u4f73\u5f97\u5206\u662f 0.889\uff0c\u6211\u4eec\u7684\u7f51\u683c\u641c\u7d22\u5f97\u5230\u4e86\u6700\u4f73\u53c2\u6570\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u7684\u4e0b\u4e00\u4e2a\u6700\u4f73\u65b9\u6cd5\u662f \u968f\u673a\u641c\u7d22 \u3002\u5728\u968f\u673a\u641c\u7d22\u4e2d\uff0c\u6211\u4eec\u968f\u673a\u9009\u62e9\u4e00\u4e2a\u53c2\u6570\u7ec4\u5408\uff0c\u7136\u540e\u8ba1\u7b97\u4ea4\u53c9\u9a8c\u8bc1\u5f97\u5206\u3002\u8fd9\u91cc\u6d88\u8017\u7684\u65f6\u95f4\u6bd4\u7f51\u683c\u641c\u7d22\u5c11\uff0c\u56e0\u4e3a\u6211\u4eec\u4e0d\u5bf9\u6240\u6709\u4e0d\u540c\u7684\u53c2\u6570\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\u3002\u6211\u4eec\u9009\u62e9\u8981\u5bf9\u6a21\u578b\u8fdb\u884c\u591a\u5c11\u6b21\u8bc4\u4f30\uff0c\u8fd9\u5c31\u51b3\u5b9a\u4e86\u641c\u7d22\u6240\u9700\u7684\u65f6\u95f4\u3002\u4ee3\u7801\u4e0e\u4e0a\u9762\u7684\u5dee\u522b\u4e0d\u5927\u3002\u9664 GridSearchCV \u5916\uff0c\u6211\u4eec\u4f7f\u7528 RandomizedSearchCV\u3002 if __name__ == \"__main__\" : classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u66f4\u6539\u641c\u7d22\u7a7a\u95f4 param_grid = { \"n_estimators\" : np . arange ( 100 , 1500 , 100 ), \"max_depth\" : np . arange ( 1 , 31 ), \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u968f\u673a\u53c2\u6570\u641c\u7d22 model = model_selection . RandomizedSearchCV ( estimator = classifier , param_distributions = param_grid , n_iter = 20 , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . 
fit ( X , y ) print ( f \"Best score: { model . best_score_ } \" ) print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u6211\u4eec\u66f4\u6539\u4e86\u968f\u673a\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c\uff0c\u7ed3\u679c\u4f3c\u4e4e\u6709\u4e86\u4e9b\u8bb8\u6539\u8fdb\u3002 Best score : 0.8905 Best parameters set : criterion : entropy max_depth : 25 n_estimators : 300 \u5982\u679c\u8fed\u4ee3\u6b21\u6570\u8f83\u5c11\uff0c\u968f\u673a\u641c\u7d22\u6bd4\u7f51\u683c\u641c\u7d22\u66f4\u5feb\u3002\u4f7f\u7528\u8fd9\u4e24\u79cd\u65b9\u6cd5\uff0c\u4f60\u53ef\u4ee5\u4e3a\u5404\u79cd\u6a21\u578b\u627e\u5230\u6700\u4f18\u53c2\u6570\uff0c\u53ea\u8981\u5b83\u4eec\u6709\u62df\u5408\u548c\u9884\u6d4b\u529f\u80fd\uff0c\u8fd9\u4e5f\u662f scikit-learn \u7684\u6807\u51c6\u3002\u6709\u65f6\uff0c\u4f60\u53ef\u80fd\u60f3\u4f7f\u7528\u7ba1\u9053\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u7531\u4e24\u5217\u6587\u672c\u7ec4\u6210\uff0c\u4f60\u9700\u8981\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u7c7b\u522b\u3002\u8ba9\u6211\u4eec\u5047\u8bbe\u4f60\u9009\u62e9\u7684\u7ba1\u9053\u662f\u9996\u5148\u4ee5\u534a\u76d1\u7763\u7684\u65b9\u5f0f\u5e94\u7528 tf-idf\uff0c\u7136\u540e\u4f7f\u7528 SVD \u548c SVM \u5206\u7c7b\u5668\u3002\u73b0\u5728\u7684\u95ee\u9898\u662f\uff0c\u6211\u4eec\u5fc5\u987b\u9009\u62e9 SVD \u7684\u6210\u5206\uff0c\u8fd8\u9700\u8981\u8c03\u6574 SVM \u7684\u53c2\u6570\u3002\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u5c55\u793a\u4e86\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 import numpy as np import pandas as pd from sklearn import metrics from sklearn import model_selection from sklearn import pipeline from sklearn.decomposition import TruncatedSVD from 
sklearn.feature_extraction.text import TfidfVectorizer from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC # \u8ba1\u7b97\u52a0\u6743\u4e8c\u6b21 Kappa \u5206\u6570 def quadratic_weighted_kappa ( y_true , y_pred ): return metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) if __name__ == '__main__' : # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( '../input/train.csv' ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u63d0\u53d6 id \u5217\u7684\u503c\uff0c\u5e76\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u7c7b\u578b\uff0c\u5b58\u50a8\u5728\u53d8\u91cf idx \u4e2d idx = test . id . values . astype ( int ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 train = train . drop ( 'id' , axis = 1 ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 test = test . drop ( 'id' , axis = 1 ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf 'relevance' \uff0c\u5b58\u50a8\u5728\u53d8\u91cf y \u4e2d y = train . relevance . values # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 traindata \u4e2d traindata = list ( train . apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 testdata \u4e2d testdata = list ( test . 
apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u521b\u5efa\u4e00\u4e2a TfidfVectorizer \u5bf9\u8c61 tfv\uff0c\u7528\u4e8e\u5c06\u6587\u672c\u6570\u636e\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv = TfidfVectorizer ( min_df = 3 , max_features = None , strip_accents = 'unicode' , analyzer = 'word' , token_pattern = r '\\w{1,}' , ngram_range = ( 1 , 3 ), use_idf = 1 , smooth_idf = 1 , sublinear_tf = 1 , stop_words = 'english' ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408 TfidfVectorizer\uff0c\u5c06\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv . fit ( traindata ) # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X X = tfv . transform ( traindata ) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X_test X_test = tfv . transform ( testdata ) # \u521b\u5efa TruncatedSVD \u5bf9\u8c61 svd\uff0c\u7528\u4e8e\u8fdb\u884c\u5947\u5f02\u503c\u5206\u89e3 svd = TruncatedSVD () # \u521b\u5efa StandardScaler \u5bf9\u8c61 scl\uff0c\u7528\u4e8e\u8fdb\u884c\u7279\u5f81\u7f29\u653e scl = StandardScaler () # \u521b\u5efa\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668\u5bf9\u8c61 svm_model svm_model = SVC () # \u521b\u5efa\u673a\u5668\u5b66\u4e60\u7ba1\u9053 clf\uff0c\u5305\u542b\u5947\u5f02\u503c\u5206\u89e3\u3001\u7279\u5f81\u7f29\u653e\u548c\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668 clf = pipeline . Pipeline ( [ ( 'svd' , svd ), ( 'scl' , scl ), ( 'svm' , svm_model ) ] ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid param_grid = { 'svd__n_components' : [ 200 , 300 ], 'svm__C' : [ 10 , 12 ] } # \u521b\u5efa\u81ea\u5b9a\u4e49\u7684\u8bc4\u5206\u51fd\u6570 kappa_scorer\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd kappa_scorer = metrics . 
make_scorer ( quadratic_weighted_kappa , greater_is_better = True ) # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model = model_selection . GridSearchCV ( estimator = clf , param_grid = param_grid , scoring = kappa_scorer , verbose = 10 , n_jobs =- 1 , refit = True , cv = 5 ) # \u4f7f\u7528 GridSearchCV \u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( \"Best score: %0.3f \" % model . best_score_ ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( \" \\t %s : %r \" % ( param_name , best_parameters [ param_name ])) # \u83b7\u53d6\u6700\u4f73\u6a21\u578b best_model = model . best_estimator_ best_model . fit ( X , y ) # \u4f7f\u7528\u6700\u4f73\u6a21\u578b\u8fdb\u884c\u9884\u6d4b preds = best_model . predict ( ... 
) \u8fd9\u91cc\u663e\u793a\u7684\u7ba1\u9053\u5305\u62ec SVD\uff08\u5947\u5f02\u503c\u5206\u89e3\uff09\u3001\u6807\u51c6\u7f29\u653e\u548c SVM\uff08\u652f\u6301\u5411\u91cf\u673a\uff09\u6a21\u578b\u3002\u8bf7\u6ce8\u610f\uff0c\u7531\u4e8e\u6ca1\u6709\u8bad\u7ec3\u6570\u636e\uff0c\u60a8\u65e0\u6cd5\u6309\u539f\u6837\u8fd0\u884c\u4e0a\u8ff0\u4ee3\u7801\u3002\u5f53\u6211\u4eec\u8fdb\u5165\u9ad8\u7ea7\u8d85\u53c2\u6570\u4f18\u5316\u6280\u672f\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4e0d\u540c\u7c7b\u578b\u7684 \u6700\u5c0f\u5316\u7b97\u6cd5 \u6765\u7814\u7a76\u51fd\u6570\u7684\u6700\u5c0f\u5316\u3002\u8fd9\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u591a\u79cd\u6700\u5c0f\u5316\u51fd\u6570\u6765\u5b9e\u73b0\uff0c\u5982\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u3001\u5185\u5c14\u5fb7-\u6885\u5fb7\u4f18\u5316\u7b97\u6cd5\u3001\u4f7f\u7528\u8d1d\u53f6\u65af\u6280\u672f\u548c\u9ad8\u65af\u8fc7\u7a0b\u5bfb\u627e\u6700\u4f18\u53c2\u6570\u6216\u4f7f\u7528\u9057\u4f20\u7b97\u6cd5\u3002\u6211\u5c06\u5728 \"\u96c6\u5408\u4e0e\u5806\u53e0\uff08ensembling and stacking\uff09 \"\u4e00\u7ae0\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u548c Nelder-Mead \u7b97\u6cd5\u7684\u5e94\u7528\u3002\u9996\u5148\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u9ad8\u65af\u8fc7\u7a0b\u5982\u4f55\u7528\u4e8e\u8d85\u53c2\u6570\u4f18\u5316\u3002\u8fd9\u7c7b\u7b97\u6cd5\u9700\u8981\u4e00\u4e2a\u53ef\u4ee5\u4f18\u5316\u7684\u51fd\u6570\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u90fd\u662f\u6700\u5c0f\u5316\u8fd9\u4e2a\u51fd\u6570\uff0c\u5c31\u50cf\u6211\u4eec\u6700\u5c0f\u5316\u635f\u5931\u4e00\u6837\u3002 
\u56e0\u6b64\uff0c\u6bd4\u65b9\u8bf4\uff0c\u4f60\u60f3\u627e\u5230\u6700\u4f73\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u663e\u7136\uff0c\u51c6\u786e\u5ea6\u8d8a\u9ad8\u8d8a\u597d\u3002\u73b0\u5728\uff0c\u6211\u4eec\u4e0d\u80fd\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\uff0c\u4f46\u6211\u4eec\u53ef\u4ee5\u5c06\u7cbe\u786e\u5ea6\u4e58\u4ee5-1\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u662f\u5728\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\u7684\u8d1f\u503c\uff0c\u4f46\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u662f\u5728\u6700\u5927\u5316\u7cbe\u786e\u5ea6\u3002 \u5728\u9ad8\u65af\u8fc7\u7a0b\u4e2d\u4f7f\u7528\u8d1d\u53f6\u65af\u4f18\u5316\uff0c\u53ef\u4ee5\u4f7f\u7528 scikit-optimize (skopt) \u5e93\u4e2d\u7684 gp_minimize \u51fd\u6570\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4f7f\u7528\u8be5\u51fd\u6570\u8c03\u6574\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u53c2\u6570\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from skopt import gp_minimize from skopt import space def optimize ( params , param_names , x , y ): # \u5c06\u53c2\u6570\u540d\u79f0\u548c\u5bf9\u5e94\u7684\u503c\u6253\u5305\u6210\u5b57\u5178 params = dict ( zip ( param_names , params )) # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u6a21\u578b\uff0c\u4f7f\u7528\u4f20\u5165\u7684\u53c2\u6570\u914d\u7f6e model = ensemble . RandomForestClassifier ( ** params ) # \u521b\u5efa StratifiedKFold \u4ea4\u53c9\u9a8c\u8bc1\u5bf9\u8c61\uff0c\u5c06\u6570\u636e\u5206\u4e3a 5 \u6298 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u6bcf\u4e2a\u6298\u53e0\u7684\u51c6\u786e\u5ea6\u7684\u5217\u8868 accuracies = [] # \u5faa\u73af\u904d\u5386\u6bcf\u4e2a\u6298\u53e0\u7684\u8bad\u7ec3\u548c\u6d4b\u8bd5\u6570\u636e for idx in kf . 
split ( X = x , y = y ): train_idx , test_idx = idx [ 0 ], idx [ 1 ] xtrain = x [ train_idx ] ytrain = y [ train_idx ] xtest = x [ test_idx ] ytest = y [ test_idx ] # \u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u62df\u5408\u6a21\u578b model . fit ( xtrain , ytrain ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u6d4b\u8bd5\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . predict ( xtest ) # \u8ba1\u7b97\u6298\u53e0\u7684\u51c6\u786e\u5ea6 fold_accuracy = metrics . accuracy_score ( ytest , preds ) accuracies . append ( fold_accuracy ) # \u8fd4\u56de\u5e73\u5747\u51c6\u786e\u5ea6\u7684\u8d1f\u6570\uff08\u56e0\u4e3a skopt \u4f7f\u7528\u8d1f\u6570\u6765\u6700\u5c0f\u5316\u76ee\u6807\u51fd\u6570\uff09 return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u53d6\u7279\u5f81\u77e9\u9635 X\uff08\u53bb\u6389\"price_range\"\u5217\uff09 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u5b9a\u4e49\u8d85\u53c2\u6570\u641c\u7d22\u7a7a\u95f4 param_space param_space = [ space . Integer ( 3 , 15 , name = \"max_depth\" ), space . Integer ( 100 , 1500 , name = \"n_estimators\" ), space . Categorical ([ \"gini\" , \"entropy\" ], name = \"criterion\" ), space . 
Real ( 0.01 , 1 , prior = \"uniform\" , name = \"max_features\" ) ] # \u5b9a\u4e49\u8d85\u53c2\u6570\u7684\u540d\u79f0\u5217\u8868 param_names param_names = [ \"max_depth\" , \"n_estimators\" , \"criterion\" , \"max_features\" ] # \u521b\u5efa\u51fd\u6570 optimization_function\uff0c\u7528\u4e8e\u4f20\u9012\u7ed9 gp_minimize optimization_function = partial ( optimize , param_names = param_names , x = X , y = y ) # \u4f7f\u7528 Bayesian Optimization\uff08\u57fa\u4e8e\u8d1d\u53f6\u65af\u4f18\u5316\uff09\u6765\u641c\u7d22\u6700\u4f73\u8d85\u53c2\u6570 result = gp_minimize ( optimization_function , dimensions = param_space , n_calls = 15 , n_random_starts = 10 , verbose = 10 ) # \u83b7\u53d6\u6700\u4f73\u8d85\u53c2\u6570\u7684\u5b57\u5178 best_params = dict ( zip ( param_names , result . x ) ) # \u6253\u5370\u51fa\u627e\u5230\u7684\u6700\u4f73\u8d85\u53c2\u6570 print ( best_params ) \u8fd9\u540c\u6837\u4f1a\u4ea7\u751f\u5927\u91cf\u8f93\u51fa\uff0c\u6700\u540e\u4e00\u90e8\u5206\u5982\u4e0b\u6240\u793a\u3002 Iteration No : 14 started . Searching for the next optimal point . Iteration No : 14 ended . Search finished for the next optimal point . Time taken : 4.7793 Function value obtained : - 0.9075 Current minimum : - 0.9075 Iteration No : 15 started . Searching for the next optimal point . Iteration No : 15 ended . Search finished for the next optimal point . 
Time taken : 49.4186 Function value obtained : - 0.9075 Current minimum : - 0.9075 { 'max_depth' : 12 , 'n_estimators' : 100 , 'criterion' : 'entropy' , 'max_features' : 1.0 } \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6210\u529f\u7a81\u7834\u4e86 0.90 \u7684\u51c6\u786e\u7387\u3002\u8fd9\u771f\u662f\u592a\u795e\u5947\u4e86\uff01 \u6211\u4eec\u8fd8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u6bb5\u67e5\u770b\uff08\u7ed8\u5236\uff09\u6211\u4eec\u662f\u5982\u4f55\u5b9e\u73b0\u6536\u655b\u7684\u3002 from skopt.plots import plot_convergence plot_convergence ( result ) \u6536\u655b\u56fe\u5982\u56fe 2 \u6240\u793a\u3002 \u56fe 2\uff1a\u968f\u673a\u68ee\u6797\u53c2\u6570\u4f18\u5316\u7684\u6536\u655b\u56fe Scikit- optimize \u5c31\u662f\u8fd9\u6837\u4e00\u4e2a\u5e93\u3002 hyperopt \u4f7f\u7528\u6811\u72b6\u7ed3\u6784\u8d1d\u53f6\u65af\u4f30\u8ba1\u5668\uff08TPE\uff09\u6765\u627e\u5230\u6700\u4f18\u53c2\u6570\u3002\u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u7247\u6bb5\uff0c\u6211\u5728\u4f7f\u7528 hyperopt \u65f6\u5bf9\u4e4b\u524d\u7684\u4ee3\u7801\u505a\u4e86\u6700\u5c0f\u7684\u6539\u52a8\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from hyperopt import hp , fmin , tpe , Trials from hyperopt.pyll.base import scope def optimize ( params , x , y ): model = ensemble . RandomForestClassifier ( ** params ) kf = model_selection . StratifiedKFold ( n_splits = 5 ) ... return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/mobile_train.csv\" ) X = df . drop ( \"price_range\" , axis = 1 ) . values y = df . price_range . values # \u5b9a\u4e49\u641c\u7d22\u7a7a\u95f4\uff08\u6574\u578b\u3001\u6d6e\u70b9\u6570\u578b\u3001\u9009\u62e9\u578b\uff09 param_space = { \"max_depth\" : scope . int ( hp . quniform ( \"max_depth\" , 1 , 15 , 1 )), \"n_estimators\" : scope . int ( hp . 
quniform ( \"n_estimators\" , 100 , 1500 , 1 ) ), \"criterion\" : hp . choice ( \"criterion\" , [ \"gini\" , \"entropy\" ]), \"max_features\" : hp . uniform ( \"max_features\" , 0 , 1 ) } # \u5305\u88c5\u51fd\u6570 optimization_function = partial ( optimize , x = X , y = y ) # \u5f00\u59cb\u8bad\u7ec3 trials = Trials () # \u6700\u5c0f\u5316\u76ee\u6807\u503c hopt = fmin ( fn = optimization_function , space = param_space , algo = tpe . suggest , max_evals = 15 , trials = trials ) #\u6253\u5370\u6700\u4f73\u53c2\u6570 print ( hopt ) \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u4e0e\u4e4b\u524d\u7684\u4ee3\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u4f60\u5fc5\u987b\u4ee5\u4e0d\u540c\u7684\u683c\u5f0f\u5b9a\u4e49\u53c2\u6570\u7a7a\u95f4\uff0c\u8fd8\u9700\u8981\u6539\u53d8\u5b9e\u9645\u4f18\u5316\u90e8\u5206\uff0c\u7528 hyperopt \u4ee3\u66ff gp_minimize\u3002\u7ed3\u679c\u76f8\u5f53\u4e0d\u9519\uff01 \u276f python rf_hyperopt . py 100 %| \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 15 / 15 [ 04 : 38 < 00 : 00 , 18.57 s / trial , best loss : - 0.9095000000000001 ] { 'criterion' : 1 , 'max_depth' : 11.0 , 'max_features' : 0.821163568049807 , 'n_estimators' : 806.0 } \u6211\u4eec\u5f97\u5230\u4e86\u6bd4\u4ee5\u524d\u66f4\u597d\u7684\u51c6\u786e\u5ea6\u548c\u4e00\u7ec4\u53ef\u4ee5\u4f7f\u7528\u7684\u53c2\u6570\u3002\u8bf7\u6ce8\u610f\uff0c\u6700\u7ec8\u7ed3\u679c\u4e2d\u7684\u6807\u51c6\u662f 1\u3002\u8fd9\u610f\u5473\u7740\u9009\u62e9\u4e86 1\uff0c\u5373\u71b5\u3002 \u4e0a\u8ff0\u8c03\u6574\u8d85\u53c2\u6570\u7684\u65b9\u6cd5\u662f\u6700\u5e38\u89c1\u7684\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u6a21\u578b\uff1a\u7ebf\u6027\u56de\u5f52\u3001\u903b\u8f91\u56de\u5f52\u3001\u57fa\u4e8e\u6811\u7684\u65b9\u6cd5\u3001\u68af\u5ea6\u63d0\u5347\u6a21\u578b\uff08\u5982 xgboost\u3001lightgbm\uff09\uff0c\u751a\u81f3\u795e\u7ecf\u7f51\u7edc\uff01 
\u867d\u7136\u8fd9\u4e9b\u65b9\u6cd5\u5df2\u7ecf\u5b58\u5728\uff0c\u4f46\u5b66\u4e60\u65f6\u5fc5\u987b\u4ece\u624b\u52a8\u8c03\u6574\u8d85\u53c2\u6570\u5f00\u59cb\uff0c\u5373\u624b\u5de5\u8c03\u6574\u3002\u624b\u52a8\u8c03\u6574\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5b66\u4e60\u57fa\u7840\u77e5\u8bc6\uff0c\u4f8b\u5982\uff0c\u5728\u68af\u5ea6\u63d0\u5347\u4e2d\uff0c\u5f53\u4f60\u589e\u52a0\u6df1\u5ea6\u65f6\uff0c\u4f60\u5e94\u8be5\u964d\u4f4e\u5b66\u4e60\u7387\u3002\u5982\u679c\u4f7f\u7528\u81ea\u52a8\u5de5\u5177\uff0c\u5c31\u65e0\u6cd5\u5b66\u4e60\u5230\u8fd9\u4e00\u70b9\u3002\u8bf7\u53c2\u8003\u4e0b\u8868\uff0c\u4e86\u89e3\u5e94\u5982\u4f55\u8c03\u6574\u3002RS* \u8868\u793a\u968f\u673a\u641c\u7d22\u5e94\u8be5\u66f4\u597d\u3002 \u4e00\u65e6\u4f60\u80fd\u66f4\u597d\u5730\u624b\u52a8\u8c03\u6574\u53c2\u6570\uff0c\u4f60\u751a\u81f3\u53ef\u80fd\u4e0d\u9700\u8981\u4efb\u4f55\u81ea\u52a8\u8d85\u53c2\u6570\u8c03\u6574\u3002\u521b\u5efa\u5927\u578b\u6a21\u578b\u6216\u5f15\u5165\u5927\u91cf\u7279\u5f81\u65f6\uff0c\u4e5f\u5bb9\u6613\u9020\u6210\u8bad\u7ec3\u6570\u636e\u7684\u8fc7\u5ea6\u62df\u5408\u3002\u4e3a\u907f\u514d\u8fc7\u5ea6\u62df\u5408\uff0c\u9700\u8981\u5728\u8bad\u7ec3\u6570\u636e\u7279\u5f81\u4e2d\u5f15\u5165\u566a\u58f0\u6216\u5bf9\u4ee3\u4ef7\u51fd\u6570\u8fdb\u884c\u60e9\u7f5a\u3002\u8fd9\u79cd\u60e9\u7f5a\u79f0\u4e3a \u6b63\u5219\u5316 \uff0c\u6709\u52a9\u4e8e\u6cdb\u5316\u6a21\u578b\u3002\u5728\u7ebf\u6027\u6a21\u578b\u4e2d\uff0c\u6700\u5e38\u89c1\u7684\u6b63\u5219\u5316\u7c7b\u578b\u662f L1 \u548c L2\u3002L1 \u4e5f\u79f0\u4e3a Lasso \u56de\u5f52\uff0cL2 \u79f0\u4e3a Ridge \u56de\u5f52\u3002\u8bf4\u5230\u795e\u7ecf\u7f51\u7edc\uff0c\u6211\u4eec\u4f1a\u4f7f\u7528dropout\u3001\u6dfb\u52a0\u589e\u5f3a\u3001\u566a\u58f0\u7b49\u65b9\u6cd5\u5bf9\u6a21\u578b\u8fdb\u884c\u6b63\u5219\u5316\u3002\u5229\u7528\u8d85\u53c2\u6570\u4f18\u5316\uff0c\u8fd8\u53ef\u4ee5\u627e\u5230\u6b63\u786e\u7684\u60e9\u7f5a\u65b9\u6cd5\u3002 Model Optimize Range of values Linear Regression - 
fit_intercept - normalize - True/False - True/False Ridge - alpha - fit_intercept - normalize - 0.01, 0.1, 1.0, 10, 100 - True/False - True/False k-neighbors - n_neighbors - p - 2, 4, 8, 16, ... - 2, 3, ... SVM - C - gamma - class_weight - 0.001, 0.01, ...,10, 100, 1000 - 'auto', RS* - 'balanced', None Logistic Regression - Penalyt - C - L1 or L2 - 0.001, 0.01, ..., 10, ..., 100 Lasso - Alpha - Normalize - 0.1, 1.0, 10 - True/False Random Forest - n_estimators - max_depth - min_samples_split - min_samples_leaf - max features - 120, 300, 500, 800, 1200 - 5, 8, 15, 25, 30, None - 1, 2, 5, 10, 15, 100 - log2, sqrt, None XGBoost - eta - gamma - max_depth - min_child_weight - subsample - colsample_bytree - lambda - alpha - 0.01, 0.015, 0.025, 0.05, 0.1 - 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0 - 3, 5, 7, 9, 12, 15, 17, 25 - 1, 3, 5, 7 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.01, 0.1, 1.0, RS - 0, 0.1, 0.5, 1.0, RS","title":"\u8d85\u53c2\u6570\u4f18\u5316"}]} \ No newline at end of file +{"config":{"indexing":"full","lang":["en"],"min_search_length":3,"prebuild_index":false,"separator":"[\\s\\-]+"},"docs":[{"location":"","text":"AAAMLP-CN \u65b0\u7279\u6027 - 2023.09.07 \u26a1 \u4fee\u6b63\u90e8\u5206\u5df2\u77e5\u6587\u5b57\u9519\u8bef\u548c\u4ee3\u7801\u9519\u8bef \ud83e\udd17 \u6dfb\u52a0 \u5728\u7ebf\u6587\u6863 \u7ffb\u8bd1\u8fdb\u7a0b 2023.09.12 \u6dfb\u52a0\u7ae0\u8282\uff1a \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5 \u3001 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5 \u7b80\u4ecb Abhishek Thakur\uff0c\u5f88\u591a kaggler \u5bf9\u4ed6\u90fd\u975e\u5e38\u719f\u6089\uff0c2017 \u5e74\uff0c\u4ed6\u5728 Linkedin \u53d1\u8868\u4e86\u4e00\u7bc7\u540d\u4e3a Approaching (Almost) Any Machine Learning Problem 
\u7684\u6587\u7ae0\uff0c\u4ecb\u7ecd\u4ed6\u5efa\u7acb\u7684\u4e00\u4e2a\u81ea\u52a8\u7684\u673a\u5668\u5b66\u4e60\u6846\u67b6\uff0c\u51e0\u4e4e\u53ef\u4ee5\u89e3\u51b3\u4efb\u4f55\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u8fd9\u7bc7\u6587\u7ae0\u66fe\u706b\u904d Kaggle\u3002 Abhishek \u5728 Kaggle \u4e0a\u7684\u6210\u5c31\uff1a Competitions Grandmaster\uff0817 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 3\uff09 Kernels Expert \uff08Kagglers \u6392\u540d\u524d 1\uff05\uff09 Discussion Grandmaster\uff0865 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 2\uff09 \u76ee\u524d\uff0cAbhishek \u5728\u632a\u5a01 boost \u516c\u53f8\u62c5\u4efb\u9996\u5e2d\u6570\u636e\u79d1\u5b66\u5bb6\u7684\u804c\u4f4d\uff0c\u8fd9\u662f\u4e00\u5bb6\u4e13\u95e8\u4ece\u4e8b\u4f1a\u8bdd\u4eba\u5de5\u667a\u80fd\u7684\u8f6f\u4ef6\u516c\u53f8\u3002 \u672c\u6587\u5bf9 Approaching (Almost) Any Machine Learning Problem \u8fdb\u884c\u4e86 \u4e2d\u6587\u7ffb\u8bd1 \uff0c\u7531\u4e8e\u672c\u4eba\u6c34\u5e73\u6709\u9650\uff0c\u4e14\u672a\u4f7f\u7528\u673a\u5668\u7ffb\u8bd1\uff0c\u53ef\u80fd\u6709\u90e8\u5206\u8a00\u8bed\u4e0d\u901a\u987a\u6216\u672c\u571f\u5316\u7a0b\u5ea6\u4e0d\u8db3\uff0c\u4e5f\u8bf7\u5927\u5bb6\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u591a\u63d0\u4f9b\u5b9d\u8d35\u610f\u89c1\u3002\u53e6\u9644\u4e0a\u4e66\u7c4d\u539f \u9879\u76ee\u5730\u5740 \uff0c \u8f6c\u8f7d\u8bf7\u4e00\u5b9a\u6807\u660e\u51fa\u5904\uff01 \u672c\u9879\u76ee \u652f\u6301\u5728\u7ebf\u9605\u8bfb \uff0c\u65b9\u4fbf\u60a8\u968f\u65f6\u968f\u5730\u8fdb\u884c\u67e5\u9605\u3002 \u56e0\u4e3a\u6709\u51e0\u7ae0\u5185\u5bb9\u592a\u8fc7\u57fa\u7840\uff0c\u6240\u4ee5\u672a\u8fdb\u884c\u7ffb\u8bd1\uff0c\u8be6\u7ec6\u60c5\u51b5\u8bf7\u53c2\u7167\u4e66\u7c4d\u76ee\u5f55\uff1a \u51c6\u5907\u73af\u5883\uff08\u672a\u7ffb\u8bd1\uff09 \u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60\uff08\u672a\u7ffb\u8bd1\uff09 \u4ea4\u53c9\u68c0\u9a8c\uff08\u5df2\u7ffb\u8bd1\uff09 
\u8bc4\u4f30\u6307\u6807\uff08\u5df2\u7ffb\u8bd1\uff09 - \u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\uff08\u5df2\u7ffb\u8bd1\uff09 \u5904\u7406\u5206\u7c7b\u53d8\u91cf\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u5de5\u7a0b\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u9009\u62e9\uff08\u5df2\u7ffb\u8bd1\uff09 \u8d85\u53c2\u6570\u4f18\u5316\uff08\u5df2\u7ffb\u8bd1\uff09 \u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5\uff08\u672a\u7ffb\u8bd1\uff09 \u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5\uff08\u672a\u7ffb\u8bd1\uff09 \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u6211\u5c06\u4f1a\u628a\u5b8c\u6574\u7684\u7ffb\u8bd1\u7248 Markdown \u6587\u4ef6\u4e0a\u4f20\u5230 GitHub\uff0c\u4ee5\u4f9b\u5927\u5bb6\u514d\u8d39\u4e0b\u8f7d\u548c\u9605\u8bfb\u3002\u4e3a\u4e86\u6700\u4f73\u7684\u9605\u8bfb\u4f53\u9a8c\uff0c\u63a8\u8350\u4f7f\u7528 PDF \u683c\u5f0f\u6216\u662f\u5728\u7ebf\u9605\u8bfb\u8fdb\u884c\u67e5\u770b \u82e5\u60a8\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u53d1\u73b0\u4efb\u4f55\u9519\u8bef\u6216\u4e0d\u51c6\u786e\u4e4b\u5904\uff0c\u975e\u5e38\u6b22\u8fce\u901a\u8fc7\u63d0\u4ea4 Issue \u6216 Pull Request \u6765\u534f\u52a9\u6211\u8fdb\u884c\u4fee\u6b63\u3002 \u968f\u7740\u65f6\u95f4\u63a8\u79fb\uff0c\u6211\u53ef\u80fd\u4f1a \u7ee7\u7eed\u7ffb\u8bd1\u5c1a\u672a\u5b8c\u6210\u7684\u7ae0\u8282 \u3002\u5982\u679c\u60a8\u89c9\u5f97\u8fd9\u4e2a\u9879\u76ee\u5bf9\u60a8\u6709\u5e2e\u52a9\uff0c\u8bf7\u4e0d\u541d\u7ed9\u4e88 Star \u6216\u8005\u8fdb\u884c\u5173\u6ce8\u3002","title":"\u524d\u8a00"},{"location":"#aaamlp-cn","text":"","title":"AAAMLP-CN"},{"location":"#-20230907","text":"\u26a1 \u4fee\u6b63\u90e8\u5206\u5df2\u77e5\u6587\u5b57\u9519\u8bef\u548c\u4ee3\u7801\u9519\u8bef \ud83e\udd17 \u6dfb\u52a0 \u5728\u7ebf\u6587\u6863","title":"\u65b0\u7279\u6027 - 2023.09.07"},{"location":"#_1","text":"2023.09.12 \u6dfb\u52a0\u7ae0\u8282\uff1a 
\u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5 \u3001 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5","title":"\u7ffb\u8bd1\u8fdb\u7a0b"},{"location":"#_2","text":"Abhishek Thakur\uff0c\u5f88\u591a kaggler \u5bf9\u4ed6\u90fd\u975e\u5e38\u719f\u6089\uff0c2017 \u5e74\uff0c\u4ed6\u5728 Linkedin \u53d1\u8868\u4e86\u4e00\u7bc7\u540d\u4e3a Approaching (Almost) Any Machine Learning Problem \u7684\u6587\u7ae0\uff0c\u4ecb\u7ecd\u4ed6\u5efa\u7acb\u7684\u4e00\u4e2a\u81ea\u52a8\u7684\u673a\u5668\u5b66\u4e60\u6846\u67b6\uff0c\u51e0\u4e4e\u53ef\u4ee5\u89e3\u51b3\u4efb\u4f55\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u8fd9\u7bc7\u6587\u7ae0\u66fe\u706b\u904d Kaggle\u3002 Abhishek \u5728 Kaggle \u4e0a\u7684\u6210\u5c31\uff1a Competitions Grandmaster\uff0817 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 3\uff09 Kernels Expert \uff08Kagglers \u6392\u540d\u524d 1\uff05\uff09 Discussion Grandmaster\uff0865 \u679a\u91d1\u724c\uff0c\u4e16\u754c\u6392\u540d\u7b2c 2\uff09 \u76ee\u524d\uff0cAbhishek \u5728\u632a\u5a01 boost \u516c\u53f8\u62c5\u4efb\u9996\u5e2d\u6570\u636e\u79d1\u5b66\u5bb6\u7684\u804c\u4f4d\uff0c\u8fd9\u662f\u4e00\u5bb6\u4e13\u95e8\u4ece\u4e8b\u4f1a\u8bdd\u4eba\u5de5\u667a\u80fd\u7684\u8f6f\u4ef6\u516c\u53f8\u3002 \u672c\u6587\u5bf9 Approaching (Almost) Any Machine Learning Problem \u8fdb\u884c\u4e86 \u4e2d\u6587\u7ffb\u8bd1 \uff0c\u7531\u4e8e\u672c\u4eba\u6c34\u5e73\u6709\u9650\uff0c\u4e14\u672a\u4f7f\u7528\u673a\u5668\u7ffb\u8bd1\uff0c\u53ef\u80fd\u6709\u90e8\u5206\u8a00\u8bed\u4e0d\u901a\u987a\u6216\u672c\u571f\u5316\u7a0b\u5ea6\u4e0d\u8db3\uff0c\u4e5f\u8bf7\u5927\u5bb6\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u591a\u63d0\u4f9b\u5b9d\u8d35\u610f\u89c1\u3002\u53e6\u9644\u4e0a\u4e66\u7c4d\u539f \u9879\u76ee\u5730\u5740 \uff0c \u8f6c\u8f7d\u8bf7\u4e00\u5b9a\u6807\u660e\u51fa\u5904\uff01 \u672c\u9879\u76ee \u652f\u6301\u5728\u7ebf\u9605\u8bfb \uff0c\u65b9\u4fbf\u60a8\u968f\u65f6\u968f\u5730\u8fdb\u884c\u67e5\u9605\u3002 
\u56e0\u4e3a\u6709\u51e0\u7ae0\u5185\u5bb9\u592a\u8fc7\u57fa\u7840\uff0c\u6240\u4ee5\u672a\u8fdb\u884c\u7ffb\u8bd1\uff0c\u8be6\u7ec6\u60c5\u51b5\u8bf7\u53c2\u7167\u4e66\u7c4d\u76ee\u5f55\uff1a \u51c6\u5907\u73af\u5883\uff08\u672a\u7ffb\u8bd1\uff09 \u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60\uff08\u672a\u7ffb\u8bd1\uff09 \u4ea4\u53c9\u68c0\u9a8c\uff08\u5df2\u7ffb\u8bd1\uff09 \u8bc4\u4f30\u6307\u6807\uff08\u5df2\u7ffb\u8bd1\uff09 - \u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\uff08\u5df2\u7ffb\u8bd1\uff09 \u5904\u7406\u5206\u7c7b\u53d8\u91cf\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u5de5\u7a0b\uff08\u5df2\u7ffb\u8bd1\uff09 \u7279\u5f81\u9009\u62e9\uff08\u5df2\u7ffb\u8bd1\uff09 \u8d85\u53c2\u6570\u4f18\u5316\uff08\u5df2\u7ffb\u8bd1\uff09 \u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5\uff08\u672a\u7ffb\u8bd1\uff09 \u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5\uff08\u672a\u7ffb\u8bd1\uff09 \u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5\uff08\u5df2\u7ffb\u8bd1\uff09 \u6211\u5c06\u4f1a\u628a\u5b8c\u6574\u7684\u7ffb\u8bd1\u7248 Markdown \u6587\u4ef6\u4e0a\u4f20\u5230 GitHub\uff0c\u4ee5\u4f9b\u5927\u5bb6\u514d\u8d39\u4e0b\u8f7d\u548c\u9605\u8bfb\u3002\u4e3a\u4e86\u6700\u4f73\u7684\u9605\u8bfb\u4f53\u9a8c\uff0c\u63a8\u8350\u4f7f\u7528 PDF \u683c\u5f0f\u6216\u662f\u5728\u7ebf\u9605\u8bfb\u8fdb\u884c\u67e5\u770b \u82e5\u60a8\u5728\u9605\u8bfb\u8fc7\u7a0b\u4e2d\u53d1\u73b0\u4efb\u4f55\u9519\u8bef\u6216\u4e0d\u51c6\u786e\u4e4b\u5904\uff0c\u975e\u5e38\u6b22\u8fce\u901a\u8fc7\u63d0\u4ea4 Issue \u6216 Pull Request \u6765\u534f\u52a9\u6211\u8fdb\u884c\u4fee\u6b63\u3002 \u968f\u7740\u65f6\u95f4\u63a8\u79fb\uff0c\u6211\u53ef\u80fd\u4f1a \u7ee7\u7eed\u7ffb\u8bd1\u5c1a\u672a\u5b8c\u6210\u7684\u7ae0\u8282 \u3002\u5982\u679c\u60a8\u89c9\u5f97\u8fd9\u4e2a\u9879\u76ee\u5bf9\u60a8\u6709\u5e2e\u52a9\uff0c\u8bf7\u4e0d\u541d\u7ed9\u4e88 Star 
\u6216\u8005\u8fdb\u884c\u5173\u6ce8\u3002","title":"\u7b80\u4ecb"},{"location":"%E4%BA%A4%E5%8F%89%E6%A3%80%E9%AA%8C/","text":"\u4ea4\u53c9\u68c0\u9a8c \u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u6ca1\u6709\u5efa\u7acb\u4efb\u4f55\u6a21\u578b\u3002\u539f\u56e0\u5f88\u7b80\u5355\uff0c\u5728\u521b\u5efa\u4efb\u4f55\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u77e5\u9053\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4ee5\u53ca\u5982\u4f55\u6839\u636e\u6570\u636e\u96c6\u9009\u62e9\u6700\u4f73\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e\u96c6\u3002 \u90a3\u4e48\uff0c\u4ec0\u4e48\u662f \u4ea4\u53c9\u68c0\u9a8c \uff0c\u6211\u4eec\u4e3a\u4ec0\u4e48\u8981\u5173\u6ce8\u5b83\uff1f \u5173\u4e8e\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u53ef\u4ee5\u627e\u5230\u591a\u79cd\u5b9a\u4e49\u3002\u6211\u7684\u5b9a\u4e49\u53ea\u6709\u4e00\u53e5\u8bdd\uff1a\u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fc7\u7a0b\u4e2d\u7684\u4e00\u4e2a\u6b65\u9aa4\uff0c\u5b83\u53ef\u4ee5\u5e2e\u52a9\u6211\u4eec\u786e\u4fdd\u6a21\u578b\u51c6\u786e\u62df\u5408\u6570\u636e\uff0c\u540c\u65f6\u786e\u4fdd\u6211\u4eec\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u4f46\u8fd9\u53c8\u5f15\u51fa\u4e86\u53e6\u4e00\u4e2a\u8bcd\uff1a \u8fc7\u62df\u5408 \u3002 \u8981\u89e3\u91ca\u8fc7\u62df\u5408\uff0c\u6211\u8ba4\u4e3a\u6700\u597d\u5148\u770b\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6709\u4e00\u4e2a\u76f8\u5f53\u6709\u540d\u7684\u7ea2\u9152\u8d28\u91cf\u6570\u636e\u96c6\uff08 red wine quality dataset \uff09\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 11 \u4e2a\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u8fd9\u4e9b\u7279\u5f81\u51b3\u5b9a\u4e86\u7ea2\u9152\u7684\u8d28\u91cf\u3002 \u8fd9\u4e9b\u5c5e\u6027\u5305\u62ec\uff1a \u56fa\u5b9a\u9178\u5ea6\uff08fixed acidity\uff09 \u6325\u53d1\u6027\u9178\u5ea6\uff08volatile acidity\uff09 \u67e0\u6aac\u9178\uff08citric acid\uff09 \u6b8b\u7559\u7cd6\uff08residual sugar\uff09 
\u6c2f\u5316\u7269\uff08chlorides\uff09 \u6e38\u79bb\u4e8c\u6c27\u5316\u786b\uff08free sulfur dioxide\uff09 \u4e8c\u6c27\u5316\u786b\u603b\u91cf\uff08total sulfur dioxide\uff09 \u5bc6\u5ea6\uff08density\uff09 PH \u503c\uff08pH\uff09 \u786b\u9178\u76d0\uff08sulphates\uff09 \u9152\u7cbe\uff08alcohol\uff09 \u6839\u636e\u8fd9\u4e9b\u4e0d\u540c\u7279\u5f81\uff0c\u6211\u4eec\u9700\u8981\u9884\u6d4b\u7ea2\u8461\u8404\u9152\u7684\u8d28\u91cf\uff0c\u8d28\u91cf\u503c\u4ecb\u4e8e 0 \u5230 10 \u4e4b\u95f4\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u6570\u636e\u662f\u600e\u6837\u7684\u3002 import pandas as pd df = pd . read_csv ( \"winequality-red.csv\" ) \u56fe 1:\u7ea2\u8461\u8404\u9152\u8d28\u91cf\u6570\u636e\u96c6\u7b80\u5355\u5c55\u793a \u6211\u4eec\u53ef\u4ee5\u5c06\u8fd9\u4e2a\u95ee\u9898\u89c6\u4e3a\u5206\u7c7b\u95ee\u9898\uff0c\u4e5f\u53ef\u4ee5\u89c6\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u9009\u62e9\u5206\u7c7b\u3002\u7136\u800c\uff0c\u8fd9\u4e2a\u6570\u636e\u96c6\u503c\u5305\u542b 6 \u79cd\u8d28\u91cf\u503c\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u6240\u6709\u8d28\u91cf\u503c\u6620\u5c04\u5230 0 \u5230 5 \u4e4b\u95f4\u3002 # \u4e00\u4e2a\u6620\u5c04\u5b57\u5178\uff0c\u7528\u4e8e\u5c06\u8d28\u91cf\u503c\u4ece 0 \u5230 5 \u8fdb\u884c\u6620\u5c04 quality_mapping = { 3 : 0 , 4 : 1 , 5 : 2 , 6 : 3 , 7 : 4 , 8 : 5 } # \u4f60\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 map \u51fd\u6570\u4ee5\u53ca\u4efb\u4f55\u5b57\u5178\uff0c # \u6765\u8f6c\u6362\u7ed9\u5b9a\u5217\u4e2d\u7684\u503c\u4e3a\u5b57\u5178\u4e2d\u7684\u503c df . loc [:, \"quality\" ] = df . quality . 
map ( quality_mapping ) \u5f53\u6211\u4eec\u770b\u5927\u8fd9\u4e9b\u6570\u636e\u5e76\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u65f6\uff0c\u6211\u4eec\u8111\u6d77\u4e2d\u4f1a\u6d6e\u73b0\u51fa\u5f88\u591a\u53ef\u4ee5\u5e94\u7528\u7684\u7b97\u6cd5\uff0c\u4e5f\u8bb8\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u4ece\u4e00\u5f00\u59cb\u5c31\u6df1\u5165\u7814\u7a76\u795e\u7ecf\u7f51\u7edc\uff0c\u90a3\u5c31\u6709\u70b9\u7275\u5f3a\u4e86\u3002\u6240\u4ee5\uff0c\u8ba9\u6211\u4eec\u4ece\u7b80\u5355\u7684\u3001\u6211\u4eec\u4e5f\u80fd\u53ef\u89c6\u5316\u7684\u4e1c\u897f\u5f00\u59cb\uff1a\u51b3\u7b56\u6811\u3002 \u5728\u5f00\u59cb\u4e86\u89e3\u4ec0\u4e48\u662f\u8fc7\u62df\u5408\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 1599 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u4fdd\u7559 1000 \u4e2a\u6837\u672c\u7528\u4e8e\u8bad\u7ec3\uff0c599 \u4e2a\u6837\u672c\u4f5c\u4e3a\u4e00\u4e2a\u5355\u72ec\u7684\u96c6\u5408\u3002 \u4ee5\u4e0b\u4ee3\u7801\u53ef\u4ee5\u8f7b\u677e\u5b8c\u6210\u5212\u5206\uff1a # \u4f7f\u7528 frac=1 \u7684 sample \u65b9\u6cd5\u6765\u6253\u4e71 dataframe # \u7531\u4e8e\u6253\u4e71\u540e\u7d22\u5f15\u4f1a\u6539\u53d8\uff0c\u6240\u4ee5\u6211\u4eec\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u9009\u53d6\u524d 1000 \u884c\u4f5c\u4e3a\u8bad\u7ec3\u6570\u636e df_train = df . head ( 1000 ) # \u9009\u53d6\u6700\u540e\u7684 599 \u884c\u4f5c\u4e3a\u6d4b\u8bd5/\u9a8c\u8bc1\u6570\u636e df_test = df . 
tail ( 599 ) \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u5728\u8bad\u7ec3\u96c6\u4e0a\u4f7f\u7528 scikit-learn \u8bad\u7ec3\u4e00\u4e2a\u51b3\u7b56\u6811\u6a21\u578b\u3002 # \u4ece scikit-learn \u5bfc\u5165\u9700\u8981\u7684\u6a21\u5757 from sklearn import tree from sklearn import metrics # \u521d\u59cb\u5316\u4e00\u4e2a\u51b3\u7b56\u6811\u5206\u7c7b\u5668\uff0c\u8bbe\u7f6e\u6700\u5927\u6df1\u5ea6\u4e3a 3 clf = tree . DecisionTreeClassifier ( max_depth = 3 ) # \u9009\u62e9\u4f60\u60f3\u8981\u8bad\u7ec3\u6a21\u578b\u7684\u5217 # \u8fd9\u4e9b\u5217\u4f5c\u4e3a\u6a21\u578b\u7684\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u4f7f\u7528\u4e4b\u524d\u6620\u5c04\u7684\u8d28\u91cf\u4ee5\u53ca\u63d0\u4f9b\u7684\u7279\u5f81\u6765\u8bad\u7ec3\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5c06\u51b3\u7b56\u6811\u5206\u7c7b\u5668\u7684\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u8bbe\u4e3a 3\u3002\u8be5\u6a21\u578b\u7684\u6240\u6709\u5176\u4ed6\u53c2\u6570\u5747\u4fdd\u6301\u9ed8\u8ba4\u503c\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u6d4b\u8bd5\u8be5\u6a21\u578b\u7684\u51c6\u786e\u6027\uff1a # \u5728\u8bad\u7ec3\u96c6\u4e0a\u751f\u6210\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) # \u5728\u6d4b\u8bd5\u96c6\u4e0a\u751f\u6210\u9884\u6d4b test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) # \u8ba1\u7b97\u6d4b\u8bd5\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 test_accuracy = metrics . accuracy_score ( df_test . 
quality , test_predictions ) \u8bad\u7ec3\u548c\u6d4b\u8bd5\u7684\u51c6\u786e\u7387\u5206\u522b\u4e3a 58.9%\u548c 54.25%\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5c06\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u589e\u52a0\u5230 7\uff0c\u5e76\u91cd\u590d\u4e0a\u8ff0\u8fc7\u7a0b\u3002\u8fd9\u6837\uff0c\u8bad\u7ec3\u51c6\u786e\u7387\u4e3a 76.6%\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4e3a 57.3%\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4f7f\u7528\u51c6\u786e\u7387\uff0c\u4e3b\u8981\u662f\u56e0\u4e3a\u5b83\u662f\u6700\u76f4\u63a5\u7684\u6307\u6807\u3002\u5bf9\u4e8e\u8fd9\u4e2a\u95ee\u9898\u6765\u8bf4\uff0c\u5b83\u53ef\u80fd\u4e0d\u662f\u6700\u597d\u7684\u6307\u6807\u3002\u6211\u4eec\u53ef\u4ee5\u6839\u636e\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u4e0d\u540c\u503c\u6765\u8ba1\u7b97\u8fd9\u4e9b\u51c6\u786e\u7387\uff0c\u5e76\u7ed8\u5236\u66f2\u7ebf\u56fe\u3002 # \u6ce8\u610f\uff1a\u8fd9\u6bb5\u4ee3\u7801\u5728 Jupyter \u7b14\u8bb0\u672c\u4e2d\u7f16\u5199 # \u5bfc\u5165 scikit-learn \u7684 tree \u548c metrics from sklearn import tree from sklearn import metrics # \u5bfc\u5165 matplotlib \u548c seaborn # \u7528\u4e8e\u7ed8\u56fe import matplotlib import matplotlib.pyplot as plt import seaborn as sns # \u8bbe\u7f6e\u5168\u5c40\u6807\u7b7e\u6587\u672c\u7684\u5927\u5c0f matplotlib . rc ( 'xtick' , labelsize = 20 ) matplotlib . rc ( 'ytick' , labelsize = 20 ) # \u786e\u4fdd\u56fe\u8868\u76f4\u63a5\u5728\u7b14\u8bb0\u672c\u5185\u663e\u793a % matplotlib inline # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6\u7684\u5217\u8868 # \u6211\u4eec\u4ece 50% \u7684\u51c6\u786e\u5ea6\u5f00\u59cb train_accuracies = [ 0.5 ] test_accuracies = [ 0.5 ] # \u904d\u5386\u51e0\u4e2a\u4e0d\u540c\u7684\u6811\u6df1\u5ea6\u503c for depth in range ( 1 , 25 ): # \u521d\u59cb\u5316\u6a21\u578b clf = tree . 
DecisionTreeClassifier ( max_depth = depth ) # \u9009\u62e9\u7528\u4e8e\u8bad\u7ec3\u7684\u5217/\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u5728\u7ed9\u5b9a\u7279\u5f81\u4e0a\u62df\u5408\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) # \u521b\u5efa\u8bad\u7ec3\u548c\u6d4b\u8bd5\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) test_accuracy = metrics . accuracy_score ( df_test . quality , test_predictions ) # \u6dfb\u52a0\u51c6\u786e\u5ea6\u5230\u5217\u8868 train_accuracies . append ( train_accuracy ) test_accuracies . append ( test_accuracy ) # \u4f7f\u7528 matplotlib \u548c seaborn \u521b\u5efa\u4e24\u4e2a\u56fe plt . figure ( figsize = ( 10 , 5 )) sns . set_style ( \"whitegrid\" ) plt . plot ( train_accuracies , label = \"train accuracy\" ) plt . plot ( test_accuracies , label = \"test accuracy\" ) plt . legend ( loc = \"upper left\" , prop = { 'size' : 15 }) plt . xticks ( range ( 0 , 26 , 5 )) plt . xlabel ( \"max_depth\" , size = 20 ) plt . ylabel ( \"accuracy\" , size = 20 ) plt . 
show () \u8fd9\u5c06\u751f\u6210\u5982\u56fe 2 \u6240\u793a\u7684\u66f2\u7ebf\u56fe\u3002 \u56fe 2\uff1a\u4e0d\u540c max_depth \u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u7387\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5f53\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u503c\u4e3a 14 \u65f6\uff0c\u6d4b\u8bd5\u6570\u636e\u7684\u5f97\u5206\u6700\u9ad8\u3002\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u589e\u52a0\u8fd9\u4e2a\u53c2\u6570\u7684\u503c\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4f1a\u4fdd\u6301\u4e0d\u53d8\u6216\u53d8\u5dee\uff0c\u4f46\u8bad\u7ec3\u51c6\u786e\u7387\u4f1a\u4e0d\u65ad\u63d0\u9ad8\u3002\u8fd9\u8bf4\u660e\uff0c\u968f\u7740\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u589e\u52a0\uff0c\u51b3\u7b56\u6811\u6a21\u578b\u5bf9\u8bad\u7ec3\u6570\u636e\u7684\u5b66\u4e60\u6548\u679c\u8d8a\u6765\u8d8a\u597d\uff0c\u4f46\u6d4b\u8bd5\u6570\u636e\u7684\u6027\u80fd\u5374\u4e1d\u6beb\u6ca1\u6709\u63d0\u9ad8\u3002 \u8fd9\u5c31\u662f\u6240\u8c13\u7684\u8fc7\u62df\u5408 \u3002 \u6a21\u578b\u5728\u8bad\u7ec3\u96c6\u4e0a\u5b8c\u5168\u62df\u5408\uff0c\u800c\u5728\u6d4b\u8bd5\u96c6\u4e0a\u5374\u8868\u73b0\u4e0d\u4f73\u3002\u8fd9\u610f\u5473\u7740\u6a21\u578b\u53ef\u4ee5\u5f88\u597d\u5730\u5b66\u4e60\u8bad\u7ec3\u6570\u636e\uff0c\u4f46\u65e0\u6cd5\u6cdb\u5316\u5230\u672a\u89c1\u8fc7\u7684\u6837\u672c\u4e0a\u3002\u5728\u4e0a\u9762\u7684\u6570\u636e\u96c6\u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u975e\u5e38\u9ad8\u7684\u6a21\u578b\uff0c\u5b83\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u4f1a\u6709\u51fa\u8272\u7684\u7ed3\u679c\uff0c\u4f46\u8fd9\u79cd\u6a21\u578b\u5e76\u4e0d\u5b9e\u7528\uff0c\u56e0\u4e3a\u5b83\u5728\u771f\u5b9e\u4e16\u754c\u7684\u6837\u672c\u6216\u5b9e\u65f6\u6570\u636e\u4e0a\u4e0d\u4f1a\u63d0\u4f9b\u7c7b\u4f3c\u7684\u7ed3\u679c\u3002 
\u6709\u4eba\u53ef\u80fd\u4f1a\u8bf4\uff0c\u8fd9\u79cd\u65b9\u6cd5\u5e76\u6ca1\u6709\u8fc7\u62df\u5408\uff0c\u56e0\u4e3a\u6d4b\u8bd5\u96c6\u7684\u51c6\u786e\u7387\u57fa\u672c\u4fdd\u6301\u4e0d\u53d8\u3002\u8fc7\u62df\u5408\u7684\u53e6\u4e00\u4e2a\u5b9a\u4e49\u662f\uff0c\u5f53\u6211\u4eec\u4e0d\u65ad\u63d0\u9ad8\u8bad\u7ec3\u635f\u5931\u65f6\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u5728\u589e\u52a0\u3002\u8fd9\u79cd\u60c5\u51b5\u5728\u795e\u7ecf\u7f51\u7edc\u4e2d\u975e\u5e38\u5e38\u89c1\u3002 \u6bcf\u5f53\u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u795e\u7ecf\u7f51\u7edc\u65f6\uff0c\u90fd\u5fc5\u987b\u5728\u8bad\u7ec3\u671f\u95f4\u76d1\u63a7\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u3002\u5982\u679c\u6211\u4eec\u6709\u4e00\u4e2a\u975e\u5e38\u5927\u7684\u7f51\u7edc\u6765\u5904\u7406\u4e00\u4e2a\u975e\u5e38\u5c0f\u7684\u6570\u636e\u96c6\uff08\u5373\u6837\u672c\u6570\u975e\u5e38\u5c11\uff09\uff0c\u6211\u4eec\u5c31\u4f1a\u89c2\u5bdf\u5230\uff0c\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u8bad\u7ec3\uff0c\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u90fd\u4f1a\u51cf\u5c11\u3002\u4f46\u662f\uff0c\u5728\u67d0\u4e2a\u65f6\u523b\uff0c\u6d4b\u8bd5\u635f\u5931\u4f1a\u8fbe\u5230\u6700\u5c0f\u503c\uff0c\u4e4b\u540e\uff0c\u5373\u4f7f\u8bad\u7ec3\u635f\u5931\u8fdb\u4e00\u6b65\u51cf\u5c11\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u4f1a\u5f00\u59cb\u589e\u52a0\u3002\u6211\u4eec\u5fc5\u987b\u5728\u9a8c\u8bc1\u635f\u5931\u8fbe\u5230\u6700\u5c0f\u503c\u65f6\u505c\u6b62\u8bad\u7ec3\u3002 \u8fd9\u662f\u5bf9\u8fc7\u62df\u5408\u6700\u5e38\u89c1\u7684\u89e3\u91ca \u3002 
\u5965\u5361\u59c6\u5243\u5200\u7528\u7b80\u5355\u7684\u8bdd\u8bf4\uff0c\u5c31\u662f\u4e0d\u8981\u8bd5\u56fe\u628a\u53ef\u4ee5\u7528\u7b80\u5355\u5f97\u591a\u7684\u65b9\u6cd5\u89e3\u51b3\u7684\u4e8b\u60c5\u590d\u6742\u5316\u3002\u6362\u53e5\u8bdd\u8bf4\uff0c\u6700\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6848\u5c31\u662f\u6700\u5177\u901a\u7528\u6027\u7684\u89e3\u51b3\u65b9\u6848\u3002\u4e00\u822c\u6765\u8bf4\uff0c\u53ea\u8981\u4f60\u7684\u6a21\u578b\u4e0d\u7b26\u5408\u5965\u5361\u59c6\u5243\u5200\u539f\u5219\uff0c\u5c31\u5f88\u53ef\u80fd\u662f\u8fc7\u62df\u5408\u3002 \u56fe 3\uff1a\u8fc7\u62df\u5408\u7684\u6700\u4e00\u822c\u5b9a\u4e49 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u56de\u5230\u4ea4\u53c9\u68c0\u9a8c\u3002 \u5728\u89e3\u91ca\u8fc7\u62df\u5408\u65f6\uff0c\u6211\u51b3\u5b9a\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u6211\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u53e6\u4e00\u90e8\u5206\u4e0a\u68c0\u67e5\u5176\u6027\u80fd\u3002\u8fd9\u4e5f\u662f\u4ea4\u53c9\u68c0\u9a8c\u7684\u4e00\u79cd\uff0c\u901a\u5e38\u88ab\u79f0\u4e3a \"\u6682\u7559\u96c6\"\uff08 hold-out set \uff09\u3002\u5f53\u6211\u4eec\u62e5\u6709\u5927\u91cf\u6570\u636e\uff0c\u800c\u6a21\u578b\u63a8\u7406\u662f\u4e00\u4e2a\u8017\u65f6\u7684\u8fc7\u7a0b\u65f6\uff0c\u6211\u4eec\u5c31\u4f1a\u4f7f\u7528\u8fd9\u79cd\uff08\u4ea4\u53c9\uff09\u9a8c\u8bc1\u3002 \u4ea4\u53c9\u68c0\u9a8c\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\uff0c\u5b83\u662f\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u6b65\u9aa4\u3002 \u9009\u62e9\u6b63\u786e\u7684\u4ea4\u53c9\u68c0\u9a8c 
\u53d6\u51b3\u4e8e\u6240\u5904\u7406\u7684\u6570\u636e\u96c6\uff0c\u5728\u4e00\u4e2a\u6570\u636e\u96c6\u4e0a\u9002\u7528\u7684\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u53ef\u80fd\u4e0d\u9002\u7528\u4e8e\u5176\u4ed6\u6570\u636e\u96c6\u3002\u4e0d\u8fc7\uff0c\u6709\u51e0\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u6700\u4e3a\u6d41\u884c\u548c\u5e7f\u6cdb\u4f7f\u7528\u3002 \u5176\u4e2d\u5305\u62ec\uff1a k \u6298\u4ea4\u53c9\u68c0\u9a8c \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c \u5206\u7ec4 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u4ea4\u53c9\u68c0\u9a8c\u662f\u5c06\u8bad\u7ec3\u6570\u636e\u5206\u5c42\u51e0\u4e2a\u90e8\u5206\uff0c\u6211\u4eec\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u5176\u4f59\u90e8\u5206\u4e0a\u8fdb\u884c\u6d4b\u8bd5\u3002\u8bf7\u770b\u56fe 4\u3002 \u56fe 4\uff1a\u5c06\u6570\u636e\u96c6\u62c6\u5206\u4e3a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u56fe 4 \u548c\u56fe 5 \u8bf4\u660e\uff0c\u5f53\u4f60\u5f97\u5230\u4e00\u4e2a\u6570\u636e\u96c6\u6765\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u65f6\uff0c\u4f60\u4f1a\u628a\u5b83\u4eec\u5206\u6210 \u4e24\u4e2a\u4e0d\u540c\u7684\u96c6\uff1a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u3002\u5f88\u591a\u4eba\u8fd8\u4f1a\u5c06\u5176\u5206\u6210\u7b2c\u4e09\u7ec4\uff0c\u79f0\u4e4b\u4e3a\u6d4b\u8bd5\u96c6\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u5c06\u53ea\u4f7f\u7528\u4e24\u4e2a\u96c6\u3002\u5982\u4f60\u6240\u89c1\uff0c\u6211\u4eec\u5c06\u6837\u672c\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\u8fdb\u884c\u4e86\u5212\u5206\u3002\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a k \u4e2a\u4e92\u4e0d\u5173\u8054\u7684\u4e0d\u540c\u96c6\u5408\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002 \u56fe 5\uff1aK \u6298\u4ea4\u53c9\u68c0\u9a8c \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 KFold 
\u5c06\u4efb\u4f55\u6570\u636e\u5206\u5272\u6210 k \u4e2a\u76f8\u7b49\u7684\u90e8\u5206\u3002\u6bcf\u4e2a\u6837\u672c\u5206\u914d\u4e00\u4e2a\u4ece 0 \u5230 k-1 \u7684\u503c\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u5b58\u50a8\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6211\u4eec\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a kfold \u7684\u65b0\u5217\uff0c\u5e76\u7528 -1 \u586b\u5145 df [ \"kfold\" ] = - 1 # \u63a5\u4e0b\u6765\u7684\u6b65\u9aa4\u662f\u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4ece model_selection \u6a21\u5757\u521d\u59cb\u5316 kfold \u7c7b kf = model_selection . KFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684 kfold \u5217\uff08enumerate\u7684\u4f5c\u7528\u662f\u8fd4\u56de\u4e00\u4e2a\u8fed\u4ee3\u5668\uff09 for fold , ( trn_ , val_ ) in enumerate ( kf . split ( X = df )): df . loc [ val_ , 'kfold' ] = fold # \u4fdd\u5b58\u5e26\u6709 kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . 
to_csv ( \"train_folds.csv\" , index = False ) \u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u636e\u96c6\u90fd\u53ef\u4ee5\u4f7f\u7528\u6b64\u6d41\u7a0b\u3002\u4f8b\u5982\uff0c\u5f53\u6570\u636e\u56fe\u50cf\u65f6\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u56fe\u50cf ID\u3001\u56fe\u50cf\u4f4d\u7f6e\u548c\u56fe\u50cf\u6807\u7b7e\u7684 CSV\uff0c\u7136\u540e\u4f7f\u7528\u4e0a\u8ff0\u6d41\u7a0b\u3002 \u53e6\u4e00\u79cd\u91cd\u8981\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u662f \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u6570\u636e\u96c6\uff0c\u5176\u4e2d\u6b63\u6837\u672c\u5360 90%\uff0c\u8d1f\u6837\u672c\u53ea\u5360 10%\uff0c\u90a3\u4e48\u4f60\u5c31\u4e0d\u5e94\u8be5\u4f7f\u7528\u968f\u673a k \u6298\u4ea4\u53c9\u3002\u5bf9\u8fd9\u6837\u7684\u6570\u636e\u96c6\u4f7f\u7528\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u5bfc\u81f4\u6298\u53e0\u6837\u672c\u5168\u90e8\u4e3a\u8d1f\u6837\u672c\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u66f4\u503e\u5411\u4e8e\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u4ee5\u4fdd\u6301\u6bcf\u4e2a\u6298\u4e2d\u6807\u7b7e\u7684\u6bd4\u4f8b\u4e0d\u53d8\u3002\u56e0\u6b64\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u90fd\u4f1a\u6709\u76f8\u540c\u7684 90% \u6b63\u6837\u672c\u548c 10% \u8d1f\u6837\u672c\u3002\u56e0\u6b64\uff0c\u65e0\u8bba\u60a8\u9009\u62e9\u4ec0\u4e48\u6307\u6807\u8fdb\u884c\u8bc4\u4f30\uff0c\u90fd\u4f1a\u5728\u6240\u6709\u6298\u53e0\u4e2d\u5f97\u5230\u76f8\u4f3c\u7684\u7ed3\u679c\u3002 \u4fee\u6539\u521b\u5efa k \u6298\u4ea4\u53c9\u68c0\u9a8c\u7684\u4ee3\u7801\u4ee5\u521b\u5efa\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5f88\u5bb9\u6613\u3002\u6211\u4eec\u53ea\u9700\u5c06 model_selection.KFold \u66f4\u6539\u4e3a model_selection.StratifiedKFold \uff0c\u5e76\u5728 kf.split(...) 
\u51fd\u6570\u4e2d\u6307\u5b9a\u8981\u5206\u5c42\u7684\u76ee\u6807\u5217\u3002\u6211\u4eec\u5047\u8bbe CSV \u6570\u636e\u96c6\u6709\u4e00\u5217\u540d\u4e3a \"target\" \uff0c\u5e76\u4e14\u662f\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u4fdd\u5b58\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6dfb\u52a0\u4e00\u4e2a\u65b0\u5217 kfold\uff0c\u5e76\u7528 -1 \u521d\u59cb\u5316 df [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u83b7\u53d6\u76ee\u6807\u53d8\u91cf y = df . target . values # \u521d\u59cb\u5316 StratifiedKFold \u7c7b\uff0c\u8bbe\u7f6e\u6298\u6570\uff08folds\uff09\u4e3a 5 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u4f7f\u7528 StratifiedKFold \u5bf9\u8c61\u7684 split \u65b9\u6cd5\u6765\u83b7\u53d6\u8bad\u7ec3\u548c\u9a8c\u8bc1\u7d22\u5f15 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u5305\u542b kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . to_csv ( \"train_folds.csv\" , index = False ) \u5bf9\u4e8e\u8461\u8404\u9152\u6570\u636e\u96c6\uff0c\u6211\u4eec\u6765\u770b\u770b\u6807\u7b7e\u7684\u5206\u5e03\u60c5\u51b5\u3002 b = sns . countplot ( x = 'quality' , data = df ) b . set_xlabel ( \"quality\" , fontsize = 20 ) b . 
set_ylabel ( \"count\" , fontsize = 20 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u7ee7\u7eed\u4e0a\u9762\u7684\u4ee3\u7801\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u8f6c\u6362\u4e86\u76ee\u6807\u503c\u3002\u4ece\u56fe 6 \u4e2d\u6211\u4eec\u53ef\u4ee5\u770b\u51fa\uff0c\u8d28\u91cf\u504f\u5dee\u5f88\u5927\u3002\u6709\u4e9b\u7c7b\u522b\u6709\u5f88\u591a\u6837\u672c\uff0c\u6709\u4e9b\u5219\u6ca1\u6709\u90a3\u4e48\u591a\u3002\u5982\u679c\u6211\u4eec\u8fdb\u884c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u90a3\u4e48\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u76ee\u6807\u503c\u5206\u5e03\u90fd\u4e0d\u4f1a\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u56fe 6\uff1a\u8461\u8404\u9152\u6570\u636e\u96c6\u4e2d \"\u8d28\u91cf\" \u5206\u5e03\u60c5\u51b5 \u89c4\u5219\u5f88\u7b80\u5355\uff0c\u5982\u679c\u662f\u6807\u51c6\u5206\u7c7b\u95ee\u9898\uff0c\u5c31\u76f2\u76ee\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u4f46\u5982\u679c\u6570\u636e\u91cf\u5f88\u5927\uff0c\u8be5\u600e\u4e48\u529e\u5462\uff1f\u5047\u8bbe\u6211\u4eec\u6709 100 \u4e07\u4e2a\u6837\u672c\u30025 \u500d\u4ea4\u53c9\u68c0\u9a8c\u610f\u5473\u7740\u5728 800k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u5728 200k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u9a8c\u8bc1\u3002\u6839\u636e\u6211\u4eec\u9009\u62e9\u7684\u7b97\u6cd5\uff0c\u5bf9\u4e8e\u8fd9\u6837\u89c4\u6a21\u7684\u6570\u636e\u96c6\u6765\u8bf4\uff0c\u8bad\u7ec3\u751a\u81f3\u9a8c\u8bc1\u90fd\u53ef\u80fd\u975e\u5e38\u6602\u8d35\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u3002 \u521b\u5efa\u4fdd\u6301\u7ed3\u679c\u7684\u8fc7\u7a0b\u4e0e\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u76f8\u540c\u3002\u5bf9\u4e8e\u62e5\u6709 100 \u4e07\u4e2a\u6837\u672c\u7684\u6570\u636e\u96c6\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa 10 
\u4e2a\u6298\u53e0\u800c\u4e0d\u662f 5 \u4e2a\uff0c\u5e76\u4fdd\u7559\u5176\u4e2d\u4e00\u4e2a\u6298\u53e0\u4f5c\u4e3a\u4fdd\u7559\u6837\u672c\u3002\u8fd9\u610f\u5473\u7740\uff0c\u6211\u4eec\u5c06\u6709 10 \u4e07\u4e2a\u6837\u672c\u88ab\u4fdd\u7559\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u59cb\u7ec8\u5728\u8fd9\u4e2a\u6837\u672c\u96c6\u4e0a\u8ba1\u7b97\u635f\u5931\u3001\u51c6\u786e\u7387\u548c\u5176\u4ed6\u6307\u6807\uff0c\u5e76\u5728 90 \u4e07\u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\u3002 \u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u6682\u7559\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u975e\u5e38\u5e38\u7528\u3002\u5047\u8bbe\u6211\u4eec\u8981\u89e3\u51b3\u7684\u95ee\u9898\u662f\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97 2020 \u5e74\u7684\u9500\u552e\u989d\uff0c\u800c\u6211\u4eec\u5f97\u5230\u7684\u662f 2015-2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f60\u53ef\u4ee5\u9009\u62e9 2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u4f5c\u4e3a\u4fdd\u7559\u6570\u636e\uff0c\u7136\u540e\u5728 2015 \u5e74\u81f3 2018 \u5e74\u7684\u6240\u6709\u6570\u636e\u4e0a\u8bad\u7ec3\u4f60\u7684\u6a21\u578b\u3002 \u56fe 7\uff1a\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u793a\u4f8b \u5728\u56fe 7 \u6240\u793a\u7684\u793a\u4f8b\u4e2d\uff0c\u5047\u8bbe\u6211\u4eec\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u4ece\u65f6\u95f4\u6b65\u9aa4 31 \u5230 40 \u7684\u9500\u552e\u989d\u3002\u6211\u4eec\u53ef\u4ee5\u4fdd\u7559 21 \u81f3 30 \u6b65\u7684\u6570\u636e\uff0c\u7136\u540e\u4ece 0 \u6b65\u5230 20 \u6b65\u8bad\u7ec3\u6a21\u578b\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5728\u9884\u6d4b 31 \u6b65\u81f3 40 \u6b65\u65f6\uff0c\u5e94\u5c06 21 \u6b65\u81f3 30 \u6b65\u7684\u6570\u636e\u7eb3\u5165\u6a21\u578b\uff0c\u5426\u5219\uff0c\u6a21\u578b\u7684\u6027\u80fd\u5c06\u5927\u6253\u6298\u6263\u3002 
\u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u5fc5\u987b\u5904\u7406\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u800c\u521b\u5efa\u5927\u578b\u9a8c\u8bc1\u96c6\u610f\u5473\u7740\u6a21\u578b\u5b66\u4e60\u4f1a\u4e22\u5931\u5927\u91cf\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9\u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c\uff0c\u76f8\u5f53\u4e8e\u7279\u6b8a\u7684 k \u5219\u4ea4\u53c9\u68c0\u9a8c\u5176\u4e2d k=N \uff0cN \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8fd9\u610f\u5473\u7740\u5728\u6240\u6709\u7684\u8bad\u7ec3\u6298\u53e0\u4e2d\uff0c\u6211\u4eec\u5c06\u5bf9\u9664 1 \u4e4b\u5916\u7684\u6240\u6709\u6570\u636e\u6837\u672c\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u7684\u6298\u53e0\u6570\u4e0e\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u76f8\u540c\u3002 \u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u6a21\u578b\u7684\u901f\u5ea6\u4e0d\u591f\u5feb\uff0c\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\uff0c\u4f46\u7531\u4e8e\u8fd9\u79cd\u4ea4\u53c9\u68c0\u9a8c\u53ea\u9002\u7528\u4e8e\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u56e0\u6b64\u5e76\u4e0d\u91cd\u8981\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u95ee\u9898\u4e86\u3002\u56de\u5f52\u95ee\u9898\u7684\u597d\u5904\u5728\u4e8e\uff0c\u9664\u4e86\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e4b\u5916\uff0c\u6211\u4eec\u53ef\u4ee5\u5728\u56de\u5f52\u95ee\u9898\u4e0a\u4f7f\u7528\u4e0a\u8ff0\u6240\u6709\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u6211\u4eec\u4e0d\u80fd\u76f4\u63a5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u6709\u4e00\u4e9b\u65b9\u6cd5\u53ef\u4ee5\u7a0d\u7a0d\u6539\u53d8\u95ee\u9898\uff0c\u4ece\u800c\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k 
\u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u9002\u7528\u4e8e\u4efb\u4f55\u56de\u5f52\u95ee\u9898\u3002\u4f46\u662f\uff0c\u5982\u679c\u53d1\u73b0\u76ee\u6807\u5206\u5e03\u4e0d\u4e00\u81f4\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u8981\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u5fc5\u987b\u5148\u5c06\u76ee\u6807\u5212\u5206\u4e3a\u82e5\u5e72\u4e2a\u5206\u5c42\uff0c\u7136\u540e\u518d\u4ee5\u5904\u7406\u5206\u7c7b\u95ee\u9898\u7684\u76f8\u540c\u65b9\u5f0f\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u9009\u62e9\u5408\u9002\u7684\u5206\u5c42\u6570\u6709\u51e0\u79cd\u9009\u62e9\u3002\u5982\u679c\u6837\u672c\u91cf\u5f88\u5927\uff08> 10k\uff0c> 100k\uff09\uff0c\u90a3\u4e48\u5c31\u4e0d\u9700\u8981\u8003\u8651\u5206\u5c42\u7684\u6570\u91cf\u3002\u53ea\u9700\u5c06\u6570\u636e\u5206\u4e3a 10 \u6216 20 \u5c42\u5373\u53ef\u3002\u5982\u679c\u6837\u672c\u6570\u4e0d\u591a\uff0c\u5219\u53ef\u4ee5\u4f7f\u7528 Sturge's Rule \u8fd9\u6837\u7684\u7b80\u5355\u89c4\u5219\u6765\u8ba1\u7b97\u9002\u5f53\u7684\u5206\u5c42\u6570\u3002 Sturge's Rule\uff1a \\[ Number of Bins = 1 + log_2(N) \\] \u5176\u4e2d \\(N\\) \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8be5\u51fd\u6570\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u5229\u7528\u65af\u7279\u683c\u6cd5\u5219\u7ed8\u5236\u6837\u672c\u4e0e\u7bb1\u6570\u5bf9\u6bd4\u56fe \u8ba9\u6211\u4eec\u5236\u4f5c\u4e00\u4e2a\u56de\u5f52\u6570\u636e\u96c6\u6837\u672c\uff0c\u5e76\u5c1d\u8bd5\u5e94\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 # stratified-kfold for regression # \u4e3a\u56de\u5f52\u95ee\u9898\u8fdb\u884c\u5206\u5c42K-\u6298\u4ea4\u53c9\u9a8c\u8bc1 # \u5bfc\u5165\u9700\u8981\u7684\u5e93 import numpy as np import pandas as pd from 
sklearn import datasets from sklearn import model_selection # \u521b\u5efa\u5206\u6298\uff08folds\uff09\u7684\u51fd\u6570 def create_folds ( data ): # \u521b\u5efa\u4e00\u4e2a\u65b0\u5217\u53eb\u505akfold\uff0c\u5e76\u7528-1\u6765\u586b\u5145 data [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c data = data . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4f7f\u7528Sturge\u89c4\u5219\u8ba1\u7b97bin\u7684\u6570\u91cf num_bins = int ( np . floor ( 1 + np . log2 ( len ( data )))) # \u4f7f\u7528pandas\u7684cut\u51fd\u6570\u8fdb\u884c\u76ee\u6807\u53d8\u91cf\uff08target\uff09\u7684\u5206\u7bb1 data . loc [:, \"bins\" ] = pd . cut ( data [ \"target\" ], bins = num_bins , labels = False ) # \u521d\u59cb\u5316StratifiedKFold\u7c7b kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684kfold\u5217 # \u6ce8\u610f\uff1a\u6211\u4eec\u4f7f\u7528\u7684\u662fbins\u800c\u4e0d\u662f\u5b9e\u9645\u7684\u76ee\u6807\u53d8\u91cf\uff08target\uff09\uff01 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = data , y = data . bins . values )): data . loc [ v_ , 'kfold' ] = f # \u5220\u9664bins\u5217 data = data . drop ( \"bins\" , axis = 1 ) # \u8fd4\u56de\u5305\u542bfolds\u7684\u6570\u636e return data # \u4e3b\u7a0b\u5e8f\u5f00\u59cb if __name__ == \"__main__\" : # \u521b\u5efa\u4e00\u4e2a\u5e26\u670915000\u4e2a\u6837\u672c\u3001100\u4e2a\u7279\u5f81\u548c1\u4e2a\u76ee\u6807\u53d8\u91cf\u7684\u6837\u672c\u6570\u636e\u96c6 X , y = datasets . make_regression ( n_samples = 15000 , n_features = 100 , n_targets = 1 ) # \u4f7f\u7528numpy\u6570\u7ec4\u521b\u5efa\u4e00\u4e2a\u6570\u636e\u6846 df = pd . DataFrame ( X , columns = [ f \"f_ { i } \" for i in range ( X . shape [ 1 ])] ) df . 
loc [:, \"target\" ] = y # \u521b\u5efafolds df = create_folds ( df ) \u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u7b2c\u4e00\u6b65\uff0c\u4e5f\u662f\u6700\u57fa\u672c\u7684\u4e00\u6b65\u3002\u5982\u679c\u8981\u505a\u7279\u5f81\u5de5\u7a0b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u8981\u5efa\u7acb\u6a21\u578b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u65b9\u6848\uff0c\u5176\u4e2d\u9a8c\u8bc1\u6570\u636e\u80fd\u591f\u4ee3\u8868\u8bad\u7ec3\u6570\u636e\u548c\u771f\u5b9e\u4e16\u754c\u7684\u6570\u636e\uff0c\u90a3\u4e48\u4f60\u5c31\u80fd\u5efa\u7acb\u4e00\u4e2a\u5177\u6709\u9ad8\u5ea6\u901a\u7528\u6027\u7684\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u3002 \u672c\u7ae0\u4ecb\u7ecd\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u6570\u636e\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u6839\u636e\u4f60\u7684\u95ee\u9898\u548c\u6570\u636e\u91c7\u7528\u65b0\u7684\u4ea4\u53c9\u68c0\u9a8c\u5f62\u5f0f\u3002 \u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u95ee\u9898\uff0c\u5e0c\u671b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u4ece\u60a3\u8005\u7684\u76ae\u80a4\u56fe\u50cf\u4e2d\u68c0\u6d4b\u51fa\u76ae\u80a4\u764c\u3002\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u5668\uff0c\u8be5\u5206\u7c7b\u5668\u63a5\u6536\u8f93\u5165\u56fe\u50cf\u5e76\u9884\u6d4b\u5176\u826f\u6027\u6216\u6076\u6027\u7684\u6982\u7387\u3002 
\u5728\u8fd9\u7c7b\u6570\u636e\u96c6\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u96c6\u4e2d\u53ef\u80fd\u6709\u540c\u4e00\u60a3\u8005\u7684\u591a\u5f20\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u8981\u5728\u8fd9\u91cc\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u7cfb\u7edf\uff0c\u5fc5\u987b\u6709\u5206\u5c42\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u4e5f\u5fc5\u987b\u786e\u4fdd\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u60a3\u8005\u4e0d\u4f1a\u51fa\u73b0\u5728\u9a8c\u8bc1\u6570\u636e\u4e2d\u3002\u5e78\u8fd0\u7684\u662f\uff0cscikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u79f0\u4e3a GroupKFold \u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u3002 \u5728\u8fd9\u91cc\uff0c\u60a3\u8005\u53ef\u4ee5\u88ab\u89c6\u4e3a\u7ec4\u3002 \u4f46\u9057\u61be\u7684\u662f\uff0cscikit-learn \u65e0\u6cd5\u5c06 GroupKFold \u4e0e StratifiedKFold \u7ed3\u5408\u8d77\u6765\u3002\u6240\u4ee5\u4f60\u9700\u8981\u81ea\u5df1\u52a8\u624b\u3002\u6211\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u8bfb\u8005\u7684\u7ec3\u4e60\u3002","title":"\u4ea4\u53c9\u68c0\u9a8c"},{"location":"%E4%BA%A4%E5%8F%89%E6%A3%80%E9%AA%8C/#_1","text":"\u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u6ca1\u6709\u5efa\u7acb\u4efb\u4f55\u6a21\u578b\u3002\u539f\u56e0\u5f88\u7b80\u5355\uff0c\u5728\u521b\u5efa\u4efb\u4f55\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u77e5\u9053\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4ee5\u53ca\u5982\u4f55\u6839\u636e\u6570\u636e\u96c6\u9009\u62e9\u6700\u4f73\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e\u96c6\u3002 \u90a3\u4e48\uff0c\u4ec0\u4e48\u662f \u4ea4\u53c9\u68c0\u9a8c \uff0c\u6211\u4eec\u4e3a\u4ec0\u4e48\u8981\u5173\u6ce8\u5b83\uff1f 
\u5173\u4e8e\u4ec0\u4e48\u662f\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u53ef\u4ee5\u627e\u5230\u591a\u79cd\u5b9a\u4e49\u3002\u6211\u7684\u5b9a\u4e49\u53ea\u6709\u4e00\u53e5\u8bdd\uff1a\u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fc7\u7a0b\u4e2d\u7684\u4e00\u4e2a\u6b65\u9aa4\uff0c\u5b83\u53ef\u4ee5\u5e2e\u52a9\u6211\u4eec\u786e\u4fdd\u6a21\u578b\u51c6\u786e\u62df\u5408\u6570\u636e\uff0c\u540c\u65f6\u786e\u4fdd\u6211\u4eec\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u4f46\u8fd9\u53c8\u5f15\u51fa\u4e86\u53e6\u4e00\u4e2a\u8bcd\uff1a \u8fc7\u62df\u5408 \u3002 \u8981\u89e3\u91ca\u8fc7\u62df\u5408\uff0c\u6211\u8ba4\u4e3a\u6700\u597d\u5148\u770b\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6709\u4e00\u4e2a\u76f8\u5f53\u6709\u540d\u7684\u7ea2\u9152\u8d28\u91cf\u6570\u636e\u96c6\uff08 red wine quality dataset \uff09\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 11 \u4e2a\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u8fd9\u4e9b\u7279\u5f81\u51b3\u5b9a\u4e86\u7ea2\u9152\u7684\u8d28\u91cf\u3002 \u8fd9\u4e9b\u5c5e\u6027\u5305\u62ec\uff1a \u56fa\u5b9a\u9178\u5ea6\uff08fixed acidity\uff09 \u6325\u53d1\u6027\u9178\u5ea6\uff08volatile acidity\uff09 \u67e0\u6aac\u9178\uff08citric acid\uff09 \u6b8b\u7559\u7cd6\uff08residual sugar\uff09 \u6c2f\u5316\u7269\uff08chlorides\uff09 \u6e38\u79bb\u4e8c\u6c27\u5316\u786b\uff08free sulfur dioxide\uff09 \u4e8c\u6c27\u5316\u786b\u603b\u91cf\uff08total sulfur dioxide\uff09 \u5bc6\u5ea6\uff08density\uff09 PH \u503c\uff08pH\uff09 \u786b\u9178\u76d0\uff08sulphates\uff09 \u9152\u7cbe\uff08alcohol\uff09 \u6839\u636e\u8fd9\u4e9b\u4e0d\u540c\u7279\u5f81\uff0c\u6211\u4eec\u9700\u8981\u9884\u6d4b\u7ea2\u8461\u8404\u9152\u7684\u8d28\u91cf\uff0c\u8d28\u91cf\u503c\u4ecb\u4e8e 0 \u5230 10 \u4e4b\u95f4\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u6570\u636e\u662f\u600e\u6837\u7684\u3002 import pandas as pd df = pd . 
read_csv ( \"winequality-red.csv\" ) \u56fe 1:\u7ea2\u8461\u8404\u9152\u8d28\u91cf\u6570\u636e\u96c6\u7b80\u5355\u5c55\u793a \u6211\u4eec\u53ef\u4ee5\u5c06\u8fd9\u4e2a\u95ee\u9898\u89c6\u4e3a\u5206\u7c7b\u95ee\u9898\uff0c\u4e5f\u53ef\u4ee5\u89c6\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u9009\u62e9\u5206\u7c7b\u3002\u7136\u800c\uff0c\u8fd9\u4e2a\u6570\u636e\u96c6\u503c\u5305\u542b 6 \u79cd\u8d28\u91cf\u503c\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u6240\u6709\u8d28\u91cf\u503c\u6620\u5c04\u5230 0 \u5230 5 \u4e4b\u95f4\u3002 # \u4e00\u4e2a\u6620\u5c04\u5b57\u5178\uff0c\u7528\u4e8e\u5c06\u8d28\u91cf\u503c\u4ece 0 \u5230 5 \u8fdb\u884c\u6620\u5c04 quality_mapping = { 3 : 0 , 4 : 1 , 5 : 2 , 6 : 3 , 7 : 4 , 8 : 5 } # \u4f60\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 map \u51fd\u6570\u4ee5\u53ca\u4efb\u4f55\u5b57\u5178\uff0c # \u6765\u8f6c\u6362\u7ed9\u5b9a\u5217\u4e2d\u7684\u503c\u4e3a\u5b57\u5178\u4e2d\u7684\u503c df . loc [:, \"quality\" ] = df . quality . 
map ( quality_mapping ) \u5f53\u6211\u4eec\u770b\u5927\u8fd9\u4e9b\u6570\u636e\u5e76\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u65f6\uff0c\u6211\u4eec\u8111\u6d77\u4e2d\u4f1a\u6d6e\u73b0\u51fa\u5f88\u591a\u53ef\u4ee5\u5e94\u7528\u7684\u7b97\u6cd5\uff0c\u4e5f\u8bb8\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u4ece\u4e00\u5f00\u59cb\u5c31\u6df1\u5165\u7814\u7a76\u795e\u7ecf\u7f51\u7edc\uff0c\u90a3\u5c31\u6709\u70b9\u7275\u5f3a\u4e86\u3002\u6240\u4ee5\uff0c\u8ba9\u6211\u4eec\u4ece\u7b80\u5355\u7684\u3001\u6211\u4eec\u4e5f\u80fd\u53ef\u89c6\u5316\u7684\u4e1c\u897f\u5f00\u59cb\uff1a\u51b3\u7b56\u6811\u3002 \u5728\u5f00\u59cb\u4e86\u89e3\u4ec0\u4e48\u662f\u8fc7\u62df\u5408\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u6709 1599 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u4fdd\u7559 1000 \u4e2a\u6837\u672c\u7528\u4e8e\u8bad\u7ec3\uff0c599 \u4e2a\u6837\u672c\u4f5c\u4e3a\u4e00\u4e2a\u5355\u72ec\u7684\u96c6\u5408\u3002 \u4ee5\u4e0b\u4ee3\u7801\u53ef\u4ee5\u8f7b\u677e\u5b8c\u6210\u5212\u5206\uff1a # \u4f7f\u7528 frac=1 \u7684 sample \u65b9\u6cd5\u6765\u6253\u4e71 dataframe # \u7531\u4e8e\u6253\u4e71\u540e\u7d22\u5f15\u4f1a\u6539\u53d8\uff0c\u6240\u4ee5\u6211\u4eec\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u9009\u53d6\u524d 1000 \u884c\u4f5c\u4e3a\u8bad\u7ec3\u6570\u636e df_train = df . head ( 1000 ) # \u9009\u53d6\u6700\u540e\u7684 599 \u884c\u4f5c\u4e3a\u6d4b\u8bd5/\u9a8c\u8bc1\u6570\u636e df_test = df . 
tail ( 599 ) \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u5728\u8bad\u7ec3\u96c6\u4e0a\u4f7f\u7528 scikit-learn \u8bad\u7ec3\u4e00\u4e2a\u51b3\u7b56\u6811\u6a21\u578b\u3002 # \u4ece scikit-learn \u5bfc\u5165\u9700\u8981\u7684\u6a21\u5757 from sklearn import tree from sklearn import metrics # \u521d\u59cb\u5316\u4e00\u4e2a\u51b3\u7b56\u6811\u5206\u7c7b\u5668\uff0c\u8bbe\u7f6e\u6700\u5927\u6df1\u5ea6\u4e3a 3 clf = tree . DecisionTreeClassifier ( max_depth = 3 ) # \u9009\u62e9\u4f60\u60f3\u8981\u8bad\u7ec3\u6a21\u578b\u7684\u5217 # \u8fd9\u4e9b\u5217\u4f5c\u4e3a\u6a21\u578b\u7684\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u4f7f\u7528\u4e4b\u524d\u6620\u5c04\u7684\u8d28\u91cf\u4ee5\u53ca\u63d0\u4f9b\u7684\u7279\u5f81\u6765\u8bad\u7ec3\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5c06\u51b3\u7b56\u6811\u5206\u7c7b\u5668\u7684\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u8bbe\u4e3a 3\u3002\u8be5\u6a21\u578b\u7684\u6240\u6709\u5176\u4ed6\u53c2\u6570\u5747\u4fdd\u6301\u9ed8\u8ba4\u503c\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u6d4b\u8bd5\u8be5\u6a21\u578b\u7684\u51c6\u786e\u6027\uff1a # \u5728\u8bad\u7ec3\u96c6\u4e0a\u751f\u6210\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) # \u5728\u6d4b\u8bd5\u96c6\u4e0a\u751f\u6210\u9884\u6d4b test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) # \u8ba1\u7b97\u6d4b\u8bd5\u6570\u636e\u96c6\u4e0a\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 test_accuracy = metrics . accuracy_score ( df_test . 
quality , test_predictions ) \u8bad\u7ec3\u548c\u6d4b\u8bd5\u7684\u51c6\u786e\u7387\u5206\u522b\u4e3a 58.9%\u548c 54.25%\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5c06\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u589e\u52a0\u5230 7\uff0c\u5e76\u91cd\u590d\u4e0a\u8ff0\u8fc7\u7a0b\u3002\u8fd9\u6837\uff0c\u8bad\u7ec3\u51c6\u786e\u7387\u4e3a 76.6%\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4e3a 57.3%\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4f7f\u7528\u51c6\u786e\u7387\uff0c\u4e3b\u8981\u662f\u56e0\u4e3a\u5b83\u662f\u6700\u76f4\u63a5\u7684\u6307\u6807\u3002\u5bf9\u4e8e\u8fd9\u4e2a\u95ee\u9898\u6765\u8bf4\uff0c\u5b83\u53ef\u80fd\u4e0d\u662f\u6700\u597d\u7684\u6307\u6807\u3002\u6211\u4eec\u53ef\u4ee5\u6839\u636e\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u4e0d\u540c\u503c\u6765\u8ba1\u7b97\u8fd9\u4e9b\u51c6\u786e\u7387\uff0c\u5e76\u7ed8\u5236\u66f2\u7ebf\u56fe\u3002 # \u6ce8\u610f\uff1a\u8fd9\u6bb5\u4ee3\u7801\u5728 Jupyter \u7b14\u8bb0\u672c\u4e2d\u7f16\u5199 # \u5bfc\u5165 scikit-learn \u7684 tree \u548c metrics from sklearn import tree from sklearn import metrics # \u5bfc\u5165 matplotlib \u548c seaborn # \u7528\u4e8e\u7ed8\u56fe import matplotlib import matplotlib.pyplot as plt import seaborn as sns # \u8bbe\u7f6e\u5168\u5c40\u6807\u7b7e\u6587\u672c\u7684\u5927\u5c0f matplotlib . rc ( 'xtick' , labelsize = 20 ) matplotlib . rc ( 'ytick' , labelsize = 20 ) # \u786e\u4fdd\u56fe\u8868\u76f4\u63a5\u5728\u7b14\u8bb0\u672c\u5185\u663e\u793a % matplotlib inline # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6\u7684\u5217\u8868 # \u6211\u4eec\u4ece 50% \u7684\u51c6\u786e\u5ea6\u5f00\u59cb train_accuracies = [ 0.5 ] test_accuracies = [ 0.5 ] # \u904d\u5386\u51e0\u4e2a\u4e0d\u540c\u7684\u6811\u6df1\u5ea6\u503c for depth in range ( 1 , 25 ): # \u521d\u59cb\u5316\u6a21\u578b clf = tree . 
DecisionTreeClassifier ( max_depth = depth ) # \u9009\u62e9\u7528\u4e8e\u8bad\u7ec3\u7684\u5217/\u7279\u5f81 cols = [ 'fixed acidity' , 'volatile acidity' , 'citric acid' , 'residual sugar' , 'chlorides' , 'free sulfur dioxide' , 'total sulfur dioxide' , 'density' , 'pH' , 'sulphates' , 'alcohol' ] # \u5728\u7ed9\u5b9a\u7279\u5f81\u4e0a\u62df\u5408\u6a21\u578b clf . fit ( df_train [ cols ], df_train . quality ) # \u521b\u5efa\u8bad\u7ec3\u548c\u6d4b\u8bd5\u9884\u6d4b train_predictions = clf . predict ( df_train [ cols ]) test_predictions = clf . predict ( df_test [ cols ]) # \u8ba1\u7b97\u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u5ea6 train_accuracy = metrics . accuracy_score ( df_train . quality , train_predictions ) test_accuracy = metrics . accuracy_score ( df_test . quality , test_predictions ) # \u6dfb\u52a0\u51c6\u786e\u5ea6\u5230\u5217\u8868 train_accuracies . append ( train_accuracy ) test_accuracies . append ( test_accuracy ) # \u4f7f\u7528 matplotlib \u548c seaborn \u521b\u5efa\u4e24\u4e2a\u56fe plt . figure ( figsize = ( 10 , 5 )) sns . set_style ( \"whitegrid\" ) plt . plot ( train_accuracies , label = \"train accuracy\" ) plt . plot ( test_accuracies , label = \"test accuracy\" ) plt . legend ( loc = \"upper left\" , prop = { 'size' : 15 }) plt . xticks ( range ( 0 , 26 , 5 )) plt . xlabel ( \"max_depth\" , size = 20 ) plt . ylabel ( \"accuracy\" , size = 20 ) plt . 
show () \u8fd9\u5c06\u751f\u6210\u5982\u56fe 2 \u6240\u793a\u7684\u66f2\u7ebf\u56fe\u3002 \u56fe 2\uff1a\u4e0d\u540c max_depth \u8bad\u7ec3\u548c\u6d4b\u8bd5\u51c6\u786e\u7387\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5f53\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u503c\u4e3a 14 \u65f6\uff0c\u6d4b\u8bd5\u6570\u636e\u7684\u5f97\u5206\u6700\u9ad8\u3002\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u589e\u52a0\u8fd9\u4e2a\u53c2\u6570\u7684\u503c\uff0c\u6d4b\u8bd5\u51c6\u786e\u7387\u4f1a\u4fdd\u6301\u4e0d\u53d8\u6216\u53d8\u5dee\uff0c\u4f46\u8bad\u7ec3\u51c6\u786e\u7387\u4f1a\u4e0d\u65ad\u63d0\u9ad8\u3002\u8fd9\u8bf4\u660e\uff0c\u968f\u7740\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u7684\u589e\u52a0\uff0c\u51b3\u7b56\u6811\u6a21\u578b\u5bf9\u8bad\u7ec3\u6570\u636e\u7684\u5b66\u4e60\u6548\u679c\u8d8a\u6765\u8d8a\u597d\uff0c\u4f46\u6d4b\u8bd5\u6570\u636e\u7684\u6027\u80fd\u5374\u4e1d\u6beb\u6ca1\u6709\u63d0\u9ad8\u3002 \u8fd9\u5c31\u662f\u6240\u8c13\u7684\u8fc7\u62df\u5408 \u3002 \u6a21\u578b\u5728\u8bad\u7ec3\u96c6\u4e0a\u5b8c\u5168\u62df\u5408\uff0c\u800c\u5728\u6d4b\u8bd5\u96c6\u4e0a\u5374\u8868\u73b0\u4e0d\u4f73\u3002\u8fd9\u610f\u5473\u7740\u6a21\u578b\u53ef\u4ee5\u5f88\u597d\u5730\u5b66\u4e60\u8bad\u7ec3\u6570\u636e\uff0c\u4f46\u65e0\u6cd5\u6cdb\u5316\u5230\u672a\u89c1\u8fc7\u7684\u6837\u672c\u4e0a\u3002\u5728\u4e0a\u9762\u7684\u6570\u636e\u96c6\u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u975e\u5e38\u9ad8\u7684\u6a21\u578b\uff0c\u5b83\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u4f1a\u6709\u51fa\u8272\u7684\u7ed3\u679c\uff0c\u4f46\u8fd9\u79cd\u6a21\u578b\u5e76\u4e0d\u5b9e\u7528\uff0c\u56e0\u4e3a\u5b83\u5728\u771f\u5b9e\u4e16\u754c\u7684\u6837\u672c\u6216\u5b9e\u65f6\u6570\u636e\u4e0a\u4e0d\u4f1a\u63d0\u4f9b\u7c7b\u4f3c\u7684\u7ed3\u679c\u3002 
\u6709\u4eba\u53ef\u80fd\u4f1a\u8bf4\uff0c\u8fd9\u79cd\u65b9\u6cd5\u5e76\u6ca1\u6709\u8fc7\u62df\u5408\uff0c\u56e0\u4e3a\u6d4b\u8bd5\u96c6\u7684\u51c6\u786e\u7387\u57fa\u672c\u4fdd\u6301\u4e0d\u53d8\u3002\u8fc7\u62df\u5408\u7684\u53e6\u4e00\u4e2a\u5b9a\u4e49\u662f\uff0c\u5f53\u6211\u4eec\u4e0d\u65ad\u63d0\u9ad8\u8bad\u7ec3\u635f\u5931\u65f6\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u5728\u589e\u52a0\u3002\u8fd9\u79cd\u60c5\u51b5\u5728\u795e\u7ecf\u7f51\u7edc\u4e2d\u975e\u5e38\u5e38\u89c1\u3002 \u6bcf\u5f53\u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u795e\u7ecf\u7f51\u7edc\u65f6\uff0c\u90fd\u5fc5\u987b\u5728\u8bad\u7ec3\u671f\u95f4\u76d1\u63a7\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u3002\u5982\u679c\u6211\u4eec\u6709\u4e00\u4e2a\u975e\u5e38\u5927\u7684\u7f51\u7edc\u6765\u5904\u7406\u4e00\u4e2a\u975e\u5e38\u5c0f\u7684\u6570\u636e\u96c6\uff08\u5373\u6837\u672c\u6570\u975e\u5e38\u5c11\uff09\uff0c\u6211\u4eec\u5c31\u4f1a\u89c2\u5bdf\u5230\uff0c\u968f\u7740\u6211\u4eec\u4e0d\u65ad\u8bad\u7ec3\uff0c\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u7684\u635f\u5931\u90fd\u4f1a\u51cf\u5c11\u3002\u4f46\u662f\uff0c\u5728\u67d0\u4e2a\u65f6\u523b\uff0c\u6d4b\u8bd5\u635f\u5931\u4f1a\u8fbe\u5230\u6700\u5c0f\u503c\uff0c\u4e4b\u540e\uff0c\u5373\u4f7f\u8bad\u7ec3\u635f\u5931\u8fdb\u4e00\u6b65\u51cf\u5c11\uff0c\u6d4b\u8bd5\u635f\u5931\u4e5f\u4f1a\u5f00\u59cb\u589e\u52a0\u3002\u6211\u4eec\u5fc5\u987b\u5728\u9a8c\u8bc1\u635f\u5931\u8fbe\u5230\u6700\u5c0f\u503c\u65f6\u505c\u6b62\u8bad\u7ec3\u3002 \u8fd9\u662f\u5bf9\u8fc7\u62df\u5408\u6700\u5e38\u89c1\u7684\u89e3\u91ca \u3002 
\u5965\u5361\u59c6\u5243\u5200\u7528\u7b80\u5355\u7684\u8bdd\u8bf4\uff0c\u5c31\u662f\u4e0d\u8981\u8bd5\u56fe\u628a\u53ef\u4ee5\u7528\u7b80\u5355\u5f97\u591a\u7684\u65b9\u6cd5\u89e3\u51b3\u7684\u4e8b\u60c5\u590d\u6742\u5316\u3002\u6362\u53e5\u8bdd\u8bf4\uff0c\u6700\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6848\u5c31\u662f\u6700\u5177\u901a\u7528\u6027\u7684\u89e3\u51b3\u65b9\u6848\u3002\u4e00\u822c\u6765\u8bf4\uff0c\u53ea\u8981\u4f60\u7684\u6a21\u578b\u4e0d\u7b26\u5408\u5965\u5361\u59c6\u5243\u5200\u539f\u5219\uff0c\u5c31\u5f88\u53ef\u80fd\u662f\u8fc7\u62df\u5408\u3002 \u56fe 3\uff1a\u8fc7\u62df\u5408\u7684\u6700\u4e00\u822c\u5b9a\u4e49 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u56de\u5230\u4ea4\u53c9\u68c0\u9a8c\u3002 \u5728\u89e3\u91ca\u8fc7\u62df\u5408\u65f6\uff0c\u6211\u51b3\u5b9a\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u6211\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u53e6\u4e00\u90e8\u5206\u4e0a\u68c0\u67e5\u5176\u6027\u80fd\u3002\u8fd9\u4e5f\u662f\u4ea4\u53c9\u68c0\u9a8c\u7684\u4e00\u79cd\uff0c\u901a\u5e38\u88ab\u79f0\u4e3a \"\u6682\u7559\u96c6\"\uff08 hold-out set \uff09\u3002\u5f53\u6211\u4eec\u62e5\u6709\u5927\u91cf\u6570\u636e\uff0c\u800c\u6a21\u578b\u63a8\u7406\u662f\u4e00\u4e2a\u8017\u65f6\u7684\u8fc7\u7a0b\u65f6\uff0c\u6211\u4eec\u5c31\u4f1a\u4f7f\u7528\u8fd9\u79cd\uff08\u4ea4\u53c9\uff09\u9a8c\u8bc1\u3002 \u4ea4\u53c9\u68c0\u9a8c\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\uff0c\u5b83\u662f\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u6b65\u9aa4\u3002 \u9009\u62e9\u6b63\u786e\u7684\u4ea4\u53c9\u68c0\u9a8c 
\u53d6\u51b3\u4e8e\u6240\u5904\u7406\u7684\u6570\u636e\u96c6\uff0c\u5728\u4e00\u4e2a\u6570\u636e\u96c6\u4e0a\u9002\u7528\u7684\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u53ef\u80fd\u4e0d\u9002\u7528\u4e8e\u5176\u4ed6\u6570\u636e\u96c6\u3002\u4e0d\u8fc7\uff0c\u6709\u51e0\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u6700\u4e3a\u6d41\u884c\u548c\u5e7f\u6cdb\u4f7f\u7528\u3002 \u5176\u4e2d\u5305\u62ec\uff1a k \u6298\u4ea4\u53c9\u68c0\u9a8c \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c \u5206\u7ec4 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u4ea4\u53c9\u68c0\u9a8c\u662f\u5c06\u8bad\u7ec3\u6570\u636e\u5206\u5c42\u51e0\u4e2a\u90e8\u5206\uff0c\u6211\u4eec\u5728\u5176\u4e2d\u4e00\u90e8\u5206\u4e0a\u8bad\u7ec3\u6a21\u578b\uff0c\u7136\u540e\u5728\u5176\u4f59\u90e8\u5206\u4e0a\u8fdb\u884c\u6d4b\u8bd5\u3002\u8bf7\u770b\u56fe 4\u3002 \u56fe 4\uff1a\u5c06\u6570\u636e\u96c6\u62c6\u5206\u4e3a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u56fe 4 \u548c\u56fe 5 \u8bf4\u660e\uff0c\u5f53\u4f60\u5f97\u5230\u4e00\u4e2a\u6570\u636e\u96c6\u6765\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u65f6\uff0c\u4f60\u4f1a\u628a\u5b83\u4eec\u5206\u6210 \u4e24\u4e2a\u4e0d\u540c\u7684\u96c6\uff1a\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 \u3002\u5f88\u591a\u4eba\u8fd8\u4f1a\u5c06\u5176\u5206\u6210\u7b2c\u4e09\u7ec4\uff0c\u79f0\u4e4b\u4e3a\u6d4b\u8bd5\u96c6\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u5c06\u53ea\u4f7f\u7528\u4e24\u4e2a\u96c6\u3002\u5982\u4f60\u6240\u89c1\uff0c\u6211\u4eec\u5c06\u6837\u672c\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\u8fdb\u884c\u4e86\u5212\u5206\u3002\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a k \u4e2a\u4e92\u4e0d\u5173\u8054\u7684\u4e0d\u540c\u96c6\u5408\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002 \u56fe 5\uff1aK \u6298\u4ea4\u53c9\u68c0\u9a8c \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 KFold 
\u5c06\u4efb\u4f55\u6570\u636e\u5206\u5272\u6210 k \u4e2a\u76f8\u7b49\u7684\u90e8\u5206\u3002\u6bcf\u4e2a\u6837\u672c\u5206\u914d\u4e00\u4e2a\u4ece 0 \u5230 k-1 \u7684\u503c\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u5b58\u50a8\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6211\u4eec\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a kfold \u7684\u65b0\u5217\uff0c\u5e76\u7528 -1 \u586b\u5145 df [ \"kfold\" ] = - 1 # \u63a5\u4e0b\u6765\u7684\u6b65\u9aa4\u662f\u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4ece model_selection \u6a21\u5757\u521d\u59cb\u5316 kfold \u7c7b kf = model_selection . KFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684 kfold \u5217\uff08enumerate\u7684\u4f5c\u7528\u662f\u8fd4\u56de\u4e00\u4e2a\u8fed\u4ee3\u5668\uff09 for fold , ( trn_ , val_ ) in enumerate ( kf . split ( X = df )): df . loc [ val_ , 'kfold' ] = fold # \u4fdd\u5b58\u5e26\u6709 kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . 
to_csv ( \"train_folds.csv\" , index = False ) \u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u636e\u96c6\u90fd\u53ef\u4ee5\u4f7f\u7528\u6b64\u6d41\u7a0b\u3002\u4f8b\u5982\uff0c\u5f53\u6570\u636e\u56fe\u50cf\u65f6\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u56fe\u50cf ID\u3001\u56fe\u50cf\u4f4d\u7f6e\u548c\u56fe\u50cf\u6807\u7b7e\u7684 CSV\uff0c\u7136\u540e\u4f7f\u7528\u4e0a\u8ff0\u6d41\u7a0b\u3002 \u53e6\u4e00\u79cd\u91cd\u8981\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u662f \u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c \u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u6570\u636e\u96c6\uff0c\u5176\u4e2d\u6b63\u6837\u672c\u5360 90%\uff0c\u8d1f\u6837\u672c\u53ea\u5360 10%\uff0c\u90a3\u4e48\u4f60\u5c31\u4e0d\u5e94\u8be5\u4f7f\u7528\u968f\u673a k \u6298\u4ea4\u53c9\u3002\u5bf9\u8fd9\u6837\u7684\u6570\u636e\u96c6\u4f7f\u7528\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u5bfc\u81f4\u6298\u53e0\u6837\u672c\u5168\u90e8\u4e3a\u8d1f\u6837\u672c\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u66f4\u503e\u5411\u4e8e\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u53ef\u4ee5\u4fdd\u6301\u6bcf\u4e2a\u6298\u4e2d\u6807\u7b7e\u7684\u6bd4\u4f8b\u4e0d\u53d8\u3002\u56e0\u6b64\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u90fd\u4f1a\u6709\u76f8\u540c\u7684 90% \u6b63\u6837\u672c\u548c 10% \u8d1f\u6837\u672c\u3002\u56e0\u6b64\uff0c\u65e0\u8bba\u60a8\u9009\u62e9\u4ec0\u4e48\u6307\u6807\u8fdb\u884c\u8bc4\u4f30\uff0c\u90fd\u4f1a\u5728\u6240\u6709\u6298\u53e0\u4e2d\u5f97\u5230\u76f8\u4f3c\u7684\u7ed3\u679c\u3002 \u4fee\u6539\u521b\u5efa k \u6298\u4ea4\u53c9\u68c0\u9a8c\u7684\u4ee3\u7801\u4ee5\u521b\u5efa\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5f88\u5bb9\u6613\u3002\u6211\u4eec\u53ea\u9700\u5c06 model_selection.KFold \u66f4\u6539\u4e3a model_selection.StratifiedKFold \uff0c\u5e76\u5728 kf.split(...) 
\u51fd\u6570\u4e2d\u6307\u5b9a\u8981\u5206\u5c42\u7684\u76ee\u6807\u5217\u3002\u6211\u4eec\u5047\u8bbe CSV \u6570\u636e\u96c6\u6709\u4e00\u5217\u540d\u4e3a \"target\" \uff0c\u5e76\u4e14\u662f\u4e00\u4e2a\u5206\u7c7b\u95ee\u9898\u3002 # \u5bfc\u5165 pandas \u548c scikit-learn \u7684 model_selection \u6a21\u5757 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bad\u7ec3\u6570\u636e\u4fdd\u5b58\u5728\u540d\u4e3a train.csv \u7684 CSV \u6587\u4ef6\u4e2d df = pd . read_csv ( \"train.csv\" ) # \u6dfb\u52a0\u4e00\u4e2a\u65b0\u5217 kfold\uff0c\u5e76\u7528 -1 \u521d\u59cb\u5316 df [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u884c df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u83b7\u53d6\u76ee\u6807\u53d8\u91cf y = df . target . values # \u521d\u59cb\u5316 StratifiedKFold \u7c7b\uff0c\u8bbe\u7f6e\u6298\u6570\uff08folds\uff09\u4e3a 5 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u4f7f\u7528 StratifiedKFold \u5bf9\u8c61\u7684 split \u65b9\u6cd5\u6765\u83b7\u53d6\u8bad\u7ec3\u548c\u9a8c\u8bc1\u7d22\u5f15 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u5305\u542b kfold \u5217\u7684\u65b0 CSV \u6587\u4ef6 df . to_csv ( \"train_folds.csv\" , index = False ) \u5bf9\u4e8e\u8461\u8404\u9152\u6570\u636e\u96c6\uff0c\u6211\u4eec\u6765\u770b\u770b\u6807\u7b7e\u7684\u5206\u5e03\u60c5\u51b5\u3002 b = sns . countplot ( x = 'quality' , data = df ) b . set_xlabel ( \"quality\" , fontsize = 20 ) b . 
set_ylabel ( \"count\" , fontsize = 20 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u7ee7\u7eed\u4e0a\u9762\u7684\u4ee3\u7801\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u8f6c\u6362\u4e86\u76ee\u6807\u503c\u3002\u4ece\u56fe 6 \u4e2d\u6211\u4eec\u53ef\u4ee5\u770b\u51fa\uff0c\u8d28\u91cf\u504f\u5dee\u5f88\u5927\u3002\u6709\u4e9b\u7c7b\u522b\u6709\u5f88\u591a\u6837\u672c\uff0c\u6709\u4e9b\u5219\u6ca1\u6709\u90a3\u4e48\u591a\u3002\u5982\u679c\u6211\u4eec\u8fdb\u884c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u90a3\u4e48\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u76ee\u6807\u503c\u5206\u5e03\u90fd\u4e0d\u4f1a\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u56fe 6\uff1a\u8461\u8404\u9152\u6570\u636e\u96c6\u4e2d \"\u8d28\u91cf\" \u5206\u5e03\u60c5\u51b5 \u89c4\u5219\u5f88\u7b80\u5355\uff0c\u5982\u679c\u662f\u6807\u51c6\u5206\u7c7b\u95ee\u9898\uff0c\u5c31\u76f2\u76ee\u9009\u62e9\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u4f46\u5982\u679c\u6570\u636e\u91cf\u5f88\u5927\uff0c\u8be5\u600e\u4e48\u529e\u5462\uff1f\u5047\u8bbe\u6211\u4eec\u6709 100 \u4e07\u4e2a\u6837\u672c\u30025 \u500d\u4ea4\u53c9\u68c0\u9a8c\u610f\u5473\u7740\u5728 800k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u5728 200k \u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u9a8c\u8bc1\u3002\u6839\u636e\u6211\u4eec\u9009\u62e9\u7684\u7b97\u6cd5\uff0c\u5bf9\u4e8e\u8fd9\u6837\u89c4\u6a21\u7684\u6570\u636e\u96c6\u6765\u8bf4\uff0c\u8bad\u7ec3\u751a\u81f3\u9a8c\u8bc1\u90fd\u53ef\u80fd\u975e\u5e38\u6602\u8d35\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 \u6682\u7559\u4ea4\u53c9\u68c0\u9a8c \u3002 \u521b\u5efa\u4fdd\u6301\u7ed3\u679c\u7684\u8fc7\u7a0b\u4e0e\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u76f8\u540c\u3002\u5bf9\u4e8e\u62e5\u6709 100 \u4e07\u4e2a\u6837\u672c\u7684\u6570\u636e\u96c6\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa 10 
\u4e2a\u6298\u53e0\u800c\u4e0d\u662f 5 \u4e2a\uff0c\u5e76\u4fdd\u7559\u5176\u4e2d\u4e00\u4e2a\u6298\u53e0\u4f5c\u4e3a\u4fdd\u7559\u6837\u672c\u3002\u8fd9\u610f\u5473\u7740\uff0c\u6211\u4eec\u5c06\u6709 10 \u4e07\u4e2a\u6837\u672c\u88ab\u4fdd\u7559\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u59cb\u7ec8\u5728\u8fd9\u4e2a\u6837\u672c\u96c6\u4e0a\u8ba1\u7b97\u635f\u5931\u3001\u51c6\u786e\u7387\u548c\u5176\u4ed6\u6307\u6807\uff0c\u5e76\u5728 90 \u4e07\u4e2a\u6837\u672c\u4e0a\u8fdb\u884c\u8bad\u7ec3\u3002 \u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u6682\u7559\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u975e\u5e38\u5e38\u7528\u3002\u5047\u8bbe\u6211\u4eec\u8981\u89e3\u51b3\u7684\u95ee\u9898\u662f\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97 2020 \u5e74\u7684\u9500\u552e\u989d\uff0c\u800c\u6211\u4eec\u5f97\u5230\u7684\u662f 2015-2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f60\u53ef\u4ee5\u9009\u62e9 2019 \u5e74\u7684\u6240\u6709\u6570\u636e\u4f5c\u4e3a\u4fdd\u7559\u6570\u636e\uff0c\u7136\u540e\u5728 2015 \u5e74\u81f3 2018 \u5e74\u7684\u6240\u6709\u6570\u636e\u4e0a\u8bad\u7ec3\u4f60\u7684\u6a21\u578b\u3002 \u56fe 7\uff1a\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u793a\u4f8b \u5728\u56fe 7 \u6240\u793a\u7684\u793a\u4f8b\u4e2d\uff0c\u5047\u8bbe\u6211\u4eec\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u4ece\u65f6\u95f4\u6b65\u9aa4 31 \u5230 40 \u7684\u9500\u552e\u989d\u3002\u6211\u4eec\u53ef\u4ee5\u4fdd\u7559 21 \u81f3 30 \u6b65\u7684\u6570\u636e\uff0c\u7136\u540e\u4ece 0 \u6b65\u5230 20 \u6b65\u8bad\u7ec3\u6a21\u578b\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5728\u9884\u6d4b 31 \u6b65\u81f3 40 \u6b65\u65f6\uff0c\u5e94\u5c06 21 \u6b65\u81f3 30 \u6b65\u7684\u6570\u636e\u7eb3\u5165\u6a21\u578b\uff0c\u5426\u5219\uff0c\u6a21\u578b\u7684\u6027\u80fd\u5c06\u5927\u6253\u6298\u6263\u3002 
\u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u5fc5\u987b\u5904\u7406\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u800c\u521b\u5efa\u5927\u578b\u9a8c\u8bc1\u96c6\u610f\u5473\u7740\u6a21\u578b\u5b66\u4e60\u4f1a\u4e22\u5931\u5927\u91cf\u6570\u636e\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9\u7559\u4e00\u4ea4\u53c9\u68c0\u9a8c\uff0c\u76f8\u5f53\u4e8e\u7279\u6b8a\u7684 k \u5219\u4ea4\u53c9\u68c0\u9a8c\u5176\u4e2d k=N \uff0cN \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8fd9\u610f\u5473\u7740\u5728\u6240\u6709\u7684\u8bad\u7ec3\u6298\u53e0\u4e2d\uff0c\u6211\u4eec\u5c06\u5bf9\u9664 1 \u4e4b\u5916\u7684\u6240\u6709\u6570\u636e\u6837\u672c\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u7684\u6298\u53e0\u6570\u4e0e\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u76f8\u540c\u3002 \u9700\u8981\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u6a21\u578b\u7684\u901f\u5ea6\u4e0d\u591f\u5feb\uff0c\u8fd9\u79cd\u7c7b\u578b\u7684\u4ea4\u53c9\u68c0\u9a8c\u53ef\u80fd\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\uff0c\u4f46\u7531\u4e8e\u8fd9\u79cd\u4ea4\u53c9\u68c0\u9a8c\u53ea\u9002\u7528\u4e8e\u5c0f\u578b\u6570\u636e\u96c6\uff0c\u56e0\u6b64\u5e76\u4e0d\u91cd\u8981\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u95ee\u9898\u4e86\u3002\u56de\u5f52\u95ee\u9898\u7684\u597d\u5904\u5728\u4e8e\uff0c\u9664\u4e86\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u4e4b\u5916\uff0c\u6211\u4eec\u53ef\u4ee5\u5728\u56de\u5f52\u95ee\u9898\u4e0a\u4f7f\u7528\u4e0a\u8ff0\u6240\u6709\u4ea4\u53c9\u68c0\u9a8c\u6280\u672f\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u6211\u4eec\u4e0d\u80fd\u76f4\u63a5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u6709\u4e00\u4e9b\u65b9\u6cd5\u53ef\u4ee5\u7a0d\u7a0d\u6539\u53d8\u95ee\u9898\uff0c\u4ece\u800c\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k 
\u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u7b80\u5355\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u9002\u7528\u4e8e\u4efb\u4f55\u56de\u5f52\u95ee\u9898\u3002\u4f46\u662f\uff0c\u5982\u679c\u53d1\u73b0\u76ee\u6807\u5206\u5e03\u4e0d\u4e00\u81f4\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002 \u8981\u5728\u56de\u5f52\u95ee\u9898\u4e2d\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u6211\u4eec\u5fc5\u987b\u5148\u5c06\u76ee\u6807\u5212\u5206\u4e3a\u82e5\u5e72\u4e2a\u5206\u5c42\uff0c\u7136\u540e\u518d\u4ee5\u5904\u7406\u5206\u7c7b\u95ee\u9898\u7684\u76f8\u540c\u65b9\u5f0f\u4f7f\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\u3002\u9009\u62e9\u5408\u9002\u7684\u5206\u5c42\u6570\u6709\u51e0\u79cd\u9009\u62e9\u3002\u5982\u679c\u6837\u672c\u91cf\u5f88\u5927\uff08> 10k\uff0c> 100k\uff09\uff0c\u90a3\u4e48\u5c31\u4e0d\u9700\u8981\u8003\u8651\u5206\u5c42\u7684\u6570\u91cf\u3002\u53ea\u9700\u5c06\u6570\u636e\u5206\u4e3a 10 \u6216 20 \u5c42\u5373\u53ef\u3002\u5982\u679c\u6837\u672c\u6570\u4e0d\u591a\uff0c\u5219\u53ef\u4ee5\u4f7f\u7528 Sturge's Rule \u8fd9\u6837\u7684\u7b80\u5355\u89c4\u5219\u6765\u8ba1\u7b97\u9002\u5f53\u7684\u5206\u5c42\u6570\u3002 Sturge's Rule\uff1a \\[ Number of Bins = 1 + log_2(N) \\] \u5176\u4e2d \\(N\\) \u662f\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u6570\u3002\u8be5\u51fd\u6570\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u5229\u7528\u65af\u7279\u683c\u6cd5\u5219\u7ed8\u5236\u6837\u672c\u4e0e\u7bb1\u6570\u5bf9\u6bd4\u56fe \u8ba9\u6211\u4eec\u5236\u4f5c\u4e00\u4e2a\u56de\u5f52\u6570\u636e\u96c6\u6837\u672c\uff0c\u5e76\u5c1d\u8bd5\u5e94\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 # stratified-kfold for regression # \u4e3a\u56de\u5f52\u95ee\u9898\u8fdb\u884c\u5206\u5c42K-\u6298\u4ea4\u53c9\u9a8c\u8bc1 # \u5bfc\u5165\u9700\u8981\u7684\u5e93 import numpy as np import pandas as pd from 
sklearn import datasets from sklearn import model_selection # \u521b\u5efa\u5206\u6298\uff08folds\uff09\u7684\u51fd\u6570 def create_folds ( data ): # \u521b\u5efa\u4e00\u4e2a\u65b0\u5217\u53eb\u505akfold\uff0c\u5e76\u7528-1\u6765\u586b\u5145 data [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e\u7684\u884c data = data . sample ( frac = 1 ) . reset_index ( drop = True ) # \u4f7f\u7528Sturge\u89c4\u5219\u8ba1\u7b97bin\u7684\u6570\u91cf num_bins = int ( np . floor ( 1 + np . log2 ( len ( data )))) # \u4f7f\u7528pandas\u7684cut\u51fd\u6570\u8fdb\u884c\u76ee\u6807\u53d8\u91cf\uff08target\uff09\u7684\u5206\u7bb1 data . loc [:, \"bins\" ] = pd . cut ( data [ \"target\" ], bins = num_bins , labels = False ) # \u521d\u59cb\u5316StratifiedKFold\u7c7b kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u586b\u5145\u65b0\u7684kfold\u5217 # \u6ce8\u610f\uff1a\u6211\u4eec\u4f7f\u7528\u7684\u662fbins\u800c\u4e0d\u662f\u5b9e\u9645\u7684\u76ee\u6807\u53d8\u91cf\uff08target\uff09\uff01 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = data , y = data . bins . values )): data . loc [ v_ , 'kfold' ] = f # \u5220\u9664bins\u5217 data = data . drop ( \"bins\" , axis = 1 ) # \u8fd4\u56de\u5305\u542bfolds\u7684\u6570\u636e return data # \u4e3b\u7a0b\u5e8f\u5f00\u59cb if __name__ == \"__main__\" : # \u521b\u5efa\u4e00\u4e2a\u5e26\u670915000\u4e2a\u6837\u672c\u3001100\u4e2a\u7279\u5f81\u548c1\u4e2a\u76ee\u6807\u53d8\u91cf\u7684\u6837\u672c\u6570\u636e\u96c6 X , y = datasets . make_regression ( n_samples = 15000 , n_features = 100 , n_targets = 1 ) # \u4f7f\u7528numpy\u6570\u7ec4\u521b\u5efa\u4e00\u4e2a\u6570\u636e\u6846 df = pd . DataFrame ( X , columns = [ f \"f_ { i } \" for i in range ( X . shape [ 1 ])] ) df . 
loc [:, \"target\" ] = y # \u521b\u5efafolds df = create_folds ( df ) \u4ea4\u53c9\u68c0\u9a8c\u662f\u6784\u5efa\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u7b2c\u4e00\u6b65\uff0c\u4e5f\u662f\u6700\u57fa\u672c\u7684\u4e00\u6b65\u3002\u5982\u679c\u8981\u505a\u7279\u5f81\u5de5\u7a0b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u8981\u5efa\u7acb\u6a21\u578b\uff0c\u9996\u5148\u8981\u62c6\u5206\u6570\u636e\u3002\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u65b9\u6848\uff0c\u5176\u4e2d\u9a8c\u8bc1\u6570\u636e\u80fd\u591f\u4ee3\u8868\u8bad\u7ec3\u6570\u636e\u548c\u771f\u5b9e\u4e16\u754c\u7684\u6570\u636e\uff0c\u90a3\u4e48\u4f60\u5c31\u80fd\u5efa\u7acb\u4e00\u4e2a\u5177\u6709\u9ad8\u5ea6\u901a\u7528\u6027\u7684\u597d\u7684\u673a\u5668\u5b66\u4e60\u6a21\u578b\u3002 \u672c\u7ae0\u4ecb\u7ecd\u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u4ea4\u53c9\u68c0\u9a8c\u4e5f\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u6570\u636e\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u6839\u636e\u4f60\u7684\u95ee\u9898\u548c\u6570\u636e\u91c7\u7528\u65b0\u7684\u4ea4\u53c9\u68c0\u9a8c\u5f62\u5f0f\u3002 \u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u95ee\u9898\uff0c\u5e0c\u671b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u4ece\u60a3\u8005\u7684\u76ae\u80a4\u56fe\u50cf\u4e2d\u68c0\u6d4b\u51fa\u76ae\u80a4\u764c\u3002\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u5668\uff0c\u8be5\u5206\u7c7b\u5668\u63a5\u6536\u8f93\u5165\u56fe\u50cf\u5e76\u9884\u6d4b\u5176\u826f\u6027\u6216\u6076\u6027\u7684\u6982\u7387\u3002 
\u5728\u8fd9\u7c7b\u6570\u636e\u96c6\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u96c6\u4e2d\u53ef\u80fd\u6709\u540c\u4e00\u60a3\u8005\u7684\u591a\u5f20\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u8981\u5728\u8fd9\u91cc\u5efa\u7acb\u4e00\u4e2a\u826f\u597d\u7684\u4ea4\u53c9\u68c0\u9a8c\u7cfb\u7edf\uff0c\u5fc5\u987b\u6709\u5206\u5c42\u7684 k \u6298\u4ea4\u53c9\u68c0\u9a8c\uff0c\u4f46\u4e5f\u5fc5\u987b\u786e\u4fdd\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u60a3\u8005\u4e0d\u4f1a\u51fa\u73b0\u5728\u9a8c\u8bc1\u6570\u636e\u4e2d\u3002\u5e78\u8fd0\u7684\u662f\uff0cscikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u79f0\u4e3a GroupKFold \u7684\u4ea4\u53c9\u68c0\u9a8c\u7c7b\u578b\u3002 \u5728\u8fd9\u91cc\uff0c\u60a3\u8005\u53ef\u4ee5\u88ab\u89c6\u4e3a\u7ec4\u3002 \u4f46\u9057\u61be\u7684\u662f\uff0cscikit-learn \u65e0\u6cd5\u5c06 GroupKFold \u4e0e StratifiedKFold \u7ed3\u5408\u8d77\u6765\u3002\u6240\u4ee5\u4f60\u9700\u8981\u81ea\u5df1\u52a8\u624b\u3002\u6211\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u8bfb\u8005\u7684\u7ec3\u4e60\u3002","title":"\u4ea4\u53c9\u68c0\u9a8c"},{"location":"%E5%87%86%E5%A4%87%E7%8E%AF%E5%A2%83/","text":"\u51c6\u5907\u73af\u5883 \u5728\u6211\u4eec\u5f00\u59cb\u7f16\u7a0b\u4e4b\u524d\uff0c\u5728\u4f60\u7684\u673a\u5668\u4e0a\u8bbe\u7f6e\u597d\u4e00\u5207\u662f\u975e\u5e38\u91cd\u8981\u7684\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Ubuntu 18.04 \u548c Python 3.7.6\u3002\u5982\u679c\u4f60\u662f Windows \u7528\u6237\uff0c\u53ef\u4ee5\u901a\u8fc7\u591a\u79cd\u65b9\u5f0f\u5b89\u88c5 Ubuntu\u3002\u4f8b\u5982\uff0c\u5728\u865a\u62df\u673a\u4e0a\u5b89\u88c5\u7531Oracle\u516c\u53f8\u63d0\u4f9b\u7684\u514d\u8d39\u8f6f\u4ef6 Virtual 
Box\u3002\u4e0eWindows\u4e00\u8d77\u4f5c\u4e3a\u53cc\u542f\u52a8\u7cfb\u7edf\u3002\u6211\u66f4\u559c\u6b22\u53cc\u542f\u52a8\uff0c\u56e0\u4e3a\u5b83\u662f\u539f\u751f\u7684\u3002\u5982\u679c\u4f60\u4e0d\u662fUbuntu\u7528\u6237\uff0c\u5728\u4f7f\u7528\u672c\u4e66\u4e2d\u7684\u67d0\u4e9bbash\u811a\u672c\u65f6\u53ef\u80fd\u4f1a\u9047\u5230\u95ee\u9898\u3002\u4e3a\u4e86\u907f\u514d\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u53ef\u4ee5\u5728\u865a\u62df\u673a\u4e2d\u5b89\u88c5Ubuntu\uff0c\u6216\u8005\u5728Windows\u4e0a\u5b89\u88c5Linux shell\u3002 \u7528 Anaconda \u5728\u4efb\u4f55\u673a\u5668\u4e0a\u5b89\u88c5 Python \u90fd\u5f88\u7b80\u5355\u3002\u6211\u7279\u522b\u559c\u6b22 Miniconda \uff0c\u5b83\u662f conda \u7684\u6700\u5c0f\u5b89\u88c5\u7a0b\u5e8f\u3002\u5b83\u9002\u7528\u4e8e Linux\u3001OSX \u548c Windows\u3002\u7531\u4e8e Python 2 \u652f\u6301\u5df2\u4e8e 2019 \u5e74\u5e95\u7ed3\u675f\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Python 3 \u53d1\u884c\u7248\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0cminiconda \u5e76\u4e0d\u50cf\u666e\u901a Anaconda \u9644\u5e26\u6240\u6709\u8f6f\u4ef6\u5305\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u968f\u65f6\u5b89\u88c5\u65b0\u8f6f\u4ef6\u5305\u3002\u5b89\u88c5 miniconda \u975e\u5e38\u7b80\u5355\u3002 \u9996\u5148\u8981\u505a\u7684\u662f\u5c06 Miniconda3 \u4e0b\u8f7d\u5230\u7cfb\u7edf\u4e2d\u3002 cd ~/Downloads wget https://repo.anaconda.com/miniconda/... 
\u5176\u4e2d wget \u547d\u4ee4\u540e\u7684 URL \u662f miniconda3 \u7f51\u9875\u7684 URL\u3002\u5bf9\u4e8e 64 \u4f4d Linux \u7cfb\u7edf\uff0c\u7f16\u5199\u672c\u4e66\u65f6\u7684 URL \u662f https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \u4e0b\u8f7d miniconda3 \u540e\uff0c\u53ef\u4ee5\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\uff1a sh Miniconda3-latest-Linux-x86_64.sh \u63a5\u4e0b\u6765\uff0c\u8bf7\u9605\u8bfb\u5e76\u6309\u7167\u5c4f\u5e55\u4e0a\u7684\u8bf4\u660e\u64cd\u4f5c\u3002\u5982\u679c\u5b89\u88c5\u6b63\u786e\uff0c\u4f60\u5e94\u8be5\u53ef\u4ee5\u901a\u8fc7\u5728\u7ec8\u7aef\u8f93\u5165 conda init \u6765\u542f\u52a8 conda \u73af\u5883\u3002\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u5728\u672c\u4e66\u4e2d\u4e00\u76f4\u4f7f\u7528\u7684 conda \u73af\u5883\u3002\u8981\u521b\u5efa conda \u73af\u5883\uff0c\u53ef\u4ee5\u8f93\u5165\uff1a conda create -n environment_name python = 3 .7.6 \u6b64\u547d\u4ee4\u5c06\u521b\u5efa\u540d\u4e3a environment_name \u7684 conda \u73af\u5883\uff0c\u53ef\u4ee5\u4f7f\u7528\uff1a conda activate environment_name \u73b0\u5728\u6211\u4eec\u7684\u73af\u5883\u5df2\u7ecf\u642d\u5efa\u5b8c\u6bd5\u3002\u662f\u65f6\u5019\u5b89\u88c5\u4e00\u4e9b\u6211\u4eec\u4f1a\u7528\u5230\u7684\u8f6f\u4ef6\u5305\u4e86\u3002\u5728 conda \u73af\u5883\u4e2d\uff0c\u5b89\u88c5\u8f6f\u4ef6\u5305\u6709\u4e24\u79cd\u4e0d\u540c\u7684\u65b9\u5f0f\u3002 \u4f60\u53ef\u4ee5\u4ece conda \u4ed3\u5e93\u6216 PyPi \u5b98\u65b9\u4ed3\u5e93\u5b89\u88c5\u8f6f\u4ef6\u5305\u3002 conda/pip install package_name \u6ce8\u610f\uff1a\u67d0\u4e9b\u8f6f\u4ef6\u5305\u53ef\u80fd\u65e0\u6cd5\u5728 conda \u8f6f\u4ef6\u4ed3\u5e93\u4e2d\u627e\u5230\u3002\u56e0\u6b64\uff0c\u5728\u672c\u4e66\u4e2d\uff0c\u4f7f\u7528 pip \u5b89\u88c5\u662f\u6700\u53ef\u53d6\u7684\u65b9\u6cd5\u3002\u6211\u5df2\u7ecf\u521b\u5efa\u4e86\u4e00\u4e2a\u7f16\u5199\u672c\u4e66\u65f6\u4f7f\u7528\u7684\u8f6f\u4ef6\u5305\u5217\u8868\uff0c\u4fdd\u5b58\u5728 environment.yml \u4e2d\u3002 
\u4f60\u53ef\u4ee5\u5728\u6211\u7684 GitHub \u4ed3\u5e93\u4e2d\u7684\u989d\u5916\u8d44\u6599\u4e2d\u627e\u5230\u5b83\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u521b\u5efa\u73af\u5883\uff1a conda env create -f environment.yml \u8be5\u547d\u4ee4\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a ml \u7684\u73af\u5883\u3002\u8981\u6fc0\u6d3b\u8be5\u73af\u5883\u5e76\u5f00\u59cb\u4f7f\u7528\uff0c\u5e94\u8fd0\u884c\uff1a conda activate ml \u73b0\u5728\u6211\u4eec\u5df2\u7ecf\u51c6\u5907\u5c31\u7eea\uff0c\u53ef\u4ee5\u8fdb\u884c\u4e00\u4e9b\u5e94\u7528\u673a\u5668\u5b66\u4e60\u7684\u5de5\u4f5c\u4e86\uff01\u5728\u4f7f\u7528\u672c\u4e66\u8fdb\u884c\u7f16\u7801\u65f6\uff0c\u8bf7\u59cb\u7ec8\u8bb0\u4f4f\u8981\u5728 \"ml \"\u73af\u5883\u4e0b\u8fdb\u884c\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5f00\u59cb\u5b66\u4e60\u771f\u6b63\u7684\u7b2c\u4e00\u7ae0\u3002","title":"\u51c6\u5907\u73af\u5883"},{"location":"%E5%87%86%E5%A4%87%E7%8E%AF%E5%A2%83/#_1","text":"\u5728\u6211\u4eec\u5f00\u59cb\u7f16\u7a0b\u4e4b\u524d\uff0c\u5728\u4f60\u7684\u673a\u5668\u4e0a\u8bbe\u7f6e\u597d\u4e00\u5207\u662f\u975e\u5e38\u91cd\u8981\u7684\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Ubuntu 18.04 \u548c Python 3.7.6\u3002\u5982\u679c\u4f60\u662f Windows \u7528\u6237\uff0c\u53ef\u4ee5\u901a\u8fc7\u591a\u79cd\u65b9\u5f0f\u5b89\u88c5 Ubuntu\u3002\u4f8b\u5982\uff0c\u5728\u865a\u62df\u673a\u4e0a\u5b89\u88c5\u7531Oracle\u516c\u53f8\u63d0\u4f9b\u7684\u514d\u8d39\u8f6f\u4ef6 Virtual 
Box\u3002\u4e0eWindows\u4e00\u8d77\u4f5c\u4e3a\u53cc\u542f\u52a8\u7cfb\u7edf\u3002\u6211\u66f4\u559c\u6b22\u53cc\u542f\u52a8\uff0c\u56e0\u4e3a\u5b83\u662f\u539f\u751f\u7684\u3002\u5982\u679c\u4f60\u4e0d\u662fUbuntu\u7528\u6237\uff0c\u5728\u4f7f\u7528\u672c\u4e66\u4e2d\u7684\u67d0\u4e9bbash\u811a\u672c\u65f6\u53ef\u80fd\u4f1a\u9047\u5230\u95ee\u9898\u3002\u4e3a\u4e86\u907f\u514d\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u53ef\u4ee5\u5728\u865a\u62df\u673a\u4e2d\u5b89\u88c5Ubuntu\uff0c\u6216\u8005\u5728Windows\u4e0a\u5b89\u88c5Linux shell\u3002 \u7528 Anaconda \u5728\u4efb\u4f55\u673a\u5668\u4e0a\u5b89\u88c5 Python \u90fd\u5f88\u7b80\u5355\u3002\u6211\u7279\u522b\u559c\u6b22 Miniconda \uff0c\u5b83\u662f conda \u7684\u6700\u5c0f\u5b89\u88c5\u7a0b\u5e8f\u3002\u5b83\u9002\u7528\u4e8e Linux\u3001OSX \u548c Windows\u3002\u7531\u4e8e Python 2 \u652f\u6301\u5df2\u4e8e 2019 \u5e74\u5e95\u7ed3\u675f\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 Python 3 \u53d1\u884c\u7248\u3002\u9700\u8981\u6ce8\u610f\u7684\u662f\uff0cminiconda \u5e76\u4e0d\u50cf\u666e\u901a Anaconda \u9644\u5e26\u6240\u6709\u8f6f\u4ef6\u5305\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u968f\u65f6\u5b89\u88c5\u65b0\u8f6f\u4ef6\u5305\u3002\u5b89\u88c5 miniconda \u975e\u5e38\u7b80\u5355\u3002 \u9996\u5148\u8981\u505a\u7684\u662f\u5c06 Miniconda3 \u4e0b\u8f7d\u5230\u7cfb\u7edf\u4e2d\u3002 cd ~/Downloads wget https://repo.anaconda.com/miniconda/... 
\u5176\u4e2d wget \u547d\u4ee4\u540e\u7684 URL \u662f miniconda3 \u7f51\u9875\u7684 URL\u3002\u5bf9\u4e8e 64 \u4f4d Linux \u7cfb\u7edf\uff0c\u7f16\u5199\u672c\u4e66\u65f6\u7684 URL \u662f https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \u4e0b\u8f7d miniconda3 \u540e\uff0c\u53ef\u4ee5\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\uff1a sh Miniconda3-latest-Linux-x86_64.sh \u63a5\u4e0b\u6765\uff0c\u8bf7\u9605\u8bfb\u5e76\u6309\u7167\u5c4f\u5e55\u4e0a\u7684\u8bf4\u660e\u64cd\u4f5c\u3002\u5982\u679c\u5b89\u88c5\u6b63\u786e\uff0c\u4f60\u5e94\u8be5\u53ef\u4ee5\u901a\u8fc7\u5728\u7ec8\u7aef\u8f93\u5165 conda init \u6765\u542f\u52a8 conda \u73af\u5883\u3002\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u5728\u672c\u4e66\u4e2d\u4e00\u76f4\u4f7f\u7528\u7684 conda \u73af\u5883\u3002\u8981\u521b\u5efa conda \u73af\u5883\uff0c\u53ef\u4ee5\u8f93\u5165\uff1a conda create -n environment_name python = 3 .7.6 \u6b64\u547d\u4ee4\u5c06\u521b\u5efa\u540d\u4e3a environment_name \u7684 conda \u73af\u5883\uff0c\u53ef\u4ee5\u4f7f\u7528\uff1a conda activate environment_name \u73b0\u5728\u6211\u4eec\u7684\u73af\u5883\u5df2\u7ecf\u642d\u5efa\u5b8c\u6bd5\u3002\u662f\u65f6\u5019\u5b89\u88c5\u4e00\u4e9b\u6211\u4eec\u4f1a\u7528\u5230\u7684\u8f6f\u4ef6\u5305\u4e86\u3002\u5728 conda \u73af\u5883\u4e2d\uff0c\u5b89\u88c5\u8f6f\u4ef6\u5305\u6709\u4e24\u79cd\u4e0d\u540c\u7684\u65b9\u5f0f\u3002 \u4f60\u53ef\u4ee5\u4ece conda \u4ed3\u5e93\u6216 PyPi \u5b98\u65b9\u4ed3\u5e93\u5b89\u88c5\u8f6f\u4ef6\u5305\u3002 conda/pip install package_name \u6ce8\u610f\uff1a\u67d0\u4e9b\u8f6f\u4ef6\u5305\u53ef\u80fd\u65e0\u6cd5\u5728 conda \u8f6f\u4ef6\u4ed3\u5e93\u4e2d\u627e\u5230\u3002\u56e0\u6b64\uff0c\u5728\u672c\u4e66\u4e2d\uff0c\u4f7f\u7528 pip \u5b89\u88c5\u662f\u6700\u53ef\u53d6\u7684\u65b9\u6cd5\u3002\u6211\u5df2\u7ecf\u521b\u5efa\u4e86\u4e00\u4e2a\u7f16\u5199\u672c\u4e66\u65f6\u4f7f\u7528\u7684\u8f6f\u4ef6\u5305\u5217\u8868\uff0c\u4fdd\u5b58\u5728 environment.yml \u4e2d\u3002 
\u4f60\u53ef\u4ee5\u5728\u6211\u7684 GitHub \u4ed3\u5e93\u4e2d\u7684\u989d\u5916\u8d44\u6599\u4e2d\u627e\u5230\u5b83\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u521b\u5efa\u73af\u5883\uff1a conda env create -f environment.yml \u8be5\u547d\u4ee4\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a ml \u7684\u73af\u5883\u3002\u8981\u6fc0\u6d3b\u8be5\u73af\u5883\u5e76\u5f00\u59cb\u4f7f\u7528\uff0c\u5e94\u8fd0\u884c\uff1a conda activate ml \u73b0\u5728\u6211\u4eec\u5df2\u7ecf\u51c6\u5907\u5c31\u7eea\uff0c\u53ef\u4ee5\u8fdb\u884c\u4e00\u4e9b\u5e94\u7528\u673a\u5668\u5b66\u4e60\u7684\u5de5\u4f5c\u4e86\uff01\u5728\u4f7f\u7528\u672c\u4e66\u8fdb\u884c\u7f16\u7801\u65f6\uff0c\u8bf7\u59cb\u7ec8\u8bb0\u4f4f\u8981\u5728 \"ml \"\u73af\u5883\u4e0b\u8fdb\u884c\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5f00\u59cb\u5b66\u4e60\u771f\u6b63\u7684\u7b2c\u4e00\u7ae0\u3002","title":"\u51c6\u5907\u73af\u5883"},{"location":"%E5%8F%AF%E9%87%8D%E5%A4%8D%E4%BB%A3%E7%A0%81%E5%92%8C%E6%A8%A1%E5%9E%8B%E6%96%B9%E6%B3%95/","text":"\u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5 \u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u5230\u4e86\u53ef\u4ee5\u5c06\u6a21\u578b/\u8bad\u7ec3\u4ee3\u7801\u5206\u53d1\u7ed9\u4ed6\u4eba\u4f7f\u7528\u7684\u9636\u6bb5\u3002\u60a8\u53ef\u4ee5\u7528\u8f6f\u76d8\u5206\u53d1\u6216\u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\uff0c\u4f46\u8fd9\u5e76\u4e0d\u7406\u60f3\u3002\u662f\u8fd9\u6837\u5417\uff1f\u4e5f\u8bb8\u5f88\u591a\u5e74\u524d\uff0c\u8fd9\u662f\u7406\u60f3\u7684\u505a\u6cd5\uff0c\u4f46\u73b0\u5728\u4e0d\u662f\u4e86\u3002 \u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\u548c\u534f\u4f5c\u7684\u9996\u9009\u65b9\u5f0f\u662f\u4f7f\u7528\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u3002Git \u662f\u6700\u6d41\u884c\u7684\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u4e4b\u4e00\u3002\u90a3\u4e48\uff0c\u5047\u8bbe\u4f60\u5df2\u7ecf\u5b66\u4f1a\u4e86 
Git\uff0c\u5e76\u6b63\u786e\u5730\u683c\u5f0f\u5316\u4e86\u4ee3\u7801\uff0c\u7f16\u5199\u4e86\u9002\u5f53\u7684\u6587\u6863\uff0c\u8fd8\u5f00\u6e90\u4e86\u4f60\u7684\u9879\u76ee\u3002\u8fd9\u5c31\u591f\u4e86\u5417\uff1f\u4e0d\uff0c\u8fd8\u4e0d\u591f\u3002\u56e0\u4e3a\u4f60\u5728\u81ea\u5df1\u7684\u7535\u8111\u4e0a\u5199\u7684\u4ee3\u7801\uff0c\u5728\u522b\u4eba\u7684\u7535\u8111\u4e0a\u53ef\u80fd\u4f1a\u56e0\u4e3a\u5404\u79cd\u539f\u56e0\u800c\u65e0\u6cd5\u8fd0\u884c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u5728\u53d1\u5e03\u4ee3\u7801\u65f6\u80fd\u590d\u5236\u81ea\u5df1\u7684\u7535\u8111\uff0c\u800c\u5176\u4ed6\u4eba\u5728\u5b89\u88c5\u60a8\u7684\u8f6f\u4ef6\u6216\u8fd0\u884c\u60a8\u7684\u4ee3\u7801\u65f6\u4e5f\u80fd\u590d\u5236\u60a8\u7684\u7535\u8111\uff0c\u90a3\u5c31\u518d\u597d\u4e0d\u8fc7\u4e86\u3002\u4e3a\u6b64\uff0c\u5982\u4eca\u6700\u6d41\u884c\u7684\u65b9\u6cd5\u662f\u4f7f\u7528 Docker \u5bb9\u5668\uff08Docker Containers\uff09\u3002\u8981\u4f7f\u7528 Docker \u5bb9\u5668\uff0c\u4f60\u9700\u8981\u5b89\u88c5 Docker\u3002 \u8ba9\u6211\u4eec\u7528\u4e0b\u9762\u7684\u547d\u4ee4\u6765\u5b89\u88c5 Docker\u3002 sudo apt install docker.io sudo systemctl start docker sudo systemctl enable docker sudo groupadd docker sudo usermod -aG docker $USER \u8fd9\u4e9b\u547d\u4ee4\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e0a\u8fd0\u884c\u3002Docker \u6700\u68d2\u7684\u5730\u65b9\u5728\u4e8e\u5b83\u53ef\u4ee5\u5b89\u88c5\u5728\u4efb\u4f55\u673a\u5668\u4e0a\uff1a Linux\u3001Windows\u3001OSX\u3002\u56e0\u6b64\uff0c\u5982\u679c\u4f60\u4e00\u76f4\u5728 Docker \u5bb9\u5668\u4e2d\u5de5\u4f5c\uff0c\u54ea\u53f0\u673a\u5668\u90fd\u6ca1\u5173\u7cfb\uff01 Docker 
\u5bb9\u5668\u53ef\u4ee5\u88ab\u89c6\u4e3a\u5c0f\u578b\u865a\u62df\u673a\u3002\u4f60\u53ef\u4ee5\u4e3a\u4f60\u7684\u4ee3\u7801\u521b\u5efa\u4e00\u4e2a\u5bb9\u5668\uff0c\u7136\u540e\u6bcf\u4e2a\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u548c\u8bbf\u95ee\u5b83\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u521b\u5efa\u53ef\u7528\u4e8e\u8bad\u7ec3\u6a21\u578b\u7684\u5bb9\u5668\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e00\u7ae0\u4e2d\u8bad\u7ec3\u7684 BERT \u6a21\u578b\uff0c\u5e76\u5c1d\u8bd5\u5c06\u8bad\u7ec3\u4ee3\u7801\u5bb9\u5668\u5316\u3002 \u9996\u5148\uff0c\u4f60\u9700\u8981\u4e00\u4e2a\u5305\u542b python \u9879\u76ee\u9700\u6c42\u7684\u6587\u4ef6\u3002\u9700\u6c42\u5305\u542b\u5728\u540d\u4e3a requirements.txt \u7684\u6587\u4ef6\u4e2d\u3002\u6587\u4ef6\u540d\u662f thestandard\u3002\u8be5\u6587\u4ef6\u5305\u542b\u9879\u76ee\u4e2d\u4f7f\u7528\u7684\u6240\u6709 python \u5e93\u3002\u4e5f\u5c31\u662f\u53ef\u4ee5\u901a\u8fc7 PyPI (pip) \u4e0b\u8f7d\u7684 python \u5e93\u3002\u7528\u4e8e \u8bad\u7ec3 BERT \u6a21\u578b\u4ee5\u68c0\u6d4b\u6b63/\u8d1f\u60c5\u611f\uff0c\u6211\u4eec\u4f7f\u7528\u4e86 torch\u3001transformers\u3001tqdm\u3001scikit-learn\u3001pandas \u548c numpy\u3002 \u8ba9\u6211\u4eec\u628a\u5b83\u4eec\u5199\u5165 requirements.txt \u4e2d\u3002\u4f60\u53ef\u4ee5\u53ea\u5199\u540d\u79f0\uff0c\u4e5f\u53ef\u4ee5\u5305\u62ec\u7248\u672c\u3002\u5305\u542b\u7248\u672c\u603b\u662f\u6700\u597d\u7684\uff0c\u8fd9\u4e5f\u662f\u4f60\u5e94\u8be5\u505a\u7684\u3002\u5305\u542b\u7248\u672c\u540e\uff0c\u53ef\u4ee5\u786e\u4fdd\u5176\u4ed6\u4eba\u4f7f\u7528\u7684\u7248\u672c\u4e0e\u4f60\u7684\u7248\u672c\u76f8\u540c\uff0c\u800c\u4e0d\u662f\u6700\u65b0\u7248\u672c\uff0c\u56e0\u4e3a\u6700\u65b0\u7248\u672c\u53ef\u80fd\u4f1a\u66f4\u6539\u67d0\u4e9b\u5185\u5bb9\uff0c\u5982\u679c\u662f\u8fd9\u6837\u7684\u8bdd\uff0c\u6a21\u578b\u7684\u8bad\u7ec3\u65b9\u5f0f\u5c31\u4e0d\u4f1a\u4e0e\u4f60\u7684\u76f8\u540c\u4e86\u3002 
\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u663e\u793a\u4e86 requirements.txt\u3002 # requirements.txt pandas == 1.0.4 scikit - learn == 0.22.1 torch == 1.5.0 transformers == 2.11.0 \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a Dockerfile \u7684 Docker \u6587\u4ef6\u3002\u6ca1\u6709\u6269\u5c55\u540d\u3002Dockerfile \u6709\u51e0\u4e2a\u5143\u7d20\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 # Dockerfile # First of all, we include where we are getting the image # from. Image can be thought of as an operating system. # You can do \"FROM ubuntu:18.04\" # this will start from a clean ubuntu 18.04 image. # All images are downloaded from dockerhub # Here are we grabbing image from nvidia's repo # they created a docker image using ubuntu 18.04 # and installed cuda 10.1 and cudnn7 in it. Thus, we don't have to # install it. Makes our life easy. FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 # this is the same apt-get command that you are used to # except the fact that, we have -y argument. Its because # when we build this container, we cannot press Y when asked for RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* # We add a new user called \"abhishek\" # this can be anything. Anything you want it # to be. Usually, we don't use our own name, # you can use \"user\" or \"ubuntu\" RUN useradd -m abhishek # make our user own its own home directory RUN chown -R abhishek:abhishek /home/abhishek/ # copy all files from this direrctory to a # directory called app inside the home of abhishek # and abhishek owns it. 
COPY --chown = abhishek *.* /home/abhishek/app/ # change to user abhishek USER abhishek RUN mkdir /home/abhishek/data/ # Now we install all the requirements # after moving to the app directory # PLEASE NOTE that ubuntu 18.04 image # has python 3.6.9 and not python 3.7.6 # you can also install conda python here and use that # however, to simplify it, I will be using python 3.6.9 # inside the docker container!!!! RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt # install mkl. its needed for transformers RUN pip3 install mkl # when we log into the docker container, # we will go inside this directory automatically WORKDIR /home/abhishek/app \u521b\u5efa\u597d Docker \u6587\u4ef6\u540e\uff0c\u6211\u4eec\u5c31\u9700\u8981\u6784\u5efa\u5b83\u3002\u6784\u5efa Docker \u5bb9\u5668\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u547d\u4ee4\u3002 docker build -f Dockerfile -t bert:train . \u8be5\u547d\u4ee4\u6839\u636e\u63d0\u4f9b\u7684 Dockerfile \u6784\u5efa\u4e00\u4e2a\u5bb9\u5668\u3002Docker \u5bb9\u5668\u7684\u540d\u79f0\u662f bert:train\u3002\u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a \u276f docker build -f Dockerfile -t bert:train . Sending build context to Docker daemon 19.97kB Step 1/7 : FROM nvidia/cuda:10.1-cudnn7-ubuntu18.04 ---> 3b55548ae91f Step 2/7 : RUN apt-get update && apt-get install -y git curl ca- certificates python3 python3-pip sudo && rm -rf /var/lib/apt/lists/* . . . . 
Removing intermediate container 8f6975dd08ba ---> d1802ac9f1b4 Step 7/7 : WORKDIR /home/abhishek/app ---> Running in 257ff09502ed Removing intermediate container 257ff09502ed ---> e5f6eb4cddd7 Successfully built e5f6eb4cddd7 Successfully tagged bert:train \u8bf7\u6ce8\u610f\uff0c\u6211\u5220\u9664\u4e86\u8f93\u51fa\u4e2d\u7684\u8bb8\u591a\u884c\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u767b\u5f55\u5bb9\u5668\u3002 docker run -ti bert:train /bin/bash \u4f60\u9700\u8981\u8bb0\u4f4f\uff0c\u4e00\u65e6\u9000\u51fa shell\uff0c\u4f60\u5728 shell \u4e2d\u6240\u505a\u7684\u4e00\u5207\u90fd\u5c06\u4e22\u5931\u3002\u4f60\u8fd8\u53ef\u4ee5\u5728 Docker \u5bb9\u5668\u4e2d\u4f7f\u7528\u3002 docker run -ti bert:train python3 train.py \u8f93\u51fa\u60c5\u51b5\uff1a Traceback (most recent call last): File \"train.py\", line 2, in import config File \"/home/abhishek/app/config.py\", line 28, in do_lower_case=True File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 393, in from_pretrained return cls._from_pretrained(*inputs, **kwargs) File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 496, in _from_pretrained list(cls.vocab_files_names.values()), OSError: Model name '../input/bert_base_uncased/' was not found in tokenizers model name list (bert-base-uncased, bert-large-uncased, bert- base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base- multilingual-cased, bert-base-chinese, bert-base-german-cased, bert- large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased- whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert- base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base- finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). 
We assumed '../input/bert_base_uncased/' was a path, a model identifier, or url to a directory containing vocabulary files named ['vocab.txt'] but couldn't find such vocabulary files at this path or url. \u54ce\u5440\uff0c\u51fa\u9519\u4e86\uff01 \u6211\u4e3a\u4ec0\u4e48\u8981\u628a\u9519\u8bef\u5370\u5728\u4e66\u4e0a\u5462\uff1f \u56e0\u4e3a\u7406\u89e3\u8fd9\u4e2a\u9519\u8bef\u975e\u5e38\u91cd\u8981\u3002\u8fd9\u4e2a\u9519\u8bef\u8bf4\u660e\u4ee3\u7801\u65e0\u6cd5\u627e\u5230\u76ee\u5f55\".../input/bert_base_cased\"\u3002\u4e3a\u4ec0\u4e48\u4f1a\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u5462\uff1f\u6211\u4eec\u53ef\u4ee5\u5728\u6ca1\u6709 Docker \u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u8bad\u7ec3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u76ee\u5f55\u548c\u6240\u6709\u6587\u4ef6\u90fd\u5b58\u5728\u3002\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u662f\u56e0\u4e3a Docker \u5c31\u50cf\u4e00\u4e2a\u865a\u62df\u673a\uff01\u5b83\u6709\u81ea\u5df1\u7684\u6587\u4ef6\u7cfb\u7edf\uff0c\u672c\u5730\u673a\u5668\u4e0a\u7684\u6587\u4ef6\u4e0d\u4f1a\u5171\u4eab\u7ed9 Docker \u5bb9\u5668\u3002\u5982\u679c\u4f60\u60f3\u4f7f\u7528\u672c\u5730\u673a\u5668\u4e0a\u7684\u8def\u5f84\u5e76\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\uff0c\u4f60\u9700\u8981\u5728\u8fd0\u884c Docker \u65f6\u5c06\u5176\u6302\u8f7d\u5230 Docker \u5bb9\u5668\u4e0a\u3002\u5f53\u6211\u4eec\u67e5\u770b\u8fd9\u4e2a\u6587\u4ef6\u5939\u7684\u8def\u5f84\u65f6\uff0c\u6211\u4eec\u77e5\u9053\u5b83\u4f4d\u4e8e\u540d\u4e3a input \u7684\u6587\u4ef6\u5939\u7684\u4e0a\u4e00\u7ea7\u3002\u8ba9\u6211\u4eec\u7a0d\u5fae\u4fee\u6539\u4e00\u4e0b config.py \u6587\u4ef6\uff01 # config.py import os import transformers # fetch home directory # in our docker container, it is # /home/abhishek HOME_DIR = os . path . expanduser ( \"~\" ) # this is the maximum number of tokens in the sentence MAX_LEN = 512 # batch sizes is low because model is huge! 
TRAIN_BATCH_SIZE = 8 VALID_BATCH_SIZE = 4 # let's train for a maximum of 10 epochs EPOCHS = 10 # define path to BERT model files # Now we assume that all the data is stored inside # /home/abhishek/data BERT_PATH = os . path . join ( HOME_DIR , \"data\" , \"bert_base_uncased\" ) # this is where you want to save the model MODEL_PATH = os . path . join ( HOME_DIR , \"data\" , \"model.bin\" ) # training file TRAINING_FILE = os . path . join ( HOME_DIR , \"data\" , \"imdb.csv\" ) TOKENIZER = transformers . BertTokenizer . from_pretrained ( BERT_PATH , do_lower_case = True ) \u73b0\u5728\uff0c\u4ee3\u7801\u5047\u5b9a\u6240\u6709\u5185\u5bb9\u90fd\u5728\u4e3b\u76ee\u5f55\u4e0b\u540d\u4e3a data \u7684\u6587\u4ef6\u5939\u4e2d\u3002 \u8bf7\u6ce8\u610f\uff0c\u5982\u679c Python \u811a\u672c\u6709\u4efb\u4f55\u6539\u52a8\uff0c\u90fd\u610f\u5473\u7740\u9700\u8981\u91cd\u5efa Docker \u5bb9\u5668\uff01\u56e0\u6b64\uff0c\u6211\u4eec\u91cd\u5efa\u5bb9\u5668\uff0c\u7136\u540e\u91cd\u65b0\u8fd0\u884c Docker \u547d\u4ee4\uff0c\u4f46\u8fd9\u6b21\u8981\u6709\u6240\u6539\u53d8\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u6ca1\u6709\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u8fd9\u4e5f\u662f\u884c\u4e0d\u901a\u7684\u3002\u522b\u62c5\u5fc3\uff0c\u8fd9\u53ea\u662f\u4e00\u4e2a Docker \u5bb9\u5668\u3002\u4f60\u53ea\u9700\u8981\u505a\u4e00\u6b21\u3002\u8981\u5b89\u88c5\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e2d\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\u3002 distribution = $( . 
/etc/os-release ; echo $ID$VERSION_ID ) curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - curl -s -L https://nvidia.github.io/nvidia-docker/ $distribution /nvidia- docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit sudo systemctl restart docker \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u6784\u5efa\u6211\u4eec\u7684\u5bb9\u5668\uff0c\u5e76\u5f00\u59cb\u8bad\u7ec3\u8fc7\u7a0b\uff1a docker run --gpus 1 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:train python3 train.py \u5176\u4e2d\uff0c-gpus 1 \u8868\u793a\u6211\u4eec\u5728 docker \u5bb9\u5668\u4e2d\u4f7f\u7528 1 \u4e2a GPU\uff0c-v \u8868\u793a\u6302\u8f7d\u5377\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u8981\u5c06\u672c\u5730\u76ee\u5f55 /home/abhishek/workspace/approaching_almost/input/ \u6302\u8f7d\u5230 docker \u5bb9\u5668\u4e2d\u7684 /home/abhishek/data/\u3002\u8fd9\u4e00\u6b65\u8981\u82b1\u70b9\u65f6\u95f4\uff0c\u4f46\u5b8c\u6210\u540e\uff0c\u672c\u5730\u6587\u4ef6\u5939\u4e2d\u5c31\u4f1a\u6709 model.bin\u3002 \u8fd9\u6837\uff0c\u53ea\u9700\u505a\u4e00\u4e9b\u7b80\u5355\u7684\u6539\u52a8\uff0c\u4f60\u7684\u8bad\u7ec3\u4ee3\u7801\u5c31\u5df2\u7ecf \"dockerized \"\u4e86\u3002\u73b0\u5728\uff0c\u4f60\u53ef\u4ee5\u5728\uff08\u51e0\u4e4e\uff09\u4efb\u4f55\u4f60\u60f3\u8981\u7684\u7cfb\u7edf\u4e0a\u4f7f\u7528\u8fd9\u4e9b\u4ee3\u7801\u8fdb\u884c\u8bad\u7ec3\u3002 \u4e0b\u4e00\u90e8\u5206\u662f\u5c06\u6211\u4eec\u8bad\u7ec3\u597d\u7684\u6a21\u578b \"\u63d0\u4f9b \"\u7ed9\u6700\u7ec8\u7528\u6237\u3002\u5047\u8bbe\u60a8\u60f3\u4ece\u63a5\u6536\u5230\u7684\u63a8\u6587\u6d41\u4e2d\u63d0\u53d6\u60c5\u611f\u4fe1\u606f\u3002\u8981\u5b8c\u6210\u8fd9\u9879\u4efb\u52a1\uff0c\u60a8\u5fc5\u987b\u521b\u5efa\u4e00\u4e2a API\uff0c\u7528\u4e8e\u8f93\u5165\u53e5\u5b50\uff0c\u7136\u540e\u8fd4\u56de\u5e26\u6709\u60c5\u611f\u6982\u7387\u7684\u8f93\u51fa\u3002\u4f7f\u7528 Python 
\u6784\u5efa API \u7684\u6700\u5e38\u89c1\u65b9\u6cd5\u662f\u4f7f\u7528 Flask \uff0c\u5b83\u662f\u4e00\u4e2a\u5fae\u578b\u7f51\u7edc\u670d\u52a1\u6846\u67b6\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) MODEL = None DEVICE = \"cuda\" def sentence_prediction ( sentence ): tokenizer = config . TOKENIZER max_len = config . MAX_LEN review = str ( sentence ) review = \" \" . join ( review . split ()) inputs = tokenizer . encode_plus ( review , None , add_special_tokens = True , max_length = max_len ) ids = inputs [ \"input_ids\" ] mask = inputs [ \"attention_mask\" ] token_type_ids = inputs [ \"token_type_ids\" ] padding_length = max_len - len ( ids ) ids = ids + ([ 0 ] * padding_length ) mask = mask + ([ 0 ] * padding_length ) token_type_ids = token_type_ids + ([ 0 ] * padding_length ) ids = torch . tensor ( ids , dtype = torch . long ) . unsqueeze ( 0 ) mask = torch . tensor ( mask , dtype = torch . long ) . unsqueeze ( 0 ) token_type_ids = torch . tensor ( token_type_ids , dtype = torch . long ) . unsqueeze ( 0 ) ids = ids . to ( DEVICE , dtype = torch . long ) token_type_ids = token_type_ids . to ( DEVICE , dtype = torch . long ) mask = mask . to ( DEVICE , dtype = torch . long ) outputs = MODEL ( ids = ids , mask = mask , token_type_ids = token_type_ids ) outputs = torch . sigmoid ( outputs ) . cpu () . detach () . numpy () return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): sentence = request . args . get ( \"sentence\" ) start_time = time . time () positive_prediction = sentence_prediction ( sentence ) negative_prediction = 1 - positive_prediction response = {} response [ \"response\" ] = { \"positive\" : str ( positive_prediction ), \"negative\" : str ( negative_prediction ), \"sentence\" : str ( sentence ), \"time_taken\" : str ( time . 
time () - start_time ), } return flask . jsonify ( response ) if __name__ == \"__main__\" : MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ) )) MODEL . to ( DEVICE ) MODEL . eval () app . run ( host = \"0.0.0.0\" ) \u7136\u540e\u8fd0\u884c \"python api.py \"\u547d\u4ee4\u542f\u52a8 API\u3002API \u5c06\u5728\u7aef\u53e3 5000 \u7684 localhost \u4e0a\u542f\u52a8\u3002cURL \u8bf7\u6c42\u53ca\u5176\u54cd\u5e94\u793a\u4f8b\u5982\u4e0b\u3002 \u276f curl $'http://192.168.86.48:5000/predict?sentence=this%20is%20the%20best%20boo k%20ever' {\"response\":{\"negative\":\"0.0032927393913269043\",\"positive\":\"0.99670726\",\" sentence\":\"this is the best book ever\",\"time_taken\":\"0.029126882553100586\"}} \u53ef\u4ee5\u770b\u5230\uff0c\u6211\u4eec\u5f97\u5230\u7684\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002 \u60a8\u8fd8\u53ef\u4ee5\u8bbf\u95ee http://127.0.0.1:5000/predict?sentence=this%20book%20is%20too%20complicated%20for%20me\u3002\u8fd9\u5c06\u518d\u6b21\u8fd4\u56de\u4e00\u4e2a JSON \u6587\u4ef6\u3002 { response : { negative : \"0.8646619468927383\" , positive : \"0.13533805\" , sentence : \"this book is too complicated for me\" , time_taken : \"0.03852701187133789\" } } \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\uff0c\u53ef\u4ee5\u7528\u6765\u4e3a\u5c11\u91cf\u7528\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u4e3a\u4ec0\u4e48\u662f\u5c11\u91cf\uff1f\u56e0\u4e3a\u8fd9\u4e2a API \u4e00\u6b21\u53ea\u670d\u52a1\u4e00\u4e2a\u8bf7\u6c42\u3002gunicorn \u662f UNIX \u4e0a\u7684 Python WSGI HTTP \u670d\u52a1\u5668\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528\u5b83\u7684 CPU \u6765\u5904\u7406\u591a\u4e2a\u5e76\u884c\u8bf7\u6c42\u3002Gunicorn \u53ef\u4ee5\u4e3a API 
\u521b\u5efa\u591a\u4e2a\u8fdb\u7a0b\uff0c\u56e0\u6b64\u6211\u4eec\u53ef\u4ee5\u540c\u65f6\u4e3a\u591a\u4e2a\u5ba2\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 \"pip install gunicorn \"\u5b89\u88c5 gunicorn\u3002 \u4e3a\u4e86\u5c06\u4ee3\u7801\u8f6c\u6362\u4e3a\u4e0e gunicorn \u517c\u5bb9\uff0c\u6211\u4eec\u9700\u8981\u79fb\u9664 init main\uff0c\u5e76\u5c06\u5176\u4e2d\u7684\u6240\u6709\u5185\u5bb9\u79fb\u81f3\u5168\u5c40\u8303\u56f4\u3002\u6b64\u5916\uff0c\u6211\u4eec\u73b0\u5728\u4f7f\u7528\u7684\u662f CPU \u800c\u4e0d\u662f GPU\u3002\u4fee\u6539\u540e\u7684\u4ee3\u7801\u5982\u4e0b\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) DEVICE = \"cpu\" MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ))) MODEL . to ( DEVICE ) MODEL . eval () def sentence_prediction ( sentence ): return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): return flask . 
jsonify ( response ) \u6211\u4eec\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002 gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8fd9\u610f\u5473\u7740\u6211\u4eec\u5728\u63d0\u4f9b\u7684 IP \u5730\u5740\u548c\u7aef\u53e3\u4e0a\u4f7f\u7528 4 \u4e2a Worker \u8fd0\u884c\u6211\u4eec\u7684 flask \u5e94\u7528\u7a0b\u5e8f\u3002\u7531\u4e8e\u6709 4 \u4e2a Worker\uff0c\u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u540c\u65f6\u5904\u7406 4 \u4e2a\u8bf7\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u73b0\u5728\u6211\u4eec\u7684\u7ec8\u7aef\u4f7f\u7528\u7684\u662f CPU\uff0c\u56e0\u6b64\u4e0d\u9700\u8981 GPU \u673a\u5668\uff0c\u53ef\u4ee5\u5728\u4efb\u4f55\u6807\u51c6\u670d\u52a1\u5668/\u865a\u62df\u673a\u4e0a\u8fd0\u884c\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u8fd8\u6709\u4e00\u4e2a\u95ee\u9898\uff1a\u6211\u4eec\u5df2\u7ecf\u5728\u672c\u5730\u673a\u5668\u4e0a\u5b8c\u6210\u4e86\u6240\u6709\u5de5\u4f5c\uff0c\u56e0\u6b64\u5fc5\u987b\u5c06\u5176\u575e\u5316\u3002\u770b\u770b\u4e0b\u9762\u8fd9\u4e2a\u672a\u6ce8\u91ca\u7684 Dockerfile\uff0c\u5b83\u53ef\u4ee5\u7528\u6765\u90e8\u7f72\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002\u8bf7\u6ce8\u610f\u7528\u4e8e\u57f9\u8bad\u7684\u65e7 Dockerfile \u548c\u8fd9\u4e2a Dockerfile \u4e4b\u95f4\u7684\u533a\u522b\u3002\u533a\u522b\u4e0d\u5927\u3002 # CPU Dockerfile FROM ubuntu:18.04 RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* RUN useradd -m abhishek RUN chown -R abhishek:abhishek /home/abhishek/ COPY --chown = abhishek *.* /home/abhishek/app/ USER abhishek RUN mkdir /home/abhishek/data/ RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt RUN pip3 install mkl WORKDIR /home/abhishek/app \u8ba9\u6211\u4eec\u6784\u5efa\u4e00\u4e2a\u65b0\u7684 Docker \u5bb9\u5668\u3002 docker build -f Dockerfile -t bert:api \u5f53 Docker 
\u5bb9\u5668\u6784\u5efa\u5b8c\u6210\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u76f4\u63a5\u8fd0\u884c API \u4e86\u3002 docker run -p 5000 :5000 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:api /home/abhishek/.local/bin/gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5c06\u5bb9\u5668\u5185\u7684 5000 \u7aef\u53e3\u66b4\u9732\u7ed9\u5bb9\u5668\u5916\u7684 5000 \u7aef\u53e3\u3002\u5982\u679c\u4f7f\u7528 docker-compose\uff0c\u4e5f\u53ef\u4ee5\u5f88\u597d\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002Dockercompose \u662f\u4e00\u4e2a\u53ef\u4ee5\u8ba9\u4f60\u540c\u65f6\u5728\u4e0d\u540c\u6216\u76f8\u540c\u5bb9\u5668\u4e2d\u8fd0\u884c\u4e0d\u540c\u670d\u52a1\u7684\u5de5\u5177\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528 \"pip install docker-compose \"\u5b89\u88c5 docker-compose\uff0c\u7136\u540e\u5728\u6784\u5efa\u5bb9\u5668\u540e\u8fd0\u884c \"docker-compose up\"\u3002\u8981\u4f7f\u7528 docker-compose\uff0c\u4f60\u9700\u8981\u4e00\u4e2a docker-compose.yml \u6587\u4ef6\u3002 # docker-compose.yml # specify a version of the compose version : '3.7' # you can add multiple services services : # specify service name. 
we call our service: api api : # specify image name image : bert:api # the command that you would like to run inside the container command : /home/abhishek/.local/bin/gunicorn api:app --bind 0.0.0.0:5000 --workers 4 # mount the volume volumes : - /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ # this ensures that our ports from container will be # exposed as it is network_mode : host \u73b0\u5728\uff0c\u60a8\u53ea\u9700\u4f7f\u7528\u4e0a\u8ff0\u547d\u4ee4\u5373\u53ef\u91cd\u65b0\u8fd0\u884c API\uff0c\u5176\u8fd0\u884c\u65b9\u5f0f\u4e0e\u4e4b\u524d\u76f8\u540c\u3002\u606d\u559c\u4f60\uff0c\u73b0\u5728\uff0c\u4f60\u4e5f\u5df2\u7ecf\u6210\u529f\u5730\u5c06\u9884\u6d4b API \u8fdb\u884c\u4e86 Docker \u5316\uff0c\u53ef\u4ee5\u968f\u65f6\u968f\u5730\u90e8\u7f72\u4e86\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u5b66\u4e60\u4e86 Docker\u3001\u4f7f\u7528 flask \u6784\u5efa API\u3001\u4f7f\u7528 gunicorn \u548c Docker \u670d\u52a1 API \u4ee5\u53ca docker-compose\u3002\u5173\u4e8e docker \u7684\u77e5\u8bc6\u8fdc\u4e0d\u6b62\u8fd9\u4e9b\uff0c\u4f46\u8fd9\u5e94\u8be5\u662f\u4e00\u4e2a\u5f00\u59cb\u3002\u5176\u4ed6\u5185\u5bb9\u53ef\u4ee5\u5728\u5b66\u4e60\u8fc7\u7a0b\u4e2d\u9010\u6e10\u638c\u63e1\u3002 \u6211\u4eec\u8fd8\u8df3\u8fc7\u4e86\u8bb8\u591a\u5de5\u5177\uff0c\u5982 kubernetes\u3001bean-stalk\u3001sagemaker\u3001heroku \u548c\u8bb8\u591a\u5176\u4ed6\u5de5\u5177\uff0c\u8fd9\u4e9b\u5de5\u5177\u5982\u4eca\u88ab\u4eba\u4eec\u7528\u6765\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u6a21\u578b\u3002\"\u6211\u8981\u5199\u4ec0\u4e48\uff1f\u70b9\u51fb\u4fee\u6539\u56fe X \u4e2d\u7684 docker \u5bb9\u5668\"\uff1f\u5728\u4e66\u4e2d\u63cf\u8ff0\u8fd9\u4e9b\u662f\u4e0d\u53ef\u884c\u7684\uff0c\u4e5f\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u6240\u4ee5\u6211\u5c06\u4f7f\u7528\u4e0d\u540c\u7684\u5a92\u4ecb\u6765\u8d5e\u7f8e\u672c\u4e66\u7684\u8fd9\u4e00\u90e8\u5206\u3002\u8bf7\u8bb0\u4f4f\uff0c\u4e00\u65e6\u4f60\u5bf9\u5e94\u7528\u7a0b\u5e8f\u8fdb\u884c\u4e86 Docker 
\u5316\uff0c\u4f7f\u7528\u8fd9\u4e9b\u6280\u672f/\u5e73\u53f0\u8fdb\u884c\u90e8\u7f72\u5c31\u53d8\u5f97\u6613\u5982\u53cd\u638c\u4e86\u3002\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u8981\u8ba9\u4f60\u7684\u4ee3\u7801\u548c\u6a21\u578b\u5bf9\u4ed6\u4eba\u53ef\u7528\uff0c\u5e76\u505a\u597d\u6587\u6863\u8bb0\u5f55\uff0c\u8fd9\u6837\u4efb\u4f55\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u4f60\u5f00\u53d1\u7684\u4e1c\u897f\uff0c\u800c\u65e0\u9700\u591a\u6b21\u8be2\u95ee\u4f60\u3002\u8fd9\u4e0d\u4ec5\u80fd\u8282\u7701\u60a8\u7684\u65f6\u95f4\uff0c\u8fd8\u80fd\u8282\u7701\u4ed6\u4eba\u7684\u65f6\u95f4\u3002\u597d\u7684\u3001\u5f00\u6e90\u7684\u3001\u53ef\u91cd\u590d\u4f7f\u7528\u7684\u4ee3\u7801\u5728\u60a8\u7684\u4f5c\u54c1\u96c6\u4e2d\u4e5f\u975e\u5e38\u91cd\u8981\u3002","title":"\u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5"},{"location":"%E5%8F%AF%E9%87%8D%E5%A4%8D%E4%BB%A3%E7%A0%81%E5%92%8C%E6%A8%A1%E5%9E%8B%E6%96%B9%E6%B3%95/#_1","text":"\u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u5230\u4e86\u53ef\u4ee5\u5c06\u6a21\u578b/\u8bad\u7ec3\u4ee3\u7801\u5206\u53d1\u7ed9\u4ed6\u4eba\u4f7f\u7528\u7684\u9636\u6bb5\u3002\u60a8\u53ef\u4ee5\u7528\u8f6f\u76d8\u5206\u53d1\u6216\u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\uff0c\u4f46\u8fd9\u5e76\u4e0d\u7406\u60f3\u3002\u662f\u8fd9\u6837\u5417\uff1f\u4e5f\u8bb8\u5f88\u591a\u5e74\u524d\uff0c\u8fd9\u662f\u7406\u60f3\u7684\u505a\u6cd5\uff0c\u4f46\u73b0\u5728\u4e0d\u662f\u4e86\u3002 \u4e0e\u4ed6\u4eba\u5171\u4eab\u4ee3\u7801\u548c\u534f\u4f5c\u7684\u9996\u9009\u65b9\u5f0f\u662f\u4f7f\u7528\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u3002Git \u662f\u6700\u6d41\u884c\u7684\u6e90\u4ee3\u7801\u7ba1\u7406\u7cfb\u7edf\u4e4b\u4e00\u3002\u90a3\u4e48\uff0c\u5047\u8bbe\u4f60\u5df2\u7ecf\u5b66\u4f1a\u4e86 
Git\uff0c\u5e76\u6b63\u786e\u5730\u683c\u5f0f\u5316\u4e86\u4ee3\u7801\uff0c\u7f16\u5199\u4e86\u9002\u5f53\u7684\u6587\u6863\uff0c\u8fd8\u5f00\u6e90\u4e86\u4f60\u7684\u9879\u76ee\u3002\u8fd9\u5c31\u591f\u4e86\u5417\uff1f\u4e0d\uff0c\u8fd8\u4e0d\u591f\u3002\u56e0\u4e3a\u4f60\u5728\u81ea\u5df1\u7684\u7535\u8111\u4e0a\u5199\u7684\u4ee3\u7801\uff0c\u5728\u522b\u4eba\u7684\u7535\u8111\u4e0a\u53ef\u80fd\u4f1a\u56e0\u4e3a\u5404\u79cd\u539f\u56e0\u800c\u65e0\u6cd5\u8fd0\u884c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u5728\u53d1\u5e03\u4ee3\u7801\u65f6\u80fd\u590d\u5236\u81ea\u5df1\u7684\u7535\u8111\uff0c\u800c\u5176\u4ed6\u4eba\u5728\u5b89\u88c5\u60a8\u7684\u8f6f\u4ef6\u6216\u8fd0\u884c\u60a8\u7684\u4ee3\u7801\u65f6\u4e5f\u80fd\u590d\u5236\u60a8\u7684\u7535\u8111\uff0c\u90a3\u5c31\u518d\u597d\u4e0d\u8fc7\u4e86\u3002\u4e3a\u6b64\uff0c\u5982\u4eca\u6700\u6d41\u884c\u7684\u65b9\u6cd5\u662f\u4f7f\u7528 Docker \u5bb9\u5668\uff08Docker Containers\uff09\u3002\u8981\u4f7f\u7528 Docker \u5bb9\u5668\uff0c\u4f60\u9700\u8981\u5b89\u88c5 Docker\u3002 \u8ba9\u6211\u4eec\u7528\u4e0b\u9762\u7684\u547d\u4ee4\u6765\u5b89\u88c5 Docker\u3002 sudo apt install docker.io sudo systemctl start docker sudo systemctl enable docker sudo groupadd docker sudo usermod -aG docker $USER \u8fd9\u4e9b\u547d\u4ee4\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e0a\u8fd0\u884c\u3002Docker \u6700\u68d2\u7684\u5730\u65b9\u5728\u4e8e\u5b83\u53ef\u4ee5\u5b89\u88c5\u5728\u4efb\u4f55\u673a\u5668\u4e0a\uff1a Linux\u3001Windows\u3001OSX\u3002\u56e0\u6b64\uff0c\u5982\u679c\u4f60\u4e00\u76f4\u5728 Docker \u5bb9\u5668\u4e2d\u5de5\u4f5c\uff0c\u54ea\u53f0\u673a\u5668\u90fd\u6ca1\u5173\u7cfb\uff01 Docker 
\u5bb9\u5668\u53ef\u4ee5\u88ab\u89c6\u4e3a\u5c0f\u578b\u865a\u62df\u673a\u3002\u4f60\u53ef\u4ee5\u4e3a\u4f60\u7684\u4ee3\u7801\u521b\u5efa\u4e00\u4e2a\u5bb9\u5668\uff0c\u7136\u540e\u6bcf\u4e2a\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u548c\u8bbf\u95ee\u5b83\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u521b\u5efa\u53ef\u7528\u4e8e\u8bad\u7ec3\u6a21\u578b\u7684\u5bb9\u5668\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e00\u7ae0\u4e2d\u8bad\u7ec3\u7684 BERT \u6a21\u578b\uff0c\u5e76\u5c1d\u8bd5\u5c06\u8bad\u7ec3\u4ee3\u7801\u5bb9\u5668\u5316\u3002 \u9996\u5148\uff0c\u4f60\u9700\u8981\u4e00\u4e2a\u5305\u542b python \u9879\u76ee\u9700\u6c42\u7684\u6587\u4ef6\u3002\u9700\u6c42\u5305\u542b\u5728\u540d\u4e3a requirements.txt \u7684\u6587\u4ef6\u4e2d\u3002\u6587\u4ef6\u540d\u662f thestandard\u3002\u8be5\u6587\u4ef6\u5305\u542b\u9879\u76ee\u4e2d\u4f7f\u7528\u7684\u6240\u6709 python \u5e93\u3002\u4e5f\u5c31\u662f\u53ef\u4ee5\u901a\u8fc7 PyPI (pip) \u4e0b\u8f7d\u7684 python \u5e93\u3002\u7528\u4e8e \u8bad\u7ec3 BERT \u6a21\u578b\u4ee5\u68c0\u6d4b\u6b63/\u8d1f\u60c5\u611f\uff0c\u6211\u4eec\u4f7f\u7528\u4e86 torch\u3001transformers\u3001tqdm\u3001scikit-learn\u3001pandas \u548c numpy\u3002 \u8ba9\u6211\u4eec\u628a\u5b83\u4eec\u5199\u5165 requirements.txt \u4e2d\u3002\u4f60\u53ef\u4ee5\u53ea\u5199\u540d\u79f0\uff0c\u4e5f\u53ef\u4ee5\u5305\u62ec\u7248\u672c\u3002\u5305\u542b\u7248\u672c\u603b\u662f\u6700\u597d\u7684\uff0c\u8fd9\u4e5f\u662f\u4f60\u5e94\u8be5\u505a\u7684\u3002\u5305\u542b\u7248\u672c\u540e\uff0c\u53ef\u4ee5\u786e\u4fdd\u5176\u4ed6\u4eba\u4f7f\u7528\u7684\u7248\u672c\u4e0e\u4f60\u7684\u7248\u672c\u76f8\u540c\uff0c\u800c\u4e0d\u662f\u6700\u65b0\u7248\u672c\uff0c\u56e0\u4e3a\u6700\u65b0\u7248\u672c\u53ef\u80fd\u4f1a\u66f4\u6539\u67d0\u4e9b\u5185\u5bb9\uff0c\u5982\u679c\u662f\u8fd9\u6837\u7684\u8bdd\uff0c\u6a21\u578b\u7684\u8bad\u7ec3\u65b9\u5f0f\u5c31\u4e0d\u4f1a\u4e0e\u4f60\u7684\u76f8\u540c\u4e86\u3002 
\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u663e\u793a\u4e86 requirements.txt\u3002 # requirements.txt pandas == 1.0.4 scikit - learn == 0.22.1 torch == 1.5.0 transformers == 2.11.0 \u73b0\u5728\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a Dockerfile \u7684 Docker \u6587\u4ef6\u3002\u6ca1\u6709\u6269\u5c55\u540d\u3002Dockerfile \u6709\u51e0\u4e2a\u5143\u7d20\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 # Dockerfile # First of all, we include where we are getting the image # from. Image can be thought of as an operating system. # You can do \"FROM ubuntu:18.04\" # this will start from a clean ubuntu 18.04 image. # All images are downloaded from dockerhub # Here are we grabbing image from nvidia's repo # they created a docker image using ubuntu 18.04 # and installed cuda 10.1 and cudnn7 in it. Thus, we don't have to # install it. Makes our life easy. FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 # this is the same apt-get command that you are used to # except the fact that, we have -y argument. Its because # when we build this container, we cannot press Y when asked for RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* # We add a new user called \"abhishek\" # this can be anything. Anything you want it # to be. Usually, we don't use our own name, # you can use \"user\" or \"ubuntu\" RUN useradd -m abhishek # make our user own its own home directory RUN chown -R abhishek:abhishek /home/abhishek/ # copy all files from this direrctory to a # directory called app inside the home of abhishek # and abhishek owns it. 
COPY --chown = abhishek *.* /home/abhishek/app/ # change to user abhishek USER abhishek RUN mkdir /home/abhishek/data/ # Now we install all the requirements # after moving to the app directory # PLEASE NOTE that ubuntu 18.04 image # has python 3.6.9 and not python 3.7.6 # you can also install conda python here and use that # however, to simplify it, I will be using python 3.6.9 # inside the docker container!!!! RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt # install mkl. its needed for transformers RUN pip3 install mkl # when we log into the docker container, # we will go inside this directory automatically WORKDIR /home/abhishek/app \u521b\u5efa\u597d Docker \u6587\u4ef6\u540e\uff0c\u6211\u4eec\u5c31\u9700\u8981\u6784\u5efa\u5b83\u3002\u6784\u5efa Docker \u5bb9\u5668\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u547d\u4ee4\u3002 docker build -f Dockerfile -t bert:train . \u8be5\u547d\u4ee4\u6839\u636e\u63d0\u4f9b\u7684 Dockerfile \u6784\u5efa\u4e00\u4e2a\u5bb9\u5668\u3002Docker \u5bb9\u5668\u7684\u540d\u79f0\u662f bert:train\u3002\u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a \u276f docker build -f Dockerfile -t bert:train . Sending build context to Docker daemon 19.97kB Step 1/7 : FROM nvidia/cuda:10.1-cudnn7-ubuntu18.04 ---> 3b55548ae91f Step 2/7 : RUN apt-get update && apt-get install -y git curl ca- certificates python3 python3-pip sudo && rm -rf /var/lib/apt/lists/* . . . . 
Removing intermediate container 8f6975dd08ba ---> d1802ac9f1b4 Step 7/7 : WORKDIR /home/abhishek/app ---> Running in 257ff09502ed Removing intermediate container 257ff09502ed ---> e5f6eb4cddd7 Successfully built e5f6eb4cddd7 Successfully tagged bert:train \u8bf7\u6ce8\u610f\uff0c\u6211\u5220\u9664\u4e86\u8f93\u51fa\u4e2d\u7684\u8bb8\u591a\u884c\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u767b\u5f55\u5bb9\u5668\u3002 docker run -ti bert:train /bin/bash \u4f60\u9700\u8981\u8bb0\u4f4f\uff0c\u4e00\u65e6\u9000\u51fa shell\uff0c\u4f60\u5728 shell \u4e2d\u6240\u505a\u7684\u4e00\u5207\u90fd\u5c06\u4e22\u5931\u3002\u4f60\u8fd8\u53ef\u4ee5\u5728 Docker \u5bb9\u5668\u4e2d\u4f7f\u7528\u3002 docker run -ti bert:train python3 train.py \u8f93\u51fa\u60c5\u51b5\uff1a Traceback (most recent call last): File \"train.py\", line 2, in import config File \"/home/abhishek/app/config.py\", line 28, in do_lower_case=True File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 393, in from_pretrained return cls._from_pretrained(*inputs, **kwargs) File \"/usr/local/lib/python3.6/dist- packages/transformers/tokenization_utils.py\", line 496, in _from_pretrained list(cls.vocab_files_names.values()), OSError: Model name '../input/bert_base_uncased/' was not found in tokenizers model name list (bert-base-uncased, bert-large-uncased, bert- base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base- multilingual-cased, bert-base-chinese, bert-base-german-cased, bert- large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased- whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert- base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base- finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). 
We assumed '../input/bert_base_uncased/' was a path, a model identifier, or url to a directory containing vocabulary files named ['vocab.txt'] but couldn't find such vocabulary files at this path or url. \u54ce\u5440\uff0c\u51fa\u9519\u4e86\uff01 \u6211\u4e3a\u4ec0\u4e48\u8981\u628a\u9519\u8bef\u5370\u5728\u4e66\u4e0a\u5462\uff1f \u56e0\u4e3a\u7406\u89e3\u8fd9\u4e2a\u9519\u8bef\u975e\u5e38\u91cd\u8981\u3002\u8fd9\u4e2a\u9519\u8bef\u8bf4\u660e\u4ee3\u7801\u65e0\u6cd5\u627e\u5230\u76ee\u5f55\".../input/bert_base_cased\"\u3002\u4e3a\u4ec0\u4e48\u4f1a\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u5462\uff1f\u6211\u4eec\u53ef\u4ee5\u5728\u6ca1\u6709 Docker \u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u8bad\u7ec3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u76ee\u5f55\u548c\u6240\u6709\u6587\u4ef6\u90fd\u5b58\u5728\u3002\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\u662f\u56e0\u4e3a Docker \u5c31\u50cf\u4e00\u4e2a\u865a\u62df\u673a\uff01\u5b83\u6709\u81ea\u5df1\u7684\u6587\u4ef6\u7cfb\u7edf\uff0c\u672c\u5730\u673a\u5668\u4e0a\u7684\u6587\u4ef6\u4e0d\u4f1a\u5171\u4eab\u7ed9 Docker \u5bb9\u5668\u3002\u5982\u679c\u4f60\u60f3\u4f7f\u7528\u672c\u5730\u673a\u5668\u4e0a\u7684\u8def\u5f84\u5e76\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\uff0c\u4f60\u9700\u8981\u5728\u8fd0\u884c Docker \u65f6\u5c06\u5176\u6302\u8f7d\u5230 Docker \u5bb9\u5668\u4e0a\u3002\u5f53\u6211\u4eec\u67e5\u770b\u8fd9\u4e2a\u6587\u4ef6\u5939\u7684\u8def\u5f84\u65f6\uff0c\u6211\u4eec\u77e5\u9053\u5b83\u4f4d\u4e8e\u540d\u4e3a input \u7684\u6587\u4ef6\u5939\u7684\u4e0a\u4e00\u7ea7\u3002\u8ba9\u6211\u4eec\u7a0d\u5fae\u4fee\u6539\u4e00\u4e0b config.py \u6587\u4ef6\uff01 # config.py import os import transformers # fetch home directory # in our docker container, it is # /home/abhishek HOME_DIR = os . path . expanduser ( \"~\" ) # this is the maximum number of tokens in the sentence MAX_LEN = 512 # batch sizes is low because model is huge! 
TRAIN_BATCH_SIZE = 8 VALID_BATCH_SIZE = 4 # let's train for a maximum of 10 epochs EPOCHS = 10 # define path to BERT model files # Now we assume that all the data is stored inside # /home/abhishek/data BERT_PATH = os . path . join ( HOME_DIR , \"data\" , \"bert_base_uncased\" ) # this is where you want to save the model MODEL_PATH = os . path . join ( HOME_DIR , \"data\" , \"model.bin\" ) # training file TRAINING_FILE = os . path . join ( HOME_DIR , \"data\" , \"imdb.csv\" ) TOKENIZER = transformers . BertTokenizer . from_pretrained ( BERT_PATH , do_lower_case = True ) \u73b0\u5728\uff0c\u4ee3\u7801\u5047\u5b9a\u6240\u6709\u5185\u5bb9\u90fd\u5728\u4e3b\u76ee\u5f55\u4e0b\u540d\u4e3a data \u7684\u6587\u4ef6\u5939\u4e2d\u3002 \u8bf7\u6ce8\u610f\uff0c\u5982\u679c Python \u811a\u672c\u6709\u4efb\u4f55\u6539\u52a8\uff0c\u90fd\u610f\u5473\u7740\u9700\u8981\u91cd\u5efa Docker \u5bb9\u5668\uff01\u56e0\u6b64\uff0c\u6211\u4eec\u91cd\u5efa\u5bb9\u5668\uff0c\u7136\u540e\u91cd\u65b0\u8fd0\u884c Docker \u547d\u4ee4\uff0c\u4f46\u8fd9\u6b21\u8981\u6709\u6240\u6539\u53d8\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u6ca1\u6709\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u8fd9\u4e5f\u662f\u884c\u4e0d\u901a\u7684\u3002\u522b\u62c5\u5fc3\uff0c\u8fd9\u53ea\u662f\u4e00\u4e2a Docker \u5bb9\u5668\u3002\u4f60\u53ea\u9700\u8981\u505a\u4e00\u6b21\u3002\u8981\u5b89\u88c5\u82f1\u4f1f\u8fbe\u2122\uff08NVIDIA\u00ae\uff09Docker \u8fd0\u884c\u65f6\uff0c\u53ef\u4ee5\u5728 Ubuntu 18.04 \u4e2d\u8fd0\u884c\u4ee5\u4e0b\u547d\u4ee4\u3002 distribution = $( . 
/etc/os-release ; echo $ID$VERSION_ID ) curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - curl -s -L https://nvidia.github.io/nvidia-docker/ $distribution /nvidia- docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit sudo systemctl restart docker \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u6784\u5efa\u6211\u4eec\u7684\u5bb9\u5668\uff0c\u5e76\u5f00\u59cb\u8bad\u7ec3\u8fc7\u7a0b\uff1a docker run --gpus 1 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:train python3 train.py \u5176\u4e2d\uff0c-gpus 1 \u8868\u793a\u6211\u4eec\u5728 docker \u5bb9\u5668\u4e2d\u4f7f\u7528 1 \u4e2a GPU\uff0c-v \u8868\u793a\u6302\u8f7d\u5377\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u8981\u5c06\u672c\u5730\u76ee\u5f55 /home/abhishek/workspace/approaching_almost/input/ \u6302\u8f7d\u5230 docker \u5bb9\u5668\u4e2d\u7684 /home/abhishek/data/\u3002\u8fd9\u4e00\u6b65\u8981\u82b1\u70b9\u65f6\u95f4\uff0c\u4f46\u5b8c\u6210\u540e\uff0c\u672c\u5730\u6587\u4ef6\u5939\u4e2d\u5c31\u4f1a\u6709 model.bin\u3002 \u8fd9\u6837\uff0c\u53ea\u9700\u505a\u4e00\u4e9b\u7b80\u5355\u7684\u6539\u52a8\uff0c\u4f60\u7684\u8bad\u7ec3\u4ee3\u7801\u5c31\u5df2\u7ecf \"dockerized \"\u4e86\u3002\u73b0\u5728\uff0c\u4f60\u53ef\u4ee5\u5728\uff08\u51e0\u4e4e\uff09\u4efb\u4f55\u4f60\u60f3\u8981\u7684\u7cfb\u7edf\u4e0a\u4f7f\u7528\u8fd9\u4e9b\u4ee3\u7801\u8fdb\u884c\u8bad\u7ec3\u3002 \u4e0b\u4e00\u90e8\u5206\u662f\u5c06\u6211\u4eec\u8bad\u7ec3\u597d\u7684\u6a21\u578b \"\u63d0\u4f9b \"\u7ed9\u6700\u7ec8\u7528\u6237\u3002\u5047\u8bbe\u60a8\u60f3\u4ece\u63a5\u6536\u5230\u7684\u63a8\u6587\u6d41\u4e2d\u63d0\u53d6\u60c5\u611f\u4fe1\u606f\u3002\u8981\u5b8c\u6210\u8fd9\u9879\u4efb\u52a1\uff0c\u60a8\u5fc5\u987b\u521b\u5efa\u4e00\u4e2a API\uff0c\u7528\u4e8e\u8f93\u5165\u53e5\u5b50\uff0c\u7136\u540e\u8fd4\u56de\u5e26\u6709\u60c5\u611f\u6982\u7387\u7684\u8f93\u51fa\u3002\u4f7f\u7528 Python 
\u6784\u5efa API \u7684\u6700\u5e38\u89c1\u65b9\u6cd5\u662f\u4f7f\u7528 Flask \uff0c\u5b83\u662f\u4e00\u4e2a\u5fae\u578b\u7f51\u7edc\u670d\u52a1\u6846\u67b6\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) MODEL = None DEVICE = \"cuda\" def sentence_prediction ( sentence ): tokenizer = config . TOKENIZER max_len = config . MAX_LEN review = str ( sentence ) review = \" \" . join ( review . split ()) inputs = tokenizer . encode_plus ( review , None , add_special_tokens = True , max_length = max_len ) ids = inputs [ \"input_ids\" ] mask = inputs [ \"attention_mask\" ] token_type_ids = inputs [ \"token_type_ids\" ] padding_length = max_len - len ( ids ) ids = ids + ([ 0 ] * padding_length ) mask = mask + ([ 0 ] * padding_length ) token_type_ids = token_type_ids + ([ 0 ] * padding_length ) ids = torch . tensor ( ids , dtype = torch . long ) . unsqueeze ( 0 ) mask = torch . tensor ( mask , dtype = torch . long ) . unsqueeze ( 0 ) token_type_ids = torch . tensor ( token_type_ids , dtype = torch . long ) . unsqueeze ( 0 ) ids = ids . to ( DEVICE , dtype = torch . long ) token_type_ids = token_type_ids . to ( DEVICE , dtype = torch . long ) mask = mask . to ( DEVICE , dtype = torch . long ) outputs = MODEL ( ids = ids , mask = mask , token_type_ids = token_type_ids ) outputs = torch . sigmoid ( outputs ) . cpu () . detach () . numpy () return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): sentence = request . args . get ( \"sentence\" ) start_time = time . time () positive_prediction = sentence_prediction ( sentence ) negative_prediction = 1 - positive_prediction response = {} response [ \"response\" ] = { \"positive\" : str ( positive_prediction ), \"negative\" : str ( negative_prediction ), \"sentence\" : str ( sentence ), \"time_taken\" : str ( time . 
time () - start_time ), } return flask . jsonify ( response ) if __name__ == \"__main__\" : MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ) )) MODEL . to ( DEVICE ) MODEL . eval () app . run ( host = \"0.0.0.0\" ) \u7136\u540e\u8fd0\u884c \"python api.py \"\u547d\u4ee4\u542f\u52a8 API\u3002API \u5c06\u5728\u7aef\u53e3 5000 \u7684 localhost \u4e0a\u542f\u52a8\u3002cURL \u8bf7\u6c42\u53ca\u5176\u54cd\u5e94\u793a\u4f8b\u5982\u4e0b\u3002 \u276f curl $'http://192.168.86.48:5000/predict?sentence=this%20is%20the%20best%20boo k%20ever' {\"response\":{\"negative\":\"0.0032927393913269043\",\"positive\":\"0.99670726\",\" sentence\":\"this is the best book ever\",\"time_taken\":\"0.029126882553100586\"}} \u53ef\u4ee5\u770b\u5230\uff0c\u6211\u4eec\u5f97\u5230\u7684\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002\u8f93\u5165\u53e5\u5b50\u7684\u6b63\u9762\u60c5\u611f\u6982\u7387\u5f88\u9ad8\u3002 \u60a8\u8fd8\u53ef\u4ee5\u8bbf\u95ee http://127.0.0.1:5000/predict?sentence=this%20book%20is%20too%20complicated%20for%20me\u3002\u8fd9\u5c06\u518d\u6b21\u8fd4\u56de\u4e00\u4e2a JSON \u6587\u4ef6\u3002 { response : { negative : \"0.8646619468927383\" , positive : \"0.13533805\" , sentence : \"this book is too complicated for me\" , time_taken : \"0.03852701187133789\" } } \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\uff0c\u53ef\u4ee5\u7528\u6765\u4e3a\u5c11\u91cf\u7528\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u4e3a\u4ec0\u4e48\u662f\u5c11\u91cf\uff1f\u56e0\u4e3a\u8fd9\u4e2a API \u4e00\u6b21\u53ea\u670d\u52a1\u4e00\u4e2a\u8bf7\u6c42\u3002gunicorn \u662f UNIX \u4e0a\u7684 Python WSGI HTTP \u670d\u52a1\u5668\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528\u5b83\u7684 CPU \u6765\u5904\u7406\u591a\u4e2a\u5e76\u884c\u8bf7\u6c42\u3002Gunicorn \u53ef\u4ee5\u4e3a API 
\u521b\u5efa\u591a\u4e2a\u8fdb\u7a0b\uff0c\u56e0\u6b64\u6211\u4eec\u53ef\u4ee5\u540c\u65f6\u4e3a\u591a\u4e2a\u5ba2\u6237\u63d0\u4f9b\u670d\u52a1\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 \"pip install gunicorn \"\u5b89\u88c5 gunicorn\u3002 \u4e3a\u4e86\u5c06\u4ee3\u7801\u8f6c\u6362\u4e3a\u4e0e gunicorn \u517c\u5bb9\uff0c\u6211\u4eec\u9700\u8981\u79fb\u9664 init main\uff0c\u5e76\u5c06\u5176\u4e2d\u7684\u6240\u6709\u5185\u5bb9\u79fb\u81f3\u5168\u5c40\u8303\u56f4\u3002\u6b64\u5916\uff0c\u6211\u4eec\u73b0\u5728\u4f7f\u7528\u7684\u662f CPU \u800c\u4e0d\u662f GPU\u3002\u4fee\u6539\u540e\u7684\u4ee3\u7801\u5982\u4e0b\u3002 # api.py import config import flask import time import torch import torch.nn as nn from flask import Flask from flask import request from model import BERTBaseUncased app = Flask ( __name__ ) DEVICE = \"cpu\" MODEL = BERTBaseUncased () MODEL . load_state_dict ( torch . load ( config . MODEL_PATH , map_location = torch . device ( DEVICE ))) MODEL . to ( DEVICE ) MODEL . eval () def sentence_prediction ( sentence ): return outputs [ 0 ][ 0 ] @app . route ( \"/predict\" , methods = [ \"GET\" ]) def predict (): return flask . 
jsonify ( response ) \u6211\u4eec\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002 gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8fd9\u610f\u5473\u7740\u6211\u4eec\u5728\u63d0\u4f9b\u7684 IP \u5730\u5740\u548c\u7aef\u53e3\u4e0a\u4f7f\u7528 4 \u4e2a Worker \u8fd0\u884c\u6211\u4eec\u7684 flask \u5e94\u7528\u7a0b\u5e8f\u3002\u7531\u4e8e\u6709 4 \u4e2a Worker\uff0c\u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u540c\u65f6\u5904\u7406 4 \u4e2a\u8bf7\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u73b0\u5728\u6211\u4eec\u7684\u7ec8\u7aef\u4f7f\u7528\u7684\u662f CPU\uff0c\u56e0\u6b64\u4e0d\u9700\u8981 GPU \u673a\u5668\uff0c\u53ef\u4ee5\u5728\u4efb\u4f55\u6807\u51c6\u670d\u52a1\u5668/\u865a\u62df\u673a\u4e0a\u8fd0\u884c\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u8fd8\u6709\u4e00\u4e2a\u95ee\u9898\uff1a\u6211\u4eec\u5df2\u7ecf\u5728\u672c\u5730\u673a\u5668\u4e0a\u5b8c\u6210\u4e86\u6240\u6709\u5de5\u4f5c\uff0c\u56e0\u6b64\u5fc5\u987b\u5c06\u5176\u575e\u5316\u3002\u770b\u770b\u4e0b\u9762\u8fd9\u4e2a\u672a\u6ce8\u91ca\u7684 Dockerfile\uff0c\u5b83\u53ef\u4ee5\u7528\u6765\u90e8\u7f72\u8fd9\u4e2a\u5e94\u7528\u7a0b\u5e8f\u63a5\u53e3\u3002\u8bf7\u6ce8\u610f\u7528\u4e8e\u57f9\u8bad\u7684\u65e7 Dockerfile \u548c\u8fd9\u4e2a Dockerfile \u4e4b\u95f4\u7684\u533a\u522b\u3002\u533a\u522b\u4e0d\u5927\u3002 # CPU Dockerfile FROM ubuntu:18.04 RUN apt-get update && apt-get install -y \\ git \\ curl \\ ca-certificates \\ python3 \\ python3-pip \\ sudo \\ && rm -rf /var/lib/apt/lists/* RUN useradd -m abhishek RUN chown -R abhishek:abhishek /home/abhishek/ COPY --chown = abhishek *.* /home/abhishek/app/ USER abhishek RUN mkdir /home/abhishek/data/ RUN cd /home/abhishek/app/ && pip3 install -r requirements.txt RUN pip3 install mkl WORKDIR /home/abhishek/app \u8ba9\u6211\u4eec\u6784\u5efa\u4e00\u4e2a\u65b0\u7684 Docker \u5bb9\u5668\u3002 docker build -f Dockerfile -t bert:api \u5f53 Docker 
\u5bb9\u5668\u6784\u5efa\u5b8c\u6210\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u76f4\u63a5\u8fd0\u884c API \u4e86\u3002 docker run -p 5000 :5000 -v /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ - ti bert:api /home/abhishek/.local/bin/gunicorn api:app --bind 0 .0.0.0:5000 --workers 4 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5c06\u5bb9\u5668\u5185\u7684 5000 \u7aef\u53e3\u66b4\u9732\u7ed9\u5bb9\u5668\u5916\u7684 5000 \u7aef\u53e3\u3002\u5982\u679c\u4f7f\u7528 docker-compose\uff0c\u4e5f\u53ef\u4ee5\u5f88\u597d\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002Dockercompose \u662f\u4e00\u4e2a\u53ef\u4ee5\u8ba9\u4f60\u540c\u65f6\u5728\u4e0d\u540c\u6216\u76f8\u540c\u5bb9\u5668\u4e2d\u8fd0\u884c\u4e0d\u540c\u670d\u52a1\u7684\u5de5\u5177\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528 \"pip install docker-compose \"\u5b89\u88c5 docker-compose\uff0c\u7136\u540e\u5728\u6784\u5efa\u5bb9\u5668\u540e\u8fd0\u884c \"docker-compose up\"\u3002\u8981\u4f7f\u7528 docker-compose\uff0c\u4f60\u9700\u8981\u4e00\u4e2a docker-compose.yml \u6587\u4ef6\u3002 # docker-compose.yml # specify a version of the compose version : '3.7' # you can add multiple services services : # specify service name. 
we call our service: api api : # specify image name image : bert:api # the command that you would like to run inside the container command : /home/abhishek/.local/bin/gunicorn api:app --bind 0.0.0.0:5000 --workers 4 # mount the volume volumes : - /home/abhishek/workspace/approaching_almost/input/:/home/abhishek/data/ # this ensures that our ports from container will be # exposed as it is network_mode : host \u73b0\u5728\uff0c\u60a8\u53ea\u9700\u4f7f\u7528\u4e0a\u8ff0\u547d\u4ee4\u5373\u53ef\u91cd\u65b0\u8fd0\u884c API\uff0c\u5176\u8fd0\u884c\u65b9\u5f0f\u4e0e\u4e4b\u524d\u76f8\u540c\u3002\u606d\u559c\u4f60\uff0c\u73b0\u5728\uff0c\u4f60\u4e5f\u5df2\u7ecf\u6210\u529f\u5730\u5c06\u9884\u6d4b API \u8fdb\u884c\u4e86 Docker \u5316\uff0c\u53ef\u4ee5\u968f\u65f6\u968f\u5730\u90e8\u7f72\u4e86\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u5b66\u4e60\u4e86 Docker\u3001\u4f7f\u7528 flask \u6784\u5efa API\u3001\u4f7f\u7528 gunicorn \u548c Docker \u670d\u52a1 API \u4ee5\u53ca docker-compose\u3002\u5173\u4e8e docker \u7684\u77e5\u8bc6\u8fdc\u4e0d\u6b62\u8fd9\u4e9b\uff0c\u4f46\u8fd9\u5e94\u8be5\u662f\u4e00\u4e2a\u5f00\u59cb\u3002\u5176\u4ed6\u5185\u5bb9\u53ef\u4ee5\u5728\u5b66\u4e60\u8fc7\u7a0b\u4e2d\u9010\u6e10\u638c\u63e1\u3002 \u6211\u4eec\u8fd8\u8df3\u8fc7\u4e86\u8bb8\u591a\u5de5\u5177\uff0c\u5982 kubernetes\u3001bean-stalk\u3001sagemaker\u3001heroku \u548c\u8bb8\u591a\u5176\u4ed6\u5de5\u5177\uff0c\u8fd9\u4e9b\u5de5\u5177\u5982\u4eca\u88ab\u4eba\u4eec\u7528\u6765\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u6a21\u578b\u3002\"\u6211\u8981\u5199\u4ec0\u4e48\uff1f\u70b9\u51fb\u4fee\u6539\u56fe X \u4e2d\u7684 docker \u5bb9\u5668\"\uff1f\u5728\u4e66\u4e2d\u63cf\u8ff0\u8fd9\u4e9b\u662f\u4e0d\u53ef\u884c\u7684\uff0c\u4e5f\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u6240\u4ee5\u6211\u5c06\u4f7f\u7528\u4e0d\u540c\u7684\u5a92\u4ecb\u6765\u8d5e\u7f8e\u672c\u4e66\u7684\u8fd9\u4e00\u90e8\u5206\u3002\u8bf7\u8bb0\u4f4f\uff0c\u4e00\u65e6\u4f60\u5bf9\u5e94\u7528\u7a0b\u5e8f\u8fdb\u884c\u4e86 Docker 
\u5316\uff0c\u4f7f\u7528\u8fd9\u4e9b\u6280\u672f/\u5e73\u53f0\u8fdb\u884c\u90e8\u7f72\u5c31\u53d8\u5f97\u6613\u5982\u53cd\u638c\u4e86\u3002\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u8981\u8ba9\u4f60\u7684\u4ee3\u7801\u548c\u6a21\u578b\u5bf9\u4ed6\u4eba\u53ef\u7528\uff0c\u5e76\u505a\u597d\u6587\u6863\u8bb0\u5f55\uff0c\u8fd9\u6837\u4efb\u4f55\u4eba\u90fd\u53ef\u4ee5\u4f7f\u7528\u4f60\u5f00\u53d1\u7684\u4e1c\u897f\uff0c\u800c\u65e0\u9700\u591a\u6b21\u8be2\u95ee\u4f60\u3002\u8fd9\u4e0d\u4ec5\u80fd\u8282\u7701\u60a8\u7684\u65f6\u95f4\uff0c\u8fd8\u80fd\u8282\u7701\u4ed6\u4eba\u7684\u65f6\u95f4\u3002\u597d\u7684\u3001\u5f00\u6e90\u7684\u3001\u53ef\u91cd\u590d\u4f7f\u7528\u7684\u4ee3\u7801\u5728\u60a8\u7684\u4f5c\u54c1\u96c6\u4e2d\u4e5f\u975e\u5e38\u91cd\u8981\u3002","title":"\u53ef\u91cd\u590d\u4ee3\u7801\u548c\u6a21\u578b\u65b9\u6cd5"},{"location":"%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB%E5%92%8C%E5%88%86%E5%89%B2%E6%96%B9%E6%B3%95/","text":"\u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5 \u8bf4\u5230\u56fe\u50cf\uff0c\u8fc7\u53bb\u51e0\u5e74\u53d6\u5f97\u4e86\u5f88\u591a\u6210\u5c31\u3002\u8ba1\u7b97\u673a\u89c6\u89c9\u7684\u8fdb\u6b65\u76f8\u5f53\u5feb\uff0c\u611f\u89c9\u8ba1\u7b97\u673a\u89c6\u89c9\u7684\u8bb8\u591a\u95ee\u9898\u73b0\u5728\u90fd\u66f4\u5bb9\u6613\u89e3\u51b3\u4e86\u3002\u968f\u7740\u9884\u8bad\u7ec3\u6a21\u578b\u7684\u51fa\u73b0\u548c\u8ba1\u7b97\u6210\u672c\u7684\u964d\u4f4e\uff0c\u73b0\u5728\u5728\u5bb6\u91cc\u5c31\u80fd\u8f7b\u677e\u8bad\u7ec3\u51fa\u63a5\u8fd1\u6700\u5148\u8fdb\u6c34\u5e73\u7684\u6a21\u578b\uff0c\u89e3\u51b3\u5927\u591a\u6570\u4e0e\u56fe\u50cf\u76f8\u5173\u7684\u95ee\u9898\u3002\u4f46\u662f\uff0c\u56fe\u50cf\u95ee\u9898\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u7c7b\u578b\u3002\u4ece\u4e24\u4e2a\u6216\u591a\u4e2a\u7c7b\u522b\u7684\u6807\u51c6\u56fe\u50cf\u5206\u7c7b\uff0c\u5230\u50cf\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66\u8fd9\u6837\u5177\u6709\u6311\u6218\u6027\u7684\u95ee\u9898\u3002\u6211\u4eec\u4e0d\u4f1a\u5728\u672c\u4e66\u4e2d\u8ba8\
u8bba\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66\uff0c\u4f46\u6211\u4eec\u663e\u7136\u4f1a\u5904\u7406\u4e00\u4e9b\u6700\u5e38\u89c1\u7684\u56fe\u50cf\u95ee\u9898\u3002 \u6211\u4eec\u53ef\u4ee5\u5bf9\u56fe\u50cf\u91c7\u7528\u54ea\u4e9b\u4e0d\u540c\u7684\u65b9\u6cd5\uff1f\u56fe\u50cf\u53ea\u4e0d\u8fc7\u662f\u4e00\u4e2a\u6570\u5b57\u77e9\u9635\u3002\u8ba1\u7b97\u673a\u65e0\u6cd5\u50cf\u4eba\u7c7b\u4e00\u6837\u770b\u5230\u56fe\u50cf\u3002\u5b83\u53ea\u80fd\u770b\u5230\u6570\u5b57\uff0c\u8fd9\u5c31\u662f\u56fe\u50cf\u3002\u7070\u5ea6\u56fe\u50cf\u662f\u4e00\u4e2a\u4e8c\u7ef4\u77e9\u9635\uff0c\u6570\u503c\u8303\u56f4\u4ece 0 \u5230 255\u30020 \u4ee3\u8868\u9ed1\u8272\uff0c255 \u4ee3\u8868\u767d\u8272\uff0c\u4ecb\u4e8e\u4e24\u8005\u4e4b\u95f4\u7684\u662f\u5404\u79cd\u7070\u8272\u3002\u4ee5\u524d\uff0c\u5728\u6ca1\u6709\u6df1\u5ea6\u5b66\u4e60\u7684\u65f6\u5019\uff08\u6216\u8005\u8bf4\u6df1\u5ea6\u5b66\u4e60\u8fd8\u4e0d\u6d41\u884c\u7684\u65f6\u5019\uff09\uff0c\u4eba\u4eec\u4e60\u60ef\u4e8e\u67e5\u770b\u50cf\u7d20\u3002\u6bcf\u4e2a\u50cf\u7d20\u90fd\u662f\u4e00\u4e2a\u7279\u5f81\u3002\u4f60\u53ef\u4ee5\u5728 Python \u4e2d\u8f7b\u677e\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u53ea\u9700\u4f7f\u7528 OpenCV \u6216 Python-PIL \u8bfb\u53d6\u7070\u5ea6\u56fe\u50cf\uff0c\u8f6c\u6362\u4e3a numpy \u6570\u7ec4\uff0c\u7136\u540e\u5c06\u77e9\u9635\u5e73\u94fa\uff08\u6241\u5e73\u5316\uff09\u5373\u53ef\u3002\u5982\u679c\u5904\u7406\u7684\u662f RGB \u56fe\u50cf\uff0c\u5219\u9700\u8981\u4e09\u4e2a\u77e9\u9635\uff0c\u800c\u4e0d\u662f\u4e00\u4e2a\u3002\u4f46\u601d\u8def\u662f\u4e00\u6837\u7684\u3002 import numpy as np import matplotlib.pyplot as plt # \u751f\u6210\u4e00\u4e2a 256x256 \u7684\u968f\u673a\u7070\u5ea6\u56fe\u50cf\uff0c\u50cf\u7d20\u503c\u57280\u5230255\u4e4b\u95f4\u968f\u673a\u5206\u5e03 random_image = np . random . randint ( 0 , 256 , ( 256 , 256 )) # \u521b\u5efa\u4e00\u4e2a\u65b0\u7684\u56fe\u50cf\u7a97\u53e3\uff0c\u8bbe\u7f6e\u7a97\u53e3\u5927\u5c0f\u4e3a7x7\u82f1\u5bf8 plt . 
figure ( figsize = ( 7 , 7 )) # \u663e\u793a\u751f\u6210\u7684\u968f\u673a\u56fe\u50cf # \u4f7f\u7528\u7070\u5ea6\u989c\u8272\u6620\u5c04 (colormap)\uff0c\u8303\u56f4\u4ece0\u5230255 plt . imshow ( random_image , cmap = 'gray' , vmin = 0 , vmax = 255 ) # \u663e\u793a\u56fe\u50cf\u7a97\u53e3 plt . show () \u4e0a\u9762\u7684\u4ee3\u7801\u4f7f\u7528 numpy \u751f\u6210\u4e00\u4e2a\u968f\u673a\u77e9\u9635\u3002\u8be5\u77e9\u9635\u7531 0 \u5230 255\uff08\u5305\u542b\uff09\u7684\u503c\u7ec4\u6210\uff0c\u5927\u5c0f\u4e3a 256x256\uff08\u4e5f\u79f0\u4e3a\u50cf\u7d20\uff09\u3002 \u56fe 1\uff1a\u4e8c\u7ef4\u56fe\u50cf\u9635\u5217\uff08\u5355\u901a\u9053\uff09\u53ca\u5176\u5c55\u5e73\u7248\u672c \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u62fc\u5199\u540e\u7684\u7248\u672c\u53ea\u662f\u4e00\u4e2a\u5927\u5c0f\u4e3a M \u7684\u5411\u91cf\uff0c\u5176\u4e2d M = N * N\uff0c\u5728\u672c\u4f8b\u4e2d\uff0c\u8fd9\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u4e3a 256 * 256 = 65536\u3002 \u73b0\u5728\uff0c\u5982\u679c\u6211\u4eec\u7ee7\u7eed\u5bf9\u6570\u636e\u96c6\u4e2d\u7684\u6240\u6709\u56fe\u50cf\u8fdb\u884c\u5904\u7406\uff0c\u6bcf\u4e2a\u6837\u672c\u5c31\u4f1a\u6709 65536 \u4e2a\u7279\u5f81\u3002\u6211\u4eec\u53ef\u4ee5\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u5feb\u901f\u5efa\u7acb \u51b3\u7b56\u6811\u6a21\u578b\u3001\u968f\u673a\u68ee\u6797\u6a21\u578b\u6216\u57fa\u4e8e SVM \u7684\u6a21\u578b \u3002\u8fd9\u4e9b\u6a21\u578b\u5c06\u57fa\u4e8e\u50cf\u7d20\u503c\uff0c\u5c1d\u8bd5\u5c06\u6b63\u6837\u672c\u4e0e\u8d1f\u6837\u672c\u533a\u5206\u5f00\u6765\uff08\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff09\u3002 
\u4f60\u4eec\u4e00\u5b9a\u90fd\u542c\u8bf4\u8fc7\u732b\u4e0e\u72d7\u7684\u95ee\u9898\uff0c\u8fd9\u662f\u4e00\u4e2a\u7ecf\u5178\u7684\u95ee\u9898\u3002\u5982\u679c\u4f60\u4eec\u8fd8\u8bb0\u5f97\uff0c\u5728\u8bc4\u4f30\u6307\u6807\u4e00\u7ae0\u7684\u5f00\u5934\uff0c\u6211\u5411\u4f60\u4eec\u4ecb\u7ecd\u4e86\u4e00\u4e2a\u6c14\u80f8\u56fe\u50cf\u6570\u636e\u96c6\u3002\u90a3\u4e48\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u68c0\u6d4b\u80ba\u90e8\u7684 X \u5149\u56fe\u50cf\u662f\u5426\u5b58\u5728\u6c14\u80f8\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u8fd9\u662f\u4e00\u4e2a\uff08\u5e76\u4e0d\uff09\u7b80\u5355\u7684\u4e8c\u5143\u5206\u7c7b\u3002 \u56fe 2\uff1a\u975e\u6c14\u80f8\u4e0e\u6c14\u80f8 X \u5149\u56fe\u50cf\u5bf9\u6bd4 \u5728\u56fe 2 \u4e2d\uff0c\u60a8\u53ef\u4ee5\u770b\u5230\u975e\u6c14\u80f8\u548c\u6c14\u80f8\u56fe\u50cf\u7684\u5bf9\u6bd4\u3002\u60a8\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\u4e86\uff0c\u5bf9\u4e8e\u4e00\u4e2a\u975e\u4e13\u4e1a\u4eba\u58eb\uff08\u6bd4\u5982\u6211\uff09\u6765\u8bf4\uff0c\u8981\u5728\u8fd9\u4e9b\u56fe\u50cf\u4e2d\u8fa8\u522b\u51fa\u54ea\u4e2a\u662f\u6c14\u80f8\u662f\u76f8\u5f53\u56f0\u96be\u7684\u3002 \u6700\u521d\u7684\u6570\u636e\u96c6\u662f\u5173\u4e8e\u68c0\u6d4b\u6c14\u80f8\u7684\u5177\u4f53\u4f4d\u7f6e\uff0c\u4f46\u6211\u4eec\u5c06\u95ee\u9898\u4fee\u6539\u4e3a\u67e5\u627e\u7ed9\u5b9a\u7684 X \u5149\u56fe\u50cf\u662f\u5426\u5b58\u5728\u6c14\u80f8\u3002\u522b\u62c5\u5fc3\uff0c\u6211\u4eec\u5c06\u5728\u672c\u7ae0\u4ecb\u7ecd\u8fd9\u4e2a\u90e8\u5206\u3002\u6570\u636e\u96c6\u7531 10675 \u5f20\u72ec\u7279\u7684\u56fe\u50cf\u7ec4\u6210\uff0c\u5176\u4e2d 2379 \u5f20\u6709\u6c14\u80f8\uff08\u6ce8\u610f\uff0c\u8fd9\u4e9b\u6570\u5b57\u662f\u7ecf\u8fc7\u6570\u636e\u6e05\u7406\u540e\u5f97\u51fa\u7684\uff0c\u56e0\u6b64\u4e0e\u539f\u59cb\u6570\u636e\u96c6\u4e0d\u7b26\uff09\u3002\u6b63\u5982\u6570\u636e\u79d1\u5b66\u5bb6\u6240\u8bf4\uff1a\u8fd9\u662f\u4e00\u4e2a\u5178\u578b\u7684 
\u504f\u659c\u4e8c\u5143\u5206\u7c7b\u6848\u4f8b \u3002\u56e0\u6b64\uff0c\u6211\u4eec\u9009\u62e9 AUC \u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\uff0c\u5e76\u91c7\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u9a8c\u8bc1\u65b9\u6848\u3002 \u60a8\u53ef\u4ee5\u5c06\u7279\u5f81\u6241\u5e73\u5316\uff0c\u7136\u540e\u5c1d\u8bd5\u4e00\u4e9b\u7ecf\u5178\u65b9\u6cd5\uff08\u5982 SVM\u3001RF\uff09\u6765\u8fdb\u884c\u5206\u7c7b\uff0c\u8fd9\u5b8c\u5168\u6ca1\u95ee\u9898\uff0c\u4f46\u5374\u65e0\u6cd5\u8ba9\u60a8\u8fbe\u5230\u6700\u5148\u8fdb\u7684\u6c34\u5e73\u3002\u6b64\u5916\uff0c\u56fe\u50cf\u5927\u5c0f\u4e3a 1024x1024\u3002\u5728\u8fd9\u4e2a\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u4e00\u4e2a\u6a21\u578b\u9700\u8981\u5f88\u957f\u65f6\u95f4\u3002\u4e0d\u7ba1\u600e\u6837\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u5efa\u7acb\u4e00\u4e2a\u7b80\u5355\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002\u7531\u4e8e\u56fe\u50cf\u662f\u7070\u5ea6\u7684\uff0c\u6211\u4eec\u4e0d\u9700\u8981\u8fdb\u884c\u4efb\u4f55\u8f6c\u6362\u3002\u6211\u4eec\u5c06\u628a\u56fe\u50cf\u5927\u5c0f\u8c03\u6574\u4e3a 256x256\uff0c\u4f7f\u5176\u66f4\u5c0f\uff0c\u5e76\u4f7f\u7528\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684 AUC \u4f5c\u4e3a\u8861\u91cf\u6307\u6807\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5b83\u7684\u8868\u73b0\u5982\u4f55\u3002 import os import numpy as np import pandas as pd from PIL import Image from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from tqdm import tqdm # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u6765\u521b\u5efa\u6570\u636e\u96c6 def create_dataset ( training_df , image_dir ): # \u521d\u59cb\u5316\u7a7a\u5217\u8868\u6765\u5b58\u50a8\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c images = [] targets = [] # \u8fed\u4ee3\u5904\u7406\u8bad\u7ec3\u6570\u636e\u96c6\u4e2d\u7684\u6bcf\u4e00\u884c for index , row in tqdm ( training_df . 
iterrows (), total = len ( training_df ), desc = \"processing images\" ): # \u83b7\u53d6\u56fe\u50cf\u6587\u4ef6\u540d image_id = row [ \"ImageId\" ] # \u6784\u5efa\u5b8c\u6574\u7684\u56fe\u50cf\u6587\u4ef6\u8def\u5f84 image_path = os . path . join ( image_dir , image_id ) # \u6253\u5f00\u56fe\u50cf\u6587\u4ef6\u5e76\u8fdb\u884c\u5927\u5c0f\u8c03\u6574\uff08resize\uff09\u4e3a 256x256 \u50cf\u7d20\uff0c\u4f7f\u7528\u53cc\u7ebf\u6027\u63d2\u503c\uff08BILINEAR\uff09 image = Image . open ( image_path + \".png\" ) image = image . resize (( 256 , 256 ), resample = Image . BILINEAR ) # \u5c06\u56fe\u50cf\u8f6c\u6362\u4e3aNumPy\u6570\u7ec4 image = np . array ( image ) # \u5c06\u56fe\u50cf\u6241\u5e73\u5316\u4e3a\u4e00\u7ef4\u6570\u7ec4\uff0c\u5e76\u5c06\u5176\u6dfb\u52a0\u5230\u56fe\u50cf\u5217\u8868 image = image . ravel () images . append ( image ) # \u5c06\u76ee\u6807\u503c\uff08target\uff09\u6dfb\u52a0\u5230\u76ee\u6807\u5217\u8868 targets . append ( int ( row [ \"target\" ])) # \u5c06\u56fe\u50cf\u5217\u8868\u8f6c\u6362\u4e3aNumPy\u6570\u7ec4 images = np . array ( images ) # \u6253\u5370\u56fe\u50cf\u6570\u7ec4\u7684\u5f62\u72b6 print ( images . shape ) # \u8fd4\u56de\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c return images , targets if __name__ == \"__main__\" : # \u5b9a\u4e49CSV\u6587\u4ef6\u8def\u5f84\u548c\u56fe\u50cf\u6587\u4ef6\u76ee\u5f55\u8def\u5f84 csv_path = \"/home/abhishek/workspace/siim_png/train.csv\" image_path = \"/home/abhishek/workspace/siim_png/train_png/\" # \u4eceCSV\u6587\u4ef6\u52a0\u8f7d\u6570\u636e df = pd . read_csv ( csv_path ) # \u6dfb\u52a0\u4e00\u4e2a\u540d\u4e3a'kfold'\u7684\u5217\uff0c\u5e76\u521d\u59cb\u5316\u4e3a-1 df [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u83b7\u53d6\u76ee\u6807\u503c\uff08target\uff09 y = df . target . values # \u4f7f\u7528\u5206\u5c42KFold\u4ea4\u53c9\u9a8c\u8bc1\u5c06\u6570\u636e\u96c6\u5206\u62105\u6298 kf = model_selection . 
StratifiedKFold ( n_splits = 5 ) # \u904d\u5386\u6bcf\u4e2a\u6298\uff08fold\uff09 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f # \u904d\u5386\u6bcf\u4e2a\u6298 for fold_ in range ( 5 ): # \u83b7\u53d6\u8bad\u7ec3\u6570\u636e\u548c\u6d4b\u8bd5\u6570\u636e train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6\u7684\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c xtrain , ytrain = create_dataset ( train_df , image_path ) # \u521b\u5efa\u6d4b\u8bd5\u6570\u636e\u96c6\u7684\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c xtest , ytest = create_dataset ( test_df , image_path ) # \u521d\u59cb\u5316\u4e00\u4e2a\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668 clf = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408\u5206\u7c7b\u5668 clf . fit ( xtrain , ytrain ) # \u4f7f\u7528\u5206\u7c7b\u5668\u5bf9\u6d4b\u8bd5\u6570\u636e\u8fdb\u884c\u9884\u6d4b\uff0c\u5e76\u83b7\u53d6\u6982\u7387\u503c preds = clf . predict_proba ( xtest )[:, 1 ] # \u6253\u5370\u6298\u6570\uff08fold\uff09\u548cAUC\uff08ROC\u66f2\u7ebf\u4e0b\u7684\u9762\u79ef\uff09 print ( f \"FOLD: { fold_ } \" ) print ( f \"AUC = { metrics . 
roc_auc_score ( ytest , preds ) } \" ) print ( \"\" ) \u5e73\u5747 AUC \u503c\u7ea6\u4e3a 0.72\u3002\u8fd9\u8fd8\u4e0d\u9519\uff0c\u4f46\u6211\u4eec\u5e0c\u671b\u80fd\u505a\u5f97\u66f4\u597d\u3002\u4f60\u53ef\u4ee5\u5c06\u8fd9\u79cd\u65b9\u6cd5\u7528\u4e8e\u56fe\u50cf\uff0c\u8fd9\u4e5f\u662f\u5b83\u5728\u4ee5\u524d\u6700\u5e38\u7528\u7684\u65b9\u6cd5\u3002SVM \u5728\u56fe\u50cf\u6570\u636e\u96c6\u65b9\u9762\u76f8\u5f53\u6709\u540d\u3002\u6df1\u5ea6\u5b66\u4e60\u5df2\u88ab\u8bc1\u660e\u662f\u89e3\u51b3\u6b64\u7c7b\u95ee\u9898\u7684\u6700\u5148\u8fdb\u65b9\u6cd5\uff0c\u56e0\u6b64\u6211\u4eec\u4e0b\u4e00\u6b65\u53ef\u4ee5\u8bd5\u8bd5\u5b83\u3002 \u5173\u4e8e\u6df1\u5ea6\u5b66\u4e60\u7684\u5386\u53f2\u4ee5\u53ca\u8c01\u53d1\u660e\u4e86\u4ec0\u4e48\uff0c\u6211\u5c31\u4e0d\u591a\u8bf4\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b\u6700\u8457\u540d\u7684\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e4b\u4e00 AlexNet\u3002 \u56fe 3\uff1aAlexNet \u67b6\u67849 \u8bf7\u6ce8\u610f\uff0c\u672c\u56fe\u4e2d\u7684\u8f93\u5165\u5927\u5c0f\u4e0d\u662f 224x224 \u800c\u662f 227x227 \u5982\u4eca\uff0c\u4f60\u53ef\u80fd\u4f1a\u8bf4\u8fd9\u53ea\u662f\u4e00\u4e2a\u57fa\u672c\u7684 \u6df1\u5ea6\u5377\u79ef\u795e\u7ecf\u7f51\u7edc \uff0c\u4f46\u5b83\u5374\u662f\u8bb8\u591a\u65b0\u578b\u6df1\u5ea6\u7f51\u7edc\uff08\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff09\u7684\u57fa\u7840\u3002\u6211\u4eec\u770b\u5230\uff0c\u56fe 3 \u4e2d\u7684\u7f51\u7edc\u662f\u4e00\u4e2a\u5177\u6709\u4e94\u4e2a\u5377\u79ef\u5c42\u3001\u4e24\u4e2a\u5bc6\u96c6\u5c42\u548c\u4e00\u4e2a\u8f93\u51fa\u5c42\u7684\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u3002\u6211\u4eec\u770b\u5230\u8fd8\u6709\u6700\u5927\u6c60\u5316\u3002\u8fd9\u662f\u4ec0\u4e48\u610f\u601d\uff1f\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5728\u8fdb\u884c\u6df1\u5ea6\u5b66\u4e60\u65f6\u4f1a\u9047\u5230\u7684\u4e00\u4e9b\u672f\u8bed\u3002 \u56fe 4\uff1a\u56fe\u50cf\u5927\u5c0f\u4e3a 8x8\uff0c\u6ee4\u6ce2\u5668\u5927\u5c0f\u4e3a 3x3\uff0c\u6b65\u957f\u4e3a 2\u3002 \u56fe 4 
\u5f15\u5165\u4e86\u4e24\u4e2a\u65b0\u672f\u8bed\uff1a\u6ee4\u6ce2\u5668\u548c\u6b65\u957f\u3002 \u6ee4\u6ce2\u5668 \u662f\u7531\u7ed9\u5b9a\u51fd\u6570\u521d\u59cb\u5316\u7684\u4e8c\u7ef4\u77e9\u9635\uff0c\u7531\u6307\u5b9a\u51fd\u6570\u521d\u59cb\u5316\u3002 Kaiming\u6b63\u6001\u521d\u59cb\u5316 \uff0c\u662f\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u7684\u6700\u4f73\u9009\u62e9\u3002\u8fd9\u662f\u56e0\u4e3a\u5927\u591a\u6570\u73b0\u4ee3\u7f51\u7edc\u90fd\u4f7f\u7528 ReLU \uff08\u6574\u6d41\u7ebf\u6027\u5355\u5143\uff09\u6fc0\u6d3b\u51fd\u6570\uff0c\u9700\u8981\u9002\u5f53\u7684\u521d\u59cb\u5316\u6765\u907f\u514d\u68af\u5ea6\u6d88\u5931\u95ee\u9898\uff08\u68af\u5ea6\u8d8b\u8fd1\u4e8e\u96f6\uff0c\u7f51\u7edc\u6743\u91cd\u4e0d\u53d8\uff09\u3002\u8be5\u6ee4\u6ce2\u5668\u4e0e\u56fe\u50cf\u8fdb\u884c\u5377\u79ef\u3002\u5377\u79ef\u4e0d\u8fc7\u662f\u6ee4\u6ce2\u5668\u4e0e\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u5f53\u524d\u91cd\u53e0\u50cf\u7d20\u4e4b\u95f4\u7684\u5143\u7d20\u76f8\u4e58\u7684\u603b\u548c\u3002\u60a8\u53ef\u4ee5\u5728\u4efb\u4f55\u9ad8\u4e2d\u6570\u5b66\u6559\u79d1\u4e66\u4e2d\u9605\u8bfb\u66f4\u591a\u5173\u4e8e\u5377\u79ef\u7684\u5185\u5bb9\u3002\u6211\u4eec\u4ece\u56fe\u50cf\u7684\u5de6\u4e0a\u89d2\u5f00\u59cb\u5bf9\u6ee4\u955c\u8fdb\u884c\u5377\u79ef\uff0c\u7136\u540e\u6c34\u5e73\u79fb\u52a8\u6ee4\u955c\u3002\u5982\u679c\u79fb\u52a8 1 \u4e2a\u50cf\u7d20\uff0c\u5219\u6b65\u957f\u4e3a 1\uff1b\u5982\u679c\u79fb\u52a8 2 \u4e2a\u50cf\u7d20\uff0c\u5219\u6b65\u957f\u4e3a 2\u3002 
\u5373\u4f7f\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u4f8b\u5982\u5728\u95ee\u9898\u548c\u56de\u7b54\u7cfb\u7edf\u4e2d\u9700\u8981\u4ece\u5927\u91cf\u6587\u672c\u8bed\u6599\u4e2d\u7b5b\u9009\u7b54\u6848\u65f6\uff0c\u6b65\u957f\u4e5f\u662f\u4e00\u4e2a\u975e\u5e38\u6709\u7528\u7684\u6982\u5ff5\u3002\u5f53\u6211\u4eec\u5728\u6c34\u5e73\u65b9\u5411\u4e0a\u8d70\u5230\u5c3d\u5934\u65f6\uff0c\u5c31\u4f1a\u4ee5\u540c\u6837\u7684\u6b65\u957f\u5782\u76f4\u5411\u4e0b\u79fb\u52a8\u8fc7\u6ee4\u5668\uff0c\u4ece\u5de6\u4fa7\u5f00\u59cb\u3002\u56fe 4 \u8fd8\u663e\u793a\u4e86\u8fc7\u6ee4\u5668\u79fb\u51fa\u56fe\u50cf\u7684\u60c5\u51b5\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u65e0\u6cd5\u8ba1\u7b97\u5377\u79ef\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u8df3\u8fc7\u5b83\u3002\u5982\u679c\u4e0d\u60f3\u8df3\u8fc7\uff0c\u5219\u9700\u8981\u5bf9\u56fe\u50cf\u8fdb\u884c \u586b\u5145\uff08pad\uff09 \u3002\u8fd8\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5377\u79ef\u4f1a\u51cf\u5c0f\u56fe\u50cf\u7684\u5927\u5c0f\u3002\u586b\u5145\u4e5f\u662f\u4fdd\u6301\u56fe\u50cf\u5927\u5c0f\u4e0d\u53d8\u7684\u4e00\u79cd\u65b9\u6cd5\u3002\u5728\u56fe 4 \u4e2d\uff0c\u4e00\u4e2a 3x3 \u6ee4\u6ce2\u5668\u6b63\u5728\u6c34\u5e73\u548c\u5782\u76f4\u79fb\u52a8\uff0c\u6bcf\u6b21\u79fb\u52a8\u90fd\u4f1a\u5206\u522b\u8df3\u8fc7\u4e24\u5217\u548c\u4e24\u884c\uff08\u5373\u50cf\u7d20\uff09\u3002\u7531\u4e8e\u5b83\u8df3\u8fc7\u4e86\u4e24\u4e2a\u50cf\u7d20\uff0c\u6240\u4ee5\u6b65\u957f = 2\u3002\u56e0\u6b64\u56fe\u50cf\u5927\u5c0f\u4e3a [(8-3) / 2] + 1 = 3.5\u3002\u6211\u4eec\u53d6 3.5 \u7684\u4e0b\u9650\uff0c\u6240\u4ee5\u662f 3x3\u3002\u60a8\u53ef\u4ee5\u5728\u8349\u7a3f\u7eb8\u4e0a\u8fdb\u884c\u5c1d\u8bd5\u3002 \u56fe 5\uff1a\u901a\u8fc7\u586b\u5145\uff0c\u6211\u4eec\u53ef\u4ee5\u63d0\u4f9b\u4e0e\u8f93\u5165\u56fe\u50cf\u5927\u5c0f\u76f8\u540c\u7684\u56fe\u50cf \u6211\u4eec\u53ef\u4ee5\u4ece\u56fe 5 \u4e2d\u770b\u5230\u586b\u5145\u7684\u6548\u679c\u3002\u73b0\u5728\uff0c\u6211\u4eec\u6709\u4e00\u4e2a 
3x3 \u7684\u6ee4\u6ce2\u5668\uff0c\u5b83\u4ee5 1 \u7684\u6b65\u957f\u79fb\u52a8\u3002\u539f\u59cb\u56fe\u50cf\u7684\u5927\u5c0f\u4e3a 6x6\uff0c\u6211\u4eec\u6dfb\u52a0\u4e86 1 \u7684 \u586b\u5145 \u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u751f\u6210\u7684\u56fe\u50cf\u5c06\u4e0e\u8f93\u5165\u56fe\u50cf\u5927\u5c0f\u76f8\u540c\uff0c\u5373 6x6\u3002\u5728\u5904\u7406\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u65f6\u53ef\u80fd\u4f1a\u9047\u5230\u7684\u53e6\u4e00\u4e2a\u76f8\u5173\u672f\u8bed\u662f \u81a8\u80c0\uff08dilation\uff09 \uff0c\u5982\u56fe 6 \u6240\u793a\u3002 \u56fe 6\uff1a\u81a8\u80c0\uff08dilation\uff09\u7684\u4f8b\u5b50 \u5728\u81a8\u80c0\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u5c06\u6ee4\u6ce2\u5668\u6269\u5927 N-1\uff0c\u5176\u4e2d N \u662f\u81a8\u80c0\u7387\u7684\u503c\uff0c\u6216\u7b80\u79f0\u4e3a\u81a8\u80c0\u3002\u5728\u8fd9\u79cd\u5e26\u81a8\u80c0\u7684\u5185\u6838\u4e2d\uff0c\u6bcf\u6b21\u5377\u79ef\u90fd\u4f1a\u8df3\u8fc7\u4e00\u4e9b\u50cf\u7d20\u3002\u8fd9\u5728\u5206\u5272\u4efb\u52a1\u4e2d\u5c24\u4e3a\u6709\u6548\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u53ea\u8ba8\u8bba\u4e86\u4e8c\u7ef4\u5377\u79ef\u3002 \u8fd8\u6709\u4e00\u7ef4\u5377\u79ef\u548c\u66f4\u9ad8\u7ef4\u5ea6\u7684\u5377\u79ef\u3002\u5b83\u4eec\u90fd\u57fa\u4e8e\u76f8\u540c\u7684\u57fa\u672c\u6982\u5ff5\u3002 \u63a5\u4e0b\u6765\u662f \u6700\u5927\u6c60\u5316\uff08Max pooling\uff09 \u3002\u6700\u5927\u503c\u6c60\u53ea\u662f\u4e00\u4e2a\u8fd4\u56de\u6700\u5927\u503c\u7684\u6ee4\u6ce2\u5668\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u63d0\u53d6\u7684\u4e0d\u662f\u5377\u79ef\uff0c\u800c\u662f\u50cf\u7d20\u7684\u6700\u5927\u503c\u3002\u540c\u6837\uff0c \u5e73\u5747\u6c60\u5316\uff08average pooling\uff09 \u6216 \u5747\u503c\u6c60\u5316\uff08mean pooling\uff09 
\u4f1a\u8fd4\u56de\u50cf\u7d20\u7684\u5e73\u5747\u503c\u3002\u5b83\u4eec\u7684\u4f7f\u7528\u65b9\u6cd5\u4e0e\u5377\u79ef\u6838\u76f8\u540c\u3002\u6c60\u5316\u6bd4\u5377\u79ef\u66f4\u5feb\uff0c\u662f\u4e00\u79cd\u5bf9\u56fe\u50cf\u8fdb\u884c\u7f29\u51cf\u91c7\u6837\u7684\u65b9\u6cd5\u3002\u6700\u5927\u6c60\u5316\u53ef\u68c0\u6d4b\u8fb9\u7f18\uff0c\u5e73\u5747\u6c60\u5316\u53ef\u5e73\u6ed1\u56fe\u50cf\u3002 \u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u548c\u6df1\u5ea6\u5b66\u4e60\u7684\u6982\u5ff5\u592a\u591a\u4e86\u3002\u6211\u6240\u8ba8\u8bba\u7684\u662f\u4e00\u4e9b\u57fa\u7840\u77e5\u8bc6\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5165\u95e8\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u5728 PyTorch \u4e2d\u6784\u5efa\u7b2c\u4e00\u4e2a\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u505a\u597d\u4e86\u5145\u5206\u51c6\u5907\u3002PyTorch \u63d0\u4f9b\u4e86\u4e00\u79cd\u76f4\u89c2\u800c\u7b80\u5355\u7684\u65b9\u6cd5\u6765\u5b9e\u73b0\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff0c\u800c\u4e14\u4f60\u4e0d\u9700\u8981\u5173\u5fc3\u53cd\u5411\u4f20\u64ad\u3002\u6211\u4eec\u7528\u4e00\u4e2a python \u7c7b\u548c\u4e00\u4e2a\u524d\u9988\u51fd\u6570\u6765\u5b9a\u4e49\u7f51\u7edc\uff0c\u544a\u8bc9 PyTorch \u5404\u5c42\u4e4b\u95f4\u5982\u4f55\u8fde\u63a5\u3002\u5728 PyTorch \u4e2d\uff0c\u56fe\u50cf\u7b26\u53f7\u662f BS\u3001C\u3001H\u3001W\uff0c\u5176\u4e2d\uff0cBS \u662f\u6279\u5927\u5c0f\uff0cC \u662f\u901a\u9053\uff0cH \u662f\u9ad8\u5ea6\uff0cW \u662f\u5bbd\u5ea6\u3002\u8ba9\u6211\u4eec\u770b\u770b PyTorch \u662f\u5982\u4f55\u5b9e\u73b0 AlexNet \u7684\u3002 import torch import torch.nn as nn import torch.nn.functional as F class AlexNet ( nn . Module ): def __init__ ( self ): super ( AlexNet , self ) . __init__ () self . conv1 = nn . Conv2d ( in_channels = 3 , out_channels = 96 , kernel_size = 11 , stride = 4 , padding = 0 ) self . pool1 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . conv2 = nn . 
Conv2d ( in_channels = 96 , out_channels = 256 , kernel_size = 5 , stride = 1 , padding = 2 ) self . pool2 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . conv3 = nn . Conv2d ( in_channels = 256 , out_channels = 384 , kernel_size = 3 , stride = 1 , padding = 1 ) self . conv4 = nn . Conv2d ( in_channels = 384 , out_channels = 384 , kernel_size = 3 , stride = 1 , padding = 1 ) self . conv5 = nn . Conv2d ( in_channels = 384 , out_channels = 256 , kernel_size = 3 , stride = 1 , padding = 1 ) self . pool3 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . fc1 = nn . Linear ( in_features = 9216 , out_features = 4096 ) self . dropout1 = nn . Dropout ( 0.5 ) self . fc2 = nn . Linear ( in_features = 4096 , out_features = 4096 ) self . dropout2 = nn . Dropout ( 0.5 ) self . fc3 = nn . Linear ( in_features = 4096 , out_features = 1000 ) def forward ( self , image ): bs , c , h , w = image . size () x = F . relu ( self . conv1 ( image )) # size: (bs, 96, 55, 55) x = self . pool1 ( x ) # size: (bs, 96, 27, 27) x = F . relu ( self . conv2 ( x )) # size: (bs, 256, 27, 27) x = self . pool2 ( x ) # size: (bs, 256, 13, 13) x = F . relu ( self . conv3 ( x )) # size: (bs, 384, 13, 13) x = F . relu ( self . conv4 ( x )) # size: (bs, 384, 13, 13) x = F . relu ( self . conv5 ( x )) # size: (bs, 256, 13, 13) x = self . pool3 ( x ) # size: (bs, 256, 6, 6) x = x . view ( bs , - 1 ) # size: (bs, 9216) x = F . relu ( self . fc1 ( x )) # size: (bs, 4096) x = self . dropout1 ( x ) # size: (bs, 4096) # dropout does not change size # dropout is used for regularization # 0.3 dropout means that only 70% of the nodes # of the current layer are used for the next layer x = F . relu ( self . fc2 ( x )) # size: (bs, 4096) x = self . dropout2 ( x ) # size: (bs, 4096) x = F . relu ( self . 
fc3 ( x )) # size: (bs, 1000) # 1000 is number of classes in ImageNet Dataset # softmax is an activation function that converts # linear output to probabilities that add up to 1 # for each sample in the batch x = torch . softmax ( x , axis = 1 ) # size: (bs, 1000) return x \u5982\u679c\u60a8\u6709\u4e00\u5e45 3x227x227 \u7684\u56fe\u50cf\uff0c\u5e76\u5e94\u7528\u4e86\u4e00\u4e2a\u5927\u5c0f\u4e3a 11x11 \u7684\u5377\u79ef\u6ee4\u6ce2\u5668\uff0c\u8fd9\u610f\u5473\u7740\u60a8\u5e94\u7528\u4e86\u4e00\u4e2a\u5927\u5c0f\u4e3a 11x11x3 \u7684\u6ee4\u6ce2\u5668\uff0c\u5e76\u4e0e\u4e00\u4e2a\u5927\u5c0f\u4e3a 227x227x3 \u7684\u56fe\u50cf\u8fdb\u884c\u4e86\u5377\u79ef\u3002\u8f93\u51fa\u901a\u9053\u7684\u6570\u91cf\u5c31\u662f\u5206\u522b\u5e94\u7528\u4e8e\u56fe\u50cf\u7684\u76f8\u540c\u5927\u5c0f\u7684\u4e0d\u540c\u5377\u79ef\u6ee4\u6ce2\u5668\u7684\u6570\u91cf\u3002 \u56e0\u6b64\uff0c\u5728\u7b2c\u4e00\u4e2a\u5377\u79ef\u5c42\u4e2d\uff0c\u8f93\u5165\u901a\u9053\u662f 3\uff0c\u4e5f\u5c31\u662f\u539f\u59cb\u8f93\u5165\uff0c\u5373 R\u3001G\u3001B \u4e09\u901a\u9053\u3002PyTorch \u7684 torchvision \u63d0\u4f9b\u4e86\u8bb8\u591a\u4e0e AlexNet \u7c7b\u4f3c\u7684\u4e0d\u540c\u6a21\u578b\uff0c\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0cAlexNet \u7684\u5b9e\u73b0\u4e0e torchvision \u7684\u5b9e\u73b0\u5e76\u4e0d\u76f8\u540c\u3002Torchvision \u7684 AlexNet \u5b9e\u73b0\u662f\u4ece\u53e6\u4e00\u7bc7\u8bba\u6587\u4e2d\u4fee\u6539\u800c\u6765\u7684 AlexNet\uff1a Krizhevsky, A. One weird trick for parallelizing convolutional neural networks. CoRR, abs/1404.5997, 2014. 
\u4f60\u53ef\u4ee5\u4e3a\u81ea\u5df1\u7684\u4efb\u52a1\u8bbe\u8ba1\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff0c\u5f88\u591a\u65f6\u5019\uff0c\u4ece\u96f6\u505a\u8d77\u662f\u4e2a\u4e0d\u9519\u7684\u4e3b\u610f\u3002\u8ba9\u6211\u4eec\u6784\u5efa\u4e00\u4e2a\u7f51\u7edc\uff0c\u7528\u4e8e\u533a\u5206\u56fe\u50cf\u6709\u65e0\u6c14\u80f8\u3002\u9996\u5148\uff0c\u8ba9\u6211\u4eec\u51c6\u5907\u4e00\u4e9b\u6587\u4ef6\u3002\u7b2c\u4e00\u6b65\u662f\u521b\u5efa\u4e00\u4e2a\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e\u96c6\uff0c\u5373 train.csv\uff0c\u4f46\u589e\u52a0\u4e00\u5217 kfold\u3002\u6211\u4eec\u5c06\u521b\u5efa\u4e94\u4e2a\u6587\u4ef6\u5939\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u5df2\u7ecf\u6f14\u793a\u4e86\u5982\u4f55\u9488\u5bf9\u4e0d\u540c\u7684\u6570\u636e\u96c6\u521b\u5efa\u6298\u53e0\uff0c\u56e0\u6b64\u6211\u5c06\u8df3\u8fc7\u8fd9\u4e00\u90e8\u5206\uff0c\u7559\u4f5c\u7ec3\u4e60\u3002\u5bf9\u4e8e\u57fa\u4e8e PyTorch \u7684\u795e\u7ecf\u7f51\u7edc\uff0c\u6211\u4eec\u9700\u8981\u521b\u5efa\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\u3002\u6570\u636e\u96c6\u7c7b\u7684\u76ee\u7684\u662f\u8fd4\u56de\u4e00\u4e2a\u6570\u636e\u9879\u6216\u6570\u636e\u6837\u672c\u3002\u8fd9\u4e2a\u6570\u636e\u6837\u672c\u5e94\u8be5\u5305\u542b\u8bad\u7ec3\u6216\u8bc4\u4f30\u6a21\u578b\u6240\u9700\u7684\u6240\u6709\u5185\u5bb9\u3002 import torch import numpy as np from PIL import Image from PIL import ImageFile ImageFile . LOAD_TRUNCATED_IMAGES = True # \u5b9a\u4e49\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\uff0c\u7528\u4e8e\u5904\u7406\u56fe\u50cf\u5206\u7c7b\u4efb\u52a1 class ClassificationDataset : def __init__ ( self , image_paths , targets , resize = None , augmentations = None ): # \u56fe\u50cf\u6587\u4ef6\u8def\u5f84\u5217\u8868 self . image_paths = image_paths # \u76ee\u6807\u6807\u7b7e\u5217\u8868 self . targets = targets # \u56fe\u50cf\u5c3a\u5bf8\u8c03\u6574\u53c2\u6570\uff0c\u53ef\u4ee5\u4e3aNone self . resize = resize # \u6570\u636e\u589e\u5f3a\u51fd\u6570\uff0c\u53ef\u4ee5\u4e3aNone self . 
augmentations = augmentations def __len__ ( self ): # \u8fd4\u56de\u6570\u636e\u96c6\u7684\u5927\u5c0f\uff0c\u5373\u56fe\u50cf\u6570\u91cf return len ( self . image_paths ) def __getitem__ ( self , item ): # \u83b7\u53d6\u6570\u636e\u96c6\u4e2d\u7684\u4e00\u4e2a\u6837\u672c image = Image . open ( self . image_paths [ item ]) image = image . convert ( \"RGB\" ) # \u5c06\u56fe\u50cf\u8f6c\u6362\u4e3aRGB\u683c\u5f0f # \u83b7\u53d6\u8be5\u6837\u672c\u7684\u76ee\u6807\u6807\u7b7e targets = self . targets [ item ] if self . resize is not None : # \u5982\u679c\u6307\u5b9a\u4e86\u5c3a\u5bf8\u8c03\u6574\u53c2\u6570\uff0c\u5c06\u56fe\u50cf\u8fdb\u884c\u5c3a\u5bf8\u8c03\u6574 image = image . resize (( self . resize [ 1 ], self . resize [ 0 ]), resample = Image . BILINEAR ) image = np . array ( image ) if self . augmentations is not None : # \u5982\u679c\u6307\u5b9a\u4e86\u6570\u636e\u589e\u5f3a\u51fd\u6570\uff0c\u5e94\u7528\u6570\u636e\u589e\u5f3a augmented = self . augmentations ( image = image ) image = augmented [ \"image\" ] # \u5c06\u56fe\u50cf\u901a\u9053\u987a\u5e8f\u8c03\u6574\u4e3a(C, H, W)\u7684\u5f62\u5f0f\uff0c\u5e76\u8f6c\u6362\u4e3afloat32\u7c7b\u578b image = np . transpose ( image , ( 2 , 0 , 1 )) . astype ( np . float32 ) # \u8fd4\u56de\u6837\u672c\uff0c\u5305\u62ec\u56fe\u50cf\u548c\u5bf9\u5e94\u7684\u76ee\u6807\u6807\u7b7e return { \"image\" : torch . tensor ( image , dtype = torch . float ), \"targets\" : torch . tensor ( targets , dtype = torch . long ), } \u73b0\u5728\u6211\u4eec\u9700\u8981 engine.py\u3002engine.py \u5305\u542b\u8bad\u7ec3\u548c\u8bc4\u4f30\u529f\u80fd\u3002\u8ba9\u6211\u4eec\u770b\u770b engine.py \u662f\u5982\u4f55\u7f16\u5199\u7684\u3002 import torch import torch.nn as nn from tqdm import tqdm # \u7528\u4e8e\u8bad\u7ec3\u6a21\u578b\u7684\u51fd\u6570 def train ( data_loader , model , optimizer , device ): # \u5c06\u6a21\u578b\u8bbe\u7f6e\u4e3a\u8bad\u7ec3\u6a21\u5f0f model . 
train () for data in data_loader : # \u4ece\u6570\u636e\u52a0\u8f7d\u5668\u4e2d\u63d0\u53d6\u8f93\u5165\u56fe\u50cf\u548c\u76ee\u6807\u6807\u7b7e inputs = data [ \"image\" ] targets = data [ \"targets\" ] # \u5c06\u8f93\u5165\u548c\u76ee\u6807\u79fb\u52a8\u5230\u6307\u5b9a\u7684\u8bbe\u5907\uff08\u4f8b\u5982\uff0cGPU\uff09 inputs = inputs . to ( device , dtype = torch . float ) targets = targets . to ( device , dtype = torch . float ) # \u5c06\u4f18\u5316\u5668\u4e2d\u7684\u68af\u5ea6\u5f52\u96f6 optimizer . zero_grad () # \u524d\u5411\u4f20\u64ad\uff1a\u8ba1\u7b97\u6a21\u578b\u9884\u6d4b outputs = model ( inputs ) # \u4f7f\u7528\u5e26\u903b\u8f91\u65af\u8482\u51fd\u6570\u7684\u4e8c\u5143\u4ea4\u53c9\u71b5\u635f\u5931\u8ba1\u7b97\u635f\u5931 loss = nn . BCEWithLogitsLoss ()( outputs , targets . view ( - 1 , 1 )) # \u53cd\u5411\u4f20\u64ad\uff1a\u8ba1\u7b97\u68af\u5ea6\u5e76\u66f4\u65b0\u6a21\u578b\u6743\u91cd loss . backward () optimizer . step () # \u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u7684\u51fd\u6570 def evaluate ( data_loader , model , device ): # \u5c06\u6a21\u578b\u8bbe\u7f6e\u4e3a\u8bc4\u4f30\u6a21\u5f0f\uff08\u4e0d\u8fdb\u884c\u68af\u5ea6\u8ba1\u7b97\uff09 model . eval () # \u521d\u59cb\u5316\u5217\u8868\u4ee5\u5b58\u50a8\u771f\u5b9e\u76ee\u6807\u548c\u6a21\u578b\u9884\u6d4b final_targets = [] final_outputs = [] with torch . no_grad (): for data in data_loader : # \u4ece\u6570\u636e\u52a0\u8f7d\u5668\u4e2d\u63d0\u53d6\u8f93\u5165\u56fe\u50cf\u548c\u76ee\u6807\u6807\u7b7e inputs = data [ \"image\" ] targets = data [ \"targets\" ] # \u5c06\u8f93\u5165\u79fb\u52a8\u5230\u6307\u5b9a\u7684\u8bbe\u5907\uff08\u4f8b\u5982\uff0cGPU\uff09 inputs = inputs . to ( device , dtype = torch . float ) # \u83b7\u53d6\u6a21\u578b\u9884\u6d4b output = model ( inputs ) # \u5c06\u76ee\u6807\u548c\u8f93\u51fa\u8f6c\u6362\u4e3aCPU\u548cPython\u5217\u8868 targets = targets . detach () . cpu () . numpy () . tolist () output = output . detach () . cpu () . numpy () . 
tolist () # \u5c06\u5217\u8868\u6269\u5c55\u4ee5\u5305\u542b\u6279\u6b21\u6570\u636e final_targets . extend ( targets ) final_outputs . extend ( output ) # \u8fd4\u56de\u6700\u7ec8\u7684\u6a21\u578b\u9884\u6d4b\u548c\u771f\u5b9e\u76ee\u6807 return final_outputs , final_targets \u6709\u4e86 engine.py\uff0c\u5c31\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u65b0\u6587\u4ef6\uff1amodel.py\u3002model.py \u5c06\u5305\u542b\u6211\u4eec\u7684\u6a21\u578b\u3002\u628a\u6a21\u578b\u4e0e\u8bad\u7ec3\u5206\u5f00\u662f\u4e2a\u597d\u4e3b\u610f\uff0c\u56e0\u4e3a\u8fd9\u6837\u6211\u4eec\u5c31\u53ef\u4ee5\u8f7b\u677e\u5730\u8bd5\u9a8c\u4e0d\u540c\u7684\u6a21\u578b\u548c\u4e0d\u540c\u7684\u67b6\u6784\u3002\u540d\u4e3a pretrainedmodels \u7684 PyTorch \u5e93\u4e2d\u6709\u5f88\u591a\u4e0d\u540c\u7684\u6a21\u578b\u67b6\u6784\uff0c\u5982 AlexNet\u3001ResNet\u3001DenseNet \u7b49\u3002\u8fd9\u4e9b\u4e0d\u540c\u7684\u6a21\u578b\u67b6\u6784\u662f\u5728\u540d\u4e3a ImageNet \u7684\u5927\u578b\u56fe\u50cf\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u51fa\u6765\u7684\u3002\u5728 ImageNet \u4e0a\u8bad\u7ec3\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5b83\u4eec\u7684\u6743\u91cd\uff0c\u4e5f\u53ef\u4ee5\u4e0d\u4f7f\u7528\u8fd9\u4e9b\u6743\u91cd\u3002\u5982\u679c\u6211\u4eec\u4e0d\u4f7f\u7528 ImageNet \u6743\u91cd\u8fdb\u884c\u8bad\u7ec3\uff0c\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u7f51\u7edc\u5c06\u4ece\u5934\u5f00\u59cb\u5b66\u4e60\u4e00\u5207\u3002\u8fd9\u5c31\u662f model.py \u7684\u6837\u5b50\u3002 import torch.nn as nn import pretrainedmodels # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u4ee5\u83b7\u53d6\u6a21\u578b def get_model ( pretrained ): if pretrained : # \u4f7f\u7528\u9884\u8bad\u7ec3\u7684 AlexNet \u6a21\u578b\uff0c\u52a0\u8f7d\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u7684\u6743\u91cd model = pretrainedmodels . __dict__ [ \"alexnet\" ]( pretrained = 'imagenet' ) else : # \u4f7f\u7528\u672a\u7ecf\u9884\u8bad\u7ec3\u7684 AlexNet \u6a21\u578b model = pretrainedmodels . 
__dict__ [ \"alexnet\" ]( pretrained = None ) # \u4fee\u6539\u6a21\u578b\u7684\u6700\u540e\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u4ee5\u9002\u5e94\u7279\u5b9a\u4efb\u52a1 model . last_linear = nn . Sequential ( nn . BatchNorm1d ( 4096 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.25 ), # \u968f\u673a\u5931\u6d3b\u5c42\uff0c\u9632\u6b62\u8fc7\u62df\u5408 nn . Linear ( in_features = 4096 , out_features = 2048 ), # \u8fde\u63a5\u5c42 nn . ReLU (), # ReLU \u6fc0\u6d3b\u51fd\u6570 nn . BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.5 ), # \u968f\u673a\u5931\u6d3b\u5c42 nn . Linear ( in_features = 2048 , out_features = 1 ) # \u6700\u7ec8\u7684\u4e8c\u5143\u5206\u7c7b\u5c42 ) return model \u5982\u679c\u4f60\u6253\u5370\u4e86\u7f51\u7edc\uff0c\u4f1a\u5f97\u5230\u5982\u4e0b\u8f93\u51fa\uff1a AlexNet ( ( avgpool ): AdaptiveAvgPool2d ( output_size = ( 6 , 6 )) ( _features ): Sequential ( ( 0 ): Conv2d ( 3 , 64 , kernel_size = ( 11 , 11 ), stride = ( 4 , 4 ), padding = ( 2 , 2 )) ( 1 ): ReLU ( inplace = True ) ( 2 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , ceil_mode = False ) ( 3 ): Conv2d ( 64 , 192 , kernel_size = ( 5 , 5 ), stride = ( 1 , 1 ), padding = ( 2 , 2 )) ( 4 ): ReLU ( inplace = True ) ( 5 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , ceil_mode = False ) ( 6 ): Conv2d ( 192 , 384 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 7 ): ReLU ( inplace = True ) ( 8 ): Conv2d ( 384 , 256 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 9 ): ReLU ( inplace = True ) ( 10 ): Conv2d ( 256 , 256 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 11 ): ReLU ( inplace = True ) ( 12 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , eil_mode = False )) ( dropout0 ): Dropout ( p = 0.5 , inplace = False ) ( linear0 ): Linear ( in_features = 9216 , out_features = 
4096 , bias = True ) ( relu0 ): ReLU ( inplace = True ) ( dropout1 ): Dropout ( p = 0.5 , inplace = False ) ( linear1 ): Linear ( in_features = 4096 , out_features = 4096 , bias = True ) ( relu1 ): ReLU ( inplace = True ) ( last_linear ): Sequential ( ( 0 ): BatchNorm1d ( 4096 , eps = 1e-05 , momentum = 0.1 , affine = True , rack_running_stats = True ) ( 1 ): Dropout ( p = 0.25 , inplace = False ) ( 2 ): Linear ( in_features = 4096 , out_features = 2048 , bias = True ) ( 3 ): ReLU () ( 4 ): BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 , affine = True , track_running_stats = True ) ( 5 ): Dropout ( p = 0.5 , inplace = False ) ( 6 ): Linear ( in_features = 2048 , out_features = 1 , bias = True ) ) ) \u73b0\u5728\uff0c\u4e07\u4e8b\u4ff1\u5907\uff0c\u53ef\u4ee5\u5f00\u59cb\u8bad\u7ec3\u4e86\u3002\u6211\u4eec\u5c06\u4f7f\u7528 train.py \u8bad\u7ec3\u6a21\u578b\u3002 import os import pandas as pd import numpy as np import albumentations import torch from sklearn import metrics from sklearn.model_selection import train_test_split import dataset import engine from model import get_model if __name__ == \"__main__\" : # \u5b9a\u4e49\u6570\u636e\u8def\u5f84\u3001\u8bbe\u5907\u3001\u8fed\u4ee3\u6b21\u6570 data_path = \"/home/abhishek/workspace/siim_png/\" device = \"cuda\" # \u4f7f\u7528GPU\u52a0\u901f epochs = 10 # \u4eceCSV\u6587\u4ef6\u8bfb\u53d6\u6570\u636e df = pd . read_csv ( os . path . join ( data_path , \"train.csv\" )) images = df . ImageId . values . tolist () images = [ os . path . join ( data_path , \"train_png\" , i + \".png\" ) for i in images ] targets = df . target . values # \u83b7\u53d6\u9884\u8bad\u7ec3\u7684\u6a21\u578b model = get_model ( pretrained = True ) model . to ( device ) # \u5b9a\u4e49\u5747\u503c\u548c\u6807\u51c6\u5dee\uff0c\u7528\u4e8e\u6570\u636e\u6807\u51c6\u5316 mean = ( 0.485 , 0.456 , 0.406 ) std = ( 0.229 , 0.224 , 0.225 ) # \u6570\u636e\u589e\u5f3a\uff0c\u5c06\u56fe\u50cf\u6807\u51c6\u5316 aug = albumentations . 
Compose ( [ albumentations . Normalize ( mean , std , max_pixel_value = 255.0 , always_apply = True ) ] ) # \u5212\u5206\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 train_images , valid_images , train_targets , valid_targets = train_test_split ( images , targets , stratify = targets , random_state = 42 ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6\u548c\u9a8c\u8bc1\u6570\u636e\u96c6 train_dataset = dataset . ClassificationDataset ( image_paths = train_images , targets = train_targets , resize = ( 227 , 227 ), augmentations = aug , ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = 16 , shuffle = True , num_workers = 4 ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u96c6 valid_dataset = dataset . ClassificationDataset ( image_paths = valid_images , targets = valid_targets , resize = ( 227 , 227 ), augmentations = aug , ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = 16 , shuffle = False , num_workers = 4 ) # \u5b9a\u4e49\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 5e-4 ) # \u8bad\u7ec3\u5faa\u73af for epoch in range ( epochs ): # \u8bad\u7ec3\u6a21\u578b engine . train ( train_loader , model , optimizer , device = device ) # \u8bc4\u4f30\u6a21\u578b\u6027\u80fd predictions , valid_targets = engine . evaluate ( valid_loader , model , device = device ) # \u8ba1\u7b97ROC AUC\u5206\u6570\u5e76\u6253\u5370 roc_auc = metrics . 
roc_auc_score ( valid_targets , predictions ) print ( f \"Epoch= { epoch } , Valid ROC AUC= { roc_auc } \" ) \u8ba9\u6211\u4eec\u5728\u6ca1\u6709\u9884\u8bad\u7ec3\u6743\u91cd\u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u8bad\u7ec3\uff1a Epoch = 0 , Valid ROC AUC = 0.5737161981475328 Epoch = 1 , Valid ROC AUC = 0.5362868001588292 Epoch = 2 , Valid ROC AUC = 0.6163448214387008 Epoch = 3 , Valid ROC AUC = 0.6119219143780944 Epoch = 4 , Valid ROC AUC = 0.6229718888519726 Epoch = 5 , Valid ROC AUC = 0.5983014999635341 Epoch = 6 , Valid ROC AUC = 0.5523236874306134 Epoch = 7 , Valid ROC AUC = 0.4717721611306046 Epoch = 8 , Valid ROC AUC = 0.6473408263980617 Epoch = 9 , Valid ROC AUC = 0.6639862888260415 AUC \u7ea6\u4e3a 0.66\uff0c\u751a\u81f3\u4f4e\u4e8e\u6211\u4eec\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002\u4f7f\u7528\u9884\u8bad\u7ec3\u6743\u91cd\u4f1a\u53d1\u751f\u4ec0\u4e48\u60c5\u51b5\uff1f Epoch = 0 , Valid ROC AUC = 0.5730387429803165 Epoch = 1 , Valid ROC AUC = 0.5319813942934937 Epoch = 2 , Valid ROC AUC = 0.627111577514323 Epoch = 3 , Valid ROC AUC = 0.6819736959393209 Epoch = 4 , Valid ROC AUC = 0.5747117168950512 Epoch = 5 , Valid ROC AUC = 0.5994619255609669 Epoch = 6 , Valid ROC AUC = 0.5080889443530546 Epoch = 7 , Valid ROC AUC = 0.6323792776512727 Epoch = 8 , Valid ROC AUC = 0.6685753182661686 Epoch = 9 , Valid ROC AUC = 0.6861802387300147 \u73b0\u5728\u7684 AUC \u597d\u4e86\u5f88\u591a\u3002\u4e0d\u8fc7\uff0c\u5b83\u4ecd\u7136\u8f83\u4f4e\u3002\u9884\u8bad\u7ec3\u6a21\u578b\u7684\u597d\u5904\u662f\u53ef\u4ee5\u8f7b\u677e\u5c1d\u8bd5\u591a\u79cd\u4e0d\u540c\u7684\u6a21\u578b\u3002\u8ba9\u6211\u4eec\u8bd5\u8bd5\u4f7f\u7528\u9884\u8bad\u7ec3\u6743\u91cd\u7684 resnet18 \u3002 import torch.nn as nn import pretrainedmodels # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u4ee5\u83b7\u53d6\u6a21\u578b def get_model ( pretrained ): if pretrained : # \u4f7f\u7528\u9884\u8bad\u7ec3\u7684 ResNet-18 \u6a21\u578b\uff0c\u52a0\u8f7d\u5728 ImageNet 
\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u7684\u6743\u91cd model = pretrainedmodels . __dict__ [ \"resnet18\" ]( pretrained = 'imagenet' ) else : # \u4f7f\u7528\u672a\u7ecf\u9884\u8bad\u7ec3\u7684 ResNet-18 \u6a21\u578b model = pretrainedmodels . __dict__ [ \"resnet18\" ]( pretrained = None ) # \u4fee\u6539\u6a21\u578b\u7684\u6700\u540e\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u4ee5\u9002\u5e94\u7279\u5b9a\u4efb\u52a1 model . last_linear = nn . Sequential ( nn . BatchNorm1d ( 512 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.25 ), # \u968f\u673a\u5931\u6d3b\u5c42\uff0c\u9632\u6b62\u8fc7\u62df\u5408 nn . Linear ( in_features = 512 , out_features = 2048 ), # \u8fde\u63a5\u5c42 nn . ReLU (), # ReLU \u6fc0\u6d3b\u51fd\u6570 nn . BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.5 ), # \u968f\u673a\u5931\u6d3b\u5c42 nn . Linear ( in_features = 2048 , out_features = 1 ) # \u6700\u7ec8\u7684\u4e8c\u5143\u5206\u7c7b\u5c42 ) return model \u5728\u5c1d\u8bd5\u8be5\u6a21\u578b\u65f6\uff0c\u6211\u8fd8\u5c06\u56fe\u50cf\u5927\u5c0f\u6539\u4e3a 512x512\uff0c\u5e76\u6dfb\u52a0\u4e86\u4e00\u4e2a\u5b66\u4e60\u7387\u8c03\u5ea6\u5668\uff0c\u6bcf 3 \u4e2aepochs\u540e\u5c06\u5b66\u4e60\u7387\u4e58\u4ee5 0.5\u3002 Epoch = 0 , Valid ROC AUC = 0.5988225569880796 Epoch = 1 , Valid ROC AUC = 0.730349343208836 Epoch = 2 , Valid ROC AUC = 0.5870943169939142 Epoch = 3 , Valid ROC AUC = 0.5775864444138311 Epoch = 4 , Valid ROC AUC = 0.7330502499939224 Epoch = 5 , Valid ROC AUC = 0.7500336296524395 Epoch = 6 , Valid ROC AUC = 0.7563722113724951 Epoch = 7 , Valid ROC AUC = 0.7987463837994215 Epoch = 8 , Valid ROC AUC = 0.798505708937384 Epoch = 9 , Valid ROC AUC = 0.8025477500546988 \u8fd9\u4e2a\u6a21\u578b\u4f3c\u4e4e\u8868\u73b0\u6700\u597d\u3002\u4e0d\u8fc7\uff0c\u60a8\u53ef\u4ee5\u8c03\u6574 AlexNet \u4e2d\u7684\u4e0d\u540c\u53c2\u6570\u548c\u56fe\u50cf\u5927\u5c0f\uff0c\u4ee5\u83b7\u5f97\u66f4\u597d\u7684\u5206\u6570\u3002 
\u4f7f\u7528\u589e\u5f3a\u6280\u672f\u5c06\u8fdb\u4e00\u6b65\u63d0\u9ad8\u5f97\u5206\u3002\u4f18\u5316\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u5f88\u96be\uff0c\u4f46\u5e76\u975e\u4e0d\u53ef\u80fd\u3002\u9009\u62e9 Adam \u4f18\u5316\u5668\u3001\u4f7f\u7528\u4f4e\u5b66\u4e60\u7387\u3001\u5728\u9a8c\u8bc1\u635f\u5931\u8fbe\u5230\u9ad8\u70b9\u65f6\u964d\u4f4e\u5b66\u4e60\u7387\u3001\u5c1d\u8bd5\u4e00\u4e9b\u589e\u5f3a\u6280\u672f\u3001\u5c1d\u8bd5\u5bf9\u56fe\u50cf\u8fdb\u884c\u9884\u5904\u7406\uff08\u5982\u5728\u9700\u8981\u65f6\u8fdb\u884c\u88c1\u526a\uff0c\u8fd9\u4e5f\u53ef\u89c6\u4e3a\u9884\u5904\u7406\uff09\u3001\u6539\u53d8\u6279\u6b21\u5927\u5c0f\u7b49\u3002\u4f60\u53ef\u4ee5\u505a\u5f88\u591a\u4e8b\u60c5\u6765\u4f18\u5316\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u3002 \u4e0e AlexNet \u76f8\u6bd4\uff0c ResNet \u7684\u7ed3\u6784\u8981\u590d\u6742\u5f97\u591a\u3002ResNet \u662f\u6b8b\u5dee\u795e\u7ecf\u7f51\u7edc\uff08Residual Neural Network\uff09\u7684\u7f29\u5199\uff0c\u7531 K. He\u3001X. Zhang\u3001S. Ren \u548c J. 
Sun \u5728 2015 \u5e74\u53d1\u8868\u7684\u8bba\u6587\u4e2d\u63d0\u51fa\u3002ResNet \u7531 \u6b8b\u5dee\u5757 \uff08residual blocks\uff09\u7ec4\u6210\uff0c\u901a\u8fc7\u8df3\u8fc7\u67d0\u4e9b\u5c42\uff0c\u4f7f\u77e5\u8bc6\u80fd\u591f\u4e0d\u65ad\u5728\u5404\u5c42\u4e2d\u8fdb\u884c\u4f20\u9012\u3002\u8fd9\u4e9b\u5c42\u4e4b\u95f4\u7684 \u8fde\u63a5\u88ab\u79f0\u4e3a \u8df3\u8dc3\u8fde\u63a5 \uff08skip-connections\uff09\uff0c\u56e0\u4e3a\u6211\u4eec\u8df3\u8fc7\u4e86\u4e00\u5c42\u6216\u591a\u5c42\u3002\u8df3\u8dc3\u8fde\u63a5\u901a\u8fc7\u5c06\u68af\u5ea6\u4f20\u64ad\u5230\u66f4\u591a\u5c42\u6765\u5e2e\u52a9\u89e3\u51b3\u68af\u5ea6\u6d88\u5931\u95ee\u9898\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u8bad\u7ec3\u975e\u5e38\u5927\u7684\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff0c\u800c\u4e0d\u4f1a\u635f\u5931\u6027\u80fd\u3002\u901a\u5e38\u60c5\u51b5\u4e0b\uff0c\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7684\u662f\u5927\u578b\u795e\u7ecf\u7f51\u7edc\uff0c\u90a3\u4e48\u5f53\u8bad\u7ec3\u5230\u67d0\u4e00\u8282\u70b9\u4e0a\u65f6\u8bad\u7ec3\u635f\u5931\u53cd\u800c\u4f1a\u589e\u52a0\uff0c\u4f46\u8fd9\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u8df3\u8dc3\u8fde\u63a5\u6765\u907f\u514d\u3002\u901a\u8fc7\u56fe 7 \u53ef\u4ee5\u66f4\u597d\u5730\u7406\u89e3\u8fd9\u4e00\u70b9\u3002 \u56fe 7\uff1a\u7b80\u5355\u8fde\u63a5\u4e0e\u6b8b\u5dee\u8fde\u63a5\u7684\u6bd4\u8f83\u3002\u53c2\u89c1\u8df3\u8dc3\u8fde\u63a5\u3002\u8bf7\u6ce8\u610f\uff0c\u672c\u56fe\u7701\u7565\u4e86\u6700\u540e\u4e00\u5c42\u3002 \u6b8b\u5dee\u5757\u975e\u5e38\u5bb9\u6613\u7406\u89e3\u3002\u4f60\u4ece\u67d0\u4e00\u5c42\u83b7\u53d6\u8f93\u51fa\uff0c\u8df3\u8fc7\u4e00\u4e9b\u5c42\uff0c\u7136\u540e\u5c06\u8f93\u51fa\u6dfb\u52a0\u5230\u7f51\u7edc\u4e2d\u66f4\u8fdc\u7684\u4e00\u5c42\u3002\u865a\u7ebf\u8868\u793a\u8f93\u5165\u5f62\u72b6\u9700\u8981\u8c03\u6574\uff0c\u56e0\u4e3a\u4f7f\u7528\u4e86\u6700\u5927\u6c60\u5316\uff0c\u800c\u6700\u5927\u6c60\u5316\u7684\u4f7f\u7528\u4f1a\u6539\u53d8\u8f93\u51fa\u7684\u5927\u5c0f\u3002 
ResNet \u6709\u591a\u79cd\u4e0d\u540c\u7684\u7248\u672c\uff1a \u6709 18 \u5c42\u300134 \u5c42\u300150 \u5c42\u3001101 \u5c42\u548c 152 \u5c42\uff0c\u6240\u6709\u8fd9\u4e9b\u5c42\u90fd\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8fdb\u884c\u4e86\u6743\u91cd\u9884\u8bad\u7ec3\u3002\u5982\u4eca\uff0c\u9884\u8bad\u7ec3\u6a21\u578b\uff08\u51e0\u4e4e\uff09\u9002\u7528\u4e8e\u6240\u6709\u60c5\u51b5\uff0c\u4f46\u8bf7\u786e\u4fdd\u60a8\u4ece\u8f83\u5c0f\u7684\u6a21\u578b\u5f00\u59cb\uff0c\u4f8b\u5982\uff0c\u4ece resnet-18 \u5f00\u59cb\uff0c\u800c\u4e0d\u662f resnet-50\u3002\u5176\u4ed6\u4e00\u4e9b ImageNet \u9884\u8bad\u7ec3\u6a21\u578b\u5305\u62ec\uff1a Inception DenseNet(different variations) NASNet PNASNet VGG Xception ResNeXt EfficientNet, etc. \u5927\u90e8\u5206\u9884\u8bad\u7ec3\u7684\u6700\u5148\u8fdb\u6a21\u578b\u53ef\u4ee5\u5728 GitHub \u4e0a\u7684 pytorch- pretrainedmodels \u8d44\u6e90\u5e93\u4e2d\u627e\u5230\uff1ahttps://github.com/Cadene/pretrained-models.pytorch\u3002\u8be6\u7ec6\u8ba8\u8bba\u8fd9\u4e9b\u6a21\u578b\u4e0d\u5728\u672c\u7ae0\uff08\u548c\u672c\u4e66\uff09\u8303\u56f4\u4e4b\u5185\u3002\u65e2\u7136\u6211\u4eec\u53ea\u5173\u6ce8\u5e94\u7528\uff0c\u90a3\u5c31\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u6837\u7684\u9884\u8bad\u7ec3\u6a21\u578b\u5982\u4f55\u7528\u4e8e\u5206\u5272\u4efb\u52a1\u3002 \u56fe 8\uff1aU-Net\u67b6\u6784 \u5206\u5272\uff08Segmentation\uff09\u662f\u8ba1\u7b97\u673a\u89c6\u89c9\u4e2d\u76f8\u5f53\u6d41\u884c\u7684\u4e00\u9879\u4efb\u52a1\u3002\u5728\u5206\u5272\u4efb\u52a1\u4e2d\uff0c\u6211\u4eec\u8bd5\u56fe\u4ece\u80cc\u666f\u4e2d\u79fb\u9664/\u63d0\u53d6\u524d\u666f\u3002 
\u524d\u666f\u548c\u80cc\u666f\u53ef\u4ee5\u6709\u4e0d\u540c\u7684\u5b9a\u4e49\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u8bf4\uff0c\u8fd9\u662f\u4e00\u9879\u50cf\u7d20\u5206\u7c7b\u4efb\u52a1\uff0c\u4f60\u7684\u5de5\u4f5c\u662f\u7ed9\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u7684\u6bcf\u4e2a\u50cf\u7d20\u5206\u914d\u4e00\u4e2a\u7c7b\u522b\u3002\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u6b63\u5728\u5904\u7406\u7684\u6c14\u80f8\u6570\u636e\u96c6\u5c31\u662f\u4e00\u9879\u5206\u5272\u4efb\u52a1\u3002\u5728\u8fd9\u9879\u4efb\u52a1\u4e2d\uff0c\u6211\u4eec\u9700\u8981\u5bf9\u7ed9\u5b9a\u7684\u80f8\u90e8\u653e\u5c04\u56fe\u50cf\u8fdb\u884c\u6c14\u80f8\u5206\u5272\u3002\u7528\u4e8e\u5206\u5272\u4efb\u52a1\u7684\u6700\u5e38\u7528\u6a21\u578b\u662f U-Net\u3002\u5176\u7ed3\u6784\u5982\u56fe 8 \u6240\u793a\u3002 U-Net \u5305\u62ec\u4e24\u4e2a\u90e8\u5206\uff1a\u7f16\u7801\u5668\u548c\u89e3\u7801\u5668\u3002\u7f16\u7801\u5668\u4e0e\u60a8\u76ee\u524d\u6240\u89c1\u8fc7\u7684\u4efb\u4f55\u5377\u79ef\u7f51\u7edc\u90fd\u662f\u4e00\u6837\u7684\u3002\u89e3\u7801\u5668\u5219\u6709\u4e9b\u4e0d\u540c\u3002\u89e3\u7801\u5668\u7531\u4e0a\u5377\u79ef\u5c42\u7ec4\u6210\u3002\u5728\u4e0a\u5377\u79ef\uff08up-convolutions\uff09\uff08 \u8f6c\u7f6e\u5377\u79ef transposed convolutions\uff09\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u6ee4\u6ce2\u5668\uff0c\u5f53\u5e94\u7528\u5230\u4e00\u4e2a\u5c0f\u56fe\u50cf\u65f6\uff0c\u4f1a\u4ea7\u751f\u4e00\u4e2a\u5927\u56fe\u50cf\u3002\u5728 PyTorch \u4e2d\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528 ConvTranspose2d 
\u6765\u5b8c\u6210\u8fd9\u4e00\u64cd\u4f5c\u3002\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u4e0a\u5377\u79ef\u4e0e\u4e0a\u91c7\u6837\u5e76\u4e0d\u76f8\u540c\u3002\u4e0a\u91c7\u6837\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u8fc7\u7a0b\uff0c\u6211\u4eec\u5728\u56fe\u50cf\u4e0a\u5e94\u7528\u4e00\u4e2a\u51fd\u6570\u6765\u8c03\u6574\u5b83\u7684\u5927\u5c0f\u3002\u5728\u4e0a\u5377\u79ef\u4e2d\uff0c\u6211\u4eec\u8981\u5b66\u4e60\u6ee4\u6ce2\u5668\u3002\u6211\u4eec\u5c06\u7f16\u7801\u5668\u7684\u67d0\u4e9b\u90e8\u5206\u4f5c\u4e3a\u67d0\u4e9b\u89e3\u7801\u5668\u7684\u8f93\u5165\u3002\u8fd9\u5bf9 \u4e0a\u5377\u79ef\u5c42\u975e\u5e38\u91cd\u8981\u3002 \u8ba9\u6211\u4eec\u770b\u770b U-Net \u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import torch import torch.nn as nn from torch.nn import functional as F # \u5b9a\u4e49\u4e00\u4e2a\u53cc\u5377\u79ef\u5c42 def double_conv ( in_channels , out_channels ): conv = nn . Sequential ( nn . Conv2d ( in_channels , out_channels , kernel_size = 3 ), nn . ReLU ( inplace = True ), nn . Conv2d ( out_channels , out_channels , kernel_size = 3 ), nn . ReLU ( inplace = True ) ) return conv # \u5b9a\u4e49\u51fd\u6570\u7528\u4e8e\u88c1\u526a\u8f93\u5165\u5f20\u91cf def crop_tensor ( tensor , target_tensor ): target_size = target_tensor . size ()[ 2 ] tensor_size = tensor . size ()[ 2 ] delta = tensor_size - target_size delta = delta // 2 return tensor [:, :, delta : tensor_size - delta , delta : tensor_size - delta ] # \u5b9a\u4e49 U-Net \u6a21\u578b class UNet ( nn . Module ): def __init__ ( self ): super ( UNet , self ) . __init__ () # \u5b9a\u4e49\u6c60\u5316\u5c42\uff0c\u7f16\u7801\u5668\u548c\u89e3\u7801\u5668\u7684\u53cc\u5377\u79ef\u5c42 self . max_pool_2x2 = nn . MaxPool2d ( kernel_size = 2 , stride = 2 ) self . down_conv_1 = double_conv ( 1 , 64 ) self . down_conv_2 = double_conv ( 64 , 128 ) self . down_conv_3 = double_conv ( 128 , 256 ) self . down_conv_4 = double_conv ( 256 , 512 ) self . 
down_conv_5 = double_conv ( 512 , 1024 ) # \u5b9a\u4e49\u4e0a\u91c7\u6837\u5c42\u548c\u89e3\u7801\u5668\u7684\u53cc\u5377\u79ef\u5c42 self . up_trans_1 = nn . ConvTranspose2d ( in_channels = 1024 , out_channels = 512 , kernel_size = 2 , stride = 2 ) self . up_conv_1 = double_conv ( 1024 , 512 ) self . up_trans_2 = nn . ConvTranspose2d ( in_channels = 512 , out_channels = 256 , kernel_size = 2 , stride = 2 ) self . up_conv_2 = double_conv ( 512 , 256 ) self . up_trans_3 = nn . ConvTranspose2d ( in_channels = 256 , out_channels = 128 , kernel_size = 2 , stride = 2 ) self . up_conv_3 = double_conv ( 256 , 128 ) self . up_trans_4 = nn . ConvTranspose2d ( in_channels = 128 , out_channels = 64 , kernel_size = 2 , stride = 2 ) self . up_conv_4 = double_conv ( 128 , 64 ) # \u5b9a\u4e49\u8f93\u51fa\u5c42 self . out = nn . Conv2d ( in_channels = 64 , out_channels = 2 , kernel_size = 1 ) def forward ( self , image ): # \u7f16\u7801\u5668\u90e8\u5206 x1 = self . down_conv_1 ( image ) x2 = self . max_pool_2x2 ( x1 ) x3 = self . down_conv_2 ( x2 ) x4 = self . max_pool_2x2 ( x3 ) x5 = self . down_conv_3 ( x4 ) x6 = self . max_pool_2x2 ( x5 ) x7 = self . down_conv_4 ( x6 ) x8 = self . max_pool_2x2 ( x7 ) x9 = self . down_conv_5 ( x8 ) # \u89e3\u7801\u5668\u90e8\u5206 x = self . up_trans_1 ( x9 ) y = crop_tensor ( x7 , x ) x = self . up_conv_1 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_2 ( x ) y = crop_tensor ( x5 , x ) x = self . up_conv_2 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_3 ( x ) y = crop_tensor ( x3 , x ) x = self . up_conv_3 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_4 ( x ) y = crop_tensor ( x1 , x ) x = self . up_conv_4 ( torch . cat ([ x , y ], axis = 1 )) # \u8f93\u51fa\u5c42 out = self . out ( x ) return out if __name__ == \"__main__\" : image = torch . 
rand (( 1 , 1 , 572 , 572 )) model = UNet () print ( model ( image )) \u8bf7\u6ce8\u610f\uff0c\u6211\u4e0a\u9762\u5c55\u793a\u7684 U-Net \u5b9e\u73b0\u662f U-Net \u8bba\u6587\u7684\u539f\u59cb\u5b9e\u73b0\u3002\u4e92\u8054\u7f51\u4e0a\u6709\u5f88\u591a\u4e0d\u540c\u7684\u5b9e\u73b0\u65b9\u6cd5\u3002 \u6709\u4e9b\u4eba\u559c\u6b22\u4f7f\u7528\u53cc\u7ebf\u6027\u91c7\u6837\u4ee3\u66ff\u8f6c\u7f6e\u5377\u79ef\u8fdb\u884c\u4e0a\u91c7\u6837\uff0c\u4f46\u8fd9\u5e76\u4e0d\u662f\u8bba\u6587\u7684\u771f\u6b63\u5b9e\u73b0\u3002\u4e0d\u8fc7\uff0c\u5b83\u7684\u6027\u80fd\u53ef\u80fd\u4f1a\u66f4\u597d\u3002\u5728\u4e0a\u56fe\u6240\u793a\u7684\u539f\u59cb\u5b9e\u73b0\u4e2d\uff0c\u6709\u4e00\u4e2a\u5355\u901a\u9053\u56fe\u50cf\uff0c\u8f93\u51fa\u4e2d\u6709\u4e24\u4e2a\u901a\u9053\uff1a\u4e00\u4e2a\u662f\u524d\u666f\uff0c\u4e00\u4e2a\u662f\u80cc\u666f\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u4e3a\u4efb\u610f\u6570\u91cf\u7684\u7c7b\u548c\u4efb\u610f\u6570\u91cf\u7684\u8f93\u5165\u901a\u9053\u8fdb\u884c\u5b9a\u5236\u3002\u5728\u6b64\u5b9e\u73b0\u4e2d\uff0c\u8f93\u5165\u56fe\u50cf\u7684\u5927\u5c0f\u4e0e\u8f93\u51fa\u56fe\u50cf\u7684\u5927\u5c0f\u4e0d\u540c\uff0c\u56e0\u4e3a\u6211\u4eec\u4f7f\u7528\u7684\u662f\u65e0\u586b\u5145\u5377\u79ef\uff08convolutions without padding\uff09\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0cU-Net \u7684\u7f16\u7801\u5668\u90e8\u5206\u53ea\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u5377\u79ef\u7f51\u7edc\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u7528\u4efb\u4f55\u7f51\u7edc\uff08\u5982 ResNet\uff09\u6765\u66ff\u6362\u5b83\u3002 \u8fd9\u79cd\u66ff\u6362\u4e5f\u53ef\u4ee5\u901a\u8fc7\u9884\u8bad\u7ec3\u6743\u91cd\u6765\u5b8c\u6210\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u57fa\u4e8e ResNet \u7684\u7f16\u7801\u5668\uff0c\u8be5\u7f16\u7801\u5668\u5df2\u5728 ImageNet 
\u548c\u901a\u7528\u89e3\u7801\u5668\u4e0a\u8fdb\u884c\u4e86\u9884\u8bad\u7ec3\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u591a\u79cd\u4e0d\u540c\u7684\u7f51\u7edc\u67b6\u6784\u6765\u4ee3\u66ff ResNet\u3002Pavel Yakubovskiy \u6240\u8457\u7684\u300aSegmentation Models Pytorch\u300b\u5c31\u662f\u8bb8\u591a\u6b64\u7c7b\u53d8\u4f53\u7684\u5b9e\u73b0\uff0c\u5176\u4e2d\u7f16\u7801\u5668\u53ef\u4ee5\u88ab\u9884\u8bad\u7ec3\u6a21\u578b\u6240\u53d6\u4ee3\u3002\u8ba9\u6211\u4eec\u5e94\u7528\u57fa\u4e8e ResNet \u7684 U-Net \u6765\u89e3\u51b3\u6c14\u80f8\u68c0\u6d4b\u95ee\u9898\u3002 \u5927\u591a\u6570\u7c7b\u4f3c\u7684\u95ee\u9898\u90fd\u6709\u4e24\u4e2a\u8f93\u5165\uff1a\u539f\u59cb\u56fe\u50cf\u548c\u63a9\u7801\uff08mask\uff09\u3002 \u5982\u679c\u6709\u591a\u4e2a\u5bf9\u8c61\uff0c\u5c31\u4f1a\u6709\u591a\u4e2a\u63a9\u7801\u3002 \u5728\u6211\u4eec\u7684\u6c14\u80f8\u6570\u636e\u96c6\u4e2d\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f RLE\u3002RLE \u4ee3\u8868\u8fd0\u884c\u957f\u5ea6\u7f16\u7801\uff0c\u662f\u4e00\u79cd\u8868\u793a\u4e8c\u8fdb\u5236\u63a9\u7801\u4ee5\u8282\u7701\u7a7a\u95f4\u7684\u65b9\u6cd5\u3002\u6df1\u5165\u7814\u7a76 RLE \u8d85\u51fa\u4e86\u672c\u7ae0\u7684\u8303\u56f4\u3002\u56e0\u6b64\uff0c\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u5f20\u8f93\u5165\u56fe\u50cf\u548c\u76f8\u5e94\u7684\u63a9\u7801\u3002\u8ba9\u6211\u4eec\u5148\u8bbe\u8ba1\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\uff0c\u7528\u4e8e\u8f93\u51fa\u56fe\u50cf\u548c\u63a9\u7801\u56fe\u50cf\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u521b\u5efa\u7684\u811a\u672c\u51e0\u4e4e\u53ef\u4ee5\u5e94\u7528\u4e8e\u4efb\u4f55\u5206\u5272\u95ee\u9898\u3002\u8bad\u7ec3\u6570\u636e\u96c6\u662f\u4e00\u4e2a CSV \u6587\u4ef6\uff0c\u53ea\u5305\u542b\u56fe\u50cf ID\uff08\u4e5f\u662f\u6587\u4ef6\u540d\uff09\u3002 import os import glob import torch import numpy as np import pandas as pd from PIL import Image , ImageFile from tqdm import tqdm from collections import defaultdict from torchvision import transforms from albumentations 
import ( Compose , OneOf , RandomBrightnessContrast , RandomGamma , ShiftScaleRotate , ) # \u8bbe\u7f6ePIL\u56fe\u50cf\u52a0\u8f7d\u622a\u65ad\u7684\u5904\u7406 ImageFile . LOAD_TRUNCATED_IMAGES = True # \u521b\u5efaSIIM\u6570\u636e\u96c6\u7c7b class SIIMDataset ( torch . utils . data . Dataset ): def __init__ ( self , image_ids , transform = True , preprocessing_fn = None ): self . data = defaultdict ( dict ) self . transform = transform self . preprocessing_fn = preprocessing_fn # \u5b9a\u4e49\u6570\u636e\u589e\u5f3a self . aug = Compose ( [ ShiftScaleRotate ( shift_limit = 0.0625 , scale_limit = 0.1 , rotate_limit = 10 , p = 0.8 ), OneOf ( [ RandomGamma ( gamma_limit = ( 90 , 110 ) ), RandomBrightnessContrast ( brightness_limit = 0.1 , contrast_limit = 0.1 ), ], p = 0.5 , ), ] ) # \u6784\u5efa\u6570\u636e\u5b57\u5178\uff0c\u5176\u4e2d\u5305\u542b\u56fe\u50cf\u548c\u63a9\u7801\u7684\u8def\u5f84\u4fe1\u606f for imgid in image_ids : files = glob . glob ( os . path . join ( TRAIN_PATH , imgid , \"*.png\" )) self . data [ counter ] = { \"img_path\" : os . path . join ( TRAIN_PATH , imgid + \".png\" ), \"mask_path\" : os . path . join ( TRAIN_PATH , imgid + \"_mask.png\" ), } def __len__ ( self ): return len ( self . data ) def __getitem__ ( self , item ): img_path = self . data [ item ][ \"img_path\" ] mask_path = self . data [ item ][ \"mask_path\" ] # \u6253\u5f00\u56fe\u50cf\u5e76\u5c06\u5176\u8f6c\u6362\u4e3aRGB\u6a21\u5f0f img = Image . open ( img_path ) img = img . convert ( \"RGB\" ) img = np . array ( img ) # \u6253\u5f00\u63a9\u7801\u56fe\u50cf\uff0c\u5e76\u5c06\u5176\u8f6c\u6362\u4e3a\u6d6e\u70b9\u6570 mask = Image . open ( mask_path ) mask = ( mask >= 1 ) . astype ( \"float32\" ) # \u5982\u679c\u9700\u8981\u8fdb\u884c\u6570\u636e\u589e\u5f3a if self . transform is True : augmented = self . 
aug ( image = img , mask = mask ) img = augmented [ \"image\" ] mask = augmented [ \"mask\" ] # \u5e94\u7528\u9884\u5904\u7406\u51fd\u6570\uff08\u5982\u679c\u6709\uff09 img = self . preprocessing_fn ( img ) # \u8fd4\u56de\u56fe\u50cf\u548c\u63a9\u7801 return { \"image\" : transforms . ToTensor ()( img ), \"mask\" : transforms . ToTensor ()( mask ) . float (), } \u6709\u4e86\u6570\u636e\u96c6\u7c7b\u4e4b\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u8bad\u7ec3\u51fd\u6570\u3002 import os import sys import torch import numpy as np import pandas as pd import segmentation_models_pytorch as smp import torch.nn as nn import torch.optim as optim from apex import amp from collections import OrderedDict from sklearn import model_selection from tqdm import tqdm from torch.optim import lr_scheduler from dataset import SIIMDataset # \u5b9a\u4e49\u8bad\u7ec3\u6570\u636e\u96c6CSV\u6587\u4ef6\u8def\u5f84 TRAINING_CSV = \"../input/train_pneumothorax.csv\" # \u5b9a\u4e49\u8bad\u7ec3\u548c\u6d4b\u8bd5\u7684\u6279\u91cf\u5927\u5c0f TRAINING_BATCH_SIZE = 16 TEST_BATCH_SIZE = 4 # \u5b9a\u4e49\u8bad\u7ec3\u7684\u65f6\u671f\u6570 EPOCHS = 10 # \u6307\u5b9a\u4f7f\u7528\u7684\u7f16\u7801\u5668\u548c\u6743\u91cd ENCODER = \"resnet18\" ENCODER_WEIGHTS = \"imagenet\" # \u6307\u5b9a\u8bbe\u5907\uff08GPU\uff09 DEVICE = \"cuda\" # \u5b9a\u4e49\u8bad\u7ec3\u51fd\u6570 def train ( dataset , data_loader , model , criterion , optimizer ): model . train () num_batches = int ( len ( dataset ) / data_loader . batch_size ) tk0 = tqdm ( data_loader , total = num_batches ) for d in tk0 : inputs = d [ \"image\" ] targets = d [ \"mask\" ] inputs = inputs . to ( DEVICE , dtype = torch . float ) targets = targets . to ( DEVICE , dtype = torch . float ) optimizer . zero_grad () outputs = model ( inputs ) loss = criterion ( outputs , targets ) with amp . scale_loss ( loss , optimizer ) as scaled_loss : scaled_loss . backward () optimizer . step () tk0 . 
close () # \u5b9a\u4e49\u8bc4\u4f30\u51fd\u6570 def evaluate ( dataset , data_loader , model ): model . eval () final_loss = 0 num_batches = int ( len ( dataset ) / data_loader . batch_size ) tk0 = tqdm ( data_loader , total = num_batches ) with torch . no_grad (): for d in tk0 : inputs = d [ \"image\" ] targets = d [ \"mask\" ] inputs = inputs to ( DEVICE , dtype = torch . float ) targets = targets . to ( DEVICE , dtype = torch . float ) output = model ( inputs ) loss = criterion ( output , targets ) final_loss += loss tk0 . close () return final_loss / num_batches if __name__ == \"__main__\" : df = pd . read_csv ( TRAINING_CSV ) df_train , df_valid = model_selection . train_test_split ( df , random_state = 42 , test_size = 0.1 ) training_images = df_train . image_id . values validation_images = df_valid . image_id . values # \u521b\u5efa U-Net \u6a21\u578b model = smp . Unet ( encoder_name = ENCODER , encoder_weights = ENCODER_WEIGHTS , classes = 1 , activation = None , ) # \u83b7\u53d6\u6570\u636e\u9884\u5904\u7406\u51fd\u6570 prep_fn = smp . encoders . get_preprocessing_fn ( ENCODER , ENCODER_WEIGHTS ) # \u5c06\u6a21\u578b\u653e\u5728\u8bbe\u5907\u4e0a model . to ( DEVICE ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6 train_dataset = SIIMDataset ( training_images , transform = True , preprocessing_fn = prep_fn , ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = TRAINING_BATCH_SIZE , shuffle = True , num_workers = 12 ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u96c6 valid_dataset = SIIMDataset ( validation_images , transform = False , preprocessing_fn = prep_fn , ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = TEST_BATCH_SIZE , shuffle = True , num_workers = 4 ) # \u5b9a\u4e49\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . 
parameters (), lr = 1e-3 ) # \u5b9a\u4e49\u5b66\u4e60\u7387\u8c03\u5ea6\u5668 scheduler = lr_scheduler . ReduceLROnPlateau ( optimizer , mode = \"min\" , patience = 3 , verbose = True ) # \u521d\u59cb\u5316 Apex \u6df7\u5408\u7cbe\u5ea6\u8bad\u7ec3 model , optimizer = amp . initialize ( model , optimizer , opt_level = \"O1\" , verbosity = 0 ) # \u5982\u679c\u6709\u591a\u4e2aGPU\uff0c\u5219\u4f7f\u7528 DataParallel \u8fdb\u884c\u5e76\u884c\u8bad\u7ec3 if torch . cuda . device_count () > 1 : print ( f \"Let's use { torch . cuda . device_count () } GPUs!\" ) model = nn . DataParallel ( model ) # \u8f93\u51fa\u8bad\u7ec3\u76f8\u5173\u7684\u4fe1\u606f print ( f \"Training batch size: { TRAINING_BATCH_SIZE } \" ) print ( f \"Test batch size: { TEST_BATCH_SIZE } \" ) print ( f \"Epochs: { EPOCHS } \" ) print ( f \"Image size: { IMAGE_SIZE } \" ) print ( f \"Number of training images: { len ( train_dataset ) } \" ) print ( f \"Number of validation images: { len ( valid_dataset ) } \" ) print ( f \"Encoder: { ENCODER } \" ) # \u5faa\u73af\u8bad\u7ec3\u591a\u4e2a\u65f6\u671f for epoch in range ( EPOCHS ): print ( f \"Training Epoch: { epoch } \" ) train ( train_dataset , train_loader , model , criterion , optimizer ) print ( f \"Validation Epoch: { epoch } \" ) val_log = evaluate ( valid_dataset , valid_loader , model ) scheduler . 
step ( val_log [ \"loss\" ]) print ( \" \\n \" ) \u5728\u5206\u5272\u95ee\u9898\u4e2d\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528\u5404\u79cd\u635f\u5931\u51fd\u6570\uff0c\u4f8b\u5982\u4e8c\u5143\u4ea4\u53c9\u71b5\u3001focal\u635f\u5931\u3001dice\u635f\u5931\u7b49\u3002\u6211\u628a\u8fd9\u4e2a\u95ee\u9898\u7559\u7ed9 \u8bfb\u8005\u6839\u636e\u8bc4\u4f30\u6307\u6807\u6765\u51b3\u5b9a\u5408\u9002\u7684\u635f\u5931\u3002\u5f53\u8bad\u7ec3\u8fd9\u6837\u4e00\u4e2a\u6a21\u578b\u65f6\uff0c\u60a8\u5c06\u5efa\u7acb\u9884\u6d4b\u6c14\u80f8\u4f4d\u7f6e\u7684\u6a21\u578b\uff0c\u5982\u56fe 9 \u6240\u793a\u3002\u5728\u4e0a\u8ff0\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u82f1\u4f1f\u8fbe apex \u8fdb\u884c\u4e86\u6df7\u5408\u7cbe\u5ea6\u8bad\u7ec3\u3002\u8bf7\u6ce8\u610f\uff0c\u4ece PyTorch 1.6.0+ \u7248\u672c\u5f00\u59cb\uff0cPyTorch \u672c\u8eab\u5c31\u63d0\u4f9b\u4e86\u8fd9\u4e00\u529f\u80fd\u3002 \u56fe 9\uff1a\u4ece\u8bad\u7ec3\u6709\u7d20\u7684\u6a21\u578b\u4e2d\u68c0\u6d4b\u5230\u6c14\u80f8\u7684\u793a\u4f8b\uff08\u53ef\u80fd\u4e0d\u662f\u6b63\u786e\u9884\u6d4b\uff09\u3002 \u6211\u5728\u4e00\u4e2a\u540d\u4e3a \"Well That's Fantastic Machine Learning (WTFML) \"\u7684 python \u8f6f\u4ef6\u5305\u4e2d\u6536\u5f55\u4e86\u4e00\u4e9b\u5e38\u7528\u51fd\u6570\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b83\u5982\u4f55\u5e2e\u52a9\u6211\u4eec\u4e3a FGVC 202013 \u690d\u7269\u75c5\u7406\u5b66\u6311\u6218\u8d5b\u4e2d\u7684\u690d\u7269\u56fe\u50cf\u5efa\u7acb\u591a\u7c7b\u5206\u7c7b\u6a21\u578b\u3002 import os import pandas as pd import numpy as np import albumentations import argparse import torch import torchvision import torch.nn as nn import torch.nn.functional as F from sklearn import metrics from sklearn.model_selection import train_test_split from wtfml.engine import Engine from wtfml.data_loaders.image import ClassificationDataLoader # \u81ea\u5b9a\u4e49\u635f\u5931\u51fd\u6570\uff0c\u5b9e\u73b0\u5bc6\u96c6\u4ea4\u53c9\u71b5 class DenseCrossEntropy ( nn . 
Module ): def __init__ ( self ): super ( DenseCrossEntropy , self ) . __init__ () def forward ( self , logits , labels ): logits = logits . float () labels = labels . float () logprobs = F . log_softmax ( logits , dim =- 1 ) loss = - labels * logprobs loss = loss . sum ( - 1 ) return loss . mean () # \u81ea\u5b9a\u4e49\u795e\u7ecf\u7f51\u7edc\u6a21\u578b class Model ( nn . Module ): def __init__ ( self ): super () . __init () self . base_model = torchvision . models . resnet18 ( pretrained = True ) in_features = self . base_model . fc . in_features self . out = nn . Linear ( in_features , 4 ) def forward ( self , image , targets = None ): batch_size , C , H , W = image . shape x = self . base_model . conv1 ( image ) x = self . base_model . bn1 ( x ) x = self . base_model . relu ( x ) x = self . base_model . maxpool ( x ) x = self . base_model . layer1 ( x ) x = self . base_model . layer2 ( x ) x = self . base_model . layer3 ( x ) x = self . base_model . layer4 ( x ) x = F . adaptive_avg_pool2d ( x , 1 ) . reshape ( batch_size , - 1 ) x = self . out ( x ) loss = None if targets is not None : loss = DenseCrossEntropy ()( x , targets . type_as ( x )) return x , loss if __name__ == \"__main__\" : # \u547d\u4ee4\u884c\u53c2\u6570\u89e3\u6790\u5668 parser = argparse . ArgumentParser () parser . add_argument ( \"--data_path\" , type = str , ) parser . add_argument ( \"--device\" , type = str ,) parser . add_argument ( \"--epochs\" , type = int ,) args = parser . parse_args () # \u4eceCSV\u6587\u4ef6\u52a0\u8f7d\u6570\u636e df = pd . read_csv ( os . path . join ( args . data_path , \"train.csv\" )) images = df . image_id . values . tolist () images = [ os . path . join ( args . data_path , \"images\" , i + \".jpg\" ) for i in images ] targets = df [[ \"healthy\" , \"multiple_diseases\" , \"rust\" , \"scab\" ]] . values # \u521b\u5efa\u795e\u7ecf\u7f51\u7edc\u6a21\u578b model = Model () model . to ( args . 
device ) # \u5b9a\u4e49\u5747\u503c\u548c\u6807\u51c6\u5dee\u4ee5\u53ca\u6570\u636e\u589e\u5f3a mean = ( 0.485 , 0.456 , 0.406 ) std = ( 0.229 , 0.224 , 0.225 ) aug = albumentations . Compose ( [ albumentations . Normalize ( mean , std , max_pixel_value = 255.0 , always_apply = True ) ] ) # \u5206\u5272\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 ( train_images , valid_images , train_targets , valid_targets ) = train_test_split ( images , targets ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = ClassificationDataLoader ( image_paths = train_images , targets = train_targets , resize = ( 128 , 128 ), augmentations = aug , ) . fetch ( batch_size = 16 , num_workers = 4 , drop_last = False , shuffle = True , tpu = False ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = ClassificationDataLoader ( image_paths = valid_images , targets = valid_targets , resize = ( 128 , 128 ), augmentations = aug , ) . fetch ( batch_size = 16 , num_workers = 4 , drop_last = False , shuffle = False , tpu = False ) # \u521b\u5efa\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 5e-4 ) # \u521b\u5efa\u5b66\u4e60\u7387\u8c03\u5ea6\u5668 scheduler = torch . optim . lr_scheduler . StepLR ( optimizer , step_size = 15 , gamma = 0.6 ) # \u5faa\u73af\u8bad\u7ec3\u591a\u4e2a\u65f6\u671f for epoch in range ( args . epochs ): # \u8bad\u7ec3\u6a21\u578b train_loss = Engine . train ( train_loader , model , optimizer , device = args . device ) # \u8bc4\u4f30\u6a21\u578b valid_loss = Engine . evaluate ( valid_loader , model , device = args . 
device ) # \u6253\u5370\u635f\u5931\u4fe1\u606f print ( f \" { epoch } , Train Loss= { train_loss } Valid Loss= { valid_loss } \" ) \u6709\u4e86\u6570\u636e\u540e\uff0c\u5c31\u53ef\u4ee5\u8fd0\u884c\u811a\u672c\u4e86\uff1a python plant.py --data_path ../../plant_pathology --device cuda -- epochs 2 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 86 /86 [ 00 :12< 00 :00, 6 .73it/s, loss = 0 .723 ] 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 29 /29 [ 00 :04< 00 :00, 6 .62it/s, loss = 0 .433 ] 0 , Train Loss = 0 .7228777609592261 Valid Loss = 0 .4327834551704341 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 86 /86 [ 00 :12< 00 :00, 6 .74it/s, loss = 0 .271 ] 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 29 /29 [ 00 :04< 00 :00, 6 .63it/s, loss = 0 .568 ] 1 , Train Loss = 0 .2708700496790021 Valid Loss = 0 .56841839541649 \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u8ba9\u6211\u4eec\u6784\u5efa\u6a21\u578b\u53d8\u5f97\u7b80\u5355\uff0c\u4ee3\u7801\u4e5f\u6613\u4e8e\u9605\u8bfb\u548c\u7406\u89e3\u3002\u6ca1\u6709\u4efb\u4f55\u5c01\u88c5\u7684 PyTorch \u6548\u679c\u6700\u597d\u3002\u56fe\u50cf\u4e2d\u4e0d\u4ec5\u4ec5\u6709\u5206\u7c7b\uff0c\u8fd8\u6709\u5f88\u591a\u5176\u4ed6\u7684\u5185\u5bb9\uff0c\u5982\u679c\u6211\u5f00\u59cb\u5199\u6240\u6709\u7684\u5185\u5bb9\uff0c\u5c31\u5f97\u518d\u5199\u4e00\u672c\u4e66\u4e86\uff0c 
\u63a5\u8fd1\uff08\u51e0\u4e4e\uff09\u4efb\u4f55\u56fe\u50cf\u95ee\u9898\uff08\u4f5c\u8005\u5728\u5f00\u73a9\u7b11\uff09\u3002","title":"\u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5"},{"location":"%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB%E5%92%8C%E5%88%86%E5%89%B2%E6%96%B9%E6%B3%95/#_1","text":"\u8bf4\u5230\u56fe\u50cf\uff0c\u8fc7\u53bb\u51e0\u5e74\u53d6\u5f97\u4e86\u5f88\u591a\u6210\u5c31\u3002\u8ba1\u7b97\u673a\u89c6\u89c9\u7684\u8fdb\u6b65\u76f8\u5f53\u5feb\uff0c\u611f\u89c9\u8ba1\u7b97\u673a\u89c6\u89c9\u7684\u8bb8\u591a\u95ee\u9898\u73b0\u5728\u90fd\u66f4\u5bb9\u6613\u89e3\u51b3\u4e86\u3002\u968f\u7740\u9884\u8bad\u7ec3\u6a21\u578b\u7684\u51fa\u73b0\u548c\u8ba1\u7b97\u6210\u672c\u7684\u964d\u4f4e\uff0c\u73b0\u5728\u5728\u5bb6\u91cc\u5c31\u80fd\u8f7b\u677e\u8bad\u7ec3\u51fa\u63a5\u8fd1\u6700\u5148\u8fdb\u6c34\u5e73\u7684\u6a21\u578b\uff0c\u89e3\u51b3\u5927\u591a\u6570\u4e0e\u56fe\u50cf\u76f8\u5173\u7684\u95ee\u9898\u3002\u4f46\u662f\uff0c\u56fe\u50cf\u95ee\u9898\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u7c7b\u578b\u3002\u4ece\u4e24\u4e2a\u6216\u591a\u4e2a\u7c7b\u522b\u7684\u6807\u51c6\u56fe\u50cf\u5206\u7c7b\uff0c\u5230\u50cf\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66\u8fd9\u6837\u5177\u6709\u6311\u6218\u6027\u7684\u95ee\u9898\u3002\u6211\u4eec\u4e0d\u4f1a\u5728\u672c\u4e66\u4e2d\u8ba8\u8bba\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66\uff0c\u4f46\u6211\u4eec\u663e\u7136\u4f1a\u5904\u7406\u4e00\u4e9b\u6700\u5e38\u89c1\u7684\u56fe\u50cf\u95ee\u9898\u3002 \u6211\u4eec\u53ef\u4ee5\u5bf9\u56fe\u50cf\u91c7\u7528\u54ea\u4e9b\u4e0d\u540c\u7684\u65b9\u6cd5\uff1f\u56fe\u50cf\u53ea\u4e0d\u8fc7\u662f\u4e00\u4e2a\u6570\u5b57\u77e9\u9635\u3002\u8ba1\u7b97\u673a\u65e0\u6cd5\u50cf\u4eba\u7c7b\u4e00\u6837\u770b\u5230\u56fe\u50cf\u3002\u5b83\u53ea\u80fd\u770b\u5230\u6570\u5b57\uff0c\u8fd9\u5c31\u662f\u56fe\u50cf\u3002\u7070\u5ea6\u56fe\u50cf\u662f\u4e00\u4e2a\u4e8c\u7ef4\u77e9\u9635\uff0c\u6570\u503c\u8303\u56f4\u4ece 0 \u5230 255\u30020 \u4ee3\u8868\u9ed1\u8272\uff0c255 
\u4ee3\u8868\u767d\u8272\uff0c\u4ecb\u4e8e\u4e24\u8005\u4e4b\u95f4\u7684\u662f\u5404\u79cd\u7070\u8272\u3002\u4ee5\u524d\uff0c\u5728\u6ca1\u6709\u6df1\u5ea6\u5b66\u4e60\u7684\u65f6\u5019\uff08\u6216\u8005\u8bf4\u6df1\u5ea6\u5b66\u4e60\u8fd8\u4e0d\u6d41\u884c\u7684\u65f6\u5019\uff09\uff0c\u4eba\u4eec\u4e60\u60ef\u4e8e\u67e5\u770b\u50cf\u7d20\u3002\u6bcf\u4e2a\u50cf\u7d20\u90fd\u662f\u4e00\u4e2a\u7279\u5f81\u3002\u4f60\u53ef\u4ee5\u5728 Python \u4e2d\u8f7b\u677e\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u53ea\u9700\u4f7f\u7528 OpenCV \u6216 Python-PIL \u8bfb\u53d6\u7070\u5ea6\u56fe\u50cf\uff0c\u8f6c\u6362\u4e3a numpy \u6570\u7ec4\uff0c\u7136\u540e\u5c06\u77e9\u9635\u5e73\u94fa\uff08\u6241\u5e73\u5316\uff09\u5373\u53ef\u3002\u5982\u679c\u5904\u7406\u7684\u662f RGB \u56fe\u50cf\uff0c\u5219\u9700\u8981\u4e09\u4e2a\u77e9\u9635\uff0c\u800c\u4e0d\u662f\u4e00\u4e2a\u3002\u4f46\u601d\u8def\u662f\u4e00\u6837\u7684\u3002 import numpy as np import matplotlib.pyplot as plt # \u751f\u6210\u4e00\u4e2a 256x256 \u7684\u968f\u673a\u7070\u5ea6\u56fe\u50cf\uff0c\u50cf\u7d20\u503c\u57280\u5230255\u4e4b\u95f4\u968f\u673a\u5206\u5e03 random_image = np . random . randint ( 0 , 256 , ( 256 , 256 )) # \u521b\u5efa\u4e00\u4e2a\u65b0\u7684\u56fe\u50cf\u7a97\u53e3\uff0c\u8bbe\u7f6e\u7a97\u53e3\u5927\u5c0f\u4e3a7x7\u82f1\u5bf8 plt . figure ( figsize = ( 7 , 7 )) # \u663e\u793a\u751f\u6210\u7684\u968f\u673a\u56fe\u50cf # \u4f7f\u7528\u7070\u5ea6\u989c\u8272\u6620\u5c04 (colormap)\uff0c\u8303\u56f4\u4ece0\u5230255 plt . imshow ( random_image , cmap = 'gray' , vmin = 0 , vmax = 255 ) # \u663e\u793a\u56fe\u50cf\u7a97\u53e3 plt . 
show () \u4e0a\u9762\u7684\u4ee3\u7801\u4f7f\u7528 numpy \u751f\u6210\u4e00\u4e2a\u968f\u673a\u77e9\u9635\u3002\u8be5\u77e9\u9635\u7531 0 \u5230 255\uff08\u5305\u542b\uff09\u7684\u503c\u7ec4\u6210\uff0c\u5927\u5c0f\u4e3a 256x256\uff08\u4e5f\u79f0\u4e3a\u50cf\u7d20\uff09\u3002 \u56fe 1\uff1a\u4e8c\u7ef4\u56fe\u50cf\u9635\u5217\uff08\u5355\u901a\u9053\uff09\u53ca\u5176\u5c55\u5e73\u7248\u672c \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u62fc\u5199\u540e\u7684\u7248\u672c\u53ea\u662f\u4e00\u4e2a\u5927\u5c0f\u4e3a M \u7684\u5411\u91cf\uff0c\u5176\u4e2d M = N * N\uff0c\u5728\u672c\u4f8b\u4e2d\uff0c\u8fd9\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u4e3a 256 * 256 = 65536\u3002 \u73b0\u5728\uff0c\u5982\u679c\u6211\u4eec\u7ee7\u7eed\u5bf9\u6570\u636e\u96c6\u4e2d\u7684\u6240\u6709\u56fe\u50cf\u8fdb\u884c\u5904\u7406\uff0c\u6bcf\u4e2a\u6837\u672c\u5c31\u4f1a\u6709 65536 \u4e2a\u7279\u5f81\u3002\u6211\u4eec\u53ef\u4ee5\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u5feb\u901f\u5efa\u7acb \u51b3\u7b56\u6811\u6a21\u578b\u3001\u968f\u673a\u68ee\u6797\u6a21\u578b\u6216\u57fa\u4e8e SVM \u7684\u6a21\u578b \u3002\u8fd9\u4e9b\u6a21\u578b\u5c06\u57fa\u4e8e\u50cf\u7d20\u503c\uff0c\u5c1d\u8bd5\u5c06\u6b63\u6837\u672c\u4e0e\u8d1f\u6837\u672c\u533a\u5206\u5f00\u6765\uff08\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff09\u3002 \u4f60\u4eec\u4e00\u5b9a\u90fd\u542c\u8bf4\u8fc7\u732b\u4e0e\u72d7\u7684\u95ee\u9898\uff0c\u8fd9\u662f\u4e00\u4e2a\u7ecf\u5178\u7684\u95ee\u9898\u3002\u5982\u679c\u4f60\u4eec\u8fd8\u8bb0\u5f97\uff0c\u5728\u8bc4\u4f30\u6307\u6807\u4e00\u7ae0\u7684\u5f00\u5934\uff0c\u6211\u5411\u4f60\u4eec\u4ecb\u7ecd\u4e86\u4e00\u4e2a\u6c14\u80f8\u56fe\u50cf\u6570\u636e\u96c6\u3002\u90a3\u4e48\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u68c0\u6d4b\u80ba\u90e8\u7684 X \u5149\u56fe\u50cf\u662f\u5426\u5b58\u5728\u6c14\u80f8\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u8fd9\u662f\u4e00\u4e2a\uff08\u5e76\u4e0d\uff09\u7b80\u5355\u7684\u4e8c\u5143\u5206\u7c7b\u3002 \u56fe 
2\uff1a\u975e\u6c14\u80f8\u4e0e\u6c14\u80f8 X \u5149\u56fe\u50cf\u5bf9\u6bd4 \u5728\u56fe 2 \u4e2d\uff0c\u60a8\u53ef\u4ee5\u770b\u5230\u975e\u6c14\u80f8\u548c\u6c14\u80f8\u56fe\u50cf\u7684\u5bf9\u6bd4\u3002\u60a8\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\u4e86\uff0c\u5bf9\u4e8e\u4e00\u4e2a\u975e\u4e13\u4e1a\u4eba\u58eb\uff08\u6bd4\u5982\u6211\uff09\u6765\u8bf4\uff0c\u8981\u5728\u8fd9\u4e9b\u56fe\u50cf\u4e2d\u8fa8\u522b\u51fa\u54ea\u4e2a\u662f\u6c14\u80f8\u662f\u76f8\u5f53\u56f0\u96be\u7684\u3002 \u6700\u521d\u7684\u6570\u636e\u96c6\u662f\u5173\u4e8e\u68c0\u6d4b\u6c14\u80f8\u7684\u5177\u4f53\u4f4d\u7f6e\uff0c\u4f46\u6211\u4eec\u5c06\u95ee\u9898\u4fee\u6539\u4e3a\u67e5\u627e\u7ed9\u5b9a\u7684 X \u5149\u56fe\u50cf\u662f\u5426\u5b58\u5728\u6c14\u80f8\u3002\u522b\u62c5\u5fc3\uff0c\u6211\u4eec\u5c06\u5728\u672c\u7ae0\u4ecb\u7ecd\u8fd9\u4e2a\u90e8\u5206\u3002\u6570\u636e\u96c6\u7531 10675 \u5f20\u72ec\u7279\u7684\u56fe\u50cf\u7ec4\u6210\uff0c\u5176\u4e2d 2379 \u5f20\u6709\u6c14\u80f8\uff08\u6ce8\u610f\uff0c\u8fd9\u4e9b\u6570\u5b57\u662f\u7ecf\u8fc7\u6570\u636e\u6e05\u7406\u540e\u5f97\u51fa\u7684\uff0c\u56e0\u6b64\u4e0e\u539f\u59cb\u6570\u636e\u96c6\u4e0d\u7b26\uff09\u3002\u6b63\u5982\u6570\u636e\u79d1\u5b66\u5bb6\u6240\u8bf4\uff1a\u8fd9\u662f\u4e00\u4e2a\u5178\u578b\u7684 \u504f\u659c\u4e8c\u5143\u5206\u7c7b\u6848\u4f8b \u3002\u56e0\u6b64\uff0c\u6211\u4eec\u9009\u62e9 AUC \u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\uff0c\u5e76\u91c7\u7528\u5206\u5c42 k \u6298\u4ea4\u53c9\u9a8c\u8bc1\u65b9\u6848\u3002 \u60a8\u53ef\u4ee5\u5c06\u7279\u5f81\u6241\u5e73\u5316\uff0c\u7136\u540e\u5c1d\u8bd5\u4e00\u4e9b\u7ecf\u5178\u65b9\u6cd5\uff08\u5982 SVM\u3001RF\uff09\u6765\u8fdb\u884c\u5206\u7c7b\uff0c\u8fd9\u5b8c\u5168\u6ca1\u95ee\u9898\uff0c\u4f46\u5374\u65e0\u6cd5\u8ba9\u60a8\u8fbe\u5230\u6700\u5148\u8fdb\u7684\u6c34\u5e73\u3002\u6b64\u5916\uff0c\u56fe\u50cf\u5927\u5c0f\u4e3a 
1024x1024\u3002\u5728\u8fd9\u4e2a\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u4e00\u4e2a\u6a21\u578b\u9700\u8981\u5f88\u957f\u65f6\u95f4\u3002\u4e0d\u7ba1\u600e\u6837\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u5efa\u7acb\u4e00\u4e2a\u7b80\u5355\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002\u7531\u4e8e\u56fe\u50cf\u662f\u7070\u5ea6\u7684\uff0c\u6211\u4eec\u4e0d\u9700\u8981\u8fdb\u884c\u4efb\u4f55\u8f6c\u6362\u3002\u6211\u4eec\u5c06\u628a\u56fe\u50cf\u5927\u5c0f\u8c03\u6574\u4e3a 256x256\uff0c\u4f7f\u5176\u66f4\u5c0f\uff0c\u5e76\u4f7f\u7528\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684 AUC \u4f5c\u4e3a\u8861\u91cf\u6307\u6807\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5b83\u7684\u8868\u73b0\u5982\u4f55\u3002 import os import numpy as np import pandas as pd from PIL import Image from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from tqdm import tqdm # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u6765\u521b\u5efa\u6570\u636e\u96c6 def create_dataset ( training_df , image_dir ): # \u521d\u59cb\u5316\u7a7a\u5217\u8868\u6765\u5b58\u50a8\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c images = [] targets = [] # \u8fed\u4ee3\u5904\u7406\u8bad\u7ec3\u6570\u636e\u96c6\u4e2d\u7684\u6bcf\u4e00\u884c for index , row in tqdm ( training_df . iterrows (), total = len ( training_df ), desc = \"processing images\" ): # \u83b7\u53d6\u56fe\u50cf\u6587\u4ef6\u540d image_id = row [ \"ImageId\" ] # \u6784\u5efa\u5b8c\u6574\u7684\u56fe\u50cf\u6587\u4ef6\u8def\u5f84 image_path = os . path . join ( image_dir , image_id ) # \u6253\u5f00\u56fe\u50cf\u6587\u4ef6\u5e76\u8fdb\u884c\u5927\u5c0f\u8c03\u6574\uff08resize\uff09\u4e3a 256x256 \u50cf\u7d20\uff0c\u4f7f\u7528\u53cc\u7ebf\u6027\u63d2\u503c\uff08BILINEAR\uff09 image = Image . open ( image_path + \".png\" ) image = image . resize (( 256 , 256 ), resample = Image . BILINEAR ) # \u5c06\u56fe\u50cf\u8f6c\u6362\u4e3aNumPy\u6570\u7ec4 image = np . 
array ( image ) # \u5c06\u56fe\u50cf\u6241\u5e73\u5316\u4e3a\u4e00\u7ef4\u6570\u7ec4\uff0c\u5e76\u5c06\u5176\u6dfb\u52a0\u5230\u56fe\u50cf\u5217\u8868 image = image . ravel () images . append ( image ) # \u5c06\u76ee\u6807\u503c\uff08target\uff09\u6dfb\u52a0\u5230\u76ee\u6807\u5217\u8868 targets . append ( int ( row [ \"target\" ])) # \u5c06\u56fe\u50cf\u5217\u8868\u8f6c\u6362\u4e3aNumPy\u6570\u7ec4 images = np . array ( images ) # \u6253\u5370\u56fe\u50cf\u6570\u7ec4\u7684\u5f62\u72b6 print ( images . shape ) # \u8fd4\u56de\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c return images , targets if __name__ == \"__main__\" : # \u5b9a\u4e49CSV\u6587\u4ef6\u8def\u5f84\u548c\u56fe\u50cf\u6587\u4ef6\u76ee\u5f55\u8def\u5f84 csv_path = \"/home/abhishek/workspace/siim_png/train.csv\" image_path = \"/home/abhishek/workspace/siim_png/train_png/\" # \u4eceCSV\u6587\u4ef6\u52a0\u8f7d\u6570\u636e df = pd . read_csv ( csv_path ) # \u6dfb\u52a0\u4e00\u4e2a\u540d\u4e3a'kfold'\u7684\u5217\uff0c\u5e76\u521d\u59cb\u5316\u4e3a-1 df [ \"kfold\" ] = - 1 # \u968f\u673a\u6253\u4e71\u6570\u636e df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u83b7\u53d6\u76ee\u6807\u503c\uff08target\uff09 y = df . target . values # \u4f7f\u7528\u5206\u5c42KFold\u4ea4\u53c9\u9a8c\u8bc1\u5c06\u6570\u636e\u96c6\u5206\u62105\u6298 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u904d\u5386\u6bcf\u4e2a\u6298\uff08fold\uff09 for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f # \u904d\u5386\u6bcf\u4e2a\u6298 for fold_ in range ( 5 ): # \u83b7\u53d6\u8bad\u7ec3\u6570\u636e\u548c\u6d4b\u8bd5\u6570\u636e train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . 
reset_index ( drop = True ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6\u7684\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c xtrain , ytrain = create_dataset ( train_df , image_path ) # \u521b\u5efa\u6d4b\u8bd5\u6570\u636e\u96c6\u7684\u56fe\u50cf\u6570\u636e\u548c\u76ee\u6807\u503c xtest , ytest = create_dataset ( test_df , image_path ) # \u521d\u59cb\u5316\u4e00\u4e2a\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668 clf = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408\u5206\u7c7b\u5668 clf . fit ( xtrain , ytrain ) # \u4f7f\u7528\u5206\u7c7b\u5668\u5bf9\u6d4b\u8bd5\u6570\u636e\u8fdb\u884c\u9884\u6d4b\uff0c\u5e76\u83b7\u53d6\u6982\u7387\u503c preds = clf . predict_proba ( xtest )[:, 1 ] # \u6253\u5370\u6298\u6570\uff08fold\uff09\u548cAUC\uff08ROC\u66f2\u7ebf\u4e0b\u7684\u9762\u79ef\uff09 print ( f \"FOLD: { fold_ } \" ) print ( f \"AUC = { metrics . roc_auc_score ( ytest , preds ) } \" ) print ( \"\" ) \u5e73\u5747 AUC \u503c\u7ea6\u4e3a 0.72\u3002\u8fd9\u8fd8\u4e0d\u9519\uff0c\u4f46\u6211\u4eec\u5e0c\u671b\u80fd\u505a\u5f97\u66f4\u597d\u3002\u4f60\u53ef\u4ee5\u5c06\u8fd9\u79cd\u65b9\u6cd5\u7528\u4e8e\u56fe\u50cf\uff0c\u8fd9\u4e5f\u662f\u5b83\u5728\u4ee5\u524d\u6700\u5e38\u7528\u7684\u65b9\u6cd5\u3002SVM \u5728\u56fe\u50cf\u6570\u636e\u96c6\u65b9\u9762\u76f8\u5f53\u6709\u540d\u3002\u6df1\u5ea6\u5b66\u4e60\u5df2\u88ab\u8bc1\u660e\u662f\u89e3\u51b3\u6b64\u7c7b\u95ee\u9898\u7684\u6700\u5148\u8fdb\u65b9\u6cd5\uff0c\u56e0\u6b64\u6211\u4eec\u4e0b\u4e00\u6b65\u53ef\u4ee5\u8bd5\u8bd5\u5b83\u3002 \u5173\u4e8e\u6df1\u5ea6\u5b66\u4e60\u7684\u5386\u53f2\u4ee5\u53ca\u8c01\u53d1\u660e\u4e86\u4ec0\u4e48\uff0c\u6211\u5c31\u4e0d\u591a\u8bf4\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b\u6700\u8457\u540d\u7684\u6df1\u5ea6\u5b66\u4e60\u6a21\u578b\u4e4b\u4e00 AlexNet\u3002 \u56fe 3\uff1aAlexNet \u67b6\u67849 \u8bf7\u6ce8\u610f\uff0c\u672c\u56fe\u4e2d\u7684\u8f93\u5165\u5927\u5c0f\u4e0d\u662f 224x224 \u800c\u662f 227x227 
\u5982\u4eca\uff0c\u4f60\u53ef\u80fd\u4f1a\u8bf4\u8fd9\u53ea\u662f\u4e00\u4e2a\u57fa\u672c\u7684 \u6df1\u5ea6\u5377\u79ef\u795e\u7ecf\u7f51\u7edc \uff0c\u4f46\u5b83\u5374\u662f\u8bb8\u591a\u65b0\u578b\u6df1\u5ea6\u7f51\u7edc\uff08\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff09\u7684\u57fa\u7840\u3002\u6211\u4eec\u770b\u5230\uff0c\u56fe 3 \u4e2d\u7684\u7f51\u7edc\u662f\u4e00\u4e2a\u5177\u6709\u4e94\u4e2a\u5377\u79ef\u5c42\u3001\u4e24\u4e2a\u5bc6\u96c6\u5c42\u548c\u4e00\u4e2a\u8f93\u51fa\u5c42\u7684\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u3002\u6211\u4eec\u770b\u5230\u8fd8\u6709\u6700\u5927\u6c60\u5316\u3002\u8fd9\u662f\u4ec0\u4e48\u610f\u601d\uff1f\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5728\u8fdb\u884c\u6df1\u5ea6\u5b66\u4e60\u65f6\u4f1a\u9047\u5230\u7684\u4e00\u4e9b\u672f\u8bed\u3002 \u56fe 4\uff1a\u56fe\u50cf\u5927\u5c0f\u4e3a 8x8\uff0c\u6ee4\u6ce2\u5668\u5927\u5c0f\u4e3a 3x3\uff0c\u6b65\u957f\u4e3a 2\u3002 \u56fe 4 \u5f15\u5165\u4e86\u4e24\u4e2a\u65b0\u672f\u8bed\uff1a\u6ee4\u6ce2\u5668\u548c\u6b65\u957f\u3002 \u6ee4\u6ce2\u5668 \u662f\u7531\u7ed9\u5b9a\u51fd\u6570\u521d\u59cb\u5316\u7684\u4e8c\u7ef4\u77e9\u9635\uff0c\u7531\u6307\u5b9a\u51fd\u6570\u521d\u59cb\u5316\u3002 Kaiming\u6b63\u6001\u521d\u59cb\u5316 \uff0c\u662f\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u7684\u6700\u4f73\u9009\u62e9\u3002\u8fd9\u662f\u56e0\u4e3a\u5927\u591a\u6570\u73b0\u4ee3\u7f51\u7edc\u90fd\u4f7f\u7528 ReLU 
\uff08\u6574\u6d41\u7ebf\u6027\u5355\u5143\uff09\u6fc0\u6d3b\u51fd\u6570\uff0c\u9700\u8981\u9002\u5f53\u7684\u521d\u59cb\u5316\u6765\u907f\u514d\u68af\u5ea6\u6d88\u5931\u95ee\u9898\uff08\u68af\u5ea6\u8d8b\u8fd1\u4e8e\u96f6\uff0c\u7f51\u7edc\u6743\u91cd\u4e0d\u53d8\uff09\u3002\u8be5\u6ee4\u6ce2\u5668\u4e0e\u56fe\u50cf\u8fdb\u884c\u5377\u79ef\u3002\u5377\u79ef\u4e0d\u8fc7\u662f\u6ee4\u6ce2\u5668\u4e0e\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u5f53\u524d\u91cd\u53e0\u50cf\u7d20\u4e4b\u95f4\u7684\u5143\u7d20\u76f8\u4e58\u7684\u603b\u548c\u3002\u60a8\u53ef\u4ee5\u5728\u4efb\u4f55\u9ad8\u4e2d\u6570\u5b66\u6559\u79d1\u4e66\u4e2d\u9605\u8bfb\u66f4\u591a\u5173\u4e8e\u5377\u79ef\u7684\u5185\u5bb9\u3002\u6211\u4eec\u4ece\u56fe\u50cf\u7684\u5de6\u4e0a\u89d2\u5f00\u59cb\u5bf9\u6ee4\u955c\u8fdb\u884c\u5377\u79ef\uff0c\u7136\u540e\u6c34\u5e73\u79fb\u52a8\u6ee4\u955c\u3002\u5982\u679c\u79fb\u52a8 1 \u4e2a\u50cf\u7d20\uff0c\u5219\u6b65\u957f\u4e3a 1\uff1b\u5982\u679c\u79fb\u52a8 2 \u4e2a\u50cf\u7d20\uff0c\u5219\u6b65\u957f\u4e3a 2\u3002 \u5373\u4f7f\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u4f8b\u5982\u5728\u95ee\u9898\u548c\u56de\u7b54\u7cfb\u7edf\u4e2d\u9700\u8981\u4ece\u5927\u91cf\u6587\u672c\u8bed\u6599\u4e2d\u7b5b\u9009\u7b54\u6848\u65f6\uff0c\u6b65\u957f\u4e5f\u662f\u4e00\u4e2a\u975e\u5e38\u6709\u7528\u7684\u6982\u5ff5\u3002\u5f53\u6211\u4eec\u5728\u6c34\u5e73\u65b9\u5411\u4e0a\u8d70\u5230\u5c3d\u5934\u65f6\uff0c\u5c31\u4f1a\u4ee5\u540c\u6837\u7684\u6b65\u957f\u5782\u76f4\u5411\u4e0b\u79fb\u52a8\u8fc7\u6ee4\u5668\uff0c\u4ece\u5de6\u4fa7\u5f00\u59cb\u3002\u56fe 4 \u8fd8\u663e\u793a\u4e86\u8fc7\u6ee4\u5668\u79fb\u51fa\u56fe\u50cf\u7684\u60c5\u51b5\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u65e0\u6cd5\u8ba1\u7b97\u5377\u79ef\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u8df3\u8fc7\u5b83\u3002\u5982\u679c\u4e0d\u60f3\u8df3\u8fc7\uff0c\u5219\u9700\u8981\u5bf9\u56fe\u50cf\u8fdb\u884c \u586b\u5145\uff08pad\uff09 
\u3002\u8fd8\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5377\u79ef\u4f1a\u51cf\u5c0f\u56fe\u50cf\u7684\u5927\u5c0f\u3002\u586b\u5145\u4e5f\u662f\u4fdd\u6301\u56fe\u50cf\u5927\u5c0f\u4e0d\u53d8\u7684\u4e00\u79cd\u65b9\u6cd5\u3002\u5728\u56fe 4 \u4e2d\uff0c\u4e00\u4e2a 3x3 \u6ee4\u6ce2\u5668\u6b63\u5728\u6c34\u5e73\u548c\u5782\u76f4\u79fb\u52a8\uff0c\u6bcf\u6b21\u79fb\u52a8\u90fd\u4f1a\u5206\u522b\u8df3\u8fc7\u4e24\u5217\u548c\u4e24\u884c\uff08\u5373\u50cf\u7d20\uff09\u3002\u7531\u4e8e\u5b83\u8df3\u8fc7\u4e86\u4e24\u4e2a\u50cf\u7d20\uff0c\u6240\u4ee5\u6b65\u957f = 2\u3002\u56e0\u6b64\u56fe\u50cf\u5927\u5c0f\u4e3a [(8-3) / 2] + 1 = 3.5\u3002\u6211\u4eec\u53d6 3.5 \u7684\u4e0b\u9650\uff0c\u6240\u4ee5\u662f 3x3\u3002\u60a8\u53ef\u4ee5\u5728\u8349\u7a3f\u7eb8\u4e0a\u8fdb\u884c\u5c1d\u8bd5\u3002 \u56fe 5\uff1a\u901a\u8fc7\u586b\u5145\uff0c\u6211\u4eec\u53ef\u4ee5\u63d0\u4f9b\u4e0e\u8f93\u5165\u56fe\u50cf\u5927\u5c0f\u76f8\u540c\u7684\u56fe\u50cf \u6211\u4eec\u53ef\u4ee5\u4ece\u56fe 5 \u4e2d\u770b\u5230\u586b\u5145\u7684\u6548\u679c\u3002\u73b0\u5728\uff0c\u6211\u4eec\u6709\u4e00\u4e2a 3x3 \u7684\u6ee4\u6ce2\u5668\uff0c\u5b83\u4ee5 1 \u7684\u6b65\u957f\u79fb\u52a8\u3002\u539f\u59cb\u56fe\u50cf\u7684\u5927\u5c0f\u4e3a 6x6\uff0c\u6211\u4eec\u6dfb\u52a0\u4e86 1 \u7684 \u586b\u5145 \u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u751f\u6210\u7684\u56fe\u50cf\u5c06\u4e0e\u8f93\u5165\u56fe\u50cf\u5927\u5c0f\u76f8\u540c\uff0c\u5373 6x6\u3002\u5728\u5904\u7406\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u65f6\u53ef\u80fd\u4f1a\u9047\u5230\u7684\u53e6\u4e00\u4e2a\u76f8\u5173\u672f\u8bed\u662f \u81a8\u80c0\uff08dilation\uff09 \uff0c\u5982\u56fe 6 \u6240\u793a\u3002 \u56fe 6\uff1a\u81a8\u80c0\uff08dilation\uff09\u7684\u4f8b\u5b50 \u5728\u81a8\u80c0\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u5c06\u6ee4\u6ce2\u5668\u6269\u5927 N-1\uff0c\u5176\u4e2d N 
\u662f\u81a8\u80c0\u7387\u7684\u503c\uff0c\u6216\u7b80\u79f0\u4e3a\u81a8\u80c0\u3002\u5728\u8fd9\u79cd\u5e26\u81a8\u80c0\u7684\u5185\u6838\u4e2d\uff0c\u6bcf\u6b21\u5377\u79ef\u90fd\u4f1a\u8df3\u8fc7\u4e00\u4e9b\u50cf\u7d20\u3002\u8fd9\u5728\u5206\u5272\u4efb\u52a1\u4e2d\u5c24\u4e3a\u6709\u6548\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u53ea\u8ba8\u8bba\u4e86\u4e8c\u7ef4\u5377\u79ef\u3002 \u8fd8\u6709\u4e00\u7ef4\u5377\u79ef\u548c\u66f4\u9ad8\u7ef4\u5ea6\u7684\u5377\u79ef\u3002\u5b83\u4eec\u90fd\u57fa\u4e8e\u76f8\u540c\u7684\u57fa\u672c\u6982\u5ff5\u3002 \u63a5\u4e0b\u6765\u662f \u6700\u5927\u6c60\u5316\uff08Max pooling\uff09 \u3002\u6700\u5927\u503c\u6c60\u53ea\u662f\u4e00\u4e2a\u8fd4\u56de\u6700\u5927\u503c\u7684\u6ee4\u6ce2\u5668\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u63d0\u53d6\u7684\u4e0d\u662f\u5377\u79ef\uff0c\u800c\u662f\u50cf\u7d20\u7684\u6700\u5927\u503c\u3002\u540c\u6837\uff0c \u5e73\u5747\u6c60\u5316\uff08average pooling\uff09 \u6216 \u5747\u503c\u6c60\u5316\uff08mean pooling\uff09 \u4f1a\u8fd4\u56de\u50cf\u7d20\u7684\u5e73\u5747\u503c\u3002\u5b83\u4eec\u7684\u4f7f\u7528\u65b9\u6cd5\u4e0e\u5377\u79ef\u6838\u76f8\u540c\u3002\u6c60\u5316\u6bd4\u5377\u79ef\u66f4\u5feb\uff0c\u662f\u4e00\u79cd\u5bf9\u56fe\u50cf\u8fdb\u884c\u7f29\u51cf\u91c7\u6837\u7684\u65b9\u6cd5\u3002\u6700\u5927\u6c60\u5316\u53ef\u68c0\u6d4b\u8fb9\u7f18\uff0c\u5e73\u5747\u6c60\u5316\u53ef\u5e73\u6ed1\u56fe\u50cf\u3002 \u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u548c\u6df1\u5ea6\u5b66\u4e60\u7684\u6982\u5ff5\u592a\u591a\u4e86\u3002\u6211\u6240\u8ba8\u8bba\u7684\u662f\u4e00\u4e9b\u57fa\u7840\u77e5\u8bc6\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5165\u95e8\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u5728 PyTorch \u4e2d\u6784\u5efa\u7b2c\u4e00\u4e2a\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u505a\u597d\u4e86\u5145\u5206\u51c6\u5907\u3002PyTorch 
\u63d0\u4f9b\u4e86\u4e00\u79cd\u76f4\u89c2\u800c\u7b80\u5355\u7684\u65b9\u6cd5\u6765\u5b9e\u73b0\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff0c\u800c\u4e14\u4f60\u4e0d\u9700\u8981\u5173\u5fc3\u53cd\u5411\u4f20\u64ad\u3002\u6211\u4eec\u7528\u4e00\u4e2a python \u7c7b\u548c\u4e00\u4e2a\u524d\u9988\u51fd\u6570\u6765\u5b9a\u4e49\u7f51\u7edc\uff0c\u544a\u8bc9 PyTorch \u5404\u5c42\u4e4b\u95f4\u5982\u4f55\u8fde\u63a5\u3002\u5728 PyTorch \u4e2d\uff0c\u56fe\u50cf\u7b26\u53f7\u662f BS\u3001C\u3001H\u3001W\uff0c\u5176\u4e2d\uff0cBS \u662f\u6279\u5927\u5c0f\uff0cC \u662f\u901a\u9053\uff0cH \u662f\u9ad8\u5ea6\uff0cW \u662f\u5bbd\u5ea6\u3002\u8ba9\u6211\u4eec\u770b\u770b PyTorch \u662f\u5982\u4f55\u5b9e\u73b0 AlexNet \u7684\u3002 import torch import torch.nn as nn import torch.nn.functional as F class AlexNet ( nn . Module ): def __init__ ( self ): super ( AlexNet , self ) . __init__ () self . conv1 = nn . Conv2d ( in_channels = 3 , out_channels = 96 , kernel_size = 11 , stride = 4 , padding = 0 ) self . pool1 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . conv2 = nn . Conv2d ( in_channels = 96 , out_channels = 256 , kernel_size = 5 , stride = 1 , padding = 2 ) self . pool2 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . conv3 = nn . Conv2d ( in_channels = 256 , out_channels = 384 , kernel_size = 3 , stride = 1 , padding = 1 ) self . conv4 = nn . Conv2d ( in_channels = 384 , out_channels = 384 , kernel_size = 3 , stride = 1 , padding = 1 ) self . conv5 = nn . Conv2d ( in_channels = 384 , out_channels = 256 , kernel_size = 3 , stride = 1 , padding = 1 ) self . pool3 = nn . MaxPool2d ( kernel_size = 3 , stride = 2 ) self . fc1 = nn . Linear ( in_features = 9216 , out_features = 4096 ) self . dropout1 = nn . Dropout ( 0.5 ) self . fc2 = nn . Linear ( in_features = 4096 , out_features = 4096 ) self . dropout2 = nn . Dropout ( 0.5 ) self . fc3 = nn . Linear ( in_features = 4096 , out_features = 1000 ) def forward ( self , image ): bs , c , h , w = image . size () x = F . 
relu ( self . conv1 ( image )) # size: (bs, 96, 55, 55) x = self . pool1 ( x ) # size: (bs, 96, 27, 27) x = F . relu ( self . conv2 ( x )) # size: (bs, 256, 27, 27) x = self . pool2 ( x ) # size: (bs, 256, 13, 13) x = F . relu ( self . conv3 ( x )) # size: (bs, 384, 13, 13) x = F . relu ( self . conv4 ( x )) # size: (bs, 384, 13, 13) x = F . relu ( self . conv5 ( x )) # size: (bs, 256, 13, 13) x = self . pool3 ( x ) # size: (bs, 256, 6, 6) x = x . view ( bs , - 1 ) # size: (bs, 9216) x = F . relu ( self . fc1 ( x )) # size: (bs, 4096) x = self . dropout1 ( x ) # size: (bs, 4096) # dropout does not change size # dropout is used for regularization # 0.3 dropout means that only 70% of the nodes # of the current layer are used for the next layer x = F . relu ( self . fc2 ( x )) # size: (bs, 4096) x = self . dropout2 ( x ) # size: (bs, 4096) x = F . relu ( self . fc3 ( x )) # size: (bs, 1000) # 1000 is number of classes in ImageNet Dataset # softmax is an activation function that converts # linear output to probabilities that add up to 1 # for each sample in the batch x = torch . 
softmax ( x , axis = 1 ) # size: (bs, 1000) return x \u5982\u679c\u60a8\u6709\u4e00\u5e45 3x227x227 \u7684\u56fe\u50cf\uff0c\u5e76\u5e94\u7528\u4e86\u4e00\u4e2a\u5927\u5c0f\u4e3a 11x11 \u7684\u5377\u79ef\u6ee4\u6ce2\u5668\uff0c\u8fd9\u610f\u5473\u7740\u60a8\u5e94\u7528\u4e86\u4e00\u4e2a\u5927\u5c0f\u4e3a 11x11x3 \u7684\u6ee4\u6ce2\u5668\uff0c\u5e76\u4e0e\u4e00\u4e2a\u5927\u5c0f\u4e3a 227x227x3 \u7684\u56fe\u50cf\u8fdb\u884c\u4e86\u5377\u79ef\u3002\u8f93\u51fa\u901a\u9053\u7684\u6570\u91cf\u5c31\u662f\u5206\u522b\u5e94\u7528\u4e8e\u56fe\u50cf\u7684\u76f8\u540c\u5927\u5c0f\u7684\u4e0d\u540c\u5377\u79ef\u6ee4\u6ce2\u5668\u7684\u6570\u91cf\u3002 \u56e0\u6b64\uff0c\u5728\u7b2c\u4e00\u4e2a\u5377\u79ef\u5c42\u4e2d\uff0c\u8f93\u5165\u901a\u9053\u662f 3\uff0c\u4e5f\u5c31\u662f\u539f\u59cb\u8f93\u5165\uff0c\u5373 R\u3001G\u3001B \u4e09\u901a\u9053\u3002PyTorch \u7684 torchvision \u63d0\u4f9b\u4e86\u8bb8\u591a\u4e0e AlexNet \u7c7b\u4f3c\u7684\u4e0d\u540c\u6a21\u578b\uff0c\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0cAlexNet \u7684\u5b9e\u73b0\u4e0e torchvision \u7684\u5b9e\u73b0\u5e76\u4e0d\u76f8\u540c\u3002Torchvision \u7684 AlexNet \u5b9e\u73b0\u662f\u4ece\u53e6\u4e00\u7bc7\u8bba\u6587\u4e2d\u4fee\u6539\u800c\u6765\u7684 AlexNet\uff1a Krizhevsky, A. One weird trick for parallelizing convolutional neural networks. CoRR, abs/1404.5997, 2014. 
\u4f60\u53ef\u4ee5\u4e3a\u81ea\u5df1\u7684\u4efb\u52a1\u8bbe\u8ba1\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff0c\u5f88\u591a\u65f6\u5019\uff0c\u4ece\u96f6\u505a\u8d77\u662f\u4e2a\u4e0d\u9519\u7684\u4e3b\u610f\u3002\u8ba9\u6211\u4eec\u6784\u5efa\u4e00\u4e2a\u7f51\u7edc\uff0c\u7528\u4e8e\u533a\u5206\u56fe\u50cf\u6709\u65e0\u6c14\u80f8\u3002\u9996\u5148\uff0c\u8ba9\u6211\u4eec\u51c6\u5907\u4e00\u4e9b\u6587\u4ef6\u3002\u7b2c\u4e00\u6b65\u662f\u521b\u5efa\u4e00\u4e2a\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e\u96c6\uff0c\u5373 train.csv\uff0c\u4f46\u589e\u52a0\u4e00\u5217 kfold\u3002\u6211\u4eec\u5c06\u521b\u5efa\u4e94\u4e2a\u6587\u4ef6\u5939\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u5df2\u7ecf\u6f14\u793a\u4e86\u5982\u4f55\u9488\u5bf9\u4e0d\u540c\u7684\u6570\u636e\u96c6\u521b\u5efa\u6298\u53e0\uff0c\u56e0\u6b64\u6211\u5c06\u8df3\u8fc7\u8fd9\u4e00\u90e8\u5206\uff0c\u7559\u4f5c\u7ec3\u4e60\u3002\u5bf9\u4e8e\u57fa\u4e8e PyTorch \u7684\u795e\u7ecf\u7f51\u7edc\uff0c\u6211\u4eec\u9700\u8981\u521b\u5efa\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\u3002\u6570\u636e\u96c6\u7c7b\u7684\u76ee\u7684\u662f\u8fd4\u56de\u4e00\u4e2a\u6570\u636e\u9879\u6216\u6570\u636e\u6837\u672c\u3002\u8fd9\u4e2a\u6570\u636e\u6837\u672c\u5e94\u8be5\u5305\u542b\u8bad\u7ec3\u6216\u8bc4\u4f30\u6a21\u578b\u6240\u9700\u7684\u6240\u6709\u5185\u5bb9\u3002 import torch import numpy as np from PIL import Image from PIL import ImageFile ImageFile . LOAD_TRUNCATED_IMAGES = True # \u5b9a\u4e49\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\uff0c\u7528\u4e8e\u5904\u7406\u56fe\u50cf\u5206\u7c7b\u4efb\u52a1 class ClassificationDataset : def __init__ ( self , image_paths , targets , resize = None , augmentations = None ): # \u56fe\u50cf\u6587\u4ef6\u8def\u5f84\u5217\u8868 self . image_paths = image_paths # \u76ee\u6807\u6807\u7b7e\u5217\u8868 self . targets = targets # \u56fe\u50cf\u5c3a\u5bf8\u8c03\u6574\u53c2\u6570\uff0c\u53ef\u4ee5\u4e3aNone self . resize = resize # \u6570\u636e\u589e\u5f3a\u51fd\u6570\uff0c\u53ef\u4ee5\u4e3aNone self . 
augmentations = augmentations def __len__ ( self ): # \u8fd4\u56de\u6570\u636e\u96c6\u7684\u5927\u5c0f\uff0c\u5373\u56fe\u50cf\u6570\u91cf return len ( self . image_paths ) def __getitem__ ( self , item ): # \u83b7\u53d6\u6570\u636e\u96c6\u4e2d\u7684\u4e00\u4e2a\u6837\u672c image = Image . open ( self . image_paths [ item ]) image = image . convert ( \"RGB\" ) # \u5c06\u56fe\u50cf\u8f6c\u6362\u4e3aRGB\u683c\u5f0f # \u83b7\u53d6\u8be5\u6837\u672c\u7684\u76ee\u6807\u6807\u7b7e targets = self . targets [ item ] if self . resize is not None : # \u5982\u679c\u6307\u5b9a\u4e86\u5c3a\u5bf8\u8c03\u6574\u53c2\u6570\uff0c\u5c06\u56fe\u50cf\u8fdb\u884c\u5c3a\u5bf8\u8c03\u6574 image = image . resize (( self . resize [ 1 ], self . resize [ 0 ]), resample = Image . BILINEAR ) image = np . array ( image ) if self . augmentations is not None : # \u5982\u679c\u6307\u5b9a\u4e86\u6570\u636e\u589e\u5f3a\u51fd\u6570\uff0c\u5e94\u7528\u6570\u636e\u589e\u5f3a augmented = self . augmentations ( image = image ) image = augmented [ \"image\" ] # \u5c06\u56fe\u50cf\u901a\u9053\u987a\u5e8f\u8c03\u6574\u4e3a(C, H, W)\u7684\u5f62\u5f0f\uff0c\u5e76\u8f6c\u6362\u4e3afloat32\u7c7b\u578b image = np . transpose ( image , ( 2 , 0 , 1 )) . astype ( np . float32 ) # \u8fd4\u56de\u6837\u672c\uff0c\u5305\u62ec\u56fe\u50cf\u548c\u5bf9\u5e94\u7684\u76ee\u6807\u6807\u7b7e return { \"image\" : torch . tensor ( image , dtype = torch . float ), \"targets\" : torch . tensor ( targets , dtype = torch . long ), } \u73b0\u5728\u6211\u4eec\u9700\u8981 engine.py\u3002engine.py \u5305\u542b\u8bad\u7ec3\u548c\u8bc4\u4f30\u529f\u80fd\u3002\u8ba9\u6211\u4eec\u770b\u770b engine.py \u662f\u5982\u4f55\u7f16\u5199\u7684\u3002 import torch import torch.nn as nn from tqdm import tqdm # \u7528\u4e8e\u8bad\u7ec3\u6a21\u578b\u7684\u51fd\u6570 def train ( data_loader , model , optimizer , device ): # \u5c06\u6a21\u578b\u8bbe\u7f6e\u4e3a\u8bad\u7ec3\u6a21\u5f0f model . 
train () for data in data_loader : # \u4ece\u6570\u636e\u52a0\u8f7d\u5668\u4e2d\u63d0\u53d6\u8f93\u5165\u56fe\u50cf\u548c\u76ee\u6807\u6807\u7b7e inputs = data [ \"image\" ] targets = data [ \"targets\" ] # \u5c06\u8f93\u5165\u548c\u76ee\u6807\u79fb\u52a8\u5230\u6307\u5b9a\u7684\u8bbe\u5907\uff08\u4f8b\u5982\uff0cGPU\uff09 inputs = inputs . to ( device , dtype = torch . float ) targets = targets . to ( device , dtype = torch . float ) # \u5c06\u4f18\u5316\u5668\u4e2d\u7684\u68af\u5ea6\u5f52\u96f6 optimizer . zero_grad () # \u524d\u5411\u4f20\u64ad\uff1a\u8ba1\u7b97\u6a21\u578b\u9884\u6d4b outputs = model ( inputs ) # \u4f7f\u7528\u5e26\u903b\u8f91\u65af\u8482\u51fd\u6570\u7684\u4e8c\u5143\u4ea4\u53c9\u71b5\u635f\u5931\u8ba1\u7b97\u635f\u5931 loss = nn . BCEWithLogitsLoss ()( outputs , targets . view ( - 1 , 1 )) # \u53cd\u5411\u4f20\u64ad\uff1a\u8ba1\u7b97\u68af\u5ea6\u5e76\u66f4\u65b0\u6a21\u578b\u6743\u91cd loss . backward () optimizer . step () # \u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u7684\u51fd\u6570 def evaluate ( data_loader , model , device ): # \u5c06\u6a21\u578b\u8bbe\u7f6e\u4e3a\u8bc4\u4f30\u6a21\u5f0f\uff08\u4e0d\u8fdb\u884c\u68af\u5ea6\u8ba1\u7b97\uff09 model . eval () # \u521d\u59cb\u5316\u5217\u8868\u4ee5\u5b58\u50a8\u771f\u5b9e\u76ee\u6807\u548c\u6a21\u578b\u9884\u6d4b final_targets = [] final_outputs = [] with torch . no_grad (): for data in data_loader : # \u4ece\u6570\u636e\u52a0\u8f7d\u5668\u4e2d\u63d0\u53d6\u8f93\u5165\u56fe\u50cf\u548c\u76ee\u6807\u6807\u7b7e inputs = data [ \"image\" ] targets = data [ \"targets\" ] # \u5c06\u8f93\u5165\u79fb\u52a8\u5230\u6307\u5b9a\u7684\u8bbe\u5907\uff08\u4f8b\u5982\uff0cGPU\uff09 inputs = inputs . to ( device , dtype = torch . float ) # \u83b7\u53d6\u6a21\u578b\u9884\u6d4b output = model ( inputs ) # \u5c06\u76ee\u6807\u548c\u8f93\u51fa\u8f6c\u6362\u4e3aCPU\u548cPython\u5217\u8868 targets = targets . detach () . cpu () . numpy () . tolist () output = output . detach () . cpu () . numpy () . 
tolist () # \u5c06\u5217\u8868\u6269\u5c55\u4ee5\u5305\u542b\u6279\u6b21\u6570\u636e final_targets . extend ( targets ) final_outputs . extend ( output ) # \u8fd4\u56de\u6700\u7ec8\u7684\u6a21\u578b\u9884\u6d4b\u548c\u771f\u5b9e\u76ee\u6807 return final_outputs , final_targets \u6709\u4e86 engine.py\uff0c\u5c31\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u65b0\u6587\u4ef6\uff1amodel.py\u3002model.py \u5c06\u5305\u542b\u6211\u4eec\u7684\u6a21\u578b\u3002\u628a\u6a21\u578b\u4e0e\u8bad\u7ec3\u5206\u5f00\u662f\u4e2a\u597d\u4e3b\u610f\uff0c\u56e0\u4e3a\u8fd9\u6837\u6211\u4eec\u5c31\u53ef\u4ee5\u8f7b\u677e\u5730\u8bd5\u9a8c\u4e0d\u540c\u7684\u6a21\u578b\u548c\u4e0d\u540c\u7684\u67b6\u6784\u3002\u540d\u4e3a pretrainedmodels \u7684 PyTorch \u5e93\u4e2d\u6709\u5f88\u591a\u4e0d\u540c\u7684\u6a21\u578b\u67b6\u6784\uff0c\u5982 AlexNet\u3001ResNet\u3001DenseNet \u7b49\u3002\u8fd9\u4e9b\u4e0d\u540c\u7684\u6a21\u578b\u67b6\u6784\u662f\u5728\u540d\u4e3a ImageNet \u7684\u5927\u578b\u56fe\u50cf\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u51fa\u6765\u7684\u3002\u5728 ImageNet \u4e0a\u8bad\u7ec3\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5b83\u4eec\u7684\u6743\u91cd\uff0c\u4e5f\u53ef\u4ee5\u4e0d\u4f7f\u7528\u8fd9\u4e9b\u6743\u91cd\u3002\u5982\u679c\u6211\u4eec\u4e0d\u4f7f\u7528 ImageNet \u6743\u91cd\u8fdb\u884c\u8bad\u7ec3\uff0c\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u7f51\u7edc\u5c06\u4ece\u5934\u5f00\u59cb\u5b66\u4e60\u4e00\u5207\u3002\u8fd9\u5c31\u662f model.py \u7684\u6837\u5b50\u3002 import torch.nn as nn import pretrainedmodels # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u4ee5\u83b7\u53d6\u6a21\u578b def get_model ( pretrained ): if pretrained : # \u4f7f\u7528\u9884\u8bad\u7ec3\u7684 AlexNet \u6a21\u578b\uff0c\u52a0\u8f7d\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u7684\u6743\u91cd model = pretrainedmodels . __dict__ [ \"alexnet\" ]( pretrained = 'imagenet' ) else : # \u4f7f\u7528\u672a\u7ecf\u9884\u8bad\u7ec3\u7684 AlexNet \u6a21\u578b model = pretrainedmodels . 
__dict__ [ \"alexnet\" ]( pretrained = None ) # \u4fee\u6539\u6a21\u578b\u7684\u6700\u540e\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u4ee5\u9002\u5e94\u7279\u5b9a\u4efb\u52a1 model . last_linear = nn . Sequential ( nn . BatchNorm1d ( 4096 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.25 ), # \u968f\u673a\u5931\u6d3b\u5c42\uff0c\u9632\u6b62\u8fc7\u62df\u5408 nn . Linear ( in_features = 4096 , out_features = 2048 ), # \u8fde\u63a5\u5c42 nn . ReLU (), # ReLU \u6fc0\u6d3b\u51fd\u6570 nn . BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.5 ), # \u968f\u673a\u5931\u6d3b\u5c42 nn . Linear ( in_features = 2048 , out_features = 1 ) # \u6700\u7ec8\u7684\u4e8c\u5143\u5206\u7c7b\u5c42 ) return model \u5982\u679c\u4f60\u6253\u5370\u4e86\u7f51\u7edc\uff0c\u4f1a\u5f97\u5230\u5982\u4e0b\u8f93\u51fa\uff1a AlexNet ( ( avgpool ): AdaptiveAvgPool2d ( output_size = ( 6 , 6 )) ( _features ): Sequential ( ( 0 ): Conv2d ( 3 , 64 , kernel_size = ( 11 , 11 ), stride = ( 4 , 4 ), padding = ( 2 , 2 )) ( 1 ): ReLU ( inplace = True ) ( 2 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , ceil_mode = False ) ( 3 ): Conv2d ( 64 , 192 , kernel_size = ( 5 , 5 ), stride = ( 1 , 1 ), padding = ( 2 , 2 )) ( 4 ): ReLU ( inplace = True ) ( 5 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , ceil_mode = False ) ( 6 ): Conv2d ( 192 , 384 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 7 ): ReLU ( inplace = True ) ( 8 ): Conv2d ( 384 , 256 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 9 ): ReLU ( inplace = True ) ( 10 ): Conv2d ( 256 , 256 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 )) ( 11 ): ReLU ( inplace = True ) ( 12 ): MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 0 , dilation = 1 , eil_mode = False )) ( dropout0 ): Dropout ( p = 0.5 , inplace = False ) ( linear0 ): Linear ( in_features = 9216 , out_features = 
4096 , bias = True ) ( relu0 ): ReLU ( inplace = True ) ( dropout1 ): Dropout ( p = 0.5 , inplace = False ) ( linear1 ): Linear ( in_features = 4096 , out_features = 4096 , bias = True ) ( relu1 ): ReLU ( inplace = True ) ( last_linear ): Sequential ( ( 0 ): BatchNorm1d ( 4096 , eps = 1e-05 , momentum = 0.1 , affine = True , rack_running_stats = True ) ( 1 ): Dropout ( p = 0.25 , inplace = False ) ( 2 ): Linear ( in_features = 4096 , out_features = 2048 , bias = True ) ( 3 ): ReLU () ( 4 ): BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 , affine = True , track_running_stats = True ) ( 5 ): Dropout ( p = 0.5 , inplace = False ) ( 6 ): Linear ( in_features = 2048 , out_features = 1 , bias = True ) ) ) \u73b0\u5728\uff0c\u4e07\u4e8b\u4ff1\u5907\uff0c\u53ef\u4ee5\u5f00\u59cb\u8bad\u7ec3\u4e86\u3002\u6211\u4eec\u5c06\u4f7f\u7528 train.py \u8bad\u7ec3\u6a21\u578b\u3002 import os import pandas as pd import numpy as np import albumentations import torch from sklearn import metrics from sklearn.model_selection import train_test_split import dataset import engine from model import get_model if __name__ == \"__main__\" : # \u5b9a\u4e49\u6570\u636e\u8def\u5f84\u3001\u8bbe\u5907\u3001\u8fed\u4ee3\u6b21\u6570 data_path = \"/home/abhishek/workspace/siim_png/\" device = \"cuda\" # \u4f7f\u7528GPU\u52a0\u901f epochs = 10 # \u4eceCSV\u6587\u4ef6\u8bfb\u53d6\u6570\u636e df = pd . read_csv ( os . path . join ( data_path , \"train.csv\" )) images = df . ImageId . values . tolist () images = [ os . path . join ( data_path , \"train_png\" , i + \".png\" ) for i in images ] targets = df . target . values # \u83b7\u53d6\u9884\u8bad\u7ec3\u7684\u6a21\u578b model = get_model ( pretrained = True ) model . to ( device ) # \u5b9a\u4e49\u5747\u503c\u548c\u6807\u51c6\u5dee\uff0c\u7528\u4e8e\u6570\u636e\u6807\u51c6\u5316 mean = ( 0.485 , 0.456 , 0.406 ) std = ( 0.229 , 0.224 , 0.225 ) # \u6570\u636e\u589e\u5f3a\uff0c\u5c06\u56fe\u50cf\u6807\u51c6\u5316 aug = albumentations . 
Compose ( [ albumentations . Normalize ( mean , std , max_pixel_value = 255.0 , always_apply = True ) ] ) # \u5212\u5206\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 train_images , valid_images , train_targets , valid_targets = train_test_split ( images , targets , stratify = targets , random_state = 42 ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6\u548c\u9a8c\u8bc1\u6570\u636e\u96c6 train_dataset = dataset . ClassificationDataset ( image_paths = train_images , targets = train_targets , resize = ( 227 , 227 ), augmentations = aug , ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = 16 , shuffle = True , num_workers = 4 ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u96c6 valid_dataset = dataset . ClassificationDataset ( image_paths = valid_images , targets = valid_targets , resize = ( 227 , 227 ), augmentations = aug , ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = 16 , shuffle = False , num_workers = 4 ) # \u5b9a\u4e49\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 5e-4 ) # \u8bad\u7ec3\u5faa\u73af for epoch in range ( epochs ): # \u8bad\u7ec3\u6a21\u578b engine . train ( train_loader , model , optimizer , device = device ) # \u8bc4\u4f30\u6a21\u578b\u6027\u80fd predictions , valid_targets = engine . evaluate ( valid_loader , model , device = device ) # \u8ba1\u7b97ROC AUC\u5206\u6570\u5e76\u6253\u5370 roc_auc = metrics . 
roc_auc_score ( valid_targets , predictions ) print ( f \"Epoch= { epoch } , Valid ROC AUC= { roc_auc } \" ) \u8ba9\u6211\u4eec\u5728\u6ca1\u6709\u9884\u8bad\u7ec3\u6743\u91cd\u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u8bad\u7ec3\uff1a Epoch = 0 , Valid ROC AUC = 0.5737161981475328 Epoch = 1 , Valid ROC AUC = 0.5362868001588292 Epoch = 2 , Valid ROC AUC = 0.6163448214387008 Epoch = 3 , Valid ROC AUC = 0.6119219143780944 Epoch = 4 , Valid ROC AUC = 0.6229718888519726 Epoch = 5 , Valid ROC AUC = 0.5983014999635341 Epoch = 6 , Valid ROC AUC = 0.5523236874306134 Epoch = 7 , Valid ROC AUC = 0.4717721611306046 Epoch = 8 , Valid ROC AUC = 0.6473408263980617 Epoch = 9 , Valid ROC AUC = 0.6639862888260415 AUC \u7ea6\u4e3a 0.66\uff0c\u751a\u81f3\u4f4e\u4e8e\u6211\u4eec\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002\u4f7f\u7528\u9884\u8bad\u7ec3\u6743\u91cd\u4f1a\u53d1\u751f\u4ec0\u4e48\u60c5\u51b5\uff1f Epoch = 0 , Valid ROC AUC = 0.5730387429803165 Epoch = 1 , Valid ROC AUC = 0.5319813942934937 Epoch = 2 , Valid ROC AUC = 0.627111577514323 Epoch = 3 , Valid ROC AUC = 0.6819736959393209 Epoch = 4 , Valid ROC AUC = 0.5747117168950512 Epoch = 5 , Valid ROC AUC = 0.5994619255609669 Epoch = 6 , Valid ROC AUC = 0.5080889443530546 Epoch = 7 , Valid ROC AUC = 0.6323792776512727 Epoch = 8 , Valid ROC AUC = 0.6685753182661686 Epoch = 9 , Valid ROC AUC = 0.6861802387300147 \u73b0\u5728\u7684 AUC \u597d\u4e86\u5f88\u591a\u3002\u4e0d\u8fc7\uff0c\u5b83\u4ecd\u7136\u8f83\u4f4e\u3002\u9884\u8bad\u7ec3\u6a21\u578b\u7684\u597d\u5904\u662f\u53ef\u4ee5\u8f7b\u677e\u5c1d\u8bd5\u591a\u79cd\u4e0d\u540c\u7684\u6a21\u578b\u3002\u8ba9\u6211\u4eec\u8bd5\u8bd5\u4f7f\u7528\u9884\u8bad\u7ec3\u6743\u91cd\u7684 resnet18 \u3002 import torch.nn as nn import pretrainedmodels # \u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u4ee5\u83b7\u53d6\u6a21\u578b def get_model ( pretrained ): if pretrained : # \u4f7f\u7528\u9884\u8bad\u7ec3\u7684 ResNet-18 \u6a21\u578b\uff0c\u52a0\u8f7d\u5728 ImageNet 
\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u7684\u6743\u91cd model = pretrainedmodels . __dict__ [ \"resnet18\" ]( pretrained = 'imagenet' ) else : # \u4f7f\u7528\u672a\u7ecf\u9884\u8bad\u7ec3\u7684 ResNet-18 \u6a21\u578b model = pretrainedmodels . __dict__ [ \"resnet18\" ]( pretrained = None ) # \u4fee\u6539\u6a21\u578b\u7684\u6700\u540e\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u4ee5\u9002\u5e94\u7279\u5b9a\u4efb\u52a1 model . last_linear = nn . Sequential ( nn . BatchNorm1d ( 512 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.25 ), # \u968f\u673a\u5931\u6d3b\u5c42\uff0c\u9632\u6b62\u8fc7\u62df\u5408 nn . Linear ( in_features = 512 , out_features = 2048 ), # \u8fde\u63a5\u5c42 nn . ReLU (), # ReLU \u6fc0\u6d3b\u51fd\u6570 nn . BatchNorm1d ( 2048 , eps = 1e-05 , momentum = 0.1 ), # \u6279\u5f52\u4e00\u5316\u5c42 nn . Dropout ( p = 0.5 ), # \u968f\u673a\u5931\u6d3b\u5c42 nn . Linear ( in_features = 2048 , out_features = 1 ) # \u6700\u7ec8\u7684\u4e8c\u5143\u5206\u7c7b\u5c42 ) return model \u5728\u5c1d\u8bd5\u8be5\u6a21\u578b\u65f6\uff0c\u6211\u8fd8\u5c06\u56fe\u50cf\u5927\u5c0f\u6539\u4e3a 512x512\uff0c\u5e76\u6dfb\u52a0\u4e86\u4e00\u4e2a\u5b66\u4e60\u7387\u8c03\u5ea6\u5668\uff0c\u6bcf 3 \u4e2aepochs\u540e\u5c06\u5b66\u4e60\u7387\u4e58\u4ee5 0.5\u3002 Epoch = 0 , Valid ROC AUC = 0.5988225569880796 Epoch = 1 , Valid ROC AUC = 0.730349343208836 Epoch = 2 , Valid ROC AUC = 0.5870943169939142 Epoch = 3 , Valid ROC AUC = 0.5775864444138311 Epoch = 4 , Valid ROC AUC = 0.7330502499939224 Epoch = 5 , Valid ROC AUC = 0.7500336296524395 Epoch = 6 , Valid ROC AUC = 0.7563722113724951 Epoch = 7 , Valid ROC AUC = 0.7987463837994215 Epoch = 8 , Valid ROC AUC = 0.798505708937384 Epoch = 9 , Valid ROC AUC = 0.8025477500546988 \u8fd9\u4e2a\u6a21\u578b\u4f3c\u4e4e\u8868\u73b0\u6700\u597d\u3002\u4e0d\u8fc7\uff0c\u60a8\u53ef\u4ee5\u8c03\u6574 AlexNet \u4e2d\u7684\u4e0d\u540c\u53c2\u6570\u548c\u56fe\u50cf\u5927\u5c0f\uff0c\u4ee5\u83b7\u5f97\u66f4\u597d\u7684\u5206\u6570\u3002 
\u4f7f\u7528\u589e\u5f3a\u6280\u672f\u5c06\u8fdb\u4e00\u6b65\u63d0\u9ad8\u5f97\u5206\u3002\u4f18\u5316\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u5f88\u96be\uff0c\u4f46\u5e76\u975e\u4e0d\u53ef\u80fd\u3002\u9009\u62e9 Adam \u4f18\u5316\u5668\u3001\u4f7f\u7528\u4f4e\u5b66\u4e60\u7387\u3001\u5728\u9a8c\u8bc1\u635f\u5931\u8fbe\u5230\u9ad8\u70b9\u65f6\u964d\u4f4e\u5b66\u4e60\u7387\u3001\u5c1d\u8bd5\u4e00\u4e9b\u589e\u5f3a\u6280\u672f\u3001\u5c1d\u8bd5\u5bf9\u56fe\u50cf\u8fdb\u884c\u9884\u5904\u7406\uff08\u5982\u5728\u9700\u8981\u65f6\u8fdb\u884c\u88c1\u526a\uff0c\u8fd9\u4e5f\u53ef\u89c6\u4e3a\u9884\u5904\u7406\uff09\u3001\u6539\u53d8\u6279\u6b21\u5927\u5c0f\u7b49\u3002\u4f60\u53ef\u4ee5\u505a\u5f88\u591a\u4e8b\u60c5\u6765\u4f18\u5316\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\u3002 \u4e0e AlexNet \u76f8\u6bd4\uff0c ResNet \u7684\u7ed3\u6784\u8981\u590d\u6742\u5f97\u591a\u3002ResNet \u662f\u6b8b\u5dee\u795e\u7ecf\u7f51\u7edc\uff08Residual Neural Network\uff09\u7684\u7f29\u5199\uff0c\u7531 K. He\u3001X. Zhang\u3001S. Ren \u548c J. 
Sun \u5728 2015 \u5e74\u53d1\u8868\u7684\u8bba\u6587\u4e2d\u63d0\u51fa\u3002ResNet \u7531 \u6b8b\u5dee\u5757 \uff08residual blocks\uff09\u7ec4\u6210\uff0c\u901a\u8fc7\u8df3\u8fc7\u67d0\u4e9b\u5c42\uff0c\u4f7f\u77e5\u8bc6\u80fd\u591f\u4e0d\u65ad\u5728\u5404\u5c42\u4e2d\u8fdb\u884c\u4f20\u9012\u3002\u8fd9\u4e9b\u5c42\u4e4b\u95f4\u7684 \u8fde\u63a5\u88ab\u79f0\u4e3a \u8df3\u8dc3\u8fde\u63a5 \uff08skip-connections\uff09\uff0c\u56e0\u4e3a\u6211\u4eec\u8df3\u8fc7\u4e86\u4e00\u5c42\u6216\u591a\u5c42\u3002\u8df3\u8dc3\u8fde\u63a5\u901a\u8fc7\u5c06\u68af\u5ea6\u4f20\u64ad\u5230\u66f4\u591a\u5c42\u6765\u5e2e\u52a9\u89e3\u51b3\u68af\u5ea6\u6d88\u5931\u95ee\u9898\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u8bad\u7ec3\u975e\u5e38\u5927\u7684\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff0c\u800c\u4e0d\u4f1a\u635f\u5931\u6027\u80fd\u3002\u901a\u5e38\u60c5\u51b5\u4e0b\uff0c\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7684\u662f\u5927\u578b\u795e\u7ecf\u7f51\u7edc\uff0c\u90a3\u4e48\u5f53\u8bad\u7ec3\u5230\u67d0\u4e00\u8282\u70b9\u4e0a\u65f6\u8bad\u7ec3\u635f\u5931\u53cd\u800c\u4f1a\u589e\u52a0\uff0c\u4f46\u8fd9\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u8df3\u8dc3\u8fde\u63a5\u6765\u907f\u514d\u3002\u901a\u8fc7\u56fe 7 \u53ef\u4ee5\u66f4\u597d\u5730\u7406\u89e3\u8fd9\u4e00\u70b9\u3002 \u56fe 7\uff1a\u7b80\u5355\u8fde\u63a5\u4e0e\u6b8b\u5dee\u8fde\u63a5\u7684\u6bd4\u8f83\u3002\u53c2\u89c1\u8df3\u8dc3\u8fde\u63a5\u3002\u8bf7\u6ce8\u610f\uff0c\u672c\u56fe\u7701\u7565\u4e86\u6700\u540e\u4e00\u5c42\u3002 \u6b8b\u5dee\u5757\u975e\u5e38\u5bb9\u6613\u7406\u89e3\u3002\u4f60\u4ece\u67d0\u4e00\u5c42\u83b7\u53d6\u8f93\u51fa\uff0c\u8df3\u8fc7\u4e00\u4e9b\u5c42\uff0c\u7136\u540e\u5c06\u8f93\u51fa\u6dfb\u52a0\u5230\u7f51\u7edc\u4e2d\u66f4\u8fdc\u7684\u4e00\u5c42\u3002\u865a\u7ebf\u8868\u793a\u8f93\u5165\u5f62\u72b6\u9700\u8981\u8c03\u6574\uff0c\u56e0\u4e3a\u4f7f\u7528\u4e86\u6700\u5927\u6c60\u5316\uff0c\u800c\u6700\u5927\u6c60\u5316\u7684\u4f7f\u7528\u4f1a\u6539\u53d8\u8f93\u51fa\u7684\u5927\u5c0f\u3002 
ResNet \u6709\u591a\u79cd\u4e0d\u540c\u7684\u7248\u672c\uff1a \u6709 18 \u5c42\u300134 \u5c42\u300150 \u5c42\u3001101 \u5c42\u548c 152 \u5c42\uff0c\u6240\u6709\u8fd9\u4e9b\u5c42\u90fd\u5728 ImageNet \u6570\u636e\u96c6\u4e0a\u8fdb\u884c\u4e86\u6743\u91cd\u9884\u8bad\u7ec3\u3002\u5982\u4eca\uff0c\u9884\u8bad\u7ec3\u6a21\u578b\uff08\u51e0\u4e4e\uff09\u9002\u7528\u4e8e\u6240\u6709\u60c5\u51b5\uff0c\u4f46\u8bf7\u786e\u4fdd\u60a8\u4ece\u8f83\u5c0f\u7684\u6a21\u578b\u5f00\u59cb\uff0c\u4f8b\u5982\uff0c\u4ece resnet-18 \u5f00\u59cb\uff0c\u800c\u4e0d\u662f resnet-50\u3002\u5176\u4ed6\u4e00\u4e9b ImageNet \u9884\u8bad\u7ec3\u6a21\u578b\u5305\u62ec\uff1a Inception DenseNet(different variations) NASNet PNASNet VGG Xception ResNeXt EfficientNet, etc. \u5927\u90e8\u5206\u9884\u8bad\u7ec3\u7684\u6700\u5148\u8fdb\u6a21\u578b\u53ef\u4ee5\u5728 GitHub \u4e0a\u7684 pytorch- pretrainedmodels \u8d44\u6e90\u5e93\u4e2d\u627e\u5230\uff1ahttps://github.com/Cadene/pretrained-models.pytorch\u3002\u8be6\u7ec6\u8ba8\u8bba\u8fd9\u4e9b\u6a21\u578b\u4e0d\u5728\u672c\u7ae0\uff08\u548c\u672c\u4e66\uff09\u8303\u56f4\u4e4b\u5185\u3002\u65e2\u7136\u6211\u4eec\u53ea\u5173\u6ce8\u5e94\u7528\uff0c\u90a3\u5c31\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u6837\u7684\u9884\u8bad\u7ec3\u6a21\u578b\u5982\u4f55\u7528\u4e8e\u5206\u5272\u4efb\u52a1\u3002 \u56fe 8\uff1aU-Net\u67b6\u6784 \u5206\u5272\uff08Segmentation\uff09\u662f\u8ba1\u7b97\u673a\u89c6\u89c9\u4e2d\u76f8\u5f53\u6d41\u884c\u7684\u4e00\u9879\u4efb\u52a1\u3002\u5728\u5206\u5272\u4efb\u52a1\u4e2d\uff0c\u6211\u4eec\u8bd5\u56fe\u4ece\u80cc\u666f\u4e2d\u79fb\u9664/\u63d0\u53d6\u524d\u666f\u3002 
\u524d\u666f\u548c\u80cc\u666f\u53ef\u4ee5\u6709\u4e0d\u540c\u7684\u5b9a\u4e49\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u8bf4\uff0c\u8fd9\u662f\u4e00\u9879\u50cf\u7d20\u5206\u7c7b\u4efb\u52a1\uff0c\u4f60\u7684\u5de5\u4f5c\u662f\u7ed9\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u7684\u6bcf\u4e2a\u50cf\u7d20\u5206\u914d\u4e00\u4e2a\u7c7b\u522b\u3002\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u6b63\u5728\u5904\u7406\u7684\u6c14\u80f8\u6570\u636e\u96c6\u5c31\u662f\u4e00\u9879\u5206\u5272\u4efb\u52a1\u3002\u5728\u8fd9\u9879\u4efb\u52a1\u4e2d\uff0c\u6211\u4eec\u9700\u8981\u5bf9\u7ed9\u5b9a\u7684\u80f8\u90e8\u653e\u5c04\u56fe\u50cf\u8fdb\u884c\u6c14\u80f8\u5206\u5272\u3002\u7528\u4e8e\u5206\u5272\u4efb\u52a1\u7684\u6700\u5e38\u7528\u6a21\u578b\u662f U-Net\u3002\u5176\u7ed3\u6784\u5982\u56fe 8 \u6240\u793a\u3002 U-Net \u5305\u62ec\u4e24\u4e2a\u90e8\u5206\uff1a\u7f16\u7801\u5668\u548c\u89e3\u7801\u5668\u3002\u7f16\u7801\u5668\u4e0e\u60a8\u76ee\u524d\u6240\u89c1\u8fc7\u7684\u4efb\u4f55 U-Net \u90fd\u662f\u4e00\u6837\u7684\u3002\u89e3\u7801\u5668\u5219\u6709\u4e9b\u4e0d\u540c\u3002\u89e3\u7801\u5668\u7531\u4e0a\u5377\u79ef\u5c42\u7ec4\u6210\u3002\u5728\u4e0a\u5377\u79ef\uff08up-convolutions\uff09\uff08 \u8f6c\u7f6e\u5377\u79ef transposed convolutions\uff09\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u6ee4\u6ce2\u5668\uff0c\u5f53\u5e94\u7528\u5230\u4e00\u4e2a\u5c0f\u56fe\u50cf\u65f6\uff0c\u4f1a\u4ea7\u751f\u4e00\u4e2a\u5927\u56fe\u50cf\u3002\u5728 PyTorch \u4e2d\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528 ConvTranspose2d 
\u6765\u5b8c\u6210\u8fd9\u4e00\u64cd\u4f5c\u3002\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u4e0a\u5377\u79ef\u4e0e\u4e0a\u91c7\u6837\u5e76\u4e0d\u76f8\u540c\u3002\u4e0a\u91c7\u6837\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u8fc7\u7a0b\uff0c\u6211\u4eec\u5728\u56fe\u50cf\u4e0a\u5e94\u7528\u4e00\u4e2a\u51fd\u6570\u6765\u8c03\u6574\u5b83\u7684\u5927\u5c0f\u3002\u5728\u4e0a\u5377\u79ef\u4e2d\uff0c\u6211\u4eec\u8981\u5b66\u4e60\u6ee4\u6ce2\u5668\u3002\u6211\u4eec\u5c06\u7f16\u7801\u5668\u7684\u67d0\u4e9b\u90e8\u5206\u4f5c\u4e3a\u67d0\u4e9b\u89e3\u7801\u5668\u7684\u8f93\u5165\u3002\u8fd9\u5bf9 \u4e0a\u5377\u79ef\u5c42\u975e\u5e38\u91cd\u8981\u3002 \u8ba9\u6211\u4eec\u770b\u770b U-Net \u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import torch import torch.nn as nn from torch.nn import functional as F # \u5b9a\u4e49\u4e00\u4e2a\u53cc\u5377\u79ef\u5c42 def double_conv ( in_channels , out_channels ): conv = nn . Sequential ( nn . Conv2d ( in_channels , out_channels , kernel_size = 3 ), nn . ReLU ( inplace = True ), nn . Conv2d ( out_channels , out_channels , kernel_size = 3 ), nn . ReLU ( inplace = True ) ) return conv # \u5b9a\u4e49\u51fd\u6570\u7528\u4e8e\u88c1\u526a\u8f93\u5165\u5f20\u91cf def crop_tensor ( tensor , target_tensor ): target_size = target_tensor . size ()[ 2 ] tensor_size = tensor . size ()[ 2 ] delta = tensor_size - target_size delta = delta // 2 return tensor [:, :, delta : tensor_size - delta , delta : tensor_size - delta ] # \u5b9a\u4e49 U-Net \u6a21\u578b class UNet ( nn . Module ): def __init__ ( self ): super ( UNet , self ) . __init () # \u5b9a\u4e49\u6c60\u5316\u5c42\uff0c\u7f16\u7801\u5668\u548c\u89e3\u7801\u5668\u7684\u53cc\u5377\u79ef\u5c42 self . max_pool_2x2 = nn . MaxPool2d ( kernel_size = 2 , stride = 2 ) self . down_conv_1 = double_conv ( 1 , 64 ) self . down_conv_2 = double_conv ( 64 , 128 ) self . down_conv_3 = double_conv ( 128 , 256 ) self . down_conv_4 = double_conv ( 256 , 512 ) self . 
down_conv_5 = double_conv ( 512 , 1024 ) # \u5b9a\u4e49\u4e0a\u91c7\u6837\u5c42\u548c\u89e3\u7801\u5668\u7684\u53cc\u5377\u79ef\u5c42 self . up_trans_1 = nn . ConvTranspose2d ( in_channels = 1024 , out_channels = 512 , kernel_size = 2 , stride = 2 ) self . up_conv_1 = double_conv ( 1024 , 512 ) self . up_trans_2 = nn . ConvTranspose2d ( in_channels = 512 , out_channels = 256 , kernel_size = 2 , stride = 2 ) self . up_conv_2 = double_conv ( 512 , 256 ) self . up_trans_3 = nn . ConvTranspose2d ( in_channels = 256 , out_channels = 128 , kernel_size = 2 , stride = 2 ) self . up_conv_3 = double_conv ( 256 , 128 ) self . up_trans_4 = nn . ConvTranspose2d ( in_channels = 128 , out_channels = 64 , kernel_size = 2 , stride = 2 ) self . up_conv_4 = double_conv ( 128 , 64 ) # \u5b9a\u4e49\u8f93\u51fa\u5c42 self . out = nn . Conv2d ( in_channels = 64 , out_channels = 2 , kernel_size = 1 ) def forward ( self , image ): # \u7f16\u7801\u5668\u90e8\u5206 x1 = self . down_conv_1 ( image ) x2 = self . max_pool_2x2 ( x1 ) x3 = self . down_conv_2 ( x2 ) x4 = self . max_pool_2x2 ( x3 ) x5 = self . down_conv_3 ( x4 ) x6 = self . max_pool_2x2 ( x5 ) x7 = self . down_conv_4 ( x6 ) x8 = self . max_pool_2x2 ( x7 ) x9 = self . down_conv_5 ( x8 ) # \u89e3\u7801\u5668\u90e8\u5206 x = self . up_trans_1 ( x9 ) y = crop_tensor ( x7 , x ) x = self . up_conv_1 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_2 ( x ) y = crop_tensor ( x5 , x ) x = self . up_conv_2 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_3 ( x ) y = crop_tensor ( x3 , x ) x = self . up_conv_3 ( torch . cat ([ x , y ], axis = 1 )) x = self . up_trans_4 ( x ) y = crop_tensor ( x1 , x ) x = self . up_conv_4 ( torch . cat ([ x , y ], axis = 1 )) # \u8f93\u51fa\u5c42 out = self . out ( x ) return out if __name__ == \"__main__\" : image = torch . 
rand (( 1 , 1 , 572 , 572 )) model = UNet () print ( model ( image )) \u8bf7\u6ce8\u610f\uff0c\u6211\u4e0a\u9762\u5c55\u793a\u7684 U-Net \u5b9e\u73b0\u662f U-Net \u8bba\u6587\u7684\u539f\u59cb\u5b9e\u73b0\u3002\u4e92\u8054\u7f51\u4e0a\u6709\u5f88\u591a\u4e0d\u540c\u7684\u5b9e\u73b0\u65b9\u6cd5\u3002 \u6709\u4e9b\u4eba\u559c\u6b22\u4f7f\u7528\u53cc\u7ebf\u6027\u91c7\u6837\u4ee3\u66ff\u8f6c\u7f6e\u5377\u79ef\u8fdb\u884c\u4e0a\u91c7\u6837\uff0c\u4f46\u8fd9\u5e76\u4e0d\u662f\u8bba\u6587\u7684\u771f\u6b63\u5b9e\u73b0\u3002\u4e0d\u8fc7\uff0c\u5b83\u7684\u6027\u80fd\u53ef\u80fd\u4f1a\u66f4\u597d\u3002\u5728\u4e0a\u56fe\u6240\u793a\u7684\u539f\u59cb\u5b9e\u73b0\u4e2d\uff0c\u6709\u4e00\u4e2a\u5355\u901a\u9053\u56fe\u50cf\uff0c\u8f93\u51fa\u4e2d\u6709\u4e24\u4e2a\u901a\u9053\uff1a\u4e00\u4e2a\u662f\u524d\u666f\uff0c\u4e00\u4e2a\u662f\u80cc\u666f\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u4e3a\u4efb\u610f\u6570\u91cf\u7684\u7c7b\u548c\u4efb\u610f\u6570\u91cf\u7684\u8f93\u5165\u901a\u9053\u8fdb\u884c\u5b9a\u5236\u3002\u5728\u6b64\u5b9e\u73b0\u4e2d\uff0c\u8f93\u5165\u56fe\u50cf\u7684\u5927\u5c0f\u4e0e\u8f93\u51fa\u56fe\u50cf\u7684\u5927\u5c0f\u4e0d\u540c\uff0c\u56e0\u4e3a\u6211\u4eec\u4f7f\u7528\u7684\u662f\u65e0\u586b\u5145\u5377\u79ef\uff08convolutions without padding\uff09\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0cU-Net \u7684\u7f16\u7801\u5668\u90e8\u5206\u53ea\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u5377\u79ef\u7f51\u7edc\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u7528\u4efb\u4f55\u7f51\u7edc\uff08\u5982 ResNet\uff09\u6765\u66ff\u6362\u5b83\u3002 \u8fd9\u79cd\u66ff\u6362\u4e5f\u53ef\u4ee5\u901a\u8fc7\u9884\u8bad\u7ec3\u6743\u91cd\u6765\u5b8c\u6210\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u57fa\u4e8e ResNet \u7684\u7f16\u7801\u5668\uff0c\u8be5\u7f16\u7801\u5668\u5df2\u5728 ImageNet 
\u548c\u901a\u7528\u89e3\u7801\u5668\u4e0a\u8fdb\u884c\u4e86\u9884\u8bad\u7ec3\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u591a\u79cd\u4e0d\u540c\u7684\u7f51\u7edc\u67b6\u6784\u6765\u4ee3\u66ff ResNet\u3002Pavel Yakubovskiy \u6240\u8457\u7684\u300aSegmentation Models Pytorch\u300b\u5c31\u662f\u8bb8\u591a\u6b64\u7c7b\u53d8\u4f53\u7684\u5b9e\u73b0\uff0c\u5176\u4e2d\u7f16\u7801\u5668\u53ef\u4ee5\u88ab\u9884\u8bad\u7ec3\u6a21\u578b\u6240\u53d6\u4ee3\u3002\u8ba9\u6211\u4eec\u5e94\u7528\u57fa\u4e8e ResNet \u7684 U-Net \u6765\u89e3\u51b3\u6c14\u80f8\u68c0\u6d4b\u95ee\u9898\u3002 \u5927\u591a\u6570\u7c7b\u4f3c\u7684\u95ee\u9898\u90fd\u6709\u4e24\u4e2a\u8f93\u5165\uff1a\u539f\u59cb\u56fe\u50cf\u548c\u63a9\u7801\uff08mask\uff09\u3002 \u5982\u679c\u6709\u591a\u4e2a\u5bf9\u8c61\uff0c\u5c31\u4f1a\u6709\u591a\u4e2a\u63a9\u7801\u3002 \u5728\u6211\u4eec\u7684\u6c14\u80f8\u6570\u636e\u96c6\u4e2d\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f RLE\u3002RLE \u4ee3\u8868\u8fd0\u884c\u957f\u5ea6\u7f16\u7801\uff0c\u662f\u4e00\u79cd\u8868\u793a\u4e8c\u8fdb\u5236\u63a9\u7801\u4ee5\u8282\u7701\u7a7a\u95f4\u7684\u65b9\u6cd5\u3002\u6df1\u5165\u7814\u7a76 RLE \u8d85\u51fa\u4e86\u672c\u7ae0\u7684\u8303\u56f4\u3002\u56e0\u6b64\uff0c\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u5f20\u8f93\u5165\u56fe\u50cf\u548c\u76f8\u5e94\u7684\u63a9\u7801\u3002\u8ba9\u6211\u4eec\u5148\u8bbe\u8ba1\u4e00\u4e2a\u6570\u636e\u96c6\u7c7b\uff0c\u7528\u4e8e\u8f93\u51fa\u56fe\u50cf\u548c\u63a9\u7801\u56fe\u50cf\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u521b\u5efa\u7684\u811a\u672c\u51e0\u4e4e\u53ef\u4ee5\u5e94\u7528\u4e8e\u4efb\u4f55\u5206\u5272\u95ee\u9898\u3002\u8bad\u7ec3\u6570\u636e\u96c6\u662f\u4e00\u4e2a CSV \u6587\u4ef6\uff0c\u53ea\u5305\u542b\u56fe\u50cf ID\uff08\u4e5f\u662f\u6587\u4ef6\u540d\uff09\u3002 import os import glob import torch import numpy as np import pandas as pd from PIL import Image , ImageFile from tqdm import tqdm from collections import defaultdict from torchvision import transforms from albumentations 
import ( Compose , OneOf , RandomBrightnessContrast , RandomGamma , ShiftScaleRotate , ) # \u8bbe\u7f6ePIL\u56fe\u50cf\u52a0\u8f7d\u622a\u65ad\u7684\u5904\u7406 ImageFile . LOAD_TRUNCATED_IMAGES = True # \u521b\u5efaSIIM\u6570\u636e\u96c6\u7c7b class SIIMDataset ( torch . utils . data . Dataset ): def __init__ ( self , image_ids , transform = True , preprocessing_fn = None ): self . data = defaultdict ( dict ) self . transform = transform self . preprocessing_fn = preprocessing_fn # \u5b9a\u4e49\u6570\u636e\u589e\u5f3a self . aug = Compose ( [ ShiftScaleRotate ( shift_limit = 0.0625 , scale_limit = 0.1 , rotate_limit = 10 , p = 0.8 ), OneOf ( [ RandomGamma ( gamma_limit = ( 90 , 110 ) ), RandomBrightnessContrast ( brightness_limit = 0.1 , contrast_limit = 0.1 ), ], p = 0.5 , ), ] ) # \u6784\u5efa\u6570\u636e\u5b57\u5178\uff0c\u5176\u4e2d\u5305\u542b\u56fe\u50cf\u548c\u63a9\u7801\u7684\u8def\u5f84\u4fe1\u606f for imgid in image_ids : files = glob . glob ( os . path . join ( TRAIN_PATH , imgid , \"*.png\" )) self . data [ counter ] = { \"img_path\" : os . path . join ( TRAIN_PATH , imgid + \".png\" ), \"mask_path\" : os . path . join ( TRAIN_PATH , imgid + \"_mask.png\" ), } def __len__ ( self ): return len ( self . data ) def __getitem__ ( self , item ): img_path = self . data [ item ][ \"img_path\" ] mask_path = self . data [ item ][ \"mask_path\" ] # \u6253\u5f00\u56fe\u50cf\u5e76\u5c06\u5176\u8f6c\u6362\u4e3aRGB\u6a21\u5f0f img = Image . open ( img_path ) img = img . convert ( \"RGB\" ) img = np . array ( img ) # \u6253\u5f00\u63a9\u7801\u56fe\u50cf\uff0c\u5e76\u5c06\u5176\u8f6c\u6362\u4e3a\u6d6e\u70b9\u6570 mask = Image . open ( mask_path ) mask = ( mask >= 1 ) . astype ( \"float32\" ) # \u5982\u679c\u9700\u8981\u8fdb\u884c\u6570\u636e\u589e\u5f3a if self . transform is True : augmented = self . 
aug ( image = img , mask = mask ) img = augmented [ \"image\" ] mask = augmented [ \"mask\" ] # \u5e94\u7528\u9884\u5904\u7406\u51fd\u6570\uff08\u5982\u679c\u6709\uff09 img = self . preprocessing_fn ( img ) # \u8fd4\u56de\u56fe\u50cf\u548c\u63a9\u7801 return { \"image\" : transforms . ToTensor ()( img ), \"mask\" : transforms . ToTensor ()( mask ) . float (), } \u6709\u4e86\u6570\u636e\u96c6\u7c7b\u4e4b\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u8bad\u7ec3\u51fd\u6570\u3002 import os import sys import torch import numpy as np import pandas as pd import segmentation_models_pytorch as smp import torch.nn as nn import torch.optim as optim from apex import amp from collections import OrderedDict from sklearn import model_selection from tqdm import tqdm from torch.optim import lr_scheduler from dataset import SIIMDataset # \u5b9a\u4e49\u8bad\u7ec3\u6570\u636e\u96c6CSV\u6587\u4ef6\u8def\u5f84 TRAINING_CSV = \"../input/train_pneumothorax.csv\" # \u5b9a\u4e49\u8bad\u7ec3\u548c\u6d4b\u8bd5\u7684\u6279\u91cf\u5927\u5c0f TRAINING_BATCH_SIZE = 16 TEST_BATCH_SIZE = 4 # \u5b9a\u4e49\u8bad\u7ec3\u7684\u65f6\u671f\u6570 EPOCHS = 10 # \u6307\u5b9a\u4f7f\u7528\u7684\u7f16\u7801\u5668\u548c\u6743\u91cd ENCODER = \"resnet18\" ENCODER_WEIGHTS = \"imagenet\" # \u6307\u5b9a\u8bbe\u5907\uff08GPU\uff09 DEVICE = \"cuda\" # \u5b9a\u4e49\u8bad\u7ec3\u51fd\u6570 def train ( dataset , data_loader , model , criterion , optimizer ): model . train () num_batches = int ( len ( dataset ) / data_loader . batch_size ) tk0 = tqdm ( data_loader , total = num_batches ) for d in tk0 : inputs = d [ \"image\" ] targets = d [ \"mask\" ] inputs = inputs . to ( DEVICE , dtype = torch . float ) targets = targets . to ( DEVICE , dtype = torch . float ) optimizer . zero_grad () outputs = model ( inputs ) loss = criterion ( outputs , targets ) with amp . scale_loss ( loss , optimizer ) as scaled_loss : scaled_loss . backward () optimizer . step () tk0 . 
close () # \u5b9a\u4e49\u8bc4\u4f30\u51fd\u6570 def evaluate ( dataset , data_loader , model ): model . eval () final_loss = 0 num_batches = int ( len ( dataset ) / data_loader . batch_size ) tk0 = tqdm ( data_loader , total = num_batches ) with torch . no_grad (): for d in tk0 : inputs = d [ \"image\" ] targets = d [ \"mask\" ] inputs = inputs to ( DEVICE , dtype = torch . float ) targets = targets . to ( DEVICE , dtype = torch . float ) output = model ( inputs ) loss = criterion ( output , targets ) final_loss += loss tk0 . close () return final_loss / num_batches if __name__ == \"__main__\" : df = pd . read_csv ( TRAINING_CSV ) df_train , df_valid = model_selection . train_test_split ( df , random_state = 42 , test_size = 0.1 ) training_images = df_train . image_id . values validation_images = df_valid . image_id . values # \u521b\u5efa U-Net \u6a21\u578b model = smp . Unet ( encoder_name = ENCODER , encoder_weights = ENCODER_WEIGHTS , classes = 1 , activation = None , ) # \u83b7\u53d6\u6570\u636e\u9884\u5904\u7406\u51fd\u6570 prep_fn = smp . encoders . get_preprocessing_fn ( ENCODER , ENCODER_WEIGHTS ) # \u5c06\u6a21\u578b\u653e\u5728\u8bbe\u5907\u4e0a model . to ( DEVICE ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u96c6 train_dataset = SIIMDataset ( training_images , transform = True , preprocessing_fn = prep_fn , ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = TRAINING_BATCH_SIZE , shuffle = True , num_workers = 12 ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u96c6 valid_dataset = SIIMDataset ( validation_images , transform = False , preprocessing_fn = prep_fn , ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = TEST_BATCH_SIZE , shuffle = True , num_workers = 4 ) # \u5b9a\u4e49\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . 
parameters (), lr = 1e-3 ) # \u5b9a\u4e49\u5b66\u4e60\u7387\u8c03\u5ea6\u5668 scheduler = lr_scheduler . ReduceLROnPlateau ( optimizer , mode = \"min\" , patience = 3 , verbose = True ) # \u521d\u59cb\u5316 Apex \u6df7\u5408\u7cbe\u5ea6\u8bad\u7ec3 model , optimizer = amp . initialize ( model , optimizer , opt_level = \"O1\" , verbosity = 0 ) # \u5982\u679c\u6709\u591a\u4e2aGPU\uff0c\u5219\u4f7f\u7528 DataParallel \u8fdb\u884c\u5e76\u884c\u8bad\u7ec3 if torch . cuda . device_count () > 1 : print ( f \"Let's use { torch . cuda . device_count () } GPUs!\" ) model = nn . DataParallel ( model ) # \u8f93\u51fa\u8bad\u7ec3\u76f8\u5173\u7684\u4fe1\u606f print ( f \"Training batch size: { TRAINING_BATCH_SIZE } \" ) print ( f \"Test batch size: { TEST_BATCH_SIZE } \" ) print ( f \"Epochs: { EPOCHS } \" ) print ( f \"Image size: { IMAGE_SIZE } \" ) print ( f \"Number of training images: { len ( train_dataset ) } \" ) print ( f \"Number of validation images: { len ( valid_dataset ) } \" ) print ( f \"Encoder: { ENCODER } \" ) # \u5faa\u73af\u8bad\u7ec3\u591a\u4e2a\u65f6\u671f for epoch in range ( EPOCHS ): print ( f \"Training Epoch: { epoch } \" ) train ( train_dataset , train_loader , model , criterion , optimizer ) print ( f \"Validation Epoch: { epoch } \" ) val_log = evaluate ( valid_dataset , valid_loader , model ) scheduler . 
step ( val_log [ \"loss\" ]) print ( \" \\n \" ) \u5728\u5206\u5272\u95ee\u9898\u4e2d\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528\u5404\u79cd\u635f\u5931\u51fd\u6570\uff0c\u4f8b\u5982\u4e8c\u5143\u4ea4\u53c9\u71b5\u3001focal\u635f\u5931\u3001dice\u635f\u5931\u7b49\u3002\u6211\u628a\u8fd9\u4e2a\u95ee\u9898\u7559\u7ed9 \u8bfb\u8005\u6839\u636e\u8bc4\u4f30\u6307\u6807\u6765\u51b3\u5b9a\u5408\u9002\u7684\u635f\u5931\u3002\u5f53\u8bad\u7ec3\u8fd9\u6837\u4e00\u4e2a\u6a21\u578b\u65f6\uff0c\u60a8\u5c06\u5efa\u7acb\u9884\u6d4b\u6c14\u80f8\u4f4d\u7f6e\u7684\u6a21\u578b\uff0c\u5982\u56fe 9 \u6240\u793a\u3002\u5728\u4e0a\u8ff0\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u82f1\u4f1f\u8fbe apex \u8fdb\u884c\u4e86\u6df7\u5408\u7cbe\u5ea6\u8bad\u7ec3\u3002\u8bf7\u6ce8\u610f\uff0c\u4ece PyTorch 1.6.0+ \u7248\u672c\u5f00\u59cb\uff0cPyTorch \u672c\u8eab\u5c31\u63d0\u4f9b\u4e86\u8fd9\u4e00\u529f\u80fd\u3002 \u56fe 9\uff1a\u4ece\u8bad\u7ec3\u6709\u7d20\u7684\u6a21\u578b\u4e2d\u68c0\u6d4b\u5230\u6c14\u80f8\u7684\u793a\u4f8b\uff08\u53ef\u80fd\u4e0d\u662f\u6b63\u786e\u9884\u6d4b\uff09\u3002 \u6211\u5728\u4e00\u4e2a\u540d\u4e3a \"Well That's Fantastic Machine Learning (WTFML) \"\u7684 python \u8f6f\u4ef6\u5305\u4e2d\u6536\u5f55\u4e86\u4e00\u4e9b\u5e38\u7528\u51fd\u6570\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b83\u5982\u4f55\u5e2e\u52a9\u6211\u4eec\u4e3a FGVC 202013 \u690d\u7269\u75c5\u7406\u5b66\u6311\u6218\u8d5b\u4e2d\u7684\u690d\u7269\u56fe\u50cf\u5efa\u7acb\u591a\u7c7b\u5206\u7c7b\u6a21\u578b\u3002 import os import pandas as pd import numpy as np import albumentations import argparse import torch import torchvision import torch.nn as nn import torch.nn.functional as F from sklearn import metrics from sklearn.model_selection import train_test_split from wtfml.engine import Engine from wtfml.data_loaders.image import ClassificationDataLoader # \u81ea\u5b9a\u4e49\u635f\u5931\u51fd\u6570\uff0c\u5b9e\u73b0\u5bc6\u96c6\u4ea4\u53c9\u71b5 class DenseCrossEntropy ( nn . 
Module ): def __init__ ( self ): super ( DenseCrossEntropy , self ) . __init__ () def forward ( self , logits , labels ): logits = logits . float () labels = labels . float () logprobs = F . log_softmax ( logits , dim =- 1 ) loss = - labels * logprobs loss = loss . sum ( - 1 ) return loss . mean () # \u81ea\u5b9a\u4e49\u795e\u7ecf\u7f51\u7edc\u6a21\u578b class Model ( nn . Module ): def __init__ ( self ): super () . __init () self . base_model = torchvision . models . resnet18 ( pretrained = True ) in_features = self . base_model . fc . in_features self . out = nn . Linear ( in_features , 4 ) def forward ( self , image , targets = None ): batch_size , C , H , W = image . shape x = self . base_model . conv1 ( image ) x = self . base_model . bn1 ( x ) x = self . base_model . relu ( x ) x = self . base_model . maxpool ( x ) x = self . base_model . layer1 ( x ) x = self . base_model . layer2 ( x ) x = self . base_model . layer3 ( x ) x = self . base_model . layer4 ( x ) x = F . adaptive_avg_pool2d ( x , 1 ) . reshape ( batch_size , - 1 ) x = self . out ( x ) loss = None if targets is not None : loss = DenseCrossEntropy ()( x , targets . type_as ( x )) return x , loss if __name__ == \"__main__\" : # \u547d\u4ee4\u884c\u53c2\u6570\u89e3\u6790\u5668 parser = argparse . ArgumentParser () parser . add_argument ( \"--data_path\" , type = str , ) parser . add_argument ( \"--device\" , type = str ,) parser . add_argument ( \"--epochs\" , type = int ,) args = parser . parse_args () # \u4eceCSV\u6587\u4ef6\u52a0\u8f7d\u6570\u636e df = pd . read_csv ( os . path . join ( args . data_path , \"train.csv\" )) images = df . image_id . values . tolist () images = [ os . path . join ( args . data_path , \"images\" , i + \".jpg\" ) for i in images ] targets = df [[ \"healthy\" , \"multiple_diseases\" , \"rust\" , \"scab\" ]] . values # \u521b\u5efa\u795e\u7ecf\u7f51\u7edc\u6a21\u578b model = Model () model . to ( args . 
device ) # \u5b9a\u4e49\u5747\u503c\u548c\u6807\u51c6\u5dee\u4ee5\u53ca\u6570\u636e\u589e\u5f3a mean = ( 0.485 , 0.456 , 0.406 ) std = ( 0.229 , 0.224 , 0.225 ) aug = albumentations . Compose ( [ albumentations . Normalize ( mean , std , max_pixel_value = 255.0 , always_apply = True ) ] ) # \u5206\u5272\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6 ( train_images , valid_images , train_targets , valid_targets ) = train_test_split ( images , targets ) # \u521b\u5efa\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668 train_loader = ClassificationDataLoader ( image_paths = train_images , targets = train_targets , resize = ( 128 , 128 ), augmentations = aug , ) . fetch ( batch_size = 16 , num_workers = 4 , drop_last = False , shuffle = True , tpu = False ) # \u521b\u5efa\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668 valid_loader = ClassificationDataLoader ( image_paths = valid_images , targets = valid_targets , resize = ( 128 , 128 ), augmentations = aug , ) . fetch ( batch_size = 16 , num_workers = 4 , drop_last = False , shuffle = False , tpu = False ) # \u521b\u5efa\u4f18\u5316\u5668 optimizer = torch . optim . Adam ( model . parameters (), lr = 5e-4 ) # \u521b\u5efa\u5b66\u4e60\u7387\u8c03\u5ea6\u5668 scheduler = torch . optim . lr_scheduler . StepLR ( optimizer , step_size = 15 , gamma = 0.6 ) # \u5faa\u73af\u8bad\u7ec3\u591a\u4e2a\u65f6\u671f for epoch in range ( args . epochs ): # \u8bad\u7ec3\u6a21\u578b train_loss = Engine . train ( train_loader , model , optimizer , device = args . device ) # \u8bc4\u4f30\u6a21\u578b valid_loss = Engine . evaluate ( valid_loader , model , device = args . 
device ) # \u6253\u5370\u635f\u5931\u4fe1\u606f print ( f \" { epoch } , Train Loss= { train_loss } Valid Loss= { valid_loss } \" ) \u6709\u4e86\u6570\u636e\u540e\uff0c\u5c31\u53ef\u4ee5\u8fd0\u884c\u811a\u672c\u4e86\uff1a python plant.py --data_path ../../plant_pathology --device cuda -- epochs 2 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 86 /86 [ 00 :12< 00 :00, 6 .73it/s, loss = 0 .723 ] 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 29 /29 [ 00 :04< 00 :00, 6 .62it/s, loss = 0 .433 ] 0 , Train Loss = 0 .7228777609592261 Valid Loss = 0 .4327834551704341 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 86 /86 [ 00 :12< 00 :00, 6 .74it/s, loss = 0 .271 ] 100 % | \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 29 /29 [ 00 :04< 00 :00, 6 .63it/s, loss = 0 .568 ] 1 , Train Loss = 0 .2708700496790021 Valid Loss = 0 .56841839541649 \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u8ba9\u6211\u4eec\u6784\u5efa\u6a21\u578b\u53d8\u5f97\u7b80\u5355\uff0c\u4ee3\u7801\u4e5f\u6613\u4e8e\u9605\u8bfb\u548c\u7406\u89e3\u3002\u6ca1\u6709\u4efb\u4f55\u5c01\u88c5\u7684 PyTorch \u6548\u679c\u6700\u597d\u3002\u56fe\u50cf\u4e2d\u4e0d\u4ec5\u4ec5\u6709\u5206\u7c7b\uff0c\u8fd8\u6709\u5f88\u591a\u5176\u4ed6\u7684\u5185\u5bb9\uff0c\u5982\u679c\u6211\u5f00\u59cb\u5199\u6240\u6709\u7684\u5185\u5bb9\uff0c\u5c31\u5f97\u518d\u5199\u4e00\u672c\u4e66\u4e86\uff0c \u63a5\u8fd1\uff08\u51e0\u4e4e\uff09\u4efb\u4f55\u56fe\u50cf\u95ee\u9898\uff08\u4f5c\u8005\u5728\u5f00\u73a9\u7b11\uff09\u3002","title":"\u56fe\u50cf\u5206\u7c7b\u548c\u5206\u5272\u65b9\u6cd5"},{"location":"%E5%A4%84%E7%90%86%E5%88%86%E7%B1%BB%E5%8F%98%E9%87%8F/","text":"\u5904\u7406\u5206\u7c7b\u53d8\u91cf 
\u5f88\u591a\u4eba\u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u65f6\u90fd\u4f1a\u9047\u5230\u5f88\u591a\u56f0\u96be\uff0c\u56e0\u6b64\u8fd9\u503c\u5f97\u7528\u6574\u6574\u4e00\u7ae0\u7684\u7bc7\u5e45\u6765\u8ba8\u8bba\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u5c06\u8bb2\u8ff0\u4e0d\u540c\u7c7b\u578b\u7684\u5206\u7c7b\u6570\u636e\uff0c\u4ee5\u53ca\u5982\u4f55\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u3002 \u4ec0\u4e48\u662f\u5206\u7c7b\u53d8\u91cf\uff1f \u5206\u7c7b\u53d8\u91cf/\u7279\u5f81\u662f\u6307\u4efb\u4f55\u7279\u5f81\u7c7b\u578b\uff0c\u53ef\u5206\u4e3a\u4e24\u5927\u7c7b\uff1a - \u65e0\u5e8f - \u6709\u5e8f \u65e0\u5e8f\u53d8\u91cf \u662f\u6307\u6709\u4e24\u4e2a\u6216\u4e24\u4e2a\u4ee5\u4e0a\u7c7b\u522b\u7684\u53d8\u91cf\uff0c\u8fd9\u4e9b\u7c7b\u522b\u6ca1\u6709\u4efb\u4f55\u76f8\u5173\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u5982\u679c\u5c06\u6027\u522b\u5206\u4e3a\u4e24\u7ec4\uff0c\u5373\u7537\u6027\u548c\u5973\u6027\uff0c\u5219\u53ef\u5c06\u5176\u89c6\u4e3a\u540d\u4e49\u53d8\u91cf\u3002 \u6709\u5e8f\u53d8\u91cf \u5219\u6709 \"\u7b49\u7ea7 \"\u6216\u7c7b\u522b\uff0c\u5e76\u6709\u7279\u5b9a\u7684\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u4e00\u4e2a\u987a\u5e8f\u5206\u7c7b\u53d8\u91cf\u53ef\u4ee5\u662f\u4e00\u4e2a\u5177\u6709\u4f4e\u3001\u4e2d\u3001\u9ad8\u4e09\u4e2a\u4e0d\u540c\u7b49\u7ea7\u7684\u7279\u5f81\u3002\u987a\u5e8f\u5f88\u91cd\u8981\u3002 \u5c31\u5b9a\u4e49\u800c\u8a00\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06\u5206\u7c7b\u53d8\u91cf\u5206\u4e3a \u4e8c\u5143\u53d8\u91cf \uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u7c7b\u522b\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u6709\u4e9b\u4eba\u751a\u81f3\u628a\u5206\u7c7b\u53d8\u91cf\u79f0\u4e3a \" \u5faa\u73af \"\u53d8\u91cf\u3002\u5468\u671f\u53d8\u91cf\u4ee5 \"\u5468\u671f \"\u7684\u5f62\u5f0f\u5b58\u5728\uff0c\u4f8b\u5982\u4e00\u5468\u4e2d\u7684\u5929\u6570\uff1a 
\u5468\u65e5\u3001\u5468\u4e00\u3001\u5468\u4e8c\u3001\u5468\u4e09\u3001\u5468\u56db\u3001\u5468\u4e94\u548c\u5468\u516d\u3002\u5468\u516d\u8fc7\u540e\uff0c\u53c8\u662f\u5468\u65e5\u3002\u8fd9\u5c31\u662f\u4e00\u4e2a\u5faa\u73af\u3002\u53e6\u4e00\u4e2a\u4f8b\u5b50\u662f\u4e00\u5929\u4e2d\u7684\u5c0f\u65f6\u6570\uff0c\u5982\u679c\u6211\u4eec\u5c06\u5b83\u4eec\u89c6\u4e3a\u7c7b\u522b\u7684\u8bdd\u3002 \u5206\u7c7b\u53d8\u91cf\u6709\u5f88\u591a\u4e0d\u540c\u7684\u5b9a\u4e49\uff0c\u5f88\u591a\u4eba\u4e5f\u8c08\u5230\u8981\u6839\u636e\u5206\u7c7b\u53d8\u91cf\u7684\u7c7b\u578b\u6765\u5904\u7406\u4e0d\u540c\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u4e0d\u8fc7\uff0c\u6211\u8ba4\u4e3a\u6ca1\u6709\u5fc5\u8981\u8fd9\u6837\u505a\u3002\u6240\u6709\u6d89\u53ca\u5206\u7c7b\u53d8\u91cf\u7684\u95ee\u9898\u90fd\u53ef\u4ee5\u7528\u540c\u6837\u7684\u65b9\u6cd5\u5904\u7406\u3002 \u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e00\u4e2a\u6570\u636e\u96c6\uff08\u4e00\u5982\u65e2\u5f80\uff09\u3002\u8981\u4e86\u89e3\u5206\u7c7b\u53d8\u91cf\uff0c\u6700\u597d\u7684\u514d\u8d39\u6570\u636e\u96c6\u4e4b\u4e00\u662f Kaggle \u5206\u7c7b\u7279\u5f81\u7f16\u7801\u6311\u6218\u8d5b\u4e2d\u7684 cat-in-the-dat \u3002\u5171\u6709\u4e24\u4e2a\u6311\u6218\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6311\u6218\u7684\u6570\u636e\uff0c\u56e0\u4e3a\u5b83\u6bd4\u524d\u4e00\u4e2a\u7248\u672c\u6709\u66f4\u591a\u53d8\u91cf\uff0c\u96be\u5ea6\u4e5f\u66f4\u5927\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u3002 \u56fe 1\uff1aCat-in-the-dat-ii challenge\u90e8\u5206\u6570\u636e\u5c55\u793a \u6570\u636e\u96c6\u7531\u5404\u79cd\u5206\u7c7b\u53d8\u91cf\u7ec4\u6210\uff1a \u65e0\u5e8f \u6709\u5e8f \u5faa\u73af \u4e8c\u5143 \u5728\u56fe 1 \u4e2d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6240\u6709\u5b58\u5728\u7684\u53d8\u91cf\u548c\u76ee\u6807\u53d8\u91cf\u7684\u5b50\u96c6\u3002 \u8fd9\u662f\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002 
\u76ee\u6807\u53d8\u91cf\u5bf9\u4e8e\u6211\u4eec\u5b66\u4e60\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\u5e76\u4e0d\u5341\u5206\u91cd\u8981\uff0c\u4f46\u6700\u7ec8\u6211\u4eec\u5c06\u5efa\u7acb\u4e00\u4e2a\u7aef\u5230\u7aef\u6a21\u578b\uff0c\u56e0\u6b64\u8ba9\u6211\u4eec\u770b\u770b\u56fe 2 \u4e2d\u7684\u76ee\u6807\u53d8\u91cf\u5206\u5e03\u3002\u6211\u4eec\u770b\u5230\u76ee\u6807\u662f \u504f\u659c \u7684\uff0c\u56e0\u6b64\u5bf9\u4e8e\u8fd9\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u6765\u8bf4\uff0c\u6700\u597d\u7684\u6307\u6807\u662f ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528\u7cbe\u786e\u5ea6\u548c\u53ec\u56de\u7387\uff0c\u4f46 AUC \u7ed3\u5408\u4e86\u8fd9\u4e24\u4e2a\u6307\u6807\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 AUC \u6765\u8bc4\u4f30\u6211\u4eec\u5728\u8be5\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u7684\u6a21\u578b\u3002 \u56fe 2\uff1a\u6807\u7b7e\u8ba1\u6570\u3002X \u8f74\u8868\u793a\u6807\u7b7e\uff0cY \u8f74\u8868\u793a\u6807\u7b7e\u8ba1\u6570 \u603b\u4f53\u800c\u8a00\uff0c\u6709\uff1a 5\u4e2a\u4e8c\u5143\u53d8\u91cf 10\u4e2a\u65e0\u5e8f\u53d8\u91cf 6\u4e2a\u6709\u5e8f\u53d8\u91cf 2\u4e2a\u5faa\u73af\u53d8\u91cf 1\u4e2a\u76ee\u6807\u53d8\u91cf \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u96c6\u4e2d\u7684 ord_2 \u7279\u5f81\u3002\u5b83\u5305\u62ec6\u4e2a\u4e0d\u540c\u7684\u7c7b\u522b\uff1a - \u51b0\u51bb - \u6e29\u6696 - \u5bd2\u51b7 - \u8f83\u70ed - \u70ed - \u975e\u5e38\u70ed \u6211\u4eec\u5fc5\u987b\u77e5\u9053\uff0c\u8ba1\u7b97\u673a\u65e0\u6cd5\u7406\u89e3\u6587\u672c\u6570\u636e\uff0c\u56e0\u6b64\u6211\u4eec\u9700\u8981\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002\u4e00\u4e2a\u7b80\u5355\u7684\u65b9\u6cd5\u662f\u521b\u5efa\u4e00\u4e2a\u5b57\u5178\uff0c\u5c06\u8fd9\u4e9b\u503c\u6620\u5c04\u4e3a\u4ece 0 \u5230 N-1 \u7684\u6570\u5b57\uff0c\u5176\u4e2d N \u662f\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7c7b\u522b\u7684\u603b\u6570\u3002 # \u6620\u5c04\u5b57\u5178 mapping = { 
\"Freezing\" : 0 , \"Warm\" : 1 , \"Cold\" : 2 , \"Boiling Hot\" : 3 , \"Hot\" : 4 , \"Lava Hot\" : 5 } \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8bfb\u53d6\u6570\u636e\u96c6\uff0c\u5e76\u8f7b\u677e\u5730\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002 import pandas as pd # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u53d6*ord_2*\u5217\uff0c\u5e76\u4f7f\u7528\u6620\u5c04\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57 df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. map ( mapping ) \u6620\u5c04\u524d\u7684\u6570\u503c\u8ba1\u6570\uff1a df .* ord_2 *. value_counts () Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : * ord_2 * , dtype : int64 \u6620\u5c04\u540e\u7684\u6570\u503c\u8ba1\u6570\uff1a 0.0 142726 1.0 124239 2.0 97822 3.0 84790 4.0 67508 5.0 64840 Name : * ord_2 * , dtype : int64 \u8fd9\u79cd\u5206\u7c7b\u53d8\u91cf\u7684\u7f16\u7801\u65b9\u5f0f\u88ab\u79f0\u4e3a\u6807\u7b7e\u7f16\u7801\uff08Label Encoding\uff09\u6211\u4eec\u5c06\u6bcf\u4e2a\u7c7b\u522b\u7f16\u7801\u4e3a\u4e00\u4e2a\u6570\u5b57\u6807\u7b7e\u3002 \u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 LabelEncoder \u8fdb\u884c\u7f16\u7801\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u5c06\u7f3a\u5931\u503c\u586b\u5145\u4e3a\"NONE\" df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. fillna ( \"NONE\" ) # LabelEncoder\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u8f6c\u6362\u6570\u636e df . loc [:, \"*ord_2*\" ] = lbl_enc . fit_transform ( df .* ord_2 *. 
values ) \u4f60\u4f1a\u770b\u5230\u6211\u4f7f\u7528\u4e86 pandas \u7684 fillna\u3002\u539f\u56e0\u662f scikit-learn \u7684 LabelEncoder \u65e0\u6cd5\u5904\u7406 NaN \u503c\uff0c\u800c ord_2 \u5217\u4e2d\u6709 NaN \u503c\u3002 \u6211\u4eec\u53ef\u4ee5\u5728\u8bb8\u591a\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u4e2d\u76f4\u63a5\u4f7f\u7528\u5b83\uff1a - \u51b3\u7b56\u6811 - \u968f\u673a\u68ee\u6797 - \u63d0\u5347\u6811 - \u6216\u4efb\u4f55\u4e00\u79cd\u63d0\u5347\u6811\u6a21\u578b - XGBoost - GBM - LightGBM \u8fd9\u79cd\u7f16\u7801\u65b9\u5f0f\u4e0d\u80fd\u7528\u4e8e\u7ebf\u6027\u6a21\u578b\u3001\u652f\u6301\u5411\u91cf\u673a\u6216\u795e\u7ecf\u7f51\u7edc\uff0c\u56e0\u4e3a\u5b83\u4eec\u5e0c\u671b\u6570\u636e\u662f\u6807\u51c6\u5316\u7684\u3002 \u5bf9\u4e8e\u8fd9\u4e9b\u7c7b\u578b\u7684\u6a21\u578b\uff0c\u6211\u4eec\u53ef\u4ee5\u5bf9\u6570\u636e\u8fdb\u884c\u4e8c\u503c\u5316\uff08binarize\uff09\u5904\u7406\u3002 Freezing --> 0 --> 0 0 0 Warm --> 1 --> 0 0 1 Cold --> 2 --> 0 1 0 Boiling Hot --> 3 --> 0 1 1 Hot --> 4 --> 1 0 0 Lava Hot --> 5 --> 1 0 1 \u8fd9\u53ea\u662f\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\uff0c\u7136\u540e\u518d\u8f6c\u6362\u4e3a\u4e8c\u503c\u5316\u8868\u793a\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u628a\u4e00\u4e2a\u7279\u5f81\u5206\u6210\u4e86\u4e09\u4e2a\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u7279\u5f81\uff08\u6216\u5217\uff09\u3002\u5982\u679c\u6211\u4eec\u6709\u66f4\u591a\u7684\u7c7b\u522b\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u5206\u6210\u66f4\u591a\u7684\u5217\u3002 
\u5982\u679c\u6211\u4eec\u7528\u7a00\u758f\u683c\u5f0f\u5b58\u50a8\u5927\u91cf\u4e8c\u503c\u5316\u53d8\u91cf\uff0c\u5c31\u53ef\u4ee5\u8f7b\u677e\u5730\u5b58\u50a8\u8fd9\u4e9b\u53d8\u91cf\u3002\u7a00\u758f\u683c\u5f0f\u4e0d\u8fc7\u662f\u4e00\u79cd\u5728\u5185\u5b58\u4e2d\u5b58\u50a8\u6570\u636e\u7684\u8868\u793a\u6216\u65b9\u5f0f\uff0c\u5728\u8fd9\u79cd\u683c\u5f0f\u4e2d\uff0c\u4f60\u5e76\u4e0d\u5b58\u50a8\u6240\u6709\u7684\u503c\uff0c\u800c\u53ea\u5b58\u50a8\u91cd\u8981\u7684\u503c\u3002\u5728\u4e0a\u8ff0\u4e8c\u8fdb\u5236\u53d8\u91cf\u7684\u60c5\u51b5\u4e2d\uff0c\u6700\u91cd\u8981\u7684\u5c31\u662f\u6709 1 \u7684\u5730\u65b9\u3002 \u5f88\u96be\u60f3\u8c61\u8fd9\u6837\u7684\u683c\u5f0f\uff0c\u4f46\u4e3e\u4e2a\u4f8b\u5b50\u5c31\u4f1a\u660e\u767d\u3002 \u5047\u8bbe\u4e0a\u9762\u7684\u6570\u636e\u5e27\u4e2d\u53ea\u6709\u4e00\u4e2a\u7279\u5f81\uff1a ord_2 \u3002 Index Feature 0 Warm 1 Hot 2 Lava hot \u76ee\u524d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6570\u636e\u96c6\u4e2d\u7684\u4e09\u4e2a\u6837\u672c\u3002\u8ba9\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u4e8c\u503c\u8868\u793a\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u6709\u4e09\u4e2a\u9879\u76ee\u3002 \u8fd9\u4e09\u4e2a\u9879\u76ee\u5c31\u662f\u4e09\u4e2a\u7279\u5f81\u3002 Index Feature_0 Feature_1 Feature_2 0 0 0 1 1 1 0 0 2 1 0 1 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u7279\u5f81\u5b58\u50a8\u5728\u4e00\u4e2a\u6709 3 \u884c 3 \u5217\uff083x3\uff09\u7684\u77e9\u9635\u4e2d\u3002\u77e9\u9635\u7684\u6bcf\u4e2a\u5143\u7d20\u5360\u7528 8 \u4e2a\u5b57\u8282\u3002\u56e0\u6b64\uff0c\u8fd9\u4e2a\u6570\u7ec4\u7684\u603b\u5185\u5b58\u9700\u6c42\u4e3a 8x3x3 = 72 \u5b57\u8282\u3002 \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e2a\u7b80\u5355\u7684 python \u4ee3\u7801\u6bb5\u6765\u68c0\u67e5\u8fd9\u4e00\u70b9\u3002 import numpy as np example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) print ( example . 
nbytes ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u6253\u5370\u51fa 72\uff0c\u5c31\u50cf\u6211\u4eec\u4e4b\u524d\u8ba1\u7b97\u7684\u90a3\u6837\u3002\u4f46\u6211\u4eec\u9700\u8981\u5b58\u50a8\u8fd9\u4e2a\u77e9\u9635\u7684\u6240\u6709\u5143\u7d20\u5417\uff1f\u5982\u524d\u6240\u8ff0\uff0c\u6211\u4eec\u53ea\u5bf9 1 \u611f\u5174\u8da3\u30020 \u5e76\u4e0d\u91cd\u8981\uff0c\u56e0\u4e3a\u4efb\u4f55\u4e0e 0 \u76f8\u4e58\u7684\u5143\u7d20\u90fd\u662f 0\uff0c\u800c 0 \u4e0e\u4efb\u4f55\u5143\u7d20\u76f8\u52a0\u6216\u76f8\u51cf\u4e5f\u6ca1\u6709\u4efb\u4f55\u533a\u522b\u3002\u53ea\u7528 1 \u8868\u793a\u77e9\u9635\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u67d0\u79cd\u5b57\u5178\u65b9\u6cd5\uff0c\u5176\u4e2d\u952e\u662f\u884c\u548c\u5217\u7684\u7d22\u5f15\uff0c\u503c\u662f 1\uff1a ( 0 , 2 ) 1 ( 1 , 0 ) 1 ( 2 , 0 ) 1 ( 2 , 2 ) 1 \u8fd9\u6837\u7684\u7b26\u53f7\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u56e0\u4e3a\u5b83\u53ea\u9700\u5b58\u50a8\u56db\u4e2a\u503c\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u3002\u4f7f\u7528\u7684\u603b\u5185\u5b58\u4e3a 8x4 = 32 \u5b57\u8282\u3002\u4efb\u4f55 numpy \u6570\u7ec4\u90fd\u53ef\u4ee5\u901a\u8fc7\u7b80\u5355\u7684 python \u4ee3\u7801\u8f6c\u6362\u4e3a\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) sparse_example = sparse . csr_matrix ( example ) print ( sparse_example . data . nbytes ) \u8fd9\u5c06\u6253\u5370 32\uff0c\u6bd4\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u5c11\u4e86\u8fd9\u4e48\u591a\uff01\u7a00\u758f csr \u77e9\u9635\u7684\u603b\u5927\u5c0f\u662f\u4e09\u4e2a\u503c\u7684\u603b\u548c\u3002 print ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . 
nbytes ) \u8fd9\u5c06\u6253\u5370\u51fa 64 \u4e2a\u5143\u7d20\uff0c\u4ecd\u7136\u5c11\u4e8e\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u3002\u9057\u61be\u7684\u662f\uff0c\u6211\u4e0d\u4f1a\u8be6\u7ec6\u4ecb\u7ecd\u8fd9\u4e9b\u5143\u7d20\u3002\u4f60\u53ef\u4ee5\u5728 scipy \u6587\u6863\u4e2d\u4e86\u89e3\u66f4\u591a\u3002\u5f53\u6211\u4eec\u62e5\u6709\u66f4\u5927\u7684\u6570\u7ec4\u65f6\uff0c\u6bd4\u5982\u8bf4\u62e5\u6709\u6570\u5343\u4e2a\u6837\u672c\u548c\u6570\u4e07\u4e2a\u7279\u5f81\u7684\u6570\u7ec4\uff0c\u5927\u5c0f\u5dee\u5f02\u5c31\u4f1a\u53d8\u5f97\u975e\u5e38\u5927\u3002\u4f8b\u5982\uff0c\u6211\u4eec\u4f7f\u7528\u57fa\u4e8e\u8ba1\u6570\u7279\u5f81\u7684\u6587\u672c\u6570\u636e\u96c6\u3002 import numpy as np from scipy import sparse n_rows = 10000 n_cols = 100000 # \u751f\u6210\u7b26\u5408\u4f2f\u52aa\u5229\u5206\u5e03\u7684\u968f\u673a\u6570\u7ec4\uff0c\u7ef4\u5ea6\u4e3a[10000, 100000] example = np . random . binomial ( 1 , p = 0.05 , size = ( n_rows , n_cols )) print ( f \"Size of dense array: { example . nbytes } \" ) # \u5c06\u968f\u673a\u77e9\u9635\u8f6c\u6362\u4e3a\u6d17\u6f31\u77e9\u9635 sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . 
nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u8fd9\u5c06\u6253\u5370\uff1a Size of dense array : 8000000000 Size of sparse array : 399932496 Full size of sparse array : 599938748 \u56e0\u6b64\uff0c\u5bc6\u96c6\u9635\u5217\u9700\u8981 ~8000MB \u6216\u5927\u7ea6 8GB \u5185\u5b58\u3002\u800c\u7a00\u758f\u9635\u5217\u53ea\u5360\u7528 399MB \u5185\u5b58\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u5f53\u6211\u4eec\u7684\u7279\u5f81\u4e2d\u6709\u5927\u91cf\u96f6\u65f6\uff0c\u6211\u4eec\u66f4\u559c\u6b22\u7a00\u758f\u9635\u5217\u800c\u4e0d\u662f\u5bc6\u96c6\u9635\u5217\u7684\u539f\u56e0\u3002 \u8bf7\u6ce8\u610f\uff0c\u7a00\u758f\u77e9\u9635\u6709\u591a\u79cd\u4e0d\u540c\u7684\u8868\u793a\u65b9\u6cd5\u3002\u8fd9\u91cc\u6211\u53ea\u5c55\u793a\u4e86\u5176\u4e2d\u4e00\u79cd\uff08\u53ef\u80fd\u4e5f\u662f\u6700\u5e38\u7528\u7684\uff09\u65b9\u6cd5\u3002\u6df1\u5165\u63a2\u8ba8\u8fd9\u4e9b\u65b9\u6cd5\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\uff0c\u56e0\u6b64\u7559\u7ed9\u8bfb\u8005\u4e00\u4e2a\u7ec3\u4e60\u3002 \u5c3d\u7ba1\u4e8c\u503c\u5316\u7279\u5f81\u7684\u7a00\u758f\u8868\u793a\u6bd4\u5176\u5bc6\u96c6\u8868\u793a\u6240\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u4f46\u5bf9\u4e8e\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\uff0c\u8fd8\u6709\u4e00\u79cd\u8f6c\u6362\u6240\u5360\u7528\u7684\u5185\u5b58\u66f4\u5c11\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \" \u72ec\u70ed\u7f16\u7801 \"\u3002 \u72ec\u70ed\u7f16\u7801\u4e5f\u662f\u4e00\u79cd\u4e8c\u503c\u7f16\u7801\uff0c\u56e0\u4e3a\u53ea\u6709 0 \u548c 1 \u4e24\u4e2a\u503c\u3002\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5b83\u5e76\u4e0d\u662f\u4e8c\u503c\u8868\u793a\u6cd5\u3002\u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0b\u9762\u7684\u4f8b\u5b50\u6765\u7406\u89e3\u5b83\u7684\u8868\u793a\u6cd5\u3002 \u5047\u8bbe\u6211\u4eec\u7528\u4e00\u4e2a\u5411\u91cf\u6765\u8868\u793a ord_2 \u53d8\u91cf\u7684\u6bcf\u4e2a\u7c7b\u522b\u3002\u8fd9\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u4e0e ord_2 
\u53d8\u91cf\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u6bcf\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u90fd\u662f 6\uff0c\u5e76\u4e14\u9664\u4e86\u4e00\u4e2a\u4f4d\u7f6e\u5916\uff0c\u5176\u4ed6\u4f4d\u7f6e\u90fd\u662f 0\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u7279\u6b8a\u7684\u5411\u91cf\u8868\u3002 Freezing 0 0 0 0 0 1 Warm 0 0 0 0 1 0 Cold 0 0 0 1 0 0 Boiling Hot 0 0 1 0 0 0 Hot 0 1 0 0 0 0 Lava Hot 1 0 0 0 0 0 \u6211\u4eec\u770b\u5230\u5411\u91cf\u7684\u5927\u5c0f\u662f 1x6\uff0c\u5373\u5411\u91cf\u4e2d\u67096\u4e2a\u5143\u7d20\u3002\u8fd9\u4e2a\u6570\u5b57\u662f\u600e\u4e48\u6765\u7684\u5462\uff1f\u5982\u679c\u4f60\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u5982\u524d\u6240\u8ff0\uff0c\u67096\u4e2a\u7c7b\u522b\u3002\u5728\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u65f6\uff0c\u5411\u91cf\u7684\u5927\u5c0f\u5fc5\u987b\u4e0e\u6211\u4eec\u8981\u67e5\u770b\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u6bcf\u4e2a\u5411\u91cf\u90fd\u6709\u4e00\u4e2a 1\uff0c\u5176\u4f59\u6240\u6709\u503c\u90fd\u662f 0\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u7528\u8fd9\u4e9b\u7279\u5f81\u6765\u4ee3\u66ff\u4e4b\u524d\u7684\u4e8c\u503c\u5316\u7279\u5f81\uff0c\u770b\u770b\u80fd\u8282\u7701\u591a\u5c11\u5185\u5b58\u3002 \u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\u4ee5\u524d\u7684\u6570\u636e\uff0c\u5b83\u770b\u8d77\u6765\u5982\u4e0b\uff1a Index Feature 0 Warm 1 Hot 2 Lava hot \u6bcf\u4e2a\u6837\u672c\u67093\u4e2a\u7279\u5f81\u3002\u4f46\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u72ec\u70ed\u5411\u91cf\u7684\u5927\u5c0f\u4e3a 6\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u67096\u4e2a\u7279\u5f81\uff0c\u800c\u4e0d\u662f3\u4e2a\u3002 Index F_0 F_1 F_2 F_3 F_4 F_5 0 0 0 0 0 1 0 1 0 1 0 0 0 0 2 1 0 0 0 0 0 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 6 \u4e2a\u7279\u5f81\uff0c\u800c\u5728\u8fd9\u4e2a 3x6 \u6570\u7ec4\u4e2d\uff0c\u53ea\u6709 3 \u4e2a1\u3002\u4f7f\u7528 numpy 
\u8ba1\u7b97\u5927\u5c0f\u4e0e\u4e8c\u503c\u5316\u5927\u5c0f\u8ba1\u7b97\u811a\u672c\u975e\u5e38\u76f8\u4f3c\u3002\u4f60\u9700\u8981\u6539\u53d8\u7684\u53ea\u662f\u6570\u7ec4\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u6bb5\u4ee3\u7801\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 0 , 0 , 1 , 0 ], [ 0 , 1 , 0 , 0 , 0 , 0 ], [ 1 , 0 , 0 , 0 , 0 , 0 ] ] ) print ( f \"Size of dense array: { example . nbytes } \" ) sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u6253\u5370\u5185\u5b58\u5927\u5c0f\u4e3a\uff1a Size of dense array : 144 Size of sparse array : 24 Full size of sparse array : 52 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5bc6\u96c6\u77e9\u9635\u7684\u5927\u5c0f\u8fdc\u8fdc\u5927\u4e8e\u4e8c\u503c\u5316\u77e9\u9635\u7684\u5927\u5c0f\u3002\u4e0d\u8fc7\uff0c\u7a00\u758f\u6570\u7ec4\u7684\u5927\u5c0f\u8981\u66f4\u5c0f\u3002\u8ba9\u6211\u4eec\u7528\u66f4\u5927\u7684\u6570\u7ec4\u6765\u8bd5\u8bd5\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 OneHotEncoder \u5c06\u5305\u542b 1001 \u4e2a\u7c7b\u522b\u7684\u7279\u5f81\u6570\u7ec4\u8f6c\u6362\u4e3a\u5bc6\u96c6\u77e9\u9635\u548c\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from sklearn import preprocessing # \u751f\u6210\u7b26\u5408\u5747\u5300\u5206\u5e03\u7684\u968f\u673a\u6574\u6570\uff0c\u7ef4\u5ea6\u4e3a[1000000, 10000000] example = np . random . randint ( 1000 , size = 1000000 ) # \u72ec\u70ed\u7f16\u7801\uff0c\u975e\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = False ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of dense array: { ohe_example . 
nbytes } \" ) # \u72ec\u70ed\u7f16\u7801\uff0c\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = True ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of sparse array: { ohe_example . data . nbytes } \" ) full_size = ( ohe_example . data . nbytes + ohe_example . indptr . nbytes + ohe_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u4e0a\u9762\u4ee3\u7801\u6253\u5370\u7684\u8f93\u51fa\uff1a Size of dense array : 8000000000 Size of sparse array : 8000000 Full size of sparse array : 16000004 \u8fd9\u91cc\u7684\u5bc6\u96c6\u9635\u5217\u5927\u5c0f\u7ea6\u4e3a 8GB\uff0c\u7a00\u758f\u9635\u5217\u4e3a 8MB\u3002\u5982\u679c\u53ef\u4ee5\u9009\u62e9\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff1f\u5728\u6211\u770b\u6765\uff0c\u9009\u62e9\u5f88\u7b80\u5355\uff0c\u4e0d\u662f\u5417\uff1f \u8fd9\u4e09\u79cd\u65b9\u6cd5\uff08\u6807\u7b7e\u7f16\u7801\u3001\u7a00\u758f\u77e9\u9635\u3001\u72ec\u70ed\u7f16\u7801\uff09\u662f\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u7684\u6700\u91cd\u8981\u65b9\u6cd5\u3002\u4e0d\u8fc7\uff0c\u4f60\u8fd8\u53ef\u4ee5\u7528\u5f88\u591a\u5176\u4ed6\u4e0d\u540c\u7684\u65b9\u6cd5\u6765\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u3002\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u6570\u503c\u53d8\u91cf\u5c31\u662f\u5176\u4e2d\u7684\u4e00\u4e2a\u4f8b\u5b50\u3002 \u5047\u8bbe\u6211\u4eec\u56de\u5230\u4e4b\u524d\u7684\u5206\u7c7b\u7279\u5f81\u6570\u636e\uff08\u539f\u59cb\u6570\u636e\u4e2d\u7684 cat-in-the-dat-ii\uff09\u3002\u5728\u6570\u636e\u4e2d\uff0c ord_2 \u7684\u503c\u4e3a\u201c\u70ed\u201c\u7684 id \u6709\u591a\u5c11\uff1f \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u8ba1\u7b97\u6570\u636e\u7684\u5f62\u72b6\uff08shape\uff09\u8f7b\u677e\u8ba1\u7b97\u51fa\u8fd9\u4e2a\u503c\uff0c\u5176\u4e2d ord_2 \u5217\u7684\u503c\u4e3a Boiling Hot \u3002 In [ X ]: df [ df . ord_2 == \"Boiling Hot\" ] . 
shape Out [ X ]: ( 84790 , 25 ) \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 84790 \u6761\u8bb0\u5f55\u5177\u6709\u6b64\u503c\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 pandas \u4e2d\u7684 groupby \u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u8be5\u503c\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . count () Out [ X ]: ord_2 Boiling Hot 84790 Cold 97822 Freezing 142726 Hot 67508 Lava Hot 64840 Warm 124239 Name : id , dtype : int64 \u5982\u679c\u6211\u4eec\u53ea\u662f\u5c06 ord_2 \u5217\u66ff\u6362\u4e3a\u5176\u8ba1\u6570\u503c\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u5c06\u5176\u8f6c\u6362\u4e3a\u4e00\u79cd\u6570\u503c\u7279\u5f81\u4e86\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 transform \u51fd\u6570\u548c groupby \u6765\u521b\u5efa\u65b0\u5217\u6216\u66ff\u6362\u8fd9\u4e00\u5217\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . transform ( \"count\" ) Out [ X ]: 0 67508.0 1 124239.0 2 142726.0 3 64840.0 4 97822.0 ... 599995 142726.0 599996 84790.0 599997 142726.0 599998 124239.0 599999 84790.0 Name : id , Length : 600000 , dtype : float64 \u4f60\u53ef\u4ee5\u6dfb\u52a0\u6240\u6709\u7279\u5f81\u7684\u8ba1\u6570\uff0c\u4e5f\u53ef\u4ee5\u66ff\u6362\u5b83\u4eec\uff0c\u6216\u8005\u6839\u636e\u591a\u4e2a\u5217\u53ca\u5176\u8ba1\u6570\u8fdb\u884c\u5206\u7ec4\u3002\u4f8b\u5982\uff0c\u4ee5\u4e0b\u4ee3\u7801\u901a\u8fc7\u5bf9 ord_1 \u548c ord_2 \u5217\u5206\u7ec4\u8fdb\u884c\u8ba1\u6570\u3002 In [ X ]: df . groupby ( ... : [ ... : \"ord_1\" , ... : \"ord_2\" ... : ] ... : )[ \"id\" ] . count () . 
reset_index ( name = \"count\" ) Out [ X ]: ord_1 ord_2 count 0 Contributor Boiling Hot 15634 1 Contributor Cold 17734 2 Contributor Freezing 26082 3 Contributor Hot 12428 4 Contributor Lava Hot 11919 5 Contributor Warm 22774 6 Expert Boiling Hot 19477 7 Expert Cold 22956 8 Expert Freezing 33249 9 Expert Hot 15792 10 Expert Lava Hot 15078 11 Expert Warm 28900 12 Grandmaster Boiling Hot 13623 13 Grandmaster Cold 15464 14 Grandmaster Freezing 22818 15 Grandmaster Hot 10805 16 Grandmaster Lava Hot 10363 17 Grandmaster Warm 19899 18 Master Boiling Hot 10800 ... \u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u884c\uff0c\u4ee5\u4fbf\u5728\u4e00\u9875\u4e2d\u5bb9\u7eb3\u8fd9\u4e9b\u884c\u3002\u8fd9\u662f\u53e6\u4e00\u79cd\u53ef\u4ee5\u4f5c\u4e3a\u529f\u80fd\u6dfb\u52a0\u7684\u8ba1\u6570\u3002\u60a8\u73b0\u5728\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\uff0c\u6211\u4f7f\u7528 id \u5217\u8fdb\u884c\u8ba1\u6570\u3002\u4e0d\u8fc7\uff0c\u4f60\u4e5f\u53ef\u4ee5\u901a\u8fc7\u5bf9\u5217\u7684\u7ec4\u5408\u8fdb\u884c\u5206\u7ec4\uff0c\u5bf9\u5176\u4ed6\u5217\u8fdb\u884c\u8ba1\u6570\u3002 \u8fd8\u6709\u4e00\u4e2a\u5c0f\u7a8d\u95e8\uff0c\u5c31\u662f\u4ece\u8fd9\u4e9b\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\u3002\u4f60\u53ef\u4ee5\u4ece\u73b0\u6709\u7684\u7279\u5f81\u4e2d\u521b\u5efa\u65b0\u7684\u5206\u7c7b\u7279\u5f81\uff0c\u800c\u4e14\u53ef\u4ee5\u6beb\u4e0d\u8d39\u529b\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot 1 Grandmaster_Warm 2 nan_Freezing 3 Novice_Lava Hot 4 Grandmaster_Cold ... 
599999 Contributor_Boiling Hot Name : new_feature , Length : 600000 , dtype : object \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u7528\u4e0b\u5212\u7ebf\u5c06 ord_1 \u548c ord_2 \u5408\u5e76\uff0c\u7136\u540e\u5c06\u8fd9\u4e9b\u5217\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u7c7b\u578b\u3002\u8bf7\u6ce8\u610f\uff0cNaN \u4e5f\u4f1a\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u3002\u4e0d\u8fc7\u6ca1\u5173\u7cfb\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06 NaN \u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u6709\u4e86\u4e00\u4e2a\u7531\u8fd9\u4e24\u4e2a\u7279\u5f81\u7ec4\u5408\u800c\u6210\u7684\u65b0\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u5c06\u4e09\u5217\u4ee5\u4e0a\u6216\u56db\u5217\u751a\u81f3\u66f4\u591a\u5217\u7ec4\u5408\u5728\u4e00\u8d77\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : + \"_\" ... : + df . ord_3 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot_c 1 Grandmaster_Warm_e 2 nan_Freezing_n 3 Novice_Lava Hot_a 4 Grandmaster_Cold_h ... 
599999 Contributor_Boiling Hot_b Name : new_feature , Length : 600000 , dtype : object \u90a3\u4e48\uff0c\u6211\u4eec\u5e94\u8be5\u628a\u54ea\u4e9b\u7c7b\u522b\u7ed3\u5408\u8d77\u6765\u5462\uff1f\u8fd9\u5e76\u6ca1\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u7b54\u6848\u3002\u8fd9\u53d6\u51b3\u4e8e\u60a8\u7684\u6570\u636e\u548c\u7279\u5f81\u7c7b\u578b\u3002\u4e00\u4e9b\u9886\u57df\u77e5\u8bc6\u5bf9\u4e8e\u521b\u5efa\u8fd9\u6837\u7684\u7279\u5f81\u53ef\u80fd\u5f88\u6709\u7528\u3002\u4f46\u662f\uff0c\u5982\u679c\u4f60\u4e0d\u62c5\u5fc3\u5185\u5b58\u548c CPU \u7684\u4f7f\u7528\uff0c\u4f60\u53ef\u4ee5\u91c7\u7528\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\uff0c\u5373\u521b\u5efa\u8bb8\u591a\u8fd9\u6837\u7684\u7ec4\u5408\uff0c\u7136\u540e\u4f7f\u7528\u4e00\u4e2a\u6a21\u578b\u6765\u51b3\u5b9a\u54ea\u4e9b\u7279\u5f81\u662f\u6709\u7528\u7684\uff0c\u5e76\u4fdd\u7559\u5b83\u4eec\u3002\u6211\u4eec\u5c06\u5728\u672c\u4e66\u7a0d\u540e\u90e8\u5206\u4ecb\u7ecd\u8fd9\u79cd\u65b9\u6cd5\u3002 \u65e0\u8bba\u4f55\u65f6\u83b7\u5f97\u5206\u7c7b\u53d8\u91cf\uff0c\u90fd\u8981\u9075\u5faa\u4ee5\u4e0b\u7b80\u5355\u6b65\u9aa4\uff1a - \u586b\u5145 NaN \u503c\uff08\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff01\uff09\u3002 - \u4f7f\u7528 scikit-learn \u7684 LabelEncoder \u6216\u6620\u5c04\u5b57\u5178\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\uff0c\u5c06\u5b83\u4eec\u8f6c\u6362\u4e3a\u6574\u6570\u3002\u5982\u679c\u6ca1\u6709\u586b\u5145 NaN \u503c\uff0c\u53ef\u80fd\u9700\u8981\u5728\u8fd9\u4e00\u6b65\u4e2d\u8fdb\u884c\u5904\u7406 - \u521b\u5efa\u72ec\u70ed\u7f16\u7801\u3002\u662f\u7684\uff0c\u4f60\u53ef\u4ee5\u8df3\u8fc7\u4e8c\u503c\u5316\uff01 - \u5efa\u6a21\uff01\u6211\u6307\u7684\u662f\u673a\u5668\u5b66\u4e60\u3002 \u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u5904\u7406 NaN \u6570\u636e\u975e\u5e38\u91cd\u8981\uff0c\u5426\u5219\u60a8\u53ef\u80fd\u4f1a\u4ece scikit-learn \u7684 LabelEncoder \u4e2d\u5f97\u5230\u81ed\u540d\u662d\u8457\u7684\u9519\u8bef\u4fe1\u606f\uff1a ValueError: y 
\u5305\u542b\u4ee5\u524d\u672a\u89c1\u8fc7\u7684\u6807\u7b7e\uff1a [Nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan) \u8fd9\u4ec5\u4ec5\u610f\u5473\u7740\uff0c\u5728\u8f6c\u6362\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6570\u636e\u4e2d\u51fa\u73b0\u4e86 NaN \u503c\u3002\u8fd9\u662f\u56e0\u4e3a\u4f60\u5728\u8bad\u7ec3\u65f6\u5fd8\u8bb0\u4e86\u5904\u7406\u5b83\u4eec\u3002 \u5904\u7406 NaN \u503c \u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u4e22\u5f03\u5b83\u4eec\u3002\u867d\u7136\u7b80\u5355\uff0c\u4f46\u5e76\u4e0d\u7406\u60f3\u3002NaN \u503c\u4e2d\u53ef\u80fd\u5305\u542b\u5f88\u591a\u4fe1\u606f\uff0c\u5982\u679c\u53ea\u662f\u4e22\u5f03\u8fd9\u4e9b\u503c\uff0c\u5c31\u4f1a\u4e22\u5931\u8fd9\u4e9b\u4fe1\u606f\u3002\u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u5927\u90e8\u5206\u6570\u636e\u90fd\u662f NaN \u503c\uff0c\u56e0\u6b64\u4e0d\u80fd\u4e22\u5f03 NaN \u503c\u7684\u884c/\u6837\u672c\u3002\u5904\u7406 NaN \u503c\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u5c06\u5176\u4f5c\u4e3a\u4e00\u4e2a\u5168\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u662f\u5904\u7406 NaN \u503c\u6700\u5e38\u7528\u7684\u65b9\u6cd5\u3002\u5982\u679c\u4f7f\u7528 pandas\uff0c\u8fd8\u53ef\u4ee5\u901a\u8fc7\u975e\u5e38\u7b80\u5355\u7684\u65b9\u5f0f\u5b9e\u73b0\u3002 \u8bf7\u770b\u6211\u4eec\u4e4b\u524d\u67e5\u770b\u8fc7\u7684\u6570\u636e\u7684 ord_2 \u5217\u3002 In [ X ]: df . ord_2 . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : ord_2 , dtype : int64 \u586b\u5165 NaN \u503c\u540e\uff0c\u5c31\u53d8\u6210\u4e86 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . 
value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u54c7\uff01\u8fd9\u4e00\u5217\u4e2d\u6709 18075 \u4e2a NaN \u503c\uff0c\u800c\u6211\u4eec\u4e4b\u524d\u751a\u81f3\u90fd\u6ca1\u6709\u8003\u8651\u4f7f\u7528\u5b83\u4eec\u3002\u589e\u52a0\u4e86\u8fd9\u4e2a\u65b0\u7c7b\u522b\u540e\uff0c\u7c7b\u522b\u603b\u6570\u4ece 6 \u4e2a\u589e\u52a0\u5230\u4e86 7 \u4e2a\u3002\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u73b0\u5728\u6211\u4eec\u5728\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u4e5f\u4f1a\u8003\u8651 NaN\u3002\u76f8\u5173\u4fe1\u606f\u8d8a\u591a\uff0c\u6a21\u578b\u5c31\u8d8a\u597d\u3002 \u5047\u8bbe ord_2 \u6ca1\u6709\u4efb\u4f55 NaN \u503c\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u4e00\u5217\u4e2d\u7684\u6240\u6709\u7c7b\u522b\u90fd\u6709\u663e\u8457\u7684\u8ba1\u6570\u3002\u5176\u4e2d\u6ca1\u6709 \"\u7f55\u89c1 \"\u7c7b\u522b\uff0c\u5373\u53ea\u5728\u6837\u672c\u603b\u6570\u4e2d\u5360\u5f88\u5c0f\u6bd4\u4f8b\u7684\u7c7b\u522b\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5047\u8bbe\u60a8\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u4e86\u4f7f\u7528\u8fd9\u4e00\u5217\u7684\u6a21\u578b\uff0c\u5f53\u6a21\u578b\u6216\u9879\u76ee\u4e0a\u7ebf\u65f6\uff0c\u60a8\u5728 ord_2 \u5217\u4e2d\u5f97\u5230\u4e86\u4e00\u4e2a\u5728\u8bad\u7ec3\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6a21\u578b\u7ba1\u9053\u4f1a\u629b\u51fa\u4e00\u4e2a\u9519\u8bef\uff0c\u60a8\u5bf9\u6b64\u65e0\u80fd\u4e3a\u529b\u3002\u5982\u679c\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\uff0c\u90a3\u4e48\u53ef\u80fd\u662f\u751f\u4ea7\u4e2d\u7684\u7ba1\u9053\u51fa\u4e86\u95ee\u9898\u3002\u5982\u679c\u8fd9\u662f\u9884\u6599\u4e4b\u4e2d\u7684\uff0c\u90a3\u4e48\u60a8\u5c31\u5fc5\u987b\u4fee\u6539\u60a8\u7684\u6a21\u578b\u7ba1\u9053\uff0c\u5e76\u5728\u8fd9\u516d\u4e2a\u7c7b\u522b\u4e2d\u52a0\u5165\u4e00\u4e2a\u65b0\u7c7b\u522b\u3002 \u8fd9\u4e2a\u65b0\u7c7b\u522b\u88ab\u79f0\u4e3a 
\"\u7f55\u89c1 \"\u7c7b\u522b\u3002\u7f55\u89c1\u7c7b\u522b\u662f\u4e00\u79cd\u4e0d\u5e38\u89c1\u7684\u7c7b\u522b\uff0c\u53ef\u4ee5\u5305\u62ec\u8bb8\u591a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8fd1\u90bb\u6a21\u578b\u6765 \"\u9884\u6d4b \"\u672a\u77e5\u7c7b\u522b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u4e86\u8fd9\u4e2a\u7c7b\u522b\uff0c\u5b83\u5c31\u4f1a\u6210\u4e3a\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u4e00\u4e2a\u7c7b\u522b\u3002 \u56fe 3\uff1a\u5177\u6709\u4e0d\u540c\u7279\u5f81\u4e14\u65e0\u6807\u7b7e\u7684\u6570\u636e\u96c6\u793a\u610f\u56fe\uff0c\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u53ef\u80fd\u4f1a\u5728\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u51fa\u73b0\u65b0\u503c \u5f53\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u56fe 3 \u6240\u793a\u7684\u6570\u636e\u96c6\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u7b80\u5355\u7684\u6a21\u578b\uff0c\u5bf9\u9664 \"f3 \"\u4e4b\u5916\u7684\u6240\u6709\u7279\u5f81\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u6837\uff0c\u4f60\u5c06\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u4e0d\u77e5\u9053\u6216\u8bad\u7ec3\u4e2d\u6ca1\u6709 \"f3 \"\u65f6\u9884\u6d4b\u5b83\u3002\u6211\u4e0d\u6562\u8bf4\u8fd9\u6837\u7684\u6a21\u578b\u662f\u5426\u80fd\u5e26\u6765\u51fa\u8272\u7684\u6027\u80fd\uff0c\u4f46\u4e5f\u8bb8\u80fd\u5904\u7406\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u7684\u7f3a\u5931\u503c\uff0c\u5c31\u50cf\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u5176\u4ed6\u4e8b\u60c5\u4e00\u6837\uff0c\u4e0d\u5c1d\u8bd5\u4e00\u4e0b\u662f\u8bf4\u4e0d\u51c6\u7684\u3002 
\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u56fa\u5b9a\u7684\u6d4b\u8bd5\u96c6\uff0c\u4f60\u53ef\u4ee5\u5c06\u6d4b\u8bd5\u6570\u636e\u6dfb\u52a0\u5230\u8bad\u7ec3\u4e2d\uff0c\u4ee5\u4e86\u89e3\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u7c7b\u522b\u3002\u8fd9\u4e0e\u534a\u76d1\u7763\u5b66\u4e60\u975e\u5e38\u76f8\u4f3c\uff0c\u5373\u4f7f\u7528\u65e0\u6cd5\u7528\u4e8e\u8bad\u7ec3\u7684\u6570\u636e\u6765\u6539\u8fdb\u6a21\u578b\u3002\u8fd9\u4e5f\u4f1a\u7167\u987e\u5230\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u51fa\u73b0\u6b21\u6570\u6781\u5c11\u4f46\u5728\u6d4b\u8bd5\u6570\u636e\u4e2d\u5927\u91cf\u5b58\u5728\u7684\u7a00\u6709\u503c\u3002\u4f60\u7684\u6a21\u578b\u5c06\u66f4\u52a0\u7a33\u5065\u3002 \u5f88\u591a\u4eba\u8ba4\u4e3a\u8fd9\u79cd\u60f3\u6cd5\u4f1a\u8fc7\u5ea6\u62df\u5408\u3002\u53ef\u80fd\u8fc7\u62df\u5408\uff0c\u4e5f\u53ef\u80fd\u4e0d\u8fc7\u62df\u5408\u3002\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6cd5\u3002\u5982\u679c\u4f60\u5728\u8bbe\u8ba1\u4ea4\u53c9\u9a8c\u8bc1\u65f6\uff0c\u80fd\u591f\u5728\u6d4b\u8bd5\u6570\u636e\u4e0a\u8fd0\u884c\u6a21\u578b\u65f6\u590d\u5236\u9884\u6d4b\u8fc7\u7a0b\uff0c\u90a3\u4e48\u5b83\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u8fd9\u610f\u5473\u7740\u7b2c\u4e00\u6b65\u5e94\u8be5\u662f\u5206\u79bb\u6298\u53e0\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u4f60\u5e94\u8be5\u5e94\u7528\u4e0e\u6d4b\u8bd5\u6570\u636e\u76f8\u540c\u7684\u9884\u5904\u7406\u3002\u5047\u8bbe\u60a8\u60f3\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u6d4b\u8bd5\u6570\u636e\uff0c\u90a3\u4e48\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u60a8\u5fc5\u987b\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\uff0c\u5e76\u786e\u4fdd\u9a8c\u8bc1\u6570\u636e\u96c6\u590d\u5236\u4e86\u6d4b\u8bd5\u96c6\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u60a8\u5fc5\u987b\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u8bbe\u8ba1\u9a8c\u8bc1\u96c6\uff0c\u4f7f\u5176\u5305\u542b\u8bad\u7ec3\u96c6\u4e2d \"\u672a\u89c1 \"\u7684\u7c7b\u522b\u3002 
\u56fe 4\uff1a\u5bf9\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u8fdb\u884c\u7b80\u5355\u5408\u5e76\uff0c\u4ee5\u4e86\u89e3\u6d4b\u8bd5\u96c6\u4e2d\u5b58\u5728\u4f46\u8bad\u7ec3\u96c6\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u6216\u8bad\u7ec3\u96c6\u4e2d\u7f55\u89c1\u7684\u7c7b\u522b \u53ea\u8981\u770b\u4e00\u4e0b\u56fe 4 \u548c\u4e0b\u9762\u7684\u4ee3\u7801\uff0c\u5c31\u80fd\u5f88\u5bb9\u6613\u7406\u89e3\u5176\u5de5\u4f5c\u539f\u7406\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( \"../input/cat_train.csv\" ) # \u8bfb\u53d6\u6d4b\u8bd5\u96c6 test = pd . read_csv ( \"../input/cat_test.csv\" ) # \u5c06\u6d4b\u8bd5\u96c6\"target\"\u5217\u5168\u90e8\u7f6e\u4e3a-1 test . loc [:, \"target\" ] = - 1 # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u62fc\u63a5 data = pd . concat ([ train , test ]) . reset_index ( drop = True ) # \u5c06\u9664\"id\"\u548c\"target\"\u5217\u7684\u5176\u4ed6\u7279\u5f81\u5217\u540d\u53d6\u51fa features = [ x for x in train . columns if x not in [ \"id\" , \"target\" ]] # \u904d\u5386\u7279\u5f81 for feat in features : # \u6807\u7b7e\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\",\u5e76\u5c06\u8be5\u5217\u683c\u5f0f\u53d8\u4e3astr temp_col = data [ feat ] . fillna ( \"NONE\" ) . astype ( str ) . values # \u8f6c\u6362\u6570\u503c data . loc [:, feat ] = lbl_enc . fit_transform ( temp_col ) # \u6839\u636e\"target\"\u5217\u5c06\u8bad\u7ec3\u96c6\u4e0e\u6d4b\u8bd5\u96c6\u5206\u5f00 train = data [ data . target != - 1 ] . reset_index ( drop = True ) test = data [ data . target == - 1 ] . 
reset_index ( drop = True ) \u5f53\u60a8\u9047\u5230\u5df2\u7ecf\u6709\u6d4b\u8bd5\u6570\u636e\u96c6\u7684\u95ee\u9898\u65f6\uff0c\u8fd9\u4e2a\u6280\u5de7\u5c31\u4f1a\u8d77\u4f5c\u7528\u3002\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd9\u4e00\u62db\u5728\u5b9e\u65f6\u73af\u5883\u4e2d\u4e0d\u8d77\u4f5c\u7528\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u60a8\u6240\u5728\u7684\u516c\u53f8\u63d0\u4f9b\u5b9e\u65f6\u7ade\u4ef7\u89e3\u51b3\u65b9\u6848\uff08RTB\uff09\u3002RTB \u7cfb\u7edf\u4f1a\u5bf9\u5728\u7ebf\u770b\u5230\u7684\u6bcf\u4e2a\u7528\u6237\u8fdb\u884c\u7ade\u4ef7\uff0c\u4ee5\u8d2d\u4e70\u5e7f\u544a\u7a7a\u95f4\u3002\u8fd9\u79cd\u6a21\u5f0f\u53ef\u4f7f\u7528\u7684\u529f\u80fd\u53ef\u80fd\u5305\u62ec\u7f51\u7ad9\u4e2d\u6d4f\u89c8\u7684\u9875\u9762\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u7279\u5f81\u662f\u7528\u6237\u8bbf\u95ee\u7684\u6700\u540e\u4e94\u4e2a\u7c7b\u522b/\u9875\u9762\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5982\u679c\u7f51\u7ad9\u5f15\u5165\u4e86\u65b0\u7684\u7c7b\u522b\uff0c\u6211\u4eec\u5c06\u65e0\u6cd5\u518d\u51c6\u786e\u9884\u6d4b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u5c31\u4f1a\u5931\u6548\u3002\u8fd9\u79cd\u60c5\u51b5\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528 \"\u672a\u77e5 \"\u7c7b\u522b\u6765\u907f\u514d \u3002 \u5728\u6211\u4eec\u7684 cat-in-the-dat \u6570\u636e\u96c6\u4e2d\uff0c ord_2 \u5217\u4e2d\u5df2\u7ecf\u6709\u4e86\u672a\u77e5\u7c7b\u522b\u3002 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . 
value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u6211\u4eec\u53ef\u4ee5\u5c06 \"NONE \"\u89c6\u4e3a\u672a\u77e5\u3002\u56e0\u6b64\uff0c\u5982\u679c\u5728\u5b9e\u65f6\u6d4b\u8bd5\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u83b7\u5f97\u4e86\u4ee5\u524d\u4ece\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\uff0c\u6211\u4eec\u5c31\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a \"NONE\"\u3002 \u8fd9\u4e0e\u81ea\u7136\u8bed\u8a00\u5904\u7406\u95ee\u9898\u975e\u5e38\u76f8\u4f3c\u3002\u6211\u4eec\u603b\u662f\u57fa\u4e8e\u56fa\u5b9a\u7684\u8bcd\u6c47\u5efa\u7acb\u6a21\u578b\u3002\u589e\u52a0\u8bcd\u6c47\u91cf\u5c31\u4f1a\u589e\u52a0\u6a21\u578b\u7684\u5927\u5c0f\u3002\u50cf BERT \u8fd9\u6837\u7684\u8f6c\u6362\u5668\u6a21\u578b\u662f\u5728 ~30000 \u4e2a\u5355\u8bcd\uff08\u82f1\u8bed\uff09\u7684\u57fa\u7840\u4e0a\u8bad\u7ec3\u7684\u3002\u56e0\u6b64\uff0c\u5f53\u6709\u65b0\u8bcd\u8f93\u5165\u65f6\uff0c\u6211\u4eec\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a UNK\uff08\u672a\u77e5\uff09\u3002 \u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u5047\u8bbe\u6d4b\u8bd5\u6570\u636e\u4e0e\u8bad\u7ec3\u6570\u636e\u5177\u6709\u76f8\u540c\u7684\u7c7b\u522b\uff0c\u4e5f\u53ef\u4ee5\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u5f15\u5165\u7f55\u89c1\u6216\u672a\u77e5\u7c7b\u522b\uff0c\u4ee5\u5904\u7406\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u65b0\u7c7b\u522b\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u586b\u5165 NaN \u503c\u540e ord_4 \u5217\u7684\u503c\u8ba1\u6570\uff1a In [ X ]: df . ord_4 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 . . . 
K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 G 3404 V 3107 J 1950 L 1657 Name : ord_4 , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u6709\u4e9b\u6570\u503c\u53ea\u51fa\u73b0\u4e86\u51e0\u5343\u6b21\uff0c\u6709\u4e9b\u5219\u51fa\u73b0\u4e86\u8fd1 40000 \u6b21\u3002NaN \u4e5f\u7ecf\u5e38\u51fa\u73b0\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u503c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5b9a\u4e49\u5c06\u4e00\u4e2a\u503c\u79f0\u4e3a \" \u7f55\u89c1\uff08rare\uff09 \"\u7684\u6807\u51c6\u4e86\u3002\u6bd4\u65b9\u8bf4\uff0c\u5728\u8fd9\u4e00\u5217\u4e2d\uff0c\u7a00\u6709\u503c\u7684\u8981\u6c42\u662f\u8ba1\u6570\u5c0f\u4e8e 2000\u3002\u8fd9\u6837\u770b\u6765\uff0cJ \u548c L \u5c31\u53ef\u4ee5\u88ab\u6807\u8bb0\u4e3a\u7a00\u6709\u503c\u4e86\u3002\u4f7f\u7528 pandas\uff0c\u6839\u636e\u8ba1\u6570\u9608\u503c\u66ff\u6362\u7c7b\u522b\u975e\u5e38\u7b80\u5355\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 In [ X ]: df . ord_4 = df . ord_4 . fillna ( \"NONE\" ) In [ X ]: df . loc [ ... : df [ \"ord_4\" ] . value_counts ()[ df [ \"ord_4\" ]] . values < 2000 , ... : \"ord_4\" ... : ] = \"RARE\" In [ X ]: df . ord_4 . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 M 32504 . . . 
B 25212 E 21871 K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 RARE 3607 G 3404 V 3107 Name : ord_4 , dtype : int64 \u6211\u4eec\u8ba4\u4e3a\uff0c\u53ea\u8981\u67d0\u4e2a\u7c7b\u522b\u7684\u503c\u5c0f\u4e8e 2000\uff0c\u5c31\u5c06\u5176\u66ff\u6362\u4e3a\u7f55\u89c1\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u5728\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6240\u6709\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"RARE\"\uff0c\u800c\u6240\u6709\u7f3a\u5931\u503c\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"NONE\"\u3002 \u8fd9\u79cd\u65b9\u6cd5\u8fd8\u80fd\u786e\u4fdd\u5373\u4f7f\u6709\u65b0\u7684\u7c7b\u522b\uff0c\u6a21\u578b\u4e5f\u80fd\u5728\u5b9e\u9645\u73af\u5883\u4e2d\u6b63\u5e38\u5de5\u4f5c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u5177\u5907\u4e86\u5904\u7406\u4efb\u4f55\u5e26\u6709\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u6240\u9700\u7684\u4e00\u5207\u6761\u4ef6\u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5efa\u7acb\u7b2c\u4e00\u4e2a\u6a21\u578b\uff0c\u5e76\u9010\u6b65\u63d0\u9ad8\u5176\u6027\u80fd\u3002 \u5728\u6784\u5efa\u4efb\u4f55\u7c7b\u578b\u7684\u6a21\u578b\u4e4b\u524d\uff0c\u4ea4\u53c9\u68c0\u9a8c\u81f3\u5173\u91cd\u8981\u3002\u6211\u4eec\u5df2\u7ecf\u770b\u5230\u4e86\u6807\u7b7e/\u76ee\u6807\u5206\u5e03\uff0c\u77e5\u9053\u8fd9\u662f\u4e00\u4e2a\u76ee\u6807\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 StratifiedKFold \u6765\u5206\u5272\u6570\u636e\u3002 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u6dfb\u52a0\"kfold\"\u5217\uff0c\u5e76\u7f6e\u4e3a-1 df [ \"kfold\" ] = - 1 # \u6253\u4e71\u6570\u636e\u987a\u5e8f\uff0c\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u5c06\u76ee\u6807\u5217\u53d6\u51fa y = df . target . values # \u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c kf = model_selection . 
StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): # \u533a\u5206\u6298\u53e0 df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u6587\u4ef6 df . to_csv ( \"../input/cat_train_folds.csv\" , index = False ) \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u68c0\u67e5\u65b0\u7684\u6298\u53e0 csv\uff0c\u67e5\u770b\u6bcf\u4e2a\u6298\u53e0\u7684\u6837\u672c\u6570\uff1a In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) In [ X ]: df . kfold . value_counts () Out [ X ]: 4 120000 3 120000 2 120000 1 120000 0 120000 Name : kfold , dtype : int64 \u6240\u6709\u6298\u53e0\u90fd\u6709 120000 \u4e2a\u6837\u672c\u3002\u8fd9\u662f\u610f\u6599\u4e4b\u4e2d\u7684\uff0c\u56e0\u4e3a\u8bad\u7ec3\u6570\u636e\u6709 600000 \u4e2a\u6837\u672c\uff0c\u800c\u6211\u4eec\u505a\u4e865\u6b21\u6298\u53e0\u3002\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u4e00\u5207\u987a\u5229\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u68c0\u67e5\u6bcf\u4e2a\u6298\u53e0\u7684\u76ee\u6807\u5206\u5e03\u3002 In [ X ]: df [ df . kfold == 0 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 1 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 2 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 3 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 4 ] . target . 
value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u76ee\u6807\u7684\u5206\u5e03\u90fd\u662f\u4e00\u6837\u7684\u3002\u8fd9\u6b63\u662f\u6211\u4eec\u6240\u9700\u8981\u7684\u3002\u5b83\u4e5f\u53ef\u4ee5\u662f\u76f8\u4f3c\u7684\uff0c\u5e76\u4e0d\u4e00\u5b9a\u8981\u4e00\u76f4\u76f8\u540c\u3002\u73b0\u5728\uff0c\u5f53\u6211\u4eec\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u6807\u7b7e\u5206\u5e03\u90fd\u5c06\u76f8\u540c\u3002 \u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u7684\u6700\u7b80\u5355\u7684\u6a21\u578b\u4e4b\u4e00\u662f\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u5e76\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): # \u8bfb\u53d6\u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) # \u53d6\u9664\"id\", \"target\", \"kfold\"\u5916\u7684\u5176\u4ed6\u7279\u5f81\u5217 features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] # \u904d\u5386\u7279\u5f81\u5217\u8868 for col in features : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u9a8c\u8bc1\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . 
fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u6d4b\u8bd5\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u903b\u8f91\u56de\u5f52 model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . target . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( auc ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6298\u53e00 run ( 0 ) \u90a3\u4e48\uff0c\u53d1\u751f\u4e86\u4ec0\u4e48\u5462\uff1f \u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u51fd\u6570\uff0c\u5c06\u6570\u636e\u5206\u4e3a\u8bad\u7ec3\u548c\u9a8c\u8bc1\u4e24\u90e8\u5206\uff0c\u7ed9\u5b9a\u6298\u53e0\u6570\uff0c\u5904\u7406 NaN \u503c\uff0c\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u5355\u6b21\u7f16\u7801\uff0c\u5e76\u8bad\u7ec3\u4e00\u4e2a\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\u3002 \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u90e8\u5206\u4ee3\u7801\u65f6\uff0c\u4f1a\u4ea7\u751f\u5982\u4e0b\u8f93\u51fa\uff1a \u276f python ohe_logres . py / home / abhishek / miniconda3 / envs / ml / lib / python3 .7 / site - packages / sklearn / linear_model / _logistic . py : 939 : ConvergenceWarning : lbfgs failed to converge ( status = 1 ): STOP : TOTAL NO . of ITERATIONS REACHED LIMIT . Increase the number of iterations ( max_iter ) or scale the data as shown in : https : // scikit - learn . org / stable / modules / preprocessing . html . Please also refer to the documentation for alternative solver options : https : // scikit - learn . org / stable / modules / linear_model . 
html #logistic- regression extra_warning_msg = _LOGISTIC_SOLVER_CONVERGENCE_MSG ) 0.7847865042255127 \u6709\u4e00\u4e9b\u8b66\u544a\u3002\u903b\u8f91\u56de\u5f52\u4f3c\u4e4e\u6ca1\u6709\u6536\u655b\u5230\u6700\u5927\u8fed\u4ee3\u6b21\u6570\u3002\u6211\u4eec\u6ca1\u6709\u8c03\u6574\u53c2\u6570\uff0c\u6240\u4ee5\u6ca1\u6709\u95ee\u9898\u3002\u6211\u4eec\u770b\u5230 AUC \u4e3a 0.785\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u5bf9\u4ee3\u7801\u8fdb\u884c\u7b80\u5355\u4fee\u6539\uff0c\u8fd0\u884c\u6240\u6709\u6298\u53e0\u3002 .... model = linear_model . LogisticRegression () model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u5faa\u73af\u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5e76\u6ca1\u6709\u505a\u5f88\u5927\u7684\u6539\u52a8\uff0c\u6240\u4ee5\u6211\u53ea\u663e\u793a\u4e86\u90e8\u5206\u4ee3\u7801\u884c\uff0c\u5176\u4e2d\u4e00\u4e9b\u4ee3\u7801\u884c\u6709\u6539\u52a8\u3002 \u8fd9\u5c31\u6253\u5370\u51fa\u4e86\uff1a python - W ignore ohe_logres . 
py Fold = 0 , AUC = 0.7847865042255127 Fold = 1 , AUC = 0.7853553605899214 Fold = 2 , AUC = 0.7879321942914885 Fold = 3 , AUC = 0.7870315929550808 Fold = 4 , AUC = 0.7864668243125608 \u8bf7\u6ce8\u610f\uff0c\u6211\u4f7f\u7528\"-W ignore \"\u5ffd\u7565\u4e86\u6240\u6709\u8b66\u544a\u3002 \u6211\u4eec\u770b\u5230\uff0cAUC \u5206\u6570\u5728\u6240\u6709\u8936\u76b1\u4e2d\u90fd\u76f8\u5f53\u7a33\u5b9a\u3002\u5e73\u5747 AUC \u4e3a 0.78631449527\u3002\u5bf9\u4e8e\u6211\u4eec\u7684\u7b2c\u4e00\u4e2a\u6a21\u578b\u6765\u8bf4\u76f8\u5f53\u4e0d\u9519\uff01 \u5f88\u591a\u4eba\u5728\u9047\u5230\u8fd9\u79cd\u95ee\u9898\u65f6\u4f1a\u9996\u5148\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6bd4\u5982\u968f\u673a\u68ee\u6797\u3002\u5728\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u5e94\u7528\u968f\u673a\u68ee\u6797\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\uff08label encoding\uff09\uff0c\u5c06\u6bcf\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81\u90fd\u8f6c\u6362\u4e3a\u6574\u6570\uff0c\u800c\u4e0d\u662f\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u72ec\u70ed\u7f16\u7801\u3002 \u8fd9\u79cd\u7f16\u7801\u4e0e\u72ec\u70ed\u7f16\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . 
values # \u968f\u673a\u68ee\u6797\u6a21\u578b model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u4f7f\u7528 scikit-learn \u4e2d\u7684\u968f\u673a\u68ee\u6797\uff0c\u5e76\u53d6\u6d88\u4e86\u72ec\u70ed\u7f16\u7801\u3002\u6211\u4eec\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u4ee3\u66ff\u72ec\u70ed\u7f16\u7801\u3002\u5f97\u5206\u5982\u4e0b \u276f python lbl_rf . py Fold = 0 , AUC = 0.7167390828113697 Fold = 1 , AUC = 0.7165459672958506 Fold = 2 , AUC = 0.7159709909587376 Fold = 3 , AUC = 0.7161589664189556 Fold = 4 , AUC = 0.7156020216155978 \u54c7 \u5de8\u5927\u7684\u5dee\u5f02\uff01 \u968f\u673a\u68ee\u6797\u6a21\u578b\u5728\u6ca1\u6709\u4efb\u4f55\u8d85\u53c2\u6570\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u8868\u73b0\u8981\u6bd4\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u5dee\u5f88\u591a\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u603b\u662f\u5e94\u8be5\u5148\u4ece\u7b80\u5355\u6a21\u578b\u5f00\u59cb\u7684\u539f\u56e0\u3002\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u7c89\u4e1d\u4f1a\u4ece\u8fd9\u91cc\u5f00\u59cb\uff0c\u800c\u5ffd\u7565\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u8ba4\u4e3a\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u6a21\u578b\uff0c\u4e0d\u80fd\u5e26\u6765\u6bd4\u968f\u673a\u68ee\u6797\u66f4\u597d\u7684\u4ef7\u503c\u3002\u8fd9\u79cd\u4eba\u5c06\u4f1a\u72af\u4e0b\u5927\u9519\u3002\u5728\u6211\u4eec\u5b9e\u73b0\u968f\u673a\u68ee\u6797\u7684\u8fc7\u7a0b\u4e2d\uff0c\u4e0e\u903b\u8f91\u56de\u5f52\u76f8\u6bd4\uff0c\u6298\u53e0\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u624d\u80fd\u5b8c\u6210\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e0d\u4ec5\u635f\u5931\u4e86 
AUC\uff0c\u8fd8\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u6765\u5b8c\u6210\u8bad\u7ec3\u3002\u8bf7\u6ce8\u610f\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u8fdb\u884c\u63a8\u7406\u4e5f\u5f88\u8017\u65f6\uff0c\u800c\u4e14\u5360\u7528\u7684\u7a7a\u95f4\u4e5f\u66f4\u5927\u3002 \u5982\u679c\u6211\u4eec\u613f\u610f\uff0c\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u5728\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u6570\u636e\u4e0a\u8fd0\u884c\u968f\u673a\u68ee\u6797\uff0c\u4f46\u8fd9\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u5947\u5f02\u503c\u5206\u89e3\u6765\u51cf\u5c11\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u77e9\u9635\u3002\u8fd9\u662f\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u63d0\u53d6\u4e3b\u9898\u7684\u5e38\u7528\u65b9\u6cd5\u3002 import pandas as pd from scipy import sparse from sklearn import decomposition from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" )] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) x_train = ohe . transform ( df_train [ features ]) x_valid = ohe . transform ( df_valid [ features ]) # \u5947\u5f02\u503c\u5206\u89e3 svd = decomposition . TruncatedSVD ( n_components = 120 ) full_sparse = sparse . vstack (( x_train , x_valid )) svd . fit ( full_sparse ) x_train = svd . transform ( x_train ) x_valid = svd . transform ( x_valid ) model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . 
values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u5bf9\u5168\u90e8\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\uff0c\u7136\u540e\u7528\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\u5728\u7a00\u758f\u77e9\u9635\u4e0a\u62df\u5408 scikit-learn \u7684 TruncatedSVD\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c06\u9ad8\u7ef4\u7a00\u758f\u77e9\u9635\u51cf\u5c11\u5230 120 \u4e2a\u7279\u5f81\uff0c\u7136\u540e\u62df\u5408\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u3002 \u4ee5\u4e0b\u662f\u8be5\u6a21\u578b\u7684\u8f93\u51fa\u7ed3\u679c\uff1a \u276f python ohe_svd_rf . py Fold = 0 , AUC = 0.7064863038754249 Fold = 1 , AUC = 0.706050102937374 Fold = 2 , AUC = 0.7086069243167242 Fold = 3 , AUC = 0.7066819080085971 Fold = 4 , AUC = 0.7058154015055585 \u6211\u4eec\u53d1\u73b0\u60c5\u51b5\u66f4\u7cdf\u3002\u770b\u6765\uff0c\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\u7684\u6700\u4f73\u65b9\u6cd5\u662f\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u3002\u968f\u673a\u68ee\u6797\u4f3c\u4e4e\u8017\u65f6\u592a\u591a\u3002\u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u8bd5\u8bd5 XGBoost\u3002\u5982\u679c\u4f60\u4e0d\u77e5\u9053 XGBoost\uff0c\u5b83\u662f\u6700\u6d41\u884c\u7684\u68af\u5ea6\u63d0\u5347\u7b97\u6cd5\u4e4b\u4e00\u3002\u7531\u4e8e\u5b83\u662f\u4e00\u79cd\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u6570\u636e\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . 
fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 , n_estimators = 200 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0c\u6211\u5bf9 xgboost \u53c2\u6570\u505a\u4e86\u4e00\u4e9b\u4fee\u6539\u3002xgboost \u7684\u9ed8\u8ba4\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u662f 3\uff0c\u6211\u628a\u5b83\u6539\u6210\u4e86 7\uff0c\u8fd8\u628a\u4f30\u8ba1\u5668\u6570\u91cf\uff08n_estimators\uff09\u4ece 100 \u6539\u6210\u4e86 200\u3002 \u8be5\u6a21\u578b\u7684 5 \u6298\u4ea4\u53c9\u68c0\u9a8c\u5f97\u5206\u5982\u4e0b\uff1a \u276f python lbl_xgb . 
py Fold = 0 , AUC = 0.7656768851999011 Fold = 1 , AUC = 0.7633006564148015 Fold = 2 , AUC = 0.7654277821434345 Fold = 3 , AUC = 0.7663609758878182 Fold = 4 , AUC = 0.764914671468069 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u4e0d\u505a\u4efb\u4f55\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u5f97\u5206\u6bd4\u666e\u901a\u968f\u673a\u68ee\u6797\u8981\u9ad8\u5f97\u591a\u3002 \u60a8\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4e00\u4e9b\u7279\u5f81\u5de5\u7a0b\uff0c\u653e\u5f03\u67d0\u4e9b\u5bf9\u6a21\u578b\u6ca1\u6709\u4efb\u4f55\u4ef7\u503c\u7684\u5217\u7b49\u3002\u4f46\u4f3c\u4e4e\u6211\u4eec\u80fd\u505a\u7684\u4e0d\u591a\uff0c\u65e0\u6cd5\u8bc1\u660e\u6a21\u578b\u7684\u6539\u8fdb\u3002\u8ba9\u6211\u4eec\u628a\u6570\u636e\u96c6\u6362\u6210\u53e6\u4e00\u4e2a\u6709\u5927\u91cf\u5206\u7c7b\u53d8\u91cf\u7684\u6570\u636e\u96c6\u3002\u53e6\u4e00\u4e2a\u6709\u540d\u7684\u6570\u636e\u96c6\u662f \u7f8e\u56fd\u6210\u4eba\u4eba\u53e3\u666e\u67e5\u6570\u636e\uff08US adult census data\uff09 \u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u5305\u542b\u4e00\u4e9b\u7279\u5f81\uff0c\u800c\u4f60\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u5de5\u8d44\u7b49\u7ea7\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u3002\u56fe 5 \u663e\u793a\u4e86\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u4e00\u4e9b\u5217\u3002 \u56fe 5\uff1a\u90e8\u5206\u6570\u636e\u96c6\u5c55\u793a \u8be5\u6570\u636e\u96c6\u6709\u4ee5\u4e0b\u51e0\u5217\uff1a - \u5e74\u9f84\uff08age\uff09 \u5de5\u4f5c\u7c7b\u522b\uff08workclass\uff09 \u5b66\u5386\uff08fnlwgt\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education.num\uff09 \u5a5a\u59fb\u72b6\u51b5\uff08marital.status\uff09 \u804c\u4e1a\uff08occupation\uff09 \u5173\u7cfb\uff08relationship\uff09 \u79cd\u65cf\uff08race\uff09 \u6027\u522b\uff08sex\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 
\u539f\u7c4d\u56fd\uff08native.country\uff09 \u6536\u5165\uff08income\uff09 \u8fd9\u4e9b\u7279\u5f81\u5927\u591a\u4e0d\u8a00\u81ea\u660e\u3002\u90a3\u4e9b\u4e0d\u660e\u767d\u7684\uff0c\u6211\u4eec\u53ef\u4ee5\u4e0d\u8003\u8651\u3002\u8ba9\u6211\u4eec\u5148\u5c1d\u8bd5\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u3002 \u6211\u4eec\u770b\u5230\u6536\u5165\u5217\u662f\u4e00\u4e2a\u5b57\u7b26\u4e32\u3002\u8ba9\u6211\u4eec\u5bf9\u8fd9\u4e00\u5217\u8fdb\u884c\u6570\u503c\u7edf\u8ba1\u3002 In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/adult.csv\" ) In [ X ]: df . income . value_counts () Out [ X ]: <= 50 K 24720 > 50 K 7841 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 7841 \u4e2a\u5b9e\u4f8b\u7684\u6536\u5165\u8d85\u8fc7 5 \u4e07\u7f8e\u5143\u3002\u8fd9\u5360\u6837\u672c\u603b\u6570\u7684 24%\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4fdd\u6301\u4e0e\u732b\u6570\u636e\u96c6\u76f8\u540c\u7684\u8bc4\u4f30\u65b9\u6cd5\uff0c\u5373 AUC\u3002 \u5728\u5f00\u59cb\u5efa\u6a21\u4e4b\u524d\uff0c\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u5c06\u53bb\u6389\u51e0\u5217\u7279\u5f81\uff0c\u5373 \u5b66\u5386\uff08fnlwgt\uff09 \u5e74\u9f84\uff08age\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 \u8ba9\u6211\u4eec\u8bd5\u7740\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u5668\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002\u7b2c\u4e00\u6b65\u603b\u662f\u8981\u8fdb\u884c\u4ea4\u53c9\u9a8c\u8bc1\u3002\u6211\u4e0d\u4f1a\u5728\u8fd9\u91cc\u5c55\u793a\u8fd9\u90e8\u5206\u4ee3\u7801\u3002\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . 
read_csv ( \"../input/adult_folds.csv\" ) # \u9700\u8981\u5220\u9664\u7684\u5217 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) # \u6620\u5c04 target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } # \u4f7f\u7528\u6620\u5c04\u66ff\u6362 df . loc [:, \"income\" ] = df . income . map ( target_mapping ) # \u53d6\u9664\"kfold\", \"income\"\u5217\u7684\u5176\u4ed6\u5217\u540d features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u9a8c\u8bc1\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u6784\u5efa\u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . income . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . income . 
values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u65f6\uff0c\u6211\u4eec\u4f1a\u5f97\u5230 \u276f python - W ignore ohe_logres . py Fold = 0 , AUC = 0.8794809708119079 Fold = 1 , AUC = 0.8875785068274882 Fold = 2 , AUC = 0.8852609687685753 Fold = 3 , AUC = 0.8681236223251438 Fold = 4 , AUC = 0.8728581541840037 \u5bf9\u4e8e\u4e00\u4e2a\u5982\u6b64\u7b80\u5355\u7684\u6a21\u578b\u6765\u8bf4\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u4e0d\u9519\u7684 AUC\uff01 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5728\u4e0d\u8c03\u6574\u4efb\u4f55\u8d85\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c1d\u8bd5\u4e00\u4e0b\u6807\u7b7e\u7f16\u7801\u7684xgboost\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . 
roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8ba9\u6211\u4eec\u8fd0\u884c\u4e0a\u9762\u4ee3\u7801\uff1a \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8800810634234078 Fold = 1 , AUC = 0.886811884948154 Fold = 2 , AUC = 0.8854421433318472 Fold = 3 , AUC = 0.8676319549361007 Fold = 4 , AUC = 0.8714450054900602 \u8fd9\u770b\u8d77\u6765\u5df2\u7ecf\u76f8\u5f53\u4e0d\u9519\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b max_depth \u589e\u52a0\u5230 7 \u548c n_estimators \u589e\u52a0\u5230 200 \u65f6\u7684\u5f97\u5206\u3002 \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8764108944332032 Fold = 1 , AUC = 0.8840708537662638 Fold = 2 , AUC = 0.8816601162613102 Fold = 3 , AUC = 0.8662335762581732 Fold = 4 , AUC = 0.8698983461709926 \u770b\u8d77\u6765\u5e76\u6ca1\u6709\u6539\u5584\u3002 \u8fd9\u8868\u660e\uff0c\u4e00\u4e2a\u6570\u636e\u96c6\u7684\u53c2\u6570\u4e0d\u80fd\u79fb\u690d\u5230\u53e6\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6211\u4eec\u5fc5\u987b\u518d\u6b21\u5c1d\u8bd5\u8c03\u6574\u53c2\u6570\uff0c\u4f46\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u8bf4\u660e\u3002 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5728\u4e0d\u8c03\u6574\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c06\u6570\u503c\u7279\u5f81\u7eb3\u5165 xgboost \u6a21\u578b\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) # \u52a0\u5165\u6570\u503c\u7279\u5f81 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . 
columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : if col not in num_cols : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u56e0\u6b64\uff0c\u6211\u4eec\u4fdd\u7559\u6570\u5b57\u5217\uff0c\u53ea\u662f\u4e0d\u5bf9\u5176\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u7684\u6700\u7ec8\u7279\u5f81\u77e9\u9635\u5c31\u7531\u6570\u5b57\u5217\uff08\u539f\u6837\uff09\u548c\u7f16\u7801\u5206\u7c7b\u5217\u7ec4\u6210\u4e86\u3002\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\u90fd\u80fd\u8f7b\u677e\u5904\u7406\u8fd9\u79cd\u6df7\u5408\u3002 \u8bf7\u6ce8\u610f\uff0c\u5728\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65f6\uff0c\u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u6570\u636e\u8fdb\u884c\u5f52\u4e00\u5316\u5904\u7406\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff0c\u5728\u4f7f\u7528\u7ebf\u6027\u6a21\u578b\uff08\u5982\u903b\u8f91\u56de\u5f52\uff09\u65f6\u4e0d\u5bb9\u5ffd\u89c6\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u4e2a\u811a\u672c\uff01 \u276f python lbl_xgb_num . 
py Fold = 0 , AUC = 0.9209790185449889 Fold = 1 , AUC = 0.9247157449144706 Fold = 2 , AUC = 0.9269329887598243 Fold = 3 , AUC = 0.9119349082169275 Fold = 4 , AUC = 0.9166408030141667 \u54c7\u54e6 \u8fd9\u662f\u4e00\u4e2a\u5f88\u597d\u7684\u5206\u6570\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c1d\u8bd5\u6dfb\u52a0\u4e00\u4e9b\u529f\u80fd\u3002\u6211\u4eec\u5c06\u63d0\u53d6\u6240\u6709\u5206\u7c7b\u5217\uff0c\u5e76\u521b\u5efa\u6240\u6709\u4e8c\u5ea6\u7ec4\u5408\u3002\u8bf7\u770b\u4e0b\u9762\u4ee3\u7801\u6bb5\u4e2d\u7684 feature_engineering \u51fd\u6570\uff0c\u4e86\u89e3\u5982\u4f55\u5b9e\u73b0\u8fd9\u4e00\u70b9\u3002 import itertools import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def feature_engineering ( df , cat_cols ): # \u751f\u6210\u4e24\u4e2a\u7279\u5f81\u7684\u7ec4\u5408 combi = list ( itertools . combinations ( cat_cols , 2 )) for c1 , c2 in combi : df . loc [:, c1 + \"_\" + c2 ] = df [ c1 ] . astype ( str ) + \"_\" + df [ c2 ] . astype ( str ) return df def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) cat_cols = [ c for c in df . columns if c not in num_cols and c not in ( \"kfold\" , \"income\" )] # \u7279\u5f81\u5de5\u7a0b df = feature_engineering ( df , cat_cols ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" )] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . 
reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u8fd9\u662f\u4ece\u5206\u7c7b\u5217\u4e2d\u521b\u5efa\u7279\u5f81\u7684\u4e00\u79cd\u975e\u5e38\u5e7c\u7a1a\u7684\u65b9\u6cd5\u3002\u6211\u4eec\u5e94\u8be5\u4ed4\u7ec6\u7814\u7a76\u6570\u636e\uff0c\u770b\u770b\u54ea\u4e9b\u7ec4\u5408\u6700\u5408\u7406\u3002\u5982\u679c\u4f7f\u7528\u8fd9\u79cd\u65b9\u6cd5\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5c31\u9700\u8981\u4f7f\u7528\u67d0\u79cd\u7279\u5f81\u9009\u62e9\u6765\u9009\u51fa\u6700\u4f73\u7279\u5f81\u3002\u7a0d\u540e\u6211\u4eec\u5c06\u8be6\u7ec6\u4ecb\u7ecd\u7279\u5f81\u9009\u62e9\u3002\u73b0\u5728\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5206\u6570\u3002 \u276f python lbl_xgb_num_feat . py Fold = 0 , AUC = 0.9211483465031423 Fold = 1 , AUC = 0.9251499446866125 Fold = 2 , AUC = 0.9262344766486692 Fold = 3 , AUC = 0.9114264068794995 Fold = 4 , AUC = 0.9177914453099201 \u770b\u6765\uff0c\u5373\u4f7f\u4e0d\u6539\u53d8\u4efb\u4f55\u8d85\u53c2\u6570\uff0c\u53ea\u589e\u52a0\u4e00\u4e9b\u7279\u5f81\uff0c\u6211\u4eec\u4e5f\u80fd\u63d0\u9ad8\u4e00\u4e9b\u6298\u53e0\u5f97\u5206\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5c06 max_depth \u589e\u52a0\u5230 7 \u662f\u5426\u6709\u5e2e\u52a9\u3002 \u276f python lbl_xgb_num_feat . 
py Fold = 0 , AUC = 0.9286668430204137 Fold = 1 , AUC = 0.9329340656165378 Fold = 2 , AUC = 0.9319817543218744 Fold = 3 , AUC = 0.919046187194538 Fold = 4 , AUC = 0.9245692057162671 \u6211\u4eec\u518d\u6b21\u6539\u8fdb\u4e86\u6211\u4eec\u7684\u6a21\u578b\u3002 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u8fd8\u6ca1\u6709\u4f7f\u7528\u7a00\u6709\u503c\u3001\u4e8c\u503c\u5316\u3001\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u7279\u5f81\u7684\u7ec4\u5408\u4ee5\u53ca\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u3002 \u4ece\u5206\u7c7b\u7279\u5f81\u4e2d\u8fdb\u884c\u7279\u5f81\u5de5\u7a0b\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u4f7f\u7528 \u76ee\u6807\u7f16\u7801 \u3002\u4f46\u662f\uff0c\u60a8\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u8fd9\u53ef\u80fd\u4f1a\u4f7f\u60a8\u7684\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u76ee\u6807\u7f16\u7801\u662f\u4e00\u79cd\u5c06\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u6620\u5c04\u5230\u5176\u5e73\u5747\u76ee\u6807\u503c\u7684\u6280\u672f\uff0c\u4f46\u5fc5\u987b\u59cb\u7ec8\u4ee5\u4ea4\u53c9\u9a8c\u8bc1\u7684\u65b9\u5f0f\u8fdb\u884c\u3002\u8fd9\u610f\u5473\u7740\u9996\u5148\u8981\u521b\u5efa\u6298\u53e0\uff0c\u7136\u540e\u4f7f\u7528\u8fd9\u4e9b\u6298\u53e0\u4e3a\u6570\u636e\u7684\u4e0d\u540c\u5217\u521b\u5efa\u76ee\u6807\u7f16\u7801\u7279\u5f81\uff0c\u65b9\u6cd5\u4e0e\u5728\u6298\u53e0\u4e0a\u62df\u5408\u548c\u9884\u6d4b\u6a21\u578b\u7684\u65b9\u6cd5\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u521b\u5efa\u4e86 5 \u4e2a\u6298\u53e0\uff0c\u60a8\u5c31\u5fc5\u987b\u521b\u5efa 5 
\u6b21\u76ee\u6807\u7f16\u7801\uff0c\u8fd9\u6837\u6700\u7ec8\uff0c\u60a8\u5c31\u53ef\u4ee5\u4e3a\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u53d8\u91cf\u521b\u5efa\u7f16\u7801\uff0c\u800c\u8fd9\u4e9b\u53d8\u91cf\u5e76\u975e\u6765\u81ea\u540c\u4e00\u4e2a\u6298\u53e0\u3002\u7136\u540e\u5728\u62df\u5408\u6a21\u578b\u65f6\uff0c\u5fc5\u987b\u518d\u6b21\u4f7f\u7528\u76f8\u540c\u7684\u6298\u53e0\u3002\u672a\u89c1\u6d4b\u8bd5\u6570\u636e\u7684\u76ee\u6807\u7f16\u7801\u53ef\u4ee5\u6765\u81ea\u5168\u90e8\u8bad\u7ec3\u6570\u636e\uff0c\u4e5f\u53ef\u4ee5\u662f\u6240\u6709 5 \u4e2a\u6298\u53e0\u7684\u5e73\u5747\u503c\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u5728\u540c\u4e00\u4e2a\u6210\u4eba\u6570\u636e\u96c6\u4e0a\u4f7f\u7528\u76ee\u6807\u7f16\u7801\uff0c\u4ee5\u4fbf\u8fdb\u884c\u6bd4\u8f83\u3002 import copy import pandas as pd from sklearn import metrics from sklearn import preprocessing import xgboost as xgb def mean_target_encoding ( data ): df = copy . deepcopy ( data ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) and f not in num_cols ] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) encoded_dfs = [] for fold in range ( 5 ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) for column in features : # \u76ee\u6807\u7f16\u7801 mapping_dict = dict ( df_train . groupby ( column )[ \"income\" ] . mean () ) df_valid . loc [:, column + \"_enc\" ] = df_valid [ column ] . map ( mapping_dict ) encoded_dfs . 
append ( df_valid ) encoded_df = pd . concat ( encoded_dfs , axis = 0 ) return encoded_df def run ( df , fold ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/adult_folds.csv\" ) df = mean_target_encoding ( df ) for fold_ in range ( 5 ): run ( df , fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u4e0a\u8ff0\u7247\u6bb5\u4e2d\uff0c\u6211\u5728\u8fdb\u884c\u76ee\u6807\u7f16\u7801\u65f6\u5e76\u6ca1\u6709\u5220\u9664\u5206\u7c7b\u5217\u3002\u6211\u4fdd\u7559\u4e86\u6240\u6709\u7279\u5f81\uff0c\u5e76\u5728\u6b64\u57fa\u7840\u4e0a\u6dfb\u52a0\u4e86\u76ee\u6807\u7f16\u7801\u7279\u5f81\u3002\u6b64\u5916\uff0c\u6211\u8fd8\u4f7f\u7528\u4e86\u5e73\u5747\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5e73\u5747\u503c\u3001\u4e2d\u4f4d\u6570\u3001\u6807\u51c6\u504f\u5dee\u6216\u76ee\u6807\u7684\u4efb\u4f55\u5176\u4ed6\u51fd\u6570\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u7ed3\u679c\u3002 Fold = 0 , AUC = 0.9332240662017529 Fold = 1 , AUC = 0.9363551625140347 Fold = 2 , AUC = 0.9375013544556173 Fold = 3 , AUC = 0.92237621307625 Fold = 4 , AUC = 0.9292131180445478 
\u4e0d\u9519\uff01\u770b\u6765\u6211\u4eec\u53c8\u6709\u8fdb\u6b65\u4e86\u3002\u4e0d\u8fc7\uff0c\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u5b83\u592a\u5bb9\u6613\u51fa\u73b0\u8fc7\u5ea6\u62df\u5408\u3002\u5f53\u6211\u4eec\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\uff0c\u6700\u597d\u4f7f\u7528\u67d0\u79cd\u5e73\u6ed1\u65b9\u6cd5\u6216\u5728\u7f16\u7801\u503c\u4e2d\u6dfb\u52a0\u566a\u58f0\u3002 Scikit-learn \u7684\u8d21\u732e\u5e93\u4e2d\u6709\u5e26\u5e73\u6ed1\u7684\u76ee\u6807\u7f16\u7801\uff0c\u4f60\u4e5f\u53ef\u4ee5\u521b\u5efa\u81ea\u5df1\u7684\u5e73\u6ed1\u3002\u5e73\u6ed1\u4f1a\u5f15\u5165\u67d0\u79cd\u6b63\u5219\u5316\uff0c\u6709\u52a9\u4e8e\u907f\u514d\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u8fd9\u5e76\u4e0d\u96be\u3002 \u5904\u7406\u5206\u7c7b\u7279\u5f81\u662f\u4e00\u9879\u590d\u6742\u7684\u4efb\u52a1\u3002\u8bb8\u591a\u8d44\u6e90\u4e2d\u90fd\u6709\u5927\u91cf\u4fe1\u606f\u3002\u672c\u7ae0\u5e94\u8be5\u80fd\u5e2e\u52a9\u4f60\u5f00\u59cb\u89e3\u51b3\u5206\u7c7b\u53d8\u91cf\u7684\u4efb\u4f55\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u5bf9\u4e8e\u5927\u591a\u6570\u95ee\u9898\u6765\u8bf4\uff0c\u9664\u4e86\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u4e4b\u5916\uff0c\u4f60\u4e0d\u9700\u8981\u66f4\u591a\u7684\u4e1c\u897f\u3002 \u8981\u8fdb\u4e00\u6b65\u6539\u8fdb\u6a21\u578b\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u66f4\u591a\uff01 \u5728\u672c\u7ae0\u7684\u6700\u540e\uff0c\u6211\u4eec\u4e0d\u80fd\u4e0d\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e00\u79cd\u79f0\u4e3a \u5b9e\u4f53\u5d4c\u5165 \u7684\u6280\u672f\u3002\u5728\u5b9e\u4f53\u5d4c\u5165\u4e2d\uff0c\u7c7b\u522b\u7528\u5411\u91cf\u8868\u793a\u3002\u5728\u4e8c\u503c\u5316\u548c\u72ec\u70ed\u7f16\u7801\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u90fd\u662f\u7528\u5411\u91cf\u6765\u8868\u793a\u7c7b\u522b\u7684\u3002 
\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u6709\u6570\u4ee5\u4e07\u8ba1\u7684\u7c7b\u522b\u600e\u4e48\u529e\uff1f\u8fd9\u5c06\u4f1a\u4ea7\u751f\u5de8\u5927\u7684\u77e9\u9635\uff0c\u6211\u4eec\u5c06\u9700\u8981\u5f88\u957f\u65f6\u95f4\u6765\u8bad\u7ec3\u590d\u6742\u7684\u6a21\u578b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u7528\u5e26\u6709\u6d6e\u70b9\u503c\u7684\u5411\u91cf\u6765\u8868\u793a\u5b83\u4eec\u3002 \u8fd9\u4e2a\u60f3\u6cd5\u975e\u5e38\u7b80\u5355\u3002\u6bcf\u4e2a\u5206\u7c7b\u7279\u5f81\u90fd\u6709\u4e00\u4e2a\u5d4c\u5165\u5c42\u3002\u56e0\u6b64\uff0c\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u73b0\u5728\u90fd\u53ef\u4ee5\u6620\u5c04\u5230\u4e00\u4e2a\u5d4c\u5165\u5c42\uff08\u5c31\u50cf\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u5c06\u5355\u8bcd\u6620\u5c04\u5230\u5d4c\u5165\u5c42\u4e00\u6837\uff09\u3002\u7136\u540e\uff0c\u6839\u636e\u5176\u7ef4\u5ea6\u91cd\u5851\u8fd9\u4e9b\u5d4c\u5165\u5c42\uff0c\u4f7f\u5176\u6241\u5e73\u5316\uff0c\u7136\u540e\u5c06\u6240\u6709\u6241\u5e73\u5316\u7684\u8f93\u5165\u5d4c\u5165\u5c42\u8fde\u63a5\u8d77\u6765\u3002\u7136\u540e\u6dfb\u52a0\u4e00\u5806\u5bc6\u96c6\u5c42\u548c\u4e00\u4e2a\u8f93\u51fa\u5c42\uff0c\u5c31\u5927\u529f\u544a\u6210\u4e86\u3002 \u56fe 6\uff1a\u7c7b\u522b\u8f6c\u6362\u4e3a\u6d6e\u70b9\u6216\u5d4c\u5165\u5411\u91cf \u51fa\u4e8e\u67d0\u79cd\u539f\u56e0\uff0c\u6211\u53d1\u73b0\u4f7f\u7528 TF/Keras \u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528 TF/Keras \u5b9e\u73b0\u5b83\u3002\u6b64\u5916\uff0c\u8fd9\u662f\u672c\u4e66\u4e2d\u552f\u4e00\u4e00\u4e2a\u4f7f\u7528 TF/Keras \u7684\u793a\u4f8b\uff0c\u5c06\u5176\u8f6c\u6362\u4e3a PyTorch\uff08\u4f7f\u7528 cat-in-the-dat-ii \u6570\u636e\u96c6\uff09\u4e5f\u975e\u5e38\u5bb9\u6613 import os import gc import joblib import pandas as pd import numpy as np from sklearn import metrics , preprocessing from tensorflow.keras import layers from 
tensorflow.keras import optimizers from tensorflow.keras.models import Model , load_model from tensorflow.keras import callbacks from tensorflow.keras import backend as K from tensorflow.keras import utils def create_model ( data , catcols ): # \u521b\u5efa\u7a7a\u7684\u8f93\u5165\u5217\u8868\u548c\u8f93\u51fa\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6a21\u578b\u7684\u8f93\u5165\u548c\u8f93\u51fa inputs = [] outputs = [] # \u904d\u5386\u5206\u7c7b\u7279\u5f81\u5217\u8868\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81 for c in catcols : # \u8ba1\u7b97\u7279\u5f81\u4e2d\u552f\u4e00\u503c\u7684\u6570\u91cf num_unique_values = int ( data [ c ] . nunique ()) # \u8ba1\u7b97\u5d4c\u5165\u7ef4\u5ea6\uff0c\u6700\u5927\u4e0d\u8d85\u8fc750 embed_dim = int ( min ( np . ceil (( num_unique_values ) / 2 ), 50 )) # \u521b\u5efa\u6a21\u578b\u7684\u8f93\u5165\u5c42\uff0c\u6bcf\u4e2a\u7279\u5f81\u5bf9\u5e94\u4e00\u4e2a\u8f93\u5165 inp = layers . Input ( shape = ( 1 ,)) # \u521b\u5efa\u5d4c\u5165\u5c42\uff0c\u5c06\u5206\u7c7b\u7279\u5f81\u6620\u5c04\u5230\u4f4e\u7ef4\u5ea6\u7684\u8fde\u7eed\u5411\u91cf out = layers . Embedding ( num_unique_values + 1 , embed_dim , name = c )( inp ) # \u5bf9\u5d4c\u5165\u5c42\u8fdb\u884c\u7a7a\u95f4\u4e22\u5f03\uff08Dropout\uff09 out = layers . SpatialDropout1D ( 0.3 )( out ) # \u5c06\u5d4c\u5165\u5c42\u7684\u5f62\u72b6\u91cd\u65b0\u8c03\u6574\u4e3a\u4e00\u7ef4 out = layers . Reshape ( target_shape = ( embed_dim ,))( out ) # \u5c06\u8f93\u5165\u548c\u8f93\u51fa\u6dfb\u52a0\u5230\u5bf9\u5e94\u7684\u5217\u8868\u4e2d inputs . append ( inp ) outputs . append ( out ) # \u4f7f\u7528Concatenate\u5c42\u5c06\u6240\u6709\u7684\u5d4c\u5165\u5c42\u8f93\u51fa\u8fde\u63a5\u5728\u4e00\u8d77 x = layers . Concatenate ()( outputs ) # \u5bf9\u8fde\u63a5\u540e\u7684\u6570\u636e\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . 
BatchNormalization ()( x ) # \u6dfb\u52a0\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u6dfb\u52a0\u53e6\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u8f93\u51fa\u5c42\uff0c\u5177\u67092\u4e2a\u795e\u7ecf\u5143\uff08\u7528\u4e8e\u4e8c\u8fdb\u5236\u5206\u7c7b\uff09\uff0c\u5e76\u4f7f\u7528softmax\u6fc0\u6d3b\u51fd\u6570 y = layers . Dense ( 2 , activation = \"softmax\" )( x ) # \u521b\u5efa\u6a21\u578b\uff0c\u5c06\u8f93\u5165\u548c\u8f93\u51fa\u4f20\u9012\u7ed9Model\u6784\u9020\u51fd\u6570 model = Model ( inputs = inputs , outputs = y ) # \u7f16\u8bd1\u6a21\u578b\uff0c\u6307\u5b9a\u635f\u5931\u51fd\u6570\u548c\u4f18\u5316\u5668 model . compile ( loss = 'binary_crossentropy' , optimizer = 'adam' ) # \u8fd4\u56de\u521b\u5efa\u7684\u6a21\u578b return model def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for feat in features : lbl_enc = preprocessing . LabelEncoder () df . loc [:, feat ] = lbl_enc . fit_transform ( df [ feat ] . values ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . 
reset_index ( drop = True ) model = create_model ( df , features ) xtrain = [ df_train [ features ] . values [:, k ] for k in range ( len ( features ))] xvalid = [ df_valid [ features ] . values [:, k ] for k in range ( len ( features )) ] ytrain = df_train . target . values yvalid = df_valid . target . values ytrain_cat = utils . to_categorical ( ytrain ) yvalid_cat = utils . to_categorical ( yvalid ) model . fit ( xtrain , ytrain_cat , validation_data = ( xvalid , yvalid_cat ), verbose = 1 , batch_size = 1024 , epochs = 3 ) valid_preds = model . predict ( xvalid )[:, 1 ] print ( metrics . roc_auc_score ( yvalid , valid_preds )) K . clear_session () if __name__ == \"__main__\" : run ( 0 ) run ( 1 ) run ( 2 ) run ( 3 ) run ( 4 ) \u4f60\u4f1a\u53d1\u73b0\u8fd9\u79cd\u65b9\u6cd5\u6548\u679c\u6700\u597d\uff0c\u800c\u4e14\u5982\u679c\u4f60\u6709 GPU\uff0c\u901f\u5ea6\u4e5f\u8d85\u5feb\uff01\u8fd9\u79cd\u65b9\u6cd5\u8fd8\u53ef\u4ee5\u8fdb\u4e00\u6b65\u6539\u8fdb\uff0c\u800c\u4e14\u4f60\u65e0\u9700\u62c5\u5fc3\u7279\u5f81\u5de5\u7a0b\uff0c\u56e0\u4e3a\u795e\u7ecf\u7f51\u7edc\u4f1a\u81ea\u884c\u5904\u7406\u3002\u5728\u5904\u7406\u5927\u91cf\u5206\u7c7b\u7279\u5f81\u6570\u636e\u96c6\u65f6\uff0c\u8fd9\u7edd\u5bf9\u503c\u5f97\u4e00\u8bd5\u3002\u5f53\u5d4c\u5165\u5927\u5c0f\u4e0e\u552f\u4e00\u7c7b\u522b\u7684\u6570\u91cf\u76f8\u540c\u65f6\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u72ec\u70ed\u7f16\u7801\uff08one-hot-encoding\uff09\u3002 
\u672c\u7ae0\u57fa\u672c\u4e0a\u90fd\u662f\u5173\u4e8e\u7279\u5f81\u5de5\u7a0b\u7684\u3002\u8ba9\u6211\u4eec\u5728\u4e0b\u4e00\u7ae0\u4e2d\u770b\u770b\u5982\u4f55\u5728\u6570\u5b57\u7279\u5f81\u548c\u4e0d\u540c\u7c7b\u578b\u7279\u5f81\u7684\u7ec4\u5408\u65b9\u9762\u8fdb\u884c\u66f4\u591a\u7684\u7279\u5f81\u5de5\u7a0b\u3002","title":"\u5904\u7406\u5206\u7c7b\u53d8\u91cf"},{"location":"%E5%A4%84%E7%90%86%E5%88%86%E7%B1%BB%E5%8F%98%E9%87%8F/#_1","text":"\u5f88\u591a\u4eba\u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u65f6\u90fd\u4f1a\u9047\u5230\u5f88\u591a\u56f0\u96be\uff0c\u56e0\u6b64\u8fd9\u503c\u5f97\u7528\u6574\u6574\u4e00\u7ae0\u7684\u7bc7\u5e45\u6765\u8ba8\u8bba\u3002\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u5c06\u8bb2\u8ff0\u4e0d\u540c\u7c7b\u578b\u7684\u5206\u7c7b\u6570\u636e\uff0c\u4ee5\u53ca\u5982\u4f55\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u3002 \u4ec0\u4e48\u662f\u5206\u7c7b\u53d8\u91cf\uff1f \u5206\u7c7b\u53d8\u91cf/\u7279\u5f81\u662f\u6307\u4efb\u4f55\u7279\u5f81\u7c7b\u578b\uff0c\u53ef\u5206\u4e3a\u4e24\u5927\u7c7b\uff1a - \u65e0\u5e8f - \u6709\u5e8f \u65e0\u5e8f\u53d8\u91cf \u662f\u6307\u6709\u4e24\u4e2a\u6216\u4e24\u4e2a\u4ee5\u4e0a\u7c7b\u522b\u7684\u53d8\u91cf\uff0c\u8fd9\u4e9b\u7c7b\u522b\u6ca1\u6709\u4efb\u4f55\u76f8\u5173\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u5982\u679c\u5c06\u6027\u522b\u5206\u4e3a\u4e24\u7ec4\uff0c\u5373\u7537\u6027\u548c\u5973\u6027\uff0c\u5219\u53ef\u5c06\u5176\u89c6\u4e3a\u540d\u4e49\u53d8\u91cf\u3002 \u6709\u5e8f\u53d8\u91cf \u5219\u6709 \"\u7b49\u7ea7 \"\u6216\u7c7b\u522b\uff0c\u5e76\u6709\u7279\u5b9a\u7684\u987a\u5e8f\u3002\u4f8b\u5982\uff0c\u4e00\u4e2a\u987a\u5e8f\u5206\u7c7b\u53d8\u91cf\u53ef\u4ee5\u662f\u4e00\u4e2a\u5177\u6709\u4f4e\u3001\u4e2d\u3001\u9ad8\u4e09\u4e2a\u4e0d\u540c\u7b49\u7ea7\u7684\u7279\u5f81\u3002\u987a\u5e8f\u5f88\u91cd\u8981\u3002 \u5c31\u5b9a\u4e49\u800c\u8a00\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06\u5206\u7c7b\u53d8\u91cf\u5206\u4e3a \u4e8c\u5143\u53d8\u91cf 
\uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u7c7b\u522b\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u6709\u4e9b\u4eba\u751a\u81f3\u628a\u5206\u7c7b\u53d8\u91cf\u79f0\u4e3a \" \u5faa\u73af \"\u53d8\u91cf\u3002\u5468\u671f\u53d8\u91cf\u4ee5 \"\u5468\u671f \"\u7684\u5f62\u5f0f\u5b58\u5728\uff0c\u4f8b\u5982\u4e00\u5468\u4e2d\u7684\u5929\u6570\uff1a \u5468\u65e5\u3001\u5468\u4e00\u3001\u5468\u4e8c\u3001\u5468\u4e09\u3001\u5468\u56db\u3001\u5468\u4e94\u548c\u5468\u516d\u3002\u5468\u516d\u8fc7\u540e\uff0c\u53c8\u662f\u5468\u65e5\u3002\u8fd9\u5c31\u662f\u4e00\u4e2a\u5faa\u73af\u3002\u53e6\u4e00\u4e2a\u4f8b\u5b50\u662f\u4e00\u5929\u4e2d\u7684\u5c0f\u65f6\u6570\uff0c\u5982\u679c\u6211\u4eec\u5c06\u5b83\u4eec\u89c6\u4e3a\u7c7b\u522b\u7684\u8bdd\u3002 \u5206\u7c7b\u53d8\u91cf\u6709\u5f88\u591a\u4e0d\u540c\u7684\u5b9a\u4e49\uff0c\u5f88\u591a\u4eba\u4e5f\u8c08\u5230\u8981\u6839\u636e\u5206\u7c7b\u53d8\u91cf\u7684\u7c7b\u578b\u6765\u5904\u7406\u4e0d\u540c\u7684\u5206\u7c7b\u53d8\u91cf\u3002\u4e0d\u8fc7\uff0c\u6211\u8ba4\u4e3a\u6ca1\u6709\u5fc5\u8981\u8fd9\u6837\u505a\u3002\u6240\u6709\u6d89\u53ca\u5206\u7c7b\u53d8\u91cf\u7684\u95ee\u9898\u90fd\u53ef\u4ee5\u7528\u540c\u6837\u7684\u65b9\u6cd5\u5904\u7406\u3002 \u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e00\u4e2a\u6570\u636e\u96c6\uff08\u4e00\u5982\u65e2\u5f80\uff09\u3002\u8981\u4e86\u89e3\u5206\u7c7b\u53d8\u91cf\uff0c\u6700\u597d\u7684\u514d\u8d39\u6570\u636e\u96c6\u4e4b\u4e00\u662f Kaggle \u5206\u7c7b\u7279\u5f81\u7f16\u7801\u6311\u6218\u8d5b\u4e2d\u7684 cat-in-the-dat \u3002\u5171\u6709\u4e24\u4e2a\u6311\u6218\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6311\u6218\u7684\u6570\u636e\uff0c\u56e0\u4e3a\u5b83\u6bd4\u524d\u4e00\u4e2a\u7248\u672c\u6709\u66f4\u591a\u53d8\u91cf\uff0c\u96be\u5ea6\u4e5f\u66f4\u5927\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u3002 \u56fe 1\uff1aCat-in-the-dat-ii challenge\u90e8\u5206\u6570\u636e\u5c55\u793a \u6570\u636e\u96c6\u7531\u5404\u79cd\u5206\u7c7b\u53d8\u91cf\u7ec4\u6210\uff1a 
\u65e0\u5e8f \u6709\u5e8f \u5faa\u73af \u4e8c\u5143 \u5728\u56fe 1 \u4e2d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6240\u6709\u5b58\u5728\u7684\u53d8\u91cf\u548c\u76ee\u6807\u53d8\u91cf\u7684\u5b50\u96c6\u3002 \u8fd9\u662f\u4e00\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002 \u76ee\u6807\u53d8\u91cf\u5bf9\u4e8e\u6211\u4eec\u5b66\u4e60\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\u5e76\u4e0d\u5341\u5206\u91cd\u8981\uff0c\u4f46\u6700\u7ec8\u6211\u4eec\u5c06\u5efa\u7acb\u4e00\u4e2a\u7aef\u5230\u7aef\u6a21\u578b\uff0c\u56e0\u6b64\u8ba9\u6211\u4eec\u770b\u770b\u56fe 2 \u4e2d\u7684\u76ee\u6807\u53d8\u91cf\u5206\u5e03\u3002\u6211\u4eec\u770b\u5230\u76ee\u6807\u662f \u504f\u659c \u7684\uff0c\u56e0\u6b64\u5bf9\u4e8e\u8fd9\u4e2a\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u6765\u8bf4\uff0c\u6700\u597d\u7684\u6307\u6807\u662f ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528\u7cbe\u786e\u5ea6\u548c\u53ec\u56de\u7387\uff0c\u4f46 AUC \u7ed3\u5408\u4e86\u8fd9\u4e24\u4e2a\u6307\u6807\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 AUC \u6765\u8bc4\u4f30\u6211\u4eec\u5728\u8be5\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u7684\u6a21\u578b\u3002 \u56fe 2\uff1a\u6807\u7b7e\u8ba1\u6570\u3002X \u8f74\u8868\u793a\u6807\u7b7e\uff0cY \u8f74\u8868\u793a\u6807\u7b7e\u8ba1\u6570 \u603b\u4f53\u800c\u8a00\uff0c\u6709\uff1a 5\u4e2a\u4e8c\u5143\u53d8\u91cf 10\u4e2a\u65e0\u5e8f\u53d8\u91cf 6\u4e2a\u6709\u5e8f\u53d8\u91cf 2\u4e2a\u5faa\u73af\u53d8\u91cf 1\u4e2a\u76ee\u6807\u53d8\u91cf \u8ba9\u6211\u4eec\u6765\u770b\u770b\u6570\u636e\u96c6\u4e2d\u7684 ord_2 \u7279\u5f81\u3002\u5b83\u5305\u62ec6\u4e2a\u4e0d\u540c\u7684\u7c7b\u522b\uff1a - \u51b0\u51bb - \u6e29\u6696 - \u5bd2\u51b7 - \u8f83\u70ed - \u70ed - \u975e\u5e38\u70ed 
\u6211\u4eec\u5fc5\u987b\u77e5\u9053\uff0c\u8ba1\u7b97\u673a\u65e0\u6cd5\u7406\u89e3\u6587\u672c\u6570\u636e\uff0c\u56e0\u6b64\u6211\u4eec\u9700\u8981\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002\u4e00\u4e2a\u7b80\u5355\u7684\u65b9\u6cd5\u662f\u521b\u5efa\u4e00\u4e2a\u5b57\u5178\uff0c\u5c06\u8fd9\u4e9b\u503c\u6620\u5c04\u4e3a\u4ece 0 \u5230 N-1 \u7684\u6570\u5b57\uff0c\u5176\u4e2d N \u662f\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7c7b\u522b\u7684\u603b\u6570\u3002 # \u6620\u5c04\u5b57\u5178 mapping = { \"Freezing\" : 0 , \"Warm\" : 1 , \"Cold\" : 2 , \"Boiling Hot\" : 3 , \"Hot\" : 4 , \"Lava Hot\" : 5 } \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8bfb\u53d6\u6570\u636e\u96c6\uff0c\u5e76\u8f7b\u677e\u5730\u5c06\u8fd9\u4e9b\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\u3002 import pandas as pd # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u53d6*ord_2*\u5217\uff0c\u5e76\u4f7f\u7528\u6620\u5c04\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57 df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. map ( mapping ) \u6620\u5c04\u524d\u7684\u6570\u503c\u8ba1\u6570\uff1a df .* ord_2 *. value_counts () Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : * ord_2 * , dtype : int64 \u6620\u5c04\u540e\u7684\u6570\u503c\u8ba1\u6570\uff1a 0.0 142726 1.0 124239 2.0 97822 3.0 84790 4.0 67508 5.0 64840 Name : * ord_2 * , dtype : int64 \u8fd9\u79cd\u5206\u7c7b\u53d8\u91cf\u7684\u7f16\u7801\u65b9\u5f0f\u88ab\u79f0\u4e3a\u6807\u7b7e\u7f16\u7801\uff08Label Encoding\uff09\u6211\u4eec\u5c06\u6bcf\u4e2a\u7c7b\u522b\u7f16\u7801\u4e3a\u4e00\u4e2a\u6570\u5b57\u6807\u7b7e\u3002 \u6211\u4eec\u4e5f\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u4e2d\u7684 LabelEncoder \u8fdb\u884c\u7f16\u7801\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u5c06\u7f3a\u5931\u503c\u586b\u5145\u4e3a\"NONE\" df . loc [:, \"*ord_2*\" ] = df .* ord_2 *. 
fillna ( \"NONE\" ) # LabelEncoder\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u8f6c\u6362\u6570\u636e df . loc [:, \"*ord_2*\" ] = lbl_enc . fit_transform ( df .* ord_2 *. values ) \u4f60\u4f1a\u770b\u5230\u6211\u4f7f\u7528\u4e86 pandas \u7684 fillna\u3002\u539f\u56e0\u662f scikit-learn \u7684 LabelEncoder \u65e0\u6cd5\u5904\u7406 NaN \u503c\uff0c\u800c ord_2 \u5217\u4e2d\u6709 NaN \u503c\u3002 \u6211\u4eec\u53ef\u4ee5\u5728\u8bb8\u591a\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u4e2d\u76f4\u63a5\u4f7f\u7528\u5b83\uff1a - \u51b3\u7b56\u6811 - \u968f\u673a\u68ee\u6797 - \u63d0\u5347\u6811 - \u6216\u4efb\u4f55\u4e00\u79cd\u63d0\u5347\u6811\u6a21\u578b - XGBoost - GBM - LightGBM \u8fd9\u79cd\u7f16\u7801\u65b9\u5f0f\u4e0d\u80fd\u7528\u4e8e\u7ebf\u6027\u6a21\u578b\u3001\u652f\u6301\u5411\u91cf\u673a\u6216\u795e\u7ecf\u7f51\u7edc\uff0c\u56e0\u4e3a\u5b83\u4eec\u5e0c\u671b\u6570\u636e\u662f\u6807\u51c6\u5316\u7684\u3002 \u5bf9\u4e8e\u8fd9\u4e9b\u7c7b\u578b\u7684\u6a21\u578b\uff0c\u6211\u4eec\u53ef\u4ee5\u5bf9\u6570\u636e\u8fdb\u884c\u4e8c\u503c\u5316\uff08binarize\uff09\u5904\u7406\u3002 Freezing --> 0 --> 0 0 0 Warm --> 1 --> 0 0 1 Cold --> 2 --> 0 1 0 Boiling Hot --> 3 --> 0 1 1 Hot --> 4 --> 1 0 0 Lava Hot --> 5 --> 1 0 1 \u8fd9\u53ea\u662f\u5c06\u7c7b\u522b\u8f6c\u6362\u4e3a\u6570\u5b57\uff0c\u7136\u540e\u518d\u8f6c\u6362\u4e3a\u4e8c\u503c\u5316\u8868\u793a\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u628a\u4e00\u4e2a\u7279\u5f81\u5206\u6210\u4e86\u4e09\u4e2a\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u7279\u5f81\uff08\u6216\u5217\uff09\u3002\u5982\u679c\u6211\u4eec\u6709\u66f4\u591a\u7684\u7c7b\u522b\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u5206\u6210\u66f4\u591a\u7684\u5217\u3002 
\u5982\u679c\u6211\u4eec\u7528\u7a00\u758f\u683c\u5f0f\u5b58\u50a8\u5927\u91cf\u4e8c\u503c\u5316\u53d8\u91cf\uff0c\u5c31\u53ef\u4ee5\u8f7b\u677e\u5730\u5b58\u50a8\u8fd9\u4e9b\u53d8\u91cf\u3002\u7a00\u758f\u683c\u5f0f\u4e0d\u8fc7\u662f\u4e00\u79cd\u5728\u5185\u5b58\u4e2d\u5b58\u50a8\u6570\u636e\u7684\u8868\u793a\u6216\u65b9\u5f0f\uff0c\u5728\u8fd9\u79cd\u683c\u5f0f\u4e2d\uff0c\u4f60\u5e76\u4e0d\u5b58\u50a8\u6240\u6709\u7684\u503c\uff0c\u800c\u53ea\u5b58\u50a8\u91cd\u8981\u7684\u503c\u3002\u5728\u4e0a\u8ff0\u4e8c\u8fdb\u5236\u53d8\u91cf\u7684\u60c5\u51b5\u4e2d\uff0c\u6700\u91cd\u8981\u7684\u5c31\u662f\u6709 1 \u7684\u5730\u65b9\u3002 \u5f88\u96be\u60f3\u8c61\u8fd9\u6837\u7684\u683c\u5f0f\uff0c\u4f46\u4e3e\u4e2a\u4f8b\u5b50\u5c31\u4f1a\u660e\u767d\u3002 \u5047\u8bbe\u4e0a\u9762\u7684\u6570\u636e\u5e27\u4e2d\u53ea\u6709\u4e00\u4e2a\u7279\u5f81\uff1a ord_2 \u3002 Index Feature 0 Warm 1 Hot 2 Lava hot \u76ee\u524d\uff0c\u6211\u4eec\u53ea\u770b\u5230\u6570\u636e\u96c6\u4e2d\u7684\u4e09\u4e2a\u6837\u672c\u3002\u8ba9\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u4e8c\u503c\u8868\u793a\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u6709\u4e09\u4e2a\u9879\u76ee\u3002 \u8fd9\u4e09\u4e2a\u9879\u76ee\u5c31\u662f\u4e09\u4e2a\u7279\u5f81\u3002 Index Feature_0 Feature_1 Feature_2 0 0 0 1 1 1 0 0 2 1 0 1 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u7279\u5f81\u5b58\u50a8\u5728\u4e00\u4e2a\u6709 3 \u884c 3 \u5217\uff083x3\uff09\u7684\u77e9\u9635\u4e2d\u3002\u77e9\u9635\u7684\u6bcf\u4e2a\u5143\u7d20\u5360\u7528 8 \u4e2a\u5b57\u8282\u3002\u56e0\u6b64\uff0c\u8fd9\u4e2a\u6570\u7ec4\u7684\u603b\u5185\u5b58\u9700\u6c42\u4e3a 8x3x3 = 72 \u5b57\u8282\u3002 \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e2a\u7b80\u5355\u7684 python \u4ee3\u7801\u6bb5\u6765\u68c0\u67e5\u8fd9\u4e00\u70b9\u3002 import numpy as np example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) print ( example . 
nbytes ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u6253\u5370\u51fa 72\uff0c\u5c31\u50cf\u6211\u4eec\u4e4b\u524d\u8ba1\u7b97\u7684\u90a3\u6837\u3002\u4f46\u6211\u4eec\u9700\u8981\u5b58\u50a8\u8fd9\u4e2a\u77e9\u9635\u7684\u6240\u6709\u5143\u7d20\u5417\uff1f\u5982\u524d\u6240\u8ff0\uff0c\u6211\u4eec\u53ea\u5bf9 1 \u611f\u5174\u8da3\u30020 \u5e76\u4e0d\u91cd\u8981\uff0c\u56e0\u4e3a\u4efb\u4f55\u4e0e 0 \u76f8\u4e58\u7684\u5143\u7d20\u90fd\u662f 0\uff0c\u800c 0 \u4e0e\u4efb\u4f55\u5143\u7d20\u76f8\u52a0\u6216\u76f8\u51cf\u4e5f\u6ca1\u6709\u4efb\u4f55\u533a\u522b\u3002\u53ea\u7528 1 \u8868\u793a\u77e9\u9635\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u67d0\u79cd\u5b57\u5178\u65b9\u6cd5\uff0c\u5176\u4e2d\u952e\u662f\u884c\u548c\u5217\u7684\u7d22\u5f15\uff0c\u503c\u662f 1\uff1a ( 0 , 2 ) 1 ( 1 , 0 ) 1 ( 2 , 0 ) 1 ( 2 , 2 ) 1 \u8fd9\u6837\u7684\u7b26\u53f7\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u56e0\u4e3a\u5b83\u53ea\u9700\u5b58\u50a8\u56db\u4e2a\u503c\uff08\u5728\u672c\u4f8b\u4e2d\uff09\u3002\u4f7f\u7528\u7684\u603b\u5185\u5b58\u4e3a 8x4 = 32 \u5b57\u8282\u3002\u4efb\u4f55 numpy \u6570\u7ec4\u90fd\u53ef\u4ee5\u901a\u8fc7\u7b80\u5355\u7684 python \u4ee3\u7801\u8f6c\u6362\u4e3a\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 1 ], [ 1 , 0 , 0 ], [ 1 , 0 , 1 ] ] ) sparse_example = sparse . csr_matrix ( example ) print ( sparse_example . data . nbytes ) \u8fd9\u5c06\u6253\u5370 32\uff0c\u6bd4\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u5c11\u4e86\u8fd9\u4e48\u591a\uff01\u7a00\u758f csr \u77e9\u9635\u7684\u603b\u5927\u5c0f\u662f\u4e09\u4e2a\u503c\u7684\u603b\u548c\u3002 print ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . 
nbytes ) \u8fd9\u5c06\u6253\u5370\u51fa 64 \u4e2a\u5143\u7d20\uff0c\u4ecd\u7136\u5c11\u4e8e\u6211\u4eec\u7684\u5bc6\u96c6\u6570\u7ec4\u3002\u9057\u61be\u7684\u662f\uff0c\u6211\u4e0d\u4f1a\u8be6\u7ec6\u4ecb\u7ecd\u8fd9\u4e9b\u5143\u7d20\u3002\u4f60\u53ef\u4ee5\u5728 scipy \u6587\u6863\u4e2d\u4e86\u89e3\u66f4\u591a\u3002\u5f53\u6211\u4eec\u62e5\u6709\u66f4\u5927\u7684\u6570\u7ec4\u65f6\uff0c\u6bd4\u5982\u8bf4\u62e5\u6709\u6570\u5343\u4e2a\u6837\u672c\u548c\u6570\u4e07\u4e2a\u7279\u5f81\u7684\u6570\u7ec4\uff0c\u5927\u5c0f\u5dee\u5f02\u5c31\u4f1a\u53d8\u5f97\u975e\u5e38\u5927\u3002\u4f8b\u5982\uff0c\u6211\u4eec\u4f7f\u7528\u57fa\u4e8e\u8ba1\u6570\u7279\u5f81\u7684\u6587\u672c\u6570\u636e\u96c6\u3002 import numpy as np from scipy import sparse n_rows = 10000 n_cols = 100000 # \u751f\u6210\u7b26\u5408\u4f2f\u52aa\u5229\u5206\u5e03\u7684\u968f\u673a\u6570\u7ec4\uff0c\u7ef4\u5ea6\u4e3a[10000, 100000] example = np . random . binomial ( 1 , p = 0.05 , size = ( n_rows , n_cols )) print ( f \"Size of dense array: { example . nbytes } \" ) # \u5c06\u968f\u673a\u77e9\u9635\u8f6c\u6362\u4e3a\u6d17\u6f31\u77e9\u9635 sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . 
nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u8fd9\u5c06\u6253\u5370\uff1a Size of dense array : 8000000000 Size of sparse array : 399932496 Full size of sparse array : 599938748 \u56e0\u6b64\uff0c\u5bc6\u96c6\u9635\u5217\u9700\u8981 ~8000MB \u6216\u5927\u7ea6 8GB \u5185\u5b58\u3002\u800c\u7a00\u758f\u9635\u5217\u53ea\u5360\u7528 399MB \u5185\u5b58\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u5f53\u6211\u4eec\u7684\u7279\u5f81\u4e2d\u6709\u5927\u91cf\u96f6\u65f6\uff0c\u6211\u4eec\u66f4\u559c\u6b22\u7a00\u758f\u9635\u5217\u800c\u4e0d\u662f\u5bc6\u96c6\u9635\u5217\u7684\u539f\u56e0\u3002 \u8bf7\u6ce8\u610f\uff0c\u7a00\u758f\u77e9\u9635\u6709\u591a\u79cd\u4e0d\u540c\u7684\u8868\u793a\u65b9\u6cd5\u3002\u8fd9\u91cc\u6211\u53ea\u5c55\u793a\u4e86\u5176\u4e2d\u4e00\u79cd\uff08\u53ef\u80fd\u4e5f\u662f\u6700\u5e38\u7528\u7684\uff09\u65b9\u6cd5\u3002\u6df1\u5165\u63a2\u8ba8\u8fd9\u4e9b\u65b9\u6cd5\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\uff0c\u56e0\u6b64\u7559\u7ed9\u8bfb\u8005\u4e00\u4e2a\u7ec3\u4e60\u3002 \u5c3d\u7ba1\u4e8c\u503c\u5316\u7279\u5f81\u7684\u7a00\u758f\u8868\u793a\u6bd4\u5176\u5bc6\u96c6\u8868\u793a\u6240\u5360\u7528\u7684\u5185\u5b58\u8981\u5c11\u5f97\u591a\uff0c\u4f46\u5bf9\u4e8e\u5206\u7c7b\u53d8\u91cf\u6765\u8bf4\uff0c\u8fd8\u6709\u4e00\u79cd\u8f6c\u6362\u6240\u5360\u7528\u7684\u5185\u5b58\u66f4\u5c11\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \" \u72ec\u70ed\u7f16\u7801 \"\u3002 \u72ec\u70ed\u7f16\u7801\u4e5f\u662f\u4e00\u79cd\u4e8c\u503c\u7f16\u7801\uff0c\u56e0\u4e3a\u53ea\u6709 0 \u548c 1 \u4e24\u4e2a\u503c\u3002\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5b83\u5e76\u4e0d\u662f\u4e8c\u503c\u8868\u793a\u6cd5\u3002\u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0b\u9762\u7684\u4f8b\u5b50\u6765\u7406\u89e3\u5b83\u7684\u8868\u793a\u6cd5\u3002 \u5047\u8bbe\u6211\u4eec\u7528\u4e00\u4e2a\u5411\u91cf\u6765\u8868\u793a ord_2 \u53d8\u91cf\u7684\u6bcf\u4e2a\u7c7b\u522b\u3002\u8fd9\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u4e0e ord_2 
\u53d8\u91cf\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u6bcf\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\u90fd\u662f 6\uff0c\u5e76\u4e14\u9664\u4e86\u4e00\u4e2a\u4f4d\u7f6e\u5916\uff0c\u5176\u4ed6\u4f4d\u7f6e\u90fd\u662f 0\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u7279\u6b8a\u7684\u5411\u91cf\u8868\u3002 Freezing 0 0 0 0 0 1 Warm 0 0 0 0 1 0 Cold 0 0 0 1 0 0 Boiling Hot 0 0 1 0 0 0 Hot 0 1 0 0 0 0 Lava Hot 1 0 0 0 0 0 \u6211\u4eec\u770b\u5230\u5411\u91cf\u7684\u5927\u5c0f\u662f 1x6\uff0c\u5373\u5411\u91cf\u4e2d\u67096\u4e2a\u5143\u7d20\u3002\u8fd9\u4e2a\u6570\u5b57\u662f\u600e\u4e48\u6765\u7684\u5462\uff1f\u5982\u679c\u4f60\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u5982\u524d\u6240\u8ff0\uff0c\u67096\u4e2a\u7c7b\u522b\u3002\u5728\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u65f6\uff0c\u5411\u91cf\u7684\u5927\u5c0f\u5fc5\u987b\u4e0e\u6211\u4eec\u8981\u67e5\u770b\u7684\u7c7b\u522b\u6570\u76f8\u540c\u3002\u6bcf\u4e2a\u5411\u91cf\u90fd\u6709\u4e00\u4e2a 1\uff0c\u5176\u4f59\u6240\u6709\u503c\u90fd\u662f 0\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u7528\u8fd9\u4e9b\u7279\u5f81\u6765\u4ee3\u66ff\u4e4b\u524d\u7684\u4e8c\u503c\u5316\u7279\u5f81\uff0c\u770b\u770b\u80fd\u8282\u7701\u591a\u5c11\u5185\u5b58\u3002 \u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\u4ee5\u524d\u7684\u6570\u636e\uff0c\u5b83\u770b\u8d77\u6765\u5982\u4e0b\uff1a Index Feature 0 Warm 1 Hot 2 Lava hot \u6bcf\u4e2a\u6837\u672c\u67093\u4e2a\u7279\u5f81\u3002\u4f46\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u72ec\u70ed\u5411\u91cf\u7684\u5927\u5c0f\u4e3a 6\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u67096\u4e2a\u7279\u5f81\uff0c\u800c\u4e0d\u662f3\u4e2a\u3002 Index F_0 F_1 F_2 F_3 F_4 F_5 0 0 0 0 0 1 0 1 0 1 0 0 0 0 2 1 0 1 0 0 0 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 6 \u4e2a\u7279\u5f81\uff0c\u800c\u5728\u8fd9\u4e2a 3x6 \u6570\u7ec4\u4e2d\uff0c\u53ea\u6709 3 \u4e2a1\u3002\u4f7f\u7528 numpy 
\u8ba1\u7b97\u5927\u5c0f\u4e0e\u4e8c\u503c\u5316\u5927\u5c0f\u8ba1\u7b97\u811a\u672c\u975e\u5e38\u76f8\u4f3c\u3002\u4f60\u9700\u8981\u6539\u53d8\u7684\u53ea\u662f\u6570\u7ec4\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u6bb5\u4ee3\u7801\u3002 import numpy as np from scipy import sparse example = np . array ( [ [ 0 , 0 , 0 , 0 , 1 , 0 ], [ 0 , 1 , 0 , 0 , 0 , 0 ], [ 1 , 0 , 0 , 0 , 0 , 0 ] ] ) print ( f \"Size of dense array: { example . nbytes } \" ) sparse_example = sparse . csr_matrix ( example ) print ( f \"Size of sparse array: { sparse_example . data . nbytes } \" ) full_size = ( sparse_example . data . nbytes + sparse_example . indptr . nbytes + sparse_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u6253\u5370\u5185\u5b58\u5927\u5c0f\u4e3a\uff1a Size of dense array : 144 Size of sparse array : 24 Full size of sparse array : 52 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5bc6\u96c6\u77e9\u9635\u7684\u5927\u5c0f\u8fdc\u8fdc\u5927\u4e8e\u4e8c\u503c\u5316\u77e9\u9635\u7684\u5927\u5c0f\u3002\u4e0d\u8fc7\uff0c\u7a00\u758f\u6570\u7ec4\u7684\u5927\u5c0f\u8981\u66f4\u5c0f\u3002\u8ba9\u6211\u4eec\u7528\u66f4\u5927\u7684\u6570\u7ec4\u6765\u8bd5\u8bd5\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 OneHotEncoder \u5c06\u5305\u542b 1001 \u4e2a\u7c7b\u522b\u7684\u7279\u5f81\u6570\u7ec4\u8f6c\u6362\u4e3a\u5bc6\u96c6\u77e9\u9635\u548c\u7a00\u758f\u77e9\u9635\u3002 import numpy as np from sklearn import preprocessing # \u751f\u6210\u7b26\u5408\u5747\u5300\u5206\u5e03\u7684\u968f\u673a\u6574\u6570\uff0c\u7ef4\u5ea6\u4e3a[1000000, 10000000] example = np . random . randint ( 1000 , size = 1000000 ) # \u72ec\u70ed\u7f16\u7801\uff0c\u975e\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = False ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of dense array: { ohe_example . 
nbytes } \" ) # \u72ec\u70ed\u7f16\u7801\uff0c\u7a00\u758f\u77e9\u9635 ohe = preprocessing . OneHotEncoder ( sparse = True ) # \u5c06\u968f\u673a\u6570\u7ec4\u5c55\u5e73 ohe_example = ohe . fit_transform ( example . reshape ( - 1 , 1 )) print ( f \"Size of sparse array: { ohe_example . data . nbytes } \" ) full_size = ( ohe_example . data . nbytes + ohe_example . indptr . nbytes + ohe_example . indices . nbytes ) print ( f \"Full size of sparse array: { full_size } \" ) \u4e0a\u9762\u4ee3\u7801\u6253\u5370\u7684\u8f93\u51fa\uff1a Size of dense array : 8000000000 Size of sparse array : 8000000 Full size of sparse array : 16000004 \u8fd9\u91cc\u7684\u5bc6\u96c6\u9635\u5217\u5927\u5c0f\u7ea6\u4e3a 8GB\uff0c\u7a00\u758f\u9635\u5217\u4e3a 8MB\u3002\u5982\u679c\u53ef\u4ee5\u9009\u62e9\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff1f\u5728\u6211\u770b\u6765\uff0c\u9009\u62e9\u5f88\u7b80\u5355\uff0c\u4e0d\u662f\u5417\uff1f \u8fd9\u4e09\u79cd\u65b9\u6cd5\uff08\u6807\u7b7e\u7f16\u7801\u3001\u7a00\u758f\u77e9\u9635\u3001\u72ec\u70ed\u7f16\u7801\uff09\u662f\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u7684\u6700\u91cd\u8981\u65b9\u6cd5\u3002\u4e0d\u8fc7\uff0c\u4f60\u8fd8\u53ef\u4ee5\u7528\u5f88\u591a\u5176\u4ed6\u4e0d\u540c\u7684\u65b9\u6cd5\u6765\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u3002\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u6570\u503c\u53d8\u91cf\u5c31\u662f\u5176\u4e2d\u7684\u4e00\u4e2a\u4f8b\u5b50\u3002 \u5047\u8bbe\u6211\u4eec\u56de\u5230\u4e4b\u524d\u7684\u5206\u7c7b\u7279\u5f81\u6570\u636e\uff08\u539f\u59cb\u6570\u636e\u4e2d\u7684 cat-in-the-dat-ii\uff09\u3002\u5728\u6570\u636e\u4e2d\uff0c ord_2 \u7684\u503c\u4e3a\u201c\u70ed\u201c\u7684 id \u6709\u591a\u5c11\uff1f \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u8ba1\u7b97\u6570\u636e\u7684\u5f62\u72b6\uff08shape\uff09\u8f7b\u677e\u8ba1\u7b97\u51fa\u8fd9\u4e2a\u503c\uff0c\u5176\u4e2d ord_2 \u5217\u7684\u503c\u4e3a Boiling Hot \u3002 In [ X ]: df [ df . ord_2 == \"Boiling Hot\" ] . 
shape Out [ X ]: ( 84790 , 25 ) \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 84790 \u6761\u8bb0\u5f55\u5177\u6709\u6b64\u503c\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 pandas \u4e2d\u7684 groupby \u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u8be5\u503c\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . count () Out [ X ]: ord_2 Boiling Hot 84790 Cold 97822 Freezing 142726 Hot 67508 Lava Hot 64840 Warm 124239 Name : id , dtype : int64 \u5982\u679c\u6211\u4eec\u53ea\u662f\u5c06 ord_2 \u5217\u66ff\u6362\u4e3a\u5176\u8ba1\u6570\u503c\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u5c06\u5176\u8f6c\u6362\u4e3a\u4e00\u79cd\u6570\u503c\u7279\u5f81\u4e86\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u7684 transform \u51fd\u6570\u548c groupby \u6765\u521b\u5efa\u65b0\u5217\u6216\u66ff\u6362\u8fd9\u4e00\u5217\u3002 In [ X ]: df . groupby ([ \"ord_2\" ])[ \"id\" ] . transform ( \"count\" ) Out [ X ]: 0 67508.0 1 124239.0 2 142726.0 3 64840.0 4 97822.0 ... 599995 142726.0 599996 84790.0 599997 142726.0 599998 124239.0 599999 84790.0 Name : id , Length : 600000 , dtype : float64 \u4f60\u53ef\u4ee5\u6dfb\u52a0\u6240\u6709\u7279\u5f81\u7684\u8ba1\u6570\uff0c\u4e5f\u53ef\u4ee5\u66ff\u6362\u5b83\u4eec\uff0c\u6216\u8005\u6839\u636e\u591a\u4e2a\u5217\u53ca\u5176\u8ba1\u6570\u8fdb\u884c\u5206\u7ec4\u3002\u4f8b\u5982\uff0c\u4ee5\u4e0b\u4ee3\u7801\u901a\u8fc7\u5bf9 ord_1 \u548c ord_2 \u5217\u5206\u7ec4\u8fdb\u884c\u8ba1\u6570\u3002 In [ X ]: df . groupby ( ... : [ ... : \"ord_1\" , ... : \"ord_2\" ... : ] ... : )[ \"id\" ] . count () . 
reset_index ( name = \"count\" ) Out [ X ]: ord_1 ord_2 count 0 Contributor Boiling Hot 15634 1 Contributor Cold 17734 2 Contributor Freezing 26082 3 Contributor Hot 12428 4 Contributor Lava Hot 11919 5 Contributor Warm 22774 6 Expert Boiling Hot 19477 7 Expert Cold 22956 8 Expert Freezing 33249 9 Expert Hot 15792 10 Expert Lava Hot 15078 11 Expert Warm 28900 12 Grandmaster Boiling Hot 13623 13 Grandmaster Cold 15464 14 Grandmaster Freezing 22818 15 Grandmaster Hot 10805 16 Grandmaster Lava Hot 10363 17 Grandmaster Warm 19899 18 Master Boiling Hot 10800 ... \u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u884c\uff0c\u4ee5\u4fbf\u5728\u4e00\u9875\u4e2d\u5bb9\u7eb3\u8fd9\u4e9b\u884c\u3002\u8fd9\u662f\u53e6\u4e00\u79cd\u53ef\u4ee5\u4f5c\u4e3a\u529f\u80fd\u6dfb\u52a0\u7684\u8ba1\u6570\u3002\u60a8\u73b0\u5728\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\uff0c\u6211\u4f7f\u7528 id \u5217\u8fdb\u884c\u8ba1\u6570\u3002\u4e0d\u8fc7\uff0c\u4f60\u4e5f\u53ef\u4ee5\u901a\u8fc7\u5bf9\u5217\u7684\u7ec4\u5408\u8fdb\u884c\u5206\u7ec4\uff0c\u5bf9\u5176\u4ed6\u5217\u8fdb\u884c\u8ba1\u6570\u3002 \u8fd8\u6709\u4e00\u4e2a\u5c0f\u7a8d\u95e8\uff0c\u5c31\u662f\u4ece\u8fd9\u4e9b\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\u3002\u4f60\u53ef\u4ee5\u4ece\u73b0\u6709\u7684\u7279\u5f81\u4e2d\u521b\u5efa\u65b0\u7684\u5206\u7c7b\u7279\u5f81\uff0c\u800c\u4e14\u53ef\u4ee5\u6beb\u4e0d\u8d39\u529b\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot 1 Grandmaster_Warm 2 nan_Freezing 3 Novice_Lava Hot 4 Grandmaster_Cold ... 
599999 Contributor_Boiling Hot Name : new_feature , Length : 600000 , dtype : object \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u7528\u4e0b\u5212\u7ebf\u5c06 ord_1 \u548c ord_2 \u5408\u5e76\uff0c\u7136\u540e\u5c06\u8fd9\u4e9b\u5217\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u7c7b\u578b\u3002\u8bf7\u6ce8\u610f\uff0cNaN \u4e5f\u4f1a\u8f6c\u6362\u4e3a\u5b57\u7b26\u4e32\u3002\u4e0d\u8fc7\u6ca1\u5173\u7cfb\u3002\u6211\u4eec\u4e5f\u53ef\u4ee5\u5c06 NaN \u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u6709\u4e86\u4e00\u4e2a\u7531\u8fd9\u4e24\u4e2a\u7279\u5f81\u7ec4\u5408\u800c\u6210\u7684\u65b0\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u5c06\u4e09\u5217\u4ee5\u4e0a\u6216\u56db\u5217\u751a\u81f3\u66f4\u591a\u5217\u7ec4\u5408\u5728\u4e00\u8d77\u3002 In [ X ]: df [ \"new_feature\" ] = ( ... : df . ord_1 . astype ( str ) ... : + \"_\" ... : + df . ord_2 . astype ( str ) ... : + \"_\" ... : + df . ord_3 . astype ( str ) ... : ) In [ X ]: df . new_feature Out [ X ]: 0 Contributor_Hot_c 1 Grandmaster_Warm_e 2 nan_Freezing_n 3 Novice_Lava Hot_a 4 Grandmaster_Cold_h ... 
599999 Contributor_Boiling Hot_b Name : new_feature , Length : 600000 , dtype : object \u90a3\u4e48\uff0c\u6211\u4eec\u5e94\u8be5\u628a\u54ea\u4e9b\u7c7b\u522b\u7ed3\u5408\u8d77\u6765\u5462\uff1f\u8fd9\u5e76\u6ca1\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u7b54\u6848\u3002\u8fd9\u53d6\u51b3\u4e8e\u60a8\u7684\u6570\u636e\u548c\u7279\u5f81\u7c7b\u578b\u3002\u4e00\u4e9b\u9886\u57df\u77e5\u8bc6\u5bf9\u4e8e\u521b\u5efa\u8fd9\u6837\u7684\u7279\u5f81\u53ef\u80fd\u5f88\u6709\u7528\u3002\u4f46\u662f\uff0c\u5982\u679c\u4f60\u4e0d\u62c5\u5fc3\u5185\u5b58\u548c CPU \u7684\u4f7f\u7528\uff0c\u4f60\u53ef\u4ee5\u91c7\u7528\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\uff0c\u5373\u521b\u5efa\u8bb8\u591a\u8fd9\u6837\u7684\u7ec4\u5408\uff0c\u7136\u540e\u4f7f\u7528\u4e00\u4e2a\u6a21\u578b\u6765\u51b3\u5b9a\u54ea\u4e9b\u7279\u5f81\u662f\u6709\u7528\u7684\uff0c\u5e76\u4fdd\u7559\u5b83\u4eec\u3002\u6211\u4eec\u5c06\u5728\u672c\u4e66\u7a0d\u540e\u90e8\u5206\u4ecb\u7ecd\u8fd9\u79cd\u65b9\u6cd5\u3002 \u65e0\u8bba\u4f55\u65f6\u83b7\u5f97\u5206\u7c7b\u53d8\u91cf\uff0c\u90fd\u8981\u9075\u5faa\u4ee5\u4e0b\u7b80\u5355\u6b65\u9aa4\uff1a - \u586b\u5145 NaN \u503c\uff08\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff01\uff09\u3002 - \u4f7f\u7528 scikit-learn \u7684 LabelEncoder \u6216\u6620\u5c04\u5b57\u5178\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\uff0c\u5c06\u5b83\u4eec\u8f6c\u6362\u4e3a\u6574\u6570\u3002\u5982\u679c\u6ca1\u6709\u586b\u5145 NaN \u503c\uff0c\u53ef\u80fd\u9700\u8981\u5728\u8fd9\u4e00\u6b65\u4e2d\u8fdb\u884c\u5904\u7406 - \u521b\u5efa\u72ec\u70ed\u7f16\u7801\u3002\u662f\u7684\uff0c\u4f60\u53ef\u4ee5\u8df3\u8fc7\u4e8c\u503c\u5316\uff01 - \u5efa\u6a21\uff01\u6211\u6307\u7684\u662f\u673a\u5668\u5b66\u4e60\u3002 \u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u5904\u7406 NaN \u6570\u636e\u975e\u5e38\u91cd\u8981\uff0c\u5426\u5219\u60a8\u53ef\u80fd\u4f1a\u4ece scikit-learn \u7684 LabelEncoder \u4e2d\u5f97\u5230\u81ed\u540d\u662d\u8457\u7684\u9519\u8bef\u4fe1\u606f\uff1a ValueError: y 
\u5305\u542b\u4ee5\u524d\u672a\u89c1\u8fc7\u7684\u6807\u7b7e\uff1a [Nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan) \u8fd9\u4ec5\u4ec5\u610f\u5473\u7740\uff0c\u5728\u8f6c\u6362\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6570\u636e\u4e2d\u51fa\u73b0\u4e86 NaN \u503c\u3002\u8fd9\u662f\u56e0\u4e3a\u4f60\u5728\u8bad\u7ec3\u65f6\u5fd8\u8bb0\u4e86\u5904\u7406\u5b83\u4eec\u3002 \u5904\u7406 NaN \u503c \u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u4e22\u5f03\u5b83\u4eec\u3002\u867d\u7136\u7b80\u5355\uff0c\u4f46\u5e76\u4e0d\u7406\u60f3\u3002NaN \u503c\u4e2d\u53ef\u80fd\u5305\u542b\u5f88\u591a\u4fe1\u606f\uff0c\u5982\u679c\u53ea\u662f\u4e22\u5f03\u8fd9\u4e9b\u503c\uff0c\u5c31\u4f1a\u4e22\u5931\u8fd9\u4e9b\u4fe1\u606f\u3002\u5728\u5f88\u591a\u60c5\u51b5\u4e0b\uff0c\u5927\u90e8\u5206\u6570\u636e\u90fd\u662f NaN \u503c\uff0c\u56e0\u6b64\u4e0d\u80fd\u4e22\u5f03 NaN \u503c\u7684\u884c/\u6837\u672c\u3002\u5904\u7406 NaN \u503c\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u5c06\u5176\u4f5c\u4e3a\u4e00\u4e2a\u5168\u65b0\u7684\u7c7b\u522b\u3002\u8fd9\u662f\u5904\u7406 NaN \u503c\u6700\u5e38\u7528\u7684\u65b9\u6cd5\u3002\u5982\u679c\u4f7f\u7528 pandas\uff0c\u8fd8\u53ef\u4ee5\u901a\u8fc7\u975e\u5e38\u7b80\u5355\u7684\u65b9\u5f0f\u5b9e\u73b0\u3002 \u8bf7\u770b\u6211\u4eec\u4e4b\u524d\u67e5\u770b\u8fc7\u7684\u6570\u636e\u7684 ord_2 \u5217\u3002 In [ X ]: df . ord_2 . value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 Name : ord_2 , dtype : int64 \u586b\u5165 NaN \u503c\u540e\uff0c\u5c31\u53d8\u6210\u4e86 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . 
value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u54c7\uff01\u8fd9\u4e00\u5217\u4e2d\u6709 18075 \u4e2a NaN \u503c\uff0c\u800c\u6211\u4eec\u4e4b\u524d\u751a\u81f3\u90fd\u6ca1\u6709\u8003\u8651\u4f7f\u7528\u5b83\u4eec\u3002\u589e\u52a0\u4e86\u8fd9\u4e2a\u65b0\u7c7b\u522b\u540e\uff0c\u7c7b\u522b\u603b\u6570\u4ece 6 \u4e2a\u589e\u52a0\u5230\u4e86 7 \u4e2a\u3002\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u73b0\u5728\u6211\u4eec\u5728\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u4e5f\u4f1a\u8003\u8651 NaN\u3002\u76f8\u5173\u4fe1\u606f\u8d8a\u591a\uff0c\u6a21\u578b\u5c31\u8d8a\u597d\u3002 \u5047\u8bbe ord_2 \u6ca1\u6709\u4efb\u4f55 NaN \u503c\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u4e00\u5217\u4e2d\u7684\u6240\u6709\u7c7b\u522b\u90fd\u6709\u663e\u8457\u7684\u8ba1\u6570\u3002\u5176\u4e2d\u6ca1\u6709 \"\u7f55\u89c1 \"\u7c7b\u522b\uff0c\u5373\u53ea\u5728\u6837\u672c\u603b\u6570\u4e2d\u5360\u5f88\u5c0f\u6bd4\u4f8b\u7684\u7c7b\u522b\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5047\u8bbe\u60a8\u5728\u751f\u4ea7\u4e2d\u90e8\u7f72\u4e86\u4f7f\u7528\u8fd9\u4e00\u5217\u7684\u6a21\u578b\uff0c\u5f53\u6a21\u578b\u6216\u9879\u76ee\u4e0a\u7ebf\u65f6\uff0c\u60a8\u5728 ord_2 \u5217\u4e2d\u5f97\u5230\u4e86\u4e00\u4e2a\u5728\u8bad\u7ec3\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6a21\u578b\u7ba1\u9053\u4f1a\u629b\u51fa\u4e00\u4e2a\u9519\u8bef\uff0c\u60a8\u5bf9\u6b64\u65e0\u80fd\u4e3a\u529b\u3002\u5982\u679c\u51fa\u73b0\u8fd9\u79cd\u60c5\u51b5\uff0c\u90a3\u4e48\u53ef\u80fd\u662f\u751f\u4ea7\u4e2d\u7684\u7ba1\u9053\u51fa\u4e86\u95ee\u9898\u3002\u5982\u679c\u8fd9\u662f\u9884\u6599\u4e4b\u4e2d\u7684\uff0c\u90a3\u4e48\u60a8\u5c31\u5fc5\u987b\u4fee\u6539\u60a8\u7684\u6a21\u578b\u7ba1\u9053\uff0c\u5e76\u5728\u8fd9\u516d\u4e2a\u7c7b\u522b\u4e2d\u52a0\u5165\u4e00\u4e2a\u65b0\u7c7b\u522b\u3002 \u8fd9\u4e2a\u65b0\u7c7b\u522b\u88ab\u79f0\u4e3a 
\"\u7f55\u89c1 \"\u7c7b\u522b\u3002\u7f55\u89c1\u7c7b\u522b\u662f\u4e00\u79cd\u4e0d\u5e38\u89c1\u7684\u7c7b\u522b\uff0c\u53ef\u4ee5\u5305\u62ec\u8bb8\u591a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8fd1\u90bb\u6a21\u578b\u6765 \"\u9884\u6d4b \"\u672a\u77e5\u7c7b\u522b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u4e86\u8fd9\u4e2a\u7c7b\u522b\uff0c\u5b83\u5c31\u4f1a\u6210\u4e3a\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u4e00\u4e2a\u7c7b\u522b\u3002 \u56fe 3\uff1a\u5177\u6709\u4e0d\u540c\u7279\u5f81\u4e14\u65e0\u6807\u7b7e\u7684\u6570\u636e\u96c6\u793a\u610f\u56fe\uff0c\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u53ef\u80fd\u4f1a\u5728\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u51fa\u73b0\u65b0\u503c \u5f53\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u56fe 3 \u6240\u793a\u7684\u6570\u636e\u96c6\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u4e00\u4e2a\u7b80\u5355\u7684\u6a21\u578b\uff0c\u5bf9\u9664 \"f3 \"\u4e4b\u5916\u7684\u6240\u6709\u7279\u5f81\u8fdb\u884c\u8bad\u7ec3\u3002\u8fd9\u6837\uff0c\u4f60\u5c06\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u4e0d\u77e5\u9053\u6216\u8bad\u7ec3\u4e2d\u6ca1\u6709 \"f3 \"\u65f6\u9884\u6d4b\u5b83\u3002\u6211\u4e0d\u6562\u8bf4\u8fd9\u6837\u7684\u6a21\u578b\u662f\u5426\u80fd\u5e26\u6765\u51fa\u8272\u7684\u6027\u80fd\uff0c\u4f46\u4e5f\u8bb8\u80fd\u5904\u7406\u6d4b\u8bd5\u96c6\u6216\u5b9e\u65f6\u6570\u636e\u4e2d\u7684\u7f3a\u5931\u503c\uff0c\u5c31\u50cf\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u5176\u4ed6\u4e8b\u60c5\u4e00\u6837\uff0c\u4e0d\u5c1d\u8bd5\u4e00\u4e0b\u662f\u8bf4\u4e0d\u51c6\u7684\u3002 
\u5982\u679c\u4f60\u6709\u4e00\u4e2a\u56fa\u5b9a\u7684\u6d4b\u8bd5\u96c6\uff0c\u4f60\u53ef\u4ee5\u5c06\u6d4b\u8bd5\u6570\u636e\u6dfb\u52a0\u5230\u8bad\u7ec3\u4e2d\uff0c\u4ee5\u4e86\u89e3\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u7c7b\u522b\u3002\u8fd9\u4e0e\u534a\u76d1\u7763\u5b66\u4e60\u975e\u5e38\u76f8\u4f3c\uff0c\u5373\u4f7f\u7528\u65e0\u6cd5\u7528\u4e8e\u8bad\u7ec3\u7684\u6570\u636e\u6765\u6539\u8fdb\u6a21\u578b\u3002\u8fd9\u4e5f\u4f1a\u7167\u987e\u5230\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u51fa\u73b0\u6b21\u6570\u6781\u5c11\u4f46\u5728\u6d4b\u8bd5\u6570\u636e\u4e2d\u5927\u91cf\u5b58\u5728\u7684\u7a00\u6709\u503c\u3002\u4f60\u7684\u6a21\u578b\u5c06\u66f4\u52a0\u7a33\u5065\u3002 \u5f88\u591a\u4eba\u8ba4\u4e3a\u8fd9\u79cd\u60f3\u6cd5\u4f1a\u8fc7\u5ea6\u62df\u5408\u3002\u53ef\u80fd\u8fc7\u62df\u5408\uff0c\u4e5f\u53ef\u80fd\u4e0d\u8fc7\u62df\u5408\u3002\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u89e3\u51b3\u65b9\u6cd5\u3002\u5982\u679c\u4f60\u5728\u8bbe\u8ba1\u4ea4\u53c9\u9a8c\u8bc1\u65f6\uff0c\u80fd\u591f\u5728\u6d4b\u8bd5\u6570\u636e\u4e0a\u8fd0\u884c\u6a21\u578b\u65f6\u590d\u5236\u9884\u6d4b\u8fc7\u7a0b\uff0c\u90a3\u4e48\u5b83\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u8fc7\u62df\u5408\u3002\u8fd9\u610f\u5473\u7740\u7b2c\u4e00\u6b65\u5e94\u8be5\u662f\u5206\u79bb\u6298\u53e0\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u4f60\u5e94\u8be5\u5e94\u7528\u4e0e\u6d4b\u8bd5\u6570\u636e\u76f8\u540c\u7684\u9884\u5904\u7406\u3002\u5047\u8bbe\u60a8\u60f3\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u6d4b\u8bd5\u6570\u636e\uff0c\u90a3\u4e48\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u60a8\u5fc5\u987b\u5408\u5e76\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\uff0c\u5e76\u786e\u4fdd\u9a8c\u8bc1\u6570\u636e\u96c6\u590d\u5236\u4e86\u6d4b\u8bd5\u96c6\u3002\u5728\u8fd9\u79cd\u7279\u5b9a\u60c5\u51b5\u4e0b\uff0c\u60a8\u5fc5\u987b\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u8bbe\u8ba1\u9a8c\u8bc1\u96c6\uff0c\u4f7f\u5176\u5305\u542b\u8bad\u7ec3\u96c6\u4e2d \"\u672a\u89c1 \"\u7684\u7c7b\u522b\u3002 
\u56fe 4\uff1a\u5bf9\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u8fdb\u884c\u7b80\u5355\u5408\u5e76\uff0c\u4ee5\u4e86\u89e3\u6d4b\u8bd5\u96c6\u4e2d\u5b58\u5728\u4f46\u8bad\u7ec3\u96c6\u4e2d\u4e0d\u5b58\u5728\u7684\u7c7b\u522b\u6216\u8bad\u7ec3\u96c6\u4e2d\u7f55\u89c1\u7684\u7c7b\u522b \u53ea\u8981\u770b\u4e00\u4e0b\u56fe 4 \u548c\u4e0b\u9762\u7684\u4ee3\u7801\uff0c\u5c31\u80fd\u5f88\u5bb9\u6613\u7406\u89e3\u5176\u5de5\u4f5c\u539f\u7406\u3002 import pandas as pd from sklearn import preprocessing # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( \"../input/cat_train.csv\" ) # \u8bfb\u53d6\u6d4b\u8bd5\u96c6 test = pd . read_csv ( \"../input/cat_test.csv\" ) # \u5c06\u6d4b\u8bd5\u96c6\"target\"\u5217\u5168\u90e8\u7f6e\u4e3a-1 test . loc [:, \"target\" ] = - 1 # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u62fc\u63a5 data = pd . concat ([ train , test ]) . reset_index ( drop = True ) # \u5c06\u9664\"id\"\u548c\"target\"\u5217\u7684\u5176\u4ed6\u7279\u5f81\u5217\u540d\u53d6\u51fa features = [ x for x in train . columns if x not in [ \"id\" , \"target\" ]] # \u904d\u5386\u7279\u5f81 for feat in features : # \u6807\u7b7e\u7f16\u7801 lbl_enc = preprocessing . LabelEncoder () # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\",\u5e76\u5c06\u8be5\u5217\u683c\u5f0f\u53d8\u4e3astr temp_col = data [ feat ] . fillna ( \"NONE\" ) . astype ( str ) . values # \u8f6c\u6362\u6570\u503c data . loc [:, feat ] = lbl_enc . fit_transform ( temp_col ) # \u6839\u636e\"target\"\u5217\u5c06\u8bad\u7ec3\u96c6\u4e0e\u6d4b\u8bd5\u96c6\u5206\u5f00 train = data [ data . target != - 1 ] . reset_index ( drop = True ) test = data [ data . target == - 1 ] . 
reset_index ( drop = True ) \u5f53\u60a8\u9047\u5230\u5df2\u7ecf\u6709\u6d4b\u8bd5\u6570\u636e\u96c6\u7684\u95ee\u9898\u65f6\uff0c\u8fd9\u4e2a\u6280\u5de7\u5c31\u4f1a\u8d77\u4f5c\u7528\u3002\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd9\u4e00\u62db\u5728\u5b9e\u65f6\u73af\u5883\u4e2d\u4e0d\u8d77\u4f5c\u7528\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u60a8\u6240\u5728\u7684\u516c\u53f8\u63d0\u4f9b\u5b9e\u65f6\u7ade\u4ef7\u89e3\u51b3\u65b9\u6848\uff08RTB\uff09\u3002RTB \u7cfb\u7edf\u4f1a\u5bf9\u5728\u7ebf\u770b\u5230\u7684\u6bcf\u4e2a\u7528\u6237\u8fdb\u884c\u7ade\u4ef7\uff0c\u4ee5\u8d2d\u4e70\u5e7f\u544a\u7a7a\u95f4\u3002\u8fd9\u79cd\u6a21\u5f0f\u53ef\u4f7f\u7528\u7684\u529f\u80fd\u53ef\u80fd\u5305\u62ec\u7f51\u7ad9\u4e2d\u6d4f\u89c8\u7684\u9875\u9762\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u7279\u5f81\u662f\u7528\u6237\u8bbf\u95ee\u7684\u6700\u540e\u4e94\u4e2a\u7c7b\u522b/\u9875\u9762\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5982\u679c\u7f51\u7ad9\u5f15\u5165\u4e86\u65b0\u7684\u7c7b\u522b\uff0c\u6211\u4eec\u5c06\u65e0\u6cd5\u518d\u51c6\u786e\u9884\u6d4b\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u5c31\u4f1a\u5931\u6548\u3002\u8fd9\u79cd\u60c5\u51b5\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528 \"\u672a\u77e5 \"\u7c7b\u522b\u6765\u907f\u514d \u3002 \u5728\u6211\u4eec\u7684 cat-in-the-dat \u6570\u636e\u96c6\u4e2d\uff0c ord_2 \u5217\u4e2d\u5df2\u7ecf\u6709\u4e86\u672a\u77e5\u7c7b\u522b\u3002 In [ X ]: df . ord_2 . fillna ( \"NONE\" ) . 
value_counts () Out [ X ]: Freezing 142726 Warm 124239 Cold 97822 Boiling Hot 84790 Hot 67508 Lava Hot 64840 NONE 18075 Name : ord_2 , dtype : int64 \u6211\u4eec\u53ef\u4ee5\u5c06 \"NONE \"\u89c6\u4e3a\u672a\u77e5\u3002\u56e0\u6b64\uff0c\u5982\u679c\u5728\u5b9e\u65f6\u6d4b\u8bd5\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u83b7\u5f97\u4e86\u4ee5\u524d\u4ece\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\uff0c\u6211\u4eec\u5c31\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a \"NONE\"\u3002 \u8fd9\u4e0e\u81ea\u7136\u8bed\u8a00\u5904\u7406\u95ee\u9898\u975e\u5e38\u76f8\u4f3c\u3002\u6211\u4eec\u603b\u662f\u57fa\u4e8e\u56fa\u5b9a\u7684\u8bcd\u6c47\u5efa\u7acb\u6a21\u578b\u3002\u589e\u52a0\u8bcd\u6c47\u91cf\u5c31\u4f1a\u589e\u52a0\u6a21\u578b\u7684\u5927\u5c0f\u3002\u50cf BERT \u8fd9\u6837\u7684\u8f6c\u6362\u5668\u6a21\u578b\u662f\u5728 ~30000 \u4e2a\u5355\u8bcd\uff08\u82f1\u8bed\uff09\u7684\u57fa\u7840\u4e0a\u8bad\u7ec3\u7684\u3002\u56e0\u6b64\uff0c\u5f53\u6709\u65b0\u8bcd\u8f93\u5165\u65f6\uff0c\u6211\u4eec\u4f1a\u5c06\u5176\u6807\u8bb0\u4e3a UNK\uff08\u672a\u77e5\uff09\u3002 \u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u5047\u8bbe\u6d4b\u8bd5\u6570\u636e\u4e0e\u8bad\u7ec3\u6570\u636e\u5177\u6709\u76f8\u540c\u7684\u7c7b\u522b\uff0c\u4e5f\u53ef\u4ee5\u5728\u8bad\u7ec3\u6570\u636e\u4e2d\u5f15\u5165\u7f55\u89c1\u6216\u672a\u77e5\u7c7b\u522b\uff0c\u4ee5\u5904\u7406\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u65b0\u7c7b\u522b\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u586b\u5165 NaN \u503c\u540e ord_4 \u5217\u7684\u503c\u8ba1\u6570\uff1a In [ X ]: df . ord_4 . fillna ( \"NONE\" ) . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 . . . 
K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 G 3404 V 3107 J 1950 L 1657 Name : ord_4 , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u6709\u4e9b\u6570\u503c\u53ea\u51fa\u73b0\u4e86\u51e0\u5343\u6b21\uff0c\u6709\u4e9b\u5219\u51fa\u73b0\u4e86\u8fd1 40000 \u6b21\u3002NaN \u4e5f\u7ecf\u5e38\u51fa\u73b0\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ece\u8f93\u51fa\u4e2d\u5220\u9664\u4e86\u4e00\u4e9b\u503c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5b9a\u4e49\u5c06\u4e00\u4e2a\u503c\u79f0\u4e3a \" \u7f55\u89c1\uff08rare\uff09 \"\u7684\u6807\u51c6\u4e86\u3002\u6bd4\u65b9\u8bf4\uff0c\u5728\u8fd9\u4e00\u5217\u4e2d\uff0c\u7a00\u6709\u503c\u7684\u8981\u6c42\u662f\u8ba1\u6570\u5c0f\u4e8e 2000\u3002\u8fd9\u6837\u770b\u6765\uff0cJ \u548c L \u5c31\u53ef\u4ee5\u88ab\u6807\u8bb0\u4e3a\u7a00\u6709\u503c\u4e86\u3002\u4f7f\u7528 pandas\uff0c\u6839\u636e\u8ba1\u6570\u9608\u503c\u66ff\u6362\u7c7b\u522b\u975e\u5e38\u7b80\u5355\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 In [ X ]: df . ord_4 = df . ord_4 . fillna ( \"NONE\" ) In [ X ]: df . loc [ ... : df [ \"ord_4\" ] . value_counts ()[ df [ \"ord_4\" ]] . values < 2000 , ... : \"ord_4\" ... : ] = \"RARE\" In [ X ]: df . ord_4 . value_counts () Out [ X ]: N 39978 P 37890 Y 36657 A 36633 R 33045 U 32897 M 32504 . . . 
B 25212 E 21871 K 21676 I 19805 NONE 17930 D 17284 F 16721 W 8268 Z 5790 S 4595 RARE 3607 G 3404 V 3107 Name : ord_4 , dtype : int64 \u6211\u4eec\u8ba4\u4e3a\uff0c\u53ea\u8981\u67d0\u4e2a\u7c7b\u522b\u7684\u503c\u5c0f\u4e8e 2000\uff0c\u5c31\u5c06\u5176\u66ff\u6362\u4e3a\u7f55\u89c1\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u5728\u6d4b\u8bd5\u6570\u636e\u65f6\uff0c\u6240\u6709\u672a\u89c1\u8fc7\u7684\u65b0\u7c7b\u522b\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"RARE\"\uff0c\u800c\u6240\u6709\u7f3a\u5931\u503c\u90fd\u5c06\u88ab\u6620\u5c04\u4e3a \"NONE\"\u3002 \u8fd9\u79cd\u65b9\u6cd5\u8fd8\u80fd\u786e\u4fdd\u5373\u4f7f\u6709\u65b0\u7684\u7c7b\u522b\uff0c\u6a21\u578b\u4e5f\u80fd\u5728\u5b9e\u9645\u73af\u5883\u4e2d\u6b63\u5e38\u5de5\u4f5c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u5177\u5907\u4e86\u5904\u7406\u4efb\u4f55\u5e26\u6709\u5206\u7c7b\u53d8\u91cf\u95ee\u9898\u6240\u9700\u7684\u4e00\u5207\u6761\u4ef6\u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5efa\u7acb\u7b2c\u4e00\u4e2a\u6a21\u578b\uff0c\u5e76\u9010\u6b65\u63d0\u9ad8\u5176\u6027\u80fd\u3002 \u5728\u6784\u5efa\u4efb\u4f55\u7c7b\u578b\u7684\u6a21\u578b\u4e4b\u524d\uff0c\u4ea4\u53c9\u68c0\u9a8c\u81f3\u5173\u91cd\u8981\u3002\u6211\u4eec\u5df2\u7ecf\u770b\u5230\u4e86\u6807\u7b7e/\u76ee\u6807\u5206\u5e03\uff0c\u77e5\u9053\u8fd9\u662f\u4e00\u4e2a\u76ee\u6807\u504f\u659c\u7684\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 StratifiedKFold \u6765\u5206\u5272\u6570\u636e\u3002 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/cat_train.csv\" ) # \u6dfb\u52a0\"kfold\"\u5217\uff0c\u5e76\u7f6e\u4e3a-1 df [ \"kfold\" ] = - 1 # \u6253\u4e71\u6570\u636e\u987a\u5e8f\uff0c\u91cd\u7f6e\u7d22\u5f15 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) # \u5c06\u76ee\u6807\u5217\u53d6\u51fa y = df . target . values # \u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c kf = model_selection . 
StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): # \u533a\u5206\u6298\u53e0 df . loc [ v_ , 'kfold' ] = f # \u4fdd\u5b58\u6587\u4ef6 df . to_csv ( \"../input/cat_train_folds.csv\" , index = False ) \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u68c0\u67e5\u65b0\u7684\u6298\u53e0 csv\uff0c\u67e5\u770b\u6bcf\u4e2a\u6298\u53e0\u7684\u6837\u672c\u6570\uff1a In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) In [ X ]: df . kfold . value_counts () Out [ X ]: 4 120000 3 120000 2 120000 1 120000 0 120000 Name : kfold , dtype : int64 \u6240\u6709\u6298\u53e0\u90fd\u6709 120000 \u4e2a\u6837\u672c\u3002\u8fd9\u662f\u610f\u6599\u4e4b\u4e2d\u7684\uff0c\u56e0\u4e3a\u8bad\u7ec3\u6570\u636e\u6709 600000 \u4e2a\u6837\u672c\uff0c\u800c\u6211\u4eec\u505a\u4e865\u6b21\u6298\u53e0\u3002\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u4e00\u5207\u987a\u5229\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u68c0\u67e5\u6bcf\u4e2a\u6298\u53e0\u7684\u76ee\u6807\u5206\u5e03\u3002 In [ X ]: df [ df . kfold == 0 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 1 ] . target . value_counts () Out [ X ]: 0 97536 1 22464 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 2 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 3 ] . target . value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 In [ X ]: df [ df . kfold == 4 ] . target . 
value_counts () Out [ X ]: 0 97535 1 22465 Name : target , dtype : int64 \u6211\u4eec\u770b\u5230\uff0c\u5728\u6bcf\u4e2a\u6298\u53e0\u4e2d\uff0c\u76ee\u6807\u7684\u5206\u5e03\u90fd\u662f\u4e00\u6837\u7684\u3002\u8fd9\u6b63\u662f\u6211\u4eec\u6240\u9700\u8981\u7684\u3002\u5b83\u4e5f\u53ef\u4ee5\u662f\u76f8\u4f3c\u7684\uff0c\u5e76\u4e0d\u4e00\u5b9a\u8981\u4e00\u76f4\u76f8\u540c\u3002\u73b0\u5728\uff0c\u5f53\u6211\u4eec\u5efa\u7acb\u6a21\u578b\u65f6\uff0c\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u6807\u7b7e\u5206\u5e03\u90fd\u5c06\u76f8\u540c\u3002 \u6211\u4eec\u53ef\u4ee5\u5efa\u7acb\u7684\u6700\u7b80\u5355\u7684\u6a21\u578b\u4e4b\u4e00\u662f\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\u5e76\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): # \u8bfb\u53d6\u5206\u5c42k\u6298\u4ea4\u53c9\u68c0\u9a8c\u6570\u636e df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) # \u53d6\u9664\"id\", \"target\", \"kfold\"\u5916\u7684\u5176\u4ed6\u7279\u5f81\u5217 features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] # \u904d\u5386\u7279\u5f81\u5217\u8868 for col in features : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u9a8c\u8bc1\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . 
fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u6d4b\u8bd5\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u903b\u8f91\u56de\u5f52 model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . target . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( auc ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6298\u53e00 run ( 0 ) \u90a3\u4e48\uff0c\u53d1\u751f\u4e86\u4ec0\u4e48\u5462\uff1f \u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e2a\u51fd\u6570\uff0c\u5c06\u6570\u636e\u5206\u4e3a\u8bad\u7ec3\u548c\u9a8c\u8bc1\u4e24\u90e8\u5206\uff0c\u7ed9\u5b9a\u6298\u53e0\u6570\uff0c\u5904\u7406 NaN \u503c\uff0c\u5bf9\u6240\u6709\u6570\u636e\u8fdb\u884c\u5355\u6b21\u7f16\u7801\uff0c\u5e76\u8bad\u7ec3\u4e00\u4e2a\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\u3002 \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u90e8\u5206\u4ee3\u7801\u65f6\uff0c\u4f1a\u4ea7\u751f\u5982\u4e0b\u8f93\u51fa\uff1a \u276f python ohe_logres . py / home / abhishek / miniconda3 / envs / ml / lib / python3 .7 / site - packages / sklearn / linear_model / _logistic . py : 939 : ConvergenceWarning : lbfgs failed to converge ( status = 1 ): STOP : TOTAL NO . of ITERATIONS REACHED LIMIT . Increase the number of iterations ( max_iter ) or scale the data as shown in : https : // scikit - learn . org / stable / modules / preprocessing . html . Please also refer to the documentation for alternative solver options : https : // scikit - learn . org / stable / modules / linear_model . 
html #logistic- regression extra_warning_msg = _LOGISTIC_SOLVER_CONVERGENCE_MSG ) 0.7847865042255127 \u6709\u4e00\u4e9b\u8b66\u544a\u3002\u903b\u8f91\u56de\u5f52\u4f3c\u4e4e\u6ca1\u6709\u6536\u655b\u5230\u6700\u5927\u8fed\u4ee3\u6b21\u6570\u3002\u6211\u4eec\u6ca1\u6709\u8c03\u6574\u53c2\u6570\uff0c\u6240\u4ee5\u6ca1\u6709\u95ee\u9898\u3002\u6211\u4eec\u770b\u5230 AUC \u4e3a 0.785\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u5bf9\u4ee3\u7801\u8fdb\u884c\u7b80\u5355\u4fee\u6539\uff0c\u8fd0\u884c\u6240\u6709\u6298\u53e0\u3002 .... model = linear_model . LogisticRegression () model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u5faa\u73af\u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u5e76\u6ca1\u6709\u505a\u5f88\u5927\u7684\u6539\u52a8\uff0c\u6240\u4ee5\u6211\u53ea\u663e\u793a\u4e86\u90e8\u5206\u4ee3\u7801\u884c\uff0c\u5176\u4e2d\u4e00\u4e9b\u4ee3\u7801\u884c\u6709\u6539\u52a8\u3002 \u8fd9\u5c31\u6253\u5370\u51fa\u4e86\uff1a python - W ignore ohe_logres . 
py Fold = 0 , AUC = 0.7847865042255127 Fold = 1 , AUC = 0.7853553605899214 Fold = 2 , AUC = 0.7879321942914885 Fold = 3 , AUC = 0.7870315929550808 Fold = 4 , AUC = 0.7864668243125608 \u8bf7\u6ce8\u610f\uff0c\u6211\u4f7f\u7528\"-W ignore \"\u5ffd\u7565\u4e86\u6240\u6709\u8b66\u544a\u3002 \u6211\u4eec\u770b\u5230\uff0cAUC \u5206\u6570\u5728\u6240\u6709\u8936\u76b1\u4e2d\u90fd\u76f8\u5f53\u7a33\u5b9a\u3002\u5e73\u5747 AUC \u4e3a 0.78631449527\u3002\u5bf9\u4e8e\u6211\u4eec\u7684\u7b2c\u4e00\u4e2a\u6a21\u578b\u6765\u8bf4\u76f8\u5f53\u4e0d\u9519\uff01 \u5f88\u591a\u4eba\u5728\u9047\u5230\u8fd9\u79cd\u95ee\u9898\u65f6\u4f1a\u9996\u5148\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6bd4\u5982\u968f\u673a\u68ee\u6797\u3002\u5728\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u5e94\u7528\u968f\u673a\u68ee\u6797\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\uff08label encoding\uff09\uff0c\u5c06\u6bcf\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81\u90fd\u8f6c\u6362\u4e3a\u6574\u6570\uff0c\u800c\u4e0d\u662f\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u72ec\u70ed\u7f16\u7801\u3002 \u8fd9\u79cd\u7f16\u7801\u4e0e\u72ec\u70ed\u7f16\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u3002 import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . 
values # \u968f\u673a\u68ee\u6797\u6a21\u578b model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u4f7f\u7528 scikit-learn \u4e2d\u7684\u968f\u673a\u68ee\u6797\uff0c\u5e76\u53d6\u6d88\u4e86\u72ec\u70ed\u7f16\u7801\u3002\u6211\u4eec\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u4ee3\u66ff\u72ec\u70ed\u7f16\u7801\u3002\u5f97\u5206\u5982\u4e0b \u276f python lbl_rf . py Fold = 0 , AUC = 0.7167390828113697 Fold = 1 , AUC = 0.7165459672958506 Fold = 2 , AUC = 0.7159709909587376 Fold = 3 , AUC = 0.7161589664189556 Fold = 4 , AUC = 0.7156020216155978 \u54c7 \u5de8\u5927\u7684\u5dee\u5f02\uff01 \u968f\u673a\u68ee\u6797\u6a21\u578b\u5728\u6ca1\u6709\u4efb\u4f55\u8d85\u53c2\u6570\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u8868\u73b0\u8981\u6bd4\u7b80\u5355\u7684\u903b\u8f91\u56de\u5f52\u5dee\u5f88\u591a\u3002 \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u603b\u662f\u5e94\u8be5\u5148\u4ece\u7b80\u5355\u6a21\u578b\u5f00\u59cb\u7684\u539f\u56e0\u3002\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u7c89\u4e1d\u4f1a\u4ece\u8fd9\u91cc\u5f00\u59cb\uff0c\u800c\u5ffd\u7565\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u8ba4\u4e3a\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u7b80\u5355\u7684\u6a21\u578b\uff0c\u4e0d\u80fd\u5e26\u6765\u6bd4\u968f\u673a\u68ee\u6797\u66f4\u597d\u7684\u4ef7\u503c\u3002\u8fd9\u79cd\u4eba\u5c06\u4f1a\u72af\u4e0b\u5927\u9519\u3002\u5728\u6211\u4eec\u5b9e\u73b0\u968f\u673a\u68ee\u6797\u7684\u8fc7\u7a0b\u4e2d\uff0c\u4e0e\u903b\u8f91\u56de\u5f52\u76f8\u6bd4\uff0c\u6298\u53e0\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u624d\u80fd\u5b8c\u6210\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e0d\u4ec5\u635f\u5931\u4e86 
AUC\uff0c\u8fd8\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u6765\u5b8c\u6210\u8bad\u7ec3\u3002\u8bf7\u6ce8\u610f\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u8fdb\u884c\u63a8\u7406\u4e5f\u5f88\u8017\u65f6\uff0c\u800c\u4e14\u5360\u7528\u7684\u7a7a\u95f4\u4e5f\u66f4\u5927\u3002 \u5982\u679c\u6211\u4eec\u613f\u610f\uff0c\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u5728\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u6570\u636e\u4e0a\u8fd0\u884c\u968f\u673a\u68ee\u6797\uff0c\u4f46\u8fd9\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u5947\u5f02\u503c\u5206\u89e3\u6765\u51cf\u5c11\u7a00\u758f\u7684\u72ec\u70ed\u7f16\u7801\u77e9\u9635\u3002\u8fd9\u662f\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u63d0\u53d6\u4e3b\u9898\u7684\u5e38\u7528\u65b9\u6cd5\u3002 import pandas as pd from scipy import sparse from sklearn import decomposition from sklearn import ensemble from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" )] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) x_train = ohe . transform ( df_train [ features ]) x_valid = ohe . transform ( df_valid [ features ]) # \u5947\u5f02\u503c\u5206\u89e3 svd = decomposition . TruncatedSVD ( n_components = 120 ) full_sparse = sparse . vstack (( x_train , x_valid )) svd . fit ( full_sparse ) x_train = svd . transform ( x_train ) x_valid = svd . transform ( x_valid ) model = ensemble . RandomForestClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . target . 
values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u6211\u4eec\u5bf9\u5168\u90e8\u6570\u636e\u8fdb\u884c\u72ec\u70ed\u7f16\u7801\uff0c\u7136\u540e\u7528\u8bad\u7ec3\u6570\u636e\u548c\u9a8c\u8bc1\u6570\u636e\u5728\u7a00\u758f\u77e9\u9635\u4e0a\u62df\u5408 scikit-learn \u7684 TruncatedSVD\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c06\u9ad8\u7ef4\u7a00\u758f\u77e9\u9635\u51cf\u5c11\u5230 120 \u4e2a\u7279\u5f81\uff0c\u7136\u540e\u62df\u5408\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u3002 \u4ee5\u4e0b\u662f\u8be5\u6a21\u578b\u7684\u8f93\u51fa\u7ed3\u679c\uff1a \u276f python ohe_svd_rf . py Fold = 0 , AUC = 0.7064863038754249 Fold = 1 , AUC = 0.706050102937374 Fold = 2 , AUC = 0.7086069243167242 Fold = 3 , AUC = 0.7066819080085971 Fold = 4 , AUC = 0.7058154015055585 \u6211\u4eec\u53d1\u73b0\u60c5\u51b5\u66f4\u7cdf\u3002\u770b\u6765\uff0c\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\u7684\u6700\u4f73\u65b9\u6cd5\u662f\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u3002\u968f\u673a\u68ee\u6797\u4f3c\u4e4e\u8017\u65f6\u592a\u591a\u3002\u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u8bd5\u8bd5 XGBoost\u3002\u5982\u679c\u4f60\u4e0d\u77e5\u9053 XGBoost\uff0c\u5b83\u662f\u6700\u6d41\u884c\u7684\u68af\u5ea6\u63d0\u5347\u7b97\u6cd5\u4e4b\u4e00\u3002\u7531\u4e8e\u5b83\u662f\u4e00\u79cd\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u6807\u7b7e\u7f16\u7801\u6570\u636e\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . 
fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 , n_estimators = 200 ) model . fit ( x_train , df_train . target . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . target . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0c\u6211\u5bf9 xgboost \u53c2\u6570\u505a\u4e86\u4e00\u4e9b\u4fee\u6539\u3002xgboost \u7684\u9ed8\u8ba4\u6700\u5927\u6df1\u5ea6\uff08max_depth\uff09\u662f 3\uff0c\u6211\u628a\u5b83\u6539\u6210\u4e86 7\uff0c\u8fd8\u628a\u4f30\u8ba1\u5668\u6570\u91cf\uff08n_estimators\uff09\u4ece 100 \u6539\u6210\u4e86 200\u3002 \u8be5\u6a21\u578b\u7684 5 \u6298\u4ea4\u53c9\u68c0\u9a8c\u5f97\u5206\u5982\u4e0b\uff1a \u276f python lbl_xgb . 
py Fold = 0 , AUC = 0.7656768851999011 Fold = 1 , AUC = 0.7633006564148015 Fold = 2 , AUC = 0.7654277821434345 Fold = 3 , AUC = 0.7663609758878182 Fold = 4 , AUC = 0.764914671468069 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u4e0d\u505a\u4efb\u4f55\u8c03\u6574\u7684\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u7684\u5f97\u5206\u6bd4\u666e\u901a\u968f\u673a\u68ee\u6797\u8981\u9ad8\u5f97\u591a\u3002 \u60a8\u8fd8\u53ef\u4ee5\u5c1d\u8bd5\u4e00\u4e9b\u7279\u5f81\u5de5\u7a0b\uff0c\u653e\u5f03\u67d0\u4e9b\u5bf9\u6a21\u578b\u6ca1\u6709\u4efb\u4f55\u4ef7\u503c\u7684\u5217\u7b49\u3002\u4f46\u4f3c\u4e4e\u6211\u4eec\u80fd\u505a\u7684\u4e0d\u591a\uff0c\u65e0\u6cd5\u8bc1\u660e\u6a21\u578b\u7684\u6539\u8fdb\u3002\u8ba9\u6211\u4eec\u628a\u6570\u636e\u96c6\u6362\u6210\u53e6\u4e00\u4e2a\u6709\u5927\u91cf\u5206\u7c7b\u53d8\u91cf\u7684\u6570\u636e\u96c6\u3002\u53e6\u4e00\u4e2a\u6709\u540d\u7684\u6570\u636e\u96c6\u662f \u7f8e\u56fd\u6210\u4eba\u4eba\u53e3\u666e\u67e5\u6570\u636e\uff08US adult census data\uff09 \u3002\u8fd9\u4e2a\u6570\u636e\u96c6\u5305\u542b\u4e00\u4e9b\u7279\u5f81\uff0c\u800c\u4f60\u7684\u4efb\u52a1\u662f\u9884\u6d4b\u5de5\u8d44\u7b49\u7ea7\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u3002\u56fe 5 \u663e\u793a\u4e86\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u4e00\u4e9b\u5217\u3002 \u56fe 5\uff1a\u90e8\u5206\u6570\u636e\u96c6\u5c55\u793a \u8be5\u6570\u636e\u96c6\u6709\u4ee5\u4e0b\u51e0\u5217\uff1a - \u5e74\u9f84\uff08age\uff09 \u5de5\u4f5c\u7c7b\u522b\uff08workclass\uff09 \u5b66\u5386\uff08fnlwgt\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education\uff09 \u6559\u80b2\u7a0b\u5ea6\uff08education.num\uff09 \u5a5a\u59fb\u72b6\u51b5\uff08marital.status\uff09 \u804c\u4e1a\uff08occupation\uff09 \u5173\u7cfb\uff08relationship\uff09 \u79cd\u65cf\uff08race\uff09 \u6027\u522b\uff08sex\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 
\u539f\u7c4d\u56fd\uff08native.country\uff09 \u6536\u5165\uff08income\uff09 \u8fd9\u4e9b\u7279\u5f81\u5927\u591a\u4e0d\u8a00\u81ea\u660e\u3002\u90a3\u4e9b\u4e0d\u660e\u767d\u7684\uff0c\u6211\u4eec\u53ef\u4ee5\u4e0d\u8003\u8651\u3002\u8ba9\u6211\u4eec\u5148\u5c1d\u8bd5\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u3002 \u6211\u4eec\u770b\u5230\u6536\u5165\u5217\u662f\u4e00\u4e2a\u5b57\u7b26\u4e32\u3002\u8ba9\u6211\u4eec\u5bf9\u8fd9\u4e00\u5217\u8fdb\u884c\u6570\u503c\u7edf\u8ba1\u3002 In [ X ]: import pandas as pd In [ X ]: df = pd . read_csv ( \"../input/adult.csv\" ) In [ X ]: df . income . value_counts () Out [ X ]: <= 50 K 24720 > 50 K 7841 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6709 7841 \u4e2a\u5b9e\u4f8b\u7684\u6536\u5165\u8d85\u8fc7 5 \u4e07\u7f8e\u5143\u3002\u8fd9\u5360\u6837\u672c\u603b\u6570\u7684 24%\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4fdd\u6301\u4e0e\u732b\u6570\u636e\u96c6\u76f8\u540c\u7684\u8bc4\u4f30\u65b9\u6cd5\uff0c\u5373 AUC\u3002 \u5728\u5f00\u59cb\u5efa\u6a21\u4e4b\u524d\uff0c\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u6211\u4eec\u5c06\u53bb\u6389\u51e0\u5217\u7279\u5f81\uff0c\u5373 \u5b66\u5386\uff08fnlwgt\uff09 \u5e74\u9f84\uff08age\uff09 \u8d44\u672c\u6536\u76ca\uff08capital.gain\uff09 \u8d44\u672c\u635f\u5931\uff08capital.loss\uff09 \u6bcf\u5468\u5c0f\u65f6\u6570\uff08hours.per.week\uff09 \u8ba9\u6211\u4eec\u8bd5\u7740\u7528\u903b\u8f91\u56de\u5f52\u548c\u72ec\u70ed\u7f16\u7801\u5668\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002\u7b2c\u4e00\u6b65\u603b\u662f\u8981\u8fdb\u884c\u4ea4\u53c9\u9a8c\u8bc1\u3002\u6211\u4e0d\u4f1a\u5728\u8fd9\u91cc\u5c55\u793a\u8fd9\u90e8\u5206\u4ee3\u7801\u3002\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . 
read_csv ( \"../input/adult_folds.csv\" ) # \u9700\u8981\u5220\u9664\u7684\u5217 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) # \u6620\u5c04 target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } # \u4f7f\u7528\u6620\u5c04\u66ff\u6362 df . loc [:, \"income\" ] = df . income . map ( target_mapping ) # \u53d6\u9664\"kfold\", \"income\"\u5217\u7684\u5176\u4ed6\u5217\u540d features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : # \u5c06\u7a7a\u503c\u66ff\u6362\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) # \u53d6\u8bad\u7ec3\u96c6\uff08kfold\u5217\u4e2d\u4e0d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u53d6\u9a8c\u8bc1\u96c6\uff08kfold\u5217\u4e2d\u4e3afold\u7684\u6837\u672c\uff0c\u91cd\u7f6e\u7d22\u5f15\uff09 df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u72ec\u70ed\u7f16\u7801 ohe = preprocessing . OneHotEncoder () # \u5c06\u8bad\u7ec3\u96c6\u3001\u6d4b\u8bd5\u96c6\u6cbf\u884c\u5408\u5e76 full_data = pd . concat ([ df_train [ features ], df_valid [ features ]], axis = 0 ) ohe . fit ( full_data [ features ]) # \u8f6c\u6362\u8bad\u7ec3\u96c6 x_train = ohe . transform ( df_train [ features ]) # \u8f6c\u6362\u9a8c\u8bc1\u96c6 x_valid = ohe . transform ( df_valid [ features ]) # \u6784\u5efa\u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b model . fit ( x_train , df_train . income . values ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u5f97\u5230\u9884\u6d4b\u6807\u7b7e valid_preds = model . predict_proba ( x_valid )[:, 1 ] # \u8ba1\u7b97auc\u6307\u6807 auc = metrics . roc_auc_score ( df_valid . income . 
values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u5f53\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u65f6\uff0c\u6211\u4eec\u4f1a\u5f97\u5230 \u276f python - W ignore ohe_logres . py Fold = 0 , AUC = 0.8794809708119079 Fold = 1 , AUC = 0.8875785068274882 Fold = 2 , AUC = 0.8852609687685753 Fold = 3 , AUC = 0.8681236223251438 Fold = 4 , AUC = 0.8728581541840037 \u5bf9\u4e8e\u4e00\u4e2a\u5982\u6b64\u7b80\u5355\u7684\u6a21\u578b\u6765\u8bf4\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u4e0d\u9519\u7684 AUC\uff01 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5728\u4e0d\u8c03\u6574\u4efb\u4f55\u8d85\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c1d\u8bd5\u4e00\u4e0b\u6807\u7b7e\u7f16\u7801\u7684xgboost\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] df = df . drop ( num_cols , axis = 1 ) target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . 
roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : # \u8fd0\u884c0~4\u6298 for fold_ in range ( 5 ): run ( fold_ ) \u8ba9\u6211\u4eec\u8fd0\u884c\u4e0a\u9762\u4ee3\u7801\uff1a \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8800810634234078 Fold = 1 , AUC = 0.886811884948154 Fold = 2 , AUC = 0.8854421433318472 Fold = 3 , AUC = 0.8676319549361007 Fold = 4 , AUC = 0.8714450054900602 \u8fd9\u770b\u8d77\u6765\u5df2\u7ecf\u76f8\u5f53\u4e0d\u9519\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b max_depth \u589e\u52a0\u5230 7 \u548c n_estimators \u589e\u52a0\u5230 200 \u65f6\u7684\u5f97\u5206\u3002 \u276f python lbl_xgb . py Fold = 0 , AUC = 0.8764108944332032 Fold = 1 , AUC = 0.8840708537662638 Fold = 2 , AUC = 0.8816601162613102 Fold = 3 , AUC = 0.8662335762581732 Fold = 4 , AUC = 0.8698983461709926 \u770b\u8d77\u6765\u5e76\u6ca1\u6709\u6539\u5584\u3002 \u8fd9\u8868\u660e\uff0c\u4e00\u4e2a\u6570\u636e\u96c6\u7684\u53c2\u6570\u4e0d\u80fd\u79fb\u690d\u5230\u53e6\u4e00\u4e2a\u6570\u636e\u96c6\u3002\u6211\u4eec\u5fc5\u987b\u518d\u6b21\u5c1d\u8bd5\u8c03\u6574\u53c2\u6570\uff0c\u4f46\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u8bf4\u660e\u3002 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u5c1d\u8bd5\u5728\u4e0d\u8c03\u6574\u53c2\u6570\u7684\u60c5\u51b5\u4e0b\u5c06\u6570\u503c\u7279\u5f81\u7eb3\u5165 xgboost \u6a21\u578b\u3002 import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) # \u52a0\u5165\u6570\u503c\u7279\u5f81 num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . 
columns if f not in ( \"kfold\" , \"income\" ) ] for col in features : if col not in num_cols : # \u5c06\u7a7a\u503c\u7f6e\u4e3a\"NONE\" df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values # XGBoost\u6a21\u578b model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u56e0\u6b64\uff0c\u6211\u4eec\u4fdd\u7559\u6570\u5b57\u5217\uff0c\u53ea\u662f\u4e0d\u5bf9\u5176\u8fdb\u884c\u6807\u7b7e\u7f16\u7801\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u7684\u6700\u7ec8\u7279\u5f81\u77e9\u9635\u5c31\u7531\u6570\u5b57\u5217\uff08\u539f\u6837\uff09\u548c\u7f16\u7801\u5206\u7c7b\u5217\u7ec4\u6210\u4e86\u3002\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u7b97\u6cd5\u90fd\u80fd\u8f7b\u677e\u5904\u7406\u8fd9\u79cd\u6df7\u5408\u3002 \u8bf7\u6ce8\u610f\uff0c\u5728\u4f7f\u7528\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65f6\uff0c\u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u6570\u636e\u8fdb\u884c\u5f52\u4e00\u5316\u5904\u7406\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e00\u70b9\u975e\u5e38\u91cd\u8981\uff0c\u5728\u4f7f\u7528\u7ebf\u6027\u6a21\u578b\uff08\u5982\u903b\u8f91\u56de\u5f52\uff09\u65f6\u4e0d\u5bb9\u5ffd\u89c6\u3002 \u73b0\u5728\u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u4e2a\u811a\u672c\uff01 \u276f python lbl_xgb_num . 
py Fold = 0 , AUC = 0.9209790185449889 Fold = 1 , AUC = 0.9247157449144706 Fold = 2 , AUC = 0.9269329887598243 Fold = 3 , AUC = 0.9119349082169275 Fold = 4 , AUC = 0.9166408030141667 \u54c7\u54e6 \u8fd9\u662f\u4e00\u4e2a\u5f88\u597d\u7684\u5206\u6570\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c1d\u8bd5\u6dfb\u52a0\u4e00\u4e9b\u529f\u80fd\u3002\u6211\u4eec\u5c06\u63d0\u53d6\u6240\u6709\u5206\u7c7b\u5217\uff0c\u5e76\u521b\u5efa\u6240\u6709\u4e8c\u5ea6\u7ec4\u5408\u3002\u8bf7\u770b\u4e0b\u9762\u4ee3\u7801\u6bb5\u4e2d\u7684 feature_engineering \u51fd\u6570\uff0c\u4e86\u89e3\u5982\u4f55\u5b9e\u73b0\u8fd9\u4e00\u70b9\u3002 import itertools import pandas as pd import xgboost as xgb from sklearn import metrics from sklearn import preprocessing def feature_engineering ( df , cat_cols ): # \u751f\u6210\u4e24\u4e2a\u7279\u5f81\u7684\u7ec4\u5408 combi = list ( itertools . combinations ( cat_cols , 2 )) for c1 , c2 in combi : df . loc [:, c1 + \"_\" + c2 ] = df [ c1 ] . astype ( str ) + \"_\" + df [ c2 ] . astype ( str ) return df def run ( fold ): df = pd . read_csv ( \"../input/adult_folds.csv\" ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) cat_cols = [ c for c in df . columns if c not in num_cols and c not in ( \"kfold\" , \"income\" )] # \u7279\u5f81\u5de5\u7a0b df = feature_engineering ( df , cat_cols ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" )] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . 
reset_index ( drop = True ) x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : for fold_ in range ( 5 ): run ( fold_ ) \u8fd9\u662f\u4ece\u5206\u7c7b\u5217\u4e2d\u521b\u5efa\u7279\u5f81\u7684\u4e00\u79cd\u975e\u5e38\u5e7c\u7a1a\u7684\u65b9\u6cd5\u3002\u6211\u4eec\u5e94\u8be5\u4ed4\u7ec6\u7814\u7a76\u6570\u636e\uff0c\u770b\u770b\u54ea\u4e9b\u7ec4\u5408\u6700\u5408\u7406\u3002\u5982\u679c\u4f7f\u7528\u8fd9\u79cd\u65b9\u6cd5\uff0c\u6700\u7ec8\u53ef\u80fd\u4f1a\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u5c31\u9700\u8981\u4f7f\u7528\u67d0\u79cd\u7279\u5f81\u9009\u62e9\u6765\u9009\u51fa\u6700\u4f73\u7279\u5f81\u3002\u7a0d\u540e\u6211\u4eec\u5c06\u8be6\u7ec6\u4ecb\u7ecd\u7279\u5f81\u9009\u62e9\u3002\u73b0\u5728\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5206\u6570\u3002 \u276f python lbl_xgb_num_feat . py Fold = 0 , AUC = 0.9211483465031423 Fold = 1 , AUC = 0.9251499446866125 Fold = 2 , AUC = 0.9262344766486692 Fold = 3 , AUC = 0.9114264068794995 Fold = 4 , AUC = 0.9177914453099201 \u770b\u6765\uff0c\u5373\u4f7f\u4e0d\u6539\u53d8\u4efb\u4f55\u8d85\u53c2\u6570\uff0c\u53ea\u589e\u52a0\u4e00\u4e9b\u7279\u5f81\uff0c\u6211\u4eec\u4e5f\u80fd\u63d0\u9ad8\u4e00\u4e9b\u6298\u53e0\u5f97\u5206\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5c06 max_depth \u589e\u52a0\u5230 7 \u662f\u5426\u6709\u5e2e\u52a9\u3002 \u276f python lbl_xgb_num_feat . 
py Fold = 0 , AUC = 0.9286668430204137 Fold = 1 , AUC = 0.9329340656165378 Fold = 2 , AUC = 0.9319817543218744 Fold = 3 , AUC = 0.919046187194538 Fold = 4 , AUC = 0.9245692057162671 \u6211\u4eec\u518d\u6b21\u6539\u8fdb\u4e86\u6211\u4eec\u7684\u6a21\u578b\u3002 \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u8fd8\u6ca1\u6709\u4f7f\u7528\u7a00\u6709\u503c\u3001\u4e8c\u503c\u5316\u3001\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u7279\u5f81\u7684\u7ec4\u5408\u4ee5\u53ca\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u3002 \u4ece\u5206\u7c7b\u7279\u5f81\u4e2d\u8fdb\u884c\u7279\u5f81\u5de5\u7a0b\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f\u4f7f\u7528 \u76ee\u6807\u7f16\u7801 \u3002\u4f46\u662f\uff0c\u60a8\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u8fd9\u53ef\u80fd\u4f1a\u4f7f\u60a8\u7684\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u76ee\u6807\u7f16\u7801\u662f\u4e00\u79cd\u5c06\u7ed9\u5b9a\u7279\u5f81\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u6620\u5c04\u5230\u5176\u5e73\u5747\u76ee\u6807\u503c\u7684\u6280\u672f\uff0c\u4f46\u5fc5\u987b\u59cb\u7ec8\u4ee5\u4ea4\u53c9\u9a8c\u8bc1\u7684\u65b9\u5f0f\u8fdb\u884c\u3002\u8fd9\u610f\u5473\u7740\u9996\u5148\u8981\u521b\u5efa\u6298\u53e0\uff0c\u7136\u540e\u4f7f\u7528\u8fd9\u4e9b\u6298\u53e0\u4e3a\u6570\u636e\u7684\u4e0d\u540c\u5217\u521b\u5efa\u76ee\u6807\u7f16\u7801\u7279\u5f81\uff0c\u65b9\u6cd5\u4e0e\u5728\u6298\u53e0\u4e0a\u62df\u5408\u548c\u9884\u6d4b\u6a21\u578b\u7684\u65b9\u6cd5\u76f8\u540c\u3002\u56e0\u6b64\uff0c\u5982\u679c\u60a8\u521b\u5efa\u4e86 5 \u4e2a\u6298\u53e0\uff0c\u60a8\u5c31\u5fc5\u987b\u521b\u5efa 5 
\u6b21\u76ee\u6807\u7f16\u7801\uff0c\u8fd9\u6837\u6700\u7ec8\uff0c\u60a8\u5c31\u53ef\u4ee5\u4e3a\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7684\u53d8\u91cf\u521b\u5efa\u7f16\u7801\uff0c\u800c\u8fd9\u4e9b\u53d8\u91cf\u5e76\u975e\u6765\u81ea\u540c\u4e00\u4e2a\u6298\u53e0\u3002\u7136\u540e\u5728\u62df\u5408\u6a21\u578b\u65f6\uff0c\u5fc5\u987b\u518d\u6b21\u4f7f\u7528\u76f8\u540c\u7684\u6298\u53e0\u3002\u672a\u89c1\u6d4b\u8bd5\u6570\u636e\u7684\u76ee\u6807\u7f16\u7801\u53ef\u4ee5\u6765\u81ea\u5168\u90e8\u8bad\u7ec3\u6570\u636e\uff0c\u4e5f\u53ef\u4ee5\u662f\u6240\u6709 5 \u4e2a\u6298\u53e0\u7684\u5e73\u5747\u503c\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u5728\u540c\u4e00\u4e2a\u6210\u4eba\u6570\u636e\u96c6\u4e0a\u4f7f\u7528\u76ee\u6807\u7f16\u7801\uff0c\u4ee5\u4fbf\u8fdb\u884c\u6bd4\u8f83\u3002 import copy import pandas as pd from sklearn import metrics from sklearn import preprocessing import xgboost as xgb def mean_target_encoding ( data ): df = copy . deepcopy ( data ) num_cols = [ \"fnlwgt\" , \"age\" , \"capital.gain\" , \"capital.loss\" , \"hours.per.week\" ] target_mapping = { \"<=50K\" : 0 , \">50K\" : 1 } df . loc [:, \"income\" ] = df . income . map ( target_mapping ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) and f not in num_cols ] for col in features : if col not in num_cols : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for col in features : if col not in num_cols : # \u6807\u7b7e\u7f16\u7801 lbl = preprocessing . LabelEncoder () lbl . fit ( df [ col ]) df . loc [:, col ] = lbl . transform ( df [ col ]) encoded_dfs = [] for fold in range ( 5 ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) for column in features : # \u76ee\u6807\u7f16\u7801 mapping_dict = dict ( df_train . groupby ( column )[ \"income\" ] . mean () ) df_valid . loc [:, column + \"_enc\" ] = df_valid [ column ] . map ( mapping_dict ) encoded_dfs . 
append ( df_valid ) encoded_df = pd . concat ( encoded_dfs , axis = 0 ) return encoded_df def run ( df , fold ): df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) features = [ f for f in df . columns if f not in ( \"kfold\" , \"income\" ) ] x_train = df_train [ features ] . values x_valid = df_valid [ features ] . values model = xgb . XGBClassifier ( n_jobs =- 1 , max_depth = 7 ) model . fit ( x_train , df_train . income . values ) valid_preds = model . predict_proba ( x_valid )[:, 1 ] auc = metrics . roc_auc_score ( df_valid . income . values , valid_preds ) print ( f \"Fold = { fold } , AUC = { auc } \" ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/adult_folds.csv\" ) df = mean_target_encoding ( df ) for fold_ in range ( 5 ): run ( df , fold_ ) \u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u5728\u4e0a\u8ff0\u7247\u6bb5\u4e2d\uff0c\u6211\u5728\u8fdb\u884c\u76ee\u6807\u7f16\u7801\u65f6\u5e76\u6ca1\u6709\u5220\u9664\u5206\u7c7b\u5217\u3002\u6211\u4fdd\u7559\u4e86\u6240\u6709\u7279\u5f81\uff0c\u5e76\u5728\u6b64\u57fa\u7840\u4e0a\u6dfb\u52a0\u4e86\u76ee\u6807\u7f16\u7801\u7279\u5f81\u3002\u6b64\u5916\uff0c\u6211\u8fd8\u4f7f\u7528\u4e86\u5e73\u5747\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5e73\u5747\u503c\u3001\u4e2d\u4f4d\u6570\u3001\u6807\u51c6\u504f\u5dee\u6216\u76ee\u6807\u7684\u4efb\u4f55\u5176\u4ed6\u51fd\u6570\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u7ed3\u679c\u3002 Fold = 0 , AUC = 0.9332240662017529 Fold = 1 , AUC = 0.9363551625140347 Fold = 2 , AUC = 0.9375013544556173 Fold = 3 , AUC = 0.92237621307625 Fold = 4 , AUC = 0.9292131180445478 
\u4e0d\u9519\uff01\u770b\u6765\u6211\u4eec\u53c8\u6709\u8fdb\u6b65\u4e86\u3002\u4e0d\u8fc7\uff0c\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\u5fc5\u987b\u975e\u5e38\u5c0f\u5fc3\uff0c\u56e0\u4e3a\u5b83\u592a\u5bb9\u6613\u51fa\u73b0\u8fc7\u5ea6\u62df\u5408\u3002\u5f53\u6211\u4eec\u4f7f\u7528\u76ee\u6807\u7f16\u7801\u65f6\uff0c\u6700\u597d\u4f7f\u7528\u67d0\u79cd\u5e73\u6ed1\u65b9\u6cd5\u6216\u5728\u7f16\u7801\u503c\u4e2d\u6dfb\u52a0\u566a\u58f0\u3002 Scikit-learn \u7684\u8d21\u732e\u5e93\u4e2d\u6709\u5e26\u5e73\u6ed1\u7684\u76ee\u6807\u7f16\u7801\uff0c\u4f60\u4e5f\u53ef\u4ee5\u521b\u5efa\u81ea\u5df1\u7684\u5e73\u6ed1\u3002\u5e73\u6ed1\u4f1a\u5f15\u5165\u67d0\u79cd\u6b63\u5219\u5316\uff0c\u6709\u52a9\u4e8e\u907f\u514d\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002\u8fd9\u5e76\u4e0d\u96be\u3002 \u5904\u7406\u5206\u7c7b\u7279\u5f81\u662f\u4e00\u9879\u590d\u6742\u7684\u4efb\u52a1\u3002\u8bb8\u591a\u8d44\u6e90\u4e2d\u90fd\u6709\u5927\u91cf\u4fe1\u606f\u3002\u672c\u7ae0\u5e94\u8be5\u80fd\u5e2e\u52a9\u4f60\u5f00\u59cb\u89e3\u51b3\u5206\u7c7b\u53d8\u91cf\u7684\u4efb\u4f55\u95ee\u9898\u3002\u4e0d\u8fc7\uff0c\u5bf9\u4e8e\u5927\u591a\u6570\u95ee\u9898\u6765\u8bf4\uff0c\u9664\u4e86\u72ec\u70ed\u7f16\u7801\u548c\u6807\u7b7e\u7f16\u7801\u4e4b\u5916\uff0c\u4f60\u4e0d\u9700\u8981\u66f4\u591a\u7684\u4e1c\u897f\u3002 \u8981\u8fdb\u4e00\u6b65\u6539\u8fdb\u6a21\u578b\uff0c\u4f60\u53ef\u80fd\u9700\u8981\u66f4\u591a\uff01 \u5728\u672c\u7ae0\u7684\u6700\u540e\uff0c\u6211\u4eec\u4e0d\u80fd\u4e0d\u5728\u8fd9\u4e9b\u6570\u636e\u4e0a\u4f7f\u7528\u795e\u7ecf\u7f51\u7edc\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e00\u79cd\u79f0\u4e3a \u5b9e\u4f53\u5d4c\u5165 \u7684\u6280\u672f\u3002\u5728\u5b9e\u4f53\u5d4c\u5165\u4e2d\uff0c\u7c7b\u522b\u7528\u5411\u91cf\u8868\u793a\u3002\u5728\u4e8c\u503c\u5316\u548c\u72ec\u70ed\u7f16\u7801\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u90fd\u662f\u7528\u5411\u91cf\u6765\u8868\u793a\u7c7b\u522b\u7684\u3002 
\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u6709\u6570\u4ee5\u4e07\u8ba1\u7684\u7c7b\u522b\u600e\u4e48\u529e\uff1f\u8fd9\u5c06\u4f1a\u4ea7\u751f\u5de8\u5927\u7684\u77e9\u9635\uff0c\u6211\u4eec\u5c06\u9700\u8981\u5f88\u957f\u65f6\u95f4\u6765\u8bad\u7ec3\u590d\u6742\u7684\u6a21\u578b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u7528\u5e26\u6709\u6d6e\u70b9\u503c\u7684\u5411\u91cf\u6765\u8868\u793a\u5b83\u4eec\u3002 \u8fd9\u4e2a\u60f3\u6cd5\u975e\u5e38\u7b80\u5355\u3002\u6bcf\u4e2a\u5206\u7c7b\u7279\u5f81\u90fd\u6709\u4e00\u4e2a\u5d4c\u5165\u5c42\u3002\u56e0\u6b64\uff0c\u4e00\u5217\u4e2d\u7684\u6bcf\u4e2a\u7c7b\u522b\u73b0\u5728\u90fd\u53ef\u4ee5\u6620\u5c04\u5230\u4e00\u4e2a\u5d4c\u5165\u5c42\uff08\u5c31\u50cf\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\u5c06\u5355\u8bcd\u6620\u5c04\u5230\u5d4c\u5165\u5c42\u4e00\u6837\uff09\u3002\u7136\u540e\uff0c\u6839\u636e\u5176\u7ef4\u5ea6\u91cd\u5851\u8fd9\u4e9b\u5d4c\u5165\u5c42\uff0c\u4f7f\u5176\u6241\u5e73\u5316\uff0c\u7136\u540e\u5c06\u6240\u6709\u6241\u5e73\u5316\u7684\u8f93\u5165\u5d4c\u5165\u5c42\u8fde\u63a5\u8d77\u6765\u3002\u7136\u540e\u6dfb\u52a0\u4e00\u5806\u5bc6\u96c6\u5c42\u548c\u4e00\u4e2a\u8f93\u51fa\u5c42\uff0c\u5c31\u5927\u529f\u544a\u6210\u4e86\u3002 \u56fe 6\uff1a\u7c7b\u522b\u8f6c\u6362\u4e3a\u6d6e\u70b9\u6216\u5d4c\u5165\u5411\u91cf \u51fa\u4e8e\u67d0\u79cd\u539f\u56e0\uff0c\u6211\u53d1\u73b0\u4f7f\u7528 TF/Keras \u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u56e0\u6b64\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528 TF/Keras \u5b9e\u73b0\u5b83\u3002\u6b64\u5916\uff0c\u8fd9\u662f\u672c\u4e66\u4e2d\u552f\u4e00\u4e00\u4e2a\u4f7f\u7528 TF/Keras \u7684\u793a\u4f8b\uff0c\u5c06\u5176\u8f6c\u6362\u4e3a PyTorch\uff08\u4f7f\u7528 cat-in-the-dat-ii \u6570\u636e\u96c6\uff09\u4e5f\u975e\u5e38\u5bb9\u6613 import os import gc import joblib import pandas as pd import numpy as np from sklearn import metrics , preprocessing from tensorflow.keras import layers from 
tensorflow.keras import optimizers from tensorflow.keras.models import Model , load_model from tensorflow.keras import callbacks from tensorflow.keras import backend as K from tensorflow.keras import utils def create_model ( data , catcols ): # \u521b\u5efa\u7a7a\u7684\u8f93\u5165\u5217\u8868\u548c\u8f93\u51fa\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6a21\u578b\u7684\u8f93\u5165\u548c\u8f93\u51fa inputs = [] outputs = [] # \u904d\u5386\u5206\u7c7b\u7279\u5f81\u5217\u8868\u4e2d\u7684\u6bcf\u4e2a\u7279\u5f81 for c in catcols : # \u8ba1\u7b97\u7279\u5f81\u4e2d\u552f\u4e00\u503c\u7684\u6570\u91cf num_unique_values = int ( data [ c ] . nunique ()) # \u8ba1\u7b97\u5d4c\u5165\u7ef4\u5ea6\uff0c\u6700\u5927\u4e0d\u8d85\u8fc750 embed_dim = int ( min ( np . ceil (( num_unique_values ) / 2 ), 50 )) # \u521b\u5efa\u6a21\u578b\u7684\u8f93\u5165\u5c42\uff0c\u6bcf\u4e2a\u7279\u5f81\u5bf9\u5e94\u4e00\u4e2a\u8f93\u5165 inp = layers . Input ( shape = ( 1 ,)) # \u521b\u5efa\u5d4c\u5165\u5c42\uff0c\u5c06\u5206\u7c7b\u7279\u5f81\u6620\u5c04\u5230\u4f4e\u7ef4\u5ea6\u7684\u8fde\u7eed\u5411\u91cf out = layers . Embedding ( num_unique_values + 1 , embed_dim , name = c )( inp ) # \u5bf9\u5d4c\u5165\u5c42\u8fdb\u884c\u7a7a\u95f4\u4e22\u5f03\uff08Dropout\uff09 out = layers . SpatialDropout1D ( 0.3 )( out ) # \u5c06\u5d4c\u5165\u5c42\u7684\u5f62\u72b6\u91cd\u65b0\u8c03\u6574\u4e3a\u4e00\u7ef4 out = layers . Reshape ( target_shape = ( embed_dim ,))( out ) # \u5c06\u8f93\u5165\u548c\u8f93\u51fa\u6dfb\u52a0\u5230\u5bf9\u5e94\u7684\u5217\u8868\u4e2d inputs . append ( inp ) outputs . append ( out ) # \u4f7f\u7528Concatenate\u5c42\u5c06\u6240\u6709\u7684\u5d4c\u5165\u5c42\u8f93\u51fa\u8fde\u63a5\u5728\u4e00\u8d77 x = layers . Concatenate ()( outputs ) # \u5bf9\u8fde\u63a5\u540e\u7684\u6570\u636e\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . 
BatchNormalization ()( x ) # \u6dfb\u52a0\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u6dfb\u52a0\u53e6\u4e00\u4e2a\u5177\u6709300\u4e2a\u795e\u7ecf\u5143\u7684\u5bc6\u96c6\u5c42\uff0c\u5e76\u4f7f\u7528ReLU\u6fc0\u6d3b\u51fd\u6570 x = layers . Dense ( 300 , activation = \"relu\" )( x ) # \u5bf9\u8be5\u5c42\u7684\u8f93\u51fa\u8fdb\u884cDropout x = layers . Dropout ( 0.3 )( x ) # \u518d\u6b21\u8fdb\u884c\u6279\u91cf\u5f52\u4e00\u5316 x = layers . BatchNormalization ()( x ) # \u8f93\u51fa\u5c42\uff0c\u5177\u67092\u4e2a\u795e\u7ecf\u5143\uff08\u7528\u4e8e\u4e8c\u8fdb\u5236\u5206\u7c7b\uff09\uff0c\u5e76\u4f7f\u7528softmax\u6fc0\u6d3b\u51fd\u6570 y = layers . Dense ( 2 , activation = \"softmax\" )( x ) # \u521b\u5efa\u6a21\u578b\uff0c\u5c06\u8f93\u5165\u548c\u8f93\u51fa\u4f20\u9012\u7ed9Model\u6784\u9020\u51fd\u6570 model = Model ( inputs = inputs , outputs = y ) # \u7f16\u8bd1\u6a21\u578b\uff0c\u6307\u5b9a\u635f\u5931\u51fd\u6570\u548c\u4f18\u5316\u5668 model . compile ( loss = 'binary_crossentropy' , optimizer = 'adam' ) # \u8fd4\u56de\u521b\u5efa\u7684\u6a21\u578b return model def run ( fold ): df = pd . read_csv ( \"../input/cat_train_folds.csv\" ) features = [ f for f in df . columns if f not in ( \"id\" , \"target\" , \"kfold\" ) ] for col in features : df . loc [:, col ] = df [ col ] . astype ( str ) . fillna ( \"NONE\" ) for feat in features : lbl_enc = preprocessing . LabelEncoder () df . loc [:, feat ] = lbl_enc . fit_transform ( df [ feat ] . values ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . 
reset_index ( drop = True ) model = create_model ( df , features ) xtrain = [ df_train [ features ] . values [:, k ] for k in range ( len ( features ))] xvalid = [ df_valid [ features ] . values [:, k ] for k in range ( len ( features )) ] ytrain = df_train . target . values yvalid = df_valid . target . values ytrain_cat = utils . to_categorical ( ytrain ) yvalid_cat = utils . to_categorical ( yvalid ) model . fit ( xtrain , ytrain_cat , validation_data = ( xvalid , yvalid_cat ), verbose = 1 , batch_size = 1024 , epochs = 3 ) valid_preds = model . predict ( xvalid )[:, 1 ] print ( metrics . roc_auc_score ( yvalid , valid_preds )) K . clear_session () if __name__ == \"__main__\" : run ( 0 ) run ( 1 ) run ( 2 ) run ( 3 ) run ( 4 ) \u4f60\u4f1a\u53d1\u73b0\u8fd9\u79cd\u65b9\u6cd5\u6548\u679c\u6700\u597d\uff0c\u800c\u4e14\u5982\u679c\u4f60\u6709 GPU\uff0c\u901f\u5ea6\u4e5f\u8d85\u5feb\uff01\u8fd9\u79cd\u65b9\u6cd5\u8fd8\u53ef\u4ee5\u8fdb\u4e00\u6b65\u6539\u8fdb\uff0c\u800c\u4e14\u4f60\u65e0\u9700\u62c5\u5fc3\u7279\u5f81\u5de5\u7a0b\uff0c\u56e0\u4e3a\u795e\u7ecf\u7f51\u7edc\u4f1a\u81ea\u884c\u5904\u7406\u3002\u5728\u5904\u7406\u5927\u91cf\u5206\u7c7b\u7279\u5f81\u6570\u636e\u96c6\u65f6\uff0c\u8fd9\u7edd\u5bf9\u503c\u5f97\u4e00\u8bd5\u3002\u5f53\u5d4c\u5165\u5927\u5c0f\u4e0e\u552f\u4e00\u7c7b\u522b\u7684\u6570\u91cf\u76f8\u540c\u65f6\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u4f7f\u7528\u72ec\u70ed\u7f16\u7801\uff08one-hot-encoding\uff09\u3002 \u672c\u7ae0\u57fa\u672c\u4e0a\u90fd\u662f\u5173\u4e8e\u7279\u5f81\u5de5\u7a0b\u7684\u3002\u8ba9\u6211\u4eec\u5728\u4e0b\u4e00\u7ae0\u4e2d\u770b\u770b\u5982\u4f55\u5728\u6570\u5b57\u7279\u5f81\u548c\u4e0d\u540c\u7c7b\u578b\u7279\u5f81\u7684\u7ec4\u5408\u65b9\u9762\u8fdb\u884c\u66f4\u591a\u7684\u7279\u5f81\u5de5\u7a0b\u3002","title":"\u5904\u7406\u5206\u7c7b\u53d8\u91cf"},{"location":"%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%88%96%E5%9B%9E%E5%BD%92%E6%96%B9%E6%B3%95/","text":"\u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5 
\u6587\u672c\u95ee\u9898\u662f\u6211\u7684\u6700\u7231\u3002\u4e00\u822c\u6765\u8bf4\uff0c\u8fd9\u4e9b\u95ee\u9898\u4e5f\u88ab\u79f0\u4e3a \u81ea\u7136\u8bed\u8a00\u5904\u7406\uff08NLP\uff09\u95ee\u9898 \u3002NLP \u95ee\u9898\u4e0e\u56fe\u50cf\u95ee\u9898\u4e5f\u6709\u5f88\u5927\u4e0d\u540c\u3002\u4f60\u9700\u8981\u521b\u5efa\u4ee5\u524d\u4ece\u672a\u4e3a\u8868\u683c\u95ee\u9898\u521b\u5efa\u8fc7\u7684\u6570\u636e\u7ba1\u9053\u3002\u4f60\u9700\u8981\u4e86\u89e3\u5546\u4e1a\u6848\u4f8b\uff0c\u624d\u80fd\u5efa\u7acb\u4e00\u4e2a\u597d\u7684\u6a21\u578b\u3002\u987a\u4fbf\u8bf4\u4e00\u53e5\uff0c\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u4efb\u4f55\u4e8b\u60c5\u90fd\u662f\u5982\u6b64\u3002\u5efa\u7acb\u6a21\u578b\u4f1a\u8ba9\u4f60\u8fbe\u5230\u4e00\u5b9a\u7684\u6c34\u5e73\uff0c\u4f46\u8981\u60f3\u6539\u5584\u548c\u4fc3\u8fdb\u4f60\u6240\u5efa\u7acb\u6a21\u578b\u7684\u4e1a\u52a1\uff0c\u4f60\u5fc5\u987b\u4e86\u89e3\u5b83\u5bf9\u4e1a\u52a1\u7684\u5f71\u54cd\u3002 NLP \u95ee\u9898\u6709\u5f88\u591a\u79cd\uff0c\u5176\u4e2d\u6700\u5e38\u89c1\u7684\u662f\u5b57\u7b26\u4e32\u5206\u7c7b\u3002\u5f88\u591a\u65f6\u5019\uff0c\u6211\u4eec\u4f1a\u770b\u5230\u4eba\u4eec\u5728\u5904\u7406\u8868\u683c\u6570\u636e\u6216\u56fe\u50cf\u65f6\u8868\u73b0\u51fa\u8272\uff0c\u4f46\u5728\u5904\u7406\u6587\u672c\u65f6\uff0c\u4ed6\u4eec\u751a\u81f3\u4e0d\u77e5\u9053\u4ece\u4f55\u5165\u624b\u3002\u6587\u672c\u6570\u636e\u4e0e\u5176\u4ed6\u7c7b\u578b\u7684\u6570\u636e\u96c6\u6ca1\u6709\u4ec0\u4e48\u4e0d\u540c\u3002\u5bf9\u4e8e\u8ba1\u7b97\u673a\u6765\u8bf4\uff0c\u4e00\u5207\u90fd\u662f\u6570\u5b57\u3002 
\u5047\u8bbe\u6211\u4eec\u4ece\u60c5\u611f\u5206\u7c7b\u8fd9\u4e00\u57fa\u672c\u4efb\u52a1\u5f00\u59cb\u3002\u6211\u4eec\u5c06\u5c1d\u8bd5\u5bf9\u7535\u5f71\u8bc4\u8bba\u8fdb\u884c\u60c5\u611f\u5206\u7c7b\u3002\u56e0\u6b64\uff0c\u60a8\u6709\u4e00\u4e2a\u6587\u672c\uff0c\u5e76\u6709\u4e0e\u4e4b\u76f8\u5173\u7684\u60c5\u611f\u3002\u4f60\u5c06\u5982\u4f55\u5904\u7406\u8fd9\u7c7b\u95ee\u9898\uff1f\u662f\u5e94\u7528\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff1f \u4e0d\uff0c\u7edd\u5bf9\u9519\u4e86\u3002\u4f60\u8981\u4ece\u6700\u57fa\u672c\u7684\u5f00\u59cb\u3002\u8ba9\u6211\u4eec\u5148\u770b\u770b\u8fd9\u4e9b\u6570\u636e\u662f\u4ec0\u4e48\u6837\u5b50\u7684\u3002 \u6211\u4eec\u4ece IMDB \u7535\u5f71\u8bc4\u8bba\u6570\u636e\u96c6 \u5f00\u59cb\uff0c\u8be5\u6570\u636e\u96c6\u5305\u542b 25000 \u7bc7\u6b63\u9762\u60c5\u611f\u8bc4\u8bba\u548c 25000 \u7bc7\u8d1f\u9762\u60c5\u611f\u8bc4\u8bba\u3002 \u6211\u5c06\u5728\u6b64\u8ba8\u8bba\u7684\u6982\u5ff5\u51e0\u4e4e\u9002\u7528\u4e8e\u4efb\u4f55\u6587\u672c\u5206\u7c7b\u6570\u636e\u96c6\u3002 \u8fd9\u4e2a\u6570\u636e\u96c6\u975e\u5e38\u5bb9\u6613\u7406\u89e3\u3002\u4e00\u7bc7\u8bc4\u8bba\u5bf9\u5e94\u4e00\u4e2a\u76ee\u6807\u53d8\u91cf\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5199\u7684\u662f\u8bc4\u8bba\u800c\u4e0d\u662f\u53e5\u5b50\u3002\u8bc4\u8bba\u5c31\u662f\u4e00\u5806\u53e5\u5b50\u3002\u6240\u4ee5\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u4f60\u4e00\u5b9a\u53ea\u770b\u5230\u4e86\u5bf9\u5355\u53e5\u7684\u5206\u7c7b\uff0c\u4f46\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u6211\u4eec\u5c06\u5bf9\u591a\u4e2a\u53e5\u5b50\u8fdb\u884c\u5206\u7c7b\u3002\u7b80\u5355\u5730\u8bf4\uff0c\u8fd9\u610f\u5473\u7740\u4e0d\u4ec5\u4e00\u4e2a\u53e5\u5b50\u4f1a\u5bf9\u60c5\u611f\u4ea7\u751f\u5f71\u54cd\uff0c\u800c\u4e14\u60c5\u611f\u5f97\u5206\u662f\u591a\u4e2a\u53e5\u5b50\u5f97\u5206\u7684\u7ec4\u5408\u3002\u6570\u636e\u7b80\u4ecb\u5982\u56fe 1 \u6240\u793a\u3002 
\u5982\u4f55\u7740\u624b\u89e3\u51b3\u8fd9\u6837\u7684\u95ee\u9898\uff1f\u4e00\u4e2a\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u624b\u5de5\u5236\u4f5c\u4e24\u4efd\u5355\u8bcd\u8868\u3002\u4e00\u4e2a\u5217\u8868\u5305\u542b\u4f60\u80fd\u60f3\u8c61\u5230\u7684\u6240\u6709\u6b63\u9762\u8bcd\u6c47\uff0c\u4f8b\u5982\u597d\u3001\u68d2\u3001\u597d\u7b49\uff1b\u53e6\u4e00\u4e2a\u5217\u8868\u5305\u542b\u6240\u6709\u8d1f\u9762\u8bcd\u6c47\uff0c\u4f8b\u5982\u574f\u3001\u6076\u7b49\u3002\u6211\u4eec\u5148\u4e0d\u8981\u4e3e\u4f8b\u8bf4\u660e\u574f\u8bcd\uff0c\u5426\u5219\u8fd9\u672c\u4e66\u5c31\u53ea\u80fd\u4f9b 18 \u5c81\u4ee5\u4e0a\u7684\u4eba\u9605\u8bfb\u4e86\u3002\u4e00\u65e6\u4f60\u6709\u4e86\u8fd9\u4e9b\u5217\u8868\uff0c\u4f60\u751a\u81f3\u4e0d\u9700\u8981\u4e00\u4e2a\u6a21\u578b\u6765\u8fdb\u884c\u9884\u6d4b\u3002\u8fd9\u4e9b\u5217\u8868\u4e5f\u88ab\u79f0\u4e3a\u60c5\u611f\u8bcd\u5178\u3002\u4f60\u53ef\u4ee5\u7528\u4e00\u4e2a\u7b80\u5355\u7684\u8ba1\u6570\u5668\u6765\u8ba1\u7b97\u53e5\u5b50\u4e2d\u6b63\u9762\u548c\u8d1f\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u3002\u5982\u679c\u6b63\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u8f83\u591a\uff0c\u5219\u8868\u793a\u8be5\u53e5\u5b50\u5177\u6709\u6b63\u9762\u60c5\u611f\uff1b\u5982\u679c\u8d1f\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u8f83\u591a\uff0c\u5219\u8868\u793a\u8be5\u53e5\u5b50\u5177\u6709\u8d1f\u9762\u60c5\u611f\u3002\u5982\u679c\u53e5\u5b50\u4e2d\u6ca1\u6709\u8fd9\u4e9b\u8bcd\uff0c\u5219\u53ef\u4ee5\u8bf4\u8be5\u53e5\u5b50\u5177\u6709\u4e2d\u6027\u60c5\u611f\u3002\u8fd9\u662f\u6700\u53e4\u8001\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u73b0\u5728\u4ecd\u6709\u4eba\u5728\u4f7f\u7528\u3002\u5b83\u4e5f\u4e0d\u9700\u8981\u592a\u591a\u4ee3\u7801\u3002 def find_sentiment ( sentence , pos , neg ): sentence = sentence . split () sentence = set ( sentence ) num_common_pos = len ( sentence . intersection ( pos )) num_common_neg = len ( sentence . 
intersection ( neg )) if num_common_pos > num_common_neg : return \"positive\" if num_common_pos < num_common_neg : return \"negative\" return \"neutral\" \u4e0d\u8fc7\uff0c\u8fd9\u79cd\u65b9\u6cd5\u8003\u8651\u7684\u56e0\u7d20\u5e76\u4e0d\u591a\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u6211\u4eec\u7684 split() \u4e5f\u5e76\u4e0d\u5b8c\u7f8e\u3002\u5982\u679c\u4f7f\u7528 split()\uff0c\u5c31\u4f1a\u51fa\u73b0\u8fd9\u6837\u7684\u53e5\u5b50\uff1a \"hi, how are you?\" \u7ecf\u8fc7\u5206\u5272\u540e\u53d8\u4e3a\uff1a [\"hi,\", \"how\",\"are\",\"you?\"] \u8fd9\u79cd\u65b9\u6cd5\u5e76\u4e0d\u7406\u60f3\uff0c\u56e0\u4e3a\u5355\u8bcd\u4e2d\u5305\u542b\u4e86\u9017\u53f7\u548c\u95ee\u53f7\uff0c\u5b83\u4eec\u5e76\u6ca1\u6709\u88ab\u5206\u5272\u3002\u56e0\u6b64\uff0c\u5982\u679c\u6ca1\u6709\u5728\u5206\u5272\u524d\u5bf9\u8fd9\u4e9b\u7279\u6b8a\u5b57\u7b26\u8fdb\u884c\u9884\u5904\u7406\uff0c\u4e0d\u5efa\u8bae\u4f7f\u7528\u8fd9\u79cd\u65b9\u6cd5\u3002\u5c06\u5b57\u7b26\u4e32\u62c6\u5206\u4e3a\u5355\u8bcd\u5217\u8868\u79f0\u4e3a\u6807\u8bb0\u5316\u3002\u6700\u6d41\u884c\u7684\u6807\u8bb0\u5316\u65b9\u6cd5\u4e4b\u4e00\u6765\u81ea NLTK\uff08\u81ea\u7136\u8bed\u8a00\u5de5\u5177\u5305\uff09 \u3002 In [ X ]: from nltk.tokenize import word_tokenize In [ X ]: sentence = \"hi, how are you?\" In [ X ]: sentence . split () Out [ X ]: [ 'hi,' , 'how' , 'are' , 'you?' ] In [ X ]: word_tokenize ( sentence ) Out [ X ]: [ 'hi' , ',' , 'how' , 'are' , 'you' , '?' 
] \u6b63\u5982\u60a8\u6240\u770b\u5230\u7684\uff0c\u4f7f\u7528 NLTK \u7684\u5355\u8bcd\u6807\u8bb0\u5316\u529f\u80fd\uff0c\u540c\u4e00\u4e2a\u53e5\u5b50\u7684\u62c6\u5206\u6548\u679c\u8981\u597d\u5f97\u591a\u3002\u4f7f\u7528\u5355\u8bcd\u5217\u8868\u8fdb\u884c\u5bf9\u6bd4\u7684\u6548\u679c\u4e5f\u4f1a\u66f4\u597d\uff01\u8fd9\u5c31\u662f\u6211\u4eec\u5c06\u5e94\u7528\u4e8e\u7b2c\u4e00\u4e2a\u60c5\u611f\u68c0\u6d4b\u6a21\u578b\u7684\u65b9\u6cd5\u3002 \u5728\u5904\u7406 NLP \u5206\u7c7b\u95ee\u9898\u65f6\uff0c\u60a8\u5e94\u8be5\u7ecf\u5e38\u5c1d\u8bd5\u7684\u57fa\u672c\u6a21\u578b\u4e4b\u4e00\u662f \u8bcd\u888b\u6a21\u578b\uff08bag of words\uff09 \u3002\u5728\u8bcd\u888b\u6a21\u578b\u4e2d\uff0c\u6211\u4eec\u521b\u5efa\u4e00\u4e2a\u5de8\u5927\u7684\u7a00\u758f\u77e9\u9635\uff0c\u5b58\u50a8\u8bed\u6599\u5e93\uff08\u8bed\u6599\u5e93=\u6240\u6709\u6587\u6863=\u6240\u6709\u53e5\u5b50\uff09\u4e2d\u6240\u6709\u5355\u8bcd\u7684\u8ba1\u6570\u3002\u4e3a\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 CountVectorizer\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b83\u662f\u5982\u4f55\u5de5\u4f5c\u7684\u3002 from sklearn.feature_extraction.text import CountVectorizer corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] ctv = CountVectorizer () ctv . fit ( corpus ) corpus_transformed = ctv . 
transform ( corpus ) \u5982\u679c\u6211\u4eec\u6253\u5370 corpus_transformed\uff0c\u5c31\u4f1a\u5f97\u5230\u7c7b\u4f3c\u4e0b\u9762\u7684\u7ed3\u679c\uff1a ( 0 , 2 ) 1 ( 0 , 9 ) 1 ( 0 , 11 ) 1 ( 0 , 22 ) 1 ( 1 , 1 ) 1 ( 1 , 3 ) 1 ( 1 , 4 ) 1 ( 1 , 7 ) 1 ( 1 , 8 ) 1 ( 1 , 10 ) 1 ( 1 , 13 ) 1 ( 1 , 17 ) 1 ( 1 , 19 ) 1 ( 1 , 22 ) 2 ( 2 , 0 ) 1 ( 2 , 5 ) 1 ( 2 , 6 ) 1 ( 2 , 14 ) 1 ( 2 , 22 ) 1 ( 3 , 12 ) 1 ( 3 , 15 ) 1 ( 3 , 16 ) 1 ( 3 , 18 ) 1 ( 3 , 20 ) 1 ( 4 , 21 ) 1 \u5728\u524d\u9762\u7684\u7ae0\u8282\u4e2d\uff0c\u6211\u4eec\u5df2\u7ecf\u89c1\u8bc6\u8fc7\u8fd9\u79cd\u8868\u793a\u6cd5\u3002\u5373\u7a00\u758f\u8868\u793a\u6cd5\u3002\u56e0\u6b64\uff0c\u8bed\u6599\u5e93\u73b0\u5728\u662f\u4e00\u4e2a\u7a00\u758f\u77e9\u9635\uff0c\u5176\u4e2d\u7b2c\u4e00\u4e2a\u6837\u672c\u6709 4 \u4e2a\u5143\u7d20\uff0c\u7b2c\u4e8c\u4e2a\u6837\u672c\u6709 10 \u4e2a\u5143\u7d20\uff0c\u4ee5\u6b64\u7c7b\u63a8\uff0c\u7b2c\u4e09\u4e2a\u6837\u672c\u6709 5 \u4e2a\u5143\u7d20\uff0c\u4ee5\u6b64\u7c7b\u63a8\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u4e9b\u5143\u7d20\u90fd\u6709\u76f8\u5173\u7684\u8ba1\u6570\u3002\u6709\u4e9b\u5143\u7d20\u4f1a\u51fa\u73b0\u4e24\u6b21\uff0c\u6709\u4e9b\u5219\u53ea\u6709\u4e00\u6b21\u3002\u4f8b\u5982\uff0c\u5728\u6837\u672c 2\uff08\u7b2c 1 \u884c\uff09\u4e2d\uff0c\u6211\u4eec\u770b\u5230\u7b2c 22 \u5217\u7684\u6570\u503c\u662f 2\u3002\u8fd9\u662f\u4e3a\u4ec0\u4e48\u5462\uff1f\u7b2c 22 \u5217\u662f\u4ec0\u4e48\uff1f CountVectorizer \u7684\u5de5\u4f5c\u65b9\u5f0f\u662f\u9996\u5148\u5bf9\u53e5\u5b50\u8fdb\u884c\u6807\u8bb0\u5316\u5904\u7406\uff0c\u7136\u540e\u4e3a\u6bcf\u4e2a\u6807\u8bb0\u8d4b\u503c\u3002\u56e0\u6b64\uff0c\u6bcf\u4e2a\u6807\u8bb0\u90fd\u7531\u4e00\u4e2a\u552f\u4e00\u7d22\u5f15\u8868\u793a\u3002\u8fd9\u4e9b\u552f\u4e00\u7d22\u5f15\u5c31\u662f\u6211\u4eec\u770b\u5230\u7684\u5217\u3002CountVectorizer \u4f1a\u5b58\u50a8\u8fd9\u4e9b\u4fe1\u606f\u3002 print ( ctv . 
vocabulary_ ) { 'hello' : 9 , 'how' : 11 , 'are' : 2 , 'you' : 22 , 'im' : 13 , 'getting' : 8 , 'bored' : 4 , 'at' : 3 , 'home' : 10 , 'and' : 1 , 'what' : 19 , 'do' : 7 , 'think' : 17 , 'did' : 6 , 'know' : 14 , 'about' : 0 , 'counts' : 5 , 'let' : 15 , 'see' : 16 , 'if' : 12 , 'this' : 18 , 'works' : 20 , 'yes' : 21 } \u6211\u4eec\u770b\u5230\uff0c\u7d22\u5f15 22 \u5c5e\u4e8e \"you\"\uff0c\u800c\u5728\u7b2c\u4e8c\u53e5\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u4e86\u4e24\u6b21 \"you\"\u3002\u6211\u5e0c\u671b\u5927\u5bb6\u73b0\u5728\u5df2\u7ecf\u6e05\u695a\u4ec0\u4e48\u662f\u8bcd\u888b\u4e86\u3002\u4f46\u662f\u6211\u4eec\u8fd8\u7f3a\u5c11\u4e00\u4e9b\u7279\u6b8a\u5b57\u7b26\u3002\u6709\u65f6\uff0c\u8fd9\u4e9b\u7279\u6b8a\u5b57\u7b26\u4e5f\u5f88\u6709\u7528\u3002\u4f8b\u5982\uff0c\"? \"\u5728\u5927\u591a\u6570\u53e5\u5b50\u4e2d\u8868\u793a\u7591\u95ee\u53e5\u3002\u8ba9\u6211\u4eec\u628a scikit-learn \u7684 word_tokenize \u6574\u5408\u5230 CountVectorizer \u4e2d\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002 from sklearn.feature_extraction.text import CountVectorizer from nltk.tokenize import word_tokenize corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] ctv = CountVectorizer ( tokenizer = word_tokenize , token_pattern = None ) ctv . fit ( corpus ) corpus_transformed = ctv . transform ( corpus ) print ( ctv . vocabulary_ ) \u8fd9\u6837\uff0c\u6211\u4eec\u7684\u8bcd\u888b\u5c31\u53d8\u6210\u4e86\uff1a { 'hello' : 14 , ',' : 2 , 'how' : 16 , 'are' : 7 , 'you' : 27 , '?' : 4 , 'im' : 18 , 'getting' : 13 , 'bored' : 9 , 'at' : 8 , 'home' : 15 , '.' : 3 , 'and' : 6 , 'what' : 24 , 'do' : 12 , 'think' : 22 , 'did' : 11 , 'know' : 19 , 'about' : 5 , 'counts' : 10 , 'let' : 20 , \"'s\" : 1 , 'see' : 21 , 'if' : 17 , 'this' : 23 , 'works' : 25 , '!' 
: 0 , 'yes' : 26 } \u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u5229\u7528 IMDB \u6570\u636e\u96c6\u4e2d\u7684\u6240\u6709\u53e5\u5b50\u521b\u5efa\u4e00\u4e2a\u7a00\u758f\u77e9\u9635\uff0c\u5e76\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u3002\u8be5\u6570\u636e\u96c6\u4e2d\u6b63\u8d1f\u6837\u672c\u7684\u6bd4\u4f8b\u4e3a 1:1\uff0c\u56e0\u6b64\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u51c6\u786e\u7387\u4f5c\u4e3a\u8861\u91cf\u6807\u51c6\u3002\u6211\u4eec\u5c06\u4f7f\u7528 StratifiedKFold \u5e76\u521b\u5efa\u4e00\u4e2a\u811a\u672c\u6765\u8bad\u7ec35\u4e2a\u6298\u53e0\u3002\u4f60\u4f1a\u95ee\u4f7f\u7528\u54ea\u4e2a\u6a21\u578b\uff1f\u5bf9\u4e8e\u9ad8\u7ef4\u7a00\u758f\u6570\u636e\uff0c\u54ea\u4e2a\u6a21\u578b\u6700\u5feb\uff1f\u903b\u8f91\u56de\u5f52\u3002\u6211\u4eec\u5c06\u9996\u5148\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u6765\u5904\u7406\u8fd9\u4e2a\u6570\u636e\u96c6\uff0c\u5e76\u521b\u5efa\u7b2c\u4e00\u4e2a\u57fa\u51c6\u6a21\u578b\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import CountVectorizer if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df [ \"kfold\" ] = - 1 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f for fold_ in range ( 5 ): train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) count_vec = CountVectorizer ( tokenizer = word_tokenize , token_pattern = None ) count_vec . fit ( train_df . review ) xtrain = count_vec . transform ( train_df . 
review ) xtest = count_vec . transform ( test_df . review ) model = linear_model . LogisticRegression () model . fit ( xtrain , train_df . sentiment ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( test_df . sentiment , preds ) print ( f \"Fold: { fold_ } \" ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u8fd9\u6bb5\u4ee3\u7801\u7684\u8fd0\u884c\u9700\u8981\u4e00\u5b9a\u7684\u65f6\u95f4\uff0c\u4f46\u53ef\u4ee5\u5f97\u5230\u4ee5\u4e0b\u8f93\u51fa\u7ed3\u679c\uff1a Fold : 0 Accuracy = 0.8903 Fold : 1 Accuracy = 0.897 Fold : 2 Accuracy = 0.891 Fold : 3 Accuracy = 0.8914 Fold : 4 Accuracy = 0.8931 \u54c7\uff0c\u51c6\u786e\u7387\u5df2\u7ecf\u8fbe\u5230 89%\uff0c\u800c\u6211\u4eec\u6240\u505a\u7684\u53ea\u662f\u4f7f\u7528\u8bcd\u888b\u548c\u903b\u8f91\u56de\u5f52\uff01\u8fd9\u771f\u662f\u592a\u68d2\u4e86\uff01\u4e0d\u8fc7\uff0c\u8fd9\u4e2a\u6a21\u578b\u7684\u8bad\u7ec3\u82b1\u8d39\u4e86\u5f88\u591a\u65f6\u95f4\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u80fd\u5426\u901a\u8fc7\u4f7f\u7528\u6734\u7d20\u8d1d\u53f6\u65af\u5206\u7c7b\u5668\u6765\u7f29\u77ed\u8bad\u7ec3\u65f6\u95f4\u3002\u6734\u7d20\u8d1d\u53f6\u65af\u5206\u7c7b\u5668\u5728 NLP \u4efb\u52a1\u4e2d\u76f8\u5f53\u6d41\u884c\uff0c\u56e0\u4e3a\u7a00\u758f\u77e9\u9635\u975e\u5e38\u5e9e\u5927\uff0c\u800c\u6734\u7d20\u8d1d\u53f6\u65af\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6a21\u578b\u3002\u8981\u4f7f\u7528\u8fd9\u4e2a\u6a21\u578b\uff0c\u9700\u8981\u66f4\u6539\u4e00\u4e2a\u5bfc\u5165\u548c\u6a21\u578b\u7684\u884c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e2a\u6a21\u578b\u7684\u6027\u80fd\u5982\u4f55\u3002\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 MultinomialNB\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import naive_bayes from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import CountVectorizer model = naive_bayes . MultinomialNB () model . fit ( xtrain , train_df . 
sentiment ) \u5f97\u5230\u5982\u4e0b\u7ed3\u679c\uff1a Fold : 0 Accuracy = 0.8444 Fold : 1 Accuracy = 0.8499 Fold : 2 Accuracy = 0.8422 Fold : 3 Accuracy = 0.8443 Fold : 4 Accuracy = 0.8455 \u6211\u4eec\u770b\u5230\u8fd9\u4e2a\u5206\u6570\u5f88\u4f4e\u3002\u4f46\u6734\u7d20\u8d1d\u53f6\u65af\u6a21\u578b\u7684\u901f\u5ea6\u975e\u5e38\u5feb\u3002 NLP \u4e2d\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f TF-IDF\uff0c\u5982\u4eca\u5927\u591a\u6570\u4eba\u90fd\u503e\u5411\u4e8e\u5ffd\u7565\u6216\u4e0d\u5c51\u4e8e\u4e86\u89e3\u8fd9\u79cd\u65b9\u6cd5\u3002TF \u662f\u672f\u8bed\u9891\u7387\uff0cIDF \u662f\u53cd\u5411\u6587\u6863\u9891\u7387\u3002\u4ece\u8fd9\u4e9b\u672f\u8bed\u6765\u770b\uff0c\u8fd9\u4f3c\u4e4e\u6709\u4e9b\u56f0\u96be\uff0c\u4f46\u901a\u8fc7 TF \u548c IDF \u7684\u8ba1\u7b97\u516c\u5f0f\uff0c\u4e8b\u60c5\u5c31\u4f1a\u53d8\u5f97\u5f88\u660e\u663e\u3002 $$ TF(t) = \\frac{Number\\ of\\ times\\ a\\ term\\ t\\ appears\\ in\\ a\\ document}{Total\\ number\\ of\\ terms\\ in \\ the\\ document} $$ \\[ IDF(t) = LOG\\left(\\frac{Total\\ number\\ of\\ documents}{Number\\ of\\ documents with\\ term\\ t\\ in\\ it}\\right) \\] \u672f\u8bed t \u7684 TF-IDF \u5b9a\u4e49\u4e3a\uff1a $$ TF-IDF(t) = TF(t) \\times IDF(t) $$ \u4e0e scikit-learn \u4e2d\u7684 CountVectorizer \u7c7b\u4f3c\uff0c\u6211\u4eec\u4e5f\u6709 TfidfVectorizer\u3002\u8ba9\u6211\u4eec\u8bd5\u7740\u50cf\u4f7f\u7528 CountVectorizer \u4e00\u6837\u4f7f\u7528\u5b83\u3002 from sklearn.feature_extraction.text import TfidfVectorizer from nltk.tokenize import word_tokenize corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] tfv = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfv . fit ( corpus ) corpus_transformed = tfv . 
transform ( corpus ) print ( corpus_transformed ) \u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a ( 0 , 27 ) 0.2965698850220162 ( 0 , 16 ) 0.4428321995085722 ( 0 , 14 ) 0.4428321995085722 ( 0 , 7 ) 0.4428321995085722 ( 0 , 4 ) 0.35727423026525224 ( 0 , 2 ) 0.4428321995085722 ( 1 , 27 ) 0.35299699146792735 ( 1 , 24 ) 0.2635440111190765 ( 1 , 22 ) 0.2635440111190765 ( 1 , 18 ) 0.2635440111190765 ( 1 , 15 ) 0.2635440111190765 ( 1 , 13 ) 0.2635440111190765 ( 1 , 12 ) 0.2635440111190765 ( 1 , 9 ) 0.2635440111190765 ( 1 , 8 ) 0.2635440111190765 ( 1 , 6 ) 0.2635440111190765 ( 1 , 4 ) 0.42525129752567803 ( 1 , 3 ) 0.2635440111190765 ( 2 , 27 ) 0.31752680284846835 ( 2 , 19 ) 0.4741246485558491 ( 2 , 11 ) 0.4741246485558491 ( 2 , 10 ) 0.4741246485558491 ( 2 , 5 ) 0.4741246485558491 ( 3 , 25 ) 0.38775666010579296 ( 3 , 23 ) 0.38775666010579296 ( 3 , 21 ) 0.38775666010579296 ( 3 , 20 ) 0.38775666010579296 ( 3 , 17 ) 0.38775666010579296 ( 3 , 1 ) 0.38775666010579296 ( 3 , 0 ) 0.3128396318588854 ( 4 , 26 ) 0.2959842226518677 ( 4 , 0 ) 0.9551928286692534 \u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u6b21\u6211\u4eec\u5f97\u5230\u7684\u4e0d\u662f\u6574\u6570\u503c\uff0c\u800c\u662f\u6d6e\u70b9\u6570\u3002 \u7528 TfidfVectorizer \u4ee3\u66ff CountVectorizer \u4e5f\u662f\u5c0f\u83dc\u4e00\u789f\u3002Scikit-learn \u8fd8\u63d0\u4f9b\u4e86 TfidfTransformer\u3002\u5982\u679c\u4f60\u4f7f\u7528\u7684\u662f\u8ba1\u6570\u503c\uff0c\u53ef\u4ee5\u4f7f\u7528 TfidfTransformer \u5e76\u83b7\u5f97\u4e0e TfidfVectorizer \u76f8\u540c\u7684\u6548\u679c\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import TfidfVectorizer for fold_ in range ( 5 ): train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . 
reset_index ( drop = True ) tfidf_vec = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfidf_vec . fit ( train_df . review ) xtrain = tfidf_vec . transform ( train_df . review ) xtest = tfidf_vec . transform ( test_df . review ) model = linear_model . LogisticRegression () model . fit ( xtrain , train_df . sentiment ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( test_df . sentiment , preds ) print ( f \"Fold: { fold_ } \" ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u6211\u4eec\u53ef\u4ee5\u770b\u770b TF-IDF \u5728\u903b\u8f91\u56de\u5f52\u6a21\u578b\u4e0a\u7684\u8868\u73b0\u5982\u4f55\u3002 Fold : 0 Accuracy = 0.8976 Fold : 1 Accuracy = 0.8998 Fold : 2 Accuracy = 0.8948 Fold : 3 Accuracy = 0.8912 Fold : 4 Accuracy = 0.8995 \u6211\u4eec\u770b\u5230\uff0c\u8fd9\u4e9b\u5206\u6570\u90fd\u6bd4 CountVectorizer \u9ad8\u4e00\u4e9b\uff0c\u56e0\u6b64\u5b83\u6210\u4e3a\u4e86\u6211\u4eec\u60f3\u8981\u51fb\u8d25\u7684\u65b0\u57fa\u51c6\u3002 NLP \u4e2d\u53e6\u4e00\u4e2a\u6709\u8da3\u7684\u6982\u5ff5\u662f N-gram\u3002N-grams \u662f\u6309\u987a\u5e8f\u6392\u5217\u7684\u5355\u8bcd\u7ec4\u5408\u3002N-grams \u5f88\u5bb9\u6613\u521b\u5efa\u3002\u60a8\u53ea\u9700\u6ce8\u610f\u987a\u5e8f\u5373\u53ef\u3002\u4e3a\u4e86\u8ba9\u4e8b\u60c5\u53d8\u5f97\u66f4\u7b80\u5355\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 NLTK \u7684 N-gram \u5b9e\u73b0\u3002 from nltk import ngrams from nltk.tokenize import word_tokenize N = 3 sentence = \"hi, how are you?\" tokenized_sentence = word_tokenize ( sentence ) n_grams = list ( ngrams ( tokenized_sentence , N )) print ( n_grams ) \u7531\u6b64\u5f97\u5230\uff1a [( 'hi' , ',' , 'how' ), ( ',' , 'how' , 'are' ), ( 'how' , 'are' , 'you' ), ( 'are' , 'you' , '?' 
)] \u540c\u6837\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u521b\u5efa 2-gram \u6216 4-gram \u7b49\u3002\u73b0\u5728\uff0c\u8fd9\u4e9b n-gram \u5c06\u6210\u4e3a\u6211\u4eec\u8bcd\u6c47\u8868\u7684\u4e00\u90e8\u5206\uff0c\u5f53\u6211\u4eec\u8ba1\u7b97\u8ba1\u6570\u6216 tf-idf \u65f6\uff0c\u6211\u4eec\u4f1a\u5c06\u4e00\u4e2a n-gram \u89c6\u4e3a\u4e00\u4e2a\u5168\u65b0\u7684\u6807\u8bb0\u3002\u56e0\u6b64\uff0c\u5728\u67d0\u79cd\u7a0b\u5ea6\u4e0a\uff0c\u6211\u4eec\u662f\u5728\u7ed3\u5408\u4e0a\u4e0b\u6587\u3002scikit-learn \u7684 CountVectorizer \u548c TfidfVectorizer \u5b9e\u73b0\u90fd\u901a\u8fc7 ngram_range \u53c2\u6570\u63d0\u4f9b n-gram\uff0c\u8be5\u53c2\u6570\u6709\u6700\u5c0f\u548c\u6700\u5927\u9650\u5236\u3002\u9ed8\u8ba4\u60c5\u51b5\u4e0b\uff0c\u8be5\u53c2\u6570\u4e3a\uff081, 1\uff09\u3002\u5f53\u6211\u4eec\u5c06\u5176\u6539\u4e3a (1, 3) \u65f6\uff0c\u6211\u4eec\u5c06\u770b\u5230\u5355\u5b57\u5143\u3001\u53cc\u5b57\u5143\u548c\u4e09\u5b57\u5143\u3002\u4ee3\u7801\u6539\u52a8\u5f88\u5c0f\u3002 \u7531\u4e8e\u5230\u76ee\u524d\u4e3a\u6b62\u6211\u4eec\u4f7f\u7528 tf-idf \u5f97\u5230\u4e86\u6700\u597d\u7684\u7ed3\u679c\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5305\u542b n-grams \u76f4\u81f3 trigrams \u662f\u5426\u80fd\u6539\u8fdb\u6a21\u578b\u3002\u552f\u4e00\u9700\u8981\u4fee\u6539\u7684\u662f TfidfVectorizer \u7684\u521d\u59cb\u5316\u3002 tfidf_vec = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None , ngram_range = ( 1 , 3 ) ) \u8ba9\u6211\u4eec\u770b\u770b\u662f\u5426\u4f1a\u6709\u6539\u8fdb\u3002 Fold : 0 Accuracy = 0.8931 Fold : 1 Accuracy = 0.8941 Fold : 2 Accuracy = 0.897 Fold : 3 Accuracy = 0.8922 Fold : 4 Accuracy = 0.8847 \u770b\u8d77\u6765\u8fd8\u884c\uff0c\u4f46\u6211\u4eec\u770b\u4e0d\u5230\u4efb\u4f55\u6539\u8fdb\u3002 \u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u591a\u4f7f\u7528 bigrams \u6765\u83b7\u5f97\u6539\u8fdb\u3002 
\u6211\u4e0d\u4f1a\u5728\u8fd9\u91cc\u5c55\u793a\u8fd9\u4e00\u90e8\u5206\u3002\u4e5f\u8bb8\u4f60\u53ef\u4ee5\u81ea\u5df1\u8bd5\u7740\u505a\u3002 NLP \u7684\u57fa\u7840\u77e5\u8bc6\u8fd8\u6709\u5f88\u591a\u3002\u4f60\u5fc5\u987b\u77e5\u9053\u7684\u4e00\u4e2a\u672f\u8bed\u662f\u8bcd\u5e72\u63d0\u53d6\uff08strmming\uff09\u3002\u53e6\u4e00\u4e2a\u662f\u8bcd\u5f62\u8fd8\u539f\uff08lemmatization\uff09\u3002 \u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f \u53ef\u4ee5\u5c06\u4e00\u4e2a\u8bcd\u51cf\u5c11\u5230\u6700\u5c0f\u5f62\u5f0f\u3002\u5728\u8bcd\u5e72\u63d0\u53d6\u7684\u60c5\u51b5\u4e0b\uff0c\u5904\u7406\u540e\u7684\u5355\u8bcd\u79f0\u4e3a\u8bcd\u5e72\u5355\u8bcd\uff0c\u800c\u5728\u8bcd\u5f62\u8fd8\u539f\u60c5\u51b5\u4e0b\uff0c\u5904\u7406\u540e\u7684\u5355\u8bcd\u79f0\u4e3a\u8bcd\u5f62\u3002\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u8bcd\u5f62\u8fd8\u539f\u6bd4\u8bcd\u5e72\u63d0\u53d6\u66f4\u6fc0\u8fdb\uff0c\u800c\u8bcd\u5e72\u63d0\u53d6\u66f4\u6d41\u884c\u548c\u5e7f\u6cdb\u3002\u8bcd\u5e72\u548c\u8bcd\u5f62\u90fd\u6765\u81ea\u8bed\u8a00\u5b66\u3002\u5982\u679c\u4f60\u6253\u7b97\u4e3a\u67d0\u79cd\u8bed\u8a00\u5236\u4f5c\u8bcd\u5e72\u6216\u8bcd\u578b\uff0c\u9700\u8981\u5bf9\u8be5\u8bed\u8a00\u6709\u6df1\u5165\u7684\u4e86\u89e3\u3002\u5982\u679c\u8981\u8fc7\u591a\u5730\u4ecb\u7ecd\u8fd9\u4e9b\u77e5\u8bc6\uff0c\u5c31\u610f\u5473\u7740\u8981\u5728\u672c\u4e66\u4e2d\u589e\u52a0\u4e00\u7ae0\u3002\u4f7f\u7528 NLTK \u8f6f\u4ef6\u5305\u53ef\u4ee5\u8f7b\u677e\u5b8c\u6210\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e24\u79cd\u65b9\u6cd5\u7684\u4e00\u4e9b\u793a\u4f8b\u3002\u6709\u8bb8\u591a\u4e0d\u540c\u7c7b\u578b\u7684\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u5668\u3002\u6211\u5c06\u7528\u6700\u5e38\u89c1\u7684 Snowball Stemmer \u548c WordNet Lemmatizer \u6765\u4e3e\u4f8b\u8bf4\u660e\u3002 from nltk.stem import WordNetLemmatizer from nltk.stem.snowball import SnowballStemmer lemmatizer = 
WordNetLemmatizer () stemmer = SnowballStemmer ( \"english\" ) words = [ \"fishing\" , \"fishes\" , \"fished\" ] for word in words : print ( f \"word= { word } \" ) print ( f \"stemmed_word= { stemmer . stem ( word ) } \" ) print ( f \"lemma= { lemmatizer . lemmatize ( word ) } \" ) print ( \"\" ) \u8fd9\u5c06\u6253\u5370\uff1a word = fishing stemmed_word = fish lemma = fishing word = fishes stemmed_word = fish lemma = fish word = fished stemmed_word = fish lemma = fished \u6b63\u5982\u60a8\u6240\u770b\u5230\u7684\uff0c\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u662f\u622a\u7136\u4e0d\u540c\u7684\u3002\u5f53\u6211\u4eec\u8fdb\u884c\u8bcd\u5e72\u63d0\u53d6\u65f6\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f\u4e00\u4e2a\u8bcd\u7684\u6700\u5c0f\u5f62\u5f0f\uff0c\u5b83\u53ef\u80fd\u662f\u4e5f\u53ef\u80fd\u4e0d\u662f\u8be5\u8bcd\u6240\u5c5e\u8bed\u8a00\u8bcd\u5178\u4e2d\u7684\u4e00\u4e2a\u8bcd\u3002\u4f46\u662f\uff0c\u5728\u8bcd\u5f62\u8fd8\u539f\u60c5\u51b5\u4e0b\uff0c\u8fd9\u5c06\u662f\u4e00\u4e2a\u8bcd\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u81ea\u5df1\u5c1d\u8bd5\u6dfb\u52a0\u8bcd\u5e72\u548c\u8bcd\u7d20\u5316\uff0c\u770b\u770b\u662f\u5426\u80fd\u6539\u5584\u7ed3\u679c\u3002 \u60a8\u8fd8\u5e94\u8be5\u4e86\u89e3\u7684\u4e00\u4e2a\u4e3b\u9898\u662f\u4e3b\u9898\u63d0\u53d6\u3002 \u4e3b\u9898\u63d0\u53d6 \u53ef\u4ee5\u4f7f\u7528\u975e\u8d1f\u77e9\u9635\u56e0\u5f0f\u5206\u89e3\uff08NMF\uff09\u6216\u6f5c\u5728\u8bed\u4e49\u5206\u6790\uff08LSA\uff09\u6765\u5b8c\u6210\uff0c\u540e\u8005\u4e5f\u88ab\u79f0\u4e3a\u5947\u5f02\u503c\u5206\u89e3\u6216 SVD\u3002\u8fd9\u4e9b\u5206\u89e3\u6280\u672f\u53ef\u5c06\u6570\u636e\u7b80\u5316\u4e3a\u7ed9\u5b9a\u6570\u91cf\u7684\u6210\u5206\u3002 \u60a8\u53ef\u4ee5\u5728\u4ece CountVectorizer \u6216 TfidfVectorizer \u4e2d\u83b7\u5f97\u7684\u7a00\u758f\u77e9\u9635\u4e0a\u5e94\u7528\u5176\u4e2d\u4efb\u4f55\u4e00\u79cd\u6280\u672f\u3002 \u8ba9\u6211\u4eec\u628a\u5b83\u5e94\u7528\u5230\u4e4b\u524d\u4f7f\u7528\u8fc7\u7684 TfidfVetorizer 
\u4e0a\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import decomposition from sklearn.feature_extraction.text import TfidfVectorizer corpus = pd . read_csv ( \"../input/imdb.csv\" , nrows = 10000 ) corpus = corpus . review . values tfv = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfv . fit ( corpus ) corpus_transformed = tfv . transform ( corpus ) svd = decomposition . TruncatedSVD ( n_components = 10 ) corpus_svd = svd . fit ( corpus_transformed ) sample_index = 0 feature_scores = dict ( zip ( tfv . get_feature_names (), corpus_svd . components_ [ sample_index ] ) ) N = 5 print ( sorted ( feature_scores , key = feature_scores . get , reverse = True )[: N ]) \u60a8\u53ef\u4ee5\u4f7f\u7528\u5faa\u73af\u6765\u8fd0\u884c\u591a\u4e2a\u6837\u672c\u3002 N = 5 for sample_index in range ( 5 ): feature_scores = dict ( zip ( tfv . get_feature_names (), corpus_svd . components_ [ sample_index ] ) ) print ( sorted ( feature_scores , key = feature_scores . get , reverse = True )[: N ] ) \u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a [ 'the' , ',' , '.' , 'a' , 'and' ] [ 'br' , '<' , '>' , '/' , '-' ] [ 'i' , 'movie' , '!' , 'it' , 'was' ] [ ',' , '!' , \"''\" , '``' , 'you' ] [ '!' , 'the' , '...' , \"''\" , '``' ] \u4f60\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u6839\u672c\u8bf4\u4e0d\u901a\u3002\u600e\u4e48\u529e\u5462\uff1f\u8ba9\u6211\u4eec\u8bd5\u7740\u6e05\u7406\u4e00\u4e0b\uff0c\u770b\u770b\u662f\u5426\u6709\u610f\u4e49\u3002\u8981\u6e05\u7406\u4efb\u4f55\u6587\u672c\u6570\u636e\uff0c\u5c24\u5176\u662f pandas \u6570\u636e\u5e27\u4e2d\u7684\u6587\u672c\u6570\u636e\uff0c\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u51fd\u6570\u3002 import re import string def clean_text ( s ): s = s . split () s = \" \" . join ( s ) s = re . sub ( f '[ { re . escape ( string . 
punctuation ) } ]' , '' , s ) return s \u8be5\u51fd\u6570\u4f1a\u5c06 \"hi, how are you????\" \u8fd9\u6837\u7684\u5b57\u7b26\u4e32\u8f6c\u6362\u4e3a \"hi how are you\"\u3002\u8ba9\u6211\u4eec\u628a\u8fd9\u4e2a\u51fd\u6570\u5e94\u7528\u5230\u65e7\u7684 SVD \u4ee3\u7801\u4e2d\uff0c\u770b\u770b\u5b83\u662f\u5426\u80fd\u7ed9\u63d0\u53d6\u7684\u4e3b\u9898\u5e26\u6765\u63d0\u5347\u3002\u4f7f\u7528 pandas\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528 apply \u51fd\u6570\u5c06\u6e05\u7406\u4ee3\u7801 \"\u5e94\u7528 \"\u5230\u4efb\u610f\u7ed9\u5b9a\u7684\u5217\u4e2d\u3002 import pandas as pd corpus = pd . read_csv ( \"../input/imdb.csv\" , nrows = 10000 ) corpus . loc [:, \"review\" ] = corpus . review . apply ( clean_text ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u53ea\u5728\u4e3b SVD \u811a\u672c\u4e2d\u6dfb\u52a0\u4e86\u4e00\u884c\u4ee3\u7801\uff0c\u8fd9\u5c31\u662f\u4f7f\u7528\u51fd\u6570\u548c pandas \u5e94\u7528\u7684\u597d\u5904\u3002\u8fd9\u6b21\u751f\u6210\u7684\u4e3b\u9898\u5982\u4e0b\u3002 [ 'the' , 'a' , 'and' , 'of' , 'to' ] [ 'i' , 'movie' , 'it' , 'was' , 'this' ] [ 'the' , 'was' , 'i' , 'were' , 'of' ] [ 'her' , 'was' , 'she' , 'i' , 'he' ] [ 'br' , 'to' , 'they' , 'he' , 'show' ] \u547c\uff01\u81f3\u5c11\u8fd9\u6bd4\u6211\u4eec\u4e4b\u524d\u597d\u591a\u4e86\u3002\u4f46\u4f60\u77e5\u9053\u5417\uff1f\u4f60\u53ef\u4ee5\u901a\u8fc7\u5728\u6e05\u7406\u529f\u80fd\u4e2d\u5220\u9664\u505c\u6b62\u8bcd\uff08stopwords\uff09\u6765\u4f7f\u5b83\u53d8\u5f97\u66f4\u597d\u3002\u4ec0\u4e48\u662fstopwords\uff1f\u5b83\u4eec\u662f\u5b58\u5728\u4e8e\u6bcf\u79cd\u8bed\u8a00\u4e2d\u7684\u9ad8\u9891\u8bcd\u3002\u4f8b\u5982\uff0c\u5728\u82f1\u8bed\u4e2d\uff0c\u8fd9\u4e9b\u8bcd\u5305\u62ec \"a\"\u3001\"an\"\u3001\"the\"\u3001\"for \"\u7b49\u3002\u5220\u9664\u505c\u6b62\u8bcd\u5e76\u975e\u603b\u662f\u660e\u667a\u7684\u9009\u62e9\uff0c\u8fd9\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u4e1a\u52a1\u95ee\u9898\u3002\u50cf \"I need a new 
dog\"\u8fd9\u6837\u7684\u53e5\u5b50\uff0c\u53bb\u6389\u505c\u6b62\u8bcd\u540e\u4f1a\u53d8\u6210 \"need new dog\"\uff0c\u6b64\u65f6\u6211\u4eec\u4e0d\u77e5\u9053\u8c01\u9700\u8981new dog\u3002 \u5982\u679c\u6211\u4eec\u603b\u662f\u5220\u9664\u505c\u6b62\u8bcd\uff0c\u5c31\u4f1a\u4e22\u5931\u5f88\u591a\u4e0a\u4e0b\u6587\u4fe1\u606f\u3002\u4f60\u53ef\u4ee5\u5728 NLTK \u4e2d\u627e\u5230\u8bb8\u591a\u8bed\u8a00\u7684\u505c\u6b62\u8bcd\uff0c\u5982\u679c\u6ca1\u6709\uff0c\u4f60\u4e5f\u53ef\u4ee5\u5728\u81ea\u5df1\u559c\u6b22\u7684\u641c\u7d22\u5f15\u64ce\u4e0a\u5feb\u901f\u641c\u7d22\u4e00\u4e0b\u3002 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u8f6c\u5230\u5927\u591a\u6570\u4eba\u90fd\u559c\u6b22\u4f7f\u7528\u7684\u65b9\u6cd5\uff1a\u6df1\u5ea6\u5b66\u4e60\u3002\u4f46\u9996\u5148\uff0c\u6211\u4eec\u5fc5\u987b\u77e5\u9053\u4ec0\u4e48\u662f\u8bcd\u5d4c\u5165\uff08embedings for words\uff09\u3002\u4f60\u5df2\u7ecf\u770b\u5230\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u5c06\u6807\u8bb0\u8f6c\u6362\u6210\u4e86\u6570\u5b57\u3002\u56e0\u6b64\uff0c\u5982\u679c\u67d0\u4e2a\u8bed\u6599\u5e93\u4e2d\u6709 N \u4e2a\u552f\u4e00\u7684\u8bcd\u5757\uff0c\u5b83\u4eec\u53ef\u4ee5\u7528 0 \u5230 N-1 \u4e4b\u95f4\u7684\u6574\u6570\u6765\u8868\u793a\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5c06\u7528\u5411\u91cf\u6765\u8868\u793a\u8fd9\u4e9b\u6574\u6570\u8bcd\u5757\u3002\u8fd9\u79cd\u5c06\u5355\u8bcd\u8868\u793a\u6210\u5411\u91cf\u7684\u65b9\u6cd5\u88ab\u79f0\u4e3a\u5355\u8bcd\u5d4c\u5165\u6216\u5355\u8bcd\u5411\u91cf\u3002\u8c37\u6b4c\u7684 Word2Vec \u662f\u5c06\u5355\u8bcd\u8f6c\u6362\u4e3a\u5411\u91cf\u7684\u6700\u53e4\u8001\u65b9\u6cd5\u4e4b\u4e00\u3002\u6b64\u5916\uff0c\u8fd8\u6709 Facebook \u7684 FastText \u548c\u65af\u5766\u798f\u5927\u5b66\u7684 GloVe\uff08\u7528\u4e8e\u5355\u8bcd\u8868\u793a\u7684\u5168\u5c40\u5411\u91cf\uff09\u3002\u8fd9\u4e9b\u65b9\u6cd5\u5f7c\u6b64\u5927\u76f8\u5f84\u5ead\u3002 
\u5176\u57fa\u672c\u601d\u60f3\u662f\u5efa\u7acb\u4e00\u4e2a\u6d45\u5c42\u7f51\u7edc\uff0c\u901a\u8fc7\u91cd\u6784\u8f93\u5165\u53e5\u5b50\u6765\u5b66\u4e60\u5355\u8bcd\u7684\u5d4c\u5165\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u5468\u56f4\u7684\u6240\u6709\u5355\u8bcd\u6765\u8bad\u7ec3\u7f51\u7edc\u9884\u6d4b\u4e00\u4e2a\u7f3a\u5931\u7684\u5355\u8bcd\uff0c\u5728\u6b64\u8fc7\u7a0b\u4e2d\uff0c\u7f51\u7edc\u5c06\u5b66\u4e60\u5e76\u66f4\u65b0\u6240\u6709\u76f8\u5173\u5355\u8bcd\u7684\u5d4c\u5165\u3002\u8fd9\u79cd\u65b9\u6cd5\u4e5f\u88ab\u79f0\u4e3a\u8fde\u7eed\u8bcd\u888b\u6216 CBoW \u6a21\u578b\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u4e00\u4e2a\u5355\u8bcd\u6765\u9884\u6d4b\u4e0a\u4e0b\u6587\u4e2d\u7684\u5355\u8bcd\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684\u8df3\u683c\u6a21\u578b\u3002Word2Vec \u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e24\u79cd\u65b9\u6cd5\u5b66\u4e60\u5d4c\u5165\u3002 FastText \u53ef\u4ee5\u5b66\u4e60\u5b57\u7b26 n-gram \u7684\u5d4c\u5165\u3002\u548c\u5355\u8bcd n-gram \u4e00\u6837\uff0c\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7684\u662f\u5b57\u7b26\uff0c\u5219\u79f0\u4e3a\u5b57\u7b26 n-gram\uff0c\u6700\u540e\uff0cGloVe \u901a\u8fc7\u5171\u73b0\u77e9\u9635\u6765\u5b66\u4e60\u8fd9\u4e9b\u5d4c\u5165\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u8bf4\uff0c\u6240\u6709\u8fd9\u4e9b\u4e0d\u540c\u7c7b\u578b\u7684\u5d4c\u5165\u6700\u7ec8\u90fd\u4f1a\u8fd4\u56de\u4e00\u4e2a\u5b57\u5178\uff0c\u5176\u4e2d\u952e\u662f\u8bed\u6599\u5e93\uff08\u4f8b\u5982\u82f1\u8bed\u7ef4\u57fa\u767e\u79d1\uff09\u4e2d\u7684\u5355\u8bcd\uff0c\u503c\u662f\u5927\u5c0f\u4e3a N\uff08\u901a\u5e38\u4e3a 300\uff09\u7684\u5411\u91cf\u3002 \u56fe 1\uff1a\u53ef\u89c6\u5316\u4e8c\u7ef4\u5355\u8bcd\u5d4c\u5165\u3002 \u56fe 1 \u663e\u793a\u4e86\u4e8c\u7ef4\u5355\u8bcd\u5d4c\u5165\u7684\u53ef\u89c6\u5316\u6548\u679c\u3002\u5047\u8bbe\u6211\u4eec\u4ee5\u67d0\u79cd\u65b9\u5f0f\u5b8c\u6210\u4e86\u8bcd\u8bed\u7684\u4e8c\u7ef4\u8868\u793a\u3002\u56fe 1 
\u663e\u793a\uff0c\u5982\u679c\u4eceBerlin\uff08\u5fb7\u56fd\u9996\u90fd\uff09\u7684\u5411\u91cf\u4e2d\u51cf\u53bb\u5fb7\u56fd\uff08Germany\uff09\u7684\u5411\u91cf\uff0c\u518d\u52a0\u4e0a\u6cd5\u56fd\uff08france\uff09\u7684\u5411\u91cf\uff0c\u5c31\u4f1a\u5f97\u5230\u4e00\u4e2a\u63a5\u8fd1Paris\uff08\u6cd5\u56fd\u9996\u90fd\uff09\u7684\u5411\u91cf\u3002\u7531\u6b64\u53ef\u89c1\uff0c\u5d4c\u5165\u5f0f\u4e5f\u80fd\u8fdb\u884c\u7c7b\u6bd4\u3002 \u8fd9\u5e76\u4e0d\u603b\u662f\u6b63\u786e\u7684\uff0c\u4f46\u8fd9\u6837\u7684\u4f8b\u5b50\u6709\u52a9\u4e8e\u7406\u89e3\u5355\u8bcd\u5d4c\u5165\u7684\u4f5c\u7528\u3002\u50cf \"\u55e8\uff0c\u4f60\u597d\u5417 \"\u8fd9\u6837\u7684\u53e5\u5b50\u53ef\u4ee5\u7528\u4e0b\u9762\u7684\u4e00\u5806\u5411\u91cf\u6765\u8868\u793a\u3002 hi \u2500> [vector (v1) of size 300] , \u2500> [vector (v2) of size 300] how \u2500> [vector (v3) of size 300] are \u2500> [vector (v4) of size 300] you \u2500> [vector (v5) of size 300] ? \u2500> [vector (v6) of size 300] \u4f7f\u7528\u8fd9\u4e9b\u4fe1\u606f\u6709\u591a\u79cd\u65b9\u6cd5\u3002\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u4e4b\u4e00\u5c31\u662f\u4f7f\u7528\u5d4c\u5165\u5411\u91cf\u3002\u5982\u4e0a\u4f8b\u6240\u793a\uff0c\u6bcf\u4e2a\u5355\u8bcd\u90fd\u6709\u4e00\u4e2a 1x300 \u7684\u5d4c\u5165\u5411\u91cf\u3002\u5229\u7528\u8fd9\u4e9b\u4fe1\u606f\uff0c\u6211\u4eec\u53ef\u4ee5\u8ba1\u7b97\u51fa\u6574\u4e2a\u53e5\u5b50\u7684\u5d4c\u5165\u3002\u8ba1\u7b97\u65b9\u6cd5\u6709\u591a\u79cd\u3002\u5176\u4e2d\u4e00\u79cd\u65b9\u6cd5\u5982\u4e0b\u6240\u793a\u3002\u5728\u8fd9\u4e2a\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u5c06\u7ed9\u5b9a\u53e5\u5b50\u4e2d\u7684\u6240\u6709\u5355\u8bcd\u5411\u91cf\u63d0\u53d6\u51fa\u6765\uff0c\u7136\u540e\u4ece\u6240\u6709\u6807\u8bb0\u8bcd\u7684\u5355\u8bcd\u5411\u91cf\u4e2d\u521b\u5efa\u4e00\u4e2a\u5f52\u4e00\u5316\u7684\u5355\u8bcd\u5411\u91cf\u3002\u8fd9\u6837\u5c31\u5f97\u5230\u4e86\u4e00\u4e2a\u53e5\u5b50\u5411\u91cf\u3002 import numpy as np def sentence_to_vec ( s , 
embedding_dict , stop_words , tokenizer ): words = str ( s ) . lower () words = tokenizer ( words ) words = [ w for w in words if not w in stop_words ] words = [ w for w in words if w . isalpha ()] M = [] for w in words : if w in embedding_dict : M . append ( embedding_dict [ w ]) if len ( M ) == 0 : return np . zeros ( 300 ) M = np . array ( M ) v = M . sum ( axis = 0 ) return v / np . sqrt (( v ** 2 ) . sum ()) \u6211\u4eec\u53ef\u4ee5\u7528\u8fd9\u79cd\u65b9\u6cd5\u5c06\u6240\u6709\u793a\u4f8b\u8f6c\u6362\u6210\u4e00\u4e2a\u5411\u91cf\u3002\u6211\u4eec\u80fd\u5426\u4f7f\u7528 fastText \u5411\u91cf\u6765\u6539\u8fdb\u4e4b\u524d\u7684\u7ed3\u679c\uff1f\u6bcf\u7bc7\u8bc4\u8bba\u90fd\u6709 300 \u4e2a\u7279\u5f81\u3002 import io import numpy as np import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import TfidfVectorizer def load_vectors ( fname ): fin = io . open ( fname , 'r' , encoding = 'utf-8' , newline = ' \\n ' , errors = 'ignore' ) n , d = map ( int , fin . readline () . split ()) data = {} for line in fin : tokens = line . rstrip () . split ( ' ' ) data [ tokens [ 0 ]] = list ( map ( float , tokens [ 1 :])) return data def sentence_to_vec ( s , embedding_dict , stop_words , tokenizer ): if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df = df . sample ( frac = 1 ) . reset_index ( drop = True ) print ( \"Loading embeddings\" ) embeddings = load_vectors ( \"../input/crawl-300d-2M.vec\" ) print ( \"Creating sentence vectors\" ) vectors = [] for review in df . review . values : vectors . append ( sentence_to_vec ( s = review , embedding_dict = embeddings , stop_words = [], tokenizer = word_tokenize ) ) vectors = np . array ( vectors ) y = df . sentiment . values kf = model_selection . 
StratifiedKFold ( n_splits = 5 ) for fold_ , ( t_ , v_ ) in enumerate ( kf . split ( X = vectors , y = y )): print ( f \"Training fold: { fold_ } \" ) xtrain = vectors [ t_ , :] ytrain = y [ t_ ] xtest = vectors [ v_ , :] ytest = y [ v_ ] model = linear_model . LogisticRegression () model . fit ( xtrain , ytrain ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( ytest , preds ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u8fd9\u5c06\u5f97\u5230\u5982\u4e0b\u7ed3\u679c\uff1a Loading embeddings Creating sentence vectors Training fold : 0 Accuracy = 0.8619 Training fold : 1 Accuracy = 0.8661 Training fold : 2 Accuracy = 0.8544 Training fold : 3 Accuracy = 0.8624 Training fold : 4 Accuracy = 0.8595 Wow\uff01\u771f\u662f\u51fa\u4e4e\u610f\u6599\u3002\u6211\u4eec\u6240\u505a\u7684\u4e00\u5207\u90fd\u662f\u4e3a\u4e86\u4f7f\u7528 FastText \u5d4c\u5165\u3002\u8bd5\u7740\u628a\u5d4c\u5165\u5f0f\u6362\u6210 GloVe\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002\u6211\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u5927\u5bb6\u3002 \u5f53\u6211\u4eec\u8c08\u8bba\u6587\u672c\u6570\u636e\u65f6\uff0c\u6211\u4eec\u5fc5\u987b\u7262\u8bb0\u4e00\u4ef6\u4e8b\u3002\u6587\u672c\u6570\u636e\u4e0e\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u975e\u5e38\u76f8\u4f3c\u3002\u5982\u56fe 2 \u6240\u793a\uff0c\u6211\u4eec\u8bc4\u8bba\u4e2d\u7684\u4efb\u4f55\u6837\u672c\u90fd\u662f\u5728\u4e0d\u540c\u65f6\u95f4\u6233\u4e0a\u6309\u9012\u589e\u987a\u5e8f\u6392\u5217\u7684\u6807\u8bb0\u5e8f\u5217\uff0c\u6bcf\u4e2a\u6807\u8bb0\u90fd\u53ef\u4ee5\u8868\u793a\u4e3a\u4e00\u4e2a\u5411\u91cf/\u5d4c\u5165\u3002 \u56fe 2\uff1a\u5c06\u6807\u8bb0\u8868\u793a\u4e3a\u5d4c\u5165\uff0c\u5e76\u5c06\u5176\u89c6\u4e3a\u65f6\u95f4\u5e8f\u5217 
\u8fd9\u610f\u5473\u7740\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5e7f\u6cdb\u7528\u4e8e\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u7684\u6a21\u578b\uff0c\u4f8b\u5982\u957f\u77ed\u671f\u8bb0\u5fc6\uff08LSTM\uff09\u6216\u95e8\u63a7\u9012\u5f52\u5355\u5143\uff08GRU\uff09\uff0c\u751a\u81f3\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff08CNN\uff09\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u5728\u8be5\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u4e00\u4e2a\u7b80\u5355\u7684\u53cc\u5411 LSTM \u6a21\u578b\u3002 \u9996\u5148\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u9879\u76ee\u3002\u4f60\u53ef\u4ee5\u968f\u610f\u7ed9\u5b83\u547d\u540d\u3002\u7136\u540e\uff0c\u6211\u4eec\u7684\u7b2c\u4e00\u6b65\u5c06\u662f\u5206\u5272\u6570\u636e\u8fdb\u884c\u4ea4\u53c9\u9a8c\u8bc1\u3002 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df [ \"kfold\" ] = - 1 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f df . to_csv ( \"../input/imdb_folds.csv\" , index = False ) \u5c06\u6570\u636e\u96c6\u5212\u5206\u4e3a\u591a\u4e2a\u6298\u53e0\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5728 dataset.py \u4e2d\u521b\u5efa\u4e00\u4e2a\u7b80\u5355\u7684\u6570\u636e\u96c6\u7c7b\u3002\u6570\u636e\u96c6\u7c7b\u4f1a\u8fd4\u56de\u4e00\u4e2a\u8bad\u7ec3\u6216\u9a8c\u8bc1\u6570\u636e\u6837\u672c\u3002 import torch class IMDBDataset : def __init__ ( self , reviews , targets ): self . reviews = reviews self . target = targets def __len__ ( self ): return len ( self . reviews ) def __getitem__ ( self , item ): review = self . reviews [ item , :] target = self . target [ item ] return { \"review\" : torch . tensor ( review , dtype = torch . long ), \"target\" : torch . 
tensor ( target , dtype = torch . float ) } \u5b8c\u6210\u6570\u636e\u96c6\u5206\u7c7b\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u521b\u5efa lstm.py\uff0c\u5176\u4e2d\u5305\u542b\u6211\u4eec\u7684 LSTM \u6a21\u578b import torch import torch.nn as nn class LSTM ( nn . Module ): def __init__ ( self , embedding_matrix ): super ( LSTM , self ) . __init__ () num_words = embedding_matrix . shape [ 0 ] embed_dim = embedding_matrix . shape [ 1 ] self . embedding = nn . Embedding ( num_embeddings = num_words , embedding_dim = embed_dim ) self . embedding . weight = nn . Parameter ( torch . tensor ( embedding_matrix , dtype = torch . float32 ) ) self . embedding . weight . requires_grad = False self . lstm = nn . LSTM ( embed_dim , 128 , bidirectional = True , batch_first = True , ) self . out = nn . Linear ( 512 , 1 ) def forward ( self , x ): x = self . embedding ( x ) x , _ = self . lstm ( x ) avg_pool = torch . mean ( x , 1 ) max_pool , _ = torch . max ( x , 1 ) out = torch . cat (( avg_pool , max_pool ), 1 ) out = self . out ( out ) return out \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa engine.py\uff0c\u5176\u4e2d\u5305\u542b\u8bad\u7ec3\u548c\u8bc4\u4f30\u51fd\u6570\u3002 import torch import torch.nn as nn def train ( data_loader , model , optimizer , device ): model . train () for data in data_loader : reviews = data [ \"review\" ] targets = data [ \"target\" ] reviews = reviews . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) optimizer . zero_grad () predictions = model ( reviews ) loss = nn . BCEWithLogitsLoss ()( predictions , targets . view ( - 1 , 1 ) ) loss . backward () optimizer . step () def evaluate ( data_loader , model , device ): final_predictions = [] final_targets = [] model . eval () with torch . no_grad (): for data in data_loader : reviews = data [ \"review\" ] targets = data [ \"target\" ] reviews = reviews . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . 
float ) predictions = model ( reviews ) predictions = predictions . cpu () . numpy () . tolist () targets = data [ \"target\" ] . cpu () . numpy () . tolist () final_predictions . extend ( predictions ) final_targets . extend ( targets ) return final_predictions , final_targets \u8fd9\u4e9b\u51fd\u6570\u5c06\u5728 train.py \u4e2d\u4e3a\u6211\u4eec\u63d0\u4f9b\u5e2e\u52a9\uff0c\u8be5\u51fd\u6570\u7528\u4e8e\u8bad\u7ec3\u591a\u4e2a\u6298\u53e0\u3002 import io import torch import numpy as np import pandas as pd import tensorflow as tf from sklearn import metrics import config import dataset import engine import lstm def load_vectors ( fname ): fin = io . open ( fname , 'r' , encoding = 'utf-8' , newline = ' \\n ' , errors = 'ignore' ) n , d = map ( int , fin . readline () . split ()) data = {} for line in fin : tokens = line . rstrip () . split ( ' ' ) data [ tokens [ 0 ]] = list ( map ( float , tokens [ 1 :])) return data def create_embedding_matrix ( word_index , embedding_dict ): embedding_matrix = np . zeros (( len ( word_index ) + 1 , 300 )) for word , i in word_index . items (): if word in embedding_dict : embedding_matrix [ i ] = embedding_dict [ word ] return embedding_matrix def run ( df , fold ): train_df = df [ df . kfold != fold ] . reset_index ( drop = True ) valid_df = df [ df . kfold == fold ] . reset_index ( drop = True ) print ( \"Fitting tokenizer\" ) tokenizer = tf . keras . preprocessing . text . Tokenizer () tokenizer . fit_on_texts ( df . review . values . tolist ()) xtrain = tokenizer . texts_to_sequences ( train_df . review . values ) xtest = tokenizer . texts_to_sequences ( valid_df . review . values ) xtrain = tf . keras . preprocessing . sequence . pad_sequences ( xtrain , maxlen = config . MAX_LEN ) xtest = tf . keras . preprocessing . sequence . pad_sequences ( xtest , maxlen = config . MAX_LEN ) train_dataset = dataset . IMDBDataset ( reviews = xtrain , targets = train_df . sentiment . values ) train_data_loader = torch . utils . data . 
DataLoader ( train_dataset , batch_size = config . TRAIN_BATCH_SIZE , num_workers = 2 ) valid_dataset = dataset . IMDBDataset ( reviews = xtest , targets = valid_df . sentiment . values ) valid_data_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = config . VALID_BATCH_SIZE , num_workers = 1 ) print ( \"Loading embeddings\" ) embedding_dict = load_vectors ( \"../input/crawl-300d-2M.vec\" ) embedding_matrix = create_embedding_matrix ( tokenizer . word_index , embedding_dict ) device = torch . device ( \"cuda\" ) model = lstm . LSTM ( embedding_matrix ) model . to ( device ) optimizer = torch . optim . Adam ( model . parameters (), lr = 1e-3 ) print ( \"Training Model\" ) best_accuracy = 0 early_stopping_counter = 0 for epoch in range ( config . EPOCHS ): engine . train ( train_data_loader , model , optimizer , device ) outputs , targets = engine . evaluate ( valid_data_loader , model , device ) outputs = np . array ( outputs ) >= 0.5 accuracy = metrics . accuracy_score ( targets , outputs ) print ( f \"FOLD: { fold } , Epoch: { epoch } , Accuracy Score = { accuracy } \" ) if accuracy > best_accuracy : best_accuracy = accuracy else : early_stopping_counter += 1 if early_stopping_counter > 2 : break if __name__ == \"__main__\" : df = pd . 
read_csv ( \"../input/imdb_folds.csv\" ) run ( df , fold = 0 ) run ( df , fold = 1 ) run ( df , fold = 2 ) run ( df , fold = 3 ) run ( df , fold = 4 ) \u6700\u540e\u662f config.py\u3002 MAX_LEN = 128 TRAIN_BATCH_SIZE = 16 VALID_BATCH_SIZE = 8 EPOCHS = 10 \u8ba9\u6211\u4eec\u770b\u770b\u8f93\u51fa\uff1a FOLD : 0 , Epoch : 3 , Accuracy Score = 0.9015 FOLD : 1 , Epoch : 4 , Accuracy Score = 0.9007 FOLD : 2 , Epoch : 3 , Accuracy Score = 0.8924 FOLD : 3 , Epoch : 2 , Accuracy Score = 0.9 FOLD : 4 , Epoch : 1 , Accuracy Score = 0.878 \u8fd9\u662f\u8fc4\u4eca\u4e3a\u6b62\u6211\u4eec\u83b7\u5f97\u7684\u6700\u597d\u6210\u7ee9\u3002 \u8bf7\u6ce8\u610f\uff0c\u6211\u53ea\u663e\u793a\u4e86\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7cbe\u5ea6\u6700\u9ad8\u7684Epoch\u3002 \u4f60\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\uff0c\u6211\u4eec\u4f7f\u7528\u4e86\u9884\u5148\u8bad\u7ec3\u7684\u5d4c\u5165\u548c\u7b80\u5355\u7684\u53cc\u5411 LSTM\u3002 \u5982\u679c\u4f60\u60f3\u6539\u53d8\u6a21\u578b\uff0c\u4f60\u53ef\u4ee5\u53ea\u6539\u53d8 lstm.py \u4e2d\u7684\u6a21\u578b\u5e76\u4fdd\u6301\u4e00\u5207\u4e0d\u53d8\u3002 \u8fd9\u79cd\u4ee3\u7801\u53ea\u9700\u8981\u5f88\u5c11\u7684\u5b9e\u9a8c\u6539\u52a8\uff0c\u5e76\u4e14\u5f88\u5bb9\u6613\u7406\u89e3\u3002 \u4f8b\u5982\uff0c\u60a8\u53ef\u4ee5\u81ea\u5df1\u5b66\u4e60\u5d4c\u5165\u800c\u4e0d\u662f\u4f7f\u7528\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u5176\u4ed6\u4e00\u4e9b\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u7ec4\u5408\u591a\u4e2a\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528GRU\uff0c\u60a8\u53ef\u4ee5\u5728\u5d4c\u5165\u540e\u4f7f\u7528\u7a7a\u95f4dropout\uff0c\u60a8\u53ef\u4ee5\u6dfb\u52a0GRU LSTM \u5c42\u4e4b\u540e\uff0c\u60a8\u53ef\u4ee5\u6dfb\u52a0\u4e24\u4e2a LSTM \u5c42\uff0c\u60a8\u53ef\u4ee5\u8fdb\u884c LSTM-GRU-LSTM \u914d\u7f6e\uff0c\u60a8\u53ef\u4ee5\u7528\u5377\u79ef\u5c42\u66ff\u6362 LSTM 
\u7b49\uff0c\u800c\u65e0\u9700\u5bf9\u4ee3\u7801\u8fdb\u884c\u592a\u591a\u66f4\u6539\u3002 \u6211\u63d0\u5230\u7684\u5927\u90e8\u5206\u5185\u5bb9\u53ea\u9700\u8981\u66f4\u6539\u6a21\u578b\u7c7b\u3002 \u5f53\u60a8\u4f7f\u7528\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\u65f6\uff0c\u5c1d\u8bd5\u67e5\u770b\u6709\u591a\u5c11\u5355\u8bcd\u65e0\u6cd5\u627e\u5230\u5d4c\u5165\u4ee5\u53ca\u539f\u56e0\u3002 \u9884\u8bad\u7ec3\u5d4c\u5165\u7684\u5355\u8bcd\u8d8a\u591a\uff0c\u7ed3\u679c\u5c31\u8d8a\u597d\u3002 \u6211\u5411\u60a8\u5c55\u793a\u4ee5\u4e0b\u672a\u6ce8\u91ca\u7684 (!) \u51fd\u6570\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u5b83\u4e3a\u4efb\u4f55\u7c7b\u578b\u7684\u9884\u8bad\u7ec3\u5d4c\u5165\u521b\u5efa\u5d4c\u5165\u77e9\u9635\uff0c\u5176\u683c\u5f0f\u4e0e glove \u6216 fastText \u76f8\u540c\uff08\u53ef\u80fd\u9700\u8981\u8fdb\u884c\u4e00\u4e9b\u66f4\u6539\uff09\u3002 def load_embeddings ( word_index , embedding_file , vector_length = 300 ): max_features = len ( word_index ) + 1 words_to_find = list ( word_index . keys ()) more_words_to_find = [] for wtf in words_to_find : more_words_to_find . append ( wtf ) more_words_to_find . append ( str ( wtf ) . capitalize ()) more_words_to_find = set ( more_words_to_find ) def get_coefs ( word , * arr ): return word , np . asarray ( arr , dtype = 'float32' ) embeddings_index = dict ( get_coefs ( * o . strip () . split ( \" \" )) for o in open ( embedding_file ) if o . split ( \" \" )[ 0 ] in more_words_to_find and len ( o ) > 100 ) embedding_matrix = np . zeros (( max_features , vector_length )) for word , i in word_index . items (): if i >= max_features : continue embedding_vector = embeddings_index . get ( word ) if embedding_vector is None : embedding_vector = embeddings_index . get ( str ( word ) . capitalize () ) if embedding_vector is None : embedding_vector = embeddings_index . get ( str ( word ) . 
upper () ) if ( embedding_vector is not None and len ( embedding_vector ) == vector_length ): embedding_matrix [ i ] = embedding_vector return embedding_matrix \u9605\u8bfb\u5e76\u8fd0\u884c\u4e0a\u9762\u7684\u51fd\u6570\uff0c\u770b\u770b\u53d1\u751f\u4e86\u4ec0\u4e48\u3002 \u8be5\u51fd\u6570\u8fd8\u53ef\u4ee5\u4fee\u6539\u4e3a\u4f7f\u7528\u8bcd\u5e72\u8bcd\u6216\u8bcd\u5f62\u8fd8\u539f\u8bcd\u3002 \u6700\u540e\uff0c\u60a8\u5e0c\u671b\u8bad\u7ec3\u8bed\u6599\u5e93\u4e2d\u7684\u672a\u77e5\u5355\u8bcd\u6570\u91cf\u6700\u5c11\u3002 \u53e6\u4e00\u4e2a\u6280\u5de7\u662f\u5b66\u4e60\u5d4c\u5165\u5c42\uff0c\u5373\u4f7f\u5176\u53ef\u8bad\u7ec3\uff0c\u7136\u540e\u8bad\u7ec3\u7f51\u7edc\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u5206\u7c7b\u95ee\u9898\u6784\u5efa\u4e86\u5f88\u591a\u6a21\u578b\u3002 \u7136\u800c\uff0c\u73b0\u5728\u662f\u5e03\u5076\u65f6\u4ee3\uff0c\u8d8a\u6765\u8d8a\u591a\u7684\u4eba\u8f6c\u5411\u57fa\u4e8e\u53d8\u5f62\u91d1\u521a\u7684\u6a21\u578b\u3002 \u57fa\u4e8e Transformer \u7684\u7f51\u7edc\u80fd\u591f\u5904\u7406\u672c\u8d28\u4e0a\u957f\u671f\u7684\u4f9d\u8d56\u5173\u7cfb\u3002 LSTM \u4ec5\u5f53\u5b83\u770b\u5230\u524d\u4e00\u4e2a\u5355\u8bcd\u65f6\u624d\u67e5\u770b\u4e0b\u4e00\u4e2a\u5355\u8bcd\u3002 \u53d8\u538b\u5668\u7684\u60c5\u51b5\u5e76\u975e\u5982\u6b64\u3002 \u5b83\u53ef\u4ee5\u540c\u65f6\u67e5\u770b\u6574\u4e2a\u53e5\u5b50\u4e2d\u7684\u6240\u6709\u5355\u8bcd\u3002 \u56e0\u6b64\uff0c\u53e6\u4e00\u4e2a\u4f18\u70b9\u662f\u5b83\u53ef\u4ee5\u8f7b\u677e\u5e76\u884c\u5316\u5e76\u66f4\u6709\u6548\u5730\u4f7f\u7528 GPU\u3002 Transformers \u662f\u4e00\u4e2a\u975e\u5e38\u5e7f\u6cdb\u7684\u8bdd\u9898\uff0c\u6709\u592a\u591a\u7684\u6a21\u578b\uff1a BERT\u3001RoBERTa\u3001XLNet\u3001XLM-RoBERTa\u3001T5 \u7b49\u3002\u6211\u5c06\u5411\u60a8\u5c55\u793a\u4e00\u79cd\u53ef\u7528\u4e8e\u6240\u6709\u8fd9\u4e9b\u6a21\u578b\uff08T5 \u9664\u5916\uff09\u8fdb\u884c\u5206\u7c7b\u7684\u901a\u7528\u65b9\u6cd5 
\u6211\u4eec\u4e00\u76f4\u5728\u8ba8\u8bba\u7684\u95ee\u9898\u3002 \u8bf7\u6ce8\u610f\uff0c\u8fd9\u4e9b\u53d8\u538b\u5668\u9700\u8981\u8bad\u7ec3\u5b83\u4eec\u6240\u9700\u7684\u8ba1\u7b97\u80fd\u529b\u3002 \u56e0\u6b64\uff0c\u5982\u679c\u60a8\u6ca1\u6709\u9ad8\u7aef\u7cfb\u7edf\uff0c\u4e0e\u57fa\u4e8e LSTM \u6216 TF-IDF \u7684\u6a21\u578b\u76f8\u6bd4\uff0c\u8bad\u7ec3\u6a21\u578b\u53ef\u80fd\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u3002 \u6211\u4eec\u8981\u505a\u7684\u7b2c\u4e00\u4ef6\u4e8b\u662f\u521b\u5efa\u4e00\u4e2a\u914d\u7f6e\u6587\u4ef6\u3002 import transformers MAX_LEN = 512 TRAIN_BATCH_SIZE = 8 VALID_BATCH_SIZE = 4 EPOCHS = 10 BERT_PATH = \"../input/bert_base_uncased/\" MODEL_PATH = \"model.bin\" TRAINING_FILE = \"../input/imdb.csv\" TOKENIZER = transformers . BertTokenizer . from_pretrained ( BERT_PATH , do_lower_case = True ) \u8fd9\u91cc\u7684\u914d\u7f6e\u6587\u4ef6\u662f\u6211\u4eec\u5b9a\u4e49\u5206\u8bcd\u5668\u548c\u5176\u4ed6\u6211\u4eec\u60f3\u8981\u7ecf\u5e38\u66f4\u6539\u7684\u53c2\u6570\u7684\u552f\u4e00\u5730\u65b9 \u2014\u2014 \u8fd9\u6837\u6211\u4eec\u5c31\u53ef\u4ee5\u505a\u5f88\u591a\u5b9e\u9a8c\u800c\u4e0d\u9700\u8981\u8fdb\u884c\u5927\u91cf\u66f4\u6539\u3002 \u4e0b\u4e00\u6b65\u662f\u6784\u5efa\u6570\u636e\u96c6\u7c7b\u3002 import config import torch class BERTDataset : def __init__ ( self , review , target ): self . review = review self . target = target self . tokenizer = config . TOKENIZER self . max_len = config . MAX_LEN def __len__ ( self ): return len ( self . review ) def __getitem__ ( self , item ): review = str ( self . review [ item ]) review = \" \" . join ( review . split ()) inputs = self . tokenizer . encode_plus ( review , None , add_special_tokens = True , max_length = self . max_len , pad_to_max_length = True , ) ids = inputs [ \"input_ids\" ] mask = inputs [ \"attention_mask\" ] token_type_ids = inputs [ \"token_type_ids\" ] return { \"ids\" : torch . tensor ( ids , dtype = torch . long ), \"mask\" : torch . 
tensor ( mask , dtype = torch . long ), \"token_type_ids\" : torch . tensor ( token_type_ids , dtype = torch . long ), \"targets\" : torch . tensor ( self . target [ item ], dtype = torch . float ) } \u73b0\u5728\u6211\u4eec\u6765\u5230\u4e86\u8be5\u9879\u76ee\u7684\u6838\u5fc3\uff0c\u5373\u6a21\u578b\u3002 import config import transformers import torch.nn as nn class BERTBaseUncased ( nn . Module ): def __init__ ( self ): super ( BERTBaseUncased , self ) . __init__ () self . bert = transformers . BertModel . from_pretrained ( config . BERT_PATH ) self . bert_drop = nn . Dropout ( 0.3 ) self . out = nn . Linear ( 768 , 1 ) def forward ( self , ids , mask , token_type_ids ): hidden state _ , o2 = self . bert ( ids , attention_mask = mask , token_type_ids = token_type_ids ) bo = self . bert_drop ( o2 ) output = self . out ( bo ) return output \u8be5\u6a21\u578b\u8fd4\u56de\u5355\u4e2a\u8f93\u51fa\u3002 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5e26\u6709 logits \u7684\u4e8c\u5143\u4ea4\u53c9\u71b5\u635f\u5931\uff0c\u5b83\u9996\u5148\u5e94\u7528 sigmoid\uff0c\u7136\u540e\u8ba1\u7b97\u635f\u5931\u3002 \u8fd9\u662f\u5728engine.py \u4e2d\u5b8c\u6210\u7684\u3002 import torch import torch.nn as nn def loss_fn ( outputs , targets ): return nn . BCEWithLogitsLoss ()( outputs , targets . view ( - 1 , 1 )) def train_fn ( data_loader , model , optimizer , device , scheduler ): model . train () for d in data_loader : ids = d [ \"ids\" ] token_type_ids = d [ \"token_type_ids\" ] mask = d [ \"mask\" ] targets = d [ \"targets\" ] ids = ids . to ( device , dtype = torch . long ) token_type_ids = token_type_ids . to ( device , dtype = torch . long ) mask = mask . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) optimizer . zero_grad () outputs = model ( ids = ids , mask = mask , token_type_ids = token_type_ids ) loss = loss_fn ( outputs , targets ) loss . backward () optimizer . step () scheduler . 
step () def eval_fn ( data_loader , model , device ): model . eval () fin_targets = [] fin_outputs = [] with torch . no_grad (): for d in data_loader : ids = d [ \"ids\" ] token_type_ids = d [ \"token_type_ids\" ] mask = d [ \"mask\" ] targets = d [ \"targets\" ] ids = ids . to ( device , dtype = torch . long ) token_type_ids = token_type_ids . to ( device , dtype = torch . long ) mask = mask . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) outputs = model ( ids = ids , mask = mask , token_type_ids = token_type_ids ) targets = targets . cpu () . detach () fin_targets . extend ( targets . numpy () . tolist ()) outputs = torch . sigmoid ( outputs ) . cpu () . detach () fin_outputs . extend ( outputs . numpy () . tolist ()) return fin_outputs , fin_targets \u6700\u540e\uff0c\u6211\u4eec\u51c6\u5907\u597d\u8bad\u7ec3\u4e86\u3002 \u6211\u4eec\u6765\u770b\u770b\u8bad\u7ec3\u811a\u672c\u5427\uff01 import config import dataset import engine import torch import pandas as pd import torch.nn as nn import numpy as np from model import BERTBaseUncased from sklearn import model_selection from sklearn import metrics from transformers import AdamW from transformers import get_linear_schedule_with_warmup def train (): dfx = pd . read_csv ( config . TRAINING_FILE ) . fillna ( \"none\" ) dfx . sentiment = dfx . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df_train , df_valid = model_selection . train_test_split ( dfx , test_size = 0.1 , random_state = 42 , stratify = dfx . sentiment . values ) df_train = df_train . reset_index ( drop = True ) df_valid = df_valid . reset_index ( drop = True ) train_dataset = dataset . BERTDataset ( review = df_train . review . values , target = df_train . sentiment . values ) train_data_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = config . TRAIN_BATCH_SIZE , num_workers = 4 ) valid_dataset = dataset . BERTDataset ( review = df_valid . review . 
values , target = df_valid . sentiment . values ) valid_data_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = config . VALID_BATCH_SIZE , num_workers = 1 ) device = torch . device ( \"cuda\" ) model = BERTBaseUncased () model . to ( device ) param_optimizer = list ( model . named_parameters ()) no_decay = [ \"bias\" , \"LayerNorm.bias\" , \"LayerNorm.weight\" ] optimizer_parameters = [ { \"params\" : [ p for n , p in param_optimizer if not any ( nd in n for nd in no_decay ) ], \"weight_decay\" : 0.001 , } \uff0c { \"params\" : [ p for n , p in param_optimizer if any ( nd in n for nd in no_decay ) ], \"weight_decay\" : 0.0 , }] num_train_steps = int ( len ( df_train ) / config . TRAIN_BATCH_SIZE * config . EPOCHS ) optimizer = AdamW ( optimizer_parameters , lr = 3e-5 ) scheduler = get_linear_schedule_with_warmup ( optimizer , num_warmup_steps = 0 , num_training_steps = num_train_steps ) model = nn . DataParallel ( model ) best_accuracy = 0 for epoch in range ( config . EPOCHS ): engine . train_fn ( train_data_loader , model , optimizer , device , scheduler ) outputs , targets = engine . eval_fn ( valid_data_loader , model , device ) outputs = np . array ( outputs ) >= 0.5 accuracy = metrics . accuracy_score ( targets , outputs ) print ( f \"Accuracy Score = { accuracy } \" ) if accuracy > best_accuracy : torch . save ( model . state_dict (), config . 
MODEL_PATH ) best_accuracy = accuracy if __name__ == \"__main__\" : train () \u4e4d\u4e00\u770b\u53ef\u80fd\u770b\u8d77\u6765\u5f88\u591a\uff0c\u4f46\u4e00\u65e6\u60a8\u4e86\u89e3\u4e86\u5404\u4e2a\u7ec4\u4ef6\uff0c\u5c31\u4e0d\u518d\u90a3\u4e48\u7b80\u5355\u4e86\u3002 \u60a8\u53ea\u9700\u66f4\u6539\u51e0\u884c\u4ee3\u7801\u5373\u53ef\u8f7b\u677e\u5c06\u5176\u66f4\u6539\u4e3a\u60a8\u60f3\u8981\u4f7f\u7528\u7684\u4efb\u4f55\u5176\u4ed6\u53d8\u538b\u5668\u6a21\u578b\u3002 \u8be5\u6a21\u578b\u7684\u51c6\u786e\u7387\u4e3a 93%\uff01 \u54c7\uff01 \u8fd9\u6bd4\u4efb\u4f55\u5176\u4ed6\u6a21\u578b\u90fd\u8981\u597d\u5f97\u591a\u3002 \u4f46\u662f\u8fd9\u503c\u5f97\u5417\uff1f \u6211\u4eec\u4f7f\u7528 LSTM \u80fd\u591f\u5b9e\u73b0 90% \u7684\u76ee\u6807\uff0c\u800c\u4e14\u5b83\u4eec\u66f4\u7b80\u5355\u3001\u66f4\u5bb9\u6613\u8bad\u7ec3\u5e76\u4e14\u63a8\u7406\u901f\u5ea6\u66f4\u5feb\u3002 \u901a\u8fc7\u4f7f\u7528\u4e0d\u540c\u7684\u6570\u636e\u5904\u7406\u6216\u8c03\u6574\u5c42\u3001\u8282\u70b9\u3001dropout\u3001\u5b66\u4e60\u7387\u3001\u66f4\u6539\u4f18\u5316\u5668\u7b49\u53c2\u6570\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u8be5\u6a21\u578b\u6539\u8fdb\u4e00\u4e2a\u767e\u5206\u70b9\u3002\u7136\u540e\u6211\u4eec\u5c06\u4ece BERT \u4e2d\u83b7\u5f97\u7ea6 2% \u7684\u6536\u76ca\u3002 \u53e6\u4e00\u65b9\u9762\uff0cBERT \u7684\u8bad\u7ec3\u65f6\u95f4\u8981\u957f\u5f97\u591a\uff0c\u53c2\u6570\u5f88\u591a\uff0c\u800c\u4e14\u63a8\u7406\u901f\u5ea6\u4e5f\u5f88\u6162\u3002 \u6700\u540e\uff0c\u60a8\u5e94\u8be5\u5ba1\u89c6\u81ea\u5df1\u7684\u4e1a\u52a1\u5e76\u505a\u51fa\u660e\u667a\u7684\u9009\u62e9\u3002 \u4e0d\u8981\u4ec5\u4ec5\u56e0\u4e3a BERT\u201c\u9177\u201d\u800c\u9009\u62e9\u5b83\u3002 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u6211\u4eec\u5728\u8fd9\u91cc\u8ba8\u8bba\u7684\u552f\u4e00\u4efb\u52a1\u662f\u5206\u7c7b\uff0c\u4f46\u5c06\u5176\u66f4\u6539\u4e3a\u56de\u5f52\u3001\u591a\u6807\u7b7e\u6216\u591a\u7c7b\u53ea\u9700\u8981\u66f4\u6539\u51e0\u884c\u4ee3\u7801\u3002 
\u4f8b\u5982\uff0c\u591a\u7c7b\u5206\u7c7b\u8bbe\u7f6e\u4e2d\u7684\u540c\u4e00\u95ee\u9898\u5c06\u6709\u591a\u4e2a\u8f93\u51fa\u548c\u4ea4\u53c9\u71b5\u635f\u5931\u3002 \u5176\u4ed6\u4e00\u5207\u90fd\u5e94\u8be5\u4fdd\u6301\u4e0d\u53d8\u3002 \u81ea\u7136\u8bed\u8a00\u5904\u7406\u975e\u5e38\u5e9e\u5927\uff0c\u6211\u4eec\u53ea\u8ba8\u8bba\u4e86\u5176\u4e2d\u7684\u4e00\u5c0f\u90e8\u5206\u3002 \u663e\u7136\uff0c\u8fd9\u662f\u4e00\u4e2a\u5f88\u5927\u7684\u6bd4\u4f8b\uff0c\u56e0\u4e3a\u5927\u591a\u6570\u5de5\u4e1a\u6a21\u578b\u90fd\u662f\u5206\u7c7b\u6216\u56de\u5f52\u6a21\u578b\u3002 \u5982\u679c\u6211\u5f00\u59cb\u8be6\u7ec6\u5199\u6240\u6709\u5185\u5bb9\uff0c\u6211\u6700\u7ec8\u53ef\u80fd\u4f1a\u5199\u51e0\u767e\u9875\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u51b3\u5b9a\u5c06\u6240\u6709\u5185\u5bb9\u5305\u542b\u5728\u4e00\u672c\u5355\u72ec\u7684\u4e66\u4e2d\uff1a\u63a5\u8fd1\uff08\u51e0\u4e4e\uff09\u4efb\u4f55 NLP \u95ee\u9898\uff01","title":"\u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5"},{"location":"%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%88%96%E5%9B%9E%E5%BD%92%E6%96%B9%E6%B3%95/#_1","text":"\u6587\u672c\u95ee\u9898\u662f\u6211\u7684\u6700\u7231\u3002\u4e00\u822c\u6765\u8bf4\uff0c\u8fd9\u4e9b\u95ee\u9898\u4e5f\u88ab\u79f0\u4e3a \u81ea\u7136\u8bed\u8a00\u5904\u7406\uff08NLP\uff09\u95ee\u9898 \u3002NLP 
\u95ee\u9898\u4e0e\u56fe\u50cf\u95ee\u9898\u4e5f\u6709\u5f88\u5927\u4e0d\u540c\u3002\u4f60\u9700\u8981\u521b\u5efa\u4ee5\u524d\u4ece\u672a\u4e3a\u8868\u683c\u95ee\u9898\u521b\u5efa\u8fc7\u7684\u6570\u636e\u7ba1\u9053\u3002\u4f60\u9700\u8981\u4e86\u89e3\u5546\u4e1a\u6848\u4f8b\uff0c\u624d\u80fd\u5efa\u7acb\u4e00\u4e2a\u597d\u7684\u6a21\u578b\u3002\u987a\u4fbf\u8bf4\u4e00\u53e5\uff0c\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u4efb\u4f55\u4e8b\u60c5\u90fd\u662f\u5982\u6b64\u3002\u5efa\u7acb\u6a21\u578b\u4f1a\u8ba9\u4f60\u8fbe\u5230\u4e00\u5b9a\u7684\u6c34\u5e73\uff0c\u4f46\u8981\u60f3\u6539\u5584\u548c\u4fc3\u8fdb\u4f60\u6240\u5efa\u7acb\u6a21\u578b\u7684\u4e1a\u52a1\uff0c\u4f60\u5fc5\u987b\u4e86\u89e3\u5b83\u5bf9\u4e1a\u52a1\u7684\u5f71\u54cd\u3002 NLP \u95ee\u9898\u6709\u5f88\u591a\u79cd\uff0c\u5176\u4e2d\u6700\u5e38\u89c1\u7684\u662f\u5b57\u7b26\u4e32\u5206\u7c7b\u3002\u5f88\u591a\u65f6\u5019\uff0c\u6211\u4eec\u4f1a\u770b\u5230\u4eba\u4eec\u5728\u5904\u7406\u8868\u683c\u6570\u636e\u6216\u56fe\u50cf\u65f6\u8868\u73b0\u51fa\u8272\uff0c\u4f46\u5728\u5904\u7406\u6587\u672c\u65f6\uff0c\u4ed6\u4eec\u751a\u81f3\u4e0d\u77e5\u9053\u4ece\u4f55\u5165\u624b\u3002\u6587\u672c\u6570\u636e\u4e0e\u5176\u4ed6\u7c7b\u578b\u7684\u6570\u636e\u96c6\u6ca1\u6709\u4ec0\u4e48\u4e0d\u540c\u3002\u5bf9\u4e8e\u8ba1\u7b97\u673a\u6765\u8bf4\uff0c\u4e00\u5207\u90fd\u662f\u6570\u5b57\u3002 \u5047\u8bbe\u6211\u4eec\u4ece\u60c5\u611f\u5206\u7c7b\u8fd9\u4e00\u57fa\u672c\u4efb\u52a1\u5f00\u59cb\u3002\u6211\u4eec\u5c06\u5c1d\u8bd5\u5bf9\u7535\u5f71\u8bc4\u8bba\u8fdb\u884c\u60c5\u611f\u5206\u7c7b\u3002\u56e0\u6b64\uff0c\u60a8\u6709\u4e00\u4e2a\u6587\u672c\uff0c\u5e76\u6709\u4e0e\u4e4b\u76f8\u5173\u7684\u60c5\u611f\u3002\u4f60\u5c06\u5982\u4f55\u5904\u7406\u8fd9\u7c7b\u95ee\u9898\uff1f\u662f\u5e94\u7528\u6df1\u5ea6\u795e\u7ecf\u7f51\u7edc\uff1f 
\u4e0d\uff0c\u7edd\u5bf9\u9519\u4e86\u3002\u4f60\u8981\u4ece\u6700\u57fa\u672c\u7684\u5f00\u59cb\u3002\u8ba9\u6211\u4eec\u5148\u770b\u770b\u8fd9\u4e9b\u6570\u636e\u662f\u4ec0\u4e48\u6837\u5b50\u7684\u3002 \u6211\u4eec\u4ece IMDB \u7535\u5f71\u8bc4\u8bba\u6570\u636e\u96c6 \u5f00\u59cb\uff0c\u8be5\u6570\u636e\u96c6\u5305\u542b 25000 \u7bc7\u6b63\u9762\u60c5\u611f\u8bc4\u8bba\u548c 25000 \u7bc7\u8d1f\u9762\u60c5\u611f\u8bc4\u8bba\u3002 \u6211\u5c06\u5728\u6b64\u8ba8\u8bba\u7684\u6982\u5ff5\u51e0\u4e4e\u9002\u7528\u4e8e\u4efb\u4f55\u6587\u672c\u5206\u7c7b\u6570\u636e\u96c6\u3002 \u8fd9\u4e2a\u6570\u636e\u96c6\u975e\u5e38\u5bb9\u6613\u7406\u89e3\u3002\u4e00\u7bc7\u8bc4\u8bba\u5bf9\u5e94\u4e00\u4e2a\u76ee\u6807\u53d8\u91cf\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5199\u7684\u662f\u8bc4\u8bba\u800c\u4e0d\u662f\u53e5\u5b50\u3002\u8bc4\u8bba\u5c31\u662f\u4e00\u5806\u53e5\u5b50\u3002\u6240\u4ee5\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u4f60\u4e00\u5b9a\u53ea\u770b\u5230\u4e86\u5bf9\u5355\u53e5\u7684\u5206\u7c7b\uff0c\u4f46\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u6211\u4eec\u5c06\u5bf9\u591a\u4e2a\u53e5\u5b50\u8fdb\u884c\u5206\u7c7b\u3002\u7b80\u5355\u5730\u8bf4\uff0c\u8fd9\u610f\u5473\u7740\u4e0d\u4ec5\u4e00\u4e2a\u53e5\u5b50\u4f1a\u5bf9\u60c5\u611f\u4ea7\u751f\u5f71\u54cd\uff0c\u800c\u4e14\u60c5\u611f\u5f97\u5206\u662f\u591a\u4e2a\u53e5\u5b50\u5f97\u5206\u7684\u7ec4\u5408\u3002\u6570\u636e\u7b80\u4ecb\u5982\u56fe 1 \u6240\u793a\u3002 
\u5982\u4f55\u7740\u624b\u89e3\u51b3\u8fd9\u6837\u7684\u95ee\u9898\uff1f\u4e00\u4e2a\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u624b\u5de5\u5236\u4f5c\u4e24\u4efd\u5355\u8bcd\u8868\u3002\u4e00\u4e2a\u5217\u8868\u5305\u542b\u4f60\u80fd\u60f3\u8c61\u5230\u7684\u6240\u6709\u6b63\u9762\u8bcd\u6c47\uff0c\u4f8b\u5982\u597d\u3001\u68d2\u3001\u597d\u7b49\uff1b\u53e6\u4e00\u4e2a\u5217\u8868\u5305\u542b\u6240\u6709\u8d1f\u9762\u8bcd\u6c47\uff0c\u4f8b\u5982\u574f\u3001\u6076\u7b49\u3002\u6211\u4eec\u5148\u4e0d\u8981\u4e3e\u4f8b\u8bf4\u660e\u574f\u8bcd\uff0c\u5426\u5219\u8fd9\u672c\u4e66\u5c31\u53ea\u80fd\u4f9b 18 \u5c81\u4ee5\u4e0a\u7684\u4eba\u9605\u8bfb\u4e86\u3002\u4e00\u65e6\u4f60\u6709\u4e86\u8fd9\u4e9b\u5217\u8868\uff0c\u4f60\u751a\u81f3\u4e0d\u9700\u8981\u4e00\u4e2a\u6a21\u578b\u6765\u8fdb\u884c\u9884\u6d4b\u3002\u8fd9\u4e9b\u5217\u8868\u4e5f\u88ab\u79f0\u4e3a\u60c5\u611f\u8bcd\u5178\u3002\u4f60\u53ef\u4ee5\u7528\u4e00\u4e2a\u7b80\u5355\u7684\u8ba1\u6570\u5668\u6765\u8ba1\u7b97\u53e5\u5b50\u4e2d\u6b63\u9762\u548c\u8d1f\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u3002\u5982\u679c\u6b63\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u8f83\u591a\uff0c\u5219\u8868\u793a\u8be5\u53e5\u5b50\u5177\u6709\u6b63\u9762\u60c5\u611f\uff1b\u5982\u679c\u8d1f\u9762\u8bcd\u8bed\u7684\u6570\u91cf\u8f83\u591a\uff0c\u5219\u8868\u793a\u8be5\u53e5\u5b50\u5177\u6709\u8d1f\u9762\u60c5\u611f\u3002\u5982\u679c\u53e5\u5b50\u4e2d\u6ca1\u6709\u8fd9\u4e9b\u8bcd\uff0c\u5219\u53ef\u4ee5\u8bf4\u8be5\u53e5\u5b50\u5177\u6709\u4e2d\u6027\u60c5\u611f\u3002\u8fd9\u662f\u6700\u53e4\u8001\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u73b0\u5728\u4ecd\u6709\u4eba\u5728\u4f7f\u7528\u3002\u5b83\u4e5f\u4e0d\u9700\u8981\u592a\u591a\u4ee3\u7801\u3002 def find_sentiment ( sentence , pos , neg ): sentence = sentence . split () sentence = set ( sentence ) num_common_pos = len ( sentence . intersection ( pos )) num_common_neg = len ( sentence . 
intersection ( neg )) if num_common_pos > num_common_neg : return \"positive\" if num_common_pos < num_common_neg : return \"negative\" return \"neutral\" \u4e0d\u8fc7\uff0c\u8fd9\u79cd\u65b9\u6cd5\u8003\u8651\u7684\u56e0\u7d20\u5e76\u4e0d\u591a\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u6211\u4eec\u7684 split() \u4e5f\u5e76\u4e0d\u5b8c\u7f8e\u3002\u5982\u679c\u4f7f\u7528 split()\uff0c\u5c31\u4f1a\u51fa\u73b0\u8fd9\u6837\u7684\u53e5\u5b50\uff1a \"hi, how are you?\" \u7ecf\u8fc7\u5206\u5272\u540e\u53d8\u4e3a\uff1a [\"hi,\", \"how\",\"are\",\"you?\"] \u8fd9\u79cd\u65b9\u6cd5\u5e76\u4e0d\u7406\u60f3\uff0c\u56e0\u4e3a\u5355\u8bcd\u4e2d\u5305\u542b\u4e86\u9017\u53f7\u548c\u95ee\u53f7\uff0c\u5b83\u4eec\u5e76\u6ca1\u6709\u88ab\u5206\u5272\u3002\u56e0\u6b64\uff0c\u5982\u679c\u6ca1\u6709\u5728\u5206\u5272\u524d\u5bf9\u8fd9\u4e9b\u7279\u6b8a\u5b57\u7b26\u8fdb\u884c\u9884\u5904\u7406\uff0c\u4e0d\u5efa\u8bae\u4f7f\u7528\u8fd9\u79cd\u65b9\u6cd5\u3002\u5c06\u5b57\u7b26\u4e32\u62c6\u5206\u4e3a\u5355\u8bcd\u5217\u8868\u79f0\u4e3a\u6807\u8bb0\u5316\u3002\u6700\u6d41\u884c\u7684\u6807\u8bb0\u5316\u65b9\u6cd5\u4e4b\u4e00\u6765\u81ea NLTK\uff08\u81ea\u7136\u8bed\u8a00\u5de5\u5177\u5305\uff09 \u3002 In [ X ]: from nltk.tokenize import word_tokenize In [ X ]: sentence = \"hi, how are you?\" In [ X ]: sentence . split () Out [ X ]: [ 'hi,' , 'how' , 'are' , 'you?' ] In [ X ]: word_tokenize ( sentence ) Out [ X ]: [ 'hi' , ',' , 'how' , 'are' , 'you' , '?' 
] \u6b63\u5982\u60a8\u6240\u770b\u5230\u7684\uff0c\u4f7f\u7528 NLTK \u7684\u5355\u8bcd\u6807\u8bb0\u5316\u529f\u80fd\uff0c\u540c\u4e00\u4e2a\u53e5\u5b50\u7684\u62c6\u5206\u6548\u679c\u8981\u597d\u5f97\u591a\u3002\u4f7f\u7528\u5355\u8bcd\u5217\u8868\u8fdb\u884c\u5bf9\u6bd4\u7684\u6548\u679c\u4e5f\u4f1a\u66f4\u597d\uff01\u8fd9\u5c31\u662f\u6211\u4eec\u5c06\u5e94\u7528\u4e8e\u7b2c\u4e00\u4e2a\u60c5\u611f\u68c0\u6d4b\u6a21\u578b\u7684\u65b9\u6cd5\u3002 \u5728\u5904\u7406 NLP \u5206\u7c7b\u95ee\u9898\u65f6\uff0c\u60a8\u5e94\u8be5\u7ecf\u5e38\u5c1d\u8bd5\u7684\u57fa\u672c\u6a21\u578b\u4e4b\u4e00\u662f \u8bcd\u888b\u6a21\u578b\uff08bag of words\uff09 \u3002\u5728\u8bcd\u888b\u6a21\u578b\u4e2d\uff0c\u6211\u4eec\u521b\u5efa\u4e00\u4e2a\u5de8\u5927\u7684\u7a00\u758f\u77e9\u9635\uff0c\u5b58\u50a8\u8bed\u6599\u5e93\uff08\u8bed\u6599\u5e93=\u6240\u6709\u6587\u6863=\u6240\u6709\u53e5\u5b50\uff09\u4e2d\u6240\u6709\u5355\u8bcd\u7684\u8ba1\u6570\u3002\u4e3a\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 CountVectorizer\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b83\u662f\u5982\u4f55\u5de5\u4f5c\u7684\u3002 from sklearn.feature_extraction.text import CountVectorizer corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] ctv = CountVectorizer () ctv . fit ( corpus ) corpus_transformed = ctv . 
transform ( corpus ) \u5982\u679c\u6211\u4eec\u6253\u5370 corpus_transformed\uff0c\u5c31\u4f1a\u5f97\u5230\u7c7b\u4f3c\u4e0b\u9762\u7684\u7ed3\u679c\uff1a ( 0 , 2 ) 1 ( 0 , 9 ) 1 ( 0 , 11 ) 1 ( 0 , 22 ) 1 ( 1 , 1 ) 1 ( 1 , 3 ) 1 ( 1 , 4 ) 1 ( 1 , 7 ) 1 ( 1 , 8 ) 1 ( 1 , 10 ) 1 ( 1 , 13 ) 1 ( 1 , 17 ) 1 ( 1 , 19 ) 1 ( 1 , 22 ) 2 ( 2 , 0 ) 1 ( 2 , 5 ) 1 ( 2 , 6 ) 1 ( 2 , 14 ) 1 ( 2 , 22 ) 1 ( 3 , 12 ) 1 ( 3 , 15 ) 1 ( 3 , 16 ) 1 ( 3 , 18 ) 1 ( 3 , 20 ) 1 ( 4 , 21 ) 1 \u5728\u524d\u9762\u7684\u7ae0\u8282\u4e2d\uff0c\u6211\u4eec\u5df2\u7ecf\u89c1\u8bc6\u8fc7\u8fd9\u79cd\u8868\u793a\u6cd5\u3002\u5373\u7a00\u758f\u8868\u793a\u6cd5\u3002\u56e0\u6b64\uff0c\u8bed\u6599\u5e93\u73b0\u5728\u662f\u4e00\u4e2a\u7a00\u758f\u77e9\u9635\uff0c\u5176\u4e2d\u7b2c\u4e00\u4e2a\u6837\u672c\u6709 4 \u4e2a\u5143\u7d20\uff0c\u7b2c\u4e8c\u4e2a\u6837\u672c\u6709 10 \u4e2a\u5143\u7d20\uff0c\u4ee5\u6b64\u7c7b\u63a8\uff0c\u7b2c\u4e09\u4e2a\u6837\u672c\u6709 5 \u4e2a\u5143\u7d20\uff0c\u4ee5\u6b64\u7c7b\u63a8\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u4e9b\u5143\u7d20\u90fd\u6709\u76f8\u5173\u7684\u8ba1\u6570\u3002\u6709\u4e9b\u5143\u7d20\u4f1a\u51fa\u73b0\u4e24\u6b21\uff0c\u6709\u4e9b\u5219\u53ea\u6709\u4e00\u6b21\u3002\u4f8b\u5982\uff0c\u5728\u6837\u672c 2\uff08\u7b2c 1 \u884c\uff09\u4e2d\uff0c\u6211\u4eec\u770b\u5230\u7b2c 22 \u5217\u7684\u6570\u503c\u662f 2\u3002\u8fd9\u662f\u4e3a\u4ec0\u4e48\u5462\uff1f\u7b2c 22 \u5217\u662f\u4ec0\u4e48\uff1f CountVectorizer \u7684\u5de5\u4f5c\u65b9\u5f0f\u662f\u9996\u5148\u5bf9\u53e5\u5b50\u8fdb\u884c\u6807\u8bb0\u5316\u5904\u7406\uff0c\u7136\u540e\u4e3a\u6bcf\u4e2a\u6807\u8bb0\u8d4b\u503c\u3002\u56e0\u6b64\uff0c\u6bcf\u4e2a\u6807\u8bb0\u90fd\u7531\u4e00\u4e2a\u552f\u4e00\u7d22\u5f15\u8868\u793a\u3002\u8fd9\u4e9b\u552f\u4e00\u7d22\u5f15\u5c31\u662f\u6211\u4eec\u770b\u5230\u7684\u5217\u3002CountVectorizer \u4f1a\u5b58\u50a8\u8fd9\u4e9b\u4fe1\u606f\u3002 print ( ctv . 
vocabulary_ ) { 'hello' : 9 , 'how' : 11 , 'are' : 2 , 'you' : 22 , 'im' : 13 , 'getting' : 8 , 'bored' : 4 , 'at' : 3 , 'home' : 10 , 'and' : 1 , 'what' : 19 , 'do' : 7 , 'think' : 17 , 'did' : 6 , 'know' : 14 , 'about' : 0 , 'counts' : 5 , 'let' : 15 , 'see' : 16 , 'if' : 12 , 'this' : 18 , 'works' : 20 , 'yes' : 21 } \u6211\u4eec\u770b\u5230\uff0c\u7d22\u5f15 22 \u5c5e\u4e8e \"you\"\uff0c\u800c\u5728\u7b2c\u4e8c\u53e5\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528\u4e86\u4e24\u6b21 \"you\"\u3002\u6211\u5e0c\u671b\u5927\u5bb6\u73b0\u5728\u5df2\u7ecf\u6e05\u695a\u4ec0\u4e48\u662f\u8bcd\u888b\u4e86\u3002\u4f46\u662f\u6211\u4eec\u8fd8\u7f3a\u5c11\u4e00\u4e9b\u7279\u6b8a\u5b57\u7b26\u3002\u6709\u65f6\uff0c\u8fd9\u4e9b\u7279\u6b8a\u5b57\u7b26\u4e5f\u5f88\u6709\u7528\u3002\u4f8b\u5982\uff0c\"? \"\u5728\u5927\u591a\u6570\u53e5\u5b50\u4e2d\u8868\u793a\u7591\u95ee\u53e5\u3002\u8ba9\u6211\u4eec\u628a scikit-learn \u7684 word_tokenize \u6574\u5408\u5230 CountVectorizer \u4e2d\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002 from sklearn.feature_extraction.text import CountVectorizer from nltk.tokenize import word_tokenize corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] ctv = CountVectorizer ( tokenizer = word_tokenize , token_pattern = None ) ctv . fit ( corpus ) corpus_transformed = ctv . transform ( corpus ) print ( ctv . vocabulary_ ) \u8fd9\u6837\uff0c\u6211\u4eec\u7684\u8bcd\u888b\u5c31\u53d8\u6210\u4e86\uff1a { 'hello' : 14 , ',' : 2 , 'how' : 16 , 'are' : 7 , 'you' : 27 , '?' : 4 , 'im' : 18 , 'getting' : 13 , 'bored' : 9 , 'at' : 8 , 'home' : 15 , '.' : 3 , 'and' : 6 , 'what' : 24 , 'do' : 12 , 'think' : 22 , 'did' : 11 , 'know' : 19 , 'about' : 5 , 'counts' : 10 , 'let' : 20 , \"'s\" : 1 , 'see' : 21 , 'if' : 17 , 'this' : 23 , 'works' : 25 , '!' 
: 0 , 'yes' : 26 } \u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u5229\u7528 IMDB \u6570\u636e\u96c6\u4e2d\u7684\u6240\u6709\u53e5\u5b50\u521b\u5efa\u4e00\u4e2a\u7a00\u758f\u77e9\u9635\uff0c\u5e76\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u3002\u8be5\u6570\u636e\u96c6\u4e2d\u6b63\u8d1f\u6837\u672c\u7684\u6bd4\u4f8b\u4e3a 1:1\uff0c\u56e0\u6b64\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u51c6\u786e\u7387\u4f5c\u4e3a\u8861\u91cf\u6807\u51c6\u3002\u6211\u4eec\u5c06\u4f7f\u7528 StratifiedKFold \u5e76\u521b\u5efa\u4e00\u4e2a\u811a\u672c\u6765\u8bad\u7ec35\u4e2a\u6298\u53e0\u3002\u4f60\u4f1a\u95ee\u4f7f\u7528\u54ea\u4e2a\u6a21\u578b\uff1f\u5bf9\u4e8e\u9ad8\u7ef4\u7a00\u758f\u6570\u636e\uff0c\u54ea\u4e2a\u6a21\u578b\u6700\u5feb\uff1f\u903b\u8f91\u56de\u5f52\u3002\u6211\u4eec\u5c06\u9996\u5148\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u6765\u5904\u7406\u8fd9\u4e2a\u6570\u636e\u96c6\uff0c\u5e76\u521b\u5efa\u7b2c\u4e00\u4e2a\u57fa\u51c6\u6a21\u578b\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import CountVectorizer if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df [ \"kfold\" ] = - 1 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f for fold_ in range ( 5 ): train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . reset_index ( drop = True ) count_vec = CountVectorizer ( tokenizer = word_tokenize , token_pattern = None ) count_vec . fit ( train_df . review ) xtrain = count_vec . transform ( train_df . 
review ) xtest = count_vec . transform ( test_df . review ) model = linear_model . LogisticRegression () model . fit ( xtrain , train_df . sentiment ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( test_df . sentiment , preds ) print ( f \"Fold: { fold_ } \" ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u8fd9\u6bb5\u4ee3\u7801\u7684\u8fd0\u884c\u9700\u8981\u4e00\u5b9a\u7684\u65f6\u95f4\uff0c\u4f46\u53ef\u4ee5\u5f97\u5230\u4ee5\u4e0b\u8f93\u51fa\u7ed3\u679c\uff1a Fold : 0 Accuracy = 0.8903 Fold : 1 Accuracy = 0.897 Fold : 2 Accuracy = 0.891 Fold : 3 Accuracy = 0.8914 Fold : 4 Accuracy = 0.8931 \u54c7\uff0c\u51c6\u786e\u7387\u5df2\u7ecf\u8fbe\u5230 89%\uff0c\u800c\u6211\u4eec\u6240\u505a\u7684\u53ea\u662f\u4f7f\u7528\u8bcd\u888b\u548c\u903b\u8f91\u56de\u5f52\uff01\u8fd9\u771f\u662f\u592a\u68d2\u4e86\uff01\u4e0d\u8fc7\uff0c\u8fd9\u4e2a\u6a21\u578b\u7684\u8bad\u7ec3\u82b1\u8d39\u4e86\u5f88\u591a\u65f6\u95f4\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u80fd\u5426\u901a\u8fc7\u4f7f\u7528\u6734\u7d20\u8d1d\u53f6\u65af\u5206\u7c7b\u5668\u6765\u7f29\u77ed\u8bad\u7ec3\u65f6\u95f4\u3002\u6734\u7d20\u8d1d\u53f6\u65af\u5206\u7c7b\u5668\u5728 NLP \u4efb\u52a1\u4e2d\u76f8\u5f53\u6d41\u884c\uff0c\u56e0\u4e3a\u7a00\u758f\u77e9\u9635\u975e\u5e38\u5e9e\u5927\uff0c\u800c\u6734\u7d20\u8d1d\u53f6\u65af\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6a21\u578b\u3002\u8981\u4f7f\u7528\u8fd9\u4e2a\u6a21\u578b\uff0c\u9700\u8981\u66f4\u6539\u4e00\u4e2a\u5bfc\u5165\u548c\u6a21\u578b\u7684\u884c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e2a\u6a21\u578b\u7684\u6027\u80fd\u5982\u4f55\u3002\u6211\u4eec\u5c06\u4f7f\u7528 scikit-learn \u4e2d\u7684 MultinomialNB\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import naive_bayes from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import CountVectorizer model = naive_bayes . MultinomialNB () model . fit ( xtrain , train_df . 
sentiment ) \u5f97\u5230\u5982\u4e0b\u7ed3\u679c\uff1a Fold : 0 Accuracy = 0.8444 Fold : 1 Accuracy = 0.8499 Fold : 2 Accuracy = 0.8422 Fold : 3 Accuracy = 0.8443 Fold : 4 Accuracy = 0.8455 \u6211\u4eec\u770b\u5230\u8fd9\u4e2a\u5206\u6570\u5f88\u4f4e\u3002\u4f46\u6734\u7d20\u8d1d\u53f6\u65af\u6a21\u578b\u7684\u901f\u5ea6\u975e\u5e38\u5feb\u3002 NLP \u4e2d\u7684\u53e6\u4e00\u79cd\u65b9\u6cd5\u662f TF-IDF\uff0c\u5982\u4eca\u5927\u591a\u6570\u4eba\u90fd\u503e\u5411\u4e8e\u5ffd\u7565\u6216\u4e0d\u5c51\u4e8e\u4e86\u89e3\u8fd9\u79cd\u65b9\u6cd5\u3002TF \u662f\u672f\u8bed\u9891\u7387\uff0cIDF \u662f\u53cd\u5411\u6587\u6863\u9891\u7387\u3002\u4ece\u8fd9\u4e9b\u672f\u8bed\u6765\u770b\uff0c\u8fd9\u4f3c\u4e4e\u6709\u4e9b\u56f0\u96be\uff0c\u4f46\u901a\u8fc7 TF \u548c IDF \u7684\u8ba1\u7b97\u516c\u5f0f\uff0c\u4e8b\u60c5\u5c31\u4f1a\u53d8\u5f97\u5f88\u660e\u663e\u3002 $$ TF(t) = \\frac{Number\\ of\\ times\\ a\\ term\\ t\\ appears\\ in\\ a\\ document}{Total\\ number\\ of\\ terms\\ in \\ the\\ document} $$ \\[ IDF(t) = LOG\\left(\\frac{Total\\ number\\ of\\ documents}{Number\\ of\\ documents with\\ term\\ t\\ in\\ it}\\right) \\] \u672f\u8bed t \u7684 TF-IDF \u5b9a\u4e49\u4e3a\uff1a $$ TF-IDF(t) = TF(t) \\times IDF(t) $$ \u4e0e scikit-learn \u4e2d\u7684 CountVectorizer \u7c7b\u4f3c\uff0c\u6211\u4eec\u4e5f\u6709 TfidfVectorizer\u3002\u8ba9\u6211\u4eec\u8bd5\u7740\u50cf\u4f7f\u7528 CountVectorizer \u4e00\u6837\u4f7f\u7528\u5b83\u3002 from sklearn.feature_extraction.text import TfidfVectorizer from nltk.tokenize import word_tokenize corpus = [ \"hello, how are you?\" , \"im getting bored at home. And you? What do you think?\" , \"did you know about counts\" , \"let's see if this works!\" , \"YES!!!!\" ] tfv = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfv . fit ( corpus ) corpus_transformed = tfv . 
transform ( corpus ) print ( corpus_transformed ) \u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a ( 0 , 27 ) 0.2965698850220162 ( 0 , 16 ) 0.4428321995085722 ( 0 , 14 ) 0.4428321995085722 ( 0 , 7 ) 0.4428321995085722 ( 0 , 4 ) 0.35727423026525224 ( 0 , 2 ) 0.4428321995085722 ( 1 , 27 ) 0.35299699146792735 ( 1 , 24 ) 0.2635440111190765 ( 1 , 22 ) 0.2635440111190765 ( 1 , 18 ) 0.2635440111190765 ( 1 , 15 ) 0.2635440111190765 ( 1 , 13 ) 0.2635440111190765 ( 1 , 12 ) 0.2635440111190765 ( 1 , 9 ) 0.2635440111190765 ( 1 , 8 ) 0.2635440111190765 ( 1 , 6 ) 0.2635440111190765 ( 1 , 4 ) 0.42525129752567803 ( 1 , 3 ) 0.2635440111190765 ( 2 , 27 ) 0.31752680284846835 ( 2 , 19 ) 0.4741246485558491 ( 2 , 11 ) 0.4741246485558491 ( 2 , 10 ) 0.4741246485558491 ( 2 , 5 ) 0.4741246485558491 ( 3 , 25 ) 0.38775666010579296 ( 3 , 23 ) 0.38775666010579296 ( 3 , 21 ) 0.38775666010579296 ( 3 , 20 ) 0.38775666010579296 ( 3 , 17 ) 0.38775666010579296 ( 3 , 1 ) 0.38775666010579296 ( 3 , 0 ) 0.3128396318588854 ( 4 , 26 ) 0.2959842226518677 ( 4 , 0 ) 0.9551928286692534 \u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u6b21\u6211\u4eec\u5f97\u5230\u7684\u4e0d\u662f\u6574\u6570\u503c\uff0c\u800c\u662f\u6d6e\u70b9\u6570\u3002 \u7528 TfidfVectorizer \u4ee3\u66ff CountVectorizer \u4e5f\u662f\u5c0f\u83dc\u4e00\u789f\u3002Scikit-learn \u8fd8\u63d0\u4f9b\u4e86 TfidfTransformer\u3002\u5982\u679c\u4f60\u4f7f\u7528\u7684\u662f\u8ba1\u6570\u503c\uff0c\u53ef\u4ee5\u4f7f\u7528 TfidfTransformer \u5e76\u83b7\u5f97\u4e0e TfidfVectorizer \u76f8\u540c\u7684\u6548\u679c\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import TfidfVectorizer for fold_ in range ( 5 ): train_df = df [ df . kfold != fold_ ] . reset_index ( drop = True ) test_df = df [ df . kfold == fold_ ] . 
reset_index ( drop = True ) tfidf_vec = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfidf_vec . fit ( train_df . review ) xtrain = tfidf_vec . transform ( train_df . review ) xtest = tfidf_vec . transform ( test_df . review ) model = linear_model . LogisticRegression () model . fit ( xtrain , train_df . sentiment ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( test_df . sentiment , preds ) print ( f \"Fold: { fold_ } \" ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u6211\u4eec\u53ef\u4ee5\u770b\u770b TF-IDF \u5728\u903b\u8f91\u56de\u5f52\u6a21\u578b\u4e0a\u7684\u8868\u73b0\u5982\u4f55\u3002 Fold : 0 Accuracy = 0.8976 Fold : 1 Accuracy = 0.8998 Fold : 2 Accuracy = 0.8948 Fold : 3 Accuracy = 0.8912 Fold : 4 Accuracy = 0.8995 \u6211\u4eec\u770b\u5230\uff0c\u8fd9\u4e9b\u5206\u6570\u90fd\u6bd4 CountVectorizer \u9ad8\u4e00\u4e9b\uff0c\u56e0\u6b64\u5b83\u6210\u4e3a\u4e86\u6211\u4eec\u60f3\u8981\u51fb\u8d25\u7684\u65b0\u57fa\u51c6\u3002 NLP \u4e2d\u53e6\u4e00\u4e2a\u6709\u8da3\u7684\u6982\u5ff5\u662f N-gram\u3002N-grams \u662f\u6309\u987a\u5e8f\u6392\u5217\u7684\u5355\u8bcd\u7ec4\u5408\u3002N-grams \u5f88\u5bb9\u6613\u521b\u5efa\u3002\u60a8\u53ea\u9700\u6ce8\u610f\u987a\u5e8f\u5373\u53ef\u3002\u4e3a\u4e86\u8ba9\u4e8b\u60c5\u53d8\u5f97\u66f4\u7b80\u5355\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 NLTK \u7684 N-gram \u5b9e\u73b0\u3002 from nltk import ngrams from nltk.tokenize import word_tokenize N = 3 sentence = \"hi, how are you?\" tokenized_sentence = word_tokenize ( sentence ) n_grams = list ( ngrams ( tokenized_sentence , N )) print ( n_grams ) \u7531\u6b64\u5f97\u5230\uff1a [( 'hi' , ',' , 'how' ), ( ',' , 'how' , 'are' ), ( 'how' , 'are' , 'you' ), ( 'are' , 'you' , '?' 
)] \u540c\u6837\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u521b\u5efa 2-gram \u6216 4-gram \u7b49\u3002\u73b0\u5728\uff0c\u8fd9\u4e9b n-gram \u5c06\u6210\u4e3a\u6211\u4eec\u8bcd\u6c47\u8868\u7684\u4e00\u90e8\u5206\uff0c\u5f53\u6211\u4eec\u8ba1\u7b97\u8ba1\u6570\u6216 tf-idf \u65f6\uff0c\u6211\u4eec\u4f1a\u5c06\u4e00\u4e2a n-gram \u89c6\u4e3a\u4e00\u4e2a\u5168\u65b0\u7684\u6807\u8bb0\u3002\u56e0\u6b64\uff0c\u5728\u67d0\u79cd\u7a0b\u5ea6\u4e0a\uff0c\u6211\u4eec\u662f\u5728\u7ed3\u5408\u4e0a\u4e0b\u6587\u3002scikit-learn \u7684 CountVectorizer \u548c TfidfVectorizer \u5b9e\u73b0\u90fd\u901a\u8fc7 ngram_range \u53c2\u6570\u63d0\u4f9b n-gram\uff0c\u8be5\u53c2\u6570\u6709\u6700\u5c0f\u548c\u6700\u5927\u9650\u5236\u3002\u9ed8\u8ba4\u60c5\u51b5\u4e0b\uff0c\u8be5\u53c2\u6570\u4e3a\uff081, 1\uff09\u3002\u5f53\u6211\u4eec\u5c06\u5176\u6539\u4e3a (1, 3) \u65f6\uff0c\u6211\u4eec\u5c06\u770b\u5230\u5355\u5b57\u5143\u3001\u53cc\u5b57\u5143\u548c\u4e09\u5b57\u5143\u3002\u4ee3\u7801\u6539\u52a8\u5f88\u5c0f\u3002 \u7531\u4e8e\u5230\u76ee\u524d\u4e3a\u6b62\u6211\u4eec\u4f7f\u7528 tf-idf \u5f97\u5230\u4e86\u6700\u597d\u7684\u7ed3\u679c\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5305\u542b n-grams \u76f4\u81f3 trigrams \u662f\u5426\u80fd\u6539\u8fdb\u6a21\u578b\u3002\u552f\u4e00\u9700\u8981\u4fee\u6539\u7684\u662f TfidfVectorizer \u7684\u521d\u59cb\u5316\u3002 tfidf_vec = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None , ngram_range = ( 1 , 3 ) ) \u8ba9\u6211\u4eec\u770b\u770b\u662f\u5426\u4f1a\u6709\u6539\u8fdb\u3002 Fold : 0 Accuracy = 0.8931 Fold : 1 Accuracy = 0.8941 Fold : 2 Accuracy = 0.897 Fold : 3 Accuracy = 0.8922 Fold : 4 Accuracy = 0.8847 \u770b\u8d77\u6765\u8fd8\u884c\uff0c\u4f46\u6211\u4eec\u770b\u4e0d\u5230\u4efb\u4f55\u6539\u8fdb\u3002 \u4e5f\u8bb8\u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u591a\u4f7f\u7528 bigrams \u6765\u83b7\u5f97\u6539\u8fdb\u3002 
\u6211\u4e0d\u4f1a\u5728\u8fd9\u91cc\u5c55\u793a\u8fd9\u4e00\u90e8\u5206\u3002\u4e5f\u8bb8\u4f60\u53ef\u4ee5\u81ea\u5df1\u8bd5\u7740\u505a\u3002 NLP \u7684\u57fa\u7840\u77e5\u8bc6\u8fd8\u6709\u5f88\u591a\u3002\u4f60\u5fc5\u987b\u77e5\u9053\u7684\u4e00\u4e2a\u672f\u8bed\u662f\u8bcd\u5e72\u63d0\u53d6\uff08strmming\uff09\u3002\u53e6\u4e00\u4e2a\u662f\u8bcd\u5f62\u8fd8\u539f\uff08lemmatization\uff09\u3002 \u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f \u53ef\u4ee5\u5c06\u4e00\u4e2a\u8bcd\u51cf\u5c11\u5230\u6700\u5c0f\u5f62\u5f0f\u3002\u5728\u8bcd\u5e72\u63d0\u53d6\u7684\u60c5\u51b5\u4e0b\uff0c\u5904\u7406\u540e\u7684\u5355\u8bcd\u79f0\u4e3a\u8bcd\u5e72\u5355\u8bcd\uff0c\u800c\u5728\u8bcd\u5f62\u8fd8\u539f\u60c5\u51b5\u4e0b\uff0c\u5904\u7406\u540e\u7684\u5355\u8bcd\u79f0\u4e3a\u8bcd\u5f62\u3002\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u8bcd\u5f62\u8fd8\u539f\u6bd4\u8bcd\u5e72\u63d0\u53d6\u66f4\u6fc0\u8fdb\uff0c\u800c\u8bcd\u5e72\u63d0\u53d6\u66f4\u6d41\u884c\u548c\u5e7f\u6cdb\u3002\u8bcd\u5e72\u548c\u8bcd\u5f62\u90fd\u6765\u81ea\u8bed\u8a00\u5b66\u3002\u5982\u679c\u4f60\u6253\u7b97\u4e3a\u67d0\u79cd\u8bed\u8a00\u5236\u4f5c\u8bcd\u5e72\u6216\u8bcd\u578b\uff0c\u9700\u8981\u5bf9\u8be5\u8bed\u8a00\u6709\u6df1\u5165\u7684\u4e86\u89e3\u3002\u5982\u679c\u8981\u8fc7\u591a\u5730\u4ecb\u7ecd\u8fd9\u4e9b\u77e5\u8bc6\uff0c\u5c31\u610f\u5473\u7740\u8981\u5728\u672c\u4e66\u4e2d\u589e\u52a0\u4e00\u7ae0\u3002\u4f7f\u7528 NLTK \u8f6f\u4ef6\u5305\u53ef\u4ee5\u8f7b\u677e\u5b8c\u6210\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e24\u79cd\u65b9\u6cd5\u7684\u4e00\u4e9b\u793a\u4f8b\u3002\u6709\u8bb8\u591a\u4e0d\u540c\u7c7b\u578b\u7684\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u5668\u3002\u6211\u5c06\u7528\u6700\u5e38\u89c1\u7684 Snowball Stemmer \u548c WordNet Lemmatizer \u6765\u4e3e\u4f8b\u8bf4\u660e\u3002 from nltk.stem import WordNetLemmatizer from nltk.stem.snowball import SnowballStemmer lemmatizer = 
WordNetLemmatizer () stemmer = SnowballStemmer ( \"english\" ) words = [ \"fishing\" , \"fishes\" , \"fished\" ] for word in words : print ( f \"word= { word } \" ) print ( f \"stemmed_word= { stemmer . stem ( word ) } \" ) print ( f \"lemma= { lemmatizer . lemmatize ( word ) } \" ) print ( \"\" ) \u8fd9\u5c06\u6253\u5370\uff1a word = fishing stemmed_word = fish lemma = fishing word = fishes stemmed_word = fish lemma = fish word = fished stemmed_word = fish lemma = fished \u6b63\u5982\u60a8\u6240\u770b\u5230\u7684\uff0c\u8bcd\u5e72\u63d0\u53d6\u548c\u8bcd\u5f62\u8fd8\u539f\u662f\u622a\u7136\u4e0d\u540c\u7684\u3002\u5f53\u6211\u4eec\u8fdb\u884c\u8bcd\u5e72\u63d0\u53d6\u65f6\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f\u4e00\u4e2a\u8bcd\u7684\u6700\u5c0f\u5f62\u5f0f\uff0c\u5b83\u53ef\u80fd\u662f\u4e5f\u53ef\u80fd\u4e0d\u662f\u8be5\u8bcd\u6240\u5c5e\u8bed\u8a00\u8bcd\u5178\u4e2d\u7684\u4e00\u4e2a\u8bcd\u3002\u4f46\u662f\uff0c\u5728\u8bcd\u5f62\u8fd8\u539f\u60c5\u51b5\u4e0b\uff0c\u8fd9\u5c06\u662f\u4e00\u4e2a\u8bcd\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u81ea\u5df1\u5c1d\u8bd5\u6dfb\u52a0\u8bcd\u5e72\u548c\u8bcd\u7d20\u5316\uff0c\u770b\u770b\u662f\u5426\u80fd\u6539\u5584\u7ed3\u679c\u3002 \u60a8\u8fd8\u5e94\u8be5\u4e86\u89e3\u7684\u4e00\u4e2a\u4e3b\u9898\u662f\u4e3b\u9898\u63d0\u53d6\u3002 \u4e3b\u9898\u63d0\u53d6 \u53ef\u4ee5\u4f7f\u7528\u975e\u8d1f\u77e9\u9635\u56e0\u5f0f\u5206\u89e3\uff08NMF\uff09\u6216\u6f5c\u5728\u8bed\u4e49\u5206\u6790\uff08LSA\uff09\u6765\u5b8c\u6210\uff0c\u540e\u8005\u4e5f\u88ab\u79f0\u4e3a\u5947\u5f02\u503c\u5206\u89e3\u6216 SVD\u3002\u8fd9\u4e9b\u5206\u89e3\u6280\u672f\u53ef\u5c06\u6570\u636e\u7b80\u5316\u4e3a\u7ed9\u5b9a\u6570\u91cf\u7684\u6210\u5206\u3002 \u60a8\u53ef\u4ee5\u5728\u4ece CountVectorizer \u6216 TfidfVectorizer \u4e2d\u83b7\u5f97\u7684\u7a00\u758f\u77e9\u9635\u4e0a\u5e94\u7528\u5176\u4e2d\u4efb\u4f55\u4e00\u79cd\u6280\u672f\u3002 \u8ba9\u6211\u4eec\u628a\u5b83\u5e94\u7528\u5230\u4e4b\u524d\u4f7f\u7528\u8fc7\u7684 TfidfVetorizer 
\u4e0a\u3002 import pandas as pd from nltk.tokenize import word_tokenize from sklearn import decomposition from sklearn.feature_extraction.text import TfidfVectorizer corpus = pd . read_csv ( \"../input/imdb.csv\" , nrows = 10000 ) corpus = corpus . review . values tfv = TfidfVectorizer ( tokenizer = word_tokenize , token_pattern = None ) tfv . fit ( corpus ) corpus_transformed = tfv . transform ( corpus ) svd = decomposition . TruncatedSVD ( n_components = 10 ) corpus_svd = svd . fit ( corpus_transformed ) sample_index = 0 feature_scores = dict ( zip ( tfv . get_feature_names (), corpus_svd . components_ [ sample_index ] ) ) N = 5 print ( sorted ( feature_scores , key = feature_scores . get , reverse = True )[: N ]) \u60a8\u53ef\u4ee5\u4f7f\u7528\u5faa\u73af\u6765\u8fd0\u884c\u591a\u4e2a\u6837\u672c\u3002 N = 5 for sample_index in range ( 5 ): feature_scores = dict ( zip ( tfv . get_feature_names (), corpus_svd . components_ [ sample_index ] ) ) print ( sorted ( feature_scores , key = feature_scores . get , reverse = True )[: N ] ) \u8f93\u51fa\u7ed3\u679c\u5982\u4e0b\uff1a [ 'the' , ',' , '.' , 'a' , 'and' ] [ 'br' , '<' , '>' , '/' , '-' ] [ 'i' , 'movie' , '!' , 'it' , 'was' ] [ ',' , '!' , \"''\" , '``' , 'you' ] [ '!' , 'the' , '...' , \"''\" , '``' ] \u4f60\u53ef\u4ee5\u770b\u5230\uff0c\u8fd9\u6839\u672c\u8bf4\u4e0d\u901a\u3002\u600e\u4e48\u529e\u5462\uff1f\u8ba9\u6211\u4eec\u8bd5\u7740\u6e05\u7406\u4e00\u4e0b\uff0c\u770b\u770b\u662f\u5426\u6709\u610f\u4e49\u3002\u8981\u6e05\u7406\u4efb\u4f55\u6587\u672c\u6570\u636e\uff0c\u5c24\u5176\u662f pandas \u6570\u636e\u5e27\u4e2d\u7684\u6587\u672c\u6570\u636e\uff0c\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u51fd\u6570\u3002 import re import string def clean_text ( s ): s = s . split () s = \" \" . join ( s ) s = re . sub ( f '[ { re . escape ( string . 
punctuation ) } ]' , '' , s ) return s \u8be5\u51fd\u6570\u4f1a\u5c06 \"hi, how are you????\" \u8fd9\u6837\u7684\u5b57\u7b26\u4e32\u8f6c\u6362\u4e3a \"hi how are you\"\u3002\u8ba9\u6211\u4eec\u628a\u8fd9\u4e2a\u51fd\u6570\u5e94\u7528\u5230\u65e7\u7684 SVD \u4ee3\u7801\u4e2d\uff0c\u770b\u770b\u5b83\u662f\u5426\u80fd\u7ed9\u63d0\u53d6\u7684\u4e3b\u9898\u5e26\u6765\u63d0\u5347\u3002\u4f7f\u7528 pandas\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528 apply \u51fd\u6570\u5c06\u6e05\u7406\u4ee3\u7801 \"\u5e94\u7528 \"\u5230\u4efb\u610f\u7ed9\u5b9a\u7684\u5217\u4e2d\u3002 import pandas as pd corpus = pd . read_csv ( \"../input/imdb.csv\" , nrows = 10000 ) corpus . loc [:, \"review\" ] = corpus . review . apply ( clean_text ) \u8bf7\u6ce8\u610f\uff0c\u6211\u4eec\u53ea\u5728\u4e3b SVD \u811a\u672c\u4e2d\u6dfb\u52a0\u4e86\u4e00\u884c\u4ee3\u7801\uff0c\u8fd9\u5c31\u662f\u4f7f\u7528\u51fd\u6570\u548c pandas \u5e94\u7528\u7684\u597d\u5904\u3002\u8fd9\u6b21\u751f\u6210\u7684\u4e3b\u9898\u5982\u4e0b\u3002 [ 'the' , 'a' , 'and' , 'of' , 'to' ] [ 'i' , 'movie' , 'it' , 'was' , 'this' ] [ 'the' , 'was' , 'i' , 'were' , 'of' ] [ 'her' , 'was' , 'she' , 'i' , 'he' ] [ 'br' , 'to' , 'they' , 'he' , 'show' ] \u547c\uff01\u81f3\u5c11\u8fd9\u6bd4\u6211\u4eec\u4e4b\u524d\u597d\u591a\u4e86\u3002\u4f46\u4f60\u77e5\u9053\u5417\uff1f\u4f60\u53ef\u4ee5\u901a\u8fc7\u5728\u6e05\u7406\u529f\u80fd\u4e2d\u5220\u9664\u505c\u6b62\u8bcd\uff08stopwords\uff09\u6765\u4f7f\u5b83\u53d8\u5f97\u66f4\u597d\u3002\u4ec0\u4e48\u662fstopwords\uff1f\u5b83\u4eec\u662f\u5b58\u5728\u4e8e\u6bcf\u79cd\u8bed\u8a00\u4e2d\u7684\u9ad8\u9891\u8bcd\u3002\u4f8b\u5982\uff0c\u5728\u82f1\u8bed\u4e2d\uff0c\u8fd9\u4e9b\u8bcd\u5305\u62ec \"a\"\u3001\"an\"\u3001\"the\"\u3001\"for \"\u7b49\u3002\u5220\u9664\u505c\u6b62\u8bcd\u5e76\u975e\u603b\u662f\u660e\u667a\u7684\u9009\u62e9\uff0c\u8fd9\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u4e1a\u52a1\u95ee\u9898\u3002\u50cf \"I need a new 
dog\"\u8fd9\u6837\u7684\u53e5\u5b50\uff0c\u53bb\u6389\u505c\u6b62\u8bcd\u540e\u4f1a\u53d8\u6210 \"need new dog\"\uff0c\u6b64\u65f6\u6211\u4eec\u4e0d\u77e5\u9053\u8c01\u9700\u8981new dog\u3002 \u5982\u679c\u6211\u4eec\u603b\u662f\u5220\u9664\u505c\u6b62\u8bcd\uff0c\u5c31\u4f1a\u4e22\u5931\u5f88\u591a\u4e0a\u4e0b\u6587\u4fe1\u606f\u3002\u4f60\u53ef\u4ee5\u5728 NLTK \u4e2d\u627e\u5230\u8bb8\u591a\u8bed\u8a00\u7684\u505c\u6b62\u8bcd\uff0c\u5982\u679c\u6ca1\u6709\uff0c\u4f60\u4e5f\u53ef\u4ee5\u5728\u81ea\u5df1\u559c\u6b22\u7684\u641c\u7d22\u5f15\u64ce\u4e0a\u5feb\u901f\u641c\u7d22\u4e00\u4e0b\u3002 \u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u8f6c\u5230\u5927\u591a\u6570\u4eba\u90fd\u559c\u6b22\u4f7f\u7528\u7684\u65b9\u6cd5\uff1a\u6df1\u5ea6\u5b66\u4e60\u3002\u4f46\u9996\u5148\uff0c\u6211\u4eec\u5fc5\u987b\u77e5\u9053\u4ec0\u4e48\u662f\u8bcd\u5d4c\u5165\uff08embedings for words\uff09\u3002\u4f60\u5df2\u7ecf\u770b\u5230\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u5c06\u6807\u8bb0\u8f6c\u6362\u6210\u4e86\u6570\u5b57\u3002\u56e0\u6b64\uff0c\u5982\u679c\u67d0\u4e2a\u8bed\u6599\u5e93\u4e2d\u6709 N \u4e2a\u552f\u4e00\u7684\u8bcd\u5757\uff0c\u5b83\u4eec\u53ef\u4ee5\u7528 0 \u5230 N-1 \u4e4b\u95f4\u7684\u6574\u6570\u6765\u8868\u793a\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5c06\u7528\u5411\u91cf\u6765\u8868\u793a\u8fd9\u4e9b\u6574\u6570\u8bcd\u5757\u3002\u8fd9\u79cd\u5c06\u5355\u8bcd\u8868\u793a\u6210\u5411\u91cf\u7684\u65b9\u6cd5\u88ab\u79f0\u4e3a\u5355\u8bcd\u5d4c\u5165\u6216\u5355\u8bcd\u5411\u91cf\u3002\u8c37\u6b4c\u7684 Word2Vec \u662f\u5c06\u5355\u8bcd\u8f6c\u6362\u4e3a\u5411\u91cf\u7684\u6700\u53e4\u8001\u65b9\u6cd5\u4e4b\u4e00\u3002\u6b64\u5916\uff0c\u8fd8\u6709 Facebook \u7684 FastText \u548c\u65af\u5766\u798f\u5927\u5b66\u7684 GloVe\uff08\u7528\u4e8e\u5355\u8bcd\u8868\u793a\u7684\u5168\u5c40\u5411\u91cf\uff09\u3002\u8fd9\u4e9b\u65b9\u6cd5\u5f7c\u6b64\u5927\u76f8\u5f84\u5ead\u3002 
\u5176\u57fa\u672c\u601d\u60f3\u662f\u5efa\u7acb\u4e00\u4e2a\u6d45\u5c42\u7f51\u7edc\uff0c\u901a\u8fc7\u91cd\u6784\u8f93\u5165\u53e5\u5b50\u6765\u5b66\u4e60\u5355\u8bcd\u7684\u5d4c\u5165\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u5468\u56f4\u7684\u6240\u6709\u5355\u8bcd\u6765\u8bad\u7ec3\u7f51\u7edc\u9884\u6d4b\u4e00\u4e2a\u7f3a\u5931\u7684\u5355\u8bcd\uff0c\u5728\u6b64\u8fc7\u7a0b\u4e2d\uff0c\u7f51\u7edc\u5c06\u5b66\u4e60\u5e76\u66f4\u65b0\u6240\u6709\u76f8\u5173\u5355\u8bcd\u7684\u5d4c\u5165\u3002\u8fd9\u79cd\u65b9\u6cd5\u4e5f\u88ab\u79f0\u4e3a\u8fde\u7eed\u8bcd\u888b\u6216 CBoW \u6a21\u578b\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u4e00\u4e2a\u5355\u8bcd\u6765\u9884\u6d4b\u4e0a\u4e0b\u6587\u4e2d\u7684\u5355\u8bcd\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684\u8df3\u683c\u6a21\u578b\u3002Word2Vec \u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e24\u79cd\u65b9\u6cd5\u5b66\u4e60\u5d4c\u5165\u3002 FastText \u53ef\u4ee5\u5b66\u4e60\u5b57\u7b26 n-gram \u7684\u5d4c\u5165\u3002\u548c\u5355\u8bcd n-gram \u4e00\u6837\uff0c\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7684\u662f\u5b57\u7b26\uff0c\u5219\u79f0\u4e3a\u5b57\u7b26 n-gram\uff0c\u6700\u540e\uff0cGloVe \u901a\u8fc7\u5171\u73b0\u77e9\u9635\u6765\u5b66\u4e60\u8fd9\u4e9b\u5d4c\u5165\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u8bf4\uff0c\u6240\u6709\u8fd9\u4e9b\u4e0d\u540c\u7c7b\u578b\u7684\u5d4c\u5165\u6700\u7ec8\u90fd\u4f1a\u8fd4\u56de\u4e00\u4e2a\u5b57\u5178\uff0c\u5176\u4e2d\u952e\u662f\u8bed\u6599\u5e93\uff08\u4f8b\u5982\u82f1\u8bed\u7ef4\u57fa\u767e\u79d1\uff09\u4e2d\u7684\u5355\u8bcd\uff0c\u503c\u662f\u5927\u5c0f\u4e3a N\uff08\u901a\u5e38\u4e3a 300\uff09\u7684\u5411\u91cf\u3002 \u56fe 1\uff1a\u53ef\u89c6\u5316\u4e8c\u7ef4\u5355\u8bcd\u5d4c\u5165\u3002 \u56fe 1 \u663e\u793a\u4e86\u4e8c\u7ef4\u5355\u8bcd\u5d4c\u5165\u7684\u53ef\u89c6\u5316\u6548\u679c\u3002\u5047\u8bbe\u6211\u4eec\u4ee5\u67d0\u79cd\u65b9\u5f0f\u5b8c\u6210\u4e86\u8bcd\u8bed\u7684\u4e8c\u7ef4\u8868\u793a\u3002\u56fe 1 
\u663e\u793a\uff0c\u5982\u679c\u4eceBerlin\uff08\u5fb7\u56fd\u9996\u90fd\uff09\u7684\u5411\u91cf\u4e2d\u51cf\u53bb\u5fb7\u56fd\uff08Germany\uff09\u7684\u5411\u91cf\uff0c\u518d\u52a0\u4e0a\u6cd5\u56fd\uff08france\uff09\u7684\u5411\u91cf\uff0c\u5c31\u4f1a\u5f97\u5230\u4e00\u4e2a\u63a5\u8fd1Paris\uff08\u6cd5\u56fd\u9996\u90fd\uff09\u7684\u5411\u91cf\u3002\u7531\u6b64\u53ef\u89c1\uff0c\u5d4c\u5165\u5f0f\u4e5f\u80fd\u8fdb\u884c\u7c7b\u6bd4\u3002 \u8fd9\u5e76\u4e0d\u603b\u662f\u6b63\u786e\u7684\uff0c\u4f46\u8fd9\u6837\u7684\u4f8b\u5b50\u6709\u52a9\u4e8e\u7406\u89e3\u5355\u8bcd\u5d4c\u5165\u7684\u4f5c\u7528\u3002\u50cf \"\u55e8\uff0c\u4f60\u597d\u5417 \"\u8fd9\u6837\u7684\u53e5\u5b50\u53ef\u4ee5\u7528\u4e0b\u9762\u7684\u4e00\u5806\u5411\u91cf\u6765\u8868\u793a\u3002 hi \u2500> [vector (v1) of size 300] , \u2500> [vector (v2) of size 300] how \u2500> [vector (v3) of size 300] are \u2500> [vector (v4) of size 300] you \u2500> [vector (v5) of size 300] ? \u2500> [vector (v6) of size 300] \u4f7f\u7528\u8fd9\u4e9b\u4fe1\u606f\u6709\u591a\u79cd\u65b9\u6cd5\u3002\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u4e4b\u4e00\u5c31\u662f\u4f7f\u7528\u5d4c\u5165\u5411\u91cf\u3002\u5982\u4e0a\u4f8b\u6240\u793a\uff0c\u6bcf\u4e2a\u5355\u8bcd\u90fd\u6709\u4e00\u4e2a 1x300 \u7684\u5d4c\u5165\u5411\u91cf\u3002\u5229\u7528\u8fd9\u4e9b\u4fe1\u606f\uff0c\u6211\u4eec\u53ef\u4ee5\u8ba1\u7b97\u51fa\u6574\u4e2a\u53e5\u5b50\u7684\u5d4c\u5165\u3002\u8ba1\u7b97\u65b9\u6cd5\u6709\u591a\u79cd\u3002\u5176\u4e2d\u4e00\u79cd\u65b9\u6cd5\u5982\u4e0b\u6240\u793a\u3002\u5728\u8fd9\u4e2a\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u5c06\u7ed9\u5b9a\u53e5\u5b50\u4e2d\u7684\u6240\u6709\u5355\u8bcd\u5411\u91cf\u63d0\u53d6\u51fa\u6765\uff0c\u7136\u540e\u4ece\u6240\u6709\u6807\u8bb0\u8bcd\u7684\u5355\u8bcd\u5411\u91cf\u4e2d\u521b\u5efa\u4e00\u4e2a\u5f52\u4e00\u5316\u7684\u5355\u8bcd\u5411\u91cf\u3002\u8fd9\u6837\u5c31\u5f97\u5230\u4e86\u4e00\u4e2a\u53e5\u5b50\u5411\u91cf\u3002 import numpy as np def sentence_to_vec ( s , 
embedding_dict , stop_words , tokenizer ): words = str ( s ) . lower () words = tokenizer ( words ) words = [ w for w in words if not w in stop_words ] words = [ w for w in words if w . isalpha ()] M = [] for w in words : if w in embedding_dict : M . append ( embedding_dict [ w ]) if len ( M ) == 0 : return np . zeros ( 300 ) M = np . array ( M ) v = M . sum ( axis = 0 ) return v / np . sqrt (( v ** 2 ) . sum ()) \u6211\u4eec\u53ef\u4ee5\u7528\u8fd9\u79cd\u65b9\u6cd5\u5c06\u6240\u6709\u793a\u4f8b\u8f6c\u6362\u6210\u4e00\u4e2a\u5411\u91cf\u3002\u6211\u4eec\u80fd\u5426\u4f7f\u7528 fastText \u5411\u91cf\u6765\u6539\u8fdb\u4e4b\u524d\u7684\u7ed3\u679c\uff1f\u6bcf\u7bc7\u8bc4\u8bba\u90fd\u6709 300 \u4e2a\u7279\u5f81\u3002 import io import numpy as np import pandas as pd from nltk.tokenize import word_tokenize from sklearn import linear_model from sklearn import metrics from sklearn import model_selection from sklearn.feature_extraction.text import TfidfVectorizer def load_vectors ( fname ): fin = io . open ( fname , 'r' , encoding = 'utf-8' , newline = ' \\n ' , errors = 'ignore' ) n , d = map ( int , fin . readline () . split ()) data = {} for line in fin : tokens = line . rstrip () . split ( ' ' ) data [ tokens [ 0 ]] = list ( map ( float , tokens [ 1 :])) return data def sentence_to_vec ( s , embedding_dict , stop_words , tokenizer ): if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df = df . sample ( frac = 1 ) . reset_index ( drop = True ) print ( \"Loading embeddings\" ) embeddings = load_vectors ( \"../input/crawl-300d-2M.vec\" ) print ( \"Creating sentence vectors\" ) vectors = [] for review in df . review . values : vectors . append ( sentence_to_vec ( s = review , embedding_dict = embeddings , stop_words = [], tokenizer = word_tokenize ) ) vectors = np . array ( vectors ) y = df . sentiment . values kf = model_selection . 
StratifiedKFold ( n_splits = 5 ) for fold_ , ( t_ , v_ ) in enumerate ( kf . split ( X = vectors , y = y )): print ( f \"Training fold: { fold_ } \" ) xtrain = vectors [ t_ , :] ytrain = y [ t_ ] xtest = vectors [ v_ , :] ytest = y [ v_ ] model = linear_model . LogisticRegression () model . fit ( xtrain , ytrain ) preds = model . predict ( xtest ) accuracy = metrics . accuracy_score ( ytest , preds ) print ( f \"Accuracy = { accuracy } \" ) print ( \"\" ) \u8fd9\u5c06\u5f97\u5230\u5982\u4e0b\u7ed3\u679c\uff1a Loading embeddings Creating sentence vectors Training fold : 0 Accuracy = 0.8619 Training fold : 1 Accuracy = 0.8661 Training fold : 2 Accuracy = 0.8544 Training fold : 3 Accuracy = 0.8624 Training fold : 4 Accuracy = 0.8595 Wow\uff01\u771f\u662f\u51fa\u4e4e\u610f\u6599\u3002\u6211\u4eec\u6240\u505a\u7684\u4e00\u5207\u90fd\u662f\u4e3a\u4e86\u4f7f\u7528 FastText \u5d4c\u5165\u3002\u8bd5\u7740\u628a\u5d4c\u5165\u5f0f\u6362\u6210 GloVe\uff0c\u770b\u770b\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002\u6211\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u5927\u5bb6\u3002 \u5f53\u6211\u4eec\u8c08\u8bba\u6587\u672c\u6570\u636e\u65f6\uff0c\u6211\u4eec\u5fc5\u987b\u7262\u8bb0\u4e00\u4ef6\u4e8b\u3002\u6587\u672c\u6570\u636e\u4e0e\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u975e\u5e38\u76f8\u4f3c\u3002\u5982\u56fe 2 \u6240\u793a\uff0c\u6211\u4eec\u8bc4\u8bba\u4e2d\u7684\u4efb\u4f55\u6837\u672c\u90fd\u662f\u5728\u4e0d\u540c\u65f6\u95f4\u6233\u4e0a\u6309\u9012\u589e\u987a\u5e8f\u6392\u5217\u7684\u6807\u8bb0\u5e8f\u5217\uff0c\u6bcf\u4e2a\u6807\u8bb0\u90fd\u53ef\u4ee5\u8868\u793a\u4e3a\u4e00\u4e2a\u5411\u91cf/\u5d4c\u5165\u3002 \u56fe 2\uff1a\u5c06\u6807\u8bb0\u8868\u793a\u4e3a\u5d4c\u5165\uff0c\u5e76\u5c06\u5176\u89c6\u4e3a\u65f6\u95f4\u5e8f\u5217 
\u8fd9\u610f\u5473\u7740\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5e7f\u6cdb\u7528\u4e8e\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u7684\u6a21\u578b\uff0c\u4f8b\u5982\u957f\u77ed\u671f\u8bb0\u5fc6\uff08LSTM\uff09\u6216\u95e8\u63a7\u9012\u5f52\u5355\u5143\uff08GRU\uff09\uff0c\u751a\u81f3\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff08CNN\uff09\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u5728\u8be5\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u4e00\u4e2a\u7b80\u5355\u7684\u53cc\u5411 LSTM \u6a21\u578b\u3002 \u9996\u5148\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u9879\u76ee\u3002\u4f60\u53ef\u4ee5\u968f\u610f\u7ed9\u5b83\u547d\u540d\u3002\u7136\u540e\uff0c\u6211\u4eec\u7684\u7b2c\u4e00\u6b65\u5c06\u662f\u5206\u5272\u6570\u636e\u8fdb\u884c\u4ea4\u53c9\u9a8c\u8bc1\u3002 import pandas as pd from sklearn import model_selection if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/imdb.csv\" ) df . sentiment = df . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df [ \"kfold\" ] = - 1 df = df . sample ( frac = 1 ) . reset_index ( drop = True ) y = df . sentiment . values kf = model_selection . StratifiedKFold ( n_splits = 5 ) for f , ( t_ , v_ ) in enumerate ( kf . split ( X = df , y = y )): df . loc [ v_ , 'kfold' ] = f df . to_csv ( \"../input/imdb_folds.csv\" , index = False ) \u5c06\u6570\u636e\u96c6\u5212\u5206\u4e3a\u591a\u4e2a\u6298\u53e0\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5728 dataset.py \u4e2d\u521b\u5efa\u4e00\u4e2a\u7b80\u5355\u7684\u6570\u636e\u96c6\u7c7b\u3002\u6570\u636e\u96c6\u7c7b\u4f1a\u8fd4\u56de\u4e00\u4e2a\u8bad\u7ec3\u6216\u9a8c\u8bc1\u6570\u636e\u6837\u672c\u3002 import torch class IMDBDataset : def __init__ ( self , reviews , targets ): self . reviews = reviews self . target = targets def __len__ ( self ): return len ( self . reviews ) def __getitem__ ( self , item ): review = self . reviews [ item , :] target = self . target [ item ] return { \"review\" : torch . tensor ( review , dtype = torch . long ), \"target\" : torch . 
tensor ( target , dtype = torch . float ) } \u5b8c\u6210\u6570\u636e\u96c6\u5206\u7c7b\u540e\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u521b\u5efa lstm.py\uff0c\u5176\u4e2d\u5305\u542b\u6211\u4eec\u7684 LSTM \u6a21\u578b import torch import torch.nn as nn class LSTM ( nn . Module ): def __init__ ( self , embedding_matrix ): super ( LSTM , self ) . __init__ () num_words = embedding_matrix . shape [ 0 ] embed_dim = embedding_matrix . shape [ 1 ] self . embedding = nn . Embedding ( num_embeddings = num_words , embedding_dim = embed_dim ) self . embedding . weight = nn . Parameter ( torch . tensor ( embedding_matrix , dtype = torch . float32 ) ) self . embedding . weight . requires_grad = False self . lstm = nn . LSTM ( embed_dim , 128 , bidirectional = True , batch_first = True , ) self . out = nn . Linear ( 512 , 1 ) def forward ( self , x ): x = self . embedding ( x ) x , _ = self . lstm ( x ) avg_pool = torch . mean ( x , 1 ) max_pool , _ = torch . max ( x , 1 ) out = torch . cat (( avg_pool , max_pool ), 1 ) out = self . out ( out ) return out \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa engine.py\uff0c\u5176\u4e2d\u5305\u542b\u8bad\u7ec3\u548c\u8bc4\u4f30\u51fd\u6570\u3002 import torch import torch.nn as nn def train ( data_loader , model , optimizer , device ): model . train () for data in data_loader : reviews = data [ \"review\" ] targets = data [ \"target\" ] reviews = reviews . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) optimizer . zero_grad () predictions = model ( reviews ) loss = nn . BCEWithLogitsLoss ()( predictions , targets . view ( - 1 , 1 ) ) loss . backward () optimizer . step () def evaluate ( data_loader , model , device ): final_predictions = [] final_targets = [] model . eval () with torch . no_grad (): for data in data_loader : reviews = data [ \"review\" ] targets = data [ \"target\" ] reviews = reviews . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . 
float ) predictions = model ( reviews ) predictions = predictions . cpu () . numpy () . tolist () targets = data [ \"target\" ] . cpu () . numpy () . tolist () final_predictions . extend ( predictions ) final_targets . extend ( targets ) return final_predictions , final_targets \u8fd9\u4e9b\u51fd\u6570\u5c06\u5728 train.py \u4e2d\u4e3a\u6211\u4eec\u63d0\u4f9b\u5e2e\u52a9\uff0c\u8be5\u51fd\u6570\u7528\u4e8e\u8bad\u7ec3\u591a\u4e2a\u6298\u53e0\u3002 import io import torch import numpy as np import pandas as pd import tensorflow as tf from sklearn import metrics import config import dataset import engine import lstm def load_vectors ( fname ): fin = io . open ( fname , 'r' , encoding = 'utf-8' , newline = ' \\n ' , errors = 'ignore' ) n , d = map ( int , fin . readline () . split ()) data = {} for line in fin : tokens = line . rstrip () . split ( ' ' ) data [ tokens [ 0 ]] = list ( map ( float , tokens [ 1 :])) return data def create_embedding_matrix ( word_index , embedding_dict ): embedding_matrix = np . zeros (( len ( word_index ) + 1 , 300 )) for word , i in word_index . items (): if word in embedding_dict : embedding_matrix [ i ] = embedding_dict [ word ] return embedding_matrix def run ( df , fold ): train_df = df [ df . kfold != fold ] . reset_index ( drop = True ) valid_df = df [ df . kfold == fold ] . reset_index ( drop = True ) print ( \"Fitting tokenizer\" ) tokenizer = tf . keras . preprocessing . text . Tokenizer () tokenizer . fit_on_texts ( df . review . values . tolist ()) xtrain = tokenizer . texts_to_sequences ( train_df . review . values ) xtest = tokenizer . texts_to_sequences ( valid_df . review . values ) xtrain = tf . keras . preprocessing . sequence . pad_sequences ( xtrain , maxlen = config . MAX_LEN ) xtest = tf . keras . preprocessing . sequence . pad_sequences ( xtest , maxlen = config . MAX_LEN ) train_dataset = dataset . IMDBDataset ( reviews = xtrain , targets = train_df . sentiment . values ) train_data_loader = torch . utils . data . 
DataLoader ( train_dataset , batch_size = config . TRAIN_BATCH_SIZE , num_workers = 2 ) valid_dataset = dataset . IMDBDataset ( reviews = xtest , targets = valid_df . sentiment . values ) valid_data_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = config . VALID_BATCH_SIZE , num_workers = 1 ) print ( \"Loading embeddings\" ) embedding_dict = load_vectors ( \"../input/crawl-300d-2M.vec\" ) embedding_matrix = create_embedding_matrix ( tokenizer . word_index , embedding_dict ) device = torch . device ( \"cuda\" ) model = lstm . LSTM ( embedding_matrix ) model . to ( device ) optimizer = torch . optim . Adam ( model . parameters (), lr = 1e-3 ) print ( \"Training Model\" ) best_accuracy = 0 early_stopping_counter = 0 for epoch in range ( config . EPOCHS ): engine . train ( train_data_loader , model , optimizer , device ) outputs , targets = engine . evaluate ( valid_data_loader , model , device ) outputs = np . array ( outputs ) >= 0.5 accuracy = metrics . accuracy_score ( targets , outputs ) print ( f \"FOLD: { fold } , Epoch: { epoch } , Accuracy Score = { accuracy } \" ) if accuracy > best_accuracy : best_accuracy = accuracy else : early_stopping_counter += 1 if early_stopping_counter > 2 : break if __name__ == \"__main__\" : df = pd . 
read_csv ( \"../input/imdb_folds.csv\" ) run ( df , fold = 0 ) run ( df , fold = 1 ) run ( df , fold = 2 ) run ( df , fold = 3 ) run ( df , fold = 4 ) \u6700\u540e\u662f config.py\u3002 MAX_LEN = 128 TRAIN_BATCH_SIZE = 16 VALID_BATCH_SIZE = 8 EPOCHS = 10 \u8ba9\u6211\u4eec\u770b\u770b\u8f93\u51fa\uff1a FOLD : 0 , Epoch : 3 , Accuracy Score = 0.9015 FOLD : 1 , Epoch : 4 , Accuracy Score = 0.9007 FOLD : 2 , Epoch : 3 , Accuracy Score = 0.8924 FOLD : 3 , Epoch : 2 , Accuracy Score = 0.9 FOLD : 4 , Epoch : 1 , Accuracy Score = 0.878 \u8fd9\u662f\u8fc4\u4eca\u4e3a\u6b62\u6211\u4eec\u83b7\u5f97\u7684\u6700\u597d\u6210\u7ee9\u3002 \u8bf7\u6ce8\u610f\uff0c\u6211\u53ea\u663e\u793a\u4e86\u6bcf\u4e2a\u6298\u53e0\u4e2d\u7cbe\u5ea6\u6700\u9ad8\u7684Epoch\u3002 \u4f60\u4e00\u5b9a\u5df2\u7ecf\u6ce8\u610f\u5230\uff0c\u6211\u4eec\u4f7f\u7528\u4e86\u9884\u5148\u8bad\u7ec3\u7684\u5d4c\u5165\u548c\u7b80\u5355\u7684\u53cc\u5411 LSTM\u3002 \u5982\u679c\u4f60\u60f3\u6539\u53d8\u6a21\u578b\uff0c\u4f60\u53ef\u4ee5\u53ea\u6539\u53d8 lstm.py \u4e2d\u7684\u6a21\u578b\u5e76\u4fdd\u6301\u4e00\u5207\u4e0d\u53d8\u3002 \u8fd9\u79cd\u4ee3\u7801\u53ea\u9700\u8981\u5f88\u5c11\u7684\u5b9e\u9a8c\u6539\u52a8\uff0c\u5e76\u4e14\u5f88\u5bb9\u6613\u7406\u89e3\u3002 \u4f8b\u5982\uff0c\u60a8\u53ef\u4ee5\u81ea\u5df1\u5b66\u4e60\u5d4c\u5165\u800c\u4e0d\u662f\u4f7f\u7528\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u5176\u4ed6\u4e00\u4e9b\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u7ec4\u5408\u591a\u4e2a\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528GRU\uff0c\u60a8\u53ef\u4ee5\u5728\u5d4c\u5165\u540e\u4f7f\u7528\u7a7a\u95f4dropout\uff0c\u60a8\u53ef\u4ee5\u6dfb\u52a0GRU LSTM \u5c42\u4e4b\u540e\uff0c\u60a8\u53ef\u4ee5\u6dfb\u52a0\u4e24\u4e2a LSTM \u5c42\uff0c\u60a8\u53ef\u4ee5\u8fdb\u884c LSTM-GRU-LSTM \u914d\u7f6e\uff0c\u60a8\u53ef\u4ee5\u7528\u5377\u79ef\u5c42\u66ff\u6362 LSTM 
\u7b49\uff0c\u800c\u65e0\u9700\u5bf9\u4ee3\u7801\u8fdb\u884c\u592a\u591a\u66f4\u6539\u3002 \u6211\u63d0\u5230\u7684\u5927\u90e8\u5206\u5185\u5bb9\u53ea\u9700\u8981\u66f4\u6539\u6a21\u578b\u7c7b\u3002 \u5f53\u60a8\u4f7f\u7528\u9884\u8bad\u7ec3\u7684\u5d4c\u5165\u65f6\uff0c\u5c1d\u8bd5\u67e5\u770b\u6709\u591a\u5c11\u5355\u8bcd\u65e0\u6cd5\u627e\u5230\u5d4c\u5165\u4ee5\u53ca\u539f\u56e0\u3002 \u9884\u8bad\u7ec3\u5d4c\u5165\u7684\u5355\u8bcd\u8d8a\u591a\uff0c\u7ed3\u679c\u5c31\u8d8a\u597d\u3002 \u6211\u5411\u60a8\u5c55\u793a\u4ee5\u4e0b\u672a\u6ce8\u91ca\u7684 (!) \u51fd\u6570\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u5b83\u4e3a\u4efb\u4f55\u7c7b\u578b\u7684\u9884\u8bad\u7ec3\u5d4c\u5165\u521b\u5efa\u5d4c\u5165\u77e9\u9635\uff0c\u5176\u683c\u5f0f\u4e0e glove \u6216 fastText \u76f8\u540c\uff08\u53ef\u80fd\u9700\u8981\u8fdb\u884c\u4e00\u4e9b\u66f4\u6539\uff09\u3002 def load_embeddings ( word_index , embedding_file , vector_length = 300 ): max_features = len ( word_index ) + 1 words_to_find = list ( word_index . keys ()) more_words_to_find = [] for wtf in words_to_find : more_words_to_find . append ( wtf ) more_words_to_find . append ( str ( wtf ) . capitalize ()) more_words_to_find = set ( more_words_to_find ) def get_coefs ( word , * arr ): return word , np . asarray ( arr , dtype = 'float32' ) embeddings_index = dict ( get_coefs ( * o . strip () . split ( \" \" )) for o in open ( embedding_file ) if o . split ( \" \" )[ 0 ] in more_words_to_find and len ( o ) > 100 ) embedding_matrix = np . zeros (( max_features , vector_length )) for word , i in word_index . items (): if i >= max_features : continue embedding_vector = embeddings_index . get ( word ) if embedding_vector is None : embedding_vector = embeddings_index . get ( str ( word ) . capitalize () ) if embedding_vector is None : embedding_vector = embeddings_index . get ( str ( word ) . 
upper () ) if ( embedding_vector is not None and len ( embedding_vector ) == vector_length ): embedding_matrix [ i ] = embedding_vector return embedding_matrix \u9605\u8bfb\u5e76\u8fd0\u884c\u4e0a\u9762\u7684\u51fd\u6570\uff0c\u770b\u770b\u53d1\u751f\u4e86\u4ec0\u4e48\u3002 \u8be5\u51fd\u6570\u8fd8\u53ef\u4ee5\u4fee\u6539\u4e3a\u4f7f\u7528\u8bcd\u5e72\u8bcd\u6216\u8bcd\u5f62\u8fd8\u539f\u8bcd\u3002 \u6700\u540e\uff0c\u60a8\u5e0c\u671b\u8bad\u7ec3\u8bed\u6599\u5e93\u4e2d\u7684\u672a\u77e5\u5355\u8bcd\u6570\u91cf\u6700\u5c11\u3002 \u53e6\u4e00\u4e2a\u6280\u5de7\u662f\u5b66\u4e60\u5d4c\u5165\u5c42\uff0c\u5373\u4f7f\u5176\u53ef\u8bad\u7ec3\uff0c\u7136\u540e\u8bad\u7ec3\u7f51\u7edc\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u5206\u7c7b\u95ee\u9898\u6784\u5efa\u4e86\u5f88\u591a\u6a21\u578b\u3002 \u7136\u800c\uff0c\u73b0\u5728\u662f\u5e03\u5076\u65f6\u4ee3\uff0c\u8d8a\u6765\u8d8a\u591a\u7684\u4eba\u8f6c\u5411\u57fa\u4e8e\u53d8\u5f62\u91d1\u521a\u7684\u6a21\u578b\u3002 \u57fa\u4e8e Transformer \u7684\u7f51\u7edc\u80fd\u591f\u5904\u7406\u672c\u8d28\u4e0a\u957f\u671f\u7684\u4f9d\u8d56\u5173\u7cfb\u3002 LSTM \u4ec5\u5f53\u5b83\u770b\u5230\u524d\u4e00\u4e2a\u5355\u8bcd\u65f6\u624d\u67e5\u770b\u4e0b\u4e00\u4e2a\u5355\u8bcd\u3002 \u53d8\u538b\u5668\u7684\u60c5\u51b5\u5e76\u975e\u5982\u6b64\u3002 \u5b83\u53ef\u4ee5\u540c\u65f6\u67e5\u770b\u6574\u4e2a\u53e5\u5b50\u4e2d\u7684\u6240\u6709\u5355\u8bcd\u3002 \u56e0\u6b64\uff0c\u53e6\u4e00\u4e2a\u4f18\u70b9\u662f\u5b83\u53ef\u4ee5\u8f7b\u677e\u5e76\u884c\u5316\u5e76\u66f4\u6709\u6548\u5730\u4f7f\u7528 GPU\u3002 Transformers \u662f\u4e00\u4e2a\u975e\u5e38\u5e7f\u6cdb\u7684\u8bdd\u9898\uff0c\u6709\u592a\u591a\u7684\u6a21\u578b\uff1a BERT\u3001RoBERTa\u3001XLNet\u3001XLM-RoBERTa\u3001T5 \u7b49\u3002\u6211\u5c06\u5411\u60a8\u5c55\u793a\u4e00\u79cd\u53ef\u7528\u4e8e\u6240\u6709\u8fd9\u4e9b\u6a21\u578b\uff08T5 \u9664\u5916\uff09\u8fdb\u884c\u5206\u7c7b\u7684\u901a\u7528\u65b9\u6cd5 
\u6211\u4eec\u4e00\u76f4\u5728\u8ba8\u8bba\u7684\u95ee\u9898\u3002 \u8bf7\u6ce8\u610f\uff0c\u8fd9\u4e9b\u53d8\u538b\u5668\u9700\u8981\u8bad\u7ec3\u5b83\u4eec\u6240\u9700\u7684\u8ba1\u7b97\u80fd\u529b\u3002 \u56e0\u6b64\uff0c\u5982\u679c\u60a8\u6ca1\u6709\u9ad8\u7aef\u7cfb\u7edf\uff0c\u4e0e\u57fa\u4e8e LSTM \u6216 TF-IDF \u7684\u6a21\u578b\u76f8\u6bd4\uff0c\u8bad\u7ec3\u6a21\u578b\u53ef\u80fd\u9700\u8981\u66f4\u957f\u7684\u65f6\u95f4\u3002 \u6211\u4eec\u8981\u505a\u7684\u7b2c\u4e00\u4ef6\u4e8b\u662f\u521b\u5efa\u4e00\u4e2a\u914d\u7f6e\u6587\u4ef6\u3002 import transformers MAX_LEN = 512 TRAIN_BATCH_SIZE = 8 VALID_BATCH_SIZE = 4 EPOCHS = 10 BERT_PATH = \"../input/bert_base_uncased/\" MODEL_PATH = \"model.bin\" TRAINING_FILE = \"../input/imdb.csv\" TOKENIZER = transformers . BertTokenizer . from_pretrained ( BERT_PATH , do_lower_case = True ) \u8fd9\u91cc\u7684\u914d\u7f6e\u6587\u4ef6\u662f\u6211\u4eec\u5b9a\u4e49\u5206\u8bcd\u5668\u548c\u5176\u4ed6\u6211\u4eec\u60f3\u8981\u7ecf\u5e38\u66f4\u6539\u7684\u53c2\u6570\u7684\u552f\u4e00\u5730\u65b9 \u2014\u2014 \u8fd9\u6837\u6211\u4eec\u5c31\u53ef\u4ee5\u505a\u5f88\u591a\u5b9e\u9a8c\u800c\u4e0d\u9700\u8981\u8fdb\u884c\u5927\u91cf\u66f4\u6539\u3002 \u4e0b\u4e00\u6b65\u662f\u6784\u5efa\u6570\u636e\u96c6\u7c7b\u3002 import config import torch class BERTDataset : def __init__ ( self , review , target ): self . review = review self . target = target self . tokenizer = config . TOKENIZER self . max_len = config . MAX_LEN def __len__ ( self ): return len ( self . review ) def __getitem__ ( self , item ): review = str ( self . review [ item ]) review = \" \" . join ( review . split ()) inputs = self . tokenizer . encode_plus ( review , None , add_special_tokens = True , max_length = self . max_len , pad_to_max_length = True , ) ids = inputs [ \"input_ids\" ] mask = inputs [ \"attention_mask\" ] token_type_ids = inputs [ \"token_type_ids\" ] return { \"ids\" : torch . tensor ( ids , dtype = torch . long ), \"mask\" : torch . 
tensor ( mask , dtype = torch . long ), \"token_type_ids\" : torch . tensor ( token_type_ids , dtype = torch . long ), \"targets\" : torch . tensor ( self . target [ item ], dtype = torch . float ) } \u73b0\u5728\u6211\u4eec\u6765\u5230\u4e86\u8be5\u9879\u76ee\u7684\u6838\u5fc3\uff0c\u5373\u6a21\u578b\u3002 import config import transformers import torch.nn as nn class BERTBaseUncased ( nn . Module ): def __init__ ( self ): super ( BERTBaseUncased , self ) . __init__ () self . bert = transformers . BertModel . from_pretrained ( config . BERT_PATH ) self . bert_drop = nn . Dropout ( 0.3 ) self . out = nn . Linear ( 768 , 1 ) def forward ( self , ids , mask , token_type_ids ): hidden state _ , o2 = self . bert ( ids , attention_mask = mask , token_type_ids = token_type_ids ) bo = self . bert_drop ( o2 ) output = self . out ( bo ) return output \u8be5\u6a21\u578b\u8fd4\u56de\u5355\u4e2a\u8f93\u51fa\u3002 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u5e26\u6709 logits \u7684\u4e8c\u5143\u4ea4\u53c9\u71b5\u635f\u5931\uff0c\u5b83\u9996\u5148\u5e94\u7528 sigmoid\uff0c\u7136\u540e\u8ba1\u7b97\u635f\u5931\u3002 \u8fd9\u662f\u5728engine.py \u4e2d\u5b8c\u6210\u7684\u3002 import torch import torch.nn as nn def loss_fn ( outputs , targets ): return nn . BCEWithLogitsLoss ()( outputs , targets . view ( - 1 , 1 )) def train_fn ( data_loader , model , optimizer , device , scheduler ): model . train () for d in data_loader : ids = d [ \"ids\" ] token_type_ids = d [ \"token_type_ids\" ] mask = d [ \"mask\" ] targets = d [ \"targets\" ] ids = ids . to ( device , dtype = torch . long ) token_type_ids = token_type_ids . to ( device , dtype = torch . long ) mask = mask . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) optimizer . zero_grad () outputs = model ( ids = ids , mask = mask , token_type_ids = token_type_ids ) loss = loss_fn ( outputs , targets ) loss . backward () optimizer . step () scheduler . 
step () def eval_fn ( data_loader , model , device ): model . eval () fin_targets = [] fin_outputs = [] with torch . no_grad (): for d in data_loader : ids = d [ \"ids\" ] token_type_ids = d [ \"token_type_ids\" ] mask = d [ \"mask\" ] targets = d [ \"targets\" ] ids = ids . to ( device , dtype = torch . long ) token_type_ids = token_type_ids . to ( device , dtype = torch . long ) mask = mask . to ( device , dtype = torch . long ) targets = targets . to ( device , dtype = torch . float ) outputs = model ( ids = ids , mask = mask , token_type_ids = token_type_ids ) targets = targets . cpu () . detach () fin_targets . extend ( targets . numpy () . tolist ()) outputs = torch . sigmoid ( outputs ) . cpu () . detach () fin_outputs . extend ( outputs . numpy () . tolist ()) return fin_outputs , fin_targets \u6700\u540e\uff0c\u6211\u4eec\u51c6\u5907\u597d\u8bad\u7ec3\u4e86\u3002 \u6211\u4eec\u6765\u770b\u770b\u8bad\u7ec3\u811a\u672c\u5427\uff01 import config import dataset import engine import torch import pandas as pd import torch.nn as nn import numpy as np from model import BERTBaseUncased from sklearn import model_selection from sklearn import metrics from transformers import AdamW from transformers import get_linear_schedule_with_warmup def train (): dfx = pd . read_csv ( config . TRAINING_FILE ) . fillna ( \"none\" ) dfx . sentiment = dfx . sentiment . apply ( lambda x : 1 if x == \"positive\" else 0 ) df_train , df_valid = model_selection . train_test_split ( dfx , test_size = 0.1 , random_state = 42 , stratify = dfx . sentiment . values ) df_train = df_train . reset_index ( drop = True ) df_valid = df_valid . reset_index ( drop = True ) train_dataset = dataset . BERTDataset ( review = df_train . review . values , target = df_train . sentiment . values ) train_data_loader = torch . utils . data . DataLoader ( train_dataset , batch_size = config . TRAIN_BATCH_SIZE , num_workers = 4 ) valid_dataset = dataset . BERTDataset ( review = df_valid . review . 
values , target = df_valid . sentiment . values ) valid_data_loader = torch . utils . data . DataLoader ( valid_dataset , batch_size = config . VALID_BATCH_SIZE , num_workers = 1 ) device = torch . device ( \"cuda\" ) model = BERTBaseUncased () model . to ( device ) param_optimizer = list ( model . named_parameters ()) no_decay = [ \"bias\" , \"LayerNorm.bias\" , \"LayerNorm.weight\" ] optimizer_parameters = [ { \"params\" : [ p for n , p in param_optimizer if not any ( nd in n for nd in no_decay ) ], \"weight_decay\" : 0.001 , } \uff0c { \"params\" : [ p for n , p in param_optimizer if any ( nd in n for nd in no_decay ) ], \"weight_decay\" : 0.0 , }] num_train_steps = int ( len ( df_train ) / config . TRAIN_BATCH_SIZE * config . EPOCHS ) optimizer = AdamW ( optimizer_parameters , lr = 3e-5 ) scheduler = get_linear_schedule_with_warmup ( optimizer , num_warmup_steps = 0 , num_training_steps = num_train_steps ) model = nn . DataParallel ( model ) best_accuracy = 0 for epoch in range ( config . EPOCHS ): engine . train_fn ( train_data_loader , model , optimizer , device , scheduler ) outputs , targets = engine . eval_fn ( valid_data_loader , model , device ) outputs = np . array ( outputs ) >= 0.5 accuracy = metrics . accuracy_score ( targets , outputs ) print ( f \"Accuracy Score = { accuracy } \" ) if accuracy > best_accuracy : torch . save ( model . state_dict (), config . 
MODEL_PATH ) best_accuracy = accuracy if __name__ == \"__main__\" : train () \u4e4d\u4e00\u770b\u53ef\u80fd\u770b\u8d77\u6765\u5f88\u591a\uff0c\u4f46\u4e00\u65e6\u60a8\u4e86\u89e3\u4e86\u5404\u4e2a\u7ec4\u4ef6\uff0c\u5c31\u4e0d\u518d\u90a3\u4e48\u7b80\u5355\u4e86\u3002 \u60a8\u53ea\u9700\u66f4\u6539\u51e0\u884c\u4ee3\u7801\u5373\u53ef\u8f7b\u677e\u5c06\u5176\u66f4\u6539\u4e3a\u60a8\u60f3\u8981\u4f7f\u7528\u7684\u4efb\u4f55\u5176\u4ed6\u53d8\u538b\u5668\u6a21\u578b\u3002 \u8be5\u6a21\u578b\u7684\u51c6\u786e\u7387\u4e3a 93%\uff01 \u54c7\uff01 \u8fd9\u6bd4\u4efb\u4f55\u5176\u4ed6\u6a21\u578b\u90fd\u8981\u597d\u5f97\u591a\u3002 \u4f46\u662f\u8fd9\u503c\u5f97\u5417\uff1f \u6211\u4eec\u4f7f\u7528 LSTM \u80fd\u591f\u5b9e\u73b0 90% \u7684\u76ee\u6807\uff0c\u800c\u4e14\u5b83\u4eec\u66f4\u7b80\u5355\u3001\u66f4\u5bb9\u6613\u8bad\u7ec3\u5e76\u4e14\u63a8\u7406\u901f\u5ea6\u66f4\u5feb\u3002 \u901a\u8fc7\u4f7f\u7528\u4e0d\u540c\u7684\u6570\u636e\u5904\u7406\u6216\u8c03\u6574\u5c42\u3001\u8282\u70b9\u3001dropout\u3001\u5b66\u4e60\u7387\u3001\u66f4\u6539\u4f18\u5316\u5668\u7b49\u53c2\u6570\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u8be5\u6a21\u578b\u6539\u8fdb\u4e00\u4e2a\u767e\u5206\u70b9\u3002\u7136\u540e\u6211\u4eec\u5c06\u4ece BERT \u4e2d\u83b7\u5f97\u7ea6 2% \u7684\u6536\u76ca\u3002 \u53e6\u4e00\u65b9\u9762\uff0cBERT \u7684\u8bad\u7ec3\u65f6\u95f4\u8981\u957f\u5f97\u591a\uff0c\u53c2\u6570\u5f88\u591a\uff0c\u800c\u4e14\u63a8\u7406\u901f\u5ea6\u4e5f\u5f88\u6162\u3002 \u6700\u540e\uff0c\u60a8\u5e94\u8be5\u5ba1\u89c6\u81ea\u5df1\u7684\u4e1a\u52a1\u5e76\u505a\u51fa\u660e\u667a\u7684\u9009\u62e9\u3002 \u4e0d\u8981\u4ec5\u4ec5\u56e0\u4e3a BERT\u201c\u9177\u201d\u800c\u9009\u62e9\u5b83\u3002 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u6211\u4eec\u5728\u8fd9\u91cc\u8ba8\u8bba\u7684\u552f\u4e00\u4efb\u52a1\u662f\u5206\u7c7b\uff0c\u4f46\u5c06\u5176\u66f4\u6539\u4e3a\u56de\u5f52\u3001\u591a\u6807\u7b7e\u6216\u591a\u7c7b\u53ea\u9700\u8981\u66f4\u6539\u51e0\u884c\u4ee3\u7801\u3002 
\u4f8b\u5982\uff0c\u591a\u7c7b\u5206\u7c7b\u8bbe\u7f6e\u4e2d\u7684\u540c\u4e00\u95ee\u9898\u5c06\u6709\u591a\u4e2a\u8f93\u51fa\u548c\u4ea4\u53c9\u71b5\u635f\u5931\u3002 \u5176\u4ed6\u4e00\u5207\u90fd\u5e94\u8be5\u4fdd\u6301\u4e0d\u53d8\u3002 \u81ea\u7136\u8bed\u8a00\u5904\u7406\u975e\u5e38\u5e9e\u5927\uff0c\u6211\u4eec\u53ea\u8ba8\u8bba\u4e86\u5176\u4e2d\u7684\u4e00\u5c0f\u90e8\u5206\u3002 \u663e\u7136\uff0c\u8fd9\u662f\u4e00\u4e2a\u5f88\u5927\u7684\u6bd4\u4f8b\uff0c\u56e0\u4e3a\u5927\u591a\u6570\u5de5\u4e1a\u6a21\u578b\u90fd\u662f\u5206\u7c7b\u6216\u56de\u5f52\u6a21\u578b\u3002 \u5982\u679c\u6211\u5f00\u59cb\u8be6\u7ec6\u5199\u6240\u6709\u5185\u5bb9\uff0c\u6211\u6700\u7ec8\u53ef\u80fd\u4f1a\u5199\u51e0\u767e\u9875\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u51b3\u5b9a\u5c06\u6240\u6709\u5185\u5bb9\u5305\u542b\u5728\u4e00\u672c\u5355\u72ec\u7684\u4e66\u4e2d\uff1a\u63a5\u8fd1\uff08\u51e0\u4e4e\uff09\u4efb\u4f55 NLP \u95ee\u9898\uff01","title":"\u6587\u672c\u5206\u7c7b\u6216\u56de\u5f52\u65b9\u6cd5"},{"location":"%E6%97%A0%E7%9B%91%E7%9D%A3%E5%92%8C%E6%9C%89%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0/","text":"\u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60 \u5728\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u65f6\uff0c\u901a\u5e38\u6709\u4e24\u7c7b\u6570\u636e\uff08\u548c\u673a\u5668\u5b66\u4e60\u6a21\u578b\uff09\uff1a \u76d1\u7763\u6570\u636e\uff1a\u603b\u662f\u6709\u4e00\u4e2a\u6216\u591a\u4e2a\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807 \u65e0\u76d1\u7763\u6570\u636e\uff1a\u6ca1\u6709\u4efb\u4f55\u76ee\u6807\u53d8\u91cf\u3002 
\u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\u3002\u6211\u4eec\u9700\u8981\u9884\u6d4b\u4e00\u4e2a\u503c\u7684\u95ee\u9898\u88ab\u79f0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u3002\u4f8b\u5982\uff0c\u5982\u679c\u95ee\u9898\u662f\u6839\u636e\u5386\u53f2\u623f\u4ef7\u9884\u6d4b\u623f\u4ef7\uff0c\u90a3\u4e48\u533b\u9662\u3001\u5b66\u6821\u6216\u8d85\u5e02\u7684\u5b58\u5728\uff0c\u4e0e\u6700\u8fd1\u516c\u5171\u4ea4\u901a\u7684\u8ddd\u79bb\u7b49\u7279\u5f81\u5c31\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u540c\u6837\uff0c\u5f53\u6211\u4eec\u5f97\u5230\u732b\u548c\u72d7\u7684\u56fe\u50cf\u65f6\uff0c\u6211\u4eec\u4e8b\u5148\u77e5\u9053\u54ea\u4e9b\u662f\u732b\uff0c\u54ea\u4e9b\u662f\u72d7\uff0c\u5982\u679c\u4efb\u52a1\u662f\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u6240\u63d0\u4f9b\u7684\u56fe\u50cf\u662f\u732b\u8fd8\u662f\u72d7\uff0c\u90a3\u4e48\u8fd9\u4e2a\u95ee\u9898\u5c31\u88ab\u8ba4\u4e3a\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002 \u56fe 1\uff1a\u6709\u76d1\u7763\u5b66\u4e60\u6570\u636e \u5982\u56fe 1 \u6240\u793a\uff0c\u6570\u636e\u7684\u6bcf\u4e00\u884c\u90fd\u4e0e\u4e00\u4e2a\u76ee\u6807\u6216\u6807\u7b7e\u76f8\u5173\u8054\u3002\u5217\u662f\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u884c\u4ee3\u8868\u4e0d\u540c\u7684\u6570\u636e\u70b9\uff0c\u901a\u5e38\u79f0\u4e3a\u6837\u672c\u3002\u793a\u4f8b\u4e2d\u7684\u5341\u4e2a\u6837\u672c\u6709\u5341\u4e2a\u7279\u5f81\u548c\u4e00\u4e2a\u76ee\u6807\u53d8\u91cf\uff0c\u76ee\u6807\u53d8\u91cf\u53ef\u4ee5\u662f\u6570\u5b57\u6216\u7c7b\u522b\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5206\u7c7b\u53d8\u91cf\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5b9e\u6570\uff0c\u95ee\u9898\u5c31\u88ab\u5b9a\u4e49\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6709\u76d1\u7763\u95ee\u9898\u53ef\u5206\u4e3a\u4e24\u4e2a\u5b50\u7c7b\uff1a 
\u5206\u7c7b\uff1a\u9884\u6d4b\u7c7b\u522b\uff0c\u5982\u732b\u6216\u72d7 \u56de\u5f52\uff1a\u9884\u6d4b\u503c\uff0c\u5982\u623f\u4ef7 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u6709\u65f6\u6211\u4eec\u53ef\u80fd\u4f1a\u5728\u5206\u7c7b\u8bbe\u7f6e\u4e2d\u4f7f\u7528\u56de\u5f52\uff0c\u8fd9\u53d6\u51b3\u4e8e\u7528\u4e8e\u8bc4\u4f30\u7684\u6307\u6807\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u7a0d\u540e\u4f1a\u8ba8\u8bba\u8fd9\u4e2a\u95ee\u9898\u3002 \u53e6\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u95ee\u9898\u662f\u65e0\u76d1\u7763\u7c7b\u578b\u3002 \u65e0\u76d1\u7763 \u6570\u636e\u96c6\u6ca1\u6709\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\uff0c\u4e00\u822c\u6765\u8bf4\uff0c\u4e0e\u6709\u76d1\u7763\u95ee\u9898\u76f8\u6bd4\uff0c\u5904\u7406\u65e0\u76d1\u7763\u6570\u636e\u96c6\u66f4\u5177\u6311\u6218\u6027\u3002 \u5047\u8bbe\u4f60\u5728\u4e00\u5bb6\u5904\u7406\u4fe1\u7528\u5361\u4ea4\u6613\u7684\u91d1\u878d\u516c\u53f8\u5de5\u4f5c\u3002\u6bcf\u79d2\u949f\u90fd\u6709\u5927\u91cf\u6570\u636e\u6d8c\u5165\u3002\u552f\u4e00\u7684\u95ee\u9898\u662f\uff0c\u5f88\u96be\u627e\u5230\u4e00\u4e2a\u4eba\u6765\u5c06\u6bcf\u7b14\u4ea4\u6613\u6807\u8bb0\u4e3a\u6709\u6548\u4ea4\u6613\u3001\u771f\u5b9e\u4ea4\u6613\u6216\u6b3a\u8bc8\u4ea4\u6613\u3002\u5f53\u6211\u4eec\u6ca1\u6709\u4efb\u4f55\u5173\u4e8e\u4ea4\u6613\u662f\u6b3a\u8bc8\u8fd8\u662f\u771f\u5b9e\u7684\u4fe1\u606f\u65f6\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u65e0\u76d1\u7763\u95ee\u9898\u3002\u8981\u89e3\u51b3\u8fd9\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5fc5\u987b\u8003\u8651\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a\u591a\u5c11\u4e2a \u805a\u7c7b 
\u3002\u805a\u7c7b\u662f\u89e3\u51b3\u6b64\u7c7b\u95ee\u9898\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd8\u6709\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u53ef\u4ee5\u5e94\u7528\u4e8e\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5bf9\u4e8e\u6b3a\u8bc8\u68c0\u6d4b\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u8bf4\u6570\u636e\u53ef\u4ee5\u5206\u4e3a\u4e24\u7c7b\uff08\u6b3a\u8bc8\u6216\u771f\u5b9e\uff09\u3002 \u5f53\u6211\u4eec\u77e5\u9053\u805a\u7c7b\u7684\u6570\u91cf\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u805a\u7c7b\u7b97\u6cd5\u6765\u89e3\u51b3\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5728\u56fe 2 \u4e2d\uff0c\u5047\u8bbe\u6570\u636e\u5206\u4e3a\u4e24\u7c7b\uff0c\u6df1\u8272\u4ee3\u8868\u6b3a\u8bc8\uff0c\u6d45\u8272\u4ee3\u8868\u771f\u5b9e\u4ea4\u6613\u3002\u7136\u800c\uff0c\u5728\u4f7f\u7528\u805a\u7c7b\u65b9\u6cd5\u4e4b\u524d\uff0c\u6211\u4eec\u5e76\u4e0d\u77e5\u9053\u8fd9\u4e9b\u7c7b\u522b\u3002\u5e94\u7528\u805a\u7c7b\u7b97\u6cd5\u540e\uff0c\u6211\u4eec\u5e94\u8be5\u80fd\u591f\u533a\u5206\u8fd9\u4e24\u4e2a\u5047\u5b9a\u76ee\u6807\u3002 \u4e3a\u4e86\u7406\u89e3\u65e0\u76d1\u7763\u95ee\u9898\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u8bb8\u591a\u5206\u89e3\u6280\u672f\uff0c\u5982 \u4e3b\u6210\u5206\u5206\u6790\uff08PCA\uff09\u3001t-\u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09 \u7b49\u3002 
\u6709\u76d1\u7763\u7684\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\uff0c\u56e0\u4e3a\u5b83\u4eec\u5f88\u5bb9\u6613\u8bc4\u4f30\u3002\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u8bc4\u4f30\u6280\u672f\u3002\u7136\u800c\uff0c\u5bf9\u65e0\u76d1\u7763\u7b97\u6cd5\u7684\u7ed3\u679c\u8fdb\u884c\u8bc4\u4f30\u5177\u6709\u6311\u6218\u6027\uff0c\u9700\u8981\u5927\u91cf\u7684\u4eba\u4e3a\u5e72\u9884\u6216\u542f\u53d1\u5f0f\u65b9\u6cd5\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4e3b\u8981\u5173\u6ce8\u6709\u76d1\u7763\u6570\u636e\u548c\u6a21\u578b\uff0c\u4f46\u8fd9\u5e76\u4e0d\u610f\u5473\u7740\u6211\u4eec\u4f1a\u5ffd\u7565\u65e0\u76d1\u7763\u6570\u636e\u95ee\u9898\u3002 \u56fe 2\uff1a\u65e0\u76d1\u7763\u5b66\u4e60\u6570\u636e\u96c6 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u5f53\u4eba\u4eec\u5f00\u59cb\u5b66\u4e60\u6570\u636e\u79d1\u5b66\u6216\u673a\u5668\u5b66\u4e60\u65f6\uff0c\u90fd\u4f1a\u4ece\u975e\u5e38\u8457\u540d\u7684\u6570\u636e\u96c6\u5f00\u59cb\uff0c\u4f8b\u5982\u6cf0\u5766\u5c3c\u514b\u6570\u636e\u96c6\u6216\u8679\u819c\u6570\u636e\u96c6\uff0c\u8fd9\u4e9b\u90fd\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u5728\u6cf0\u5766\u5c3c\u514b\u53f7\u6570\u636e\u96c6\u4e2d\uff0c\u4f60\u5fc5\u987b\u6839\u636e\u8239\u7968\u7b49\u7ea7\u3001\u6027\u522b\u3001\u5e74\u9f84\u7b49\u56e0\u7d20\u9884\u6d4b\u6cf0\u5766\u5c3c\u514b\u53f7\u4e0a\u4e58\u5ba2\u7684\u5b58\u6d3b\u7387\u3002\u540c\u6837\uff0c\u5728\u9e22\u5c3e\u82b1\u6570\u636e\u96c6\u4e2d\uff0c\u60a8\u5fc5\u987b\u6839\u636e\u843c\u7247\u5bbd\u5ea6\u3001\u82b1\u74e3\u957f\u5ea6\u3001\u843c\u7247\u957f\u5ea6\u548c\u82b1\u74e3\u5bbd\u5ea6\u7b49\u56e0\u7d20\u9884\u6d4b\u82b1\u7684\u79cd\u7c7b\u3002 \u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u80fd\u5305\u62ec\u7528\u4e8e\u5ba2\u6237\u7ec6\u5206\u7684\u6570\u636e\u96c6\u3002 
\u4f8b\u5982\uff0c\u60a8\u62e5\u6709\u8bbf\u95ee\u60a8\u7684\u7535\u5b50\u5546\u52a1\u7f51\u7ad9\u7684\u5ba2\u6237\u6570\u636e\uff0c\u6216\u8005\u8bbf\u95ee\u5546\u5e97\u6216\u5546\u573a\u7684\u5ba2\u6237\u6570\u636e\uff0c\u800c\u60a8\u5e0c\u671b\u5c06\u5b83\u4eec\u7ec6\u5206\u6216\u805a\u7c7b\u4e3a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u65e0\u76d1\u7763\u6570\u636e\u96c6\u7684\u53e6\u4e00\u4e2a\u4f8b\u5b50\u53ef\u80fd\u5305\u62ec\u4fe1\u7528\u5361\u6b3a\u8bc8\u68c0\u6d4b\u6216\u5bf9\u51e0\u5f20\u56fe\u7247\u8fdb\u884c\u805a\u7c7b\u7b49\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd8\u53ef\u4ee5\u5c06\u6709\u76d1\u7763\u6570\u636e\u96c6\u8f6c\u6362\u4e3a\u65e0\u76d1\u7763\u6570\u636e\u96c6\uff0c\u4ee5\u67e5\u770b\u5b83\u4eec\u5728\u7ed8\u5236\u65f6\u7684\u6548\u679c\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u56fe 3 \u4e2d\u7684\u6570\u636e\u96c6\u3002\u56fe 3 \u663e\u793a\u7684\u662f MNIST \u6570\u636e\u96c6\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u6d41\u884c\u7684\u624b\u5199\u6570\u5b57\u6570\u636e\u96c6\uff0c\u5b83\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\uff0c\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u4f60\u4f1a\u5f97\u5230\u6570\u5b57\u56fe\u50cf\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u6b63\u786e\u6807\u7b7e\u3002\u4f60\u5fc5\u987b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u53ea\u63d0\u4f9b\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\u8bc6\u522b\u51fa\u54ea\u4e2a\u6570\u5b57\u662f\u5b83\u3002 \u56fe 3\uff1aMNIST\u6570\u636e\u96c6 \u5982\u679c\u6211\u4eec\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c t \u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09\u5206\u89e3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u53ea\u9700\u5728\u56fe\u50cf\u50cf\u7d20\u4e0a\u964d\u7ef4\u81f3 2 \u4e2a\u7ef4\u5ea6\uff0c\u5c31\u80fd\u5728\u4e00\u5b9a\u7a0b\u5ea6\u4e0a\u5206\u79bb\u56fe\u50cf\u3002\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1aMNIST \u6570\u636e\u96c6\u7684 t-SNE \u53ef\u89c6\u5316\u3002\u4f7f\u7528\u4e86 3000 
\u5e45\u56fe\u50cf\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002\u9996\u5148\u662f\u5bfc\u5165\u6240\u6709\u9700\u8981\u7684\u5e93\u3002 import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns from sklearn import datasets from sklearn import manifold % matplotlib inline \u6211\u4eec\u4f7f\u7528 matplotlib \u548c seaborn \u8fdb\u884c\u7ed8\u56fe\uff0c\u4f7f\u7528 numpy \u5904\u7406\u6570\u503c\u6570\u7ec4\uff0c\u4f7f\u7528 pandas \u4ece\u6570\u503c\u6570\u7ec4\u521b\u5efa\u6570\u636e\u5e27\uff0c\u4f7f\u7528 scikit-learn (sklearn) \u83b7\u53d6\u6570\u636e\u5e76\u6267\u884c t-SNE\u3002 \u5bfc\u5165\u540e\uff0c\u6211\u4eec\u9700\u8981\u4e0b\u8f7d\u6570\u636e\u5e76\u5355\u72ec\u8bfb\u53d6\uff0c\u6216\u8005\u4f7f\u7528 sklearn \u7684\u5185\u7f6e\u51fd\u6570\u6765\u63d0\u4f9b MNIST \u6570\u636e\u96c6\u3002 data = datasets . fetch_openml ( 'mnist_784' , version = 1 , return_X_y = True ) pixel_values , targets = data targets = targets . 
astype ( int ) \u5728\u8fd9\u90e8\u5206\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528 sklearn \u6570\u636e\u96c6\u83b7\u53d6\u4e86\u6570\u636e\uff0c\u5e76\u83b7\u5f97\u4e86\u4e00\u4e2a\u50cf\u7d20\u503c\u6570\u7ec4\u548c\u53e6\u4e00\u4e2a\u76ee\u6807\u6570\u7ec4\u3002\u7531\u4e8e\u76ee\u6807\u662f\u5b57\u7b26\u4e32\u7c7b\u578b\uff0c\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u3002 pixel_values \u662f\u4e00\u4e2a\u5f62\u72b6\u4e3a 70000x784 \u7684\u4e8c\u7ef4\u6570\u7ec4\u3002 \u5171\u6709 70000 \u5f20\u4e0d\u540c\u7684\u56fe\u50cf\uff0c\u6bcf\u5f20\u56fe\u50cf\u5927\u5c0f\u4e3a 28x28 \u50cf\u7d20\u3002\u5e73\u94fa 28x28 \u540e\u5f97\u5230 784 \u4e2a\u6570\u636e\u70b9\u3002 \u6211\u4eec\u53ef\u4ee5\u5c06\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u91cd\u5851\u4e3a\u539f\u6765\u7684\u5f62\u72b6\uff0c\u7136\u540e\u4f7f\u7528 matplotlib \u7ed8\u5236\u6210\u56fe\u8868\uff0c\u4ece\u800c\u5c06\u5176\u53ef\u89c6\u5316\u3002 single_image = pixel_values [ 1 , :] . reshape ( 28 , 28 ) plt . imshow ( single_image , cmap = 'gray' ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u7ed8\u5236\u5982\u4e0b\u56fe\u50cf\uff1a \u56fe 5\uff1a\u7ed8\u5236MNIST\u6570\u636e\u96c6\u5355\u5f20\u56fe\u7247 \u6700\u91cd\u8981\u7684\u4e00\u6b65\u662f\u5728\u6211\u4eec\u83b7\u53d6\u6570\u636e\u4e4b\u540e\u3002 tsne = manifold . TSNE ( n_components = 2 , random_state = 42 ) transformed_data = tsne . 
fit_transform ( pixel_values [: 3000 , :]) \u8fd9\u4e00\u6b65\u521b\u5efa\u4e86\u6570\u636e\u7684 t-SNE \u53d8\u6362\u3002\u6211\u4eec\u53ea\u4f7f\u7528 2 \u4e2a\u7ef4\u5ea6\uff0c\u56e0\u4e3a\u5728\u4e8c\u7ef4\u73af\u5883\u4e2d\u53ef\u4ee5\u5f88\u597d\u5730\u5c06\u5b83\u4eec\u53ef\u89c6\u5316\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u8f6c\u6362\u540e\u7684\u6570\u636e\u662f\u4e00\u4e2a 3000x2 \u5f62\u72b6\u7684\u6570\u7ec4\uff083000 \u884c 2 \u5217\uff09\u3002\u5728\u6570\u7ec4\u4e0a\u8c03\u7528 pd.DataFrame \u53ef\u4ee5\u5c06\u8fd9\u6837\u7684\u6570\u636e\u8f6c\u6362\u4e3a pandas \u6570\u636e\u5e27\u3002 tsne_df = pd . DataFrame ( np . column_stack (( transformed_data , targets [: 3000 ])), columns = [ \"x\" , \"y\" , \"targets\" ]) tsne_df . loc [:, \"targets\" ] = tsne_df . targets . astype ( int ) \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a numpy \u6570\u7ec4\u521b\u5efa\u4e00\u4e2a pandas \u6570\u636e\u5e27\u3002x \u548c y \u662f t-SNE \u5206\u89e3\u7684\u4e24\u4e2a\u7ef4\u5ea6\uff0ctarget \u662f\u5b9e\u9645\u6570\u5b57\u3002\u8fd9\u6837\u6211\u4eec\u5c31\u5f97\u5230\u4e86\u5982\u56fe 6 \u6240\u793a\u7684\u6570\u636e\u5e27\u3002 \u56fe 6\uff1at-SNE\u540e\u6570\u636e\u524d10\u884c \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 seaborn \u548c matplotlib \u7ed8\u5236\u5b83\u3002 grid = sns . FacetGrid ( tsne_df , hue = \"targets\" , size = 8 ) grid . map ( plt . scatter , \"x\" , \"y\" ) . 
add_legend () \u8fd9\u662f\u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u89c6\u5316\u7684\u4e00\u79cd\u65b9\u6cd5\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5728\u540c\u4e00\u6570\u636e\u96c6\u4e0a\u8fdb\u884c k-means \u805a\u7c7b\uff0c\u770b\u770b\u5b83\u5728\u65e0\u76d1\u7763\u73af\u5883\u4e0b\u7684\u8868\u73b0\u5982\u4f55\u3002\u4e00\u4e2a\u7ecf\u5e38\u51fa\u73b0\u7684\u95ee\u9898\u662f\uff0c\u5982\u4f55\u5728 k-means \u805a\u7c7b\u4e2d\u627e\u5230\u6700\u4f73\u7684\u7c07\u6570\u3002\u8fd9\u4e2a\u95ee\u9898\u6ca1\u6709\u6b63\u786e\u7b54\u6848\u3002\u4f60\u5fc5\u987b\u901a\u8fc7\u4ea4\u53c9\u9a8c\u8bc1\u6765\u627e\u5230\u6700\u4f73\u7c07\u6570\u3002\u672c\u4e66\u7a0d\u540e\u5c06\u8ba8\u8bba\u4ea4\u53c9\u9a8c\u8bc1\u3002\u8bf7\u6ce8\u610f\uff0c\u4e0a\u8ff0\u4ee3\u7801\u662f\u5728 jupyter \u7b14\u8bb0\u672c\u4e2d\u8fd0\u884c\u7684\u3002 \u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 jupyter \u505a\u4e00\u4e9b\u7b80\u5355\u7684\u4e8b\u60c5\uff0c\u6bd4\u5982\u4e0a\u9762\u7684\u4f8b\u5b50\u548c \u7ed8\u56fe\u3002\u5bf9\u4e8e\u672c\u4e66\u4e2d\u7684\u5927\u90e8\u5206\u5185\u5bb9\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 python \u811a\u672c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5176\u4ed6 IDE \u56e0\u4e3a\u7ed3\u679c\u90fd\u662f\u4e00\u6837\u7684\u3002 MNIST 
\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u628a\u5b83\u8f6c\u6362\u6210\u4e00\u4e2a\u65e0\u76d1\u7763\u7684\u95ee\u9898\uff0c\u53ea\u662f\u4e3a\u4e86\u68c0\u67e5\u5b83\u662f\u5426\u80fd\u5e26\u6765\u4efb\u4f55\u597d\u7684\u7ed3\u679c\u3002\u5982\u679c\u6211\u4eec\u4f7f\u7528\u5206\u7c7b\u7b97\u6cd5\uff0c\u6548\u679c\u4f1a\u66f4\u597d\u3002\u8ba9\u6211\u4eec\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u4e00\u63a2\u7a76\u7adf\u3002","title":"\u6709\u76d1\u7763\u548c\u65e0\u76d1\u7763\u5b66\u4e60"},{"location":"%E6%97%A0%E7%9B%91%E7%9D%A3%E5%92%8C%E6%9C%89%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0/#_1","text":"\u5728\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u65f6\uff0c\u901a\u5e38\u6709\u4e24\u7c7b\u6570\u636e\uff08\u548c\u673a\u5668\u5b66\u4e60\u6a21\u578b\uff09\uff1a \u76d1\u7763\u6570\u636e\uff1a\u603b\u662f\u6709\u4e00\u4e2a\u6216\u591a\u4e2a\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807 \u65e0\u76d1\u7763\u6570\u636e\uff1a\u6ca1\u6709\u4efb\u4f55\u76ee\u6807\u53d8\u91cf\u3002 
\u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\u3002\u6211\u4eec\u9700\u8981\u9884\u6d4b\u4e00\u4e2a\u503c\u7684\u95ee\u9898\u88ab\u79f0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u3002\u4f8b\u5982\uff0c\u5982\u679c\u95ee\u9898\u662f\u6839\u636e\u5386\u53f2\u623f\u4ef7\u9884\u6d4b\u623f\u4ef7\uff0c\u90a3\u4e48\u533b\u9662\u3001\u5b66\u6821\u6216\u8d85\u5e02\u7684\u5b58\u5728\uff0c\u4e0e\u6700\u8fd1\u516c\u5171\u4ea4\u901a\u7684\u8ddd\u79bb\u7b49\u7279\u5f81\u5c31\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u540c\u6837\uff0c\u5f53\u6211\u4eec\u5f97\u5230\u732b\u548c\u72d7\u7684\u56fe\u50cf\u65f6\uff0c\u6211\u4eec\u4e8b\u5148\u77e5\u9053\u54ea\u4e9b\u662f\u732b\uff0c\u54ea\u4e9b\u662f\u72d7\uff0c\u5982\u679c\u4efb\u52a1\u662f\u521b\u5efa\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u6240\u63d0\u4f9b\u7684\u56fe\u50cf\u662f\u732b\u8fd8\u662f\u72d7\uff0c\u90a3\u4e48\u8fd9\u4e2a\u95ee\u9898\u5c31\u88ab\u8ba4\u4e3a\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002 \u56fe 1\uff1a\u6709\u76d1\u7763\u5b66\u4e60\u6570\u636e \u5982\u56fe 1 \u6240\u793a\uff0c\u6570\u636e\u7684\u6bcf\u4e00\u884c\u90fd\u4e0e\u4e00\u4e2a\u76ee\u6807\u6216\u6807\u7b7e\u76f8\u5173\u8054\u3002\u5217\u662f\u4e0d\u540c\u7684\u7279\u5f81\uff0c\u884c\u4ee3\u8868\u4e0d\u540c\u7684\u6570\u636e\u70b9\uff0c\u901a\u5e38\u79f0\u4e3a\u6837\u672c\u3002\u793a\u4f8b\u4e2d\u7684\u5341\u4e2a\u6837\u672c\u6709\u5341\u4e2a\u7279\u5f81\u548c\u4e00\u4e2a\u76ee\u6807\u53d8\u91cf\uff0c\u76ee\u6807\u53d8\u91cf\u53ef\u4ee5\u662f\u6570\u5b57\u6216\u7c7b\u522b\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5206\u7c7b\u53d8\u91cf\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u76ee\u6807\u53d8\u91cf\u662f\u5b9e\u6570\uff0c\u95ee\u9898\u5c31\u88ab\u5b9a\u4e49\u4e3a\u56de\u5f52\u95ee\u9898\u3002\u56e0\u6b64\uff0c\u6709\u76d1\u7763\u95ee\u9898\u53ef\u5206\u4e3a\u4e24\u4e2a\u5b50\u7c7b\uff1a 
\u5206\u7c7b\uff1a\u9884\u6d4b\u7c7b\u522b\uff0c\u5982\u732b\u6216\u72d7 \u56de\u5f52\uff1a\u9884\u6d4b\u503c\uff0c\u5982\u623f\u4ef7 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u6709\u65f6\u6211\u4eec\u53ef\u80fd\u4f1a\u5728\u5206\u7c7b\u8bbe\u7f6e\u4e2d\u4f7f\u7528\u56de\u5f52\uff0c\u8fd9\u53d6\u51b3\u4e8e\u7528\u4e8e\u8bc4\u4f30\u7684\u6307\u6807\u3002\u4e0d\u8fc7\uff0c\u6211\u4eec\u7a0d\u540e\u4f1a\u8ba8\u8bba\u8fd9\u4e2a\u95ee\u9898\u3002 \u53e6\u4e00\u79cd\u673a\u5668\u5b66\u4e60\u95ee\u9898\u662f\u65e0\u76d1\u7763\u7c7b\u578b\u3002 \u65e0\u76d1\u7763 \u6570\u636e\u96c6\u6ca1\u6709\u4e0e\u4e4b\u76f8\u5173\u7684\u76ee\u6807\uff0c\u4e00\u822c\u6765\u8bf4\uff0c\u4e0e\u6709\u76d1\u7763\u95ee\u9898\u76f8\u6bd4\uff0c\u5904\u7406\u65e0\u76d1\u7763\u6570\u636e\u96c6\u66f4\u5177\u6311\u6218\u6027\u3002 \u5047\u8bbe\u4f60\u5728\u4e00\u5bb6\u5904\u7406\u4fe1\u7528\u5361\u4ea4\u6613\u7684\u91d1\u878d\u516c\u53f8\u5de5\u4f5c\u3002\u6bcf\u79d2\u949f\u90fd\u6709\u5927\u91cf\u6570\u636e\u6d8c\u5165\u3002\u552f\u4e00\u7684\u95ee\u9898\u662f\uff0c\u5f88\u96be\u627e\u5230\u4e00\u4e2a\u4eba\u6765\u5c06\u6bcf\u7b14\u4ea4\u6613\u6807\u8bb0\u4e3a\u6709\u6548\u4ea4\u6613\u3001\u771f\u5b9e\u4ea4\u6613\u6216\u6b3a\u8bc8\u4ea4\u6613\u3002\u5f53\u6211\u4eec\u6ca1\u6709\u4efb\u4f55\u5173\u4e8e\u4ea4\u6613\u662f\u6b3a\u8bc8\u8fd8\u662f\u771f\u5b9e\u7684\u4fe1\u606f\u65f6\uff0c\u95ee\u9898\u5c31\u53d8\u6210\u4e86\u65e0\u76d1\u7763\u95ee\u9898\u3002\u8981\u89e3\u51b3\u8fd9\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5fc5\u987b\u8003\u8651\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a\u591a\u5c11\u4e2a \u805a\u7c7b 
\u3002\u805a\u7c7b\u662f\u89e3\u51b3\u6b64\u7c7b\u95ee\u9898\u7684\u65b9\u6cd5\u4e4b\u4e00\uff0c\u4f46\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u8fd8\u6709\u5176\u4ed6\u51e0\u79cd\u65b9\u6cd5\u53ef\u4ee5\u5e94\u7528\u4e8e\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5bf9\u4e8e\u6b3a\u8bc8\u68c0\u6d4b\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u8bf4\u6570\u636e\u53ef\u4ee5\u5206\u4e3a\u4e24\u7c7b\uff08\u6b3a\u8bc8\u6216\u771f\u5b9e\uff09\u3002 \u5f53\u6211\u4eec\u77e5\u9053\u805a\u7c7b\u7684\u6570\u91cf\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u805a\u7c7b\u7b97\u6cd5\u6765\u89e3\u51b3\u65e0\u76d1\u7763\u95ee\u9898\u3002\u5728\u56fe 2 \u4e2d\uff0c\u5047\u8bbe\u6570\u636e\u5206\u4e3a\u4e24\u7c7b\uff0c\u6df1\u8272\u4ee3\u8868\u6b3a\u8bc8\uff0c\u6d45\u8272\u4ee3\u8868\u771f\u5b9e\u4ea4\u6613\u3002\u7136\u800c\uff0c\u5728\u4f7f\u7528\u805a\u7c7b\u65b9\u6cd5\u4e4b\u524d\uff0c\u6211\u4eec\u5e76\u4e0d\u77e5\u9053\u8fd9\u4e9b\u7c7b\u522b\u3002\u5e94\u7528\u805a\u7c7b\u7b97\u6cd5\u540e\uff0c\u6211\u4eec\u5e94\u8be5\u80fd\u591f\u533a\u5206\u8fd9\u4e24\u4e2a\u5047\u5b9a\u76ee\u6807\u3002 \u4e3a\u4e86\u7406\u89e3\u65e0\u76d1\u7763\u95ee\u9898\uff0c\u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528\u8bb8\u591a\u5206\u89e3\u6280\u672f\uff0c\u5982 \u4e3b\u6210\u5206\u5206\u6790\uff08PCA\uff09\u3001t-\u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09 \u7b49\u3002 
\u6709\u76d1\u7763\u7684\u95ee\u9898\u66f4\u5bb9\u6613\u89e3\u51b3\uff0c\u56e0\u4e3a\u5b83\u4eec\u5f88\u5bb9\u6613\u8bc4\u4f30\u3002\u6211\u4eec\u5c06\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u8bc4\u4f30\u6280\u672f\u3002\u7136\u800c\uff0c\u5bf9\u65e0\u76d1\u7763\u7b97\u6cd5\u7684\u7ed3\u679c\u8fdb\u884c\u8bc4\u4f30\u5177\u6709\u6311\u6218\u6027\uff0c\u9700\u8981\u5927\u91cf\u7684\u4eba\u4e3a\u5e72\u9884\u6216\u542f\u53d1\u5f0f\u65b9\u6cd5\u3002\u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4e3b\u8981\u5173\u6ce8\u6709\u76d1\u7763\u6570\u636e\u548c\u6a21\u578b\uff0c\u4f46\u8fd9\u5e76\u4e0d\u610f\u5473\u7740\u6211\u4eec\u4f1a\u5ffd\u7565\u65e0\u76d1\u7763\u6570\u636e\u95ee\u9898\u3002 \u56fe 2\uff1a\u65e0\u76d1\u7763\u5b66\u4e60\u6570\u636e\u96c6 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u5f53\u4eba\u4eec\u5f00\u59cb\u5b66\u4e60\u6570\u636e\u79d1\u5b66\u6216\u673a\u5668\u5b66\u4e60\u65f6\uff0c\u90fd\u4f1a\u4ece\u975e\u5e38\u8457\u540d\u7684\u6570\u636e\u96c6\u5f00\u59cb\uff0c\u4f8b\u5982\u6cf0\u5766\u5c3c\u514b\u6570\u636e\u96c6\u6216\u8679\u819c\u6570\u636e\u96c6\uff0c\u8fd9\u4e9b\u90fd\u662f\u6709\u76d1\u7763\u7684\u95ee\u9898\u3002\u5728\u6cf0\u5766\u5c3c\u514b\u53f7\u6570\u636e\u96c6\u4e2d\uff0c\u4f60\u5fc5\u987b\u6839\u636e\u8239\u7968\u7b49\u7ea7\u3001\u6027\u522b\u3001\u5e74\u9f84\u7b49\u56e0\u7d20\u9884\u6d4b\u6cf0\u5766\u5c3c\u514b\u53f7\u4e0a\u4e58\u5ba2\u7684\u5b58\u6d3b\u7387\u3002\u540c\u6837\uff0c\u5728\u9e22\u5c3e\u82b1\u6570\u636e\u96c6\u4e2d\uff0c\u60a8\u5fc5\u987b\u6839\u636e\u843c\u7247\u5bbd\u5ea6\u3001\u82b1\u74e3\u957f\u5ea6\u3001\u843c\u7247\u957f\u5ea6\u548c\u82b1\u74e3\u5bbd\u5ea6\u7b49\u56e0\u7d20\u9884\u6d4b\u82b1\u7684\u79cd\u7c7b\u3002 \u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u80fd\u5305\u62ec\u7528\u4e8e\u5ba2\u6237\u7ec6\u5206\u7684\u6570\u636e\u96c6\u3002 
\u4f8b\u5982\uff0c\u60a8\u62e5\u6709\u8bbf\u95ee\u60a8\u7684\u7535\u5b50\u5546\u52a1\u7f51\u7ad9\u7684\u5ba2\u6237\u6570\u636e\uff0c\u6216\u8005\u8bbf\u95ee\u5546\u5e97\u6216\u5546\u573a\u7684\u5ba2\u6237\u6570\u636e\uff0c\u800c\u60a8\u5e0c\u671b\u5c06\u5b83\u4eec\u7ec6\u5206\u6216\u805a\u7c7b\u4e3a\u4e0d\u540c\u7684\u7c7b\u522b\u3002\u65e0\u76d1\u7763\u6570\u636e\u96c6\u7684\u53e6\u4e00\u4e2a\u4f8b\u5b50\u53ef\u80fd\u5305\u62ec\u4fe1\u7528\u5361\u6b3a\u8bc8\u68c0\u6d4b\u6216\u5bf9\u51e0\u5f20\u56fe\u7247\u8fdb\u884c\u805a\u7c7b\u7b49\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd8\u53ef\u4ee5\u5c06\u6709\u76d1\u7763\u6570\u636e\u96c6\u8f6c\u6362\u4e3a\u65e0\u76d1\u7763\u6570\u636e\u96c6\uff0c\u4ee5\u67e5\u770b\u5b83\u4eec\u5728\u7ed8\u5236\u65f6\u7684\u6548\u679c\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u56fe 3 \u4e2d\u7684\u6570\u636e\u96c6\u3002\u56fe 3 \u663e\u793a\u7684\u662f MNIST \u6570\u636e\u96c6\uff0c\u8fd9\u662f\u4e00\u4e2a\u975e\u5e38\u6d41\u884c\u7684\u624b\u5199\u6570\u5b57\u6570\u636e\u96c6\uff0c\u5b83\u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u95ee\u9898\uff0c\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u4f60\u4f1a\u5f97\u5230\u6570\u5b57\u56fe\u50cf\u548c\u4e0e\u4e4b\u76f8\u5173\u7684\u6b63\u786e\u6807\u7b7e\u3002\u4f60\u5fc5\u987b\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5728\u53ea\u63d0\u4f9b\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\u8bc6\u522b\u51fa\u54ea\u4e2a\u6570\u5b57\u662f\u5b83\u3002 \u56fe 3\uff1aMNIST\u6570\u636e\u96c6 \u5982\u679c\u6211\u4eec\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c t \u5206\u5e03\u968f\u673a\u90bb\u57df\u5d4c\u5165\uff08t-SNE\uff09\u5206\u89e3\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u53ea\u9700\u5728\u56fe\u50cf\u50cf\u7d20\u4e0a\u964d\u7ef4\u81f3 2 \u4e2a\u7ef4\u5ea6\uff0c\u5c31\u80fd\u5728\u4e00\u5b9a\u7a0b\u5ea6\u4e0a\u5206\u79bb\u56fe\u50cf\u3002\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1aMNIST \u6570\u636e\u96c6\u7684 t-SNE \u53ef\u89c6\u5316\u3002\u4f7f\u7528\u4e86 3000 
\u5e45\u56fe\u50cf\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002\u9996\u5148\u662f\u5bfc\u5165\u6240\u6709\u9700\u8981\u7684\u5e93\u3002 import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns from sklearn import datasets from sklearn import manifold % matplotlib inline \u6211\u4eec\u4f7f\u7528 matplotlib \u548c seaborn \u8fdb\u884c\u7ed8\u56fe\uff0c\u4f7f\u7528 numpy \u5904\u7406\u6570\u503c\u6570\u7ec4\uff0c\u4f7f\u7528 pandas \u4ece\u6570\u503c\u6570\u7ec4\u521b\u5efa\u6570\u636e\u5e27\uff0c\u4f7f\u7528 scikit-learn (sklearn) \u83b7\u53d6\u6570\u636e\u5e76\u6267\u884c t-SNE\u3002 \u5bfc\u5165\u540e\uff0c\u6211\u4eec\u9700\u8981\u4e0b\u8f7d\u6570\u636e\u5e76\u5355\u72ec\u8bfb\u53d6\uff0c\u6216\u8005\u4f7f\u7528 sklearn \u7684\u5185\u7f6e\u51fd\u6570\u6765\u63d0\u4f9b MNIST \u6570\u636e\u96c6\u3002 data = datasets . fetch_openml ( 'mnist_784' , version = 1 , return_X_y = True ) pixel_values , targets = data targets = targets . 
astype ( int ) \u5728\u8fd9\u90e8\u5206\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4f7f\u7528 sklearn \u6570\u636e\u96c6\u83b7\u53d6\u4e86\u6570\u636e\uff0c\u5e76\u83b7\u5f97\u4e86\u4e00\u4e2a\u50cf\u7d20\u503c\u6570\u7ec4\u548c\u53e6\u4e00\u4e2a\u76ee\u6807\u6570\u7ec4\u3002\u7531\u4e8e\u76ee\u6807\u662f\u5b57\u7b26\u4e32\u7c7b\u578b\uff0c\u6211\u4eec\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u3002 pixel_values \u662f\u4e00\u4e2a\u5f62\u72b6\u4e3a 70000x784 \u7684\u4e8c\u7ef4\u6570\u7ec4\u3002 \u5171\u6709 70000 \u5f20\u4e0d\u540c\u7684\u56fe\u50cf\uff0c\u6bcf\u5f20\u56fe\u50cf\u5927\u5c0f\u4e3a 28x28 \u50cf\u7d20\u3002\u5e73\u94fa 28x28 \u540e\u5f97\u5230 784 \u4e2a\u6570\u636e\u70b9\u3002 \u6211\u4eec\u53ef\u4ee5\u5c06\u8be5\u6570\u636e\u96c6\u4e2d\u7684\u6837\u672c\u91cd\u5851\u4e3a\u539f\u6765\u7684\u5f62\u72b6\uff0c\u7136\u540e\u4f7f\u7528 matplotlib \u7ed8\u5236\u6210\u56fe\u8868\uff0c\u4ece\u800c\u5c06\u5176\u53ef\u89c6\u5316\u3002 single_image = pixel_values [ 1 , :] . reshape ( 28 , 28 ) plt . imshow ( single_image , cmap = 'gray' ) \u8fd9\u6bb5\u4ee3\u7801\u5c06\u7ed8\u5236\u5982\u4e0b\u56fe\u50cf\uff1a \u56fe 5\uff1a\u7ed8\u5236MNIST\u6570\u636e\u96c6\u5355\u5f20\u56fe\u7247 \u6700\u91cd\u8981\u7684\u4e00\u6b65\u662f\u5728\u6211\u4eec\u83b7\u53d6\u6570\u636e\u4e4b\u540e\u3002 tsne = manifold . TSNE ( n_components = 2 , random_state = 42 ) transformed_data = tsne . 
fit_transform ( pixel_values [: 3000 , :]) \u8fd9\u4e00\u6b65\u521b\u5efa\u4e86\u6570\u636e\u7684 t-SNE \u53d8\u6362\u3002\u6211\u4eec\u53ea\u4f7f\u7528 2 \u4e2a\u7ef4\u5ea6\uff0c\u56e0\u4e3a\u5728\u4e8c\u7ef4\u73af\u5883\u4e2d\u53ef\u4ee5\u5f88\u597d\u5730\u5c06\u5b83\u4eec\u53ef\u89c6\u5316\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u8f6c\u6362\u540e\u7684\u6570\u636e\u662f\u4e00\u4e2a 3000x2 \u5f62\u72b6\u7684\u6570\u7ec4\uff083000 \u884c 2 \u5217\uff09\u3002\u5728\u6570\u7ec4\u4e0a\u8c03\u7528 pd.DataFrame \u53ef\u4ee5\u5c06\u8fd9\u6837\u7684\u6570\u636e\u8f6c\u6362\u4e3a pandas \u6570\u636e\u5e27\u3002 tsne_df = pd . DataFrame ( np . column_stack (( transformed_data , targets [: 3000 ])), columns = [ \"x\" , \"y\" , \"targets\" ]) tsne_df . loc [:, \"targets\" ] = tsne_df . targets . astype ( int ) \u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a numpy \u6570\u7ec4\u521b\u5efa\u4e00\u4e2a pandas \u6570\u636e\u5e27\u3002x \u548c y \u662f t-SNE \u5206\u89e3\u7684\u4e24\u4e2a\u7ef4\u5ea6\uff0ctarget \u662f\u5b9e\u9645\u6570\u5b57\u3002\u8fd9\u6837\u6211\u4eec\u5c31\u5f97\u5230\u4e86\u5982\u56fe 6 \u6240\u793a\u7684\u6570\u636e\u5e27\u3002 \u56fe 6\uff1at-SNE\u540e\u6570\u636e\u524d10\u884c \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 seaborn \u548c matplotlib \u7ed8\u5236\u5b83\u3002 grid = sns . FacetGrid ( tsne_df , hue = \"targets\" , size = 8 ) grid . map ( plt . scatter , \"x\" , \"y\" ) . 
add_legend () \u8fd9\u662f\u65e0\u76d1\u7763\u6570\u636e\u96c6\u53ef\u89c6\u5316\u7684\u4e00\u79cd\u65b9\u6cd5\u3002\u6211\u4eec\u8fd8\u53ef\u4ee5\u5728\u540c\u4e00\u6570\u636e\u96c6\u4e0a\u8fdb\u884c k-means \u805a\u7c7b\uff0c\u770b\u770b\u5b83\u5728\u65e0\u76d1\u7763\u73af\u5883\u4e0b\u7684\u8868\u73b0\u5982\u4f55\u3002\u4e00\u4e2a\u7ecf\u5e38\u51fa\u73b0\u7684\u95ee\u9898\u662f\uff0c\u5982\u4f55\u5728 k-means \u805a\u7c7b\u4e2d\u627e\u5230\u6700\u4f73\u7684\u7c07\u6570\u3002\u8fd9\u4e2a\u95ee\u9898\u6ca1\u6709\u6b63\u786e\u7b54\u6848\u3002\u4f60\u5fc5\u987b\u901a\u8fc7\u4ea4\u53c9\u9a8c\u8bc1\u6765\u627e\u5230\u6700\u4f73\u7c07\u6570\u3002\u672c\u4e66\u7a0d\u540e\u5c06\u8ba8\u8bba\u4ea4\u53c9\u9a8c\u8bc1\u3002\u8bf7\u6ce8\u610f\uff0c\u4e0a\u8ff0\u4ee3\u7801\u662f\u5728 jupyter \u7b14\u8bb0\u672c\u4e2d\u8fd0\u884c\u7684\u3002 \u5728\u672c\u4e66\u4e2d\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 jupyter \u505a\u4e00\u4e9b\u7b80\u5355\u7684\u4e8b\u60c5\uff0c\u6bd4\u5982\u4e0a\u9762\u7684\u4f8b\u5b50\u548c \u7ed8\u56fe\u3002\u5bf9\u4e8e\u672c\u4e66\u4e2d\u7684\u5927\u90e8\u5206\u5185\u5bb9\uff0c\u6211\u4eec\u5c06\u4f7f\u7528 python \u811a\u672c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u5176\u4ed6 IDE \u56e0\u4e3a\u7ed3\u679c\u90fd\u662f\u4e00\u6837\u7684\u3002 MNIST \u662f\u4e00\u4e2a\u6709\u76d1\u7763\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u628a\u5b83\u8f6c\u6362\u6210\u4e00\u4e2a\u65e0\u76d1\u7763\u7684\u95ee\u9898\uff0c\u53ea\u662f\u4e3a\u4e86\u68c0\u67e5\u5b83\u662f\u5426\u80fd\u5e26\u6765\u4efb\u4f55\u597d\u7684\u7ed3\u679c\u3002\u5982\u679c\u6211\u4eec\u4f7f\u7528\u5206\u7c7b\u7b97\u6cd5\uff0c\u6548\u679c\u4f1a\u66f4\u597d\u3002\u8ba9\u6211\u4eec\u5728\u63a5\u4e0b\u6765\u7684\u7ae0\u8282\u4e2d\u4e00\u63a2\u7a76\u7adf\u3002","title":"\u65e0\u76d1\u7763\u548c\u6709\u76d1\u7763\u5b66\u4e60"},{"location":"%E7%89%B9%E5%BE%81%E5%B7%A5%E7%A8%8B/","text":"\u7279\u5f81\u5de5\u7a0b 
\u7279\u5f81\u5de5\u7a0b\u662f\u6784\u5efa\u826f\u597d\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u90e8\u5206\u4e4b\u4e00\u3002\u5982\u679c\u6211\u4eec\u62e5\u6709\u6709\u7528\u7684\u7279\u5f81\uff0c\u6a21\u578b\u5c31\u4f1a\u8868\u73b0\u5f97\u66f4\u597d\u3002\u5728\u8bb8\u591a\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u907f\u514d\u4f7f\u7528\u5927\u578b\u590d\u6742\u6a21\u578b\uff0c\u800c\u4f7f\u7528\u5177\u6709\u5173\u952e\u5de5\u7a0b\u7279\u5f81\u7684\u7b80\u5355\u6a21\u578b\u3002\u6211\u4eec\u5fc5\u987b\u7262\u8bb0\uff0c\u53ea\u6709\u5f53\u4f60\u5bf9\u95ee\u9898\u7684\u9886\u57df\u6709\u4e00\u5b9a\u7684\u4e86\u89e3\uff0c\u5e76\u4e14\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u76f8\u5173\u6570\u636e\u65f6\uff0c\u624d\u80fd\u4ee5\u6700\u4f73\u65b9\u5f0f\u5b8c\u6210\u7279\u5f81\u5de5\u7a0b\u3002\u4e0d\u8fc7\uff0c\u60a8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u4e00\u4e9b\u901a\u7528\u6280\u672f\uff0c\u4ece\u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u7279\u5f81\u3002\u7279\u5f81\u5de5\u7a0b\u4e0d\u4ec5\u4ec5\u662f\u4ece\u6570\u636e\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\uff0c\u8fd8\u5305\u62ec\u4e0d\u540c\u7c7b\u578b\u7684\u5f52\u4e00\u5316\u548c\u8f6c\u6362\u3002 \u5728\u6709\u5173\u5206\u7c7b\u7279\u5f81\u7684\u7ae0\u8282\u4e2d\uff0c\u6211\u4eec\u5df2\u7ecf\u4e86\u89e3\u4e86\u7ed3\u5408\u4e0d\u540c\u5206\u7c7b\u53d8\u91cf\u7684\u65b9\u6cd5\u3001\u5982\u4f55\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u8ba1\u6570\u3001\u6807\u7b7e\u7f16\u7801\u548c\u4f7f\u7528\u5d4c\u5165\u3002\u8fd9\u4e9b\u51e0\u4e4e\u90fd\u662f\u5229\u7528\u5206\u7c7b\u53d8\u91cf\u8bbe\u8ba1\u7279\u5f81\u7684\u65b9\u6cd5\u3002\u56e0\u6b64\uff0c\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u7684\u91cd\u70b9\u5c06\u4ec5\u9650\u4e8e\u6570\u503c\u53d8\u91cf\u4ee5\u53ca\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u7684\u7ec4\u5408\u3002 
\u8ba9\u6211\u4eec\u4ece\u6700\u7b80\u5355\u4f46\u5e94\u7528\u6700\u5e7f\u6cdb\u7684\u7279\u5f81\u5de5\u7a0b\u6280\u672f\u5f00\u59cb\u3002\u5047\u8bbe\u4f60\u6b63\u5728\u5904\u7406\u65e5\u671f\u548c\u65f6\u95f4\u6570\u636e\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709\u4e00\u4e2a\u5e26\u6709\u65e5\u671f\u7c7b\u578b\u5217\u7684 pandas \u6570\u636e\u5e27\u3002\u5229\u7528\u8fd9\u4e00\u5217\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4ee5\u4e0b\u7279\u5f81\uff1a \u5e74 \u5e74\u4e2d\u7684\u5468 \u6708 \u661f\u671f \u5468\u672b \u5c0f\u65f6 \u8fd8\u6709\u66f4\u591a \u800c\u4f7f\u7528pandas\u5c31\u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u6dfb\u52a0'year'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5e74\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'year' ] = df [ 'datetime_column' ] . dt . year # \u6dfb\u52a0'weekofyear'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5468\u6570\u63d0\u53d6\u51fa\u6765 df . loc [:, 'weekofyear' ] = df [ 'datetime_column' ] . dt . weekofyear # \u6dfb\u52a0'month'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u6708\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'month' ] = df [ 'datetime_column' ] . dt . month # \u6dfb\u52a0'dayofweek'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u661f\u671f\u51e0\u63d0\u53d6\u51fa\u6765 df . loc [:, 'dayofweek' ] = df [ 'datetime_column' ] . dt . dayofweek # \u6dfb\u52a0'weekend'\u5217\uff0c\u5224\u65ad\u5f53\u5929\u662f\u5426\u4e3a\u5468\u672b df . loc [:, 'weekend' ] = ( df . datetime_column . dt . weekday >= 5 ) . astype ( int ) # \u6dfb\u52a0 'hour' \u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5c0f\u65f6\u63d0\u53d6\u51fa\u6765 df . loc [:, 'hour' ] = df [ 'datetime_column' ] . dt . 
hour \u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u65e5\u671f\u65f6\u95f4\u5217\u521b\u5efa\u4e00\u7cfb\u5217\u65b0\u5217\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u53ef\u4ee5\u521b\u5efa\u7684\u4e00\u4e9b\u793a\u4f8b\u529f\u80fd\u3002 import pandas as pd # \u521b\u5efa\u65e5\u671f\u65f6\u95f4\u5e8f\u5217\uff0c\u5305\u542b\u4e86\u4ece '2020-01-06' \u5230 '2020-01-10' \u7684\u65e5\u671f\u65f6\u95f4\u70b9\uff0c\u65f6\u95f4\u95f4\u9694\u4e3a10\u5c0f\u65f6 s = pd . date_range ( '2020-01-06' , '2020-01-10' , freq = '10H' ) . to_series () # \u63d0\u53d6\u5bf9\u5e94\u65f6\u95f4\u7279\u5f81 features = { \"dayofweek\" : s . dt . dayofweek . values , \"dayofyear\" : s . dt . dayofyear . values , \"hour\" : s . dt . hour . values , \"is_leap_year\" : s . dt . is_leap_year . values , \"quarter\" : s . dt . quarter . values , \"weekofyear\" : s . dt . weekofyear . values } \u8fd9\u5c06\u4ece\u7ed9\u5b9a\u7cfb\u5217\u4e2d\u751f\u6210\u4e00\u4e2a\u7279\u5f81\u5b57\u5178\u3002\u60a8\u53ef\u4ee5\u5c06\u6b64\u5e94\u7528\u4e8e pandas \u6570\u636e\u4e2d\u7684\u4efb\u4f55\u65e5\u671f\u65f6\u95f4\u5217\u3002\u8fd9\u4e9b\u662f pandas \u63d0\u4f9b\u7684\u4f17\u591a\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u4e2d\u7684\u4e00\u90e8\u5206\u3002\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u975e\u5e38\u91cd\u8981\uff0c\u4f8b\u5982\uff0c\u5728\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97\u7684\u9500\u552e\u989d\u65f6\uff0c\u5982\u679c\u60f3\u5728\u805a\u5408\u7279\u5f81\u4e0a\u4f7f\u7528 xgboost \u7b49\u6a21\u578b\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u5c31\u975e\u5e38\u91cd\u8981\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u4e0b\u6240\u793a\u7684\u6570\u636e\uff1a \u56fe 1\uff1a\u5305\u542b\u5206\u7c7b\u548c\u65e5\u671f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e \u5728\u56fe 1 
\u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u6709\u4e00\u4e2a\u65e5\u671f\u5217\uff0c\u4ece\u4e2d\u53ef\u4ee5\u8f7b\u677e\u63d0\u53d6\u5e74\u3001\u6708\u3001\u5b63\u5ea6\u7b49\u7279\u5f81\u3002\u7136\u540e\uff0c\u6211\u4eec\u6709\u4e00\u4e2a customer_id \u5217\uff0c\u8be5\u5217\u6709\u591a\u4e2a\u6761\u76ee\uff0c\u56e0\u6b64\u4e00\u4e2a\u5ba2\u6237\u4f1a\u88ab\u770b\u5230\u5f88\u591a\u6b21\uff08\u622a\u56fe\u4e2d\u770b\u4e0d\u5230\uff09\u3002\u6bcf\u4e2a\u65e5\u671f\u548c\u5ba2\u6237 ID \u90fd\u6709\u4e09\u4e2a\u5206\u7c7b\u7279\u5f81\u548c\u4e00\u4e2a\u6570\u5b57\u7279\u5f81\u3002\u6211\u4eec\u53ef\u4ee5\u4ece\u4e2d\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff1a - \u5ba2\u6237\u6700\u6d3b\u8dc3\u7684\u6708\u4efd\u662f\u51e0\u6708 - \u67d0\u4e2a\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u7684\u8ba1\u6570\u662f\u591a\u5c11 - \u67d0\u5e74\u67d0\u6708\u67d0\u5468\u67d0\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u6570\u91cf\u662f\u591a\u5c11\uff1f - \u67d0\u4e2a\u5ba2\u6237\u7684 num1 \u5e73\u5747\u503c\u662f\u591a\u5c11\uff1f - \u7b49\u7b49\u3002 \u4f7f\u7528 pandas \u4e2d\u7684\u805a\u5408\uff0c\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u521b\u5efa\u7c7b\u4f3c\u7684\u529f\u80fd\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u5b9e\u73b0\u3002 def generate_features ( df ): df . loc [:, 'year' ] = df [ 'date' ] . dt . year df . loc [:, 'weekofyear' ] = df [ 'date' ] . dt . weekofyear df . loc [:, 'month' ] = df [ 'date' ] . dt . month df . loc [:, 'dayofweek' ] = df [ 'date' ] . dt . dayofweek df . loc [:, 'weekend' ] = ( df [ 'date' ] . dt . weekday >= 5 ) . 
astype ( int ) aggs = {} # \u5bf9 'month' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'month' ] = [ 'nunique' , 'mean' ] # \u5bf9 'weekofyear' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'weekofyear' ] = [ 'nunique' , 'mean' ] # \u5bf9 'num1' \u5217\u8fdb\u884c sum\u3001max\u3001min\u3001mean \u805a\u5408 aggs [ 'num1' ] = [ 'sum' , 'max' , 'min' , 'mean' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c size \u805a\u5408 aggs [ 'customer_id' ] = [ 'size' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c nunique \u805a\u5408 aggs [ 'customer_id' ] = [ 'nunique' ] # \u5bf9\u6570\u636e\u5e94\u7528\u4e0d\u540c\u7684\u805a\u5408\u51fd\u6570 agg_df = df . groupby ( 'customer_id' ) . agg ( aggs ) # \u91cd\u7f6e\u7d22\u5f15 agg_df = agg_df . reset_index () return agg_df \u8bf7\u6ce8\u610f\uff0c\u5728\u4e0a\u8ff0\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u8df3\u8fc7\u4e86\u5206\u7c7b\u53d8\u91cf\uff0c\u4f46\u60a8\u53ef\u4ee5\u50cf\u4f7f\u7528\u5176\u4ed6\u805a\u5408\u53d8\u91cf\u4e00\u6837\u4f7f\u7528\u5b83\u4eec\u3002 \u56fe 2\uff1a\u603b\u4f53\u7279\u5f81\u548c\u5176\u4ed6\u7279\u5f81 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u56fe 2 \u4e2d\u7684\u6570\u636e\u4e0e\u5e26\u6709 customer_id \u5217\u7684\u539f\u59cb\u6570\u636e\u5e27\u8fde\u63a5\u8d77\u6765\uff0c\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5e76\u4e0d\u662f\u8981\u9884\u6d4b\u4ec0\u4e48\uff1b\u6211\u4eec\u53ea\u662f\u5728\u521b\u5efa\u901a\u7528\u7279\u5f81\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u8bd5\u56fe\u5728\u8fd9\u91cc\u9884\u6d4b\u4ec0\u4e48\uff0c\u521b\u5efa\u7279\u5f81\u4f1a\u66f4\u5bb9\u6613\u3002 \u4f8b\u5982\uff0c\u6709\u65f6\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u95ee\u9898\u65f6\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u7684\u7279\u5f81\u4e0d\u662f\u5355\u4e2a\u503c\uff0c\u800c\u662f\u4e00\u7cfb\u5217\u503c\u3002 
\u4f8b\u5982\uff0c\u5ba2\u6237\u5728\u7279\u5b9a\u65f6\u95f4\u6bb5\u5185\u7684\u4ea4\u6613\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4f1a\u521b\u5efa\u4e0d\u540c\u7c7b\u578b\u7684\u7279\u5f81\uff0c\u4f8b\u5982\uff1a\u4f7f\u7528\u6570\u503c\u7279\u5f81\u65f6\uff0c\u5728\u5bf9\u5206\u7c7b\u5217\u8fdb\u884c\u5206\u7ec4\u65f6\uff0c\u4f1a\u5f97\u5230\u7c7b\u4f3c\u4e8e\u65f6\u95f4\u5206\u5e03\u503c\u5217\u8868\u7684\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u7cfb\u5217\u7edf\u8ba1\u7279\u5f81\uff0c\u4f8b\u5982 \u5e73\u5747\u503c \u6700\u5927\u503c \u6700\u5c0f\u503c \u72ec\u7279\u6027 \u504f\u659c \u5cf0\u5ea6 Kstat \u767e\u5206\u4f4d\u6570 \u5b9a\u91cf \u5cf0\u503c\u5230\u5cf0\u503c \u4ee5\u53ca\u66f4\u591a \u8fd9\u4e9b\u53ef\u4ee5\u4f7f\u7528\u7b80\u5355\u7684 numpy \u51fd\u6570\u521b\u5efa\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 import numpy as np # \u521b\u5efa\u5b57\u5178\uff0c\u7528\u4e8e\u5b58\u50a8\u4e0d\u540c\u7684\u7edf\u8ba1\u7279\u5f81 feature_dict = {} # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5e73\u5747\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'mean' \u952e\u4e0b feature_dict [ 'mean' ] = np . mean ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5927\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'max' \u952e\u4e0b feature_dict [ 'max' ] = np . max ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5c0f\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'min' \u952e\u4e0b feature_dict [ 'min' ] = np . min ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6807\u51c6\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'std' \u952e\u4e0b feature_dict [ 'std' ] = np . 
std ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u65b9\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'var' \u952e\u4e0b feature_dict [ 'var' ] = np . var ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5dee\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'ptp' \u952e\u4e0b feature_dict [ 'ptp' ] = np . ptp ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c10\u767e\u5206\u4f4d\u6570\uff08\u5373\u767e\u5206\u4e4b10\u5206\u4f4d\u6570\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_10' \u952e\u4e0b feature_dict [ 'percentile_10' ] = np . percentile ( x , 10 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c60\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_60' \u952e\u4e0b feature_dict [ 'percentile_60' ] = np . percentile ( x , 60 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c90\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_90' \u952e\u4e0b feature_dict [ 'percentile_90' ] = np . percentile ( x , 90 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u76845%\u5206\u4f4d\u6570\uff08\u53730.05\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_5' \u952e\u4e0b feature_dict [ 'quantile_5' ] = np . quantile ( x , 0.05 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768495%\u5206\u4f4d\u6570\uff08\u53730.95\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_95' \u952e\u4e0b feature_dict [ 'quantile_95' ] = np . quantile ( x , 0.95 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768499%\u5206\u4f4d\u6570\uff08\u53730.99\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_99' \u952e\u4e0b feature_dict [ 'quantile_99' ] = np . 
quantile ( x , 0.99 ) \u65f6\u95f4\u5e8f\u5217\u6570\u636e\uff08\u6570\u503c\u5217\u8868\uff09\u53ef\u4ee5\u8f6c\u6362\u6210\u8bb8\u591a\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4e00\u4e2a\u540d\u4e3a tsfresh \u7684 python \u5e93\u975e\u5e38\u6709\u7528\u3002 from tsfresh.feature_extraction import feature_calculators as fc # \u8ba1\u7b97 x \u6570\u5217\u7684\u7edd\u5bf9\u80fd\u91cf\uff08abs_energy\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'abs_energy' \u952e\u4e0b feature_dict [ 'abs_energy' ] = fc . abs_energy ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u9ad8\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_above_mean' \u952e\u4e0b feature_dict [ 'count_above_mean' ] = fc . count_above_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u4f4e\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_below_mean' \u952e\u4e0b feature_dict [ 'count_below_mean' ] = fc . count_below_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u7edd\u5bf9\u53d8\u5316\uff08mean_abs_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_abs_change' \u952e\u4e0b feature_dict [ 'mean_abs_change' ] = fc . mean_abs_change ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u53d8\u5316\u7387\uff08mean_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_change' \u952e\u4e0b feature_dict [ 'mean_change' ] = fc . 
mean_change ( x ) \u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\uff1btsfresh \u63d0\u4f9b\u4e86\u6570\u767e\u79cd\u7279\u5f81\u548c\u6570\u5341\u79cd\u4e0d\u540c\u7279\u5f81\u7684\u53d8\u4f53\uff0c\u4f60\u53ef\u4ee5\u5c06\u5b83\u4eec\u7528\u4e8e\u57fa\u4e8e\u65f6\u95f4\u5e8f\u5217\uff08\u503c\u5217\u8868\uff09\u7684\u7279\u5f81\u3002\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0cx \u662f\u4e00\u4e2a\u503c\u5217\u8868\u3002\u4f46\u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\u3002\u60a8\u8fd8\u53ef\u4ee5\u4e3a\u5305\u542b\u6216\u4e0d\u5305\u542b\u5206\u7c7b\u6570\u636e\u7684\u6570\u503c\u6570\u636e\u521b\u5efa\u8bb8\u591a\u5176\u4ed6\u7279\u5f81\u3002\u751f\u6210\u8bb8\u591a\u7279\u5f81\u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u521b\u5efa\u4e00\u5806\u591a\u9879\u5f0f\u7279\u5f81\u3002\u4f8b\u5982\uff0c\u4ece\u4e24\u4e2a\u7279\u5f81 \"a \"\u548c \"b \"\u751f\u6210\u7684\u4e8c\u7ea7\u591a\u9879\u5f0f\u7279\u5f81\u5305\u62ec \"a\"\u3001\"b\"\u3001\"ab\"\u3001\"a^2 \"\u548c \"b^2\"\u3002 import numpy as np df = pd . DataFrame ( np . random . rand ( 100 , 2 ), columns = [ f \"f_ { i } \" for i in range ( 1 , 3 )]) \u5982\u56fe 3 \u6240\u793a\uff0c\u5b83\u7ed9\u51fa\u4e86\u4e00\u4e2a\u6570\u636e\u8868\u3002 \u56fe 3\uff1a\u5305\u542b\u4e24\u4e2a\u6570\u5b57\u7279\u5f81\u7684\u968f\u673a\u6570\u636e\u8868 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u7684 PolynomialFeatures \u521b\u5efa\u4e24\u6b21\u591a\u9879\u5f0f\u7279\u5f81\u3002 from sklearn import preprocessing # \u6307\u5b9a\u591a\u9879\u5f0f\u7684\u6b21\u6570\u4e3a 2\uff0c\u4e0d\u4ec5\u8003\u8651\u4ea4\u4e92\u9879\uff0c\u4e0d\u5305\u62ec\u504f\u5dee\uff08include_bias=False\uff09 pf = preprocessing . PolynomialFeatures ( degree = 2 , interaction_only = False , include_bias = False ) # \u62df\u5408\uff0c\u521b\u5efa\u591a\u9879\u5f0f\u7279\u5f81 pf . fit ( df ) # \u8f6c\u6362\u6570\u636e poly_feats = pf . 
transform ( df ) # \u83b7\u53d6\u751f\u6210\u7684\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf num_feats = poly_feats . shape [ 1 ] # \u4e3a\u65b0\u751f\u6210\u7684\u7279\u5f81\u547d\u540d df_transformed = pd . DataFrame ( poly_feats , columns = [ f \"f_ { i } \" for i in range ( 1 , num_feats + 1 )] ) \u8fd9\u6837\u5c31\u5f97\u5230\u4e86\u4e00\u4e2a\u6570\u636e\u8868\uff0c\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1a\u5e26\u6709\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e\u8868 \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e9b\u591a\u9879\u5f0f\u7279\u5f81\u3002\u5982\u679c\u521b\u5efa\u7684\u662f\u4e09\u6b21\u591a\u9879\u5f0f\u7279\u5f81\uff0c\u6700\u7ec8\u603b\u5171\u4f1a\u6709\u4e5d\u4e2a\u7279\u5f81\u3002\u7279\u5f81\u7684\u6570\u91cf\u8d8a\u591a\uff0c\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf\u4e5f\u5c31\u8d8a\u591a\uff0c\u800c\u4e14\u4f60\u8fd8\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5982\u679c\u6570\u636e\u96c6\u4e2d\u6709\u5f88\u591a\u6837\u672c\uff0c\u90a3\u4e48\u521b\u5efa\u8fd9\u7c7b\u7279\u5f81\u5c31\u9700\u8981\u82b1\u8d39\u4e00\u4e9b\u65f6\u95f4\u3002 \u56fe 5\uff1a\u6570\u5b57\u7279\u5f81\u5217\u7684\u76f4\u65b9\u56fe \u53e6\u4e00\u4e2a\u6709\u8da3\u7684\u529f\u80fd\u662f\u5c06\u6570\u5b57\u8f6c\u6362\u4e3a\u7c7b\u522b\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \u5206\u7bb1 \u3002\u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u56fe 5\uff0c\u5b83\u663e\u793a\u4e86\u4e00\u4e2a\u968f\u673a\u6570\u5b57\u7279\u5f81\u7684\u6837\u672c\u76f4\u65b9\u56fe\u3002\u6211\u4eec\u5728\u8be5\u56fe\u4e2d\u4f7f\u7528\u4e8610\u4e2a\u5206\u7bb1\uff0c\u53ef\u4ee5\u770b\u5230\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a10\u4e2a\u90e8\u5206\u3002\u8fd9\u53ef\u4ee5\u4f7f\u7528 pandas \u7684cat\u51fd\u6570\u6765\u5b9e\u73b0\u3002 # \u521b\u5efa10\u4e2a\u5206\u7bb1 df [ \"f_bin_10\" ] = pd . cut ( df [ \"f_1\" ], bins = 10 , labels = False ) # \u521b\u5efa100\u4e2a\u5206\u7bb1 df [ \"f_bin_100\" ] = pd . 
cut ( df [ \"f_1\" ], bins = 100 , labels = False ) \u5982\u56fe 6 \u6240\u793a\uff0c\u8fd9\u5c06\u5728\u6570\u636e\u5e27\u4e2d\u751f\u6210\u4e24\u4e2a\u65b0\u7279\u5f81\u3002 \u56fe 6\uff1a\u6570\u503c\u7279\u5f81\u5206\u7bb1 \u5f53\u4f60\u8fdb\u884c\u5206\u7c7b\u65f6\uff0c\u53ef\u4ee5\u540c\u65f6\u4f7f\u7528\u5206\u7bb1\u548c\u539f\u59cb\u7279\u5f81\u3002\u6211\u4eec\u5c06\u5728\u672c\u7ae0\u540e\u534a\u90e8\u5206\u5b66\u4e60\u66f4\u591a\u5173\u4e8e\u9009\u62e9\u7279\u5f81\u7684\u77e5\u8bc6\u3002\u5206\u7bb1\u8fd8\u53ef\u4ee5\u5c06\u6570\u5b57\u7279\u5f81\u89c6\u4e3a\u5206\u7c7b\u7279\u5f81\u3002 \u53e6\u4e00\u79cd\u53ef\u4ee5\u4ece\u6570\u503c\u7279\u5f81\u4e2d\u521b\u5efa\u7684\u6709\u8da3\u7279\u5f81\u7c7b\u578b\u662f\u5bf9\u6570\u53d8\u6362\u3002\u8bf7\u770b\u56fe 7 \u4e2d\u7684\u7279\u5f81 f_3\u3002 \u4e0e\u5176\u4ed6\u65b9\u5dee\u8f83\u5c0f\u7684\u7279\u5f81\u76f8\u6bd4\uff08\u5047\u8bbe\u5982\u6b64\uff09\uff0cf_3 \u662f\u4e00\u79cd\u65b9\u5dee\u975e\u5e38\u5927\u7684\u7279\u6b8a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5e0c\u671b\u964d\u4f4e\u8fd9\u4e00\u5217\u7684\u65b9\u5dee\uff0c\u8fd9\u53ef\u4ee5\u901a\u8fc7\u5bf9\u6570\u53d8\u6362\u6765\u5b9e\u73b0\u3002 f_3 \u5217\u7684\u503c\u8303\u56f4\u4e3a 0 \u5230 10000\uff0c\u76f4\u65b9\u56fe\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u7279\u5f81 f_3 \u7684\u76f4\u65b9\u56fe \u6211\u4eec\u53ef\u4ee5\u5bf9\u8fd9\u4e00\u5217\u5e94\u7528 log(1 + x) \u6765\u51cf\u5c11\u5176\u65b9\u5dee\u3002\u56fe 9 \u663e\u793a\u4e86\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u76f4\u65b9\u56fe\u7684\u53d8\u5316\u3002 \u56fe 9\uff1a\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u7684 f_3 \u76f4\u65b9\u56fe \u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e0d\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u548c\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u7684\u65b9\u5dee\u3002 In [ X ]: df . f_3 . var () Out [ X ]: 8077265.875858586 In [ X ]: df . f_3 . apply ( lambda x : np . log ( 1 + x )) . 
var () Out [ X ]: 0.6058771732119975 \u6709\u65f6\uff0c\u4e5f\u53ef\u4ee5\u7528\u6307\u6570\u6765\u4ee3\u66ff\u5bf9\u6570\u3002\u4e00\u79cd\u975e\u5e38\u6709\u8da3\u7684\u60c5\u51b5\u662f\uff0c\u60a8\u4f7f\u7528\u57fa\u4e8e\u5bf9\u6570\u7684\u8bc4\u4f30\u6307\u6807\uff0c\u4f8b\u5982 RMSLE\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u5728\u5bf9\u6570\u53d8\u6362\u7684\u76ee\u6807\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u7136\u540e\u5728\u9884\u6d4b\u65f6\u4f7f\u7528\u6307\u6570\u503c\u8f6c\u6362\u56de\u539f\u59cb\u503c\u3002\u8fd9\u5c06\u6709\u52a9\u4e8e\u9488\u5bf9\u6307\u6807\u4f18\u5316\u6a21\u578b\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u7c7b\u6570\u5b57\u7279\u5f81\u90fd\u662f\u57fa\u4e8e\u76f4\u89c9\u521b\u5efa\u7684\u3002\u6ca1\u6709\u516c\u5f0f\u53ef\u5faa\u3002\u5982\u679c\u60a8\u4ece\u4e8b\u7684\u662f\u67d0\u4e00\u884c\u4e1a\uff0c\u60a8\u5c06\u521b\u5efa\u7279\u5b9a\u884c\u4e1a\u7684\u7279\u5f81\u3002 \u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u548c\u6570\u503c\u53d8\u91cf\u65f6\uff0c\u53ef\u80fd\u4f1a\u9047\u5230\u7f3a\u5931\u503c\u3002\u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u4e00\u4e9b\u5904\u7406\u5206\u7c7b\u7279\u5f81\u4e2d\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\uff0c\u4f46\u8fd8\u6709\u66f4\u591a\u65b9\u6cd5\u53ef\u4ee5\u5904\u7406\u7f3a\u5931\u503c/NaN \u503c\u3002\u8fd9\u4e5f\u88ab\u89c6\u4e3a\u7279\u5f81\u5de5\u7a0b\u3002 \u5982\u679c\u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u9047\u5230\u7f3a\u5931\u503c\uff0c\u5c31\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\uff01\u8fd9\u6837\u505a\u867d\u7136\u7b80\u5355\uff0c\u4f46\uff08\u51e0\u4e4e\uff09\u603b\u662f\u6709\u6548\u7684\uff01 
\u5728\u6570\u503c\u6570\u636e\u4e2d\u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u9009\u62e9\u4e00\u4e2a\u5728\u7279\u5b9a\u7279\u5f81\u4e2d\u6ca1\u6709\u51fa\u73b0\u7684\u503c\uff0c\u7136\u540e\u7528\u5b83\u6765\u586b\u8865\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u7279\u5f81\u4e2d\u6ca1\u6709 0\u3002\u8fd9\u662f\u5176\u4e2d\u4e00\u79cd\u65b9\u6cd5\uff0c\u4f46\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\u3002\u5bf9\u4e8e\u6570\u503c\u6570\u636e\u6765\u8bf4\uff0c\u6bd4\u586b\u5145 0 \u66f4\u6709\u6548\u7684\u65b9\u6cd5\u4e4b\u4e00\u662f\u4f7f\u7528\u5e73\u5747\u503c\u8fdb\u884c\u586b\u5145\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8be5\u7279\u5f81\u6240\u6709\u503c\u7684\u4e2d\u4f4d\u6570\u6765\u586b\u5145\uff0c\u6216\u8005\u4f7f\u7528\u6700\u5e38\u89c1\u7684\u503c\u6765\u586b\u5145\u7f3a\u5931\u503c\u3002\u8fd9\u6837\u505a\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002 \u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u9ad8\u7ea7\u65b9\u6cd5\u662f\u4f7f\u7528 K \u8fd1\u90bb\u6cd5 \u3002 \u60a8\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\uff0c\u7136\u540e\u5229\u7528\u67d0\u79cd\u8ddd\u79bb\u5ea6\u91cf\uff08\u4f8b\u5982\u6b27\u6c0f\u8ddd\u79bb\uff09\u627e\u5230\u6700\u8fd1\u7684\u90bb\u5c45\u3002\u7136\u540e\u53d6\u6240\u6709\u8fd1\u90bb\u7684\u5e73\u5747\u503c\u6765\u586b\u8865\u7f3a\u5931\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 KNN \u6765\u586b\u8865\u8fd9\u6837\u7684\u7f3a\u5931\u503c\u3002 \u56fe 10\uff1a\u6709\u7f3a\u5931\u503c\u7684\u4e8c\u7ef4\u6570\u7ec4 \u8ba9\u6211\u4eec\u770b\u770b KNN \u662f\u5982\u4f55\u5904\u7406\u56fe 10 \u6240\u793a\u7684\u7f3a\u5931\u503c\u77e9\u9635\u7684\u3002 import numpy as np from sklearn import impute # \u751f\u6210\u7ef4\u5ea6\u4e3a (10, 6) \u7684\u968f\u673a\u6574\u6570\u77e9\u9635 X\uff0c\u6570\u503c\u8303\u56f4\u5728 1 \u5230 14 \u4e4b\u95f4 X = np . random . randint ( 1 , 15 , ( 10 , 6 )) # \u6570\u636e\u7c7b\u578b\u8f6c\u6362\u4e3a float X = X . 
astype ( float ) # \u5728\u77e9\u9635 X \u4e2d\u968f\u673a\u9009\u62e9 10 \u4e2a\u4f4d\u7f6e\uff0c\u5c06\u8fd9\u4e9b\u4f4d\u7f6e\u7684\u5143\u7d20\u8bbe\u7f6e\u4e3a NaN\uff08\u7f3a\u5931\u503c\uff09 X . ravel ()[ np . random . choice ( X . size , 10 , replace = False )] = np . nan # \u521b\u5efa\u4e00\u4e2a KNNImputer \u5bf9\u8c61 knn_imputer\uff0c\u6307\u5b9a\u90bb\u5c45\u6570\u91cf\u4e3a 2 knn_imputer = impute . KNNImputer ( n_neighbors = 2 ) # # \u4f7f\u7528 knn_imputer \u5bf9\u77e9\u9635 X \u8fdb\u884c\u62df\u5408\u548c\u8f6c\u6362\uff0c\u7528 K-\u6700\u8fd1\u90bb\u65b9\u6cd5\u586b\u8865\u7f3a\u5931\u503c knn_imputer . fit_transform ( X ) \u5982\u56fe 11 \u6240\u793a\uff0c\u5b83\u586b\u5145\u4e86\u4e0a\u8ff0\u77e9\u9635\u3002 \u56fe 11\uff1aKNN\u4f30\u7b97\u7684\u6570\u503c \u53e6\u4e00\u79cd\u5f25\u8865\u5217\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\u662f\u8bad\u7ec3\u56de\u5f52\u6a21\u578b\uff0c\u8bd5\u56fe\u6839\u636e\u5176\u4ed6\u5217\u9884\u6d4b\u67d0\u5217\u7684\u7f3a\u5931\u503c\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u4ece\u6709\u7f3a\u5931\u503c\u7684\u4e00\u5217\u5f00\u59cb\uff0c\u5c06\u8fd9\u4e00\u5217\u4f5c\u4e3a\u65e0\u7f3a\u5931\u503c\u56de\u5f52\u6a21\u578b\u7684\u76ee\u6807\u5217\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u6240\u6709\u5176\u4ed6\u5217\uff0c\u5bf9\u76f8\u5173\u5217\u4e2d\u6ca1\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\uff0c\u7136\u540e\u5c1d\u8bd5\u9884\u6d4b\u4e4b\u524d\u5220\u9664\u7684\u6837\u672c\u7684\u76ee\u6807\u5217\uff08\u540c\u4e00\u5217\uff09\u3002\u8fd9\u6837\uff0c\u57fa\u4e8e\u6a21\u578b\u7684\u4f30\u7b97\u5c31\u4f1a\u66f4\u52a0\u7a33\u5065\u3002 \u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u5bf9\u4e8e\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6ca1\u6709\u5fc5\u8981\u8fdb\u884c\u6570\u503c\u5f52\u4e00\u5316\uff0c\u56e0\u4e3a\u5b83\u4eec\u53ef\u4ee5\u81ea\u884c\u5904\u7406\u3002 
\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u6240\u5c55\u793a\u7684\u53ea\u662f\u521b\u5efa\u4e00\u822c\u7279\u5f81\u7684\u4e00\u4e9b\u65b9\u6cd5\u3002\u73b0\u5728\uff0c\u5047\u8bbe\u60a8\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u9884\u6d4b\u4e0d\u540c\u5546\u54c1\uff08\u6bcf\u5468\u6216\u6bcf\u6708\uff09\u5546\u5e97\u9500\u552e\u989d\u7684\u95ee\u9898\u3002\u60a8\u6709\u5546\u54c1\uff0c\u4e5f\u6709\u5546\u5e97 ID\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u6bcf\u4e2a\u5546\u5e97\u7684\u5546\u54c1\u7b49\u7279\u5f81\u3002\u73b0\u5728\uff0c\u8fd9\u662f\u4e0a\u6587\u6ca1\u6709\u8ba8\u8bba\u7684\u7279\u5f81\u4e4b\u4e00\u3002\u8fd9\u7c7b\u7279\u5f81\u4e0d\u80fd\u4e00\u6982\u800c\u8bba\uff0c\u5b8c\u5168\u6765\u81ea\u4e8e\u9886\u57df\u3001\u6570\u636e\u548c\u4e1a\u52a1\u77e5\u8bc6\u3002\u67e5\u770b\u6570\u636e\uff0c\u627e\u51fa\u9002\u5408\u7684\u7279\u5f81\uff0c\u7136\u540e\u521b\u5efa\u76f8\u5e94\u7684\u7279\u5f81\u3002\u5982\u679c\u60a8\u4f7f\u7528\u7684\u662f\u903b\u8f91\u56de\u5f52\u7b49\u7ebf\u6027\u6a21\u578b\u6216 SVM 
\u7b49\u6a21\u578b\uff0c\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\u5bf9\u7279\u5f81\u8fdb\u884c\u7f29\u653e\u6216\u5f52\u4e00\u5316\u5904\u7406\u3002\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65e0\u9700\u5bf9\u7279\u5f81\u8fdb\u884c\u4efb\u4f55\u5f52\u4e00\u5316\u5904\u7406\u5373\u53ef\u6b63\u5e38\u5de5\u4f5c\u3002","title":"\u7279\u5f81\u5de5\u7a0b"},{"location":"%E7%89%B9%E5%BE%81%E5%B7%A5%E7%A8%8B/#_1","text":"\u7279\u5f81\u5de5\u7a0b\u662f\u6784\u5efa\u826f\u597d\u673a\u5668\u5b66\u4e60\u6a21\u578b\u7684\u6700\u5173\u952e\u90e8\u5206\u4e4b\u4e00\u3002\u5982\u679c\u6211\u4eec\u62e5\u6709\u6709\u7528\u7684\u7279\u5f81\uff0c\u6a21\u578b\u5c31\u4f1a\u8868\u73b0\u5f97\u66f4\u597d\u3002\u5728\u8bb8\u591a\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u907f\u514d\u4f7f\u7528\u5927\u578b\u590d\u6742\u6a21\u578b\uff0c\u800c\u4f7f\u7528\u5177\u6709\u5173\u952e\u5de5\u7a0b\u7279\u5f81\u7684\u7b80\u5355\u6a21\u578b\u3002\u6211\u4eec\u5fc5\u987b\u7262\u8bb0\uff0c\u53ea\u6709\u5f53\u4f60\u5bf9\u95ee\u9898\u7684\u9886\u57df\u6709\u4e00\u5b9a\u7684\u4e86\u89e3\uff0c\u5e76\u4e14\u5728\u5f88\u5927\u7a0b\u5ea6\u4e0a\u53d6\u51b3\u4e8e\u76f8\u5173\u6570\u636e\u65f6\uff0c\u624d\u80fd\u4ee5\u6700\u4f73\u65b9\u5f0f\u5b8c\u6210\u7279\u5f81\u5de5\u7a0b\u3002\u4e0d\u8fc7\uff0c\u60a8\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u4e00\u4e9b\u901a\u7528\u6280\u672f\uff0c\u4ece\u51e0\u4e4e\u6240\u6709\u7c7b\u578b\u7684\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u4e2d\u521b\u5efa\u7279\u5f81\u3002\u7279\u5f81\u5de5\u7a0b\u4e0d\u4ec5\u4ec5\u662f\u4ece\u6570\u636e\u4e2d\u521b\u5efa\u65b0\u7279\u5f81\uff0c\u8fd8\u5305\u62ec\u4e0d\u540c\u7c7b\u578b\u7684\u5f52\u4e00\u5316\u548c\u8f6c\u6362\u3002 
\u5728\u6709\u5173\u5206\u7c7b\u7279\u5f81\u7684\u7ae0\u8282\u4e2d\uff0c\u6211\u4eec\u5df2\u7ecf\u4e86\u89e3\u4e86\u7ed3\u5408\u4e0d\u540c\u5206\u7c7b\u53d8\u91cf\u7684\u65b9\u6cd5\u3001\u5982\u4f55\u5c06\u5206\u7c7b\u53d8\u91cf\u8f6c\u6362\u4e3a\u8ba1\u6570\u3001\u6807\u7b7e\u7f16\u7801\u548c\u4f7f\u7528\u5d4c\u5165\u3002\u8fd9\u4e9b\u51e0\u4e4e\u90fd\u662f\u5229\u7528\u5206\u7c7b\u53d8\u91cf\u8bbe\u8ba1\u7279\u5f81\u7684\u65b9\u6cd5\u3002\u56e0\u6b64\uff0c\u5728\u672c\u7ae0\u4e2d\uff0c\u6211\u4eec\u7684\u91cd\u70b9\u5c06\u4ec5\u9650\u4e8e\u6570\u503c\u53d8\u91cf\u4ee5\u53ca\u6570\u503c\u53d8\u91cf\u548c\u5206\u7c7b\u53d8\u91cf\u7684\u7ec4\u5408\u3002 \u8ba9\u6211\u4eec\u4ece\u6700\u7b80\u5355\u4f46\u5e94\u7528\u6700\u5e7f\u6cdb\u7684\u7279\u5f81\u5de5\u7a0b\u6280\u672f\u5f00\u59cb\u3002\u5047\u8bbe\u4f60\u6b63\u5728\u5904\u7406\u65e5\u671f\u548c\u65f6\u95f4\u6570\u636e\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709\u4e00\u4e2a\u5e26\u6709\u65e5\u671f\u7c7b\u578b\u5217\u7684 pandas \u6570\u636e\u5e27\u3002\u5229\u7528\u8fd9\u4e00\u5217\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4ee5\u4e0b\u7279\u5f81\uff1a \u5e74 \u5e74\u4e2d\u7684\u5468 \u6708 \u661f\u671f \u5468\u672b \u5c0f\u65f6 \u8fd8\u6709\u66f4\u591a \u800c\u4f7f\u7528pandas\u5c31\u53ef\u4ee5\u975e\u5e38\u5bb9\u6613\u5730\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u6dfb\u52a0'year'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5e74\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'year' ] = df [ 'datetime_column' ] . dt . year # \u6dfb\u52a0'weekofyear'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5468\u6570\u63d0\u53d6\u51fa\u6765 df . loc [:, 'weekofyear' ] = df [ 'datetime_column' ] . dt . weekofyear # \u6dfb\u52a0'month'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u6708\u4efd\u63d0\u53d6\u51fa\u6765 df . loc [:, 'month' ] = df [ 'datetime_column' ] . dt . month # \u6dfb\u52a0'dayofweek'\u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u661f\u671f\u51e0\u63d0\u53d6\u51fa\u6765 df . 
loc [:, 'dayofweek' ] = df [ 'datetime_column' ] . dt . dayofweek # \u6dfb\u52a0'weekend'\u5217\uff0c\u5224\u65ad\u5f53\u5929\u662f\u5426\u4e3a\u5468\u672b df . loc [:, 'weekend' ] = ( df . datetime_column . dt . weekday >= 5 ) . astype ( int ) # \u6dfb\u52a0 'hour' \u5217\uff0c\u5c06 'datetime_column' \u4e2d\u7684\u5c0f\u65f6\u63d0\u53d6\u51fa\u6765 df . loc [:, 'hour' ] = df [ 'datetime_column' ] . dt . hour \u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u4f7f\u7528\u65e5\u671f\u65f6\u95f4\u5217\u521b\u5efa\u4e00\u7cfb\u5217\u65b0\u5217\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u53ef\u4ee5\u521b\u5efa\u7684\u4e00\u4e9b\u793a\u4f8b\u529f\u80fd\u3002 import pandas as pd # \u521b\u5efa\u65e5\u671f\u65f6\u95f4\u5e8f\u5217\uff0c\u5305\u542b\u4e86\u4ece '2020-01-06' \u5230 '2020-01-10' \u7684\u65e5\u671f\u65f6\u95f4\u70b9\uff0c\u65f6\u95f4\u95f4\u9694\u4e3a10\u5c0f\u65f6 s = pd . date_range ( '2020-01-06' , '2020-01-10' , freq = '10H' ) . to_series () # \u63d0\u53d6\u5bf9\u5e94\u65f6\u95f4\u7279\u5f81 features = { \"dayofweek\" : s . dt . dayofweek . values , \"dayofyear\" : s . dt . dayofyear . values , \"hour\" : s . dt . hour . values , \"is_leap_year\" : s . dt . is_leap_year . values , \"quarter\" : s . dt . quarter . values , \"weekofyear\" : s . dt . weekofyear . 
values } \u8fd9\u5c06\u4ece\u7ed9\u5b9a\u7cfb\u5217\u4e2d\u751f\u6210\u4e00\u4e2a\u7279\u5f81\u5b57\u5178\u3002\u60a8\u53ef\u4ee5\u5c06\u6b64\u5e94\u7528\u4e8e pandas \u6570\u636e\u4e2d\u7684\u4efb\u4f55\u65e5\u671f\u65f6\u95f4\u5217\u3002\u8fd9\u4e9b\u662f pandas \u63d0\u4f9b\u7684\u4f17\u591a\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u4e2d\u7684\u4e00\u90e8\u5206\u3002\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u6570\u636e\u65f6\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u975e\u5e38\u91cd\u8981\uff0c\u4f8b\u5982\uff0c\u5728\u9884\u6d4b\u4e00\u5bb6\u5546\u5e97\u7684\u9500\u552e\u989d\u65f6\uff0c\u5982\u679c\u60f3\u5728\u805a\u5408\u7279\u5f81\u4e0a\u4f7f\u7528 xgboost \u7b49\u6a21\u578b\uff0c\u65e5\u671f\u65f6\u95f4\u7279\u5f81\u5c31\u975e\u5e38\u91cd\u8981\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a\u5982\u4e0b\u6240\u793a\u7684\u6570\u636e\uff1a \u56fe 1\uff1a\u5305\u542b\u5206\u7c7b\u548c\u65e5\u671f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e \u5728\u56fe 1 \u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\u6709\u4e00\u4e2a\u65e5\u671f\u5217\uff0c\u4ece\u4e2d\u53ef\u4ee5\u8f7b\u677e\u63d0\u53d6\u5e74\u3001\u6708\u3001\u5b63\u5ea6\u7b49\u7279\u5f81\u3002\u7136\u540e\uff0c\u6211\u4eec\u6709\u4e00\u4e2a customer_id \u5217\uff0c\u8be5\u5217\u6709\u591a\u4e2a\u6761\u76ee\uff0c\u56e0\u6b64\u4e00\u4e2a\u5ba2\u6237\u4f1a\u88ab\u770b\u5230\u5f88\u591a\u6b21\uff08\u622a\u56fe\u4e2d\u770b\u4e0d\u5230\uff09\u3002\u6bcf\u4e2a\u65e5\u671f\u548c\u5ba2\u6237 ID \u90fd\u6709\u4e09\u4e2a\u5206\u7c7b\u7279\u5f81\u548c\u4e00\u4e2a\u6570\u5b57\u7279\u5f81\u3002\u6211\u4eec\u53ef\u4ee5\u4ece\u4e2d\u521b\u5efa\u5927\u91cf\u7279\u5f81\uff1a - \u5ba2\u6237\u6700\u6d3b\u8dc3\u7684\u6708\u4efd\u662f\u51e0\u6708 - \u67d0\u4e2a\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u7684\u8ba1\u6570\u662f\u591a\u5c11 - \u67d0\u5e74\u67d0\u6708\u67d0\u5468\u67d0\u5ba2\u6237\u7684 cat1\u3001cat2\u3001cat3 \u6570\u91cf\u662f\u591a\u5c11\uff1f - \u67d0\u4e2a\u5ba2\u6237\u7684 num1 
\u5e73\u5747\u503c\u662f\u591a\u5c11\uff1f - \u7b49\u7b49\u3002 \u4f7f\u7528 pandas \u4e2d\u7684\u805a\u5408\uff0c\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u521b\u5efa\u7c7b\u4f3c\u7684\u529f\u80fd\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u5b9e\u73b0\u3002 def generate_features ( df ): df . loc [:, 'year' ] = df [ 'date' ] . dt . year df . loc [:, 'weekofyear' ] = df [ 'date' ] . dt . weekofyear df . loc [:, 'month' ] = df [ 'date' ] . dt . month df . loc [:, 'dayofweek' ] = df [ 'date' ] . dt . dayofweek df . loc [:, 'weekend' ] = ( df [ 'date' ] . dt . weekday >= 5 ) . astype ( int ) aggs = {} # \u5bf9 'month' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'month' ] = [ 'nunique' , 'mean' ] # \u5bf9 'weekofyear' \u5217\u8fdb\u884c nunique \u548c mean \u805a\u5408 aggs [ 'weekofyear' ] = [ 'nunique' , 'mean' ] # \u5bf9 'num1' \u5217\u8fdb\u884c sum\u3001max\u3001min\u3001mean \u805a\u5408 aggs [ 'num1' ] = [ 'sum' , 'max' , 'min' , 'mean' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c size \u805a\u5408 aggs [ 'customer_id' ] = [ 'size' ] # \u5bf9 'customer_id' \u5217\u8fdb\u884c nunique \u805a\u5408 aggs [ 'customer_id' ] = [ 'nunique' ] # \u5bf9\u6570\u636e\u5e94\u7528\u4e0d\u540c\u7684\u805a\u5408\u51fd\u6570 agg_df = df . groupby ( 'customer_id' ) . agg ( aggs ) # \u91cd\u7f6e\u7d22\u5f15 agg_df = agg_df . 
reset_index () return agg_df \u8bf7\u6ce8\u610f\uff0c\u5728\u4e0a\u8ff0\u51fd\u6570\u4e2d\uff0c\u6211\u4eec\u8df3\u8fc7\u4e86\u5206\u7c7b\u53d8\u91cf\uff0c\u4f46\u60a8\u53ef\u4ee5\u50cf\u4f7f\u7528\u5176\u4ed6\u805a\u5408\u53d8\u91cf\u4e00\u6837\u4f7f\u7528\u5b83\u4eec\u3002 \u56fe 2\uff1a\u603b\u4f53\u7279\u5f81\u548c\u5176\u4ed6\u7279\u5f81 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u56fe 2 \u4e2d\u7684\u6570\u636e\u4e0e\u5e26\u6709 customer_id \u5217\u7684\u539f\u59cb\u6570\u636e\u5e27\u8fde\u63a5\u8d77\u6765\uff0c\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5e76\u4e0d\u662f\u8981\u9884\u6d4b\u4ec0\u4e48\uff1b\u6211\u4eec\u53ea\u662f\u5728\u521b\u5efa\u901a\u7528\u7279\u5f81\u3002\u4e0d\u8fc7\uff0c\u5982\u679c\u6211\u4eec\u8bd5\u56fe\u5728\u8fd9\u91cc\u9884\u6d4b\u4ec0\u4e48\uff0c\u521b\u5efa\u7279\u5f81\u4f1a\u66f4\u5bb9\u6613\u3002 \u4f8b\u5982\uff0c\u6709\u65f6\u5728\u5904\u7406\u65f6\u95f4\u5e8f\u5217\u95ee\u9898\u65f6\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u7684\u7279\u5f81\u4e0d\u662f\u5355\u4e2a\u503c\uff0c\u800c\u662f\u4e00\u7cfb\u5217\u503c\u3002 \u4f8b\u5982\uff0c\u5ba2\u6237\u5728\u7279\u5b9a\u65f6\u95f4\u6bb5\u5185\u7684\u4ea4\u6613\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4f1a\u521b\u5efa\u4e0d\u540c\u7c7b\u578b\u7684\u7279\u5f81\uff0c\u4f8b\u5982\uff1a\u4f7f\u7528\u6570\u503c\u7279\u5f81\u65f6\uff0c\u5728\u5bf9\u5206\u7c7b\u5217\u8fdb\u884c\u5206\u7ec4\u65f6\uff0c\u4f1a\u5f97\u5230\u7c7b\u4f3c\u4e8e\u65f6\u95f4\u5206\u5e03\u503c\u5217\u8868\u7684\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u4e00\u7cfb\u5217\u7edf\u8ba1\u7279\u5f81\uff0c\u4f8b\u5982 \u5e73\u5747\u503c \u6700\u5927\u503c \u6700\u5c0f\u503c \u72ec\u7279\u6027 \u504f\u659c \u5cf0\u5ea6 Kstat \u767e\u5206\u4f4d\u6570 \u5b9a\u91cf \u5cf0\u503c\u5230\u5cf0\u503c \u4ee5\u53ca\u66f4\u591a \u8fd9\u4e9b\u53ef\u4ee5\u4f7f\u7528\u7b80\u5355\u7684 numpy 
\u51fd\u6570\u521b\u5efa\uff0c\u5982\u4e0b\u9762\u7684 python \u4ee3\u7801\u6bb5\u6240\u793a\u3002 import numpy as np # \u521b\u5efa\u5b57\u5178\uff0c\u7528\u4e8e\u5b58\u50a8\u4e0d\u540c\u7684\u7edf\u8ba1\u7279\u5f81 feature_dict = {} # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5e73\u5747\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'mean' \u952e\u4e0b feature_dict [ 'mean' ] = np . mean ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5927\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'max' \u952e\u4e0b feature_dict [ 'max' ] = np . max ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6700\u5c0f\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'min' \u952e\u4e0b feature_dict [ 'min' ] = np . min ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u6807\u51c6\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'std' \u952e\u4e0b feature_dict [ 'std' ] = np . std ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u65b9\u5dee\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'var' \u952e\u4e0b feature_dict [ 'var' ] = np . var ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u5dee\u503c\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'ptp' \u952e\u4e0b feature_dict [ 'ptp' ] = np . ptp ( x ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c10\u767e\u5206\u4f4d\u6570\uff08\u5373\u767e\u5206\u4e4b10\u5206\u4f4d\u6570\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_10' \u952e\u4e0b feature_dict [ 'percentile_10' ] = np . percentile ( x , 10 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c60\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_60' \u952e\u4e0b feature_dict [ 'percentile_60' ] = np . 
percentile ( x , 60 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u7684\u7b2c90\u767e\u5206\u4f4d\u6570\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'percentile_90' \u952e\u4e0b feature_dict [ 'percentile_90' ] = np . percentile ( x , 90 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u76845%\u5206\u4f4d\u6570\uff08\u53730.05\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_5' \u952e\u4e0b feature_dict [ 'quantile_5' ] = np . quantile ( x , 0.05 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768495%\u5206\u4f4d\u6570\uff08\u53730.95\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_95' \u952e\u4e0b feature_dict [ 'quantile_95' ] = np . quantile ( x , 0.95 ) # \u8ba1\u7b97 x \u4e2d\u5143\u7d20\u768499%\u5206\u4f4d\u6570\uff08\u53730.99\u5206\u4f4d\u6570\uff09\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u4e2d\u7684 'quantile_99' \u952e\u4e0b feature_dict [ 'quantile_99' ] = np . quantile ( x , 0.99 ) \u65f6\u95f4\u5e8f\u5217\u6570\u636e\uff08\u6570\u503c\u5217\u8868\uff09\u53ef\u4ee5\u8f6c\u6362\u6210\u8bb8\u591a\u7279\u5f81\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4e00\u4e2a\u540d\u4e3a tsfresh \u7684 python \u5e93\u975e\u5e38\u6709\u7528\u3002 from tsfresh.feature_extraction import feature_calculators as fc # \u8ba1\u7b97 x \u6570\u5217\u7684\u7edd\u5bf9\u80fd\u91cf\uff08abs_energy\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'abs_energy' \u952e\u4e0b feature_dict [ 'abs_energy' ] = fc . abs_energy ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u9ad8\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_above_mean' \u952e\u4e0b feature_dict [ 'count_above_mean' ] = fc . 
count_above_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u4e2d\u4f4e\u4e8e\u5747\u503c\u7684\u6570\u636e\u70b9\u6570\u91cf\uff0c\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'count_below_mean' \u952e\u4e0b feature_dict [ 'count_below_mean' ] = fc . count_below_mean ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u7edd\u5bf9\u53d8\u5316\uff08mean_abs_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_abs_change' \u952e\u4e0b feature_dict [ 'mean_abs_change' ] = fc . mean_abs_change ( x ) # \u8ba1\u7b97 x \u6570\u5217\u7684\u5747\u503c\u53d8\u5316\u7387\uff08mean_change\uff09\uff0c\u5e76\u5c06\u7ed3\u679c\u5b58\u50a8\u5728 feature_dict \u5b57\u5178\u4e2d\u7684 'mean_change' \u952e\u4e0b feature_dict [ 'mean_change' ] = fc . mean_change ( x ) \u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\uff1btsfresh \u63d0\u4f9b\u4e86\u6570\u767e\u79cd\u7279\u5f81\u548c\u6570\u5341\u79cd\u4e0d\u540c\u7279\u5f81\u7684\u53d8\u4f53\uff0c\u4f60\u53ef\u4ee5\u5c06\u5b83\u4eec\u7528\u4e8e\u57fa\u4e8e\u65f6\u95f4\u5e8f\u5217\uff08\u503c\u5217\u8868\uff09\u7684\u7279\u5f81\u3002\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0cx \u662f\u4e00\u4e2a\u503c\u5217\u8868\u3002\u4f46\u8fd9\u8fd8\u4e0d\u662f\u5168\u90e8\u3002\u60a8\u8fd8\u53ef\u4ee5\u4e3a\u5305\u542b\u6216\u4e0d\u5305\u542b\u5206\u7c7b\u6570\u636e\u7684\u6570\u503c\u6570\u636e\u521b\u5efa\u8bb8\u591a\u5176\u4ed6\u7279\u5f81\u3002\u751f\u6210\u8bb8\u591a\u7279\u5f81\u7684\u4e00\u4e2a\u7b80\u5355\u65b9\u6cd5\u5c31\u662f\u521b\u5efa\u4e00\u5806\u591a\u9879\u5f0f\u7279\u5f81\u3002\u4f8b\u5982\uff0c\u4ece\u4e24\u4e2a\u7279\u5f81 \"a \"\u548c \"b \"\u751f\u6210\u7684\u4e8c\u7ea7\u591a\u9879\u5f0f\u7279\u5f81\u5305\u62ec \"a\"\u3001\"b\"\u3001\"ab\"\u3001\"a^2 \"\u548c \"b^2\"\u3002 import numpy as np df = pd . DataFrame ( np . random . 
rand ( 100 , 2 ), columns = [ f \"f_ { i } \" for i in range ( 1 , 3 )]) \u5982\u56fe 3 \u6240\u793a\uff0c\u5b83\u7ed9\u51fa\u4e86\u4e00\u4e2a\u6570\u636e\u8868\u3002 \u56fe 3\uff1a\u5305\u542b\u4e24\u4e2a\u6570\u5b57\u7279\u5f81\u7684\u968f\u673a\u6570\u636e\u8868 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u7684 PolynomialFeatures \u521b\u5efa\u4e24\u6b21\u591a\u9879\u5f0f\u7279\u5f81\u3002 from sklearn import preprocessing # \u6307\u5b9a\u591a\u9879\u5f0f\u7684\u6b21\u6570\u4e3a 2\uff0c\u4e0d\u4ec5\u8003\u8651\u4ea4\u4e92\u9879\uff0c\u4e0d\u5305\u62ec\u504f\u5dee\uff08include_bias=False\uff09 pf = preprocessing . PolynomialFeatures ( degree = 2 , interaction_only = False , include_bias = False ) # \u62df\u5408\uff0c\u521b\u5efa\u591a\u9879\u5f0f\u7279\u5f81 pf . fit ( df ) # \u8f6c\u6362\u6570\u636e poly_feats = pf . transform ( df ) # \u83b7\u53d6\u751f\u6210\u7684\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf num_feats = poly_feats . shape [ 1 ] # \u4e3a\u65b0\u751f\u6210\u7684\u7279\u5f81\u547d\u540d df_transformed = pd . 
DataFrame ( poly_feats , columns = [ f \"f_ { i } \" for i in range ( 1 , num_feats + 1 )] ) \u8fd9\u6837\u5c31\u5f97\u5230\u4e86\u4e00\u4e2a\u6570\u636e\u8868\uff0c\u5982\u56fe 4 \u6240\u793a\u3002 \u56fe 4\uff1a\u5e26\u6709\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6837\u672c\u6570\u636e\u8868 \u73b0\u5728\uff0c\u6211\u4eec\u521b\u5efa\u4e86\u4e00\u4e9b\u591a\u9879\u5f0f\u7279\u5f81\u3002\u5982\u679c\u521b\u5efa\u7684\u662f\u4e09\u6b21\u591a\u9879\u5f0f\u7279\u5f81\uff0c\u6700\u7ec8\u603b\u5171\u4f1a\u6709\u4e5d\u4e2a\u7279\u5f81\u3002\u7279\u5f81\u7684\u6570\u91cf\u8d8a\u591a\uff0c\u591a\u9879\u5f0f\u7279\u5f81\u7684\u6570\u91cf\u4e5f\u5c31\u8d8a\u591a\uff0c\u800c\u4e14\u4f60\u8fd8\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5982\u679c\u6570\u636e\u96c6\u4e2d\u6709\u5f88\u591a\u6837\u672c\uff0c\u90a3\u4e48\u521b\u5efa\u8fd9\u7c7b\u7279\u5f81\u5c31\u9700\u8981\u82b1\u8d39\u4e00\u4e9b\u65f6\u95f4\u3002 \u56fe 5\uff1a\u6570\u5b57\u7279\u5f81\u5217\u7684\u76f4\u65b9\u56fe \u53e6\u4e00\u4e2a\u6709\u8da3\u7684\u529f\u80fd\u662f\u5c06\u6570\u5b57\u8f6c\u6362\u4e3a\u7c7b\u522b\u3002\u8fd9\u5c31\u662f\u6240\u8c13\u7684 \u5206\u7bb1 \u3002\u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u56fe 5\uff0c\u5b83\u663e\u793a\u4e86\u4e00\u4e2a\u968f\u673a\u6570\u5b57\u7279\u5f81\u7684\u6837\u672c\u76f4\u65b9\u56fe\u3002\u6211\u4eec\u5728\u8be5\u56fe\u4e2d\u4f7f\u7528\u4e8610\u4e2a\u5206\u7bb1\uff0c\u53ef\u4ee5\u770b\u5230\u6211\u4eec\u53ef\u4ee5\u5c06\u6570\u636e\u5206\u4e3a10\u4e2a\u90e8\u5206\u3002\u8fd9\u53ef\u4ee5\u4f7f\u7528 pandas \u7684cat\u51fd\u6570\u6765\u5b9e\u73b0\u3002 # \u521b\u5efa10\u4e2a\u5206\u7bb1 df [ \"f_bin_10\" ] = pd . cut ( df [ \"f_1\" ], bins = 10 , labels = False ) # \u521b\u5efa100\u4e2a\u5206\u7bb1 df [ \"f_bin_100\" ] = pd . 
cut ( df [ \"f_1\" ], bins = 100 , labels = False ) \u5982\u56fe 6 \u6240\u793a\uff0c\u8fd9\u5c06\u5728\u6570\u636e\u5e27\u4e2d\u751f\u6210\u4e24\u4e2a\u65b0\u7279\u5f81\u3002 \u56fe 6\uff1a\u6570\u503c\u7279\u5f81\u5206\u7bb1 \u5f53\u4f60\u8fdb\u884c\u5206\u7c7b\u65f6\uff0c\u53ef\u4ee5\u540c\u65f6\u4f7f\u7528\u5206\u7bb1\u548c\u539f\u59cb\u7279\u5f81\u3002\u6211\u4eec\u5c06\u5728\u672c\u7ae0\u540e\u534a\u90e8\u5206\u5b66\u4e60\u66f4\u591a\u5173\u4e8e\u9009\u62e9\u7279\u5f81\u7684\u77e5\u8bc6\u3002\u5206\u7bb1\u8fd8\u53ef\u4ee5\u5c06\u6570\u5b57\u7279\u5f81\u89c6\u4e3a\u5206\u7c7b\u7279\u5f81\u3002 \u53e6\u4e00\u79cd\u53ef\u4ee5\u4ece\u6570\u503c\u7279\u5f81\u4e2d\u521b\u5efa\u7684\u6709\u8da3\u7279\u5f81\u7c7b\u578b\u662f\u5bf9\u6570\u53d8\u6362\u3002\u8bf7\u770b\u56fe 7 \u4e2d\u7684\u7279\u5f81 f_3\u3002 \u4e0e\u5176\u4ed6\u65b9\u5dee\u8f83\u5c0f\u7684\u7279\u5f81\u76f8\u6bd4\uff08\u5047\u8bbe\u5982\u6b64\uff09\uff0cf_3 \u662f\u4e00\u79cd\u65b9\u5dee\u975e\u5e38\u5927\u7684\u7279\u6b8a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5e0c\u671b\u964d\u4f4e\u8fd9\u4e00\u5217\u7684\u65b9\u5dee\uff0c\u8fd9\u53ef\u4ee5\u901a\u8fc7\u5bf9\u6570\u53d8\u6362\u6765\u5b9e\u73b0\u3002 f_3 \u5217\u7684\u503c\u8303\u56f4\u4e3a 0 \u5230 10000\uff0c\u76f4\u65b9\u56fe\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u7279\u5f81 f_3 \u7684\u76f4\u65b9\u56fe \u6211\u4eec\u53ef\u4ee5\u5bf9\u8fd9\u4e00\u5217\u5e94\u7528 log(1 + x) \u6765\u51cf\u5c11\u5176\u65b9\u5dee\u3002\u56fe 9 \u663e\u793a\u4e86\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u76f4\u65b9\u56fe\u7684\u53d8\u5316\u3002 \u56fe 9\uff1a\u5e94\u7528\u5bf9\u6570\u53d8\u6362\u540e\u7684 f_3 \u76f4\u65b9\u56fe \u8ba9\u6211\u4eec\u6765\u770b\u770b\u4e0d\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u548c\u4f7f\u7528\u5bf9\u6570\u53d8\u6362\u7684\u65b9\u5dee\u3002 In [ X ]: df . f_3 . var () Out [ X ]: 8077265.875858586 In [ X ]: df . f_3 . apply ( lambda x : np . log ( 1 + x )) . 
var () Out [ X ]: 0.6058771732119975 \u6709\u65f6\uff0c\u4e5f\u53ef\u4ee5\u7528\u6307\u6570\u6765\u4ee3\u66ff\u5bf9\u6570\u3002\u4e00\u79cd\u975e\u5e38\u6709\u8da3\u7684\u60c5\u51b5\u662f\uff0c\u60a8\u4f7f\u7528\u57fa\u4e8e\u5bf9\u6570\u7684\u8bc4\u4f30\u6307\u6807\uff0c\u4f8b\u5982 RMSLE\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u5728\u5bf9\u6570\u53d8\u6362\u7684\u76ee\u6807\u4e0a\u8fdb\u884c\u8bad\u7ec3\uff0c\u7136\u540e\u5728\u9884\u6d4b\u65f6\u4f7f\u7528\u6307\u6570\u503c\u8f6c\u6362\u56de\u539f\u59cb\u503c\u3002\u8fd9\u5c06\u6709\u52a9\u4e8e\u9488\u5bf9\u6307\u6807\u4f18\u5316\u6a21\u578b\u3002 \u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u7c7b\u6570\u5b57\u7279\u5f81\u90fd\u662f\u57fa\u4e8e\u76f4\u89c9\u521b\u5efa\u7684\u3002\u6ca1\u6709\u516c\u5f0f\u53ef\u5faa\u3002\u5982\u679c\u60a8\u4ece\u4e8b\u7684\u662f\u67d0\u4e00\u884c\u4e1a\uff0c\u60a8\u5c06\u521b\u5efa\u7279\u5b9a\u884c\u4e1a\u7684\u7279\u5f81\u3002 \u5728\u5904\u7406\u5206\u7c7b\u53d8\u91cf\u548c\u6570\u503c\u53d8\u91cf\u65f6\uff0c\u53ef\u80fd\u4f1a\u9047\u5230\u7f3a\u5931\u503c\u3002\u5728\u4e0a\u4e00\u7ae0\u4e2d\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u4e00\u4e9b\u5904\u7406\u5206\u7c7b\u7279\u5f81\u4e2d\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\uff0c\u4f46\u8fd8\u6709\u66f4\u591a\u65b9\u6cd5\u53ef\u4ee5\u5904\u7406\u7f3a\u5931\u503c/NaN \u503c\u3002\u8fd9\u4e5f\u88ab\u89c6\u4e3a\u7279\u5f81\u5de5\u7a0b\u3002 \u5982\u679c\u5728\u5206\u7c7b\u7279\u5f81\u4e2d\u9047\u5230\u7f3a\u5931\u503c\uff0c\u5c31\u5c06\u5176\u89c6\u4e3a\u4e00\u4e2a\u65b0\u7684\u7c7b\u522b\uff01\u8fd9\u6837\u505a\u867d\u7136\u7b80\u5355\uff0c\u4f46\uff08\u51e0\u4e4e\uff09\u603b\u662f\u6709\u6548\u7684\uff01 
\u5728\u6570\u503c\u6570\u636e\u4e2d\u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u65b9\u6cd5\u662f\u9009\u62e9\u4e00\u4e2a\u5728\u7279\u5b9a\u7279\u5f81\u4e2d\u6ca1\u6709\u51fa\u73b0\u7684\u503c\uff0c\u7136\u540e\u7528\u5b83\u6765\u586b\u8865\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u7279\u5f81\u4e2d\u6ca1\u6709 0\u3002\u8fd9\u662f\u5176\u4e2d\u4e00\u79cd\u65b9\u6cd5\uff0c\u4f46\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\u3002\u5bf9\u4e8e\u6570\u503c\u6570\u636e\u6765\u8bf4\uff0c\u6bd4\u586b\u5145 0 \u66f4\u6709\u6548\u7684\u65b9\u6cd5\u4e4b\u4e00\u662f\u4f7f\u7528\u5e73\u5747\u503c\u8fdb\u884c\u586b\u5145\u3002\u60a8\u4e5f\u53ef\u4ee5\u5c1d\u8bd5\u4f7f\u7528\u8be5\u7279\u5f81\u6240\u6709\u503c\u7684\u4e2d\u4f4d\u6570\u6765\u586b\u5145\uff0c\u6216\u8005\u4f7f\u7528\u6700\u5e38\u89c1\u7684\u503c\u6765\u586b\u5145\u7f3a\u5931\u503c\u3002\u8fd9\u6837\u505a\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002 \u586b\u8865\u7f3a\u5931\u503c\u7684\u4e00\u79cd\u9ad8\u7ea7\u65b9\u6cd5\u662f\u4f7f\u7528 K \u8fd1\u90bb\u6cd5 \u3002 \u60a8\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\uff0c\u7136\u540e\u5229\u7528\u67d0\u79cd\u8ddd\u79bb\u5ea6\u91cf\uff08\u4f8b\u5982\u6b27\u6c0f\u8ddd\u79bb\uff09\u627e\u5230\u6700\u8fd1\u7684\u90bb\u5c45\u3002\u7136\u540e\u53d6\u6240\u6709\u8fd1\u90bb\u7684\u5e73\u5747\u503c\u6765\u586b\u8865\u7f3a\u5931\u503c\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528 KNN \u6765\u586b\u8865\u8fd9\u6837\u7684\u7f3a\u5931\u503c\u3002 \u56fe 10\uff1a\u6709\u7f3a\u5931\u503c\u7684\u4e8c\u7ef4\u6570\u7ec4 \u8ba9\u6211\u4eec\u770b\u770b KNN \u662f\u5982\u4f55\u5904\u7406\u56fe 10 \u6240\u793a\u7684\u7f3a\u5931\u503c\u77e9\u9635\u7684\u3002 import numpy as np from sklearn import impute # \u751f\u6210\u7ef4\u5ea6\u4e3a (10, 6) \u7684\u968f\u673a\u6574\u6570\u77e9\u9635 X\uff0c\u6570\u503c\u8303\u56f4\u5728 1 \u5230 14 \u4e4b\u95f4 X = np . random . randint ( 1 , 15 , ( 10 , 6 )) # \u6570\u636e\u7c7b\u578b\u8f6c\u6362\u4e3a float X = X . 
astype ( float ) # \u5728\u77e9\u9635 X \u4e2d\u968f\u673a\u9009\u62e9 10 \u4e2a\u4f4d\u7f6e\uff0c\u5c06\u8fd9\u4e9b\u4f4d\u7f6e\u7684\u5143\u7d20\u8bbe\u7f6e\u4e3a NaN\uff08\u7f3a\u5931\u503c\uff09 X . ravel ()[ np . random . choice ( X . size , 10 , replace = False )] = np . nan # \u521b\u5efa\u4e00\u4e2a KNNImputer \u5bf9\u8c61 knn_imputer\uff0c\u6307\u5b9a\u90bb\u5c45\u6570\u91cf\u4e3a 2 knn_imputer = impute . KNNImputer ( n_neighbors = 2 ) # # \u4f7f\u7528 knn_imputer \u5bf9\u77e9\u9635 X \u8fdb\u884c\u62df\u5408\u548c\u8f6c\u6362\uff0c\u7528 K-\u6700\u8fd1\u90bb\u65b9\u6cd5\u586b\u8865\u7f3a\u5931\u503c knn_imputer . fit_transform ( X ) \u5982\u56fe 11 \u6240\u793a\uff0c\u5b83\u586b\u5145\u4e86\u4e0a\u8ff0\u77e9\u9635\u3002 \u56fe 11\uff1aKNN\u4f30\u7b97\u7684\u6570\u503c \u53e6\u4e00\u79cd\u5f25\u8865\u5217\u7f3a\u5931\u503c\u7684\u65b9\u6cd5\u662f\u8bad\u7ec3\u56de\u5f52\u6a21\u578b\uff0c\u8bd5\u56fe\u6839\u636e\u5176\u4ed6\u5217\u9884\u6d4b\u67d0\u5217\u7684\u7f3a\u5931\u503c\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u4ece\u6709\u7f3a\u5931\u503c\u7684\u4e00\u5217\u5f00\u59cb\uff0c\u5c06\u8fd9\u4e00\u5217\u4f5c\u4e3a\u65e0\u7f3a\u5931\u503c\u56de\u5f52\u6a21\u578b\u7684\u76ee\u6807\u5217\u3002\u73b0\u5728\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u6240\u6709\u5176\u4ed6\u5217\uff0c\u5bf9\u76f8\u5173\u5217\u4e2d\u6ca1\u6709\u7f3a\u5931\u503c\u7684\u6837\u672c\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\uff0c\u7136\u540e\u5c1d\u8bd5\u9884\u6d4b\u4e4b\u524d\u5220\u9664\u7684\u6837\u672c\u7684\u76ee\u6807\u5217\uff08\u540c\u4e00\u5217\uff09\u3002\u8fd9\u6837\uff0c\u57fa\u4e8e\u6a21\u578b\u7684\u4f30\u7b97\u5c31\u4f1a\u66f4\u52a0\u7a33\u5065\u3002 \u8bf7\u52a1\u5fc5\u8bb0\u4f4f\uff0c\u5bf9\u4e8e\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6ca1\u6709\u5fc5\u8981\u8fdb\u884c\u6570\u503c\u5f52\u4e00\u5316\uff0c\u56e0\u4e3a\u5b83\u4eec\u53ef\u4ee5\u81ea\u884c\u5904\u7406\u3002 
\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u6240\u5c55\u793a\u7684\u53ea\u662f\u521b\u5efa\u4e00\u822c\u7279\u5f81\u7684\u4e00\u4e9b\u65b9\u6cd5\u3002\u73b0\u5728\uff0c\u5047\u8bbe\u60a8\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u9884\u6d4b\u4e0d\u540c\u5546\u54c1\uff08\u6bcf\u5468\u6216\u6bcf\u6708\uff09\u5546\u5e97\u9500\u552e\u989d\u7684\u95ee\u9898\u3002\u60a8\u6709\u5546\u54c1\uff0c\u4e5f\u6709\u5546\u5e97 ID\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u4ee5\u521b\u5efa\u6bcf\u4e2a\u5546\u5e97\u7684\u5546\u54c1\u7b49\u7279\u5f81\u3002\u73b0\u5728\uff0c\u8fd9\u662f\u4e0a\u6587\u6ca1\u6709\u8ba8\u8bba\u7684\u7279\u5f81\u4e4b\u4e00\u3002\u8fd9\u7c7b\u7279\u5f81\u4e0d\u80fd\u4e00\u6982\u800c\u8bba\uff0c\u5b8c\u5168\u6765\u81ea\u4e8e\u9886\u57df\u3001\u6570\u636e\u548c\u4e1a\u52a1\u77e5\u8bc6\u3002\u67e5\u770b\u6570\u636e\uff0c\u627e\u51fa\u9002\u5408\u7684\u7279\u5f81\uff0c\u7136\u540e\u521b\u5efa\u76f8\u5e94\u7684\u7279\u5f81\u3002\u5982\u679c\u60a8\u4f7f\u7528\u7684\u662f\u903b\u8f91\u56de\u5f52\u7b49\u7ebf\u6027\u6a21\u578b\u6216 SVM \u7b49\u6a21\u578b\uff0c\u8bf7\u52a1\u5fc5\u8bb0\u4f4f\u5bf9\u7279\u5f81\u8fdb\u884c\u7f29\u653e\u6216\u5f52\u4e00\u5316\u5904\u7406\u3002\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u65e0\u9700\u5bf9\u7279\u5f81\u8fdb\u884c\u4efb\u4f55\u5f52\u4e00\u5316\u5904\u7406\u5373\u53ef\u6b63\u5e38\u5de5\u4f5c\u3002","title":"\u7279\u5f81\u5de5\u7a0b"},{"location":"%E7%89%B9%E5%BE%81%E9%80%89%E6%8B%A9/","text":"\u7279\u5f81\u9009\u62e9 \u5f53\u4f60\u521b\u5efa\u4e86\u6210\u5343\u4e0a\u4e07\u4e2a\u7279\u5f81\u540e\uff0c\u5c31\u8be5\u4ece\u4e2d\u6311\u9009\u51fa\u51e0\u4e2a\u4e86\u3002\u4f46\u662f\uff0c\u6211\u4eec\u7edd\u4e0d\u5e94\u8be5\u521b\u5efa\u6210\u767e\u4e0a\u5343\u4e2a\u65e0\u7528\u7684\u7279\u5f81\u3002\u7279\u5f81\u8fc7\u591a\u4f1a\u5e26\u6765\u4e00\u4e2a\u4f17\u6240\u5468\u77e5\u7684\u95ee\u9898\uff0c\u5373 
\"\u7ef4\u5ea6\u8bc5\u5492\"\u3002\u5982\u679c\u4f60\u6709\u5f88\u591a\u7279\u5f81\uff0c\u4f60\u4e5f\u5fc5\u987b\u6709\u5f88\u591a\u8bad\u7ec3\u6837\u672c\u6765\u6355\u6349\u6240\u6709\u7279\u5f81\u3002\u4ec0\u4e48\u662f \"\u5927\u91cf \"\u5e76\u6ca1\u6709\u6b63\u786e\u7684\u5b9a\u4e49\uff0c\u8fd9\u9700\u8981\u60a8\u901a\u8fc7\u6b63\u786e\u9a8c\u8bc1\u60a8\u7684\u6a21\u578b\u548c\u68c0\u67e5\u8bad\u7ec3\u6a21\u578b\u6240\u9700\u7684\u65f6\u95f4\u6765\u786e\u5b9a\u3002 \u9009\u62e9\u7279\u5f81\u7684\u6700\u7b80\u5355\u65b9\u6cd5\u662f \u5220\u9664\u65b9\u5dee\u975e\u5e38\u5c0f\u7684\u7279\u5f81 \u3002\u5982\u679c\u7279\u5f81\u7684\u65b9\u5dee\u975e\u5e38\u5c0f\uff08\u5373\u975e\u5e38\u63a5\u8fd1\u4e8e 0\uff09\uff0c\u5b83\u4eec\u5c31\u63a5\u8fd1\u4e8e\u5e38\u91cf\uff0c\u56e0\u6b64\u6839\u672c\u4e0d\u4f1a\u7ed9\u4efb\u4f55\u6a21\u578b\u589e\u52a0\u4efb\u4f55\u4ef7\u503c\u3002\u6700\u597d\u7684\u529e\u6cd5\u5c31\u662f\u53bb\u6389\u5b83\u4eec\uff0c\u4ece\u800c\u964d\u4f4e\u590d\u6742\u5ea6\u3002\u8bf7\u6ce8\u610f\uff0c\u65b9\u5dee\u4e5f\u53d6\u51b3\u4e8e\u6570\u636e\u7684\u7f29\u653e\u3002 Scikit-learn \u7684 VarianceThreshold \u5b9e\u73b0\u4e86\u8fd9\u4e00\u70b9\u3002 from sklearn.feature_selection import VarianceThreshold data = ... # \u521b\u5efa VarianceThreshold \u5bf9\u8c61 var_thresh\uff0c\u6307\u5b9a\u65b9\u5dee\u9608\u503c\u4e3a 0.1 var_thresh = VarianceThreshold ( threshold = 0.1 ) # \u4f7f\u7528 var_thresh \u5bf9\u6570\u636e data \u8fdb\u884c\u62df\u5408\u548c\u53d8\u6362\uff0c\u5c06\u65b9\u5dee\u4f4e\u4e8e\u9608\u503c\u7684\u7279\u5f81\u79fb\u9664 transformed_data = var_thresh . 
fit_transform ( data ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u5220\u9664\u76f8\u5173\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u3002\u8981\u8ba1\u7b97\u4e0d\u540c\u6570\u5b57\u7279\u5f81\u4e4b\u95f4\u7684\u76f8\u5173\u6027\uff0c\u53ef\u4ee5\u4f7f\u7528\u76ae\u5c14\u900a\u76f8\u5173\u6027\u3002 import pandas as pd from sklearn.datasets import fetch_california_housing # \u52a0\u8f7d\u6570\u636e data = fetch_california_housing () # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u77e9\u9635 X X = data [ \"data\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u7684\u5217\u540d col_names = data [ \"feature_names\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf y y = data [ \"target\" ] df = pd . DataFrame ( X , columns = col_names ) # \u6dfb\u52a0 MedInc_Sqrt \u5217\uff0c\u662f MedInc \u5217\u4e2d\u6bcf\u4e2a\u5143\u7d20\u8fdb\u884c\u5e73\u65b9\u6839\u8fd0\u7b97\u7684\u7ed3\u679c df . loc [:, \"MedInc_Sqrt\" ] = df . MedInc . apply ( np . sqrt ) # \u8ba1\u7b97\u76ae\u5c14\u900a\u76f8\u5173\u6027\u77e9\u9635 df . 
corr () \u5f97\u51fa\u76f8\u5173\u77e9\u9635\uff0c\u5982\u56fe 1 \u6240\u793a\u3002 \u56fe 1\uff1a\u76ae\u5c14\u900a\u76f8\u5173\u77e9\u9635\u6837\u672c \u6211\u4eec\u770b\u5230\uff0cMedInc_Sqrt \u4e0e MedInc \u7684\u76f8\u5173\u6027\u975e\u5e38\u9ad8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5220\u9664\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u4e00\u4e9b \u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5 \u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u53ea\u4e0d\u8fc7\u662f\u9488\u5bf9\u7ed9\u5b9a\u76ee\u6807\u5bf9\u6bcf\u4e2a\u7279\u5f81\u8fdb\u884c\u8bc4\u5206\u3002 \u4e92\u4fe1\u606f \u3001 \u65b9\u5dee\u5206\u6790 F \u68c0\u9a8c\u548c chi2 \u662f\u4e00\u4e9b\u6700\u5e38\u7528\u7684\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u3002\u5728 scikit- learn \u4e2d\uff0c\u6709\u4e24\u79cd\u65b9\u6cd5\u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e9b\u65b9\u6cd5\u3002 - SelectKBest\uff1a\u4fdd\u7559\u5f97\u5206\u6700\u9ad8\u7684 k \u4e2a\u7279\u5f81 - SelectPercentile\uff1a\u4fdd\u7559\u7528\u6237\u6307\u5b9a\u767e\u5206\u6bd4\u5185\u7684\u9876\u7ea7\u7279\u5f81\u3002 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u53ea\u6709\u975e\u8d1f\u6570\u636e\u624d\u80fd\u4f7f\u7528 chi2\u3002\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u5f53\u6211\u4eec\u6709\u4e00\u4e9b\u5355\u8bcd\u6216\u57fa\u4e8e tf-idf \u7684\u7279\u5f81\u65f6\uff0c\u8fd9\u662f\u4e00\u79cd\u7279\u522b\u6709\u7528\u7684\u7279\u5f81\u9009\u62e9\u6280\u672f\u3002\u6700\u597d\u4e3a\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u521b\u5efa\u4e00\u4e2a\u5305\u88c5\u5668\uff0c\u51e0\u4e4e\u53ef\u4ee5\u7528\u4e8e\u4efb\u4f55\u65b0\u95ee\u9898\u3002 from sklearn.feature_selection import chi2 from sklearn.feature_selection import f_classif from sklearn.feature_selection import f_regression from sklearn.feature_selection import mutual_info_classif from sklearn.feature_selection import mutual_info_regression from sklearn.feature_selection import SelectKBest 
from sklearn.feature_selection import SelectPercentile class UnivariateFeatureSelction : def __init__ ( self , n_features , problem_type , scoring ): # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u5206\u7c7b\u95ee\u9898 if problem_type == \"classification\" : # \u521b\u5efa\u5b57\u5178 valid_scoring \uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_classif\" : f_classif , \"chi2\" : chi2 , \"mutual_info_classif\" : mutual_info_classif } # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u56de\u5f52\u95ee\u9898 else : # \u521b\u5efa\u5b57\u5178 valid_scoring\uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_regression\" : f_regression , \"mutual_info_regression\" : mutual_info_regression } # \u68c0\u67e5\u7279\u5f81\u91cd\u8981\u6027\u65b9\u5f0f\u662f\u5426\u5728\u5b57\u5178\u4e2d if scoring not in valid_scoring : raise Exception ( \"Invalid scoring function\" ) # \u68c0\u67e5 n_features \u7684\u7c7b\u578b\uff0c\u5982\u679c\u662f\u6574\u6570\uff0c\u5219\u4f7f\u7528 SelectKBest \u8fdb\u884c\u7279\u5f81\u9009\u62e9 if isinstance ( n_features , int ): self . selection = SelectKBest ( valid_scoring [ scoring ], k = n_features ) # \u5982\u679c n_features \u662f\u6d6e\u70b9\u6570\uff0c\u5219\u4f7f\u7528 SelectPercentile \u8fdb\u884c\u7279\u5f81\u9009\u62e9 elif isinstance ( n_features , float ): self . selection = SelectPercentile ( valid_scoring [ scoring ], percentile = int ( n_features * 100 ) ) # \u5982\u679c n_features \u7c7b\u578b\u65e0\u6548\uff0c\u5f15\u53d1\u5f02\u5e38 else : raise Exception ( \"Invalid type of feature\" ) # \u5b9a\u4e49 fit \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 def fit ( self , X , y ): return self . selection . fit ( X , y ) # \u5b9a\u4e49 transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u5bf9\u6570\u636e\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def transform ( self , X ): return self . selection . 
transform ( X ) # \u5b9a\u4e49 fit_transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668\u5e76\u540c\u65f6\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def fit_transform ( self , X , y ): return self . selection . fit_transform ( X , y ) \u4f7f\u7528\u8be5\u7c7b\u975e\u5e38\u7b80\u5355\u3002 # \u5b9e\u4f8b\u5316\u7279\u5f81\u9009\u62e9\u5668\uff0c\u4fdd\u7559\u524d10%\u7684\u7279\u5f81\uff0c\u56de\u5f52\u95ee\u9898\uff0c\u4f7f\u7528f_regression\u8861\u91cf\u7279\u5f81\u91cd\u8981\u6027 ufs = UnivariateFeatureSelction ( n_features = 0.1 , problem_type = \"regression\" , scoring = \"f_regression\" ) # \u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 ufs . fit ( X , y ) # \u7279\u5f81\u8f6c\u6362 X_transformed = ufs . transform ( X ) \u8fd9\u6837\u5c31\u80fd\u6ee1\u8db3\u5927\u90e8\u5206\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u7684\u9700\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u521b\u5efa\u8f83\u5c11\u800c\u91cd\u8981\u7684\u7279\u5f81\u901a\u5e38\u6bd4\u521b\u5efa\u6570\u4ee5\u767e\u8ba1\u7684\u7279\u5f81\u8981\u597d\u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u4e0d\u4e00\u5b9a\u603b\u662f\u8868\u73b0\u826f\u597d\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u4eba\u4eec\u66f4\u559c\u6b22\u4f7f\u7528\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 \u4f7f\u7528\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u7684\u6700\u7b80\u5355\u5f62\u5f0f\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u3002\u5728\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u4e2d\uff0c\u7b2c\u4e00\u6b65\u662f\u9009\u62e9\u4e00\u4e2a\u6a21\u578b\u3002\u7b2c\u4e8c\u6b65\u662f\u9009\u62e9\u635f\u5931/\u8bc4\u5206\u51fd\u6570\u3002\u7b2c\u4e09\u6b65\u4e5f\u662f\u6700\u540e\u4e00\u6b65\u662f\u53cd\u590d\u8bc4\u4f30\u6bcf\u4e2a\u7279\u5f81\uff0c\u5982\u679c\u80fd\u63d0\u9ad8\u635f\u5931/\u8bc4\u5206\uff0c\u5c31\u5c06\u5176\u6dfb\u52a0\u5230 \"\u597d 
\"\u7279\u5f81\u5217\u8868\u4e2d\u3002\u6ca1\u6709\u6bd4\u8fd9\u66f4\u7b80\u5355\u7684\u4e86\u3002\u4f46\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u8fd9\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u662f\u6709\u539f\u56e0\u7684\u3002\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u8fc7\u7a0b\u5728\u6bcf\u6b21\u8bc4\u4f30\u7279\u5f81\u65f6\u90fd\u4f1a\u9002\u5408\u7ed9\u5b9a\u7684\u6a21\u578b\u3002\u8fd9\u79cd\u65b9\u6cd5\u7684\u8ba1\u7b97\u6210\u672c\u975e\u5e38\u9ad8\u3002\u5b8c\u6210\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u4e5f\u9700\u8981\u5927\u91cf\u65f6\u95f4\u3002\u5982\u679c\u4e0d\u6b63\u786e\u4f7f\u7528\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\uff0c\u751a\u81f3\u4f1a\u5bfc\u81f4\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn.datasets import make_classification class GreedyFeatureSelection : # \u5b9a\u4e49\u8bc4\u4f30\u5206\u6570\u7684\u65b9\u6cd5\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd def evaluate_score ( self , X , y ): # \u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u8bad\u7ec3\u6a21\u578b model . fit ( X , y ) # \u9884\u6d4b\u6982\u7387\u503c predictions = model . predict_proba ( X )[:, 1 ] # \u8ba1\u7b97 AUC \u5206\u6570 auc = metrics . roc_auc_score ( y , predictions ) return auc # \u7279\u5f81\u9009\u62e9\u51fd\u6570 def _feature_selection ( self , X , y ): # \u521d\u59cb\u5316\u7a7a\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6700\u4f73\u7279\u5f81\u548c\u6700\u4f73\u5206\u6570 good_features = [] best_scores = [] # \u83b7\u53d6\u7279\u5f81\u6570\u91cf num_features = X . 
shape [ 1 ] # \u5f00\u59cb\u7279\u5f81\u9009\u62e9\u7684\u5faa\u73af while True : this_feature = None best_score = 0 # \u904d\u5386\u6bcf\u4e2a\u7279\u5f81 for feature in range ( num_features ): if feature in good_features : continue selected_features = good_features + [ feature ] xtrain = X [:, selected_features ] score = self . evaluate_score ( xtrain , y ) # \u5982\u679c\u5f53\u524d\u7279\u5f81\u7684\u5f97\u5206\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u5f97\u5206\uff0c\u5219\u66f4\u65b0 if score > best_score : this_feature = feature best_score = score # \u82e5\u627e\u5230\u4e86\u65b0\u7684\u6700\u4f73\u7279\u5f81 if this_feature != None : # \u7279\u5f81\u6dfb\u52a0\u5230 good_features \u5217\u8868 good_features . append ( this_feature ) # \u5f97\u5206\u6dfb\u52a0\u5230 best_scores \u5217\u8868 best_scores . append ( best_score ) # \u5982\u679c best_scores \u5217\u8868\u957f\u5ea6\u5927\u4e8e2\uff0c\u5e76\u4e14\u6700\u540e\u4e24\u4e2a\u5f97\u5206\u76f8\u6bd4\u8f83\u5dee\uff0c\u5219\u7ed3\u675f\u5faa\u73af if len ( best_scores ) > 2 : if best_scores [ - 1 ] < best_scores [ - 2 ]: break # \u8fd4\u56de\u6700\u4f73\u7279\u5f81\u7684\u5f97\u5206\u5217\u8868\u548c\u6700\u4f73\u7279\u5f81\u5217\u8868 return best_scores [: - 1 ], good_features [: - 1 ] # \u5b9a\u4e49\u7c7b\u7684\u8c03\u7528\u65b9\u6cd5\uff0c\u7528\u4e8e\u6267\u884c\u7279\u5f81\u9009\u62e9 def __call__ ( self , X , y ): scores , features = self . 
_feature_selection ( X , y ) return X [:, features ], scores if __name__ == \"__main__\" : # \u751f\u6210\u4e00\u4e2a\u793a\u4f8b\u7684\u5206\u7c7b\u6570\u636e\u96c6 X \u548c\u6807\u7b7e y X , y = make_classification ( n_samples = 1000 , n_features = 100 ) # \u5b9e\u4f8b\u5316 GreedyFeatureSelection \u7c7b\uff0c\u5e76\u4f7f\u7528 __call__ \u65b9\u6cd5\u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed , scores = GreedyFeatureSelection ()( X , y ) \u8fd9\u79cd\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u4f1a\u8fd4\u56de\u5206\u6570\u548c\u7279\u5f81\u7d22\u5f15\u5217\u8868\u3002\u56fe 2 \u663e\u793a\u4e86\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u589e\u52a0\u4e00\u4e2a\u65b0\u7279\u5f81\u540e\uff0c\u5206\u6570\u662f\u5982\u4f55\u63d0\u9ad8\u7684\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u67d0\u4e00\u70b9\u4e4b\u540e\uff0c\u6211\u4eec\u5c31\u65e0\u6cd5\u63d0\u9ad8\u5206\u6570\u4e86\uff0c\u8fd9\u5c31\u662f\u6211\u4eec\u505c\u6b62\u7684\u5730\u65b9\u3002 \u53e6\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\u88ab\u79f0\u4e3a\u9012\u5f52\u7279\u5f81\u6d88\u9664\u6cd5\uff08RFE\uff09\u3002\u5728\u524d\u4e00\u79cd\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a\u7279\u5f81\u5f00\u59cb\uff0c\u7136\u540e\u4e0d\u65ad\u6dfb\u52a0\u65b0\u7684\u7279\u5f81\uff0c\u4f46\u5728 RFE 
\u4e2d\uff0c\u6211\u4eec\u4ece\u6240\u6709\u7279\u5f81\u5f00\u59cb\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u4e0d\u65ad\u53bb\u9664\u4e00\u4e2a\u5bf9\u7ed9\u5b9a\u6a21\u578b\u63d0\u4f9b\u6700\u5c0f\u503c\u7684\u7279\u5f81\u3002\u4f46\u6211\u4eec\u5982\u4f55\u77e5\u9053\u54ea\u4e2a\u7279\u5f81\u7684\u4ef7\u503c\u6700\u5c0f\u5462\uff1f\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7ebf\u6027\u652f\u6301\u5411\u91cf\u673a\uff08SVM\uff09\u6216\u903b\u8f91\u56de\u5f52\u7b49\u6a21\u578b\uff0c\u6211\u4eec\u4f1a\u4e3a\u6bcf\u4e2a\u7279\u5f81\u5f97\u5230\u4e00\u4e2a\u7cfb\u6570\uff0c\u8be5\u7cfb\u6570\u51b3\u5b9a\u4e86\u7279\u5f81\u7684\u91cd\u8981\u6027\u3002\u800c\u5bf9\u4e8e\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f\u7279\u5f81\u91cd\u8981\u6027\uff0c\u800c\u4e0d\u662f\u7cfb\u6570\u3002\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u53ef\u4ee5\u5254\u9664\u6700\u4e0d\u91cd\u8981\u7684\u7279\u5f81\uff0c\u76f4\u5230\u8fbe\u5230\u6240\u9700\u7684\u7279\u5f81\u6570\u91cf\u4e3a\u6b62\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u51b3\u5b9a\u8981\u4fdd\u7559\u591a\u5c11\u7279\u5f81\u3002 \u56fe 2\uff1a\u589e\u52a0\u65b0\u7279\u5f81\u540e\uff0c\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7684 AUC \u5206\u6570\u5982\u4f55\u53d8\u5316 \u5f53\u6211\u4eec\u8fdb\u884c\u9012\u5f52\u7279\u5f81\u5254\u9664\u65f6\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u4f1a\u5254\u9664\u7279\u5f81\u91cd\u8981\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u6216\u7cfb\u6570\u63a5\u8fd1 0 
\u7684\u7279\u5f81\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5f53\u4f60\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u8fd9\u6837\u7684\u6a21\u578b\u8fdb\u884c\u4e8c\u5143\u5206\u7c7b\u65f6\uff0c\u5982\u679c\u7279\u5f81\u5bf9\u6b63\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u6b63\uff0c\u800c\u5982\u679c\u7279\u5f81\u5bf9\u8d1f\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u8d1f\u3002\u4fee\u6539\u6211\u4eec\u7684\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7c7b\uff0c\u521b\u5efa\u4e00\u4e2a\u65b0\u7684\u9012\u5f52\u7279\u5f81\u6d88\u9664\u7c7b\u975e\u5e38\u5bb9\u6613\uff0c\u4f46 scikit-learn \u4e5f\u63d0\u4f9b\u4e86 RFE\u3002\u4e0b\u9762\u7684\u793a\u4f8b\u5c55\u793a\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u7528\u6cd5\u3002 import pandas as pd from sklearn.feature_selection import RFE from sklearn.linear_model import LinearRegression from sklearn.datasets import fetch_california_housing data = fetch_california_housing () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] model = LinearRegression () # \u521b\u5efa RFE\uff08\u9012\u5f52\u7279\u5f81\u6d88\u9664\uff09\uff0c\u6307\u5b9a\u6a21\u578b\u4e3a\u7ebf\u6027\u56de\u5f52\u6a21\u578b\uff0c\u8981\u9009\u62e9\u7684\u7279\u5f81\u6570\u91cf\u4e3a 3 rfe = RFE ( estimator = model , n_features_to_select = 3 ) # \u8bad\u7ec3\u6a21\u578b rfe . fit ( X , y ) # \u4f7f\u7528 RFE \u9009\u62e9\u7684\u7279\u5f81\u8fdb\u884c\u6570\u636e\u8f6c\u6362 X_transformed = rfe . 
transform ( X ) \u6211\u4eec\u770b\u5230\u4e86\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u7684\u4e24\u79cd\u4e0d\u540c\u7684\u8d2a\u5a6a\u65b9\u6cd5\u3002\u4f46\u4e5f\u53ef\u4ee5\u6839\u636e\u6570\u636e\u62df\u5408\u6a21\u578b\uff0c\u7136\u540e\u901a\u8fc7\u7279\u5f81\u7cfb\u6570\u6216\u7279\u5f81\u7684\u91cd\u8981\u6027\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u5982\u679c\u4f7f\u7528\u7cfb\u6570\uff0c\u5219\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u9608\u503c\uff0c\u5982\u679c\u7cfb\u6570\u9ad8\u4e8e\u8be5\u9608\u503c\uff0c\u5219\u53ef\u4ee5\u4fdd\u7559\u8be5\u7279\u5f81\uff0c\u5426\u5219\u5c06\u5176\u5254\u9664\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4ece\u968f\u673a\u68ee\u6797\u8fd9\u6837\u7684\u6a21\u578b\u4e2d\u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u5b9e\u4f8b\u5316\u968f\u673a\u68ee\u6797\u6a21\u578b model = RandomForestRegressor () # \u62df\u5408\u6a21\u578b model . fit ( X , y ) \u968f\u673a\u68ee\u6797\uff08\u6216\u4efb\u4f55\u6a21\u578b\uff09\u7684\u7279\u5f81\u91cd\u8981\u6027\u53ef\u6309\u5982\u4e0b\u65b9\u5f0f\u7ed8\u5236\u3002 # \u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027 importances = model . feature_importances_ # \u964d\u5e8f\u6392\u5217 idxs = np . argsort ( importances ) # \u8bbe\u5b9a\u6807\u9898 plt . title ( 'Feature Importances' ) # \u521b\u5efa\u76f4\u65b9\u56fe plt . barh ( range ( len ( idxs )), importances [ idxs ], align = 'center' ) # y\u8f74\u6807\u7b7e plt . yticks ( range ( len ( idxs )), [ col_names [ i ] for i in idxs ]) # x\u8f74\u6807\u7b7e plt . xlabel ( 'Random Forest Feature Importance' ) plt . 
show () \u7ed3\u679c\u5982\u56fe 3 \u6240\u793a\u3002 \u56fe 3\uff1a\u7279\u5f81\u91cd\u8981\u6027\u56fe \u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u6700\u4f73\u7279\u5f81\u5e76\u4e0d\u662f\u4ec0\u4e48\u65b0\u9c9c\u4e8b\u3002\u60a8\u53ef\u4ee5\u4ece\u4e00\u4e2a\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u53e6\u4e00\u4e2a\u6a21\u578b\u8fdb\u884c\u8bad\u7ec3\u3002\u4f8b\u5982\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u7cfb\u6570\u6765\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u968f\u673a\u68ee\u6797\uff08Random Forest\uff09\u5bf9\u6240\u9009\u7279\u5f81\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\u3002Scikit-learn \u8fd8\u63d0\u4f9b\u4e86 SelectFromModel \u7c7b\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4f60\u76f4\u63a5\u4ece\u7ed9\u5b9a\u7684\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u6839\u636e\u9700\u8981\u6307\u5b9a\u7cfb\u6570\u6216\u7279\u5f81\u91cd\u8981\u6027\u7684\u9608\u503c\uff0c\u4ee5\u53ca\u8981\u9009\u62e9\u7684\u7279\u5f81\u7684\u6700\u5927\u6570\u91cf\u3002 \u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\uff0c\u6211\u4eec\u4f7f\u7528 SelectFromModel \u4e2d\u7684\u9ed8\u8ba4\u53c2\u6570\u6765\u9009\u62e9\u7279\u5f81\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor from sklearn.feature_selection import SelectFromModel data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u521b\u5efa\u968f\u673a\u68ee\u6797\u6a21\u578b\u56de\u5f52\u6a21\u578b model = RandomForestRegressor () # \u521b\u5efa SelectFromModel \u5bf9\u8c61 sfm\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u6a21\u578b\u4f5c\u4e3a\u4f30\u7b97\u5668 sfm = SelectFromModel ( estimator = model ) # \u4f7f\u7528 sfm \u5bf9\u7279\u5f81\u77e9\u9635 X \u548c\u76ee\u6807\u53d8\u91cf y \u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed = sfm . 
fit_transform ( X , y ) # \u83b7\u53d6\u7ecf\u8fc7\u7279\u5f81\u9009\u62e9\u540e\u7684\u7279\u5f81\u63a9\u7801\uff08True \u8868\u793a\u7279\u5f81\u88ab\u9009\u62e9\uff0cFalse \u8868\u793a\u7279\u5f81\u672a\u88ab\u9009\u62e9\uff09 support = sfm . get_support () # \u6253\u5370\u88ab\u9009\u62e9\u7684\u7279\u5f81\u5217\u540d print ([ x for x , y in zip ( col_names , support ) if y == True ]) \u4e0a\u9762\u7a0b\u5e8f\u6253\u5370\u7ed3\u679c\uff1a ['bmi'\uff0c's5']\u3002\u6211\u4eec\u518d\u770b\u56fe 3\uff0c\u5c31\u4f1a\u53d1\u73b0\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u4e2a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u76f4\u63a5\u4ece\u968f\u673a\u68ee\u6797\u63d0\u4f9b\u7684\u7279\u5f81\u91cd\u8981\u6027\u4e2d\u8fdb\u884c\u9009\u62e9\u3002\u6211\u4eec\u8fd8\u7f3a\u5c11\u4e00\u4ef6\u4e8b\uff0c\u90a3\u5c31\u662f\u4f7f\u7528 L1\uff08Lasso\uff09\u60e9\u7f5a\u6a21\u578b \u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u5f53\u6211\u4eec\u4f7f\u7528 L1 \u60e9\u7f5a\u8fdb\u884c\u6b63\u5219\u5316\u65f6\uff0c\u5927\u90e8\u5206\u7cfb\u6570\u90fd\u5c06\u4e3a 0\uff08\u6216\u63a5\u8fd1 0\uff09\uff0c\u56e0\u6b64\u6211\u4eec\u8981\u9009\u62e9\u7cfb\u6570\u4e0d\u4e3a 0 \u7684\u7279\u5f81\u3002\u53ea\u9700\u5c06\u6a21\u578b\u9009\u62e9\u7247\u6bb5\u4e2d\u7684\u968f\u673a\u68ee\u6797\u66ff\u6362\u4e3a\u652f\u6301 L1 \u60e9\u7f5a\u7684\u6a21\u578b\uff08\u5982 lasso \u56de\u5f52\uff09\u5373\u53ef\u3002\u6240\u6709\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u90fd\u63d0\u4f9b\u7279\u5f81\u91cd\u8981\u6027\uff0c\u56e0\u6b64\u672c\u7ae0\u4e2d\u5c55\u793a\u7684\u6240\u6709\u57fa\u4e8e\u6a21\u578b\u7684\u7247\u6bb5\u90fd\u53ef\u7528\u4e8e XGBoost\u3001LightGBM \u6216 
CatBoost\u3002\u7279\u5f81\u91cd\u8981\u6027\u51fd\u6570\u7684\u540d\u79f0\u53ef\u80fd\u4e0d\u540c\uff0c\u4ea7\u751f\u7ed3\u679c\u7684\u683c\u5f0f\u4e5f\u53ef\u80fd\u4e0d\u540c\uff0c\u4f46\u7528\u6cd5\u662f\u4e00\u6837\u7684\u3002\u6700\u540e\uff0c\u5728\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u65f6\u5fc5\u987b\u5c0f\u5fc3\u8c28\u614e\u3002\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u9009\u62e9\u7279\u5f81\uff0c\u5e76\u5728\u9a8c\u8bc1\u6570\u636e\u4e0a\u9a8c\u8bc1\u6a21\u578b\uff0c\u4ee5\u4fbf\u5728\u4e0d\u8fc7\u5ea6\u62df\u5408\u6a21\u578b\u7684\u60c5\u51b5\u4e0b\u6b63\u786e\u9009\u62e9\u7279\u5f81\u3002","title":"\u7279\u5f81\u9009\u62e9"},{"location":"%E7%89%B9%E5%BE%81%E9%80%89%E6%8B%A9/#_1","text":"\u5f53\u4f60\u521b\u5efa\u4e86\u6210\u5343\u4e0a\u4e07\u4e2a\u7279\u5f81\u540e\uff0c\u5c31\u8be5\u4ece\u4e2d\u6311\u9009\u51fa\u51e0\u4e2a\u4e86\u3002\u4f46\u662f\uff0c\u6211\u4eec\u7edd\u4e0d\u5e94\u8be5\u521b\u5efa\u6210\u767e\u4e0a\u5343\u4e2a\u65e0\u7528\u7684\u7279\u5f81\u3002\u7279\u5f81\u8fc7\u591a\u4f1a\u5e26\u6765\u4e00\u4e2a\u4f17\u6240\u5468\u77e5\u7684\u95ee\u9898\uff0c\u5373 \"\u7ef4\u5ea6\u8bc5\u5492\"\u3002\u5982\u679c\u4f60\u6709\u5f88\u591a\u7279\u5f81\uff0c\u4f60\u4e5f\u5fc5\u987b\u6709\u5f88\u591a\u8bad\u7ec3\u6837\u672c\u6765\u6355\u6349\u6240\u6709\u7279\u5f81\u3002\u4ec0\u4e48\u662f \"\u5927\u91cf \"\u5e76\u6ca1\u6709\u6b63\u786e\u7684\u5b9a\u4e49\uff0c\u8fd9\u9700\u8981\u60a8\u901a\u8fc7\u6b63\u786e\u9a8c\u8bc1\u60a8\u7684\u6a21\u578b\u548c\u68c0\u67e5\u8bad\u7ec3\u6a21\u578b\u6240\u9700\u7684\u65f6\u95f4\u6765\u786e\u5b9a\u3002 \u9009\u62e9\u7279\u5f81\u7684\u6700\u7b80\u5355\u65b9\u6cd5\u662f \u5220\u9664\u65b9\u5dee\u975e\u5e38\u5c0f\u7684\u7279\u5f81 \u3002\u5982\u679c\u7279\u5f81\u7684\u65b9\u5dee\u975e\u5e38\u5c0f\uff08\u5373\u975e\u5e38\u63a5\u8fd1\u4e8e 
0\uff09\uff0c\u5b83\u4eec\u5c31\u63a5\u8fd1\u4e8e\u5e38\u91cf\uff0c\u56e0\u6b64\u6839\u672c\u4e0d\u4f1a\u7ed9\u4efb\u4f55\u6a21\u578b\u589e\u52a0\u4efb\u4f55\u4ef7\u503c\u3002\u6700\u597d\u7684\u529e\u6cd5\u5c31\u662f\u53bb\u6389\u5b83\u4eec\uff0c\u4ece\u800c\u964d\u4f4e\u590d\u6742\u5ea6\u3002\u8bf7\u6ce8\u610f\uff0c\u65b9\u5dee\u4e5f\u53d6\u51b3\u4e8e\u6570\u636e\u7684\u7f29\u653e\u3002 Scikit-learn \u7684 VarianceThreshold \u5b9e\u73b0\u4e86\u8fd9\u4e00\u70b9\u3002 from sklearn.feature_selection import VarianceThreshold data = ... # \u521b\u5efa VarianceThreshold \u5bf9\u8c61 var_thresh\uff0c\u6307\u5b9a\u65b9\u5dee\u9608\u503c\u4e3a 0.1 var_thresh = VarianceThreshold ( threshold = 0.1 ) # \u4f7f\u7528 var_thresh \u5bf9\u6570\u636e data \u8fdb\u884c\u62df\u5408\u548c\u53d8\u6362\uff0c\u5c06\u65b9\u5dee\u4f4e\u4e8e\u9608\u503c\u7684\u7279\u5f81\u79fb\u9664 transformed_data = var_thresh . fit_transform ( data ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u5220\u9664\u76f8\u5173\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u3002\u8981\u8ba1\u7b97\u4e0d\u540c\u6570\u5b57\u7279\u5f81\u4e4b\u95f4\u7684\u76f8\u5173\u6027\uff0c\u53ef\u4ee5\u4f7f\u7528\u76ae\u5c14\u900a\u76f8\u5173\u6027\u3002 import pandas as pd from sklearn.datasets import fetch_california_housing # \u52a0\u8f7d\u6570\u636e data = fetch_california_housing () # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u77e9\u9635 X X = data [ \"data\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u7279\u5f81\u7684\u5217\u540d col_names = data [ \"feature_names\" ] # \u4ece\u6570\u636e\u96c6\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf y y = data [ \"target\" ] df = pd . DataFrame ( X , columns = col_names ) # \u6dfb\u52a0 MedInc_Sqrt \u5217\uff0c\u662f MedInc \u5217\u4e2d\u6bcf\u4e2a\u5143\u7d20\u8fdb\u884c\u5e73\u65b9\u6839\u8fd0\u7b97\u7684\u7ed3\u679c df . loc [:, \"MedInc_Sqrt\" ] = df . MedInc . apply ( np . sqrt ) # \u8ba1\u7b97\u76ae\u5c14\u900a\u76f8\u5173\u6027\u77e9\u9635 df . 
corr () \u5f97\u51fa\u76f8\u5173\u77e9\u9635\uff0c\u5982\u56fe 1 \u6240\u793a\u3002 \u56fe 1\uff1a\u76ae\u5c14\u900a\u76f8\u5173\u77e9\u9635\u6837\u672c \u6211\u4eec\u770b\u5230\uff0cMedInc_Sqrt \u4e0e MedInc \u7684\u76f8\u5173\u6027\u975e\u5e38\u9ad8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5220\u9664\u5176\u4e2d\u4e00\u4e2a\u7279\u5f81\u3002 \u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u4e00\u4e9b \u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5 \u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u53ea\u4e0d\u8fc7\u662f\u9488\u5bf9\u7ed9\u5b9a\u76ee\u6807\u5bf9\u6bcf\u4e2a\u7279\u5f81\u8fdb\u884c\u8bc4\u5206\u3002 \u4e92\u4fe1\u606f \u3001 \u65b9\u5dee\u5206\u6790 F \u68c0\u9a8c\u548c chi2 \u662f\u4e00\u4e9b\u6700\u5e38\u7528\u7684\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u3002\u5728 scikit- learn \u4e2d\uff0c\u6709\u4e24\u79cd\u65b9\u6cd5\u53ef\u4ee5\u4f7f\u7528\u8fd9\u4e9b\u65b9\u6cd5\u3002 - SelectKBest\uff1a\u4fdd\u7559\u5f97\u5206\u6700\u9ad8\u7684 k \u4e2a\u7279\u5f81 - SelectPercentile\uff1a\u4fdd\u7559\u7528\u6237\u6307\u5b9a\u767e\u5206\u6bd4\u5185\u7684\u9876\u7ea7\u7279\u5f81\u3002 \u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u53ea\u6709\u975e\u8d1f\u6570\u636e\u624d\u80fd\u4f7f\u7528 chi2\u3002\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u4e2d\uff0c\u5f53\u6211\u4eec\u6709\u4e00\u4e9b\u5355\u8bcd\u6216\u57fa\u4e8e tf-idf \u7684\u7279\u5f81\u65f6\uff0c\u8fd9\u662f\u4e00\u79cd\u7279\u522b\u6709\u7528\u7684\u7279\u5f81\u9009\u62e9\u6280\u672f\u3002\u6700\u597d\u4e3a\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u521b\u5efa\u4e00\u4e2a\u5305\u88c5\u5668\uff0c\u51e0\u4e4e\u53ef\u4ee5\u7528\u4e8e\u4efb\u4f55\u65b0\u95ee\u9898\u3002 from sklearn.feature_selection import chi2 from sklearn.feature_selection import f_classif from sklearn.feature_selection import f_regression from sklearn.feature_selection import mutual_info_classif from sklearn.feature_selection import mutual_info_regression from sklearn.feature_selection import SelectKBest 
from sklearn.feature_selection import SelectPercentile class UnivariateFeatureSelction : def __init__ ( self , n_features , problem_type , scoring ): # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u5206\u7c7b\u95ee\u9898 if problem_type == \"classification\" : # \u521b\u5efa\u5b57\u5178 valid_scoring \uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_classif\" : f_classif , \"chi2\" : chi2 , \"mutual_info_classif\" : mutual_info_classif } # \u82e5\u95ee\u9898\u7c7b\u578b\u662f\u56de\u5f52\u95ee\u9898 else : # \u521b\u5efa\u5b57\u5178 valid_scoring\uff0c\u5305\u542b\u5404\u79cd\u7279\u5f81\u91cd\u8981\u6027\u8861\u91cf\u65b9\u5f0f valid_scoring = { \"f_regression\" : f_regression , \"mutual_info_regression\" : mutual_info_regression } # \u68c0\u67e5\u7279\u5f81\u91cd\u8981\u6027\u65b9\u5f0f\u662f\u5426\u5728\u5b57\u5178\u4e2d if scoring not in valid_scoring : raise Exception ( \"Invalid scoring function\" ) # \u68c0\u67e5 n_features \u7684\u7c7b\u578b\uff0c\u5982\u679c\u662f\u6574\u6570\uff0c\u5219\u4f7f\u7528 SelectKBest \u8fdb\u884c\u7279\u5f81\u9009\u62e9 if isinstance ( n_features , int ): self . selection = SelectKBest ( valid_scoring [ scoring ], k = n_features ) # \u5982\u679c n_features \u662f\u6d6e\u70b9\u6570\uff0c\u5219\u4f7f\u7528 SelectPercentile \u8fdb\u884c\u7279\u5f81\u9009\u62e9 elif isinstance ( n_features , float ): self . selection = SelectPercentile ( valid_scoring [ scoring ], percentile = int ( n_features * 100 ) ) # \u5982\u679c n_features \u7c7b\u578b\u65e0\u6548\uff0c\u5f15\u53d1\u5f02\u5e38 else : raise Exception ( \"Invalid type of feature\" ) # \u5b9a\u4e49 fit \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 def fit ( self , X , y ): return self . selection . fit ( X , y ) # \u5b9a\u4e49 transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u5bf9\u6570\u636e\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def transform ( self , X ): return self . selection . 
transform ( X ) # \u5b9a\u4e49 fit_transform \u65b9\u6cd5\uff0c\u7528\u4e8e\u62df\u5408\u7279\u5f81\u9009\u62e9\u5668\u5e76\u540c\u65f6\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u8f6c\u6362 def fit_transform ( self , X , y ): return self . selection . fit_transform ( X , y ) \u4f7f\u7528\u8be5\u7c7b\u975e\u5e38\u7b80\u5355\u3002 # \u5b9e\u4f8b\u5316\u7279\u5f81\u9009\u62e9\u5668\uff0c\u4fdd\u7559\u524d10%\u7684\u7279\u5f81\uff0c\u56de\u5f52\u95ee\u9898\uff0c\u4f7f\u7528f_regression\u8861\u91cf\u7279\u5f81\u91cd\u8981\u6027 ufs = UnivariateFeatureSelction ( n_features = 0.1 , problem_type = \"regression\" , scoring = \"f_regression\" ) # \u62df\u5408\u7279\u5f81\u9009\u62e9\u5668 ufs . fit ( X , y ) # \u7279\u5f81\u8f6c\u6362 X_transformed = ufs . transform ( X ) \u8fd9\u6837\u5c31\u80fd\u6ee1\u8db3\u5927\u90e8\u5206\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u7684\u9700\u6c42\u3002\u8bf7\u6ce8\u610f\uff0c\u521b\u5efa\u8f83\u5c11\u800c\u91cd\u8981\u7684\u7279\u5f81\u901a\u5e38\u6bd4\u521b\u5efa\u6570\u4ee5\u767e\u8ba1\u7684\u7279\u5f81\u8981\u597d\u3002\u5355\u53d8\u91cf\u7279\u5f81\u9009\u62e9\u4e0d\u4e00\u5b9a\u603b\u662f\u8868\u73b0\u826f\u597d\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u4eba\u4eec\u66f4\u559c\u6b22\u4f7f\u7528\u673a\u5668\u5b66\u4e60\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 \u4f7f\u7528\u6a21\u578b\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u7684\u6700\u7b80\u5355\u5f62\u5f0f\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u3002\u5728\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u4e2d\uff0c\u7b2c\u4e00\u6b65\u662f\u9009\u62e9\u4e00\u4e2a\u6a21\u578b\u3002\u7b2c\u4e8c\u6b65\u662f\u9009\u62e9\u635f\u5931/\u8bc4\u5206\u51fd\u6570\u3002\u7b2c\u4e09\u6b65\u4e5f\u662f\u6700\u540e\u4e00\u6b65\u662f\u53cd\u590d\u8bc4\u4f30\u6bcf\u4e2a\u7279\u5f81\uff0c\u5982\u679c\u80fd\u63d0\u9ad8\u635f\u5931/\u8bc4\u5206\uff0c\u5c31\u5c06\u5176\u6dfb\u52a0\u5230 \"\u597d 
\"\u7279\u5f81\u5217\u8868\u4e2d\u3002\u6ca1\u6709\u6bd4\u8fd9\u66f4\u7b80\u5355\u7684\u4e86\u3002\u4f46\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u8fd9\u88ab\u79f0\u4e3a\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u662f\u6709\u539f\u56e0\u7684\u3002\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u8fc7\u7a0b\u5728\u6bcf\u6b21\u8bc4\u4f30\u7279\u5f81\u65f6\u90fd\u4f1a\u9002\u5408\u7ed9\u5b9a\u7684\u6a21\u578b\u3002\u8fd9\u79cd\u65b9\u6cd5\u7684\u8ba1\u7b97\u6210\u672c\u975e\u5e38\u9ad8\u3002\u5b8c\u6210\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\u4e5f\u9700\u8981\u5927\u91cf\u65f6\u95f4\u3002\u5982\u679c\u4e0d\u6b63\u786e\u4f7f\u7528\u8fd9\u79cd\u7279\u5f81\u9009\u62e9\uff0c\u751a\u81f3\u4f1a\u5bfc\u81f4\u6a21\u578b\u8fc7\u5ea6\u62df\u5408\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import pandas as pd from sklearn import linear_model from sklearn import metrics from sklearn.datasets import make_classification class GreedyFeatureSelection : # \u5b9a\u4e49\u8bc4\u4f30\u5206\u6570\u7684\u65b9\u6cd5\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd def evaluate_score ( self , X , y ): # \u903b\u8f91\u56de\u5f52\u6a21\u578b model = linear_model . LogisticRegression () # \u8bad\u7ec3\u6a21\u578b model . fit ( X , y ) # \u9884\u6d4b\u6982\u7387\u503c predictions = model . predict_proba ( X )[:, 1 ] # \u8ba1\u7b97 AUC \u5206\u6570 auc = metrics . roc_auc_score ( y , predictions ) return auc # \u7279\u5f81\u9009\u62e9\u51fd\u6570 def _feature_selection ( self , X , y ): # \u521d\u59cb\u5316\u7a7a\u5217\u8868\uff0c\u7528\u4e8e\u5b58\u50a8\u6700\u4f73\u7279\u5f81\u548c\u6700\u4f73\u5206\u6570 good_features = [] best_scores = [] # \u83b7\u53d6\u7279\u5f81\u6570\u91cf num_features = X . 
shape [ 1 ] # \u5f00\u59cb\u7279\u5f81\u9009\u62e9\u7684\u5faa\u73af while True : this_feature = None best_score = 0 # \u904d\u5386\u6bcf\u4e2a\u7279\u5f81 for feature in range ( num_features ): if feature in good_features : continue selected_features = good_features + [ feature ] xtrain = X [:, selected_features ] score = self . evaluate_score ( xtrain , y ) # \u5982\u679c\u5f53\u524d\u7279\u5f81\u7684\u5f97\u5206\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u5f97\u5206\uff0c\u5219\u66f4\u65b0 if score > best_score : this_feature = feature best_score = score # \u82e5\u627e\u5230\u4e86\u65b0\u7684\u6700\u4f73\u7279\u5f81 if this_feature != None : # \u7279\u5f81\u6dfb\u52a0\u5230 good_features \u5217\u8868 good_features . append ( this_feature ) # \u5f97\u5206\u6dfb\u52a0\u5230 best_scores \u5217\u8868 best_scores . append ( best_score ) # \u5982\u679c best_scores \u5217\u8868\u957f\u5ea6\u5927\u4e8e2\uff0c\u5e76\u4e14\u6700\u540e\u4e24\u4e2a\u5f97\u5206\u76f8\u6bd4\u8f83\u5dee\uff0c\u5219\u7ed3\u675f\u5faa\u73af if len ( best_scores ) > 2 : if best_scores [ - 1 ] < best_scores [ - 2 ]: break # \u8fd4\u56de\u6700\u4f73\u7279\u5f81\u7684\u5f97\u5206\u5217\u8868\u548c\u6700\u4f73\u7279\u5f81\u5217\u8868 return best_scores [: - 1 ], good_features [: - 1 ] # \u5b9a\u4e49\u7c7b\u7684\u8c03\u7528\u65b9\u6cd5\uff0c\u7528\u4e8e\u6267\u884c\u7279\u5f81\u9009\u62e9 def __call__ ( self , X , y ): scores , features = self . 
_feature_selection ( X , y ) return X [:, features ], scores if __name__ == \"__main__\" : # \u751f\u6210\u4e00\u4e2a\u793a\u4f8b\u7684\u5206\u7c7b\u6570\u636e\u96c6 X \u548c\u6807\u7b7e y X , y = make_classification ( n_samples = 1000 , n_features = 100 ) # \u5b9e\u4f8b\u5316 GreedyFeatureSelection \u7c7b\uff0c\u5e76\u4f7f\u7528 __call__ \u65b9\u6cd5\u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed , scores = GreedyFeatureSelection ()( X , y ) \u8fd9\u79cd\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u65b9\u6cd5\u4f1a\u8fd4\u56de\u5206\u6570\u548c\u7279\u5f81\u7d22\u5f15\u5217\u8868\u3002\u56fe 2 \u663e\u793a\u4e86\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u589e\u52a0\u4e00\u4e2a\u65b0\u7279\u5f81\u540e\uff0c\u5206\u6570\u662f\u5982\u4f55\u63d0\u9ad8\u7684\u3002\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u5728\u67d0\u4e00\u70b9\u4e4b\u540e\uff0c\u6211\u4eec\u5c31\u65e0\u6cd5\u63d0\u9ad8\u5206\u6570\u4e86\uff0c\u8fd9\u5c31\u662f\u6211\u4eec\u505c\u6b62\u7684\u5730\u65b9\u3002 \u53e6\u4e00\u79cd\u8d2a\u5a6a\u7684\u65b9\u6cd5\u88ab\u79f0\u4e3a\u9012\u5f52\u7279\u5f81\u6d88\u9664\u6cd5\uff08RFE\uff09\u3002\u5728\u524d\u4e00\u79cd\u65b9\u6cd5\u4e2d\uff0c\u6211\u4eec\u4ece\u4e00\u4e2a\u7279\u5f81\u5f00\u59cb\uff0c\u7136\u540e\u4e0d\u65ad\u6dfb\u52a0\u65b0\u7684\u7279\u5f81\uff0c\u4f46\u5728 RFE 
\u4e2d\uff0c\u6211\u4eec\u4ece\u6240\u6709\u7279\u5f81\u5f00\u59cb\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\u4e0d\u65ad\u53bb\u9664\u4e00\u4e2a\u5bf9\u7ed9\u5b9a\u6a21\u578b\u63d0\u4f9b\u6700\u5c0f\u503c\u7684\u7279\u5f81\u3002\u4f46\u6211\u4eec\u5982\u4f55\u77e5\u9053\u54ea\u4e2a\u7279\u5f81\u7684\u4ef7\u503c\u6700\u5c0f\u5462\uff1f\u5982\u679c\u6211\u4eec\u4f7f\u7528\u7ebf\u6027\u652f\u6301\u5411\u91cf\u673a\uff08SVM\uff09\u6216\u903b\u8f91\u56de\u5f52\u7b49\u6a21\u578b\uff0c\u6211\u4eec\u4f1a\u4e3a\u6bcf\u4e2a\u7279\u5f81\u5f97\u5230\u4e00\u4e2a\u7cfb\u6570\uff0c\u8be5\u7cfb\u6570\u51b3\u5b9a\u4e86\u7279\u5f81\u7684\u91cd\u8981\u6027\u3002\u800c\u5bf9\u4e8e\u4efb\u4f55\u57fa\u4e8e\u6811\u7684\u6a21\u578b\uff0c\u6211\u4eec\u5f97\u5230\u7684\u662f\u7279\u5f81\u91cd\u8981\u6027\uff0c\u800c\u4e0d\u662f\u7cfb\u6570\u3002\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u53ef\u4ee5\u5254\u9664\u6700\u4e0d\u91cd\u8981\u7684\u7279\u5f81\uff0c\u76f4\u5230\u8fbe\u5230\u6240\u9700\u7684\u7279\u5f81\u6570\u91cf\u4e3a\u6b62\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u51b3\u5b9a\u8981\u4fdd\u7559\u591a\u5c11\u7279\u5f81\u3002 \u56fe 2\uff1a\u589e\u52a0\u65b0\u7279\u5f81\u540e\uff0c\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7684 AUC \u5206\u6570\u5982\u4f55\u53d8\u5316 \u5f53\u6211\u4eec\u8fdb\u884c\u9012\u5f52\u7279\u5f81\u5254\u9664\u65f6\uff0c\u5728\u6bcf\u6b21\u8fed\u4ee3\u4e2d\uff0c\u6211\u4eec\u90fd\u4f1a\u5254\u9664\u7279\u5f81\u91cd\u8981\u6027\u8f83\u9ad8\u7684\u7279\u5f81\u6216\u7cfb\u6570\u63a5\u8fd1 0 
\u7684\u7279\u5f81\u3002\u8bf7\u8bb0\u4f4f\uff0c\u5f53\u4f60\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u8fd9\u6837\u7684\u6a21\u578b\u8fdb\u884c\u4e8c\u5143\u5206\u7c7b\u65f6\uff0c\u5982\u679c\u7279\u5f81\u5bf9\u6b63\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u6b63\uff0c\u800c\u5982\u679c\u7279\u5f81\u5bf9\u8d1f\u5206\u7c7b\u5f88\u91cd\u8981\uff0c\u5176\u7cfb\u6570\u5c31\u4f1a\u66f4\u8d1f\u3002\u4fee\u6539\u6211\u4eec\u7684\u8d2a\u5a6a\u7279\u5f81\u9009\u62e9\u7c7b\uff0c\u521b\u5efa\u4e00\u4e2a\u65b0\u7684\u9012\u5f52\u7279\u5f81\u6d88\u9664\u7c7b\u975e\u5e38\u5bb9\u6613\uff0c\u4f46 scikit-learn \u4e5f\u63d0\u4f9b\u4e86 RFE\u3002\u4e0b\u9762\u7684\u793a\u4f8b\u5c55\u793a\u4e86\u4e00\u4e2a\u7b80\u5355\u7684\u7528\u6cd5\u3002 import pandas as pd from sklearn.feature_selection import RFE from sklearn.linear_model import LinearRegression from sklearn.datasets import fetch_california_housing data = fetch_california_housing () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] model = LinearRegression () # \u521b\u5efa RFE\uff08\u9012\u5f52\u7279\u5f81\u6d88\u9664\uff09\uff0c\u6307\u5b9a\u6a21\u578b\u4e3a\u7ebf\u6027\u56de\u5f52\u6a21\u578b\uff0c\u8981\u9009\u62e9\u7684\u7279\u5f81\u6570\u91cf\u4e3a 3 rfe = RFE ( estimator = model , n_features_to_select = 3 ) # \u8bad\u7ec3\u6a21\u578b rfe . fit ( X , y ) # \u4f7f\u7528 RFE \u9009\u62e9\u7684\u7279\u5f81\u8fdb\u884c\u6570\u636e\u8f6c\u6362 X_transformed = rfe . 
transform ( X ) \u6211\u4eec\u770b\u5230\u4e86\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u7684\u4e24\u79cd\u4e0d\u540c\u7684\u8d2a\u5a6a\u65b9\u6cd5\u3002\u4f46\u4e5f\u53ef\u4ee5\u6839\u636e\u6570\u636e\u62df\u5408\u6a21\u578b\uff0c\u7136\u540e\u901a\u8fc7\u7279\u5f81\u7cfb\u6570\u6216\u7279\u5f81\u7684\u91cd\u8981\u6027\u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u5982\u679c\u4f7f\u7528\u7cfb\u6570\uff0c\u5219\u53ef\u4ee5\u9009\u62e9\u4e00\u4e2a\u9608\u503c\uff0c\u5982\u679c\u7cfb\u6570\u9ad8\u4e8e\u8be5\u9608\u503c\uff0c\u5219\u53ef\u4ee5\u4fdd\u7559\u8be5\u7279\u5f81\uff0c\u5426\u5219\u5c06\u5176\u5254\u9664\u3002 \u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4ece\u968f\u673a\u68ee\u6797\u8fd9\u6837\u7684\u6a21\u578b\u4e2d\u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u5b9e\u4f8b\u5316\u968f\u673a\u68ee\u6797\u6a21\u578b model = RandomForestRegressor () # \u62df\u5408\u6a21\u578b model . fit ( X , y ) \u968f\u673a\u68ee\u6797\uff08\u6216\u4efb\u4f55\u6a21\u578b\uff09\u7684\u7279\u5f81\u91cd\u8981\u6027\u53ef\u6309\u5982\u4e0b\u65b9\u5f0f\u7ed8\u5236\u3002 # \u83b7\u53d6\u7279\u5f81\u91cd\u8981\u6027 importances = model . feature_importances_ # \u964d\u5e8f\u6392\u5217 idxs = np . argsort ( importances ) # \u8bbe\u5b9a\u6807\u9898 plt . title ( 'Feature Importances' ) # \u521b\u5efa\u76f4\u65b9\u56fe plt . barh ( range ( len ( idxs )), importances [ idxs ], align = 'center' ) # y\u8f74\u6807\u7b7e plt . yticks ( range ( len ( idxs )), [ col_names [ i ] for i in idxs ]) # x\u8f74\u6807\u7b7e plt . xlabel ( 'Random Forest Feature Importance' ) plt . 
show () \u7ed3\u679c\u5982\u56fe 3 \u6240\u793a\u3002 \u56fe 3\uff1a\u7279\u5f81\u91cd\u8981\u6027\u56fe \u4ece\u6a21\u578b\u4e2d\u9009\u62e9\u6700\u4f73\u7279\u5f81\u5e76\u4e0d\u662f\u4ec0\u4e48\u65b0\u9c9c\u4e8b\u3002\u60a8\u53ef\u4ee5\u4ece\u4e00\u4e2a\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u53e6\u4e00\u4e2a\u6a21\u578b\u8fdb\u884c\u8bad\u7ec3\u3002\u4f8b\u5982\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528\u903b\u8f91\u56de\u5f52\u7cfb\u6570\u6765\u9009\u62e9\u7279\u5f81\uff0c\u7136\u540e\u4f7f\u7528\u968f\u673a\u68ee\u6797\uff08Random Forest\uff09\u5bf9\u6240\u9009\u7279\u5f81\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\u3002Scikit-learn \u8fd8\u63d0\u4f9b\u4e86 SelectFromModel \u7c7b\uff0c\u53ef\u4ee5\u5e2e\u52a9\u4f60\u76f4\u63a5\u4ece\u7ed9\u5b9a\u7684\u6a21\u578b\u4e2d\u9009\u62e9\u7279\u5f81\u3002\u60a8\u8fd8\u53ef\u4ee5\u6839\u636e\u9700\u8981\u6307\u5b9a\u7cfb\u6570\u6216\u7279\u5f81\u91cd\u8981\u6027\u7684\u9608\u503c\uff0c\u4ee5\u53ca\u8981\u9009\u62e9\u7684\u7279\u5f81\u7684\u6700\u5927\u6570\u91cf\u3002 \u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\uff0c\u6211\u4eec\u4f7f\u7528 SelectFromModel \u4e2d\u7684\u9ed8\u8ba4\u53c2\u6570\u6765\u9009\u62e9\u7279\u5f81\u3002 import pandas as pd from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor from sklearn.feature_selection import SelectFromModel data = load_diabetes () X = data [ \"data\" ] col_names = data [ \"feature_names\" ] y = data [ \"target\" ] # \u521b\u5efa\u968f\u673a\u68ee\u6797\u6a21\u578b\u56de\u5f52\u6a21\u578b model = RandomForestRegressor () # \u521b\u5efa SelectFromModel \u5bf9\u8c61 sfm\uff0c\u4f7f\u7528\u968f\u673a\u68ee\u6797\u6a21\u578b\u4f5c\u4e3a\u4f30\u7b97\u5668 sfm = SelectFromModel ( estimator = model ) # \u4f7f\u7528 sfm \u5bf9\u7279\u5f81\u77e9\u9635 X \u548c\u76ee\u6807\u53d8\u91cf y \u8fdb\u884c\u7279\u5f81\u9009\u62e9 X_transformed = sfm . 
fit_transform ( X , y ) # \u83b7\u53d6\u7ecf\u8fc7\u7279\u5f81\u9009\u62e9\u540e\u7684\u7279\u5f81\u63a9\u7801\uff08True \u8868\u793a\u7279\u5f81\u88ab\u9009\u62e9\uff0cFalse \u8868\u793a\u7279\u5f81\u672a\u88ab\u9009\u62e9\uff09 support = sfm . get_support () # \u6253\u5370\u88ab\u9009\u62e9\u7684\u7279\u5f81\u5217\u540d print ([ x for x , y in zip ( col_names , support ) if y == True ]) \u4e0a\u9762\u7a0b\u5e8f\u6253\u5370\u7ed3\u679c\uff1a ['bmi'\uff0c's5']\u3002\u6211\u4eec\u518d\u770b\u56fe 3\uff0c\u5c31\u4f1a\u53d1\u73b0\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u4e2a\u7279\u5f81\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u76f4\u63a5\u4ece\u968f\u673a\u68ee\u6797\u63d0\u4f9b\u7684\u7279\u5f81\u91cd\u8981\u6027\u4e2d\u8fdb\u884c\u9009\u62e9\u3002\u6211\u4eec\u8fd8\u7f3a\u5c11\u4e00\u4ef6\u4e8b\uff0c\u90a3\u5c31\u662f\u4f7f\u7528 L1\uff08Lasso\uff09\u60e9\u7f5a\u6a21\u578b \u8fdb\u884c\u7279\u5f81\u9009\u62e9\u3002\u5f53\u6211\u4eec\u4f7f\u7528 L1 \u60e9\u7f5a\u8fdb\u884c\u6b63\u5219\u5316\u65f6\uff0c\u5927\u90e8\u5206\u7cfb\u6570\u90fd\u5c06\u4e3a 0\uff08\u6216\u63a5\u8fd1 0\uff09\uff0c\u56e0\u6b64\u6211\u4eec\u8981\u9009\u62e9\u7cfb\u6570\u4e0d\u4e3a 0 \u7684\u7279\u5f81\u3002\u53ea\u9700\u5c06\u6a21\u578b\u9009\u62e9\u7247\u6bb5\u4e2d\u7684\u968f\u673a\u68ee\u6797\u66ff\u6362\u4e3a\u652f\u6301 L1 \u60e9\u7f5a\u7684\u6a21\u578b\uff08\u5982 lasso \u56de\u5f52\uff09\u5373\u53ef\u3002\u6240\u6709\u57fa\u4e8e\u6811\u7684\u6a21\u578b\u90fd\u63d0\u4f9b\u7279\u5f81\u91cd\u8981\u6027\uff0c\u56e0\u6b64\u672c\u7ae0\u4e2d\u5c55\u793a\u7684\u6240\u6709\u57fa\u4e8e\u6a21\u578b\u7684\u7247\u6bb5\u90fd\u53ef\u7528\u4e8e XGBoost\u3001LightGBM \u6216 
CatBoost\u3002\u7279\u5f81\u91cd\u8981\u6027\u51fd\u6570\u7684\u540d\u79f0\u53ef\u80fd\u4e0d\u540c\uff0c\u4ea7\u751f\u7ed3\u679c\u7684\u683c\u5f0f\u4e5f\u53ef\u80fd\u4e0d\u540c\uff0c\u4f46\u7528\u6cd5\u662f\u4e00\u6837\u7684\u3002\u6700\u540e\uff0c\u5728\u8fdb\u884c\u7279\u5f81\u9009\u62e9\u65f6\u5fc5\u987b\u5c0f\u5fc3\u8c28\u614e\u3002\u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u9009\u62e9\u7279\u5f81\uff0c\u5e76\u5728\u9a8c\u8bc1\u6570\u636e\u4e0a\u9a8c\u8bc1\u6a21\u578b\uff0c\u4ee5\u4fbf\u5728\u4e0d\u8fc7\u5ea6\u62df\u5408\u6a21\u578b\u7684\u60c5\u51b5\u4e0b\u6b63\u786e\u9009\u62e9\u7279\u5f81\u3002","title":"\u7279\u5f81\u9009\u62e9"},{"location":"%E7%BB%84%E5%90%88%E5%92%8C%E5%A0%86%E5%8F%A0%E6%96%B9%E6%B3%95/","text":"\u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5 \u542c\u5230\u4e0a\u9762\u4e24\u4e2a\u8bcd\uff0c\u6211\u4eec\u9996\u5148\u60f3\u5230\u7684\u5c31\u662f\u5728\u7ebf\uff08online\uff09/\u79bb\u7ebf\uff08offline\uff09\u673a\u5668\u5b66\u4e60\u7ade\u8d5b\u3002\u51e0\u5e74\u524d\u662f\u8fd9\u6837\uff0c\u4f46\u73b0\u5728\u968f\u7740\u8ba1\u7b97\u80fd\u529b\u7684\u8fdb\u6b65\u548c\u865a\u62df\u5b9e\u4f8b\u7684\u5ec9\u4ef7\uff0c\u4eba\u4eec\u751a\u81f3\u5f00\u59cb\u5728\u884c\u4e1a\u4e2d\u4f7f\u7528\u7ec4\u5408\u6a21\u578b\uff08ensemble models\uff09\u3002\u4f8b\u5982\uff0c\u90e8\u7f72\u591a\u4e2a\u795e\u7ecf\u7f51\u7edc\u5e76\u5b9e\u65f6\u4e3a\u5b83\u4eec\u63d0\u4f9b\u670d\u52a1\u975e\u5e38\u5bb9\u6613\uff0c\u54cd\u5e94\u65f6\u95f4\u5c0f\u4e8e 500 
\u6beb\u79d2\u3002\u6709\u65f6\uff0c\u4e00\u4e2a\u5e9e\u5927\u7684\u795e\u7ecf\u7f51\u7edc\u6216\u5927\u578b\u6a21\u578b\u4e5f\u53ef\u4ee5\u88ab\u5176\u4ed6\u51e0\u4e2a\u6a21\u578b\u53d6\u4ee3\uff0c\u8fd9\u4e9b\u6a21\u578b\u4f53\u79ef\u5c0f\uff0c\u6027\u80fd\u4e0e\u5927\u578b\u6a21\u578b\u76f8\u4f3c\uff0c\u901f\u5ea6\u5374\u5feb\u4e00\u500d\u3002\u5982\u679c\u662f\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff08\u4e9b\uff09\u6a21\u578b\u5462\uff1f\u6211\u4e2a\u4eba\u66f4\u503e\u5411\u4e8e\u9009\u62e9\u591a\u4e2a\u5c0f\u673a\u578b\uff0c\u5b83\u4eec\u901f\u5ea6\u66f4\u5feb\uff0c\u6027\u80fd\u4e0e\u5927\u673a\u578b\u548c\u6162\u673a\u578b\u76f8\u540c\u3002\u8bf7\u8bb0\u4f4f\uff0c\u8f83\u5c0f\u7684\u578b\u53f7\u4e5f\u66f4\u5bb9\u6613\u548c\u66f4\u5feb\u5730\u8fdb\u884c\u8c03\u6574\u3002 \u7ec4\u5408\uff08ensembling\uff09\u4e0d\u8fc7\u662f\u4e0d\u540c\u6a21\u578b\u7684\u7ec4\u5408\u3002\u6a21\u578b\u53ef\u4ee5\u901a\u8fc7\u9884\u6d4b/\u6982\u7387\u8fdb\u884c\u7ec4\u5408\u3002\u7ec4\u5408\u6a21\u578b\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6c42\u5e73\u5747\u503c\u3002 $$ Ensemble Probabilities = (M1_proba + M2_proba + ... 
+ Mn_Proba)/n $$ \u8fd9\u662f\u6700\u7b80\u5355\u4e5f\u662f\u6700\u6709\u6548\u7684\u7ec4\u5408\u6a21\u578b\u7684\u65b9\u6cd5\u3002\u5728\u7b80\u5355\u5e73\u5747\u6cd5\u4e2d\uff0c\u6240\u6709\u6a21\u578b\u7684\u6743\u91cd\u90fd\u662f\u76f8\u7b49\u7684\u3002\u65e0\u8bba\u91c7\u7528\u54ea\u79cd\u7ec4\u5408\u65b9\u6cd5\uff0c\u60a8\u90fd\u5e94\u8be5\u7262\u8bb0\u4e00\u70b9\uff0c\u90a3\u5c31\u662f\u60a8\u5e94\u8be5\u59cb\u7ec8\u5c06\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b/\u6982\u7387\u7ec4\u5408\u5728\u4e00\u8d77\u3002\u7b80\u5355\u5730\u8bf4\uff0c\u7ec4\u5408\u76f8\u5173\u6027\u4e0d\u9ad8\u7684\u6a21\u578b\u6bd4\u7ec4\u5408\u76f8\u5173\u6027\u5f88\u9ad8\u7684\u6a21\u578b\u6548\u679c\u66f4\u597d\u3002 \u5982\u679c\u6ca1\u6709\u6982\u7387\uff0c\u4e5f\u53ef\u4ee5\u7ec4\u5408\u9884\u6d4b\u3002\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6295\u7968\u3002\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u8fdb\u884c\u591a\u7c7b\u5206\u7c7b\uff0c\u6709\u4e09\u4e2a\u7c7b\u522b\uff1a 0\u30011 \u548c 2\u3002 [0, 0, 1] : \u6700\u9ad8\u7968\u6570\uff1a 0 [0, 1, 2] : \u6700\u9ad8\u7968\u7ea7\uff1a \u65e0\uff08\u968f\u673a\u9009\u62e9\u4e00\u4e2a\uff09 [2, 2, 2] : \u6700\u9ad8\u7968\u6570\uff1a 2 \u4ee5\u4e0b\u7b80\u5355\u51fd\u6570\u53ef\u4ee5\u5b8c\u6210\u8fd9\u4e9b\u7b80\u5355\u64cd\u4f5c\u3002 import numpy as np def mean_predictions ( probas ): # \u8ba1\u7b97\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u6bcf\u884c\u5e73\u5747\u503c return np . mean ( probas , axis = 1 ) def max_voting ( preds ): # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u67e5\u627e\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u7684\u7d22\u5f15 idxs = np . argmax ( preds , axis = 1 ) # \u6839\u636e\u7d22\u5f15\u53d6\u51fa\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u5143\u7d20 return np . 
take_along_axis ( preds , idxs [:, None ], axis = 1 ) \u8bf7\u6ce8\u610f\uff0cprobas \u7684\u6bcf\u4e00\u5217\u90fd\u53ea\u6709\u4e00\u4e2a\u6982\u7387\uff08\u5373\u4e8c\u5143\u5206\u7c7b\uff0c\u901a\u5e38\u4e3a\u7c7b\u522b 1\uff09\u3002\u56e0\u6b64\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u4e00\u4e2a\u65b0\u6a21\u578b\u3002\u540c\u6837\uff0c\u5bf9\u4e8e preds\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u6765\u81ea\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b\u503c\u3002\u8fd9\u4e24\u4e2a\u51fd\u6570\u90fd\u5047\u8bbe\u4e86\u4e00\u4e2a 2 \u7ef4 numpy \u6570\u7ec4\u3002\u60a8\u53ef\u4ee5\u6839\u636e\u81ea\u5df1\u7684\u9700\u6c42\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u6709\u4e00\u4e2a 2 \u7ef4\u6570\u7ec4\uff0c\u5176\u4e2d\u5305\u542b\u6bcf\u4e2a\u6a21\u578b\u7684\u6982\u7387\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u51fd\u6570\u4f1a\u6709\u4e00\u4e9b\u53d8\u5316\u3002 \u53e6\u4e00\u79cd\u7ec4\u5408\u591a\u4e2a\u6a21\u578b\u7684\u65b9\u6cd5\u662f\u901a\u8fc7\u5b83\u4eec\u7684 \u6982\u7387\u6392\u5e8f \u3002\u5f53\u76f8\u5173\u6307\u6807\u662f\u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u65f6\uff0c\u8fd9\u79cd\u7ec4\u5408\u65b9\u5f0f\u975e\u5e38\u6709\u6548\uff0c\u56e0\u4e3a AUC \u5c31\u662f\u5bf9\u6837\u672c\u8fdb\u884c\u6392\u5e8f\u3002 def rank_mean ( probas ): # \u521b\u5efa\u7a7a\u5217\u8868ranked\u5b58\u50a8\u6bcf\u4e2a\u7c7b\u522b\u6982\u7387\u503c\u6392\u540d ranked = [] # \u904d\u5386\u6982\u7387\u503c\u6bcf\u4e00\u5217\uff08\u6bcf\u4e2a\u7c7b\u522b\u7684\u6982\u7387\u503c\uff09 for i in range ( probas . shape [ 1 ]): # \u5f53\u524d\u5217\u6982\u7387\u503c\u6392\u540d\uff0crank_data\u662f\u6392\u540d\u7ed3\u679c rank_data = stats . rankdata ( probas [:, i ]) # \u5c06\u5f53\u524d\u5217\u6392\u540d\u7ed3\u679c\u6dfb\u52a0\u5230ranked\u5217\u8868\u4e2d ranked . append ( rank_data ) # \u5c06ranked\u5217\u8868\u4e2d\u6392\u540d\u7ed3\u679c\u6309\u5217\u5806\u53e0\uff0c\u5f62\u6210\u4e8c\u7ef4\u6570\u7ec4 ranked = np . 
column_stack ( ranked ) # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u8ba1\u7b97\u6837\u672c\u6392\u540d\u5e73\u5747\u503c return np . mean ( ranked , axis = 1 ) \u8bf7\u6ce8\u610f\uff0c\u5728 scipy \u7684 rankdata \u4e2d\uff0c\u7b49\u7ea7\u4ece 1 \u5f00\u59cb\u3002 \u4e3a\u4ec0\u4e48\u8fd9\u7c7b\u96c6\u5408\u6709\u6548\uff1f\u8ba9\u6211\u4eec\u770b\u770b\u56fe 1\u3002 \u56fe 1\uff1a\u4e09\u4eba\u731c\u5927\u8c61\u7684\u8eab\u9ad8 \u56fe 1 \u663e\u793a\uff0c\u5982\u679c\u6709\u4e09\u4e2a\u4eba\u5728\u731c\u5927\u8c61\u7684\u9ad8\u5ea6\uff0c\u90a3\u4e48\u539f\u59cb\u9ad8\u5ea6\u5c06\u975e\u5e38\u63a5\u8fd1\u4e09\u4e2a\u4eba\u731c\u6d4b\u7684\u5e73\u5747\u503c\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u4eba\u90fd\u80fd\u731c\u5230\u975e\u5e38\u63a5\u8fd1\u5927\u8c61\u539f\u6765\u7684\u9ad8\u5ea6\u3002\u63a5\u8fd1\u4f30\u8ba1\u503c\u610f\u5473\u7740\u8bef\u5dee\uff0c\u4f46\u5982\u679c\u6211\u4eec\u5c06\u4e09\u4e2a\u9884\u6d4b\u503c\u5e73\u5747\uff0c\u5c31\u80fd\u5c06\u8bef\u5dee\u964d\u5230\u6700\u4f4e\u3002\u8fd9\u5c31\u662f\u591a\u4e2a\u6a21\u578b\u5e73\u5747\u7684\u4e3b\u8981\u601d\u60f3\u3002 $$ Final\\ Probabilities = w_1 \\times M1_proba + w_2 \\times M2_proba + \\cdots + w_n \\times Mn_proba $$ \u5176\u4e2d \\((w_1 + w_2 + w_3 + \\cdots + w_n)=1.0\\) \u4f8b\u5982\uff0c\u5982\u679c\u4f60\u6709\u4e00\u4e2a AUC \u975e\u5e38\u9ad8\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u548c\u4e00\u4e2a AUC \u7a0d\u4f4e\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u4f60\u53ef\u4ee5\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0c\u968f\u673a\u68ee\u6797\u6a21\u578b\u5360 70%\uff0c\u903b\u8f91\u56de\u5f52\u6a21\u578b\u5360 30%\u3002\u90a3\u4e48\uff0c\u6211\u662f\u5982\u4f55\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u7684\u5462\uff1f\u8ba9\u6211\u4eec\u518d\u6dfb\u52a0\u4e00\u4e2a\u6a21\u578b\uff0c\u5047\u8bbe\u73b0\u5728\u6211\u4eec\u4e5f\u6709\u4e00\u4e2a xgboost \u6a21\u578b\uff0c\u5b83\u7684 AUC 
\u6bd4\u968f\u673a\u68ee\u6797\u9ad8\u3002\u73b0\u5728\uff0c\u6211\u5c06\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0cxgboost\uff1a\u968f\u673a\u68ee\u6797\uff1a\u903b\u8f91\u56de\u5f52\u7684\u6bd4\u4f8b\u4e3a 3:2:1\u3002\u5f88\u7b80\u5355\u5427\uff1f\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u6613\u5982\u53cd\u638c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u662f\u5982\u4f55\u505a\u5230\u7684\u3002 \u5047\u5b9a\u6211\u4eec\u6709\u4e09\u53ea\u7334\u5b50\uff0c\u4e09\u53ea\u65cb\u94ae\u7684\u6570\u503c\u5728 0 \u548c 1 \u4e4b\u95f4\u3002\u8fd9\u4e9b\u7334\u5b50\u8f6c\u52a8\u65cb\u94ae\uff0c\u6211\u4eec\u8ba1\u7b97\u5b83\u4eec\u6bcf\u8f6c\u5230\u4e00\u4e2a\u6570\u503c\u65f6\u7684 AUC \u5206\u6570\u3002\u6700\u7ec8\uff0c\u7334\u5b50\u4eec\u4f1a\u627e\u5230\u4e00\u4e2a\u80fd\u7ed9\u51fa\u6700\u4f73 AUC \u7684\u7ec4\u5408\u3002\u6ca1\u9519\uff0c\u8fd9\u5c31\u662f\u968f\u673a\u641c\u7d22\uff01\u5728\u8fdb\u884c\u8fd9\u7c7b\u641c\u7d22\u4e4b\u524d\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\u4e24\u4e2a\u6700\u91cd\u8981\u7684\u7ec4\u5408\u89c4\u5219\u3002 \u7ec4\u5408\u7684\u7b2c\u4e00\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u7ec4\u5408\u7684\u7b2c\u4e8c\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u662f\u7684\u3002\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u6761\u89c4\u5219\u3002\u7b2c\u4e00\u6b65\u662f\u521b\u5efa\u6298\u53e0\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u5047\u8bbe\u6211\u4eec\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\uff1a\u6298\u53e0 1 \u548c\u6298\u53e0 2\u3002\u8bf7\u6ce8\u610f\uff0c\u8fd9\u6837\u505a\u53ea\u662f\u4e3a\u4e86\u7b80\u5316\u89e3\u91ca\u3002\u5728\u5b9e\u9645\u5e94\u7528\u4e2d\uff0c\u60a8\u5e94\u8be5\u521b\u5efa\u66f4\u591a\u7684\u6298\u53e0\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5728\u6298\u53e0 1 
\u4e0a\u8bad\u7ec3\u968f\u673a\u68ee\u6797\u6a21\u578b\u3001\u903b\u8f91\u56de\u5f52\u6a21\u578b\u548c xgboost \u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 2 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u4e4b\u540e\uff0c\u6211\u4eec\u5728\u6298\u53e0 2 \u4e0a\u4ece\u5934\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 1 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u4e3a\u6240\u6709\u8bad\u7ec3\u6570\u636e\u521b\u5efa\u4e86\u9884\u6d4b\u7ed3\u679c\u3002\u73b0\u5728\uff0c\u4e3a\u4e86\u5408\u5e76\u8fd9\u4e9b\u6a21\u578b\uff0c\u6211\u4eec\u5c06\u6298\u53e0 1 \u548c\u6298\u53e0 1 \u7684\u6240\u6709\u9884\u6d4b\u6570\u636e\u5408\u5e76\u5728\u4e00\u8d77\uff0c\u7136\u540e\u521b\u5efa\u4e00\u4e2a\u4f18\u5316\u51fd\u6570\uff0c\u8bd5\u56fe\u627e\u5230\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4fbf\u9488\u5bf9\u6298\u53e0 2 \u7684\u76ee\u6807\u6700\u5c0f\u5316\u8bef\u5dee\u6216\u6700\u5927\u5316 AUC\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u662f\u7528\u4e09\u4e2a\u6a21\u578b\u7684\u9884\u6d4b\u6982\u7387\u5728\u6298\u53e0 1 \u4e0a\u8bad\u7ec3\u4e00\u4e2a\u4f18\u5316\u6a21\u578b\uff0c\u7136\u540e\u5728\u6298\u53e0 2 \u4e0a\u5bf9\u5176\u8fdb\u884c\u8bc4\u4f30\u3002\u8ba9\u6211\u4eec\u5148\u6765\u770b\u770b\u6211\u4eec\u53ef\u4ee5\u7528\u6765\u627e\u5230\u591a\u4e2a\u6a21\u578b\u7684\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4f18\u5316 AUC\uff08\u6216\u4efb\u4f55\u7c7b\u578b\u7684\u9884\u6d4b\u6307\u6807\u7ec4\u5408\uff09\u7684\u7c7b\u3002 import numpy as np from functools import partial from scipy.optimize import fmin from sklearn import metrics class OptimizeAUC : def __init__ ( self ): # \u521d\u59cb\u5316\u7cfb\u6570 self . coef_ = 0 def _auc ( self , coef , X , y ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u7cfb\u6570 x_coef = X * coef # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8ba1\u7b97AUC\u5206\u6570 auc_score = metrics . 
roc_auc_score ( y , predictions ) # \u8fd4\u56de\u8d1fAUC\u4ee5\u4fbf\u6700\u5c0f\u5316 return - 1.0 * auc_score def fit ( self , X , y ): # \u521b\u5efa\u5e26\u6709\u90e8\u5206\u53c2\u6570\u7684\u76ee\u6807\u51fd\u6570 loss_partial = partial ( self . _auc , X = X , y = y ) # \u521d\u59cb\u5316\u7cfb\u6570 initial_coef = np . random . dirichlet ( np . ones ( X . shape [ 1 ]), size = 1 ) # \u4f7f\u7528fmin\u51fd\u6570\u4f18\u5316AUC\u76ee\u6807\u51fd\u6570\uff0c\u627e\u5230\u6700\u4f18\u7cfb\u6570 self . coef_ = fmin ( loss_partial , initial_coef , disp = True ) def predict ( self , X ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u8bad\u7ec3\u597d\u7684\u7cfb\u6570 x_coef = X * self . coef_ # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8fd4\u56de\u9884\u6d4b\u7ed3\u679c return predictions \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528\u5b83\uff0c\u5e76\u5c06\u5176\u4e0e\u7b80\u5355\u5e73\u5747\u6cd5\u8fdb\u884c\u6bd4\u8f83\u3002 import xgboost as xgb from sklearn.datasets import make_classification from sklearn import ensemble from sklearn import linear_model from sklearn import metrics from sklearn import model_selection # \u751f\u6210\u4e00\u4e2a\u5206\u7c7b\u6570\u636e\u96c6 X , y = make_classification ( n_samples = 10000 , n_features = 25 ) # \u5212\u5206\u6570\u636e\u96c6\u4e3a\u4e24\u4e2a\u4ea4\u53c9\u9a8c\u8bc1\u6298\u53e0 xfold1 , xfold2 , yfold1 , yfold2 = model_selection . train_test_split ( X , y , test_size = 0.5 , stratify = y ) # \u521d\u59cb\u5316\u4e09\u4e2a\u4e0d\u540c\u7684\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold1 , yfold1 ) rf . fit ( xfold1 , yfold1 ) xgbc . 
fit ( xfold1 , yfold1 ) # \u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold2 )[:, 1 ] pred_rf = rf . predict_proba ( xfold2 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold2 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold2_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold2 = [] for i in range ( fold2_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold2 , fold2_preds [:, i ]) aucs_fold2 . append ( auc ) print ( f \"Fold-2: LR AUC = { aucs_fold2 [ 0 ] } \" ) print ( f \"Fold-2: RF AUC = { aucs_fold2 [ 1 ] } \" ) print ( f \"Fold-2: XGB AUC = { aucs_fold2 [ 2 ] } \" ) print ( f \"Fold-2: Average Pred AUC = { aucs_fold2 [ 3 ] } \" ) # \u91cd\u65b0\u521d\u59cb\u5316\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold2 , yfold2 ) rf . fit ( xfold2 , yfold2 ) xgbc . fit ( xfold2 , yfold2 ) # \u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold1 )[:, 1 ] pred_rf = rf . predict_proba ( xfold1 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold1 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold1_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold1 = [] for i in range ( fold1_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold1 , fold1_preds [:, i ]) aucs_fold1 . 
append ( auc ) print ( f \"Fold-1: LR AUC = { aucs_fold1 [ 0 ] } \" ) print ( f \"Fold-1: RF AUC = { aucs_fold1 [ 1 ] } \" ) print ( f \"Fold-1: XGB AUC = { aucs_fold1 [ 2 ] } \" ) print ( f \"Fold-1: Average prediction AUC = { aucs_fold1 [ 3 ] } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765\u8bad\u7ec3\u4f18\u5316\u5668 opt . fit ( fold1_preds [:, : - 1 ], yfold1 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold2 = opt . predict ( fold2_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold2 , opt_preds_fold2 ) print ( f \"Optimized AUC, Fold 2 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765 opt . fit ( fold2_preds [:, : - 1 ], yfold2 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold1 = opt . predict ( fold1_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold1 , opt_preds_fold1 ) print ( f \"Optimized AUC, Fold 1 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) \u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u8f93\u51fa\uff1a \u276f python auc_opt . py Fold - 2 : LR AUC = 0.9145446769443348 Fold - 2 : RF AUC = 0.9269918948683287 Fold - 2 : XGB AUC = 0.9302436595508696 Fold - 2 : Average Pred AUC = 0.927701495890154 Fold - 1 : LR AUC = 0.9050872233256017 Fold - 1 : RF AUC = 0.9179382818311258 Fold - 1 : XGB AUC = 0.9195837242005629 Fold - 1 : Average prediction AUC = 0.9189669233123695 Optimization terminated successfully . 
Current function value : - 0.920643 Iterations : 50 Function evaluations : 109 Optimized AUC , Fold 2 = 0.9305386199756128 Coefficients = [ - 0.00188194 0.19328336 0.35891836 ] Optimization terminated successfully . Current function value : - 0.931232 Iterations : 56 Function evaluations : 113 Optimized AUC , Fold 1 = 0.9192523637234037 Coefficients = [ - 0.15655124 0.22393151 0.58711366 ] \u6211\u4eec\u770b\u5230\uff0c\u5e73\u5747\u503c\u66f4\u597d\uff0c\u4f46\u4f7f\u7528\u4f18\u5316\u5668\u627e\u5230\u9608\u503c\u66f4\u597d\uff01\u6709\u65f6\uff0c\u5e73\u5747\u503c\u662f\u6700\u597d\u7684\u9009\u62e9\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u7cfb\u6570\u52a0\u8d77\u6765\u5e76\u6ca1\u6709\u8fbe\u5230 1.0\uff0c\u4f46\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u6211\u4eec\u8981\u5904\u7406\u7684\u662f AUC\uff0c\u800c AUC \u53ea\u5173\u5fc3\u7b49\u7ea7\u3002 \u5373\u4f7f\u968f\u673a\u68ee\u6797\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u968f\u673a\u68ee\u6797\u53ea\u662f\u8bb8\u591a\u7b80\u5355\u51b3\u7b56\u6811\u7684\u7ec4\u5408\u3002\u968f\u673a\u68ee\u6797\u5c5e\u4e8e\u96c6\u5408\u6a21\u578b\u7684\u4e00\u79cd\uff0c\u4e5f\u5c31\u662f\u4fd7\u79f0\u7684 \"bagging\" \u3002\u5728\u888b\u96c6\u6a21\u578b\u4e2d\uff0c\u6211\u4eec\u521b\u5efa\u5c0f\u6570\u636e\u5b50\u96c6\u5e76\u8bad\u7ec3\u591a\u4e2a\u7b80\u5355\u6a21\u578b\u3002\u6700\u7ec8\u7ed3\u679c\u7531\u6240\u6709\u8fd9\u4e9b\u5c0f\u6a21\u578b\u7684\u9884\u6d4b\u7ed3\u679c\uff08\u5982\u5e73\u5747\u503c\uff09\u7ec4\u5408\u800c\u6210\u3002 \u6211\u4eec\u4f7f\u7528\u7684 xgboost \u6a21\u578b\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u6240\u6709\u68af\u5ea6\u63d0\u5347\u6a21\u578b\u90fd\u662f\u96c6\u5408\u6a21\u578b\uff0c\u7edf\u79f0\u4e3a \u63d0\u5347\u6a21\u578b\uff08boosting models\uff09 
\u3002\u63d0\u5347\u6a21\u578b\u7684\u5de5\u4f5c\u539f\u7406\u4e0e\u88c5\u888b\u6a21\u578b\u7c7b\u4f3c\uff0c\u4e0d\u540c\u4e4b\u5904\u5728\u4e8e\u63d0\u5347\u6a21\u578b\u4e2d\u7684\u8fde\u7eed\u6a21\u578b\u662f\u6839\u636e\u8bef\u5dee\u6b8b\u5dee\u8bad\u7ec3\u7684\uff0c\u5e76\u503e\u5411\u4e8e\u6700\u5c0f\u5316\u524d\u9762\u6a21\u578b\u7684\u8bef\u5dee\u3002\u8fd9\u6837\uff0c\u63d0\u5347\u6a21\u578b\u5c31\u80fd\u5b8c\u7f8e\u5730\u5b66\u4e60\u6570\u636e\uff0c\u56e0\u6b64\u5bb9\u6613\u51fa\u73b0\u8fc7\u62df\u5408\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u770b\u5230\u7684\u4ee3\u7801\u7247\u6bb5\u53ea\u8003\u8651\u4e86\u4e00\u5217\u3002\u4f46\u60c5\u51b5\u5e76\u975e\u603b\u662f\u5982\u6b64\uff0c\u5f88\u591a\u65f6\u5019\u60a8\u9700\u8981\u5904\u7406\u591a\u5217\u9884\u6d4b\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u4f1a\u9047\u5230\u4ece\u591a\u4e2a\u7c7b\u522b\u4e2d\u9884\u6d4b\u4e00\u4e2a\u7c7b\u522b\u7684\u95ee\u9898\uff0c\u5373\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5bf9\u4e8e\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\uff0c\u4f60\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u9009\u62e9\u6295\u7968\u65b9\u6cd5\u3002\u4f46\u6295\u7968\u6cd5\u5e76\u4e0d\u603b\u662f\u6700\u4f73\u65b9\u6cd5\u3002\u5982\u679c\u8981\u7ec4\u5408\u6982\u7387\uff0c\u5c31\u4f1a\u6709\u4e00\u4e2a\u4e8c\u7ef4\u6570\u7ec4\uff0c\u800c\u4e0d\u662f\u50cf\u6211\u4eec\u4e4b\u524d\u4f18\u5316 AUC \u65f6\u7684\u5411\u91cf\u3002\u5982\u679c\u6709\u591a\u4e2a\u7c7b\u522b\uff0c\u53ef\u4ee5\u5c1d\u8bd5\u4f18\u5316\u5bf9\u6570\u635f\u5931\uff08\u6216\u5176\u4ed6\u4e0e\u4e1a\u52a1\u76f8\u5173\u7684\u6307\u6807\uff09\u3002 \u8981\u8fdb\u884c\u7ec4\u5408\uff0c\u53ef\u4ee5\u5728\u62df\u5408\u51fd\u6570 (X) \u4e2d\u4f7f\u7528 numpy \u6570\u7ec4\u5217\u8868\u800c\u4e0d\u662f numpy \u6570\u7ec4\uff0c\u968f\u540e\u8fd8\u9700\u8981\u66f4\u6539\u4f18\u5316\u5668\u548c\u9884\u6d4b\u51fd\u6570\u3002\u6211\u5c31\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u5927\u5bb6\u5427\u3002 
\u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8fdb\u5165\u4e0b\u4e00\u4e2a\u6709\u8da3\u7684\u8bdd\u9898\uff0c\u8fd9\u4e2a\u8bdd\u9898\u76f8\u5f53\u6d41\u884c\uff0c\u88ab\u79f0\u4e3a \u5806\u53e0 \u3002\u56fe 2 \u5c55\u793a\u4e86\u5982\u4f55\u5806\u53e0\u6a21\u578b\u3002 \u56fe2 : Stacking \u5806\u53e0\u4e0d\u50cf\u5236\u9020\u706b\u7bad\u3002\u5b83\u7b80\u5355\u660e\u4e86\u3002\u5982\u679c\u60a8\u8fdb\u884c\u4e86\u6b63\u786e\u7684\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5e76\u5728\u6574\u4e2a\u5efa\u6a21\u8fc7\u7a0b\u4e2d\u4fdd\u6301\u6298\u53e0\u4e0d\u53d8\uff0c\u90a3\u4e48\u5c31\u4e0d\u4f1a\u51fa\u73b0\u4efb\u4f55\u8fc7\u5ea6\u8d34\u5408\u7684\u60c5\u51b5\u3002 \u8ba9\u6211\u7528\u7b80\u5355\u7684\u8981\u70b9\u5411\u4f60\u63cf\u8ff0\u4e00\u4e0b\u8fd9\u4e2a\u60f3\u6cd5\u3002 - \u5c06\u8bad\u7ec3\u6570\u636e\u5206\u6210\u82e5\u5e72\u6298\u53e0\u3002 - \u8bad\u7ec3\u4e00\u5806\u6a21\u578b\uff1a M1\u3001M2.....Mn\u3002 - \u521b\u5efa\u5b8c\u6574\u7684\u8bad\u7ec3\u9884\u6d4b\uff08\u4f7f\u7528\u975e\u6298\u53e0\u8bad\u7ec3\uff09\uff0c\u5e76\u4f7f\u7528\u6240\u6709\u8fd9\u4e9b\u6a21\u578b\u8fdb\u884c\u6d4b\u8bd5\u9884\u6d4b\u3002 - \u76f4\u5230\u8fd9\u91cc\u662f\u7b2c 1 \u5c42 (L1)\u3002 - \u5c06\u8fd9\u4e9b\u6a21\u578b\u7684\u6298\u53e0\u9884\u6d4b\u4f5c\u4e3a\u53e6\u4e00\u4e2a\u6a21\u578b\u7684\u7279\u5f81\u3002\u8fd9\u5c31\u662f\u4e8c\u7ea7\u6a21\u578b\uff08L2\uff09\u3002 - \u4f7f\u7528\u4e0e\u4e4b\u524d\u76f8\u540c\u7684\u6298\u53e0\u6765\u8bad\u7ec3\u8fd9\u4e2a L2 \u6a21\u578b\u3002 - \u73b0\u5728\uff0c\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u521b\u5efa OOF\uff08\u6298\u53e0\u5916\uff09\u9884\u6d4b\u3002 - \u73b0\u5728\u60a8\u5c31\u6709\u4e86\u8bad\u7ec3\u6570\u636e\u7684 L2 \u9884\u6d4b\u548c\u6700\u7ec8\u6d4b\u8bd5\u96c6\u9884\u6d4b\u3002 \u60a8\u53ef\u4ee5\u4e0d\u65ad\u91cd\u590d L1 \u90e8\u5206\uff0c\u4e5f\u53ef\u4ee5\u521b\u5efa\u4efb\u610f\u591a\u7684\u5c42\u6b21\u3002 
\u6709\u65f6\uff0c\u4f60\u8fd8\u4f1a\u9047\u5230\u4e00\u4e2a\u53eb\u6df7\u5408\u7684\u672f\u8bed blending \u3002\u5982\u679c\u4f60\u9047\u5230\u4e86\uff0c\u4e0d\u7528\u592a\u62c5\u5fc3\u3002\u5b83\u53ea\u4e0d\u8fc7\u662f\u7528\u4e00\u4e2a\u4fdd\u7559\u7ec4\u6765\u5806\u53e0\uff0c\u800c\u4e0d\u662f\u591a\u91cd\u6298\u53e0\u3002\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u6211\u5728\u672c\u7ae0\u4e2d\u6240\u63cf\u8ff0\u7684\u5185\u5bb9\u53ef\u4ee5\u5e94\u7528\u4e8e\u4efb\u4f55\u7c7b\u578b\u7684\u95ee\u9898\uff1a\u5206\u7c7b\u3001\u56de\u5f52\u3001\u591a\u6807\u7b7e\u5206\u7c7b\u7b49\u3002","title":"\u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5"},{"location":"%E7%BB%84%E5%90%88%E5%92%8C%E5%A0%86%E5%8F%A0%E6%96%B9%E6%B3%95/#_1","text":"\u542c\u5230\u4e0a\u9762\u4e24\u4e2a\u8bcd\uff0c\u6211\u4eec\u9996\u5148\u60f3\u5230\u7684\u5c31\u662f\u5728\u7ebf\uff08online\uff09/\u79bb\u7ebf\uff08offline\uff09\u673a\u5668\u5b66\u4e60\u7ade\u8d5b\u3002\u51e0\u5e74\u524d\u662f\u8fd9\u6837\uff0c\u4f46\u73b0\u5728\u968f\u7740\u8ba1\u7b97\u80fd\u529b\u7684\u8fdb\u6b65\u548c\u865a\u62df\u5b9e\u4f8b\u7684\u5ec9\u4ef7\uff0c\u4eba\u4eec\u751a\u81f3\u5f00\u59cb\u5728\u884c\u4e1a\u4e2d\u4f7f\u7528\u7ec4\u5408\u6a21\u578b\uff08ensemble models\uff09\u3002\u4f8b\u5982\uff0c\u90e8\u7f72\u591a\u4e2a\u795e\u7ecf\u7f51\u7edc\u5e76\u5b9e\u65f6\u4e3a\u5b83\u4eec\u63d0\u4f9b\u670d\u52a1\u975e\u5e38\u5bb9\u6613\uff0c\u54cd\u5e94\u65f6\u95f4\u5c0f\u4e8e 500 
\u6beb\u79d2\u3002\u6709\u65f6\uff0c\u4e00\u4e2a\u5e9e\u5927\u7684\u795e\u7ecf\u7f51\u7edc\u6216\u5927\u578b\u6a21\u578b\u4e5f\u53ef\u4ee5\u88ab\u5176\u4ed6\u51e0\u4e2a\u6a21\u578b\u53d6\u4ee3\uff0c\u8fd9\u4e9b\u6a21\u578b\u4f53\u79ef\u5c0f\uff0c\u6027\u80fd\u4e0e\u5927\u578b\u6a21\u578b\u76f8\u4f3c\uff0c\u901f\u5ea6\u5374\u5feb\u4e00\u500d\u3002\u5982\u679c\u662f\u8fd9\u79cd\u60c5\u51b5\uff0c\u4f60\u4f1a\u9009\u62e9\u54ea\u4e2a\uff08\u4e9b\uff09\u6a21\u578b\u5462\uff1f\u6211\u4e2a\u4eba\u66f4\u503e\u5411\u4e8e\u9009\u62e9\u591a\u4e2a\u5c0f\u673a\u578b\uff0c\u5b83\u4eec\u901f\u5ea6\u66f4\u5feb\uff0c\u6027\u80fd\u4e0e\u5927\u673a\u578b\u548c\u6162\u673a\u578b\u76f8\u540c\u3002\u8bf7\u8bb0\u4f4f\uff0c\u8f83\u5c0f\u7684\u578b\u53f7\u4e5f\u66f4\u5bb9\u6613\u548c\u66f4\u5feb\u5730\u8fdb\u884c\u8c03\u6574\u3002 \u7ec4\u5408\uff08ensembling\uff09\u4e0d\u8fc7\u662f\u4e0d\u540c\u6a21\u578b\u7684\u7ec4\u5408\u3002\u6a21\u578b\u53ef\u4ee5\u901a\u8fc7\u9884\u6d4b/\u6982\u7387\u8fdb\u884c\u7ec4\u5408\u3002\u7ec4\u5408\u6a21\u578b\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6c42\u5e73\u5747\u503c\u3002 $$ Ensemble Probabilities = (M1_proba + M2_proba + ... 
+ Mn_Proba)/n $$ \u8fd9\u662f\u6700\u7b80\u5355\u4e5f\u662f\u6700\u6709\u6548\u7684\u7ec4\u5408\u6a21\u578b\u7684\u65b9\u6cd5\u3002\u5728\u7b80\u5355\u5e73\u5747\u6cd5\u4e2d\uff0c\u6240\u6709\u6a21\u578b\u7684\u6743\u91cd\u90fd\u662f\u76f8\u7b49\u7684\u3002\u65e0\u8bba\u91c7\u7528\u54ea\u79cd\u7ec4\u5408\u65b9\u6cd5\uff0c\u60a8\u90fd\u5e94\u8be5\u7262\u8bb0\u4e00\u70b9\uff0c\u90a3\u5c31\u662f\u60a8\u5e94\u8be5\u59cb\u7ec8\u5c06\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b/\u6982\u7387\u7ec4\u5408\u5728\u4e00\u8d77\u3002\u7b80\u5355\u5730\u8bf4\uff0c\u7ec4\u5408\u76f8\u5173\u6027\u4e0d\u9ad8\u7684\u6a21\u578b\u6bd4\u7ec4\u5408\u76f8\u5173\u6027\u5f88\u9ad8\u7684\u6a21\u578b\u6548\u679c\u66f4\u597d\u3002 \u5982\u679c\u6ca1\u6709\u6982\u7387\uff0c\u4e5f\u53ef\u4ee5\u7ec4\u5408\u9884\u6d4b\u3002\u6700\u7b80\u5355\u7684\u65b9\u6cd5\u5c31\u662f\u6295\u7968\u3002\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u8fdb\u884c\u591a\u7c7b\u5206\u7c7b\uff0c\u6709\u4e09\u4e2a\u7c7b\u522b\uff1a 0\u30011 \u548c 2\u3002 [0, 0, 1] : \u6700\u9ad8\u7968\u6570\uff1a 0 [0, 1, 2] : \u6700\u9ad8\u7968\u7ea7\uff1a \u65e0\uff08\u968f\u673a\u9009\u62e9\u4e00\u4e2a\uff09 [2, 2, 2] : \u6700\u9ad8\u7968\u6570\uff1a 2 \u4ee5\u4e0b\u7b80\u5355\u51fd\u6570\u53ef\u4ee5\u5b8c\u6210\u8fd9\u4e9b\u7b80\u5355\u64cd\u4f5c\u3002 import numpy as np def mean_predictions ( probas ): # \u8ba1\u7b97\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u6bcf\u884c\u5e73\u5747\u503c return np . mean ( probas , axis = 1 ) def max_voting ( preds ): # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u67e5\u627e\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u7684\u7d22\u5f15 idxs = np . argmax ( preds , axis = 1 ) # \u6839\u636e\u7d22\u5f15\u53d6\u51fa\u6bcf\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u5143\u7d20 return np . 
take_along_axis ( preds , idxs [:, None ], axis = 1 ) \u8bf7\u6ce8\u610f\uff0cprobas \u7684\u6bcf\u4e00\u5217\u90fd\u53ea\u6709\u4e00\u4e2a\u6982\u7387\uff08\u5373\u4e8c\u5143\u5206\u7c7b\uff0c\u901a\u5e38\u4e3a\u7c7b\u522b 1\uff09\u3002\u56e0\u6b64\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u4e00\u4e2a\u65b0\u6a21\u578b\u3002\u540c\u6837\uff0c\u5bf9\u4e8e preds\uff0c\u6bcf\u4e00\u5217\u90fd\u662f\u6765\u81ea\u4e0d\u540c\u6a21\u578b\u7684\u9884\u6d4b\u503c\u3002\u8fd9\u4e24\u4e2a\u51fd\u6570\u90fd\u5047\u8bbe\u4e86\u4e00\u4e2a 2 \u7ef4 numpy \u6570\u7ec4\u3002\u60a8\u53ef\u4ee5\u6839\u636e\u81ea\u5df1\u7684\u9700\u6c42\u5bf9\u5176\u8fdb\u884c\u4fee\u6539\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u6709\u4e00\u4e2a 2 \u7ef4\u6570\u7ec4\uff0c\u5176\u4e2d\u5305\u542b\u6bcf\u4e2a\u6a21\u578b\u7684\u6982\u7387\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u51fd\u6570\u4f1a\u6709\u4e00\u4e9b\u53d8\u5316\u3002 \u53e6\u4e00\u79cd\u7ec4\u5408\u591a\u4e2a\u6a21\u578b\u7684\u65b9\u6cd5\u662f\u901a\u8fc7\u5b83\u4eec\u7684 \u6982\u7387\u6392\u5e8f \u3002\u5f53\u76f8\u5173\u6307\u6807\u662f\u66f2\u7ebf\u4e0b\u9762\u79ef\uff08AUC\uff09\u65f6\uff0c\u8fd9\u79cd\u7ec4\u5408\u65b9\u5f0f\u975e\u5e38\u6709\u6548\uff0c\u56e0\u4e3a AUC \u5c31\u662f\u5bf9\u6837\u672c\u8fdb\u884c\u6392\u5e8f\u3002 def rank_mean ( probas ): # \u521b\u5efa\u7a7a\u5217\u8868ranked\u5b58\u50a8\u6bcf\u4e2a\u7c7b\u522b\u6982\u7387\u503c\u6392\u540d ranked = [] # \u904d\u5386\u6982\u7387\u503c\u6bcf\u4e00\u5217\uff08\u6bcf\u4e2a\u7c7b\u522b\u7684\u6982\u7387\u503c\uff09 for i in range ( probas . shape [ 1 ]): # \u5f53\u524d\u5217\u6982\u7387\u503c\u6392\u540d\uff0crank_data\u662f\u6392\u540d\u7ed3\u679c rank_data = stats . rankdata ( probas [:, i ]) # \u5c06\u5f53\u524d\u5217\u6392\u540d\u7ed3\u679c\u6dfb\u52a0\u5230ranked\u5217\u8868\u4e2d ranked . append ( rank_data ) # \u5c06ranked\u5217\u8868\u4e2d\u6392\u540d\u7ed3\u679c\u6309\u5217\u5806\u53e0\uff0c\u5f62\u6210\u4e8c\u7ef4\u6570\u7ec4 ranked = np . 
column_stack ( ranked ) # \u6cbf\u7740\u7b2c\u4e8c\u4e2a\u7ef4\u5ea6\uff08\u5217\uff09\u8ba1\u7b97\u6837\u672c\u6392\u540d\u5e73\u5747\u503c return np . mean ( ranked , axis = 1 ) \u8bf7\u6ce8\u610f\uff0c\u5728 scipy \u7684 rankdata \u4e2d\uff0c\u7b49\u7ea7\u4ece 1 \u5f00\u59cb\u3002 \u4e3a\u4ec0\u4e48\u8fd9\u7c7b\u96c6\u5408\u6709\u6548\uff1f\u8ba9\u6211\u4eec\u770b\u770b\u56fe 1\u3002 \u56fe 1\uff1a\u4e09\u4eba\u731c\u5927\u8c61\u7684\u8eab\u9ad8 \u56fe 1 \u663e\u793a\uff0c\u5982\u679c\u6709\u4e09\u4e2a\u4eba\u5728\u731c\u5927\u8c61\u7684\u9ad8\u5ea6\uff0c\u90a3\u4e48\u539f\u59cb\u9ad8\u5ea6\u5c06\u975e\u5e38\u63a5\u8fd1\u4e09\u4e2a\u4eba\u731c\u6d4b\u7684\u5e73\u5747\u503c\u3002\u6211\u4eec\u5047\u8bbe\u8fd9\u4e9b\u4eba\u90fd\u80fd\u731c\u5230\u975e\u5e38\u63a5\u8fd1\u5927\u8c61\u539f\u6765\u7684\u9ad8\u5ea6\u3002\u63a5\u8fd1\u4f30\u8ba1\u503c\u610f\u5473\u7740\u8bef\u5dee\uff0c\u4f46\u5982\u679c\u6211\u4eec\u5c06\u4e09\u4e2a\u9884\u6d4b\u503c\u5e73\u5747\uff0c\u5c31\u80fd\u5c06\u8bef\u5dee\u964d\u5230\u6700\u4f4e\u3002\u8fd9\u5c31\u662f\u591a\u4e2a\u6a21\u578b\u5e73\u5747\u7684\u4e3b\u8981\u601d\u60f3\u3002 $$ Final\\ Probabilities = w_1 \\times M1_proba + w_2 \\times M2_proba + \\cdots + w_n \\times Mn_proba $$ \u5176\u4e2d \\((w_1 + w_2 + w_3 + \\cdots + w_n)=1.0\\) \u4f8b\u5982\uff0c\u5982\u679c\u4f60\u6709\u4e00\u4e2a AUC \u975e\u5e38\u9ad8\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u548c\u4e00\u4e2a AUC \u7a0d\u4f4e\u7684\u903b\u8f91\u56de\u5f52\u6a21\u578b\uff0c\u4f60\u53ef\u4ee5\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0c\u968f\u673a\u68ee\u6797\u6a21\u578b\u5360 70%\uff0c\u903b\u8f91\u56de\u5f52\u6a21\u578b\u5360 30%\u3002\u90a3\u4e48\uff0c\u6211\u662f\u5982\u4f55\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u7684\u5462\uff1f\u8ba9\u6211\u4eec\u518d\u6dfb\u52a0\u4e00\u4e2a\u6a21\u578b\uff0c\u5047\u8bbe\u73b0\u5728\u6211\u4eec\u4e5f\u6709\u4e00\u4e2a xgboost \u6a21\u578b\uff0c\u5b83\u7684 AUC 
\u6bd4\u968f\u673a\u68ee\u6797\u9ad8\u3002\u73b0\u5728\uff0c\u6211\u5c06\u628a\u5b83\u4eec\u7ed3\u5408\u8d77\u6765\uff0cxgboost\uff1a\u968f\u673a\u68ee\u6797\uff1a\u903b\u8f91\u56de\u5f52\u7684\u6bd4\u4f8b\u4e3a 3:2:1\u3002\u5f88\u7b80\u5355\u5427\uff1f\u5f97\u51fa\u8fd9\u4e9b\u6570\u5b57\u6613\u5982\u53cd\u638c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u662f\u5982\u4f55\u505a\u5230\u7684\u3002 \u5047\u5b9a\u6211\u4eec\u6709\u4e09\u53ea\u7334\u5b50\uff0c\u4e09\u53ea\u65cb\u94ae\u7684\u6570\u503c\u5728 0 \u548c 1 \u4e4b\u95f4\u3002\u8fd9\u4e9b\u7334\u5b50\u8f6c\u52a8\u65cb\u94ae\uff0c\u6211\u4eec\u8ba1\u7b97\u5b83\u4eec\u6bcf\u8f6c\u5230\u4e00\u4e2a\u6570\u503c\u65f6\u7684 AUC \u5206\u6570\u3002\u6700\u7ec8\uff0c\u7334\u5b50\u4eec\u4f1a\u627e\u5230\u4e00\u4e2a\u80fd\u7ed9\u51fa\u6700\u4f73 AUC \u7684\u7ec4\u5408\u3002\u6ca1\u9519\uff0c\u8fd9\u5c31\u662f\u968f\u673a\u641c\u7d22\uff01\u5728\u8fdb\u884c\u8fd9\u7c7b\u641c\u7d22\u4e4b\u524d\uff0c\u4f60\u5fc5\u987b\u8bb0\u4f4f\u4e24\u4e2a\u6700\u91cd\u8981\u7684\u7ec4\u5408\u89c4\u5219\u3002 \u7ec4\u5408\u7684\u7b2c\u4e00\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u7ec4\u5408\u7684\u7b2c\u4e8c\u6761\u89c4\u5219\u662f\uff0c\u5728\u5f00\u59cb\u5408\u594f\u4e4b\u524d\uff0c\u4e00\u5b9a\u8981\u5148\u521b\u5efa\u6298\u53e0\u3002 \u662f\u7684\u3002\u8fd9\u662f\u6700\u91cd\u8981\u7684\u4e24\u6761\u89c4\u5219\u3002\u7b2c\u4e00\u6b65\u662f\u521b\u5efa\u6298\u53e0\u3002\u4e3a\u4e86\u7b80\u5355\u8d77\u89c1\uff0c\u5047\u8bbe\u6211\u4eec\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\uff1a\u6298\u53e0 1 \u548c\u6298\u53e0 2\u3002\u8bf7\u6ce8\u610f\uff0c\u8fd9\u6837\u505a\u53ea\u662f\u4e3a\u4e86\u7b80\u5316\u89e3\u91ca\u3002\u5728\u5b9e\u9645\u5e94\u7528\u4e2d\uff0c\u60a8\u5e94\u8be5\u521b\u5efa\u66f4\u591a\u7684\u6298\u53e0\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u5728\u6298\u53e0 1 
\u4e0a\u8bad\u7ec3\u968f\u673a\u68ee\u6797\u6a21\u578b\u3001\u903b\u8f91\u56de\u5f52\u6a21\u578b\u548c xgboost \u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 2 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u4e4b\u540e\uff0c\u6211\u4eec\u5728\u6298\u53e0 2 \u4e0a\u4ece\u5934\u5f00\u59cb\u8bad\u7ec3\u6a21\u578b\uff0c\u5e76\u5728\u6298\u53e0 1 \u4e0a\u8fdb\u884c\u9884\u6d4b\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u5c31\u4e3a\u6240\u6709\u8bad\u7ec3\u6570\u636e\u521b\u5efa\u4e86\u9884\u6d4b\u7ed3\u679c\u3002\u73b0\u5728\uff0c\u4e3a\u4e86\u5408\u5e76\u8fd9\u4e9b\u6a21\u578b\uff0c\u6211\u4eec\u5c06\u6298\u53e0 1 \u548c\u6298\u53e0 1 \u7684\u6240\u6709\u9884\u6d4b\u6570\u636e\u5408\u5e76\u5728\u4e00\u8d77\uff0c\u7136\u540e\u521b\u5efa\u4e00\u4e2a\u4f18\u5316\u51fd\u6570\uff0c\u8bd5\u56fe\u627e\u5230\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4fbf\u9488\u5bf9\u6298\u53e0 2 \u7684\u76ee\u6807\u6700\u5c0f\u5316\u8bef\u5dee\u6216\u6700\u5927\u5316 AUC\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u662f\u7528\u4e09\u4e2a\u6a21\u578b\u7684\u9884\u6d4b\u6982\u7387\u5728\u6298\u53e0 1 \u4e0a\u8bad\u7ec3\u4e00\u4e2a\u4f18\u5316\u6a21\u578b\uff0c\u7136\u540e\u5728\u6298\u53e0 2 \u4e0a\u5bf9\u5176\u8fdb\u884c\u8bc4\u4f30\u3002\u8ba9\u6211\u4eec\u5148\u6765\u770b\u770b\u6211\u4eec\u53ef\u4ee5\u7528\u6765\u627e\u5230\u591a\u4e2a\u6a21\u578b\u7684\u6700\u4f73\u6743\u91cd\uff0c\u4ee5\u4f18\u5316 AUC\uff08\u6216\u4efb\u4f55\u7c7b\u578b\u7684\u9884\u6d4b\u6307\u6807\u7ec4\u5408\uff09\u7684\u7c7b\u3002 import numpy as np from functools import partial from scipy.optimize import fmin from sklearn import metrics class OptimizeAUC : def __init__ ( self ): # \u521d\u59cb\u5316\u7cfb\u6570 self . coef_ = 0 def _auc ( self , coef , X , y ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u7cfb\u6570 x_coef = X * coef # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8ba1\u7b97AUC\u5206\u6570 auc_score = metrics . 
roc_auc_score ( y , predictions ) # \u8fd4\u56de\u8d1fAUC\u4ee5\u4fbf\u6700\u5c0f\u5316 return - 1.0 * auc_score def fit ( self , X , y ): # \u521b\u5efa\u5e26\u6709\u90e8\u5206\u53c2\u6570\u7684\u76ee\u6807\u51fd\u6570 loss_partial = partial ( self . _auc , X = X , y = y ) # \u521d\u59cb\u5316\u7cfb\u6570 initial_coef = np . random . dirichlet ( np . ones ( X . shape [ 1 ]), size = 1 ) # \u4f7f\u7528fmin\u51fd\u6570\u4f18\u5316AUC\u76ee\u6807\u51fd\u6570\uff0c\u627e\u5230\u6700\u4f18\u7cfb\u6570 self . coef_ = fmin ( loss_partial , initial_coef , disp = True ) def predict ( self , X ): # \u5bf9\u8f93\u5165\u6570\u636e\u4e58\u4ee5\u8bad\u7ec3\u597d\u7684\u7cfb\u6570 x_coef = X * self . coef_ # \u8ba1\u7b97\u6bcf\u4e2a\u6837\u672c\u9884\u6d4b\u503c predictions = np . sum ( x_coef , axis = 1 ) # \u8fd4\u56de\u9884\u6d4b\u7ed3\u679c return predictions \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5982\u4f55\u4f7f\u7528\u5b83\uff0c\u5e76\u5c06\u5176\u4e0e\u7b80\u5355\u5e73\u5747\u6cd5\u8fdb\u884c\u6bd4\u8f83\u3002 import xgboost as xgb from sklearn.datasets import make_classification from sklearn import ensemble from sklearn import linear_model from sklearn import metrics from sklearn import model_selection # \u751f\u6210\u4e00\u4e2a\u5206\u7c7b\u6570\u636e\u96c6 X , y = make_classification ( n_samples = 10000 , n_features = 25 ) # \u5212\u5206\u6570\u636e\u96c6\u4e3a\u4e24\u4e2a\u4ea4\u53c9\u9a8c\u8bc1\u6298\u53e0 xfold1 , xfold2 , yfold1 , yfold2 = model_selection . train_test_split ( X , y , test_size = 0.5 , stratify = y ) # \u521d\u59cb\u5316\u4e09\u4e2a\u4e0d\u540c\u7684\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold1 , yfold1 ) rf . fit ( xfold1 , yfold1 ) xgbc . 
fit ( xfold1 , yfold1 ) # \u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold2 )[:, 1 ] pred_rf = rf . predict_proba ( xfold2 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold2 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold2_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold2 = [] for i in range ( fold2_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold2 , fold2_preds [:, i ]) aucs_fold2 . append ( auc ) print ( f \"Fold-2: LR AUC = { aucs_fold2 [ 0 ] } \" ) print ( f \"Fold-2: RF AUC = { aucs_fold2 [ 1 ] } \" ) print ( f \"Fold-2: XGB AUC = { aucs_fold2 [ 2 ] } \" ) print ( f \"Fold-2: Average Pred AUC = { aucs_fold2 [ 3 ] } \" ) # \u91cd\u65b0\u521d\u59cb\u5316\u5206\u7c7b\u5668 logreg = linear_model . LogisticRegression () rf = ensemble . RandomForestClassifier () xgbc = xgb . XGBClassifier () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8bad\u7ec3\u5206\u7c7b\u5668 logreg . fit ( xfold2 , yfold2 ) rf . fit ( xfold2 , yfold2 ) xgbc . fit ( xfold2 , yfold2 ) # \u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u8fdb\u884c\u9884\u6d4b pred_logreg = logreg . predict_proba ( xfold1 )[:, 1 ] pred_rf = rf . predict_proba ( xfold1 )[:, 1 ] pred_xgbc = xgbc . predict_proba ( xfold1 )[:, 1 ] # \u8ba1\u7b97\u5e73\u5747\u9884\u6d4b\u7ed3\u679c avg_pred = ( pred_logreg + pred_rf + pred_xgbc ) / 3 fold1_preds = np . column_stack (( pred_logreg , pred_rf , pred_xgbc , avg_pred )) # \u8ba1\u7b97\u6bcf\u4e2a\u6a21\u578b\u7684AUC\u5206\u6570\u5e76\u6253\u5370 aucs_fold1 = [] for i in range ( fold1_preds . shape [ 1 ]): auc = metrics . roc_auc_score ( yfold1 , fold1_preds [:, i ]) aucs_fold1 . 
append ( auc ) print ( f \"Fold-1: LR AUC = { aucs_fold1 [ 0 ] } \" ) print ( f \"Fold-1: RF AUC = { aucs_fold1 [ 1 ] } \" ) print ( f \"Fold-1: XGB AUC = { aucs_fold1 [ 2 ] } \" ) print ( f \"Fold-1: Average prediction AUC = { aucs_fold1 [ 3 ] } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765\u8bad\u7ec3\u4f18\u5316\u5668 opt . fit ( fold1_preds [:, : - 1 ], yfold1 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold2 = opt . predict ( fold2_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold2 , opt_preds_fold2 ) print ( f \"Optimized AUC, Fold 2 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) # \u521d\u59cb\u5316AUC\u4f18\u5316\u5668 opt = OptimizeAUC () # \u4f7f\u7528\u7b2c\u4e8c\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u6765 opt . fit ( fold2_preds [:, : - 1 ], yfold2 ) # \u4f7f\u7528\u4f18\u5316\u5668\u5bf9\u7b2c\u4e00\u4e2a\u6298\u53e0\u6570\u636e\u96c6\u7684\u9884\u6d4b\u7ed3\u679c\u8fdb\u884c\u4f18\u5316 opt_preds_fold1 = opt . predict ( fold1_preds [:, : - 1 ]) auc = metrics . roc_auc_score ( yfold1 , opt_preds_fold1 ) print ( f \"Optimized AUC, Fold 1 = { auc } \" ) print ( f \"Coefficients = { opt . coef_ } \" ) \u8ba9\u6211\u4eec\u770b\u4e00\u4e0b\u8f93\u51fa\uff1a \u276f python auc_opt . py Fold - 2 : LR AUC = 0.9145446769443348 Fold - 2 : RF AUC = 0.9269918948683287 Fold - 2 : XGB AUC = 0.9302436595508696 Fold - 2 : Average Pred AUC = 0.927701495890154 Fold - 1 : LR AUC = 0.9050872233256017 Fold - 1 : RF AUC = 0.9179382818311258 Fold - 1 : XGB AUC = 0.9195837242005629 Fold - 1 : Average prediction AUC = 0.9189669233123695 Optimization terminated successfully . 
Current function value : - 0.920643 Iterations : 50 Function evaluations : 109 Optimized AUC , Fold 2 = 0.9305386199756128 Coefficients = [ - 0.00188194 0.19328336 0.35891836 ] Optimization terminated successfully . Current function value : - 0.931232 Iterations : 56 Function evaluations : 113 Optimized AUC , Fold 1 = 0.9192523637234037 Coefficients = [ - 0.15655124 0.22393151 0.58711366 ] \u6211\u4eec\u770b\u5230\uff0c\u5e73\u5747\u503c\u66f4\u597d\uff0c\u4f46\u4f7f\u7528\u4f18\u5316\u5668\u627e\u5230\u9608\u503c\u66f4\u597d\uff01\u6709\u65f6\uff0c\u5e73\u5747\u503c\u662f\u6700\u597d\u7684\u9009\u62e9\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u7cfb\u6570\u52a0\u8d77\u6765\u5e76\u6ca1\u6709\u8fbe\u5230 1.0\uff0c\u4f46\u8fd9\u6ca1\u5173\u7cfb\uff0c\u56e0\u4e3a\u6211\u4eec\u8981\u5904\u7406\u7684\u662f AUC\uff0c\u800c AUC \u53ea\u5173\u5fc3\u7b49\u7ea7\u3002 \u5373\u4f7f\u968f\u673a\u68ee\u6797\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u968f\u673a\u68ee\u6797\u53ea\u662f\u8bb8\u591a\u7b80\u5355\u51b3\u7b56\u6811\u7684\u7ec4\u5408\u3002\u968f\u673a\u68ee\u6797\u5c5e\u4e8e\u96c6\u5408\u6a21\u578b\u7684\u4e00\u79cd\uff0c\u4e5f\u5c31\u662f\u4fd7\u79f0\u7684 \"bagging\" \u3002\u5728\u888b\u96c6\u6a21\u578b\u4e2d\uff0c\u6211\u4eec\u521b\u5efa\u5c0f\u6570\u636e\u5b50\u96c6\u5e76\u8bad\u7ec3\u591a\u4e2a\u7b80\u5355\u6a21\u578b\u3002\u6700\u7ec8\u7ed3\u679c\u7531\u6240\u6709\u8fd9\u4e9b\u5c0f\u6a21\u578b\u7684\u9884\u6d4b\u7ed3\u679c\uff08\u5982\u5e73\u5747\u503c\uff09\u7ec4\u5408\u800c\u6210\u3002 \u6211\u4eec\u4f7f\u7528\u7684 xgboost \u6a21\u578b\u4e5f\u662f\u4e00\u4e2a\u96c6\u5408\u6a21\u578b\u3002\u6240\u6709\u68af\u5ea6\u63d0\u5347\u6a21\u578b\u90fd\u662f\u96c6\u5408\u6a21\u578b\uff0c\u7edf\u79f0\u4e3a \u63d0\u5347\u6a21\u578b\uff08boosting models\uff09 
\u3002\u63d0\u5347\u6a21\u578b\u7684\u5de5\u4f5c\u539f\u7406\u4e0e\u88c5\u888b\u6a21\u578b\u7c7b\u4f3c\uff0c\u4e0d\u540c\u4e4b\u5904\u5728\u4e8e\u63d0\u5347\u6a21\u578b\u4e2d\u7684\u8fde\u7eed\u6a21\u578b\u662f\u6839\u636e\u8bef\u5dee\u6b8b\u5dee\u8bad\u7ec3\u7684\uff0c\u5e76\u503e\u5411\u4e8e\u6700\u5c0f\u5316\u524d\u9762\u6a21\u578b\u7684\u8bef\u5dee\u3002\u8fd9\u6837\uff0c\u63d0\u5347\u6a21\u578b\u5c31\u80fd\u5b8c\u7f8e\u5730\u5b66\u4e60\u6570\u636e\uff0c\u56e0\u6b64\u5bb9\u6613\u51fa\u73b0\u8fc7\u62df\u5408\u3002 \u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u770b\u5230\u7684\u4ee3\u7801\u7247\u6bb5\u53ea\u8003\u8651\u4e86\u4e00\u5217\u3002\u4f46\u60c5\u51b5\u5e76\u975e\u603b\u662f\u5982\u6b64\uff0c\u5f88\u591a\u65f6\u5019\u60a8\u9700\u8981\u5904\u7406\u591a\u5217\u9884\u6d4b\u3002\u4f8b\u5982\uff0c\u60a8\u53ef\u80fd\u4f1a\u9047\u5230\u4ece\u591a\u4e2a\u7c7b\u522b\u4e2d\u9884\u6d4b\u4e00\u4e2a\u7c7b\u522b\u7684\u95ee\u9898\uff0c\u5373\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5bf9\u4e8e\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\uff0c\u4f60\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u9009\u62e9\u6295\u7968\u65b9\u6cd5\u3002\u4f46\u6295\u7968\u6cd5\u5e76\u4e0d\u603b\u662f\u6700\u4f73\u65b9\u6cd5\u3002\u5982\u679c\u8981\u7ec4\u5408\u6982\u7387\uff0c\u5c31\u4f1a\u6709\u4e00\u4e2a\u4e8c\u7ef4\u6570\u7ec4\uff0c\u800c\u4e0d\u662f\u50cf\u6211\u4eec\u4e4b\u524d\u4f18\u5316 AUC \u65f6\u7684\u5411\u91cf\u3002\u5982\u679c\u6709\u591a\u4e2a\u7c7b\u522b\uff0c\u53ef\u4ee5\u5c1d\u8bd5\u4f18\u5316\u5bf9\u6570\u635f\u5931\uff08\u6216\u5176\u4ed6\u4e0e\u4e1a\u52a1\u76f8\u5173\u7684\u6307\u6807\uff09\u3002 \u8981\u8fdb\u884c\u7ec4\u5408\uff0c\u53ef\u4ee5\u5728\u62df\u5408\u51fd\u6570 (X) \u4e2d\u4f7f\u7528 numpy \u6570\u7ec4\u5217\u8868\u800c\u4e0d\u662f numpy \u6570\u7ec4\uff0c\u968f\u540e\u8fd8\u9700\u8981\u66f4\u6539\u4f18\u5316\u5668\u548c\u9884\u6d4b\u51fd\u6570\u3002\u6211\u5c31\u628a\u5b83\u4f5c\u4e3a\u4e00\u4e2a\u7ec3\u4e60\u7559\u7ed9\u5927\u5bb6\u5427\u3002 
\u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8fdb\u5165\u4e0b\u4e00\u4e2a\u6709\u8da3\u7684\u8bdd\u9898\uff0c\u8fd9\u4e2a\u8bdd\u9898\u76f8\u5f53\u6d41\u884c\uff0c\u88ab\u79f0\u4e3a \u5806\u53e0 \u3002\u56fe 2 \u5c55\u793a\u4e86\u5982\u4f55\u5806\u53e0\u6a21\u578b\u3002 \u56fe2 : Stacking \u5806\u53e0\u4e0d\u50cf\u5236\u9020\u706b\u7bad\u3002\u5b83\u7b80\u5355\u660e\u4e86\u3002\u5982\u679c\u60a8\u8fdb\u884c\u4e86\u6b63\u786e\u7684\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5e76\u5728\u6574\u4e2a\u5efa\u6a21\u8fc7\u7a0b\u4e2d\u4fdd\u6301\u6298\u53e0\u4e0d\u53d8\uff0c\u90a3\u4e48\u5c31\u4e0d\u4f1a\u51fa\u73b0\u4efb\u4f55\u8fc7\u5ea6\u8d34\u5408\u7684\u60c5\u51b5\u3002 \u8ba9\u6211\u7528\u7b80\u5355\u7684\u8981\u70b9\u5411\u4f60\u63cf\u8ff0\u4e00\u4e0b\u8fd9\u4e2a\u60f3\u6cd5\u3002 - \u5c06\u8bad\u7ec3\u6570\u636e\u5206\u6210\u82e5\u5e72\u6298\u53e0\u3002 - \u8bad\u7ec3\u4e00\u5806\u6a21\u578b\uff1a M1\u3001M2.....Mn\u3002 - \u521b\u5efa\u5b8c\u6574\u7684\u8bad\u7ec3\u9884\u6d4b\uff08\u4f7f\u7528\u975e\u6298\u53e0\u8bad\u7ec3\uff09\uff0c\u5e76\u4f7f\u7528\u6240\u6709\u8fd9\u4e9b\u6a21\u578b\u8fdb\u884c\u6d4b\u8bd5\u9884\u6d4b\u3002 - \u76f4\u5230\u8fd9\u91cc\u662f\u7b2c 1 \u5c42 (L1)\u3002 - \u5c06\u8fd9\u4e9b\u6a21\u578b\u7684\u6298\u53e0\u9884\u6d4b\u4f5c\u4e3a\u53e6\u4e00\u4e2a\u6a21\u578b\u7684\u7279\u5f81\u3002\u8fd9\u5c31\u662f\u4e8c\u7ea7\u6a21\u578b\uff08L2\uff09\u3002 - \u4f7f\u7528\u4e0e\u4e4b\u524d\u76f8\u540c\u7684\u6298\u53e0\u6765\u8bad\u7ec3\u8fd9\u4e2a L2 \u6a21\u578b\u3002 - \u73b0\u5728\uff0c\u5728\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4e0a\u521b\u5efa OOF\uff08\u6298\u53e0\u5916\uff09\u9884\u6d4b\u3002 - \u73b0\u5728\u60a8\u5c31\u6709\u4e86\u8bad\u7ec3\u6570\u636e\u7684 L2 \u9884\u6d4b\u548c\u6700\u7ec8\u6d4b\u8bd5\u96c6\u9884\u6d4b\u3002 \u60a8\u53ef\u4ee5\u4e0d\u65ad\u91cd\u590d L1 \u90e8\u5206\uff0c\u4e5f\u53ef\u4ee5\u521b\u5efa\u4efb\u610f\u591a\u7684\u5c42\u6b21\u3002 
\u6709\u65f6\uff0c\u4f60\u8fd8\u4f1a\u9047\u5230\u4e00\u4e2a\u53eb\u6df7\u5408\u7684\u672f\u8bed blending \u3002\u5982\u679c\u4f60\u9047\u5230\u4e86\uff0c\u4e0d\u7528\u592a\u62c5\u5fc3\u3002\u5b83\u53ea\u4e0d\u8fc7\u662f\u7528\u4e00\u4e2a\u4fdd\u7559\u7ec4\u6765\u5806\u53e0\uff0c\u800c\u4e0d\u662f\u591a\u91cd\u6298\u53e0\u3002\u5fc5\u987b\u6307\u51fa\u7684\u662f\uff0c\u6211\u5728\u672c\u7ae0\u4e2d\u6240\u63cf\u8ff0\u7684\u5185\u5bb9\u53ef\u4ee5\u5e94\u7528\u4e8e\u4efb\u4f55\u7c7b\u578b\u7684\u95ee\u9898\uff1a\u5206\u7c7b\u3001\u56de\u5f52\u3001\u591a\u6807\u7b7e\u5206\u7c7b\u7b49\u3002","title":"\u7ec4\u5408\u548c\u5806\u53e0\u65b9\u6cd5"},{"location":"%E7%BB%84%E7%BB%87%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E9%A1%B9%E7%9B%AE/","text":"\u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\u9879\u76ee \u7ec8\u4e8e\uff0c\u6211\u4eec\u53ef\u4ee5\u5f00\u59cb\u6784\u5efa\u7b2c\u4e00\u4e2a\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e86\u3002 \u662f\u8fd9\u6837\u5417\uff1f \u5728\u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u6ce8\u610f\u51e0\u4ef6\u4e8b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u6211\u4eec\u5c06\u5728\u96c6\u6210\u5f00\u53d1\u73af\u5883/\u6587\u672c\u7f16\u8f91\u5668\u4e2d\u5de5\u4f5c\uff0c\u800c\u4e0d\u662f\u5728 jupyter notebook\u4e2d\u3002\u4f60\u4e5f\u53ef\u4ee5\u5728 jupyter notebook\u4e2d\u5de5\u4f5c\uff0c\u8fd9\u5b8c\u5168\u53d6\u51b3\u4e8e\u4f60\u3002\u4e0d\u8fc7\uff0c\u6211\u5c06\u53ea\u4f7f\u7528 jupyter notebook\u6765\u63a2\u7d22\u6570\u636e\u3001\u7ed8\u5236\u56fe\u8868\u548c\u56fe\u5f62\u3002\u6211\u4eec\u5c06\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u6784\u5efa\u5206\u7c7b\u6846\u67b6\uff0c\u5373\u63d2\u5373\u7528\u3002\u60a8\u65e0\u9700\u5bf9\u4ee3\u7801\u505a\u592a\u591a\u6539\u52a8\u5c31\u80fd\u8bad\u7ec3\u6a21\u578b\uff0c\u800c\u4e14\u5f53\u60a8\u6539\u8fdb\u6a21\u578b\u65f6\uff0c\u8fd8\u80fd\u4f7f\u7528 git \u5bf9\u5176\u8fdb\u884c\u8ddf\u8e2a\u3002 
\u6211\u4eec\u9996\u5148\u6765\u770b\u770b\u6587\u4ef6\u7684\u7ed3\u6784\u3002\u5bf9\u4e8e\u4f60\u6b63\u5728\u505a\u7684\u4efb\u4f55\u9879\u76ee\uff0c\u90fd\u8981\u521b\u5efa\u4e00\u4e2a\u65b0\u6587\u4ef6\u5939\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u5c06\u9879\u76ee\u547d\u540d\u4e3a \"project\"\u3002 \u9879\u76ee\u6587\u4ef6\u5939\u5185\u90e8\u5e94\u8be5\u5982\u4e0b\u6240\u793a\u3002 input train.csv test.csv src create_folds.py train.py inference.py models.py config.py model_dispatcher.py models model_rf.bin model_et.bin notebooks exploration.ipynb check_data.ipynb README.md LICENSE \u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e9b\u6587\u4ef6\u5939\u548c\u6587\u4ef6\u7684\u5185\u5bb9\u3002 input/ \uff1a\u8be5\u6587\u4ef6\u5939\u5305\u542b\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7684\u6240\u6709\u8f93\u5165\u6587\u4ef6\u548c\u6570\u636e\u3002\u5982\u679c\u60a8\u6b63\u5728\u5f00\u53d1 NLP \u9879\u76ee\uff0c\u60a8\u53ef\u4ee5\u5c06embeddings\u653e\u5728\u8fd9\u91cc\u3002\u5982\u679c\u662f\u56fe\u50cf\u9879\u76ee\uff0c\u6240\u6709\u56fe\u50cf\u90fd\u653e\u5728\u8be5\u6587\u4ef6\u5939\u4e0b\u7684\u5b50\u6587\u4ef6\u5939\u4e2d\u3002 src/ \uff1a\u6211\u4eec\u5c06\u5728\u8fd9\u91cc\u4fdd\u5b58\u4e0e\u9879\u76ee\u76f8\u5173\u7684\u6240\u6709 python \u811a\u672c\u3002\u5982\u679c\u6211\u8bf4\u7684\u662f\u4e00\u4e2a python \u811a\u672c\uff0c\u5373\u4efb\u4f55 *.py \u6587\u4ef6\uff0c\u5b83\u90fd\u5b58\u50a8\u5728 src \u6587\u4ef6\u5939\u4e2d\u3002 models/ \uff1a\u8be5\u6587\u4ef6\u5939\u4fdd\u5b58\u6240\u6709\u8bad\u7ec3\u8fc7\u7684\u6a21\u578b\u3002 notebook/ \uff1a\u6240\u6709 jupyter notebook\uff08\u5373\u4efb\u4f55 *.ipynb \u6587\u4ef6\uff09\u90fd\u5b58\u50a8\u5728\u7b14\u8bb0\u672c \u6587\u4ef6\u5939\u4e2d\u3002 README.md 
\uff1a\u8fd9\u662f\u4e00\u4e2a\u6807\u8bb0\u7b26\u6587\u4ef6\uff0c\u60a8\u53ef\u4ee5\u5728\u5176\u4e2d\u63cf\u8ff0\u60a8\u7684\u9879\u76ee\uff0c\u5e76\u5199\u660e\u5982\u4f55\u8bad\u7ec3\u6a21\u578b\u6216\u5728\u751f\u4ea7\u73af\u5883\u4e2d\u4f7f\u7528\u3002 LICENSE \uff1a\u8fd9\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6587\u672c\u6587\u4ef6\uff0c\u5305\u542b\u9879\u76ee\u7684\u8bb8\u53ef\u8bc1\uff0c\u5982 MIT\u3001Apache \u7b49\u3002\u5173\u4e8e\u8bb8\u53ef\u8bc1\u7684\u8be6\u7ec6\u4ecb\u7ecd\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 \u5047\u8bbe\u4f60\u6b63\u5728\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u5bf9 MNIST \u6570\u636e\u96c6\uff08\u51e0\u4e4e\u6bcf\u672c\u673a\u5668\u5b66\u4e60\u4e66\u7c4d\u90fd\u4f1a\u7528\u5230\u7684\u6570\u636e\u96c6\uff09\u8fdb\u884c\u5206\u7c7b\u3002\u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\uff0c\u6211\u4eec\u5728\u4ea4\u53c9\u68c0\u9a8c\u4e00\u7ae0\u4e2d\u4e5f\u63d0\u5230\u8fc7 MNIST \u6570\u636e\u96c6\u3002\u6240\u4ee5\uff0c\u6211\u5c31\u4e0d\u89e3\u91ca\u8fd9\u4e2a\u6570\u636e\u96c6\u662f\u4ec0\u4e48\u6837\u5b50\u4e86\u3002\u7f51\u4e0a\u6709\u8bb8\u591a\u4e0d\u540c\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6\uff0c\u4f46\u6211\u4eec\u5c06\u4f7f\u7528 CSV \u683c\u5f0f\u7684\u6570\u636e\u96c6\u3002 \u5728\u8fd9\u79cd\u683c\u5f0f\u7684\u6570\u636e\u96c6\u4e2d\uff0cCSV \u7684\u6bcf\u4e00\u884c\u90fd\u5305\u542b\u56fe\u50cf\u7684\u6807\u7b7e\u548c 784 \u4e2a\u50cf\u7d20\u503c\uff0c\u50cf\u7d20\u503c\u8303\u56f4\u4ece 0 \u5230 255\u3002\u6570\u636e\u96c6\u5305\u542b 60000 \u5f20\u8fd9\u79cd\u683c\u5f0f\u7684\u56fe\u50cf\u3002 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u8f7b\u677e\u8bfb\u53d6\u8fd9\u79cd\u6570\u636e\u683c\u5f0f\u3002 \u8bf7\u6ce8\u610f\uff0c\u5c3d\u7ba1\u56fe 1 \u663e\u793a\u6240\u6709\u50cf\u7d20\u503c\u5747\u4e3a\u96f6\uff0c\u4f46\u4e8b\u5b9e\u5e76\u975e\u5982\u6b64\u3002 \u56fe 1\uff1aCSV\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6 
\u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u6807\u7b7e\u5217\u7684\u8ba1\u6570\u3002 \u56fe 2\uff1aMNIST \u6570\u636e\u96c6\u4e2d\u7684\u6807\u7b7e\u8ba1\u6570 \u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c\u66f4\u591a\u7684\u63a2\u7d22\u3002\u6211\u4eec\u5df2\u7ecf\u77e5\u9053\u4e86\u6211\u4eec\u6240\u62e5\u6709\u7684\u6570\u636e\uff0c\u6ca1\u6709\u5fc5\u8981\u518d\u5bf9\u4e0d\u540c\u7684\u50cf\u7d20\u503c\u8fdb\u884c\u7ed8\u56fe\u3002\u4ece\u56fe 2 \u4e2d\u53ef\u4ee5\u6e05\u695a\u5730\u770b\u51fa\uff0c\u6807\u7b7e\u7684\u5206\u5e03\u76f8\u5f53\u5747\u5300\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u51c6\u786e\u7387/F1 \u4f5c\u4e3a\u8861\u91cf\u6807\u51c6\u3002\u8fd9\u5c31\u662f\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u7684\u7b2c\u4e00\u6b65\uff1a\u786e\u5b9a\u8861\u91cf\u6807\u51c6\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7f16\u5199\u4e00\u4e9b\u4ee3\u7801\u4e86\u3002\u6211\u4eec\u9700\u8981\u521b\u5efa src/ \u6587\u4ef6\u5939\u548c\u4e00\u4e9b python \u811a\u672c\u3002 \u8bf7\u6ce8\u610f\uff0c\u8bad\u7ec3 CSV \u6587\u4ef6\u4f4d\u4e8e input/ \u6587\u4ef6\u5939\u4e2d\uff0c\u540d\u4e3a mnist_train.csv \u3002 \u5bf9\u4e8e\u8fd9\u6837\u4e00\u4e2a\u9879\u76ee\uff0c\u8fd9\u4e9b\u6587\u4ef6\u5e94\u8be5\u662f\u4ec0\u4e48\u6837\u7684\u5462\uff1f \u9996\u5148\u8981\u521b\u5efa\u7684\u811a\u672c\u662f create_folds.py \u3002 \u8fd9\u5c06\u5728 input/ \u6587\u4ef6\u5939\u4e2d\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a mnist_train_folds.csv \u7684\u65b0\u6587\u4ef6\uff0c\u4e0e mnist_train.csv \u76f8\u540c\u3002\u552f\u4e00\u4e0d\u540c\u7684\u662f\uff0c\u8fd9\u4e2a CSV \u6587\u4ef6\u7ecf\u8fc7\u4e86\u968f\u673a\u6392\u5e8f\uff0c\u5e76\u65b0\u589e\u4e86\u4e00\u5217\u540d\u4e3a kfold \u7684\u5185\u5bb9\u3002 
\u4e00\u65e6\u6211\u4eec\u51b3\u5b9a\u4e86\u8981\u4f7f\u7528\u54ea\u79cd\u8bc4\u4f30\u6307\u6807\u5e76\u521b\u5efa\u4e86\u6298\u53e0\uff0c\u5c31\u53ef\u4ee5\u5f00\u59cb\u521b\u5efa\u57fa\u672c\u6a21\u578b\u4e86\u3002\u8fd9\u53ef\u4ee5\u5728 train.py \u4e2d\u5b8c\u6210\u3002 import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/mnist_train_folds.csv\" ) # \u9009\u53d6df\u4e2dkfold\u5217\u4e0d\u7b49\u4e8efold df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u9009\u53d6df\u4e2dkfold\u5217\u7b49\u4e8efold df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u8bad\u7ec3\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_train = df_train . drop ( \"label\" , axis = 1 ) . values # \u8bad\u7ec3\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_train = df_train . label . values # \u9a8c\u8bc1\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values # \u9a8c\u8bc1\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_valid = df_valid . label . values # \u5b9e\u4f8b\u5316\u51b3\u7b56\u6811\u6a21\u578b clf = tree . DecisionTreeClassifier () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b clf . fit ( x_train , y_train ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u8f93\u5165\u5f97\u5230\u9884\u6d4b\u7ed3\u679c preds = clf . predict ( x_valid ) # \u8ba1\u7b97\u9a8c\u8bc1\u96c6\u51c6\u786e\u7387 accuracy = metrics . accuracy_score ( y_valid , preds ) # \u6253\u5370fold\u4fe1\u606f\u548c\u51c6\u786e\u7387 print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) # \u4fdd\u5b58\u6a21\u578b joblib . 
dump ( clf , f \"../models/dt_ { fold } .bin\" ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u60a8\u53ef\u4ee5\u5728\u63a7\u5236\u53f0\u8c03\u7528 python train.py \u8fd0\u884c\u8be5\u811a\u672c\u3002 \u276f python train . py Fold = 0 , Accuracy = 0.8680833333333333 Fold = 1 , Accuracy = 0.8685 Fold = 2 , Accuracy = 0.8674166666666666 Fold = 3 , Accuracy = 0.8703333333333333 Fold = 4 , Accuracy = 0.8699166666666667 \u67e5\u770b\u8bad\u7ec3\u811a\u672c\u65f6\uff0c\u60a8\u4f1a\u53d1\u73b0\u8fd8\u6709\u4e00\u4e9b\u5185\u5bb9\u662f\u786c\u7f16\u7801\u7684\uff0c\u4f8b\u5982\u6298\u53e0\u6570\u3001\u8bad\u7ec3\u6587\u4ef6\u548c\u8f93\u51fa\u6587\u4ef6\u5939\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u6240\u6709\u8fd9\u4e9b\u4fe1\u606f\u7684\u914d\u7f6e\u6587\u4ef6\uff1a config.py \u3002 TRAINING_FILE = \"../input/mnist_train_folds.csv\" MODEL_OUTPUT = \"../models/\" \u6211\u4eec\u8fd8\u5bf9\u8bad\u7ec3\u811a\u672c\u8fdb\u884c\u4e86\u4e00\u4e9b\u4fee\u6539\u3002\u8bad\u7ec3\u6587\u4ef6\u73b0\u5728\u4f7f\u7528\u914d\u7f6e\u6587\u4ef6\u3002\u8fd9\u6837\uff0c\u66f4\u6539\u6570\u636e\u6216\u6a21\u578b\u8f93\u51fa\u5c31\u66f4\u5bb9\u6613\u4e86\u3002 import os import config import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u4f7f\u7528config\u4e2d\u7684\u8def\u5f84\u8bfb\u53d6\u6570\u636e df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values clf = tree . DecisionTreeClassifier () clf . fit ( x_train , y_train ) preds = clf . 
predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" ) ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5e76\u6ca1\u6709\u5c55\u793a\u8fd9\u4e2a\u57f9\u8bad\u811a\u672c\u4e0e\u4e4b\u524d\u811a\u672c\u7684\u533a\u522b\u3002\u8bf7\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e24\u4e2a\u811a\u672c\uff0c\u81ea\u5df1\u627e\u51fa\u4e0d\u540c\u4e4b\u5904\u3002\u533a\u522b\u5e76\u4e0d\u591a\u3002 \u4e0e\u8bad\u7ec3\u811a\u672c\u76f8\u5173\u7684\u8fd8\u6709\u4e00\u70b9\u53ef\u4ee5\u6539\u8fdb\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u6211\u4eec\u4e3a\u6bcf\u4e2a\u6298\u53e0\u591a\u6b21\u8c03\u7528\u8fd0\u884c\u51fd\u6570\u3002\u6709\u65f6\uff0c\u5728\u540c\u4e00\u4e2a\u811a\u672c\u4e2d\u8fd0\u884c\u591a\u4e2a\u6298\u53e0\u5e76\u4e0d\u53ef\u53d6\uff0c\u56e0\u4e3a\u5185\u5b58\u6d88\u8017\u53ef\u80fd\u4f1a\u4e0d\u65ad\u589e\u52a0\uff0c\u7a0b\u5e8f\u53ef\u80fd\u4f1a\u5d29\u6e83\u3002\u4e3a\u4e86\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u5411\u8bad\u7ec3\u811a\u672c\u4f20\u9012\u53c2\u6570\u3002\u6211\u559c\u6b22\u4f7f\u7528 argparse\u3002 import argparse if __name__ == \"__main__\" : # \u5b9e\u4f8b\u5316\u53c2\u6570\u73af\u5883 parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # \u8bfb\u53d6\u53c2\u6570 args = parser . parse_args () run ( fold = args . fold ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u8fd0\u884c python \u811a\u672c\uff0c\u4f46\u4ec5\u9650\u4e8e\u7ed9\u5b9a\u7684\u6298\u53e0\u3002 \u276f python train . 
py -- fold 0 Fold = 0 , Accuracy = 0.8656666666666667 \u4ed4\u7ec6\u89c2\u5bdf\uff0c\u6211\u4eec\u7684\u7b2c 0 \u6298\u5f97\u5206\u4e0e\u4e4b\u524d\u6709\u4e9b\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6a21\u578b\u4e2d\u5b58\u5728\u968f\u673a\u6027\u3002\u6211\u4eec\u5c06\u5728\u540e\u9762\u7684\u7ae0\u8282\u4e2d\u8ba8\u8bba\u5982\u4f55\u5904\u7406\u968f\u673a\u6027\u3002 \u73b0\u5728\uff0c\u5982\u679c\u4f60\u613f\u610f\uff0c\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a shell \u811a\u672c \uff0c\u9488\u5bf9\u4e0d\u540c\u7684\u6298\u53e0\u4f7f\u7528\u4e0d\u540c\u7684\u547d\u4ee4\uff0c\u7136\u540e\u4e00\u8d77\u8fd0\u884c\uff0c\u5982\u4e0b\u56fe\u6240\u793a\u3002 python train . py -- fold 0 python train . py -- fold 1 python train . py -- fold 2 python train . py -- fold 3 python train . py -- fold 4 \u60a8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u5b83\u3002 \u276f sh run . sh Fold = 0 , Accuracy = 0.8675 Fold = 1 , Accuracy = 0.8693333333333333 Fold = 2 , Accuracy = 0.8683333333333333 Fold = 3 , Accuracy = 0.8704166666666666 Fold = 4 , Accuracy = 0.8685 \u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u53d6\u5f97\u4e86\u4e00\u4e9b\u8fdb\u5c55\uff0c\u4f46\u5982\u679c\u6211\u4eec\u770b\u4e00\u4e0b\u6211\u4eec\u7684\u8bad\u7ec3\u811a\u672c\uff0c\u6211\u4eec\u4ecd\u7136\u53d7\u5230\u4e00\u4e9b\u4e1c\u897f\u7684\u9650\u5236\uff0c\u4f8b\u5982\u6a21\u578b\u3002\u6a21\u578b\u662f\u786c\u7f16\u7801\u5728\u8bad\u7ec3\u811a\u672c\u4e2d\u7684\uff0c\u53ea\u6709\u4fee\u6539\u811a\u672c\u624d\u80fd\u6539\u53d8\u5b83\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u65b0\u7684 python \u811a\u672c\uff0c\u540d\u4e3a model_dispatcher.py \u3002model_dispatcher.py\uff0c\u987e\u540d\u601d\u4e49\uff0c\u5c06\u8c03\u5ea6\u6211\u4eec\u7684\u6a21\u578b\u5230\u8bad\u7ec3\u811a\u672c\u4e2d\u3002 from sklearn import tree models = { # \u4ee5gini\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_gini\" : tree . 
DecisionTreeClassifier ( criterion = \"gini\" ), # \u4ee5entropy\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), } model_dispatcher.py \u4ece scikit-learn \u4e2d\u5bfc\u5165\u4e86 tree\uff0c\u5e76\u5b9a\u4e49\u4e86\u4e00\u4e2a\u5b57\u5178\uff0c\u5176\u4e2d\u952e\u662f\u6a21\u578b\u7684\u540d\u79f0\uff0c\u503c\u662f\u6a21\u578b\u672c\u8eab\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5b9a\u4e49\u4e86\u4e24\u79cd\u4e0d\u540c\u7684\u51b3\u7b56\u6811\uff0c\u4e00\u79cd\u4f7f\u7528\u57fa\u5c3c\u6807\u51c6\uff0c\u53e6\u4e00\u79cd\u4f7f\u7528\u71b5\u6807\u51c6\u3002\u8981\u4f7f\u7528 py\uff0c\u6211\u4eec\u9700\u8981\u5bf9\u8bad\u7ec3\u811a\u672c\u505a\u4e00\u4e9b\u4fee\u6539\u3002 import argparse import os import joblib import pandas as pd from sklearn import metrics import config import model_dispatcher def run ( fold , model ): df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values # \u6839\u636emodel\u53c2\u6570\u9009\u62e9\u6a21\u578b clf = model_dispatcher . models [ model ] clf . fit ( x_train , y_train ) preds = clf . predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" )) if __name__ == \"__main__\" : parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # model\u53c2\u6570 parser . add_argument ( \"--model\" , type = str ) args = parser . parse_args () run ( fold = args . fold , model = args . 
model ) train.py \u6709\u51e0\u5904\u91cd\u5927\u6539\u52a8\uff1a - \u5bfc\u5165 model_dispatcher - \u4e3a ArgumentParser \u6dfb\u52a0 --model \u53c2\u6570 - \u4e3a run() \u51fd\u6570\u6dfb\u52a0model\u53c2\u6570 - \u4f7f\u7528\u8c03\u5ea6\u7a0b\u5e8f\u83b7\u53d6\u6307\u5b9a\u540d\u79f0\u7684\u6a21\u578b \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u811a\u672c\uff1a \u276f python train . py -- fold 0 -- model decision_tree_gini Fold = 0 , Accuracy = 0.8665833333333334 \u6216\u6267\u884c\u4ee5\u4e0b\u547d\u4ee4 \u276f python train . py -- fold 0 -- model decision_tree_entropy Fold = 0 , Accuracy = 0.8705833333333334 \u73b0\u5728\uff0c\u5982\u679c\u8981\u6dfb\u52a0\u65b0\u6a21\u578b\uff0c\u53ea\u9700\u4fee\u6539 model_dispatcher.py \u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u6dfb\u52a0\u968f\u673a\u68ee\u6797\uff0c\u770b\u770b\u51c6\u786e\u7387\u4f1a\u6709\u4ec0\u4e48\u53d8\u5316\u3002 from sklearn import ensemble from sklearn import tree models = { \"decision_tree_gini\" : tree . DecisionTreeClassifier ( criterion = \"gini\" ), \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), # \u968f\u673a\u68ee\u6797\u6a21\u578b \"rf\" : ensemble . RandomForestClassifier (), } \u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u3002 \u276f python train . py -- fold 0 -- model rf Fold = 0 , Accuracy = 0.9670833333333333 \u54c7\uff0c\u4e00\u4e2a\u7b80\u5355\u7684\u6539\u52a8\u5c31\u80fd\u8ba9\u5206\u6570\u6709\u5982\u6b64\u5927\u7684\u63d0\u5347\uff01\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528 run.sh \u811a\u672c\u8fd0\u884c 5 \u4e2a\u6298\u53e0\uff01 python train . py -- fold 0 -- model rf python train . py -- fold 1 -- model rf python train . py -- fold 2 -- model rf python train . py -- fold 3 -- model rf python train . py -- fold 4 -- model rf \u5f97\u5206\u60c5\u51b5\u5982\u4e0b \u276f sh run . 
sh Fold = 0 , Accuracy = 0.9674166666666667 Fold = 1 , Accuracy = 0.9698333333333333 Fold = 2 , Accuracy = 0.96575 Fold = 3 , Accuracy = 0.9684166666666667 Fold = 4 , Accuracy = 0.9666666666666667 MNIST \u51e0\u4e4e\u662f\u6bcf\u672c\u4e66\u548c\u6bcf\u7bc7\u535a\u5ba2\u90fd\u4f1a\u8ba8\u8bba\u7684\u95ee\u9898\u3002\u4f46\u6211\u8bd5\u56fe\u5c06\u8fd9\u4e2a\u95ee\u9898\u8f6c\u6362\u5f97\u66f4\u6709\u8da3\uff0c\u5e76\u5411\u4f60\u5c55\u793a\u5982\u4f55\u4e3a\u4f60\u6b63\u5728\u505a\u7684\u6216\u8ba1\u5212\u5728\u4e0d\u4e45\u7684\u5c06\u6765\u505a\u7684\u51e0\u4e4e\u6240\u6709\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7f16\u5199\u4e00\u4e2a\u57fa\u672c\u6846\u67b6\u3002\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\u53ef\u4ee5\u6539\u8fdb\u8fd9\u4e2a MNIST \u6a21\u578b\u548c\u8fd9\u4e2a\u6846\u67b6\uff0c\u6211\u4eec\u5c06\u5728\u4ee5\u540e\u7684\u7ae0\u8282\u4e2d\u770b\u5230\u3002 \u6211\u4f7f\u7528\u4e86\u4e00\u4e9b\u811a\u672c\uff0c\u5982 model_dispatcher.py \u548c config.py \uff0c\u5e76\u5c06\u5b83\u4eec\u5bfc\u5165\u5230\u6211\u7684\u8bad\u7ec3\u811a\u672c\u4e2d\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u6ca1\u6709\u5bfc\u5165 \uff0c\u4f60\u4e5f\u4e0d\u5e94\u8be5\u5bfc\u5165\u3002\u5982\u679c\u6211\u5bfc\u5165\u4e86 
\uff0c\u4f60\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u77e5\u9053\u6a21\u578b\u5b57\u5178\u662f\u4ece\u54ea\u91cc\u6765\u7684\u3002\u7f16\u5199\u4f18\u79c0\u3001\u6613\u61c2\u7684\u4ee3\u7801\u662f\u4e00\u4e2a\u4eba\u5fc5\u987b\u5177\u5907\u7684\u57fa\u672c\u7d20\u8d28\uff0c\u4f46\u8bb8\u591a\u6570\u636e\u79d1\u5b66\u5bb6\u5374\u5ffd\u89c6\u4e86\u8fd9\u4e00\u70b9\u3002\u5982\u679c\u4f60\u6240\u505a\u7684\u9879\u76ee\u80fd\u8ba9\u5176\u4ed6\u4eba\u7406\u89e3\u5e76\u4f7f\u7528\uff0c\u800c\u65e0\u9700\u54a8\u8be2\u4f60\u7684\u610f\u89c1\uff0c\u90a3\u4e48\u4f60\u5c31\u8282\u7701\u4e86\u4ed6\u4eec\u7684\u65f6\u95f4\u548c\u81ea\u5df1\u7684\u65f6\u95f4\uff0c\u53ef\u4ee5\u5c06\u8fd9\u4e9b\u65f6\u95f4\u6295\u5165\u5230\u6539\u8fdb\u4f60\u7684\u9879\u76ee\u6216\u5f00\u53d1\u65b0\u9879\u76ee\u4e2d\u53bb\u3002","title":"\u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\u9879\u76ee"},{"location":"%E7%BB%84%E7%BB%87%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E9%A1%B9%E7%9B%AE/#_1","text":"\u7ec8\u4e8e\uff0c\u6211\u4eec\u53ef\u4ee5\u5f00\u59cb\u6784\u5efa\u7b2c\u4e00\u4e2a\u673a\u5668\u5b66\u4e60\u6a21\u578b\u4e86\u3002 \u662f\u8fd9\u6837\u5417\uff1f \u5728\u5f00\u59cb\u4e4b\u524d\uff0c\u6211\u4eec\u5fc5\u987b\u6ce8\u610f\u51e0\u4ef6\u4e8b\u3002\u8bf7\u8bb0\u4f4f\uff0c\u6211\u4eec\u5c06\u5728\u96c6\u6210\u5f00\u53d1\u73af\u5883/\u6587\u672c\u7f16\u8f91\u5668\u4e2d\u5de5\u4f5c\uff0c\u800c\u4e0d\u662f\u5728 jupyter notebook\u4e2d\u3002\u4f60\u4e5f\u53ef\u4ee5\u5728 jupyter notebook\u4e2d\u5de5\u4f5c\uff0c\u8fd9\u5b8c\u5168\u53d6\u51b3\u4e8e\u4f60\u3002\u4e0d\u8fc7\uff0c\u6211\u5c06\u53ea\u4f7f\u7528 jupyter 
notebook\u6765\u63a2\u7d22\u6570\u636e\u3001\u7ed8\u5236\u56fe\u8868\u548c\u56fe\u5f62\u3002\u6211\u4eec\u5c06\u4ee5\u8fd9\u6837\u4e00\u79cd\u65b9\u5f0f\u6784\u5efa\u5206\u7c7b\u6846\u67b6\uff0c\u5373\u63d2\u5373\u7528\u3002\u60a8\u65e0\u9700\u5bf9\u4ee3\u7801\u505a\u592a\u591a\u6539\u52a8\u5c31\u80fd\u8bad\u7ec3\u6a21\u578b\uff0c\u800c\u4e14\u5f53\u60a8\u6539\u8fdb\u6a21\u578b\u65f6\uff0c\u8fd8\u80fd\u4f7f\u7528 git \u5bf9\u5176\u8fdb\u884c\u8ddf\u8e2a\u3002 \u6211\u4eec\u9996\u5148\u6765\u770b\u770b\u6587\u4ef6\u7684\u7ed3\u6784\u3002\u5bf9\u4e8e\u4f60\u6b63\u5728\u505a\u7684\u4efb\u4f55\u9879\u76ee\uff0c\u90fd\u8981\u521b\u5efa\u4e00\u4e2a\u65b0\u6587\u4ef6\u5939\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u6211\u5c06\u9879\u76ee\u547d\u540d\u4e3a \"project\"\u3002 \u9879\u76ee\u6587\u4ef6\u5939\u5185\u90e8\u5e94\u8be5\u5982\u4e0b\u6240\u793a\u3002 input train.csv test.csv src create_folds.py train.py inference.py models.py config.py model_dispatcher.py models model_rf.bin model_et.bin notebooks exploration.ipynb check_data.ipynb README.md LICENSE \u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e9b\u6587\u4ef6\u5939\u548c\u6587\u4ef6\u7684\u5185\u5bb9\u3002 input/ \uff1a\u8be5\u6587\u4ef6\u5939\u5305\u542b\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7684\u6240\u6709\u8f93\u5165\u6587\u4ef6\u548c\u6570\u636e\u3002\u5982\u679c\u60a8\u6b63\u5728\u5f00\u53d1 NLP \u9879\u76ee\uff0c\u60a8\u53ef\u4ee5\u5c06embeddings\u653e\u5728\u8fd9\u91cc\u3002\u5982\u679c\u662f\u56fe\u50cf\u9879\u76ee\uff0c\u6240\u6709\u56fe\u50cf\u90fd\u653e\u5728\u8be5\u6587\u4ef6\u5939\u4e0b\u7684\u5b50\u6587\u4ef6\u5939\u4e2d\u3002 src/ \uff1a\u6211\u4eec\u5c06\u5728\u8fd9\u91cc\u4fdd\u5b58\u4e0e\u9879\u76ee\u76f8\u5173\u7684\u6240\u6709 python \u811a\u672c\u3002\u5982\u679c\u6211\u8bf4\u7684\u662f\u4e00\u4e2a python \u811a\u672c\uff0c\u5373\u4efb\u4f55 *.py \u6587\u4ef6\uff0c\u5b83\u90fd\u5b58\u50a8\u5728 src \u6587\u4ef6\u5939\u4e2d\u3002 models/ 
\uff1a\u8be5\u6587\u4ef6\u5939\u4fdd\u5b58\u6240\u6709\u8bad\u7ec3\u8fc7\u7684\u6a21\u578b\u3002 notebook/ \uff1a\u6240\u6709 jupyter notebook\uff08\u5373\u4efb\u4f55 *.ipynb \u6587\u4ef6\uff09\u90fd\u5b58\u50a8\u5728\u7b14\u8bb0\u672c \u6587\u4ef6\u5939\u4e2d\u3002 README.md \uff1a\u8fd9\u662f\u4e00\u4e2a\u6807\u8bb0\u7b26\u6587\u4ef6\uff0c\u60a8\u53ef\u4ee5\u5728\u5176\u4e2d\u63cf\u8ff0\u60a8\u7684\u9879\u76ee\uff0c\u5e76\u5199\u660e\u5982\u4f55\u8bad\u7ec3\u6a21\u578b\u6216\u5728\u751f\u4ea7\u73af\u5883\u4e2d\u4f7f\u7528\u3002 LICENSE \uff1a\u8fd9\u662f\u4e00\u4e2a\u7b80\u5355\u7684\u6587\u672c\u6587\u4ef6\uff0c\u5305\u542b\u9879\u76ee\u7684\u8bb8\u53ef\u8bc1\uff0c\u5982 MIT\u3001Apache \u7b49\u3002\u5173\u4e8e\u8bb8\u53ef\u8bc1\u7684\u8be6\u7ec6\u4ecb\u7ecd\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 \u5047\u8bbe\u4f60\u6b63\u5728\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u5bf9 MNIST \u6570\u636e\u96c6\uff08\u51e0\u4e4e\u6bcf\u672c\u673a\u5668\u5b66\u4e60\u4e66\u7c4d\u90fd\u4f1a\u7528\u5230\u7684\u6570\u636e\u96c6\uff09\u8fdb\u884c\u5206\u7c7b\u3002\u5982\u679c\u4f60\u8fd8\u8bb0\u5f97\uff0c\u6211\u4eec\u5728\u4ea4\u53c9\u68c0\u9a8c\u4e00\u7ae0\u4e2d\u4e5f\u63d0\u5230\u8fc7 MNIST \u6570\u636e\u96c6\u3002\u6240\u4ee5\uff0c\u6211\u5c31\u4e0d\u89e3\u91ca\u8fd9\u4e2a\u6570\u636e\u96c6\u662f\u4ec0\u4e48\u6837\u5b50\u4e86\u3002\u7f51\u4e0a\u6709\u8bb8\u591a\u4e0d\u540c\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6\uff0c\u4f46\u6211\u4eec\u5c06\u4f7f\u7528 CSV \u683c\u5f0f\u7684\u6570\u636e\u96c6\u3002 \u5728\u8fd9\u79cd\u683c\u5f0f\u7684\u6570\u636e\u96c6\u4e2d\uff0cCSV \u7684\u6bcf\u4e00\u884c\u90fd\u5305\u542b\u56fe\u50cf\u7684\u6807\u7b7e\u548c 784 \u4e2a\u50cf\u7d20\u503c\uff0c\u50cf\u7d20\u503c\u8303\u56f4\u4ece 0 \u5230 255\u3002\u6570\u636e\u96c6\u5305\u542b 60000 \u5f20\u8fd9\u79cd\u683c\u5f0f\u7684\u56fe\u50cf\u3002 \u6211\u4eec\u53ef\u4ee5\u4f7f\u7528 pandas \u8f7b\u677e\u8bfb\u53d6\u8fd9\u79cd\u6570\u636e\u683c\u5f0f\u3002 
\u8bf7\u6ce8\u610f\uff0c\u5c3d\u7ba1\u56fe 1 \u663e\u793a\u6240\u6709\u50cf\u7d20\u503c\u5747\u4e3a\u96f6\uff0c\u4f46\u4e8b\u5b9e\u5e76\u975e\u5982\u6b64\u3002 \u56fe 1\uff1aCSV\u683c\u5f0f\u7684 MNIST \u6570\u636e\u96c6 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u8fd9\u4e2a\u6570\u636e\u96c6\u4e2d\u6807\u7b7e\u5217\u7684\u8ba1\u6570\u3002 \u56fe 2\uff1aMNIST \u6570\u636e\u96c6\u4e2d\u7684\u6807\u7b7e\u8ba1\u6570 \u6211\u4eec\u4e0d\u9700\u8981\u5bf9\u8fd9\u4e2a\u6570\u636e\u96c6\u8fdb\u884c\u66f4\u591a\u7684\u63a2\u7d22\u3002\u6211\u4eec\u5df2\u7ecf\u77e5\u9053\u4e86\u6211\u4eec\u6240\u62e5\u6709\u7684\u6570\u636e\uff0c\u6ca1\u6709\u5fc5\u8981\u518d\u5bf9\u4e0d\u540c\u7684\u50cf\u7d20\u503c\u8fdb\u884c\u7ed8\u56fe\u3002\u4ece\u56fe 2 \u4e2d\u53ef\u4ee5\u6e05\u695a\u5730\u770b\u51fa\uff0c\u6807\u7b7e\u7684\u5206\u5e03\u76f8\u5f53\u5747\u5300\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u51c6\u786e\u7387/F1 \u4f5c\u4e3a\u8861\u91cf\u6807\u51c6\u3002\u8fd9\u5c31\u662f\u5904\u7406\u673a\u5668\u5b66\u4e60\u95ee\u9898\u7684\u7b2c\u4e00\u6b65\uff1a\u786e\u5b9a\u8861\u91cf\u6807\u51c6\uff01 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7f16\u5199\u4e00\u4e9b\u4ee3\u7801\u4e86\u3002\u6211\u4eec\u9700\u8981\u521b\u5efa src/ \u6587\u4ef6\u5939\u548c\u4e00\u4e9b python \u811a\u672c\u3002 \u8bf7\u6ce8\u610f\uff0c\u8bad\u7ec3 CSV \u6587\u4ef6\u4f4d\u4e8e input/ \u6587\u4ef6\u5939\u4e2d\uff0c\u540d\u4e3a mnist_train.csv \u3002 \u5bf9\u4e8e\u8fd9\u6837\u4e00\u4e2a\u9879\u76ee\uff0c\u8fd9\u4e9b\u6587\u4ef6\u5e94\u8be5\u662f\u4ec0\u4e48\u6837\u7684\u5462\uff1f \u9996\u5148\u8981\u521b\u5efa\u7684\u811a\u672c\u662f create_folds.py \u3002 \u8fd9\u5c06\u5728 input/ \u6587\u4ef6\u5939\u4e2d\u521b\u5efa\u4e00\u4e2a\u540d\u4e3a mnist_train_folds.csv \u7684\u65b0\u6587\u4ef6\uff0c\u4e0e mnist_train.csv \u76f8\u540c\u3002\u552f\u4e00\u4e0d\u540c\u7684\u662f\uff0c\u8fd9\u4e2a CSV 
\u6587\u4ef6\u7ecf\u8fc7\u4e86\u968f\u673a\u6392\u5e8f\uff0c\u5e76\u65b0\u589e\u4e86\u4e00\u5217\u540d\u4e3a kfold \u7684\u5185\u5bb9\u3002 \u4e00\u65e6\u6211\u4eec\u51b3\u5b9a\u4e86\u8981\u4f7f\u7528\u54ea\u79cd\u8bc4\u4f30\u6307\u6807\u5e76\u521b\u5efa\u4e86\u6298\u53e0\uff0c\u5c31\u53ef\u4ee5\u5f00\u59cb\u521b\u5efa\u57fa\u672c\u6a21\u578b\u4e86\u3002\u8fd9\u53ef\u4ee5\u5728 train.py \u4e2d\u5b8c\u6210\u3002 import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6 df = pd . read_csv ( \"../input/mnist_train_folds.csv\" ) # \u9009\u53d6df\u4e2dkfold\u5217\u4e0d\u7b49\u4e8efold df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) # \u9009\u53d6df\u4e2dkfold\u5217\u7b49\u4e8efold df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) # \u8bad\u7ec3\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_train = df_train . drop ( \"label\" , axis = 1 ) . values # \u8bad\u7ec3\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_train = df_train . label . values # \u9a8c\u8bc1\u96c6\u8f93\u5165\uff0c\u5220\u9664label\u5217 x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values # \u9a8c\u8bc1\u96c6\u8f93\u51fa\uff0c\u53d6label\u5217 y_valid = df_valid . label . values # \u5b9e\u4f8b\u5316\u51b3\u7b56\u6811\u6a21\u578b clf = tree . DecisionTreeClassifier () # \u4f7f\u7528\u8bad\u7ec3\u96c6\u8bad\u7ec3\u6a21\u578b clf . fit ( x_train , y_train ) # \u4f7f\u7528\u9a8c\u8bc1\u96c6\u8f93\u5165\u5f97\u5230\u9884\u6d4b\u7ed3\u679c preds = clf . predict ( x_valid ) # \u8ba1\u7b97\u9a8c\u8bc1\u96c6\u51c6\u786e\u7387 accuracy = metrics . accuracy_score ( y_valid , preds ) # \u6253\u5370fold\u4fe1\u606f\u548c\u51c6\u786e\u7387 print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) # \u4fdd\u5b58\u6a21\u578b joblib . 
dump ( clf , f \"../models/dt_ { fold } .bin\" ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u60a8\u53ef\u4ee5\u5728\u63a7\u5236\u53f0\u8c03\u7528 python train.py \u8fd0\u884c\u8be5\u811a\u672c\u3002 \u276f python train . py Fold = 0 , Accuracy = 0.8680833333333333 Fold = 1 , Accuracy = 0.8685 Fold = 2 , Accuracy = 0.8674166666666666 Fold = 3 , Accuracy = 0.8703333333333333 Fold = 4 , Accuracy = 0.8699166666666667 \u67e5\u770b\u8bad\u7ec3\u811a\u672c\u65f6\uff0c\u60a8\u4f1a\u53d1\u73b0\u8fd8\u6709\u4e00\u4e9b\u5185\u5bb9\u662f\u786c\u7f16\u7801\u7684\uff0c\u4f8b\u5982\u6298\u53e0\u6570\u3001\u8bad\u7ec3\u6587\u4ef6\u548c\u8f93\u51fa\u6587\u4ef6\u5939\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u6240\u6709\u8fd9\u4e9b\u4fe1\u606f\u7684\u914d\u7f6e\u6587\u4ef6\uff1a config.py \u3002 TRAINING_FILE = \"../input/mnist_train_folds.csv\" MODEL_OUTPUT = \"../models/\" \u6211\u4eec\u8fd8\u5bf9\u8bad\u7ec3\u811a\u672c\u8fdb\u884c\u4e86\u4e00\u4e9b\u4fee\u6539\u3002\u8bad\u7ec3\u6587\u4ef6\u73b0\u5728\u4f7f\u7528\u914d\u7f6e\u6587\u4ef6\u3002\u8fd9\u6837\uff0c\u66f4\u6539\u6570\u636e\u6216\u6a21\u578b\u8f93\u51fa\u5c31\u66f4\u5bb9\u6613\u4e86\u3002 import os import config import joblib import pandas as pd from sklearn import metrics from sklearn import tree def run ( fold ): # \u4f7f\u7528config\u4e2d\u7684\u8def\u5f84\u8bfb\u53d6\u6570\u636e df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values clf = tree . DecisionTreeClassifier () clf . fit ( x_train , y_train ) preds = clf . 
predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" ) ) if __name__ == \"__main__\" : # \u8fd0\u884c\u6bcf\u4e2a\u6298\u53e0 run ( fold = 0 ) run ( fold = 1 ) run ( fold = 2 ) run ( fold = 3 ) run ( fold = 4 ) \u8bf7\u6ce8\u610f\uff0c\u6211\u5e76\u6ca1\u6709\u5c55\u793a\u8fd9\u4e2a\u57f9\u8bad\u811a\u672c\u4e0e\u4e4b\u524d\u811a\u672c\u7684\u533a\u522b\u3002\u8bf7\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e24\u4e2a\u811a\u672c\uff0c\u81ea\u5df1\u627e\u51fa\u4e0d\u540c\u4e4b\u5904\u3002\u533a\u522b\u5e76\u4e0d\u591a\u3002 \u4e0e\u8bad\u7ec3\u811a\u672c\u76f8\u5173\u7684\u8fd8\u6709\u4e00\u70b9\u53ef\u4ee5\u6539\u8fdb\u3002\u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u6211\u4eec\u4e3a\u6bcf\u4e2a\u6298\u53e0\u591a\u6b21\u8c03\u7528\u8fd0\u884c\u51fd\u6570\u3002\u6709\u65f6\uff0c\u5728\u540c\u4e00\u4e2a\u811a\u672c\u4e2d\u8fd0\u884c\u591a\u4e2a\u6298\u53e0\u5e76\u4e0d\u53ef\u53d6\uff0c\u56e0\u4e3a\u5185\u5b58\u6d88\u8017\u53ef\u80fd\u4f1a\u4e0d\u65ad\u589e\u52a0\uff0c\u7a0b\u5e8f\u53ef\u80fd\u4f1a\u5d29\u6e83\u3002\u4e3a\u4e86\u89e3\u51b3\u8fd9\u4e2a\u95ee\u9898\uff0c\u6211\u4eec\u53ef\u4ee5\u5411\u8bad\u7ec3\u811a\u672c\u4f20\u9012\u53c2\u6570\u3002\u6211\u559c\u6b22\u4f7f\u7528 argparse\u3002 import argparse if __name__ == \"__main__\" : # \u5b9e\u4f8b\u5316\u53c2\u6570\u73af\u5883 parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # \u8bfb\u53d6\u53c2\u6570 args = parser . parse_args () run ( fold = args . fold ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u518d\u6b21\u8fd0\u884c python \u811a\u672c\uff0c\u4f46\u4ec5\u9650\u4e8e\u7ed9\u5b9a\u7684\u6298\u53e0\u3002 \u276f python train . 
py -- fold 0 Fold = 0 , Accuracy = 0.8656666666666667 \u4ed4\u7ec6\u89c2\u5bdf\uff0c\u6211\u4eec\u7684\u7b2c 0 \u6298\u5f97\u5206\u4e0e\u4e4b\u524d\u6709\u4e9b\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6a21\u578b\u4e2d\u5b58\u5728\u968f\u673a\u6027\u3002\u6211\u4eec\u5c06\u5728\u540e\u9762\u7684\u7ae0\u8282\u4e2d\u8ba8\u8bba\u5982\u4f55\u5904\u7406\u968f\u673a\u6027\u3002 \u73b0\u5728\uff0c\u5982\u679c\u4f60\u613f\u610f\uff0c\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a shell \u811a\u672c \uff0c\u9488\u5bf9\u4e0d\u540c\u7684\u6298\u53e0\u4f7f\u7528\u4e0d\u540c\u7684\u547d\u4ee4\uff0c\u7136\u540e\u4e00\u8d77\u8fd0\u884c\uff0c\u5982\u4e0b\u56fe\u6240\u793a\u3002 python train . py -- fold 0 python train . py -- fold 1 python train . py -- fold 2 python train . py -- fold 3 python train . py -- fold 4 \u60a8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u5b83\u3002 \u276f sh run . sh Fold = 0 , Accuracy = 0.8675 Fold = 1 , Accuracy = 0.8693333333333333 Fold = 2 , Accuracy = 0.8683333333333333 Fold = 3 , Accuracy = 0.8704166666666666 Fold = 4 , Accuracy = 0.8685 \u6211\u4eec\u73b0\u5728\u5df2\u7ecf\u53d6\u5f97\u4e86\u4e00\u4e9b\u8fdb\u5c55\uff0c\u4f46\u5982\u679c\u6211\u4eec\u770b\u4e00\u4e0b\u6211\u4eec\u7684\u8bad\u7ec3\u811a\u672c\uff0c\u6211\u4eec\u4ecd\u7136\u53d7\u5230\u4e00\u4e9b\u4e1c\u897f\u7684\u9650\u5236\uff0c\u4f8b\u5982\u6a21\u578b\u3002\u6a21\u578b\u662f\u786c\u7f16\u7801\u5728\u8bad\u7ec3\u811a\u672c\u4e2d\u7684\uff0c\u53ea\u6709\u4fee\u6539\u811a\u672c\u624d\u80fd\u6539\u53d8\u5b83\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u5c06\u521b\u5efa\u4e00\u4e2a\u65b0\u7684 python \u811a\u672c\uff0c\u540d\u4e3a model_dispatcher.py \u3002model_dispatcher.py\uff0c\u987e\u540d\u601d\u4e49\uff0c\u5c06\u8c03\u5ea6\u6211\u4eec\u7684\u6a21\u578b\u5230\u8bad\u7ec3\u811a\u672c\u4e2d\u3002 from sklearn import tree models = { # \u4ee5gini\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_gini\" : tree . 
DecisionTreeClassifier ( criterion = \"gini\" ), # \u4ee5entropy\u7cfb\u6570\u5ea6\u91cf\u7684\u51b3\u7b56\u6811 \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), } model_dispatcher.py \u4ece scikit-learn \u4e2d\u5bfc\u5165\u4e86 tree\uff0c\u5e76\u5b9a\u4e49\u4e86\u4e00\u4e2a\u5b57\u5178\uff0c\u5176\u4e2d\u952e\u662f\u6a21\u578b\u7684\u540d\u79f0\uff0c\u503c\u662f\u6a21\u578b\u672c\u8eab\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5b9a\u4e49\u4e86\u4e24\u79cd\u4e0d\u540c\u7684\u51b3\u7b56\u6811\uff0c\u4e00\u79cd\u4f7f\u7528\u57fa\u5c3c\u6807\u51c6\uff0c\u53e6\u4e00\u79cd\u4f7f\u7528\u71b5\u6807\u51c6\u3002\u8981\u4f7f\u7528 py\uff0c\u6211\u4eec\u9700\u8981\u5bf9\u8bad\u7ec3\u811a\u672c\u505a\u4e00\u4e9b\u4fee\u6539\u3002 import argparse import os import joblib import pandas as pd from sklearn import metrics import config import model_dispatcher def run ( fold , model ): df = pd . read_csv ( config . TRAINING_FILE ) df_train = df [ df . kfold != fold ] . reset_index ( drop = True ) df_valid = df [ df . kfold == fold ] . reset_index ( drop = True ) x_train = df_train . drop ( \"label\" , axis = 1 ) . values y_train = df_train . label . values x_valid = df_valid . drop ( \"label\" , axis = 1 ) . values y_valid = df_valid . label . values # \u6839\u636emodel\u53c2\u6570\u9009\u62e9\u6a21\u578b clf = model_dispatcher . models [ model ] clf . fit ( x_train , y_train ) preds = clf . predict ( x_valid ) accuracy = metrics . accuracy_score ( y_valid , preds ) print ( f \"Fold= { fold } , Accuracy= { accuracy } \" ) joblib . dump ( clf , os . path . join ( config . MODEL_OUTPUT , f \"dt_ { fold } .bin\" )) if __name__ == \"__main__\" : parser = argparse . ArgumentParser () # fold\u53c2\u6570 parser . add_argument ( \"--fold\" , type = int ) # model\u53c2\u6570 parser . add_argument ( \"--model\" , type = str ) args = parser . parse_args () run ( fold = args . fold , model = args . 
model ) train.py \u6709\u51e0\u5904\u91cd\u5927\u6539\u52a8\uff1a - \u5bfc\u5165 model_dispatcher - \u4e3a ArgumentParser \u6dfb\u52a0 --model \u53c2\u6570 - \u4e3a run() \u51fd\u6570\u6dfb\u52a0model\u53c2\u6570 - \u4f7f\u7528\u8c03\u5ea6\u7a0b\u5e8f\u83b7\u53d6\u6307\u5b9a\u540d\u79f0\u7684\u6a21\u578b \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4ee5\u4e0b\u547d\u4ee4\u8fd0\u884c\u811a\u672c\uff1a \u276f python train . py -- fold 0 -- model decision_tree_gini Fold = 0 , Accuracy = 0.8665833333333334 \u6216\u6267\u884c\u4ee5\u4e0b\u547d\u4ee4 \u276f python train . py -- fold 0 -- model decision_tree_entropy Fold = 0 , Accuracy = 0.8705833333333334 \u73b0\u5728\uff0c\u5982\u679c\u8981\u6dfb\u52a0\u65b0\u6a21\u578b\uff0c\u53ea\u9700\u4fee\u6539 model_dispatcher.py \u3002\u8ba9\u6211\u4eec\u5c1d\u8bd5\u6dfb\u52a0\u968f\u673a\u68ee\u6797\uff0c\u770b\u770b\u51c6\u786e\u7387\u4f1a\u6709\u4ec0\u4e48\u53d8\u5316\u3002 from sklearn import ensemble from sklearn import tree models = { \"decision_tree_gini\" : tree . DecisionTreeClassifier ( criterion = \"gini\" ), \"decision_tree_entropy\" : tree . DecisionTreeClassifier ( criterion = \"entropy\" ), # \u968f\u673a\u68ee\u6797\u6a21\u578b \"rf\" : ensemble . RandomForestClassifier (), } \u8ba9\u6211\u4eec\u8fd0\u884c\u8fd9\u6bb5\u4ee3\u7801\u3002 \u276f python train . py -- fold 0 -- model rf Fold = 0 , Accuracy = 0.9670833333333333 \u54c7\uff0c\u4e00\u4e2a\u7b80\u5355\u7684\u6539\u52a8\u5c31\u80fd\u8ba9\u5206\u6570\u6709\u5982\u6b64\u5927\u7684\u63d0\u5347\uff01\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u4f7f\u7528 run.sh \u811a\u672c\u8fd0\u884c 5 \u4e2a\u6298\u53e0\uff01 python train . py -- fold 0 -- model rf python train . py -- fold 1 -- model rf python train . py -- fold 2 -- model rf python train . py -- fold 3 -- model rf python train . py -- fold 4 -- model rf \u5f97\u5206\u60c5\u51b5\u5982\u4e0b \u276f sh run . 
sh Fold = 0 , Accuracy = 0.9674166666666667 Fold = 1 , Accuracy = 0.9698333333333333 Fold = 2 , Accuracy = 0.96575 Fold = 3 , Accuracy = 0.9684166666666667 Fold = 4 , Accuracy = 0.9666666666666667 MNIST \u51e0\u4e4e\u662f\u6bcf\u672c\u4e66\u548c\u6bcf\u7bc7\u535a\u5ba2\u90fd\u4f1a\u8ba8\u8bba\u7684\u95ee\u9898\u3002\u4f46\u6211\u8bd5\u56fe\u5c06\u8fd9\u4e2a\u95ee\u9898\u8f6c\u6362\u5f97\u66f4\u6709\u8da3\uff0c\u5e76\u5411\u4f60\u5c55\u793a\u5982\u4f55\u4e3a\u4f60\u6b63\u5728\u505a\u7684\u6216\u8ba1\u5212\u5728\u4e0d\u4e45\u7684\u5c06\u6765\u505a\u7684\u51e0\u4e4e\u6240\u6709\u673a\u5668\u5b66\u4e60\u9879\u76ee\u7f16\u5199\u4e00\u4e2a\u57fa\u672c\u6846\u67b6\u3002\u6709\u8bb8\u591a\u4e0d\u540c\u7684\u65b9\u6cd5\u53ef\u4ee5\u6539\u8fdb\u8fd9\u4e2a MNIST \u6a21\u578b\u548c\u8fd9\u4e2a\u6846\u67b6\uff0c\u6211\u4eec\u5c06\u5728\u4ee5\u540e\u7684\u7ae0\u8282\u4e2d\u770b\u5230\u3002 \u6211\u4f7f\u7528\u4e86\u4e00\u4e9b\u811a\u672c\uff0c\u5982 model_dispatcher.py \u548c config.py \uff0c\u5e76\u5c06\u5b83\u4eec\u5bfc\u5165\u5230\u6211\u7684\u8bad\u7ec3\u811a\u672c\u4e2d\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u6ca1\u6709\u5bfc\u5165 \uff0c\u4f60\u4e5f\u4e0d\u5e94\u8be5\u5bfc\u5165\u3002\u5982\u679c\u6211\u5bfc\u5165\u4e86 
\uff0c\u4f60\u5c31\u6c38\u8fdc\u4e0d\u4f1a\u77e5\u9053\u6a21\u578b\u5b57\u5178\u662f\u4ece\u54ea\u91cc\u6765\u7684\u3002\u7f16\u5199\u4f18\u79c0\u3001\u6613\u61c2\u7684\u4ee3\u7801\u662f\u4e00\u4e2a\u4eba\u5fc5\u987b\u5177\u5907\u7684\u57fa\u672c\u7d20\u8d28\uff0c\u4f46\u8bb8\u591a\u6570\u636e\u79d1\u5b66\u5bb6\u5374\u5ffd\u89c6\u4e86\u8fd9\u4e00\u70b9\u3002\u5982\u679c\u4f60\u6240\u505a\u7684\u9879\u76ee\u80fd\u8ba9\u5176\u4ed6\u4eba\u7406\u89e3\u5e76\u4f7f\u7528\uff0c\u800c\u65e0\u9700\u54a8\u8be2\u4f60\u7684\u610f\u89c1\uff0c\u90a3\u4e48\u4f60\u5c31\u8282\u7701\u4e86\u4ed6\u4eec\u7684\u65f6\u95f4\u548c\u81ea\u5df1\u7684\u65f6\u95f4\uff0c\u53ef\u4ee5\u5c06\u8fd9\u4e9b\u65f6\u95f4\u6295\u5165\u5230\u6539\u8fdb\u4f60\u7684\u9879\u76ee\u6216\u5f00\u53d1\u65b0\u9879\u76ee\u4e2d\u53bb\u3002","title":"\u7ec4\u7ec7\u673a\u5668\u5b66\u4e60\u9879\u76ee"},{"location":"%E8%AF%84%E4%BC%B0%E6%8C%87%E6%A0%87/","text":"\u8bc4\u4f30\u6307\u6807 \u8bf4\u5230\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u4f60\u4f1a\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\u9047\u5230\u5f88\u591a\u4e0d\u540c\u7c7b\u578b\u7684\u6307\u6807\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u751a\u81f3\u4f1a\u6839\u636e\u4e1a\u52a1\u95ee\u9898\u521b\u5efa\u5ea6\u91cf\u6807\u51c6\u3002\u9010\u4e00\u4ecb\u7ecd\u548c\u89e3\u91ca\u6bcf\u4e00\u79cd\u5ea6\u91cf\u7c7b\u578b\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002\u76f8\u53cd\uff0c\u6211\u4eec\u5c06\u4ecb\u7ecd\u4e00\u4e9b\u6700\u5e38\u89c1\u7684\u5ea6\u91cf\u6807\u51c6\uff0c\u4f9b\u4f60\u5728\u6700\u521d\u7684\u51e0\u4e2a\u9879\u76ee\u4e2d\u4f7f\u7528\u3002 
\u5728\u672c\u4e66\u7684\u5f00\u5934\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u76d1\u7763\u5b66\u4e60\u548c\u975e\u76d1\u7763\u5b66\u4e60\u3002\u867d\u7136\u65e0\u76d1\u7763\u5b66\u4e60\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e9b\u6307\u6807\uff0c\u4f46\u6211\u4eec\u5c06\u53ea\u5173\u6ce8\u6709\u76d1\u7763\u5b66\u4e60\u3002\u8fd9\u662f\u56e0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u591a\uff0c\u800c\u4e14\u5bf9\u65e0\u76d1\u7763\u65b9\u6cd5\u7684\u8bc4\u4f30\u76f8\u5f53\u4e3b\u89c2\u3002 \u5982\u679c\u6211\u4eec\u8c08\u8bba\u5206\u7c7b\u95ee\u9898\uff0c\u6700\u5e38\u7528\u7684\u6307\u6807\u662f\uff1a \u51c6\u786e\u7387\uff08Accuracy\uff09 \u7cbe\u786e\u7387\uff08P\uff09 \u53ec\u56de\u7387\uff08R\uff09 F1 \u5206\u6570\uff08F1\uff09 AUC\uff08AUC\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u8bf4\u5230\u56de\u5f52\uff0c\u6700\u5e38\u7528\u7684\u8bc4\u4ef7\u6307\u6807\u662f \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee \uff08MAE\uff09 \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u5747\u65b9\u6839\u8bef\u5dee \uff08RMSE\uff09 \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \uff08RMSLE\uff09 \u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee \uff08MPE\uff09 \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee \uff08MAPE\uff09 R2 \u4e86\u89e3\u4e0a\u8ff0\u6307\u6807\u7684\u5de5\u4f5c\u539f\u7406\u5e76\u4e0d\u662f\u6211\u4eec\u5fc5\u987b\u4e86\u89e3\u7684\u552f\u4e00\u4e8b\u60c5\u3002\u6211\u4eec\u8fd8\u5fc5\u987b\u77e5\u9053\u4f55\u65f6\u4f7f\u7528\u54ea\u4e9b\u6307\u6807\uff0c\u800c\u8fd9\u53d6\u51b3\u4e8e\u4f60\u6709\u4ec0\u4e48\u6837\u7684\u6570\u636e\u548c\u76ee\u6807\u3002\u6211\u8ba4\u4e3a\u8fd9\u4e0e\u76ee\u6807\u6709\u5173\uff0c\u800c\u4e0e\u6570\u636e\u65e0\u5173\u3002 
\u8981\u8fdb\u4e00\u6b65\u4e86\u89e3\u8fd9\u4e9b\u6307\u6807\uff0c\u8ba9\u6211\u4eec\u4ece\u4e00\u4e2a\u7b80\u5355\u7684\u95ee\u9898\u5f00\u59cb\u3002\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a \u4e8c\u5143\u5206\u7c7b \u95ee\u9898\uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u76ee\u6807\u7684\u95ee\u9898\uff0c\u5047\u8bbe\u8fd9\u662f\u4e00\u4e2a\u80f8\u90e8 X \u5149\u56fe\u50cf\u5206\u7c7b\u95ee\u9898\u3002\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6ca1\u6709\u95ee\u9898\uff0c\u800c\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6709\u80ba\u584c\u9677\uff0c\u4e5f\u5c31\u662f\u6240\u8c13\u7684\u6c14\u80f8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u5206\u7c7b\u5668\uff0c\u5728\u7ed9\u5b9a\u80f8\u90e8 X \u5149\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\uff0c\u5b83\u80fd\u68c0\u6d4b\u51fa\u56fe\u50cf\u662f\u5426\u6709\u6c14\u80f8\u3002 \u56fe 1\uff1a\u6c14\u80f8\u80ba\u90e8\u56fe\u50cf \u6211\u4eec\u8fd8\u5047\u8bbe\u6709\u76f8\u540c\u6570\u91cf\u7684\u6c14\u80f8\u548c\u975e\u6c14\u80f8\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u6bd4\u5982\u5404 100 \u5f20\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709 100 \u5f20\u9633\u6027\u6837\u672c\u548c 100 \u5f20\u9634\u6027\u6837\u672c\uff0c\u5171\u8ba1 200 \u5f20\u56fe\u50cf\u3002 \u7b2c\u4e00\u6b65\u662f\u5c06\u4e0a\u8ff0\u6570\u636e\u5206\u4e3a\u4e24\u7ec4\uff0c\u6bcf\u7ec4 100 \u5f20\u56fe\u50cf\uff0c\u5373\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u8fd9\u4e24\u4e2a\u96c6\u5408\u4e2d\uff0c\u6211\u4eec\u90fd\u6709 50 \u4e2a\u6b63\u6837\u672c\u548c 50 \u4e2a\u8d1f\u6837\u672c\u3002 \u5728\u4e8c\u5143\u5206\u7c7b\u6307\u6807\u4e2d\uff0c\u5f53\u6b63\u8d1f\u6837\u672c\u6570\u91cf\u76f8\u7b49\u65f6\uff0c\u6211\u4eec\u901a\u5e38\u4f7f\u7528\u51c6\u786e\u7387\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002 \u51c6\u786e\u7387 
\uff1a\u8fd9\u662f\u673a\u5668\u5b66\u4e60\u4e2d\u6700\u76f4\u63a5\u7684\u6307\u6807\u4e4b\u4e00\u3002\u5b83\u5b9a\u4e49\u4e86\u6a21\u578b\u7684\u51c6\u786e\u5ea6\u3002\u5bf9\u4e8e\u4e0a\u8ff0\u95ee\u9898\uff0c\u5982\u679c\u4f60\u5efa\u7acb\u7684\u6a21\u578b\u80fd\u51c6\u786e\u5206\u7c7b 90 \u5f20\u56fe\u7247\uff0c\u90a3\u4e48\u4f60\u7684\u51c6\u786e\u7387\u5c31\u662f 90% \u6216 0.90\u3002\u5982\u679c\u53ea\u6709 83 \u5e45\u56fe\u50cf\u88ab\u6b63\u786e\u5206\u7c7b\uff0c\u90a3\u4e48\u6a21\u578b\u7684\u51c6\u786e\u7387\u5c31\u662f 83% \u6216 0.83\u3002 \u8ba1\u7b97\u51c6\u786e\u7387\u7684 Python \u4ee3\u7801\u4e5f\u975e\u5e38\u7b80\u5355\u3002 def accuracy ( y_true , y_pred ): # \u4e3a\u6b63\u786e\u9884\u6d4b\u6570\u521d\u59cb\u5316\u4e00\u4e2a\u7b80\u5355\u8ba1\u6570\u5668 correct_counter = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): if yt == yp : # \u5982\u679c\u9884\u6d4b\u6807\u7b7e\u4e0e\u771f\u5b9e\u6807\u7b7e\u76f8\u540c\uff0c\u5219\u589e\u52a0\u8ba1\u6570\u5668 correct_counter += 1 # \u8fd4\u56de\u6b63\u786e\u7387\uff0c\u6b63\u786e\u6807\u7b7e\u6570/\u603b\u6807\u7b7e\u6570 return correct_counter / len ( y_true ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u8ba1\u7b97\u51c6\u786e\u7387\u3002 In [ X ]: from sklearn import metrics ... : l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] ... : metrics . 
accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u73b0\u5728\uff0c\u5047\u8bbe\u6211\u4eec\u628a\u6570\u636e\u96c6\u7a0d\u5fae\u6539\u52a8\u4e00\u4e0b\uff0c\u6709 180 \u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u53ea\u6709 20 \u5f20\u6709\u6c14\u80f8\u3002\u5373\u4f7f\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4e5f\u8981\u521b\u5efa\u6b63\u8d1f\uff08\u6c14\u80f8\u4e0e\u975e\u6c14\u80f8\uff09\u76ee\u6807\u6bd4\u4f8b\u76f8\u540c\u7684\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u6bcf\u4e00\u7ec4\u4e2d\uff0c\u6211\u4eec\u6709 90 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u5982\u679c\u8bf4\u9a8c\u8bc1\u96c6\u4e2d\u7684\u6240\u6709\u56fe\u50cf\u90fd\u662f\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u90a3\u4e48\u60a8\u7684\u51c6\u786e\u7387\u4f1a\u662f\u591a\u5c11\u5462\uff1f\u8ba9\u6211\u4eec\u6765\u770b\u770b\uff1b\u60a8\u5bf9 90% \u7684\u56fe\u50cf\u8fdb\u884c\u4e86\u6b63\u786e\u5206\u7c7b\u3002\u56e0\u6b64\uff0c\u60a8\u7684\u51c6\u786e\u7387\u662f 90%\u3002 \u4f46\u8bf7\u518d\u770b\u4e00\u904d\u3002 \u4f60\u751a\u81f3\u6ca1\u6709\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5c31\u5f97\u5230\u4e86 90% 
\u7684\u51c6\u786e\u7387\u3002\u8fd9\u4f3c\u4e4e\u6709\u70b9\u6ca1\u7528\u3002\u5982\u679c\u6211\u4eec\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u6570\u636e\u96c6\u662f\u504f\u659c\u7684\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u6bd4\u53e6\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u591a\u5f88\u591a\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f7f\u7528\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u56e0\u4e3a\u5b83\u4e0d\u80fd\u4ee3\u8868\u6570\u636e\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u80fd\u4f1a\u83b7\u5f97\u5f88\u9ad8\u7684\u51c6\u786e\u7387\uff0c\u4f46\u60a8\u7684\u6a21\u578b\u5728\u5b9e\u9645\u6837\u672c\u4e2d\u7684\u8868\u73b0\u53ef\u80fd\u5e76\u4e0d\u7406\u60f3\uff0c\u800c\u4e14\u60a8\u4e5f\u65e0\u6cd5\u5411\u7ecf\u7406\u89e3\u91ca\u539f\u56e0\u3002 \u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b \u7cbe\u786e\u7387 \u7b49\u5176\u4ed6\u6307\u6807\u3002 \u5728\u5b66\u4e60\u7cbe\u786e\u7387\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e86\u89e3\u4e00\u4e9b\u672f\u8bed\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5047\u8bbe\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e3a\u6b63\u7c7b (1)\uff0c\u6ca1\u6709\u6c14\u80f8\u7684\u4e3a\u8d1f\u7c7b (0)\u3002 \u771f\u9633\u6027 \uff08TP\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6709\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9633\u6027\u3002 \u771f\u9634\u6027 \uff08TN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u800c\u5b9e\u9645\u76ee\u6807\u663e\u793a\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9634\u6027\u3002 
\u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u6b63\u786e\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9633\u6027\uff1b\u5982\u679c\u60a8\u7684\u6a21\u578b\u51c6\u786e\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9634\u6027\u3002 \u5047\u9633\u6027 \uff08FP\uff09 \uff1a\u7ed9\u5b9a\u4e00\u5f20\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u975e\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9633\u6027\u3002 \u5047\u9634\u6027 \uff08FN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u975e\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9634\u6027\u3002 \u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\uff0c\u90a3\u4e48\u5b83\u5c31\u662f\u5047\u9633\u6027\u3002\u5982\u679c\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5219\u662f\u5047\u9634\u6027\u3002 \u8ba9\u6211\u4eec\u9010\u4e00\u770b\u770b\u8fd9\u4e9b\u5b9e\u73b0\u3002 def true_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u8ba1\u6570\u5668 tp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 1 : tp += 1 # \u8fd4\u56de\u771f\u9633\u6027\u6837\u672c\u6570 return tp def true_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9634\u6027\u6837\u672c\u8ba1\u6570\u5668 tn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # 
\u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 0 : tn += 1 # \u8fd4\u56de\u771f\u9634\u6027\u6837\u672c\u6570 return tn def false_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9633\u6027\u8ba1\u6570\u5668 fp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 1 : fp += 1 # \u8fd4\u56de\u5047\u9633\u6027\u6837\u672c\u6570 return fp def false_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9634\u6027\u8ba1\u6570\u5668 fn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 0 : fn += 1 # \u8fd4\u56de\u5047\u9634\u6027\u6570 return fn \u6211\u5728\u8fd9\u91cc\u5b9e\u73b0\u8fd9\u4e9b\u529f\u80fd\u7684\u65b9\u6cd5\u975e\u5e38\u7b80\u5355\uff0c\u800c\u4e14\u53ea\u9002\u7528\u4e8e\u4e8c\u5143\u5206\u7c7b\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u8fd9\u4e9b\u51fd\u6570\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... 
: l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: true_positive ( l1 , l2 ) Out [ X ]: 2 In [ X ]: false_positive ( l1 , l2 ) Out [ X ]: 1 In [ X ]: false_negative ( l1 , l2 ) Out [ X ]: 2 In [ X ]: true_negative ( l1 , l2 ) Out [ X ]: 3 \u5982\u679c\u6211\u4eec\u5fc5\u987b\u7528\u4e0a\u8ff0\u672f\u8bed\u6765\u5b9a\u4e49\u7cbe\u786e\u7387\uff0c\u6211\u4eec\u53ef\u4ee5\u5199\u4e3a\uff1a \\[ Accuracy Score = (TP + TN)/(TP + TN + FP +FN) \\] \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5728 python \u4e2d\u4f7f\u7528 TP\u3001TN\u3001FP \u548c FN \u5feb\u901f\u5b9e\u73b0\u51c6\u786e\u5ea6\u5f97\u5206\u3002\u6211\u4eec\u5c06\u5176\u79f0\u4e3a accuracy_v2\u3002 def accuracy_v2 ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u51c6\u786e\u7387 accuracy_score = ( tp + tn ) / ( tp + tn + fp + fn ) return accuracy_score \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0e\u4e4b\u524d\u7684\u5b9e\u73b0\u548c scikit-learn \u7248\u672c\u8fdb\u884c\u6bd4\u8f83\uff0c\u5feb\u901f\u68c0\u67e5\u8be5\u51fd\u6570\u7684\u6b63\u786e\u6027\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: accuracy ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: accuracy_v2 ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: metrics . 
accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u8bf7\u6ce8\u610f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0cmetrics.accuracy_score \u6765\u81ea scikit-learn\u3002 \u5f88\u597d\u3002\u6240\u6709\u503c\u90fd\u5339\u914d\u3002\u8fd9\u8bf4\u660e\u6211\u4eec\u5728\u5b9e\u73b0\u8fc7\u7a0b\u4e2d\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u5176\u4ed6\u91cd\u8981\u6307\u6807\u3002 \u9996\u5148\u662f\u7cbe\u786e\u7387\u3002\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u662f \\[ Precision = TP/(TP + FP) \\] \u5047\u8bbe\u6211\u4eec\u5728\u65b0\u7684\u504f\u659c\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u4e86\u4e00\u4e2a\u65b0\u6a21\u578b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 90 \u5f20\u56fe\u50cf\u4e2d\u7684 80 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u56fe\u50cf\u4e2d\u7684 8 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6210\u529f\u8bc6\u522b\u4e86 100 \u5f20\u56fe\u50cf\u4e2d\u7684 88 \u5f20\u3002\u56e0\u6b64\uff0c\u51c6\u786e\u7387\u4e3a 0.88 \u6216 88%\u3002 \u4f46\u662f\uff0c\u5728\u8fd9 100 \u5f20\u6837\u672c\u4e2d\uff0c\u6709 10 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u6c14\u80f8\uff0c2 \u5f20\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u975e\u6c14\u80f8\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 TP : 8 TN: 80 FP: 10 FN: 2 \u7cbe\u786e\u7387\u4e3a 8 / (8 + 10) = 0.444\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u5728\u8bc6\u522b\u9633\u6027\u6837\u672c\uff08\u6c14\u80f8\uff09\u65f6\u6709 44.4% \u7684\u6b63\u786e\u7387\u3002 \u73b0\u5728\uff0c\u65e2\u7136\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86 TP\u3001TN\u3001FP \u548c FN\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5728 python \u4e2d\u5b9e\u73b0\u7cbe\u786e\u7387\u4e86\u3002 def precision ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true 
, y_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8ba9\u6211\u4eec\u8bd5\u8bd5\u8fd9\u79cd\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u65b9\u5f0f\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: precision ( l1 , l2 ) Out [ X ]: 0.6666666666666666 \u8fd9\u4f3c\u4e4e\u6ca1\u6709\u95ee\u9898\u3002 \u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u6765\u770b \u53ec\u56de\u7387 \u3002\u53ec\u56de\u7387\u7684\u5b9a\u4e49\u662f\uff1a \\[ Recall = TP/(TP + FN) \\] \u5728\u4e0a\u8ff0\u60c5\u51b5\u4e0b\uff0c\u53ec\u56de\u7387\u4e3a 8 / (8 + 2) = 0.80\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 80% \u7684\u9633\u6027\u6837\u672c\u3002 def recall ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u53ec\u56de\u7387 recall = tp / ( tp + fn ) return recall \u5c31\u6211\u4eec\u7684\u4e24\u4e2a\u5c0f\u5217\u8868\u800c\u8a00\uff0c\u53ec\u56de\u7387\u5e94\u8be5\u662f 0.5\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... 
: l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: recall ( l1 , l2 ) Out [ X ]: 0.5 \u8fd9\u4e0e\u6211\u4eec\u7684\u8ba1\u7b97\u503c\u76f8\u7b26\uff01 \u5bf9\u4e8e\u4e00\u4e2a \"\u597d \"\u6a21\u578b\u6765\u8bf4\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u503c\u90fd\u5e94\u8be5\u5f88\u9ad8\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0c\u53ec\u56de\u503c\u76f8\u5f53\u9ad8\u3002\u4f46\u662f\uff0c\u7cbe\u786e\u7387\u5374\u5f88\u4f4e\uff01\u6211\u4eec\u7684\u6a21\u578b\u4ea7\u751f\u4e86\u5927\u91cf\u7684\u8bef\u62a5\uff0c\u4f46\u8bef\u62a5\u8f83\u5c11\u3002\u5728\u8fd9\u7c7b\u95ee\u9898\u4e2d\uff0c\u5047\u9634\u6027\u8f83\u5c11\u662f\u597d\u4e8b\uff0c\u56e0\u4e3a\u4f60\u4e0d\u60f3\u5728\u75c5\u4eba\u6709\u6c14\u80f8\u7684\u60c5\u51b5\u4e0b\u5374\u8bf4\u4ed6\u4eec\u6ca1\u6709\u6c14\u80f8\u3002\u8fd9\u6837\u505a\u4f1a\u9020\u6210\u66f4\u5927\u7684\u4f24\u5bb3\u3002\u4f46\u6211\u4eec\u4e5f\u6709\u5f88\u591a\u5047\u9633\u6027\u7ed3\u679c\uff0c\u8fd9\u4e5f\u4e0d\u662f\u597d\u4e8b\u3002 \u5927\u591a\u6570\u6a21\u578b\u90fd\u4f1a\u9884\u6d4b\u4e00\u4e2a\u6982\u7387\uff0c\u5f53\u6211\u4eec\u9884\u6d4b\u65f6\uff0c\u901a\u5e38\u4f1a\u5c06\u8fd9\u4e2a\u9608\u503c\u9009\u4e3a 0.5\u3002\u8fd9\u4e2a\u9608\u503c\u5e76\u4e0d\u603b\u662f\u7406\u60f3\u7684\uff0c\u6839\u636e\u8fd9\u4e2a\u9608\u503c\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u503c\u53ef\u80fd\u4f1a\u53d1\u751f\u5f88\u5927\u7684\u53d8\u5316\u3002\u5982\u679c\u6211\u4eec\u9009\u62e9\u7684\u6bcf\u4e2a\u9608\u503c\u90fd\u80fd\u8ba1\u7b97\u51fa\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u53ef\u4ee5\u5728\u8fd9\u4e9b\u503c\u4e4b\u95f4\u7ed8\u5236\u51fa\u66f2\u7ebf\u56fe\u3002\u8fd9\u5e45\u56fe\u6216\u66f2\u7ebf\u88ab\u79f0\u4e3a \"\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\"\u3002 \u5728\u7814\u7a76\u7cbe\u786e\u7387-\u8c03\u7528\u66f2\u7ebf\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5047\u8bbe\u6709\u4e24\u4e2a\u5217\u8868\u3002 In [ X ]: 
y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0.02638412 , 0.11114267 , 0.31620708 , ... : 0.0490937 , 0.0191491 , 0.17554844 , ... : 0.15952202 , 0.03819563 , 0.11639273 , ... : 0.079377 , 0.08584789 , 0.39095342 , ... : 0.27259048 , 0.03447096 , 0.04644807 , ... : 0.03543574 , 0.18521942 , 0.05934905 , ... : 0.61977213 , 0.33056815 ] \u56e0\u6b64\uff0cy_true \u662f\u6211\u4eec\u7684\u76ee\u6807\u503c\uff0c\u800c y_pred \u662f\u6837\u672c\u88ab\u8d4b\u503c\u4e3a 1 \u7684\u6982\u7387\u503c\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u6211\u4eec\u8981\u770b\u7684\u662f\u9884\u6d4b\u4e2d\u7684\u6982\u7387\uff0c\u800c\u4e0d\u662f\u9884\u6d4b\u503c\uff08\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u9884\u6d4b\u503c\u7684\u8ba1\u7b97\u9608\u503c\u4e3a 0.5\uff09\u3002 precisions = [] recalls = [] thresholds = [ 0.0490937 , 0.05934905 , 0.079377 , 0.08584789 , 0.11114267 , 0.11639273 , 0.15952202 , 0.17554844 , 0.18521942 , 0.27259048 , 0.31620708 , 0.33056815 , 0.39095342 , 0.61977213 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for i in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_prediction = [ 1 if x >= i else 0 for x in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , temp_prediction ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , temp_prediction ) # \u52a0\u5165\u7cbe\u786e\u7387\u5217\u8868 precisions . append ( p ) # \u52a0\u5165\u53ec\u56de\u7387\u5217\u8868 recalls . append ( r ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7ed8\u5236\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 7 , 7 )) # x\u8f74\u4e3a\u53ec\u56de\u7387\uff0cy\u8f74\u4e3a\u7cbe\u786e\u7387 plt . plot ( recalls , precisions ) # \u6dfb\u52a0x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a15 plt . 
xlabel ( 'Recall' , fontsize = 15 ) # \u6dfb\u52a0y\u8f74\u6807\u7b7e\uff0c\u5b57\u6761\u5927\u5c0f\u4e3a15 plt . ylabel ( 'Precision' , fontsize = 15 ) \u56fe 2 \u663e\u793a\u4e86\u6211\u4eec\u901a\u8fc7\u8fd9\u79cd\u65b9\u6cd5\u5f97\u5230\u7684\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 \u56fe 2\uff1a\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u8fd9\u6761 \u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u4e0e\u60a8\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230\u7684\u66f2\u7ebf\u622a\u7136\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6211\u4eec\u53ea\u6709 20 \u4e2a\u6837\u672c\uff0c\u5176\u4e2d\u53ea\u6709 3 \u4e2a\u662f\u9633\u6027\u6837\u672c\u3002\u4f46\u8fd9\u6ca1\u4ec0\u4e48\u597d\u62c5\u5fc3\u7684\u3002\u8fd9\u8fd8\u662f\u90a3\u6761\u7cbe\u786e\u7387-\u53ec\u56de\u66f2\u7ebf\u3002 \u4f60\u4f1a\u53d1\u73b0\uff0c\u9009\u62e9\u4e00\u4e2a\u65e2\u80fd\u63d0\u4f9b\u826f\u597d\u7cbe\u786e\u7387\u53c8\u80fd\u63d0\u4f9b\u53ec\u56de\u503c\u7684\u9608\u503c\u662f\u5f88\u6709\u6311\u6218\u6027\u7684\u3002\u5982\u679c\u9608\u503c\u8fc7\u9ad8\uff0c\u771f\u9633\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u51cf\u5c11\uff0c\u800c\u5047\u9634\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u589e\u52a0\u3002\u8fd9\u4f1a\u964d\u4f4e\u53ec\u56de\u7387\uff0c\u4f46\u7cbe\u786e\u7387\u5f97\u5206\u4f1a\u5f88\u9ad8\u3002\u5982\u679c\u5c06\u9608\u503c\u964d\u5f97\u592a\u4f4e\uff0c\u5219\u8bef\u62a5\u4f1a\u5927\u91cf\u589e\u52a0\uff0c\u7cbe\u786e\u7387\u4e5f\u4f1a\u964d\u4f4e\u3002 \u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u8d8a\u63a5\u8fd1 1 \u8d8a\u597d\u3002 F1 \u5206\u6570\u662f\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7efc\u5408\u6307\u6807\u3002\u5b83\u88ab\u5b9a\u4e49\u4e3a\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7b80\u5355\u52a0\u6743\u5e73\u5747\u503c\uff08\u8c03\u548c\u5e73\u5747\u503c\uff09\u3002\u5982\u679c\u6211\u4eec\u7528 P \u8868\u793a\u7cbe\u786e\u7387\uff0c\u7528 R 
\u8868\u793a\u53ec\u56de\u7387\uff0c\u90a3\u4e48 F1 \u5206\u6570\u53ef\u4ee5\u8868\u793a\u4e3a\uff1a \\[ F1 = 2PR/(P + R) \\] \u6839\u636e TP\u3001FP \u548c FN\uff0c\u7a0d\u52a0\u6570\u5b66\u8ba1\u7b97\u5c31\u80fd\u5f97\u51fa\u4ee5\u4e0b F1 \u7b49\u5f0f\uff1a \\[ F1 = 2TP/(2TP + FP + FN) \\] Python \u5b9e\u73b0\u5f88\u7b80\u5355\uff0c\u56e0\u4e3a\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86\u8fd9\u4e9b def f1 ( y_true , y_pred ): # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , y_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , y_pred ) # \u8ba1\u7b97f1\u503c score = 2 * p * r / ( p + r ) return score \u8ba9\u6211\u4eec\u770b\u770b\u5176\u7ed3\u679c\uff0c\u5e76\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\u3002 In [ X ]: y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0 , 0 , 1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: f1 ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u901a\u8fc7 scikit learn\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u76f8\u540c\u7684\u5217\u8868\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . 
f1_score ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u4e0e\u5176\u5355\u72ec\u770b\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u60a8\u8fd8\u53ef\u4ee5\u53ea\u770b F1 \u5206\u6570\u3002\u4e0e\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c\u51c6\u786e\u5ea6\u4e00\u6837\uff0cF1 \u5206\u6570\u7684\u8303\u56f4\u4e5f\u662f\u4ece 0 \u5230 1\uff0c\u5b8c\u7f8e\u9884\u6d4b\u6a21\u578b\u7684 F1 \u5206\u6570\u4e3a 1\u3002 \u6b64\u5916\uff0c\u6211\u4eec\u8fd8\u5e94\u8be5\u4e86\u89e3\u5176\u4ed6\u4e00\u4e9b\u5173\u952e\u672f\u8bed\u3002 \u7b2c\u4e00\u4e2a\u672f\u8bed\u662f TPR \u6216\u771f\u9633\u6027\u7387\uff08True Positive Rate\uff09\uff0c\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\u3002 \\[ TPR = TP/(TP + FN) \\] \u5c3d\u7ba1\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\uff0c\u4f46\u6211\u4eec\u5c06\u4e3a\u5b83\u521b\u5efa\u4e00\u4e2a python \u51fd\u6570\uff0c\u4ee5\u4fbf\u4eca\u540e\u4f7f\u7528\u8fd9\u4e2a\u540d\u79f0\u3002 def tpr ( y_true , y_pred ): # \u771f\u9633\u6027\u7387\uff08TPR\uff09\uff0c\u4e0e\u53ec\u56de\u7387\u8ba1\u7b97\u516c\u5f0f\u4e00\u81f4 return recall ( y_true , y_pred ) TPR \u6216\u53ec\u56de\u7387\u4e5f\u88ab\u79f0\u4e3a\u7075\u654f\u5ea6\u3002 \u800c FPR \u6216\u5047\u9633\u6027\u7387\uff08False Positive Rate\uff09\u7684\u5b9a\u4e49\u662f\uff1a \\[ FPR = FP / (TN + FP) \\] def fpr ( y_true , y_pred ): # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u8fd4\u56de\u5047\u9633\u6027\u7387\uff08FPR\uff09 return fp / ( tn + fp ) 1 - FPR \u88ab\u79f0\u4e3a\u7279\u5f02\u6027\u6216\u771f\u9634\u6027\u7387\u6216 TNR\u3002\u8fd9\u4e9b\u672f\u8bed\u5f88\u591a\uff0c\u4f46\u5176\u4e2d\u6700\u91cd\u8981\u7684\u53ea\u6709 TPR \u548c FPR\u3002\u5047\u8bbe\u6211\u4eec\u53ea\u6709 15 \u4e2a\u6837\u672c\uff0c\u5176\u76ee\u6807\u503c\u4e3a\u4e8c\u5143\uff1a Actual targets : [0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1] 
\u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u7c7b\u4f3c\u968f\u673a\u68ee\u6797\u7684\u6a21\u578b\uff0c\u5c31\u80fd\u5f97\u5230\u6837\u672c\u5448\u9633\u6027\u7684\u6982\u7387\u3002 Predicted probabilities for 1: [0.1, 0.3, 0.2, 0.6, 0.8, 0.05, 0.9, 0.5, 0.3, 0.66, 0.3, 0.2, 0.85, 0.15, 0.99] \u5bf9\u4e8e >= 0.5 \u7684\u5178\u578b\u9608\u503c\uff0c\u6211\u4eec\u53ef\u4ee5\u8bc4\u4f30\u4e0a\u8ff0\u6240\u6709\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387/TPR\u3001F1 \u548c FPR \u503c\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u5c06\u9608\u503c\u9009\u4e3a 0.4 \u6216 0.6\uff0c\u4e5f\u53ef\u4ee5\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0 \u5230 1 \u4e4b\u95f4\u7684\u4efb\u4f55\u503c\uff0c\u5e76\u8ba1\u7b97\u4e0a\u8ff0\u6240\u6709\u6307\u6807\u3002 \u4e0d\u8fc7\uff0c\u6211\u4eec\u53ea\u8ba1\u7b97\u4e24\u4e2a\u503c\uff1a TPR \u548c FPR\u3002 # \u521d\u59cb\u5316\u771f\u9633\u6027\u7387\u5217\u8868 tpr_list = [] # \u521d\u59cb\u5316\u5047\u9633\u6027\u7387\u5217\u8868 fpr_list = [] # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u7387 temp_tpr = tpr ( y_true , temp_pred ) # \u5047\u9633\u6027\u7387 temp_fpr = fpr ( y_true , temp_pred ) # \u5c06\u771f\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 tpr_list . append ( temp_tpr ) # \u5c06\u5047\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 fpr_list . 
append ( temp_fpr ) \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u6bcf\u4e2a\u9608\u503c\u7684 TPR \u503c\u548c FPR \u503c\u3002 \u56fe 3\uff1a\u9608\u503c\u3001TPR \u548c FPR \u503c\u8868 \u5982\u679c\u6211\u4eec\u7ed8\u5236\u5982\u56fe 3 \u6240\u793a\u7684\u8868\u683c\uff0c\u5373\u4ee5 TPR \u4e3a Y \u8f74\uff0cFPR \u4e3a X \u8f74\uff0c\u5c31\u4f1a\u5f97\u5230\u5982\u56fe 4 \u6240\u793a\u7684\u66f2\u7ebf\u3002 \u56fe 4\uff1aROC\u66f2\u7ebf \u8fd9\u6761\u66f2\u7ebf\u4e5f\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u3002\u5982\u679c\u6211\u4eec\u8ba1\u7b97\u8fd9\u6761 ROC \u66f2\u7ebf\u4e0b\u7684\u9762\u79ef\uff0c\u5c31\u662f\u5728\u8ba1\u7b97\u53e6\u4e00\u4e2a\u6307\u6807\uff0c\u5f53\u6570\u636e\u96c6\u7684\u4e8c\u5143\u76ee\u6807\u504f\u659c\u65f6\uff0c\u8fd9\u4e2a\u6307\u6807\u5c31\u4f1a\u975e\u5e38\u5e38\u7528\u3002 \u8fd9\u4e2a\u6307\u6807\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u6216\u66f2\u7ebf\u4e0b\u9762\u79ef\uff0c\u7b80\u79f0 AUC\u3002\u8ba1\u7b97 ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002\u5728\u6b64\uff0c\u6211\u4eec\u5c06\u91c7\u7528 scikit- learn \u7684\u5947\u5999\u5b9e\u73b0\u65b9\u6cd5\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: metrics . 
roc_auc_score ( y_true , y_pred ) Out [ X ]: 0.8300000000000001 AUC \u503c\u4ece 0 \u5230 1 \u4e0d\u7b49\u3002 AUC = 1 \u610f\u5473\u7740\u60a8\u62e5\u6709\u4e00\u4e2a\u5b8c\u7f8e\u7684\u6a21\u578b\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u610f\u5473\u7740\u4f60\u5728\u9a8c\u8bc1\u65f6\u72af\u4e86\u4e00\u4e9b\u9519\u8bef\uff0c\u5e94\u8be5\u91cd\u65b0\u5ba1\u89c6\u6570\u636e\u5904\u7406\u548c\u9a8c\u8bc1\u6d41\u7a0b\u3002\u5982\u679c\u4f60\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\uff0c\u90a3\u4e48\u606d\u559c\u4f60\uff0c\u4f60\u5df2\u7ecf\u62e5\u6709\u4e86\u9488\u5bf9\u6570\u636e\u96c6\u5efa\u7acb\u7684\u6700\u4f73\u6a21\u578b\u3002 AUC = 0 \u610f\u5473\u7740\u60a8\u7684\u6a21\u578b\u975e\u5e38\u7cdf\u7cd5\uff08\u6216\u975e\u5e38\u597d\uff01\uff09\u3002\u8bd5\u7740\u53cd\u8f6c\u9884\u6d4b\u7684\u6982\u7387\uff0c\u4f8b\u5982\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u6b63\u7c7b\u7684\u6982\u7387\u662f p\uff0c\u8bd5\u7740\u7528 1-p \u4ee3\u66ff\u5b83\u3002\u8fd9\u79cd AUC \u4e5f\u53ef\u80fd\u610f\u5473\u7740\u60a8\u7684\u9a8c\u8bc1\u6216\u6570\u636e\u5904\u7406\u5b58\u5728\u95ee\u9898\u3002 AUC = 0.5 \u610f\u5473\u7740\u4f60\u7684\u9884\u6d4b\u662f\u968f\u673a\u7684\u3002\u56e0\u6b64\uff0c\u5bf9\u4e8e\u4efb\u4f55\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u5982\u679c\u6211\u5c06\u6240\u6709\u76ee\u6807\u90fd\u9884\u6d4b\u4e3a 0.5\uff0c\u6211\u5c06\u5f97\u5230 0.5 \u7684 AUC\u3002 AUC \u503c\u4ecb\u4e8e 0 \u548c 0.5 \u4e4b\u95f4\uff0c\u610f\u5473\u7740\u4f60\u7684\u6a21\u578b\u6bd4\u968f\u673a\u6a21\u578b\u66f4\u5dee\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u662f\u56e0\u4e3a\u4f60\u98a0\u5012\u4e86\u7c7b\u522b\u3002 \u5982\u679c\u60a8\u5c1d\u8bd5\u53cd\u8f6c\u9884\u6d4b\uff0c\u60a8\u7684 AUC \u503c\u53ef\u80fd\u4f1a\u8d85\u8fc7 0.5\u3002\u63a5\u8fd1 1 \u7684 AUC \u503c\u88ab\u8ba4\u4e3a\u662f\u597d\u503c\u3002 \u4f46 AUC \u5bf9\u6211\u4eec\u7684\u6a21\u578b\u6709\u4ec0\u4e48\u5f71\u54cd\u5462\uff1f 
\u5047\u8bbe\u60a8\u5efa\u7acb\u4e86\u4e00\u4e2a\u4ece\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e2d\u68c0\u6d4b\u6c14\u80f8\u7684\u6a21\u578b\uff0c\u5176 AUC \u503c\u4e3a 0.85\u3002\u8fd9\u610f\u5473\u7740\uff0c\u5982\u679c\u60a8\u4ece\u6570\u636e\u96c6\u4e2d\u968f\u673a\u9009\u62e9\u4e00\u5f20\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9633\u6027\u6837\u672c\uff09\u548c\u53e6\u4e00\u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9634\u6027\u6837\u672c\uff09\uff0c\u90a3\u4e48\u6c14\u80f8\u56fe\u50cf\u7684\u6392\u540d\u5c06\u9ad8\u4e8e\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u6982\u7387\u4e3a 0.85\u3002 \u8ba1\u7b97\u6982\u7387\u548c AUC \u540e\uff0c\u60a8\u9700\u8981\u5bf9\u6d4b\u8bd5\u96c6\u8fdb\u884c\u9884\u6d4b\u3002\u6839\u636e\u95ee\u9898\u548c\u4f7f\u7528\u60c5\u51b5\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u6982\u7387\u6216\u5b9e\u9645\u7c7b\u522b\u3002\u5982\u679c\u4f60\u60f3\u8981\u6982\u7387\uff0c\u8fd9\u5e76\u4e0d\u96be\u3002\u5982\u679c\u60a8\u60f3\u8981\u7c7b\u522b\uff0c\u5219\u9700\u8981\u9009\u62e9\u4e00\u4e2a\u9608\u503c\u3002\u5728\u4e8c\u5143\u5206\u7c7b\u7684\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u91c7\u7528\u7c7b\u4f3c\u4e0b\u9762\u7684\u65b9\u6cd5\u3002 \\[ Prediction = Probability >= Threshold \\] \u4e5f\u5c31\u662f\u8bf4\uff0c\u9884\u6d4b\u662f\u4e00\u4e2a\u53ea\u5305\u542b\u4e8c\u5143\u53d8\u91cf\u7684\u65b0\u5217\u8868\u3002\u5982\u679c\u6982\u7387\u5927\u4e8e\u6216\u7b49\u4e8e\u7ed9\u5b9a\u7684\u9608\u503c\uff0c\u5219\u9884\u6d4b\u4e2d\u7684\u4e00\u9879\u4e3a 1\uff0c\u5426\u5219\u4e3a 0\u3002 \u4f60\u731c\u600e\u4e48\u7740\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528 ROC \u66f2\u7ebf\u6765\u9009\u62e9\u8fd9\u4e2a\u9608\u503c\uff01ROC \u66f2\u7ebf\u4f1a\u544a\u8bc9\u60a8\u9608\u503c\u5bf9\u5047\u9633\u6027\u7387\u548c\u771f\u9633\u6027\u7387\u7684\u5f71\u54cd\uff0c\u8fdb\u800c\u5f71\u54cd\u5047\u9633\u6027\u548c\u771f\u9633\u6027\u3002\u60a8\u5e94\u8be5\u9009\u62e9\u6700\u9002\u5408\u60a8\u7684\u95ee\u9898\u548c\u6570\u636e\u96c6\u7684\u9608\u503c\u3002 
\u4f8b\u5982\uff0c\u5982\u679c\u60a8\u4e0d\u5e0c\u671b\u6709\u592a\u591a\u7684\u8bef\u62a5\uff0c\u90a3\u4e48\u9608\u503c\u5c31\u5e94\u8be5\u9ad8\u4e00\u4e9b\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e5f\u4f1a\u5e26\u6765\u66f4\u591a\u7684\u6f0f\u62a5\u3002\u6ce8\u610f\u6743\u8861\u5229\u5f0a\uff0c\u9009\u62e9\u6700\u4f73\u9608\u503c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u9608\u503c\u5982\u4f55\u5f71\u54cd\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u503c\u3002 # \u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list = [] # \u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list = [] # \u771f\u5b9e\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 temp_tp = true_positive ( y_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 temp_fp = false_positive ( y_true , temp_pred ) # \u52a0\u5165\u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list . append ( temp_tp ) # \u52a0\u5165\u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list . 
append ( temp_fp ) \u5229\u7528\u8fd9\u4e00\u70b9\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u8868\u683c\uff0c\u5982\u56fe 5 \u6240\u793a\u3002 \u56fe 5\uff1a\u4e0d\u540c\u9608\u503c\u7684 TP \u503c\u548c FP \u503c \u5982\u56fe 6 \u6240\u793a\uff0c\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0cROC \u66f2\u7ebf\u5de6\u4e0a\u89d2\u7684\u503c\u5e94\u8be5\u662f\u4e00\u4e2a\u76f8\u5f53\u4e0d\u9519\u7684\u9608\u503c\u3002 \u5bf9\u6bd4\u8868\u683c\u548c ROC \u66f2\u7ebf\uff0c\u6211\u4eec\u53ef\u4ee5\u53d1\u73b0\uff0c0.6 \u5de6\u53f3\u7684\u9608\u503c\u76f8\u5f53\u4e0d\u9519\uff0c\u65e2\u4e0d\u4f1a\u4e22\u5931\u5927\u91cf\u7684\u771f\u9633\u6027\u7ed3\u679c\uff0c\u4e5f\u4e0d\u4f1a\u51fa\u73b0\u5927\u91cf\u7684\u5047\u9633\u6027\u7ed3\u679c\u3002 \u56fe 6\uff1a\u4ece ROC \u66f2\u7ebf\u6700\u5de6\u4fa7\u7684\u9876\u70b9\u9009\u62e9\u6700\u4f73\u9608\u503c AUC \u662f\u4e1a\u5185\u5e7f\u6cdb\u5e94\u7528\u4e8e\u504f\u659c\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6307\u6807\uff0c\u4e5f\u662f\u6bcf\u4e2a\u4eba\u90fd\u5e94\u8be5\u4e86\u89e3\u7684\u6307\u6807\u3002\u4e00\u65e6\u7406\u89e3\u4e86 AUC \u80cc\u540e\u7684\u7406\u5ff5\uff08\u5982\u4e0a\u6587\u6240\u8ff0\uff09\uff0c\u4e5f\u5c31\u5f88\u5bb9\u6613\u5411\u4e1a\u754c\u53ef\u80fd\u4f1a\u8bc4\u4f30\u60a8\u7684\u6a21\u578b\u7684\u975e\u6280\u672f\u4eba\u5458\u89e3\u91ca\u5b83\u4e86\u3002 \u5b66\u4e60 AUC \u540e\uff0c\u4f60\u5e94\u8be5\u5b66\u4e60\u7684\u53e6\u4e00\u4e2a\u91cd\u8981\u6307\u6807\u662f\u5bf9\u6570\u635f\u5931\u3002\u5bf9\u4e8e\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5c06\u5bf9\u6570\u635f\u5931\u5b9a\u4e49\u4e3a\uff1a \\[ LogLoss = -1.0 \\times (target \\times log(prediction) + (1-target) \\times log(1-prediction)) \\] \u5176\u4e2d\uff0c\u76ee\u6807\u503c\u4e3a 0 \u6216 1\uff0c\u9884\u6d4b\u503c\u4e3a\u6837\u672c\u5c5e\u4e8e\u7c7b\u522b 1 \u7684\u6982\u7387\u3002 
\u5bf9\u4e8e\u6570\u636e\u96c6\u4e2d\u7684\u591a\u4e2a\u6837\u672c\uff0c\u6240\u6709\u6837\u672c\u7684\u5bf9\u6570\u635f\u5931\u53ea\u662f\u6240\u6709\u5355\u4e2a\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u9700\u8981\u8bb0\u4f4f\u7684\u4e00\u70b9\u662f\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u4e0d\u6b63\u786e\u6216\u504f\u5dee\u8f83\u5927\u7684\u9884\u6d4b\u8fdb\u884c\u76f8\u5f53\u9ad8\u7684\u60e9\u7f5a\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u975e\u5e38\u786e\u5b9a\u548c\u975e\u5e38\u9519\u8bef\u7684\u9884\u6d4b\u8fdb\u884c\u60e9\u7f5a\u3002 import numpy as np def log_loss ( y_true , y_proba ): # \u6781\u5c0f\u503c\uff0c\u9632\u6b620\u505a\u5206\u6bcd epsilon = 1e-15 # \u5bf9\u6570\u635f\u5931\u5217\u8868 loss = [] # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_proba ): # \u9650\u5236yp\u8303\u56f4\uff0c\u6700\u5c0f\u4e3aepsilon\uff0c\u6700\u5927\u4e3a1-epsilon yp = np . clip ( yp , epsilon , 1 - epsilon ) # \u8ba1\u7b97\u5bf9\u6570\u635f\u5931 temp_loss = - 1.0 * ( yt * np . log ( yp ) + ( 1 - yt ) * np . log ( 1 - yp )) # \u52a0\u5165\u5bf9\u6570\u635f\u5931\u5217\u8868 loss . append ( temp_loss ) return np . mean ( loss ) \u8ba9\u6211\u4eec\u6d4b\u8bd5\u4e00\u4e0b\u51fd\u6570\u6267\u884c\u60c5\u51b5\uff1a In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_proba = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u6211\u4eec\u53ef\u4ee5\u5c06\u5176\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . 
log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u5b9e\u73b0\u662f\u6b63\u786e\u7684\u3002 \u5bf9\u6570\u635f\u5931\u7684\u5b9e\u73b0\u5f88\u5bb9\u6613\u3002\u89e3\u91ca\u8d77\u6765\u4f3c\u4e4e\u6709\u70b9\u56f0\u96be\u3002\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5bf9\u6570\u635f\u5931\u7684\u60e9\u7f5a\u8981\u6bd4\u5176\u4ed6\u6307\u6807\u5927\u5f97\u591a\u3002 \u4f8b\u5982\uff0c\u5982\u679c\u60a8\u6709 51% \u7684\u628a\u63e1\u8ba4\u4e3a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\uff0c\u90a3\u4e48\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (1 \\times log(0.51) + (1 - 1) \\times log(1 - 0.51))=0.67 \\] \u5982\u679c\u4f60\u5bf9\u5c5e\u4e8e 0 \u7c7b\u7684\u6837\u672c\u6709 49% \u7684\u628a\u63e1\uff0c\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (0 \\times log(0.49) + (1 - 0) \\times log(1 - 0.49))=0.67 \\] \u56e0\u6b64\uff0c\u5373\u4f7f\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0.5 \u7684\u622a\u65ad\u503c\u5e76\u5f97\u5230\u5b8c\u7f8e\u7684\u9884\u6d4b\u7ed3\u679c\uff0c\u6211\u4eec\u4ecd\u7136\u4f1a\u6709\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002\u56e0\u6b64\uff0c\u5728\u5904\u7406\u5bf9\u6570\u635f\u5931\u65f6\uff0c\u4f60\u9700\u8981\u975e\u5e38\u5c0f\u5fc3\uff1b\u4efb\u4f55\u4e0d\u786e\u5b9a\u7684\u9884\u6d4b\u90fd\u4f1a\u4ea7\u751f\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002 \u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u5927\u591a\u6570\u6307\u6807\u90fd\u53ef\u4ee5\u8f6c\u6362\u6210\u591a\u7c7b\u7248\u672c\u3002\u8fd9\u4e2a\u60f3\u6cd5\u5f88\u7b80\u5355\u3002\u4ee5\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u4e3a\u4f8b\u3002\u6211\u4eec\u53ef\u4ee5\u8ba1\u7b97\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u4e2d\u6bcf\u4e00\u7c7b\u7684\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 
\u6709\u4e09\u79cd\u4e0d\u540c\u7684\u8ba1\u7b97\u65b9\u6cd5\uff0c\u6709\u65f6\u53ef\u80fd\u4f1a\u4ee4\u4eba\u56f0\u60d1\u3002\u5047\u8bbe\u6211\u4eec\u9996\u5148\u5bf9\u7cbe\u786e\u7387\u611f\u5174\u8da3\u3002\u6211\u4eec\u77e5\u9053\uff0c\u7cbe\u786e\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u3002 \u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Macro averaged precision\uff09\uff1a\u5206\u522b\u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u7cbe\u786e\u7387\u7136\u540e\u6c42\u5e73\u5747\u503c \u5fae\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Micro averaged precision\uff09\uff1a\u6309\u7c7b\u8ba1\u7b97\u771f\u9633\u6027\u548c\u5047\u9633\u6027\uff0c\u7136\u540e\u7528\u5176\u8ba1\u7b97\u603b\u4f53\u7cbe\u786e\u7387\u3002 \u52a0\u6743\u7cbe\u786e\u7387 \uff08Weighted precision\uff09\uff1a\u4e0e\u5b8f\u89c2\u7cbe\u786e\u7387\u76f8\u540c\uff0c\u4f46\u8fd9\u91cc\u662f\u52a0\u6743\u5e73\u5747\u7cbe\u786e\u7387 \u53d6\u51b3\u4e8e\u6bcf\u4e2a\u7c7b\u522b\u4e2d\u7684\u9879\u76ee\u6570 \u8fd9\u770b\u4f3c\u590d\u6742\uff0c\u4f46\u5728 python \u5b9e\u73b0\u4e2d\u5f88\u5bb9\u6613\u7406\u89e3\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import numpy as np def macro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u5982\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u8ba1\u7b97\u7cbe\u786e\u5ea6 temp_precision = tp / ( tp + fp ) # \u5404\u7c7b\u7cbe\u786e\u7387\u76f8\u52a0 precision += temp_precision # \u8ba1\u7b97\u5e73\u5747\u503c precision /= num_classes return precision \u4f60\u4f1a\u53d1\u73b0\u8fd9\u5e76\u4e0d\u96be\u3002\u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5fae\u5e73\u5747\u7cbe\u786e\u7387\u5206\u6570\u3002 import numpy as np def micro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u6570 tp = 0 # \u521d\u59cb\u5316\u5047\u9633\u6027\u6837\u672c\u6570 fp = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 tp += true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 fp += false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8fd9\u4e5f\u4e0d\u96be\u3002\u90a3\u4ec0\u4e48\u96be\uff1f\u4ec0\u4e48\u90fd\u4e0d\u96be\u3002\u673a\u5668\u5b66\u4e60\u5f88\u7b80\u5355\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u52a0\u6743\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u3002 from collections import Counter import numpy as np def weighted_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 temp_precision = tp / ( tp + fp ) # \u6839\u636e\u8be5\u79cd\u7c7b\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_precision = class_counts [ class_ ] * temp_precision # \u52a0\u6743\u7cbe\u786e\u7387\u6c42\u548c precision += weighted_precision # \u8ba1\u7b97\u5e73\u5747\u7cbe\u786e\u7387 overall_precision = precision / len ( y_true ) return overall_precision \u5c06\u6211\u4eec\u7684\u5b9e\u73b0\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff0c\u4ee5\u4e86\u89e3\u5b9e\u73b0\u662f\u5426\u6b63\u786e\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: macro_precision ( y_true , y_pred ) Out [ X ]: 0.3611111111111111 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"macro\" ) Out [ X ]: 0.3611111111111111 In [ X ]: micro_precision ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"micro\" ) Out [ X ]: 0.4444444444444444 In [ X ]: weighted_precision ( y_true , y_pred ) Out [ X ]: 0.39814814814814814 In [ X ]: metrics . 
precision_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.39814814814814814 \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6b63\u786e\u5730\u5b9e\u73b0\u4e86\u4e00\u5207\u3002 \u8bf7\u6ce8\u610f\uff0c\u8fd9\u91cc\u5c55\u793a\u7684\u5b9e\u73b0\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\uff0c\u4f46\u5374\u662f\u6700\u5bb9\u6613\u7406\u89e3\u7684\u3002 \u540c\u6837\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5b9e\u73b0 \u591a\u7c7b\u522b\u7684\u53ec\u56de\u7387\u6307\u6807 \u3002\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u3001\u5047\u9633\u6027\u548c\u5047\u9634\u6027\uff0c\u800c F1 \u5219\u53d6\u51b3\u4e8e\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 \u53ec\u56de\u7387\u7684\u5b9e\u73b0\u65b9\u6cd5\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\uff0c\u8fd9\u91cc\u5b9e\u73b0\u7684\u662f\u591a\u7c7b F1 \u7684\u4e00\u4e2a\u7248\u672c\uff0c\u5373\u52a0\u6743\u5e73\u5747\u503c\u3002 from collections import Counter import numpy as np def weighted_f1 ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316F1\u503c f1 = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( temp_true , temp_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( temp_true , temp_pred ) # \u82e5\u7cbe\u786e\u7387+\u53ec\u56de\u7387\u4e0d\u4e3a0\uff0c\u5219\u4f7f\u7528\u516c\u5f0f\u8ba1\u7b97F1\u503c if p + r != 0 : temp_f1 = 2 * p * r / ( p + r ) # \u5426\u5219\u76f4\u63a5\u4e3a0 else : temp_f1 = 0 # \u6839\u636e\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_f1 = class_counts [ class_ ] * temp_f1 # \u52a0\u6743F1\u503c\u76f8\u52a0 f1 += weighted_f1 # \u8ba1\u7b97\u52a0\u6743\u5e73\u5747F1\u503c overall_f1 = f1 / len ( y_true ) return overall_f1 \u8bf7\u6ce8\u610f\uff0c\u4e0a\u9762\u6709\u51e0\u884c\u4ee3\u7801\u662f\u65b0\u5199\u7684\u3002\u56e0\u6b64\uff0c\u4f60\u5e94\u8be5\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e9b\u4ee3\u7801\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: weighted_f1 ( y_true , y_pred ) Out [ X ]: 0.41269841269841273 In [ X ]: metrics . 
f1_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.41269841269841273 \u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u591a\u7c7b\u95ee\u9898\u5b9e\u73b0\u4e86\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002\u540c\u6837\uff0c\u60a8\u4e5f\u53ef\u4ee5\u5c06 AUC \u548c\u5bf9\u6570\u635f\u5931\u8f6c\u6362\u4e3a\u591a\u7c7b\u683c\u5f0f\u3002\u8fd9\u79cd\u8f6c\u6362\u683c\u5f0f\u88ab\u79f0\u4e3a one-vs-all \u3002\u8fd9\u91cc\u6211\u4e0d\u6253\u7b97\u5b9e\u73b0\u5b83\u4eec\uff0c\u56e0\u4e3a\u5b9e\u73b0\u65b9\u6cd5\u4e0e\u6211\u4eec\u5df2\u7ecf\u8ba8\u8bba\u8fc7\u7684\u5f88\u76f8\u4f3c\u3002 \u5728\u4e8c\u5143\u6216\u591a\u7c7b\u5206\u7c7b\u4e2d\uff0c\u770b\u4e00\u4e0b \u6df7\u6dc6\u77e9\u9635 \u4e5f\u5f88\u6d41\u884c\u3002\u4e0d\u8981\u56f0\u60d1\uff0c\u8fd9\u5f88\u7b80\u5355\u3002\u6df7\u6dc6\u77e9\u9635\u53ea\u4e0d\u8fc7\u662f\u4e00\u4e2a\u5305\u542b TP\u3001FP\u3001TN \u548c FN \u7684\u8868\u683c\u3002\u4f7f\u7528\u6df7\u6dc6\u77e9\u9635\uff0c\u60a8\u53ef\u4ee5\u5feb\u901f\u67e5\u770b\u6709\u591a\u5c11\u6837\u672c\u88ab\u9519\u8bef\u5206\u7c7b\uff0c\u6709\u591a\u5c11\u6837\u672c\u88ab\u6b63\u786e\u5206\u7c7b\u3002\u4e5f\u8bb8\u6709\u4eba\u4f1a\u8bf4\uff0c\u6df7\u6dc6\u77e9\u9635\u5e94\u8be5\u5728\u672c\u7ae0\u5f88\u65e9\u5c31\u8bb2\u5230\uff0c\u4f46\u6211\u6ca1\u6709\u8fd9\u4e48\u505a\u3002\u5982\u679c\u4e86\u89e3\u4e86 TP\u3001FP\u3001TN\u3001FN\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c AUC\uff0c\u5c31\u5f88\u5bb9\u6613\u7406\u89e3\u548c\u89e3\u91ca\u6df7\u6dc6\u77e9\u9635\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b\u56fe 7 \u4e2d\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6df7\u6dc6\u77e9\u9635\u7531 TP\u3001FP\u3001FN \u548c TN \u7ec4\u6210\u3002\u6211\u4eec\u53ea\u9700\u8981\u8fd9\u4e9b\u503c\u6765\u8ba1\u7b97\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u3001F1 \u5206\u6570\u548c AUC\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u4e5f\u559c\u6b22\u628a FP \u79f0\u4e3a 
\u7b2c\u4e00\u7c7b\u9519\u8bef \uff0c\u628a FN \u79f0\u4e3a \u7b2c\u4e8c\u7c7b\u9519\u8bef \u3002 \u56fe 7\uff1a\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6df7\u6dc6\u77e9\u9635 \u6211\u4eec\u8fd8\u53ef\u4ee5\u5c06\u4e8c\u5143\u6df7\u6dc6\u77e9\u9635\u6269\u5c55\u4e3a\u591a\u7c7b\u6df7\u6dc6\u77e9\u9635\u3002\u5b83\u4f1a\u662f\u4ec0\u4e48\u6837\u5b50\u5462\uff1f\u5982\u679c\u6211\u4eec\u6709 N \u4e2a\u7c7b\u522b\uff0c\u5b83\u5c06\u662f\u4e00\u4e2a\u5927\u5c0f\u4e3a NxN \u7684\u77e9\u9635\u3002\u5bf9\u4e8e\u6bcf\u4e2a\u7c7b\u522b\uff0c\u6211\u4eec\u90fd\u8981\u8ba1\u7b97\u76f8\u5173\u7c7b\u522b\u548c\u5176\u4ed6\u7c7b\u522b\u7684\u6837\u672c\u603b\u6570\u3002\u4e3e\u4e2a\u4f8b\u5b50\u53ef\u4ee5\u8ba9\u6211\u4eec\u66f4\u597d\u5730\u7406\u89e3\u8fd9\u4e00\u70b9\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4ee5\u4e0b\u771f\u5b9e\u6807\u7b7e\uff1a \\[ [0, 1, 2, 0, 1, 2, 0, 2, 2] \\] \u6211\u4eec\u7684\u9884\u6d4b\u6807\u7b7e\u662f\uff1a \\[ [0, 2, 1, 0, 2, 1, 0, 0, 2] \\] \u90a3\u4e48\uff0c\u6211\u4eec\u7684\u6df7\u6dc6\u77e9\u9635\u5c06\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u591a\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635 \u56fe 8 \u8bf4\u660e\u4e86\u4ec0\u4e48\uff1f \u8ba9\u6211\u4eec\u6765\u770b\u770b 0 \u7c7b\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e 0 \u7c7b\u3002\u7136\u800c\uff0c\u5728\u9884\u6d4b\u4e2d\uff0c\u6211\u4eec\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 0 \u7c7b\uff0c1 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\u3002\u7406\u60f3\u60c5\u51b5\u4e0b\uff0c\u5bf9\u4e8e\u771f\u5b9e\u6807\u7b7e\u4e2d\u7684\u7c7b\u522b 0\uff0c\u9884\u6d4b\u6807\u7b7e 1 \u548c 2 \u5e94\u8be5\u6ca1\u6709\u4efb\u4f55\u6837\u672c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u7c7b\u522b 2\u3002\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 4\uff0c\u800c\u5728\u9884\u6d4b\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 3\u3002 
\u4e00\u4e2a\u5b8c\u7f8e\u7684\u6df7\u6dc6\u77e9\u9635\u53ea\u80fd\u4ece\u5de6\u5230\u53f3\u659c\u5411\u586b\u5145\u3002 \u6df7\u6dc6\u77e9\u9635 \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u7684\u65b9\u6cd5\u6765\u8ba1\u7b97\u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u4e0d\u540c\u6307\u6807\u3002Scikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u76f4\u63a5\u7684\u65b9\u6cd5\u6765\u751f\u6210\u6df7\u6dc6\u77e9\u9635\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5728\u56fe 8 \u4e2d\u663e\u793a\u7684\u6df7\u6dc6\u77e9\u9635\u662f scikit-learn \u6df7\u6dc6\u77e9\u9635\u7684\u8f6c\u7f6e\uff0c\u539f\u59cb\u7248\u672c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u7ed8\u5236\u3002 import matplotlib.pyplot as plt import seaborn as sns from sklearn import metrics # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] # \u9884\u6d4b\u6837\u672c\u6807\u7b7e y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] # \u8ba1\u7b97\u6df7\u6dc6\u77e9\u9635 cm = metrics . confusion_matrix ( y_true , y_pred ) # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 10 , 10 )) # \u521b\u5efa\u65b9\u683c cmap = sns . cubehelix_palette ( 50 , hue = 0.05 , rot = 0 , light = 0.9 , dark = 0 , as_cmap = True ) # \u89c4\u5b9a\u5b57\u4f53\u5927\u5c0f sns . set ( font_scale = 2.5 ) # \u7ed8\u5236\u70ed\u56fe sns . heatmap ( cm , annot = True , cmap = cmap , cbar = False ) # y\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . ylabel ( 'Actual Labels' , fontsize = 20 ) # x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . 
xlabel ( 'Predicted Labels' , fontsize = 20 ) \u56e0\u6b64\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u89e3\u51b3\u4e86\u4e8c\u5143\u5206\u7c7b\u548c\u591a\u7c7b\u5206\u7c7b\u7684\u5ea6\u91cf\u95ee\u9898\u3002\u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u8ba8\u8bba\u53e6\u4e00\u79cd\u7c7b\u578b\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u5373\u591a\u6807\u7b7e\u5206\u7c7b\u3002\u5728\u591a\u6807\u7b7e\u5206\u7c7b\u4e2d\uff0c\u6bcf\u4e2a\u6837\u672c\u90fd\u53ef\u80fd\u4e0e\u4e00\u4e2a\u6216\u591a\u4e2a\u7c7b\u522b\u76f8\u5173\u8054\u3002\u8fd9\u7c7b\u95ee\u9898\u7684\u4e00\u4e2a\u7b80\u5355\u4f8b\u5b50\u5c31\u662f\u8981\u6c42\u4f60\u9884\u6d4b\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53\u3002 \u56fe 9 \u663e\u793a\u4e86\u4e00\u4e2a\u8457\u540d\u6570\u636e\u96c6\u7684\u56fe\u50cf\u793a\u4f8b\u3002\u8bf7\u6ce8\u610f\uff0c\u8be5\u6570\u636e\u96c6\u7684\u76ee\u6807\u6709\u6240\u4e0d\u540c\uff0c\u4f46\u6211\u4eec\u6682\u4e14\u4e0d\u53bb\u8ba8\u8bba\u5b83\u3002\u6211\u4eec\u5047\u8bbe\u5176\u76ee\u7684\u53ea\u662f\u9884\u6d4b\u56fe\u50cf\u4e2d\u662f\u5426\u5b58\u5728\u67d0\u4e2a\u7269\u4f53\u3002\u5728\u56fe 9 \u4e2d\uff0c\u6211\u4eec\u6709\u6905\u5b50\u3001\u82b1\u76c6\u3001\u7a97\u6237\uff0c\u4f46\u6ca1\u6709\u5176\u4ed6\u7269\u4f53\uff0c\u5982\u7535\u8111\u3001\u5e8a\u3001\u7535\u89c6\u7b49\u3002\u56e0\u6b64\uff0c\u4e00\u5e45\u56fe\u50cf\u53ef\u80fd\u6709\u591a\u4e2a\u76f8\u5173\u76ee\u6807\u3002\u8fd9\u7c7b\u95ee\u9898\u5c31\u662f\u591a\u6807\u7b7e\u5206\u7c7b\u95ee\u9898\u3002 \u56fe 9\uff1a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53 \u8fd9\u7c7b\u5206\u7c7b\u95ee\u9898\u7684\u8861\u91cf\u6807\u51c6\u6709\u4e9b\u4e0d\u540c\u3002\u4e00\u4e9b\u5408\u9002\u7684 \u6700\u5e38\u89c1\u7684\u6307\u6807\u6709\uff1a k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u786e\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 \u8ba9\u6211\u4eec\u4ece k 
\u7cbe\u786e\u7387\u6216\u8005 P@k \u6211\u4eec\u4e0d\u80fd\u5c06\u8fd9\u4e00\u7cbe\u786e\u7387\u4e0e\u524d\u9762\u8ba8\u8bba\u7684\u7cbe\u786e\u7387\u6df7\u6dc6\u3002\u5982\u679c\u60a8\u6709\u4e00\u4e2a\u7ed9\u5b9a\u6837\u672c\u7684\u539f\u59cb\u7c7b\u522b\u5217\u8868\u548c\u540c\u4e00\u4e2a\u6837\u672c\u7684\u9884\u6d4b\u7c7b\u522b\u5217\u8868\uff0c\u90a3\u4e48\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u5c31\u662f\u9884\u6d4b\u5217\u8868\u4e2d\u4ec5\u8003\u8651\u524d k \u4e2a\u9884\u6d4b\u7ed3\u679c\u7684\u547d\u4e2d\u6570\u9664\u4ee5 k\u3002 \u5982\u679c\u60a8\u5bf9\u6b64\u611f\u5230\u56f0\u60d1\uff0c\u4f7f\u7528 python \u4ee3\u7801\u540e\u5c31\u4f1a\u660e\u767d\u3002 def pk ( y_true , y_pred , k ): # \u5982\u679ck\u4e3a0 if k == 0 : # \u8fd4\u56de0 return 0 # \u53d6\u9884\u6d4b\u6807\u7b7e\u524dk\u4e2a y_pred = y_pred [: k ] # \u5c06\u9884\u6d4b\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 pred_set = set ( y_pred ) # \u5c06\u771f\u5b9e\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 true_set = set ( y_true ) # \u9884\u6d4b\u6807\u7b7e\u96c6\u5408\u4e0e\u771f\u5b9e\u6807\u7b7e\u96c6\u5408\u4ea4\u96c6 common_values = pred_set . intersection ( true_set ) # \u8ba1\u7b97\u7cbe\u786e\u7387 return len ( common_values ) / len ( y_pred [: k ]) \u6709\u4e86\u4ee3\u7801\uff0c\u4e00\u5207\u90fd\u53d8\u5f97\u66f4\u5bb9\u6613\u7406\u89e3\u4e86\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6709\u4e86 k \u5e73\u5747\u7cbe\u786e\u7387\u6216 AP@k \u3002AP@k \u662f\u901a\u8fc7 P@k \u8ba1\u7b97\u5f97\u51fa\u7684\u3002\u4f8b\u5982\uff0c\u5982\u679c\u8981\u8ba1\u7b97 AP@3\uff0c\u6211\u4eec\u8981\u5148\u8ba1\u7b97 P@1\u3001P@2 \u548c P@3\uff0c\u7136\u540e\u5c06\u603b\u548c\u9664\u4ee5 3\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u7684\u5b9e\u73b0\u3002 def apk ( y_true , y_pred , k ): # \u521d\u59cb\u5316P@k\u5217\u8868 pk_values = [] # \u904d\u53861~k for i in range ( 1 , k + 1 ): # \u5c06P@k\u52a0\u5165\u5217\u8868 pk_values . 
append ( pk ( y_true , y_pred , i )) # \u82e5\u957f\u5ea6\u4e3a0 if len ( pk_values ) == 0 : # \u8fd4\u56de0 return 0 # \u5426\u5219\u8ba1\u7b97AP@K return sum ( pk_values ) / len ( pk_values ) \u8fd9\u4e24\u4e2a\u51fd\u6570\u53ef\u4ee5\u7528\u6765\u8ba1\u7b97\u4e24\u4e2a\u7ed9\u5b9a\u5217\u8868\u7684 k \u5e73\u5747\u7cbe\u786e\u7387 (AP@k)\uff1b\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u8ba1\u7b97\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: for i in range ( len ( y_true )): ... : for j in range ( 1 , 4 ): ... : print ( ... : f \"\"\" ...: y_true= { y_true [ i ] } , ...: y_pred= { y_pred [ i ] } , ...: AP@ { j } = { apk ( y_true [ i ], y_pred [ i ], k = j ) } ...: \"\"\" ... : ) ... : y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 1 = 0.0 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 2 = 0.25 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 3 = 0.38888888888888884 \u8bf7\u6ce8\u610f\uff0c\u6211\u7701\u7565\u4e86\u8f93\u51fa\u7ed3\u679c\u4e2d\u7684\u8bb8\u591a\u6570\u503c\uff0c\u4f46\u4f60\u4f1a\u660e\u767d\u5176\u4e2d\u7684\u610f\u601d\u3002\u8fd9\u5c31\u662f\u6211\u4eec\u5982\u4f55\u8ba1\u7b97 AP@k \u7684\u65b9\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u7684 AP@k\u3002\u5728\u673a\u5668\u5b66\u4e60\u4e2d\uff0c\u6211\u4eec\u5bf9\u6240\u6709\u6837\u672c\u90fd\u611f\u5174\u8da3\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u6709 \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387 k \u6216 MAP@k \u3002MAP@k \u53ea\u662f AP@k \u7684\u5e73\u5747\u503c\uff0c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b python \u4ee3\u7801\u8f7b\u677e\u8ba1\u7b97\u3002 def mapk ( y_true , y_pred , k ): # \u521d\u59cb\u5316AP@k\u5217\u8868 apk_values = [] # \u904d\u53860~\uff08\u771f\u5b9e\u6807\u7b7e\u6570-1\uff09 for i in range ( len ( y_true )): 
# \u5c06AP@K\u52a0\u5165\u5217\u8868 apk_values . append ( apk ( y_true [ i ], y_pred [ i ], k = k ) ) # \u8ba1\u7b97\u5e73\u5747AP@k return sum ( apk_values ) / len ( apk_values ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u9488\u5bf9\u76f8\u540c\u7684\u5217\u8868\u8ba1\u7b97 k=1\u30012\u30013 \u548c 4 \u65f6\u7684 MAP@k\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: mapk ( y_true , y_pred , k = 1 ) Out [ X ]: 0.3333333333333333 In [ X ]: mapk ( y_true , y_pred , k = 2 ) Out [ X ]: 0.375 In [ X ]: mapk ( y_true , y_pred , k = 3 ) Out [ X ]: 0.3611111111111111 In [ X ]: mapk ( y_true , y_pred , k = 4 ) Out [ X ]: 0.34722222222222215 P@k\u3001AP@k \u548c MAP@k \u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u5176\u4e2d 1 \u4e3a\u6700\u4f73\u3002 \u8bf7\u6ce8\u610f\uff0c\u6709\u65f6\u60a8\u53ef\u80fd\u4f1a\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230 P@k \u548c AP@k \u7684\u4e0d\u540c\u5b9e\u73b0\u65b9\u5f0f\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5176\u4e2d\u4e00\u79cd\u5b9e\u73b0\u65b9\u5f0f\u3002 import numpy as np def apk ( actual , predicted , k = 10 ): # \u82e5\u9884\u6d4b\u6807\u7b7e\u957f\u5ea6\u5927\u4e8ek if len ( predicted ) > k : # \u53d6\u524dk\u4e2a\u6807\u7b7e predicted = predicted [: k ] score = 0.0 num_hits = 0.0 for i , p in enumerate ( predicted ): if p in actual and p not in predicted [: i ]: num_hits += 1.0 score += num_hits / ( i + 1.0 ) if not actual : return 0.0 return score / min ( len ( actual ), k ) \u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u662f AP@k 
\u7684\u53e6\u4e00\u4e2a\u7248\u672c\uff0c\u5176\u4e2d\u987a\u5e8f\u5f88\u91cd\u8981\uff0c\u6211\u4eec\u8981\u6743\u8861\u9884\u6d4b\u7ed3\u679c\u3002\u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u7684\u7ed3\u679c\u4e0e\u6211\u7684\u4ecb\u7ecd\u7565\u6709\u4e0d\u540c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6765\u770b\u770b \u591a\u6807\u7b7e\u5206\u7c7b\u7684\u5bf9\u6570\u635f\u5931 \u3002\u8fd9\u5f88\u5bb9\u6613\u3002\u60a8\u53ef\u4ee5\u5c06\u76ee\u6807\u8f6c\u6362\u4e3a\u4e8c\u5143\u5206\u7c7b\uff0c\u7136\u540e\u5bf9\u6bcf\u4e00\u5217\u4f7f\u7528\u5bf9\u6570\u635f\u5931\u3002\u6700\u540e\uff0c\u4f60\u53ef\u4ee5\u6c42\u51fa\u6bcf\u5217\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u8fd9\u4e5f\u88ab\u79f0\u4e3a\u5e73\u5747\u5217\u5bf9\u6570\u635f\u5931\u3002\u5f53\u7136\uff0c\u8fd8\u6709\u5176\u4ed6\u65b9\u6cd5\u53ef\u4ee5\u5b9e\u73b0\u8fd9\u4e00\u70b9\uff0c\u4f60\u5e94\u8be5\u5728\u9047\u5230\u65f6\u52a0\u4ee5\u63a2\u7d22\u3002 \u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u8bf4\u5df2\u7ecf\u638c\u63e1\u4e86\u6240\u6709\u4e8c\u5143\u5206\u7c7b\u3001\u591a\u7c7b\u5206\u7c7b\u548c\u591a\u6807\u7b7e\u5206\u7c7b\u6307\u6807\uff0c\u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u6307\u6807\u3002 \u56de\u5f52\u4e2d\u6700\u5e38\u89c1\u7684\u6307\u6807\u662f \u8bef\u5dee\uff08Error\uff09 \u3002\u8bef\u5dee\u5f88\u7b80\u5355\uff0c\u4e5f\u5f88\u5bb9\u6613\u7406\u89e3\u3002 \\[ Error = True\\ Value - Predicted\\ Value \\] \u7edd\u5bf9\u8bef\u5dee\uff08Absolute error\uff09 \u53ea\u662f\u4e0a\u8ff0\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\u3002 \\[ Absolute\\ Error = Abs(True\\ Value - Predicted\\ Value) \\] \u63a5\u4e0b\u6765\u6211\u4eec\u8ba8\u8bba \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\uff08MAE\uff09 \u3002\u5b83\u53ea\u662f\u6240\u6709\u7edd\u5bf9\u8bef\u5dee\u7684\u5e73\u5747\u503c\u3002 import numpy as np def mean_absolute_error ( y_true , y_pred ): #\u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # 
\u7d2f\u52a0\u7edd\u5bf9\u8bef\u5dee error += np . abs ( yt - yp ) # \u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee return error / len ( y_true ) \u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5e73\u65b9\u8bef\u5dee\u548c \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u3002 \\[ Squared\\ Error = (True Value - Predicted\\ Value)^2 \\] \u5747\u65b9\u8bef\u5dee\uff08MSE\uff09\u7684\u8ba1\u7b97\u65b9\u5f0f\u5982\u4e0b def mean_squared_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u7d2f\u52a0\u8bef\u5dee\u5e73\u65b9\u548c error += ( yt - yp ) ** 2 # \u8ba1\u7b97\u5747\u65b9\u8bef\u5dee return error / len ( y_true ) MSE \u548c RMSE\uff08\u5747\u65b9\u6839\u8bef\u5dee\uff09 \u662f\u8bc4\u4f30\u56de\u5f52\u6a21\u578b\u6700\u5e38\u7528\u7684\u6307\u6807\u3002 \\[ RMSE = SQRT(MSE) \\] \u540c\u4e00\u7c7b\u8bef\u5dee\u7684\u53e6\u4e00\u79cd\u7c7b\u578b\u662f \u5e73\u65b9\u5bf9\u6570\u8bef\u5dee \u3002\u6709\u4eba\u79f0\u5176\u4e3a SLE \uff0c\u5f53\u6211\u4eec\u53d6\u6240\u6709\u6837\u672c\u4e2d\u8fd9\u4e00\u8bef\u5dee\u7684\u5e73\u5747\u503c\u65f6\uff0c\u5b83\u88ab\u79f0\u4e3a MSLE\uff08\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee\uff09\uff0c\u5b9e\u73b0\u65b9\u6cd5\u5982\u4e0b\u3002 import numpy as np def mean_squared_log_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee error += ( np . log ( 1 + yt ) - np . 
log ( 1 + yp )) ** 2 # \u8ba1\u7b97\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee return error / len ( y_true ) \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \u53ea\u662f\u5176\u5e73\u65b9\u6839\u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a RMSLE \u3002 \u7136\u540e\u662f\u767e\u5206\u6bd4\u8bef\u5dee\uff1a \\[ Percentage\\ Error = (( True\\ Value \u2013 Predicted\\ Value ) / True\\ Value ) \\times 100 \\] \u540c\u6837\u53ef\u4ee5\u8f6c\u6362\u4e3a\u6240\u6709\u6837\u672c\u7684\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee\u3002 def mean_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u767e\u5206\u6bd4\u8bef\u5dee error += ( yt - yp ) / yt # \u8fd4\u56de\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u7edd\u5bf9\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\uff08\u4e5f\u662f\u66f4\u5e38\u89c1\u7684\u7248\u672c\uff09\u88ab\u79f0\u4e3a \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee\u6216 MAPE \u3002 import numpy as np def mean_abs_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee error += np . 
abs ( yt - yp ) / yt #\u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u56de\u5f52\u7684\u6700\u5927\u4f18\u70b9\u662f\uff0c\u53ea\u6709\u51e0\u4e2a\u6700\u5e38\u7528\u7684\u6307\u6807\uff0c\u51e0\u4e4e\u53ef\u4ee5\u5e94\u7528\u4e8e\u6240\u6709\u56de\u5f52\u95ee\u9898\u3002\u4e0e\u5206\u7c7b\u6307\u6807\u76f8\u6bd4\uff0c\u56de\u5f52\u6307\u6807\u66f4\u5bb9\u6613\u7406\u89e3\u3002 \u8ba9\u6211\u4eec\u6765\u8c08\u8c08\u53e6\u4e00\u4e2a\u56de\u5f52\u6307\u6807 \\(R^2\\) \uff08R \u65b9\uff09\uff0c\u4e5f\u79f0\u4e3a \u5224\u5b9a\u7cfb\u6570 \u3002 \u7b80\u5355\u5730\u8bf4\uff0cR \u65b9\u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u3002R \u65b9\u63a5\u8fd1 1.0 \u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u76f8\u5f53\u597d\uff0c\u800c\u63a5\u8fd1 0 \u5219\u8868\u793a\u6a21\u578b\u4e0d\u662f\u90a3\u4e48\u597d\u3002\u5f53\u6a21\u578b\u53ea\u662f\u505a\u51fa\u8352\u8c2c\u7684\u9884\u6d4b\u65f6\uff0cR \u65b9\u4e5f\u53ef\u80fd\u662f\u8d1f\u503c\u3002 R \u65b9\u7684\u8ba1\u7b97\u516c\u5f0f\u5982\u4e0b\u6240\u793a\uff0c\u4f46 Python \u7684\u5b9e\u73b0\u603b\u662f\u80fd\u8ba9\u4e00\u5207\u66f4\u52a0\u6e05\u6670\u3002 \\[ R^2 = \\frac{\\sum^{N}_{i=1}(y_{t_i}-y_{p_i})^2}{\\sum^{N}_{i=1}(y_{t_i} - y_{t_{mean}})} \\] import numpy as np def r2 ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u5747\u771f\u5b9e\u503c mean_true_value = np . 
mean ( y_true ) # \u521d\u59cb\u5316\u5e73\u65b9\u8bef\u5dee numerator = 0 denominator = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): numerator += ( yt - yp ) ** 2 denominator += ( yt - mean_true_value ) ** 2 ratio = numerator / denominator # \u8ba1\u7b97R\u65b9 return 1 \u2013 ratio \u8fd8\u6709\u66f4\u591a\u7684\u8bc4\u4ef7\u6307\u6807\uff0c\u8fd9\u4e2a\u6e05\u5355\u6c38\u8fdc\u4e5f\u5217\u4e0d\u5b8c\u3002\u6211\u53ef\u4ee5\u5199\u4e00\u672c\u4e66\uff0c\u53ea\u4ecb\u7ecd\u4e0d\u540c\u7684\u8bc4\u4ef7\u6307\u6807\u3002\u4e5f\u8bb8\u6211\u4f1a\u7684\u3002\u73b0\u5728\uff0c\u8fd9\u4e9b\u8bc4\u4f30\u6307\u6807\u51e0\u4e4e\u53ef\u4ee5\u6ee1\u8db3\u4f60\u60f3\u5c1d\u8bd5\u89e3\u51b3\u7684\u6240\u6709\u95ee\u9898\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ee5\u6700\u76f4\u63a5\u7684\u65b9\u5f0f\u5b9e\u73b0\u4e86\u8fd9\u4e9b\u6307\u6807\uff0c\u8fd9\u610f\u5473\u7740\u5b83\u4eec\u4e0d\u591f\u9ad8\u6548\u3002\u4f60\u53ef\u4ee5\u901a\u8fc7\u6b63\u786e\u4f7f\u7528 numpy \u4ee5\u975e\u5e38\u9ad8\u6548\u7684\u65b9\u5f0f\u5b9e\u73b0\u5176\u4e2d\u5927\u90e8\u5206\u6307\u6807\u3002\u4f8b\u5982\uff0c\u770b\u770b\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\u7684\u5b9e\u73b0\uff0c\u4e0d\u9700\u8981\u4efb\u4f55\u5faa\u73af\u3002 import numpy as np def mae_np ( y_true , y_pred ): return np . mean ( np . 
abs ( y_true - y_pred )) \u6211\u672c\u53ef\u4ee5\u7528\u8fd9\u79cd\u65b9\u6cd5\u5b9e\u73b0\u6240\u6709\u6307\u6807\uff0c\u4f46\u4e3a\u4e86\u5b66\u4e60\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b\u5e95\u5c42\u5b9e\u73b0\u3002\u4e00\u65e6\u4f60\u5b66\u4f1a\u4e86\u7eaf python \u7684\u5e95\u5c42\u5b9e\u73b0\uff0c\u5e76\u4e14\u4e0d\u4f7f\u7528\u5927\u91cf numpy\uff0c\u4f60\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5c06\u5176\u8f6c\u6362\u4e3a numpy\uff0c\u5e76\u4f7f\u5176\u53d8\u5f97\u66f4\u5feb\u3002 \u7136\u540e\u662f\u4e00\u4e9b\u9ad8\u7ea7\u5ea6\u91cf\u3002 \u5176\u4e2d\u4e00\u4e2a\u5e94\u7528\u76f8\u5f53\u5e7f\u6cdb\u7684\u6307\u6807\u662f \u4e8c\u6b21\u52a0\u6743\u5361\u5e15 \uff0c\u4e5f\u79f0\u4e3a QWK \u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a\u79d1\u6069\u5361\u5e15\u3002 QWK \u8861\u91cf\u4e24\u4e2a \"\u8bc4\u5206 \"\u4e4b\u95f4\u7684 \"\u4e00\u81f4\u6027\"\u3002\u8bc4\u5206\u53ef\u4ee5\u662f 0 \u5230 N \u4e4b\u95f4\u7684\u4efb\u4f55\u5b9e\u6570\uff0c\u9884\u6d4b\u4e5f\u5728\u540c\u4e00\u8303\u56f4\u5185\u3002\u4e00\u81f4\u6027\u53ef\u4ee5\u5b9a\u4e49\u4e3a\u8fd9\u4e9b\u8bc4\u7ea7\u4e4b\u95f4\u7684\u63a5\u8fd1\u7a0b\u5ea6\u3002\u56e0\u6b64\uff0c\u5b83\u9002\u7528\u4e8e\u6709 N \u4e2a\u4e0d\u540c\u7c7b\u522b\u7684\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u4e00\u81f4\u5ea6\u9ad8\uff0c\u5206\u6570\u5c31\u66f4\u63a5\u8fd1 1.0\u3002Cohen's kappa \u5728 scikit-learn \u4e2d\u6709\u5f88\u597d\u7684\u5b9e\u73b0\uff0c\u5173\u4e8e\u8be5\u6307\u6807\u7684\u8be6\u7ec6\u8ba8\u8bba\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3 ] In [ X ]: y_pred = [ 2 , 1 , 3 , 1 , 2 , 3 , 3 , 1 , 2 ] In [ X ]: metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) Out [ X ]: 0.33333333333333337 In [ X ]: metrics . 
accuracy_score ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 \u60a8\u53ef\u4ee5\u770b\u5230\uff0c\u5c3d\u7ba1\u51c6\u786e\u5ea6\u5f88\u9ad8\uff0c\u4f46 QWK \u5374\u5f88\u4f4e\u3002QWK \u5927\u4e8e 0.85 \u5373\u4e3a\u975e\u5e38\u597d\uff01 \u4e00\u4e2a\u91cd\u8981\u7684\u6307\u6807\u662f \u9a6c\u4fee\u76f8\u5173\u7cfb\u6570\uff08MCC\uff09 \u30021 \u4ee3\u8868\u5b8c\u7f8e\u9884\u6d4b\uff0c-1 \u4ee3\u8868\u4e0d\u5b8c\u7f8e\u9884\u6d4b\uff0c0 \u4ee3\u8868\u968f\u673a\u9884\u6d4b\u3002MCC \u7684\u8ba1\u7b97\u516c\u5f0f\u975e\u5e38\u7b80\u5355\u3002 \\[ MCC = \\frac{TP \\times TN - FP \\times FN}{\\sqrt{(TP + FP) \\times (FN + TN) \\times (FP + TN) \\times (TP + FN)}} \\] \u6211\u4eec\u770b\u5230\uff0cMCC \u8003\u8651\u4e86 TP\u3001FP\u3001TN \u548c FN\uff0c\u56e0\u6b64\u53ef\u7528\u4e8e\u5904\u7406\u7c7b\u504f\u659c\u7684\u95ee\u9898\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u7684\u65b9\u6cd5\u5728 python \u4e2d\u5feb\u901f\u5b9e\u73b0\u5b83\u3002 def mcc ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) numerator = ( tp * tn ) - ( fp * fn ) denominator = ( ( tp + fp ) * ( fn + tn ) * ( fp + tn ) * ( tp + fn ) ) denominator = denominator ** 0.5 return numerator / denominator \u8fd9\u4e9b\u6307\u6807\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5165\u95e8\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002 
\u9700\u8981\u6ce8\u610f\u7684\u4e00\u70b9\u662f\uff0c\u5728\u8bc4\u4f30\u975e\u76d1\u7763\u65b9\u6cd5\uff08\u4f8b\u5982\u67d0\u79cd\u805a\u7c7b\uff09\u65f6\uff0c\u6700\u597d\u521b\u5efa\u6216\u624b\u52a8\u6807\u8bb0\u6d4b\u8bd5\u96c6\uff0c\u5e76\u5c06\u5176\u4e0e\u5efa\u6a21\u90e8\u5206\u7684\u6240\u6709\u5185\u5bb9\u5206\u5f00\u3002\u5b8c\u6210\u805a\u7c7b\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u4efb\u4f55\u4e00\u79cd\u76d1\u7763\u5b66\u4e60\u6307\u6807\u6765\u8bc4\u4f30\u6d4b\u8bd5\u96c6\u7684\u6027\u80fd\u4e86\u3002 \u4e00\u65e6\u6211\u4eec\u4e86\u89e3\u4e86\u7279\u5b9a\u95ee\u9898\u5e94\u8be5\u4f7f\u7528\u4ec0\u4e48\u6307\u6807\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f00\u59cb\u66f4\u6df1\u5165\u5730\u7814\u7a76\u6211\u4eec\u7684\u6a21\u578b\uff0c\u4ee5\u6c42\u6539\u8fdb\u3002","title":"\u8bc4\u4f30\u6307\u6807"},{"location":"%E8%AF%84%E4%BC%B0%E6%8C%87%E6%A0%87/#_1","text":"\u8bf4\u5230\u673a\u5668\u5b66\u4e60\u95ee\u9898\uff0c\u4f60\u4f1a\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\u9047\u5230\u5f88\u591a\u4e0d\u540c\u7c7b\u578b\u7684\u6307\u6807\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u751a\u81f3\u4f1a\u6839\u636e\u4e1a\u52a1\u95ee\u9898\u521b\u5efa\u5ea6\u91cf\u6807\u51c6\u3002\u9010\u4e00\u4ecb\u7ecd\u548c\u89e3\u91ca\u6bcf\u4e00\u79cd\u5ea6\u91cf\u7c7b\u578b\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002\u76f8\u53cd\uff0c\u6211\u4eec\u5c06\u4ecb\u7ecd\u4e00\u4e9b\u6700\u5e38\u89c1\u7684\u5ea6\u91cf\u6807\u51c6\uff0c\u4f9b\u4f60\u5728\u6700\u521d\u7684\u51e0\u4e2a\u9879\u76ee\u4e2d\u4f7f\u7528\u3002 
\u5728\u672c\u4e66\u7684\u5f00\u5934\uff0c\u6211\u4eec\u4ecb\u7ecd\u4e86\u76d1\u7763\u5b66\u4e60\u548c\u975e\u76d1\u7763\u5b66\u4e60\u3002\u867d\u7136\u65e0\u76d1\u7763\u5b66\u4e60\u53ef\u4ee5\u4f7f\u7528\u4e00\u4e9b\u6307\u6807\uff0c\u4f46\u6211\u4eec\u5c06\u53ea\u5173\u6ce8\u6709\u76d1\u7763\u5b66\u4e60\u3002\u8fd9\u662f\u56e0\u4e3a\u6709\u76d1\u7763\u95ee\u9898\u6bd4\u65e0\u76d1\u7763\u95ee\u9898\u591a\uff0c\u800c\u4e14\u5bf9\u65e0\u76d1\u7763\u65b9\u6cd5\u7684\u8bc4\u4f30\u76f8\u5f53\u4e3b\u89c2\u3002 \u5982\u679c\u6211\u4eec\u8c08\u8bba\u5206\u7c7b\u95ee\u9898\uff0c\u6700\u5e38\u7528\u7684\u6307\u6807\u662f\uff1a \u51c6\u786e\u7387\uff08Accuracy\uff09 \u7cbe\u786e\u7387\uff08P\uff09 \u53ec\u56de\u7387\uff08R\uff09 F1 \u5206\u6570\uff08F1\uff09 AUC\uff08AUC\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u8bf4\u5230\u56de\u5f52\uff0c\u6700\u5e38\u7528\u7684\u8bc4\u4ef7\u6307\u6807\u662f \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee \uff08MAE\uff09 \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u5747\u65b9\u6839\u8bef\u5dee \uff08RMSE\uff09 \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \uff08RMSLE\uff09 \u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee \uff08MPE\uff09 \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee \uff08MAPE\uff09 R2 \u4e86\u89e3\u4e0a\u8ff0\u6307\u6807\u7684\u5de5\u4f5c\u539f\u7406\u5e76\u4e0d\u662f\u6211\u4eec\u5fc5\u987b\u4e86\u89e3\u7684\u552f\u4e00\u4e8b\u60c5\u3002\u6211\u4eec\u8fd8\u5fc5\u987b\u77e5\u9053\u4f55\u65f6\u4f7f\u7528\u54ea\u4e9b\u6307\u6807\uff0c\u800c\u8fd9\u53d6\u51b3\u4e8e\u4f60\u6709\u4ec0\u4e48\u6837\u7684\u6570\u636e\u548c\u76ee\u6807\u3002\u6211\u8ba4\u4e3a\u8fd9\u4e0e\u76ee\u6807\u6709\u5173\uff0c\u800c\u4e0e\u6570\u636e\u65e0\u5173\u3002 
\u8981\u8fdb\u4e00\u6b65\u4e86\u89e3\u8fd9\u4e9b\u6307\u6807\uff0c\u8ba9\u6211\u4eec\u4ece\u4e00\u4e2a\u7b80\u5355\u7684\u95ee\u9898\u5f00\u59cb\u3002\u5047\u8bbe\u6211\u4eec\u6709\u4e00\u4e2a \u4e8c\u5143\u5206\u7c7b \u95ee\u9898\uff0c\u5373\u53ea\u6709\u4e24\u4e2a\u76ee\u6807\u7684\u95ee\u9898\uff0c\u5047\u8bbe\u8fd9\u662f\u4e00\u4e2a\u80f8\u90e8 X \u5149\u56fe\u50cf\u5206\u7c7b\u95ee\u9898\u3002\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6ca1\u6709\u95ee\u9898\uff0c\u800c\u6709\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u6709\u80ba\u584c\u9677\uff0c\u4e5f\u5c31\u662f\u6240\u8c13\u7684\u6c14\u80f8\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u7684\u4efb\u52a1\u662f\u5efa\u7acb\u4e00\u4e2a\u5206\u7c7b\u5668\uff0c\u5728\u7ed9\u5b9a\u80f8\u90e8 X \u5149\u56fe\u50cf\u7684\u60c5\u51b5\u4e0b\uff0c\u5b83\u80fd\u68c0\u6d4b\u51fa\u56fe\u50cf\u662f\u5426\u6709\u6c14\u80f8\u3002 \u56fe 1\uff1a\u6c14\u80f8\u80ba\u90e8\u56fe\u50cf \u6211\u4eec\u8fd8\u5047\u8bbe\u6709\u76f8\u540c\u6570\u91cf\u7684\u6c14\u80f8\u548c\u975e\u6c14\u80f8\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u6bd4\u5982\u5404 100 \u5f20\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6709 100 \u5f20\u9633\u6027\u6837\u672c\u548c 100 \u5f20\u9634\u6027\u6837\u672c\uff0c\u5171\u8ba1 200 \u5f20\u56fe\u50cf\u3002 \u7b2c\u4e00\u6b65\u662f\u5c06\u4e0a\u8ff0\u6570\u636e\u5206\u4e3a\u4e24\u7ec4\uff0c\u6bcf\u7ec4 100 \u5f20\u56fe\u50cf\uff0c\u5373\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u8fd9\u4e24\u4e2a\u96c6\u5408\u4e2d\uff0c\u6211\u4eec\u90fd\u6709 50 \u4e2a\u6b63\u6837\u672c\u548c 50 \u4e2a\u8d1f\u6837\u672c\u3002 \u5728\u4e8c\u5143\u5206\u7c7b\u6307\u6807\u4e2d\uff0c\u5f53\u6b63\u8d1f\u6837\u672c\u6570\u91cf\u76f8\u7b49\u65f6\uff0c\u6211\u4eec\u901a\u5e38\u4f7f\u7528\u51c6\u786e\u7387\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002 \u51c6\u786e\u7387 
\uff1a\u8fd9\u662f\u673a\u5668\u5b66\u4e60\u4e2d\u6700\u76f4\u63a5\u7684\u6307\u6807\u4e4b\u4e00\u3002\u5b83\u5b9a\u4e49\u4e86\u6a21\u578b\u7684\u51c6\u786e\u5ea6\u3002\u5bf9\u4e8e\u4e0a\u8ff0\u95ee\u9898\uff0c\u5982\u679c\u4f60\u5efa\u7acb\u7684\u6a21\u578b\u80fd\u51c6\u786e\u5206\u7c7b 90 \u5f20\u56fe\u7247\uff0c\u90a3\u4e48\u4f60\u7684\u51c6\u786e\u7387\u5c31\u662f 90% \u6216 0.90\u3002\u5982\u679c\u53ea\u6709 83 \u5e45\u56fe\u50cf\u88ab\u6b63\u786e\u5206\u7c7b\uff0c\u90a3\u4e48\u6a21\u578b\u7684\u51c6\u786e\u7387\u5c31\u662f 83% \u6216 0.83\u3002 \u8ba1\u7b97\u51c6\u786e\u7387\u7684 Python \u4ee3\u7801\u4e5f\u975e\u5e38\u7b80\u5355\u3002 def accuracy ( y_true , y_pred ): # \u4e3a\u6b63\u786e\u9884\u6d4b\u6570\u521d\u59cb\u5316\u4e00\u4e2a\u7b80\u5355\u8ba1\u6570\u5668 correct_counter = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): if yt == yp : # \u5982\u679c\u9884\u6d4b\u6807\u7b7e\u4e0e\u771f\u5b9e\u6807\u7b7e\u76f8\u540c\uff0c\u5219\u589e\u52a0\u8ba1\u6570\u5668 correct_counter += 1 # \u8fd4\u56de\u6b63\u786e\u7387\uff0c\u6b63\u786e\u6807\u7b7e\u6570/\u603b\u6807\u7b7e\u6570 return correct_counter / len ( y_true ) \u6211\u4eec\u8fd8\u53ef\u4ee5\u4f7f\u7528 scikit-learn \u8ba1\u7b97\u51c6\u786e\u7387\u3002 In [ X ]: from sklearn import metrics ... : l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] ... : metrics . 
accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u73b0\u5728\uff0c\u5047\u8bbe\u6211\u4eec\u628a\u6570\u636e\u96c6\u7a0d\u5fae\u6539\u52a8\u4e00\u4e0b\uff0c\u6709 180 \u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\uff0c\u53ea\u6709 20 \u5f20\u6709\u6c14\u80f8\u3002\u5373\u4f7f\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u4e5f\u8981\u521b\u5efa\u6b63\u8d1f\uff08\u6c14\u80f8\u4e0e\u975e\u6c14\u80f8\uff09\u76ee\u6807\u6bd4\u4f8b\u76f8\u540c\u7684\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u3002\u5728\u6bcf\u4e00\u7ec4\u4e2d\uff0c\u6211\u4eec\u6709 90 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u5982\u679c\u8bf4\u9a8c\u8bc1\u96c6\u4e2d\u7684\u6240\u6709\u56fe\u50cf\u90fd\u662f\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u90a3\u4e48\u60a8\u7684\u51c6\u786e\u7387\u4f1a\u662f\u591a\u5c11\u5462\uff1f\u8ba9\u6211\u4eec\u6765\u770b\u770b\uff1b\u60a8\u5bf9 90% \u7684\u56fe\u50cf\u8fdb\u884c\u4e86\u6b63\u786e\u5206\u7c7b\u3002\u56e0\u6b64\uff0c\u60a8\u7684\u51c6\u786e\u7387\u662f 90%\u3002 \u4f46\u8bf7\u518d\u770b\u4e00\u904d\u3002 \u4f60\u751a\u81f3\u6ca1\u6709\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\uff0c\u5c31\u5f97\u5230\u4e86 90% 
\u7684\u51c6\u786e\u7387\u3002\u8fd9\u4f3c\u4e4e\u6709\u70b9\u6ca1\u7528\u3002\u5982\u679c\u6211\u4eec\u4ed4\u7ec6\u89c2\u5bdf\uff0c\u5c31\u4f1a\u53d1\u73b0\u6570\u636e\u96c6\u662f\u504f\u659c\u7684\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u6bd4\u53e6\u4e00\u4e2a\u7c7b\u522b\u4e2d\u7684\u6837\u672c\u6570\u91cf\u591a\u5f88\u591a\u3002\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u4f7f\u7528\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u662f\u4e0d\u53ef\u53d6\u7684\uff0c\u56e0\u4e3a\u5b83\u4e0d\u80fd\u4ee3\u8868\u6570\u636e\u3002\u56e0\u6b64\uff0c\u60a8\u53ef\u80fd\u4f1a\u83b7\u5f97\u5f88\u9ad8\u7684\u51c6\u786e\u7387\uff0c\u4f46\u60a8\u7684\u6a21\u578b\u5728\u5b9e\u9645\u6837\u672c\u4e2d\u7684\u8868\u73b0\u53ef\u80fd\u5e76\u4e0d\u7406\u60f3\uff0c\u800c\u4e14\u60a8\u4e5f\u65e0\u6cd5\u5411\u7ecf\u7406\u89e3\u91ca\u539f\u56e0\u3002 \u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b \u7cbe\u786e\u7387 \u7b49\u5176\u4ed6\u6307\u6807\u3002 \u5728\u5b66\u4e60\u7cbe\u786e\u7387\u4e4b\u524d\uff0c\u6211\u4eec\u9700\u8981\u4e86\u89e3\u4e00\u4e9b\u672f\u8bed\u3002\u5728\u8fd9\u91cc\uff0c\u6211\u4eec\u5047\u8bbe\u6709\u6c14\u80f8\u7684\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e3a\u6b63\u7c7b (1)\uff0c\u6ca1\u6709\u6c14\u80f8\u7684\u4e3a\u8d1f\u7c7b (0)\u3002 \u771f\u9633\u6027 \uff08TP\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6709\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9633\u6027\u3002 \u771f\u9634\u6027 \uff08TN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u800c\u5b9e\u9645\u76ee\u6807\u663e\u793a\u8be5\u56fe\u50cf\u6ca1\u6709\u6c14\u80f8\uff0c\u5219\u89c6\u4e3a\u771f\u9634\u6027\u3002 
\u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u6b63\u786e\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9633\u6027\uff1b\u5982\u679c\u60a8\u7684\u6a21\u578b\u51c6\u786e\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5b83\u5c31\u662f\u771f\u9634\u6027\u3002 \u5047\u9633\u6027 \uff08FP\uff09 \uff1a\u7ed9\u5b9a\u4e00\u5f20\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u975e\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9633\u6027\u3002 \u5047\u9634\u6027 \uff08FN\uff09 \uff1a \u7ed9\u5b9a\u4e00\u5e45\u56fe\u50cf\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9884\u6d4b\u4e3a\u975e\u6c14\u80f8\uff0c\u800c\u8be5\u56fe\u50cf\u7684\u5b9e\u9645\u76ee\u6807\u662f\u6c14\u80f8\uff0c\u5219\u4e3a\u5047\u9634\u6027\u3002 \u7b80\u5355\u5730\u8bf4\uff0c\u5982\u679c\u60a8\u7684\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9633\u6027\u7c7b\uff0c\u90a3\u4e48\u5b83\u5c31\u662f\u5047\u9633\u6027\u3002\u5982\u679c\u6a21\u578b\u9519\u8bef\u5730\uff08\u6216\u865a\u5047\u5730\uff09\u9884\u6d4b\u4e86\u9634\u6027\u7c7b\u522b\uff0c\u5219\u662f\u5047\u9634\u6027\u3002 \u8ba9\u6211\u4eec\u9010\u4e00\u770b\u770b\u8fd9\u4e9b\u5b9e\u73b0\u3002 def true_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u8ba1\u6570\u5668 tp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 1 : tp += 1 # \u8fd4\u56de\u771f\u9633\u6027\u6837\u672c\u6570 return tp def true_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u771f\u9634\u6027\u6837\u672c\u8ba1\u6570\u5668 tn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # 
\u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u4e14\u9884\u6d4b\u6807\u7b7e\u4e5f\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 0 : tn += 1 # \u8fd4\u56de\u771f\u9634\u6027\u6837\u672c\u6570 return tn def false_positive ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9633\u6027\u8ba1\u6570\u5668 fp = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u8d1f\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u6b63\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 0 and yp == 1 : fp += 1 # \u8fd4\u56de\u5047\u9633\u6027\u6837\u672c\u6570 return fp def false_negative ( y_true , y_pred ): # \u521d\u59cb\u5316\u5047\u9634\u6027\u8ba1\u6570\u5668 fn = 0 # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_pred ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3a\u6b63\u7c7b\u800c\u9884\u6d4b\u6807\u7b7e\u4e3a\u8d1f\u7c7b\uff0c\u8ba1\u6570\u5668\u589e\u52a0 if yt == 1 and yp == 0 : fn += 1 # \u8fd4\u56de\u5047\u9634\u6027\u6570 return fn \u6211\u5728\u8fd9\u91cc\u5b9e\u73b0\u8fd9\u4e9b\u529f\u80fd\u7684\u65b9\u6cd5\u975e\u5e38\u7b80\u5355\uff0c\u800c\u4e14\u53ea\u9002\u7528\u4e8e\u4e8c\u5143\u5206\u7c7b\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u8fd9\u4e9b\u51fd\u6570\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... 
: l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: true_positive ( l1 , l2 ) Out [ X ]: 2 In [ X ]: false_positive ( l1 , l2 ) Out [ X ]: 1 In [ X ]: false_negative ( l1 , l2 ) Out [ X ]: 2 In [ X ]: true_negative ( l1 , l2 ) Out [ X ]: 3 \u5982\u679c\u6211\u4eec\u5fc5\u987b\u7528\u4e0a\u8ff0\u672f\u8bed\u6765\u5b9a\u4e49\u7cbe\u786e\u7387\uff0c\u6211\u4eec\u53ef\u4ee5\u5199\u4e3a\uff1a \\[ Accuracy Score = (TP + TN)/(TP + TN + FP +FN) \\] \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u5728 python \u4e2d\u4f7f\u7528 TP\u3001TN\u3001FP \u548c FN \u5feb\u901f\u5b9e\u73b0\u51c6\u786e\u5ea6\u5f97\u5206\u3002\u6211\u4eec\u5c06\u5176\u79f0\u4e3a accuracy_v2\u3002 def accuracy_v2 ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u51c6\u786e\u7387 accuracy_score = ( tp + tn ) / ( tp + tn + fp + fn ) return accuracy_score \u6211\u4eec\u53ef\u4ee5\u901a\u8fc7\u4e0e\u4e4b\u524d\u7684\u5b9e\u73b0\u548c scikit-learn \u7248\u672c\u8fdb\u884c\u6bd4\u8f83\uff0c\u5feb\u901f\u68c0\u67e5\u8be5\u51fd\u6570\u7684\u6b63\u786e\u6027\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: accuracy ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: accuracy_v2 ( l1 , l2 ) Out [ X ]: 0.625 In [ X ]: metrics . 
accuracy_score ( l1 , l2 ) Out [ X ]: 0.625 \u8bf7\u6ce8\u610f\uff0c\u5728\u8fd9\u6bb5\u4ee3\u7801\u4e2d\uff0cmetrics.accuracy_score \u6765\u81ea scikit-learn\u3002 \u5f88\u597d\u3002\u6240\u6709\u503c\u90fd\u5339\u914d\u3002\u8fd9\u8bf4\u660e\u6211\u4eec\u5728\u5b9e\u73b0\u8fc7\u7a0b\u4e2d\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u5176\u4ed6\u91cd\u8981\u6307\u6807\u3002 \u9996\u5148\u662f\u7cbe\u786e\u7387\u3002\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u662f \\[ Precision = TP/(TP + FP) \\] \u5047\u8bbe\u6211\u4eec\u5728\u65b0\u7684\u504f\u659c\u6570\u636e\u96c6\u4e0a\u5efa\u7acb\u4e86\u4e00\u4e2a\u65b0\u6a21\u578b\uff0c\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 90 \u5f20\u56fe\u50cf\u4e2d\u7684 80 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u548c 10 \u5f20\u56fe\u50cf\u4e2d\u7684 8 \u5f20\u6c14\u80f8\u56fe\u50cf\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6210\u529f\u8bc6\u522b\u4e86 100 \u5f20\u56fe\u50cf\u4e2d\u7684 88 \u5f20\u3002\u56e0\u6b64\uff0c\u51c6\u786e\u7387\u4e3a 0.88 \u6216 88%\u3002 \u4f46\u662f\uff0c\u5728\u8fd9 100 \u5f20\u6837\u672c\u4e2d\uff0c\u6709 10 \u5f20\u975e\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u6c14\u80f8\uff0c2 \u5f20\u6c14\u80f8\u56fe\u50cf\u88ab\u8bef\u5224\u4e3a\u975e\u6c14\u80f8\u3002 \u56e0\u6b64\uff0c\u6211\u4eec\u6709 TP : 8 TN: 80 FP: 10 FN: 2 \u7cbe\u786e\u7387\u4e3a 8 / (8 + 10) = 0.444\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u5728\u8bc6\u522b\u9633\u6027\u6837\u672c\uff08\u6c14\u80f8\uff09\u65f6\u6709 44.4% \u7684\u6b63\u786e\u7387\u3002 \u73b0\u5728\uff0c\u65e2\u7136\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86 TP\u3001TN\u3001FP \u548c FN\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5728 python \u4e2d\u5b9e\u73b0\u7cbe\u786e\u7387\u4e86\u3002 def precision ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true 
, y_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8ba9\u6211\u4eec\u8bd5\u8bd5\u8fd9\u79cd\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u65b9\u5f0f\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... : l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: precision ( l1 , l2 ) Out [ X ]: 0.6666666666666666 \u8fd9\u4f3c\u4e4e\u6ca1\u6709\u95ee\u9898\u3002 \u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u6765\u770b \u53ec\u56de\u7387 \u3002\u53ec\u56de\u7387\u7684\u5b9a\u4e49\u662f\uff1a \\[ Recall = TP/(TP + FN) \\] \u5728\u4e0a\u8ff0\u60c5\u51b5\u4e0b\uff0c\u53ec\u56de\u7387\u4e3a 8 / (8 + 2) = 0.80\u3002\u8fd9\u610f\u5473\u7740\u6211\u4eec\u7684\u6a21\u578b\u6b63\u786e\u8bc6\u522b\u4e86 80% \u7684\u9633\u6027\u6837\u672c\u3002 def recall ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) # \u53ec\u56de\u7387 recall = tp / ( tp + fn ) return recall \u5c31\u6211\u4eec\u7684\u4e24\u4e2a\u5c0f\u5217\u8868\u800c\u8a00\uff0c\u53ec\u56de\u7387\u5e94\u8be5\u662f 0.5\u3002\u8ba9\u6211\u4eec\u68c0\u67e5\u4e00\u4e0b\u3002 In [ X ]: l1 = [ 0 , 1 , 1 , 1 , 0 , 0 , 0 , 1 ] ... 
: l2 = [ 0 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ] In [ X ]: recall ( l1 , l2 ) Out [ X ]: 0.5 \u8fd9\u4e0e\u6211\u4eec\u7684\u8ba1\u7b97\u503c\u76f8\u7b26\uff01 \u5bf9\u4e8e\u4e00\u4e2a \"\u597d \"\u6a21\u578b\u6765\u8bf4\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u503c\u90fd\u5e94\u8be5\u5f88\u9ad8\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u4e0a\u9762\u7684\u4f8b\u5b50\u4e2d\uff0c\u53ec\u56de\u503c\u76f8\u5f53\u9ad8\u3002\u4f46\u662f\uff0c\u7cbe\u786e\u7387\u5374\u5f88\u4f4e\uff01\u6211\u4eec\u7684\u6a21\u578b\u4ea7\u751f\u4e86\u5927\u91cf\u7684\u8bef\u62a5\uff0c\u4f46\u8bef\u62a5\u8f83\u5c11\u3002\u5728\u8fd9\u7c7b\u95ee\u9898\u4e2d\uff0c\u5047\u9634\u6027\u8f83\u5c11\u662f\u597d\u4e8b\uff0c\u56e0\u4e3a\u4f60\u4e0d\u60f3\u5728\u75c5\u4eba\u6709\u6c14\u80f8\u7684\u60c5\u51b5\u4e0b\u5374\u8bf4\u4ed6\u4eec\u6ca1\u6709\u6c14\u80f8\u3002\u8fd9\u6837\u505a\u4f1a\u9020\u6210\u66f4\u5927\u7684\u4f24\u5bb3\u3002\u4f46\u6211\u4eec\u4e5f\u6709\u5f88\u591a\u5047\u9633\u6027\u7ed3\u679c\uff0c\u8fd9\u4e5f\u4e0d\u662f\u597d\u4e8b\u3002 \u5927\u591a\u6570\u6a21\u578b\u90fd\u4f1a\u9884\u6d4b\u4e00\u4e2a\u6982\u7387\uff0c\u5f53\u6211\u4eec\u9884\u6d4b\u65f6\uff0c\u901a\u5e38\u4f1a\u5c06\u8fd9\u4e2a\u9608\u503c\u9009\u4e3a 0.5\u3002\u8fd9\u4e2a\u9608\u503c\u5e76\u4e0d\u603b\u662f\u7406\u60f3\u7684\uff0c\u6839\u636e\u8fd9\u4e2a\u9608\u503c\uff0c\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u503c\u53ef\u80fd\u4f1a\u53d1\u751f\u5f88\u5927\u7684\u53d8\u5316\u3002\u5982\u679c\u6211\u4eec\u9009\u62e9\u7684\u6bcf\u4e2a\u9608\u503c\u90fd\u80fd\u8ba1\u7b97\u51fa\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u90a3\u4e48\u6211\u4eec\u5c31\u53ef\u4ee5\u5728\u8fd9\u4e9b\u503c\u4e4b\u95f4\u7ed8\u5236\u51fa\u66f2\u7ebf\u56fe\u3002\u8fd9\u5e45\u56fe\u6216\u66f2\u7ebf\u88ab\u79f0\u4e3a \"\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\"\u3002 \u5728\u7814\u7a76\u7cbe\u786e\u7387-\u8c03\u7528\u66f2\u7ebf\u4e4b\u524d\uff0c\u6211\u4eec\u5148\u5047\u8bbe\u6709\u4e24\u4e2a\u5217\u8868\u3002 In [ X ]: 
y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0.02638412 , 0.11114267 , 0.31620708 , ... : 0.0490937 , 0.0191491 , 0.17554844 , ... : 0.15952202 , 0.03819563 , 0.11639273 , ... : 0.079377 , 0.08584789 , 0.39095342 , ... : 0.27259048 , 0.03447096 , 0.04644807 , ... : 0.03543574 , 0.18521942 , 0.05934905 , ... : 0.61977213 , 0.33056815 ] \u56e0\u6b64\uff0cy_true \u662f\u6211\u4eec\u7684\u76ee\u6807\u503c\uff0c\u800c y_pred \u662f\u6837\u672c\u88ab\u8d4b\u503c\u4e3a 1 \u7684\u6982\u7387\u503c\u3002\u56e0\u6b64\uff0c\u73b0\u5728\u6211\u4eec\u8981\u770b\u7684\u662f\u9884\u6d4b\u4e2d\u7684\u6982\u7387\uff0c\u800c\u4e0d\u662f\u9884\u6d4b\u503c\uff08\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u9884\u6d4b\u503c\u7684\u8ba1\u7b97\u9608\u503c\u4e3a 0.5\uff09\u3002 precisions = [] recalls = [] thresholds = [ 0.0490937 , 0.05934905 , 0.079377 , 0.08584789 , 0.11114267 , 0.11639273 , 0.15952202 , 0.17554844 , 0.18521942 , 0.27259048 , 0.31620708 , 0.33056815 , 0.39095342 , 0.61977213 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for i in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_prediction = [ 1 if x >= i else 0 for x in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , temp_prediction ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , temp_prediction ) # \u52a0\u5165\u7cbe\u786e\u7387\u5217\u8868 precisions . append ( p ) # \u52a0\u5165\u53ec\u56de\u7387\u5217\u8868 recalls . append ( r ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u7ed8\u5236\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 7 , 7 )) # x\u8f74\u4e3a\u53ec\u56de\u7387\uff0cy\u8f74\u4e3a\u7cbe\u786e\u7387 plt . plot ( recalls , precisions ) # \u6dfb\u52a0x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a15 plt . 
xlabel ( 'Recall' , fontsize = 15 ) # \u6dfb\u52a0y\u8f74\u6807\u7b7e\uff0c\u5b57\u6761\u5927\u5c0f\u4e3a15 plt . ylabel ( 'Precision' , fontsize = 15 ) \u56fe 2 \u663e\u793a\u4e86\u6211\u4eec\u901a\u8fc7\u8fd9\u79cd\u65b9\u6cd5\u5f97\u5230\u7684\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf\u3002 \u56fe 2\uff1a\u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u8fd9\u6761 \u7cbe\u786e\u7387-\u53ec\u56de\u7387\u66f2\u7ebf \u4e0e\u60a8\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230\u7684\u66f2\u7ebf\u622a\u7136\u4e0d\u540c\u3002\u8fd9\u662f\u56e0\u4e3a\u6211\u4eec\u53ea\u6709 20 \u4e2a\u6837\u672c\uff0c\u5176\u4e2d\u53ea\u6709 3 \u4e2a\u662f\u9633\u6027\u6837\u672c\u3002\u4f46\u8fd9\u6ca1\u4ec0\u4e48\u597d\u62c5\u5fc3\u7684\u3002\u8fd9\u8fd8\u662f\u90a3\u6761\u7cbe\u786e\u7387-\u53ec\u56de\u66f2\u7ebf\u3002 \u4f60\u4f1a\u53d1\u73b0\uff0c\u9009\u62e9\u4e00\u4e2a\u65e2\u80fd\u63d0\u4f9b\u826f\u597d\u7cbe\u786e\u7387\u53c8\u80fd\u63d0\u4f9b\u53ec\u56de\u503c\u7684\u9608\u503c\u662f\u5f88\u6709\u6311\u6218\u6027\u7684\u3002\u5982\u679c\u9608\u503c\u8fc7\u9ad8\uff0c\u771f\u9633\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u51cf\u5c11\uff0c\u800c\u5047\u9634\u6027\u7684\u6570\u91cf\u5c31\u4f1a\u589e\u52a0\u3002\u8fd9\u4f1a\u964d\u4f4e\u53ec\u56de\u7387\uff0c\u4f46\u7cbe\u786e\u7387\u5f97\u5206\u4f1a\u5f88\u9ad8\u3002\u5982\u679c\u5c06\u9608\u503c\u964d\u5f97\u592a\u4f4e\uff0c\u5219\u8bef\u62a5\u4f1a\u5927\u91cf\u589e\u52a0\uff0c\u7cbe\u786e\u7387\u4e5f\u4f1a\u964d\u4f4e\u3002 \u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u8d8a\u63a5\u8fd1 1 \u8d8a\u597d\u3002 F1 \u5206\u6570\u662f\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7efc\u5408\u6307\u6807\u3002\u5b83\u88ab\u5b9a\u4e49\u4e3a\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u7684\u7b80\u5355\u52a0\u6743\u5e73\u5747\u503c\uff08\u8c03\u548c\u5e73\u5747\u503c\uff09\u3002\u5982\u679c\u6211\u4eec\u7528 P \u8868\u793a\u7cbe\u786e\u7387\uff0c\u7528 R 
\u8868\u793a\u53ec\u56de\u7387\uff0c\u90a3\u4e48 F1 \u5206\u6570\u53ef\u4ee5\u8868\u793a\u4e3a\uff1a \\[ F1 = 2PR/(P + R) \\] \u6839\u636e TP\u3001FP \u548c FN\uff0c\u7a0d\u52a0\u6570\u5b66\u8ba1\u7b97\u5c31\u80fd\u5f97\u51fa\u4ee5\u4e0b F1 \u7b49\u5f0f\uff1a \\[ F1 = 2TP/(2TP + FP + FN) \\] Python \u5b9e\u73b0\u5f88\u7b80\u5355\uff0c\u56e0\u4e3a\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u4e86\u8fd9\u4e9b def f1 ( y_true , y_pred ): # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( y_true , y_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( y_true , y_pred ) # \u8ba1\u7b97f1\u503c score = 2 * p * r / ( p + r ) return score \u8ba9\u6211\u4eec\u770b\u770b\u5176\u7ed3\u679c\uff0c\u5e76\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\u3002 In [ X ]: y_true = [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: y_pred = [ 0 , 0 , 1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , ... : 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 ] In [ X ]: f1 ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u901a\u8fc7 scikit learn\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u76f8\u540c\u7684\u5217\u8868\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . 
f1_score ( y_true , y_pred ) Out [ X ]: 0.5714285714285715 \u4e0e\u5176\u5355\u72ec\u770b\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\uff0c\u60a8\u8fd8\u53ef\u4ee5\u53ea\u770b F1 \u5206\u6570\u3002\u4e0e\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c\u51c6\u786e\u5ea6\u4e00\u6837\uff0cF1 \u5206\u6570\u7684\u8303\u56f4\u4e5f\u662f\u4ece 0 \u5230 1\uff0c\u5b8c\u7f8e\u9884\u6d4b\u6a21\u578b\u7684 F1 \u5206\u6570\u4e3a 1\u3002 \u6b64\u5916\uff0c\u6211\u4eec\u8fd8\u5e94\u8be5\u4e86\u89e3\u5176\u4ed6\u4e00\u4e9b\u5173\u952e\u672f\u8bed\u3002 \u7b2c\u4e00\u4e2a\u672f\u8bed\u662f TPR \u6216\u771f\u9633\u6027\u7387\uff08True Positive Rate\uff09\uff0c\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\u3002 \\[ TPR = TP/(TP + FN) \\] \u5c3d\u7ba1\u5b83\u4e0e\u53ec\u56de\u7387\u76f8\u540c\uff0c\u4f46\u6211\u4eec\u5c06\u4e3a\u5b83\u521b\u5efa\u4e00\u4e2a python \u51fd\u6570\uff0c\u4ee5\u4fbf\u4eca\u540e\u4f7f\u7528\u8fd9\u4e2a\u540d\u79f0\u3002 def tpr ( y_true , y_pred ): # \u771f\u9633\u6027\u7387\uff08TPR\uff09\uff0c\u4e0e\u53ec\u56de\u7387\u8ba1\u7b97\u516c\u5f0f\u4e00\u81f4 return recall ( y_true , y_pred ) TPR \u6216\u53ec\u56de\u7387\u4e5f\u88ab\u79f0\u4e3a\u7075\u654f\u5ea6\u3002 \u800c FPR \u6216\u5047\u9633\u6027\u7387\uff08False Positive Rate\uff09\u7684\u5b9a\u4e49\u662f\uff1a \\[ FPR = FP / (TN + FP) \\] def fpr ( y_true , y_pred ): # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u8fd4\u56de\u5047\u9633\u6027\u7387\uff08FPR\uff09 return fp / ( tn + fp ) 1 - FPR \u88ab\u79f0\u4e3a\u7279\u5f02\u6027\u6216\u771f\u9634\u6027\u7387\u6216 TNR\u3002\u8fd9\u4e9b\u672f\u8bed\u5f88\u591a\uff0c\u4f46\u5176\u4e2d\u6700\u91cd\u8981\u7684\u53ea\u6709 TPR \u548c FPR\u3002\u5047\u8bbe\u6211\u4eec\u53ea\u6709 15 \u4e2a\u6837\u672c\uff0c\u5176\u76ee\u6807\u503c\u4e3a\u4e8c\u5143\uff1a Actual targets : [0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1] 
\u6211\u4eec\u8bad\u7ec3\u4e00\u4e2a\u7c7b\u4f3c\u968f\u673a\u68ee\u6797\u7684\u6a21\u578b\uff0c\u5c31\u80fd\u5f97\u5230\u6837\u672c\u5448\u9633\u6027\u7684\u6982\u7387\u3002 Predicted probabilities for 1: [0.1, 0.3, 0.2, 0.6, 0.8, 0.05, 0.9, 0.5, 0.3, 0.66, 0.3, 0.2, 0.85, 0.15, 0.99] \u5bf9\u4e8e >= 0.5 \u7684\u5178\u578b\u9608\u503c\uff0c\u6211\u4eec\u53ef\u4ee5\u8bc4\u4f30\u4e0a\u8ff0\u6240\u6709\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387/TPR\u3001F1 \u548c FPR \u503c\u3002\u4f46\u662f\uff0c\u5982\u679c\u6211\u4eec\u5c06\u9608\u503c\u9009\u4e3a 0.4 \u6216 0.6\uff0c\u4e5f\u53ef\u4ee5\u505a\u5230\u8fd9\u4e00\u70b9\u3002\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0 \u5230 1 \u4e4b\u95f4\u7684\u4efb\u4f55\u503c\uff0c\u5e76\u8ba1\u7b97\u4e0a\u8ff0\u6240\u6709\u6307\u6807\u3002 \u4e0d\u8fc7\uff0c\u6211\u4eec\u53ea\u8ba1\u7b97\u4e24\u4e2a\u503c\uff1a TPR \u548c FPR\u3002 # \u521d\u59cb\u5316\u771f\u9633\u6027\u7387\u5217\u8868 tpr_list = [] # \u521d\u59cb\u5316\u5047\u9633\u6027\u7387\u5217\u8868 fpr_list = [] # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u7387 temp_tpr = tpr ( y_true , temp_pred ) # \u5047\u9633\u6027\u7387 temp_fpr = fpr ( y_true , temp_pred ) # \u5c06\u771f\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 tpr_list . append ( temp_tpr ) # \u5c06\u5047\u9633\u6027\u7387\u52a0\u5165\u5217\u8868 fpr_list . 
append ( temp_fpr ) \u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u4ee5\u5f97\u5230\u6bcf\u4e2a\u9608\u503c\u7684 TPR \u503c\u548c FPR \u503c\u3002 \u56fe 3\uff1a\u9608\u503c\u3001TPR \u548c FPR \u503c\u8868 \u5982\u679c\u6211\u4eec\u7ed8\u5236\u5982\u56fe 3 \u6240\u793a\u7684\u8868\u683c\uff0c\u5373\u4ee5 TPR \u4e3a Y \u8f74\uff0cFPR \u4e3a X \u8f74\uff0c\u5c31\u4f1a\u5f97\u5230\u5982\u56fe 4 \u6240\u793a\u7684\u66f2\u7ebf\u3002 \u56fe 4\uff1aROC\u66f2\u7ebf \u8fd9\u6761\u66f2\u7ebf\u4e5f\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u3002\u5982\u679c\u6211\u4eec\u8ba1\u7b97\u8fd9\u6761 ROC \u66f2\u7ebf\u4e0b\u7684\u9762\u79ef\uff0c\u5c31\u662f\u5728\u8ba1\u7b97\u53e6\u4e00\u4e2a\u6307\u6807\uff0c\u5f53\u6570\u636e\u96c6\u7684\u4e8c\u5143\u76ee\u6807\u504f\u659c\u65f6\uff0c\u8fd9\u4e2a\u6307\u6807\u5c31\u4f1a\u975e\u5e38\u5e38\u7528\u3002 \u8fd9\u4e2a\u6307\u6807\u88ab\u79f0\u4e3a ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u6216\u66f2\u7ebf\u4e0b\u9762\u79ef\uff0c\u7b80\u79f0 AUC\u3002\u8ba1\u7b97 ROC \u66f2\u7ebf\u4e0b\u9762\u79ef\u7684\u65b9\u6cd5\u6709\u5f88\u591a\u3002\u5728\u6b64\uff0c\u6211\u4eec\u5c06\u91c7\u7528 scikit- learn \u7684\u5947\u5999\u5b9e\u73b0\u65b9\u6cd5\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: metrics . 
roc_auc_score ( y_true , y_pred ) Out [ X ]: 0.8300000000000001 AUC \u503c\u4ece 0 \u5230 1 \u4e0d\u7b49\u3002 AUC = 1 \u610f\u5473\u7740\u60a8\u62e5\u6709\u4e00\u4e2a\u5b8c\u7f8e\u7684\u6a21\u578b\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u610f\u5473\u7740\u4f60\u5728\u9a8c\u8bc1\u65f6\u72af\u4e86\u4e00\u4e9b\u9519\u8bef\uff0c\u5e94\u8be5\u91cd\u65b0\u5ba1\u89c6\u6570\u636e\u5904\u7406\u548c\u9a8c\u8bc1\u6d41\u7a0b\u3002\u5982\u679c\u4f60\u6ca1\u6709\u72af\u4efb\u4f55\u9519\u8bef\uff0c\u90a3\u4e48\u606d\u559c\u4f60\uff0c\u4f60\u5df2\u7ecf\u62e5\u6709\u4e86\u9488\u5bf9\u6570\u636e\u96c6\u5efa\u7acb\u7684\u6700\u4f73\u6a21\u578b\u3002 AUC = 0 \u610f\u5473\u7740\u60a8\u7684\u6a21\u578b\u975e\u5e38\u7cdf\u7cd5\uff08\u6216\u975e\u5e38\u597d\uff01\uff09\u3002\u8bd5\u7740\u53cd\u8f6c\u9884\u6d4b\u7684\u6982\u7387\uff0c\u4f8b\u5982\uff0c\u5982\u679c\u60a8\u9884\u6d4b\u6b63\u7c7b\u7684\u6982\u7387\u662f p\uff0c\u8bd5\u7740\u7528 1-p \u4ee3\u66ff\u5b83\u3002\u8fd9\u79cd AUC \u4e5f\u53ef\u80fd\u610f\u5473\u7740\u60a8\u7684\u9a8c\u8bc1\u6216\u6570\u636e\u5904\u7406\u5b58\u5728\u95ee\u9898\u3002 AUC = 0.5 \u610f\u5473\u7740\u4f60\u7684\u9884\u6d4b\u662f\u968f\u673a\u7684\u3002\u56e0\u6b64\uff0c\u5bf9\u4e8e\u4efb\u4f55\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u5982\u679c\u6211\u5c06\u6240\u6709\u76ee\u6807\u90fd\u9884\u6d4b\u4e3a 0.5\uff0c\u6211\u5c06\u5f97\u5230 0.5 \u7684 AUC\u3002 AUC \u503c\u4ecb\u4e8e 0 \u548c 0.5 \u4e4b\u95f4\uff0c\u610f\u5473\u7740\u4f60\u7684\u6a21\u578b\u6bd4\u968f\u673a\u6a21\u578b\u66f4\u5dee\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u8fd9\u662f\u56e0\u4e3a\u4f60\u98a0\u5012\u4e86\u7c7b\u522b\u3002 \u5982\u679c\u60a8\u5c1d\u8bd5\u53cd\u8f6c\u9884\u6d4b\uff0c\u60a8\u7684 AUC \u503c\u53ef\u80fd\u4f1a\u8d85\u8fc7 0.5\u3002\u63a5\u8fd1 1 \u7684 AUC \u503c\u88ab\u8ba4\u4e3a\u662f\u597d\u503c\u3002 \u4f46 AUC \u5bf9\u6211\u4eec\u7684\u6a21\u578b\u6709\u4ec0\u4e48\u5f71\u54cd\u5462\uff1f 
\u5047\u8bbe\u60a8\u5efa\u7acb\u4e86\u4e00\u4e2a\u4ece\u80f8\u90e8 X \u5149\u56fe\u50cf\u4e2d\u68c0\u6d4b\u6c14\u80f8\u7684\u6a21\u578b\uff0c\u5176 AUC \u503c\u4e3a 0.85\u3002\u8fd9\u610f\u5473\u7740\uff0c\u5982\u679c\u60a8\u4ece\u6570\u636e\u96c6\u4e2d\u968f\u673a\u9009\u62e9\u4e00\u5f20\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9633\u6027\u6837\u672c\uff09\u548c\u53e6\u4e00\u5f20\u6ca1\u6709\u6c14\u80f8\u7684\u56fe\u50cf\uff08\u9634\u6027\u6837\u672c\uff09\uff0c\u90a3\u4e48\u6c14\u80f8\u56fe\u50cf\u7684\u6392\u540d\u5c06\u9ad8\u4e8e\u975e\u6c14\u80f8\u56fe\u50cf\uff0c\u6982\u7387\u4e3a 0.85\u3002 \u8ba1\u7b97\u6982\u7387\u548c AUC \u540e\uff0c\u60a8\u9700\u8981\u5bf9\u6d4b\u8bd5\u96c6\u8fdb\u884c\u9884\u6d4b\u3002\u6839\u636e\u95ee\u9898\u548c\u4f7f\u7528\u60c5\u51b5\uff0c\u60a8\u53ef\u80fd\u9700\u8981\u6982\u7387\u6216\u5b9e\u9645\u7c7b\u522b\u3002\u5982\u679c\u4f60\u60f3\u8981\u6982\u7387\uff0c\u8fd9\u5e76\u4e0d\u96be\u3002\u5982\u679c\u60a8\u60f3\u8981\u7c7b\u522b\uff0c\u5219\u9700\u8981\u9009\u62e9\u4e00\u4e2a\u9608\u503c\u3002\u5728\u4e8c\u5143\u5206\u7c7b\u7684\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u91c7\u7528\u7c7b\u4f3c\u4e0b\u9762\u7684\u65b9\u6cd5\u3002 \\[ Prediction = Probability >= Threshold \\] \u4e5f\u5c31\u662f\u8bf4\uff0c\u9884\u6d4b\u662f\u4e00\u4e2a\u53ea\u5305\u542b\u4e8c\u5143\u53d8\u91cf\u7684\u65b0\u5217\u8868\u3002\u5982\u679c\u6982\u7387\u5927\u4e8e\u6216\u7b49\u4e8e\u7ed9\u5b9a\u7684\u9608\u503c\uff0c\u5219\u9884\u6d4b\u4e2d\u7684\u4e00\u9879\u4e3a 1\uff0c\u5426\u5219\u4e3a 0\u3002 \u4f60\u731c\u600e\u4e48\u7740\uff0c\u4f60\u53ef\u4ee5\u4f7f\u7528 ROC \u66f2\u7ebf\u6765\u9009\u62e9\u8fd9\u4e2a\u9608\u503c\uff01ROC \u66f2\u7ebf\u4f1a\u544a\u8bc9\u60a8\u9608\u503c\u5bf9\u5047\u9633\u6027\u7387\u548c\u771f\u9633\u6027\u7387\u7684\u5f71\u54cd\uff0c\u8fdb\u800c\u5f71\u54cd\u5047\u9633\u6027\u548c\u771f\u9633\u6027\u3002\u60a8\u5e94\u8be5\u9009\u62e9\u6700\u9002\u5408\u60a8\u7684\u95ee\u9898\u548c\u6570\u636e\u96c6\u7684\u9608\u503c\u3002 
\u4f8b\u5982\uff0c\u5982\u679c\u60a8\u4e0d\u5e0c\u671b\u6709\u592a\u591a\u7684\u8bef\u62a5\uff0c\u90a3\u4e48\u9608\u503c\u5c31\u5e94\u8be5\u9ad8\u4e00\u4e9b\u3002\u4e0d\u8fc7\uff0c\u8fd9\u4e5f\u4f1a\u5e26\u6765\u66f4\u591a\u7684\u8bef\u62a5\u3002\u6ce8\u610f\u6743\u8861\u5229\u5f0a\uff0c\u9009\u62e9\u6700\u4f73\u9608\u503c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u8fd9\u4e9b\u9608\u503c\u5982\u4f55\u5f71\u54cd\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u503c\u3002 # \u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list = [] # \u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list = [] # \u771f\u5b9e\u6807\u7b7e y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] # \u9884\u6d4b\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387 y_pred = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , 0.85 , 0.15 , 0.99 ] # \u9884\u6d4b\u9608\u503c thresholds = [ 0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.85 , 0.9 , 0.99 , 1.0 ] # \u904d\u5386\u9884\u6d4b\u9608\u503c for thresh in thresholds : # \u82e5\u6837\u672c\u4e3a\u6b63\u7c7b\uff081\uff09\u7684\u6982\u7387\u5927\u4e8e\u9608\u503c\uff0c\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if x >= thresh else 0 for x in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 temp_tp = true_positive ( y_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 temp_fp = false_positive ( y_true , temp_pred ) # \u52a0\u5165\u771f\u9633\u6027\u6837\u672c\u6570\u5217\u8868 tp_list . append ( temp_tp ) # \u52a0\u5165\u5047\u9633\u6027\u6837\u672c\u6570\u5217\u8868 fp_list . 
append ( temp_fp ) \u5229\u7528\u8fd9\u4e00\u70b9\uff0c\u6211\u4eec\u53ef\u4ee5\u521b\u5efa\u4e00\u4e2a\u8868\u683c\uff0c\u5982\u56fe 5 \u6240\u793a\u3002 \u56fe 5\uff1a\u4e0d\u540c\u9608\u503c\u7684 TP \u503c\u548c FP \u503c \u5982\u56fe 6 \u6240\u793a\uff0c\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0cROC \u66f2\u7ebf\u5de6\u4e0a\u89d2\u7684\u503c\u5e94\u8be5\u662f\u4e00\u4e2a\u76f8\u5f53\u4e0d\u9519\u7684\u9608\u503c\u3002 \u5bf9\u6bd4\u8868\u683c\u548c ROC \u66f2\u7ebf\uff0c\u6211\u4eec\u53ef\u4ee5\u53d1\u73b0\uff0c0.6 \u5de6\u53f3\u7684\u9608\u503c\u76f8\u5f53\u4e0d\u9519\uff0c\u65e2\u4e0d\u4f1a\u4e22\u5931\u5927\u91cf\u7684\u771f\u9633\u6027\u7ed3\u679c\uff0c\u4e5f\u4e0d\u4f1a\u51fa\u73b0\u5927\u91cf\u7684\u5047\u9633\u6027\u7ed3\u679c\u3002 \u56fe 6\uff1a\u4ece ROC \u66f2\u7ebf\u6700\u5de6\u4fa7\u7684\u9876\u70b9\u9009\u62e9\u6700\u4f73\u9608\u503c AUC \u662f\u4e1a\u5185\u5e7f\u6cdb\u5e94\u7528\u4e8e\u504f\u659c\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6307\u6807\uff0c\u4e5f\u662f\u6bcf\u4e2a\u4eba\u90fd\u5e94\u8be5\u4e86\u89e3\u7684\u6307\u6807\u3002\u4e00\u65e6\u7406\u89e3\u4e86 AUC \u80cc\u540e\u7684\u7406\u5ff5\uff08\u5982\u4e0a\u6587\u6240\u8ff0\uff09\uff0c\u4e5f\u5c31\u5f88\u5bb9\u6613\u5411\u4e1a\u754c\u53ef\u80fd\u4f1a\u8bc4\u4f30\u60a8\u7684\u6a21\u578b\u7684\u975e\u6280\u672f\u4eba\u5458\u89e3\u91ca\u5b83\u4e86\u3002 \u5b66\u4e60 AUC \u540e\uff0c\u4f60\u5e94\u8be5\u5b66\u4e60\u7684\u53e6\u4e00\u4e2a\u91cd\u8981\u6307\u6807\u662f\u5bf9\u6570\u635f\u5931\u3002\u5bf9\u4e8e\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\uff0c\u6211\u4eec\u5c06\u5bf9\u6570\u635f\u5931\u5b9a\u4e49\u4e3a\uff1a \\[ LogLoss = -1.0 \\times (target \\times log(prediction) + (1-target) \\times log(1-prediction)) \\] \u5176\u4e2d\uff0c\u76ee\u6807\u503c\u4e3a 0 \u6216 1\uff0c\u9884\u6d4b\u503c\u4e3a\u6837\u672c\u5c5e\u4e8e\u7c7b\u522b 1 \u7684\u6982\u7387\u3002 
\u5bf9\u4e8e\u6570\u636e\u96c6\u4e2d\u7684\u591a\u4e2a\u6837\u672c\uff0c\u6240\u6709\u6837\u672c\u7684\u5bf9\u6570\u635f\u5931\u53ea\u662f\u6240\u6709\u5355\u4e2a\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u9700\u8981\u8bb0\u4f4f\u7684\u4e00\u70b9\u662f\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u4e0d\u6b63\u786e\u6216\u504f\u5dee\u8f83\u5927\u7684\u9884\u6d4b\u8fdb\u884c\u76f8\u5f53\u9ad8\u7684\u60e9\u7f5a\uff0c\u4e5f\u5c31\u662f\u8bf4\uff0c\u5bf9\u6570\u635f\u5931\u4f1a\u5bf9\u975e\u5e38\u786e\u5b9a\u548c\u975e\u5e38\u9519\u8bef\u7684\u9884\u6d4b\u8fdb\u884c\u60e9\u7f5a\u3002 import numpy as np def log_loss ( y_true , y_proba ): # \u6781\u5c0f\u503c\uff0c\u9632\u6b620\u505a\u5206\u6bcd epsilon = 1e-15 # \u5bf9\u6570\u635f\u5931\u5217\u8868 loss = [] # \u904d\u5386y_true\uff0cy_pred\u4e2d\u6240\u6709\u5143\u7d20 for yt , yp in zip ( y_true , y_proba ): # \u9650\u5236yp\u8303\u56f4\uff0c\u6700\u5c0f\u4e3aepsilon\uff0c\u6700\u5927\u4e3a1-epsilon yp = np . clip ( yp , epsilon , 1 - epsilon ) # \u8ba1\u7b97\u5bf9\u6570\u635f\u5931 temp_loss = - 1.0 * ( yt * np . log ( yp ) + ( 1 - yt ) * np . log ( 1 - yp )) # \u52a0\u5165\u5bf9\u6570\u635f\u5931\u5217\u8868 loss . append ( temp_loss ) return np . mean ( loss ) \u8ba9\u6211\u4eec\u6d4b\u8bd5\u4e00\u4e0b\u51fd\u6570\u6267\u884c\u60c5\u51b5\uff1a In [ X ]: y_true = [ 0 , 0 , 0 , 0 , 1 , 0 , 1 , ... : 0 , 0 , 1 , 0 , 1 , 0 , 0 , 1 ] In [ X ]: y_proba = [ 0.1 , 0.3 , 0.2 , 0.6 , 0.8 , 0.05 , ... : 0.9 , 0.5 , 0.3 , 0.66 , 0.3 , 0.2 , ... : 0.85 , 0.15 , 0.99 ] In [ X ]: log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u6211\u4eec\u53ef\u4ee5\u5c06\u5176\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff1a In [ X ]: from sklearn import metrics In [ X ]: metrics . 
log_loss ( y_true , y_proba ) Out [ X ]: 0.49882711861432294 \u56e0\u6b64\uff0c\u6211\u4eec\u7684\u5b9e\u73b0\u662f\u6b63\u786e\u7684\u3002 \u5bf9\u6570\u635f\u5931\u7684\u5b9e\u73b0\u5f88\u5bb9\u6613\u3002\u89e3\u91ca\u8d77\u6765\u4f3c\u4e4e\u6709\u70b9\u56f0\u96be\u3002\u4f60\u5fc5\u987b\u8bb0\u4f4f\uff0c\u5bf9\u6570\u635f\u5931\u7684\u60e9\u7f5a\u8981\u6bd4\u5176\u4ed6\u6307\u6807\u5927\u5f97\u591a\u3002 \u4f8b\u5982\uff0c\u5982\u679c\u60a8\u6709 51% \u7684\u628a\u63e1\u8ba4\u4e3a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\uff0c\u90a3\u4e48\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (1 \\times log(0.51) + (1 - 1) \\times log(1 - 0.51))=0.67 \\] \u5982\u679c\u4f60\u5bf9\u5c5e\u4e8e 0 \u7c7b\u7684\u6837\u672c\u6709 49% \u7684\u628a\u63e1\uff0c\u5bf9\u6570\u635f\u5931\u5c31\u662f\uff1a \\[ -1.0 \\times (1 \\times log(0.49) + (1 - 1) \\times log(1 - 0.49))=0.67 \\] \u56e0\u6b64\uff0c\u5373\u4f7f\u6211\u4eec\u53ef\u4ee5\u9009\u62e9 0.5 \u7684\u622a\u65ad\u503c\u5e76\u5f97\u5230\u5b8c\u7f8e\u7684\u9884\u6d4b\u7ed3\u679c\uff0c\u6211\u4eec\u4ecd\u7136\u4f1a\u6709\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002\u56e0\u6b64\uff0c\u5728\u5904\u7406\u5bf9\u6570\u635f\u5931\u65f6\uff0c\u4f60\u9700\u8981\u975e\u5e38\u5c0f\u5fc3\uff1b\u4efb\u4f55\u4e0d\u786e\u5b9a\u7684\u9884\u6d4b\u90fd\u4f1a\u4ea7\u751f\u975e\u5e38\u9ad8\u7684\u5bf9\u6570\u635f\u5931\u3002 \u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u5927\u591a\u6570\u6307\u6807\u90fd\u53ef\u4ee5\u8f6c\u6362\u6210\u591a\u7c7b\u7248\u672c\u3002\u8fd9\u4e2a\u60f3\u6cd5\u5f88\u7b80\u5355\u3002\u4ee5\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u4e3a\u4f8b\u3002\u6211\u4eec\u53ef\u4ee5\u8ba1\u7b97\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u4e2d\u6bcf\u4e00\u7c7b\u7684\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 
\u6709\u4e09\u79cd\u4e0d\u540c\u7684\u8ba1\u7b97\u65b9\u6cd5\uff0c\u6709\u65f6\u53ef\u80fd\u4f1a\u4ee4\u4eba\u56f0\u60d1\u3002\u5047\u8bbe\u6211\u4eec\u9996\u5148\u5bf9\u7cbe\u786e\u7387\u611f\u5174\u8da3\u3002\u6211\u4eec\u77e5\u9053\uff0c\u7cbe\u786e\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u548c\u5047\u9633\u6027\u3002 \u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Macro averaged precision\uff09\uff1a\u5206\u522b\u8ba1\u7b97\u6240\u6709\u7c7b\u522b\u7684\u7cbe\u786e\u7387\u7136\u540e\u6c42\u5e73\u5747\u503c \u5fae\u89c2\u5e73\u5747\u7cbe\u786e\u7387 \uff08Micro averaged precision\uff09\uff1a\u6309\u7c7b\u8ba1\u7b97\u771f\u9633\u6027\u548c\u5047\u9633\u6027\uff0c\u7136\u540e\u7528\u5176\u8ba1\u7b97\u603b\u4f53\u7cbe\u786e\u7387\u3002\u7136\u540e\u4ee5\u6b64\u8ba1\u7b97\u603b\u4f53\u7cbe\u786e\u7387 \u52a0\u6743\u7cbe\u786e\u7387 \uff08Weighted precision\uff09\uff1a\u4e0e\u5b8f\u89c2\u7cbe\u786e\u7387\u76f8\u540c\uff0c\u4f46\u8fd9\u91cc\u662f\u52a0\u6743\u5e73\u5747\u7cbe\u786e\u7387 \u53d6\u51b3\u4e8e\u6bcf\u4e2a\u7c7b\u522b\u4e2d\u7684\u9879\u76ee\u6570 \u8fd9\u770b\u4f3c\u590d\u6742\uff0c\u4f46\u5728 python \u5b9e\u73b0\u4e2d\u5f88\u5bb9\u6613\u7406\u89e3\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5b8f\u89c2\u5e73\u5747\u7cbe\u786e\u7387\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 import numpy as np def macro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u5982\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u8ba1\u7b97\u7cbe\u786e\u5ea6 temp_precision = tp / ( tp + fp ) # \u5404\u7c7b\u7cbe\u786e\u7387\u76f8\u52a0 precision += temp_precision # \u8ba1\u7b97\u5e73\u5747\u503c precision /= num_classes return precision \u4f60\u4f1a\u53d1\u73b0\u8fd9\u5e76\u4e0d\u96be\u3002\u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5fae\u5e73\u5747\u7cbe\u786e\u7387\u5206\u6570\u3002 import numpy as np def micro_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u521d\u59cb\u5316\u771f\u9633\u6027\u6837\u672c\u6570 tp = 0 # \u521d\u59cb\u5316\u5047\u9633\u6027\u6837\u672c\u6570 fp = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 tp += true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570\u76f8\u52a0 fp += false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 precision = tp / ( tp + fp ) return precision \u8fd9\u4e5f\u4e0d\u96be\u3002\u90a3\u4ec0\u4e48\u96be\uff1f\u4ec0\u4e48\u90fd\u4e0d\u96be\u3002\u673a\u5668\u5b66\u4e60\u5f88\u7b80\u5355\u3002\u73b0\u5728\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u52a0\u6743\u7cbe\u786e\u7387\u7684\u5b9e\u73b0\u3002 from collections import Counter import numpy as np def weighted_precision ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316\u7cbe\u786e\u7387 precision = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( temp_true , temp_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( temp_true , temp_pred ) # \u7cbe\u786e\u7387 temp_precision = tp / ( tp + fp ) # \u6839\u636e\u8be5\u79cd\u7c7b\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_precision = class_counts [ class_ ] * temp_precision # \u52a0\u6743\u7cbe\u786e\u7387\u6c42\u548c precision += weighted_precision # \u8ba1\u7b97\u5e73\u5747\u7cbe\u786e\u7387 overall_precision = precision / len ( y_true ) return overall_precision \u5c06\u6211\u4eec\u7684\u5b9e\u73b0\u4e0e scikit-learn \u8fdb\u884c\u6bd4\u8f83\uff0c\u4ee5\u4e86\u89e3\u5b9e\u73b0\u662f\u5426\u6b63\u786e\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: macro_precision ( y_true , y_pred ) Out [ X ]: 0.3611111111111111 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"macro\" ) Out [ X ]: 0.3611111111111111 In [ X ]: micro_precision ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 In [ X ]: metrics . precision_score ( y_true , y_pred , average = \"micro\" ) Out [ X ]: 0.4444444444444444 In [ X ]: weighted_precision ( y_true , y_pred ) Out [ X ]: 0.39814814814814814 In [ X ]: metrics . 
precision_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.39814814814814814 \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6b63\u786e\u5730\u5b9e\u73b0\u4e86\u4e00\u5207\u3002 \u8bf7\u6ce8\u610f\uff0c\u8fd9\u91cc\u5c55\u793a\u7684\u5b9e\u73b0\u53ef\u80fd\u4e0d\u662f\u6700\u6709\u6548\u7684\uff0c\u4f46\u5374\u662f\u6700\u5bb9\u6613\u7406\u89e3\u7684\u3002 \u540c\u6837\uff0c\u6211\u4eec\u4e5f\u53ef\u4ee5\u5b9e\u73b0 \u591a\u7c7b\u522b\u7684\u53ec\u56de\u7387\u6307\u6807 \u3002\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u53d6\u51b3\u4e8e\u771f\u9633\u6027\u3001\u5047\u9633\u6027\u548c\u5047\u9634\u6027\uff0c\u800c F1 \u5219\u53d6\u51b3\u4e8e\u7cbe\u786e\u7387\u548c\u53ec\u56de\u7387\u3002 \u53ec\u56de\u7387\u7684\u5b9e\u73b0\u65b9\u6cd5\u7559\u5f85\u8bfb\u8005\u7ec3\u4e60\uff0c\u8fd9\u91cc\u5b9e\u73b0\u7684\u662f\u591a\u7c7b F1 \u7684\u4e00\u4e2a\u7248\u672c\uff0c\u5373\u52a0\u6743\u5e73\u5747\u503c\u3002 from collections import Counter import numpy as np def weighted_f1 ( y_true , y_pred ): # \u79cd\u7c7b\u6570 num_classes = len ( np . 
unique ( y_true )) # \u7edf\u8ba1\u5404\u79cd\u7c7b\u6837\u672c\u6570 class_counts = Counter ( y_true ) # \u521d\u59cb\u5316F1\u503c f1 = 0 # \u904d\u53860~\uff08\u79cd\u7c7b\u6570-1\uff09 for class_ in range ( num_classes ): # \u82e5\u771f\u5b9e\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_true = [ 1 if p == class_ else 0 for p in y_true ] # \u82e5\u9884\u6d4b\u6807\u7b7e\u4e3aclass_\u4e3a1\uff0c\u5426\u5219\u4e3a0 temp_pred = [ 1 if p == class_ else 0 for p in y_pred ] # \u8ba1\u7b97\u7cbe\u786e\u7387 p = precision ( temp_true , temp_pred ) # \u8ba1\u7b97\u53ec\u56de\u7387 r = recall ( temp_true , temp_pred ) # \u82e5\u7cbe\u786e\u7387+\u53ec\u56de\u7387\u4e0d\u4e3a0\uff0c\u5219\u4f7f\u7528\u516c\u5f0f\u8ba1\u7b97F1\u503c if p + r != 0 : temp_f1 = 2 * p * r / ( p + r ) # \u5426\u5219\u76f4\u63a5\u4e3a0 else : temp_f1 = 0 # \u6839\u636e\u6837\u672c\u6570\u5206\u914d\u6743\u91cd weighted_f1 = class_counts [ class_ ] * temp_f1 # \u52a0\u6743F1\u503c\u76f8\u52a0 f1 += weighted_f1 # \u8ba1\u7b97\u52a0\u6743\u5e73\u5747F1\u503c overall_f1 = f1 / len ( y_true ) return overall_f1 \u8bf7\u6ce8\u610f\uff0c\u4e0a\u9762\u6709\u51e0\u884c\u4ee3\u7801\u662f\u65b0\u5199\u7684\u3002\u56e0\u6b64\uff0c\u4f60\u5e94\u8be5\u4ed4\u7ec6\u9605\u8bfb\u8fd9\u4e9b\u4ee3\u7801\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] In [ X ]: y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] In [ X ]: weighted_f1 ( y_true , y_pred ) Out [ X ]: 0.41269841269841273 In [ X ]: metrics . 
f1_score ( y_true , y_pred , average = \"weighted\" ) Out [ X ]: 0.41269841269841273 \u56e0\u6b64\uff0c\u6211\u4eec\u5df2\u7ecf\u4e3a\u591a\u7c7b\u95ee\u9898\u5b9e\u73b0\u4e86\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c F1\u3002\u540c\u6837\uff0c\u60a8\u4e5f\u53ef\u4ee5\u5c06 AUC \u548c\u5bf9\u6570\u635f\u5931\u8f6c\u6362\u4e3a\u591a\u7c7b\u683c\u5f0f\u3002\u8fd9\u79cd\u8f6c\u6362\u683c\u5f0f\u88ab\u79f0\u4e3a one-vs-all \u3002\u8fd9\u91cc\u6211\u4e0d\u6253\u7b97\u5b9e\u73b0\u5b83\u4eec\uff0c\u56e0\u4e3a\u5b9e\u73b0\u65b9\u6cd5\u4e0e\u6211\u4eec\u5df2\u7ecf\u8ba8\u8bba\u8fc7\u7684\u5f88\u76f8\u4f3c\u3002 \u5728\u4e8c\u5143\u6216\u591a\u7c7b\u5206\u7c7b\u4e2d\uff0c\u770b\u4e00\u4e0b \u6df7\u6dc6\u77e9\u9635 \u4e5f\u5f88\u6d41\u884c\u3002\u4e0d\u8981\u56f0\u60d1\uff0c\u8fd9\u5f88\u7b80\u5355\u3002\u6df7\u6dc6\u77e9\u9635\u53ea\u4e0d\u8fc7\u662f\u4e00\u4e2a\u5305\u542b TP\u3001FP\u3001TN \u548c FN \u7684\u8868\u683c\u3002\u4f7f\u7528\u6df7\u6dc6\u77e9\u9635\uff0c\u60a8\u53ef\u4ee5\u5feb\u901f\u67e5\u770b\u6709\u591a\u5c11\u6837\u672c\u88ab\u9519\u8bef\u5206\u7c7b\uff0c\u6709\u591a\u5c11\u6837\u672c\u88ab\u6b63\u786e\u5206\u7c7b\u3002\u4e5f\u8bb8\u6709\u4eba\u4f1a\u8bf4\uff0c\u6df7\u6dc6\u77e9\u9635\u5e94\u8be5\u5728\u672c\u7ae0\u5f88\u65e9\u5c31\u8bb2\u5230\uff0c\u4f46\u6211\u6ca1\u6709\u8fd9\u4e48\u505a\u3002\u5982\u679c\u4e86\u89e3\u4e86 TP\u3001FP\u3001TN\u3001FN\u3001\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u548c AUC\uff0c\u5c31\u5f88\u5bb9\u6613\u7406\u89e3\u548c\u89e3\u91ca\u6df7\u6dc6\u77e9\u9635\u4e86\u3002\u8ba9\u6211\u4eec\u770b\u770b\u56fe 7 \u4e2d\u4e8c\u5143\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635\u3002 \u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c\u6df7\u6dc6\u77e9\u9635\u7531 TP\u3001FP\u3001FN \u548c TN \u7ec4\u6210\u3002\u6211\u4eec\u53ea\u9700\u8981\u8fd9\u4e9b\u503c\u6765\u8ba1\u7b97\u7cbe\u786e\u7387\u3001\u53ec\u56de\u7387\u3001F1 \u5206\u6570\u548c AUC\u3002\u6709\u65f6\uff0c\u4eba\u4eec\u4e5f\u559c\u6b22\u628a FP \u79f0\u4e3a 
\u7b2c\u4e00\u7c7b\u9519\u8bef \uff0c\u628a FN \u79f0\u4e3a \u7b2c\u4e8c\u7c7b\u9519\u8bef \u3002 \u56fe 7\uff1a\u4e8c\u5143\u5206\u7c7b\u4efb\u52a1\u7684\u6df7\u6dc6\u77e9\u9635 \u6211\u4eec\u8fd8\u53ef\u4ee5\u5c06\u4e8c\u5143\u6df7\u6dc6\u77e9\u9635\u6269\u5c55\u4e3a\u591a\u7c7b\u6df7\u6dc6\u77e9\u9635\u3002\u5b83\u4f1a\u662f\u4ec0\u4e48\u6837\u5b50\u5462\uff1f\u5982\u679c\u6211\u4eec\u6709 N \u4e2a\u7c7b\u522b\uff0c\u5b83\u5c06\u662f\u4e00\u4e2a\u5927\u5c0f\u4e3a NxN \u7684\u77e9\u9635\u3002\u5bf9\u4e8e\u6bcf\u4e2a\u7c7b\u522b\uff0c\u6211\u4eec\u90fd\u8981\u8ba1\u7b97\u76f8\u5173\u7c7b\u522b\u548c\u5176\u4ed6\u7c7b\u522b\u7684\u6837\u672c\u603b\u6570\u3002\u4e3e\u4e2a\u4f8b\u5b50\u53ef\u4ee5\u8ba9\u6211\u4eec\u66f4\u597d\u5730\u7406\u89e3\u8fd9\u4e00\u70b9\u3002 \u5047\u8bbe\u6211\u4eec\u6709\u4ee5\u4e0b\u771f\u5b9e\u6807\u7b7e\uff1a \\[ [0, 1, 2, 0, 1, 2, 0, 2, 2] \\] \u6211\u4eec\u7684\u9884\u6d4b\u6807\u7b7e\u662f\uff1a \\[ [0, 2, 1, 0, 2, 1, 0, 0, 2] \\] \u90a3\u4e48\uff0c\u6211\u4eec\u7684\u6df7\u6dc6\u77e9\u9635\u5c06\u5982\u56fe 8 \u6240\u793a\u3002 \u56fe 8\uff1a\u591a\u5206\u7c7b\u95ee\u9898\u7684\u6df7\u6dc6\u77e9\u9635 \u56fe 8 \u8bf4\u660e\u4e86\u4ec0\u4e48\uff1f \u8ba9\u6211\u4eec\u6765\u770b\u770b 0 \u7c7b\u3002\u6211\u4eec\u770b\u5230\uff0c\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e 0 \u7c7b\u3002\u7136\u800c\uff0c\u5728\u9884\u6d4b\u4e2d\uff0c\u6211\u4eec\u6709 3 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 0 \u7c7b\uff0c1 \u4e2a\u6837\u672c\u5c5e\u4e8e\u7b2c 1 \u7c7b\u3002\u7406\u60f3\u60c5\u51b5\u4e0b\uff0c\u5bf9\u4e8e\u771f\u5b9e\u6807\u7b7e\u4e2d\u7684\u7c7b\u522b 0\uff0c\u9884\u6d4b\u6807\u7b7e 1 \u548c 2 \u5e94\u8be5\u6ca1\u6709\u4efb\u4f55\u6837\u672c\u3002\u8ba9\u6211\u4eec\u770b\u770b\u7c7b\u522b 2\u3002\u5728\u771f\u5b9e\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 4\uff0c\u800c\u5728\u9884\u6d4b\u6807\u7b7e\u4e2d\uff0c\u8fd9\u4e2a\u6570\u5b57\u52a0\u8d77\u6765\u662f 3\u3002 
\u4e00\u4e2a\u5b8c\u7f8e\u7684\u6df7\u6dc6\u77e9\u9635\u53ea\u80fd\u4ece\u5de6\u5230\u53f3\u659c\u5411\u586b\u5145\u3002 \u6df7\u6dc6\u77e9\u9635 \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u7684\u65b9\u6cd5\u6765\u8ba1\u7b97\u6211\u4eec\u4e4b\u524d\u8ba8\u8bba\u8fc7\u7684\u4e0d\u540c\u6307\u6807\u3002Scikit-learn \u63d0\u4f9b\u4e86\u4e00\u79cd\u7b80\u5355\u76f4\u63a5\u7684\u65b9\u6cd5\u6765\u751f\u6210\u6df7\u6dc6\u77e9\u9635\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5728\u56fe 8 \u4e2d\u663e\u793a\u7684\u6df7\u6dc6\u77e9\u9635\u662f scikit-learn \u6df7\u6dc6\u77e9\u9635\u7684\u8f6c\u7f6e\uff0c\u539f\u59cb\u7248\u672c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u7ed8\u5236\u3002 import matplotlib.pyplot as plt import seaborn as sns from sklearn import metrics # \u771f\u5b9e\u6837\u672c\u6807\u7b7e y_true = [ 0 , 1 , 2 , 0 , 1 , 2 , 0 , 2 , 2 ] # \u9884\u6d4b\u6837\u672c\u6807\u7b7e y_pred = [ 0 , 2 , 1 , 0 , 2 , 1 , 0 , 0 , 2 ] # \u8ba1\u7b97\u6df7\u6dc6\u77e9\u9635 cm = metrics . confusion_matrix ( y_true , y_pred ) # \u521b\u5efa\u753b\u5e03 plt . figure ( figsize = ( 10 , 10 )) # \u521b\u5efa\u65b9\u683c cmap = sns . cubehelix_palette ( 50 , hue = 0.05 , rot = 0 , light = 0.9 , dark = 0 , as_cmap = True ) # \u89c4\u5b9a\u5b57\u4f53\u5927\u5c0f sns . set ( font_scale = 2.5 ) # \u7ed8\u5236\u70ed\u56fe sns . heatmap ( cm , annot = True , cmap = cmap , cbar = False ) # y\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . ylabel ( 'Actual Labels' , fontsize = 20 ) # x\u8f74\u6807\u7b7e\uff0c\u5b57\u4f53\u5927\u5c0f\u4e3a20 plt . 
xlabel ( 'Predicted Labels' , fontsize = 20 ) \u56e0\u6b64\uff0c\u5230\u76ee\u524d\u4e3a\u6b62\uff0c\u6211\u4eec\u5df2\u7ecf\u89e3\u51b3\u4e86\u4e8c\u5143\u5206\u7c7b\u548c\u591a\u7c7b\u5206\u7c7b\u7684\u5ea6\u91cf\u95ee\u9898\u3002\u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u5c06\u8ba8\u8bba\u53e6\u4e00\u79cd\u7c7b\u578b\u7684\u5206\u7c7b\u95ee\u9898\uff0c\u5373\u591a\u6807\u7b7e\u5206\u7c7b\u3002\u5728\u591a\u6807\u7b7e\u5206\u7c7b\u4e2d\uff0c\u6bcf\u4e2a\u6837\u672c\u90fd\u53ef\u80fd\u4e0e\u4e00\u4e2a\u6216\u591a\u4e2a\u7c7b\u522b\u76f8\u5173\u8054\u3002\u8fd9\u7c7b\u95ee\u9898\u7684\u4e00\u4e2a\u7b80\u5355\u4f8b\u5b50\u5c31\u662f\u8981\u6c42\u4f60\u9884\u6d4b\u7ed9\u5b9a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53\u3002 \u56fe 9 \u663e\u793a\u4e86\u4e00\u4e2a\u8457\u540d\u6570\u636e\u96c6\u7684\u56fe\u50cf\u793a\u4f8b\u3002\u8bf7\u6ce8\u610f\uff0c\u8be5\u6570\u636e\u96c6\u7684\u76ee\u6807\u6709\u6240\u4e0d\u540c\uff0c\u4f46\u6211\u4eec\u6682\u4e14\u4e0d\u53bb\u8ba8\u8bba\u5b83\u3002\u6211\u4eec\u5047\u8bbe\u5176\u76ee\u7684\u53ea\u662f\u9884\u6d4b\u56fe\u50cf\u4e2d\u662f\u5426\u5b58\u5728\u67d0\u4e2a\u7269\u4f53\u3002\u5728\u56fe 9 \u4e2d\uff0c\u6211\u4eec\u6709\u6905\u5b50\u3001\u82b1\u76c6\u3001\u7a97\u6237\uff0c\u4f46\u6ca1\u6709\u5176\u4ed6\u7269\u4f53\uff0c\u5982\u7535\u8111\u3001\u5e8a\u3001\u7535\u89c6\u7b49\u3002\u56e0\u6b64\uff0c\u4e00\u5e45\u56fe\u50cf\u53ef\u80fd\u6709\u591a\u4e2a\u76f8\u5173\u76ee\u6807\u3002\u8fd9\u7c7b\u95ee\u9898\u5c31\u662f\u591a\u6807\u7b7e\u5206\u7c7b\u95ee\u9898\u3002 \u56fe 9\uff1a\u56fe\u50cf\u4e2d\u7684\u4e0d\u540c\u7269\u4f53 \u8fd9\u7c7b\u5206\u7c7b\u95ee\u9898\u7684\u8861\u91cf\u6807\u51c6\u6709\u4e9b\u4e0d\u540c\u3002\u4e00\u4e9b\u5408\u9002\u7684 \u6700\u5e38\u89c1\u7684\u6307\u6807\u6709\uff1a k \u7cbe\u786e\u7387\uff08P@k\uff09 k \u5e73\u5747\u7cbe\u786e\u7387\uff08AP@k\uff09 k \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387\uff08MAP@k\uff09 \u5bf9\u6570\u635f\u5931\uff08Log loss\uff09 \u8ba9\u6211\u4eec\u4ece k 
\u7cbe\u786e\u7387\u6216\u8005 P@k \u6211\u4eec\u4e0d\u80fd\u5c06\u8fd9\u4e00\u7cbe\u786e\u7387\u4e0e\u524d\u9762\u8ba8\u8bba\u7684\u7cbe\u786e\u7387\u6df7\u6dc6\u3002\u5982\u679c\u60a8\u6709\u4e00\u4e2a\u7ed9\u5b9a\u6837\u672c\u7684\u539f\u59cb\u7c7b\u522b\u5217\u8868\u548c\u540c\u4e00\u4e2a\u6837\u672c\u7684\u9884\u6d4b\u7c7b\u522b\u5217\u8868\uff0c\u90a3\u4e48\u7cbe\u786e\u7387\u7684\u5b9a\u4e49\u5c31\u662f\u9884\u6d4b\u5217\u8868\u4e2d\u4ec5\u8003\u8651\u524d k \u4e2a\u9884\u6d4b\u7ed3\u679c\u7684\u547d\u4e2d\u6570\u9664\u4ee5 k\u3002 \u5982\u679c\u60a8\u5bf9\u6b64\u611f\u5230\u56f0\u60d1\uff0c\u4f7f\u7528 python \u4ee3\u7801\u540e\u5c31\u4f1a\u660e\u767d\u3002 def pk ( y_true , y_pred , k ): # \u5982\u679ck\u4e3a0 if k == 0 : # \u8fd4\u56de0 return 0 # \u53d6\u9884\u6d4b\u6807\u7b7e\u524dk\u4e2a y_pred = y_pred [: k ] # \u5c06\u9884\u6d4b\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 pred_set = set ( y_pred ) # \u5c06\u771f\u5b9e\u6807\u7b7e\u8f6c\u6362\u4e3a\u96c6\u5408 true_set = set ( y_true ) # \u9884\u6d4b\u6807\u7b7e\u96c6\u5408\u4e0e\u771f\u5b9e\u6807\u7b7e\u96c6\u5408\u4ea4\u96c6 common_values = pred_set . intersection ( true_set ) # \u8ba1\u7b97\u7cbe\u786e\u7387 return len ( common_values ) / len ( y_pred [: k ]) \u6709\u4e86\u4ee3\u7801\uff0c\u4e00\u5207\u90fd\u53d8\u5f97\u66f4\u5bb9\u6613\u7406\u89e3\u4e86\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6709\u4e86 k \u5e73\u5747\u7cbe\u786e\u7387\u6216 AP@k \u3002AP@k \u662f\u901a\u8fc7 P@k \u8ba1\u7b97\u5f97\u51fa\u7684\u3002\u4f8b\u5982\uff0c\u5982\u679c\u8981\u8ba1\u7b97 AP@3\uff0c\u6211\u4eec\u8981\u5148\u8ba1\u7b97 P@1\u3001P@2 \u548c P@3\uff0c\u7136\u540e\u5c06\u603b\u548c\u9664\u4ee5 3\u3002 \u8ba9\u6211\u4eec\u6765\u770b\u770b\u5b83\u7684\u5b9e\u73b0\u3002 def apk ( y_true , y_pred , k ): # \u521d\u59cb\u5316P@k\u5217\u8868 pk_values = [] # \u904d\u53861~k for i in range ( 1 , k + 1 ): # \u5c06P@k\u52a0\u5165\u5217\u8868 pk_values . 
append ( pk ( y_true , y_pred , i )) # \u82e5\u957f\u5ea6\u4e3a0 if len ( pk_values ) == 0 : # \u8fd4\u56de0 return 0 # \u5426\u5219\u8ba1\u7b97AP@K return sum ( pk_values ) / len ( pk_values ) \u8fd9\u4e24\u4e2a\u51fd\u6570\u53ef\u4ee5\u7528\u6765\u8ba1\u7b97\u4e24\u4e2a\u7ed9\u5b9a\u5217\u8868\u7684 k \u5e73\u5747\u7cbe\u786e\u7387 (AP@k)\uff1b\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u8ba1\u7b97\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: for i in range ( len ( y_true )): ... : for j in range ( 1 , 4 ): ... : print ( ... : f \"\"\" ...: y_true= { y_true [ i ] } , ...: y_pred= { y_pred [ i ] } , ...: AP@ { j } = { apk ( y_true [ i ], y_pred [ i ], k = j ) } ...: \"\"\" ... : ) ... : y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 1 = 0.0 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 2 = 0.25 y_true = [ 1 , 2 , 3 ], y_pred = [ 0 , 1 , 2 ], AP @ 3 = 0.38888888888888884 \u8bf7\u6ce8\u610f\uff0c\u6211\u7701\u7565\u4e86\u8f93\u51fa\u7ed3\u679c\u4e2d\u7684\u8bb8\u591a\u6570\u503c\uff0c\u4f46\u4f60\u4f1a\u660e\u767d\u5176\u4e2d\u7684\u610f\u601d\u3002\u8fd9\u5c31\u662f\u6211\u4eec\u5982\u4f55\u8ba1\u7b97 AP@k \u7684\u65b9\u6cd5\uff0c\u5373\u6bcf\u4e2a\u6837\u672c\u7684 AP@k\u3002\u5728\u673a\u5668\u5b66\u4e60\u4e2d\uff0c\u6211\u4eec\u5bf9\u6240\u6709\u6837\u672c\u90fd\u611f\u5174\u8da3\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u6211\u4eec\u6709 \u5747\u503c\u5e73\u5747\u7cbe\u786e\u7387 k \u6216 MAP@k \u3002MAP@k \u53ea\u662f AP@k \u7684\u5e73\u5747\u503c\uff0c\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b python \u4ee3\u7801\u8f7b\u677e\u8ba1\u7b97\u3002 def mapk ( y_true , y_pred , k ): # \u521d\u59cb\u5316AP@k\u5217\u8868 apk_values = [] # \u904d\u53860~\uff08\u771f\u5b9e\u6807\u7b7e\u6570-1\uff09 for i in range ( len ( y_true )): 
# \u5c06AP@K\u52a0\u5165\u5217\u8868 apk_values . append ( apk ( y_true [ i ], y_pred [ i ], k = k ) ) # \u8ba1\u7b97\u5e73\u5747AP@k return sum ( apk_values ) / len ( apk_values ) \u73b0\u5728\uff0c\u6211\u4eec\u53ef\u4ee5\u9488\u5bf9\u76f8\u540c\u7684\u5217\u8868\u8ba1\u7b97 k=1\u30012\u30013 \u548c 4 \u65f6\u7684 MAP@k\u3002 In [ X ]: y_true = [ ... : [ 1 , 2 , 3 ], ... : [ 0 , 2 ], ... : [ 1 ], ... : [ 2 , 3 ], ... : [ 1 , 0 ], ... : [] ... : ] In [ X ]: y_pred = [ ... : [ 0 , 1 , 2 ], ... : [ 1 ], ... : [ 0 , 2 , 3 ], ... : [ 2 , 3 , 4 , 0 ], ... : [ 0 , 1 , 2 ], ... : [ 0 ] ... : ] In [ X ]: mapk ( y_true , y_pred , k = 1 ) Out [ X ]: 0.3333333333333333 In [ X ]: mapk ( y_true , y_pred , k = 2 ) Out [ X ]: 0.375 In [ X ]: mapk ( y_true , y_pred , k = 3 ) Out [ X ]: 0.3611111111111111 In [ X ]: mapk ( y_true , y_pred , k = 4 ) Out [ X ]: 0.34722222222222215 P@k\u3001AP@k \u548c MAP@k \u7684\u8303\u56f4\u90fd\u662f\u4ece 0 \u5230 1\uff0c\u5176\u4e2d 1 \u4e3a\u6700\u4f73\u3002 \u8bf7\u6ce8\u610f\uff0c\u6709\u65f6\u60a8\u53ef\u80fd\u4f1a\u5728\u4e92\u8054\u7f51\u4e0a\u770b\u5230 P@k \u548c AP@k \u7684\u4e0d\u540c\u5b9e\u73b0\u65b9\u5f0f\u3002 \u4f8b\u5982\uff0c\u8ba9\u6211\u4eec\u6765\u770b\u770b\u5176\u4e2d\u4e00\u79cd\u5b9e\u73b0\u65b9\u5f0f\u3002 import numpy as np def apk ( actual , predicted , k = 10 ): # \u82e5\u9884\u6d4b\u6807\u7b7e\u957f\u5ea6\u5927\u4e8ek if len ( predicted ) > k : # \u53d6\u524dk\u4e2a\u6807\u7b7e predicted = predicted [: k ] score = 0.0 num_hits = 0.0 for i , p in enumerate ( predicted ): if p in actual and p not in predicted [: i ]: num_hits += 1.0 score += num_hits / ( i + 1.0 ) if not actual : return 0.0 return score / min ( len ( actual ), k ) \u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u662f AP@k 
\u7684\u53e6\u4e00\u4e2a\u7248\u672c\uff0c\u5176\u4e2d\u987a\u5e8f\u5f88\u91cd\u8981\uff0c\u6211\u4eec\u8981\u6743\u8861\u9884\u6d4b\u7ed3\u679c\u3002\u8fd9\u79cd\u5b9e\u73b0\u65b9\u5f0f\u7684\u7ed3\u679c\u4e0e\u6211\u7684\u4ecb\u7ecd\u7565\u6709\u4e0d\u540c\u3002 \u73b0\u5728\uff0c\u6211\u4eec\u6765\u770b\u770b \u591a\u6807\u7b7e\u5206\u7c7b\u7684\u5bf9\u6570\u635f\u5931 \u3002\u8fd9\u5f88\u5bb9\u6613\u3002\u60a8\u53ef\u4ee5\u5c06\u76ee\u6807\u8f6c\u6362\u4e3a\u4e8c\u5143\u5206\u7c7b\uff0c\u7136\u540e\u5bf9\u6bcf\u4e00\u5217\u4f7f\u7528\u5bf9\u6570\u635f\u5931\u3002\u6700\u540e\uff0c\u4f60\u53ef\u4ee5\u6c42\u51fa\u6bcf\u5217\u5bf9\u6570\u635f\u5931\u7684\u5e73\u5747\u503c\u3002\u8fd9\u4e5f\u88ab\u79f0\u4e3a\u5e73\u5747\u5217\u5bf9\u6570\u635f\u5931\u3002\u5f53\u7136\uff0c\u8fd8\u6709\u5176\u4ed6\u65b9\u6cd5\u53ef\u4ee5\u5b9e\u73b0\u8fd9\u4e00\u70b9\uff0c\u4f60\u5e94\u8be5\u5728\u9047\u5230\u65f6\u52a0\u4ee5\u63a2\u7d22\u3002 \u6211\u4eec\u73b0\u5728\u53ef\u4ee5\u8bf4\u5df2\u7ecf\u638c\u63e1\u4e86\u6240\u6709\u4e8c\u5143\u5206\u7c7b\u3001\u591a\u7c7b\u5206\u7c7b\u548c\u591a\u6807\u7b7e\u5206\u7c7b\u6307\u6807\uff0c\u73b0\u5728\u6211\u4eec\u53ef\u4ee5\u8f6c\u5411\u56de\u5f52\u6307\u6807\u3002 \u56de\u5f52\u4e2d\u6700\u5e38\u89c1\u7684\u6307\u6807\u662f \u8bef\u5dee\uff08Error\uff09 \u3002\u8bef\u5dee\u5f88\u7b80\u5355\uff0c\u4e5f\u5f88\u5bb9\u6613\u7406\u89e3\u3002 \\[ Error = True\\ Value - Predicted\\ Value \\] \u7edd\u5bf9\u8bef\u5dee\uff08Absolute error\uff09 \u53ea\u662f\u4e0a\u8ff0\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\u3002 \\[ Absolute\\ Error = Abs(True\\ Value - Predicted\\ Value) \\] \u63a5\u4e0b\u6765\u6211\u4eec\u8ba8\u8bba \u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\uff08MAE\uff09 \u3002\u5b83\u53ea\u662f\u6240\u6709\u7edd\u5bf9\u8bef\u5dee\u7684\u5e73\u5747\u503c\u3002 import numpy as np def mean_absolute_error ( y_true , y_pred ): #\u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # 
\u7d2f\u52a0\u7edd\u5bf9\u8bef\u5dee error += np . abs ( yt - yp ) # \u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee return error / len ( y_true ) \u540c\u6837\uff0c\u6211\u4eec\u8fd8\u6709\u5e73\u65b9\u8bef\u5dee\u548c \u5747\u65b9\u8bef\u5dee \uff08MSE\uff09 \u3002 \\[ Squared\\ Error = (True Value - Predicted\\ Value)^2 \\] \u5747\u65b9\u8bef\u5dee\uff08MSE\uff09\u7684\u8ba1\u7b97\u65b9\u5f0f\u5982\u4e0b def mean_squared_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u7d2f\u52a0\u8bef\u5dee\u5e73\u65b9\u548c error += ( yt - yp ) ** 2 # \u8ba1\u7b97\u5747\u65b9\u8bef\u5dee return error / len ( y_true ) MSE \u548c RMSE\uff08\u5747\u65b9\u6839\u8bef\u5dee\uff09 \u662f\u8bc4\u4f30\u56de\u5f52\u6a21\u578b\u6700\u5e38\u7528\u7684\u6307\u6807\u3002 \\[ RMSE = SQRT(MSE) \\] \u540c\u4e00\u7c7b\u8bef\u5dee\u7684\u53e6\u4e00\u79cd\u7c7b\u578b\u662f \u5e73\u65b9\u5bf9\u6570\u8bef\u5dee \u3002\u6709\u4eba\u79f0\u5176\u4e3a SLE \uff0c\u5f53\u6211\u4eec\u53d6\u6240\u6709\u6837\u672c\u4e2d\u8fd9\u4e00\u8bef\u5dee\u7684\u5e73\u5747\u503c\u65f6\uff0c\u5b83\u88ab\u79f0\u4e3a MSLE\uff08\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee\uff09\uff0c\u5b9e\u73b0\u65b9\u6cd5\u5982\u4e0b\u3002 import numpy as np def mean_squared_log_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee error += ( np . log ( 1 + yt ) - np . 
log ( 1 + yp )) ** 2 # \u8ba1\u7b97\u5e73\u5747\u5e73\u65b9\u5bf9\u6570\u8bef\u5dee return error / len ( y_true ) \u5747\u65b9\u6839\u5bf9\u6570\u8bef\u5dee \u53ea\u662f\u5176\u5e73\u65b9\u6839\u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a RMSLE \u3002 \u7136\u540e\u662f\u767e\u5206\u6bd4\u8bef\u5dee\uff1a \\[ Percentage\\ Error = (( True\\ Value \u2013 Predicted\\ Value ) / True\\ Value ) \\times 100 \\] \u540c\u6837\u53ef\u4ee5\u8f6c\u6362\u4e3a\u6240\u6709\u6837\u672c\u7684\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee\u3002 def mean_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u767e\u5206\u6bd4\u8bef\u5dee error += ( yt - yp ) / yt # \u8fd4\u56de\u5e73\u5747\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u7edd\u5bf9\u8bef\u5dee\u7684\u7edd\u5bf9\u503c\uff08\u4e5f\u662f\u66f4\u5e38\u89c1\u7684\u7248\u672c\uff09\u88ab\u79f0\u4e3a \u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee\u6216 MAPE \u3002 import numpy as np def mean_abs_percentage_error ( y_true , y_pred ): # \u521d\u59cb\u5316\u8bef\u5dee error = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): # \u8ba1\u7b97\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee error += np . 
abs ( yt - yp ) / yt #\u8fd4\u56de\u5e73\u5747\u7edd\u5bf9\u767e\u5206\u6bd4\u8bef\u5dee return error / len ( y_true ) \u56de\u5f52\u7684\u6700\u5927\u4f18\u70b9\u662f\uff0c\u53ea\u6709\u51e0\u4e2a\u6700\u5e38\u7528\u7684\u6307\u6807\uff0c\u51e0\u4e4e\u53ef\u4ee5\u5e94\u7528\u4e8e\u6240\u6709\u56de\u5f52\u95ee\u9898\u3002\u4e0e\u5206\u7c7b\u6307\u6807\u76f8\u6bd4\uff0c\u56de\u5f52\u6307\u6807\u66f4\u5bb9\u6613\u7406\u89e3\u3002 \u8ba9\u6211\u4eec\u6765\u8c08\u8c08\u53e6\u4e00\u4e2a\u56de\u5f52\u6307\u6807 \\(R^2\\) \uff08R \u65b9\uff09\uff0c\u4e5f\u79f0\u4e3a \u5224\u5b9a\u7cfb\u6570 \u3002 \u7b80\u5355\u5730\u8bf4\uff0cR \u65b9\u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u3002R \u65b9\u63a5\u8fd1 1.0 \u8868\u793a\u6a21\u578b\u4e0e\u6570\u636e\u7684\u62df\u5408\u7a0b\u5ea6\u76f8\u5f53\u597d\uff0c\u800c\u63a5\u8fd1 0 \u5219\u8868\u793a\u6a21\u578b\u4e0d\u662f\u90a3\u4e48\u597d\u3002\u5f53\u6a21\u578b\u53ea\u662f\u505a\u51fa\u8352\u8c2c\u7684\u9884\u6d4b\u65f6\uff0cR \u65b9\u4e5f\u53ef\u80fd\u662f\u8d1f\u503c\u3002 R \u65b9\u7684\u8ba1\u7b97\u516c\u5f0f\u5982\u4e0b\u6240\u793a\uff0c\u4f46 Python \u7684\u5b9e\u73b0\u603b\u662f\u80fd\u8ba9\u4e00\u5207\u66f4\u52a0\u6e05\u6670\u3002 \\[ R^2 = \\frac{\\sum^{N}_{i=1}(y_{t_i}-y_{p_i})^2}{\\sum^{N}_{i=1}(y_{t_i} - y_{t_{mean}})} \\] import numpy as np def r2 ( y_true , y_pred ): # \u8ba1\u7b97\u5e73\u5747\u771f\u5b9e\u503c mean_true_value = np . 
mean ( y_true ) # \u521d\u59cb\u5316\u5e73\u65b9\u8bef\u5dee numerator = 0 denominator = 0 # \u904d\u5386y_true, y_pred for yt , yp in zip ( y_true , y_pred ): numerator += ( yt - yp ) ** 2 denominator += ( yt - mean_true_value ) ** 2 ratio = numerator / denominator # \u8ba1\u7b97R\u65b9 return 1 \u2013 ratio \u8fd8\u6709\u66f4\u591a\u7684\u8bc4\u4ef7\u6307\u6807\uff0c\u8fd9\u4e2a\u6e05\u5355\u6c38\u8fdc\u4e5f\u5217\u4e0d\u5b8c\u3002\u6211\u53ef\u4ee5\u5199\u4e00\u672c\u4e66\uff0c\u53ea\u4ecb\u7ecd\u4e0d\u540c\u7684\u8bc4\u4ef7\u6307\u6807\u3002\u4e5f\u8bb8\u6211\u4f1a\u7684\u3002\u73b0\u5728\uff0c\u8fd9\u4e9b\u8bc4\u4f30\u6307\u6807\u51e0\u4e4e\u53ef\u4ee5\u6ee1\u8db3\u4f60\u60f3\u5c1d\u8bd5\u89e3\u51b3\u7684\u6240\u6709\u95ee\u9898\u3002\u8bf7\u6ce8\u610f\uff0c\u6211\u5df2\u7ecf\u4ee5\u6700\u76f4\u63a5\u7684\u65b9\u5f0f\u5b9e\u73b0\u4e86\u8fd9\u4e9b\u6307\u6807\uff0c\u8fd9\u610f\u5473\u7740\u5b83\u4eec\u4e0d\u591f\u9ad8\u6548\u3002\u4f60\u53ef\u4ee5\u901a\u8fc7\u6b63\u786e\u4f7f\u7528 numpy \u4ee5\u975e\u5e38\u9ad8\u6548\u7684\u65b9\u5f0f\u5b9e\u73b0\u5176\u4e2d\u5927\u90e8\u5206\u6307\u6807\u3002\u4f8b\u5982\uff0c\u770b\u770b\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee\u7684\u5b9e\u73b0\uff0c\u4e0d\u9700\u8981\u4efb\u4f55\u5faa\u73af\u3002 import numpy as np def mae_np ( y_true , y_pred ): return np . mean ( np . 
abs ( y_true - y_pred )) \u6211\u672c\u53ef\u4ee5\u7528\u8fd9\u79cd\u65b9\u6cd5\u5b9e\u73b0\u6240\u6709\u6307\u6807\uff0c\u4f46\u4e3a\u4e86\u5b66\u4e60\uff0c\u6700\u597d\u8fd8\u662f\u770b\u770b\u5e95\u5c42\u5b9e\u73b0\u3002\u4e00\u65e6\u4f60\u5b66\u4f1a\u4e86\u7eaf python \u7684\u5e95\u5c42\u5b9e\u73b0\uff0c\u5e76\u4e14\u4e0d\u4f7f\u7528\u5927\u91cf numpy\uff0c\u4f60\u5c31\u53ef\u4ee5\u5f88\u5bb9\u6613\u5730\u5c06\u5176\u8f6c\u6362\u4e3a numpy\uff0c\u5e76\u4f7f\u5176\u53d8\u5f97\u66f4\u5feb\u3002 \u7136\u540e\u662f\u4e00\u4e9b\u9ad8\u7ea7\u5ea6\u91cf\u3002 \u5176\u4e2d\u4e00\u4e2a\u5e94\u7528\u76f8\u5f53\u5e7f\u6cdb\u7684\u6307\u6807\u662f \u4e8c\u6b21\u52a0\u6743\u5361\u5e15 \uff0c\u4e5f\u79f0\u4e3a QWK \u3002\u5b83\u4e5f\u88ab\u79f0\u4e3a\u79d1\u6069\u5361\u5e15\u3002 QWK \u8861\u91cf\u4e24\u4e2a \"\u8bc4\u5206 \"\u4e4b\u95f4\u7684 \"\u4e00\u81f4\u6027\"\u3002\u8bc4\u5206\u53ef\u4ee5\u662f 0 \u5230 N \u4e4b\u95f4\u7684\u4efb\u4f55\u5b9e\u6570\uff0c\u9884\u6d4b\u4e5f\u5728\u540c\u4e00\u8303\u56f4\u5185\u3002\u4e00\u81f4\u6027\u53ef\u4ee5\u5b9a\u4e49\u4e3a\u8fd9\u4e9b\u8bc4\u7ea7\u4e4b\u95f4\u7684\u63a5\u8fd1\u7a0b\u5ea6\u3002\u56e0\u6b64\uff0c\u5b83\u9002\u7528\u4e8e\u6709 N \u4e2a\u4e0d\u540c\u7c7b\u522b\u7684\u5206\u7c7b\u95ee\u9898\u3002\u5982\u679c\u4e00\u81f4\u5ea6\u9ad8\uff0c\u5206\u6570\u5c31\u66f4\u63a5\u8fd1 1.0\u3002Cohen's kappa \u5728 scikit-learn \u4e2d\u6709\u5f88\u597d\u7684\u5b9e\u73b0\uff0c\u5173\u4e8e\u8be5\u6307\u6807\u7684\u8be6\u7ec6\u8ba8\u8bba\u8d85\u51fa\u4e86\u672c\u4e66\u7684\u8303\u56f4\u3002 In [ X ]: from sklearn import metrics In [ X ]: y_true = [ 1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3 ] In [ X ]: y_pred = [ 2 , 1 , 3 , 1 , 2 , 3 , 3 , 1 , 2 ] In [ X ]: metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) Out [ X ]: 0.33333333333333337 In [ X ]: metrics . 
accuracy_score ( y_true , y_pred ) Out [ X ]: 0.4444444444444444 \u60a8\u53ef\u4ee5\u770b\u5230\uff0c\u5c3d\u7ba1\u51c6\u786e\u5ea6\u5f88\u9ad8\uff0c\u4f46 QWK \u5374\u5f88\u4f4e\u3002QWK \u5927\u4e8e 0.85 \u5373\u4e3a\u975e\u5e38\u597d\uff01 \u4e00\u4e2a\u91cd\u8981\u7684\u6307\u6807\u662f \u9a6c\u4fee\u76f8\u5173\u7cfb\u6570\uff08MCC\uff09 \u30021 \u4ee3\u8868\u5b8c\u7f8e\u9884\u6d4b\uff0c-1 \u4ee3\u8868\u4e0d\u5b8c\u7f8e\u9884\u6d4b\uff0c0 \u4ee3\u8868\u968f\u673a\u9884\u6d4b\u3002MCC \u7684\u8ba1\u7b97\u516c\u5f0f\u975e\u5e38\u7b80\u5355\u3002 \\[ MCC = \\frac{TP \\times TN - FP \\times FN}{\\sqrt{(TP + FP) \\times (FN + TN) \\times (FP + TN) \\times (TP + FN)}} \\] \u6211\u4eec\u770b\u5230\uff0cMCC \u8003\u8651\u4e86 TP\u3001FP\u3001TN \u548c FN\uff0c\u56e0\u6b64\u53ef\u7528\u4e8e\u5904\u7406\u7c7b\u504f\u659c\u7684\u95ee\u9898\u3002\u60a8\u53ef\u4ee5\u4f7f\u7528\u6211\u4eec\u5df2\u7ecf\u5b9e\u73b0\u7684\u65b9\u6cd5\u5728 python \u4e2d\u5feb\u901f\u5b9e\u73b0\u5b83\u3002 def mcc ( y_true , y_pred ): # \u771f\u9633\u6027\u6837\u672c\u6570 tp = true_positive ( y_true , y_pred ) # \u771f\u9634\u6027\u6837\u672c\u6570 tn = true_negative ( y_true , y_pred ) # \u5047\u9633\u6027\u6837\u672c\u6570 fp = false_positive ( y_true , y_pred ) # \u5047\u9634\u6027\u6837\u672c\u6570 fn = false_negative ( y_true , y_pred ) numerator = ( tp * tn ) - ( fp * fn ) denominator = ( ( tp + fp ) * ( fn + tn ) * ( fp + tn ) * ( tp + fn ) ) denominator = denominator ** 0.5 return numerator / denominator \u8fd9\u4e9b\u6307\u6807\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5165\u95e8\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u673a\u5668\u5b66\u4e60\u95ee\u9898\u3002 
\u9700\u8981\u6ce8\u610f\u7684\u4e00\u70b9\u662f\uff0c\u5728\u8bc4\u4f30\u975e\u76d1\u7763\u65b9\u6cd5\uff08\u4f8b\u5982\u67d0\u79cd\u805a\u7c7b\uff09\u65f6\uff0c\u6700\u597d\u521b\u5efa\u6216\u624b\u52a8\u6807\u8bb0\u6d4b\u8bd5\u96c6\uff0c\u5e76\u5c06\u5176\u4e0e\u5efa\u6a21\u90e8\u5206\u7684\u6240\u6709\u5185\u5bb9\u5206\u5f00\u3002\u5b8c\u6210\u805a\u7c7b\u540e\uff0c\u5c31\u53ef\u4ee5\u4f7f\u7528\u4efb\u4f55\u4e00\u79cd\u76d1\u7763\u5b66\u4e60\u6307\u6807\u6765\u8bc4\u4f30\u6d4b\u8bd5\u96c6\u7684\u6027\u80fd\u4e86\u3002 \u4e00\u65e6\u6211\u4eec\u4e86\u89e3\u4e86\u7279\u5b9a\u95ee\u9898\u5e94\u8be5\u4f7f\u7528\u4ec0\u4e48\u6307\u6807\uff0c\u6211\u4eec\u5c31\u53ef\u4ee5\u5f00\u59cb\u66f4\u6df1\u5165\u5730\u7814\u7a76\u6211\u4eec\u7684\u6a21\u578b\uff0c\u4ee5\u6c42\u6539\u8fdb\u3002","title":"\u8bc4\u4f30\u6307\u6807"},{"location":"%E8%B6%85%E5%8F%82%E6%95%B0%E4%BC%98%E5%8C%96/","text":"\u8d85\u53c2\u6570\u4f18\u5316 \u6709\u4e86\u4f18\u79c0\u7684\u6a21\u578b\uff0c\u5c31\u6709\u4e86\u4f18\u5316\u8d85\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u5f97\u5206\u6a21\u578b\u7684\u96be\u9898\u3002\u90a3\u4e48\uff0c\u4ec0\u4e48\u662f\u8d85\u53c2\u6570\u4f18\u5316\u5462\uff1f\u5047\u8bbe\u60a8\u7684\u673a\u5668\u5b66\u4e60\u9879\u76ee\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u6d41\u7a0b\u3002\u6709\u4e00\u4e2a\u6570\u636e\u96c6\uff0c\u4f60\u76f4\u63a5\u5e94\u7528\u4e00\u4e2a\u6a21\u578b\uff0c\u7136\u540e\u5f97\u5230\u7ed3\u679c\u3002\u6a21\u578b\u5728\u8fd9\u91cc\u7684\u53c2\u6570\u88ab\u79f0\u4e3a\u8d85\u53c2\u6570\uff0c\u5373\u63a7\u5236\u6a21\u578b\u8bad\u7ec3/\u62df\u5408\u8fc7\u7a0b\u7684\u53c2\u6570\u3002\u5982\u679c\u6211\u4eec\u7528 SGD 
\u8bad\u7ec3\u7ebf\u6027\u56de\u5f52\uff0c\u6a21\u578b\u7684\u53c2\u6570\u662f\u659c\u7387\u548c\u504f\u5dee\uff0c\u8d85\u53c2\u6570\u662f\u5b66\u4e60\u7387\u3002\u4f60\u4f1a\u53d1\u73b0\u6211\u5728\u672c\u7ae0\u548c\u672c\u4e66\u4e2d\u4ea4\u66ff\u4f7f\u7528\u8fd9\u4e9b\u672f\u8bed\u3002\u5047\u8bbe\u6a21\u578b\u4e2d\u6709\u4e09\u4e2a\u53c2\u6570 a\u3001b\u3001c\uff0c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u90fd\u53ef\u4ee5\u662f 1 \u5230 10 \u4e4b\u95f4\u7684\u6574\u6570\u3002\u8fd9\u4e9b\u53c2\u6570\u7684 \"\u6b63\u786e \"\u7ec4\u5408\u5c06\u4e3a\u60a8\u63d0\u4f9b\u6700\u4f73\u7ed3\u679c\u3002\u56e0\u6b64\uff0c\u8fd9\u5c31\u6709\u70b9\u50cf\u4e00\u4e2a\u88c5\u6709\u4e09\u62e8\u5bc6\u7801\u9501\u7684\u624b\u63d0\u7bb1\u3002\u4e0d\u8fc7\uff0c\u4e09\u62e8\u5bc6\u7801\u9501\u53ea\u6709\u4e00\u4e2a\u6b63\u786e\u7b54\u6848\u3002\u800c\u6a21\u578b\u6709\u5f88\u591a\u6b63\u786e\u7b54\u6848\u3002\u90a3\u4e48\uff0c\u5982\u4f55\u627e\u5230\u6700\u4f73\u53c2\u6570\u5462\uff1f\u4e00\u79cd\u65b9\u6cd5\u662f\u5bf9\u6240\u6709\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\uff0c\u770b\u54ea\u79cd\u7ec4\u5408\u80fd\u63d0\u9ad8\u6307\u6807\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u521d\u59cb\u5316\u6700\u4f73\u51c6\u786e\u5ea6 best_accuracy = 0 # \u521d\u59cb\u5316\u6700\u4f73\u53c2\u6570\u7684\u5b57\u5178 best_parameters = { \"a\" : 0 , \"b\" : 0 , \"c\" : 0 } # \u5faa\u73af\u904d\u5386 a \u7684\u53d6\u503c\u8303\u56f4 1~10 for a in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 b \u7684\u53d6\u503c\u8303\u56f4 1~10 for b in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 c \u7684\u53d6\u503c\u8303\u56f4 1~10 for c in range ( 1 , 11 ): # \u521b\u5efa\u6a21\u578b\uff0c\u4f7f\u7528 a\u3001b\u3001c \u53c2\u6570 model = MODEL ( a , b , c ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408\u6a21\u578b model . fit ( training_data ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u9a8c\u8bc1\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . 
predict ( validation_data ) # \u8ba1\u7b97\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 accuracy = metrics . accuracy_score ( targets , preds ) # \u5982\u679c\u5f53\u524d\u51c6\u786e\u5ea6\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u5219\u66f4\u65b0\u6700\u4f73\u51c6\u786e\u5ea6\u548c\u6700\u4f73\u53c2\u6570 if accuracy > best_accuracy : best_accuracy = accuracy best_parameters [ \"a\" ] = a best_parameters [ \"b\" ] = b best_parameters [ \"c\" ] = c \u5728\u4e0a\u8ff0\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4ece 1 \u5230 10 \u5bf9\u6240\u6709\u53c2\u6570\u8fdb\u884c\u4e86\u62df\u5408\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u603b\u5171\u8981\u5bf9\u6a21\u578b\u8fdb\u884c 1000 \u6b21\uff0810 x 10 x 10\uff09\u62df\u5408\u3002\u8fd9\u53ef\u80fd\u4f1a\u5f88\u6602\u8d35\uff0c\u56e0\u4e3a\u6a21\u578b\u7684\u8bad\u7ec3\u9700\u8981\u5f88\u957f\u65f6\u95f4\u3002\u4e0d\u8fc7\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\u5e94\u8be5\u6ca1\u95ee\u9898\uff0c\u4f46\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\uff0c\u5e76\u4e0d\u662f\u53ea\u6709\u4e09\u4e2a\u53c2\u6570\uff0c\u6bcf\u4e2a\u53c2\u6570\u4e5f\u4e0d\u662f\u53ea\u6709\u5341\u4e2a\u503c\u3002 \u5927\u591a\u6570\u6a21\u578b\u53c2\u6570\u90fd\u662f\u5b9e\u6570\uff0c\u4e0d\u540c\u53c2\u6570\u7684\u7ec4\u5408\u53ef\u4ee5\u662f\u65e0\u9650\u7684\u3002 \u8ba9\u6211\u4eec\u770b\u770b scikit-learn \u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002 RandomForestClassifier ( n_estimators = 100 , criterion = 'gini' , max_depth = None , min_samples_split = 2 , min_samples_leaf = 1 , min_weight_fraction_leaf = 0.0 , max_features = 'auto' , max_leaf_nodes = None , min_impurity_decrease = 0.0 , min_impurity_split = None , bootstrap = True , oob_score = False , n_jobs = None , random_state = None , verbose = 0 , warm_start = False , class_weight = None , ccp_alpha = 0.0 , max_samples = None , ) \u6709 19 
\u4e2a\u53c2\u6570\uff0c\u800c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u7684\u6240\u6709\u7ec4\u5408\uff0c\u4ee5\u53ca\u5b83\u4eec\u53ef\u4ee5\u627f\u62c5\u7684\u6240\u6709\u503c\uff0c\u90fd\u5c06\u662f\u65e0\u7a77\u65e0\u5c3d\u7684\u3002\u901a\u5e38\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u6ca1\u6709\u8db3\u591f\u7684\u8d44\u6e90\u548c\u65f6\u95f4\u6765\u505a\u8fd9\u4ef6\u4e8b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6307\u5b9a\u4e86\u4e00\u4e2a\u53c2\u6570\u7f51\u683c\u3002\u5728\u8fd9\u4e2a\u7f51\u683c\u4e0a\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408\u7684\u641c\u7d22\u79f0\u4e3a\u7f51\u683c\u641c\u7d22\u3002\u6211\u4eec\u53ef\u4ee5\u8bf4\uff0cn_estimators \u53ef\u4ee5\u662f 100\u3001200\u3001250\u3001300\u3001400\u3001500\uff1bmax_depth \u53ef\u4ee5\u662f 1\u30012\u30015\u30017\u300111\u300115\uff1bcriterion \u53ef\u4ee5\u662f gini \u6216 entropy\u3002\u8fd9\u4e9b\u53c2\u6570\u770b\u8d77\u6765\u5e76\u4e0d\u591a\uff0c\u4f46\u5982\u679c\u6570\u636e\u96c6\u8fc7\u5927\uff0c\u8ba1\u7b97\u8d77\u6765\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u53ef\u4ee5\u50cf\u4e4b\u524d\u4e00\u6837\u521b\u5efa\u4e09\u4e2a for \u5faa\u73af\uff0c\u5e76\u5728\u9a8c\u8bc1\u96c6\u4e0a\u8ba1\u7b97\u5f97\u5206\uff0c\u8fd9\u6837\u5c31\u80fd\u5b9e\u73b0\u7f51\u683c\u641c\u7d22\u3002\u8fd8\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u8981\u8fdb\u884c k \u6298\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5219\u9700\u8981\u66f4\u591a\u7684\u5faa\u73af\uff0c\u8fd9\u610f\u5473\u7740\u9700\u8981\u66f4\u591a\u7684\u65f6\u95f4\u6765\u627e\u5230\u5b8c\u7f8e\u7684\u53c2\u6570\u3002\u56e0\u6b64\uff0c\u7f51\u683c\u641c\u7d22\u5e76\u4e0d\u6d41\u884c\u3002\u8ba9\u6211\u4eec\u4ee5\u6839\u636e \u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4 \u6570\u636e\u96c6\u4e3a\u4f8b\uff0c\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 \u56fe 1\uff1a\u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4\u6570\u636e\u96c6\u5c55\u793a 
\u8bad\u7ec3\u96c6\u4e2d\u53ea\u6709 2000 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u53ef\u4ee5\u8f7b\u677e\u5730\u4f7f\u7528\u5206\u5c42 kfold \u548c\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5177\u6709\u4e0a\u8ff0\u53c2\u6570\u8303\u56f4\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\uff0c\u5e76\u5728\u4e0b\u9762\u7684\u793a\u4f8b\u4e2d\u4e86\u89e3\u5982\u4f55\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u3002 # rf_grid_search.py import numpy as np import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u5220\u9664 price_range \u5217 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u53d6\u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\uff0c\u4f7f\u7528\u6240\u6709\u53ef\u7528\u7684 CPU \u6838\u5fc3\u8fdb\u884c\u8bad\u7ec3 classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid = { \"n_estimators\" : [ 100 , 200 , 250 , 300 , 400 , 500 ], \"max_depth\" : [ 1 , 2 , 5 , 7 , 11 , 15 ], \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22 model = model_selection . GridSearchCV ( estimator = classifier , param_grid = param_grid , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( f \"Best score: { model . 
best_score_ } \" ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u8fd9\u91cc\u6253\u5370\u4e86\u5f88\u591a\u5185\u5bb9\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u6700\u540e\u51e0\u884c\u3002 [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.895 , total = 1.0 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.890 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.910 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.880 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... 
[ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.870 , total = 1.1 s [ Parallel ( n_jobs = 1 )]: Done 360 out of 360 | elapsed : 3.7 min finished Best score : 0.889 Best parameters set : criterion : 'entropy' max_depth : 15 n_estimators : 500 \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c5\u6298\u4ea4\u53c9\u68c0\u9a8c\u6700\u4f73\u5f97\u5206\u662f 0.889\uff0c\u6211\u4eec\u7684\u7f51\u683c\u641c\u7d22\u5f97\u5230\u4e86\u6700\u4f73\u53c2\u6570\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u7684\u4e0b\u4e00\u4e2a\u6700\u4f73\u65b9\u6cd5\u662f \u968f\u673a\u641c\u7d22 \u3002\u5728\u968f\u673a\u641c\u7d22\u4e2d\uff0c\u6211\u4eec\u968f\u673a\u9009\u62e9\u4e00\u4e2a\u53c2\u6570\u7ec4\u5408\uff0c\u7136\u540e\u8ba1\u7b97\u4ea4\u53c9\u9a8c\u8bc1\u5f97\u5206\u3002\u8fd9\u91cc\u6d88\u8017\u7684\u65f6\u95f4\u6bd4\u7f51\u683c\u641c\u7d22\u5c11\uff0c\u56e0\u4e3a\u6211\u4eec\u4e0d\u5bf9\u6240\u6709\u4e0d\u540c\u7684\u53c2\u6570\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\u3002\u6211\u4eec\u9009\u62e9\u8981\u5bf9\u6a21\u578b\u8fdb\u884c\u591a\u5c11\u6b21\u8bc4\u4f30\uff0c\u8fd9\u5c31\u51b3\u5b9a\u4e86\u641c\u7d22\u6240\u9700\u7684\u65f6\u95f4\u3002\u4ee3\u7801\u4e0e\u4e0a\u9762\u7684\u5dee\u522b\u4e0d\u5927\u3002\u9664 GridSearchCV \u5916\uff0c\u6211\u4eec\u4f7f\u7528 RandomizedSearchCV\u3002 if __name__ == \"__main__\" : classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u66f4\u6539\u641c\u7d22\u7a7a\u95f4 param_grid = { \"n_estimators\" : np . arange ( 100 , 1500 , 100 ), \"max_depth\" : np . arange ( 1 , 31 ), \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u968f\u673a\u53c2\u6570\u641c\u7d22 model = model_selection . RandomizedSearchCV ( estimator = classifier , param_distributions = param_grid , n_iter = 20 , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . 
fit ( X , y ) print ( f \"Best score: { model . best_score_ } \" ) print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u6211\u4eec\u66f4\u6539\u4e86\u968f\u673a\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c\uff0c\u7ed3\u679c\u4f3c\u4e4e\u6709\u4e86\u4e9b\u8bb8\u6539\u8fdb\u3002 Best score : 0.8905 Best parameters set : criterion : entropy max_depth : 25 n_estimators : 300 \u5982\u679c\u8fed\u4ee3\u6b21\u6570\u8f83\u5c11\uff0c\u968f\u673a\u641c\u7d22\u6bd4\u7f51\u683c\u641c\u7d22\u66f4\u5feb\u3002\u4f7f\u7528\u8fd9\u4e24\u79cd\u65b9\u6cd5\uff0c\u4f60\u53ef\u4ee5\u4e3a\u5404\u79cd\u6a21\u578b\u627e\u5230\u6700\u4f18\u53c2\u6570\uff0c\u53ea\u8981\u5b83\u4eec\u6709\u62df\u5408\u548c\u9884\u6d4b\u529f\u80fd\uff0c\u8fd9\u4e5f\u662f scikit-learn \u7684\u6807\u51c6\u3002\u6709\u65f6\uff0c\u4f60\u53ef\u80fd\u60f3\u4f7f\u7528\u7ba1\u9053\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u7531\u4e24\u5217\u6587\u672c\u7ec4\u6210\uff0c\u4f60\u9700\u8981\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u7c7b\u522b\u3002\u8ba9\u6211\u4eec\u5047\u8bbe\u4f60\u9009\u62e9\u7684\u7ba1\u9053\u662f\u9996\u5148\u4ee5\u534a\u76d1\u7763\u7684\u65b9\u5f0f\u5e94\u7528 tf-idf\uff0c\u7136\u540e\u4f7f\u7528 SVD \u548c SVM \u5206\u7c7b\u5668\u3002\u73b0\u5728\u7684\u95ee\u9898\u662f\uff0c\u6211\u4eec\u5fc5\u987b\u9009\u62e9 SVD \u7684\u6210\u5206\uff0c\u8fd8\u9700\u8981\u8c03\u6574 SVM \u7684\u53c2\u6570\u3002\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u5c55\u793a\u4e86\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 import numpy as np import pandas as pd from sklearn import metrics from sklearn import model_selection from sklearn import pipeline from sklearn.decomposition import TruncatedSVD from 
sklearn.feature_extraction.text import TfidfVectorizer from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC # \u8ba1\u7b97\u52a0\u6743\u4e8c\u6b21 Kappa \u5206\u6570 def quadratic_weighted_kappa ( y_true , y_pred ): return metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) if __name__ == '__main__' : # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( '../input/train.csv' ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u63d0\u53d6 id \u5217\u7684\u503c\uff0c\u5e76\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u7c7b\u578b\uff0c\u5b58\u50a8\u5728\u53d8\u91cf idx \u4e2d idx = test . id . values . astype ( int ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 train = train . drop ( 'id' , axis = 1 ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 test = test . drop ( 'id' , axis = 1 ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf 'relevance' \uff0c\u5b58\u50a8\u5728\u53d8\u91cf y \u4e2d y = train . relevance . values # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 traindata \u4e2d traindata = list ( train . apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 testdata \u4e2d testdata = list ( test . 
apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u521b\u5efa\u4e00\u4e2a TfidfVectorizer \u5bf9\u8c61 tfv\uff0c\u7528\u4e8e\u5c06\u6587\u672c\u6570\u636e\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv = TfidfVectorizer ( min_df = 3 , max_features = None , strip_accents = 'unicode' , analyzer = 'word' , token_pattern = r '\\w{1,}' , ngram_range = ( 1 , 3 ), use_idf = 1 , smooth_idf = 1 , sublinear_tf = 1 , stop_words = 'english' ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408 TfidfVectorizer\uff0c\u5c06\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv . fit ( traindata ) # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X X = tfv . transform ( traindata ) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X_test X_test = tfv . transform ( testdata ) # \u521b\u5efa TruncatedSVD \u5bf9\u8c61 svd\uff0c\u7528\u4e8e\u8fdb\u884c\u5947\u5f02\u503c\u5206\u89e3 svd = TruncatedSVD () # \u521b\u5efa StandardScaler \u5bf9\u8c61 scl\uff0c\u7528\u4e8e\u8fdb\u884c\u7279\u5f81\u7f29\u653e scl = StandardScaler () # \u521b\u5efa\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668\u5bf9\u8c61 svm_model svm_model = SVC () # \u521b\u5efa\u673a\u5668\u5b66\u4e60\u7ba1\u9053 clf\uff0c\u5305\u542b\u5947\u5f02\u503c\u5206\u89e3\u3001\u7279\u5f81\u7f29\u653e\u548c\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668 clf = pipeline . Pipeline ( [ ( 'svd' , svd ), ( 'scl' , scl ), ( 'svm' , svm_model ) ] ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid param_grid = { 'svd__n_components' : [ 200 , 300 ], 'svm__C' : [ 10 , 12 ] } # \u521b\u5efa\u81ea\u5b9a\u4e49\u7684\u8bc4\u5206\u51fd\u6570 kappa_scorer\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd kappa_scorer = metrics . 
make_scorer ( quadratic_weighted_kappa , greater_is_better = True ) # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model = model_selection . GridSearchCV ( estimator = clf , param_grid = param_grid , scoring = kappa_scorer , verbose = 10 , n_jobs =- 1 , refit = True , cv = 5 ) # \u4f7f\u7528 GridSearchCV \u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( \"Best score: %0.3f \" % model . best_score_ ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( \" \\t %s : %r \" % ( param_name , best_parameters [ param_name ])) # \u83b7\u53d6\u6700\u4f73\u6a21\u578b best_model = model . best_estimator_ best_model . fit ( X , y ) # \u4f7f\u7528\u6700\u4f73\u6a21\u578b\u8fdb\u884c\u9884\u6d4b preds = best_model . predict ( ... 
) \u8fd9\u91cc\u663e\u793a\u7684\u7ba1\u9053\u5305\u62ec SVD\uff08\u5947\u5f02\u503c\u5206\u89e3\uff09\u3001\u6807\u51c6\u7f29\u653e\u548c SVM\uff08\u652f\u6301\u5411\u91cf\u673a\uff09\u6a21\u578b\u3002\u8bf7\u6ce8\u610f\uff0c\u7531\u4e8e\u6ca1\u6709\u8bad\u7ec3\u6570\u636e\uff0c\u60a8\u65e0\u6cd5\u6309\u539f\u6837\u8fd0\u884c\u4e0a\u8ff0\u4ee3\u7801\u3002\u5f53\u6211\u4eec\u8fdb\u5165\u9ad8\u7ea7\u8d85\u53c2\u6570\u4f18\u5316\u6280\u672f\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4e0d\u540c\u7c7b\u578b\u7684 \u6700\u5c0f\u5316\u7b97\u6cd5 \u6765\u7814\u7a76\u51fd\u6570\u7684\u6700\u5c0f\u5316\u3002\u8fd9\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u591a\u79cd\u6700\u5c0f\u5316\u51fd\u6570\u6765\u5b9e\u73b0\uff0c\u5982\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u3001\u5185\u5c14\u5fb7-\u6885\u5fb7\u4f18\u5316\u7b97\u6cd5\u3001\u4f7f\u7528\u8d1d\u53f6\u65af\u6280\u672f\u548c\u9ad8\u65af\u8fc7\u7a0b\u5bfb\u627e\u6700\u4f18\u53c2\u6570\u6216\u4f7f\u7528\u9057\u4f20\u7b97\u6cd5\u3002\u6211\u5c06\u5728 \"\u96c6\u5408\u4e0e\u5806\u53e0\uff08ensembling and stacking\uff09 \"\u4e00\u7ae0\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u548c Nelder-Mead \u7b97\u6cd5\u7684\u5e94\u7528\u3002\u9996\u5148\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u9ad8\u65af\u8fc7\u7a0b\u5982\u4f55\u7528\u4e8e\u8d85\u53c2\u6570\u4f18\u5316\u3002\u8fd9\u7c7b\u7b97\u6cd5\u9700\u8981\u4e00\u4e2a\u53ef\u4ee5\u4f18\u5316\u7684\u51fd\u6570\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u90fd\u662f\u6700\u5c0f\u5316\u8fd9\u4e2a\u51fd\u6570\uff0c\u5c31\u50cf\u6211\u4eec\u6700\u5c0f\u5316\u635f\u5931\u4e00\u6837\u3002 
\u56e0\u6b64\uff0c\u6bd4\u65b9\u8bf4\uff0c\u4f60\u60f3\u627e\u5230\u6700\u4f73\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u663e\u7136\uff0c\u51c6\u786e\u5ea6\u8d8a\u9ad8\u8d8a\u597d\u3002\u73b0\u5728\uff0c\u6211\u4eec\u4e0d\u80fd\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\uff0c\u4f46\u6211\u4eec\u53ef\u4ee5\u5c06\u7cbe\u786e\u5ea6\u4e58\u4ee5-1\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u662f\u5728\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\u7684\u8d1f\u503c\uff0c\u4f46\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u662f\u5728\u6700\u5927\u5316\u7cbe\u786e\u5ea6\u3002 \u5728\u9ad8\u65af\u8fc7\u7a0b\u4e2d\u4f7f\u7528\u8d1d\u53f6\u65af\u4f18\u5316\uff0c\u53ef\u4ee5\u4f7f\u7528 scikit-optimize (skopt) \u5e93\u4e2d\u7684 gp_minimize \u51fd\u6570\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4f7f\u7528\u8be5\u51fd\u6570\u8c03\u6574\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u53c2\u6570\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from skopt import gp_minimize from skopt import space def optimize ( params , param_names , x , y ): # \u5c06\u53c2\u6570\u540d\u79f0\u548c\u5bf9\u5e94\u7684\u503c\u6253\u5305\u6210\u5b57\u5178 params = dict ( zip ( param_names , params )) # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u6a21\u578b\uff0c\u4f7f\u7528\u4f20\u5165\u7684\u53c2\u6570\u914d\u7f6e model = ensemble . RandomForestClassifier ( ** params ) # \u521b\u5efa StratifiedKFold \u4ea4\u53c9\u9a8c\u8bc1\u5bf9\u8c61\uff0c\u5c06\u6570\u636e\u5206\u4e3a 5 \u6298 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u6bcf\u4e2a\u6298\u53e0\u7684\u51c6\u786e\u5ea6\u7684\u5217\u8868 accuracies = [] # \u5faa\u73af\u904d\u5386\u6bcf\u4e2a\u6298\u53e0\u7684\u8bad\u7ec3\u548c\u6d4b\u8bd5\u6570\u636e for idx in kf . 
split ( X = x , y = y ): train_idx , test_idx = idx [ 0 ], idx [ 1 ] xtrain = x [ train_idx ] ytrain = y [ train_idx ] xtest = x [ test_idx ] ytest = y [ test_idx ] # \u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u62df\u5408\u6a21\u578b model . fit ( xtrain , ytrain ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u6d4b\u8bd5\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . predict ( xtest ) # \u8ba1\u7b97\u6298\u53e0\u7684\u51c6\u786e\u5ea6 fold_accuracy = metrics . accuracy_score ( ytest , preds ) accuracies . append ( fold_accuracy ) # \u8fd4\u56de\u5e73\u5747\u51c6\u786e\u5ea6\u7684\u8d1f\u6570\uff08\u56e0\u4e3a skopt \u4f7f\u7528\u8d1f\u6570\u6765\u6700\u5c0f\u5316\u76ee\u6807\u51fd\u6570\uff09 return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u53d6\u7279\u5f81\u77e9\u9635 X\uff08\u53bb\u6389\"price_range\"\u5217\uff09 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u5b9a\u4e49\u8d85\u53c2\u6570\u641c\u7d22\u7a7a\u95f4 param_space param_space = [ space . Integer ( 3 , 15 , name = \"max_depth\" ), space . Integer ( 100 , 1500 , name = \"n_estimators\" ), space . Categorical ([ \"gini\" , \"entropy\" ], name = \"criterion\" ), space . 
Real ( 0.01 , 1 , prior = \"uniform\" , name = \"max_features\" ) ] # \u5b9a\u4e49\u8d85\u53c2\u6570\u7684\u540d\u79f0\u5217\u8868 param_names param_names = [ \"max_depth\" , \"n_estimators\" , \"criterion\" , \"max_features\" ] # \u521b\u5efa\u51fd\u6570 optimization_function\uff0c\u7528\u4e8e\u4f20\u9012\u7ed9 gp_minimize optimization_function = partial ( optimize , param_names = param_names , x = X , y = y ) # \u4f7f\u7528 Bayesian Optimization\uff08\u57fa\u4e8e\u8d1d\u53f6\u65af\u4f18\u5316\uff09\u6765\u641c\u7d22\u6700\u4f73\u8d85\u53c2\u6570 result = gp_minimize ( optimization_function , dimensions = param_space , n_calls = 15 , n_random_starts = 10 , verbose = 10 ) # \u83b7\u53d6\u6700\u4f73\u8d85\u53c2\u6570\u7684\u5b57\u5178 best_params = dict ( zip ( param_names , result . x ) ) # \u6253\u5370\u51fa\u627e\u5230\u7684\u6700\u4f73\u8d85\u53c2\u6570 print ( best_params ) \u8fd9\u540c\u6837\u4f1a\u4ea7\u751f\u5927\u91cf\u8f93\u51fa\uff0c\u6700\u540e\u4e00\u90e8\u5206\u5982\u4e0b\u6240\u793a\u3002 Iteration No : 14 started . Searching for the next optimal point . Iteration No : 14 ended . Search finished for the next optimal point . Time taken : 4.7793 Function value obtained : - 0.9075 Current minimum : - 0.9075 Iteration No : 15 started . Searching for the next optimal point . Iteration No : 15 ended . Search finished for the next optimal point . 
Time taken : 49.4186 Function value obtained : - 0.9075 Current minimum : - 0.9075 { 'max_depth' : 12 , 'n_estimators' : 100 , 'criterion' : 'entropy' , 'max_features' : 1.0 } \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6210\u529f\u7a81\u7834\u4e86 0.90 \u7684\u51c6\u786e\u7387\u3002\u8fd9\u771f\u662f\u592a\u795e\u5947\u4e86\uff01 \u6211\u4eec\u8fd8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u6bb5\u67e5\u770b\uff08\u7ed8\u5236\uff09\u6211\u4eec\u662f\u5982\u4f55\u5b9e\u73b0\u6536\u655b\u7684\u3002 from skopt.plots import plot_convergence plot_convergence ( result ) \u6536\u655b\u56fe\u5982\u56fe 2 \u6240\u793a\u3002 \u56fe 2\uff1a\u968f\u673a\u68ee\u6797\u53c2\u6570\u4f18\u5316\u7684\u6536\u655b\u56fe Scikit- optimize \u5c31\u662f\u8fd9\u6837\u4e00\u4e2a\u5e93\u3002 hyperopt \u4f7f\u7528\u6811\u72b6\u7ed3\u6784\u8d1d\u53f6\u65af\u4f30\u8ba1\u5668\uff08TPE\uff09\u6765\u627e\u5230\u6700\u4f18\u53c2\u6570\u3002\u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u7247\u6bb5\uff0c\u6211\u5728\u4f7f\u7528 hyperopt \u65f6\u5bf9\u4e4b\u524d\u7684\u4ee3\u7801\u505a\u4e86\u6700\u5c0f\u7684\u6539\u52a8\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from hyperopt import hp , fmin , tpe , Trials from hyperopt.pyll.base import scope def optimize ( params , x , y ): model = ensemble . RandomForestClassifier ( ** params ) kf = model_selection . StratifiedKFold ( n_splits = 5 ) ... return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/mobile_train.csv\" ) X = df . drop ( \"price_range\" , axis = 1 ) . values y = df . price_range . values # \u5b9a\u4e49\u641c\u7d22\u7a7a\u95f4\uff08\u6574\u578b\u3001\u6d6e\u70b9\u6570\u578b\u3001\u9009\u62e9\u578b\uff09 param_space = { \"max_depth\" : scope . int ( hp . quniform ( \"max_depth\" , 1 , 15 , 1 )), \"n_estimators\" : scope . int ( hp . 
quniform ( \"n_estimators\" , 100 , 1500 , 1 ) ), \"criterion\" : hp . choice ( \"criterion\" , [ \"gini\" , \"entropy\" ]), \"max_features\" : hp . uniform ( \"max_features\" , 0 , 1 ) } # \u5305\u88c5\u51fd\u6570 optimization_function = partial ( optimize , x = X , y = y ) # \u5f00\u59cb\u8bad\u7ec3 trials = Trials () # \u6700\u5c0f\u5316\u76ee\u6807\u503c hopt = fmin ( fn = optimization_function , space = param_space , algo = tpe . suggest , max_evals = 15 , trials = trials ) #\u6253\u5370\u6700\u4f73\u53c2\u6570 print ( hopt ) \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u4e0e\u4e4b\u524d\u7684\u4ee3\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u4f60\u5fc5\u987b\u4ee5\u4e0d\u540c\u7684\u683c\u5f0f\u5b9a\u4e49\u53c2\u6570\u7a7a\u95f4\uff0c\u8fd8\u9700\u8981\u6539\u53d8\u5b9e\u9645\u4f18\u5316\u90e8\u5206\uff0c\u7528 hyperopt \u4ee3\u66ff gp_minimize\u3002\u7ed3\u679c\u76f8\u5f53\u4e0d\u9519\uff01 \u276f python rf_hyperopt . py 100 %| \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 15 / 15 [ 04 : 38 < 00 : 00 , 18.57 s / trial , best loss : - 0.9095000000000001 ] { 'criterion' : 1 , 'max_depth' : 11.0 , 'max_features' : 0.821163568049807 , 'n_estimators' : 806.0 } \u6211\u4eec\u5f97\u5230\u4e86\u6bd4\u4ee5\u524d\u66f4\u597d\u7684\u51c6\u786e\u5ea6\u548c\u4e00\u7ec4\u53ef\u4ee5\u4f7f\u7528\u7684\u53c2\u6570\u3002\u8bf7\u6ce8\u610f\uff0c\u6700\u7ec8\u7ed3\u679c\u4e2d\u7684\u6807\u51c6\u662f 1\u3002\u8fd9\u610f\u5473\u7740\u9009\u62e9\u4e86 1\uff0c\u5373\u71b5\u3002 \u4e0a\u8ff0\u8c03\u6574\u8d85\u53c2\u6570\u7684\u65b9\u6cd5\u662f\u6700\u5e38\u89c1\u7684\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u6a21\u578b\uff1a\u7ebf\u6027\u56de\u5f52\u3001\u903b\u8f91\u56de\u5f52\u3001\u57fa\u4e8e\u6811\u7684\u65b9\u6cd5\u3001\u68af\u5ea6\u63d0\u5347\u6a21\u578b\uff08\u5982 xgboost\u3001lightgbm\uff09\uff0c\u751a\u81f3\u795e\u7ecf\u7f51\u7edc\uff01 
\u867d\u7136\u8fd9\u4e9b\u65b9\u6cd5\u5df2\u7ecf\u5b58\u5728\uff0c\u4f46\u5b66\u4e60\u65f6\u5fc5\u987b\u4ece\u624b\u52a8\u8c03\u6574\u8d85\u53c2\u6570\u5f00\u59cb\uff0c\u5373\u624b\u5de5\u8c03\u6574\u3002\u624b\u52a8\u8c03\u6574\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5b66\u4e60\u57fa\u7840\u77e5\u8bc6\uff0c\u4f8b\u5982\uff0c\u5728\u68af\u5ea6\u63d0\u5347\u4e2d\uff0c\u5f53\u4f60\u589e\u52a0\u6df1\u5ea6\u65f6\uff0c\u4f60\u5e94\u8be5\u964d\u4f4e\u5b66\u4e60\u7387\u3002\u5982\u679c\u4f7f\u7528\u81ea\u52a8\u5de5\u5177\uff0c\u5c31\u65e0\u6cd5\u5b66\u4e60\u5230\u8fd9\u4e00\u70b9\u3002\u8bf7\u53c2\u8003\u4e0b\u8868\uff0c\u4e86\u89e3\u5e94\u5982\u4f55\u8c03\u6574\u3002RS* \u8868\u793a\u968f\u673a\u641c\u7d22\u5e94\u8be5\u66f4\u597d\u3002 \u4e00\u65e6\u4f60\u80fd\u66f4\u597d\u5730\u624b\u52a8\u8c03\u6574\u53c2\u6570\uff0c\u4f60\u751a\u81f3\u53ef\u80fd\u4e0d\u9700\u8981\u4efb\u4f55\u81ea\u52a8\u8d85\u53c2\u6570\u8c03\u6574\u3002\u521b\u5efa\u5927\u578b\u6a21\u578b\u6216\u5f15\u5165\u5927\u91cf\u7279\u5f81\u65f6\uff0c\u4e5f\u5bb9\u6613\u9020\u6210\u8bad\u7ec3\u6570\u636e\u7684\u8fc7\u5ea6\u62df\u5408\u3002\u4e3a\u907f\u514d\u8fc7\u5ea6\u62df\u5408\uff0c\u9700\u8981\u5728\u8bad\u7ec3\u6570\u636e\u7279\u5f81\u4e2d\u5f15\u5165\u566a\u58f0\u6216\u5bf9\u4ee3\u4ef7\u51fd\u6570\u8fdb\u884c\u60e9\u7f5a\u3002\u8fd9\u79cd\u60e9\u7f5a\u79f0\u4e3a \u6b63\u5219\u5316 \uff0c\u6709\u52a9\u4e8e\u6cdb\u5316\u6a21\u578b\u3002\u5728\u7ebf\u6027\u6a21\u578b\u4e2d\uff0c\u6700\u5e38\u89c1\u7684\u6b63\u5219\u5316\u7c7b\u578b\u662f L1 \u548c L2\u3002L1 \u4e5f\u79f0\u4e3a Lasso \u56de\u5f52\uff0cL2 \u79f0\u4e3a Ridge \u56de\u5f52\u3002\u8bf4\u5230\u795e\u7ecf\u7f51\u7edc\uff0c\u6211\u4eec\u4f1a\u4f7f\u7528dropout\u3001\u6dfb\u52a0\u589e\u5f3a\u3001\u566a\u58f0\u7b49\u65b9\u6cd5\u5bf9\u6a21\u578b\u8fdb\u884c\u6b63\u5219\u5316\u3002\u5229\u7528\u8d85\u53c2\u6570\u4f18\u5316\uff0c\u8fd8\u53ef\u4ee5\u627e\u5230\u6b63\u786e\u7684\u60e9\u7f5a\u65b9\u6cd5\u3002 Model Optimize Range of values Linear Regression - 
fit_intercept - normalize - True/False - True/False Ridge - alpha - fit_intercept - normalize - 0.01, 0.1, 1.0, 10, 100 - True/False - True/False k-neighbors - n_neighbors - p - 2, 4, 8, 16, ... - 2, 3, ... SVM - C - gamma - class_weight - 0.001, 0.01, ...,10, 100, 1000 - 'auto', RS* - 'balanced', None Logistic Regression - Penalty - C - L1 or L2 - 0.001, 0.01, ..., 10, ..., 100 Lasso - Alpha - Normalize - 0.1, 1.0, 10 - True/False Random Forest - n_estimators - max_depth - min_samples_split - min_samples_leaf - max features - 120, 300, 500, 800, 1200 - 5, 8, 15, 25, 30, None - 1, 2, 5, 10, 15, 100 - log2, sqrt, None XGBoost - eta - gamma - max_depth - min_child_weight - subsample - colsample_bytree - lambda - alpha - 0.01, 0.015, 0.025, 0.05, 0.1 - 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0 - 3, 5, 7, 9, 12, 15, 17, 25 - 1, 3, 5, 7 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.01, 0.1, 1.0, RS - 0, 0.1, 0.5, 1.0, RS","title":"\u8d85\u53c2\u6570\u4f18\u5316"},{"location":"%E8%B6%85%E5%8F%82%E6%95%B0%E4%BC%98%E5%8C%96/#_1","text":"\u6709\u4e86\u4f18\u79c0\u7684\u6a21\u578b\uff0c\u5c31\u6709\u4e86\u4f18\u5316\u8d85\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u5f97\u5206\u6a21\u578b\u7684\u96be\u9898\u3002\u90a3\u4e48\uff0c\u4ec0\u4e48\u662f\u8d85\u53c2\u6570\u4f18\u5316\u5462\uff1f\u5047\u8bbe\u60a8\u7684\u673a\u5668\u5b66\u4e60\u9879\u76ee\u6709\u4e00\u4e2a\u7b80\u5355\u7684\u6d41\u7a0b\u3002\u6709\u4e00\u4e2a\u6570\u636e\u96c6\uff0c\u4f60\u76f4\u63a5\u5e94\u7528\u4e00\u4e2a\u6a21\u578b\uff0c\u7136\u540e\u5f97\u5230\u7ed3\u679c\u3002\u6a21\u578b\u5728\u8fd9\u91cc\u7684\u53c2\u6570\u88ab\u79f0\u4e3a\u8d85\u53c2\u6570\uff0c\u5373\u63a7\u5236\u6a21\u578b\u8bad\u7ec3/\u62df\u5408\u8fc7\u7a0b\u7684\u53c2\u6570\u3002\u5982\u679c\u6211\u4eec\u7528 SGD 
\u8bad\u7ec3\u7ebf\u6027\u56de\u5f52\uff0c\u6a21\u578b\u7684\u53c2\u6570\u662f\u659c\u7387\u548c\u504f\u5dee\uff0c\u8d85\u53c2\u6570\u662f\u5b66\u4e60\u7387\u3002\u4f60\u4f1a\u53d1\u73b0\u6211\u5728\u672c\u7ae0\u548c\u672c\u4e66\u4e2d\u4ea4\u66ff\u4f7f\u7528\u8fd9\u4e9b\u672f\u8bed\u3002\u5047\u8bbe\u6a21\u578b\u4e2d\u6709\u4e09\u4e2a\u53c2\u6570 a\u3001b\u3001c\uff0c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u90fd\u53ef\u4ee5\u662f 1 \u5230 10 \u4e4b\u95f4\u7684\u6574\u6570\u3002\u8fd9\u4e9b\u53c2\u6570\u7684 \"\u6b63\u786e \"\u7ec4\u5408\u5c06\u4e3a\u60a8\u63d0\u4f9b\u6700\u4f73\u7ed3\u679c\u3002\u56e0\u6b64\uff0c\u8fd9\u5c31\u6709\u70b9\u50cf\u4e00\u4e2a\u88c5\u6709\u4e09\u62e8\u5bc6\u7801\u9501\u7684\u624b\u63d0\u7bb1\u3002\u4e0d\u8fc7\uff0c\u4e09\u62e8\u5bc6\u7801\u9501\u53ea\u6709\u4e00\u4e2a\u6b63\u786e\u7b54\u6848\u3002\u800c\u6a21\u578b\u6709\u5f88\u591a\u6b63\u786e\u7b54\u6848\u3002\u90a3\u4e48\uff0c\u5982\u4f55\u627e\u5230\u6700\u4f73\u53c2\u6570\u5462\uff1f\u4e00\u79cd\u65b9\u6cd5\u662f\u5bf9\u6240\u6709\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\uff0c\u770b\u54ea\u79cd\u7ec4\u5408\u80fd\u63d0\u9ad8\u6307\u6807\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 # \u521d\u59cb\u5316\u6700\u4f73\u51c6\u786e\u5ea6 best_accuracy = 0 # \u521d\u59cb\u5316\u6700\u4f73\u53c2\u6570\u7684\u5b57\u5178 best_parameters = { \"a\" : 0 , \"b\" : 0 , \"c\" : 0 } # \u5faa\u73af\u904d\u5386 a \u7684\u53d6\u503c\u8303\u56f4 1~10 for a in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 b \u7684\u53d6\u503c\u8303\u56f4 1~10 for b in range ( 1 , 11 ): # \u5faa\u73af\u904d\u5386 c \u7684\u53d6\u503c\u8303\u56f4 1~10 for c in range ( 1 , 11 ): # \u521b\u5efa\u6a21\u578b\uff0c\u4f7f\u7528 a\u3001b\u3001c \u53c2\u6570 model = MODEL ( a , b , c ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408\u6a21\u578b model . fit ( training_data ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u9a8c\u8bc1\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . 
predict ( validation_data ) # \u8ba1\u7b97\u9884\u6d4b\u7684\u51c6\u786e\u5ea6 accuracy = metrics . accuracy_score ( targets , preds ) # \u5982\u679c\u5f53\u524d\u51c6\u786e\u5ea6\u4f18\u4e8e\u4e4b\u524d\u7684\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u5219\u66f4\u65b0\u6700\u4f73\u51c6\u786e\u5ea6\u548c\u6700\u4f73\u53c2\u6570 if accuracy > best_accuracy : best_accuracy = accuracy best_parameters [ \"a\" ] = a best_parameters [ \"b\" ] = b best_parameters [ \"c\" ] = c \u5728\u4e0a\u8ff0\u4ee3\u7801\u4e2d\uff0c\u6211\u4eec\u4ece 1 \u5230 10 \u5bf9\u6240\u6709\u53c2\u6570\u8fdb\u884c\u4e86\u62df\u5408\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u603b\u5171\u8981\u5bf9\u6a21\u578b\u8fdb\u884c 1000 \u6b21\uff0810 x 10 x 10\uff09\u62df\u5408\u3002\u8fd9\u53ef\u80fd\u4f1a\u5f88\u6602\u8d35\uff0c\u56e0\u4e3a\u6a21\u578b\u7684\u8bad\u7ec3\u9700\u8981\u5f88\u957f\u65f6\u95f4\u3002\u4e0d\u8fc7\uff0c\u5728\u8fd9\u79cd\u60c5\u51b5\u4e0b\u5e94\u8be5\u6ca1\u95ee\u9898\uff0c\u4f46\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\uff0c\u5e76\u4e0d\u662f\u53ea\u6709\u4e09\u4e2a\u53c2\u6570\uff0c\u6bcf\u4e2a\u53c2\u6570\u4e5f\u4e0d\u662f\u53ea\u6709\u5341\u4e2a\u503c\u3002 \u5927\u591a\u6570\u6a21\u578b\u53c2\u6570\u90fd\u662f\u5b9e\u6570\uff0c\u4e0d\u540c\u53c2\u6570\u7684\u7ec4\u5408\u53ef\u4ee5\u662f\u65e0\u9650\u7684\u3002 \u8ba9\u6211\u4eec\u770b\u770b scikit-learn \u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\u3002 RandomForestClassifier ( n_estimators = 100 , criterion = 'gini' , max_depth = None , min_samples_split = 2 , min_samples_leaf = 1 , min_weight_fraction_leaf = 0.0 , max_features = 'auto' , max_leaf_nodes = None , min_impurity_decrease = 0.0 , min_impurity_split = None , bootstrap = True , oob_score = False , n_jobs = None , random_state = None , verbose = 0 , warm_start = False , class_weight = None , ccp_alpha = 0.0 , max_samples = None , ) \u6709 19 
\u4e2a\u53c2\u6570\uff0c\u800c\u6240\u6709\u8fd9\u4e9b\u53c2\u6570\u7684\u6240\u6709\u7ec4\u5408\uff0c\u4ee5\u53ca\u5b83\u4eec\u53ef\u4ee5\u627f\u62c5\u7684\u6240\u6709\u503c\uff0c\u90fd\u5c06\u662f\u65e0\u7a77\u65e0\u5c3d\u7684\u3002\u901a\u5e38\u60c5\u51b5\u4e0b\uff0c\u6211\u4eec\u6ca1\u6709\u8db3\u591f\u7684\u8d44\u6e90\u548c\u65f6\u95f4\u6765\u505a\u8fd9\u4ef6\u4e8b\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u6307\u5b9a\u4e86\u4e00\u4e2a\u53c2\u6570\u7f51\u683c\u3002\u5728\u8fd9\u4e2a\u7f51\u683c\u4e0a\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408\u7684\u641c\u7d22\u79f0\u4e3a\u7f51\u683c\u641c\u7d22\u3002\u6211\u4eec\u53ef\u4ee5\u8bf4\uff0cn_estimators \u53ef\u4ee5\u662f 100\u3001200\u3001250\u3001300\u3001400\u3001500\uff1bmax_depth \u53ef\u4ee5\u662f 1\u30012\u30015\u30017\u300111\u300115\uff1bcriterion \u53ef\u4ee5\u662f gini \u6216 entropy\u3002\u8fd9\u4e9b\u53c2\u6570\u770b\u8d77\u6765\u5e76\u4e0d\u591a\uff0c\u4f46\u5982\u679c\u6570\u636e\u96c6\u8fc7\u5927\uff0c\u8ba1\u7b97\u8d77\u6765\u4f1a\u8017\u8d39\u5927\u91cf\u65f6\u95f4\u3002\u6211\u4eec\u53ef\u4ee5\u50cf\u4e4b\u524d\u4e00\u6837\u521b\u5efa\u4e09\u4e2a for \u5faa\u73af\uff0c\u5e76\u5728\u9a8c\u8bc1\u96c6\u4e0a\u8ba1\u7b97\u5f97\u5206\uff0c\u8fd9\u6837\u5c31\u80fd\u5b9e\u73b0\u7f51\u683c\u641c\u7d22\u3002\u8fd8\u5fc5\u987b\u6ce8\u610f\u7684\u662f\uff0c\u5982\u679c\u8981\u8fdb\u884c k \u6298\u4ea4\u53c9\u9a8c\u8bc1\uff0c\u5219\u9700\u8981\u66f4\u591a\u7684\u5faa\u73af\uff0c\u8fd9\u610f\u5473\u7740\u9700\u8981\u66f4\u591a\u7684\u65f6\u95f4\u6765\u627e\u5230\u5b8c\u7f8e\u7684\u53c2\u6570\u3002\u56e0\u6b64\uff0c\u7f51\u683c\u641c\u7d22\u5e76\u4e0d\u6d41\u884c\u3002\u8ba9\u6211\u4eec\u4ee5\u6839\u636e \u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4 \u6570\u636e\u96c6\u4e3a\u4f8b\uff0c\u770b\u770b\u5b83\u662f\u5982\u4f55\u5b9e\u73b0\u7684\u3002 \u56fe 1\uff1a\u624b\u673a\u914d\u7f6e\u9884\u6d4b\u624b\u673a\u4ef7\u683c\u8303\u56f4\u6570\u636e\u96c6\u5c55\u793a 
\u8bad\u7ec3\u96c6\u4e2d\u53ea\u6709 2000 \u4e2a\u6837\u672c\u3002\u6211\u4eec\u53ef\u4ee5\u8f7b\u677e\u5730\u4f7f\u7528\u5206\u5c42 kfold \u548c\u51c6\u786e\u7387\u4f5c\u4e3a\u8bc4\u4f30\u6307\u6807\u3002\u6211\u4eec\u5c06\u4f7f\u7528\u5177\u6709\u4e0a\u8ff0\u53c2\u6570\u8303\u56f4\u7684\u968f\u673a\u68ee\u6797\u6a21\u578b\uff0c\u5e76\u5728\u4e0b\u9762\u7684\u793a\u4f8b\u4e2d\u4e86\u89e3\u5982\u4f55\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u3002 # rf_grid_search.py import numpy as np import pandas as pd from sklearn import ensemble from sklearn import metrics from sklearn import model_selection if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u5220\u9664 price_range \u5217 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u53d6\u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\uff0c\u4f7f\u7528\u6240\u6709\u53ef\u7528\u7684 CPU \u6838\u5fc3\u8fdb\u884c\u8bad\u7ec3 classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid = { \"n_estimators\" : [ 100 , 200 , 250 , 300 , 400 , 500 ], \"max_depth\" : [ 1 , 2 , 5 , 7 , 11 , 15 ], \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22 model = model_selection . GridSearchCV ( estimator = classifier , param_grid = param_grid , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( f \"Best score: { model . 
best_score_ } \" ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u8fd9\u91cc\u6253\u5370\u4e86\u5f88\u591a\u5185\u5bb9\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u6700\u540e\u51e0\u884c\u3002 [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.895 , total = 1.0 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.890 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.910 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.880 , total = 1.1 s [ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 ............... 
[ CV ] criterion = entropy , max_depth = 15 , n_estimators = 500 , score = 0.870 , total = 1.1 s [ Parallel ( n_jobs = 1 )]: Done 360 out of 360 | elapsed : 3.7 min finished Best score : 0.889 Best parameters set : criterion : 'entropy' max_depth : 15 n_estimators : 500 \u6700\u540e\uff0c\u6211\u4eec\u53ef\u4ee5\u770b\u5230\uff0c5\u6298\u4ea4\u53c9\u68c0\u9a8c\u6700\u4f73\u5f97\u5206\u662f 0.889\uff0c\u6211\u4eec\u7684\u7f51\u683c\u641c\u7d22\u5f97\u5230\u4e86\u6700\u4f73\u53c2\u6570\u3002\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u7684\u4e0b\u4e00\u4e2a\u6700\u4f73\u65b9\u6cd5\u662f \u968f\u673a\u641c\u7d22 \u3002\u5728\u968f\u673a\u641c\u7d22\u4e2d\uff0c\u6211\u4eec\u968f\u673a\u9009\u62e9\u4e00\u4e2a\u53c2\u6570\u7ec4\u5408\uff0c\u7136\u540e\u8ba1\u7b97\u4ea4\u53c9\u9a8c\u8bc1\u5f97\u5206\u3002\u8fd9\u91cc\u6d88\u8017\u7684\u65f6\u95f4\u6bd4\u7f51\u683c\u641c\u7d22\u5c11\uff0c\u56e0\u4e3a\u6211\u4eec\u4e0d\u5bf9\u6240\u6709\u4e0d\u540c\u7684\u53c2\u6570\u7ec4\u5408\u8fdb\u884c\u8bc4\u4f30\u3002\u6211\u4eec\u9009\u62e9\u8981\u5bf9\u6a21\u578b\u8fdb\u884c\u591a\u5c11\u6b21\u8bc4\u4f30\uff0c\u8fd9\u5c31\u51b3\u5b9a\u4e86\u641c\u7d22\u6240\u9700\u7684\u65f6\u95f4\u3002\u4ee3\u7801\u4e0e\u4e0a\u9762\u7684\u5dee\u522b\u4e0d\u5927\u3002\u9664 GridSearchCV \u5916\uff0c\u6211\u4eec\u4f7f\u7528 RandomizedSearchCV\u3002 if __name__ == \"__main__\" : classifier = ensemble . RandomForestClassifier ( n_jobs =- 1 ) # \u66f4\u6539\u641c\u7d22\u7a7a\u95f4 param_grid = { \"n_estimators\" : np . arange ( 100 , 1500 , 100 ), \"max_depth\" : np . arange ( 1 , 31 ), \"criterion\" : [ \"gini\" , \"entropy\" ] } # \u968f\u673a\u53c2\u6570\u641c\u7d22 model = model_selection . RandomizedSearchCV ( estimator = classifier , param_distributions = param_grid , n_iter = 20 , scoring = \"accuracy\" , verbose = 10 , n_jobs = 1 , cv = 5 ) # \u4f7f\u7528\u7f51\u683c\u641c\u7d22\u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . 
fit ( X , y ) print ( f \"Best score: { model . best_score_ } \" ) print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( f \" \\t { param_name } : { best_parameters [ param_name ] } \" ) \u6211\u4eec\u66f4\u6539\u4e86\u968f\u673a\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c\uff0c\u7ed3\u679c\u4f3c\u4e4e\u6709\u4e86\u4e9b\u8bb8\u6539\u8fdb\u3002 Best score : 0.8905 Best parameters set : criterion : entropy max_depth : 25 n_estimators : 300 \u5982\u679c\u8fed\u4ee3\u6b21\u6570\u8f83\u5c11\uff0c\u968f\u673a\u641c\u7d22\u6bd4\u7f51\u683c\u641c\u7d22\u66f4\u5feb\u3002\u4f7f\u7528\u8fd9\u4e24\u79cd\u65b9\u6cd5\uff0c\u4f60\u53ef\u4ee5\u4e3a\u5404\u79cd\u6a21\u578b\u627e\u5230\u6700\u4f18\u53c2\u6570\uff0c\u53ea\u8981\u5b83\u4eec\u6709\u62df\u5408\u548c\u9884\u6d4b\u529f\u80fd\uff0c\u8fd9\u4e5f\u662f scikit-learn \u7684\u6807\u51c6\u3002\u6709\u65f6\uff0c\u4f60\u53ef\u80fd\u60f3\u4f7f\u7528\u7ba1\u9053\u3002\u4f8b\u5982\uff0c\u5047\u8bbe\u6211\u4eec\u6b63\u5728\u5904\u7406\u4e00\u4e2a\u591a\u7c7b\u5206\u7c7b\u95ee\u9898\u3002\u5728\u8fd9\u4e2a\u95ee\u9898\u4e2d\uff0c\u8bad\u7ec3\u6570\u636e\u7531\u4e24\u5217\u6587\u672c\u7ec4\u6210\uff0c\u4f60\u9700\u8981\u5efa\u7acb\u4e00\u4e2a\u6a21\u578b\u6765\u9884\u6d4b\u7c7b\u522b\u3002\u8ba9\u6211\u4eec\u5047\u8bbe\u4f60\u9009\u62e9\u7684\u7ba1\u9053\u662f\u9996\u5148\u4ee5\u534a\u76d1\u7763\u7684\u65b9\u5f0f\u5e94\u7528 tf-idf\uff0c\u7136\u540e\u4f7f\u7528 SVD \u548c SVM \u5206\u7c7b\u5668\u3002\u73b0\u5728\u7684\u95ee\u9898\u662f\uff0c\u6211\u4eec\u5fc5\u987b\u9009\u62e9 SVD \u7684\u6210\u5206\uff0c\u8fd8\u9700\u8981\u8c03\u6574 SVM \u7684\u53c2\u6570\u3002\u4e0b\u9762\u7684\u4ee3\u7801\u6bb5\u5c55\u793a\u4e86\u5982\u4f55\u505a\u5230\u8fd9\u4e00\u70b9\u3002 import numpy as np import pandas as pd from sklearn import metrics from sklearn import model_selection from sklearn import pipeline from sklearn.decomposition import TruncatedSVD from 
sklearn.feature_extraction.text import TfidfVectorizer from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC # \u8ba1\u7b97\u52a0\u6743\u4e8c\u6b21 Kappa \u5206\u6570 def quadratic_weighted_kappa ( y_true , y_pred ): return metrics . cohen_kappa_score ( y_true , y_pred , weights = \"quadratic\" ) if __name__ == '__main__' : # \u8bfb\u53d6\u8bad\u7ec3\u96c6 train = pd . read_csv ( '../input/train.csv' ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u63d0\u53d6 id \u5217\u7684\u503c\uff0c\u5e76\u5c06\u5176\u8f6c\u6362\u4e3a\u6574\u6570\u7c7b\u578b\uff0c\u5b58\u50a8\u5728\u53d8\u91cf idx \u4e2d idx = test . id . values . astype ( int ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 train = train . drop ( 'id' , axis = 1 ) # \u4ece\u6d4b\u8bd5\u6570\u636e\u4e2d\u5220\u9664 'id' \u5217 test = test . drop ( 'id' , axis = 1 ) # \u4ece\u8bad\u7ec3\u6570\u636e\u4e2d\u63d0\u53d6\u76ee\u6807\u53d8\u91cf 'relevance' \uff0c\u5b58\u50a8\u5728\u53d8\u91cf y \u4e2d y = train . relevance . values # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 traindata \u4e2d traindata = list ( train . apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81 'text1' \u548c 'text2' \u5408\u5e76\u6210\u4e00\u4e2a\u65b0\u7684\u7279\u5f81\u5217\uff0c\u5e76\u5b58\u50a8\u5728\u5217\u8868 testdata \u4e2d testdata = list ( test . 
apply ( lambda x : ' %s %s ' % ( x [ 'text1' ], x [ 'text2' ]), axis = 1 )) # \u521b\u5efa\u4e00\u4e2a TfidfVectorizer \u5bf9\u8c61 tfv\uff0c\u7528\u4e8e\u5c06\u6587\u672c\u6570\u636e\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv = TfidfVectorizer ( min_df = 3 , max_features = None , strip_accents = 'unicode' , analyzer = 'word' , token_pattern = r '\\w{1,}' , ngram_range = ( 1 , 3 ), use_idf = 1 , smooth_idf = 1 , sublinear_tf = 1 , stop_words = 'english' ) # \u4f7f\u7528\u8bad\u7ec3\u6570\u636e\u62df\u5408 TfidfVectorizer\uff0c\u5c06\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81 tfv . fit ( traindata ) # \u5c06\u8bad\u7ec3\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X X = tfv . transform ( traindata ) # \u5c06\u6d4b\u8bd5\u6570\u636e\u4e2d\u7684\u6587\u672c\u7279\u5f81\u8f6c\u6362\u4e3a TF-IDF \u7279\u5f81\u77e9\u9635 X_test X_test = tfv . transform ( testdata ) # \u521b\u5efa TruncatedSVD \u5bf9\u8c61 svd\uff0c\u7528\u4e8e\u8fdb\u884c\u5947\u5f02\u503c\u5206\u89e3 svd = TruncatedSVD () # \u521b\u5efa StandardScaler \u5bf9\u8c61 scl\uff0c\u7528\u4e8e\u8fdb\u884c\u7279\u5f81\u7f29\u653e scl = StandardScaler () # \u521b\u5efa\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668\u5bf9\u8c61 svm_model svm_model = SVC () # \u521b\u5efa\u673a\u5668\u5b66\u4e60\u7ba1\u9053 clf\uff0c\u5305\u542b\u5947\u5f02\u503c\u5206\u89e3\u3001\u7279\u5f81\u7f29\u653e\u548c\u652f\u6301\u5411\u91cf\u673a\u5206\u7c7b\u5668 clf = pipeline . Pipeline ( [ ( 'svd' , svd ), ( 'scl' , scl ), ( 'svm' , svm_model ) ] ) # \u5b9a\u4e49\u8981\u8fdb\u884c\u7f51\u683c\u641c\u7d22\u7684\u53c2\u6570\u7f51\u683c param_grid param_grid = { 'svd__n_components' : [ 200 , 300 ], 'svm__C' : [ 10 , 12 ] } # \u521b\u5efa\u81ea\u5b9a\u4e49\u7684\u8bc4\u5206\u51fd\u6570 kappa_scorer\uff0c\u7528\u4e8e\u8bc4\u4f30\u6a21\u578b\u6027\u80fd kappa_scorer = metrics . 
make_scorer ( quadratic_weighted_kappa , greater_is_better = True ) # \u521b\u5efa GridSearchCV \u5bf9\u8c61 model\uff0c\u7528\u4e8e\u5728\u53c2\u6570\u7f51\u683c\u4e0a\u8fdb\u884c\u7f51\u683c\u641c\u7d22\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model = model_selection . GridSearchCV ( estimator = clf , param_grid = param_grid , scoring = kappa_scorer , verbose = 10 , n_jobs =- 1 , refit = True , cv = 5 ) # \u4f7f\u7528 GridSearchCV \u5bf9\u8c61 model \u62df\u5408\u6570\u636e\uff0c\u5bfb\u627e\u6700\u4f73\u53c2\u6570\u7ec4\u5408 model . fit ( X , y ) # \u6253\u5370\u51fa\u6700\u4f73\u6a21\u578b\u7684\u6700\u4f73\u51c6\u786e\u5ea6\u5206\u6570 print ( \"Best score: %0.3f \" % model . best_score_ ) # \u6253\u5370\u6700\u4f73\u53c2\u6570\u96c6\u5408 print ( \"Best parameters set:\" ) best_parameters = model . best_estimator_ . get_params () for param_name in sorted ( param_grid . keys ()): print ( \" \\t %s : %r \" % ( param_name , best_parameters [ param_name ])) # \u83b7\u53d6\u6700\u4f73\u6a21\u578b best_model = model . best_estimator_ best_model . fit ( X , y ) # \u4f7f\u7528\u6700\u4f73\u6a21\u578b\u8fdb\u884c\u9884\u6d4b preds = best_model . predict ( ... 
) \u8fd9\u91cc\u663e\u793a\u7684\u7ba1\u9053\u5305\u62ec SVD\uff08\u5947\u5f02\u503c\u5206\u89e3\uff09\u3001\u6807\u51c6\u7f29\u653e\u548c SVM\uff08\u652f\u6301\u5411\u91cf\u673a\uff09\u6a21\u578b\u3002\u8bf7\u6ce8\u610f\uff0c\u7531\u4e8e\u6ca1\u6709\u8bad\u7ec3\u6570\u636e\uff0c\u60a8\u65e0\u6cd5\u6309\u539f\u6837\u8fd0\u884c\u4e0a\u8ff0\u4ee3\u7801\u3002\u5f53\u6211\u4eec\u8fdb\u5165\u9ad8\u7ea7\u8d85\u53c2\u6570\u4f18\u5316\u6280\u672f\u65f6\uff0c\u6211\u4eec\u53ef\u4ee5\u4f7f\u7528\u4e0d\u540c\u7c7b\u578b\u7684 \u6700\u5c0f\u5316\u7b97\u6cd5 \u6765\u7814\u7a76\u51fd\u6570\u7684\u6700\u5c0f\u5316\u3002\u8fd9\u53ef\u4ee5\u901a\u8fc7\u4f7f\u7528\u591a\u79cd\u6700\u5c0f\u5316\u51fd\u6570\u6765\u5b9e\u73b0\uff0c\u5982\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u3001\u5185\u5c14\u5fb7-\u6885\u5fb7\u4f18\u5316\u7b97\u6cd5\u3001\u4f7f\u7528\u8d1d\u53f6\u65af\u6280\u672f\u548c\u9ad8\u65af\u8fc7\u7a0b\u5bfb\u627e\u6700\u4f18\u53c2\u6570\u6216\u4f7f\u7528\u9057\u4f20\u7b97\u6cd5\u3002\u6211\u5c06\u5728 \"\u96c6\u5408\u4e0e\u5806\u53e0\uff08ensembling and stacking\uff09 \"\u4e00\u7ae0\u4e2d\u8be6\u7ec6\u4ecb\u7ecd\u4e0b\u5761\u5355\u7eaf\u5f62\u7b97\u6cd5\u548c Nelder-Mead \u7b97\u6cd5\u7684\u5e94\u7528\u3002\u9996\u5148\uff0c\u8ba9\u6211\u4eec\u770b\u770b\u9ad8\u65af\u8fc7\u7a0b\u5982\u4f55\u7528\u4e8e\u8d85\u53c2\u6570\u4f18\u5316\u3002\u8fd9\u7c7b\u7b97\u6cd5\u9700\u8981\u4e00\u4e2a\u53ef\u4ee5\u4f18\u5316\u7684\u51fd\u6570\u3002\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u90fd\u662f\u6700\u5c0f\u5316\u8fd9\u4e2a\u51fd\u6570\uff0c\u5c31\u50cf\u6211\u4eec\u6700\u5c0f\u5316\u635f\u5931\u4e00\u6837\u3002 
\u56e0\u6b64\uff0c\u6bd4\u65b9\u8bf4\uff0c\u4f60\u60f3\u627e\u5230\u6700\u4f73\u53c2\u6570\u4ee5\u83b7\u5f97\u6700\u4f73\u51c6\u786e\u5ea6\uff0c\u663e\u7136\uff0c\u51c6\u786e\u5ea6\u8d8a\u9ad8\u8d8a\u597d\u3002\u73b0\u5728\uff0c\u6211\u4eec\u4e0d\u80fd\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\uff0c\u4f46\u6211\u4eec\u53ef\u4ee5\u5c06\u7cbe\u786e\u5ea6\u4e58\u4ee5-1\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u662f\u5728\u6700\u5c0f\u5316\u7cbe\u786e\u5ea6\u7684\u8d1f\u503c\uff0c\u4f46\u4e8b\u5b9e\u4e0a\uff0c\u6211\u4eec\u662f\u5728\u6700\u5927\u5316\u7cbe\u786e\u5ea6\u3002 \u5728\u9ad8\u65af\u8fc7\u7a0b\u4e2d\u4f7f\u7528\u8d1d\u53f6\u65af\u4f18\u5316\uff0c\u53ef\u4ee5\u4f7f\u7528 scikit-optimize (skopt) \u5e93\u4e2d\u7684 gp_minimize \u51fd\u6570\u3002\u8ba9\u6211\u4eec\u770b\u770b\u5982\u4f55\u4f7f\u7528\u8be5\u51fd\u6570\u8c03\u6574\u968f\u673a\u68ee\u6797\u6a21\u578b\u7684\u53c2\u6570\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from skopt import gp_minimize from skopt import space def optimize ( params , param_names , x , y ): # \u5c06\u53c2\u6570\u540d\u79f0\u548c\u5bf9\u5e94\u7684\u503c\u6253\u5305\u6210\u5b57\u5178 params = dict ( zip ( param_names , params )) # \u521b\u5efa\u968f\u673a\u68ee\u6797\u5206\u7c7b\u5668\u6a21\u578b\uff0c\u4f7f\u7528\u4f20\u5165\u7684\u53c2\u6570\u914d\u7f6e model = ensemble . RandomForestClassifier ( ** params ) # \u521b\u5efa StratifiedKFold \u4ea4\u53c9\u9a8c\u8bc1\u5bf9\u8c61\uff0c\u5c06\u6570\u636e\u5206\u4e3a 5 \u6298 kf = model_selection . StratifiedKFold ( n_splits = 5 ) # \u521d\u59cb\u5316\u7528\u4e8e\u5b58\u50a8\u6bcf\u4e2a\u6298\u53e0\u7684\u51c6\u786e\u5ea6\u7684\u5217\u8868 accuracies = [] # \u5faa\u73af\u904d\u5386\u6bcf\u4e2a\u6298\u53e0\u7684\u8bad\u7ec3\u548c\u6d4b\u8bd5\u6570\u636e for idx in kf . 
split ( X = x , y = y ): train_idx , test_idx = idx [ 0 ], idx [ 1 ] xtrain = x [ train_idx ] ytrain = y [ train_idx ] xtest = x [ test_idx ] ytest = y [ test_idx ] # \u5728\u8bad\u7ec3\u6570\u636e\u4e0a\u62df\u5408\u6a21\u578b model . fit ( xtrain , ytrain ) # \u4f7f\u7528\u6a21\u578b\u5bf9\u6d4b\u8bd5\u6570\u636e\u8fdb\u884c\u9884\u6d4b preds = model . predict ( xtest ) # \u8ba1\u7b97\u6298\u53e0\u7684\u51c6\u786e\u5ea6 fold_accuracy = metrics . accuracy_score ( ytest , preds ) accuracies . append ( fold_accuracy ) # \u8fd4\u56de\u5e73\u5747\u51c6\u786e\u5ea6\u7684\u8d1f\u6570\uff08\u56e0\u4e3a skopt \u4f7f\u7528\u8d1f\u6570\u6765\u6700\u5c0f\u5316\u76ee\u6807\u51fd\u6570\uff09 return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : # \u8bfb\u53d6\u6570\u636e df = pd . read_csv ( \"../input/mobile_train.csv\" ) # \u53d6\u7279\u5f81\u77e9\u9635 X\uff08\u53bb\u6389\"price_range\"\u5217\uff09 X = df . drop ( \"price_range\" , axis = 1 ) . values # \u76ee\u6807\u53d8\u91cf y\uff08\"price_range\"\u5217\uff09 y = df . price_range . values # \u5b9a\u4e49\u8d85\u53c2\u6570\u641c\u7d22\u7a7a\u95f4 param_space param_space = [ space . Integer ( 3 , 15 , name = \"max_depth\" ), space . Integer ( 100 , 1500 , name = \"n_estimators\" ), space . Categorical ([ \"gini\" , \"entropy\" ], name = \"criterion\" ), space . 
Real ( 0.01 , 1 , prior = \"uniform\" , name = \"max_features\" ) ] # \u5b9a\u4e49\u8d85\u53c2\u6570\u7684\u540d\u79f0\u5217\u8868 param_names param_names = [ \"max_depth\" , \"n_estimators\" , \"criterion\" , \"max_features\" ] # \u521b\u5efa\u51fd\u6570 optimization_function\uff0c\u7528\u4e8e\u4f20\u9012\u7ed9 gp_minimize optimization_function = partial ( optimize , param_names = param_names , x = X , y = y ) # \u4f7f\u7528 Bayesian Optimization\uff08\u57fa\u4e8e\u8d1d\u53f6\u65af\u4f18\u5316\uff09\u6765\u641c\u7d22\u6700\u4f73\u8d85\u53c2\u6570 result = gp_minimize ( optimization_function , dimensions = param_space , n_calls = 15 , n_random_starts = 10 , verbose = 10 ) # \u83b7\u53d6\u6700\u4f73\u8d85\u53c2\u6570\u7684\u5b57\u5178 best_params = dict ( zip ( param_names , result . x ) ) # \u6253\u5370\u51fa\u627e\u5230\u7684\u6700\u4f73\u8d85\u53c2\u6570 print ( best_params ) \u8fd9\u540c\u6837\u4f1a\u4ea7\u751f\u5927\u91cf\u8f93\u51fa\uff0c\u6700\u540e\u4e00\u90e8\u5206\u5982\u4e0b\u6240\u793a\u3002 Iteration No : 14 started . Searching for the next optimal point . Iteration No : 14 ended . Search finished for the next optimal point . Time taken : 4.7793 Function value obtained : - 0.9075 Current minimum : - 0.9075 Iteration No : 15 started . Searching for the next optimal point . Iteration No : 15 ended . Search finished for the next optimal point . 
Time taken : 49.4186 Function value obtained : - 0.9075 Current minimum : - 0.9075 { 'max_depth' : 12 , 'n_estimators' : 100 , 'criterion' : 'entropy' , 'max_features' : 1.0 } \u770b\u6765\u6211\u4eec\u5df2\u7ecf\u6210\u529f\u7a81\u7834\u4e86 0.90 \u7684\u51c6\u786e\u7387\u3002\u8fd9\u771f\u662f\u592a\u795e\u5947\u4e86\uff01 \u6211\u4eec\u8fd8\u53ef\u4ee5\u901a\u8fc7\u4ee5\u4e0b\u4ee3\u7801\u6bb5\u67e5\u770b\uff08\u7ed8\u5236\uff09\u6211\u4eec\u662f\u5982\u4f55\u5b9e\u73b0\u6536\u655b\u7684\u3002 from skopt.plots import plot_convergence plot_convergence ( result ) \u6536\u655b\u56fe\u5982\u56fe 2 \u6240\u793a\u3002 \u56fe 2\uff1a\u968f\u673a\u68ee\u6797\u53c2\u6570\u4f18\u5316\u7684\u6536\u655b\u56fe Scikit- optimize \u5c31\u662f\u8fd9\u6837\u4e00\u4e2a\u5e93\u3002 hyperopt \u4f7f\u7528\u6811\u72b6\u7ed3\u6784\u8d1d\u53f6\u65af\u4f30\u8ba1\u5668\uff08TPE\uff09\u6765\u627e\u5230\u6700\u4f18\u53c2\u6570\u3002\u8bf7\u770b\u4e0b\u9762\u7684\u4ee3\u7801\u7247\u6bb5\uff0c\u6211\u5728\u4f7f\u7528 hyperopt \u65f6\u5bf9\u4e4b\u524d\u7684\u4ee3\u7801\u505a\u4e86\u6700\u5c0f\u7684\u6539\u52a8\u3002 import numpy as np import pandas as pd from functools import partial from sklearn import ensemble from sklearn import metrics from sklearn import model_selection from hyperopt import hp , fmin , tpe , Trials from hyperopt.pyll.base import scope def optimize ( params , x , y ): model = ensemble . RandomForestClassifier ( ** params ) kf = model_selection . StratifiedKFold ( n_splits = 5 ) ... return - 1 * np . mean ( accuracies ) if __name__ == \"__main__\" : df = pd . read_csv ( \"../input/mobile_train.csv\" ) X = df . drop ( \"price_range\" , axis = 1 ) . values y = df . price_range . values # \u5b9a\u4e49\u641c\u7d22\u7a7a\u95f4\uff08\u6574\u578b\u3001\u6d6e\u70b9\u6570\u578b\u3001\u9009\u62e9\u578b\uff09 param_space = { \"max_depth\" : scope . int ( hp . quniform ( \"max_depth\" , 1 , 15 , 1 )), \"n_estimators\" : scope . int ( hp . 
quniform ( \"n_estimators\" , 100 , 1500 , 1 ) ), \"criterion\" : hp . choice ( \"criterion\" , [ \"gini\" , \"entropy\" ]), \"max_features\" : hp . uniform ( \"max_features\" , 0 , 1 ) } # \u5305\u88c5\u51fd\u6570 optimization_function = partial ( optimize , x = X , y = y ) # \u5f00\u59cb\u8bad\u7ec3 trials = Trials () # \u6700\u5c0f\u5316\u76ee\u6807\u503c hopt = fmin ( fn = optimization_function , space = param_space , algo = tpe . suggest , max_evals = 15 , trials = trials ) #\u6253\u5370\u6700\u4f73\u53c2\u6570 print ( hopt ) \u6b63\u5982\u4f60\u6240\u770b\u5230\u7684\uff0c\u8fd9\u4e0e\u4e4b\u524d\u7684\u4ee3\u7801\u5e76\u65e0\u592a\u5927\u533a\u522b\u3002\u4f60\u5fc5\u987b\u4ee5\u4e0d\u540c\u7684\u683c\u5f0f\u5b9a\u4e49\u53c2\u6570\u7a7a\u95f4\uff0c\u8fd8\u9700\u8981\u6539\u53d8\u5b9e\u9645\u4f18\u5316\u90e8\u5206\uff0c\u7528 hyperopt \u4ee3\u66ff gp_minimize\u3002\u7ed3\u679c\u76f8\u5f53\u4e0d\u9519\uff01 \u276f python rf_hyperopt . py 100 %| \u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588 | 15 / 15 [ 04 : 38 < 00 : 00 , 18.57 s / trial , best loss : - 0.9095000000000001 ] { 'criterion' : 1 , 'max_depth' : 11.0 , 'max_features' : 0.821163568049807 , 'n_estimators' : 806.0 } \u6211\u4eec\u5f97\u5230\u4e86\u6bd4\u4ee5\u524d\u66f4\u597d\u7684\u51c6\u786e\u5ea6\u548c\u4e00\u7ec4\u53ef\u4ee5\u4f7f\u7528\u7684\u53c2\u6570\u3002\u8bf7\u6ce8\u610f\uff0c\u6700\u7ec8\u7ed3\u679c\u4e2d\u7684\u6807\u51c6\u662f 1\u3002\u8fd9\u610f\u5473\u7740\u9009\u62e9\u4e86 1\uff0c\u5373\u71b5\u3002 \u4e0a\u8ff0\u8c03\u6574\u8d85\u53c2\u6570\u7684\u65b9\u6cd5\u662f\u6700\u5e38\u89c1\u7684\uff0c\u51e0\u4e4e\u9002\u7528\u4e8e\u6240\u6709\u6a21\u578b\uff1a\u7ebf\u6027\u56de\u5f52\u3001\u903b\u8f91\u56de\u5f52\u3001\u57fa\u4e8e\u6811\u7684\u65b9\u6cd5\u3001\u68af\u5ea6\u63d0\u5347\u6a21\u578b\uff08\u5982 xgboost\u3001lightgbm\uff09\uff0c\u751a\u81f3\u795e\u7ecf\u7f51\u7edc\uff01 
\u867d\u7136\u8fd9\u4e9b\u65b9\u6cd5\u5df2\u7ecf\u5b58\u5728\uff0c\u4f46\u5b66\u4e60\u65f6\u5fc5\u987b\u4ece\u624b\u52a8\u8c03\u6574\u8d85\u53c2\u6570\u5f00\u59cb\uff0c\u5373\u624b\u5de5\u8c03\u6574\u3002\u624b\u52a8\u8c03\u6574\u53ef\u4ee5\u5e2e\u52a9\u4f60\u5b66\u4e60\u57fa\u7840\u77e5\u8bc6\uff0c\u4f8b\u5982\uff0c\u5728\u68af\u5ea6\u63d0\u5347\u4e2d\uff0c\u5f53\u4f60\u589e\u52a0\u6df1\u5ea6\u65f6\uff0c\u4f60\u5e94\u8be5\u964d\u4f4e\u5b66\u4e60\u7387\u3002\u5982\u679c\u4f7f\u7528\u81ea\u52a8\u5de5\u5177\uff0c\u5c31\u65e0\u6cd5\u5b66\u4e60\u5230\u8fd9\u4e00\u70b9\u3002\u8bf7\u53c2\u8003\u4e0b\u8868\uff0c\u4e86\u89e3\u5e94\u5982\u4f55\u8c03\u6574\u3002RS* \u8868\u793a\u968f\u673a\u641c\u7d22\u5e94\u8be5\u66f4\u597d\u3002 \u4e00\u65e6\u4f60\u80fd\u66f4\u597d\u5730\u624b\u52a8\u8c03\u6574\u53c2\u6570\uff0c\u4f60\u751a\u81f3\u53ef\u80fd\u4e0d\u9700\u8981\u4efb\u4f55\u81ea\u52a8\u8d85\u53c2\u6570\u8c03\u6574\u3002\u521b\u5efa\u5927\u578b\u6a21\u578b\u6216\u5f15\u5165\u5927\u91cf\u7279\u5f81\u65f6\uff0c\u4e5f\u5bb9\u6613\u9020\u6210\u8bad\u7ec3\u6570\u636e\u7684\u8fc7\u5ea6\u62df\u5408\u3002\u4e3a\u907f\u514d\u8fc7\u5ea6\u62df\u5408\uff0c\u9700\u8981\u5728\u8bad\u7ec3\u6570\u636e\u7279\u5f81\u4e2d\u5f15\u5165\u566a\u58f0\u6216\u5bf9\u4ee3\u4ef7\u51fd\u6570\u8fdb\u884c\u60e9\u7f5a\u3002\u8fd9\u79cd\u60e9\u7f5a\u79f0\u4e3a \u6b63\u5219\u5316 \uff0c\u6709\u52a9\u4e8e\u6cdb\u5316\u6a21\u578b\u3002\u5728\u7ebf\u6027\u6a21\u578b\u4e2d\uff0c\u6700\u5e38\u89c1\u7684\u6b63\u5219\u5316\u7c7b\u578b\u662f L1 \u548c L2\u3002L1 \u4e5f\u79f0\u4e3a Lasso \u56de\u5f52\uff0cL2 \u79f0\u4e3a Ridge \u56de\u5f52\u3002\u8bf4\u5230\u795e\u7ecf\u7f51\u7edc\uff0c\u6211\u4eec\u4f1a\u4f7f\u7528dropout\u3001\u6dfb\u52a0\u589e\u5f3a\u3001\u566a\u58f0\u7b49\u65b9\u6cd5\u5bf9\u6a21\u578b\u8fdb\u884c\u6b63\u5219\u5316\u3002\u5229\u7528\u8d85\u53c2\u6570\u4f18\u5316\uff0c\u8fd8\u53ef\u4ee5\u627e\u5230\u6b63\u786e\u7684\u60e9\u7f5a\u65b9\u6cd5\u3002 Model Optimize Range of values Linear Regression - 
fit_intercept - normalize - True/False - True/False Ridge - alpha - fit_intercept - normalize - 0.01, 0.1, 1.0, 10, 100 - True/False - True/False k-neighbors - n_neighbors - p - 2, 4, 8, 16, ... - 2, 3, ... SVM - C - gamma - class_weight - 0.001, 0.01, ...,10, 100, 1000 - 'auto', RS* - 'balanced', None Logistic Regression - Penalyt - C - L1 or L2 - 0.001, 0.01, ..., 10, ..., 100 Lasso - Alpha - Normalize - 0.1, 1.0, 10 - True/False Random Forest - n_estimators - max_depth - min_samples_split - min_samples_leaf - max features - 120, 300, 500, 800, 1200 - 5, 8, 15, 25, 30, None - 1, 2, 5, 10, 15, 100 - log2, sqrt, None XGBoost - eta - gamma - max_depth - min_child_weight - subsample - colsample_bytree - lambda - alpha - 0.01, 0.015, 0.025, 0.05, 0.1 - 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0 - 3, 5, 7, 9, 12, 15, 17, 25 - 1, 3, 5, 7 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.6, 0.7, 0.8, 0.9, 1.0 - 0.01, 0.1, 1.0, RS - 0, 0.1, 0.5, 1.0, RS","title":"\u8d85\u53c2\u6570\u4f18\u5316"}]} \ No newline at end of file diff --git a/sitemap.xml.gz b/sitemap.xml.gz index 8dc2b54..d7b2c59 100644 Binary files a/sitemap.xml.gz and b/sitemap.xml.gz differ diff --git "a/\345\233\276\345\203\217\345\210\206\347\261\273\345\222\214\345\210\206\345\211\262\346\226\271\346\263\225/index.html" "b/\345\233\276\345\203\217\345\210\206\347\261\273\345\222\214\345\210\206\345\211\262\346\226\271\346\263\225/index.html" new file mode 100644 index 0000000..26594a2 --- /dev/null +++ "b/\345\233\276\345\203\217\345\210\206\347\261\273\345\222\214\345\210\206\345\211\262\346\226\271\346\263\225/index.html" @@ -0,0 +1,1578 @@ + + + + + + + + + + + + + + + + + 图像分类和分割方法 - AAAMLP 中译版 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + 跳转至 + + +
+
+ +
+ + + +
+ + +
+ +
+ + + + + + +
+
+ + + +
+
+
+ + + + +
+
+
+ + + +
+
+
+ + + + +
+
+
+ + +
+
+ + + + + + + +

图像分类和分割方法

+

说到图像,过去几年取得了很多成就。计算机视觉的进步相当快,感觉计算机视觉的许多问题现在都更容易解决了。随着预训练模型的出现和计算成本的降低,现在在家里就能轻松训练出接近最先进水平的模型,解决大多数与图像相关的问题。但是,图像问题有许多不同的类型。从两个或多个类别的标准图像分类,到像自动驾驶汽车这样具有挑战性的问题。我们不会在本书中讨论自动驾驶汽车,但我们显然会处理一些最常见的图像问题。

+

我们可以对图像采用哪些不同的方法?图像只不过是一个数字矩阵。计算机无法像人类一样看到图像。它只能看到数字,这就是图像。灰度图像是一个二维矩阵,数值范围从 0 到 255。0 代表黑色,255 代表白色,介于两者之间的是各种灰色。以前,在没有深度学习的时候(或者说深度学习还不流行的时候),人们习惯于查看像素。每个像素都是一个特征。你可以在 Python 中轻松做到这一点。只需使用 OpenCV 或 Python-PIL 读取灰度图像,转换为 numpy 数组,然后将矩阵平铺(扁平化)即可。如果处理的是 RGB 图像,则需要三个矩阵,而不是一个。但思路是一样的。

+
import numpy as np
+import matplotlib.pyplot as plt
+# 生成一个 256x256 的随机灰度图像,像素值在0到255之间随机分布
+random_image = np.random.randint(0, 256, (256, 256))
+
+# 创建一个新的图像窗口,设置窗口大小为7x7英寸
+plt.figure(figsize=(7, 7))
+
+# 显示生成的随机图像
+# 使用灰度颜色映射 (colormap),范围从0到255
+plt.imshow(random_image, cmap='gray', vmin=0, vmax=255)
+
+# 显示图像窗口
+plt.show()
+
+

上面的代码使用 numpy 生成一个随机矩阵。该矩阵由 0 到 255(包含)的值组成,大小为 256x256(也称为像素)。

+

+

图 1:二维图像阵列(单通道)及其展平版本

+ +

正如你所看到的,拼写后的版本只是一个大小为 M 的向量,其中 M = N * N,在本例中,这个向量的大小为 256 * 256 = 65536。

+

现在,如果我们继续对数据集中的所有图像进行处理,每个样本就会有 65536 个特征。我们可以在这些数据上快速建立决策树模型、随机森林模型或基于 SVM 的模型。这些模型将基于像素值,尝试将正样本与负样本区分开来(二元分类问题)。

+

你们一定都听说过猫与狗的问题,这是一个经典的问题。如果你们还记得,在评估指标一章的开头,我向你们介绍了一个气胸图像数据集。那么,让我们尝试建立一个模型来检测肺部的 X 光图像是否存在气胸。也就是说,这是一个(并不)简单的二元分类。

+

+

图 2:非气胸与气胸 X 光图像对比

+ +

在图 2 中,您可以看到非气胸和气胸图像的对比。您一定已经注意到了,对于一个非专业人士(比如我)来说,要在这些图像中辨别出哪个是气胸是相当困难的。

+

最初的数据集是关于检测气胸的具体位置,但我们将问题修改为查找给定的 X 光图像是否存在气胸。别担心,我们将在本章介绍这个部分。数据集由 10675 张独特的图像组成,其中 2379 张有气胸(注意,这些数字是经过数据清理后得出的,因此与原始数据集不符)。正如数据科学家所说:这是一个典型的偏斜二元分类案例。因此,我们选择 AUC 作为评估指标,并采用分层 k 折交叉验证方案。

+

您可以将特征扁平化,然后尝试一些经典方法(如 SVM、RF)来进行分类,这完全没问题,但却无法让您达到最先进的水平。此外,图像大小为 1024x1024。在这个数据集上训练一个模型需要很长时间。不管怎样,让我们尝试在这些数据上建立一个简单的随机森林模型。由于图像是灰度的,我们不需要进行任何转换。我们将把图像大小调整为 256x256,使其更小,并使用之前讨论过的 AUC 作为衡量指标。

+

让我们看看它的表现如何。

+
import os
+import numpy as np
+import pandas as pd
+from PIL import Image
+from sklearn import ensemble
+from sklearn import metrics
+from sklearn import model_selection
+from tqdm import tqdm
+
+# 定义一个函数来创建数据集
+def create_dataset(training_df, image_dir):
+    # 初始化空列表来存储图像数据和目标值
+    images = []
+    targets = []
+
+    # 迭代处理训练数据集中的每一行
+    for index, row in tqdm(
+        training_df.iterrows(), 
+        total=len(training_df), 
+        desc="processing images"
+    ):
+        # 获取图像文件名
+        image_id = row["ImageId"] 
+
+        # 构建完整的图像文件路径
+        image_path = os.path.join(image_dir, image_id)
+
+        # 打开图像文件并进行大小调整(resize)为 256x256 像素,使用双线性插值(BILINEAR)
+        image = Image.open(image_path + ".png")
+        image = image.resize((256, 256), resample=Image.BILINEAR) 
+
+        # 将图像转换为NumPy数组
+        image = np.array(image)
+
+        # 将图像扁平化为一维数组,并将其添加到图像列表
+        image = image.ravel()
+        images.append(image)
+
+        # 将目标值(target)添加到目标列表
+        targets.append(int(row["target"]))
+
+    # 将图像列表转换为NumPy数组
+    images = np.array(images)
+
+    # 打印图像数组的形状
+    print(images.shape) 
+
+    # 返回图像数据和目标值
+    return images, targets
+
+if __name__ == "__main__":
+    # 定义CSV文件路径和图像文件目录路径
+    csv_path = "/home/abhishek/workspace/siim_png/train.csv" 
+    image_path = "/home/abhishek/workspace/siim_png/train_png/"
+
+    # 从CSV文件加载数据
+    df = pd.read_csv(csv_path)
+
+    # 添加一个名为'kfold'的列,并初始化为-1
+    df["kfold"] = -1
+
+    # 随机打乱数据
+    df = df.sample(frac=1).reset_index(drop=True)
+
+    # 获取目标值(target)
+    y = df.target.values
+
+    # 使用分层KFold交叉验证将数据集分成5折
+    kf = model_selection.StratifiedKFold(n_splits=5)
+
+    # 遍历每个折(fold)
+    for f, (t_, v_) in enumerate(kf.split(X=df, y=y)): 
+        df.loc[v_, 'kfold'] = f
+
+    # 遍历每个折
+    for fold_ in range(5):
+        # 获取训练数据和测试数据
+        train_df = df[df.kfold != fold_].reset_index(drop=True) 
+        test_df = df[df.kfold == fold_].reset_index(drop=True)
+
+        # 创建训练数据集的图像数据和目标值
+        xtrain, ytrain = create_dataset(train_df, image_path)
+
+        # 创建测试数据集的图像数据和目标值
+        xtest, ytest = create_dataset(test_df, image_path)
+
+        # 初始化一个随机森林分类器
+        clf = ensemble.RandomForestClassifier(n_jobs=-1)
+
+        # 使用训练数据拟合分类器
+        clf.fit(xtrain, ytrain)
+
+        # 使用分类器对测试数据进行预测,并获取概率值
+        preds = clf.predict_proba(xtest)[:, 1]
+
+        # 打印折数(fold)和AUC(ROC曲线下的面积)
+        print(f"FOLD: {fold_}")
+        print(f"AUC = {metrics.roc_auc_score(ytest, preds)}")
+        print("")
+
+

平均 AUC 值约为 0.72。这还不错,但我们希望能做得更好。你可以将这种方法用于图像,这也是它在以前最常用的方法。SVM 在图像数据集方面相当有名。深度学习已被证明是解决此类问题的最先进方法,因此我们下一步可以试试它。

+

关于深度学习的历史以及谁发明了什么,我就不多说了。让我们看看最著名的深度学习模型之一 AlexNet。

+

+

图 3:AlexNet 架构9 请注意,本图中的输入大小不是 224x224 而是 227x227

+ +

如今,你可能会说这只是一个基本的深度卷积神经网络,但它却是许多新型深度网络(深度神经网络)的基础。我们看到,图 3 中的网络是一个具有五个卷积层、两个密集层和一个输出层的卷积神经网络。我们看到还有最大池化。这是什么意思?让我们来看看在进行深度学习时会遇到的一些术语。

+

+

图 4:图像大小为 8x8,滤波器大小为 3x3,步长为 2。

+ +

图 4 引入了两个新术语:滤波器和步长。滤波器是由给定函数初始化的二维矩阵,由指定函数初始化。Kaiming正态初始化,是卷积神经网络的最佳选择。这是因为大多数现代网络都使用 ReLU(整流线性单元)激活函数,需要适当的初始化来避免梯度消失问题(梯度趋近于零,网络权重不变)。该滤波器与图像进行卷积。卷积不过是滤波器与给定图像中当前重叠像素之间的元素相乘的总和。您可以在任何高中数学教科书中阅读更多关于卷积的内容。我们从图像的左上角开始对滤镜进行卷积,然后水平移动滤镜。如果移动 1 个像素,则步长为 1;如果移动 2 个像素,则步长为 2。

+

即使在自然语言处理中,例如在问题和回答系统中需要从大量文本语料中筛选答案时,步长也是一个非常有用的概念。当我们在水平方向上走到尽头时,就会以同样的步长垂直向下移动过滤器,从左侧开始。图 4 还显示了过滤器移出图像的情况。在这种情况下,无法计算卷积。因此,我们跳过它。如果不想跳过,则需要对图像进行填充(pad)。还必须注意的是,卷积会减小图像的大小。填充也是保持图像大小不变的一种方法。在图 4 中,一个 3x3 滤波器正在水平和垂直移动,每次移动都会分别跳过两列和两行(即像素)。由于它跳过了两个像素,所以步长 = 2。因此图像大小为 [(8-3) / 2] + 1 = 3.5。我们取 3.5 的下限,所以是 3x3。您可以在草稿纸上进行尝试。

+

+

图 5:通过填充,我们可以提供与输入图像大小相同的图像

+ +

我们可以从图 5 中看到填充的效果。现在,我们有一个 3x3 的滤波器,它以 1 的步长移动。原始图像的大小为 6x6,我们添加了 1 的填充。在这种情况下,生成的图像将与输入图像大小相同,即 6x6。在处理深度神经网络时可能会遇到的另一个相关术语是膨胀(dilation),如图 6 所示。

+

+

图 6:膨胀(dilation)的例子

+ +

在膨胀过程中,我们将滤波器扩大 N-1,其中 N 是膨胀率的值,或简称为膨胀。在这种带膨胀的内核中,每次卷积都会跳过一些像素。这在分割任务中尤为有效。请注意,我们只讨论了二维卷积。 还有一维卷积和更高维度的卷积。它们都基于相同的基本概念。

+

接下来是最大池化(Max pooling)。最大值池只是一个返回最大值的滤波器。因此,我们提取的不是卷积,而是像素的最大值。同样,平均池化(average pooling)均值池化(mean pooling)会返回像素的平均值。它们的使用方法与卷积核相同。池化比卷积更快,是一种对图像进行缩减采样的方法。最大池化可检测边缘,平均池化可平滑图像。

+

卷积神经网络和深度学习的概念太多了。我所讨论的是一些基础知识,可以帮助你入门。现在,我们已经为在 PyTorch 中构建第一个卷积神经网络做好了充分准备。PyTorch 提供了一种直观而简单的方法来实现深度神经网络,而且你不需要关心反向传播。我们用一个 python 类和一个前馈函数来定义网络,告诉 PyTorch 各层之间如何连接。在 PyTorch 中,图像符号是 BS、C、H、W,其中,BS 是批大小,C 是通道,H 是高度,W 是宽度。让我们看看 PyTorch 是如何实现 AlexNet 的。

+
import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class AlexNet(nn.Module): 
+    def __init__(self):
+        super(AlexNet, self).__init__()
+        self.conv1 = nn.Conv2d(
+            in_channels=3, 
+            out_channels=96, 
+            kernel_size=11, 
+            stride=4, 
+            padding=0)
+        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2) 
+        self.conv2 = nn.Conv2d(
+            in_channels=96, 
+            out_channels=256, 
+            kernel_size=5, 
+            stride=1,
+            padding=2)
+        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2) 
+        self.conv3 = nn.Conv2d(
+            in_channels=256, 
+            out_channels=384, 
+            kernel_size=3, 
+            stride=1,
+            padding=1)
+        self.conv4 = nn.Conv2d(in_channels=384,out_channels=384, 
+            kernel_size=3, stride=1, padding=1)
+        self.conv5 = nn.Conv2d(in_channels=384, out_channels=256,
+            kernel_size=3, stride=1, padding=1)
+        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2) 
+        self.fc1 = nn.Linear(in_features=9216, out_features=4096)
+        self.dropout1 = nn.Dropout(0.5) 
+        self.fc2 = nn.Linear(in_features=4096, 
+                             out_features=4096)
+        self.dropout2 = nn.Dropout(0.5) 
+        self.fc3 = nn.Linear(
+            in_features=4096, 
+            out_features=1000)
+    def forward(self, image):
+        bs, c, h, w = image.size()
+        x = F.relu(self.conv1(image)) # size: (bs, 96, 55, 55)
+        x = self.pool1(x) # size: (bs, 96, 27, 27)
+        x = F.relu(self.conv2(x)) # size: (bs, 256, 27, 27)
+        x = self.pool2(x) # size: (bs, 256, 13, 13)
+        x = F.relu(self.conv3(x)) # size: (bs, 384, 13, 13)
+        x = F.relu(self.conv4(x)) # size: (bs, 384, 13, 13)
+        x = F.relu(self.conv5(x)) # size: (bs, 256, 13, 13)
+        x = self.pool3(x) # size: (bs, 256, 6, 6)
+        x = x.view(bs, -1) # size: (bs, 9216)
+        x = F.relu(self.fc1(x)) # size: (bs, 4096)
+        x = self.dropout1(x) # size: (bs, 4096) 
+        # dropout does not change size
+        # dropout is used for regularization
+        # 0.3 dropout means that only 70% of the nodes 
+        # of the current layer are used for the next layer 
+        x = F.relu(self.fc2(x)) # size: (bs, 4096)
+        x = self.dropout2(x) # size: (bs, 4096)
+        x = F.relu(self.fc3(x)) # size: (bs, 1000)
+        # 1000 is number of classes in ImageNet Dataset 
+        # softmax is an activation function that converts 
+        # linear output to probabilities that add up to 1
+        # for each sample in the batch
+        x = torch.softmax(x, axis=1) # size: (bs, 1000) 
+        return x
+
+

如果您有一幅 3x227x227 的图像,并应用了一个大小为 11x11 的卷积滤波器,这意味着您应用了一个大小为 11x11x3 的滤波器,并与一个大小为 227x227x3 的图像进行了卷积。输出通道的数量就是分别应用于图像的相同大小的不同卷积滤波器的数量。 因此,在第一个卷积层中,输入通道是 3,也就是原始输入,即 R、G、B 三通道。PyTorch 的 torchvision 提供了许多与 AlexNet 类似的不同模型,必须指出的是,AlexNet 的实现与 torchvision 的实现并不相同。Torchvision 的 AlexNet 实现是从另一篇论文中修改而来的 AlexNet: Krizhevsky, A. One weird trick for parallelizing convolutional neural networks. CoRR, abs/1404.5997, 2014.

+

你可以为自己的任务设计卷积神经网络,很多时候,从零做起是个不错的主意。让我们构建一个网络,用于区分图像有无气胸。首先,让我们准备一些文件。第一步是创建一个交叉检验数据集,即 train.csv,但增加一列 kfold。我们将创建五个文件夹。在本书中,我已经演示了如何针对不同的数据集创建折叠,因此我将跳过这一部分,留作练习。对于基于 PyTorch 的神经网络,我们需要创建一个数据集类。数据集类的目的是返回一个数据项或数据样本。这个数据样本应该包含训练或评估模型所需的所有内容。

+
import torch
+import numpy as np
+from PIL import Image 
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+# 定义一个数据集类,用于处理图像分类任务
+class ClassificationDataset:
+    def __init__(self, image_paths, targets, resize=None, augmentations=None):
+        # 图像文件路径列表
+        self.image_paths = image_paths
+        # 目标标签列表
+        self.targets = targets
+        # 图像尺寸调整参数,可以为None
+        self.resize = resize
+        # 数据增强函数,可以为None
+        self.augmentations = augmentations
+
+    def __len__(self):
+        # 返回数据集的大小,即图像数量
+        return len(self.image_paths)
+
+    def __getitem__(self, item):
+        # 获取数据集中的一个样本
+        image = Image.open(self.image_paths[item])
+        image = image.convert("RGB")  # 将图像转换为RGB格式
+
+        # 获取该样本的目标标签
+        targets = self.targets[item]
+
+        if self.resize is not None:
+            # 如果指定了尺寸调整参数,将图像进行尺寸调整
+            image = image.resize((self.resize[1], self.resize[0]),
+                                 resample=Image.BILINEAR)
+            image = np.array(image)
+
+            if self.augmentations is not None:
+                # 如果指定了数据增强函数,应用数据增强
+                augmented = self.augmentations(image=image)
+                image = augmented["image"]
+
+            # 将图像通道顺序调整为(C, H, W)的形式,并转换为float32类型
+            image = np.transpose(image, (2, 0, 1)).astype(np.float32)
+
+        # 返回样本,包括图像和对应的目标标签
+        return {
+            "image": torch.tensor(image, dtype=torch.float),
+            "targets": torch.tensor(targets, dtype=torch.long),
+        }
+
+

现在我们需要 engine.py。engine.py 包含训练和评估功能。让我们看看 engine.py 是如何编写的。

+
import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+# 用于训练模型的函数
+def train(data_loader, model, optimizer, device):
+    # 将模型设置为训练模式
+    model.train()
+    for data in data_loader:
+        # 从数据加载器中提取输入图像和目标标签
+        inputs = data["image"]
+        targets = data["targets"]
+
+        # 将输入和目标移动到指定的设备(例如,GPU)
+        inputs = inputs.to(device, dtype=torch.float)
+        targets = targets.to(device, dtype=torch.float)
+
+        # 将优化器中的梯度归零
+        optimizer.zero_grad()
+
+        # 前向传播:计算模型预测
+        outputs = model(inputs)
+
+        # 使用带逻辑斯蒂函数的二元交叉熵损失计算损失
+        loss = nn.BCEWithLogitsLoss()(outputs, targets.view(-1, 1))
+
+        # 反向传播:计算梯度并更新模型权重
+        loss.backward()
+        optimizer.step()
+
+# 用于评估模型的函数
+def evaluate(data_loader, model, device):
+    # 将模型设置为评估模式(不进行梯度计算)
+    model.eval()
+
+    # 初始化列表以存储真实目标和模型预测
+    final_targets = []
+    final_outputs = []
+
+    with torch.no_grad():
+        for data in data_loader:
+            # 从数据加载器中提取输入图像和目标标签
+            inputs = data["image"]
+            targets = data["targets"]
+
+            # 将输入移动到指定的设备(例如,GPU)
+            inputs = inputs.to(device, dtype=torch.float)
+
+            # 获取模型预测
+            output = model(inputs)
+
+            # 将目标和输出转换为CPU和Python列表
+            targets = targets.detach().cpu().numpy().tolist()
+            output = output.detach().cpu().numpy().tolist()
+
+            # 将列表扩展以包含批次数据
+            final_targets.extend(targets)
+            final_outputs.extend(output)
+
+    # 返回最终的模型预测和真实目标
+    return final_outputs, final_targets
+
+

有了 engine.py,就可以创建一个新文件:model.py。model.py 将包含我们的模型。把模型与训练分开是个好主意,因为这样我们就可以轻松地试验不同的模型和不同的架构。名为 pretrainedmodels 的 PyTorch 库中有很多不同的模型架构,如 AlexNet、ResNet、DenseNet 等。这些不同的模型架构是在名为 ImageNet 的大型图像数据集上训练出来的。在 ImageNet 上训练后,我们可以使用它们的权重,也可以不使用这些权重。如果我们不使用 ImageNet 权重进行训练,这意味着我们的网络将从头开始学习一切。这就是 model.py 的样子。

+
import torch.nn as nn
+import pretrainedmodels
+
+# 定义一个函数以获取模型
+def get_model(pretrained):
+    if pretrained:
+        # 使用预训练的 AlexNet 模型,加载在 ImageNet 数据集上训练的权重
+        model = pretrainedmodels.__dict__["alexnet"](pretrained='imagenet')
+    else:
+        # 使用未经预训练的 AlexNet 模型
+        model = pretrainedmodels.__dict__["alexnet"](pretrained=None)
+
+    # 修改模型的最后一层全连接层,以适应特定任务
+    model.last_linear = nn.Sequential(
+        nn.BatchNorm1d(4096),  # 批归一化层
+        nn.Dropout(p=0.25),  # 随机失活层,防止过拟合
+        nn.Linear(in_features=4096, out_features=2048),  # 连接层
+        nn.ReLU(),  # ReLU 激活函数
+        nn.BatchNorm1d(2048, eps=1e-05, momentum=0.1),  # 批归一化层
+        nn.Dropout(p=0.5),  # 随机失活层
+        nn.Linear(in_features=2048, out_features=1)  # 最终的二元分类层
+    )
+
+    return model
+
+

如果你打印了网络,会得到如下输出:

+
AlexNet(
+    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6)) 
+    (_features): Sequential(
+        (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2,2))
+        (1): ReLU(inplace=True)
+        (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
+        (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2,2))
+        (4): ReLU(inplace=True)
+        (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
+        (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1,1))
+        (7): ReLU(inplace=True)
+        (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1,1))
+        (9): ReLU(inplace=True)
+        (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1,1))
+        (11): ReLU(inplace=True)
+        (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False))
+    (dropout0): Dropout(p=0.5, inplace=False)
+    (linear0): Linear(in_features=9216, out_features=4096, bias=True) 
+    (relu0): ReLU(inplace=True)
+    (dropout1): Dropout(p=0.5, inplace=False)
+    (linear1): Linear(in_features=4096, out_features=4096, bias=True) 
+    (relu1): ReLU(inplace=True)
+    (last_linear): Sequential(
+        (0): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+        (1): Dropout(p=0.25, inplace=False)
+        (2): Linear(in_features=4096, out_features=2048, bias=True)
+        (3): ReLU()
+        (4): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, 
+track_running_stats=True)
+        (5): Dropout(p=0.5, inplace=False)
+        (6): Linear(in_features=2048, out_features=1, bias=True) 
+    )
+)
+
+

现在,万事俱备,可以开始训练了。我们将使用 train.py 训练模型。

+
import os
+import pandas as pd
+import numpy as np
+import albumentations
+import torch
+from sklearn import metrics
+from sklearn.model_selection import train_test_split
+import dataset
+import engine
+from model import get_model
+
+if __name__ == "__main__":
+    # 定义数据路径、设备、迭代次数
+    data_path = "/home/abhishek/workspace/siim_png/"
+    device = "cuda"  # 使用GPU加速
+    epochs = 10
+
+    # 从CSV文件读取数据
+    df = pd.read_csv(os.path.join(data_path, "train.csv"))
+    images = df.ImageId.values.tolist()
+    images = [os.path.join(data_path, "train_png", i + ".png") for i in images]
+    targets = df.target.values
+
+    # 获取预训练的模型
+    model = get_model(pretrained=True)
+    model.to(device)
+
+    # 定义均值和标准差,用于数据标准化
+    mean = (0.485, 0.456, 0.406)
+    std = (0.229, 0.224, 0.225)
+
+    # 数据增强,将图像标准化
+    aug = albumentations.Compose(
+        [
+            albumentations.Normalize(
+                mean, std, max_pixel_value=255.0, always_apply=True
+            )
+        ]
+    )
+
+    # 划分训练集和验证集
+    train_images, valid_images, train_targets, valid_targets = train_test_split(
+        images, targets, stratify=targets, random_state=42
+    )
+
+    # 创建训练数据集和验证数据集
+    train_dataset = dataset.ClassificationDataset(
+        image_paths=train_images,
+        targets=train_targets,
+        resize=(227, 227),
+        augmentations=aug,
+    )
+
+    # 创建训练数据加载器
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=16, shuffle=True, num_workers=4
+    )
+
+    # 创建验证数据集
+    valid_dataset = dataset.ClassificationDataset(
+        image_paths=valid_images,
+        targets=valid_targets,
+        resize=(227, 227),
+        augmentations=aug,
+    )
+
+    # 创建验证数据加载器
+    valid_loader = torch.utils.data.DataLoader(
+        valid_dataset, batch_size=16, shuffle=False, num_workers=4
+    )
+
+    # 定义优化器
+    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
+
+    # 训练循环
+    for epoch in range(epochs):
+        # 训练模型
+        engine.train(train_loader, model, optimizer, device=device)
+
+        # 评估模型性能
+        predictions, valid_targets = engine.evaluate(
+            valid_loader, model, device=device
+        )
+
+        # 计算ROC AUC分数并打印
+        roc_auc = metrics.roc_auc_score(valid_targets, predictions)
+        print(f"Epoch={epoch}, Valid ROC AUC={roc_auc}")
+
+

让我们在没有预训练权重的情况下进行训练:

+
Epoch=0, Valid ROC AUC=0.5737161981475328
+Epoch=1, Valid ROC AUC=0.5362868001588292
+Epoch=2, Valid ROC AUC=0.6163448214387008
+Epoch=3, Valid ROC AUC=0.6119219143780944
+Epoch=4, Valid ROC AUC=0.6229718888519726
+Epoch=5, Valid ROC AUC=0.5983014999635341
+Epoch=6, Valid ROC AUC=0.5523236874306134
+Epoch=7, Valid ROC AUC=0.4717721611306046
+Epoch=8, Valid ROC AUC=0.6473408263980617
+Epoch=9, Valid ROC AUC=0.6639862888260415
+
+

AUC 约为 0.66,甚至低于我们的随机森林模型。使用预训练权重会发生什么情况?

+
Epoch=0, Valid ROC AUC=0.5730387429803165
+Epoch=1, Valid ROC AUC=0.5319813942934937
+Epoch=2, Valid ROC AUC=0.627111577514323
+Epoch=3, Valid ROC AUC=0.6819736959393209
+Epoch=4, Valid ROC AUC=0.5747117168950512
+Epoch=5, Valid ROC AUC=0.5994619255609669
+Epoch=6, Valid ROC AUC=0.5080889443530546
+Epoch=7, Valid ROC AUC=0.6323792776512727
+Epoch=8, Valid ROC AUC=0.6685753182661686
+Epoch=9, Valid ROC AUC=0.6861802387300147
+
+

现在的 AUC 好了很多。不过,它仍然较低。预训练模型的好处是可以轻松尝试多种不同的模型。让我们试试使用预训练权重的 resnet18

+
import torch.nn as nn
+import pretrainedmodels
+
+# 定义一个函数以获取模型
+def get_model(pretrained):
+    if pretrained:
+        # 使用预训练的 ResNet-18 模型,加载在 ImageNet 数据集上训练的权重
+        model = pretrainedmodels.__dict__["resnet18"](pretrained='imagenet')
+    else:
+        # 使用未经预训练的 ResNet-18 模型
+        model = pretrainedmodels.__dict__["resnet18"](pretrained=None)
+
+    # 修改模型的最后一层全连接层,以适应特定任务
+    model.last_linear = nn.Sequential(
+        nn.BatchNorm1d(512),  # 批归一化层
+        nn.Dropout(p=0.25),  # 随机失活层,防止过拟合
+        nn.Linear(in_features=512, out_features=2048),  # 连接层
+        nn.ReLU(),  # ReLU 激活函数
+        nn.BatchNorm1d(2048, eps=1e-05, momentum=0.1),  # 批归一化层
+        nn.Dropout(p=0.5),  # 随机失活层
+        nn.Linear(in_features=2048, out_features=1)  # 最终的二元分类层
+    )
+
+    return model
+
+

在尝试该模型时,我还将图像大小改为 512x512,并添加了一个学习率调度器,每 3 个epochs后将学习率乘以 0.5。

+
Epoch=0, Valid ROC AUC=0.5988225569880796
+Epoch=1, Valid ROC AUC=0.730349343208836
+Epoch=2, Valid ROC AUC=0.5870943169939142
+Epoch=3, Valid ROC AUC=0.5775864444138311
+Epoch=4, Valid ROC AUC=0.7330502499939224
+Epoch=5, Valid ROC AUC=0.7500336296524395
+Epoch=6, Valid ROC AUC=0.7563722113724951
+Epoch=7, Valid ROC AUC=0.7987463837994215
+Epoch=8, Valid ROC AUC=0.798505708937384
+Epoch=9, Valid ROC AUC=0.8025477500546988
+
+

这个模型似乎表现最好。不过,您可以调整 AlexNet 中的不同参数和图像大小,以获得更好的分数。 使用增强技术将进一步提高得分。优化深度神经网络很难,但并非不可能。选择 Adam 优化器、使用低学习率、在验证损失达到高点时降低学习率、尝试一些增强技术、尝试对图像进行预处理(如在需要时进行裁剪,这也可视为预处理)、改变批次大小等。你可以做很多事情来优化深度神经网络。

+

与 AlexNet 相比,ResNet 的结构要复杂得多。ResNet 是残差神经网络(Residual Neural Network)的缩写,由 K. He、X. Zhang、S. Ren 和 J. Sun 在 2015 年发表的论文中提出。ResNet 由残差块(residual blocks)组成,通过跳过某些层,使知识能够不断在各层中进行传递。这些层之间的 连接被称为跳跃连接(skip-connections),因为我们跳过了一层或多层。跳跃连接通过将梯度传播到更多层来帮助解决梯度消失问题。这样,我们就可以训练非常大的卷积神经网络,而不会损失性能。通常情况下,如果我们使用的是大型神经网络,那么当训练到某一节点上时训练损失反而会增加,但这可以通过使用跳跃连接来避免。通过图 7 可以更好地理解这一点。

+

+

图 7:简单连接与残差连接的比较。参见跳跃连接。请注意,本图省略了最后一层。

+ +

残差块非常容易理解。你从某一层获取输出,跳过一些层,然后将输出添加到网络中更远的一层。虚线表示输入形状需要调整,因为使用了最大池化,而最大池化的使用会改变输出的大小。

+

ResNet 有多种不同的版本: 有 18 层、34 层、50 层、101 层和 152 层,所有这些层都在 ImageNet 数据集上进行了权重预训练。如今,预训练模型(几乎)适用于所有情况,但请确保您从较小的模型开始,例如,从 resnet-18 开始,而不是 resnet-50。其他一些 ImageNet 预训练模型包括:

+
    +
  • Inception
  • +
  • +

    DenseNet(different variations)

    +
  • +
  • +

    NASNet

    +
  • +
  • PNASNet
  • +
  • VGG
  • +
  • Xception
  • +
  • ResNeXt
  • +
  • EfficientNet, etc.
  • +
+

大部分预训练的最先进模型可以在 GitHub 上的 pytorch- pretrainedmodels 资源库中找到:https://github.com/Cadene/pretrained-models.pytorch。详细讨论这些模型不在本章(和本书)范围之内。既然我们只关注应用,那就让我们看看这样的预训练模型如何用于分割任务。

+

+

图 8:U-Net架构

+ +

分割(Segmentation)是计算机视觉中相当流行的一项任务。在分割任务中,我们试图从背景中移除/提取前景。 前景和背景可以有不同的定义。我们也可以说,这是一项像素分类任务,你的工作是给给定图像中的每个像素分配一个类别。事实上,我们正在处理的气胸数据集就是一项分割任务。在这项任务中,我们需要对给定的胸部放射图像进行气胸分割。用于分割任务的最常用模型是 U-Net。其结构如图 8 所示。

+

U-Net 包括两个部分:编码器和解码器。编码器与您目前所见过的任何 U-Net 都是一样的。解码器则有些不同。解码器由上卷积层组成。在上卷积(up-convolutions)(转置卷积transposed convolutions)中,我们使用滤波器,当应用到一个小图像时,会产生一个大图像。在 PyTorch 中,您可以使用 ConvTranspose2d 来完成这一操作。必须注意的是,上卷积与上采样并不相同。上采样是一个简单的过程,我们在图像上应用一个函数来调整它的大小。在上卷积中,我们要学习滤波器。我们将编码器的某些部分作为某些解码器的输入。这对 上卷积层非常重要。

+

让我们看看 U-Net 是如何实现的。

+
import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+# 定义一个双卷积层
+def double_conv(in_channels, out_channels): 
+    conv = nn.Sequential(
+        nn.Conv2d(in_channels, out_channels, kernel_size=3), 
+        nn.ReLU(inplace=True),
+        nn.Conv2d(out_channels, out_channels, kernel_size=3), 
+        nn.ReLU(inplace=True)
+    )
+    return conv
+
+# 定义函数用于裁剪输入张量
+def crop_tensor(tensor, target_tensor):
+    target_size = target_tensor.size()[2] 
+    tensor_size = tensor.size()[2] 
+    delta = tensor_size - target_size 
+    delta = delta // 2
+    return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]
+
+# 定义 U-Net 模型
+class UNet(nn.Module):
+    def __init__(self):
+        super(UNet, self).__init__()
+
+        # 定义池化层,编码器和解码器的双卷积层
+        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.down_conv_1 = double_conv(1, 64)
+        self.down_conv_2 = double_conv(64, 128)
+        self.down_conv_3 = double_conv(128, 256)
+        self.down_conv_4 = double_conv(256, 512)
+        self.down_conv_5 = double_conv(512, 1024)
+
+        # 定义上采样层和解码器的双卷积层
+        self.up_trans_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
+        self.up_conv_1 = double_conv(1024, 512) 
+        self.up_trans_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
+        self.up_conv_2 = double_conv(512, 256) 
+        self.up_trans_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
+        self.up_conv_3 = double_conv(256, 128) 
+        self.up_trans_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
+        self.up_conv_4 = double_conv(128, 64) 
+
+        # 定义输出层
+        self.out = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)
+
+    def forward(self, image): 
+        # 编码器部分
+        x1 = self.down_conv_1(image)
+        x2 = self.max_pool_2x2(x1)
+        x3 = self.down_conv_2(x2)
+        x4 = self.max_pool_2x2(x3)
+        x5 = self.down_conv_3(x4)
+        x6 = self.max_pool_2x2(x5)
+        x7 = self.down_conv_4(x6)
+        x8 = self.max_pool_2x2(x7)
+        x9 = self.down_conv_5(x8)
+
+        # 解码器部分
+        x = self.up_trans_1(x9) 
+        y = crop_tensor(x7, x)
+        x = self.up_conv_1(torch.cat([x, y], axis=1)) 
+        x = self.up_trans_2(x)
+        y = crop_tensor(x5, x)
+        x = self.up_conv_2(torch.cat([x, y], axis=1)) 
+        x = self.up_trans_3(x)
+        y = crop_tensor(x3, x)
+        x = self.up_conv_3(torch.cat([x, y], axis=1)) 
+        x = self.up_trans_4(x)
+        y = crop_tensor(x1, x)
+        x = self.up_conv_4(torch.cat([x, y], axis=1))
+
+        # 输出层
+        out = self.out(x) 
+        return out
+
+if __name__ == "__main__":
+    image = torch.rand((1, 1, 572, 572)) 
+    model = UNet()
+    print(model(image))
+
+

请注意,我上面展示的 U-Net 实现是 U-Net 论文的原始实现。互联网上有很多不同的实现方法。 有些人喜欢使用双线性采样代替转置卷积进行上采样,但这并不是论文的真正实现。不过,它的性能可能会更好。在上图所示的原始实现中,有一个单通道图像,输出中有两个通道:一个是前景,一个是背景。正如你所看到的,这可以很容易地为任意数量的类和任意数量的输入通道进行定制。在此实现中,输入图像的大小与输出图像的大小不同,因为我们使用的是无填充卷积(convolutions without padding)。

+

我们可以看到,U-Net 的编码器部分只是一个简单的卷积网络。 因此,我们可以用任何网络(如 ResNet)来替换它。 这种替换也可以通过预训练权重来完成。因此,我们可以使用基于 ResNet 的编码器,该编码器已在 ImageNet 和通用解码器上进行了预训练。我们可以使用多种不同的网络架构来代替 ResNet。Pavel Yakubovskiy 所著的《Segmentation Models Pytorch》就是许多此类变体的实现,其中编码器可以被预训练模型所取代。让我们应用基于 ResNet 的 U-Net 来解决气胸检测问题。

+

大多数类似的问题都有两个输入:原始图像和掩码(mask)。 如果有多个对象,就会有多个掩码。 在我们的气胸数据集中,我们得到的是 RLE。RLE 代表运行长度编码,是一种表示二进制掩码以节省空间的方法。深入研究 RLE 超出了本章的范围。因此,假设我们有一张输入图像和相应的掩码。让我们先设计一个数据集类,用于输出图像和掩码图像。请注意,我们创建的脚本几乎可以应用于任何分割问题。训练数据集是一个 CSV 文件,只包含图像 ID(也是文件名)。

+
import os
+import glob
+import torch
+import numpy as np
+import pandas as pd
+from PIL import Image, ImageFile
+from tqdm import tqdm
+from collections import defaultdict
+from torchvision import transforms
+from albumentations import (Compose,
+                            OneOf,
+                            RandomBrightnessContrast, 
+                            RandomGamma,
+                            ShiftScaleRotate, )
+
+# 设置PIL图像加载截断的处理
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+# 创建SIIM数据集类
+class SIIMDataset(torch.utils.data.Dataset):
+    def __init__(self, image_ids, transform=True, preprocessing_fn=None):
+        self.data = defaultdict(dict)
+        self.transform = transform
+        self.preprocessing_fn = preprocessing_fn
+
+        # 定义数据增强
+        self.aug = Compose(
+            [ShiftScaleRotate(
+                shift_limit=0.0625,
+                scale_limit=0.1,
+                rotate_limit=10, p=0.8
+            ),
+             OneOf(
+                 [
+                     RandomGamma(
+                         gamma_limit=(90, 110) 
+                     ),
+                     RandomBrightnessContrast( 
+                         brightness_limit=0.1, 
+                         contrast_limit=0.1
+                     ), 
+                 ], 
+                 p=0.5,
+             ), 
+            ]
+        )
+
+        # 构建数据字典,其中包含图像和掩码的路径信息
+        for counter, imgid in enumerate(image_ids):
+            files = glob.glob(os.path.join(TRAIN_PATH, imgid, "*.png")) 
+            self.data[counter] = {
+                "img_path": os.path.join( 
+                    TRAIN_PATH, imgid + ".png"
+                ),
+                "mask_path": os.path.join(
+                    TRAIN_PATH, imgid + "_mask.png" 
+                ),
+            }
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, item):
+        img_path = self.data[item]["img_path"] 
+        mask_path = self.data[item]["mask_path"]
+
+        # 打开图像并将其转换为RGB模式
+        img = Image.open(img_path) 
+        img = img.convert("RGB")
+        img = np.array(img)
+
+        # 打开掩码图像,并将其转换为浮点数
+        mask = Image.open(mask_path)
+        mask = (np.array(mask) >= 1).astype("float32")
+
+        # 如果需要进行数据增强
+        if self.transform is True:
+            augmented = self.aug(image=img, mask=mask) 
+            img = augmented["image"]
+            mask = augmented["mask"]
+
+        # 应用预处理函数(如果有)
+        img = self.preprocessing_fn(img)
+
+        # 返回图像和掩码
+        return {
+            "image": transforms.ToTensor()(img),
+            "mask": transforms.ToTensor()(mask).float(), 
+        }
+
+

有了数据集类之后,我们就可以创建一个训练函数。

+
import os
+import sys
+import torch
+import numpy as np
+import pandas as pd
+import segmentation_models_pytorch as smp
+import torch.nn as nn
+import torch.optim as optim
+from apex import amp
+from collections import OrderedDict
+from sklearn import model_selection
+from tqdm import tqdm
+from torch.optim import lr_scheduler
+from dataset import SIIMDataset
+
+# 定义训练数据集CSV文件路径
+TRAINING_CSV = "../input/train_pneumothorax.csv" 
+# 定义训练和测试的批量大小
+TRAINING_BATCH_SIZE = 16 
+TEST_BATCH_SIZE = 4
+# 定义训练的时期数
+EPOCHS = 10
+# 指定使用的编码器和权重
+ENCODER = "resnet18"
+ENCODER_WEIGHTS = "imagenet"
+# 指定设备(GPU)
+DEVICE = "cuda"
+
+# 定义训练函数
+def train(dataset, data_loader, model, criterion, optimizer): 
+    model.train()
+    num_batches = int(len(dataset) / data_loader.batch_size)
+    tk0 = tqdm(data_loader, total=num_batches)
+    for d in tk0:
+        inputs = d["image"] 
+        targets = d["mask"]
+        inputs = inputs.to(DEVICE, dtype=torch.float)
+        targets = targets.to(DEVICE, dtype=torch.float)
+        optimizer.zero_grad()
+        outputs = model(inputs)
+        loss = criterion(outputs, targets)
+        with amp.scale_loss(loss, optimizer) as scaled_loss: 
+            scaled_loss.backward()
+        optimizer.step()
+    tk0.close()
+
+# 定义评估函数
+def evaluate(dataset, data_loader, model): 
+    model.eval()
+    final_loss = 0
+    num_batches = int(len(dataset) / data_loader.batch_size) 
+    tk0 = tqdm(data_loader, total=num_batches)
+    with torch.no_grad():
+        for d in tk0:
+            inputs = d["image"] 
+            targets = d["mask"]
+            inputs = inputs.to(DEVICE, dtype=torch.float) 
+            targets = targets.to(DEVICE, dtype=torch.float) 
+            output = model(inputs)
+            loss = criterion(output, targets)
+            final_loss += loss
+        tk0.close()
+        return final_loss / num_batches
+
+if __name__ == "__main__":
+    df = pd.read_csv(TRAINING_CSV)
+    df_train, df_valid = model_selection.train_test_split( 
+        df, random_state=42, test_size=0.1
+    )
+    training_images = df_train.image_id.values 
+    validation_images = df_valid.image_id.values
+
+    # 创建 U-Net 模型
+    model = smp.Unet(
+        encoder_name=ENCODER,
+        encoder_weights=ENCODER_WEIGHTS, 
+        classes=1,
+        activation=None, 
+    )
+
+    # 获取数据预处理函数
+    prep_fn = smp.encoders.get_preprocessing_fn( 
+        ENCODER,
+        ENCODER_WEIGHTS 
+    )
+
+    # 将模型放在设备上
+    model.to(DEVICE)
+
+    # 创建训练数据集
+    train_dataset = SIIMDataset( 
+        training_images, 
+        transform=True, 
+        preprocessing_fn=prep_fn,
+    )
+
+    # 创建训练数据加载器
+    train_loader = torch.utils.data.DataLoader( 
+        train_dataset,
+        batch_size=TRAINING_BATCH_SIZE, 
+        shuffle=True,
+        num_workers=12 
+    )
+
+    # 创建验证数据集
+    valid_dataset = SIIMDataset( 
+        validation_images, 
+        transform=False, 
+        preprocessing_fn=prep_fn,
+    )
+
+    # 创建验证数据加载器
+    valid_loader = torch.utils.data.DataLoader( 
+        valid_dataset,
+        batch_size=TEST_BATCH_SIZE, 
+        shuffle=True,
+        num_workers=4 
+    )
+
+    # 定义优化器
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) 
+
+    # 定义学习率调度器
+    scheduler = lr_scheduler.ReduceLROnPlateau( 
+        optimizer, mode="min", patience=3, verbose=True
+    )
+
+    # 初始化 Apex 混合精度训练
+    model, optimizer = amp.initialize(
+        model, optimizer, opt_level="O1", verbosity=0 
+    )
+
+    # 如果有多个GPU,则使用 DataParallel 进行并行训练
+    if torch.cuda.device_count() > 1:
+        print(f"Let's use {torch.cuda.device_count()} GPUs!") 
+        model = nn.DataParallel(model)
+
+    # 输出训练相关的信息
+    print(f"Training batch size: {TRAINING_BATCH_SIZE}")
+    print(f"Test batch size: {TEST_BATCH_SIZE}")
+    print(f"Epochs: {EPOCHS}")
+    print(f"Image size: {IMAGE_SIZE}")
+    print(f"Number of training images: {len(train_dataset)}")
+    print(f"Number of validation images: {len(valid_dataset)}")
+    print(f"Encoder: {ENCODER}")
+
+    # 循环训练多个时期
+    for epoch in range(EPOCHS):
+        print(f"Training Epoch: {epoch}") 
+        train(
+            train_dataset, 
+            train_loader, 
+            model, 
+            criterion, 
+            optimizer
+        )
+        print(f"Validation Epoch: {epoch}") 
+        val_log = evaluate( 
+            valid_dataset, 
+            valid_loader, 
+            model
+        )
+        scheduler.step(val_log) 
+        print("\n")
+
+

在分割问题中,你可以使用各种损失函数,例如二元交叉熵、focal损失、dice损失等。我把这个问题留给 读者根据评估指标来决定合适的损失。当训练这样一个模型时,您将建立预测气胸位置的模型,如图 9 所示。在上述代码中,我们使用英伟达 apex 进行了混合精度训练。请注意,从 PyTorch 1.6.0+ 版本开始,PyTorch 本身就提供了这一功能。

+

+

图 9:从训练有素的模型中检测到气胸的示例(可能不是正确预测)。

+ +

我在一个名为 "Well That's Fantastic Machine Learning (WTFML) "的 python 软件包中收录了一些常用函数。让我们看看它如何帮助我们为 FGVC 202013 植物病理学挑战赛中的植物图像建立多类分类模型。

+
import os
+import pandas as pd
+import numpy as np
+import albumentations
+import argparse
+import torch
+import torchvision
+import torch.nn as nn
+import torch.nn.functional as F
+from sklearn import metrics
+from sklearn.model_selection import train_test_split
+from wtfml.engine import Engine
+from wtfml.data_loaders.image import ClassificationDataLoader
+
+# 自定义损失函数,实现密集交叉熵
+class DenseCrossEntropy(nn.Module): 
+    def __init__(self):
+        super(DenseCrossEntropy, self).__init__() 
+
+    def forward(self, logits, labels):
+        logits = logits.float() 
+        labels = labels.float()
+        logprobs = F.log_softmax(logits, dim=-1) 
+        loss = -labels * logprobs
+        loss = loss.sum(-1) 
+        return loss.mean()
+
+# 自定义神经网络模型
+class Model(nn.Module): 
+    def __init__(self):
+        super().__init__()
+        self.base_model = torchvision.models.resnet18(pretrained=True) 
+        in_features = self.base_model.fc.in_features
+        self.out = nn.Linear(in_features, 4) 
+
+    def forward(self, image, targets=None):
+        batch_size, C, H, W = image.shape
+        x = self.base_model.conv1(image)
+        x = self.base_model.bn1(x)
+        x = self.base_model.relu(x)
+        x = self.base_model.maxpool(x)
+        x = self.base_model.layer1(x)
+        x = self.base_model.layer2(x)
+        x = self.base_model.layer3(x)
+        x = self.base_model.layer4(x)
+        x = F.adaptive_avg_pool2d(x, 1).reshape(batch_size, -1)
+        x = self.out(x)
+        loss = None
+        if targets is not None:
+            loss = DenseCrossEntropy()(x, targets.type_as(x)) 
+        return x, loss
+
+if __name__ == "__main__":
+    # 命令行参数解析器
+    parser = argparse.ArgumentParser() 
+    parser.add_argument("--data_path", type=str, )
+    parser.add_argument("--device", type=str,)
+    parser.add_argument("--epochs", type=int,)
+    args = parser.parse_args()
+
+    # 从CSV文件加载数据
+    df = pd.read_csv(os.path.join(args.data_path, "train.csv")) 
+    images = df.image_id.values.tolist()
+    images = [os.path.join(args.data_path, "images", i + ".jpg") for i in images]
+    targets = df[["healthy", "multiple_diseases", "rust", "scab"]].values 
+
+    # 创建神经网络模型
+    model = Model()
+    model.to(args.device)
+
+    # 定义均值和标准差以及数据增强
+    mean = (0.485, 0.456, 0.406)
+    std = (0.229, 0.224, 0.225)
+    aug = albumentations.Compose( 
+        [
+            albumentations.Normalize(
+                mean, 
+                std,
+                max_pixel_value=255.0, 
+                always_apply=True
+            ) 
+        ]
+    ) 
+
+    # 分割训练集和验证集
+    (
+        train_images, valid_images, 
+        train_targets, valid_targets
+    ) = train_test_split(images, targets) 
+
+    # 创建训练数据加载器
+    train_loader = ClassificationDataLoader(
+        image_paths=train_images, 
+        targets=train_targets, 
+        resize=(128, 128), 
+        augmentations=aug,
+    ).fetch(
+        batch_size=16, 
+        num_workers=4, 
+        drop_last=False, 
+        shuffle=True, 
+        tpu=False
+    )
+
+    # 创建验证数据加载器
+    valid_loader = ClassificationDataLoader( 
+        image_paths=valid_images,
+        targets=valid_targets, 
+        resize=(128, 128), 
+        augmentations=aug,
+    ).fetch(
+        batch_size=16, 
+        num_workers=4, 
+        drop_last=False, 
+        shuffle=False, 
+        tpu=False
+    )
+
+    # 创建优化器
+    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4) 
+
+    # 创建学习率调度器
+    scheduler = torch.optim.lr_scheduler.StepLR(
+        optimizer, step_size=15, gamma=0.6 
+    )
+
+    # 循环训练多个时期
+    for epoch in range(args.epochs): 
+        # 训练模型
+        train_loss = Engine.train(
+            train_loader, model, optimizer, device=args.device 
+        )
+        # 评估模型
+        valid_loss = Engine.evaluate(
+            valid_loader, model, device=args.device 
+        )
+        # 打印损失信息
+        print(f"{epoch}, Train Loss={train_loss} Valid Loss={valid_loss}")
+
+

有了数据后,就可以运行脚本了:

+
python plant.py --data_path ../../plant_pathology --device cuda -- 
+epochs 2
+100%|█████████████| 86/86 [00:12<00:00, 6.73it/s, loss=0.723] 
+100%|█████████████ 29/29 [00:04<00:00, 6.62it/s, loss=0.433] 
+0, Train Loss=0.7228777609592261 Valid Loss=0.4327834551704341 
+100%|█████████████| 86/86 [00:12<00:00, 6.74it/s, loss=0.271] 
+100%|█████████████ 29/29 [00:04<00:00, 6.63it/s, loss=0.568] 
+1, Train Loss=0.2708700496790021 Valid Loss=0.56841839541649
+
+

正如你所看到的,这让我们构建模型变得简单,代码也易于阅读和理解。没有任何封装的 PyTorch 效果最好。图像中不仅仅有分类,还有很多其他的内容,如果我开始写所有的内容,就得再写一本书了, 接近(几乎)任何图像问题(作者在开玩笑)。

+ + + + + + + +
+
+
+ + + + Back to top + + +
+ + + + +
+
+
+
+ + + + + + + + + + + + + + \ No newline at end of file diff --git "a/\346\226\207\346\234\254\345\210\206\347\261\273\346\210\226\345\233\236\345\275\222\346\226\271\346\263\225/index.html" "b/\346\226\207\346\234\254\345\210\206\347\261\273\346\210\226\345\233\236\345\275\222\346\226\271\346\263\225/index.html" new file mode 100644 index 0000000..c78d0b7 --- /dev/null +++ "b/\346\226\207\346\234\254\345\210\206\347\261\273\346\210\226\345\233\236\345\275\222\346\226\271\346\263\225/index.html" @@ -0,0 +1,1537 @@ + + + + + + + + + + + + + + + + + 文本分类或回归方法 - AAAMLP 中译版 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + 跳转至 + + +
+
+ +
+ + + +
+ + +
+ +
+ + + + + + +
+
+ + + +
+
+
+ + + + +
+
+
+ + + +
+
+
+ + + + +
+
+
+ + +
+
+ + + + + + + +

文本分类或回归方法

+

文本问题是我的最爱。一般来说,这些问题也被称为自然语言处理(NLP)问题。NLP 问题与图像问题也有很大不同。你需要创建以前从未为表格问题创建过的数据管道。你需要了解商业案例,才能建立一个好的模型。顺便说一句,机器学习中的任何事情都是如此。建立模型会让你达到一定的水平,但要想改善和促进你所建立模型的业务,你必须了解它对业务的影响。

+

NLP 问题有很多种,其中最常见的是字符串分类。很多时候,我们会看到人们在处理表格数据或图像时表现出色,但在处理文本时,他们甚至不知道从何入手。文本数据与其他类型的数据集没有什么不同。对于计算机来说,一切都是数字。

+

假设我们从情感分类这一基本任务开始。我们将尝试对电影评论进行情感分类。因此,您有一个文本,并有与之相关的情感。你将如何处理这类问题?是应用深度神经网络? 不,绝对错了。你要从最基本的开始。让我们先看看这些数据是什么样子的。

+

我们从IMDB 电影评论数据集开始,该数据集包含 25000 篇正面情感评论和 25000 篇负面情感评论。

+

我将在此讨论的概念几乎适用于任何文本分类数据集。

+

这个数据集非常容易理解。一篇评论对应一个目标变量。请注意,我写的是评论而不是句子。评论就是一堆句子。所以,到目前为止,你一定只看到了对单句的分类,但在这个问题中,我们将对多个句子进行分类。简单地说,这意味着不仅一个句子会对情感产生影响,而且情感得分是多个句子得分的组合。数据简介如图 1 所示。

+

+

如何着手解决这样的问题?一个简单的方法就是手工制作两份单词表。一个列表包含你能想象到的所有正面词汇,例如好、棒、好等;另一个列表包含所有负面词汇,例如坏、恶等。我们先不要举例说明坏词,否则这本书就只能供 18 岁以上的人阅读了。一旦你有了这些列表,你甚至不需要一个模型来进行预测。这些列表也被称为情感词典。你可以用一个简单的计数器来计算句子中正面和负面词语的数量。如果正面词语的数量较多,则表示该句子具有正面情感;如果负面词语的数量较多,则表示该句子具有负面情感。如果句子中没有这些词,则可以说该句子具有中性情感。这是最古老的方法之一,现在仍有人在使用。它也不需要太多代码。

+
def find_sentiment(sentence, pos, neg):
+    sentence = sentence.split()
+    sentence = set(sentence)
+    num_common_pos = len(sentence.intersection(pos))
+    num_common_neg = len(sentence.intersection(neg))
+    if num_common_pos > num_common_neg:
+        return "positive"
+    if num_common_pos < num_common_neg: 
+        return "negative"
+    return "neutral"
+
+

不过,这种方法考虑的因素并不多。正如你所看到的,我们的 split() 也并不完美。如果使用 split(),就会出现这样的句子:

+

"hi, how are you?"

+

经过分割后变为:

+

["hi,", "how","are","you?"]

+

这种方法并不理想,因为单词中包含了逗号和问号,它们并没有被分割。因此,如果没有在分割前对这些特殊字符进行预处理,不建议使用这种方法。将字符串拆分为单词列表称为标记化。最流行的标记化方法之一来自 NLTK(自然语言工具包)

+
In [X]: from nltk.tokenize import word_tokenize
+In [X]: sentence = "hi, how are you?"
+In [X]: sentence.split()
+Out[X]: ['hi,', 'how', 'are', 'you?'] 
+In [X]: word_tokenize(sentence)
+Out[X]: ['hi', ',', 'how', 'are', 'you', '?']
+
+

正如您所看到的,使用 NLTK 的单词标记化功能,同一个句子的拆分效果要好得多。使用单词列表进行对比的效果也会更好!这就是我们将应用于第一个情感检测模型的方法。

+

在处理 NLP 分类问题时,您应该经常尝试的基本模型之一是词袋模型(bag of words)。在词袋模型中,我们创建一个巨大的稀疏矩阵,存储语料库(语料库=所有文档=所有句子)中所有单词的计数。为此,我们将使用 scikit-learn 中的 CountVectorizer。让我们看看它是如何工作的。

+
from sklearn.feature_extraction.text import CountVectorizer
+
+corpus = [
+    "hello, how are you?",
+    "im getting bored at home. And you? What do you think?", 
+    "did you know about counts",
+    "let's see if this works!", 
+    "YES!!!!"
+]
+ctv = CountVectorizer()
+ctv.fit(corpus)
+corpus_transformed = ctv.transform(corpus)
+
+

如果我们打印 corpus_transformed,就会得到类似下面的结果:

+
(0, 2)      1
+(0, 9)      1
+(0, 11)     1
+(0, 22)     1
+(1, 1)      1
+(1, 3)      1
+(1, 4)      1
+(1, 7)      1
+(1, 8)      1
+(1, 10)     1
+(1, 13)     1
+(1, 17)     1
+(1, 19)     1
+(1, 22)     2
+(2, 0)      1
+(2, 5)      1
+(2, 6)      1
+(2, 14)     1
+(2, 22)     1
+(3, 12)     1
+(3, 15)     1
+(3, 16)     1
+(3, 18)     1
+(3, 20)     1
+(4, 21)     1
+
+

在前面的章节中,我们已经见识过这种表示法。即稀疏表示法。因此,语料库现在是一个稀疏矩阵,其中第一个样本有 4 个元素,第二个样本有 10 个元素,以此类推,第三个样本有 5 个元素,以此类推。我们还可以看到,这些元素都有相关的计数。有些元素会出现两次,有些则只有一次。例如,在样本 2(第 1 行)中,我们看到第 22 列的数值是 2。这是为什么呢?第 22 列是什么?

+

CountVectorizer 的工作方式是首先对句子进行标记化处理,然后为每个标记赋值。因此,每个标记都由一个唯一索引表示。这些唯一索引就是我们看到的列。CountVectorizer 会存储这些信息。

+
print(ctv.vocabulary_)
+{'hello': 9, 'how': 11, 'are': 2, 'you': 22, 'im': 13, 'getting': 8, 
+'bored': 4, 'at': 3, 'home': 10, 'and': 1, 'what': 19, 'do': 7, 'think': 
+17, 'did': 6, 'know': 14, 'about': 0, 'counts': 5, 'let': 15, 'see': 16, 
+'if': 12, 'this': 18, 'works': 20, 'yes': 21}
+
+

我们看到,索引 22 属于 "you",而在第二句中,我们使用了两次 "you"。我希望大家现在已经清楚什么是词袋了。但是我们还缺少一些特殊字符。有时,这些特殊字符也很有用。例如,"? "在大多数句子中表示疑问句。让我们把 scikit-learn 的 word_tokenize 整合到 CountVectorizer 中,看看会发生什么。

+
from sklearn.feature_extraction.text import CountVectorizer
+from nltk.tokenize import word_tokenize 
+
+corpus = [
+    "hello, how are you?",
+    "im getting bored at home. And you? What do you think?", 
+    "did you know about counts",
+    "let's see if this works!", 
+    "YES!!!!"
+]
+ctv = CountVectorizer(tokenizer=word_tokenize, token_pattern=None) 
+ctv.fit(corpus)
+corpus_transformed = ctv.transform(corpus) 
+print(ctv.vocabulary_)
+
+

这样,我们的词袋就变成了:

+
{'hello': 14, ',': 2, 'how': 16, 'are': 7, 'you': 27, '?': 4, 'im': 18,
+'getting': 13, 'bored': 9, 'at': 8, 'home': 15, '.': 3, 'and': 6, 'what':
+24, 'do': 12, 'think': 22, 'did': 11, 'know': 19, 'about': 5, 'counts': 
+10, 'let': 20, "'s": 1, 'see': 21, 'if': 17, 'this': 23, 'works': 25, 
+'!': 0, 'yes': 26}
+
+

我们现在可以利用 IMDB 数据集中的所有句子创建一个稀疏矩阵,并建立一个模型。该数据集中正负样本的比例为 1:1,因此我们可以使用准确率作为衡量标准。我们将使用 StratifiedKFold 并创建一个脚本来训练5个折叠。你会问使用哪个模型?对于高维稀疏数据,哪个模型最快?逻辑回归。我们将首先使用逻辑回归来处理这个数据集,并创建第一个基准模型。

+

让我们看看如何做到这一点。

+
import pandas as pd
+from nltk.tokenize import word_tokenize
+from sklearn import linear_model
+from sklearn import metrics
+from sklearn import model_selection
+from sklearn.feature_extraction.text import CountVectorizer
+if __name__ == "__main__": 
+    df = pd.read_csv("../input/imdb.csv") 
+    df.sentiment = df.sentiment.apply( 
+        lambda x: 1 if x == "positive" else 0
+    )
+    df["kfold"] = -1
+    df = df.sample(frac=1).reset_index(drop=True)
+    y = df.sentiment.values
+    kf = model_selection.StratifiedKFold(n_splits=5)
+    for f, (t_, v_) in enumerate(kf.split(X=df, y=y)): 
+        df.loc[v_, 'kfold'] = f
+
+    for fold_ in range(5):
+        train_df = df[df.kfold != fold_].reset_index(drop=True)
+        test_df = df[df.kfold == fold_].reset_index(drop=True) 
+        count_vec = CountVectorizer( 
+            tokenizer=word_tokenize, 
+            token_pattern=None
+        )
+        count_vec.fit(train_df.review)
+        xtrain = count_vec.transform(train_df.review) 
+        xtest = count_vec.transform(test_df.review)
+        model = linear_model.LogisticRegression()
+        model.fit(xtrain, train_df.sentiment)
+        preds = model.predict(xtest)
+        accuracy = metrics.accuracy_score(test_df.sentiment, preds)
+        print(f"Fold: {fold_}")
+        print(f"Accuracy = {accuracy}")
+        print("")
+
+

这段代码的运行需要一定的时间,但可以得到以下输出结果:

+
Fold: 0
+Accuracy = 0.8903
+
+Fold: 1
+Accuracy = 0.897
+
+Fold: 2
+Accuracy = 0.891
+
+Fold: 3 
+Accuracy = 0.8914
+
+Fold: 4 
+Accuracy = 0.8931
+
+

哇,准确率已经达到 89%,而我们所做的只是使用词袋和逻辑回归!这真是太棒了!不过,这个模型的训练花费了很多时间,让我们看看能否通过使用朴素贝叶斯分类器来缩短训练时间。朴素贝叶斯分类器在 NLP 任务中相当流行,因为稀疏矩阵非常庞大,而朴素贝叶斯是一个简单的模型。要使用这个模型,需要更改一个导入和模型的行。让我们看看这个模型的性能如何。我们将使用 scikit-learn 中的 MultinomialNB。

+
import pandas as pd
+from nltk.tokenize import word_tokenize
+from sklearn import naive_bayes
+from sklearn import metrics
+from sklearn import model_selection
+from sklearn.feature_extraction.text import CountVectorizer
+
+
+model = naive_bayes.MultinomialNB()
+model.fit(xtrain, train_df.sentiment)
+
+

得到如下结果:

+
Fold: 0
+Accuracy = 0.8444
+
+Fold: 1 
+Accuracy = 0.8499
+
+Fold: 2 
+Accuracy = 0.8422
+
+Fold: 3 
+Accuracy = 0.8443
+
+Fold: 4 
+Accuracy = 0.8455
+
+

我们看到这个分数很低。但朴素贝叶斯模型的速度非常快。

+

NLP 中的另一种方法是 TF-IDF,如今大多数人都倾向于忽略或不屑于了解这种方法。TF 是术语频率,IDF 是反向文档频率。从这些术语来看,这似乎有些困难,但通过 TF 和 IDF 的计算公式,事情就会变得很明显。 +$$ +TF(t) = \frac{Number\ of\ times\ a\ term\ t\ appears\ in\ a\ document}{Total\ number\ of\ terms\ in \ the\ document} +$$

+
\[ +IDF(t) = LOG\left(\frac{Total\ number\ of\ documents}{Number\ of\ documents\ with\ term\ t\ in\ it}\right) +\]
+

术语 t 的 TF-IDF 定义为: +$$ +TF-IDF(t) = TF(t) \times IDF(t) +$$ +与 scikit-learn 中的 CountVectorizer 类似,我们也有 TfidfVectorizer。让我们试着像使用 CountVectorizer 一样使用它。

+
from sklearn.feature_extraction.text import TfidfVectorizer
+from nltk.tokenize import word_tokenize
+
+corpus = [
+    "hello, how are you?",
+    "im getting bored at home. And you? What do you think?", 
+    "did you know about counts",
+    "let's see if this works!", 
+    "YES!!!!"
+]
+tfv = TfidfVectorizer(tokenizer=word_tokenize, token_pattern=None) 
+tfv.fit(corpus)
+corpus_transformed = tfv.transform(corpus) 
+print(corpus_transformed)
+
+

输出结果如下:

+
(0, 27)     0.2965698850220162
+(0, 16)     0.4428321995085722
+(0, 14)     0.4428321995085722
+(0, 7)      0.4428321995085722
+(0, 4)      0.35727423026525224
+(0, 2)      0.4428321995085722
+(1, 27)     0.35299699146792735
+(1, 24)     0.2635440111190765
+(1, 22)     0.2635440111190765
+(1, 18)     0.2635440111190765
+(1, 15)     0.2635440111190765
+(1, 13)     0.2635440111190765
+(1, 12)     0.2635440111190765
+(1, 9)      0.2635440111190765
+(1, 8)      0.2635440111190765
+(1, 6)      0.2635440111190765
+(1, 4)      0.42525129752567803
+(1, 3)      0.2635440111190765
+(2, 27)     0.31752680284846835
+(2, 19)     0.4741246485558491
+(2, 11)     0.4741246485558491
+(2, 10)     0.4741246485558491
+(2, 5)      0.4741246485558491
+(3, 25)     0.38775666010579296
+(3, 23)     0.38775666010579296
+(3, 21)     0.38775666010579296
+(3, 20)     0.38775666010579296 
+(3, 17)     0.38775666010579296 
+(3, 1)      0.38775666010579296 
+(3, 0)      0.3128396318588854 
+(4, 26)     0.2959842226518677 
+(4, 0)      0.9551928286692534
+
+

可以看到,这次我们得到的不是整数值,而是浮点数。 用 TfidfVectorizer 代替 CountVectorizer 也是小菜一碟。Scikit-learn 还提供了 TfidfTransformer。如果你使用的是计数值,可以使用 TfidfTransformer 并获得与 TfidfVectorizer 相同的效果。

+
import pandas as pd
+from nltk.tokenize import word_tokenize
+from sklearn import linear_model
+from sklearn import metrics
+from sklearn import model_selection
+from sklearn.feature_extraction.text import TfidfVectorizer
+
+for fold_ in range(5):
+    train_df = df[df.kfold != fold_].reset_index(drop=True) 
+    test_df = df[df.kfold == fold_].reset_index(drop=True) 
+    tfidf_vec = TfidfVectorizer( 
+        tokenizer=word_tokenize, 
+        token_pattern=None
+    )
+    tfidf_vec.fit(train_df.review)
+    xtrain = tfidf_vec.transform(train_df.review) 
+    xtest = tfidf_vec.transform(test_df.review)
+    model = linear_model.LogisticRegression()
+    model.fit(xtrain, train_df.sentiment)
+    preds = model.predict(xtest)
+    accuracy = metrics.accuracy_score(test_df.sentiment, preds)
+    print(f"Fold: {fold_}")
+    print(f"Accuracy = {accuracy}")
+    print("")
+
+

我们可以看看 TF-IDF 在逻辑回归模型上的表现如何。

+
Fold: 0 
+Accuracy = 0.8976
+
+Fold: 1 
+Accuracy = 0.8998
+
+Fold: 2 
+Accuracy = 0.8948
+
+Fold: 3 
+Accuracy = 0.8912
+
+Fold: 4 
+Accuracy = 0.8995
+
+

我们看到,这些分数都比 CountVectorizer 高一些,因此它成为了我们想要击败的新基准。

+

NLP 中另一个有趣的概念是 N-gram。N-grams 是按顺序排列的单词组合。N-grams 很容易创建。您只需注意顺序即可。为了让事情变得更简单,我们可以使用 NLTK 的 N-gram 实现。

+
from nltk import ngrams
+from nltk.tokenize import word_tokenize 
+
+N = 3
+sentence = "hi, how are you?" 
+tokenized_sentence = word_tokenize(sentence)
+n_grams = list(ngrams(tokenized_sentence, N)) 
+print(n_grams)
+
+

由此得到:

+
[('hi', ',', 'how'),
+(',', 'how', 'are'), 
+('how', 'are', 'you'), 
+('are', 'you', '?')]
+
+

同样,我们还可以创建 2-gram 或 4-gram 等。现在,这些 n-gram 将成为我们词汇表的一部分,当我们计算计数或 tf-idf 时,我们会将一个 n-gram 视为一个全新的标记。因此,在某种程度上,我们是在结合上下文。scikit-learn 的 CountVectorizer 和 TfidfVectorizer 实现都通过 ngram_range 参数提供 n-gram,该参数有最小和最大限制。默认情况下,该参数为(1, 1)。当我们将其改为 (1, 3) 时,我们将看到单字元、双字元和三字元。代码改动很小。

+

由于到目前为止我们使用 tf-idf 得到了最好的结果,让我们来看看包含 n-grams 直至 trigrams 是否能改进模型。唯一需要修改的是 TfidfVectorizer 的初始化。

+
tfidf_vec = TfidfVectorizer(
+    tokenizer=word_tokenize, 
+    token_pattern=None, 
+    ngram_range=(1, 3)
+)
+
+

让我们看看是否会有改进。

+
Fold: 0 
+Accuracy = 0.8931
+
+Fold: 1 
+Accuracy = 0.8941
+
+Fold: 2 
+Accuracy = 0.897
+
+Fold: 3 
+Accuracy = 0.8922
+
+Fold: 4 
+Accuracy = 0.8847
+
+

看起来还行,但我们看不到任何改进。 也许我们可以通过多使用 bigrams 来获得改进。 我不会在这里展示这一部分。也许你可以自己试着做。

+

NLP 的基础知识还有很多。你必须知道的一个术语是词干提取(stemming)。另一个是词形还原(lemmatization)。词干提取和词形还原可以将一个词减少到最小形式。在词干提取的情况下,处理后的单词称为词干单词,而在词形还原情况下,处理后的单词称为词形。必须指出的是,词形还原比词干提取更激进,而词干提取更流行和广泛。词干和词形都来自语言学。如果你打算为某种语言制作词干或词形,需要对该语言有深入的了解。如果要过多地介绍这些知识,就意味着要在本书中增加一章。使用 NLTK 软件包可以轻松完成词干提取和词形还原。让我们来看看这两种方法的一些示例。有许多不同类型的词干提取和词形还原器。我将用最常见的 Snowball Stemmer 和 WordNet Lemmatizer 来举例说明。

+
from nltk.stem import WordNetLemmatizer
+from nltk.stem.snowball import SnowballStemmer 
+
+lemmatizer = WordNetLemmatizer()
+stemmer = SnowballStemmer("english") 
+words = ["fishing", "fishes", "fished"] 
+for word in words:
+    print(f"word={word}")
+    print(f"stemmed_word={stemmer.stem(word)}")
+    print(f"lemma={lemmatizer.lemmatize(word)}")
+    print("")
+
+

这将打印:

+
word=fishing
+stemmed_word=fish 
+lemma=fishing
+word=fishes 
+stemmed_word=fish 
+lemma=fish
+word=fished 
+stemmed_word=fish 
+lemma=fished
+
+

正如您所看到的,词干提取和词形还原是截然不同的。当我们进行词干提取时,我们得到的是一个词的最小形式,它可能是也可能不是该词所属语言词典中的一个词。但是,在词形还原情况下,这将是一个词。现在,您可以自己尝试添加词干和词素化,看看是否能改善结果。

+

您还应该了解的一个主题是主题提取。主题提取可以使用非负矩阵因式分解(NMF)或潜在语义分析(LSA)来完成,后者也被称为奇异值分解或 SVD。这些分解技术可将数据简化为给定数量的成分。 您可以在从 CountVectorizer 或 TfidfVectorizer 中获得的稀疏矩阵上应用其中任何一种技术。

+

让我们把它应用到之前使用过的 TfidfVetorizer 上。

+
import pandas as pd
+from nltk.tokenize import word_tokenize
+from sklearn import decomposition
+from sklearn.feature_extraction.text import TfidfVectorizer
+corpus = pd.read_csv("../input/imdb.csv", nrows=10000) 
+corpus = corpus.review.values
+tfv = TfidfVectorizer(tokenizer=word_tokenize, token_pattern=None) 
+tfv.fit(corpus)
+corpus_transformed = tfv.transform(corpus)
+svd = decomposition.TruncatedSVD(n_components=10) 
+corpus_svd = svd.fit(corpus_transformed)
+sample_index = 0
+feature_scores = dict( 
+    zip(
+        tfv.get_feature_names(),
+        corpus_svd.components_[sample_index] 
+    )
+)
+N = 5
+print(sorted(feature_scores, key=feature_scores.get, reverse=True)[:N])
+
+

您可以使用循环来运行多个样本。

+
N = 5
+for sample_index in range(5): 
+    feature_scores = dict(
+        zip(
+            tfv.get_feature_names(),
+            corpus_svd.components_[sample_index] 
+        )
+    )
+    print(
+        sorted(
+            feature_scores,
+            key=feature_scores.get, 
+            reverse=True
+        )[:N] 
+    )
+
+

输出结果如下:

+
['the', ',', '.', 'a', 'and']
+['br', '<', '>', '/', '-']
+['i', 'movie', '!', 'it', 'was'] 
+[',', '!', "''", '``', 'you'] 
+['!', 'the', '...', "''", '``']
+
+

你可以看到,这根本说不通。怎么办呢?让我们试着清理一下,看看是否有意义。要清理任何文本数据,尤其是 pandas 数据帧中的文本数据,可以创建一个函数。

+
import re
+import string 
+
+def clean_text(s):
+    s = s.split()
+    s = " ".join(s)
+    s = re.sub(f'[{re.escape(string.punctuation)}]', '', s)
+    return s
+
+

该函数会将 "hi, how are you????" 这样的字符串转换为 "hi how are you"。让我们把这个函数应用到旧的 SVD 代码中,看看它是否能给提取的主题带来提升。使用 pandas,你可以使用 apply 函数将清理代码 "应用 "到任意给定的列中。

+
import pandas as pd
+corpus = pd.read_csv("../input/imdb.csv", nrows=10000) 
+corpus.loc[:, "review"] = corpus.review.apply(clean_text)
+
+

请注意,我们只在主 SVD 脚本中添加了一行代码,这就是使用函数和 pandas 应用的好处。这次生成的主题如下。

+
['the', 'a', 'and', 'of', 'to']
+['i', 'movie', 'it', 'was', 'this'] 
+['the', 'was', 'i', 'were', 'of'] 
+['her', 'was', 'she', 'i', 'he'] 
+['br', 'to', 'they', 'he', 'show']
+
+

呼!至少这比我们之前好多了。但你知道吗?你可以通过在清理功能中删除停止词(stopwords)来使它变得更好。什么是stopwords?它们是存在于每种语言中的高频词。例如,在英语中,这些词包括 "a"、"an"、"the"、"for "等。删除停止词并非总是明智的选择,这在很大程度上取决于业务问题。像 "I need a new dog"这样的句子,去掉停止词后会变成 "need new dog",此时我们不知道谁需要new dog。

+

如果我们总是删除停止词,就会丢失很多上下文信息。你可以在 NLTK 中找到许多语言的停止词,如果没有,你也可以在自己喜欢的搜索引擎上快速搜索一下。

+

现在,让我们转到大多数人都喜欢使用的方法:深度学习。但首先,我们必须知道什么是词嵌入(embedings for words)。你已经看到,到目前为止,我们已经将标记转换成了数字。因此,如果某个语料库中有 N 个唯一的词块,它们可以用 0 到 N-1 之间的整数来表示。现在,我们将用向量来表示这些整数词块。这种将单词表示成向量的方法被称为单词嵌入或单词向量。谷歌的 Word2Vec 是将单词转换为向量的最古老方法之一。此外,还有 Facebook 的 FastText 和斯坦福大学的 GloVe(用于单词表示的全局向量)。这些方法彼此大相径庭。

+

其基本思想是建立一个浅层网络,通过重构输入句子来学习单词的嵌入。因此,您可以通过使用周围的所有单词来训练网络预测一个缺失的单词,在此过程中,网络将学习并更新所有相关单词的嵌入。这种方法也被称为连续词袋或 CBoW 模型。您也可以尝试使用一个单词来预测上下文中的单词。这就是所谓的跳格模型。Word2Vec 可以使用这两种方法学习嵌入。

+

FastText 可以学习字符 n-gram 的嵌入。和单词 n-gram 一样,如果我们使用的是字符,则称为字符 n-gram,最后,GloVe 通过共现矩阵来学习这些嵌入。因此,我们可以说,所有这些不同类型的嵌入最终都会返回一个字典,其中键是语料库(例如英语维基百科)中的单词,值是大小为 N(通常为 300)的向量。

+

+

图 1:可视化二维单词嵌入。

+ +

图 1 显示了二维单词嵌入的可视化效果。假设我们以某种方式完成了词语的二维表示。图 1 显示,如果从Berlin(德国首都)的向量中减去德国(Germany)的向量,再加上法国(france)的向量,就会得到一个接近Paris(法国首都)的向量。由此可见,嵌入式也能进行类比。 这并不总是正确的,但这样的例子有助于理解单词嵌入的作用。像 "嗨,你好吗 "这样的句子可以用下面的一堆向量来表示。

+

hi ─> [vector (v1) of size 300]

+

, ─> [vector (v2) of size 300]

+

how ─> [vector (v3) of size 300]

+

are ─> [vector (v4) of size 300]

+

you ─> [vector (v5) of size 300]

+

? ─> [vector (v6) of size 300]

+

使用这些信息有多种方法。最简单的方法之一就是使用嵌入向量。如上例所示,每个单词都有一个 1x300 的嵌入向量。利用这些信息,我们可以计算出整个句子的嵌入。计算方法有多种。其中一种方法如下所示。在这个函数中,我们将给定句子中的所有单词向量提取出来,然后从所有标记词的单词向量中创建一个归一化的单词向量。这样就得到了一个句子向量。

+
import numpy as np
+def sentence_to_vec(s, embedding_dict, stop_words, tokenizer): 
+    words = str(s).lower()
+    words = tokenizer(words)
+    words = [w for w in words if not w in stop_words] 
+    words = [w for w in words if w.isalpha()] 
+    M = []
+    for w in words:
+        if w in embedding_dict:
+            M.append(embedding_dict[w])
+    if len(M) == 0:
+        return np.zeros(300)
+    M = np.array(M)
+    v = M.sum(axis=0)
+    return v / np.sqrt((v ** 2).sum())
+
+

我们可以用这种方法将所有示例转换成一个向量。我们能否使用 fastText 向量来改进之前的结果?每篇评论都有 300 个特征。

+
import io
+import numpy as np
+import pandas as pd
+from nltk.tokenize import word_tokenize
+from sklearn import linear_model
+from sklearn import metrics
+from sklearn import model_selection
+from sklearn.feature_extraction.text import TfidfVectorizer
+def load_vectors(fname):
+    fin = io.open(
+        fname, 
+        'r',
+        encoding='utf-8', 
+        newline='\n', 
+        errors='ignore'
+    )
+    n, d = map(int, fin.readline().split()) 
+    data = {}
+    for line in fin:
+        tokens = line.rstrip().split(' ')
+        data[tokens[0]] = list(map(float, tokens[1:])) 
+    return data
+
+def sentence_to_vec(s, embedding_dict, stop_words, tokenizer):
+    # 此函数与前文定义的 sentence_to_vec 完全相同,此处省略函数体
+    ...
+if __name__ == "__main__": 
+    df = pd.read_csv("../input/imdb.csv")
+    df.sentiment = df.sentiment.apply( 
+        lambda x: 1 if x == "positive" else 0
+    )
+    df = df.sample(frac=1).reset_index(drop=True)
+    print("Loading embeddings")
+    embeddings = load_vectors("../input/crawl-300d-2M.vec") 
+    print("Creating sentence vectors") 
+    vectors = []
+    for review in df.review.values: 
+        vectors.append(
+            sentence_to_vec( 
+                s = review,
+                embedding_dict = embeddings, 
+                stop_words = [], 
+                tokenizer = word_tokenize
+            ) 
+        )
+    vectors = np.array(vectors)
+    y = df.sentiment.values
+    kf = model_selection.StratifiedKFold(n_splits=5)
+    for fold_, (t_, v_) in enumerate(kf.split(X=vectors, y=y)):
+        print(f"Training fold: {fold_}")
+        xtrain = vectors[t_, :]
+        ytrain = y[t_]
+        xtest = vectors[v_, :]
+        ytest = y[v_]
+        model = linear_model.LogisticRegression()
+        model.fit(xtrain, ytrain)
+        preds = model.predict(xtest)
+        accuracy = metrics.accuracy_score(ytest, preds)
+        print(f"Accuracy = {accuracy}")
+        print("")
+
+

这将得到如下结果:

+
Loading embeddings 
+Creating sentence vectors 
+
+Training fold: 0 
+Accuracy = 0.8619
+
+Training fold: 1 
+Accuracy = 0.8661 
+
+Training fold: 2 
+Accuracy = 0.8544 
+
+Training fold: 3 
+Accuracy = 0.8624 
+
+Training fold: 4 
+Accuracy = 0.8595
+
+

Wow!真是出乎意料。我们所做的只是使用了 FastText 嵌入。试着把嵌入换成 GloVe,看看会发生什么。我把它作为一个练习留给大家。 +当我们谈论文本数据时,我们必须牢记一件事。文本数据与时间序列数据非常相似。如图 2 所示,我们评论中的任何样本都是在不同时间戳上按递增顺序排列的标记序列,每个标记都可以表示为一个向量/嵌入。

+

+

图 2:将标记表示为嵌入,并将其视为时间序列

+ +

这意味着我们可以使用广泛用于时间序列数据的模型,例如长短期记忆(LSTM)或门控递归单元(GRU),甚至卷积神经网络(CNN)。让我们看看如何在该数据集上训练一个简单的双向 LSTM 模型。

+

首先,我们将创建一个项目。你可以随意给它命名。然后,我们的第一步将是分割数据进行交叉验证。

+
import pandas as pd
+from sklearn import model_selection 
+if __name__ == "__main__":
+    df = pd.read_csv("../input/imdb.csv") 
+    df.sentiment = df.sentiment.apply( 
+        lambda x: 1 if x == "positive" else 0
+    )
+    df["kfold"] = -1
+    df = df.sample(frac=1).reset_index(drop=True)
+    y = df.sentiment.values
+    kf = model_selection.StratifiedKFold(n_splits=5)
+    for f, (t_, v_) in enumerate(kf.split(X=df, y=y)):
+        df.loc[v_, 'kfold'] = f
+    df.to_csv("../input/imdb_folds.csv", index=False)
+
+

将数据集划分为多个折叠后,我们就可以在 dataset.py 中创建一个简单的数据集类。数据集类会返回一个训练或验证数据样本。

+
import torch
+class IMDBDataset:
+    def __init__(self, reviews, targets):
+        self.reviews = reviews
+        self.target = targets
+    def __len__(self):
+        return len(self.reviews)
+    def __getitem__(self, item):
+        review = self.reviews[item, :]
+        target = self.target[item]
+        return {
+            "review": torch.tensor(review, dtype=torch.long),
+            "target": torch.tensor(target, dtype=torch.float)
+        }
+
+

完成数据集分类后,我们就可以创建 lstm.py,其中包含我们的 LSTM 模型

+
import torch
+import torch.nn as nn 
+class LSTM(nn.Module):
+    def __init__(self, embedding_matrix): 
+        super(LSTM, self).__init__()
+        num_words = embedding_matrix.shape[0]
+        embed_dim = embedding_matrix.shape[1]
+        self.embedding = nn.Embedding( 
+            num_embeddings=num_words, 
+            embedding_dim=embed_dim)
+        self.embedding.weight = nn.Parameter( 
+            torch.tensor(
+                embedding_matrix, 
+                dtype=torch.float32
+            ) 
+        )
+        self.embedding.weight.requires_grad = False
+        self.lstm = nn.LSTM( 
+            embed_dim,
+            128,
+            bidirectional=True, 
+            batch_first=True,
+        )
+        self.out = nn.Linear(512, 1)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x, _ = self.lstm(x)
+        avg_pool = torch.mean(x, 1)
+        max_pool, _ = torch.max(x, 1)
+        out = torch.cat((avg_pool, max_pool), 1)
+        out = self.out(out)
+        return out
+
+

现在,我们创建 engine.py,其中包含训练和评估函数。

+
import torch
+import torch.nn as nn
+def train(data_loader, model, optimizer, device):
+    model.train()
+    for data in data_loader:
+        reviews = data["review"]
+        targets = data["target"]
+        reviews = reviews.to(device, dtype=torch.long) 
+        targets = targets.to(device, dtype=torch.float)
+        optimizer.zero_grad()
+        predictions = model(reviews)
+        loss = nn.BCEWithLogitsLoss()( 
+            predictions,
+            targets.view(-1, 1) 
+        )
+        loss.backward()
+        optimizer.step()
+
+def evaluate(data_loader, model, device):
+    final_predictions = [] 
+    final_targets = []
+    model.eval()
+    with torch.no_grad():
+        for data in data_loader: 
+            reviews = data["review"] 
+            targets = data["target"]
+            reviews = reviews.to(device, dtype=torch.long) 
+            targets = targets.to(device, dtype=torch.float)
+            predictions = model(reviews)
+            predictions = predictions.cpu().numpy().tolist() 
+            targets = data["target"].cpu().numpy().tolist() 
+            final_predictions.extend(predictions)
+            final_targets.extend(targets)
+    return final_predictions, final_targets
+
+

这些函数将在 train.py 中为我们提供帮助,该函数用于训练多个折叠。

+
import io
+import torch
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from sklearn import metrics
+import config
+import dataset
+import engine
+import lstm
+
+def load_vectors(fname):
+    fin = io.open( 
+        fname, 
+        'r',
+        encoding='utf-8', 
+        newline='\n', 
+        errors='ignore'
+    )
+    n, d = map(int, fin.readline().split()) 
+    data = {}
+    for line in fin:
+        tokens = line.rstrip().split(' ')
+        data[tokens[0]] = list(map(float, tokens[1:])) 
+    return data
+
+def create_embedding_matrix(word_index, embedding_dict): 
+    embedding_matrix = np.zeros((len(word_index) + 1, 300)) 
+    for word, i in word_index.items():
+        if word in embedding_dict:
+            embedding_matrix[i] = embedding_dict[word]
+    return embedding_matrix
+
+def run(df, fold): 
+    train_df = df[df.kfold != fold].reset_index(drop=True) 
+    valid_df = df[df.kfold == fold].reset_index(drop=True) 
+    print("Fitting tokenizer")
+    tokenizer = tf.keras.preprocessing.text.Tokenizer()
+    tokenizer.fit_on_texts(df.review.values.tolist())
+    xtrain = tokenizer.texts_to_sequences(train_df.review.values)
+    xtest = tokenizer.texts_to_sequences(valid_df.review.values)
+    xtrain = tf.keras.preprocessing.sequence.pad_sequences( 
+        xtrain, maxlen=config.MAX_LEN
+    )
+    xtest = tf.keras.preprocessing.sequence.pad_sequences(
+        xtest, maxlen=config.MAX_LEN
+    )
+    train_dataset = dataset.IMDBDataset( 
+        reviews=xtrain,
+        targets=train_df.sentiment.values 
+    )
+    train_data_loader = torch.utils.data.DataLoader( 
+        train_dataset,
+        batch_size=config.TRAIN_BATCH_SIZE, 
+        num_workers=2
+    )
+    valid_dataset = dataset.IMDBDataset( 
+        reviews=xtest,
+        targets=valid_df.sentiment.values 
+    )
+    valid_data_loader = torch.utils.data.DataLoader( 
+        valid_dataset,
+        batch_size=config.VALID_BATCH_SIZE, 
+        num_workers=1
+    )
+    print("Loading embeddings")
+    embedding_dict = load_vectors("../input/crawl-300d-2M.vec") 
+    embedding_matrix = create_embedding_matrix(
+        tokenizer.word_index, embedding_dict 
+    )
+    device = torch.device("cuda")
+    model = lstm.LSTM(embedding_matrix)
+    model.to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) 
+    print("Training Model")
+    best_accuracy = 0
+    early_stopping_counter = 0
+    for epoch in range(config.EPOCHS):
+        engine.train(train_data_loader, model, optimizer, device)
+        outputs, targets = engine.evaluate(
+            valid_data_loader, model, device
+        )
+        outputs = np.array(outputs) >= 0.5
+        accuracy = metrics.accuracy_score(targets, outputs) 
+        print(f"FOLD:{fold}, Epoch: {epoch}, Accuracy Score = {accuracy}")
+        if accuracy > best_accuracy: 
+            best_accuracy = accuracy
+        else:
+            early_stopping_counter += 1 
+            if early_stopping_counter > 2:
+                break
+if __name__ == "__main__":
+    df = pd.read_csv("../input/imdb_folds.csv") 
+    run(df, fold=0)
+    run(df, fold=1)
+    run(df, fold=2)
+    run(df, fold=3)
+    run(df, fold=4)
+
+

最后是 config.py。

+
MAX_LEN = 128
+TRAIN_BATCH_SIZE = 16 
+VALID_BATCH_SIZE = 8 
+EPOCHS = 10
+
+

让我们看看输出:

+
FOLD:0, Epoch: 3, Accuracy Score = 0.9015
+FOLD:1, Epoch: 4, Accuracy Score = 0.9007
+FOLD:2, Epoch: 3, Accuracy Score = 0.8924
+FOLD:3, Epoch: 2, Accuracy Score = 0.9
+FOLD:4, Epoch: 1, Accuracy Score = 0.878
+
+

这是迄今为止我们获得的最好成绩。 请注意,我只显示了每个折叠中精度最高的Epoch。

+

你一定已经注意到,我们使用了预先训练的嵌入和简单的双向 LSTM。 如果你想改变模型,你可以只改变 lstm.py 中的模型并保持一切不变。 这种代码只需要很少的实验改动,并且很容易理解。 例如,您可以自己学习嵌入而不是使用预训练的嵌入,您可以使用其他一些预训练的嵌入,您可以组合多个预训练的嵌入,您可以使用 GRU,您可以在嵌入后使用空间 dropout,您可以在 LSTM 层之后添加 GRU,您可以添加两个 LSTM 层,您可以进行 LSTM-GRU-LSTM 配置,您可以用卷积层替换 LSTM 等,而无需对代码进行太多更改。 我提到的大部分内容只需要更改模型类。

+

当您使用预训练的嵌入时,尝试查看有多少单词无法找到嵌入以及原因。 预训练嵌入的单词越多,结果就越好。 我向您展示以下未注释的 (!) 函数,您可以使用它为任何类型的预训练嵌入创建嵌入矩阵,其格式与 glove 或 fastText 相同(可能需要进行一些更改)。

+
def load_embeddings(word_index, embedding_file, vector_length=300):
+    max_features = len(word_index) + 1
+    words_to_find = list(word_index.keys())
+    more_words_to_find = []
+    for wtf in words_to_find:
+        more_words_to_find.append(wtf)
+        more_words_to_find.append(str(wtf).capitalize())
+    more_words_to_find = set(more_words_to_find)
+
+    def get_coefs(word, *arr):
+        return word, np.asarray(arr, dtype='float32')
+
+    embeddings_index = dict(
+        get_coefs(*o.strip().split(" "))
+        for o in open(embedding_file)
+        if o.split(" ")[0] in more_words_to_find
+        and len(o) > 100
+    )
+
+    embedding_matrix = np.zeros((max_features, vector_length))
+    for word, i in word_index.items():
+        if i >= max_features:
+            continue
+        embedding_vector = embeddings_index.get(word)
+        if embedding_vector is None:
+            embedding_vector = embeddings_index.get(
+                str(word).capitalize()
+            )
+        if embedding_vector is None:
+            embedding_vector = embeddings_index.get(
+                str(word).upper()
+            )
+        if (embedding_vector is not None
+            and len(embedding_vector) == vector_length):
+            embedding_matrix[i] = embedding_vector
+    return embedding_matrix
+
+

阅读并运行上面的函数,看看发生了什么。 该函数还可以修改为使用词干词或词形还原词。 最后,您希望训练语料库中的未知单词数量最少。 另一个技巧是学习嵌入层,即使其可训练,然后训练网络。

+

到目前为止,我们已经为分类问题构建了很多模型。 然而,现在是布偶模型的时代,越来越多的人转向基于 Transformer 的模型。 基于 Transformer 的网络能够处理本质上的长期依赖关系。 LSTM 只有在看到前一个单词后才会查看下一个单词,而 Transformer 则不同, 它可以同时查看整个句子中的所有单词。 因此,它的另一个优点是可以轻松并行化,从而更高效地利用 GPU。

+

Transformers 是一个非常广泛的话题,有太多的模型:BERT、RoBERTa、XLNet、XLM-RoBERTa、T5 等。我将向您展示一种通用方法,可以使用这些模型(T5 除外)来解决我们一直在讨论的分类问题。 请注意,训练这些 Transformer 模型需要相当大的计算能力。 因此,如果您没有高端系统,与基于 LSTM 或 TF-IDF 的模型相比,训练模型可能需要更长的时间。

+

我们要做的第一件事是创建一个配置文件。

+
import transformers
+MAX_LEN = 512
+TRAIN_BATCH_SIZE = 8
+VALID_BATCH_SIZE = 4
+EPOCHS = 10
+
+BERT_PATH = "../input/bert_base_uncased/" 
+MODEL_PATH = "model.bin" 
+TRAINING_FILE = "../input/imdb.csv" 
+TOKENIZER = transformers.BertTokenizer.from_pretrained( 
+    BERT_PATH,
+    do_lower_case=True 
+)
+
+

这里的配置文件是我们定义分词器和其他我们想要经常更改的参数的唯一地方 —— 这样我们就可以做很多实验而不需要进行大量更改。

+

下一步是构建数据集类。

+
import config 
+import torch
+class BERTDataset:
+    def __init__(self, review, target): 
+        self.review = review 
+        self.target = target
+        self.tokenizer = config.TOKENIZER 
+        self.max_len = config.MAX_LEN
+    def __len__(self):
+        return len(self.review)
+    def __getitem__(self, item):
+        review = str(self.review[item]) 
+        review = " ".join(review.split())
+        inputs = self.tokenizer.encode_plus( 
+            review,
+            None,
+            add_special_tokens=True, 
+            max_length=self.max_len, 
+            pad_to_max_length=True,
+        )
+        ids = inputs["input_ids"]
+        mask = inputs["attention_mask"]
+        token_type_ids = inputs["token_type_ids"]
+        return {
+            "ids": torch.tensor( 
+                ids, dtype=torch.long
+            ),
+            "mask": torch.tensor( 
+                mask, dtype=torch.long
+            ),
+            "token_type_ids": torch.tensor( 
+                token_type_ids, dtype=torch.long
+            ),
+            "targets": torch.tensor(
+                self.target[item], dtype=torch.float 
+            )
+        }
+
+

现在我们来到了该项目的核心,即模型。

+
import config
+import transformers
+import torch.nn as nn
+class BERTBaseUncased(nn.Module): 
+    def __init__(self):
+        super(BERTBaseUncased, self).__init__()
+        self.bert = transformers.BertModel.from_pretrained(
+            config.BERT_PATH 
+        )
+        self.bert_drop = nn.Dropout(0.3) 
+        self.out = nn.Linear(768, 1)
+    def forward(self, ids, mask, token_type_ids):
+        # o2 是 BERT pooler 的输出(由最后一层 hidden state 计算得到)
+        _, o2 = self.bert( 
+            ids,
+            attention_mask=mask,
+            token_type_ids=token_type_ids 
+        )
+        bo = self.bert_drop(o2)
+        output = self.out(bo)
+        return output
+
+

该模型返回单个输出。 我们可以使用带有 logits 的二元交叉熵损失,它首先应用 sigmoid,然后计算损失。 这是在engine.py 中完成的。

+
import torch
+import torch.nn as nn
+def loss_fn(outputs, targets): 
+    return nn.BCEWithLogitsLoss()(outputs, targets.view(-1, 1))
+def train_fn(data_loader, model, optimizer, device, scheduler):
+    model.train()
+    for d in data_loader:
+        ids = d["ids"]
+        token_type_ids = d["token_type_ids"] 
+        mask = d["mask"]
+        targets = d["targets"]
+        ids = ids.to(device, dtype=torch.long)
+        token_type_ids = token_type_ids.to(device, dtype=torch.long) 
+        mask = mask.to(device, dtype=torch.long)
+        targets = targets.to(device, dtype=torch.float)
+        optimizer.zero_grad()
+        outputs = model( 
+            ids=ids, 
+            mask=mask,
+            token_type_ids=token_type_ids 
+        )
+        loss = loss_fn(outputs, targets)
+        loss.backward()
+        optimizer.step()
+        scheduler.step()
+
+def eval_fn(data_loader, model, device): 
+    model.eval()
+    fin_targets = [] 
+    fin_outputs = []
+    with torch.no_grad():
+        for d in data_loader: 
+            ids = d["ids"]
+            token_type_ids = d["token_type_ids"] 
+            mask = d["mask"]
+            targets = d["targets"]
+            ids = ids.to(device, dtype=torch.long)
+            token_type_ids = token_type_ids.to(device, dtype=torch.long) 
+            mask = mask.to(device, dtype=torch.long)
+            targets = targets.to(device, dtype=torch.float) 
+            outputs = model(
+                ids=ids, 
+                mask=mask,
+                token_type_ids=token_type_ids 
+            )
+            targets = targets.cpu().detach()
+            fin_targets.extend(targets.numpy().tolist())
+            outputs = torch.sigmoid(outputs).cpu().detach() 
+            fin_outputs.extend(outputs.numpy().tolist())
+    return fin_outputs, fin_targets
+
+

最后,我们准备好训练了。 我们来看看训练脚本吧!

+
import config
+import dataset
+import engine
+import torch
+import pandas as pd
+import torch.nn as nn
+import numpy as np
+from model import BERTBaseUncased
+from sklearn import model_selection
+from sklearn import metrics
+from transformers import AdamW
+from transformers import get_linear_schedule_with_warmup
+def train():
+    dfx = pd.read_csv(config.TRAINING_FILE).fillna("none") 
+    dfx.sentiment = dfx.sentiment.apply( 
+        lambda x: 1 if x == "positive" else 0
+    )
+    df_train, df_valid = model_selection.train_test_split( 
+        dfx,
+        test_size=0.1,
+        random_state=42,
+        stratify=dfx.sentiment.values
+    )
+    df_train = df_train.reset_index(drop=True) 
+    df_valid = df_valid.reset_index(drop=True)
+    train_dataset = dataset.BERTDataset( 
+        review=df_train.review.values, 
+        target=df_train.sentiment.values
+    )
+    train_data_loader = torch.utils.data.DataLoader( 
+        train_dataset,
+        batch_size=config.TRAIN_BATCH_SIZE, 
+        num_workers=4
+    )
+    valid_dataset = dataset.BERTDataset( 
+        review=df_valid.review.values, 
+        target=df_valid.sentiment.values
+    )
+    valid_data_loader = torch.utils.data.DataLoader( 
+        valid_dataset,
+        batch_size=config.VALID_BATCH_SIZE, 
+        num_workers=1
+    )
+    device = torch.device("cuda")
+    model = BERTBaseUncased() 
+    model.to(device)
+    param_optimizer = list(model.named_parameters())
+    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] 
+    optimizer_parameters = [
+        {
+            "params": [
+                p for n, p in param_optimizer if 
+                not any(nd in n for nd in no_decay)
+            ],
+            "weight_decay": 0.001,
+        },
+        {
+            "params": [
+                p for n, p in param_optimizer if 
+                any(nd in n for nd in no_decay)
+            ],
+            "weight_decay": 0.0,
+        }]
+    num_train_steps = int(
+        len(df_train) / config.TRAIN_BATCH_SIZE * config.EPOCHS 
+    )
+    optimizer = AdamW(optimizer_parameters, lr=3e-5) 
+    scheduler = get_linear_schedule_with_warmup( 
+        optimizer,
+        num_warmup_steps=0,
+        num_training_steps=num_train_steps 
+    )
+    model = nn.DataParallel(model)
+    best_accuracy = 0
+    for epoch in range(config.EPOCHS): 
+        engine.train_fn(
+            train_data_loader, model, optimizer, device, scheduler 
+        )
+        outputs, targets = engine.eval_fn(
+            valid_data_loader, model, device 
+        )
+
+        outputs = np.array(outputs) >= 0.5
+        accuracy = metrics.accuracy_score(targets, outputs) 
+        print(f"Accuracy Score = {accuracy}")
+        if accuracy > best_accuracy:
+            torch.save(model.state_dict(), config.MODEL_PATH) 
+            best_accuracy = accuracy
+
+if __name__ == "__main__": 
+    train()
+
+

乍一看内容可能很多,但一旦您了解了各个组件,就不难理解了。 您只需更改几行代码,即可轻松将其换成您想要使用的任何其他 Transformer 模型。

+

该模型的准确率为 93%! 哇! 这比任何其他模型都要好得多。 但是这值得吗?

+

我们使用 LSTM 能够实现 90% 的目标,而且它们更简单、更容易训练并且推理速度更快。 通过使用不同的数据处理或调整层、节点、dropout、学习率、更改优化器等参数,我们可以将该模型改进一个百分点。然后我们将从 BERT 中获得约 2% 的收益。 另一方面,BERT 的训练时间要长得多,参数很多,而且推理速度也很慢。 最后,您应该审视自己的业务并做出明智的选择。 不要仅仅因为 BERT“酷”而选择它。

+

必须注意的是,我们在这里讨论的唯一任务是分类,但将其更改为回归、多标签或多类只需要更改几行代码。 例如,多类分类设置中的同一问题将有多个输出和交叉熵损失。 其他一切都应该保持不变。 自然语言处理非常庞大,我们只讨论了其中的一小部分。 显然,这是一个很大的比例,因为大多数工业模型都是分类或回归模型。 如果我开始详细写所有内容,我最终可能会写几百页,这就是为什么我决定将所有内容包含在一本单独的书中:接近(几乎)任何 NLP 问题!

+ + + + + + + +
+
+
+ + + + Back to top + + +
+ + + + +
+
+
+
+ + + + + + + + + + + + + + \ No newline at end of file