Skip to content

Commit

Permalink
[CELEBORN-1150] support io encryption for spark
Browse files Browse the repository at this point in the history
  • Loading branch information
FMX committed Dec 6, 2023
1 parent 406cef8 commit 075f544
Show file tree
Hide file tree
Showing 19 changed files with 522 additions and 24 deletions.
1 change: 1 addition & 0 deletions client-spark/spark-2-shaded/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
<include>io.netty:*</include>
<include>org.apache.commons:commons-lang3</include>
<include>org.roaringbitmap:RoaringBitmap</include>
<include>org.apache.commons:commons-crypto</include>
</includes>
</artifactSet>
<filters>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@
package org.apache.spark.shuffle.celeborn;

import java.io.IOException;
import java.util.Optional;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;

import scala.Int;
import scala.Option;

import org.apache.spark.*;
import org.apache.spark.internal.config.package$;
import org.apache.spark.launcher.SparkLauncher;
import org.apache.spark.rdd.DeterministicLevel;
import org.apache.spark.shuffle.*;
Expand All @@ -35,6 +39,7 @@

import org.apache.celeborn.client.LifecycleManager;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.security.CryptoUtils;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.ShuffleMode;
import org.apache.celeborn.common.util.ThreadUtils;
Expand Down Expand Up @@ -99,7 +104,29 @@ private SortShuffleManager sortShuffleManager() {
return _sortShuffleManager;
}

private void initializeLifecycleManager(String appId) {
private Properties getIoCryptoConf() {
if (!celebornConf.ioEncryptionEnabled()) return new Properties();
Properties cryptoConf = CryptoStreamUtils.toCryptoConf(conf);
cryptoConf.put(
CryptoUtils.COMMONS_CRYPTO_CONFIG_TRANSFORMATION,
conf.get(package$.MODULE$.IO_CRYPTO_CIPHER_TRANSFORMATION()));
return cryptoConf;
}

private Optional<byte[]> getIoCryptoKey() {
if (!celebornConf.ioEncryptionEnabled()) return Optional.empty();
Option<byte[]> key = SparkEnv.get().securityManager().getIOEncryptionKey();
return key.isEmpty() ? Optional.empty() : Optional.ofNullable(key.get());
}

private byte[] getIoCryptoInitializationVector() {
if (!celebornConf.ioEncryptionEnabled()) return null;
return conf.getBoolean(package$.MODULE$.IO_ENCRYPTION_ENABLED().key(), false)
? CryptoUtils.createIoCryptoInitializationVector()
: null;
}

private void initializeLifecycleManager(String appId, byte[] ioCryptoInitializationVector) {
// Only create LifecycleManager singleton in Driver. When register shuffle multiple times, we
// need to ensure that LifecycleManager will only be created once. Parallelism needs to be
// considered in this place, because if there is one RDD that depends on multiple RDDs
Expand All @@ -126,7 +153,8 @@ public <K, V, C> ShuffleHandle registerShuffle(
// is the same SparkContext among different shuffleIds.
// This method may be called many times.
appUniqueId = SparkUtils.appUniqueId(dependency.rdd().context());
initializeLifecycleManager(appUniqueId);
byte[] iv = getIoCryptoInitializationVector();
initializeLifecycleManager(appUniqueId, iv);

lifecycleManager.registerAppShuffleDeterminate(
shuffleId,
Expand All @@ -146,7 +174,8 @@ public <K, V, C> ShuffleHandle registerShuffle(
shuffleId,
celebornConf.clientFetchThrowsFetchFailure(),
numMaps,
dependency);
dependency,
iv);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,23 @@ public static CelebornConf fromSparkConf(SparkConf conf) {
}

public static String appUniqueId(SparkContext context) {
return appUniqueId(context, fromSparkConf(context.conf()));
}

public static String appUniqueId(SparkContext context, CelebornConf celebornConf) {
String appUniqueId = "";
if (context.applicationAttemptId().isDefined()) {
return context.applicationId() + "_" + context.applicationAttemptId().get();
appUniqueId = context.applicationId() + "_" + context.applicationAttemptId().get();
} else {
return context.applicationId();
appUniqueId = context.applicationId();
}
}

public static String appUniqueId(SparkContext context, CelebornConf celebornConf) {
String appUniqueId = appUniqueId(context);
return celebornConf.appIdWithIdentifierPrefix(appUniqueId);
}

public static String getAppShuffleIdentifier(int appShuffleId, TaskContext context) {
return appShuffleId + "-" + context.stageId() + "-" + context.stageAttemptNumber();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,23 @@ class CelebornShuffleHandle[K, V, C](
shuffleId: Int,
val throwsFetchFailure: Boolean,
numMappers: Int,
dependency: ShuffleDependency[K, V, C])
extends BaseShuffleHandle(shuffleId, numMappers, dependency)
dependency: ShuffleDependency[K, V, C],
val ioCryptoInitializationVector: Array[Byte])
extends BaseShuffleHandle(shuffleId, numMappers, dependency) {
def this(
appUniqueId: String,
lifecycleManagerHost: String,
lifecycleManagerPort: Int,
userIdentifier: UserIdentifier,
shuffleId: Int,
numMappers: Int,
dependency: ShuffleDependency[K, V, C]) = this(
appUniqueId,
lifecycleManagerHost,
lifecycleManagerPort,
userIdentifier,
shuffleId,
numMappers,
dependency,
null)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.shuffle.celeborn

import java.io.IOException
import java.util.{Optional, Properties}
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.shuffle.celeborn

import java.util.{Optional, Properties}

import org.apache.spark.{ShuffleDependency, TaskContext}
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.shuffle.ShuffleReadMetricsReporter
Expand All @@ -34,7 +36,9 @@ class CelebornColumnarShuffleReader[K, C](
context: TaskContext,
conf: CelebornConf,
metrics: ShuffleReadMetricsReporter,
shuffleIdTracker: ExecutorShuffleIdTracker)
shuffleIdTracker: ExecutorShuffleIdTracker,
ioCryptoKey: Optional[Array[Byte]],
ioCryptoConf: Properties)
extends CelebornShuffleReader[K, C](
handle,
startPartition,
Expand All @@ -44,7 +48,9 @@ class CelebornColumnarShuffleReader[K, C](
context,
conf,
metrics,
shuffleIdTracker) {
shuffleIdTracker,
ioCryptoKey,
ioCryptoConf) {

override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = {
val schema = CustomShuffleDependencyUtils.getSchema(dep)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.spark.shuffle.celeborn

import java.util.Comparator.nullsFirst

import com.esotericsoftware.kryo.serializers.FieldSerializer.Optional
import org.apache.spark.{ShuffleDependency, SparkConf}
import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance}
import org.apache.spark.sql.execution.UnsafeRowSerializer
Expand Down Expand Up @@ -55,7 +58,9 @@ class CelebornColumnarShuffleReaderSuite {
null,
new CelebornConf(),
null,
new ExecutorShuffleIdTracker())
new ExecutorShuffleIdTracker(),
java.util.Optional.empty(),
null)
assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]])
} finally {
if (shuffleClient != null) {
Expand All @@ -78,6 +83,7 @@ class CelebornColumnarShuffleReaderSuite {
0,
false,
10,
null,
null),
0,
10,
Expand All @@ -86,7 +92,9 @@ class CelebornColumnarShuffleReaderSuite {
null,
new CelebornConf(),
null,
new ExecutorShuffleIdTracker())
new ExecutorShuffleIdTracker(),
java.util.Optional.empty(),
null)
val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]])
Mockito.when(shuffleDependency.shuffleId).thenReturn(0)
Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer(
Expand Down
1 change: 1 addition & 0 deletions client-spark/spark-3-shaded/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
<include>io.netty:*</include>
<include>org.apache.commons:commons-lang3</include>
<include>org.roaringbitmap:RoaringBitmap</include>
<include>org.apache.commons:commons-crypto</include>
</includes>
</artifactSet>
<filters>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@
package org.apache.spark.shuffle.celeborn;

import java.io.IOException;
import java.util.Optional;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.spark.*;
import org.apache.spark.internal.config.package$;
import org.apache.spark.launcher.SparkLauncher;
import org.apache.spark.rdd.DeterministicLevel;
import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.shuffle.*;
import org.apache.spark.shuffle.sort.SortShuffleManager;
import org.apache.spark.sql.internal.SQLConf;
Expand All @@ -34,6 +38,7 @@

import org.apache.celeborn.client.LifecycleManager;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.security.CryptoUtils;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.ShuffleMode;
import org.apache.celeborn.common.util.ThreadUtils;
Expand Down Expand Up @@ -130,7 +135,32 @@ private SortShuffleManager sortShuffleManager() {
return _sortShuffleManager;
}

private void initializeLifecycleManager() {
private Properties getIoCryptoConf() {
if (!celebornConf.ioEncryptionEnabled()) return new Properties();
Properties cryptoConf = CryptoStreamUtils.toCryptoConf(conf);
cryptoConf.put(
CryptoUtils.COMMONS_CRYPTO_CONFIG_TRANSFORMATION,
conf.get(package$.MODULE$.IO_CRYPTO_CIPHER_TRANSFORMATION()));
return cryptoConf;
}

private Optional<byte[]> getIoCryptoKey() {
if (!celebornConf.ioEncryptionEnabled()) return Optional.empty();
return SparkEnv.get()
.securityManager()
.getIOEncryptionKey()
.map(key -> Optional.ofNullable(key))
.getOrElse(() -> Optional.empty());
}

private byte[] getIoCryptoInitializationVector() {
if (!celebornConf.ioEncryptionEnabled()) return null;
return conf.getBoolean(package$.MODULE$.IO_ENCRYPTION_ENABLED().key(), false)
? CryptoUtils.createIoCryptoInitializationVector()
: null;
}

private void initializeLifecycleManager(byte[] ioCryptoInitializationVector) {
// Only create LifecycleManager singleton in Driver. When register shuffle multiple times, we
// need to ensure that LifecycleManager will only be created once. Parallelism needs to be
// considered in this place, because if there is one RDD that depends on multiple RDDs
Expand All @@ -139,6 +169,16 @@ private void initializeLifecycleManager() {
synchronized (this) {
if (lifecycleManager == null) {
lifecycleManager = new LifecycleManager(appUniqueId, celebornConf);
shuffleClient =
ShuffleClient.get(
appUniqueId,
lifecycleManager.getHost(),
lifecycleManager.getPort(),
celebornConf,
lifecycleManager.getUserIdentifier(),
getIoCryptoKey(),
getIoCryptoConf(),
ioCryptoInitializationVector);
if (celebornConf.clientFetchThrowsFetchFailure()) {
MapOutputTrackerMaster mapOutputTracker =
(MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
Expand All @@ -158,7 +198,8 @@ public <K, V, C> ShuffleHandle registerShuffle(
// is the same SparkContext among different shuffleIds.
// This method may be called many times.
appUniqueId = SparkUtils.appUniqueId(dependency.rdd().context());
initializeLifecycleManager();
byte[] iv = getIoCryptoInitializationVector();
initializeLifecycleManager(iv);

lifecycleManager.registerAppShuffleDeterminate(
shuffleId,
Expand Down Expand Up @@ -187,7 +228,8 @@ public <K, V, C> ShuffleHandle registerShuffle(
shuffleId,
celebornConf.clientFetchThrowsFetchFailure(),
dependency.rdd().getNumPartitions(),
dependency);
dependency,
iv);
}
}

Expand Down Expand Up @@ -242,7 +284,10 @@ public <K, V> ShuffleWriter<K, V> getWriter(
h.lifecycleManagerHost(),
h.lifecycleManagerPort(),
celebornConf,
h.userIdentifier());
h.userIdentifier(),
getIoCryptoKey(),
getIoCryptoConf(),
h.ioCryptoInitializationVector());
int shuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, true);
shuffleIdTracker.track(h.shuffleId(), shuffleId);

Expand Down Expand Up @@ -371,7 +416,9 @@ public <K, C> ShuffleReader<K, C> getCelebornShuffleReader(
context,
celebornConf,
metrics,
shuffleIdTracker);
shuffleIdTracker,
getIoCryptoKey(),
getIoCryptoConf());
} else {
return new CelebornShuffleReader<>(
h,
Expand All @@ -382,7 +429,9 @@ public <K, C> ShuffleReader<K, C> getCelebornShuffleReader(
context,
celebornConf,
metrics,
shuffleIdTracker);
shuffleIdTracker,
getIoCryptoKey(),
getIoCryptoConf());
}
}

Expand Down
Loading

0 comments on commit 075f544

Please sign in to comment.