-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsearch.xml
2426 lines (2052 loc) · 495 KB
/
search.xml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
<?xml version="1.0" encoding="utf-8"?>
<search>
<entry>
<title><![CDATA[Flink 均衡调度]]></title>
<url>https://sharkdtu.github.io/posts/flink-schedule-balance.html</url>
<content type="html"><![CDATA[<p>Flink 当前的计算任务调度是完全随机的,直接后果是各个 taskmanager 上运行的计算任务分布不均,进而导致 taskmanagers 之间的负载不均衡,用户在配置 taskmanager 资源时不得不预留较大的资源 buffer,带来不必要的浪费。为此,我们扩展了一种均衡调度策略,尽量保证每个 flink 算子的子任务均匀分布在所有的 taskmanagers 上,使得 taskmanagers 之间的负载相对均衡。<a id="more"></a></p>
<h2 id="背景"><a href="#背景" class="headerlink" title="背景"></a>背景</h2><p>flink 在下发计算任务时,只要有空闲的 slot 资源就直接分配,并不考虑计算任务在 taskmanagers 上的分布情况,然而,不同算子的计算逻辑不同,如果一个算子是计算密集型的,其多个并行任务被扎堆调度下发到同一个 taskmanager 上,那么这个 taskmanager 的 cpu 负载压力会很大。更形象地,如下图所示的 JobGraph,有三个算子,最大并行度为 6,按照 flink 默认的 slot 共享调度机制,需要 6 个 slot。</p>
<p><img src="/images/flink-schedule-jobgraph-demo.png" width="600" height="400" align="center"></p>
<p>假如用户配置 2 个 taskmanager,每个 taskmanager 3 个 slot,按照目前 flink 的调度下发机制,很可能会出现如下图所示的计算任务分配情况,可以看到,source 和 sink 这两个算子的子任务被扎堆下发到同一个 taskmanager 上了,势必会造成该 taskmanager 上的负载(包括cpu、mem、network io 等)比其他 taskmanager 更高。</p>
<p><img src="/images/flink-schedule-execution-demo.png" width="600" height="400" align="center"></p>
<h2 id="方案"><a href="#方案" class="headerlink" title="方案"></a>方案</h2><p>在阐述具体方案前,先通过一个例子简单介绍下当前 flink 计算任务分配下发的过程,如下图所示,上面的 JobGraph 在调度下发时,会创建一系列的 ExecutionSlotSharingGroup,每个 ExecutionSlotSharingGroup 包含不同算子的子任务,一个 ExecutionSlotSharingGroup 需要一个 slot,所以申请 slot 时,只需按照按 ExecutionSlotSharingGroup 数量来申请即可。</p>
<p><img src="/images/flink-schedule-executionslotsharing.png" width="600" height="400" align="center"></p>
<p>如下图所示,JobMaster 向 ResourceManager 声明请求 slot 个数,ResourceManager 判断是否有足够的 slot 资源,如果有,则将 job 信息发给 TaskExecutor 请求 slot,TaskExecutor 再向 JobMaster 提供 slot,JobMaster 即可下发计算任务;如果没有,则会尝试向集群申请资源,TaskExecutor 起来后会向 ResourceManager 上报 slot 资源信息。</p>
<p><img src="/images/flink-schedule-task-deploy.png" width="600" height="400" align="center"></p>
<p>计算任务分布不均衡本质原因是,JobMaster 申请到的 slot 不是一次性拿到的,每次 TaskExecutor 向 JobMaster 提供 slot 时,JobMaster 就将这部分 slot 分给 ExecutionSlotSharingGroup ,在分配的时候,并不考虑分布情况。</p>
<p>为了能有一个全局的分配视角,需要等所有 slot 到齐后,一把分配。问题就变成了:有 K 个大小不一的 ExecutionSlotSharingGroup,要放到 m*n = K 个 slot 里(m 为 tm 个数,n 为每个 tm 的 slot 数),尽量让每个 tm 上的 ExecutionSlotSharingGroup 分布均衡。为此,我们对每个 ExecutionSlotSharingGroup 分类编号,如果其包含的子任务所属的算子相同,会被分配同一个编号,如下图所示,总共有三类,相同计算负载的 ExecutionSlotSharingGroup 编号相同。</p>
<p><img src="/images/flink-schedule-executionslotsharing-optimize.png" width="600" height="400" align="center"></p>
<p>有了上述基础后,我们只需要实现一个算法,按 ExecutionSlotSharingGroup 类别 id,均匀分配到 taskmanager 中即可,如下图所示,可以看到,最终运行时,两个 taskmanager 上的负载是相对均衡的。</p>
<p><img src="/images/flink-schedule-task-balance.png" width="600" height="400" align="center"></p>
<h2 id="小结"><a href="#小结" class="headerlink" title="小结"></a>小结</h2><p>本文介绍了 flink 均衡调度,目的是尽可能使计算任务在各 taskmanagers 上分布均衡,保证作业稳定性以及节省资源。该特性需要等所有 slot 全部到位一把分配,仅适用于流处理模式,对批处理意义不大。</p>
<p><span style="color:red"><em>转载请注明出处,本文永久链接:<a href="https://sharkdtu.github.io/posts/flink-schedule-balance.html">https://sharkdtu.github.io/posts/flink-schedule-balance.html</a></em></span></p>
]]></content>
<categories>
<category> flink </category>
</categories>
<tags>
<tag> flink </tag>
<tag> scheduler </tag>
</tags>
</entry>
<entry>
<title><![CDATA[Flink RPC 详解]]></title>
<url>https://sharkdtu.github.io/posts/flink-rpc.html</url>
<content type="html"><![CDATA[<p>要理解 Flink 内部各组件交互的源码实现,首先必须要理解其 RPC 的工作机制。与 Hadoop、Spark 等系统类似,作为一个独立的分布式系统框架,Flink 也抽象了自己的一套 RPC 框架,本文尝试尽可能详尽地阐述其设计及实现原理。<a id="more"></a></p>
<h2 id="接口设计"><a href="#接口设计" class="headerlink" title="接口设计"></a>接口设计</h2><p>首先不用纠结其内部实现细节,先感性地认识下如何使用 Flink RPC 框架实现一个基本的 RPC 调用。</p>
<ol>
<li><p>定义接口协议</p>
<figure class="highlight java"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">public</span> <span class="class"><span class="keyword">interface</span> <span class="title">HelloGateway</span> <span class="keyword">extends</span> <span class="title">RpcGateway</span> </span>{</span><br><span class="line"> <span class="function">CompletableFuture<String> <span class="title">sayHello</span><span class="params">()</span></span>;</span><br><span class="line">}</span><br></pre></td></tr></table></figure>
</li>
<li><p>服务端组件实现接口</p>
<figure class="highlight java"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// RpcEndpoint 可以理解为服务端组件</span></span><br><span class="line"><span class="keyword">public</span> <span class="keyword">static</span> <span class="class"><span class="keyword">class</span> <span class="title">HelloEndpoint</span> <span class="keyword">extends</span> <span class="title">RpcEndpoint</span> <span class="keyword">implements</span> <span class="title">HelloGateway</span> </span>{</span><br><span class="line"> <span class="function"><span class="keyword">protected</span> <span class="title">HelloEndpoint</span><span class="params">(RpcService rpcService)</span> </span>{</span><br><span class="line"> <span class="keyword">super</span>(rpcService);</span><br><span class="line"> ...</span><br><span class="line"> }</span><br><span class="line"> <span class="meta">@Override</span></span><br><span class="line"> <span class="function"><span class="keyword">public</span> CompletableFuture<String> <span class="title">sayHello</span><span class="params">()</span> </span>{</span><br><span class="line"> <span class="keyword">return</span> CompletableFuture.completedFuture(<span class="string">"Hello World"</span>);</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure>
</li>
<li><p>实例化服务端组件</p>
<figure class="highlight java"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// RpcService 可以理解为 RPC 框架引擎(客户端和服务端都有),可以用来启动、停止、连接一个服务端组件</span></span><br><span class="line">RpcService rpcService = getRpcService ...</span><br><span class="line">HelloEndpoint helloEndpoint = <span class="keyword">new</span> HelloEndpoint(rpcService); <span class="comment">// 内部会启动这个组件服务</span></span><br></pre></td></tr></table></figure>
</li>
<li><p>客户端发起远程调用</p>
<figure class="highlight java"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">RpcService rpcService = getRpcService ...</span><br><span class="line"><span class="comment">// rpcAddress 唯一标识要连接的服务端组件,例如 "rpc://host:port/path/to/helloendpoint"</span></span><br><span class="line">HelloGateway helloGateway = rpcService.connect(rpcAddress, HelloGateway<span class="class">.<span class="keyword">class</span>)</span>;</span><br><span class="line"><span class="comment">// 如果客户端跟服务端组件在同一个进程里,可以省去connect</span></span><br><span class="line"><span class="comment">// HelloGateway helloGateway = helloEndpoint.getSelfGateway(HelloGateway.class);</span></span><br><span class="line">helloGateway.sayHello(); <span class="comment">// helloGateway 作为客户端代理调用远程方法</span></span><br></pre></td></tr></table></figure>
</li>
</ol>
<p>从以上四步可以看到,Flink RPC 的封装比较高层,客户端的远程调用看起来完全就是调用本地方法,毫无收发消息的痕迹,接口类的命名也比较形象,如下图所示,当要发起远程调用时,临时拿到对应的接口网关,直接调用对应的接口。</p>
<p><img src="/images/flink-rpc-abstract.png" width="600" height="400" align="center"></p>
<p>有了一个基本的高层次认识后,再仔细分析上述代码,提出几个问题:</p>
<ol>
<li>服务端组件(RpcEndpoint)实例化过程中做了什么?</li>
<li>我们只是定了接口协议,接口网关(RpcGateway)是如何实例化出来的?</li>
<li>通过接口网关(RpcGateway)调用方法时,其内部是怎么收发消息的?</li>
</ol>
<p>在具体回答以上三个问题前,先简单介绍下 Java 的动态代理技术。</p>
<h3 id="Java-动态代理简介"><a href="#Java-动态代理简介" class="headerlink" title="Java 动态代理简介"></a>Java 动态代理简介</h3><p>有一种设计模式叫代理模式,通过代理对象访问目标对象,可以在不修改原目标对象的前提下,提供额外的功能操作,以达到扩展目标对象的功能。其UML大致如下图所示。</p>
<p><img src="/images/flink-rpc-proxy-pattern.png" width="600" height="400" align="center"></p>
<p>代理模式在 Java 中有静态代理和动态代理之分,我们先看下静态代理:<br><figure class="highlight java"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">public</span> <span class="class"><span class="keyword">interface</span> <span class="title">HelloInterface</span> </span>{</span><br><span class="line"> <span class="function">String <span class="title">sayHello</span><span class="params">()</span></span>;</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="keyword">public</span> <span class="class"><span class="keyword">class</span> <span class="title">ChinaHello</span> <span class="keyword">implements</span> <span class="title">HelloInterface</span> </span>{</span><br><span class="line"> <span class="meta">@Override</span></span><br><span class="line"> <span class="function"><span class="keyword">public</span> String <span class="title">sayHello</span><span class="params">()</span> </span>{</span><br><span class="line"> <span class="keyword">return</span> <span class="string">"你好"</span>;</span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="keyword">public</span> <span class="class"><span class="keyword">class</span> <span class="title">HelloProxy</span> <span class="keyword">implements</span> <span class="title">HelloInterface</span> </span>{</span><br><span class="line"> <span class="keyword">private</span> HelloInterface target;</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">public</span> <span class="title">HelloProxy</span><span class="params">(HelloInterface target)</span> </span>{</span><br><span class="line"> <span class="keyword">this</span>.target = target;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="meta">@Override</span></span><br><span class="line"> <span class="function"><span class="keyword">public</span> String <span class="title">sayHello</span><span class="params">()</span> </span>{</span><br><span class="line"> <span class="comment">// do something before</span></span><br><span class="line"> <span class="keyword">return</span> target.sayHello();</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure></p>
<p>以上静态代理模式的代码相信大家或多或少都有见过,通过<code>HelloProxy</code>去代理实际目标对象,扩展相关功能。但是静态代理需要在编译时实现,冗余代码较多。另外,Java 也提供了动态代理模式的实现,不需要事先实现接口,运行时通过反射动态实例化特定接口的实例,上述静态代理模式代码可以用如下动态代理模式来实现。<br><figure class="highlight java"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">public</span> <span class="class"><span class="keyword">interface</span> <span class="title">HelloInterface</span> </span>{</span><br><span class="line"> <span class="function">String <span class="title">sayHello</span><span class="params">()</span></span>;</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line">HelloInterface helloProxy = (HelloInterface) Proxy.newProxyInstance(</span><br><span class="line"> getClass().getClassLoader(),</span><br><span class="line"> <span class="keyword">new</span> Class<?>[] {HelloInterface<span class="class">.<span class="keyword">class</span>},</span></span><br><span class="line"><span class="class"> <span class="title">new</span> <span class="title">InvocationHandler</span>() </span>{</span><br><span class="line"> <span class="meta">@Override</span></span><br><span class="line"> <span class="function"><span class="keyword">public</span> Object <span class="title">invoke</span><span class="params">(Object proxy, Method method, Object[] args)</span> </span>{</span><br><span class="line"> <span class="keyword">if</span> (method.getName().equals(<span class="string">"sayHello"</span>)) {</span><br><span class="line"> <span class="keyword">return</span> <span class="string">"你好"</span>;</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">return</span> <span class="keyword">null</span>;</span><br><span class="line"> }</span><br><span class="line"> });</span><br><span class="line">helloProxy.sayHello();</span><br></pre></td></tr></table></figure></p>
<p>可以看到以上代码并没有显示地实现接口<code>HelloInterface</code>,但是通过 Java 提供的<code>Proxy.newProxyInstance</code>方法可以动态创建该接口的实例,当调用该实例的方法时,会被转发到<code>InvocationHandler#invoke</code>中。认识了动态代理后,下面回过头来逐一回答前面提到的三个问题。</p>
<h3 id="接口实现规范"><a href="#接口实现规范" class="headerlink" title="接口实现规范"></a>接口实现规范</h3><p>为了阅读方便,下面先把前面的三个问题再拎出来:</p>
<ol>
<li><p>服务端组件(RpcEndpoint)实例化过程中做了什么?</p>
<figure class="highlight java"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">HelloEndpoint helloEndpoint = <span class="keyword">new</span> HelloEndpoint(rpcService);</span><br></pre></td></tr></table></figure>
</li>
<li><p>我们只是定了接口协议,接口网关(RpcGateway)是如何实例化出来的?</p>
<figure class="highlight java"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">HelloGateway helloGateway = rpcService.connect(rpcAddress, HelloGateway<span class="class">.<span class="keyword">class</span>)</span>;</span><br><span class="line"><span class="comment">// or</span></span><br><span class="line">HelloGateway helloGateway = helloEndpoint.getSelfGateway(HelloGateway<span class="class">.<span class="keyword">class</span>)</span>;</span><br></pre></td></tr></table></figure>
</li>
<li><p>通过接口网关(RpcGateway)调用方法时,其内部是怎么收发消息的?</p>
<figure class="highlight java"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">helloGateway.sayHello();</span><br></pre></td></tr></table></figure>
</li>
</ol>
<h4 id="服务端组件(RpcEndpoint)初始化"><a href="#服务端组件(RpcEndpoint)初始化" class="headerlink" title="服务端组件(RpcEndpoint)初始化"></a>服务端组件(RpcEndpoint)初始化</h4><p>为了更简单地处理多线程并发问题,对同一个<code>RpcEndpoint</code>的所有调用被设计成在同一个主线里串行执行,所以每个<code>RpcEndpoint</code>在实现的时候都不用担心数据共享一致性问题(不用考虑加锁等)。从前面的例子可以知道服务端组件实现了接口协议,如果客户端跟服务端在同一个进程中,客户端直接通过<code>RpcEndpoint#getSelfGateway</code>拿到<code>RpcEndpoint</code>实例调用对应的方法,那么就无法保证对同一个<code>RpcEndpoint</code>的所有调用在同一个主线程中串行执行。</p>
<p>为此,服务端在实例化具体<code>RpcEndpoint</code>时,其内部会启动一个<code>RpcServer</code>(不对外暴露),<code>RpcServer</code>只是一个接口,要实例化一个特定的<code>RpcServer</code>实例,就需要通过前面介绍的动态代理技术,在运行时动态生成,UML关系如下图所示。</p>
<p><img src="/images/flink-rpc-endpoint.png" width="600" height="400" align="center"></p>
<p>通过动态代理生成的<code>RpcServer</code>实例会绑定其对应的<code>RpcEndpoint</code>所实现的接口协议,即上述例子中<code>HelloEndpoint</code>中的<code>RpcServer</code>会有<code>sayHello</code>方法,所以当客户端跟服务端在同一个进程中,客户端通过<code>RpcEndpoint#getSelfGateway</code>拿到其中的<code>RpcServer</code>实例作为接口网关,进而调用其绑定的接口协议方法,根据Java动态代理原理,对<code>RpcServer</code>中的方法调用会被转发给<code>InvocationHandler</code>,在<code>InvocationHandler</code>中控制所有调用在同一个主线里串行执行。</p>
<h4 id="客户端获取接口网关(RpcGateway)"><a href="#客户端获取接口网关(RpcGateway)" class="headerlink" title="客户端获取接口网关(RpcGateway)"></a>客户端获取接口网关(RpcGateway)</h4><p>客户端发起RPC调用前,需要先拿到对应的接口网关<code>RpcGateway</code>,前面介绍到,当客户端与服务端在同一个进程中,通过<code>RpcEndpoint#getSelfGateway</code>获取,实际是拿到的是<code>RpcEndpoint</code>中的<code>RpcServer</code>实例,因为它是通过动态代理绑定了特定的RpcGateway创建的,所以也可以作为<code>RpcGateway</code>。当客户端与服务端不在一个进程中,通过<code>RpcService#connect</code>获取,服务端的每个<code>RpcEndpoint</code>都有一个唯一的 RPC 地址,客户端通过这个地址去连接路由到指定的<code>RpcEndpoint</code>,拿到消息 handler,通过消息 handler 双方握手成功后,客户端再通过动态代理创建特定的<code>RpcGateway</code>实例,其总体流程如下图所示。</p>
<p><img src="/images/flink-rpc-new-gateway.png" width="600" height="400" align="center"></p>
<h4 id="客户端发起-RPC-调用"><a href="#客户端发起-RPC-调用" class="headerlink" title="客户端发起 RPC 调用"></a>客户端发起 RPC 调用</h4><p>无论客户端是否与服务端在同一个进程中,客户端与<code>RpcGateway</code>的UML关系抽象如下图所示。</p>
<p><img src="/images/flink-rpc-gateway-call.png" width="600" height="400" align="center"></p>
<p>当客户端通过<code>RpcGateway</code>调用方法时,根据动态代理原理,该调用会被转发到<code>InvocationHandler</code>中,<code>InvocationHandler</code>将方法名、参数类型、参数对象列表打包成<code>RpcInvocation</code>消息,通过其握有的消息 handler,发送消息,并接受服务端响应,完成一次RPC调用。</p>
<h2 id="基于-Akka-的实现"><a href="#基于-Akka-的实现" class="headerlink" title="基于 Akka 的实现"></a>基于 Akka 的实现</h2><p>以上只是在一个抽象的层面介绍了 Flink RPC 的设计,具体实现还需要借助一套消息系统来完成,目前 Flink RPC 的默认是基于 Akka 框架实现(也是唯一的实现),Akka 的核心是 Actor 模型,如下图所示,Actor 与 Actor之前只能用消息进行通信,每个Actor都有对应一个信箱,消息是有顺序地被投递到信箱,Actor 串行处理信箱中的消息。建议自行先了解 Akka 及 Actor 的相关知识,这里不展开详细介绍。</p>
<p><img src="/images/flink-rpc-actor.png" width="600" height="400" align="center"></p>
<p>基于 Actor 模型,每个 RpcEndpoint 关联一个 Actor,正好契合了对每个 RpcEndpoint 的调用要求在同一个线程中完成的设计,同时,Akka 的每个 Actor 都有一个唯一的地址,正好作为 RpcEndpoint 的 RPC Address。</p>
<p>具体实现上,<code>AkkaRpcService</code>实现启动、停止、连接一个服务端组件,<code>AkkaRpcService</code>内部持有一个<code>ActorSystem</code>实例,当启动一个服务端组件时,会创建一个 <code>AkkaRpcActor</code>(其中定义了消息处理逻辑,当收到<code>RpcInvocation</code>消息时,会按照方法名调用<code>RpcEndpoint</code>中的具体实现),作为前面提到的消息 handler,然后通过动态代理实例化一个 <code>RpcServer</code>,绑定一个<code>AkkaInvocationHandler</code>,其持有前面创建 <code>ActorRef</code>。</p>
<p>当客户端与服务端在同一个进程中,那么直接获取这个<code>RpcServer</code>实例作为接口网关<code>RpcGateway</code>,这样接口网关<code>RpcGateway</code>上的方法调用会被转到<code>AkkaInvocationHandler</code>中,进而将方法名、参数类型、参数对象列表打包成<code>RpcInvocation</code>消息通过 <code>ActorRef</code> 发送,其实现如下图UML所示。</p>
<p><img src="/images/flink-rpc-akka.png" width="600" height="400" align="center"></p>
<p>当客户端与服务端不在同一个进程中,其通过<code>AkkaRpcService#connect</code>方法,连接服务端对应的<code>AkkaRpcActor</code>以得到其对应的<code>ActorRef</code>,类似地,通过动态代理实例化特定的 <code>RpcGateway</code>,绑定一个<code>AkkaInvocationHandler</code>,其持有前面连接获取到的 <code>ActorRef</code>,之后这个接口网关<code>RpcGateway</code>的方法调用会被转到<code>AkkaInvocationHandler</code>中,进而将方法名、参数类型、参数对象列表打包成<code>RpcInvocation</code>消息通过 <code>ActorRef</code> 发送,其实现如下图UML所示。。</p>
<p><img src="/images/flink-rpc-akka2.png" width="600" height="400" align="center"></p>
<h2 id="小结"><a href="#小结" class="headerlink" title="小结"></a>小结</h2><p>某种程度上讲,Flink RPC 设计来源于 Actor 模型,只是在这之上做了更高层的抽象,应用层不感知底层的消息收发,做到如同本地方法调用一般。一开始看其源码实现可能会觉得过度设计,来回绕圈,但是当看懂其设计本意后,就会觉得别有一番风味,里面包含了大量的优秀设计模式,对于我们实际写代码有很大的参考价值。</p>
<p><span style="color:red"><em>转载请注明出处,本文永久链接:<a href="https://sharkdtu.github.io/posts/flink-rpc.html">https://sharkdtu.github.io/posts/flink-rpc.html</a></em></span></p>
]]></content>
<categories>
<category> flink </category>
</categories>
<tags>
<tag> flink </tag>
<tag> rpc </tag>
</tags>
</entry>
<entry>
<title><![CDATA[TensorFlow 迁移学习实践小记]]></title>
<url>https://sharkdtu.github.io/posts/tf-migrate-learning.html</url>
<content type="html"><![CDATA[<p>在我们的很多推荐业务场景中,通常一个模型可能是一直不断增量训练的,如果哪天业务需要调整模型结构,去训练一个新模型,但是又不想完全从0开始,希望复用原来模型里面的部分参数,这样冷启动的代价就小很多了。<a id="more"></a></p>
<p>实际上 TensorFlow 提供了足够的灵活性,我们可以控制从其他模型 restore 部分参数到新的模型里。因为目前生产环境普遍还是在用 tf-1.x,下面分别介绍Low-Level API 和 Estimator API 两种实践。</p>
<h2 id="Low-Level-API-实践"><a href="#Low-Level-API-实践" class="headerlink" title="Low-Level API 实践"></a>Low-Level API 实践</h2><p>在决定从已有的模型预热参数前,可以先将模型ckpt拉到本地,开一个 ipython 或 jupyter,列出模型中的所有参数。<br><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line">In [<span class="number">6</span>]: tf.train.list_variables(checkpoint_dir)</span><br><span class="line">Out[<span class="number">6</span>]:</span><br><span class="line">[(<span class="string">'dense/bias'</span>, [<span class="number">1</span>]),</span><br><span class="line"> (<span class="string">'dense/bias/Adagrad'</span>, [<span class="number">1</span>]),</span><br><span class="line"> (<span class="string">'dense/kernel'</span>, [<span class="number">17</span>, <span class="number">1</span>]),</span><br><span class="line"> (<span class="string">'dense/kernel/Adagrad'</span>, [<span class="number">17</span>, <span class="number">1</span>]),</span><br><span class="line"> (<span class="string">'fm/b'</span>, [<span class="number">1</span>]),</span><br><span class="line"> (<span class="string">'fm/b/Adagrad'</span>, [<span class="number">1</span>]),</span><br><span class="line"> (<span class="string">'fm/v'</span>, [<span class="number">4809162</span>, <span class="number">16</span>]),</span><br><span class="line"> (<span class="string">'fm/v/Adagrad'</span>, [<span class="number">4809162</span>, <span class="number">16</span>]),</span><br><span class="line"> (<span class="string">'fm/w'</span>, [<span class="number">4809162</span>, <span class="number">1</span>]),</span><br><span class="line"> (<span class="string">'fm/w/Adagrad'</span>, [<span class="number">4809162</span>, <span class="number">1</span>]),</span><br><span class="line"> (<span class="string">'global_step'</span>, [])]</span><br></pre></td></tr></table></figure></p>
<p>假如,我们想要从ckpt中预热 <code>fm/v</code> 和 <code>fm/w</code> 两个参数,很简单,通过自定义一个 <code>tf.train.Saver</code> 来控制加载哪些参数:</p>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br></pre></td><td class="code"><pre><span class="line">...</span><br><span class="line">w2 = tf.get_variable(</span><br><span class="line"> <span class="string">"w2"</span>, shape=[<span class="number">4809162</span>, <span class="number">1</span>],</span><br><span class="line"> dtype=tf.float32,</span><br><span class="line"> initializer=tf.initializers.zeros())</span><br><span class="line">v2 = tf.get_variable(</span><br><span class="line"> <span class="string">"v2"</span>,</span><br><span class="line"> shape=[<span class="number">4809162</span>, <span class="number">16</span>],</span><br><span class="line"> dtype=tf.float32,</span><br><span class="line"> initializer=tf.initializers.truncated_normal(mean=<span class="number">0.0</span>, stddev=<span class="number">1</span> / math.sqrt(<span class="number">16</span>)))</span><br><span class="line"></span><br><span class="line">ckpt_state = tf.train.get_checkpoint_state(checkpoint_dir)</span><br><span class="line">recover_vars = {<span class="string">'fm/w'</span>: w2, <span class="string">'fm/v'</span>: v2}</span><br><span class="line">saver = tf.train.Saver(recover_vars)</span><br><span class="line">...</span><br><span class="line"></span><br><span class="line"><span class="keyword">with</span> tf.Session() <span class="keyword">as</span> sess:</span><br><span class="line"> ...</span><br><span class="line"> sess.run(init_op)</span><br><span class="line"> saver.restore(sess, ckpt_state.model_checkpoint_path)</span><br><span class="line"> ...</span><br></pre></td></tr></table></figure>
<p>以上代码中 <code>recover_vars</code> 定义了要从 ckpt 中恢复的参数,是一个字典形式,key 为 ckpt 中的变量名,从上面我们 list 出来的变量里找即可,value 为要覆盖的变量,即从 ckpt 中找到名字为 key 的变量参数,去覆盖 value 指定的变量。</p>
<h2 id="Estimator-API-实践"><a href="#Estimator-API-实践" class="headerlink" title="Estimator API 实践"></a>Estimator API 实践</h2><p>如果你是用高阶 Estimator API,其实完全可以借助 Estimator 自带的 <a href="https://www.tensorflow.org/api_docs/python/tf/estimator/WarmStartSettings" target="_blank" rel="noopener">warm_start</a> 功能来实现。</p>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">tf.estimator.WarmStartSettings(</span><br><span class="line"> ckpt_to_initialize_from, vars_to_warm_start=<span class="string">'.*'</span>, var_name_to_vocab_info=<span class="literal">None</span>,</span><br><span class="line"> var_name_to_prev_var_name=<span class="literal">None</span></span><br><span class="line">)</span><br></pre></td></tr></table></figure>
<ul>
<li>ckpt_to_initialize_from:预热模型的ckpt路径</li>
<li>vars_to_warm_start:要加载哪些变量出来预热,可以通过上述 <code>tf.train.list_variables</code> 方法先列出变量名再决定要哪些变量</li>
<li>var_name_to_vocab_info:动态词表信息</li>
<li>var_name_to_prev_var_name:新模型中的变量名 -> 旧模型中的变量名,意思就是加载出来的变量会预热到新模型的变量</li>
</ul>
<p>如果旧模型中有变量A,新模型有变量A、B,需要将旧模型的变量A恢复到新模型的变量B,如果使用tf warm_start,它既会将旧模型的变量A恢复到新模型的变量B,也会恢复到新模型的变量A。为解决名字冲突问题,我们可以自定义一个 Hook 将上述 low-level api 的使用方式封装一下,实现定制化恢复即可。</p>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">VariableRecoverHook</span><span class="params">(tf.train.SessionRunHook)</span>:</span></span><br><span class="line"> <span class="string">"""Recover specified variables from checkpoint."""</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, ckp_dir, recover_vars)</span>:</span></span><br><span class="line"> <span class="string">"""Initializes a `VariableRecoverHook`.</span></span><br><span class="line"><span class="string"></span></span><br><span class="line"><span class="string"> Args:</span></span><br><span class="line"><span class="string"> ckp_dir: Checkpoint directory where variables recover from</span></span><br><span class="line"><span class="string"> recover_vars: A `dict` of names to variables</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> self._ckpt_state = tf.train.get_checkpoint_state(ckp_dir)</span><br><span class="line"> <span class="keyword">if</span> <span class="keyword">not</span> isinstance(recover_vars, dict):</span><br><span class="line"> <span class="keyword">raise</span> ValueError(<span class="string">"recover_vars must be a dict of names to variables"</span>)</span><br><span class="line"> self._vars = recover_vars</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">begin</span><span class="params">(self)</span>:</span></span><br><span class="line"> <span class="string">"""Create a tf saver for recover variables."""</span></span><br><span class="line"> self._saver = tf.train.Saver(self._vars)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">after_create_session</span><span class="params">(self, session, coord)</span>:</span></span><br><span class="line"> <span class="string">"""Recover variables from checkpoint."""</span></span><br><span class="line"> self._saver.restore(session, self._ckpt_state.model_checkpoint_path)</span><br></pre></td></tr></table></figure>
<p>以上代码实现一个 Hook,其中初始化参数 <code>recover_vars</code> 表示要从 ckpt 中恢复的参数。一般在恢复参数前,也需要list一下旧模型中的参数,找到对应的变量名。有了这个 Hook 后,那么我们就可以在 <code>model_fn</code> 中插入这个 Hook 的实例即可。</p>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">model_fn</span><span class="params">(features, labels, mode, params)</span>:</span></span><br><span class="line"> ...</span><br><span class="line"> spec = head.create_estimator_spec(...)</span><br><span class="line"> recover_hook = VariableRecoverHook(ckp_dir=old_ckpt_dir, recover_vars=recover_vars)</span><br><span class="line"> <span class="keyword">return</span> spec._replace(training_hooks=(spec.training_hooks + (recover_hook,)))</span><br></pre></td></tr></table></figure>
<h2 id="小结"><a href="#小结" class="headerlink" title="小结"></a>小结</h2><p>实践中我们可以基于 TensorFlow 灵活保存以及恢复参数,当有迁移学习需求时,可以通过定制化 <code>tf.train.Saver</code> 的方式来控制预热指定的参数。目前 TensorFlow 也进入 2.x 时代了,官方主推Keras API,通过 Keras API 可以更加灵活的控制<a href="https://www.tensorflow.org/guide/checkpoint" target="_blank" rel="noopener">保存以及恢复参数</a>。但是如果你是用 Estimator,则可以直接复用 warm_start 或上述 Hook 实现。</p>
<p><span style="color:red"><em>转载请注明出处,本文永久链接:<a href="https://sharkdtu.github.io/posts/tf-migrate-learning.html">https://sharkdtu.github.io/posts/tf-migrate-learning.html</a></em></span></p>
]]></content>
<categories>
<category> 机器学习 </category>
</categories>
<tags>
<tag> TensorFlow </tag>
</tags>
</entry>
<entry>
<title><![CDATA[分布式TensorFlow编程模型演进]]></title>
<url>https://sharkdtu.github.io/posts/dist-tf-evolution.html</url>
<content type="html"><![CDATA[<p>TensorFlow从15年10月开源至今,可谓是发展迅猛,从v0.5到如今的v2.0.0-alpha,经历了无数个功能特性的升级,性能、可用性、易用性等都在稳步提升。相对来说,对于我们工业界,大家可能更关注分布式TensorFlow的发展,本文尝试梳理下分布式TensorFlow从问世到现在经历过的变迁。<a id="more"></a></p>
<h2 id="分布式TensorFlow运行时基本组件"><a href="#分布式TensorFlow运行时基本组件" class="headerlink" title="分布式TensorFlow运行时基本组件"></a>分布式TensorFlow运行时基本组件</h2><p>用户基于TensorFlow-API编写好代码提交运行,整体架构如下图所示。</p>
<p><img src="/images/tf-runtime.png" width="600" height="400" alt="tf-runtime" align="center"></p>
<ul>
<li><p>Client<br>可以把它看成是TensorFlow前端,它支持多语言的编程环境(Python/C++/Go/Java等),方便用户构造各种复杂的计算图。Client通过<code>Session</code>连接TensorFlow后端,并启动计算图的执行。</p>
</li>
<li><p>Master<br>Master根据要计算的操作(Op),从计算图中反向遍历,找到其所依赖的最小子图,然后将该子图再次分裂为多个子图片段,以便在不同的进程和设备上运行这些子图片段,最后将这些子图片段派发给Worker执行。</p>
</li>
<li><p>Worker<br>Worker按照计算子图中节点之间的依赖关系,根据当前的可用的硬件环境(GPU/CPU/TPU),调用Op的Kernel实现完成运算。</p>
</li>
</ul>
<p>在分布式TensorFlow中,参与分布式系统的所有节点或者设备统称为一个Cluster,一个Cluster中包含很多Server,每个Server去执行一项Task,Server和Task是一一对应的。所以,Cluster可以看成是Server的集合,也可以看成是Task的集合,TensorFlow为各个Task又增加了一个抽象层,将一系列相似的Task集合称为一个Job。形式化地,一个TensorFlow Cluster可以通过以下json来描述:<br><figure class="highlight json"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">{</span><br><span class="line"> <span class="attr">"${job_name1}"</span>: [</span><br><span class="line"> <span class="string">"${host1}:${port1}"</span>,</span><br><span class="line"> <span class="string">"${host2}:${port2}"</span>,</span><br><span class="line"> <span class="string">"${host3}:${port3}"</span></span><br><span class="line"> ],</span><br><span class="line"> <span class="attr">"${job_name2}"</span>: [</span><br><span class="line"> <span class="string">"${host4}:${port4}"</span>,</span><br><span class="line"> <span class="string">"${host5}:${port5}"</span></span><br><span class="line"> ]</span><br><span class="line">}</span><br></pre></td></tr></table></figure></p>
<p>job用job_name(字符串)标识,而task用index(整数索引)标识,那么cluster中的每个task可以用job的name加上task的index来唯一标识,例如‘/job:worker/task:1’。一组Task集合(即Job)有若干个Server(host和port标识),每个Server上会绑定两个Service,就是前面提到的Master Service和Worker Service,Client通过Session连接集群中的任意一个Server的Master Service提交计算图,Master Service负责划分子图并派发Task给Worker Service,Worker Service则负责运算派发过来的Task完成子图的运算。下面详细阐述分布式TensorFlow不同架构的编程模型演进。</p>
<h2 id="基于PS的分布式TensorFlow编程模型"><a href="#基于PS的分布式TensorFlow编程模型" class="headerlink" title="基于PS的分布式TensorFlow编程模型"></a>基于PS的分布式TensorFlow编程模型</h2><p>分布式TensorFlow设计之初是沿用DistBelief(Google第一代深度学习系统)中采用的经典ps-worker架构,如下图所示。</p>
<p><img src="/images/tf-ps-worker.png" width="600" height="400" alt="tf-ps-worker" align="center"></p>
<p>对于PS架构,Parameter Server的Task集合为ps(即job类型为ps),而执行梯度计算的Task集合为worker(即job类型为worker),所以一个TensorFlow Cluster可以通过如下json描述:<br><figure class="highlight json"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">{</span><br><span class="line"> <span class="attr">"worker"</span>: [</span><br><span class="line"> <span class="string">"${host1}:${port1}"</span>,</span><br><span class="line"> <span class="string">"${host2}:${port2}"</span>,</span><br><span class="line"> <span class="string">"${host3}:${port3}"</span></span><br><span class="line"> ],</span><br><span class="line"> <span class="attr">"ps"</span>: [</span><br><span class="line"> <span class="string">"${host4}:${port4}"</span>,</span><br><span class="line"> <span class="string">"${host5}:${port5}"</span></span><br><span class="line"> ]</span><br><span class="line">}</span><br></pre></td></tr></table></figure></p>
<h3 id="Low-level-分布式编程模型"><a href="#Low-level-分布式编程模型" class="headerlink" title="Low-level 分布式编程模型"></a>Low-level 分布式编程模型</h3><p>最原始的分布式TensorFlow编程是基于Low-level API来实现,下面我们通过举例来理解最原始的分布式TensorFlow编程步骤。我们在一台机器上启动三个Server(2个worker,1个ps)来模拟分布式多机环境,开启三个Python解释器(分别对应2个worker和1个ps),执行如下python语句,定义一个Cluster:<br><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> tensorflow <span class="keyword">as</span> tf</span><br><span class="line"></span><br><span class="line">cluster = tf.train.ClusterSpec({</span><br><span class="line"> <span class="string">"worker"</span>: [</span><br><span class="line"> <span class="string">"localhost:2222"</span>,</span><br><span class="line"> <span class="string">"localhost:2223"</span></span><br><span class="line"> ],</span><br><span class="line"> <span class="string">"ps"</span>: [</span><br><span class="line"> <span class="string">"localhost:2224"</span></span><br><span class="line"> ]})</span><br></pre></td></tr></table></figure></p>
<p>在第一个worker解释器内执行如下语句启动Server:<br><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">server = tf.train.Server(cluster, job_name=<span class="string">"worker"</span>, task_index=<span class="number">0</span>)</span><br></pre></td></tr></table></figure></p>
<p>在第二个worker解释器内执行如下语句启动Server:<br><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">server = tf.train.Server(cluster, job_name=<span class="string">"worker"</span>, task_index=<span class="number">1</span>)</span><br></pre></td></tr></table></figure></p>
<p>在ps解释器内执行如下语句启动Server:<br><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">server = tf.train.Server(cluster, job_name=<span class="string">"ps"</span>, task_index=<span class="number">0</span>)</span><br></pre></td></tr></table></figure></p>
<p>至此,我们已经启动了一个TensorFlow Cluster,它由两个worker节点和一个ps节点组成,每个节点上都有Master Service和Worker Service,其中worker节点上的Worker Service将负责梯度运算,ps节点上的Worker Service将负责参数更新,三个Master Service将仅有一个会在需要时被用到,负责子图划分与Task派发。</p>
<p>有了Cluster,我们就可以编写Client,构建计算图,并提交到这个Cluster上执行。使用分布式TensorFlow时,最常采用的分布式训练策略是数据并行,数据并行就是在很多设备上放置相同的模型,在TensorFlow中称之为Replicated training,主要表现为两种模式:图内复制(in-graph replication)和图间复制(between-graph replication)。不同的运行模式,Client的表现形式不一样。</p>
<h4 id="图内复制"><a href="#图内复制" class="headerlink" title="图内复制"></a>图内复制</h4><p>对于图内复制,只构建一个Client,这个Client构建一个Graph,Graph中包含一套模型参数,放置在ps上,同时Graph中包含模型计算部分的多个副本,每个副本都放置在一个worker上,这样多个worker可以同时训练复制的模型。</p>
<p>再开一个Python解释器,作为Client,执行如下语句构建计算图,并:<br><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> tensorflow <span class="keyword">as</span> tf</span><br><span class="line"></span><br><span class="line"><span class="keyword">with</span> tf.device(<span class="string">"/job:ps/task:0"</span>):</span><br><span class="line"> w = tf.get_variable([[<span class="number">1.</span>, <span class="number">2.</span>, <span class="number">3.</span>], [<span class="number">1.</span>, <span class="number">3.</span>, <span class="number">5.</span>]])</span><br><span class="line"></span><br><span class="line">input_data = ...</span><br><span class="line">inputs = tf.split(input_data, num_workers)</span><br><span class="line">outputs = []</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> range(num_workers):</span><br><span class="line"> <span class="keyword">with</span> tf.device(<span class="string">"/job:ps/task:%s"</span> % str(i)):</span><br><span class="line"> outputs.append(tf.matmul(inputs[i], w))</span><br><span class="line"></span><br><span class="line">output = tf.concat(outputs, axis=<span class="number">0</span>)</span><br><span class="line"><span class="keyword">with</span> tf.Session() <span class="keyword">as</span> sess:</span><br><span class="line"> sess.run(tf.global_variables_initializer())</span><br><span class="line"> <span class="keyword">print</span> sess.run(output)</span><br></pre></td></tr></table></figure></p>
<p>从以上代码可以看到,当采用图内复制时,需要在Client上创建一个包含所有worker副本的流程图,随着worker数量的增长,计算图将会变得非常大,不利于计算图的维护。此外,数据分发在Client单点,要把训练数据分发到不同的机器上,会严重影响并发训练速度。所以在大规模分布式多机训练情况下,一般不会采用图内复制的模式,该模式常用于单机多卡情况下,简单直接。</p>
<h4 id="图间复制"><a href="#图间复制" class="headerlink" title="图间复制"></a>图间复制</h4><p>为可以解决图内复制在扩展上的局限性,我们可以采用图间复制模式。对于图间复制,每个worker节点上都创建一个Client,各个Client构建相同的Graph,但是参数还是放置在ps上,每个worker节点单独运算,一个worker节点挂掉了,系统还可以继续跑。</p>
<p>所以我们在第一个worker和第二个worker的Python解释器里继续执行如下语句实现Client完成整个分布式TensorFlow的运行:<br><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">with</span> tf.device(<span class="string">"/job:ps/task:0"</span>):</span><br><span class="line"> w = tf.get_variable(name=<span class="string">'w'</span>, shape=[<span class="number">784</span>, <span class="number">10</span>])</span><br><span class="line"> b = tf.get_variable(name=<span class="string">'b'</span>, shape=[<span class="number">10</span>])</span><br><span class="line"></span><br><span class="line">x = tf.placeholder(tf.float32, shape=[<span class="literal">None</span>, <span class="number">784</span>])</span><br><span class="line">y = tf.placeholder(tf.int32, shape=[<span class="literal">None</span>])</span><br><span class="line">logits = tf.matmul(x, w) + b</span><br><span class="line">loss = ...</span><br><span class="line">train_op = ...</span><br><span class="line"></span><br><span class="line"><span class="keyword">with</span> tf.Session() <span class="keyword">as</span> sess:</span><br><span class="line"> <span class="keyword">for</span> _ <span class="keyword">in</span> range(<span class="number">10000</span>):</span><br><span class="line"> sess.run(train_op, feed_dict=...)</span><br></pre></td></tr></table></figure></p>
<p>在上述描述的过程中,我们是全程手动做分布式驱动的,先建立Cluster,然后构建计算图提交执行,Server上的Master Service和Worker Service根本没有用到。实际应用时当然不会这么愚蠢,一般是将以上代码片段放到一个文件中,通过参数控制执行不同的代码片段,例如:<br><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> tensorflow <span class="keyword">as</span> tf</span><br><span class="line"></span><br><span class="line">ps_hosts = FLAGS.ps_hosts.split(<span class="string">","</span>)</span><br><span class="line">worker_hosts = FLAGS.worker_hosts.split(<span class="string">","</span>)</span><br><span class="line">cluster = tf.train.ClusterSpec({<span class="string">"ps"</span>: ps_hosts, <span class="string">"worker"</span>: worker_hosts})</span><br><span class="line">server = tf.train.Server(</span><br><span class="line"> cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)</span><br><span class="line"></span><br><span class="line"><span class="keyword">if</span> FLAGS.job_name == <span class="string">'ps'</span>:</span><br><span class="line"> server.join()</span><br><span class="line"><span class="keyword">elif</span> FLAGS.job_name == <span class="string">"worker"</span>:</span><br><span class="line"> <span class="keyword">with</span> tf.device(tf.train.replica_device_setter(</span><br><span class="line"> worker_device=<span class="string">"/job:worker/task:%d"</span> % FLAGS.task_index,</span><br><span class="line"> cluster=cluster)):</span><br><span class="line"> <span class="comment"># Build model...</span></span><br><span class="line"> loss = ...</span><br><span class="line"> train_op = ...</span><br><span class="line"></span><br><span class="line"> <span class="keyword">with</span> tf.train.MonitoredTrainingSession(</span><br><span class="line"> master=<span class="string">"/job:worker/task:0"</span>,</span><br><span class="line"> is_chief=(FLAGS.task_index == <span class="number">0</span>),</span><br><span class="line"> checkpoint_dir=<span class="string">"/tmp/train_logs"</span>) <span class="keyword">as</span> mon_sess:</span><br><span class="line"> <span class="keyword">while</span> <span class="keyword">not</span> mon_sess.should_stop():</span><br><span class="line"> mon_sess.run(train_op)</span><br></pre></td></tr></table></figure></p>
<p>每个节点上都执行如上代码,只是不同节点输入的参数不一样,对于ps节点,启动Server后就堵塞等待参数服务,对于worker节点,启动Server后(后台服务),开始扮演Client,构建计算图,最后通过<code>Session</code>提交计算。注意在调用<code>Session.run</code>之前,仅仅是Client的构图,并未开始计算,各节点上的Server还未发挥作用,只有在调用<code>Session.run</code>后,worker和ps节点才会被派发Task。在调用<code>Session.run</code>时,需要给<code>Session</code>传递<code>target</code>参数,指定使用哪个worker节点上的Master Service,Client将构建的计算图发给<code>target</code>指定的Master Service,一个TensorFlow集群中只有一个Master Service在工作,它负责子图划分、Task的分发以及模型保存与恢复等,在子图划分时,它会自动将模型参数分发到ps节点,将梯度计算分发到worker节点。另外,在Client构图时通过<code>tf.train.replica_device_setter</code>告诉worker节点默认在本机分配Op,这样每个Worker Service收到计算任务后构建出一个单独的计算子图副本,这样每个worker节点就可以单独运行,挂了不影响其他worker节点继续运行。</p>
<p>虽然图间复制具有较好的扩展性,但是从以上代码可以看到,写一个分布式TensorFlow应用,需要用户自行控制不同组件的运行,这就需要用户对TensorFlow的分布式架构有较深的理解。另外,分布式TensorFlow应用与单机版TensorFlow应用的代码是两套,一般使用过程中,用户都是先在单机上调试好基本逻辑,然后再部署到集群,在部署分布式TensorFlow应用前,就需要将前面的单机版代码改写成分布式多机版,用户体验非常差。所以说,使用Low-level 分布式编程模型,不能做到一套代码既可以在单机上运行也可以在分布式多机上运行,其用户门槛较高,一度被相关工程及研究人员诟病。为此,TensorFlow推出了High-level分布式编程模型,极大地改善用户易用性。</p>
<h3 id="High-level-分布式编程模型"><a href="#High-level-分布式编程模型" class="headerlink" title="High-level 分布式编程模型"></a>High-level 分布式编程模型</h3><p>TensorFlow提供<code>Estimator</code>和<code>Dataset</code>高阶API,简化模型构建以及数据输入,用户通过<code>Estimator</code>和<code>Dataset</code>高阶API编写TensorFlow应用,不用了解TensorFlow内部实现细节,只需关注模型本身即可。</p>
<p><code>Estimator</code>代表一个完整的模型,它提供方法用于模型的训练、评估、预测及导出。下图概括了<code>Estimator</code>的所有功能。</p>
<p><img src="/images/tf-estimator-interface.png" width="600" height="400" alt="tf-estimator-interface" align="center"></p>
<p><code>Estimator</code>具备如下优势:</p>
<ul>
<li>基于Estimator编写的代码,可运行在单机和分布式环境中,不用区别对待</li>
<li>简化了模型开发者之间共享部署,它提供了标准的模型导出功能,可以将训练好的模型直接用于TensorFlow-Serving等在线服务</li>
<li>提供全套的分布式训练生命周期管理,自动初始化变量、处理异常、创建检查点文件并从故障中恢复、以及保存TensorBoard 的摘要等</li>
<li>提供了一系列开箱即用的常见<code>Estimator</code>,例如<code>DNNClassifier</code>,<code>LinearClassifier</code>等</li>
</ul>
<p>使用<code>Estimator</code>编写应用时,需将数据输入从模型中分离出来。数据输入可以通过 <code>Dataset</code> API 构建数据 pipeline,类似Spark RDD或DataFrame,可以轻松处理大规模数据、不同的数据格式以及复杂的转换等。具体关于<code>Estimator</code>的使用可以参考<a href="https://www.tensorflow.org/guide/estimators" target="_blank" rel="noopener">TensorFlow官方文档</a>,讲的特别详细。</p>
<p>使用<code>Estimator</code>编写完应用后,可以直接单机上运行,如果需要将其部署到分布式环境运行,则需要在每个节点执行代码前设置集群的<code>TF_CONFIG</code>环境变量(实际应用时通常借助资源调度平台自动完成,如K8S,不需要修改TensorFlow应用程序代码):<br><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">TF_CONFIG=<span class="string">'{</span></span><br><span class="line"><span class="string"> "cluster": {</span></span><br><span class="line"><span class="string"> "chief": ["host0:2222"],</span></span><br><span class="line"><span class="string"> "worker": ["host1:2222", "host2:2222", "host3:2222"],</span></span><br><span class="line"><span class="string"> "ps": ["host4:2222", "host5:2222"]</span></span><br><span class="line"><span class="string"> },</span></span><br><span class="line"><span class="string"> "task": {"type": "chief", "index": 0}</span></span><br><span class="line"><span class="string">}'</span></span><br></pre></td></tr></table></figure></p>
<p><code>TF_CONFIG</code>环境变量是一个json字符串,指定集群规格cluster以及节点自身的角色task,cluster包括chief、worker、ps节点,chief节点其实是一个特殊的worker节点,而且只能有一个节点,表示分布式TensorFlow Master Service所在的节点。</p>
<p>通过以上描述可以看到,使用高阶API编写分布式TensorFlow应用已经很方便了,然而因为PS架构的缘故,我们实际部署时,需要规划使用多少个ps,多少个worker,那么调试过程中,需要反复调整ps和worker的数量。当模型规模较大时,在分布式训练过程中,ps可能成为网络瓶颈,因为所有worker都需要从ps处更新/获取参数,如果ps节点网络被打满,那么worker节点可能就会堵塞等待,以至于其计算能力就发挥不出来。所以后面TensorFlow引入All-Reduce架构解决这类问题。</p>
<h2 id="基于All-Reduce的分布式TensorFlow架构"><a href="#基于All-Reduce的分布式TensorFlow架构" class="headerlink" title="基于All-Reduce的分布式TensorFlow架构"></a>基于All-Reduce的分布式TensorFlow架构</h2><p>在单机多卡情况下,如下图左表所示(对应TensorFlow图内复制模式),GPU1~4卡负责网络参数的训练,每个卡上都布置了相同的深度学习网络,每个卡都分配到不同的数据的minibatch。每张卡训练结束后将网络参数同步到GPU0,也就是Reducer这张卡上,然后再求参数变换的平均下发到每张计算卡。</p>
<p><img src="/images/dl-ring-allreduce.png" width="600" height="400" alt="dl-ring-allreduce" align="center"></p>
<p>很显然,如果GPU较多,GPU0这张卡将成为整个训练的瓶颈,为了解决这样的问题,就引入了一种通信算法Ring Allreduce,通过将GPU卡的通信模式拼接成一个环形,解决带宽瓶颈问题,如上图右边所示。Ring Allreduce最早由百度提出,通过Ring Allreduce算法可以将整个训练过程中的带宽占用分摊到每块GPU卡上,详情可参考uber的一篇<a href="https://arxiv.org/pdf/1802.05799.pdf" target="_blank" rel="noopener">论文</a>。</p>
<p>TensorFlow从v1.8版本开始支持All-Reduce架构,它采用NVIDIA NCCL作为All-Reduce实现,为支持多种分布式架构,TensorFlow引入Distributed Strategy API,用户通过该API控制使用何种分布式架构,例如如果用户需要在单机多卡环境中使用All-Reduce架构,只需定义对应架构下的<code>Strategy</code>,指定<code>Estimator</code>的<code>config</code>参数即可:<br><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">mirrored_strategy = tf.distribute.MirroredStrategy()</span><br><span class="line">config = tf.estimator.RunConfig(</span><br><span class="line"> train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)</span><br><span class="line">regressor = tf.estimator.LinearRegressor(</span><br><span class="line"> feature_columns=[tf.feature_column.numeric_column(<span class="string">'feats'</span>)],</span><br><span class="line"> optimizer=<span class="string">'SGD'</span>,</span><br><span class="line"> config=config)</span><br></pre></td></tr></table></figure></p>
<p>对于分布式多机环境,最早是Uber专门提出了一种基于Ring-Allreduce的分布式TensorFlow架构<a href="https://github.com/horovod/horovod" target="_blank" rel="noopener">Horovod</a>,并已开源。目前TensorFlow已经官方支持,通过<code>MultiWorkerMirroredStrategy</code>来指定,目前该API尚处于实验阶段。如果在代码中通过<code>MultiWorkerMirroredStrategy</code>指定使用All-Reduce架构,则分布式提交时,<code>TF_CONFIG</code>环境变量中的cluster就不需要ps类型的节点了,例如:<br><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">TF_CONFIG=<span class="string">'{</span></span><br><span class="line"><span class="string"> "cluster": {</span></span><br><span class="line"><span class="string"> "chief": ["host0:2222"],</span></span><br><span class="line"><span class="string"> "worker": ["host1:2222", "host2:2222", "host3:2222"]</span></span><br><span class="line"><span class="string"> },</span></span><br><span class="line"><span class="string"> "task": {"type": "chief", "index": 0}</span></span><br><span class="line"><span class="string">}'</span></span><br></pre></td></tr></table></figure></p>
<p>通过不同的<code>Strategy</code>,可以轻松控制使用不同的分布式TensorFlow架构,可见TensorFlow的API设计更加灵活友好,拥有极强的可扩展性,相信将来会出现更多的<code>Strategy</code>来应对复杂的分布式场景。</p>
<h2 id="小结"><a href="#小结" class="headerlink" title="小结"></a>小结</h2><p>本文梳理了分布式TensorFlow编程模型的发展,主要从用户使用分布式TensorFlow角度出发,阐述了不同的分布式TensorFlow架构。可以看到,随着TensorFlow的迭代演进,其易用性越来越友好。目前TensorFlow已经发布了2.0.0-alpha版本了,标志着TensorFlow正式进入2.0时代了,在2.0版本中,其主打卖点是Eager Execution与Keras高阶API,整体易用性将进一步提升,通过Eager Execution功能,我们可以像使用原生Python一样操作Tensor,而不需要像以前一样需要通过<code>Session.run</code>的方式求解Tensor,另外,通过TensorFlow Keras高阶API,可以更加灵活方便构建模型,同时可以将模型导出为Keras标准格式HDF5,以灵活兼容在线服务等。</p>
<p><span style="color:red"><em>转载请注明出处,本文永久链接:<a href="https://sharkdtu.github.io/posts/dist-tf-evolution.html">https://sharkdtu.github.io/posts/dist-tf-evolution.html</a></em></span></p>
]]></content>
<categories>
<category> 深度学习 </category>
</categories>
<tags>
<tag> 分布式计算 </tag>
<tag> TensorFlow </tag>
</tags>
</entry>
<entry>
<title><![CDATA[借助Spark调度MPI作业]]></title>
<url>https://sharkdtu.github.io/posts/spark-mpi.html</url>
<content type="html"><![CDATA[<p>在Spark-2.4.0中社区提出一种新的Spark调度执行方式,名为<a href="https://jira.apache.org/jira/browse/SPARK-24374" target="_blank" rel="noopener">Barrier Execution Mode</a>,旨在通过Spark去调度分布式ML/DL(机器学习/深度学习)训练作业,这种训练作业一般是通过其他框架实现,例如MPI、TensorFlow等。由于Spark本身的计算框架是遵循MapReduce架构的,所以在调度执行时,每个Task都是独立执行的。然而,MPI这类分布式计算作业运行时需要所有Task一起执行,Task与Task之间需要相互通信。为了将诸如MPI分布式机器学习作业很好地与Spark结合,Spark社区提出一种新的Barrier Scheduler,可保证所有的Task全部调起(当然要保证资源足够)。<a id="more"></a></p>
<h2 id="API风格介绍"><a href="#API风格介绍" class="headerlink" title="API风格介绍"></a>API风格介绍</h2><p>众所周知,Spark其实本质上仍然是MapReduce架构,它能颠覆Hadoop的很重要一点是它提供了极其易用的API,用户不再需要绞尽脑汁将问题抽象成Map和Reduce,而是直接通过Spark提供的一个个算子来组织业务逻辑。新增Barrier Execution Mode当然不能破坏Spark本身的易用性,沿用原有的RDD API风格,例如调起一个MPI作业的大致轮廓为:<br><figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line">rdd.barrier().mapPartitions { (iter, context) =></span><br><span class="line"> <span class="comment">// Maybe write iter to disk.</span></span><br><span class="line"> ???</span><br><span class="line"> <span class="comment">// Wait until all tasks finished writing.</span></span><br><span class="line"> context.barrier()</span><br><span class="line"> <span class="comment">// The 0-th task launches an MPI job.</span></span><br><span class="line"> <span class="keyword">if</span> (context.partitionId() == <span class="number">0</span>) {</span><br><span class="line"> <span class="keyword">val</span> hosts = context.getTaskInfos().map(_.host)</span><br><span class="line"> <span class="comment">// Set up MPI machine file using host infos.</span></span><br><span class="line"> ???</span><br><span class="line"> <span class="comment">// Launch the MPI job by calling mpirun.</span></span><br><span class="line"> ???</span><br><span class="line"> }</span><br><span class="line"> <span class="comment">// Wait until the MPI job finished.</span></span><br><span class="line"> context.barrier()</span><br><span class="line"> <span class="comment">// Collect output and return.</span></span><br><span class="line"> ???</span><br><span class="line">}</span><br></pre></td></tr></table></figure></p>
<p>上述代码首先将数据写到本地磁盘,以供mpi程序读取,在正式启动mpi程序前,先通过<code>barrier</code>操作同步等待所有Task完成前序工作,然后通过第一个Task(一般为rank=0的mpi进程)去拉起一个mpi job,拉起mpi job的参数通过<code>context</code>相关方法获取,最后等待所有mpi任务执行完毕,mpi任务将结果写到本地磁盘,由spark最后完成结果收集。</p>
<h2 id="小结"><a href="#小结" class="headerlink" title="小结"></a>小结</h2><p>如今大数据平台或机器学习平台一般都会提供Spark作业的调度,而对mpi作业的支持可能没有统一的调度方式,如果你的作业恰好需要mpi来实现,却没有一个平台来支持,这个时候我们可能会给平台方提需求,要求其支持你的mpi作业调度,但是这就拉长了你的开发周期,很多东西变得不可控,通过上述spark拉起mpi作业的方式,你可以将你的mpi作业变成一个spark作业,快速完成部署。另一方面,mpi作业的输入数据一般是要提前预处理好,这部分工作spark完全可以胜任,这样就可以通过pipeline的方式将整个业务逻辑串起来。</p>
<p><span style="color:red"><em>转载请注明出处,本文永久链接:<a href="https://sharkdtu.github.io/posts/spark-mpi.html">https://sharkdtu.github.io/posts/spark-mpi.html</a></em></span></p>
]]></content>
<categories>
<category> spark </category>
</categories>
<tags>
<tag> spark </tag>
<tag> 分布式计算 </tag>
<tag> 大数据 </tag>
<tag> mpi </tag>
</tags>
</entry>
<entry>
<title><![CDATA[Spark-Streaming状态管理应用优化之路]]></title>
<url>https://sharkdtu.github.io/posts/spark-streaming-state.html</url>
<content type="html"><![CDATA[<p>通常来说,使用Spark-Streaming做无状态的流式计算是很方便的,每个batch时间间隔内仅需要计算当前时间间隔的数据即可,不需要关注之前的状态。但是很多时候,我们需要对一些数据做跨周期的统计,例如我们需要统计一个小时内每个用户的行为,我们定义的计算间隔(batch-duration)肯定会比一个小时小,一般是数十秒到几分钟左右,每个batch的计算都要更新最近一小时的用户行为,所以需要在整个计算过程中维护一个状态来保存近一个小时的用户行为。在Spark-1.6以前,可以通过<code>updateStateByKey</code>操作实现有状态的流式计算,从spark-1.6开始,新增了<code>mapWithState</code>操作,引入了一种新的流式状态管理机制。<a id="more"></a></p>
<h2 id="背景"><a href="#背景" class="headerlink" title="背景"></a>背景</h2><p>为了更形象的介绍Spark-Streaming中的状态管理,我们从一个简单的问题展开:我们需要实时统计近一小时内每个用户的行为(点击、购买等),为了简单,就把这个行为看成点击列表吧,来一条记录,则加到指定用户的点击列表中,并保证点击列表无重复。计算时间间隔为1分钟,即每1分钟更新近一小时用户行为,并将有状态变化的用户行为输出。</p>
<h2 id="updateStateByKey"><a href="#updateStateByKey" class="headerlink" title="updateStateByKey"></a>updateStateByKey</h2><p>在Spark-1.6以前,可以通过<code>updateStateByKey</code>实现状态管理,其内部维护一个状态流来保存状态,上述问题可以通过如下实现完成:</p>
<figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// 更新一个用户的状态</span></span><br><span class="line"><span class="comment">// values: 单个用户实时过来的数据,这里为Seq类型,表示一分钟内可能有多条数据.</span></span><br><span class="line"><span class="comment">// state:单个用户上一个时刻的状态,如果没有这个用户的状态,则默认为空.</span></span><br><span class="line"><span class="keyword">val</span> updateState = (values: <span class="type">Seq</span>[<span class="type">Int</span>], state: <span class="type">Option</span>[<span class="type">Set</span>[<span class="type">Int</span>]]) => {</span><br><span class="line"> <span class="keyword">val</span> currentValues = values.toSet</span><br><span class="line"> <span class="keyword">val</span> previousValues = state.getOrElse(<span class="type">Set</span>.empty[<span class="type">Int</span>])</span><br><span class="line"> <span class="type">Some</span>(currentValues ++ previousValues)</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="comment">// 更新一个分区内用户的状态</span></span><br><span class="line"><span class="keyword">val</span> updateFunc =</span><br><span class="line"> (iterator: <span class="type">Iterator</span>[(<span class="type">String</span>, <span class="type">Seq</span>[<span class="type">Int</span>], <span class="type">Option</span>[<span class="type">Set</span>[<span class="type">Int</span>]])]) => {</span><br><span class="line"> iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="comment">// 原始数据流,经过过滤清洗,得到的记录形式为(userId, clickId)</span></span><br><span class="line"><span class="keyword">val</span> liveDStream = ... <span class="comment">// (userId, clickId)</span></span><br><span class="line"><span class="comment">// 使用updateStateByKey更新状态</span></span><br><span class="line"><span class="keyword">val</span> stateDstream = liveDStream.updateStateByKey(updateFunc)</span><br><span class="line"></span><br><span class="line">stateDstream.foreach(...)</span><br></pre></td></tr></table></figure>
<p>上述代码显示,我们只需要定义一个状态更新函数,传给<code>updateStateByKey</code>即可,Spark-Streaming会根据我们定义的更新函数,在每个计算时间间隔内更新内部维护的状态,最后把更新后的状态返回给我们。那么其内部是怎么做到的呢,简单来说就是cache+checkpoint+cogroup,状态更新流程如下图所示。</p>
<p><img src="/images/spark-streaming-updateStateByKey.png" width="600" height="300" alt="updateStateByKey" align="center"></p>
<p>上图左边蓝色箭头为实时过来的数据流<code>liveDStream</code>,通过<code>liveDStream.updateStateByKey</code>的调用,会得到一个<code>StateDStream</code>,为方框中上面浅绿色的箭头,实际更新状态时,Spark-Streaming会将当前时间间隔内的数据rdd-x,与上一个时间间隔的状态state-(x-1)做<code>cogroup</code>操作,<code>cogroup</code>中做的更新操作就是我们前面定义的<code>updateState</code>函数。程序开始时,state-0状态为空,即由rdd-1去初始化state-1。另外,出于容错考虑,状态数据流<code>StateDStream</code>一般会做cache和定期checkpoint,程序因为机器宕机等原因挂掉可以从checkpoint处恢复状态。</p>
<p>但是,我们之前的问题描述是“输出有状态变化的用户行为”,通过<code>updateStateByKey</code>得到的是整个状态数据,这并不是我们想要的。同时在每次状态更新时,都需要将实时过来的数据跟全量的状态做<code>cogroup</code>计算,也就是说,每次计算都要将全量状态扫一遍进行比对,当计算随着时间的进行,状态数据逐步覆盖到全量用户,数据量慢慢增大,在做<code>cogroup</code>时遍历就变的越来越慢,使得在一个batch的时间内完成不了计算,导致后续数据堆积,最终挂掉。所以说,<code>updateStateByKey</code>并不能解决我们之前描述的那个问题。</p>
<h2 id="mapWithState"><a href="#mapWithState" class="headerlink" title="mapWithState"></a>mapWithState</h2><p>从Spark-1.6开始,Spark-Streaming引入一种新的状态管理机制<code>mapWithState</code>,支持输出全量的状态和更新的状态,还支持对状态超时管理,用户可以自行选择需要的输出,通过<code>mapWithState</code>操作可以很方便地实现前面提出的问题。</p>
<figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// 状态更新函数,output是输出,state是状态</span></span><br><span class="line"><span class="keyword">val</span> mappingFunc = (</span><br><span class="line"> userId: <span class="type">Long</span>,</span><br><span class="line"> value: <span class="type">Option</span>[<span class="type">Int</span>],</span><br><span class="line"> state: <span class="type">State</span>[<span class="type">Set</span>[<span class="type">Int</span>]]) => {</span><br><span class="line"> <span class="keyword">val</span> previousValues = state.getOption.getOrElse(<span class="type">Set</span>.empty[<span class="type">Int</span>])</span><br><span class="line"> <span class="keyword">val</span> newValues = <span class="keyword">if</span> (value.isDefined){</span><br><span class="line"> previousValues.add(value.get)</span><br><span class="line"> } <span class="keyword">else</span> previousValues</span><br><span class="line"> <span class="keyword">val</span> output = (userId, newValues)</span><br><span class="line"> state.update(newValues)</span><br><span class="line"> output</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="comment">// 原始数据流,经过过滤清洗,得到的记录形式为(userId, clickId)</span></span><br><span class="line"><span class="keyword">val</span> liveDStream = ... <span class="comment">// (userId, clickId)</span></span><br><span class="line"><span class="comment">// 使用mapWithState更新状态,并设置状态超时时间为1小时</span></span><br><span class="line"><span class="keyword">val</span> stateDstream = liveDStream.mapWithState(</span><br><span class="line"> <span class="type">StateSpec</span>.function(mappingFunc).timeout(<span class="type">Minutes</span>(<span class="number">60</span>)))</span><br><span class="line"><span class="comment">// stateDstream默认只返回新数据经过mappingFunc后的结果</span></span><br><span class="line"><span class="comment">// 通过stateDstream.snapshot()返回当前的全量状态</span></span><br><span class="line">stateDstream.foreach(...)</span><br></pre></td></tr></table></figure>
<p>上述代码显示,我们需要定义一个状态更新函数<code>mappingFunc</code>,该函数会更新指定用户的状态,同时会返回更新后的状态,将该函数传给<code>mapWithState</code>,并设置状态超时时间,Spark-Streaming通过根据我们定义的更新函数,在每个计算时间间隔内更新内部维护的状态,同时返回经过<code>mappingFunc</code>后的结果数据流,其内部执行流程如下图所示。</p>
<p><img src="/images/spark-streaming-mapWithState-1.png" width="600" height="300" alt="mapWithState-1" align="center"></p>
<p>上图左边蓝色箭头为实时过来的数据流<code>liveDStream</code>,通过<code>liveDStream.mapWithState</code>的调用,会得到一个<code>MapWithStateDStream</code>,为方框中上面浅绿色的箭头,计算过程中,Spark-Streaming会遍历当前时间间隔内的数据rdd-x,在上一个时间间隔的状态state-(x-1)中查找指定的记录,并更新状态,更新操作就是我们前面定义的<code>mappingFunc</code>函数。这里的状态更新不再需要全量扫描状态数据了,状态数据是存在hashmap中,可以根据过来的数据快速定位到,详细的状态更新流程如下图所示。</p>
<p><img src="/images/spark-streaming-mapWithState-2.png" width="600" height="300" alt="mapWithState-2" align="center"></p>
<p>首先通过<code>partitionBy</code>将新来的数据分区到对应的状态分区上,每个状态分区中的仅有一条记录,类型为<code>MapWithStateRDDRecord</code>,它打包了两份数据,如下代码所示。</p>
<figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">case</span> <span class="class"><span class="keyword">class</span> <span class="title">MapWithStateRDDRecord</span>[<span class="type">K</span>, <span class="type">S</span>, <span class="type">E</span>](<span class="params"></span></span></span><br><span class="line"><span class="class"><span class="params"> var stateMap: <span class="type">StateMap</span>[<span class="type">K</span>, <span class="type">S</span>], var mappedData: <span class="type">Seq</span>[<span class="type">E</span>]</span>)</span></span><br></pre></td></tr></table></figure>
<p>其中<code>stateMap</code>保存当前分区内所有的状态,底层为hashmap类型,<code>mappedData</code>保存经过<code>mappingFunc</code>处理后的结果。这样,<code>liveDStream</code>经过<code>mapWithState</code>后就可以得到两份数据,默认输出的是<code>mappedData</code>这份,如果需要输出全量状态,则可以在<code>mapWithState</code>后调用<code>snapshot</code>函数获取。</p>
<table>
<thead>
<tr>
<th style="text-align:left">输入</th>
<th style="text-align:left">mapWithState后的结果</th>
<th style="text-align:left">调用stateSnapshots后的结果</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align:left">(200100101, 1)</td>
<td style="text-align:left">(200100101, Set(1))</td>
<td style="text-align:left">(200100101, Set(1, 2))</td>
</tr>
<tr>
<td style="text-align:left">(200100101, 2)</td>
<td style="text-align:left">(200100101, Set(1, 2))</td>
<td style="text-align:left">(200100102, Set(1))</td>
</tr>
<tr>
<td style="text-align:left">(200100102, 1)</td>
<td style="text-align:left">(200100102, Set(1))</td>
<td style="text-align:left"></td>
</tr>
<tr>
<td style="text-align:left">(200100101, 2)</td>
<td style="text-align:left">(200100101, Set(1, 2))</td>
</tr>
</tbody>
</table>
<p>上述实现看似很美好,基本可以满足大部分的流式计算状态管理需求。但是,经过实际测试发现,状态缓存太耗内存了,出于容错考虑,状态数据会做cache和定期checkpoint,默认情况下是10个batch的时间做一次checkpoint,cache的记忆时间是20个batch时间,也就是说最多会缓存20份历史状态,我们的用户数是10亿,不可能hold住这么大的量。最最奇葩的是,checkpoint时间间隔和cache记忆时间都是代码里写死的,而且缓存方式采用<code>MEMORY_ONLY</code>也是写死的(估计是出于hashmap查找性能的考虑)。</p>
<p>既然写死了,那我们就修改源代码,将cache的记忆时间改为一个batch的时间,即每次仅缓存最新的那份,但是实际运行时,状态缓存数据量还是很大,膨胀了10倍以上,原因是Spark-Streaming在存储状态时,除了存储我们必要的数据外,还会带一些额外数据,例如时间戳、是否被删除标记、是否更新标记等,再加上JVM本身内存布局的膨胀,最终导致10倍以上的膨胀,而且在状态没有完全更新完毕时,旧的状态不会删除,所以中间会有两份的临时状态,如下图所示。</p>
<p><img src="/images/spark-streaming-cache.png" alt="spark-streaming-cache"></p>
<p>所以说,在状态数据量较大的情况下,<code>mapWithState</code>还是处理不了,看其源码的注释也是<code>@Experimental</code>状态,这大概也解释了为什么有些可调参数写死在代码里:),对于状态数据量较小的情况,还是可以一试。</p>
<p>综上分析,我们之前提出的那个问题当前Spark-Streaming是没法儿解决了,那就这样放弃了么?既然Spark-Streaming的状态管理做的那么差,那我们不用它的状态管理就是了,看看是否可以通过其他方式来存状态。我们最后想到了Redis,它是全内存的KV存储,具有较高的访问性能,同时它还支持超时管理,可以通过借助Redis来缓存状态,实现<code>mapWithState</code>类似的工作。</p>
<h2 id="使用Redis管理状态"><a href="#使用Redis管理状态" class="headerlink" title="使用Redis管理状态"></a>使用Redis管理状态</h2><p>通过前面的分析,我们不使用Spark自身的缓存机制来存储状态,而是使用Redis来存储状态。来一批新数据,先去redis上读取它们的上一个状态,然后更新写回Redis,逻辑非常简单,如下图所示。</p>
<p><img src="/images/sparkstreaming+redis.png" width="600" height="300" alt="sparkstreaming+redis" align="center"></p>
<p>在实际实现过程中,为了避免对同一个key有多次get/set请求,所以在更新状态前,使用<code>groupByKey</code>对相同key的记录做个归并,对于前面描述的问题,我们可以先这样做:</p>
<figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">val</span> liveDStream = ... <span class="comment">// (userId, clickId)</span></span><br><span class="line"></span><br><span class="line">liveDStream.groupByKey().mapPartitions(...)</span><br></pre></td></tr></table></figure>
<p>为了减少访问Redis的次数,我们使用pipeline的方式批量访问,即在一个分区内,一个一个批次的get/set,以提高Redis的访问性能,那么我们的更新逻辑就可以做到<code>mapPartitions</code>里面,如下代码所示。</p>
<figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">val</span> updateAndflush = (</span><br><span class="line"> records: <span class="type">Seq</span>[(<span class="type">Long</span>, <span class="type">Set</span>(<span class="type">Int</span>))],</span><br><span class="line"> states: <span class="type">Seq</span>[<span class="type">Response</span>[<span class="type">String</span>]],</span><br><span class="line"> pipeline: <span class="type">Pipeline</span>) => {</span><br><span class="line"> pipeline.sync() <span class="comment">// wait for getting</span></span><br><span class="line"> <span class="keyword">var</span> i = <span class="number">0</span></span><br><span class="line"> <span class="keyword">while</span> (i < records.size) {</span><br><span class="line"> <span class="keyword">val</span> (userId, values) = records(i)</span><br><span class="line"> <span class="comment">// 从字符串中解析出上一个状态中的点击列表</span></span><br><span class="line"> <span class="keyword">val</span> oldValues: <span class="type">Set</span>[<span class="type">Int</span>] = parseFrom(states(i).get())</span><br><span class="line"> <span class="keyword">val</span> newValues = values ++ oldValues</span><br><span class="line"> <span class="comment">// toString函数将Set[Int]编码为字符串</span></span><br><span class="line"> pipeline.setex(userId.toString, <span class="number">3600</span>, toString(newValues))</span><br><span class="line"> i += <span class="number">1</span></span><br><span class="line"> }</span><br><span class="line"> pipeline.sync() <span class="comment">// wait for setting</span></span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="keyword">val</span> mappingFunc = (iter: <span class="type">Iterator</span>[(<span class="type">Long</span>, <span class="type">Iterable</span>[<span class="type">Int</span>])]) => {</span><br><span class="line"> <span class="keyword">val</span> jedis = <span class="type">ConnectionPool</span>.getConnection()</span><br><span class="line"> <span class="keyword">val</span> pipeline = jedis.pipelined()</span><br><span class="line"></span><br><span class="line"> <span class="keyword">val</span> records = <span class="type">ArrayBuffer</span>.empty[(<span class="type">Long</span>, <span class="type">Set</span>(<span class="type">Int</span>))]</span><br><span class="line"> <span class="keyword">val</span> states = <span class="type">ArrayBuffer</span>.empty[<span class="type">Response</span>[<span class="type">String</span>]]</span><br><span class="line"> <span class="keyword">while</span> (iter.hasNext) {</span><br><span class="line"> <span class="keyword">val</span> (userId, values) = iter.next()</span><br><span class="line"> records += ((userId, values.toSet))</span><br><span class="line"> states += pipeline.get(userId.toString)</span><br><span class="line"> <span class="keyword">if</span> (records.size == batchSize) {</span><br><span class="line"> updateAndflush(records, states, pipeline)</span><br><span class="line"> records.clear()</span><br><span class="line"> states.clear()</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> updateAndflush(records, states, pipeline)</span><br><span class="line"> <span class="type">Iterator</span>[<span class="type">Int</span>]()</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line">liveDStream.groupByKey()</span><br><span class="line"> .mapPartitions(mappingFunc)</span><br><span class="line"> .foreachRDD { rdd =></span><br><span class="line"> rdd.foreach(_ => <span class="type">Unit</span>) <span class="comment">// force action</span></span><br><span class="line">}</span><br></pre></td></tr></table></figure>
<p>上述代码没有加容错等操作,仅描述实现逻辑,可以看到,函数<code>mappingFunc</code>会对每个分区的数据处理,实际计算时,会累计到batchSize才去访问Redis并更新,以降低访问Redis的频率。这样就不再需要cache和checkpoint了,程序挂了,快速拉起来即可,不需要从checkpoint处恢复状态,同时可以节省相当大的计算资源。</p>
<h2 id="测试及优化选项"><a href="#测试及优化选项" class="headerlink" title="测试及优化选项"></a>测试及优化选项</h2><p>经过上述改造后,实际测试中,我们的batch时间为一分钟,每个batch约200W条记录,使用资源列表如下:</p>
<ul>
<li>driver-memory: 4g</li>
<li>num-executors: 10</li>
<li>executor-memory: 4g</li>
<li>executor-cores: 3</li>
</ul>
<p>每个executor上启一个receiver,则总共启用10个receiver收数据,一个receiver占用一个core,则总共剩下10*2=20个core可供计算用,通过调整如下参数,可控制每个batch的分区数为 10*(60*1000)/10000=60(10个receiver,每个receiver上(60*1000)/10000个分区)。<br><figure class="highlight plain"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">spark.streaming.blockInterval=10000</span><br></pre></td></tr></table></figure></p>
<p>为了避免在某个瞬间数据量暴增导致程序处理不过来,我们可以对receiver进行反压限速,只需调整如下两个参数即可,其中第一个参数是开启反压机制,即使数据源的数据出现瞬间暴增,每个receiver在收数据时都不会超过第二个参数的配置值,第二个参数控制单个receiver每秒接收数据的最大条数,通过下面的配置,一分钟内最多收 10*60*5000=300W(10个receiver,每个receiver一分钟最多收60*5000)条。</p>
<figure class="highlight plain"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">spark.streaming.backpressure.enabled=true</span><br><span class="line">spark.streaming.receiver.maxRate=5000</span><br></pre></td></tr></table></figure>
<p>如果程序因为机器故障挂掉,我们应该迅速把拉重新拉起来,为了保险起见,我们应该加上如下参数让Driver失败重试4次,并在相应的任务调度平台上配置失败重试。</p>
<figure class="highlight plain"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">spark.yarn.maxAppAttempts=4</span><br></pre></td></tr></table></figure>
<p>此外,为了防止少数任务太慢影响整个计算的速度,可以开启推测,并增加任务的失败容忍次数,这样在少数几个任务非常慢的情况下,会在其他机器上尝试拉起新任务做同样的事,哪个先做完,就干掉另外那个。但是开启推测有个条件,每个任务必须是幂等的,否则就会存在单条数据被计算多次。</p>
<figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">spark.speculation=<span class="literal">true</span></span><br><span class="line">spark.task.maxFailures=<span class="number">8</span></span><br></pre></td></tr></table></figure>
<p>经过上述配置优化后,基本可以保证程序7*24小时稳定运行,实际测试显示每个batch的计算时间可以稳定在30秒以内,没有上升趋势。</p>
<p><span style="color:red"><em>转载请注明出处,本文永久链接:<a href="https://sharkdtu.github.io/posts/spark-streaming-state.html">https://sharkdtu.github.io/posts/spark-streaming-state.html</a></em></span></p>
]]></content>
<categories>
<category> spark </category>
</categories>
<tags>
<tag> spark </tag>
<tag> 分布式计算 </tag>
<tag> 大数据 </tag>
<tag> streaming </tag>
<tag> updateStateByKey </tag>
<tag> mapWithState </tag>
</tags>
</entry>
<entry>
<title><![CDATA[机器学习套路--朴素贝叶斯]]></title>
<url>https://sharkdtu.github.io/posts/ml-nb.html</url>
<content type="html"><![CDATA[<p>朴素贝叶斯(NaiveBayes)是基于贝叶斯定理与特征条件独立假设的一种分类方法,常用于文档分类、垃圾邮件分类等应用场景。其基本思想是,对于给定的训练集,基于特征条件独立的假设,学习输入输出的联合概率分布,然后根据贝叶斯定理,对给定的预测数据,预测其类别为后验概率最大的类别。<a id="more"></a></p>
<h2 id="基本套路"><a href="#基本套路" class="headerlink" title="基本套路"></a>基本套路</h2><p>给定训练集 $T$,每个实例表示为 $(x, y)$,其中 $x$ 为 $n$ 维特征向量,定义 $X$ 为输入(特征)空间上的随机向量,$Y$ 为输出(类别)空间上的随机变量,根据训练集计算如下概率分布:</p>
<ul>
<li>先验概率分布,即每个类别在训练集中概率分布</li>
</ul>
<p>$$<br>P\left( Y=c_k \right) ,k=1, 2,…, K \left(\text{其中K为类别个数}\right)<br>$$</p>
<ul>
<li>条件概率分布,即在每个类别下,各特征的条件概率分布</li>
</ul>
<p>$$<br>P\left( X=x \mid Y=c_k \right) = P\left( X_1=x_1, X_2=x_2,…, X_n=x_n \mid Y=c_k \right)<br>$$</p>
<p>假设每个特征之间是独立的,那么上述条件概率分布可以展开为如下形式:</p>
<p>$$<br>\begin{split}<br>P\left( X=x \mid Y=c_k \right) &= P\left( X_1=x_1, X_2=x_2,…, X_n=x_n \mid Y=c_k \right) \\<br>&= \prod_{j=1}^{n} P\left( X_j=x_j \mid Y=c_k \right)<br>\end{split}<br>$$</p>
<p>如果有了每个类别的概率 $P\left( Y=c_k \right)$,以及 每个类别下每个特征的条件概率 $P\left( X_j=x_j \mid Y=c_k \right)$,那么对于一个未知类别的实例 $x$,就可以用贝叶斯公式求解其属于每个类别的后验概率:</p>
<p>$$<br>\begin{split}<br>P\left( Y=c_k \mid X=x \right) &= \frac {P\left( X=x \mid Y=c_k \right) P\left( Y=c_k \right)} {\sum_{k}P\left( X=x \mid Y=c_k \right) P\left( Y=c_k \right)} \\<br>&= \frac {P\left( Y=c_k \right) \prod_{j} P\left( X_j=x_j \mid Y=c_k \right)} {\sum_{k} P\left( Y=c_k \right)\prod_{j} P\left( X_j=x_j \mid Y=c_k \right)}<br>\end{split}<br>$$</p>
<p>对于每个实例,分母都一样,则将该实例的类别判别为:</p>
<p>$$<br>y = {arg \, max}_{c_k} \; P\left( Y=c_k \right) \prod_{j} P\left( X_j=x_j \mid Y=c_k \right)<br>$$</p>
<h2 id="应用套路"><a href="#应用套路" class="headerlink" title="应用套路"></a>应用套路</h2><p>那么如何求解 $P\left( Y=c_k \right)$ 和 $P\left( X_j=x_j \mid Y=c_k \right)$ 这些概率值呢?答案是极大似然估计。先验概率的极大似然估计为:</p>
<p>$$<br>P\left( Y=c_k \right) = \frac {N_{y=c_k} + \lambda} {\sum_i^K N_{y=c_i} + K\lambda}<br>$$</p>
<blockquote>
<p>其中 $N_{y=c_k}$ 为类别 $c_k$ 的实例个数,$K$ 为类别个数,$\lambda$ 为平滑系数,避免估计的概率为0的情况。</p>
</blockquote>
<p>对于条件概率 $P\left( X_j=x_j \mid Y=c_k \right)$ 的极大似然估计通常有两种模型:多项式模型和伯努利模型。</p>
<p><strong>多项式模型</strong></p>
<p>$$<br>P\left( X_j=x_j \mid Y=c_k \right) = \frac {N_{x_j \mid y=c_k} + \lambda} {\sum_i^{n}N_{x_j \mid y=c_k} + n\lambda}<br>$$</p>
<blockquote>
<p>其中 $N_{x_j \mid y=c_k}$ 为类别 $c_k$ 下特征 $x_j$ 出现的总次数, $n$ 为特征维度。</p>
</blockquote>
<p><strong>伯努利模型</strong></p>
<p>对于每个特征 $x_j$,只能有{0, 1}两种可能的取值:</p>
<p>$$<br>\begin{split}<br>P\left( X_j=1 \mid Y=c_k \right) &= \frac {N_{y=c_k, x_j=1} + \lambda} {N_{y=c_k} + 2\lambda} \\<br>P\left( X_j=0 \mid Y=c_k \right) &= 1- P\left( X_j=1 \mid Y=c_k \right)<br>\end{split}<br>$$</p>
<blockquote>
<p>其中 $N_{y=c_k, x_j=1}$ 为类别 $c_k$ 下特征 $x_j=1$ 出现的总次数。</p>
</blockquote>
<p>通过给定的训练集,根据上述极大似然估计方法,可以求得朴素贝叶斯模型的参数(即上述的先验概率和条件概率),基于这些参数即可根据下面的模型对未知类别的数据进行预测。</p>
<p>$$<br>y = {arg \, max}_{c_k} \; P\left( Y=c_k \right) \prod_{j} P\left( X_j=x_j \mid Y=c_k \right)<br>$$</p>
<h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h2><p>朴素贝叶斯模型是基于特征之间独立的假设,这是个非常强的假设,这也是其名字的由来,它属于生成学习方法,训练时不需要迭代拟合,模型简单易于理解,常用于文本分类等,并能取得较好的效果。</p>
<p><span style="color:red"><em>转载请注明出处,本文永久链接:<a href="https://sharkdtu.github.io/posts/ml-nb.html">https://sharkdtu.github.io/posts/ml-nb.html</a></em></span></p>
]]></content>
<categories>
<category> 机器学习 </category>
</categories>
<tags>
<tag> NaiveBayes </tag>
<tag> Classification </tag>
</tags>
</entry>
<entry>
<title><![CDATA[PySpark 的背后原理]]></title>
<url>https://sharkdtu.github.io/posts/pyspark-internal.html</url>
<content type="html"><![CDATA[<p><a href="http://spark.apache.org/" target="_blank" rel="noopener">Spark</a>主要是由Scala语言开发,为了方便和其他系统集成而不引入scala相关依赖,部分实现使用Java语言开发,例如External Shuffle Service等。总体来说,Spark是由JVM语言实现,会运行在JVM中。然而,Spark除了提供Scala/Java开发接口外,还提供了Python、R等语言的开发接口,为了保证Spark核心实现的独立性,Spark仅在外围做包装,实现对不同语言的开发支持,本文主要介绍Python Spark的实现原理,剖析pyspark应用程序是如何运行起来的。<a id="more"></a></p>
<h2 id="Spark运行时架构"><a href="#Spark运行时架构" class="headerlink" title="Spark运行时架构"></a>Spark运行时架构</h2><p>首先我们先回顾下Spark的基本运行时架构,如下图所示,其中橙色部分表示为JVM,Spark应用程序运行时主要分为Driver和Executor,Driver负载总体调度及UI展示,Executor负责Task运行,Spark可以部署在多种资源管理系统中,例如Yarn、Mesos等,同时Spark自身也实现了一种简单的Standalone(独立部署)资源管理系统,可以不用借助其他资源管理系统即可运行。更多细节请参考<a href="/posts/spark-scheduler.html">Spark Scheduler内部原理剖析</a>。</p>
<p><img src="/images/spark-structure.png" width="600" height="400" alt="spark-structure" align="center"></p>
<p>用户的Spark应用程序运行在Driver上(某种程度上说,用户的程序就是Spark Driver程序),经过Spark调度封装成一个个Task,再将这些Task信息发给Executor执行,Task信息包括代码逻辑以及数据信息,Executor不直接运行用户的代码。</p>
<h2 id="PySpark运行时架构"><a href="#PySpark运行时架构" class="headerlink" title="PySpark运行时架构"></a>PySpark运行时架构</h2><p>为了不破坏Spark已有的运行时架构,Spark在外围包装一层Python API,借助<a href="https://www.py4j.org/" target="_blank" rel="noopener">Py4j</a>实现Python和Java的交互,进而实现通过Python编写Spark应用程序,其运行时架构如下图所示。</p>
<p><img src="/images/pyspark-structure.png" width="600" height="400" alt="pyspark-structure" align="center"></p>
<p>其中白色部分是新增的Python进程,在Driver端,通过Py4j实现在Python中调用Java的方法,即将用户写的PySpark程序”映射”到JVM中,例如,用户在PySpark中实例化一个Python的SparkContext对象,最终会在JVM中实例化Scala的SparkContext对象;在Executor端,则不需要借助Py4j,因为Executor端运行的Task逻辑是由Driver发过来的,那是序列化后的字节码,虽然里面可能包含有用户定义的Python函数或Lambda表达式,Py4j并不能实现在Java里调用Python的方法,为了能在Executor端运行用户定义的Python函数或Lambda表达式,则需要为每个Task单独启一个Python进程,通过socket通信方式将Python函数或Lambda表达式发给Python进程执行。语言层面的交互总体流程如下图所示,实线表示方法调用,虚线表示结果返回。</p>
<p><img src="/images/pyspark-call.png" width="600" height="400" alt="pyspark-call" align="center"></p>
<p>下面分别详细剖析PySpark的Driver是如何运行起来的以及Executor是如何运行Task的。</p>
<h3 id="Driver端运行原理"><a href="#Driver端运行原理" class="headerlink" title="Driver端运行原理"></a>Driver端运行原理</h3><p>当我们通过spark-submmit提交pyspark程序,首先会上传python脚本及依赖,并申请Driver资源,当申请到Driver资源后,会通过PythonRunner(其中有main方法)拉起JVM,如下图所示。</p>
<p><img src="/images/pyspark-driver-runtime.png" width="600" height="400" alt="pyspark-driver-runtime" align="center"></p>
<p>PythonRunner入口main函数里主要做两件事:</p>
<ul>
<li>开启Py4j GatewayServer</li>
<li>通过Java Process方式运行用户上传的Python脚本</li>
</ul>
<p>用户Python脚本起来后,首先会实例化Python版的SparkContext对象,在实例化过程中会做两件事:</p>
<ul>
<li>实例化Py4j GatewayClient,连接JVM中的Py4j GatewayServer,后续在Python中调用Java的方法都是借助这个Py4j Gateway</li>
<li>通过Py4j Gateway在JVM中实例化SparkContext对象</li>
</ul>
<p>经过上面两步后,SparkContext对象初始化完毕,Driver已经起来了,开始申请Executor资源,同时开始调度任务。用户Python脚本中定义的一系列处理逻辑最终遇到action方法后会触发Job的提交,提交Job时是直接通过Py4j调用Java的PythonRDD.runJob方法完成,映射到JVM中,会转给sparkContext.runJob方法,Job运行完成后,JVM中会开启一个本地Socket等待Python进程拉取,对应地,Python进程在调用PythonRDD.runJob后就会通过Socket去拉取结果。</p>
<p>把前面运行时架构图中Driver部分单独拉出来,如下图所示,通过PythonRunner入口main函数拉起JVM和Python进程,JVM进程对应下图橙色部分,Python进程对应下图白色部分。Python进程通过Py4j调用Java方法提交Job,Job运行结果通过本地Socket被拉取到Python进程。还有一点是,对于大数据量,例如广播变量等,Python进程和JVM进程是通过本地文件系统来交互,以减少进程间的数据传输。</p>
<p><img src="/images/pyspark-driver.png" width="400" height="230" alt="pyspark-driver" align="center"></p>
<h3 id="Executor端运行原理"><a href="#Executor端运行原理" class="headerlink" title="Executor端运行原理"></a>Executor端运行原理</h3><p>为了方便阐述,以Spark On Yarn为例,当Driver申请到Executor资源时,会通过CoarseGrainedExecutorBackend(其中有main方法)拉起JVM,启动一些必要的服务后等待Driver的Task下发,在还没有Task下发过来时,Executor端是没有Python进程的。当收到Driver下发过来的Task后,Executor的内部运行过程如下图所示。</p>
<p><img src="/images/pyspark-executor-runtime.png" width="600" height="400" alt="pyspark-executor-runtime" align="center"></p>
<p>Executor端收到Task后,会通过launchTask运行Task,最后会调用到PythonRDD的compute方法,来处理一个分区的数据,PythonRDD的compute方法的计算流程大致分三步走:</p>
<ul>
<li>如果不存在pyspark.deamon后台Python进程,那么通过Java Process的方式启动pyspark.deamon后台进程,注意每个Executor上只会有一个pyspark.deamon后台进程,否则,直接通过Socket连接pyspark.deamon,请求开启一个pyspark.worker进程运行用户定义的Python函数或Lambda表达式。pyspark.deamon是一个典型的多进程服务器,来一个Socket请求,fork一个pyspark.worker进程处理,一个Executor上同时运行多少个Task,就会有多少个对应的pyspark.worker进程。</li>
<li>紧接着会单独开一个线程,给pyspark.worker进程喂数据,pyspark.worker则会调用用户定义的Python函数或Lambda表达式处理计算。</li>
<li>在一边喂数据的过程中,另一边则通过Socket去拉取pyspark.worker的计算结果。</li>
</ul>
<p>把前面运行时架构图中Executor部分单独拉出来,如下图所示,橙色部分为JVM进程,白色部分为Python进程,每个Executor上有一个公共的pyspark.deamon进程,负责接收Task请求,并fork pyspark.worker进程单独处理每个Task,实际数据处理过程中,pyspark.worker进程和JVM Task会较频繁地进行本地Socket数据通信。</p>
<p><img src="/images/pyspark-executor.png" width="300" height="130" alt="pyspark-executor.png" align="center"></p>
<h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h2><p>总体上来说,PySpark是借助Py4j实现Python调用Java,来驱动Spark应用程序,本质上主要还是JVM runtime,Java到Python的结果返回是通过本地Socket完成。虽然这种架构保证了Spark核心代码的独立性,但是在大数据场景下,JVM和Python进程间频繁的数据通信导致其性能损耗较多,恶劣时还可能会直接卡死,所以建议对于大规模机器学习或者Streaming应用场景还是慎用PySpark,尽量使用原生的Scala/Java编写应用程序,对于中小规模数据量下的简单离线任务,可以使用PySpark快速部署提交。</p>
<p><span style="color:red"><em>转载请注明出处,本文永久链接:<a href="https://sharkdtu.github.io/posts/pyspark-internal.html">https://sharkdtu.github.io/posts/pyspark-internal.html</a></em></span></p>
]]></content>
<categories>
<category> spark </category>
</categories>
<tags>
<tag> spark </tag>
<tag> pyspark </tag>
<tag> 分布式计算 </tag>
</tags>
</entry>
<entry>
<title><![CDATA[机器学习套路--协同过滤推荐ALS]]></title>
<url>https://sharkdtu.github.io/posts/ml-als.html</url>
<content type="html"><![CDATA[<p>如今,协同过滤推荐(CollaboratIve Filtering)技术已广泛应用于各类推荐系统中,其通常分为两类,一种是基于用户的协同过滤算法(User-Based CF),它是根据用户对物品的历史评价数据(如,喜欢、点击、购买等),计算不同用户之间的相似度,在有相同喜好的用户间进行物品推荐,例如将跟我有相同电影爱好的人看过的电影推荐给我;另一种是基于物品的协同过滤算法(Item-Based CF),它是根据用户对物品的历史评价数据,计算物品之间的相似度,用户如果喜欢A物品,那么可以给用户推荐跟A物品相似的其他物品,例如如果我们在购物网站上买过尿片,第二天你再到购物网站上浏览时,可能会被推荐奶瓶。<a id="more"></a>更多关于User-Based CF和Item-Based CF的阐述请参考<a href="http://www.cnblogs.com/luchen927/archive/2012/02/01/2325360.html" target="_blank" rel="noopener">文章</a>。然而,在用户数量以及用户评分不足的情况下,上述两种方法就不是那么地好使了,近年来,基于模型的推荐算法ALS(交替最小二乘)在Netflix成功应用并取得显著效果提升,ALS使用机器学习算法建立用户和物品间的相互作用模型,进而去预测新项。</p>
<h2 id="基本原理"><a href="#基本原理" class="headerlink" title="基本原理"></a>基本原理</h2><p>用户对物品的打分行为可以表示成一个打分矩阵 $R$,例如下表所示:</p>
<p><img src="/images/als-ratings.png" alt="als-ratings | center"></p>
<p>矩阵中的打分值 $r_{ij}$ 表示用户 $u_i$ 对物品 $v_j$ 的打分,其中”?”表示用户没有打分,这也就是要通过机器学习的方法去预测这个打分值,从而达到推荐的目的。</p>
<h3 id="模型抽象"><a href="#模型抽象" class="headerlink" title="模型抽象"></a>模型抽象</h3><p>按照User-Based CF的思想,$R$ 的行向量对应每个用户 $u$ ,按照Item-Based CF的思想,$R$ 的列向量对应每个物品 $v$ 。ALS 的核心思想是,将用户和物品都投影到 $k$ 维空间,也就是说,假设有 $k$ 个隐含特征,至于 $k$ 个隐含特征具体指什么不用关心,将每个用户和物品都用 $k$ 维向量来表示,把它们之间的内积近似为打分值,这样就可以得到如下近似关系:</p>
<p>$$ R \approx U V^T $$</p>
<blockquote>
<p>$R$ 为打分矩阵($m \times n$),$m$ 个用户,$n$ 个物品,$U$ 为用户对隐含特征的偏好矩阵($m \times k$),$V$ 为物品对隐含特征的偏好矩阵($n \times k$)。</p>
</blockquote>
<p>上述模型的参数就是矩阵 $U$ 和 $V$,即求解出 $U$ 和 $V$ 我们就可以重现打分矩阵,填补原始打分矩阵中的缺失值”?”。</p>
<h3 id="显示反馈代价函数"><a href="#显示反馈代价函数" class="headerlink" title="显示反馈代价函数"></a>显示反馈代价函数</h3><p>要求解上述模型中的 $U$ 和 $V$,那么就需要一个代价函数来衡量参数的拟合程度,如果有比较明确的显式反馈打分数据,那么可以比较重构出来的打分矩阵与实际打分矩阵,即得到重构误差,由于实际打分矩阵有很多缺失值,所以仅计算已知打分的重构误差,下面函数为显示反馈代价函数。</p>
<p>$$<br>J\left( U, V \right) = \sum_i \sum_j \left[ \left( r_{ij} - u_i v_j^T \right)^2 + \lambda \left( \|u_i\|^2 + \|v_j\|^2 \right) \right]<br>$$</p>
<blockquote>
<p>$r_{ij}$ 为矩阵 $R$ 的第 $i$ 行第 $j$ 列,表示用户 $u_i$ 对物品 $v_j$ 的打分,$u_i$ 为矩阵 $U$ 的第 $i$ 行 $(1 \times k)$,$v_j^T$ 为矩阵 $V^T$ 的第 $j$ 列 $(k \times 1)$,$\lambda$ 为正则项系数。</p>
</blockquote>
<h3 id="隐式反馈代价函数"><a href="#隐式反馈代价函数" class="headerlink" title="隐式反馈代价函数"></a>隐式反馈代价函数</h3><p>很多情况下,用户并没有明确反馈对物品的偏好,需要通过用户的相关行为来推测其对物品的偏好,例如,在视频推荐问题中,可能由于用户就懒得对其所看的视频进行反馈,通常是收集一些用户的行为数据,得到其对视频的偏好,例如观看时长等。通过这种方式得到的偏好值称之为隐式反馈值,即矩阵 $R$ 为隐式反馈矩阵,引入变量 $p_{ij}$ 表示用户 $u_i$ 对物品 $v_j$ 的置信度,如果隐式反馈值大于0,置信度为1,否则置信度为0。</p>
<p>$$ p_{ij} = \left\{\begin{matrix}1 \qquad r_{ij} > 0 & \\ 0 \qquad r_{ij} = 0 & \end{matrix}\right. $$</p>
<p>但是隐式反馈值为0并不能说明用户就完全不喜欢,用户对一个物品没有得到一个正的偏好可能源于多方面的原因,例如,用户可能不知道该物品的存在,另外,用户购买一个物品也并不一定是用户喜欢它,所以需要一个信任等级来显示用户偏爱某个物品,一般情况下,$r_{ij}$ 越大,越能暗示用户喜欢某个物品,因此,引入变量 $c_{ij}$,来衡量 $p_{ij}$ 的信任度。</p>
<p>$$ c_{ij} = 1 + \alpha r_{ij} $$</p>
<blockquote>
<p>$\alpha$ 为置信度系数</p>
</blockquote>
<p>那么,代价函数则变成如下形式:</p>
<p>$$<br>J\left( U, V \right) = \sum_i \sum_j \left[ c_{ij} \left( p_{ij} - u_i v_j^T \right)^2 + \lambda \left( \|u_i\|^2 + \|v_j\|^2 \right)\right]<br>$$</p>
<h3 id="算法"><a href="#算法" class="headerlink" title="算法"></a>算法</h3><p>无论是显示反馈代价函数还是隐式反馈代价函数,它们都不是凸的,变量互相耦合在一起,常规的梯度下降法可不好使了。但是如果先固定 $U$ 求解 $V$,再固定 $V$ 求解 $U$ ,如此迭代下去,问题就可以得到解决了。</p>
<p>$$ U^{(0)} \rightarrow V^{(1)} \rightarrow U^{(1)} \rightarrow V^{(2)} \rightarrow \cdots $$</p>
<p>那么固定一个变量求解另一个变量如何实现呢,梯度下降?虽然可以用梯度下降,但是需要迭代,计算起来相对较慢,试想想,固定 $U$ 求解 $V$,或者固定 $V$ 求解 $U$,其实是一个最小二乘问题,由于一般隐含特征个数 $k$ 取值不会特别大,可以将最小二乘转化为正规方程一次性求解,而不用像梯度下降一样需要迭代。如此交替地解最小二乘问题,所以得名交替最小二乘法ALS,下面是基于显示反馈和隐式反馈的最小二乘正规方程。</p>
<h4 id="显示反馈"><a href="#显示反馈" class="headerlink" title="显示反馈"></a>显示反馈</h4><p><strong>固定 $V$ 求解 $U$</strong></p>
<p>$$ U ^T = \left( V^T V + \lambda I \right)^{-1} V^T R^T $$</p>
<p>更直观一点,每个用户向量的求解公式如下:</p>
<p>$$<br>u_i ^T = \left( V^T V + \lambda I \right)^{-1} V^T r_i^T<br>$$</p>
<blockquote>
<p>$u_i^T$ 为矩阵 $U$ 的第 $i$ 行的转置($k \times 1$),$r_i^T$ 为矩阵 $R$ 的第 $i$ 行的转置($n \times 1$)。</p>
</blockquote>
<p><strong>固定 $U$ 求解 $V$</strong></p>
<p>$$ V ^T = \left( U^T U + \lambda I \right)^{-1} U^T R $$</p>
<p>更直观一点,每个物品向量的求解公式如下:</p>
<p>$$<br>v_j ^T = \left( U^T U + \lambda I \right)^{-1} U^T r_j^T<br>$$</p>
<blockquote>
<p>$v_j^T$ 为矩阵 $V^T$ 的第 $j$ 列($k \times 1$),$r_j^T$ 为矩阵 $R$ 的第 $j$ 列($m \times 1$)。</p>
</blockquote>
<h4 id="隐式反馈"><a href="#隐式反馈" class="headerlink" title="隐式反馈"></a>隐式反馈</h4><p><strong>固定 $V$ 求解 $U$</strong></p>
<p>$$<br>U ^T = \left( V^T C_v V + \lambda I \right)^{-1} V^T C_v R^T<br>$$</p>
<p>更直观一点,每个用户向量的求解公式如下:</p>
<p>$$<br>u_i ^T = \left( V^T C_v V + \lambda I \right)^{-1} V^T C_v r_i^T<br>$$</p>
<blockquote>
<p>$u_i^T$ 为矩阵 $U$ 的第 $i$ 行的转置($k \times 1$),$r_i^T$ 为矩阵 $R$ 的第 $i$ 行的转置($n \times 1$), $C_v$ 为对角矩阵($n \times n$)。</p>
</blockquote>
<p><strong>固定 $U$ 求解 $V$</strong></p>
<p>$$<br>V ^T = \left( U^T C_u U + \lambda I \right)^{-1} U^T C_u R<br>$$</p>
<p>更直观一点,每个物品向量的求解公式如下:</p>
<p>$$<br>v_j ^T = \left( U^T C_u U + \lambda I \right)^{-1} U^T C_u r_j^T<br>$$</p>
<blockquote>
<p>$v_j^T$ 为矩阵 $V^T$ 的第 $j$ 列($k \times 1$),$r_j^T$ 为矩阵 $R$ 的第 $j$ 列($m \times 1$),, $C_u$ 为对角矩阵($m \times m$)。</p>
</blockquote>
<h2 id="Spark-分布式实现"><a href="#Spark-分布式实现" class="headerlink" title="Spark 分布式实现"></a>Spark 分布式实现</h2><p>上述ALS算法虽然明朗了,但是要将其实现起来并不是信手拈来那么简单,尤其是数据量较大,需要使用分布式计算来实现,就更加不是那么地容易了。下面详细阐述Spark ML是如何完成ALS分布式实现的。为了更加直观的了解其分布式实现,下面用前面的打分矩阵作为例子,如下图所示。</p>
<p><img src="/images/als-ratings.png" alt="als-ratings | center"></p>
<p>由前面的原理介绍可知,按照显示反馈模型,固定 $U$ 求解 $V$,每个物品对隐含特征的偏好向量 $ v_j ^T$ 由以下公式得到:</p>
<p>$$ v_j ^T = \left( U^T U + \lambda I \right)^{-1} U^T r_j^T $$</p>
<p>计算时,只需要计算得到 $ U^T U + \lambda I $ 和 $U^T r_j^T$,再利用BLAS库即可解方程,初次迭代计算时,随机初始化矩阵 $U$,假设得到如下初始形式:</p>
<p>$$<br>U = \begin{bmatrix} -u_1- \\ -u_2- \\ -u_3- \end{bmatrix}<br>$$</p>
<p>假如求解 $v_1^T$,由于只有 $u_1$ 和 $u_2$ 对 $v_1$ 有打分,那么只需基于 $u_1$ 和 $u_2$ 来计算,根据相关线性代数知识就可以得到:</p>
<p> $$<br> \begin{split}<br>&U^T U = u_1^T u_1 + u_2^T u_2 \\<br>&U^T r_1^T = {\begin{bmatrix} u_1^T & u_2^T \end{bmatrix}} {\begin{bmatrix} 4 \\ 5 \end{bmatrix}} = 4u_1^T + 5u_2^T<br> \end{split}<br>$$</p>
<p>有了这个基本求解思路后,考虑 $u$ 的维度为 $k$,可以在单机上完成上述求解,那么就可以在不同task里完成不同物品 $v^T$ 的计算,实现分布式求解,由打分矩阵可以得到如下图所示的关系图。</p>
<p><img src="/images/mllib-als-reduce-1.png" width="600" height="400" alt="mllib-als-reduce-1" align="center"></p>
<p>基于上述思路,就是要把有打分关联的 u 和 v 想办法放到同一个分区里,这样就可以在一个task里完成对 v 的求解,例如要求解 $v_1$,就必须把 $u_1$ 和 $u_2$ 以及其对应地打分放到同一个分区,才能利用上述公式求解。首先对uid和vid以Hash分区的方式分区,假设分区数均为2,那么分区后的大致情况如下图所示,$v_1$ 和 $v_3$ 在同一个分区中被求解,$v_2$ 和 $v_4$ 在同一个分区中被求解。</p>
<p><img src="/images/als-id-partition.png" width="600" height="400" alt="als-id-partition" align="center"></p>
<p>上面的图仅为感性认识图,实际上手头仅有的数据就是打分矩阵,可以通过一个RDD表示打分矩阵<code>ratings</code>,RDD中的每条记录为<code>(uid, vid, rating)</code>形式,由于是基于 $U$ 求解 $V$,把uid称之为<code>srcId</code>,vid称之为<code>dstId</code>,按照<code>srcId</code>和<code>dstId</code>的分区方式,将<code>ratings</code>重新分区,得到的RDD为<code>blockRatings</code>,其中的每条记录为<code>((srcBlockId, dstBlockId), RatingBlock)</code>形式,key为<code>srcId</code>和<code>dstId</code>对应的分区id组成的二元组,value(<code>RatingBlock</code>)包含一个三元组<code>(srcIds, dstIds, ratings)</code>。对于前面的打分关系,原始打分矩阵重新分区如下图所示。</p>
<p><img src="/images/als-ratings-partition.png" width="600" height="400" alt="als-ratings-partition" align="center"></p>
<p>对于 u 来说,是要将自身信息发给不同的 v,对于 v 来说,是要接收来自不同 u 的信息,例如,要将 $u_1$ 发给 $v_1$、$v_2$、$v_3$ ,$v_1$ 要接收 $u_1$ 和 $u_2$。那么基于上述重新分区后的打分RDD,分别得到关于 u 的出口信息<code>userOutBlocks</code>,以及 v 的入口信息<code>itemInBlocks</code>,就可以通过<code>join</code>将两者联系起来计算了。由于后面基于 $V$ 求 $U$,也需要求解关于 u 的入口信息<code>userInBlocks</code>,以及 v 的出口信息<code>itemOutBlocks</code>,所以一次性计算好并缓存起来。以计算 u 的入口信息和出口信息为例,在前面得到的重新分区后的<code>blockRatings</code>基础上求解,如下图所示。</p>
<p><img src="/images/als-user-inblock.png" width="600" height="400" alt="als-user-inblock" align="center"></p>
<p>首先通过一个<code>map</code>操作,将记录形式<code>((srcBlockId, dstBlockId), RatingBlock)</code>转换为<code>(srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings))</code>,其中<code>dstLocalIndices</code>为<code>dstIds</code>去重排序后,每个<code>dstId</code>的索引,最后根据<code>srcBlockId</code>做<code>groupByKey</code>,合并相同<code>srcBlockId</code>对应的value,合并过程中,对<code>dstLocalIndices</code>中的每个元素加上其对应的<code>dstBlockId</code>,这里做了一个优化,就是将<code>localIndex</code>和<code>blockId</code>用一个<code>Int</code>编码表示,同时采用类似<a href="http://blog.csdn.net/Em_dark/article/details/54313539" target="_blank" rel="noopener">CSC压缩编码</a>的方式,进一步压缩<code>srcIds</code>和<code>dstIds</code>的对应关系。这样就按照 uid 进行分区后,得到 u 的入口信息,即将跟 u 关联的 v 绑定在一起了。基于该入口信息,可以进一步得到 u 的出口信息,如下图所示。</p>
<p><img src="/images/als-user-outblock.png" width="600" height="400" alt="als-user-outblock" align="center"></p>
<p>在<code>userInBlocks</code>基础上根据<code>srcId</code>和<code>dstId</code>的对应关系,通过<code>map</code>操作将<code>(srcBlockId, (srcIds, dstPtrs, dstEncodedIndices, ratings))</code>形式的记录转换为<code>(srcBlockId, OutBlock)</code>得到<code>userOutBlocks</code>,其中<code>OutBlock</code>是一个二维数组,有<code>numDstBlock</code>行,每一行为<code>srcId</code>所在<code>srcBlockId</code>中的索引,意为当前<code>srcBlockId</code>应该往每个<code>dstBlockId</code>发送哪些用户信息。</p>
<p>同理,在<code>userInBlocks</code>基础上初始化用户信息,得到<code>userFactors</code>,如下图所示,其中 $u_1$、$u_2$、$u_3$为随机初始化的向量($1 \times k$)。</p>
<p><img src="/images/als-user-factors.png" width="600" height="400" alt="als-user-factors" align="center"></p>
<p>接着对<code>userOutBlocks</code>和<code>userFactors</code>做<code>join</code> 就可以模拟发送信息了,<code>userOutBlocks</code>中保存了应该往哪里发送的信息,<code>userFactors</code>中保存了用户信息,即一个掌握了方向,一个掌握了信息,如下图所示:</p>
<p><img src="/images/als-user-send.png" width="600" height="400" alt="als-user-send" align="center"></p>
<p>完成了从 u 到 v 的信息发送,后面就是基于 v 的入口信息来收集来自不同 u 的信息了,计算 v 的入口信息跟计算 u 的入口信息一样,只是先要把打分数据<code>blockRatings</code>的src和dst交换一下,如下图所示。</p>
<p><img src="/images/als-item-inblock.png" width="600" height="400" alt="als-item-inblock" align="center"></p>
<p>将<code>itemInBlocks</code>与前面的<code>userOut</code>做<code>join</code>,即可将具有相同<code>dstBlockId</code>的记录拉到一起,<code>userOut</code>中包含来自 u 的信息,<code>itemInBlocks</code>包含了与src的对应关系以及打分数据,针对每个 v 找到所有给它发送信息的 u,进而套最小二乘正规方程计算得到<code>itemFactors</code>。</p>
<p><img src="/images/als-item-factors.png" width="600" height="400" alt="als-item-factors" align="center"></p>
<p>得到<code>itemFactors</code>后可以以同样的方法基于 $V$ 求解 $U$,如此交替求解,直到最大迭代次数为止。</p>
<h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h2><p>ALS从基本原理上来看应该是很好理解的,但是要通过分布式计算来实现它,相对而言还是较为复杂的,本文重点阐述了Spark ML库中ALS的实现,要看懂以上计算流程,请务必结合源代码理解,凭空理解上述流程可能比较困难,在实际源码实现中,使用了很多优化技巧,例如使用在分区中的索引代替实际uid或vid,实现<code>Int</code>代替<code>Long</code>,使用数组等连续内存数据结构避免由于过多对象造成JVM GC后的内存碎片等。</p>
<p><span style="color:red"><em>转载请注明出处,本文永久链接:<a href="https://sharkdtu.github.io/posts/ml-als.html">https://sharkdtu.github.io/posts/ml-als.html</a></em></span></p>
]]></content>
<categories>
<category> 机器学习 </category>
</categories>
<tags>
<tag> 逻辑回归 </tag>
<tag> ALS </tag>
<tag> 协同过滤 </tag>
</tags>
</entry>
<entry>
<title><![CDATA[机器学习套路--逻辑回归]]></title>
<url>https://sharkdtu.github.io/posts/ml-lr.html</url>
<content type="html"><![CDATA[<p>逻辑回归常用于解决二分类问题,它将具有 $n$ 维特征的样本 $X$,经过线性加权后,通过 $sigmoid$ 函数转换得到一个概率值 $y$,预测时根据一个门限 $threshold$ (例如0.5)来划分类别,$y < threshold$ 为负类,$y \geq threshold$ 为正类。<a id="more"></a></p>
<h2 id="感性认识"><a href="#感性认识" class="headerlink" title="感性认识"></a>感性认识</h2><p>$sigmoid$ 函数 $\sigma (z) = \frac{1}{1+e^{-z}}$ 有如下图所示的漂亮S型曲线。</p>
<p><img src="/images/sigmoid.png" alt="sigmoid | center"></p>
<p>逻辑回归其实是在线性回归的基础上 $z = \sum_{i=1}^{n} {w_ix_i}$ ,借助 $sigmoid$ 函数将预测值压缩到0-1之间,实际上它是一种线性模型。其决策边界并不是上图中的S型曲线,而是一条直线或平面,如下图所示。</p>
<p><img src="/images/lr-boundary.png" width="328" height="200" alt="lr-boundary" align="center"></p>
<h2 id="基本套路"><a href="#基本套路" class="headerlink" title="基本套路"></a>基本套路</h2><p>机器学习问题,无外乎三点:模型,代价函数,优化算法。首先找到一个模型用于预测未知世界,然后针对该模型确定代价函数,以度量预测错误的程度,最后使用优化算法在已有的样本数据上不断地优化模型参数,来最小化代价函数。通常来说,用的最多的优化算法主要是梯度下降或拟牛顿法,计算过程都需要计算参数梯度值,下面仅从模型、代价函数以及参数梯度来描述一种机器学习算法。</p>
<p><strong>基本模型</strong>:<br>$$ h_ \theta(X) = \frac {1} {1 + e^{-\theta^T X}} $$</p>
<blockquote>
<p>$\theta$ 为模型参数,$X$ 为表示样本特征,它们均为 $n$ 维向量。</p>
</blockquote>
<p><strong>代价函数</strong>:<br>$$<br>J(\theta) = - \frac {1} {m} \sum_{i=1}^m \left( y^{(i)} logh_\theta(X^{(i)}) + (1-y^{(i)})(1-logh_\theta(X^{(i)}) \right)<br>$$</p>
<blockquote>
<p>上述公式也称之为交叉熵,$m$ 为样本个数,$(X^{(i)}, y^{(i)})$ 为第 $i$ 个样本。</p>
</blockquote>
<p><strong>参数梯度</strong>:<br>$$<br>\bigtriangledown_{\theta_j} J(\theta) = \frac {1} {m} \sum_{i=1}^m \left[ \left( y^{(i)} - h_\theta(X^{(i)}) \right) X^{(i)}_j \right]<br>$$</p>
<blockquote>
<p>$\theta_j$ 表示第 $j$ 个参数,$X^{(i)}_j$ 表示样本 $X^{(i)}$ 的第 $j$ 个特征值。</p>
</blockquote>
<h2 id="应用套路"><a href="#应用套路" class="headerlink" title="应用套路"></a>应用套路</h2><p>在实际应用时,基于上述基本套路可能会有些小变化,下面还是从模型、代价函数以及参数梯度来描述。</p>
<p>通常来说在模型中会加个偏置项,模型变成如下形式:<br>$$ h_ {\theta,b}(X) = \frac {1} {1 + e^{-(\theta^T X + b)}} $$</p>
<p>为了防止过拟合,一般会在代价函数上增加正则项,常见的正则方法参考前面的文章<a href="http://sharkdtu.com/posts/ml-linear-regression.html#正则化" target="_blank" rel="noopener">“线性回归”</a>。</p>
<p>加上正则项后,代价函数变成如下形式:<br>$$<br>\begin{split}<br>J(\theta, b) =& - \frac {1} {m} \sum_{i=1}^m \left( y^{(i)} log h_{\theta,b}(X^{(i)}) + (1-y^{(i)})(1-log h_{\theta,b}(X^{(i)}) \right) \\<br>&+ \frac {\lambda} {m} \left(\alpha \left \| \theta \right \| + \frac {1-\alpha} {2} {\left \| \theta \right \|}^2 \right)<br>\end{split}<br>$$</p>
<blockquote>
<p> $\lambda$ 为正则项系数,$\alpha$ 为ElasticNet参数,他们都是可调整的超参数, 当 $\alpha = 0$,则为L2正则, 当 $\alpha = 1$,则为L1正则。L1正则项增加 $1/m$ 以及L2正则项增加 $1/2m$ 系数,仅仅是为了使求导后的形式规整一些。</p>
</blockquote>
<p>由于 $sigmoid$ 函数在两端靠近极值点附近特别平缓,如果使用梯度下降优化算法,收敛非常慢,通常实际应用时,会使用拟牛顿法,它是沿着梯度下降最快的方向搜索,收敛相对较快,常见的拟牛顿法为<a href="http://blog.csdn.net/itplus/article/details/21896453" target="_blank" rel="noopener">L-BFGS</a>和<a href="http://research.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf" target="_blank" rel="noopener">OWL-QN</a>。L-BFGS只能处理可导的代价函数,由于L1正则项不可导,如果 $\alpha$ 不为0,那么不能使用L-BFGS,OWL-QN是基于L-BFGS算法的可用于求解L1正则的算法,所以当 $\alpha$ 不为0,可以使用OWL-QN。基于上述代价函数,下面仅列出包含L2正则项时的参数梯度:<br>$$<br>\begin{split}<br>\bigtriangledown_{\theta_j} J(\theta, b) &= \frac {1} {m} \sum_{i=1}^m \left( y^{(i)} - h_{\theta,b} (X^{(i)}) \right) X^{(i)}_j + \frac {\beta} {m} {\theta_j}^\ast \\<br>\bigtriangledown_b J(\theta, b) &= \frac {1} {m} \sum_{i=1}^m \left( y^{(i)} - h_{\theta,b} (X^{(i)}) \right)<br>\end{split}<br>$$</p>
<blockquote>
<p>${\theta_j}^\ast$ 为上一次迭代得到的参数值。</p>
</blockquote>
<h2 id="Softmax"><a href="#Softmax" class="headerlink" title="Softmax"></a>Softmax</h2><p>上述逻辑回归为二元逻辑回归,只能解决二分类问题,更一般地,可以推广到多元逻辑回归,用于解决多分类问题,一般将其称之为softmax,其模型、代价函数以及参数梯度描述如下。</p>
<p><strong>基本模型</strong><br>$$<br>H_\Theta(X) = \frac {1} {\sum_{j=1}^k exp(\Theta_j^T X)}<br>\begin{bmatrix}<br>exp(\Theta_1^T X)\\<br>exp(\Theta_2^T X)\\<br>…\\<br>exp(\Theta_k^T X)<br>\end{bmatrix}<br>$$</p>
<blockquote>
<p>$H_ \Theta(X)$ 是一个 $k$ 维向量,$k$ 为类别的个数,对于一个实例 $X$ ,经过上述模型输出 $k$ 个概率值,表示预测不同类别的概率,不难看出,输出的 $k$ 个概率值之和为1。模型中的参数则可以抽象为如下矩阵形式:<br> $$ \Theta = \begin{bmatrix}-\Theta_1^T-\\ -\Theta_2^T-\\ \cdots \\ -\Theta_k^T-\end{bmatrix} $$ $\Theta_j$ 表示第 $j$ 个参数向量,如果参数中带有偏置项,那么总共有 $k \times (n+1)$ 个参数。</p>
</blockquote>
<p><strong>代价函数</strong><br>$$ J(\Theta) = - \frac {1} {m} \left[\sum_{i=1}^m \sum_{j=1}^k 1 \left\{ y^{(i)} = j \right\} log \frac {exp(\Theta_j^T X)} {\sum_{l=1}^k exp(\Theta_l^T X)} \right] $$</p>
<blockquote>
<p>$1 \left\{ y^{(i)} = j \right\}$ 为示性函数,表示 $y^{(i)} = j$ 为真时,其结果为1,否则为0.</p>
</blockquote>
<p><strong>参数梯度</strong><br>$$<br>\begin{split}<br>& P\left( y^{(i)} = j \mid X^{(i)}, \Theta \right) = \frac {exp(\Theta_j^T X)} {\sum_{l=1}^k exp(\Theta_l^T X)} \\<br>& \bigtriangledown_{\Theta_j} J(\Theta) = \frac {1} {m} \sum_{i=1}^m \left[ \left( 1 \left\{ y^{(i)} = j \right\} - P\left( y^{(i)} = j \mid X^{(i)}, \Theta \right) \right ) X^{(i)} \right]<br>\end{split}<br>$$</p>
<blockquote>
<p>$P\left( y^{(i)} = j \mid X^{(i)}, \Theta \right)$ 表示将 $X^{(i)}$ 预测为第 $j$ 类的概率,注意 $\bigtriangledown_ {\Theta_j} J(\Theta)$ 是一个向量。</p>
</blockquote>
<h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h2><p>虽然逻辑回归是线性模型,看起来很简单,但是被应用到大量实际业务中,尤其在计算广告领域它一直是一颗闪耀的明珠,总结其优缺点如下:</p>
<ul>
<li>优点:计算代价低,速度快,易于理解和实现。</li>
<li>缺点:容易欠拟合,分类的精度可能不高。</li>
</ul>
<p><span style="color:red"><em>转载请注明出处,本文永久链接:<a href="https://sharkdtu.github.io/posts/ml-lr.html">https://sharkdtu.github.io/posts/ml-lr.html</a></em></span></p>
]]></content>
<categories>
<category> 机器学习 </category>
</categories>
<tags>
<tag> 逻辑回归 </tag>
<tag> LR </tag>
<tag> Softmax </tag>
</tags>
</entry>
<entry>
<title><![CDATA[机器学习套路--线性回归]]></title>
<url>https://sharkdtu.github.io/posts/ml-linear-regression.html</url>
<content type="html"><![CDATA[<p>线性回归可以说是机器学习中最简单,最基础的机器学习算法,它是一种监督学习方法,可以被用来解决回归问题。它用一条直线(或者高维空间中的平面)来拟合训练数据,进而对未知数据进行预测。<a id="more"></a></p>
<p><img src="/images/linear_regression.png" alt="Alt text | center"></p>
<h2 id="基本套路"><a href="#基本套路" class="headerlink" title="基本套路"></a>基本套路</h2><p>机器学习方法,无外乎三点:模型,代价函数,优化算法。首先找到一个模型用于预测未知世界,然后针对该模型确定代价函数,以度量预测错误的程度,最后使用优化算法在已有的样本数据上不断地优化模型参数,来最小化代价函数。通常来说,用的最多的优化算法主要是梯度下降或拟牛顿法(<a href="http://blog.csdn.net/itplus/article/details/21896453" target="_blank" rel="noopener">L-BFGS</a>或<a href="http://research.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf" target="_blank" rel="noopener">OWL-QN</a>),计算过程都需要计算参数梯度值,下面仅从模型、代价函数以及参数梯度来描述一种机器学习算法。</p>
<p><strong>基本模型</strong>:<br>$$ \begin{split}<br>h_ \theta(X) &= \theta^T X \\<br>&= \theta_0 + \theta_1 x_1 + \theta_2 x_2 + \cdots + \theta_n x_n<br>\end{split} $$</p>
<blockquote>
<p>$X$ 为表示样本特征,为 $n$ 维向量,$\theta$ 为模型参数,为 $n+1$ 维向量,包括一个偏置 $\theta_0$</p>
</blockquote>
<p><strong>代价函数</strong>:<br>$$ J(\theta) = \frac {1} {2m} \sum_{i=1}^m \left ( y^{(i)}-h_\theta(X) \right ) ^2 $$</p>
<blockquote>
<p>上述公式也称之为平方误差,$m$ 为样本个数,$(X^{(i)}, y^{(i)})$ 为第 $i$ 个样本。</p>
</blockquote>
<p><strong>参数梯度</strong>:<br>$$ \bigtriangledown_{\theta_j} J(\theta) = \frac {1} {m} \sum_{i=1}^m \left[\left ( y^{(i)} - h_ \theta(X^{(i)}) \right ) X^{(i)}_j \right] $$</p>
<blockquote>
<p>$\theta_j$ 表示第 $j$ 个参数,$X^{(i)}_j$ 表示样本 $X^{(i)}$ 的第 $j$ 个特征值。</p>
</blockquote>
<p>上述描述是按照常规的机器学习方法来描述线性回归,模型参数一般是通过梯度下降或拟牛顿法优化迭代得到,其实线性回归问题是可解的,只是在样本维度较大时很难求解才使用优化迭代的方法来逼近,如果样本维度并不是很大的情况下,是可以解方程一次性得到样本参数。</p>
<p><strong>最小二乘</strong>:<br>$$ \theta = {\left( X^T X \right)} ^{-1} X^T y$$</p>
<blockquote>
<p>注意这里 $X$ 为 $m \times n$ 矩阵,$n$ 为特征维度,$m$ 为样本个数; $y$ 为 $m \times 1$ 向量,表示每个样本的标签。</p>
</blockquote>
<p><strong>加权最小二乘</strong>:<br>$$ \theta = {\left( X^T W X \right)} ^{-1} X^T W y$$</p>
<blockquote>
<p>$W$ 为 $m \times m$ 对角矩阵,对角线上的每个值表示对应样本实例的权重。</p>
</blockquote>
<h2 id="应用套路"><a href="#应用套路" class="headerlink" title="应用套路"></a>应用套路</h2><p>在实际应用时,基于上述基本套路可能会有些小变化,下面首先还是从模型、代价函数以及参数梯度来描述。把基本套路中模型公式中的 $\theta_0$ 改成 $b$,表示截距项,模型变成如下形式:<br>$$<br>\begin{split}<br>h_{\theta,b}(X) &= \theta^T X + b \\<br>&= \theta_1 x_1 + \theta_2 x_2 + \cdots + \theta_n x_n + b<br>\end{split}<br>$$</p>
<h3 id="正则化"><a href="#正则化" class="headerlink" title="正则化"></a>正则化</h3><p>为了防止过拟合,一般会在代价函数上增加正则项,常见的正则方法有:</p>
<ul>
<li>L1: $\lambda \left \| \theta \right \|$ , 也称之为套索回归(Lasso),可将参数稀疏化,但是不可导</li>
<li>L2: $\frac {\lambda} {2} {\left \| \theta \right \|}^2$,也称之为岭回归(Ridge),可将参数均匀化,可导</li>
<li>L1&L2: $\lambda \left(\alpha \left \| \theta \right \| + \frac {1-\alpha} {2} {\left \| \theta \right \|}^2 \right)$, 也称之为弹性网络(ElasticNet),具备L1&L2的双重特性</li>
</ul>
<p>加上正则项后,代价函数变成如下形式:<br>$$<br>\begin{split}<br>J(\theta, b) =& \frac {1} {2m} \sum_{i=1}^m \left ( y^{(i)}-h_{\theta,b}(X) \right ) ^2 + \frac {\lambda} {m} \left(\alpha \left \| \theta \right \| + \frac {1-\alpha} {2} {\left \| \theta \right \|}^2 \right)<br>\end{split}<br>$$</p>
<blockquote>
<p> $\lambda$ 为正则项系数,$\alpha$ 为ElasticNet参数,他们都是可调整的超参数, 当 $\alpha = 0$,则为L2正则, 当 $\alpha = 1$,则为L1正则。L1正则项增加 $1/m$ 以及L2正则项增加 $1/2m$ 系数,仅仅是为了使求导后的形式规整一些。</p>
</blockquote>
<p>由于L1正则项不可导,如果 $\alpha$ 不为0,那么不能简单的套用梯度下降或L-BFGS,需要采用借助<a href="http://blog.csdn.net/jbb0523/article/details/52103257" target="_blank" rel="noopener">软阈值(Soft Thresholding)</a>函数解决,如果是使用拟牛顿法,可以采用<a href="http://research.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf" target="_blank" rel="noopener">OWL-QN</a>,它是基于L-BFGS算法的可用于求解L1正则的算法。基于上述代价函数,下面仅列出包含L2正则项时的参数梯度:<br>$$<br>\begin{split}<br>\bigtriangledown_{\theta_j} J(\theta, b) &= \frac {1} {m} \sum_{i=1}^m \left ( y^{(i)} - h_{\theta,b} (X^{(i)}) \right ) X^{(i)}_j + \frac {\lambda (1-\alpha)} {m} {\theta_j}^\ast \\<br>\bigtriangledown_b J(\theta, b) &= \frac {1} {m} \sum_{i=1}^m \left( y^{(i)} - h_{\theta,b} (X^{(i)}) \right)<br>\end{split}<br>$$</p>
<blockquote>
<p>${\theta_j}^\ast$ 为上一次迭代得到的参数值。</p>
</blockquote>
<p>实际上,使用L2正则,是将前面所述的最小二乘方程改成如下形式:<br>$$ \theta = {\left( X^T X + kI \right)}^{-1} X^T y$$</p>
<blockquote>
<p>这样可以降低矩阵 $X^T X $ 奇异的可能,否则就不能求逆了。</p>
</blockquote>
<h3 id="标准化"><a href="#标准化" class="headerlink" title="标准化"></a>标准化</h3><p>一般来说,一个特征的值可能在区间 $(0, 1)$ 之间,另一特征的值可能在区间$(-\infty, \infty)$ ,这就是所谓的样本特征之间量纲不同,这样会导致优化迭代过程中的不稳定。当参数有不同初始值时,其收敛速度差异性较大,得到的结果可能也有较大的差异性,如下图所示,可以看到X和Y这两个变量的变化幅度不一致,如果直接使用梯度下降来优化迭代,那么量纲较大的特征信息量会被放大,量纲较小的特征信息量会被缩小。</p>
<p><img src="/images/ml-no-normalize.png" width="400" height="230" alt="ml-no-normalize" align="center"></p>
<p>所以一般要对数据作无量纲化处理,通常会采用标准化方法 $(x-u)/\sigma$ ,得到如下数据分布,这样无论从哪个点开始,其迭代方向的抖动都不会太大,每个特征的信息也不至于被放大和缩小。</p>
<p><img src="/images/ml-normalize.png" width="400" height="230" alt="ml-normalize.png" align="center"></p>
<h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h2><p>虽然线性回归现在可能很少用于解决实际问题,但是因为其简单易懂,学习它有助于对机器学习有个入门级的初步掌握,了解机器学习的套路等。</p>
<p><span style="color:red"><em>转载请注明出处,本文永久链接:<a href="https://sharkdtu.github.io/posts/ml-linear-regression.html">https://sharkdtu.github.io/posts/ml-linear-regression.html</a></em></span></p>
]]></content>
<categories>
<category> 机器学习 </category>
</categories>
<tags>
<tag> 线性回归 </tag>
<tag> LinearRegression </tag>
</tags>
</entry>
<entry>
<title><![CDATA[Spark SQL 之 Join 实现]]></title>
<url>https://sharkdtu.github.io/posts/spark-sql-join.html</url>
<content type="html"><![CDATA[<p>Join作为SQL中一个重要语法特性,几乎所有稍微复杂一点的数据分析场景都离不开Join,如今Spark SQL(<code>Dataset/DataFrame</code>)已经成为Spark应用程序开发的主流,作为开发者,我们有必要了解Join在Spark中是如何组织运行的。<a id="more"></a></p>
<h2 id="SparkSQL总体流程介绍"><a href="#SparkSQL总体流程介绍" class="headerlink" title="SparkSQL总体流程介绍"></a>SparkSQL总体流程介绍</h2><p>在阐述Join实现之前,我们首先简单介绍SparkSQL的总体流程,一般地,我们有两种方式使用SparkSQL,一种是直接写sql语句,这个需要有元数据库支持,例如Hive等,另一种是通过<code>Dataset/DataFrame</code>编写Spark应用程序。如下图所示,sql语句被语法解析(SQL AST)成查询计划,或者我们通过<code>Dataset/DataFrame</code>提供的APIs组织成查询计划,查询计划分为两大类:逻辑计划和物理计划,这个阶段通常叫做逻辑计划,经过语法分析(Analyzer)、一系列查询优化(Optimizer)后得到优化后的逻辑计划,最后被映射成物理计划,转换成RDD执行。</p>
<p><img src="/images/spark-sql-overview.png" width="600" height="400" alt="spark-sql-overview" align="center"></p>
<p>更多关于SparkSQL的解析与执行请参考文章<a href="http://www.cnblogs.com/hseagle/p/3752917.html" target="_blank" rel="noopener">【sql的解析与执行】</a>。对于语法解析、语法分析以及查询优化,本文不做详细阐述,本文重点介绍Join的物理执行过程。</p>
<h2 id="Join基本要素"><a href="#Join基本要素" class="headerlink" title="Join基本要素"></a>Join基本要素</h2><p>如下图所示,Join大致包括三个要素:Join方式、Join条件以及过滤条件。其中过滤条件也可以通过AND语句放在Join条件中。</p>
<p><img src="/images/spark-sql-join-overview.png" width="600" height="400" alt="spark-sql-join-overview" align="center"></p>
<p>Spark支持所有类型的Join,包括:</p>
<ul>
<li>inner join</li>
<li>left outer join</li>
<li>right outer join</li>
<li>full outer join</li>
<li>left semi join</li>
<li>left anti join</li>
</ul>
<p>下面分别阐述这几种Join的实现。</p>
<h2 id="Join基本实现流程"><a href="#Join基本实现流程" class="headerlink" title="Join基本实现流程"></a>Join基本实现流程</h2><p>总体上来说,Join的基本实现流程如下图所示,Spark将参与Join的两张表抽象为流式遍历表(<code>streamIter</code>)和查找表(<code>buildIter</code>),通常<code>streamIter</code>为大表,<code>buildIter</code>为小表,我们不用担心哪个表为<code>streamIter</code>,哪个表为<code>buildIter</code>,这个spark会根据join语句自动帮我们完成。</p>
<p><img src="/images/spark-sql-join-basic.png" width="600" height="400" alt="spark-sql-join-basic" align="center"></p>
<p>在实际计算时,spark会基于<code>streamIter</code>来遍历,每次取出<code>streamIter</code>中的一条记录<code>rowA</code>,根据Join条件计算<code>keyA</code>,然后根据该<code>keyA</code>去<code>buildIter</code>中查找所有满足Join条件(<code>keyB==keyA</code>)的记录<code>rowBs</code>,并将<code>rowBs</code>中每条记录分别与<code>rowA</code>join得到join后的记录,最后根据过滤条件得到最终join的记录。</p>
<p>从上述计算过程中不难发现,对于每条来自<code>streamIter</code>的记录,都要去<code>buildIter</code>中查找匹配的记录,所以<code>buildIter</code>一定要是查找性能较优的数据结构。spark提供了三种join实现:sort merge join、broadcast join以及hash join。</p>
<h3 id="sort-merge-join实现"><a href="#sort-merge-join实现" class="headerlink" title="sort merge join实现"></a>sort merge join实现</h3><p>要让两条记录能join到一起,首先需要将具有相同key的记录在同一个分区,所以通常来说,需要做一次shuffle,map阶段根据join条件确定每条记录的key,基于该key做shuffle write,将可能join到一起的记录分到同一个分区中,这样在shuffle read阶段就可以将两个表中具有相同key的记录拉到同一个分区处理。前面我们也提到,对于<code>buildIter</code>一定要是查找性能较优的数据结构,通常我们能想到hash表,但是对于一张较大的表来说,不可能将所有记录全部放到hash表中,另外也可以对<code>buildIter</code>先排序,查找时按顺序查找,查找代价也是可以接受的,我们知道,spark shuffle阶段天然就支持排序,这个是非常好实现的,下面是sort merge join示意图。</p>
<p><img src="/images/spark-sql-sort-join.png" width="600" height="400" alt="spark-sql-sort-join" align="center"></p>
<p>在shuffle read阶段,分别对<code>streamIter</code>和<code>buildIter</code>进行merge sort,在遍历<code>streamIter</code>时,对于每条记录,都采用顺序查找的方式从<code>buildIter</code>查找对应的记录,由于两个表都是排序的,每次处理完<code>streamIter</code>的一条记录后,对于<code>streamIter</code>的下一条记录,只需从<code>buildIter</code>中上一次查找结束的位置开始查找,所以说每次在<code>buildIter</code>中查找不必重头开始,整体上来说,查找性能还是较优的。</p>
<h3 id="broadcast-join实现"><a href="#broadcast-join实现" class="headerlink" title="broadcast join实现"></a>broadcast join实现</h3><p>为了能具有相同key的记录分到同一个分区,我们通常是做shuffle,那么如果<code>buildIter</code>是一个非常小的表,那么其实就没有必要大动干戈做shuffle了,直接将<code>buildIter</code>广播到每个计算节点,然后将<code>buildIter</code>放到hash表中,如下图所示。</p>
<p><img src="/images/spark-sql-broadcast-join.png" width="600" height="400" alt="spark-sql-broadcast-join" align="center"></p>
<p>从上图可以看到,不用做shuffle,可以直接在一个map中完成,通常这种join也称之为map join。那么问题来了,什么时候会用broadcast join实现呢?这个不用我们担心,spark sql自动帮我们完成,当<code>buildIter</code>的估计大小不超过参数<code>spark.sql.autoBroadcastJoinThreshold</code>设定的值(默认10M),那么就会自动采用broadcast join,否则采用sort merge join。</p>
<h3 id="hash-join实现"><a href="#hash-join实现" class="headerlink" title="hash join实现"></a>hash join实现</h3><p>除了上面两种join实现方式外,spark还提供了hash join实现方式,在shuffle read阶段不对记录排序,反正来自两格表的具有相同key的记录会在同一个分区,只是在分区内不排序,将来自<code>buildIter</code>的记录放到hash表中,以便查找,如下图所示。</p>
<p><img src="/images/spark-sql-hash-join.png" width="600" height="400" alt="spark-sql-hash-join" align="center"></p>
<p>不难发现,要将来自<code>buildIter</code>的记录放到hash表中,那么每个分区来自<code>buildIter</code>的记录不能太大,否则就存不下,默认情况下hash join的实现是关闭状态,如果要使用hash join,必须满足以下四个条件:</p>
<ul>
<li><code>buildIter</code>总体估计大小超过<code>spark.sql.autoBroadcastJoinThreshold</code>设定的值,即不满足broadcast join条件</li>
<li>开启尝试使用hash join的开关,<code>spark.sql.join.preferSortMergeJoin=false</code></li>
<li>每个分区的平均大小不超过<code>spark.sql.autoBroadcastJoinThreshold</code>设定的值,即shuffle read阶段每个分区来自<code>buildIter</code>的记录要能放到内存中</li>
<li><code>streamIter</code>的大小是<code>buildIter</code>三倍以上</li>
</ul>
<p>所以说,使用hash join的条件其实是很苛刻的,在大多数实际场景中,即使能使用hash join,但是使用sort merge join也不会比hash join差很多,所以尽量使用hash</p>
<p>下面我们分别阐述不同Join方式的实现流程。</p>
<h2 id="inner-join"><a href="#inner-join" class="headerlink" title="inner join"></a>inner join</h2><p>inner join是一定要找到左右表中满足join条件的记录,我们在写sql语句或者使用<code>DataFrame</code>时,可以不用关心哪个是左表,哪个是右表,在spark sql查询优化阶段,spark会自动将大表设为左表,即<code>streamIter</code>,将小表设为右表,即<code>buildIter</code>。这样对小表的查找相对更优。其基本实现流程如下图所示,在查找阶段,如果右表不存在满足join条件的记录,则跳过。</p>
<p><img src="/images/spark-sql-inner-join.png" width="600" height="400" alt="spark-sql-inner-join" align="center"></p>
<h2 id="left-outer-join"><a href="#left-outer-join" class="headerlink" title="left outer join"></a>left outer join</h2><p>left outer join是以左表为准,在右表中查找匹配的记录,如果查找失败,则返回一个所有字段都为null的记录。我们在写sql语句或者使用<code>DataFrmae</code>时,一般让大表在左边,小表在右边。其基本实现流程如下图所示。</p>
<p><img src="/images/spark-sql-leftouter-join.png" width="600" height="400" alt="spark-sql-leftouter-join" align="center"></p>
<h2 id="right-outer-join"><a href="#right-outer-join" class="headerlink" title="right outer join"></a>right outer join</h2><p>right outer join是以右表为准,在左表中查找匹配的记录,如果查找失败,则返回一个所有字段都为null的记录。所以说,右表是<code>streamIter</code>,左表是<code>buildIter</code>,我们在写sql语句或者使用<code>DataFrame</code>时,一般让大表在右边,小表在左边。其基本实现流程如下图所示。</p>
<p><img src="/images/spark-sql-rightouter-join.png" width="600" height="400" alt="spark-sql-rightouter-join" align="center"></p>
<h2 id="full-outer-join"><a href="#full-outer-join" class="headerlink" title="full outer join"></a>full outer join</h2><p>full outer join相对来说要复杂一点,总体上来看既要做left outer join,又要做right outer join,但是又不能简单地先left outer join,再right outer join,最后<code>union</code>得到最终结果,因为这样最终结果中就存在两份inner join的结果了。因为既然完成left outer join又要完成right outer join,所以full outer join仅采用sort merge join实现,左边和右表既要作为<code>streamIter</code>,又要作为<code>buildIter</code>,其基本实现流程如下图所示。</p>
<p><img src="/images/spark-sql-fullouter-join.png" width="600" height="400" alt="spark-sql-fullouter-join" align="center"></p>
<p>由于左表和右表已经排好序,首先分别顺序取出左表和右表中的一条记录,比较key,如果key相等,则join<code>rowA</code>和<code>rowB</code>,并将<code>rowA</code>和<code>rowB</code>分别更新到左表和右表的下一条记录;如果<code>keyA<keyB</code>,则说明右表中没有与左表<code>rowA</code>对应的记录,那么join<code>rowA</code>与<code>nullRow</code>,紧接着,<code>rowA</code>更新到左表的下一条记录;如果<code>keyA>keyB</code>,则说明左表中没有与右表<code>rowB</code>对应的记录,那么join<code>nullRow</code>与<code>rowB</code>,紧接着,<code>rowB</code>更新到右表的下一条记录。如此循环遍历直到左表和右表的记录全部处理完。</p>
<h2 id="left-semi-join"><a href="#left-semi-join" class="headerlink" title="left semi join"></a>left semi join</h2><p>left semi join是以左表为准,在右表中查找匹配的记录,如果查找成功,则仅返回左边的记录,否则返回<code>null</code>,其基本实现流程如下图所示。</p>
<p><img src="/images/spark-sql-semi-join.png" width="600" height="400" alt="spark-sql-semi-join" align="center"></p>
<h2 id="left-anti-join"><a href="#left-anti-join" class="headerlink" title="left anti join"></a>left anti join</h2><p>left anti join与left semi join相反,是以左表为准,在右表中查找匹配的记录,如果查找成功,则返回<code>null</code>,否则仅返回左边的记录,其基本实现流程如下图所示。</p>
<p><img src="/images/spark-sql-anti-join.png" width="600" height="400" alt="spark-sql-anti-join" align="center"></p>
<h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h2><p>Join是数据库查询中一个非常重要的语法特性,在数据库领域可以说是“得join者得天下”,SparkSQL作为一种分布式数据仓库系统,给我们提供了全面的join支持,并在内部实现上无声无息地做了很多优化,了解join的实现将有助于我们更深刻的了解我们的应用程序的运行轨迹。</p>
<p><span style="color:red"><em>转载请注明出处,本文永久链接:<a href="https://sharkdtu.github.io/posts/spark-sql-join.html">https://sharkdtu.github.io/posts/spark-sql-join.html</a></em></span></p>
]]></content>
<categories>
<category> spark </category>
</categories>
<tags>
<tag> spark </tag>
<tag> 分布式计算 </tag>
<tag> sql </tag>
<tag> join </tag>
</tags>
</entry>
<entry>
<title><![CDATA[从PageRank Example谈Spark应用程序调优]]></title>
<url>https://sharkdtu.github.io/posts/spark-app-optimize.html</url>
<content type="html"><![CDATA[<p>最近在做<a href="http://prof.ict.ac.cn/BigDataBench" target="_blank" rel="noopener">BigData-Benchmark</a>中PageRank测试,在测试时,发现有很多有趣的调优点,想到这些调优点可能是普遍有效的,现把它整理出来一一分析,以供大家参考。<a href="http://prof.ict.ac.cn/BigDataBench" target="_blank" rel="noopener">BigData-Benchmark</a>中的Spark PageRank采用的是Spark开源代码examples包里的PageRank的代码,原理及代码实现都比较简单,下面我简单地介绍下。<a id="more"></a></p>
<h2 id="PageRank基本原理介绍"><a href="#PageRank基本原理介绍" class="headerlink" title="PageRank基本原理介绍"></a>PageRank基本原理介绍</h2><p>PageRank的作用是评价网页的重要性,除了应用于搜索结果的排序之外,在其他领域也有广泛的应用,例如图算法中的节点重要度等。假设一个由4个页面组成的网络如下图所示,B链接到A、C,C连接到A,D链接到所有页面。</p>
<p><img src="/images/pagerank-graph-example.png" alt="pagerank-graph-example | center"></p>
<p>那么A的PR(PageRank)值分别来自B、C、D的贡献之和,由于B除了链接到A还链接到C,D除了链接到A还链接B、C,所以它们对A的贡献需要平摊,计算公式为:</p>
<p>$$ PR(A) = \frac {PR(B)} {2} + \frac {PR(C)} {1} + \frac {PR(D)} {3} \tag{1-1}$$</p>
<p>简单来说,就是根据链出总数平分一个页面的PR值:</p>
<p>$$ PR(A) = \frac {PR(B)} {L(B)} + \frac {PR(C)} {L(C)} + \frac {PR(D)} {L(D)} \tag{1-2}$$</p>
<p>对于上图中的A页面来说,它没有外链,这样计算迭代下去,PR值会全部收敛到A上去,所以实际上需要对这类没有外链的页面加上系数:</p>
<p>$$ PR(A) = d(\frac {PR(B)} {L(B)} + \frac {PR(C)} {L(C)} + \frac {PR(D)} {L(D)} + …) + \frac {1 - d} {N} \tag{1-3}$$</p>
<h2 id="Spark-PageRank-Example"><a href="#Spark-PageRank-Example" class="headerlink" title="Spark PageRank Example"></a>Spark PageRank Example</h2><p>Spark Examples中给出了一个简易的实现,后续讨论的相关优化都是基于该简易实现,所以并不一定可以用来解决实际PageRank问题,这里仅用于引出关于Spark调优的思考。下面是原始版本的实现代码,我们称之为V1。</p>
<figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">/**</span></span><br><span class="line"><span class="comment"> * Computes the PageRank of URLs from an input file.</span></span><br><span class="line"><span class="comment"> * Input file should be in format of:</span></span><br><span class="line"><span class="comment"> * URL neighbor URL</span></span><br><span class="line"><span class="comment"> * URL neighbor URL</span></span><br><span class="line"><span class="comment"> * URL neighbor URL</span></span><br><span class="line"><span class="comment"> * ...</span></span><br><span class="line"><span class="comment"> * where URL and their neighbor URL are separated by space(s).</span></span><br><span class="line"><span class="comment"> */</span></span><br><span class="line"><span class="keyword">val</span> lines = sc.textFile(inputPath)</span><br><span class="line"></span><br><span class="line"><span class="keyword">val</span> links = lines.map { s =></span><br><span class="line"> <span class="keyword">val</span> parts = s.split(<span class="string">"\\s+"</span>)</span><br><span class="line"> (parts(<span class="number">0</span>), parts(<span class="number">1</span>))</span><br><span class="line">}.distinct().groupByKey().cache()</span><br><span class="line"></span><br><span class="line"><span class="keyword">var</span> ranks = links.mapValues(v => <span class="number">1.0</span>)</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> (i <- <span class="number">1</span> to iters) {</span><br><span class="line"> <span class="keyword">val</span> contribs = links.join(ranks).values.flatMap {</span><br><span class="line"> <span class="keyword">case</span> (urls, rank) =></span><br><span class="line"> <span class="keyword">val</span> size = urls.size</span><br><span class="line"> urls.map(url => (url, rank / size))</span><br><span class="line"> }</span><br><span class="line"> ranks = contribs.reduceByKey(_ + _).mapValues(<span class="number">0.15</span> + <span class="number">0.85</span> * _)</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="comment">// Force action, like ranks.saveAsTextFile(outputPath)</span></span><br><span class="line">ranks.foreach(_ => <span class="type">Unit</span>)</span><br></pre></td></tr></table></figure>
<p>上面的代码应该不难理解,它首先通过<code>groupByKey</code>得到每个url链接的urls列表,初始化每个url的初始rank为1.0,然后通过<code>join</code>将每个url的rank均摊到其链接的urls上,最后通过<code>reduceByKey</code>规约来自每个url贡献的rank,经过若干次迭代后得到最终的<code>ranks</code>,为了方便测试,上面代码29行我改成了一个空操作的action,用于触发计算。</p>
<h2 id="优化一-Cache-amp-Checkpoint"><a href="#优化一-Cache-amp-Checkpoint" class="headerlink" title="优化一(Cache&Checkpoint)"></a><span id="opt1">优化一(Cache&Checkpoint)</span></h2><p>从原始版本的代码来看,有些童鞋可能会觉得有必要对<code>ranks</code>做cache,避免每次迭代重计算,我们不妨先运行下原始代码,看看是否真的有必要,下图是指定迭代次数为3时的Job DAG图,其中蓝色的点表示被cache过。</p>
<p><img src="/images/pagerank-iter-3-dag.png" alt="pagerank-iter-3-dag | center"></p>
<p>从上图可以看到,<code>ranks</code>没有被cache,3次迭代计算是在一个job里一气呵成的,所以没必要对<code>ranks</code>做cache,因为从整个代码来看,在迭代循环里没有出现action方法,所以迭代循环中不会触发job,仅仅是组织RDD之间的依赖关系。</p>
<p>但是,一般来说迭代次数都比较大,如果迭代1000甚至10000次,上述RDD依赖关系将变得非常长。一方面会增加driver的维护压力,很可能导致driver OOM;另一方面可能导致失败重算,单个task失败后,会根据RDD的依赖链从头开始计算。所以从容错以及可用性来说,上述代码实现是不可取的。所幸,Spark提供了checkpoint机制,来实现断链及中间结果持久化。</p>
<p>使用checkpoint,我们来改造上述迭代循环,在每迭代若干次后做一次checkpoint,保存中间结果状态,并切断RDD依赖关系链,迭代循环代码改造如下:<br><figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br></pre></td><td class="code"><pre><span class="line">...</span><br><span class="line"><span class="keyword">var</span> lastCheckpointRanks: <span class="type">RDD</span>[(<span class="type">String</span>, <span class="type">Double</span>)] = <span class="literal">null</span></span><br><span class="line"><span class="keyword">for</span> (i <- <span class="number">1</span> to iters) {</span><br><span class="line"> <span class="keyword">val</span> contribs = links.join(ranks).values.flatMap {</span><br><span class="line"> <span class="keyword">case</span> (urls, rank) =></span><br><span class="line"> <span class="keyword">val</span> size = urls.size</span><br><span class="line"> urls.map(url => (url, rank / size))</span><br><span class="line"> }</span><br><span class="line"> ranks = contribs.reduceByKey(_ + _).mapValues(<span class="number">0.15</span> + <span class="number">0.85</span> * _)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> (i % <span class="number">10</span> == <span class="number">0</span> && i != iters) {</span><br><span class="line"> ranks.cache().setName(<span class="string">s"iter<span class="subst">$i</span>: ranks"</span>)</span><br><span class="line"> ranks.checkpoint()</span><br><span class="line"> <span class="comment">// Force action, just for trigger calculation</span></span><br><span class="line"> ranks.foreach(_ => <span class="type">Unit</span>)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> (lastCheckpointRanks != <span class="literal">null</span>) {</span><br><span class="line"> lastCheckpointRanks.getCheckpointFile.foreach { ckp =></span><br><span class="line"> <span class="keyword">val</span> p = <span class="keyword">new</span> <span class="type">Path</span>(ckp)</span><br><span class="line"> <span class="keyword">val</span> fs = p.getFileSystem(sc.hadoopConfiguration)</span><br><span class="line"> fs.delete(p, <span class="literal">true</span>)</span><br><span class="line"> }</span><br><span class="line"> lastCheckpointRanks.unpersist(blocking = <span class="literal">false</span>)</span><br><span class="line"> }</span><br><span class="line"> lastCheckpointRanks = ranks</span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="comment">// Final force action, like ranks.saveAsTextFile(outputPath)</span></span><br><span class="line">ranks.foreach(_ => <span class="type">Unit</span>)</span><br></pre></td></tr></table></figure></p>
<p>上述代码中每隔10次迭代,做一次checkpoint,并强制触发计算。一定要注意,在做checkpoint前,一定要对要checkpoint的RDD做cache,否则会重计算。这里简单描述下checkpoint的计算流程: 调用<code>rdd.checkpoint()</code>仅仅是标记该RDD需要做checkpoint,并不会触发计算,只有在遇到action方法后,才会触发计算,在job执行完毕后,会启动checkpoint计算,如果RDD依赖链中有RDD被标记为checkpoint,则会对这个RDD再次触发一个job执行checkpoint计算。所以在checkpoint前,对RDD做cache,可以避免checkpoint计算过程中重新根据RDD依赖链计算。在上述代码中变量<code>lastCheckpointRanks</code>记录上一次checkpoint的结果,在一次迭代完毕后,删除上一次checkpoint的结果,并更新变量<code>lastCheckpointRanks</code>。</p>
<p>为了方便测试,我每隔3次迭代做一次checkpoint,总共迭代5次,运行上述代码,整个计算过程中会有一次checkpoint,根据前面checkpoint的计算描述可知,在代码15行处会有两个job,一个是常规计算,一个是checkpoint计算,checkpoint计算是直接从缓存中拿数据写到hdfs,所以计算开销是很小的。加上最终的一个job,整个计算过程中总共有3个job,下面是测试过程中job的截图,注意图中对应的行号跟上面贴的代码没有对应关系哦。</p>
<p><img src="/images/pagerank-checkpoint-jobs.png" alt="jobs | center"></p>
<p>第一个job执行3次迭代计算,并将结果缓存起来,下面是第一个job的DAG:</p>
<p><img src="/images/pagerank-iter-3-dag-cache.png" alt="iter-3-dag-cache | center"></p>
<p>第二个job做checkpoint,由于需要checkpoint的RDD已经缓存了,所以不会重新计算,它会跳过依赖链中前面的RDD,直接从缓存中读取数据写到hdfs,所以前面的依赖链显示是灰色的:</p>
<p><img src="/images/pagerank-checkpoint-dag.png" alt="checkpoint-dag | center"></p>
<p>第三个job执行剩下的2次迭代计算,由于前3次迭代的结果已经做过checkpoint,所以这里的依赖链中不包含前3次迭代计算的依赖链,也就是说checkpoint起到了断链作用,这样driver维护的依赖链就不会越变越长了:</p>
<p><img src="/images/pagerank-after-checkpoint-dag.png" alt="after-checkpoint-dag | center"></p>
<blockquote>
<p>Tips: 对于迭代型任务,每迭代若干次后,做一次checkpoint</p>
</blockquote>
<p>到这里,我们有一个稍微比较稳定的版本了,我们称之为V2。但是,一般实际场景中,<code>links</code>可能会特别大,例如好友关系,就有近10亿的key,每个key对应的value平均应该也有100-200,全部缓存到内存对资源要求比较大,从之前文章<a href="http://km.oa.com/group/2430/articles/show/300304" target="_blank" rel="noopener">Spark Cache性能测试</a>的结论可知,我们可以选择<code>MEMORY_ONLY_SER</code>或<code>DISK_ONLY</code>的缓存方式来减少内存的使用,由于在YARN集群环境中磁盘资源是没有被隔离的,也就是说一台机器上的磁盘资源是多任务共享的,所以使用<code>DISK_ONLY</code>存在磁盘溢出的风险,还是建议使用<code>MEMORY_ONLY_SER</code>,并加上压缩参数<code>spark.rdd.compress=true</code>,这样可以大大降低内存的使用,同时性能不至于损失太多。在上面加了checkpoint的代码基础上,把所有使用cache的地方全部改成如下形式:<br><figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// Submit conf: spark.rdd.compress=true</span></span><br><span class="line">links.persist(<span class="type">StorageLevel</span>.<span class="type">MEMORY_ONLY_SER</span>).setName(<span class="string">"links"</span>)</span><br><span class="line">...</span><br><span class="line">ranks.persist(<span class="type">StorageLevel</span>.<span class="type">MEMORY_ONLY_SER</span>).setName(<span class="string">s"iter<span class="subst">$i</span>: ranks"</span>)</span><br></pre></td></tr></table></figure></p>
<p>相同资源和参数下分别使用默认的<code>MEMORY_ONLY</code>和带压缩的<code>MEMORY_ONLY_SER</code>测试3次迭代的性能,下图是使用默认的<code>MEMORY_ONLY</code>方式缓存时,<code>links</code>在内存中的大小,可以看到<code>links</code>缓存后占用了6.6G内存:</p>
<p><img src="/images/links-string-cache.png" alt="links-string-cache | center"></p>
<p>改用带压缩的<code>MEMORY_ONLY_SER</code>的缓存方式后,<code>links</code>缓存后只占用了861.8M内存,仅为之前6.6G的12%:</p>
<p><img src="/images/links-string-cache-compress.png" alt="links-string-cache-compress | center"></p>
<p>通过在日志中打印运行时间,得到使用<code>MEMORY_ONLY</code>时运行时间为333s,使用<code>MEMORY_ONLY_SER</code>时运行时间为391s,性能牺牲了17%左右,所以使用<code>MEMORY_ONLY_SER</code>是以牺牲CPU代价来换取内存的一种较为稳妥的方案。在实际使用过程中需要权衡性能以及内存资源情况。</p>
<blockquote>
<p>Tips: 内存资源较为稀缺时,缓存方式使用带压缩的<code>MEMORY_ONLY_SER</code>代替默认的<code>MEMORY_ONLY</code></p>
</blockquote>
<h2 id="优化二-数据结构"><a href="#优化二-数据结构" class="headerlink" title="优化二(数据结构)"></a>优化二(数据结构)</h2><p>在上述PageRank代码实现中,<code>links</code>中的记录为url -> urls,url类型为<code>String</code>,通常情况下,<code>String</code>占用的内存比<code>Int</code>、 <code>Long</code>等原生类型要多,在PageRank算法中,url完全可以被编码成一个<code>Long</code>型,因为在整个计算过程中根本没有用到url中的内容,这样就可以一定程度上减少<code>links</code>缓存时的内存占用。由于在我的测试数据中,url本身是由数字来表示的,所以在<a href="#opt1">优化一</a>V2代码的基础上再将<code>links</code>的定义改为如下代码,我们将该版本称之为V3:<br><figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">...</span><br><span class="line"><span class="keyword">val</span> lines = sc.textFile(inputPath)</span><br><span class="line"><span class="keyword">val</span> links = lines.map { s =></span><br><span class="line"> <span class="keyword">val</span> parts = s.split(<span class="string">"\\s+"</span>)</span><br><span class="line"> (parts(<span class="number">0</span>).trim.toLong, parts(<span class="number">1</span>).trim.toLong)</span><br><span class="line">}.distinct().groupByKey()</span><br><span class="line">links.persist(storageLevel).setName(<span class="string">"links"</span>)</span><br><span class="line">...</span><br></pre></td></tr></table></figure></p>
<p>经过测试发现,url改成<code>Long</code>型后,使用<code>MEMORY_ONLY</code>缓存时,如下图所示,<code>links</code>仅占用2.5G,相比为<code>String</code>类型时的6.6G,缩小了一半多。此外,url改成<code>Long</code>型后,运行3次迭代的时间为278s,相比为<code>String</code>类型时的333s,性能提升了17%左右。</p>
<p><img src="/images/links-long-cache.png" alt="links-long-cache | center"></p>
<p>使用带压缩的<code>MEMORY_ONLY_SER</code>缓存时,如下图所示,<code>links</code>仅占用549.5M,相比为<code>String</code>类型时的861.8M,也缩小了近一半。此外,url改成<code>Long</code>型后,运行3次迭代的时间为306s,相比为<code>String</code>类型时的391s,性能提升了21%左右。</p>
<p><img src="/images/links-long-cache-compress.png" alt="links-long-cache-compress | center"></p>
<blockquote>
<p>Tips: 实际开发中,尽可能使用原生类型,尤其是Numeric的原生类型(<code>Int</code>, <code>Long</code>等)</p>
</blockquote>
<h2 id="优化三-数据倾斜"><a href="#优化三-数据倾斜" class="headerlink" title="优化三(数据倾斜)"></a>优化三(数据倾斜)</h2><p>经过前面两个优化后,基本可以应用到线上跑了,但是,可能还不够,如果我们的数据集中有少数url链接的urls特别多,那么在使用<code>groupByKey</code>初始化<code>links</code>时,少数记录的value(urls)可能会有溢出风险,由于<code>groupByKey</code>底层是用一个<code>Array</code>保存value,如果一个节点链接了数十万个节点,那么要开一个超大的数组,即使不溢出,很可能因为没有足够大的连续内存,导致频繁GC,进而引发OOM等致命性错误,通常我们把这类问题称之为数据倾斜问题。此外,在后续迭代循环中<code>links</code>和<code>ranks</code>的<code>join</code>也可能因为数据倾斜导致部分task非常慢甚至引发OOM,下图是<code>groupByKey</code>和<code>join</code>的示意图,左边是<code>groupByKey</code>后得到每个url链接的urls,底层用数组保存,在<code>join</code>时,shuffle阶段会将来自两个RDD相同key的记录通过网络拉到一个partition中,右边显示对url1的shuffle read,如果url1对应的urls特别多,join过程将会非常慢。<br><img src="/images/pagerank-shuffle-origin.png" alt="shuffle-origin | center"></p>
<h3 id="对key进行分桶"><a href="#对key进行分桶" class="headerlink" title="对key进行分桶"></a>对key进行分桶</h3><p>首先我们应该考虑避免使用<code>groupByKey</code>,这是导致后续数据倾斜的源头。既然可能存在单个key对应的value(urls)特别多,那么可以将key做一个随机化处理,例如将具有相同key的记录随机分配到10个桶中,这样就相当于把数据倾斜的记录给打散了,其大概原理如下图所示。</p>
<p><img src="/images/pagerank-random-int-skew.png" alt="random-int-skew | center"></p>
<p>基于上面的理论基础,我们先得到不用<code>groupByKey</code>的<code>links</code>:<br><figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">val</span> lines = sc.textFile(inputPath)</span><br><span class="line"><span class="keyword">val</span> links = lines.map { s =></span><br><span class="line"> <span class="keyword">val</span> parts = s.split(<span class="string">"\\s+"</span>)</span><br><span class="line"> (parts(<span class="number">0</span>).trim.toLong, parts(<span class="number">1</span>).trim.toLong)</span><br><span class="line">}.distinct()</span><br><span class="line">links.persist(storageLevel).setName(<span class="string">"links"</span>)</span><br></pre></td></tr></table></figure></p>
<p>再分析前面代码里的迭代循环,发现我们之前使用<code>groupByKey</code>很大一部分原因是想要得到每个key对应的urls size,我们可以单独通过<code>reduceByKey</code>来得到,<code>reduceByKey</code>会做本地combine,这个操作shuffle开销很小的:<br><figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// Count of each url's outs</span></span><br><span class="line"><span class="keyword">val</span> outCnts = links.mapValues(_ => <span class="number">1</span>).reduceByKey(_ + _)</span><br><span class="line">outCnts.persist(storageLevel).setName(<span class="string">"out-counts"</span>)</span><br></pre></td></tr></table></figure></p>
<p>现在我们就可以使用<code>cogroup</code>将<code>links</code>、<code>outCnts</code>以及<code>ranks</code>三者join起来了,很快我们会想到使用如下代码:<br><figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">val</span> contribs = links.cogroup(outCnts, ranks).values.flatMap { pair =></span><br><span class="line"> <span class="keyword">for</span> (u <- pair._1.iterator; v <- pair._2.iterator; w <- pair._3.iterator)</span><br><span class="line"> <span class="keyword">yield</span> (u, w/v)</span><br><span class="line">}</span><br></pre></td></tr></table></figure></p>
<p>但是!但是!但是!这样做还是会跟之前一样出现数据倾斜,因为<code>cogroup</code>执行过程中,在shuffle阶段还是会把<code>links</code>中相同key的记录分到同一个partition,也就说上面代码<code>pair._1.iterator</code>也可能非常大,这个<code>iterator</code>底层也是<code>Array</code>,面临的问题基本没解决。</p>
<p>所以我们就要考虑使用前面介绍的分桶方法了,对<code>links</code>中的每条记录都随机打散到10个桶中,那么相同key的记录就会被随机分到不同桶中了:<br><figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">keyWithRandomInt</span></span>[<span class="type">K</span>, <span class="type">V</span>](rdd: <span class="type">RDD</span>[(<span class="type">K</span>, <span class="type">V</span>)]): <span class="type">RDD</span>[((<span class="type">K</span>, <span class="type">Int</span>), <span class="type">V</span>)] = {</span><br><span class="line"> rdd.map(x => ((x._1, <span class="type">Random</span>.nextInt(<span class="number">10</span>)), x._2))</span><br><span class="line">}</span><br></pre></td></tr></table></figure></p>
<p>然而,cogroup是按照key进行join的,就是说它把来自多个RDD具有相同key的记录汇聚到一起计算,既然<code>links</code>的key已经被我们改变了,那么<code>outCnts</code>和<code>ranks</code>也要变成跟<code>links</code>相同的形式,才能join到一起去计算:<br><figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">expandKeyWithRandomInt</span></span>[<span class="type">K</span>, <span class="type">V</span>](rdd: <span class="type">RDD</span>[(<span class="type">K</span>, <span class="type">V</span>)])</span><br><span class="line"> : <span class="type">RDD</span>[((<span class="type">K</span>, <span class="type">Int</span>), <span class="type">V</span>)] = {</span><br><span class="line"> rdd.flatMap { x =></span><br><span class="line"> <span class="keyword">for</span> (i <- <span class="number">0</span> until <span class="number">10</span>)</span><br><span class="line"> <span class="keyword">yield</span> ((x._1, i), x._2)</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></table></figure></p>
<p>有了这个基础后,我们就可以将前面的<code>cogroup</code>逻辑修改一下,让他们能够顺利join到一块儿去:<br><figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">val</span> contribs = keyWithRandomInt(links).cogroup(</span><br><span class="line"> expandKeyWithRandomInt(outCnts),</span><br><span class="line"> expandKeyWithRandomInt(ranks)</span><br><span class="line">).values.flatMap { pair =></span><br><span class="line"> <span class="keyword">for</span> (u <- pair._1.iterator; v <- pair._2.iterator; w <- pair._3.iterator)</span><br><span class="line"> <span class="keyword">yield</span> (u, w/v)</span><br><span class="line">}</span><br></pre></td></tr></table></figure></p>
<p>我们将该版本称之为V4,将上述逻辑整理成如下图,可以看到,其实我们对<code>outCnts</code>和<code>ranks</code>做了膨胀处理,才能保证<code>cogroup</code>shuffle阶段对于<code>links</code>中的每条记录,都能找到与之对应的<code>outCnts</code>和<code>ranks</code>记录。</p>
<p><img src="/images/pagerank-shuffle-skewed-process.png" alt="shuffle-skewed-process | center"></p>
<p>其实这种做法会极大地损失性能,虽然这样做可能把之前OOM的问题搞定,能够不出错的跑完,但是由于数据膨胀,实际跑起来是非常慢的,不建议采用这种方法处理数据倾斜问题。这里仅仅引出一些问题让我们更多地去思考。</p>
<h3 id="拆分发生倾斜的key"><a href="#拆分发生倾斜的key" class="headerlink" title="拆分发生倾斜的key"></a>拆分发生倾斜的key</h3><p>有了前面的分析基础,我们知道对key分桶的方法,是不加区分地对所有key都一股脑地处理了,把不倾斜的key也当做倾斜来处理了,其实大部分实际情况下,只有少数key有倾斜,如果大部分key都倾斜那就不是数据倾斜了,那叫数据量特别大。所以我们可以考虑对倾斜的key和不倾斜的key分别用不同的处理逻辑,对不倾斜的key,还是用原来<code>groupByKey</code>和<code>join</code>方式来处理,对倾斜的key可以考虑使用<code>broadcast</code>来实现map join,因为倾斜的key一般来说是可数的,其对应的<code>outCnts</code>和<code>ranks</code>信息在我们PageRank场景里也不会很大,所以可以使用广播。</p>
<p>首先我们把链接的urls个数超过1000000的key定义为倾斜key,使用下面代码将<code>links</code>切分为两部分:<br><figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">val</span> lines = sc.textFile(path)</span><br><span class="line"><span class="keyword">val</span> links = lines.map { s =></span><br><span class="line"> <span class="keyword">val</span> parts = s.split(<span class="string">"\\s+"</span>)</span><br><span class="line"> (parts(<span class="number">0</span>).trim.toLong, parts(<span class="number">1</span>).trim.toLong)</span><br><span class="line">}.distinct()</span><br><span class="line">links.persist(storageLevel).setName(<span class="string">"links"</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment">// Count of each url's outs</span></span><br><span class="line"><span class="keyword">val</span> outCnts = links.mapValues(_ => <span class="number">1</span>L).reduceByKey(_ + _)</span><br><span class="line"> .persist(storageLevel).setName(<span class="string">"out-counts"</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment">// Init ranks</span></span><br><span class="line"><span class="keyword">var</span> ranks = outCnts.mapValues(_ => <span class="number">1.0</span>)</span><br><span class="line"> .persist(storageLevel).setName(<span class="string">"init-ranks"</span>)</span><br><span class="line"><span class="comment">// Force action, just for trigger calculation</span></span><br><span class="line">ranks.foreach(_ => <span class="type">Unit</span>)</span><br><span class="line"></span><br><span class="line"><span class="keyword">val</span> skewedOutCnts = outCnts.filter(_._2 >= <span class="number">1000000</span>).collectAsMap()</span><br><span class="line"><span class="keyword">val</span> bcSkewedOutCnts = sc.broadcast(skewedOutCnts)</span><br><span class="line"></span><br><span class="line"><span class="keyword">val</span> skewed = links.filter { link =></span><br><span class="line"> <span class="keyword">val</span> cnts = bcSkewedOutCnts.value</span><br><span class="line"> cnts.contains(link._1)</span><br><span class="line">}.persist(storageLevel).setName(<span class="string">"skewed-links"</span>)</span><br><span class="line"><span class="comment">// Force action, just for trigger calculation</span></span><br><span class="line">skewed.foreach(_ => <span class="type">Unit</span>)</span><br><span class="line"></span><br><span class="line"><span class="keyword">val</span> noSkewed = links.filter { link =></span><br><span class="line"> <span class="keyword">val</span> cnts = bcSkewedOutCnts.value</span><br><span class="line"> !cnts.contains(link._1)</span><br><span class="line">}.groupByKey().persist(storageLevel).setName(<span class="string">"no-skewed-links"</span>)</span><br><span class="line"><span class="comment">// Force action, just for trigger calculation</span></span><br><span class="line">noSkewed.foreach(_ => <span class="type">Unit</span>)</span><br><span class="line"></span><br><span class="line">links.unpersist(blocking = <span class="literal">false</span>)</span><br></pre></td></tr></table></figure></p>
<p>首先统计出链接数超过1000000的key,广播到每个计算节点,然后过滤<code>links</code>,如果key在广播变量中则为倾斜的数据,否则为非倾斜的数据,过滤完毕后原始<code>links</code>被销毁。下面就可以在迭代循环中分别处理倾斜的数据<code>skewed</code>和非倾斜的数据<code>noSkewed</code>了。</p>
<p>对<code>noSkewed</code>使用原来的方法:<br><figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">val</span> noSkewedPart = noSkewed.join(ranks).values.flatMap {</span><br><span class="line"> <span class="keyword">case</span> (urls, rank) =></span><br><span class="line"> <span class="keyword">val</span> size = urls.size</span><br><span class="line"> urls.map(url => (url, rank / size))</span><br><span class="line">}</span><br></pre></td></tr></table></figure></p>
<p>对<code>skewed</code>使用<code>broadcast</code>方式实现map join,类似地,要把倾斜的key对应的rank收集起来广播,之前的<code>cogroup</code>中的<code>outCnts</code>和<code>ranks</code>在这里就都被广播了,所以可以直接在<code>map</code>操作里完成对<code>skewed</code>中的数据处理:<br><figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">val</span> skewedRanks = ranks.filter { rank =></span><br><span class="line"> <span class="keyword">val</span> cnts = bcSkewedOutCnts.value</span><br><span class="line"> cnts.contains(rank._1)</span><br><span class="line">}.collectAsMap()</span><br><span class="line"><span class="keyword">val</span> bcSkewedRanks = sc.broadcast(skewedRanks)</span><br><span class="line"><span class="keyword">val</span> skewedPart = skewed.map { link =></span><br><span class="line"> <span class="keyword">val</span> cnts = bcSkewedOutCnts.value</span><br><span class="line"> <span class="keyword">val</span> ranks = bcSkewedRanks.value</span><br><span class="line"> (link._2, ranks(link._1)/cnts(link._1))</span><br><span class="line">}</span><br></pre></td></tr></table></figure></p>
<p>最后将两部分的处理结果<code>union</code>一下:<br><figure class="highlight scala"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">val</span> contribs = noSkewedPart.union(skewedPart)</span><br></pre></td></tr></table></figure></p>
<p>后面的逻辑就跟前面一样了,我们将该版本称之为V5。分别测V3和V5版本代码,迭代3次,在没有数据倾斜的情况下,相同数据、资源和参数下V3运行时间306s,V5运行时间311s,但是在有数据倾斜的情况下,相同数据、资源和参数下V3运行时间722s并伴有严重的GC,V5运行时间472s。可以发现V5版本在不牺牲性能的情况可以解决数据倾斜问题,同时还能以V3相同的性能处理不倾斜的数据集,所以说V5版本更具通用性。</p>
<blockquote>
<p>Tips: 对有倾斜的数据集,将倾斜的记录和非倾斜的记录切分,对倾斜的记录使用map join来解决由于数据倾斜导致少数task非常慢的问题</p>
</blockquote>
<h2 id="优化四-资源利用最大化"><a href="#优化四-资源利用最大化" class="headerlink" title="优化四(资源利用最大化)"></a>优化四(资源利用最大化)</h2><p>通过前面几个优化操作后,V5版本基本可以用于线上例行化跑作业了,但是部署到线上集群,面临如何给资源的困扰。为了测试方便,测试数据集中没有数据倾斜,下面就拿V5来测试并监控资源利用情况。</p>
<p>原始测试数据(使用带压缩的<code>MEMORY_ONLY_SER</code>缓存方式)情况如下表:</p>
<table>
<thead>
<tr>
<th style="text-align:center">磁盘中大小</th>
<th style="text-align:center"><code>links</code>缓存大小</th>
<th style="text-align:center">分区数</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align:center">1.5g</td>
<td style="text-align:center">549.5M</td>
<td style="text-align:center">20</td>
</tr>
</tbody>
</table>
<p>运行3次迭代,一开始大概估计使用如下资源,使用5个executor,每个executor配2个core,一次并行运行10个partition,20个partition 2轮task就可以跑完:</p>
<table>
<thead>
<tr>
<th style="text-align:center">driver_mem</th>
<th style="text-align:center">num_executor</th>
<th style="text-align:center">executor_mem</th>
<th style="text-align:center">executor_cores</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align:center">4g</td>
<td style="text-align:center">5</td>
<td style="text-align:center">2g</td>
<td style="text-align:center">2</td>
</tr>
</tbody>
</table>
<p>在提交参数中加上如下额外JVM参数,表示分别对driver和executor在运行期间开启<a href="https://docs.oracle.com/javacomponents/index.html" target="_blank" rel="noopener">Java Flight Recorder</a>:<br><figure class="highlight vim"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">spark.driver.extraJavaOptions -XX:+UnlockCommercialFeatures -XX:+FlightRecorder -XX:StartFlightRecording=filename=<span class="symbol"><LOG_DIR></span>/driver.jfr,dumponexit=true</span><br><span class="line">spark.executor.extraJavaOptions -XX:+UnlockCommercialFeatures -XX:+FlightRecorder -XX:StartFlightRecording=filename=<span class="symbol"><LOG_DIR></span>/excutor.jfr,dumponexit=true</span><br></pre></td></tr></table></figure></p>
<p>运行完毕后,统计运行时间为439s,将<code>driver.jfr</code>和<code>excutor.jfr</code>拿到开发机上来,打开jmc分析工具(位于java安装目录<code>bin/</code>下面),首先我们看driver的监控信息,主页如下图所示,可以看到driver的cpu占用是很小的:</p>
<p><img src="/images/driver-control.png" alt="driver-control | center"></p>
<p>切到内存tab,把物理内存的两个勾选去掉,可以看到driver的内存使用曲线,我们给了4g,但是实际上最大也就用了差不多1g,看下图中的GC统计信息,没有什么瓶颈。</p>
<p><img src="/images/driver-heap.png" alt="driver-heap | center"></p>
<p>所以给driver分配4g是浪费的,我们把它调到2g,虽然实际上只用了大概1g,这里多给driver留点余地,其他配置不变,重新提交程序,统计运行时间为443s,跟4g时运行时间439s差不多。</p>
<p>再来看executor的监控信息,主页如下图所示,可以看到executor的cpu利用明显比driver多,因为要做序列化、压缩以及排序等。</p>
<p><img src="/images/executor-control.png" alt="executor-control | center"></p>
<p>再切到内存tab,可以看到executor的内存使用波动较大,最大内存使用差不多1.75g,我们给了2g,还是相当合适的。但是看下面的GC统计信息,发现最长暂停4s多,而且垃圾回收次数也较多。</p>
<p><img src="/images/executor-heap.png" alt="executor-heap | center"></p>
<p>为此,我们切到”GC时间”tab,可以看到,GC还是比较频繁的,还有一次持续4s多的GC,看右边GC类型,对最长暂停时间从大到小排序,居然有几个SerialOld类型的GC,其他一部分是ParNew类型GC,一部分是CMS类型的GC,没有出现FULL GC,下面先分析内存使用,回过头来再分析这里出现的诡异SerialOld。</p>
<p><img src="/images/executor-gc.png" alt="executor-gc | center"></p>
<p>我们再看下堆内存大对象占用情况,大对象主要是在<code>ExternalAppendOnlyMap</code>和<code>ExternalSorter</code>中,<code>ExternalAppendOnlyMap</code>用于存放shuffle read的数据,<code>ExternalSorter</code>用于存放shuffle write前的数据,用于对记录排序,这两个数据结构底层使用<code>Array</code>存储数据,所以这里表现为大对象。</p>
<p><img src="/images/executor-heap-info.png" alt="executor-heap-info | center"></p>
<p>切换到TLAB,再细化到小对象,可以看到大部分是<code>Long</code>型(url),展开堆栈跟踪,大部分是用在shuffle阶段,因为在<code>join</code>时,一方面会读取<code>groupByKey</code>后的<code>links</code>,用于做shuffle write,一方面在shuffle read阶段,将相同key的<code>links</code>和<code>ranks</code>拉到一起做<code>join</code>计算。</p>
<p><img src="/images/executor-heap-info2.png" alt="executor-heap-info2 | center"></p>
<p>所以总体来说,内存情况是符合业务逻辑的,没有出现莫名其妙的内存占用。让人有点摸不清头脑的是,GC信息中有SerialOld这玩意儿,我明明用了CMS垃圾回收方式,经过一番Google查阅资料,”Concurrent Mode Failure”可能导致Serial Old的出现,查阅”Concurrent Mode Failure”发生的原因: 当CMS GC正进行时,此时有新的对象要进入老年代,但是老年代空间不足。仔细分析,个人觉得可能是因为CMS GC后存在较多的内存碎片,而我们的程序在shuffle阶段底层使用<code>Array</code>,需要连续内存,导致CMS GC过程中出现了”Concurrent Mode Failure”,才退化到Serial Old,Serial Old是采用标记整理回收算法,回收过程中会整理内存碎片。这样看来,应该是CMS GC过程中,老年代空间不足导致的,从两个方面考虑优化下,一是增加老年代内存占比,二是减小参数<code>-XX:CMSInitiatingOccupancyFraction</code>,降低触发CMS GC的阈值,让CMS GC及早回收老年代。</p>
<p>首先我们增加老年代内存占比,也就是降低新生代内存占比,默认<code>-XX:NewRatio=2</code>,我们把它改成<code>-XX:NewRatio=3</code>,将老年代内存占比由2/3提升到3/4,重新提交程序,得到<code>executor.jfr</code>,打开GC监控信息,发现有很大的改善,不在出现Serial Old类型的GC了,最长暂停时间从原来的4s降低到600ms左右,整体运行时间从448s降低到436s。</p>
<p><img src="/images/executor-gc2.png" alt="executor-gc2 | center"></p>
<p>把上述<code>-XX:NewRatio=3</code>去掉,设置参数<code>-XX:CMSInitiatingOccupancyFraction=60</code>,重新提交程序,得到executor GC的监控信息,发现GC最大暂停时间也降下来了,但是由于老年代GC的频率加大了,整体运行时间为498s,比原来的436s还要长。</p>
<p><img src="/images/executor-gc3.png" alt="executor-gc3 | center"></p>
<p>综合考虑以上信息,增加executor的jvm启动参数<code>-XX:NewRatio=3</code>,能把GC状态调整到一个较优的状态。</p>
<h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h2><p>Spark给我们提供了一种简单灵活的大数据编程框架,但是对于很多实际问题的处理,还应该多思考下如何让我们写出来的应用程序更高效更节约,以上几个调优点是可以推广到其他应用的,在我们编写spark应用程序时,通过这种思考也可以加速我们对spark的理解。</p>
<p><span style="color:red"><em>转载请注明出处,本文永久链接:<a href="https://sharkdtu.github.io/posts/spark-app-optimize.html">https://sharkdtu.github.io/posts/spark-app-optimize.html</a></em></span></p>
]]></content>
<categories>
<category> spark </category>
</categories>
<tags>
<tag> spark </tag>
<tag> 分布式计算 </tag>
<tag> benchmark </tag>
<tag> 优化 </tag>
</tags>
</entry>
<entry>
<title><![CDATA[Spark Cache性能测试]]></title>
<url>https://sharkdtu.github.io/posts/spark-cache-benchmark.html</url>
<content type="html"><![CDATA[<p>采用Spark自带的Kmeans算法作为测试基准(Spark版本为2.1),该算法Shuffle数据量较小,对于这类迭代型任务,又需要多次加载训练数据,此测试的目的在于评判各种Cache IO的性能,并总结其Spark内部原理作分析,作为Spark用户的参考。<a id="more"></a></p>
<h2 id="测试准备"><a href="#测试准备" class="headerlink" title="测试准备"></a>测试准备</h2><p>训练数据是通过<a href="http://prof.ict.ac.cn/BigDataBench/dowloads/" target="_blank" rel="noopener">Facebook SNS公开数据集生成器</a>得到,在HDFS上大小为9.3G,100个文件,添加如下两个参数以保证所有资源全部到位后才启动task,训练时间为加载数据到训练完毕这期间的耗时。</p>
<figure class="highlight plain"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">--conf spark.scheduler.minRegisteredResourcesRatio=1</span><br><span class="line">--conf spark.scheduler.maxRegisteredResourcesWaitingTime=100000000</span><br></pre></td></tr></table></figure>
<p>测试集群为3个节点的TS5机器搭建而成,其中一台作为RM,并运行着Alluxio Master,两个NM上同时运行着Alluxio Worker。除以上配置外,其他配置全部保持Spark默认状态。公共资源配置、分区设置以及算法参数如下表所示,executor_memory视不同的测试用例不同:</p>
<table>
<thead>
<tr>
<th style="text-align:center">driver_memory</th>
<th style="text-align:center">num_executor</th>
<th style="text-align:center">executor_cores</th>
<th style="text-align:center">分区数</th>
<th style="text-align:center">聚类个数</th>
<th style="text-align:center">迭代次数</th>
</tr>
</thead>
<tbody>