[SPARK-37137][PYTHON] Inline type hints for python/pyspark/conf.py #34411

Closed · wants to merge 9 commits
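For context, a minimal illustration of what "inlining" type hints means here (a sketch reusing the names of the real SparkConf.set, not the PR's actual code): the signatures that previously lived in the separate conf.pyi stub move into conf.py itself, so the stub file can be deleted.

# Before: conf.py was unannotated and conf.pyi carried the signatures
# for type checkers, e.g.
#     class SparkConf:
#         def set(self, key: str, value: str) -> SparkConf: ...
#
# After: the annotations live inline in conf.py.
class SparkConf:
    def set(self, key: str, value: str) -> "SparkConf":
        ...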
67 changes: 41 additions & 26 deletions python/pyspark/conf.py
@@ -18,10 +18,12 @@
__all__ = ['SparkConf']

import sys
from typing import Dict, List, Optional, Tuple, cast, overload

from py4j.java_gateway import JVMView, JavaObject # type: ignore[import]

class SparkConf(object):

class SparkConf(object):
"""
Configuration for a Spark application. Used to set various Spark
parameters as key-value pairs.
@@ -105,15 +107,19 @@ class SparkConf(object):
spark.home=/path
"""

def __init__(self, loadDefaults=True, _jvm=None, _jconf=None):
_jconf: Optional[JavaObject]
_conf: Optional[Dict[str, str]]

def __init__(self, loadDefaults: bool = True, _jvm: Optional[JVMView] = None,
_jconf: Optional[JavaObject] = None):
"""
Create a new Spark configuration.
"""
if _jconf:
self._jconf = _jconf
else:
from pyspark.context import SparkContext
_jvm = _jvm or SparkContext._jvm
_jvm = _jvm or SparkContext._jvm # type: ignore[attr-defined]

if _jvm is not None:
# JVM is created, so create self._jconf directly through JVM
@@ -124,48 +130,57 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None):
self._jconf = None
self._conf = {}

def set(self, key, value):
def set(self, key: str, value: str) -> "SparkConf":
Review thread on this line:

Member:
I guess we could be less restrictive here, but I am not sure if that's a good idea, since it would depend on the __str__ or __repr__ implementation for the value. This has some weird consequences, like:

>>> key, value = "foo", None
>>> conf = sc.getConf()
>>> conf.set(key, value)
<pyspark.conf.SparkConf object at 0x7f4870aa5ee0>
>>> conf.get(key) == value
False

ByronHsu (Member, Author), Oct 29, 2021:
So do we need to change the type of value from str to Optional[str]? Or could we open another ticket for this issue?

ByronHsu (Member, Author):
@ueshin @HyukjinKwon Could you help me review this patch? Thanks!

Member:
Tough call .. I would keep it as just str for now though .. for None, it should ideally be mapped to null on the JVM side.

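A standalone sketch of the pitfall the thread above describes (using a plain dict as a hypothetical stand-in, not the real SparkConf): with the unconditional str() coercion, passing value=None stores the literal string "None", which is why keeping value: str is the safer annotation for now.

from typing import Dict, Optional

def set_option(conf: Dict[str, str], key: str, value: Optional[str]) -> None:
    # Mirrors SparkConf.set's coercion: every value goes through str().
    conf[key] = str(value)

conf: Dict[str, str] = {}
set_option(conf, "foo", None)
print(conf["foo"])          # prints the string "None", not an absent value
print(conf["foo"] is None)  # False, matching the reviewer's REPL session
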
"""Set a configuration property."""
# Try to set self._jconf first if JVM is created, set self._conf if JVM is not created yet.
if self._jconf is not None:
self._jconf.set(key, str(value))
else:
self._conf[key] = str(value)
cast(Dict[str, str], self._conf)[key] = str(value)
Review thread on this line:

Member:
I'd probably go with assert here:

diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py
index 09c8e63d09..a8538b06e4 100644
--- a/python/pyspark/conf.py
+++ b/python/pyspark/conf.py
@@ -136,7 +136,8 @@ class SparkConf(object):
         if self._jconf is not None:
             self._jconf.set(key, str(value))
         else:
-            cast(Dict[str, str], self._conf)[key] = str(value)
+            assert self._conf is not None
+            self._conf[key] = str(value)
         return self

     def setIfMissing(self, key: str, value: str) -> "SparkConf":

but I guess it is fine for now.

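For readers unfamiliar with the trade-off discussed above, here is a self-contained sketch (a simplified stand-in class with the same Optional[Dict] invariant, not the real SparkConf) contrasting the two ways of narrowing the attribute for mypy:

from typing import Dict, Optional, cast

class Conf:
    def __init__(self, backed_by_jvm: bool) -> None:
        # Either the JVM holds the config (dict stays None) or Python does.
        self._conf: Optional[Dict[str, str]] = None if backed_by_jvm else {}

    def set_with_cast(self, key: str, value: str) -> None:
        # cast() only tells the type checker to trust us; if _conf is actually
        # None, this still fails at runtime with a TypeError.
        cast(Dict[str, str], self._conf)[key] = str(value)

    def set_with_assert(self, key: str, value: str) -> None:
        # assert narrows Optional[Dict[str, str]] to Dict[str, str] for mypy
        # and fails loudly (AssertionError) if the invariant is ever broken.
        assert self._conf is not None
        self._conf[key] = str(value)
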
return self

def setIfMissing(self, key, value):
def setIfMissing(self, key: str, value: str) -> "SparkConf":
"""Set a configuration property, if not already set."""
if self.get(key) is None:
self.set(key, value)
return self

def setMaster(self, value):
def setMaster(self, value: str) -> "SparkConf":
"""Set master URL to connect to."""
self.set("spark.master", value)
return self

def setAppName(self, value):
def setAppName(self, value: str) -> "SparkConf":
"""Set application name."""
self.set("spark.app.name", value)
return self

def setSparkHome(self, value):
def setSparkHome(self, value: str) -> "SparkConf":
"""Set path where Spark is installed on worker nodes."""
self.set("spark.home", value)
return self

def setExecutorEnv(self, key=None, value=None, pairs=None):
@overload
def setExecutorEnv(self, key: str, value: str) -> "SparkConf":
...

@overload
def setExecutorEnv(self, *, pairs: List[Tuple[str, str]]) -> "SparkConf":
...

def setExecutorEnv(self, key: Optional[str] = None, value: Optional[str] = None,
pairs: Optional[List[Tuple[str, str]]] = None) -> "SparkConf":
"""Set an environment variable to be passed to executors."""
if (key is not None and pairs is not None) or (key is None and pairs is None):
raise RuntimeError("Either pass one key-value pair or a list of pairs")
elif key is not None:
self.set("spark.executorEnv." + key, value)
self.set("spark.executorEnv.{}".format(key), cast(str, value))
elif pairs is not None:
for (k, v) in pairs:
self.set("spark.executorEnv." + k, v)
self.set("spark.executorEnv.{}".format(k), v)
return self

def setAll(self, pairs):
def setAll(self, pairs: List[Tuple[str, str]]) -> "SparkConf":
"""
Set multiple parameters, passed as a list of key-value pairs.

@@ -178,49 +193,49 @@ def setAll(self, pairs):
self.set(k, v)
return self

def get(self, key, defaultValue=None):
def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]:
"""Get the configured value for some key, or return a default otherwise."""
if defaultValue is None: # Py4J doesn't call the right get() if we pass None
if defaultValue is None: # Py4J doesn't call the right get() if we pass None
if self._jconf is not None:
if not self._jconf.contains(key):
return None
return self._jconf.get(key)
else:
if key not in self._conf:
if key not in cast(Dict[str, str], self._conf):
return None
return self._conf[key]
return cast(Dict[str, str], self._conf)[key]
Review thread on this line:

Member:
Unrelated note ‒ shouldn't we use get here?

Member:
Also, same as above ‒ single assert might be a better option.

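A small sketch of the reviewer's suggestion (again on a plain dict rather than the real SparkConf): dict.get collapses the membership test and the indexing into one call and already returns None for a missing key.

from typing import Dict, Optional

def get_option(conf: Dict[str, str], key: str,
               default: Optional[str] = None) -> Optional[str]:
    # Equivalent to:  conf[key] if key in conf else default
    return conf.get(key, default)

conf = {"spark.app.name": "demo"}
print(get_option(conf, "spark.app.name"))             # demo
print(get_option(conf, "spark.master"))               # None
print(get_option(conf, "spark.master", "local[*]"))   # local[*]
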
else:
if self._jconf is not None:
return self._jconf.get(key, defaultValue)
else:
return self._conf.get(key, defaultValue)
return cast(Dict[str, str], self._conf).get(key, defaultValue)

def getAll(self):
def getAll(self) -> List[Tuple[str, str]]:
"""Get all values as a list of key-value pairs."""
if self._jconf is not None:
return [(elem._1(), elem._2()) for elem in self._jconf.getAll()]
return [(elem._1(), elem._2()) for elem in cast(JavaObject, self._jconf).getAll()]
else:
return self._conf.items()
return list(cast(Dict[str, str], self._conf).items())

def contains(self, key):
def contains(self, key: str) -> bool:
"""Does this configuration contain a given key?"""
if self._jconf is not None:
return self._jconf.contains(key)
else:
return key in self._conf
return key in cast(Dict[str, str], self._conf)

def toDebugString(self):
def toDebugString(self) -> str:
"""
Returns a printable version of the configuration, as a list of
key=value pairs, one per line.
"""
if self._jconf is not None:
return self._jconf.toDebugString()
else:
return '\n'.join('%s=%s' % (k, v) for k, v in self._conf.items())
return '\n'.join('%s=%s' % (k, v) for k, v in cast(Dict[str, str], self._conf).items())


def _test():
def _test() -> None:
import doctest
(failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
if failure_count:
44 changes: 0 additions & 44 deletions python/pyspark/conf.pyi

This file was deleted.

4 changes: 2 additions & 2 deletions python/pyspark/sql/session.py
Expand Up @@ -610,7 +610,7 @@ def _create_shell_session() -> "SparkSession":
try:
# Try to access HiveConf, it will raise exception if Hive is not added
conf = SparkConf()
if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive':
if cast(str, conf.get('spark.sql.catalogImplementation', 'hive')).lower() == 'hive':
(SparkContext._jvm # type: ignore[attr-defined]
.org.apache.hadoop.hive.conf.HiveConf())
return SparkSession.builder\
Expand All @@ -619,7 +619,7 @@ def _create_shell_session() -> "SparkSession":
else:
return SparkSession.builder.getOrCreate()
except (py4j.protocol.Py4JError, TypeError):
if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive':
if cast(str, conf.get('spark.sql.catalogImplementation', '')).lower() == 'hive':
warnings.warn("Fall back to non-hive support because failing to access HiveConf, "
"please make sure you build spark with hive")
