--- /dev/null
+import sys
+import types
+from copy import deepcopy
+import collections
+import inspect
+
+__all__ = ['dataclass',
+ 'field',
+ 'FrozenInstanceError',
+ 'InitVar',
+
+ # Helper functions.
+ 'fields',
+ 'asdict',
+ 'astuple',
+ 'make_dataclass',
+ 'replace',
+ ]
+
+# Raised when an attempt is made to modify a frozen class.
+class FrozenInstanceError(AttributeError): pass
+
+# A sentinel object for default values to signal that a
+# default-factory will be used.
+# This is given a nice repr() which will appear in the function
+# signature of dataclasses' constructors.
+class _HAS_DEFAULT_FACTORY_CLASS:
+ def __repr__(self):
+ return '<factory>'
+_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
+
+# A sentinel object to detect if a parameter is supplied or not.
+class _MISSING_FACTORY:
+ def __repr__(self):
+ return '<missing>'
+_MISSING = _MISSING_FACTORY()
+
+# Since most per-field metadata will be unused, create an empty
+# read-only proxy that can be shared among all fields.
+_EMPTY_METADATA = types.MappingProxyType({})
+
+# Markers for the various kinds of fields and pseudo-fields.
+_FIELD = object() # An actual field.
+_FIELD_CLASSVAR = object() # Not a field, but a ClassVar.
+_FIELD_INITVAR = object() # Not a field, but an InitVar.
+
+# The name of an attribute on the class where we store the Field
+# objects. Also used to check if a class is a Data Class.
+_MARKER = '__dataclass_fields__'
+
+# The name of the function, that if it exists, is called at the end of
+# __init__.
+_POST_INIT_NAME = '__post_init__'
+
+
+class _InitVarMeta(type):
+ def __getitem__(self, params):
+ return self
+
+class InitVar(metaclass=_InitVarMeta):
+ pass
+
+
+# Instances of Field are only ever created from within this module,
+# and only from the field() function, although Field instances are
+# exposed externally as (conceptually) read-only objects.
+# name and type are filled in after the fact, not in __init__. They're
+# not known at the time this class is instantiated, but it's
+# convenient if they're available later.
+# When cls._MARKER is filled in with a list of Field objects, the name
+# and type fields will have been populated.
+class Field:
+ __slots__ = ('name',
+ 'type',
+ 'default',
+ 'default_factory',
+ 'repr',
+ 'hash',
+ 'init',
+ 'compare',
+ 'metadata',
+ '_field_type', # Private: not to be used by user code.
+ )
+
+ def __init__(self, default, default_factory, init, repr, hash, compare,
+ metadata):
+ self.name = None
+ self.type = None
+ self.default = default
+ self.default_factory = default_factory
+ self.init = init
+ self.repr = repr
+ self.hash = hash
+ self.compare = compare
+ self.metadata = (_EMPTY_METADATA
+ if metadata is None or len(metadata) == 0 else
+ types.MappingProxyType(metadata))
+ self._field_type = None
+
+ def __repr__(self):
+ return ('Field('
+ f'name={self.name!r},'
+ f'type={self.type},'
+ f'default={self.default},'
+ f'default_factory={self.default_factory},'
+ f'init={self.init},'
+ f'repr={self.repr},'
+ f'hash={self.hash},'
+ f'compare={self.compare},'
+ f'metadata={self.metadata}'
+ ')')
+
+
+# This function is used instead of exposing Field creation directly,
+# so that a type checker can be told (via overloads) that this is a
+# function whose type depends on its parameters.
+def field(*, default=_MISSING, default_factory=_MISSING, init=True, repr=True,
+ hash=None, compare=True, metadata=None):
+ """Return an object to identify dataclass fields.
+
+ default is the default value of the field. default_factory is a
+ 0-argument function called to initialize a field's value. If init
+ is True, the field will be a parameter to the class's __init__()
+ function. If repr is True, the field will be included in the
+ object's repr(). If hash is True, the field will be included in
+ the object's hash(). If compare is True, the field will be used in
+ comparison functions. metadata, if specified, must be a mapping
+ which is stored but not otherwise examined by dataclass.
+
+ It is an error to specify both default and default_factory.
+ """
+
+ if default is not _MISSING and default_factory is not _MISSING:
+ raise ValueError('cannot specify both default and default_factory')
+ return Field(default, default_factory, init, repr, hash, compare,
+ metadata)
+
+
+def _tuple_str(obj_name, fields):
+ # Return a string representing each field of obj_name as a tuple
+ # member. So, if fields is ['x', 'y'] and obj_name is "self",
+ # return "(self.x,self.y)".
+
+ # Special case for the 0-tuple.
+ if len(fields) == 0:
+ return '()'
+ # Note the trailing comma, needed if this turns out to be a 1-tuple.
+ return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
+
+
+def _create_fn(name, args, body, globals=None, locals=None,
+ return_type=_MISSING):
+ # Note that we mutate locals when exec() is called. Caller beware!
+ if locals is None:
+ locals = {}
+ return_annotation = ''
+ if return_type is not _MISSING:
+ locals['_return_type'] = return_type
+ return_annotation = '->_return_type'
+ args = ','.join(args)
+ body = '\n'.join(f' {b}' for b in body)
+
+ txt = f'def {name}({args}){return_annotation}:\n{body}'
+
+ exec(txt, globals, locals)
+ return locals[name]
+
+
+def _field_assign(frozen, name, value, self_name):
+ # If we're a frozen class, then assign to our fields in __init__
+ # via object.__setattr__. Otherwise, just use a simple
+ # assignment.
+ # self_name is what "self" is called in this function: don't
+ # hard-code "self", since that might be a field name.
+ if frozen:
+ return f'object.__setattr__({self_name},{name!r},{value})'
+ return f'{self_name}.{name}={value}'
+
+
+def _field_init(f, frozen, globals, self_name):
+ # Return the text of the line in the body of __init__ that will
+ # initialize this field.
+
+ default_name = f'_dflt_{f.name}'
+ if f.default_factory is not _MISSING:
+ if f.init:
+ # This field has a default factory. If a parameter is
+ # given, use it. If not, call the factory.
+ globals[default_name] = f.default_factory
+ value = (f'{default_name}() '
+ f'if {f.name} is _HAS_DEFAULT_FACTORY '
+ f'else {f.name}')
+ else:
+ # This is a field that's not in the __init__ params, but
+ # has a default factory function. It needs to be
+ # initialized here by calling the factory function,
+ # because there's no other way to initialize it.
+
+ # For a field initialized with a default=defaultvalue, the
+ # class dict just has the default value
+ # (cls.fieldname=defaultvalue). But that won't work for a
+ # default factory, the factory must be called in __init__
+ # and we must assign that to self.fieldname. We can't
+ # fall back to the class dict's value, both because it's
+ # not set, and because it might be different per-class
+ # (which, after all, is why we have a factory function!).
+
+ globals[default_name] = f.default_factory
+ value = f'{default_name}()'
+ else:
+ # No default factory.
+ if f.init:
+ if f.default is _MISSING:
+ # There's no default, just do an assignment.
+ value = f.name
+ elif f.default is not _MISSING:
+ globals[default_name] = f.default
+ value = f.name
+ else:
+ # This field does not need initialization. Signify that to
+ # the caller by returning None.
+ return None
+
+ # Only test this now, so that we can create variables for the
+ # default. However, return None to signify that we're not going
+ # to actually do the assignment statement for InitVars.
+ if f._field_type == _FIELD_INITVAR:
+ return None
+
+ # Now, actually generate the field assignment.
+ return _field_assign(frozen, f.name, value, self_name)
+
+
+def _init_param(f):
+ # Return the __init__ parameter string for this field.
+ # For example, the equivalent of 'x:int=3' (except instead of 'int',
+ # reference a variable set to int, and instead of '3', reference a
+ # variable set to 3).
+ if f.default is _MISSING and f.default_factory is _MISSING:
+ # There's no default, and no default_factory, just
+ # output the variable name and type.
+ default = ''
+ elif f.default is not _MISSING:
+ # There's a default, this will be the name that's used to look it up.
+ default = f'=_dflt_{f.name}'
+ elif f.default_factory is not _MISSING:
+ # There's a factory function. Set a marker.
+ default = '=_HAS_DEFAULT_FACTORY'
+ return f'{f.name}:_type_{f.name}{default}'
+
+
+def _init_fn(fields, frozen, has_post_init, self_name):
+ # fields contains both real fields and InitVar pseudo-fields.
+
+ # Make sure we don't have fields without defaults following fields
+ # with defaults. This actually would be caught when exec-ing the
+ # function source code, but catching it here gives a better error
+ # message, and future-proofs us in case we build up the function
+ # using ast.
+ seen_default = False
+ for f in fields:
+ # Only consider fields in the __init__ call.
+ if f.init:
+ if not (f.default is _MISSING and f.default_factory is _MISSING):
+ seen_default = True
+ elif seen_default:
+ raise TypeError(f'non-default argument {f.name!r} '
+ 'follows default argument')
+
+ globals = {'_MISSING': _MISSING,
+ '_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY}
+
+ body_lines = []
+ for f in fields:
+ # Do not initialize the pseudo-fields, only the real ones.
+ line = _field_init(f, frozen, globals, self_name)
+ if line is not None:
+ # line is None means that this field doesn't require
+ # initialization. Just skip it.
+ body_lines.append(line)
+
+ # Does this class have a post-init function?
+ if has_post_init:
+ params_str = ','.join(f.name for f in fields
+ if f._field_type is _FIELD_INITVAR)
+ body_lines += [f'{self_name}.{_POST_INIT_NAME}({params_str})']
+
+ # If no body lines, use 'pass'.
+ if len(body_lines) == 0:
+ body_lines = ['pass']
+
+ locals = {f'_type_{f.name}': f.type for f in fields}
+ return _create_fn('__init__',
+ [self_name] +[_init_param(f) for f in fields if f.init],
+ body_lines,
+ locals=locals,
+ globals=globals,
+ return_type=None)
+
+
+def _repr_fn(fields):
+ return _create_fn('__repr__',
+ ['self'],
+ ['return self.__class__.__qualname__ + f"(' +
+ ', '.join([f"{f.name}={{self.{f.name}!r}}"
+ for f in fields]) +
+ ')"'])
+
+
+def _frozen_setattr(self, name, value):
+ raise FrozenInstanceError(f'cannot assign to field {name!r}')
+
+
+def _frozen_delattr(self, name):
+ raise FrozenInstanceError(f'cannot delete field {name!r}')
+
+
+def _cmp_fn(name, op, self_tuple, other_tuple):
+ # Create a comparison function. If the fields in the object are
+ # named 'x' and 'y', then self_tuple is the string
+ # '(self.x,self.y)' and other_tuple is the string
+ # '(other.x,other.y)'.
+
+ return _create_fn(name,
+ ['self', 'other'],
+ [ 'if other.__class__ is self.__class__:',
+ f' return {self_tuple}{op}{other_tuple}',
+ 'return NotImplemented'])
+
+
+def _set_eq_fns(cls, fields):
+ # Create and set the equality comparison methods on cls.
+ # Pre-compute self_tuple and other_tuple, then re-use them for
+ # each function.
+ self_tuple = _tuple_str('self', fields)
+ other_tuple = _tuple_str('other', fields)
+ for name, op in [('__eq__', '=='),
+ ('__ne__', '!='),
+ ]:
+ _set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple))
+
+
+def _set_order_fns(cls, fields):
+ # Create and set the ordering methods on cls.
+ # Pre-compute self_tuple and other_tuple, then re-use them for
+ # each function.
+ self_tuple = _tuple_str('self', fields)
+ other_tuple = _tuple_str('other', fields)
+ for name, op in [('__lt__', '<'),
+ ('__le__', '<='),
+ ('__gt__', '>'),
+ ('__ge__', '>='),
+ ]:
+ _set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple))
+
+
+def _hash_fn(fields):
+ self_tuple = _tuple_str('self', fields)
+ return _create_fn('__hash__',
+ ['self'],
+ [f'return hash({self_tuple})'])
+
+
+def _get_field(cls, a_name, a_type):
+ # Return a Field object, for this field name and type. ClassVars
+ # and InitVars are also returned, but marked as such (see
+ # f._field_type).
+
+ # If the default value isn't derived from field, then it's
+ # only a normal default value. Convert it to a Field().
+ default = getattr(cls, a_name, _MISSING)
+ if isinstance(default, Field):
+ f = default
+ else:
+ f = field(default=default)
+
+ # Assume it's a normal field until proven otherwise.
+ f._field_type = _FIELD
+
+ # Only at this point do we know the name and the type. Set them.
+ f.name = a_name
+ f.type = a_type
+
+ # If typing has not been imported, then it's impossible for
+ # any annotation to be a ClassVar. So, only look for ClassVar
+ # if typing has been imported.
+ typing = sys.modules.get('typing')
+ if typing is not None:
+ # This test uses a typing internal class, but it's the best
+ # way to test if this is a ClassVar.
+ if type(a_type) is typing._ClassVar:
+ # This field is a ClassVar, so it's not a field.
+ f._field_type = _FIELD_CLASSVAR
+
+ if f._field_type is _FIELD:
+ # Check if this is an InitVar.
+ if a_type is InitVar:
+ # InitVars are not fields, either.
+ f._field_type = _FIELD_INITVAR
+
+ # Validations for fields. This is delayed until now, instead of
+ # in the Field() constructor, since only here do we know the field
+ # name, which allows better error reporting.
+
+ # Special restrictions for ClassVar and InitVar.
+ if f._field_type in (_FIELD_CLASSVAR, _FIELD_INITVAR):
+ if f.default_factory is not _MISSING:
+ raise TypeError(f'field {f.name} cannot have a '
+ 'default factory')
+ # Should I check for other field settings? default_factory
+ # seems the most serious to check for. Maybe add others. For
+ # example, how about init=False (or really,
+ # init=<not-the-default-init-value>)? It makes no sense for
+ # ClassVar and InitVar to specify init=<anything>.
+
+ # For real fields, disallow mutable defaults for known types.
+ if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
+ raise ValueError(f'mutable default {type(f.default)} for field '
+ f'{f.name} is not allowed: use default_factory')
+
+ return f
+
+
+def _find_fields(cls):
+ # Return a list of Field objects, in order, for this class (and no
+ # base classes). Fields are found from __annotations__ (which is
+ # guaranteed to be ordered). Default values are from class
+ # attributes, if a field has a default. If the default value is
+ # a Field(), then it contains additional info beyond (and
+ # possibly including) the actual default value. Pseudo-fields
+ # ClassVars and InitVars are included, despite the fact that
+ # they're not real fields. That's deal with later.
+
+ annotations = getattr(cls, '__annotations__', {})
+
+ return [_get_field(cls, a_name, a_type)
+ for a_name, a_type in annotations.items()]
+
+
+def _set_attribute(cls, name, value):
+ # Raise TypeError if an attribute by this name already exists.
+ if name in cls.__dict__:
+ raise TypeError(f'Cannot overwrite attribute {name} '
+ f'in {cls.__name__}')
+ setattr(cls, name, value)
+
+
+def _process_class(cls, repr, eq, order, hash, init, frozen):
+ # Use an OrderedDict because:
+ # - Order matters!
+ # - Derived class fields overwrite base class fields, but the
+ # order is defined by the base class, which is found first.
+ fields = collections.OrderedDict()
+
+ # Find our base classes in reverse MRO order, and exclude
+ # ourselves. In reversed order so that more derived classes
+ # override earlier field definitions in base classes.
+ for b in cls.__mro__[-1:0:-1]:
+ # Only process classes that have been processed by our
+ # decorator. That is, they have a _MARKER attribute.
+ base_fields = getattr(b, _MARKER, None)
+ if base_fields:
+ for f in base_fields.values():
+ fields[f.name] = f
+
+ # Now find fields in our class. While doing so, validate some
+ # things, and set the default values (as class attributes)
+ # where we can.
+ for f in _find_fields(cls):
+ fields[f.name] = f
+
+ # If the class attribute (which is the default value for
+ # this field) exists and is of type 'Field', replace it
+ # with the real default. This is so that normal class
+ # introspection sees a real default value, not a Field.
+ if isinstance(getattr(cls, f.name, None), Field):
+ if f.default is _MISSING:
+ # If there's no default, delete the class attribute.
+ # This happens if we specify field(repr=False), for
+ # example (that is, we specified a field object, but
+ # no default value). Also if we're using a default
+ # factory. The class attribute should not be set at
+ # all in the post-processed class.
+ delattr(cls, f.name)
+ else:
+ setattr(cls, f.name, f.default)
+
+ # Remember all of the fields on our class (including bases). This
+ # marks this class as being a dataclass.
+ setattr(cls, _MARKER, fields)
+
+ # We also need to check if a parent class is frozen: frozen has to
+ # be inherited down.
+ is_frozen = frozen or cls.__setattr__ is _frozen_setattr
+
+ # If we're generating ordering methods, we must be generating
+ # the eq methods.
+ if order and not eq:
+ raise ValueError('eq must be true if order is true')
+
+ if init:
+ # Does this class have a post-init function?
+ has_post_init = hasattr(cls, _POST_INIT_NAME)
+
+ # Include InitVars and regular fields (so, not ClassVars).
+ _set_attribute(cls, '__init__',
+ _init_fn(list(filter(lambda f: f._field_type
+ in (_FIELD, _FIELD_INITVAR),
+ fields.values())),
+ is_frozen,
+ has_post_init,
+ # The name to use for the "self" param
+ # in __init__. Use "self" if possible.
+ '__dataclass_self__' if 'self' in fields
+ else 'self',
+ ))
+
+ # Get the fields as a list, and include only real fields. This is
+ # used in all of the following methods.
+ field_list = list(filter(lambda f: f._field_type is _FIELD,
+ fields.values()))
+
+ if repr:
+ _set_attribute(cls, '__repr__',
+ _repr_fn(list(filter(lambda f: f.repr, field_list))))
+
+ if is_frozen:
+ _set_attribute(cls, '__setattr__', _frozen_setattr)
+ _set_attribute(cls, '__delattr__', _frozen_delattr)
+
+ generate_hash = False
+ if hash is None:
+ if eq and frozen:
+ # Generate a hash function.
+ generate_hash = True
+ elif eq and not frozen:
+ # Not hashable.
+ _set_attribute(cls, '__hash__', None)
+ elif not eq:
+ # Otherwise, use the base class definition of hash(). That is,
+ # don't set anything on this class.
+ pass
+ else:
+ assert "can't get here"
+ else:
+ generate_hash = hash
+ if generate_hash:
+ _set_attribute(cls, '__hash__',
+ _hash_fn(list(filter(lambda f: f.compare
+ if f.hash is None
+ else f.hash,
+ field_list))))
+
+ if eq:
+ # Create and __eq__ and __ne__ methods.
+ _set_eq_fns(cls, list(filter(lambda f: f.compare, field_list)))
+
+ if order:
+ # Create and __lt__, __le__, __gt__, and __ge__ methods.
+ # Create and set the comparison functions.
+ _set_order_fns(cls, list(filter(lambda f: f.compare, field_list)))
+
+ if not getattr(cls, '__doc__'):
+ # Create a class doc-string.
+ cls.__doc__ = (cls.__name__ +
+ str(inspect.signature(cls)).replace(' -> None', ''))
+
+ return cls
+
+
+# _cls should never be specified by keyword, so start it with an
+# underscore. The presense of _cls is used to detect if this
+# decorator is being called with parameters or not.
+def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
+ hash=None, frozen=False):
+ """Returns the same class as was passed in, with dunder methods
+ added based on the fields defined in the class.
+
+ Examines PEP 526 __annotations__ to determine fields.
+
+ If init is true, an __init__() method is added to the class. If
+ repr is true, a __repr__() method is added. If order is true, rich
+ comparison dunder methods are added. If hash is true, a __hash__()
+ method function is added. If frozen is true, fields may not be
+ assigned to after instance creation.
+ """
+
+ def wrap(cls):
+ return _process_class(cls, repr, eq, order, hash, init, frozen)
+
+ # See if we're being called as @dataclass or @dataclass().
+ if _cls is None:
+ # We're called with parens.
+ return wrap
+
+ # We're called as @dataclass without parens.
+ return wrap(_cls)
+
+
+def fields(class_or_instance):
+ """Return a tuple describing the fields of this dataclass.
+
+ Accepts a dataclass or an instance of one. Tuple elements are of
+ type Field.
+ """
+
+ # Might it be worth caching this, per class?
+ try:
+ fields = getattr(class_or_instance, _MARKER)
+ except AttributeError:
+ raise TypeError('must be called with a dataclass type or instance')
+
+ # Exclude pseudo-fields.
+ return tuple(f for f in fields.values() if f._field_type is _FIELD)
+
+
+def _isdataclass(obj):
+ """Returns True if obj is an instance of a dataclass."""
+ return not isinstance(obj, type) and hasattr(obj, _MARKER)
+
+
+def asdict(obj, *, dict_factory=dict):
+ """Return the fields of a dataclass instance as a new dictionary mapping
+ field names to field values.
+
+ Example usage:
+
+ @dataclass
+ class C:
+ x: int
+ y: int
+
+ c = C(1, 2)
+ assert asdict(c) == {'x': 1, 'y': 2}
+
+ If given, 'dict_factory' will be used instead of built-in dict.
+ The function applies recursively to field values that are
+ dataclass instances. This will also look into built-in containers:
+ tuples, lists, and dicts.
+ """
+ if not _isdataclass(obj):
+ raise TypeError("asdict() should be called on dataclass instances")
+ return _asdict_inner(obj, dict_factory)
+
+def _asdict_inner(obj, dict_factory):
+ if _isdataclass(obj):
+ result = []
+ for f in fields(obj):
+ value = _asdict_inner(getattr(obj, f.name), dict_factory)
+ result.append((f.name, value))
+ return dict_factory(result)
+ elif isinstance(obj, (list, tuple)):
+ return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
+ elif isinstance(obj, dict):
+ return type(obj)((_asdict_inner(k, dict_factory), _asdict_inner(v, dict_factory))
+ for k, v in obj.items())
+ else:
+ return deepcopy(obj)
+
+
+def astuple(obj, *, tuple_factory=tuple):
+ """Return the fields of a dataclass instance as a new tuple of field values.
+
+ Example usage::
+
+ @dataclass
+ class C:
+ x: int
+ y: int
+
+ c = C(1, 2)
+ assert asdtuple(c) == (1, 2)
+
+ If given, 'tuple_factory' will be used instead of built-in tuple.
+ The function applies recursively to field values that are
+ dataclass instances. This will also look into built-in containers:
+ tuples, lists, and dicts.
+ """
+
+ if not _isdataclass(obj):
+ raise TypeError("astuple() should be called on dataclass instances")
+ return _astuple_inner(obj, tuple_factory)
+
+def _astuple_inner(obj, tuple_factory):
+ if _isdataclass(obj):
+ result = []
+ for f in fields(obj):
+ value = _astuple_inner(getattr(obj, f.name), tuple_factory)
+ result.append(value)
+ return tuple_factory(result)
+ elif isinstance(obj, (list, tuple)):
+ return type(obj)(_astuple_inner(v, tuple_factory) for v in obj)
+ elif isinstance(obj, dict):
+ return type(obj)((_astuple_inner(k, tuple_factory), _astuple_inner(v, tuple_factory))
+ for k, v in obj.items())
+ else:
+ return deepcopy(obj)
+
+
+def make_dataclass(cls_name, fields, *, bases=(), namespace=None):
+ """Return a new dynamically created dataclass.
+
+ The dataclass name will be 'cls_name'. 'fields' is an interable
+ of either (name, type) or (name, type, Field) objects. Field
+ objects are created by calling 'field(name, type [, Field])'.
+
+ C = make_class('C', [('a', int', ('b', int, Field(init=False))], bases=Base)
+
+ is equivalent to:
+
+ @dataclass
+ class C(Base):
+ a: int
+ b: int = field(init=False)
+
+ For the bases and namespace paremeters, see the builtin type() function.
+ """
+
+ if namespace is None:
+ namespace = {}
+ else:
+ # Copy namespace since we're going to mutate it.
+ namespace = namespace.copy()
+
+ anns = collections.OrderedDict((name, tp) for name, tp, *_ in fields)
+ namespace['__annotations__'] = anns
+ for item in fields:
+ if len(item) == 3:
+ name, tp, spec = item
+ namespace[name] = spec
+ cls = type(cls_name, bases, namespace)
+ return dataclass(cls)
+
+
+def replace(obj, **changes):
+ """Return a new object replacing specified fields with new values.
+
+ This is especially useful for frozen classes. Example usage:
+
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ y: int
+
+ c = C(1, 2)
+ c1 = replace(c, x=3)
+ assert c1.x == 3 and c1.y == 2
+ """
+
+ # We're going to mutate 'changes', but that's okay because it's a new
+ # dict, even if called with 'replace(obj, **my_changes)'.
+
+ if not _isdataclass(obj):
+ raise TypeError("replace() should be called on dataclass instances")
+
+ # It's an error to have init=False fields in 'changes'.
+ # If a field is not in 'changes', read its value from the provided obj.
+
+ for f in getattr(obj, _MARKER).values():
+ if not f.init:
+ # Error if this field is specified in changes.
+ if f.name in changes:
+ raise ValueError(f'field {f.name} is declared with '
+ 'init=False, it cannot be specified with '
+ 'replace()')
+ continue
+
+ if f.name not in changes:
+ changes[f.name] = getattr(obj, f.name)
+
+ # Create the new object, which calls __init__() and __post_init__
+ # (if defined), using all of the init fields we've added and/or
+ # left in 'changes'.
+ # If there are values supplied in changes that aren't fields, this
+ # will correctly raise a TypeError.
+ return obj.__class__(**changes)
--- /dev/null
+from dataclasses import (
+ dataclass, field, FrozenInstanceError, fields, asdict, astuple,
+ make_dataclass, replace, InitVar, Field
+)
+
+import pickle
+import inspect
+import unittest
+from unittest.mock import Mock
+from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
+from collections import deque, OrderedDict, namedtuple
+
+# Just any custom exception we can catch.
+class CustomError(Exception): pass
+
+class TestCase(unittest.TestCase):
+ def test_no_fields(self):
+ @dataclass
+ class C:
+ pass
+
+ o = C()
+ self.assertEqual(len(fields(C)), 0)
+
+ def test_one_field_no_default(self):
+ @dataclass
+ class C:
+ x: int
+
+ o = C(42)
+ self.assertEqual(o.x, 42)
+
+ def test_named_init_params(self):
+ @dataclass
+ class C:
+ x: int
+
+ o = C(x=32)
+ self.assertEqual(o.x, 32)
+
+ def test_two_fields_one_default(self):
+ @dataclass
+ class C:
+ x: int
+ y: int = 0
+
+ o = C(3)
+ self.assertEqual((o.x, o.y), (3, 0))
+
+ # Non-defaults following defaults.
+ with self.assertRaisesRegex(TypeError,
+ "non-default argument 'y' follows "
+ "default argument"):
+ @dataclass
+ class C:
+ x: int = 0
+ y: int
+
+ # A derived class adds a non-default field after a default one.
+ with self.assertRaisesRegex(TypeError,
+ "non-default argument 'y' follows "
+ "default argument"):
+ @dataclass
+ class B:
+ x: int = 0
+
+ @dataclass
+ class C(B):
+ y: int
+
+ # Override a base class field and add a default to
+ # a field which didn't use to have a default.
+ with self.assertRaisesRegex(TypeError,
+ "non-default argument 'y' follows "
+ "default argument"):
+ @dataclass
+ class B:
+ x: int
+ y: int
+
+ @dataclass
+ class C(B):
+ x: int = 0
+
+ def test_overwriting_init(self):
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __init__ '
+ 'in C'):
+ @dataclass
+ class C:
+ x: int
+ def __init__(self, x):
+ self.x = 2 * x
+
+ @dataclass(init=False)
+ class C:
+ x: int
+ def __init__(self, x):
+ self.x = 2 * x
+ self.assertEqual(C(5).x, 10)
+
+ def test_overwriting_repr(self):
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __repr__ '
+ 'in C'):
+ @dataclass
+ class C:
+ x: int
+ def __repr__(self):
+ pass
+
+ @dataclass(repr=False)
+ class C:
+ x: int
+ def __repr__(self):
+ return 'x'
+ self.assertEqual(repr(C(0)), 'x')
+
+ def test_overwriting_cmp(self):
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __eq__ '
+ 'in C'):
+ # This will generate the comparison functions, make sure we can't
+ # overwrite them.
+ @dataclass(hash=False, frozen=False)
+ class C:
+ x: int
+ def __eq__(self):
+ pass
+
+ @dataclass(order=False, eq=False)
+ class C:
+ x: int
+ def __eq__(self, other):
+ return True
+ self.assertEqual(C(0), 'x')
+
+ def test_overwriting_hash(self):
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __hash__ '
+ 'in C'):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ def __hash__(self):
+ pass
+
+ @dataclass(frozen=True,hash=False)
+ class C:
+ x: int
+ def __hash__(self):
+ return 600
+ self.assertEqual(hash(C(0)), 600)
+
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __hash__ '
+ 'in C'):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ def __hash__(self):
+ pass
+
+ @dataclass(frozen=True, hash=False)
+ class C:
+ x: int
+ def __hash__(self):
+ return 600
+ self.assertEqual(hash(C(0)), 600)
+
+ def test_overwriting_frozen(self):
+ # frozen uses __setattr__ and __delattr__
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __setattr__ '
+ 'in C'):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ def __setattr__(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __delattr__ '
+ 'in C'):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ def __delattr__(self):
+ pass
+
+ @dataclass(frozen=False)
+ class C:
+ x: int
+ def __setattr__(self, name, value):
+ self.__dict__['x'] = value * 2
+ self.assertEqual(C(10).x, 20)
+
+ def test_overwrite_fields_in_derived_class(self):
+ # Note that x from C1 replaces x in Base, but the order remains
+ # the same as defined in Base.
+ @dataclass
+ class Base:
+ x: Any = 15.0
+ y: int = 0
+
+ @dataclass
+ class C1(Base):
+ z: int = 10
+ x: int = 15
+
+ o = Base()
+ self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
+
+ o = C1()
+ self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
+
+ o = C1(x=5)
+ self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
+
+ def test_field_named_self(self):
+ @dataclass
+ class C:
+ self: str
+ c=C('foo')
+ self.assertEqual(c.self, 'foo')
+
+ # Make sure the first parameter is not named 'self'.
+ sig = inspect.signature(C.__init__)
+ first = next(iter(sig.parameters))
+ self.assertNotEqual('self', first)
+
+ # But we do use 'self' if no field named self.
+ @dataclass
+ class C:
+ selfx: str
+
+ # Make sure the first parameter is named 'self'.
+ sig = inspect.signature(C.__init__)
+ first = next(iter(sig.parameters))
+ self.assertEqual('self', first)
+
+ def test_repr(self):
+ @dataclass
+ class B:
+ x: int
+
+ @dataclass
+ class C(B):
+ y: int = 10
+
+ o = C(4)
+ self.assertEqual(repr(o), 'TestCase.test_repr.<locals>.C(x=4, y=10)')
+
+ @dataclass
+ class D(C):
+ x: int = 20
+ self.assertEqual(repr(D()), 'TestCase.test_repr.<locals>.D(x=20, y=10)')
+
+ @dataclass
+ class C:
+ @dataclass
+ class D:
+ i: int
+ @dataclass
+ class E:
+ pass
+ self.assertEqual(repr(C.D(0)), 'TestCase.test_repr.<locals>.C.D(i=0)')
+ self.assertEqual(repr(C.E()), 'TestCase.test_repr.<locals>.C.E()')
+
+ def test_0_field_compare(self):
+ # Ensure that order=False is the default.
+ @dataclass
+ class C0:
+ pass
+
+ @dataclass(order=False)
+ class C1:
+ pass
+
+ for cls in [C0, C1]:
+ with self.subTest(cls=cls):
+ self.assertEqual(cls(), cls())
+ for idx, fn in enumerate([lambda a, b: a < b,
+ lambda a, b: a <= b,
+ lambda a, b: a > b,
+ lambda a, b: a >= b]):
+ with self.subTest(idx=idx):
+ with self.assertRaisesRegex(TypeError,
+ f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
+ fn(cls(), cls())
+
+ @dataclass(order=True)
+ class C:
+ pass
+ self.assertLessEqual(C(), C())
+ self.assertGreaterEqual(C(), C())
+
+ def test_1_field_compare(self):
+ # Ensure that order=False is the default.
+ @dataclass
+ class C0:
+ x: int
+
+ @dataclass(order=False)
+ class C1:
+ x: int
+
+ for cls in [C0, C1]:
+ with self.subTest(cls=cls):
+ self.assertEqual(cls(1), cls(1))
+ self.assertNotEqual(cls(0), cls(1))
+ for idx, fn in enumerate([lambda a, b: a < b,
+ lambda a, b: a <= b,
+ lambda a, b: a > b,
+ lambda a, b: a >= b]):
+ with self.subTest(idx=idx):
+ with self.assertRaisesRegex(TypeError,
+ f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
+ fn(cls(0), cls(0))
+
+ @dataclass(order=True)
+ class C:
+ x: int
+ self.assertLess(C(0), C(1))
+ self.assertLessEqual(C(0), C(1))
+ self.assertLessEqual(C(1), C(1))
+ self.assertGreater(C(1), C(0))
+ self.assertGreaterEqual(C(1), C(0))
+ self.assertGreaterEqual(C(1), C(1))
+
+ def test_simple_compare(self):
+ # Ensure that order=False is the default.
+ @dataclass
+ class C0:
+ x: int
+ y: int
+
+ @dataclass(order=False)
+ class C1:
+ x: int
+ y: int
+
+ for cls in [C0, C1]:
+ with self.subTest(cls=cls):
+ self.assertEqual(cls(0, 0), cls(0, 0))
+ self.assertEqual(cls(1, 2), cls(1, 2))
+ self.assertNotEqual(cls(1, 0), cls(0, 0))
+ self.assertNotEqual(cls(1, 0), cls(1, 1))
+ for idx, fn in enumerate([lambda a, b: a < b,
+ lambda a, b: a <= b,
+ lambda a, b: a > b,
+ lambda a, b: a >= b]):
+ with self.subTest(idx=idx):
+ with self.assertRaisesRegex(TypeError,
+ f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
+ fn(cls(0, 0), cls(0, 0))
+
+ @dataclass(order=True)
+ class C:
+ x: int
+ y: int
+
+ for idx, fn in enumerate([lambda a, b: a == b,
+ lambda a, b: a <= b,
+ lambda a, b: a >= b]):
+ with self.subTest(idx=idx):
+ self.assertTrue(fn(C(0, 0), C(0, 0)))
+
+ for idx, fn in enumerate([lambda a, b: a < b,
+ lambda a, b: a <= b,
+ lambda a, b: a != b]):
+ with self.subTest(idx=idx):
+ self.assertTrue(fn(C(0, 0), C(0, 1)))
+ self.assertTrue(fn(C(0, 1), C(1, 0)))
+ self.assertTrue(fn(C(1, 0), C(1, 1)))
+
+ for idx, fn in enumerate([lambda a, b: a > b,
+ lambda a, b: a >= b,
+ lambda a, b: a != b]):
+ with self.subTest(idx=idx):
+ self.assertTrue(fn(C(0, 1), C(0, 0)))
+ self.assertTrue(fn(C(1, 0), C(0, 1)))
+ self.assertTrue(fn(C(1, 1), C(1, 0)))
+
+ def test_compare_subclasses(self):
+ # Comparisons fail for subclasses, even if no fields
+ # are added.
+ @dataclass
+ class B:
+ i: int
+
+ @dataclass
+ class C(B):
+ pass
+
+ for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
+ (lambda a, b: a != b, True)]):
+ with self.subTest(idx=idx):
+ self.assertEqual(fn(B(0), C(0)), expected)
+
+ for idx, fn in enumerate([lambda a, b: a < b,
+ lambda a, b: a <= b,
+ lambda a, b: a > b,
+ lambda a, b: a >= b]):
+ with self.subTest(idx=idx):
+ with self.assertRaisesRegex(TypeError,
+ "not supported between instances of 'B' and 'C'"):
+ fn(B(0), C(0))
+
+ def test_0_field_hash(self):
+ @dataclass(hash=True)
+ class C:
+ pass
+ self.assertEqual(hash(C()), hash(()))
+
+ def test_1_field_hash(self):
+ @dataclass(hash=True)
+ class C:
+ x: int
+ self.assertEqual(hash(C(4)), hash((4,)))
+ self.assertEqual(hash(C(42)), hash((42,)))
+
+ def test_hash(self):
+ @dataclass(hash=True)
+ class C:
+ x: int
+ y: str
+ self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
+
+ def test_no_hash(self):
+ @dataclass(hash=None)
+ class C:
+ x: int
+ with self.assertRaisesRegex(TypeError,
+ "unhashable type: 'C'"):
+ hash(C(1))
+
+ def test_hash_rules(self):
+ # There are 24 cases of:
+ # hash=True/False/None
+ # eq=True/False
+ # order=True/False
+ # frozen=True/False
+ for (hash, eq, order, frozen, result ) in [
+ (False, False, False, False, 'absent'),
+ (False, False, False, True, 'absent'),
+ (False, False, True, False, 'exception'),
+ (False, False, True, True, 'exception'),
+ (False, True, False, False, 'absent'),
+ (False, True, False, True, 'absent'),
+ (False, True, True, False, 'absent'),
+ (False, True, True, True, 'absent'),
+ (True, False, False, False, 'fn'),
+ (True, False, False, True, 'fn'),
+ (True, False, True, False, 'exception'),
+ (True, False, True, True, 'exception'),
+ (True, True, False, False, 'fn'),
+ (True, True, False, True, 'fn'),
+ (True, True, True, False, 'fn'),
+ (True, True, True, True, 'fn'),
+ (None, False, False, False, 'absent'),
+ (None, False, False, True, 'absent'),
+ (None, False, True, False, 'exception'),
+ (None, False, True, True, 'exception'),
+ (None, True, False, False, 'none'),
+ (None, True, False, True, 'fn'),
+ (None, True, True, False, 'none'),
+ (None, True, True, True, 'fn'),
+ ]:
+ with self.subTest(hash=hash, eq=eq, order=order, frozen=frozen):
+ if result == 'exception':
+ with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
+ @dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
+ class C:
+ pass
+ else:
+ @dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
+ class C:
+ pass
+
+ # See if the result matches what's expected.
+ if result == 'fn':
+ # __hash__ contains the function we generated.
+ self.assertIn('__hash__', C.__dict__)
+ self.assertIsNotNone(C.__dict__['__hash__'])
+ elif result == 'absent':
+ # __hash__ is not present in our class.
+ self.assertNotIn('__hash__', C.__dict__)
+ elif result == 'none':
+ # __hash__ is set to None.
+ self.assertIn('__hash__', C.__dict__)
+ self.assertIsNone(C.__dict__['__hash__'])
+ else:
+ assert False, f'unknown result {result!r}'
+
+ def test_eq_order(self):
+ for (eq, order, result ) in [
+ (False, False, 'neither'),
+ (False, True, 'exception'),
+ (True, False, 'eq_only'),
+ (True, True, 'both'),
+ ]:
+ with self.subTest(eq=eq, order=order):
+ if result == 'exception':
+ with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
+ @dataclass(eq=eq, order=order)
+ class C:
+ pass
+ else:
+ @dataclass(eq=eq, order=order)
+ class C:
+ pass
+
+ if result == 'neither':
+ self.assertNotIn('__eq__', C.__dict__)
+ self.assertNotIn('__ne__', C.__dict__)
+ self.assertNotIn('__lt__', C.__dict__)
+ self.assertNotIn('__le__', C.__dict__)
+ self.assertNotIn('__gt__', C.__dict__)
+ self.assertNotIn('__ge__', C.__dict__)
+ elif result == 'both':
+ self.assertIn('__eq__', C.__dict__)
+ self.assertIn('__ne__', C.__dict__)
+ self.assertIn('__lt__', C.__dict__)
+ self.assertIn('__le__', C.__dict__)
+ self.assertIn('__gt__', C.__dict__)
+ self.assertIn('__ge__', C.__dict__)
+ elif result == 'eq_only':
+ self.assertIn('__eq__', C.__dict__)
+ self.assertIn('__ne__', C.__dict__)
+ self.assertNotIn('__lt__', C.__dict__)
+ self.assertNotIn('__le__', C.__dict__)
+ self.assertNotIn('__gt__', C.__dict__)
+ self.assertNotIn('__ge__', C.__dict__)
+ else:
+ assert False, f'unknown result {result!r}'
+
+ def test_field_no_default(self):
+ @dataclass
+ class C:
+ x: int = field()
+
+ self.assertEqual(C(5).x, 5)
+
+ with self.assertRaisesRegex(TypeError,
+ r"__init__\(\) missing 1 required "
+ "positional argument: 'x'"):
+ C()
+
+ def test_field_default(self):
+ default = object()
+ @dataclass
+ class C:
+ x: object = field(default=default)
+
+ self.assertIs(C.x, default)
+ c = C(10)
+ self.assertEqual(c.x, 10)
+
+ # If we delete the instance attribute, we should then see the
+ # class attribute.
+ del c.x
+ self.assertIs(c.x, default)
+
+ self.assertIs(C().x, default)
+
+ def test_not_in_repr(self):
+ @dataclass
+ class C:
+ x: int = field(repr=False)
+ with self.assertRaises(TypeError):
+ C()
+ c = C(10)
+ self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
+
+ @dataclass
+ class C:
+ x: int = field(repr=False)
+ y: int
+ c = C(10, 20)
+ self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
+
+ def test_not_in_compare(self):
+ @dataclass
+ class C:
+ x: int = 0
+ y: int = field(compare=False, default=4)
+
+ self.assertEqual(C(), C(0, 20))
+ self.assertEqual(C(1, 10), C(1, 20))
+ self.assertNotEqual(C(3), C(4, 10))
+ self.assertNotEqual(C(3, 10), C(4, 10))
+
+ def test_hash_field_rules(self):
+ # Test all 6 cases of:
+ # hash=True/False/None
+ # compare=True/False
+ for (hash_val, compare, result ) in [
+ (True, False, 'field' ),
+ (True, True, 'field' ),
+ (False, False, 'absent'),
+ (False, True, 'absent'),
+ (None, False, 'absent'),
+ (None, True, 'field' ),
+ ]:
+ with self.subTest(hash_val=hash_val, compare=compare):
+ @dataclass(hash=True)
+ class C:
+ x: int = field(compare=compare, hash=hash_val, default=5)
+
+ if result == 'field':
+ # __hash__ contains the field.
+ self.assertEqual(C(5).__hash__(), hash((5,)))
+ elif result == 'absent':
+ # The field is not present in the hash.
+ self.assertEqual(C(5).__hash__(), hash(()))
+ else:
+ assert False, f'unknown result {result!r}'
+
+ def test_init_false_no_default(self):
+ # If init=False and no default value, then the field won't be
+ # present in the instance.
+ @dataclass
+ class C:
+ x: int = field(init=False)
+
+ self.assertNotIn('x', C().__dict__)
+
+ @dataclass
+ class C:
+ x: int
+ y: int = 0
+ z: int = field(init=False)
+ t: int = 10
+
+ self.assertNotIn('z', C(0).__dict__)
+ self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
+
+ def test_class_marker(self):
+ @dataclass
+ class C:
+ x: int
+ y: str = field(init=False, default=None)
+ z: str = field(repr=False)
+
+ the_fields = fields(C)
+ # the_fields is a tuple of 3 items, each value
+ # is in __annotations__.
+ self.assertIsInstance(the_fields, tuple)
+ for f in the_fields:
+ self.assertIs(type(f), Field)
+ self.assertIn(f.name, C.__annotations__)
+
+ self.assertEqual(len(the_fields), 3)
+
+ self.assertEqual(the_fields[0].name, 'x')
+ self.assertEqual(the_fields[0].type, int)
+ self.assertFalse(hasattr(C, 'x'))
+ self.assertTrue (the_fields[0].init)
+ self.assertTrue (the_fields[0].repr)
+ self.assertEqual(the_fields[1].name, 'y')
+ self.assertEqual(the_fields[1].type, str)
+ self.assertIsNone(getattr(C, 'y'))
+ self.assertFalse(the_fields[1].init)
+ self.assertTrue (the_fields[1].repr)
+ self.assertEqual(the_fields[2].name, 'z')
+ self.assertEqual(the_fields[2].type, str)
+ self.assertFalse(hasattr(C, 'z'))
+ self.assertTrue (the_fields[2].init)
+ self.assertFalse(the_fields[2].repr)
+
+ def test_field_order(self):
+ @dataclass
+ class B:
+ a: str = 'B:a'
+ b: str = 'B:b'
+ c: str = 'B:c'
+
+ @dataclass
+ class C(B):
+ b: str = 'C:b'
+
+ self.assertEqual([(f.name, f.default) for f in fields(C)],
+ [('a', 'B:a'),
+ ('b', 'C:b'),
+ ('c', 'B:c')])
+
+ @dataclass
+ class D(B):
+ c: str = 'D:c'
+
+ self.assertEqual([(f.name, f.default) for f in fields(D)],
+ [('a', 'B:a'),
+ ('b', 'B:b'),
+ ('c', 'D:c')])
+
+ @dataclass
+ class E(D):
+ a: str = 'E:a'
+ d: str = 'E:d'
+
+ self.assertEqual([(f.name, f.default) for f in fields(E)],
+ [('a', 'E:a'),
+ ('b', 'B:b'),
+ ('c', 'D:c'),
+ ('d', 'E:d')])
+
+ def test_class_attrs(self):
+ # We only have a class attribute if a default value is
+ # specified, either directly or via a field with a default.
+ default = object()
+ @dataclass
+ class C:
+ x: int
+ y: int = field(repr=False)
+ z: object = default
+ t: int = field(default=100)
+
+ self.assertFalse(hasattr(C, 'x'))
+ self.assertFalse(hasattr(C, 'y'))
+ self.assertIs (C.z, default)
+ self.assertEqual(C.t, 100)
+
+ def test_disallowed_mutable_defaults(self):
+ # For the known types, don't allow mutable default values.
+ for typ, empty, non_empty in [(list, [], [1]),
+ (dict, {}, {0:1}),
+ (set, set(), set([1])),
+ ]:
+ with self.subTest(typ=typ):
+ # Can't use a zero-length value.
+ with self.assertRaisesRegex(ValueError,
+ f'mutable default {typ} for field '
+ 'x is not allowed'):
+ @dataclass
+ class Point:
+ x: typ = empty
+
+
+ # Nor a non-zero-length value
+ with self.assertRaisesRegex(ValueError,
+ f'mutable default {typ} for field '
+ 'y is not allowed'):
+ @dataclass
+ class Point:
+ y: typ = non_empty
+
+ # Check subtypes also fail.
+ class Subclass(typ): pass
+
+ with self.assertRaisesRegex(ValueError,
+ f"mutable default .*Subclass'>"
+ ' for field z is not allowed'
+ ):
+ @dataclass
+ class Point:
+ z: typ = Subclass()
+
+ # Because this is a ClassVar, it can be mutable.
+ @dataclass
+ class C:
+ z: ClassVar[typ] = typ()
+
+ # Because this is a ClassVar, it can be mutable.
+ @dataclass
+ class C:
+ x: ClassVar[typ] = Subclass()
+
+
+ def test_deliberately_mutable_defaults(self):
+ # If a mutable default isn't in the known list of
+ # (list, dict, set), then it's okay.
+ class Mutable:
+ def __init__(self):
+ self.l = []
+
+ @dataclass
+ class C:
+ x: Mutable
+
+ # These 2 instances will share this value of x.
+ lst = Mutable()
+ o1 = C(lst)
+ o2 = C(lst)
+ self.assertEqual(o1, o2)
+ o1.x.l.extend([1, 2])
+ self.assertEqual(o1, o2)
+ self.assertEqual(o1.x.l, [1, 2])
+ self.assertIs(o1.x, o2.x)
+
+ def test_no_options(self):
+ # call with dataclass()
+ @dataclass()
+ class C:
+ x: int
+
+ self.assertEqual(C(42).x, 42)
+
+ def test_not_tuple(self):
+ # Make sure we can't be compared to a tuple.
+ @dataclass
+ class Point:
+ x: int
+ y: int
+ self.assertNotEqual(Point(1, 2), (1, 2))
+
+ # And that we can't compare to another unrelated dataclass
+ @dataclass
+ class C:
+ x: int
+ y: int
+ self.assertNotEqual(Point(1, 3), C(1, 3))
+
+ def test_base_has_init(self):
+ class B:
+ def __init__(self):
+ pass
+
+ # Make sure that declaring this class doesn't raise an error.
+ # The issue is that we can't override __init__ in our class,
+ # but it should be okay to add __init__ to us if our base has
+ # an __init__.
+ @dataclass
+ class C(B):
+ x: int = 0
+
+ def test_frozen(self):
+ @dataclass(frozen=True)
+ class C:
+ i: int
+
+ c = C(10)
+ self.assertEqual(c.i, 10)
+ with self.assertRaises(FrozenInstanceError):
+ c.i = 5
+ self.assertEqual(c.i, 10)
+
+ # Check that a derived class is still frozen, even if not
+ # marked so.
+ @dataclass
+ class D(C):
+ pass
+
+ d = D(20)
+ self.assertEqual(d.i, 20)
+ with self.assertRaises(FrozenInstanceError):
+ d.i = 5
+ self.assertEqual(d.i, 20)
+
+ def test_not_tuple(self):
+ # Test that some of the problems with namedtuple don't happen
+ # here.
+ @dataclass
+ class Point3D:
+ x: int
+ y: int
+ z: int
+
+ @dataclass
+ class Date:
+ year: int
+ month: int
+ day: int
+
+ self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
+ self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
+
+ # Make sure we can't unpack
+ with self.assertRaisesRegex(TypeError, 'is not iterable'):
+ x, y, z = Point3D(4, 5, 6)
+
+ # Maka sure another class with the same field names isn't
+ # equal.
+ @dataclass
+ class Point3Dv1:
+ x: int = 0
+ y: int = 0
+ z: int = 0
+ self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
+
+ def test_function_annotations(self):
+ # Some dummy class and instance to use as a default.
+ class F:
+ pass
+ f = F()
+
+ def validate_class(cls):
+ # First, check __annotations__, even though they're not
+ # function annotations.
+ self.assertEqual(cls.__annotations__['i'], int)
+ self.assertEqual(cls.__annotations__['j'], str)
+ self.assertEqual(cls.__annotations__['k'], F)
+ self.assertEqual(cls.__annotations__['l'], float)
+ self.assertEqual(cls.__annotations__['z'], complex)
+
+ # Verify __init__.
+
+ signature = inspect.signature(cls.__init__)
+ # Check the return type, should be None
+ self.assertIs(signature.return_annotation, None)
+
+ # Check each parameter.
+ params = iter(signature.parameters.values())
+ param = next(params)
+ # This is testing an internal name, and probably shouldn't be tested.
+ self.assertEqual(param.name, 'self')
+ param = next(params)
+ self.assertEqual(param.name, 'i')
+ self.assertIs (param.annotation, int)
+ self.assertEqual(param.default, inspect.Parameter.empty)
+ self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ param = next(params)
+ self.assertEqual(param.name, 'j')
+ self.assertIs (param.annotation, str)
+ self.assertEqual(param.default, inspect.Parameter.empty)
+ self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ param = next(params)
+ self.assertEqual(param.name, 'k')
+ self.assertIs (param.annotation, F)
+ # Don't test for the default, since it's set to _MISSING
+ self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ param = next(params)
+ self.assertEqual(param.name, 'l')
+ self.assertIs (param.annotation, float)
+ # Don't test for the default, since it's set to _MISSING
+ self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ self.assertRaises(StopIteration, next, params)
+
+
+ @dataclass
+ class C:
+ i: int
+ j: str
+ k: F = f
+ l: float=field(default=None)
+ z: complex=field(default=3+4j, init=False)
+
+ validate_class(C)
+
+ # Now repeat with __hash__.
+ @dataclass(frozen=True, hash=True)
+ class C:
+ i: int
+ j: str
+ k: F = f
+ l: float=field(default=None)
+ z: complex=field(default=3+4j, init=False)
+
+ validate_class(C)
+
+ def test_dont_include_other_annotations(self):
+ @dataclass
+ class C:
+ i: int
+ def foo(self) -> int:
+ return 4
+ @property
+ def bar(self) -> int:
+ return 5
+ self.assertEqual(list(C.__annotations__), ['i'])
+ self.assertEqual(C(10).foo(), 4)
+ self.assertEqual(C(10).bar, 5)
+
+ def test_post_init(self):
+ # Just make sure it gets called
+ @dataclass
+ class C:
+ def __post_init__(self):
+ raise CustomError()
+ with self.assertRaises(CustomError):
+ C()
+
+ @dataclass
+ class C:
+ i: int = 10
+ def __post_init__(self):
+ if self.i == 10:
+ raise CustomError()
+ with self.assertRaises(CustomError):
+ C()
+ # post-init gets called, but doesn't raise. This is just
+ # checking that self is used correctly.
+ C(5)
+
+ # If there's not an __init__, then post-init won't get called.
+ @dataclass(init=False)
+ class C:
+ def __post_init__(self):
+ raise CustomError()
+ # Creating the class won't raise
+ C()
+
+ @dataclass
+ class C:
+ x: int = 0
+ def __post_init__(self):
+ self.x *= 2
+ self.assertEqual(C().x, 0)
+ self.assertEqual(C(2).x, 4)
+
+ # Make sure that if we'r frozen, post-init can't set
+ # attributes.
+ @dataclass(frozen=True)
+ class C:
+ x: int = 0
+ def __post_init__(self):
+ self.x *= 2
+ with self.assertRaises(FrozenInstanceError):
+ C()
+
+ def test_post_init_super(self):
+ # Make sure super() post-init isn't called by default.
+ class B:
+ def __post_init__(self):
+ raise CustomError()
+
+ @dataclass
+ class C(B):
+ def __post_init__(self):
+ self.x = 5
+
+ self.assertEqual(C().x, 5)
+
+ # Now call super(), and it will raise
+ @dataclass
+ class C(B):
+ def __post_init__(self):
+ super().__post_init__()
+
+ with self.assertRaises(CustomError):
+ C()
+
+ # Make sure post-init is called, even if not defined in our
+ # class.
+ @dataclass
+ class C(B):
+ pass
+
+ with self.assertRaises(CustomError):
+ C()
+
+ def test_post_init_staticmethod(self):
+ flag = False
+ @dataclass
+ class C:
+ x: int
+ y: int
+ @staticmethod
+ def __post_init__():
+ nonlocal flag
+ flag = True
+
+ self.assertFalse(flag)
+ c = C(3, 4)
+ self.assertEqual((c.x, c.y), (3, 4))
+ self.assertTrue(flag)
+
+ def test_post_init_classmethod(self):
+ @dataclass
+ class C:
+ flag = False
+ x: int
+ y: int
+ @classmethod
+ def __post_init__(cls):
+ cls.flag = True
+
+ self.assertFalse(C.flag)
+ c = C(3, 4)
+ self.assertEqual((c.x, c.y), (3, 4))
+ self.assertTrue(C.flag)
+
+ def test_class_var(self):
+ # Make sure ClassVars are ignored in __init__, __repr__, etc.
+ @dataclass
+ class C:
+ x: int
+ y: int = 10
+ z: ClassVar[int] = 1000
+ w: ClassVar[int] = 2000
+ t: ClassVar[int] = 3000
+
+ c = C(5)
+ self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
+ self.assertEqual(len(fields(C)), 2) # We have 2 fields
+ self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
+ self.assertEqual(c.z, 1000)
+ self.assertEqual(c.w, 2000)
+ self.assertEqual(c.t, 3000)
+ C.z += 1
+ self.assertEqual(c.z, 1001)
+ c = C(20)
+ self.assertEqual((c.x, c.y), (20, 10))
+ self.assertEqual(c.z, 1001)
+ self.assertEqual(c.w, 2000)
+ self.assertEqual(c.t, 3000)
+
+ def test_class_var_no_default(self):
+ # If a ClassVar has no default value, it should not be set on the class.
+ @dataclass
+ class C:
+ x: ClassVar[int]
+
+ self.assertNotIn('x', C.__dict__)
+
+ def test_class_var_default_factory(self):
+ # It makes no sense for a ClassVar to have a default factory. When
+ # would it be called? Call it yourself, since it's class-wide.
+ with self.assertRaisesRegex(TypeError,
+ 'cannot have a default factory'):
+ @dataclass
+ class C:
+ x: ClassVar[int] = field(default_factory=int)
+
+ self.assertNotIn('x', C.__dict__)
+
+ def test_class_var_with_default(self):
+ # If a ClassVar has a default value, it should be set on the class.
+ @dataclass
+ class C:
+ x: ClassVar[int] = 10
+ self.assertEqual(C.x, 10)
+
+ @dataclass
+ class C:
+ x: ClassVar[int] = field(default=10)
+ self.assertEqual(C.x, 10)
+
+ def test_class_var_frozen(self):
+ # Make sure ClassVars work even if we're frozen.
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ y: int = 10
+ z: ClassVar[int] = 1000
+ w: ClassVar[int] = 2000
+ t: ClassVar[int] = 3000
+
+ c = C(5)
+ self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
+ self.assertEqual(len(fields(C)), 2) # We have 2 fields
+ self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
+ self.assertEqual(c.z, 1000)
+ self.assertEqual(c.w, 2000)
+ self.assertEqual(c.t, 3000)
+ # We can still modify the ClassVar, it's only instances that are
+ # frozen.
+ C.z += 1
+ self.assertEqual(c.z, 1001)
+ c = C(20)
+ self.assertEqual((c.x, c.y), (20, 10))
+ self.assertEqual(c.z, 1001)
+ self.assertEqual(c.w, 2000)
+ self.assertEqual(c.t, 3000)
+
+ def test_init_var_no_default(self):
+ # If an InitVar has no default value, it should not be set on the class.
+ @dataclass
+ class C:
+ x: InitVar[int]
+
+ self.assertNotIn('x', C.__dict__)
+
+ def test_init_var_default_factory(self):
+ # It makes no sense for an InitVar to have a default factory. When
+ # would it be called? Call it yourself, since it's class-wide.
+ with self.assertRaisesRegex(TypeError,
+ 'cannot have a default factory'):
+ @dataclass
+ class C:
+ x: InitVar[int] = field(default_factory=int)
+
+ self.assertNotIn('x', C.__dict__)
+
+ def test_init_var_with_default(self):
+ # If an InitVar has a default value, it should be set on the class.
+ @dataclass
+ class C:
+ x: InitVar[int] = 10
+ self.assertEqual(C.x, 10)
+
+ @dataclass
+ class C:
+ x: InitVar[int] = field(default=10)
+ self.assertEqual(C.x, 10)
+
+ def test_init_var(self):
+ @dataclass
+ class C:
+ x: int = None
+ init_param: InitVar[int] = None
+
+ def __post_init__(self, init_param):
+ if self.x is None:
+ self.x = init_param*2
+
+ c = C(init_param=10)
+ self.assertEqual(c.x, 20)
+
+ def test_init_var_inheritance(self):
+ # Note that this deliberately tests that a dataclass need not
+ # have a __post_init__ function if it has an InitVar field.
+ # It could just be used in a derived class, as shown here.
+ @dataclass
+ class Base:
+ x: int
+ init_base: InitVar[int]
+
+ # We can instantiate by passing the InitVar, even though
+ # it's not used.
+ b = Base(0, 10)
+ self.assertEqual(vars(b), {'x': 0})
+
+ @dataclass
+ class C(Base):
+ y: int
+ init_derived: InitVar[int]
+
+ def __post_init__(self, init_base, init_derived):
+ self.x = self.x + init_base
+ self.y = self.y + init_derived
+
+ c = C(10, 11, 50, 51)
+ self.assertEqual(vars(c), {'x': 21, 'y': 101})
+
+ def test_default_factory(self):
+ # Test a factory that returns a new list.
+ @dataclass
+ class C:
+ x: int
+ y: list = field(default_factory=list)
+
+ c0 = C(3)
+ c1 = C(3)
+ self.assertEqual(c0.x, 3)
+ self.assertEqual(c0.y, [])
+ self.assertEqual(c0, c1)
+ self.assertIsNot(c0.y, c1.y)
+ self.assertEqual(astuple(C(5, [1])), (5, [1]))
+
+ # Test a factory that returns a shared list.
+ l = []
+ @dataclass
+ class C:
+ x: int
+ y: list = field(default_factory=lambda: l)
+
+ c0 = C(3)
+ c1 = C(3)
+ self.assertEqual(c0.x, 3)
+ self.assertEqual(c0.y, [])
+ self.assertEqual(c0, c1)
+ self.assertIs(c0.y, c1.y)
+ self.assertEqual(astuple(C(5, [1])), (5, [1]))
+
+ # Test various other field flags.
+ # repr
+ @dataclass
+ class C:
+ x: list = field(default_factory=list, repr=False)
+ self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
+ self.assertEqual(C().x, [])
+
+ # hash
+ @dataclass(hash=True)
+ class C:
+ x: list = field(default_factory=list, hash=False)
+ self.assertEqual(astuple(C()), ([],))
+ self.assertEqual(hash(C()), hash(()))
+
+ # init (see also test_default_factory_with_no_init)
+ @dataclass
+ class C:
+ x: list = field(default_factory=list, init=False)
+ self.assertEqual(astuple(C()), ([],))
+
+ # compare
+ @dataclass
+ class C:
+ x: list = field(default_factory=list, compare=False)
+ self.assertEqual(C(), C([1]))
+
+ def test_default_factory_with_no_init(self):
+ # We need a factory with a side effect.
+ factory = Mock()
+
+ @dataclass
+ class C:
+ x: list = field(default_factory=factory, init=False)
+
+ # Make sure the default factory is called for each new instance.
+ C().x
+ self.assertEqual(factory.call_count, 1)
+ C().x
+ self.assertEqual(factory.call_count, 2)
+
+ def test_default_factory_not_called_if_value_given(self):
+ # We need a factory that we can test if it's been called.
+ factory = Mock()
+
+ @dataclass
+ class C:
+ x: int = field(default_factory=factory)
+
+ # Make sure that if a field has a default factory function,
+ # it's not called if a value is specified.
+ C().x
+ self.assertEqual(factory.call_count, 1)
+ self.assertEqual(C(10).x, 10)
+ self.assertEqual(factory.call_count, 1)
+ C().x
+ self.assertEqual(factory.call_count, 2)
+
+ def x_test_classvar_default_factory(self):
+ # XXX: it's an error for a ClassVar to have a factory function
+ @dataclass
+ class C:
+ x: ClassVar[int] = field(default_factory=int)
+
+ self.assertIs(C().x, int)
+
+ def test_isdataclass(self):
+ # There is no isdataclass() helper any more, but the PEP
+ # describes how to write it, so make sure that works. Note
+ # that this version returns True for both classes and
+ # instances.
+ def isdataclass(obj):
+ try:
+ fields(obj)
+ return True
+ except TypeError:
+ return False
+
+ self.assertFalse(isdataclass(0))
+ self.assertFalse(isdataclass(int))
+
+ @dataclass
+ class C:
+ x: int
+
+ self.assertTrue(isdataclass(C))
+ self.assertTrue(isdataclass(C(0)))
+
+ def test_helper_fields_with_class_instance(self):
+ # Check that we can call fields() on either a class or instance,
+ # and get back the same thing.
+ @dataclass
+ class C:
+ x: int
+ y: float
+
+ self.assertEqual(fields(C), fields(C(0, 0.0)))
+
+ def test_helper_fields_exception(self):
+ # Check that TypeError is raised if not passed a dataclass or
+ # instance.
+ with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
+ fields(0)
+
+ class C: pass
+ with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
+ fields(C)
+ with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
+ fields(C())
+
+ def test_helper_asdict(self):
+ # Basic tests for asdict(), it should return a new dictionary
+ @dataclass
+ class C:
+ x: int
+ y: int
+ c = C(1, 2)
+
+ self.assertEqual(asdict(c), {'x': 1, 'y': 2})
+ self.assertEqual(asdict(c), asdict(c))
+ self.assertIsNot(asdict(c), asdict(c))
+ c.x = 42
+ self.assertEqual(asdict(c), {'x': 42, 'y': 2})
+ self.assertIs(type(asdict(c)), dict)
+
+ def test_helper_asdict_raises_on_classes(self):
+ # asdict() should raise on a class object
+ @dataclass
+ class C:
+ x: int
+ y: int
+ with self.assertRaisesRegex(TypeError, 'dataclass instance'):
+ asdict(C)
+ with self.assertRaisesRegex(TypeError, 'dataclass instance'):
+ asdict(int)
+
+ def test_helper_asdict_copy_values(self):
+ @dataclass
+ class C:
+ x: int
+ y: List[int] = field(default_factory=list)
+ initial = []
+ c = C(1, initial)
+ d = asdict(c)
+ self.assertEqual(d['y'], initial)
+ self.assertIsNot(d['y'], initial)
+ c = C(1)
+ d = asdict(c)
+ d['y'].append(1)
+ self.assertEqual(c.y, [])
+
+ def test_helper_asdict_nested(self):
+ @dataclass
+ class UserId:
+ token: int
+ group: int
+ @dataclass
+ class User:
+ name: str
+ id: UserId
+ u = User('Joe', UserId(123, 1))
+ d = asdict(u)
+ self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
+ self.assertIsNot(asdict(u), asdict(u))
+ u.id.group = 2
+ self.assertEqual(asdict(u), {'name': 'Joe',
+ 'id': {'token': 123, 'group': 2}})
+
+ def test_helper_asdict_builtin_containers(self):
+ @dataclass
+ class User:
+ name: str
+ id: int
+ @dataclass
+ class GroupList:
+ id: int
+ users: List[User]
+ @dataclass
+ class GroupTuple:
+ id: int
+ users: Tuple[User, ...]
+ @dataclass
+ class GroupDict:
+ id: int
+ users: Dict[str, User]
+ a = User('Alice', 1)
+ b = User('Bob', 2)
+ gl = GroupList(0, [a, b])
+ gt = GroupTuple(0, (a, b))
+ gd = GroupDict(0, {'first': a, 'second': b})
+ self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
+ {'name': 'Bob', 'id': 2}]})
+ self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
+ {'name': 'Bob', 'id': 2})})
+ self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
+ 'second': {'name': 'Bob', 'id': 2}}})
+
+ def test_helper_asdict_builtin_containers(self):
+ @dataclass
+ class Child:
+ d: object
+
+ @dataclass
+ class Parent:
+ child: Child
+
+ self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
+ self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
+
+ def test_helper_asdict_factory(self):
+ @dataclass
+ class C:
+ x: int
+ y: int
+ c = C(1, 2)
+ d = asdict(c, dict_factory=OrderedDict)
+ self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
+ self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
+ c.x = 42
+ d = asdict(c, dict_factory=OrderedDict)
+ self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
+ self.assertIs(type(d), OrderedDict)
+
+ def test_helper_astuple(self):
+ # Basic tests for astuple(), it should return a new tuple
+ @dataclass
+ class C:
+ x: int
+ y: int = 0
+ c = C(1)
+
+ self.assertEqual(astuple(c), (1, 0))
+ self.assertEqual(astuple(c), astuple(c))
+ self.assertIsNot(astuple(c), astuple(c))
+ c.y = 42
+ self.assertEqual(astuple(c), (1, 42))
+ self.assertIs(type(astuple(c)), tuple)
+
+ def test_helper_astuple_raises_on_classes(self):
+ # astuple() should raise on a class object
+ @dataclass
+ class C:
+ x: int
+ y: int
+ with self.assertRaisesRegex(TypeError, 'dataclass instance'):
+ astuple(C)
+ with self.assertRaisesRegex(TypeError, 'dataclass instance'):
+ astuple(int)
+
+ def test_helper_astuple_copy_values(self):
+ @dataclass
+ class C:
+ x: int
+ y: List[int] = field(default_factory=list)
+ initial = []
+ c = C(1, initial)
+ t = astuple(c)
+ self.assertEqual(t[1], initial)
+ self.assertIsNot(t[1], initial)
+ c = C(1)
+ t = astuple(c)
+ t[1].append(1)
+ self.assertEqual(c.y, [])
+
+ def test_helper_astuple_nested(self):
+ @dataclass
+ class UserId:
+ token: int
+ group: int
+ @dataclass
+ class User:
+ name: str
+ id: UserId
+ u = User('Joe', UserId(123, 1))
+ t = astuple(u)
+ self.assertEqual(t, ('Joe', (123, 1)))
+ self.assertIsNot(astuple(u), astuple(u))
+ u.id.group = 2
+ self.assertEqual(astuple(u), ('Joe', (123, 2)))
+
+ def test_helper_astuple_builtin_containers(self):
+ @dataclass
+ class User:
+ name: str
+ id: int
+ @dataclass
+ class GroupList:
+ id: int
+ users: List[User]
+ @dataclass
+ class GroupTuple:
+ id: int
+ users: Tuple[User, ...]
+ @dataclass
+ class GroupDict:
+ id: int
+ users: Dict[str, User]
+ a = User('Alice', 1)
+ b = User('Bob', 2)
+ gl = GroupList(0, [a, b])
+ gt = GroupTuple(0, (a, b))
+ gd = GroupDict(0, {'first': a, 'second': b})
+ self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
+ self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
+ self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
+
+ def test_helper_astuple_builtin_containers(self):
+ @dataclass
+ class Child:
+ d: object
+
+ @dataclass
+ class Parent:
+ child: Child
+
+ self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
+ self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
+
+ def test_helper_astuple_factory(self):
+ @dataclass
+ class C:
+ x: int
+ y: int
+ NT = namedtuple('NT', 'x y')
+ def nt(lst):
+ return NT(*lst)
+ c = C(1, 2)
+ t = astuple(c, tuple_factory=nt)
+ self.assertEqual(t, NT(1, 2))
+ self.assertIsNot(t, astuple(c, tuple_factory=nt))
+ c.x = 42
+ t = astuple(c, tuple_factory=nt)
+ self.assertEqual(t, NT(42, 2))
+ self.assertIs(type(t), NT)
+
+ def test_dynamic_class_creation(self):
+ cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
+ }
+
+ # Create the class.
+ cls = type('C', (), cls_dict)
+
+ # Make it a dataclass.
+ cls1 = dataclass(cls)
+
+ self.assertEqual(cls1, cls)
+ self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
+
+ def test_dynamic_class_creation_using_field(self):
+ cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
+ 'y': field(default=5),
+ }
+
+ # Create the class.
+ cls = type('C', (), cls_dict)
+
+ # Make it a dataclass.
+ cls1 = dataclass(cls)
+
+ self.assertEqual(cls1, cls)
+ self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
+
+ def test_init_in_order(self):
+ @dataclass
+ class C:
+ a: int
+ b: int = field()
+ c: list = field(default_factory=list, init=False)
+ d: list = field(default_factory=list)
+ e: int = field(default=4, init=False)
+ f: int = 4
+
+ calls = []
+ def setattr(self, name, value):
+ calls.append((name, value))
+
+ C.__setattr__ = setattr
+ c = C(0, 1)
+ self.assertEqual(('a', 0), calls[0])
+ self.assertEqual(('b', 1), calls[1])
+ self.assertEqual(('c', []), calls[2])
+ self.assertEqual(('d', []), calls[3])
+ self.assertNotIn(('e', 4), calls)
+ self.assertEqual(('f', 4), calls[4])
+
+ def test_items_in_dicts(self):
+ @dataclass
+ class C:
+ a: int
+ b: list = field(default_factory=list, init=False)
+ c: list = field(default_factory=list)
+ d: int = field(default=4, init=False)
+ e: int = 0
+
+ c = C(0)
+ # Class dict
+ self.assertNotIn('a', C.__dict__)
+ self.assertNotIn('b', C.__dict__)
+ self.assertNotIn('c', C.__dict__)
+ self.assertIn('d', C.__dict__)
+ self.assertEqual(C.d, 4)
+ self.assertIn('e', C.__dict__)
+ self.assertEqual(C.e, 0)
+ # Instance dict
+ self.assertIn('a', c.__dict__)
+ self.assertEqual(c.a, 0)
+ self.assertIn('b', c.__dict__)
+ self.assertEqual(c.b, [])
+ self.assertIn('c', c.__dict__)
+ self.assertEqual(c.c, [])
+ self.assertNotIn('d', c.__dict__)
+ self.assertIn('e', c.__dict__)
+ self.assertEqual(c.e, 0)
+
+ def test_alternate_classmethod_constructor(self):
+ # Since __post_init__ can't take params, use a classmethod
+ # alternate constructor. This is mostly an example to show how
+ # to use this technique.
+ @dataclass
+ class C:
+ x: int
+ @classmethod
+ def from_file(cls, filename):
+ # In a real example, create a new instance
+ # and populate 'x' from contents of a file.
+ value_in_file = 20
+ return cls(value_in_file)
+
+ self.assertEqual(C.from_file('filename').x, 20)
+
+ def test_field_metadata_default(self):
+ # Make sure the default metadata is read-only and of
+ # zero length.
+ @dataclass
+ class C:
+ i: int
+
+ self.assertFalse(fields(C)[0].metadata)
+ self.assertEqual(len(fields(C)[0].metadata), 0)
+ with self.assertRaisesRegex(TypeError,
+ 'does not support item assignment'):
+ fields(C)[0].metadata['test'] = 3
+
+ def test_field_metadata_mapping(self):
+ # Make sure only a mapping can be passed as metadata
+ # zero length.
+ with self.assertRaises(TypeError):
+ @dataclass
+ class C:
+ i: int = field(metadata=0)
+
+ # Make sure an empty dict works
+ @dataclass
+ class C:
+ i: int = field(metadata={})
+ self.assertFalse(fields(C)[0].metadata)
+ self.assertEqual(len(fields(C)[0].metadata), 0)
+ with self.assertRaisesRegex(TypeError,
+ 'does not support item assignment'):
+ fields(C)[0].metadata['test'] = 3
+
+ # Make sure a non-empty dict works.
+ @dataclass
+ class C:
+ i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
+ self.assertEqual(len(fields(C)[0].metadata), 3)
+ self.assertEqual(fields(C)[0].metadata['test'], 10)
+ self.assertEqual(fields(C)[0].metadata['bar'], '42')
+ self.assertEqual(fields(C)[0].metadata[3], 'three')
+ with self.assertRaises(KeyError):
+ # Non-existent key.
+ fields(C)[0].metadata['baz']
+ with self.assertRaisesRegex(TypeError,
+ 'does not support item assignment'):
+ fields(C)[0].metadata['test'] = 3
+
+ def test_field_metadata_custom_mapping(self):
+ # Try a custom mapping.
+ class SimpleNameSpace:
+ def __init__(self, **kw):
+ self.__dict__.update(kw)
+
+ def __getitem__(self, item):
+ if item == 'xyzzy':
+ return 'plugh'
+ return getattr(self, item)
+
+ def __len__(self):
+ return self.__dict__.__len__()
+
+ @dataclass
+ class C:
+ i: int = field(metadata=SimpleNameSpace(a=10))
+
+ self.assertEqual(len(fields(C)[0].metadata), 1)
+ self.assertEqual(fields(C)[0].metadata['a'], 10)
+ with self.assertRaises(AttributeError):
+ fields(C)[0].metadata['b']
+ # Make sure we're still talking to our custom mapping.
+ self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
+
+ def test_generic_dataclasses(self):
+ T = TypeVar('T')
+
+ @dataclass
+ class LabeledBox(Generic[T]):
+ content: T
+ label: str = '<unknown>'
+
+ box = LabeledBox(42)
+ self.assertEqual(box.content, 42)
+ self.assertEqual(box.label, '<unknown>')
+
+ # subscripting the resulting class should work, etc.
+ Alias = List[LabeledBox[int]]
+
+ def test_generic_extending(self):
+ S = TypeVar('S')
+ T = TypeVar('T')
+
+ @dataclass
+ class Base(Generic[T, S]):
+ x: T
+ y: S
+
+ @dataclass
+ class DataDerived(Base[int, T]):
+ new_field: str
+ Alias = DataDerived[str]
+ c = Alias(0, 'test1', 'test2')
+ self.assertEqual(astuple(c), (0, 'test1', 'test2'))
+
+ class NonDataDerived(Base[int, T]):
+ def new_method(self):
+ return self.y
+ Alias = NonDataDerived[float]
+ c = Alias(10, 1.0)
+ self.assertEqual(c.new_method(), 1.0)
+
+ def test_helper_replace(self):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ y: int
+
+ c = C(1, 2)
+ c1 = replace(c, x=3)
+ self.assertEqual(c1.x, 3)
+ self.assertEqual(c1.y, 2)
+
+ def test_helper_replace_frozen(self):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ y: int
+ z: int = field(init=False, default=10)
+ t: int = field(init=False, default=100)
+
+ c = C(1, 2)
+ c1 = replace(c, x=3)
+ self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
+ self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
+
+
+ with self.assertRaisesRegex(ValueError, 'init=False'):
+ replace(c, x=3, z=20, t=50)
+ with self.assertRaisesRegex(ValueError, 'init=False'):
+ replace(c, z=20)
+ replace(c, x=3, z=20, t=50)
+
+ # Make sure the result is still frozen.
+ with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
+ c1.x = 3
+
+ # Make sure we can't replace an attribute that doesn't exist,
+ # if we're also replacing one that does exist. Test this
+ # here, because setting attributes on frozen instances is
+ # handled slightly differently from non-frozen ones.
+ with self.assertRaisesRegex(TypeError, "__init__\(\) got an unexpected "
+ "keyword argument 'a'"):
+ c1 = replace(c, x=20, a=5)
+
+ def test_helper_replace_invalid_field_name(self):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ y: int
+
+ c = C(1, 2)
+ with self.assertRaisesRegex(TypeError, "__init__\(\) got an unexpected "
+ "keyword argument 'z'"):
+ c1 = replace(c, z=3)
+
+ def test_helper_replace_invalid_object(self):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ y: int
+
+ with self.assertRaisesRegex(TypeError, 'dataclass instance'):
+ replace(C, x=3)
+
+ with self.assertRaisesRegex(TypeError, 'dataclass instance'):
+ replace(0, x=3)
+
+ def test_helper_replace_no_init(self):
+ @dataclass
+ class C:
+ x: int
+ y: int = field(init=False, default=10)
+
+ c = C(1)
+ c.y = 20
+
+ # Make sure y gets the default value.
+ c1 = replace(c, x=5)
+ self.assertEqual((c1.x, c1.y), (5, 10))
+
+ # Trying to replace y is an error.
+ with self.assertRaisesRegex(ValueError, 'init=False'):
+ replace(c, x=2, y=30)
+ with self.assertRaisesRegex(ValueError, 'init=False'):
+ replace(c, y=30)
+
+ def test_dataclassses_pickleable(self):
+ global P, Q, R
+ @dataclass
+ class P:
+ x: int
+ y: int = 0
+ @dataclass
+ class Q:
+ x: int
+ y: int = field(default=0, init=False)
+ @dataclass
+ class R:
+ x: int
+ y: List[int] = field(default_factory=list)
+ q = Q(1)
+ q.y = 2
+ samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
+ for sample in samples:
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(sample=sample, proto=proto):
+ new_sample = pickle.loads(pickle.dumps(sample, proto))
+ self.assertEqual(sample.x, new_sample.x)
+ self.assertEqual(sample.y, new_sample.y)
+ self.assertIsNot(sample, new_sample)
+ new_sample.x = 42
+ another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
+ self.assertEqual(new_sample.x, another_new_sample.x)
+ self.assertEqual(sample.y, another_new_sample.y)
+
+ def test_helper_make_dataclass(self):
+ C = make_dataclass('C',
+ [('x', int),
+ ('y', int, field(default=5))],
+ namespace={'add_one': lambda self: self.x + 1})
+ c = C(10)
+ self.assertEqual((c.x, c.y), (10, 5))
+ self.assertEqual(c.add_one(), 11)
+
+
+ def test_helper_make_dataclass_no_mutate_namespace(self):
+ # Make sure a provided namespace isn't mutated.
+ ns = {}
+ C = make_dataclass('C',
+ [('x', int),
+ ('y', int, field(default=5))],
+ namespace=ns)
+ self.assertEqual(ns, {})
+
+ def test_helper_make_dataclass_base(self):
+ class Base1:
+ pass
+ class Base2:
+ pass
+ C = make_dataclass('C',
+ [('x', int)],
+ bases=(Base1, Base2))
+ c = C(2)
+ self.assertIsInstance(c, C)
+ self.assertIsInstance(c, Base1)
+ self.assertIsInstance(c, Base2)
+
+ def test_helper_make_dataclass_base_dataclass(self):
+ @dataclass
+ class Base1:
+ x: int
+ class Base2:
+ pass
+ C = make_dataclass('C',
+ [('y', int)],
+ bases=(Base1, Base2))
+ with self.assertRaisesRegex(TypeError, 'required positional'):
+ c = C(2)
+ c = C(1, 2)
+ self.assertIsInstance(c, C)
+ self.assertIsInstance(c, Base1)
+ self.assertIsInstance(c, Base2)
+
+ self.assertEqual((c.x, c.y), (1, 2))
+
+ def test_helper_make_dataclass_init_var(self):
+ def post_init(self, y):
+ self.x *= y
+
+ C = make_dataclass('C',
+ [('x', int),
+ ('y', InitVar[int]),
+ ],
+ namespace={'__post_init__': post_init},
+ )
+ c = C(2, 3)
+ self.assertEqual(vars(c), {'x': 6})
+ self.assertEqual(len(fields(c)), 1)
+
+ def test_helper_make_dataclass_class_var(self):
+ C = make_dataclass('C',
+ [('x', int),
+ ('y', ClassVar[int], 10),
+ ('z', ClassVar[int], field(default=20)),
+ ])
+ c = C(1)
+ self.assertEqual(vars(c), {'x': 1})
+ self.assertEqual(len(fields(c)), 1)
+ self.assertEqual(C.y, 10)
+ self.assertEqual(C.z, 20)
+
+
+class TestDocString(unittest.TestCase):
+ def assertDocStrEqual(self, a, b):
+ # Because 3.6 and 3.7 differ in how inspect.signature work
+ # (see bpo #32108), for the time being just compare them with
+ # whitespace stripped.
+ self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
+
+ def test_existing_docstring_not_overridden(self):
+ @dataclass
+ class C:
+ """Lorem ipsum"""
+ x: int
+
+ self.assertEqual(C.__doc__, "Lorem ipsum")
+
+ def test_docstring_no_fields(self):
+ @dataclass
+ class C:
+ pass
+
+ self.assertDocStrEqual(C.__doc__, "C()")
+
+ def test_docstring_one_field(self):
+ @dataclass
+ class C:
+ x: int
+
+ self.assertDocStrEqual(C.__doc__, "C(x:int)")
+
+ def test_docstring_two_fields(self):
+ @dataclass
+ class C:
+ x: int
+ y: int
+
+ self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
+
+ def test_docstring_three_fields(self):
+ @dataclass
+ class C:
+ x: int
+ y: int
+ z: str
+
+ self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
+
+ def test_docstring_one_field_with_default(self):
+ @dataclass
+ class C:
+ x: int = 3
+
+ self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
+
+ def test_docstring_one_field_with_default_none(self):
+ @dataclass
+ class C:
+ x: Union[int, type(None)] = None
+
+ self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
+
+ def test_docstring_list_field(self):
+ @dataclass
+ class C:
+ x: List[int]
+
+ self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
+
+ def test_docstring_list_field_with_default_factory(self):
+ @dataclass
+ class C:
+ x: List[int] = field(default_factory=list)
+
+ self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
+
+ def test_docstring_deque_field(self):
+ @dataclass
+ class C:
+ x: deque
+
+ self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
+
+ def test_docstring_deque_field_with_default_factory(self):
+ @dataclass
+ class C:
+ x: deque = field(default_factory=deque)
+
+ self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
+
+
+if __name__ == '__main__':
+ unittest.main()