"""
Handles client connection/disconnection.
"""
import asyncio
import os
import socket
import ssl
import sys
import warnings
from typing import Any, Optional, Type, Union
from .compat import create_connection, create_unix_connection, get_running_loop
from .default import Default, _default
from .errors import (
SMTPConnectError,
SMTPConnectTimeoutError,
SMTPResponseException,
SMTPServerDisconnected,
SMTPTimeoutError,
)
from .protocol import SMTPProtocol
from .response import SMTPResponse
from .status import SMTPStatus
__all__ = ("SMTPConnection",)

# Well-known SMTP ports, used by SMTPConnection.connect when no port is given.
SMTP_PORT = 25  # plain SMTP (neither use_tls nor start_tls)
SMTP_TLS_PORT = 465  # connection wrapped in TLS from the start (use_tls)
SMTP_STARTTLS_PORT = 587  # plaintext, upgraded via STARTTLS (start_tls)

# Default connection/command timeout, in seconds.
DEFAULT_TIMEOUT = 60

# Mypy special cases sys.version checks, so this must stay a plain if/else
# on sys.version_info (a conditional expression would not be understood).
if sys.version_info < (3, 7):
    SocketPathType = Union[str, bytes]
else:
    # Path-like objects are accepted for Unix socket paths from 3.7 on.
    SocketPathType = Union[str, bytes, os.PathLike]
class SMTPConnection:
    """
    Handles connection/disconnection from the SMTP server provided.

    Keyword arguments can be provided either on :meth:`__init__` or when
    calling the :meth:`connect` method. Note that in both cases these options
    are saved for later use; subsequent calls to :meth:`connect` will use the
    same options, unless new ones are provided.

    :meth:`quit`, :meth:`login` and :meth:`starttls` raise
    :exc:`NotImplementedError` here; presumably a subclass provides them.
    """

    def __init__(
        self,
        hostname: Optional[str] = "localhost",
        port: Optional[int] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        source_address: Optional[str] = None,
        timeout: Optional[float] = DEFAULT_TIMEOUT,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        use_tls: bool = False,
        start_tls: bool = False,
        validate_certs: bool = True,
        client_cert: Optional[str] = None,
        client_key: Optional[str] = None,
        tls_context: Optional[ssl.SSLContext] = None,
        cert_bundle: Optional[str] = None,
        socket_path: Optional[SocketPathType] = None,
        sock: Optional[socket.socket] = None,
    ) -> None:
        """
        :keyword hostname: Server name (or IP) to connect to. Defaults to "localhost".
            Because of that default, ``hostname=None`` must be passed
            explicitly when using ``sock`` or ``socket_path``.
        :keyword port: Server port. Defaults ``465`` if ``use_tls`` is ``True``,
            ``587`` if ``start_tls`` is ``True``, or ``25`` otherwise.
        :keyword username: Username to login as after connect.
        :keyword password: Password for login after connect.
        :keyword source_address: The hostname of the client. Defaults to the
            result of :func:`socket.getfqdn`. Note that this call blocks.
        :keyword timeout: Default timeout value for the connection, in seconds.
            Defaults to 60.
        :keyword loop: event loop to run on. If no loop is passed, the running loop
            will be used. This option is deprecated, and will be removed in future.
        :keyword use_tls: If True, make the _initial_ connection to the server
            over TLS/SSL. Note that if the server supports STARTTLS only, this
            should be False.
        :keyword start_tls: If True, make the initial connection to the server
            over plaintext, and then upgrade the connection to TLS/SSL. Not
            compatible with use_tls.
        :keyword validate_certs: Determines if server certificates are
            validated. Defaults to True.
        :keyword client_cert: Path to client side certificate, for TLS
            verification.
        :keyword client_key: Path to client side key, for TLS verification.
        :keyword tls_context: An existing :py:class:`ssl.SSLContext`, for TLS
            verification. Mutually exclusive with ``client_cert``/
            ``client_key``.
        :keyword cert_bundle: Path to certificate bundle, for TLS verification.
        :keyword socket_path: Path to a Unix domain socket. Not compatible with
            hostname or port. Accepts str or bytes, or a pathlike object in 3.7+.
        :keyword sock: An existing, connected socket object. If given, none of
            hostname, port, or socket_path should be provided.
        :raises ValueError: mutually exclusive options provided
        """
        self.protocol = None  # type: Optional[SMTPProtocol]
        self.transport = None  # type: Optional[asyncio.BaseTransport]

        # Kwarg defaults are provided here, and saved for connect.
        self.hostname = hostname
        self.port = port
        self._login_username = username
        self._login_password = password
        self.timeout = timeout
        self.use_tls = use_tls
        self._start_tls_on_connect = start_tls
        self._source_address = source_address
        self.validate_certs = validate_certs
        self.client_cert = client_cert
        self.client_key = client_key
        self.tls_context = tls_context
        self.cert_bundle = cert_bundle
        self.socket_path = socket_path
        self.sock = sock

        if loop:
            warnings.warn(
                "Passing an event loop via the loop keyword argument is deprecated. "
                "It will be removed in version 2.0.",
                DeprecationWarning,
                stacklevel=4,
            )
        self.loop = loop

        # Created lazily by connect(), acquired there, and only released by
        # close(); it is held for as long as a connection is open.
        self._connect_lock = None  # type: Optional[asyncio.Lock]

        self._validate_config()

    async def __aenter__(self) -> "SMTPConnection":
        """Connect (unless already connected) on entering ``async with``."""
        if not self.is_connected:
            await self.connect()

        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        traceback: Any,
    ) -> None:
        """
        Leave the ``async with`` block.

        If the block exited because of a connection error or a timeout, the
        connection is closed straight away. Otherwise a QUIT is attempted;
        disconnects, error responses and timeouts while quitting are ignored.
        Either way :meth:`close` runs before returning, so the transport and
        the connect lock are never left behind.
        """
        # asyncio.TimeoutError only became an alias of the builtin
        # TimeoutError in Python 3.11, so check for both.
        if isinstance(exc, (ConnectionError, TimeoutError, asyncio.TimeoutError)):
            self.close()
            return

        try:
            await self.quit()
        except (SMTPServerDisconnected, SMTPResponseException, SMTPTimeoutError):
            pass
        finally:
            # If QUIT failed (e.g. the server answered with an error), the
            # connection would otherwise stay open with the lock held.
            # close() is safe to call on an already closed connection.
            self.close()

    @property
    def is_connected(self) -> bool:
        """
        Check if our transport is still connected.
        """
        return bool(self.protocol is not None and self.protocol.is_connected)

    @property
    def source_address(self) -> str:
        """
        Get the system hostname to be sent to the SMTP server.
        Simply caches the result of :func:`socket.getfqdn`.
        """
        if self._source_address is None:
            # Blocking call; the result is cached for subsequent accesses.
            self._source_address = socket.getfqdn()

        return self._source_address

    def _update_settings_from_kwargs(
        self,
        hostname: Optional[Union[str, Default]] = _default,
        port: Optional[Union[int, Default]] = _default,
        username: Optional[Union[str, Default]] = _default,
        password: Optional[Union[str, Default]] = _default,
        source_address: Optional[Union[str, Default]] = _default,
        timeout: Optional[Union[float, Default]] = _default,
        loop: Optional[Union[asyncio.AbstractEventLoop, Default]] = _default,
        use_tls: Optional[bool] = None,
        start_tls: Optional[bool] = None,
        validate_certs: Optional[bool] = None,
        client_cert: Optional[Union[str, Default]] = _default,
        client_key: Optional[Union[str, Default]] = _default,
        tls_context: Optional[Union[ssl.SSLContext, Default]] = _default,
        cert_bundle: Optional[Union[str, Default]] = _default,
        socket_path: Optional[Union[SocketPathType, Default]] = _default,
        sock: Optional[Union[socket.socket, Default]] = _default,
    ) -> None:
        """Update our configuration from the kwargs provided.

        This method can be called multiple times. Options left at
        ``_default`` (or ``None`` for the boolean flags) keep their current
        value, so ``None`` can still be set explicitly for the others.
        """
        if hostname is not _default:
            self.hostname = hostname
        if loop is not _default:
            if loop is not None:
                warnings.warn(
                    "Passing an event loop via the loop keyword argument is deprecated."
                    " It will be removed in version 2.0.",
                    DeprecationWarning,
                    stacklevel=3,
                )
            self.loop = loop
        if use_tls is not None:
            self.use_tls = use_tls
        if start_tls is not None:
            self._start_tls_on_connect = start_tls
        if validate_certs is not None:
            self.validate_certs = validate_certs
        if port is not _default:
            self.port = port
        if username is not _default:
            self._login_username = username
        if password is not _default:
            self._login_password = password

        if timeout is not _default:
            self.timeout = timeout
        if source_address is not _default:
            self._source_address = source_address
        if client_cert is not _default:
            self.client_cert = client_cert
        if client_key is not _default:
            self.client_key = client_key
        if tls_context is not _default:
            self.tls_context = tls_context
        if cert_bundle is not _default:
            self.cert_bundle = cert_bundle
        if socket_path is not _default:
            self.socket_path = socket_path
        if sock is not _default:
            self.sock = sock

    def _validate_config(self) -> None:
        """
        Reject mutually exclusive option combinations.

        Note the checks on ``hostname``/``port`` are truthiness checks, so the
        ``"localhost"`` hostname default conflicts with ``sock`` and
        ``socket_path`` unless ``hostname=None`` is passed.

        :raises ValueError: mutually exclusive options provided
        """
        if self._start_tls_on_connect and self.use_tls:
            raise ValueError("The start_tls and use_tls options are not compatible.")

        if self.tls_context is not None and self.client_cert is not None:
            raise ValueError(
                "Either a TLS context or a certificate/key must be provided"
            )

        if self.sock is not None and any([self.hostname, self.port, self.socket_path]):
            raise ValueError(
                "The socket option is not compatible with hostname, port or socket_path"
            )

        if self.socket_path is not None and any([self.hostname, self.port]):
            raise ValueError(
                "The socket_path option is not compatible with hostname/port"
            )

    async def connect(self, **kwargs) -> SMTPResponse:
        """
        Initialize a connection to the server. Options provided to
        :meth:`.connect` take precedence over those used to initialize the
        class.

        After the server greeting is received, STARTTLS is run if
        ``start_tls`` is set, and :meth:`login` is called if a username is
        configured (with an empty password if none was given).

        While connected, an internal lock is held that only :meth:`close`
        releases; calling :meth:`connect` again before closing waits for it.

        :keyword hostname: Server name (or IP) to connect to. Defaults to "localhost".
        :keyword port: Server port. Defaults ``465`` if ``use_tls`` is ``True``,
            ``587`` if ``start_tls`` is ``True``, or ``25`` otherwise.
        :keyword username: Username to login as after connect.
        :keyword password: Password for login after connect.
        :keyword source_address: The hostname of the client. Defaults to the
            result of :func:`socket.getfqdn`. Note that this call blocks.
        :keyword timeout: Default timeout value for the connection, in seconds.
            Defaults to 60.
        :keyword loop: event loop to run on. If no loop is passed, the running loop
            will be used. This option is deprecated, and will be removed in future.
        :keyword use_tls: If True, make the initial connection to the server
            over TLS/SSL. Note that if the server supports STARTTLS only, this
            should be False.
        :keyword start_tls: If True, make the initial connection to the server
            over plaintext, and then upgrade the connection to TLS/SSL. Not
            compatible with use_tls.
        :keyword validate_certs: Determines if server certificates are
            validated. Defaults to True.
        :keyword client_cert: Path to client side certificate, for TLS.
        :keyword client_key: Path to client side key, for TLS.
        :keyword tls_context: An existing :py:class:`ssl.SSLContext`, for TLS.
            Mutually exclusive with ``client_cert``/``client_key``.
        :keyword cert_bundle: Path to certificate bundle, for TLS verification.
        :keyword socket_path: Path to a Unix domain socket. Not compatible with
            hostname or port. Accepts str or bytes, or a pathlike object in 3.7+.
        :keyword sock: An existing, connected socket object. If given, none of
            hostname, port, or socket_path should be provided.
        :raises ValueError: mutually exclusive options provided
        """
        self._update_settings_from_kwargs(**kwargs)
        self._validate_config()

        if self.loop is None:
            self.loop = get_running_loop()
        if self._connect_lock is None:
            self._connect_lock = asyncio.Lock()
        await self._connect_lock.acquire()

        # Set default port last in case use_tls or start_tls is provided,
        # and only if we're not using a socket.
        if self.port is None and self.sock is None and self.socket_path is None:
            if self.use_tls:
                self.port = SMTP_TLS_PORT
            elif self._start_tls_on_connect:
                self.port = SMTP_STARTTLS_PORT
            else:
                self.port = SMTP_PORT

        try:
            response = await self._create_connection()
        except BaseException:
            # Reset our state to disconnected and release the connect lock.
            # BaseException rather than Exception: asyncio.CancelledError is
            # a BaseException on Python 3.8+, and a cancelled connect must
            # not leave the lock held (later connects would wait forever).
            self.close()
            raise

        if self._start_tls_on_connect:
            await self.starttls()

        if self._login_username is not None:
            password = self._login_password if self._login_password is not None else ""
            await self.login(self._login_username, password)

        return response

    async def _create_connection(self) -> SMTPResponse:
        """
        Open the transport (TCP, Unix socket or existing socket) and read the
        server greeting.

        :raises SMTPConnectTimeoutError: connecting or the greeting timed out
        :raises SMTPConnectError: connecting failed, the server disconnected,
            or the greeting was not a "ready" response
        """
        if self.loop is None:
            raise RuntimeError("No event loop set")

        protocol = SMTPProtocol(
            loop=self.loop, connection_lost_callback=self._connection_lost
        )

        tls_context = None  # type: Optional[ssl.SSLContext]
        ssl_handshake_timeout = None  # type: Optional[float]
        if self.use_tls:
            tls_context = self._get_tls_context()
            ssl_handshake_timeout = self.timeout

        if self.sock:
            connect_coro = create_connection(
                self.loop,
                lambda: protocol,
                sock=self.sock,
                ssl=tls_context,
                ssl_handshake_timeout=ssl_handshake_timeout,
            )
        elif self.socket_path:
            connect_coro = create_unix_connection(
                self.loop,
                lambda: protocol,
                path=self.socket_path,
                ssl=tls_context,
                ssl_handshake_timeout=ssl_handshake_timeout,
            )
        else:
            connect_coro = create_connection(
                self.loop,
                lambda: protocol,
                host=self.hostname,
                port=self.port,
                ssl=tls_context,
                ssl_handshake_timeout=ssl_handshake_timeout,
            )

        try:
            transport, _ = await asyncio.wait_for(connect_coro, timeout=self.timeout)
        except asyncio.TimeoutError as exc:
            # Must be handled before OSError: from Python 3.11 on,
            # asyncio.TimeoutError is the builtin TimeoutError, which is an
            # OSError subclass and would otherwise never reach this branch.
            raise SMTPConnectTimeoutError(
                "Timed out connecting to {host} on port {port}".format(
                    host=self.hostname, port=self.port
                )
            ) from exc
        except OSError as exc:
            raise SMTPConnectError(
                "Error connecting to {host} on port {port}: {err}".format(
                    host=self.hostname, port=self.port, err=exc
                )
            ) from exc

        self.protocol = protocol
        self.transport = transport

        try:
            response = await protocol.read_response(timeout=self.timeout)
        except SMTPServerDisconnected as exc:
            raise SMTPConnectError(
                "Error connecting to {host} on port {port}: {err}".format(
                    host=self.hostname, port=self.port, err=exc
                )
            ) from exc
        except SMTPTimeoutError as exc:
            raise SMTPConnectTimeoutError(
                "Timed out waiting for server ready message"
            ) from exc

        if response.code != SMTPStatus.ready:
            raise SMTPConnectError(str(response))

        return response

    def _connection_lost(self, waiter: asyncio.Future) -> None:
        """
        Callback handed to :class:`SMTPProtocol`; closes our side if
        ``waiter`` was cancelled or finished with an exception.

        NOTE(review): assumes the protocol only invokes this with a completed
        future (``waiter.exception()`` raises on a pending one) — confirm.
        """
        if waiter.cancelled() or waiter.exception() is not None:
            self.close()

    async def execute_command(
        self, *args: bytes, timeout: Optional[Union[float, Default]] = _default
    ) -> SMTPResponse:
        """
        Check that we're connected, if we got a timeout value, and then
        pass the command to the protocol.

        If the server replies with the ``domain_unavailable`` status, the
        connection is closed before the response is returned.

        :raises SMTPServerDisconnected: connection lost
        """
        if self.protocol is None:
            raise SMTPServerDisconnected("Server not connected")

        if timeout is _default:
            timeout = self.timeout

        response = await self.protocol.execute_command(*args, timeout=timeout)

        # If the server is unavailable, be nice and close the connection
        if response.code == SMTPStatus.domain_unavailable:
            self.close()

        return response

    async def quit(
        self, timeout: Optional[Union[float, Default]] = _default
    ) -> SMTPResponse:
        """Not implemented here; see the class docstring."""
        raise NotImplementedError

    async def login(
        self,
        username: str,
        password: str,
        timeout: Optional[Union[float, Default]] = _default,
    ) -> SMTPResponse:
        """Not implemented here; see the class docstring."""
        raise NotImplementedError

    async def starttls(
        self,
        server_hostname: Optional[str] = None,
        validate_certs: Optional[bool] = None,
        client_cert: Optional[Union[str, Default]] = _default,
        client_key: Optional[Union[str, Default]] = _default,
        cert_bundle: Optional[Union[str, Default]] = _default,
        tls_context: Optional[Union[ssl.SSLContext, Default]] = _default,
        timeout: Optional[Union[float, Default]] = _default,
    ) -> SMTPResponse:
        """Not implemented here; see the class docstring."""
        raise NotImplementedError

    def _get_tls_context(self) -> ssl.SSLContext:
        """
        Build an SSLContext object from the options we've been given.

        A user-supplied ``tls_context`` is returned unchanged; otherwise a
        default client context is created and configured from
        ``validate_certs``, ``cert_bundle`` and ``client_cert``/``client_key``.
        """
        if self.tls_context is not None:
            context = self.tls_context
        else:
            # SERVER_AUTH is what we want for a client side socket
            context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
            # check_hostname must be disabled before verify_mode can be
            # set to CERT_NONE, so keep this order.
            context.check_hostname = bool(self.validate_certs)
            if self.validate_certs:
                context.verify_mode = ssl.CERT_REQUIRED
            else:
                context.verify_mode = ssl.CERT_NONE

            if self.cert_bundle is not None:
                context.load_verify_locations(cafile=self.cert_bundle)

            if self.client_cert is not None:
                context.load_cert_chain(self.client_cert, keyfile=self.client_key)

        return context

    def close(self) -> None:
        """
        Closes the connection.

        Safe to call repeatedly: it closes the transport if still open,
        releases the connect lock if held, and resets protocol/transport.
        """
        if self.transport is not None and not self.transport.is_closing():
            self.transport.close()

        if self._connect_lock is not None and self._connect_lock.locked():
            self._connect_lock.release()

        self.protocol = None
        self.transport = None

    def get_transport_info(self, key: str) -> Any:
        """
        Get extra info from the transport.
        Supported keys:

            - ``peername``
            - ``socket``
            - ``sockname``
            - ``compression``
            - ``cipher``
            - ``peercert``
            - ``sslcontext``
            - ``sslobject``

        :raises SMTPServerDisconnected: connection lost
        """
        if self.transport is None:
            raise SMTPServerDisconnected("Server not connected")

        return self.transport.get_extra_info(key)