Source code for xframes.spark_context

"""
Provides functions to create and maintain the spark context.
"""

import os
import atexit
from zipfile import PyZipFile
from tempfile import NamedTemporaryFile
import logging

from xframes.environment import Environment
from xframes.xrdd import XRdd

def get_xframes_home():
    import xframes
    return os.path.dirname(xframes.__file__)


# CommonSparkContext wraps SparkContext, which must only be instantiated once
# in a program.
# Singleton is used as the metaclass for CommonSparkContext, so that only one
# instance is ever created.
class Singleton(type):
    def __init__(cls, name, bases, dictionary):
        super(Singleton, cls).__init__(name, bases, dictionary)
        cls.instance = None

    def __call__(cls, *args):
        if cls.instance is None:
            cls.instance = super(Singleton, cls).__call__(*args)
        return cls.instance
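

# Illustrative sketch, not part of the original xframes module: how the
# Singleton metaclass behaves.  Any class that declares Singleton as its
# metaclass (Python 2 syntax, as used by CommonSparkContext below) gets
# exactly one instance; repeated construction returns the object created by
# the first call.  The class and function names here are hypothetical.
def _singleton_example():
    class _SingletonDemo(object):
        __metaclass__ = Singleton

        def __init__(self):
            self.created = True

    first = _SingletonDemo()
    second = _SingletonDemo()
    return first is second   # True: the second call returns the first instance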


# noinspection PyClassHasNoInit
class SparkInitContext:
    """
    Spark Context initialization.

    This may be used to initialize the spark context with the supplied values.
    If this mechanism is not used, then the spark context will be initialized
    using the config file the first time a context is needed.
    """
    context = {}

    @staticmethod
    def set(context):
        """
        Sets the spark context parameters, and then creates a context.

        If the spark context has already been created, then this will have no effect.

        Parameters
        ----------
        context : dict
            Dictionary of property/value pairs.  These are passed to spark as config
            parameters.  If a config file is present, these parameters will override
            the parameters there.

        Notes
        -----
        The following values are the most commonly used.  They will be given default
        values if none are supplied in a configuration file.  Other values can be
        found in the spark configuration documentation.

        spark.master : str, optional
            The url of the spark cluster to use.  To use the local spark, give 'local'.
            To use a spark cluster with its master on a specific IP address, give the
            IP address or the hostname as in the following examples:

                spark.master=spark://my_spark_host:7077
                spark.master=mesos://my_mesos_host:5050

        spark.app.name : str, optional
            The app name is used on the job monitoring server, and for logging.

        spark.cores.max : str, optional
            The maximum number of cores to use for execution.

        spark.executor.memory : str, optional
            The amount of main memory to allocate to executors.  For example, '2g'.
        """
        SparkInitContext.context = context
        CommonSparkContext()
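

# Illustrative usage sketch, not part of the original xframes module:
# configuring the spark context before any XFrames operation creates one.
# The property values shown (host name, app name, core count, memory size)
# are placeholders, not defaults, and the function name is hypothetical.
def _spark_init_context_example():
    SparkInitContext.set({
        'spark.master': 'spark://my_spark_host:7077',   # or 'local[*]', 'mesos://...'
        'spark.app.name': 'MyXFramesJob',
        'spark.cores.max': '8',
        'spark.executor.memory': '2g'})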


def create_spark_config(env):
    # Create the context for configuring spark.
    from xframes.utils import merge_dicts

    default_context = {'spark.master': 'local[*]',
                       'spark.app.name': 'XFrames'}
    # Get values from the [spark] config section and from SparkInitContext.
    config_context = env.get_config_items('spark')
    context = merge_dicts(default_context, config_context)
    context = merge_dicts(context, SparkInitContext.context)
    app_name = os.environ.get('SPARK_APP_NAME', None)
    if app_name is not None:
        context['spark.app.name'] = app_name
    return context


class CommonSparkContext(object):
    __metaclass__ = Singleton

    def __init__(self):
        """
        Create a spark context.

        The spark configuration is taken from xframes/config.ini and from
        the values set in SparkInitContext.set() if this has been called.
        """

        # This is placed here because otherwise it causes an error when used in a spark slave.
        from pyspark import SparkConf, SparkContext, SQLContext, HiveContext
        # This reads from default.ini and then $XFRAMES_CONFIG_DIR/config.ini
        # if they exist.
        self._env = Environment.create()
        context = create_spark_config(self._env)
        verbose = self._env.get_config('xframes', 'verbose', 'false').lower() == 'true'
        hdfs_user_name = self._env.get_config('webhdfs', 'user', 'hdfs')
        os.environ['HADOOP_USER_NAME'] = hdfs_user_name
        config_pairs = [(k, v) for k, v in context.iteritems()]
        self._config = SparkConf().setAll(config_pairs)
        self._sc = SparkContext(conf=self._config)
        # Create these when needed.
        self._sqlc = None
        self._hivec = None
        self._streamingc = None
        if verbose:
            actual_config = self._config.getAll()
            print('Spark Config:')
            for cfg in actual_config:
                print(' {}: {}'.format(cfg[0], cfg[1]))

        self.zip_path = []
        version = [int(n) for n in self._sc.version.split('.')]
        self.status_tracker = self._sc.statusTracker()
        # The applicationId property is only available in spark 1.4.1 and later.
        if version >= [1, 4, 1]:
            self.application_id = self._sc.applicationId
        else:
            self.application_id = None

        if verbose:
            print('Spark Version: {}'.format(self._sc.version))
            if self.application_id:
                print('Application Id: {}'.format(self.application_id))
            print('Application Name: {}'.format(self._sc.appName))

        if not context['spark.master'].startswith('local'):
            zip_path = self._build_zip(get_xframes_home())
            if zip_path:
                self._sc.addPyFile(zip_path)
                self.zip_path.append(zip_path)

        trace_flag = self._env.get_config('xframes', 'rdd-trace', 'false').lower() == 'true'
        XRdd.set_trace(trace_flag)
        atexit.register(self.close_context)

    def spark_add_files(self, dirs):
        """
        Adds python files in the given directory or directories.

        Parameters
        ----------
        dirs : str or list(str)
            If a str, the pathname to a directory containing a python module.
            If a list, then it is a list of such directories.

            The python files in each directory are compiled, packed into a zip,
            distributed to each spark slave, and placed in PYTHONPATH.

            This is only done if spark is deployed on a cluster.
        """
        props = self.config()
        if props.get('spark.master', 'local').startswith('local'):
            return
        if isinstance(dirs, basestring):
            dirs = [dirs]
        for path in dirs:
            zip_path = self._build_zip(path)
            if zip_path:
                self._sc.addPyFile(zip_path)
                self.zip_path.append(zip_path)

    def close_context(self):
        if self._sc:
            self._sc.stop()
            self._sc = None
        for zip_path in self.zip_path:
            os.remove(zip_path)

    def _get_config(self):
        props = self._config.getAll()
        return {prop[0]: prop[1] for prop in props}

    def _get_env(self):
        return self._env

    def _get_sqlc(self):
        from pyspark import SQLContext
        if self._sqlc is None:
            self._sqlc = SQLContext(self._sc)
        return self._sqlc

    def _get_hivec(self):
        from pyspark import HiveContext
        if self._hivec is None:
            self._hivec = HiveContext(self._sc)
        return self._hivec

    def _get_streamingc(self, interval=1):
        from pyspark.streaming import StreamingContext
        if self._streamingc is None:
            self._streamingc = StreamingContext(self._sc, interval)
        return self._streamingc

    def _get_version(self):
        return [int(n) for n in self._sc.version.split('.')]

    def _get_jobs(self):
        return {job_id: self.status_tracker.getJobInfo(job_id)
                for job_id in self.status_tracker.getActiveJobIds()}

    def _get_cluster_mode(self):
        return not self._config.get('spark.master').startswith('local')

    # noinspection PyBroadException
    @staticmethod
    def _build_zip(module_dir):
        # This can fail at writepy if there is something wrong with the files
        # in xframes.  Go ahead anyway, but things will probably fail if this
        # job is distributed.
        try:
            tf = NamedTemporaryFile(suffix='.zip', delete=False)
            z = PyZipFile(tf, 'w')
            z.writepy(module_dir)
            z.close()
            return tf.name
        except:
            logging.warning('Zip file distribution failed -- workers will not get xframes code.')
            logging.warning('Check for unexpected files in xframes directory.')
            return None

    @staticmethod
    def config():
        """
        Gets the configuration parameters used to initialize the spark context.

        Returns
        -------
        dict
            A dict of the properties used to initialize the spark context.
        """
        return CommonSparkContext()._get_config()

    @staticmethod
    def env():
        """
        Gets the config environment.

        Returns
        -------
        :class:`.Environment`
            The environment.  This contains all the values from the configuration file(s).
        """
        return CommonSparkContext()._get_env()

    @staticmethod
    def spark_context():
        """
        Returns the spark context.

        Returns
        -------
        :class:`~pyspark.SparkContext`
            The SparkContext object from spark.
        """
        return CommonSparkContext()._sc

    @staticmethod
    def spark_config():
        """
        Returns the spark config parameters.

        Returns
        -------
        list
            A list of the key-value pairs stored as tuples, used to initialize the spark context.
        """
        return CommonSparkContext().config()

    @staticmethod
    def spark_sql_context():
        """
        Returns the spark sql context.

        Returns
        -------
        :class:`~pyspark.sql.SQLContext`
            The SQLContext object from spark.
        """
        return CommonSparkContext()._get_sqlc()

    @staticmethod
    def hive_context():
        """
        Returns the hive context.

        Returns
        -------
        :class:`~pyspark.sql.HiveContext`
            The HiveContext object from spark.
        """
        return CommonSparkContext()._get_hivec()

    @staticmethod
    def streaming_context(interval=1):
        """
        Returns the streaming context.

        Parameters
        ----------
        interval : int, optional
            The batch duration in seconds for the stream.  Default is one second.

        Returns
        -------
        :class:`~pyspark.streaming.StreamingContext`
            The streaming context.
        """
        return CommonSparkContext()._get_streamingc(interval)

    @staticmethod
    def spark_version():
        """
        Gets the spark version.

        Returns
        -------
        list[int]
            The spark version, as a list of integers.
        """
        return CommonSparkContext()._get_version()

    @staticmethod
    def jobs():
        """
        Gets the spark job IDs and info for the active jobs.

        This method would normally be called by a thread other than the one
        executing the job.

        Returns
        -------
        dict {job_id: job_info}
            A map of the active job IDs and their corresponding job info.
        """
        return CommonSparkContext()._get_jobs()

    @staticmethod
    def cluster_mode():
        """
        Gets the cluster mode.

        Returns
        -------
        boolean
            True if spark is running in cluster mode.  Cluster mode means that spark
            is running on a platform separate from the program.  In practice, cluster
            mode means that file arguments must be located on a network filesystem
            such as HDFS or NFS.
        """
        return CommonSparkContext()._get_cluster_mode()