· 6 years ago · Feb 04, 2020, 08:42 PM
1# -*- coding: utf-8 -*-
2#
3# Licensed to the Apache Software Foundation (ASF) under one
4# or more contributor license agreements. See the NOTICE file
5# distributed with this work for additional information
6# regarding copyright ownership. The ASF licenses this file
7# to you under the Apache License, Version 2.0 (the
8# "License"); you may not use this file except in compliance
9# with the License. You may obtain a copy of the License at
10#
11# http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing,
14# software distributed under the License is distributed on an
15# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16# KIND, either express or implied. See the License for the
17# specific language governing permissions and limitations
18# under the License.
19"""
20Implements Docker operator
21"""
22import ast
23import json
24from tempfile import TemporaryDirectory
25from typing import Dict, Iterable, List, Optional, Union
26
27from airflow.exceptions import AirflowException
28from airflow.hooks.docker_hook import DockerHook
29from airflow.models import BaseOperator
30from airflow.utils.decorators import apply_defaults
31
32from docker import APIClient, tls
33
34
# pylint: disable=too-many-instance-attributes
class DockerOperator(BaseOperator):
    """
    Execute a command inside a docker container.

    A temporary directory is created on the host and
    mounted into a container to allow storing files
    that together exceed the default disk size of 10GB in a container.
    The path to the mounted directory can be accessed
    via the environment variable ``AIRFLOW_TMP_DIR``.

    If a login to a private registry is required prior to pulling the image, a
    Docker connection needs to be configured in Airflow and the connection ID
    be provided with the parameter ``docker_conn_id``.

    :param image: Docker image from which to create the container.
        If image tag is omitted, "latest" will be used.
    :type image: str
    :param api_version: Remote API version. Set to ``auto`` to automatically
        detect the server's version.
    :type api_version: str
    :param command: Command to be run in the container. (templated)
        A string starting with ``[`` is parsed as a Python list literal.
    :type command: str or list
    :param container_name: Name of the container. Optional (templated)
    :type container_name: str or None
    :param cpus: Number of CPUs to assign to the container.
        This value gets multiplied with 1024. See
        https://docs.docker.com/engine/reference/run/#cpu-share-constraint
    :type cpus: float
    :param docker_url: URL of the host running the docker daemon.
        Default is unix://var/run/docker.sock
    :type docker_url: str
    :param environment: Environment variables to set in the container. (templated)
        ``AIRFLOW_TMP_DIR`` is added to this mapping at execution time.
    :type environment: dict
    :param force_pull: Pull the docker image on every run. Default is False.
    :type force_pull: bool
    :param mem_limit: Maximum amount of memory the container can use.
        Either a float value, which represents the limit in bytes,
        or a string like ``128m`` or ``1g``.
    :type mem_limit: float or str
    :param host_tmp_dir: Specify the location of the temporary directory on the host which will
        be mapped to tmp_dir. If not provided defaults to using the standard system temp directory.
    :type host_tmp_dir: str
    :param network_mode: Network mode for the container.
    :type network_mode: str
    :param tls_ca_cert: Path to a PEM-encoded certificate authority
        to secure the docker connection.
    :type tls_ca_cert: str
    :param tls_client_cert: Path to the PEM-encoded certificate
        used to authenticate docker client.
    :type tls_client_cert: str
    :param tls_client_key: Path to the PEM-encoded key used to authenticate docker client.
    :type tls_client_key: str
    :param tls_hostname: Hostname to match against
        the docker server certificate or False to disable the check.
    :type tls_hostname: str or bool
    :param tls_ssl_version: Version of SSL to use when communicating with docker daemon.
    :type tls_ssl_version: str
    :param tmp_dir: Mount point inside the container to
        a temporary directory created on the host by the operator.
        The path is also made available via the environment variable
        ``AIRFLOW_TMP_DIR`` inside the container.
    :type tmp_dir: str
    :param user: Default user inside the docker container.
    :type user: int or str
    :param volumes: List of volumes to mount into the container, e.g.
        ``['/host/path:/container/path', '/host/path2:/container/path2:ro']``.
        The iterable is copied; the caller's object is never modified.
    :type volumes: list
    :param working_dir: Working directory to
        set on the container (equivalent to the -w switch of the docker client)
    :type working_dir: str
    :param xcom_all: Push all the stdout or just the last line.
        The default is False (last line).
    :type xcom_all: bool
    :param docker_conn_id: ID of the Airflow connection to use
    :type docker_conn_id: str
    :param dns: Docker custom DNS servers
    :type dns: list[str]
    :param dns_search: Docker custom DNS search domain
    :type dns_search: list[str]
    :param auto_remove: Auto-removal of the container on daemon side when the
        container's process exits.
        The default is False.
    :type auto_remove: bool
    :param shm_size: Size of ``/dev/shm`` in bytes. The size must be
        greater than 0. If omitted uses system default.
    :type shm_size: int
    :param tty: Allocate pseudo-TTY to the container.
        This needs to be set to see logs of the Docker container.
    :type tty: bool
    """
    template_fields = ('command', 'environment', 'container_name')
    template_ext = ('.sh', '.bash',)

    # pylint: disable=too-many-arguments,too-many-locals
    @apply_defaults
    def __init__(
            self,
            image: str,
            api_version: Optional[str] = None,
            command: Optional[Union[str, List[str]]] = None,
            container_name: Optional[str] = None,
            cpus: float = 1.0,
            docker_url: str = 'unix://var/run/docker.sock',
            environment: Optional[Dict] = None,
            force_pull: bool = False,
            mem_limit: Optional[Union[float, str]] = None,
            host_tmp_dir: Optional[str] = None,
            network_mode: Optional[str] = None,
            tls_ca_cert: Optional[str] = None,
            tls_client_cert: Optional[str] = None,
            tls_client_key: Optional[str] = None,
            tls_hostname: Optional[Union[str, bool]] = None,
            tls_ssl_version: Optional[str] = None,
            tmp_dir: str = '/tmp/airflow',
            user: Optional[Union[str, int]] = None,
            volumes: Optional[Iterable[str]] = None,
            working_dir: Optional[str] = None,
            xcom_all: bool = False,
            docker_conn_id: Optional[str] = None,
            dns: Optional[List[str]] = None,
            dns_search: Optional[List[str]] = None,
            auto_remove: bool = False,
            shm_size: Optional[int] = None,
            tty: Optional[bool] = False,
            *args,
            **kwargs) -> None:

        # Reject the removed ``xcom_push`` argument before delegating to
        # BaseOperator, which may otherwise reject the unknown keyword with a
        # less specific error before this targeted message is reached.
        if kwargs.get('xcom_push') is not None:
            raise AirflowException("'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead")

        super().__init__(*args, **kwargs)
        self.api_version = api_version
        self.auto_remove = auto_remove
        self.command = command
        self.container_name = container_name
        self.cpus = cpus
        self.dns = dns
        self.dns_search = dns_search
        self.docker_url = docker_url
        self.environment = environment or {}
        self.force_pull = force_pull
        self.image = image
        self.mem_limit = mem_limit
        self.host_tmp_dir = host_tmp_dir
        self.network_mode = network_mode
        self.tls_ca_cert = tls_ca_cert
        self.tls_client_cert = tls_client_cert
        self.tls_client_key = tls_client_key
        self.tls_hostname = tls_hostname
        self.tls_ssl_version = tls_ssl_version
        self.tmp_dir = tmp_dir
        self.user = user
        # Copy into a private list: the caller's object is never aliased or
        # mutated, and any Iterable (e.g. a tuple) is accepted as annotated.
        self.volumes = list(volumes or [])
        self.working_dir = working_dir
        self.xcom_all = xcom_all
        self.docker_conn_id = docker_conn_id
        self.shm_size = shm_size
        self.tty = tty

        # Docker API client, created in ``execute``.
        self.cli = None
        # Response of ``create_container`` (a dict with an ``'Id'`` key), set in ``_run_image``.
        self.container = None

    def get_hook(self) -> DockerHook:
        """
        Retrieves hook for the operator.

        :return: The Docker Hook
        """
        # __get_tls_config may rewrite self.docker_url (tcp:// -> https://),
        # so it must run before the URL is read for the hook.
        tls_config = self.__get_tls_config()
        return DockerHook(
            docker_conn_id=self.docker_conn_id,
            base_url=self.docker_url,
            version=self.api_version,
            tls=tls_config
        )

    def _run_image(self):
        """
        Run a Docker container with the provided image.

        A fresh temporary host directory is bind-mounted at ``self.tmp_dir``
        for the duration of the run and removed afterwards. Container output
        is streamed to the task log.

        :return: ``None`` when ``do_xcom_push`` is disabled; otherwise the
            value of ``cli.logs`` for the container when ``xcom_all`` is set,
            else the last line of output encoded as UTF-8 bytes.
        :raises AirflowException: if the container exits with a non-zero status code.
        """
        self.log.info('Starting docker container from image %s', self.image)

        with TemporaryDirectory(prefix='airflowtmp', dir=self.host_tmp_dir) as host_tmp_dir:
            # Build a per-run bind list instead of appending to self.volumes, so
            # repeated runs of the same instance do not accumulate stale mounts.
            binds = list(self.volumes)
            binds.append('{0}:{1}'.format(host_tmp_dir, self.tmp_dir))

            self.container = self.cli.create_container(
                command=self.get_command(),
                name=self.container_name,
                environment=self.environment,
                host_config=self.cli.create_host_config(
                    auto_remove=self.auto_remove,
                    binds=binds,
                    network_mode=self.network_mode,
                    shm_size=self.shm_size,
                    dns=self.dns,
                    dns_search=self.dns_search,
                    cpu_shares=int(round(self.cpus * 1024)),
                    mem_limit=self.mem_limit),
                image=self.image,
                user=self.user,
                working_dir=self.working_dir,
                tty=self.tty,
            )
            container_id = self.container['Id']

            # Attach before starting so no early output is missed.
            lines = self.cli.attach(container=container_id,
                                    stdout=True,
                                    stderr=True,
                                    stream=True)

            self.cli.start(container_id)

            # ``line`` keeps the last (stripped, decoded) output line for XCom.
            line = ''
            for line in lines:
                line = line.strip()
                if hasattr(line, 'decode'):
                    line = line.decode('utf-8')
                self.log.info("%s", line)

            result = self.cli.wait(container_id)
            if result['StatusCode'] != 0:
                raise AirflowException('docker container failed: ' + repr(result))

            if not self.do_xcom_push:
                return None
            # Fetching the full logs is an extra API call; only do it when needed.
            if self.xcom_all:
                return self.cli.logs(container=container_id)
            return line.encode('utf-8')

    def execute(self, context):
        """
        Connect to the Docker daemon, make sure the image is present and run it.

        The image is pulled when ``force_pull`` is set or when it is not found
        locally. ``AIRFLOW_TMP_DIR`` is added to ``self.environment``.

        :param context: Airflow task context (not used directly).
        :return: whatever ``_run_image`` returns (the XCom value, or ``None``).
        """
        # Computed first: besides building the TLS config, this may rewrite
        # self.docker_url to https:// before the client is constructed.
        tls_config = self.__get_tls_config()

        if self.docker_conn_id:
            self.cli = self.get_hook().get_conn()
        else:
            self.cli = APIClient(
                base_url=self.docker_url,
                version=self.api_version,
                tls=tls_config
            )

        # Pull the docker image if `force_pull` is set or image does not exist locally
        if self.force_pull or not self.cli.images(name=self.image):
            self.log.info('Pulling docker image %s', self.image)
            for line in self.cli.pull(self.image, stream=True):
                output = json.loads(line.decode('utf-8').strip())
                if 'status' in output:
                    self.log.info("%s", output['status'])

        self.environment['AIRFLOW_TMP_DIR'] = self.tmp_dir

        return self._run_image()

    def get_command(self):
        """
        Retrieve the command to run in the container.

        If ``command`` is a string whose first non-whitespace character is
        ``[``, it is parsed with ``ast.literal_eval`` (literals only, never
        arbitrary code) and the resulting list is returned. Otherwise
        ``command`` is returned unchanged.

        :return: the command (or commands)
        :rtype: str | List[str]
        """
        if isinstance(self.command, str):
            stripped = self.command.strip()
            if stripped.startswith('['):
                # Evaluate the stripped text: before Python 3.10,
                # ast.literal_eval raises IndentationError on leading whitespace.
                return ast.literal_eval(stripped)
        return self.command

    def on_kill(self):
        """Stop the running container, if a client and container exist."""
        # The task can be killed before the container was created (e.g. while
        # the image is still being pulled); there is nothing to stop then.
        if self.cli is not None and self.container is not None:
            self.log.info('Stopping docker container')
            self.cli.stop(self.container['Id'])

    def __get_tls_config(self):
        """
        Build the TLS configuration for the Docker client.

        A config is only built when the CA certificate, client certificate and
        client key are all set. In that case ``self.docker_url`` is also
        rewritten from ``tcp://`` to ``https://`` as a side effect.

        :return: a ``tls.TLSConfig`` instance, or ``None`` when TLS is not configured.
        """
        tls_config = None
        if self.tls_ca_cert and self.tls_client_cert and self.tls_client_key:
            # Ignore type error on SSL version here - it is deprecated and type annotation is wrong
            # it should be string
            # noinspection PyTypeChecker
            tls_config = tls.TLSConfig(
                ca_cert=self.tls_ca_cert,
                client_cert=(self.tls_client_cert, self.tls_client_key),
                verify=True,
                ssl_version=self.tls_ssl_version,  # type: ignore
                assert_hostname=self.tls_hostname
            )
            self.docker_url = self.docker_url.replace('tcp://', 'https://')
        return tls_config