[SPARK-37137][PYTHON] Inline type hints for python/pyspark/conf.py
### What changes were proposed in this pull request?

Inline type hints for python/pyspark/conf.py

### Why are the changes needed?

Currently, the type hints for python/pyspark/conf.py live in a separate stub file (python/pyspark/conf.pyi), and stub files don't support type checking within function bodies. Inlining the type hints enables that; see the sketch below.
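
A minimal sketch of what inlining buys (toy module, not part of pyspark): with a separate `.pyi` stub, mypy checks callers against the stub but skips the unannotated implementation body; with inline hints, the body itself is checked.

```python
# toy.py -- hypothetical module, for illustration only.
# If the annotations lived in a separate toy.pyi stub, mypy would use the stub
# and skip this unannotated body, so the bug below would go unreported:
#
#     def double(x):
#         return x + "1"
#
# With the hints inlined, mypy checks the body and flags the error:
def double(x: int) -> int:
    return x + "1"  # error: Unsupported operand types for + ("int" and "str")
```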

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #34411 from ByronHsu/SPARK-37137.

Authored-by: Byron <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
ByronHsu authored and HyukjinKwon committed Nov 5, 2021
1 parent 8f20398 commit 2b4099f
Showing 3 changed files with 44 additions and 69 deletions.
65 changes: 42 additions & 23 deletions python/pyspark/conf.py
@@ -18,10 +18,12 @@
__all__ = ['SparkConf']

import sys
from typing import Dict, List, Optional, Tuple, cast, overload

from py4j.java_gateway import JVMView, JavaObject # type: ignore[import]


class SparkConf(object):
"""
Configuration for a Spark application. Used to set various Spark
parameters as key-value pairs.
@@ -105,15 +107,19 @@ class SparkConf(object):
spark.home=/path
"""

def __init__(self, loadDefaults=True, _jvm=None, _jconf=None):
_jconf: Optional[JavaObject]
_conf: Optional[Dict[str, str]]

def __init__(self, loadDefaults: bool = True, _jvm: Optional[JVMView] = None,
_jconf: Optional[JavaObject] = None):
"""
Create a new Spark configuration.
"""
if _jconf:
self._jconf = _jconf
else:
from pyspark.context import SparkContext
_jvm = _jvm or SparkContext._jvm
_jvm = _jvm or SparkContext._jvm # type: ignore[attr-defined]

if _jvm is not None:
# JVM is created, so create self._jconf directly through JVM
@@ -124,48 +130,58 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None):
self._jconf = None
self._conf = {}

def set(self, key, value):
def set(self, key: str, value: str) -> "SparkConf":
"""Set a configuration property."""
# Try to set self._jconf first if JVM is created, set self._conf if JVM is not created yet.
if self._jconf is not None:
self._jconf.set(key, str(value))
else:
assert self._conf is not None
self._conf[key] = str(value)
return self

def setIfMissing(self, key, value):
def setIfMissing(self, key: str, value: str) -> "SparkConf":
"""Set a configuration property, if not already set."""
if self.get(key) is None:
self.set(key, value)
return self

def setMaster(self, value):
def setMaster(self, value: str) -> "SparkConf":
"""Set master URL to connect to."""
self.set("spark.master", value)
return self

def setAppName(self, value):
def setAppName(self, value: str) -> "SparkConf":
"""Set application name."""
self.set("spark.app.name", value)
return self

def setSparkHome(self, value):
def setSparkHome(self, value: str) -> "SparkConf":
"""Set path where Spark is installed on worker nodes."""
self.set("spark.home", value)
return self

def setExecutorEnv(self, key=None, value=None, pairs=None):
@overload
def setExecutorEnv(self, key: str, value: str) -> "SparkConf":
...

@overload
def setExecutorEnv(self, *, pairs: List[Tuple[str, str]]) -> "SparkConf":
...

def setExecutorEnv(self, key: Optional[str] = None, value: Optional[str] = None,
pairs: Optional[List[Tuple[str, str]]] = None) -> "SparkConf":
"""Set an environment variable to be passed to executors."""
if (key is not None and pairs is not None) or (key is None and pairs is None):
raise RuntimeError("Either pass one key-value pair or a list of pairs")
elif key is not None:
self.set("spark.executorEnv." + key, value)
self.set("spark.executorEnv.{}".format(key), cast(str, value))
elif pairs is not None:
for (k, v) in pairs:
self.set("spark.executorEnv." + k, v)
self.set("spark.executorEnv.{}".format(k), v)
return self

def setAll(self, pairs):
def setAll(self, pairs: List[Tuple[str, str]]) -> "SparkConf":
"""
Set multiple parameters, passed as a list of key-value pairs.
Expand All @@ -178,49 +194,52 @@ def setAll(self, pairs):
self.set(k, v)
return self

def get(self, key, defaultValue=None):
def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]:
"""Get the configured value for some key, or return a default otherwise."""
if defaultValue is None: # Py4J doesn't call the right get() if we pass None
if defaultValue is None: # Py4J doesn't call the right get() if we pass None
if self._jconf is not None:
if not self._jconf.contains(key):
return None
return self._jconf.get(key)
else:
if key not in self._conf:
return None
return self._conf[key]
assert self._conf is not None
return self._conf.get(key, None)
else:
if self._jconf is not None:
return self._jconf.get(key, defaultValue)
else:
assert self._conf is not None
return self._conf.get(key, defaultValue)

def getAll(self):
def getAll(self) -> List[Tuple[str, str]]:
"""Get all values as a list of key-value pairs."""
if self._jconf is not None:
return [(elem._1(), elem._2()) for elem in self._jconf.getAll()]
return [(elem._1(), elem._2()) for elem in cast(JavaObject, self._jconf).getAll()]
else:
return self._conf.items()
assert self._conf is not None
return list(self._conf.items())

def contains(self, key):
def contains(self, key: str) -> bool:
"""Does this configuration contain a given key?"""
if self._jconf is not None:
return self._jconf.contains(key)
else:
assert self._conf is not None
return key in self._conf

def toDebugString(self):
def toDebugString(self) -> str:
"""
Returns a printable version of the configuration, as a list of
key=value pairs, one per line.
"""
if self._jconf is not None:
return self._jconf.toDebugString()
else:
assert self._conf is not None
return '\n'.join('%s=%s' % (k, v) for k, v in self._conf.items())


def _test():
def _test() -> None:
import doctest
(failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
if failure_count:
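
For reference, a minimal usage sketch of the newly annotated API, assuming pyspark is importable and no SparkContext/JVM has been started (so values go into the plain-dict fallback). Note that the new `@overload` pair makes `pairs` effectively keyword-only for the type checker:

```python
from pyspark.conf import SparkConf

# No SparkContext is running, so SparkConf falls back to a plain dict (_conf).
conf = SparkConf(loadDefaults=False)

conf.setMaster("local[2]").setAppName("type-hints-demo")

# First overload: a single key/value pair.
conf.setExecutorEnv("PYTHONHASHSEED", "0")

# Second overload: a list of pairs, passed by keyword.
conf.setExecutorEnv(pairs=[("VAR_A", "1"), ("VAR_B", "2")])

print(conf.toDebugString())  # key=value, one per line
```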
44 changes: 0 additions & 44 deletions python/pyspark/conf.pyi

This file was deleted.

4 changes: 2 additions & 2 deletions python/pyspark/sql/session.py
@@ -610,7 +610,7 @@ def _create_shell_session() -> "SparkSession":
try:
# Try to access HiveConf, it will raise exception if Hive is not added
conf = SparkConf()
if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive':
if cast(str, conf.get('spark.sql.catalogImplementation', 'hive')).lower() == 'hive':
(SparkContext._jvm # type: ignore[attr-defined]
.org.apache.hadoop.hive.conf.HiveConf())
return SparkSession.builder\
@@ -619,7 +619,7 @@ def _create_shell_session() -> "SparkSession":
else:
return SparkSession.builder.getOrCreate()
except (py4j.protocol.Py4JError, TypeError):
if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive':
if cast(str, conf.get('spark.sql.catalogImplementation', '')).lower() == 'hive':
warnings.warn("Fall back to non-hive support because failing to access HiveConf, "
"please make sure you build spark with hive")

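
The `cast(str, ...)` in session.py is needed because the inlined signature of `SparkConf.get` returns `Optional[str]` even when a non-None default is supplied. A small standalone sketch of the same pattern (names here are illustrative, not pyspark's):

```python
from typing import Dict, Optional, cast

_conf: Dict[str, str] = {}

def get(key: str, defaultValue: Optional[str] = None) -> Optional[str]:
    # Same shape as SparkConf.get: the return type stays Optional[str]
    # even when the caller supplies a non-None default.
    return _conf.get(key, defaultValue)

value = get("spark.sql.catalogImplementation", "hive")
# mypy sees Optional[str] here, so narrow explicitly before using str methods,
# which is what the cast(str, ...) call sites above do:
if cast(str, value).lower() == "hive":
    print("using the hive catalog")
```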
