# dbt-selly/dbt-env/lib/python3.8/site-packages/dbt/context/base.py
#
# 543 lines
# 16 KiB
# Python

import json
import os
from typing import (
Any, Dict, NoReturn, Optional, Mapping
)
from dbt import flags
from dbt import tracking
from dbt.clients.jinja import undefined_error, get_rendered
from dbt.clients.yaml_helper import ( # noqa: F401
yaml, safe_load, SafeLoader, Loader, Dumper
)
from dbt.contracts.graph.compiled import CompiledResource
from dbt.exceptions import raise_compiler_error, MacroReturn
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.version import __version__ as dbt_version
# These modules are added to the context. Consider alternative
# approaches which will extend well to potentially many modules
import pytz
import datetime
import re
def get_pytz_module_context() -> Dict[str, Any]:
    """Return a mapping of every public ``pytz`` export, keyed by name.

    The set of exported names is taken from ``pytz.__all__``.
    """
    exported: Dict[str, Any] = {}
    for attr in pytz.__all__:  # type: ignore
        exported[attr] = getattr(pytz, attr)
    return exported
def get_datetime_module_context() -> Dict[str, Any]:
    """Return the subset of the ``datetime`` module exposed to templates.

    Only the core classes ``date``, ``datetime``, ``time``, ``timedelta``
    and ``tzinfo`` are exported, keyed by their attribute name and in that
    order.
    """
    exported_names = ('date', 'datetime', 'time', 'timedelta', 'tzinfo')
    return dict((attr, getattr(datetime, attr)) for attr in exported_names)
def get_re_module_context() -> Dict[str, Any]:
    """Return a mapping of every public ``re`` export (per ``re.__all__``)."""
    names = re.__all__
    return dict(zip(names, [getattr(re, attr) for attr in names]))
def get_context_modules() -> Dict[str, Dict[str, Any]]:
    """Build the ``modules`` namespace exposed to Jinja templates.

    Each key is a module name; each value is the dict of that module's
    exported attributes as produced by the matching
    ``get_<module>_module_context`` helper.
    """
    builders = (
        ('pytz', get_pytz_module_context),
        ('datetime', get_datetime_module_context),
        ('re', get_re_module_context),
    )
    return {module_name: build() for module_name, build in builders}
class ContextMember:
    """Marker that wraps a class attribute for export into a template context.

    :param value: The wrapped object (function, staticmethod, property, ...).
    :param name: Optional name to export the value under. When omitted, the
        name of the class attribute it was assigned to is used instead.
    """
    def __init__(self, value, name=None):
        self.inner = value
        self.name = name

    def key(self, default):
        """Return the export name, falling back to ``default`` when unset."""
        return default if self.name is None else self.name
def contextmember(value):
    """Mark a class attribute for export into the template context.

    Usable bare (``@contextmember``), which exports under the attribute's own
    name, or with a string argument (``@contextmember('return')``) to export
    under that explicit name.
    """
    if not isinstance(value, str):
        return ContextMember(value)

    def wrap(target):
        return ContextMember(target, name=value)
    return wrap
def contextproperty(value):
    """Like ``contextmember``, but wraps the decorated getter in ``property``.

    Usable bare (``@contextproperty``) or with a string argument giving an
    explicit export name.
    """
    if not isinstance(value, str):
        return ContextMember(property(value))

    def wrap(getter):
        return ContextMember(property(getter), name=value)
    return wrap
class ContextMeta(type):
    """Metaclass that collects ``ContextMember``-wrapped class attributes.

    Members inherited from base classes are merged first, then members of the
    class being created. Each ``ContextMember`` is unwrapped so the class
    attribute holds the inner object. The collected data is stored on the
    class as:

    * ``_context_members_``: export name -> inner object
    * ``_context_attrs_``: export name -> class attribute name
    """
    def __new__(mcls, name, bases, dct):
        members = {}
        attrs = {}
        for parent in bases:
            members.update(getattr(parent, '_context_members_', {}))
            attrs.update(getattr(parent, '_context_attrs_', {}))

        namespace = {}
        for attr_name, attr_value in dct.items():
            if isinstance(attr_value, ContextMember):
                exported = attr_value.key(attr_name)
                members[exported] = attr_value.inner
                attrs[exported] = attr_name
                # store the unwrapped object on the class itself
                attr_value = attr_value.inner
            namespace[attr_name] = attr_value

        namespace['_context_members_'] = members
        namespace['_context_attrs_'] = attrs
        return type.__new__(mcls, name, bases, namespace)
class Var:
    """The callable exposed to templates as ``var``.

    Resolves ``var_name`` from the merged variable mapping (here just the CLI
    vars; subclasses may override ``_generate_merged``). String values are
    rendered as Jinja against ``context``; other values are returned as-is.
    If the variable is missing and no default was supplied, a compiler error
    listing the available variables is raised.

    :param context: The Jinja context used to render string-valued vars.
    :param cli_vars: Variables supplied on the command line.
    :param node: The resource being compiled, if any; used in error messages.
    """
    UndefinedVarError = "Required var '{}' not found in config:\nVars "\
        "supplied to {} = {}"
    # Sentinel distinguishing "no default given" from an explicit None.
    _VAR_NOTSET = object()

    def __init__(
        self,
        context: Mapping[str, Any],
        cli_vars: Mapping[str, Any],
        node: Optional[CompiledResource] = None
    ) -> None:
        self._context: Mapping[str, Any] = context
        self._cli_vars: Mapping[str, Any] = cli_vars
        self._node: Optional[CompiledResource] = node
        self._merged: Mapping[str, Any] = self._generate_merged()

    def _generate_merged(self) -> Mapping[str, Any]:
        """Return the mapping variables are resolved from (the CLI vars)."""
        return self._cli_vars

    @property
    def node_name(self):
        """Name of the node being compiled, or ``'<Configuration>'``."""
        if self._node is not None:
            return self._node.name
        else:
            return '<Configuration>'

    def get_missing_var(self, var_name):
        """Raise a compiler error reporting that ``var_name`` is undefined.

        The error message includes every available variable rendered as
        pretty-printed JSON.
        """
        # default=str: var values may not be natively JSON-encodable (e.g.
        # dates, if they were parsed from YAML). Stringify them so that a
        # TypeError from json does not mask the real "missing var" error.
        pretty_vars = json.dumps(
            dict(self._merged), sort_keys=True, indent=4, default=str
        )
        msg = self.UndefinedVarError.format(
            var_name, self.node_name, pretty_vars
        )
        raise_compiler_error(msg, self._node)

    def has_var(self, var_name: str):
        """Return True if ``var_name`` is present in the merged vars."""
        return var_name in self._merged

    def get_rendered_var(self, var_name):
        """Return the value of ``var_name``, rendering it if it is a str."""
        raw = self._merged[var_name]
        # if bool/int/float/etc are passed in, don't compile anything
        if not isinstance(raw, str):
            return raw
        return get_rendered(raw, self._context)

    def __call__(self, var_name, default=_VAR_NOTSET):
        """Look up ``var_name``; fall back to ``default`` or raise if unset."""
        if self.has_var(var_name):
            return self.get_rendered_var(var_name)
        elif default is not self._VAR_NOTSET:
            return default
        else:
            return self.get_missing_var(var_name)
class BaseContext(metaclass=ContextMeta):
    """Root of the Jinja context hierarchy.

    Attributes decorated with ``@contextmember`` / ``@contextproperty`` are
    collected by ``ContextMeta`` into ``_context_members_`` and exposed to
    templates by ``to_dict``.

    :param cli_vars: Variables supplied on the command line; backs ``var``.
    """
    def __init__(self, cli_vars):
        self._ctx = {}
        self.cli_vars = cli_vars

    def generate_builtins(self):
        """Resolve every registered context member against this instance.

        Members that are descriptors (properties, staticmethods, plain
        functions) are resolved through ``__get__``. Properties therefore
        yield their computed value and functions yield a callable.
        """
        builtins: Dict[str, Any] = {}
        for key, value in self._context_members_.items():
            if hasattr(value, '__get__'):
                # handle properties, bound methods, etc
                value = value.__get__(self)
            builtins[key] = value
        return builtins

    # no dbtClassMixin so this is not an actual override
    def to_dict(self):
        """Populate and return the context dict handed to Jinja.

        The dict contains a self-reference under ``'context'``, all builtins
        under ``'builtins'``, and every builtin merged in at the top level.
        """
        self._ctx['context'] = self._ctx
        builtins = self.generate_builtins()
        self._ctx['builtins'] = builtins
        self._ctx.update(builtins)
        return self._ctx

    @contextproperty
    def dbt_version(self) -> str:
        """The `dbt_version` variable returns the installed version of dbt that
        is currently running. It can be used for debugging or auditing
        purposes.

            > macros/get_version.sql

            {% macro get_version() %}
              {% set msg = "The installed version of dbt is: " ~ dbt_version %}
              {% do log(msg, info=true) %}
            {% endmacro %}

        Example output:

            $ dbt run-operation get_version
            The installed version of dbt is: 0.16.0
        """
        return dbt_version

    @contextproperty
    def var(self) -> Var:
        """Variables can be passed from your `dbt_project.yml` file into models
        during compilation. These variables are useful for configuring packages
        for deployment in multiple environments, or defining values that should
        be used across multiple models within a package.

        To add a variable to a model, use the `var()` function:

            > my_model.sql:

            select * from events where event_type = '{{ var("event_type") }}'

        If you try to run this model without supplying an `event_type`
        variable, you'll receive a compilation error that looks like this:

            Encountered an error:
            ! Compilation error while compiling model package_name.my_model:
            ! Required var 'event_type' not found in config:
            Vars supplied to package_name.my_model = {
            }

        To supply a variable to a given model, add one or more `vars`
        dictionaries to the `models` config in your `dbt_project.yml` file.
        These `vars` are in-scope for all models at or below where they are
        defined, so place them where they make the most sense. Below are three
        different placements of the `vars` dict, all of which will make the
        `my_model` model compile.

            > dbt_project.yml:

            # 1) scoped at the model level
            models:
              package_name:
                my_model:
                  materialized: view
                  vars:
                    event_type: activation
            # 2) scoped at the package level
            models:
              package_name:
                vars:
                  event_type: activation
                my_model:
                  materialized: view
            # 3) scoped globally
            models:
              vars:
                event_type: activation
              package_name:
                my_model:
                  materialized: view

        ## Variable default values

        The `var()` function takes an optional second argument, `default`. If
        this argument is provided, then it will be the default value for the
        variable if one is not explicitly defined.

            > my_model.sql:

            -- Use 'activation' as the event_type if the variable is not
            -- defined.
            select *
            from events
            where event_type = '{{ var("event_type", "activation") }}'
        """
        return Var(self._ctx, self.cli_vars)

    @contextmember
    @staticmethod
    def env_var(var: str, default: Optional[str] = None) -> str:
        """The env_var() function. Return the environment variable named 'var'.
        If there is no such environment variable set, return the default.
        If the default is None, raise an exception for an undefined variable.
        """
        if var in os.environ:
            return os.environ[var]
        elif default is not None:
            return default
        else:
            msg = f"Env var required but not provided: '{var}'"
            # undefined_error is expected to raise; nothing is returned here.
            undefined_error(msg)

    if os.environ.get('DBT_MACRO_DEBUGGING'):
        @contextmember
        @staticmethod
        def debug():
            """Enter a debugger at this line in the compiled jinja code."""
            import sys
            import ipdb  # type: ignore
            # skip the frames between the template's call site and here
            frame = sys._getframe(3)
            ipdb.set_trace(frame)
            return ''

    @contextmember('return')
    @staticmethod
    def _return(data: Any) -> NoReturn:
        """The `return` function can be used in macros to return data to the
        caller. The type of the data (`dict`, `list`, `int`, etc) will be
        preserved through the return call.

        :param data: The data to return to the caller

            > macros/example.sql:

            {% macro get_data() %}
              {{ return([1,2,3]) }}
            {% endmacro %}

            > models/my_model.sql:

            select
              -- get_data() returns a list!
              {% for i in get_data() %}
                {{ i }}
                {% if not loop.last %},{% endif %}
              {% endfor %}
        """
        raise MacroReturn(data)

    @contextmember
    @staticmethod
    def fromjson(string: str, default: Any = None) -> Any:
        """The `fromjson` context method can be used to deserialize a json
        string into a Python object primitive, eg. a `dict` or `list`.

        :param string: The json string to deserialize
        :param default: A default value to return if the `string` argument
            cannot be deserialized (optional)

        Usage:

            {% set my_json_str = '{"abc": 123}' %}
            {% set my_dict = fromjson(my_json_str) %}
            {% do log(my_dict['abc']) %}
        """
        try:
            return json.loads(string)
        except (TypeError, ValueError):
            # ValueError covers json.JSONDecodeError (malformed text);
            # TypeError covers non-string input such as None.
            return default

    @contextmember
    @staticmethod
    def tojson(
        value: Any, default: Any = None, sort_keys: bool = False
    ) -> Any:
        """The `tojson` context method can be used to serialize a Python
        object primitive, eg. a `dict` or `list` to a json string.

        :param value: The value to serialize to json
        :param default: A default value to return if the `value` argument
            cannot be serialized
        :param sort_keys: If True, sort the keys.

        Usage:

            {% set my_dict = {"abc": 123} %}
            {% set my_json_string = tojson(my_dict) %}
            {% do log(my_json_string) %}
        """
        try:
            return json.dumps(value, sort_keys=sort_keys)
        except (TypeError, ValueError):
            # TypeError: an object json cannot encode (e.g. a set), or
            # sort_keys=True with keys that cannot be ordered together.
            # ValueError: e.g. a circular reference.
            return default

    @contextmember
    @staticmethod
    def fromyaml(value: str, default: Any = None) -> Any:
        """The fromyaml context method can be used to deserialize a yaml string
        into a Python object primitive, eg. a `dict` or `list`.

        :param value: The yaml string to deserialize
        :param default: A default value to return if the `value` argument
            cannot be deserialized (optional)

        Usage:

            {% set my_yml_str -%}
            dogs:
             - good
             - bad
            {%- endset %}
            {% set my_dict = fromyaml(my_yml_str) %}
            {% do log(my_dict['dogs'], info=true) %}
            -- ["good", "bad"]
            {% do my_dict['dogs'].pop() %}
            {% do log(my_dict['dogs'], info=true) %}
            -- ["good"]
        """
        try:
            return safe_load(value)
        except (AttributeError, ValueError, yaml.YAMLError):
            return default

    # safe_dump defaults to sort_keys=True, but we act like json.dumps (the
    # opposite)
    @contextmember
    @staticmethod
    def toyaml(
        value: Any, default: Optional[str] = None, sort_keys: bool = False
    ) -> Optional[str]:
        """The `toyaml` context method can be used to serialize a Python
        object primitive, eg. a `dict` or `list` to a yaml string.

        :param value: The value to serialize to yaml
        :param default: A default value to return if the `value` argument
            cannot be serialized
        :param sort_keys: If True, sort the keys.

        Usage:

            {% set my_dict = {"abc": 123} %}
            {% set my_yaml_string = toyaml(my_dict) %}
            {% do log(my_yaml_string) %}
        """
        try:
            return yaml.safe_dump(data=value, sort_keys=sort_keys)
        except (ValueError, yaml.YAMLError):
            return default

    @contextmember
    @staticmethod
    def log(msg: str, info: bool = False) -> str:
        """Logs a line to either the log file or stdout.

        :param msg: The message to log
        :param info: If `False`, write to the log file. If `True`, write to
            both the log file and stdout.

            > macros/my_log_macro.sql

            {% macro some_macro(arg1, arg2) %}
              {{ log("Running some_macro: " ~ arg1 ~ ", " ~ arg2) }}
            {% endmacro %}
        """
        if info:
            logger.info(msg)
        else:
            logger.debug(msg)
        return ''

    @contextproperty
    def run_started_at(self) -> Optional[datetime.datetime]:
        """`run_started_at` outputs the timestamp that this run started, e.g.
        `2017-04-21 01:23:45.678`. The `run_started_at` variable is a Python
        `datetime` object. As of 0.9.1, the timezone of this variable defaults
        to UTC.

            > run_started_at_example.sql

            select
                '{{ run_started_at.strftime("%Y-%m-%d") }}' as date_day
            from ...

        To modify the timezone of this variable, use the `pytz` module:

            > run_started_at_utc.sql

            {% set est = modules.pytz.timezone("America/New_York") %}
            select
                '{{ run_started_at.astimezone(est) }}' as run_started_est
            from ...

        Returns None when there is no active tracking user.
        """
        if tracking.active_user is not None:
            return tracking.active_user.run_started_at
        else:
            return None

    @contextproperty
    def invocation_id(self) -> Optional[str]:
        """invocation_id outputs a UUID generated for this dbt run (useful for
        auditing). Returns None when there is no active tracking user.
        """
        if tracking.active_user is not None:
            return tracking.active_user.invocation_id
        else:
            return None

    @contextproperty
    def modules(self) -> Dict[str, Any]:
        """The `modules` variable in the Jinja context contains useful Python
        modules for operating on data.

        # datetime

        Exposes `date`, `datetime`, `time`, `timedelta` and `tzinfo` from the
        Python datetime module.

        Usage:

            {% set dt = modules.datetime.datetime.now() %}

        # pytz

        Exposes the public names of the Python pytz module.

        Usage:

            {% set dt = modules.datetime.datetime(2002, 10, 27, 6, 0, 0) %}
            {% set dt_local = modules.pytz.timezone('US/Eastern').localize(dt) %}
            {{ dt_local }}

        # re

        Exposes the public names of the Python re module.

        Usage:

            {% set is_match = modules.re.match('^[a-z]+$', 'abc') %}
        """  # noqa
        return get_context_modules()

    @contextproperty
    def flags(self) -> Any:
        """The `flags` variable contains true/false values for flags provided
        on the command line.

            > flags.sql:

            {% if flags.FULL_REFRESH %}
            drop table ...
            {% else %}
            -- no-op
            {% endif %}

        The list of valid flags are:

            - `flags.STRICT_MODE`: True if `--strict` (or `-S`) was provided on
              the command line
            - `flags.FULL_REFRESH`: True if `--full-refresh` was provided on
              the command line
            - `flags.NON_DESTRUCTIVE`: True if `--non-destructive` was provided
              on the command line
        """
        return flags
def generate_base_context(cli_vars: Dict[str, Any]) -> Dict[str, Any]:
    """Build the base Jinja context dict for the given command-line vars."""
    # BaseContext.to_dict is defined on BaseContext itself; this is not a
    # Mashumaro to_dict call.
    return BaseContext(cli_vars).to_dict()