-
Notifications
You must be signed in to change notification settings - Fork 1
/
iko.py
292 lines (246 loc) · 8.16 KB
/
iko.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import asyncio
import inspect
import typing
# Sentinel meaning "no value": Field.dump/Field.load return it for absent
# values, and Schema.dump/Schema.load leave such entries out of their result.
OPTIONAL = object()
# Values compared against ``Meta.unknown`` (EXCLUDE is the default read by
# SchemaOpts). With INCLUDE, Schema.dump/Schema.load copy keys of the input
# that are neither declared fields nor excluded into the result.
EXCLUDE = 'exclude'
INCLUDE = 'include'
# Typing aliases: a field or schema may be given as an instance or as a class.
_TField = typing.Union['Field', typing.Type['Field']]
_TFields = typing.Dict[str, _TField]
_TSchema = typing.Union['Schema', typing.Type['Schema']]
class Field:
    """Base field: reads one value from a mapping and post-processes it.

    ``dump`` and ``load`` look the attribute up with ``data.get``, fall back
    to the field default when it is missing, and return ``OPTIONAL`` when the
    resulting value is contained in ``OPTIONAL_LIST``. Otherwise the value is
    handed to ``post_dump`` / ``post_load``, which subclasses may override.
    """

    # Values treated as "absent"; membership is tested with ``in``.
    OPTIONAL_LIST: typing.Iterable = (OPTIONAL,)

    # Private marker used to tell "key missing" apart from any stored value.
    _MISSING = object()

    def __init__(
        self,
        default=OPTIONAL,
        dump_to=None,
        load_from=None,
        outer_name=None,
    ):
        """Configure the field.

        :param default: value used when the attribute is missing; if it is
            callable, it is called to produce the value.
        :param dump_to: stored as ``self.dump_to`` (the output key used by
            ``Schema.dump``).
        :param load_from: stored as ``self.load_from`` (the input key used by
            ``Schema.load``).
        :param outer_name: shorthand that sets both ``dump_to`` and
            ``load_from``; must not be combined with either of them.
        """
        if outer_name is not None:
            assert dump_to is None and load_from is None, (
                'outer_name cannot be combined with dump_to or load_from'
            )
        self._default = default
        self.dump_to = dump_to or outer_name
        self.load_from = load_from or outer_name

    @property
    def default(self):
        """Return the default value, calling it first if it is callable."""
        if callable(self._default):
            return self._default()
        return self._default

    def _get_value(self, data, attr):
        """Return ``data``'s value for ``attr``, or the default if missing.

        The lookup uses a private marker so that the default (which may be a
        factory) is only evaluated when the key is really absent, instead of
        on every call.
        """
        value = data.get(attr, self._MISSING)
        if value is self._MISSING:
            return self.default
        return value

    async def dump(self, data, attr, context):
        """Return the dumped value of ``attr`` from ``data``, or ``OPTIONAL``."""
        value = self._get_value(data, attr)
        if value in self.OPTIONAL_LIST:
            return OPTIONAL
        return await self.post_dump(value, context)

    async def post_dump(self, value, context):
        """Hook applied to a present value on dump; returns it unchanged."""
        return value

    async def load(self, data, attr, context):
        """Return the loaded value of ``attr`` from ``data``, or ``OPTIONAL``."""
        value = self._get_value(data, attr)
        if value in self.OPTIONAL_LIST:
            return OPTIONAL
        return await self.post_load(value, context)

    async def post_load(self, value, context):
        """Hook applied to a present value on load; returns it unchanged."""
        return value
class Nested(Field):
    """Field whose value is dumped and loaded through another schema."""

    schema: 'Schema'

    def __init__(
        self,
        schema,
        default=OPTIONAL,
        dump_to=None,
        load_from=None,
        outer_name=None,
    ):
        """Store ``schema`` (instantiating it first when a class is given)."""
        self.schema = schema() if inspect.isclass(schema) else schema
        super().__init__(default, dump_to, load_from, outer_name)

    async def post_dump(self, value, context):
        """Dump ``value`` with the nested schema, forwarding ``context``."""
        dumped = await self.schema.dump(value, context=context)
        return dumped

    async def post_load(self, value, context):
        """Load ``value`` with the nested schema, forwarding ``context``."""
        loaded = await self.schema.load(value, context=context)
        return loaded
class List(Field):
    """
    Field holding an iterable of items, returned as a new list.

    With a schema, every item goes through ``schema.dump`` / ``schema.load``.
    If a field is passed, then only its post methods are used. With neither,
    items are copied over unchanged.
    """

    def __init__(
        self,
        schema: typing.Optional[_TSchema] = None,
        field: typing.Optional[_TField] = None,
        default=OPTIONAL,
        dump_to=None,
        load_from=None,
        outer_name=None,
    ):
        """Pick the per-item dump/load callables from ``schema`` or ``field``.

        ``schema`` and ``field`` may each be a class (instantiated here) or an
        instance; passing both is rejected by the assertion.
        """
        assert not (schema and field)
        self.obj_dump = self.obj_load = None
        if schema:
            schema_obj = schema() if inspect.isclass(schema) else schema  # type: ignore
            self.obj_dump = schema_obj.dump
            self.obj_load = schema_obj.load
        if field:
            field_obj = field() if inspect.isclass(field) else field  # type: ignore
            self.obj_dump = field_obj.post_dump  # type: ignore
            self.obj_load = field_obj.post_load  # type: ignore
        super().__init__(default, dump_to, load_from, outer_name)

    async def post_dump(self, value, context):
        """Dump each item of ``value`` in order and return the results."""
        dumped = []
        for item in value:
            if self.obj_dump:
                item = await self.obj_dump(item, context=context)
            dumped.append(item)
        return dumped

    async def post_load(self, value, context):
        """Load each item of ``value`` in order and return the results."""
        loaded = []
        for item in value:
            if self.obj_load:
                item = await self.obj_load(item, context=context)
            loaded.append(item)
        return loaded
class SchemaMeta(type):
    """Metaclass that gathers field declarations into ``__fields__``.

    It also builds ``__opts__`` by passing the class's ``Meta`` to its
    ``OPTIONS_CLASS``.
    """

    def __new__(mcs, name, bases, attrs):
        collected = {}

        # Start from the fields of every Schema base; bases are visited left
        # to right, so a later base overrides an earlier one on a name clash.
        for parent in bases:
            if issubclass(parent, Schema):
                collected.update(parent.__fields__)

        for attr_name, candidate in attrs.items():
            if isinstance(candidate, Field):
                collected[attr_name] = candidate
                continue
            is_field_class = (
                inspect.isclass(candidate) and issubclass(candidate, Field)
            )
            # A Field subclass bound under its own class name is left alone
            # (presumably a nested class definition rather than a field
            # declaration); any other Field subclass is instantiated.
            if is_field_class and candidate.__name__ != attr_name:
                collected[attr_name] = candidate()

        attrs['__fields__'] = collected
        new_cls = super().__new__(mcs, name, bases, attrs)
        new_cls.__opts__ = new_cls.OPTIONS_CLASS(new_cls.Meta)
        return new_cls
class SchemaOpts:
    """Options read from a schema's ``Meta`` class."""

    def __init__(self, meta):
        """Read ``unknown`` and ``exclude`` from ``meta``, with defaults.

        ``unknown`` defaults to ``EXCLUDE`` and ``exclude`` to an empty tuple.
        """
        no_exclusions = tuple()
        self.unknown = getattr(meta, 'unknown', EXCLUDE)
        self.exclude = getattr(meta, 'exclude', no_exclusions)
class Schema(metaclass=SchemaMeta):
    """Declarative async serializer built from ``Field`` class attributes.

    ``SchemaMeta`` fills ``__fields__`` (attribute name -> field) and
    ``__opts__`` (options built from ``Meta``). ``dump`` and ``load`` process
    a mapping field by field; ``pre_*`` / ``post_*`` hooks and ``gather`` can
    be overridden by subclasses.
    """

    OPTIONS_CLASS = SchemaOpts
    __fields__: typing.Dict[str, Field]
    __opts__: SchemaOpts

    class Meta:
        pass

    @classmethod
    def _select_fields(cls, only, exclude):
        """Return ``(fields, excluded)`` for one dump/load call.

        ``fields`` is the list of ``(attr, field)`` pairs, in declaration
        order, whose name is in ``only`` (all fields when ``only`` is empty)
        and not excluded. ``excluded`` is the set union of the ``exclude``
        argument and ``Meta.exclude``.

        The caller's ``exclude`` is only read, never modified, and any
        iterable of names is accepted.
        """
        excluded = set(exclude or ()) | set(cls.__opts__.exclude)
        wanted = set(only or cls.__fields__) - excluded
        selected = [
            (attr, field)
            for attr, field in cls.__fields__.items()
            if attr in wanted
        ]
        return selected, excluded

    @classmethod
    async def dump(cls, data, *, only=None, exclude=None, context=None):
        """Dump ``data`` into a new dict.

        :param data: mapping to read from (after ``pre_dump``).
        :param only: optional iterable of field names to restrict the dump to.
        :param exclude: optional iterable of field names to skip, in addition
            to ``Meta.exclude``.
        :param context: object passed to every field and hook; a fresh dict
            when omitted.
        :return: the result of ``post_dump`` applied to the dumped dict.
        """
        if context is None:
            context = {}
        data = await cls.pre_dump(data, context)
        fields, excluded = cls._select_fields(only, exclude)
        values = await cls.gather(
            [field.dump(data, attr, context) for attr, field in fields],
            context,
        )
        # OPTIONAL is a sentinel, so compare by identity: this also keeps
        # values with unusual equality semantics from breaking the filter.
        result = {
            field.dump_to or attr: value
            for (attr, field), value in zip(fields, values)
            if value is not OPTIONAL
        }
        if cls.__opts__.unknown == INCLUDE:
            known_fields = set(cls.__fields__) | excluded
            result.update(
                {
                    key: value
                    for key, value in data.items()
                    if key not in known_fields
                },
            )
        return await cls.post_dump(result, context)

    @classmethod
    async def gather(cls, fields: typing.List[typing.Coroutine], context):
        """Await all field coroutines with ``asyncio.gather``.

        ``context`` is not used here; it is available to overriding
        subclasses.
        """
        return await asyncio.gather(*fields)

    @classmethod
    async def pre_dump(cls, value, context):
        """Hook run on the input before dumping; returns it unchanged."""
        return value

    @classmethod
    async def post_dump(cls, value, context):
        """Hook run on the dumped dict; returns it unchanged."""
        return value

    @classmethod
    def dump_many(cls, items, *, context=None):
        """Return the ``asyncio.gather`` awaitable dumping every item.

        All items share the same ``context`` (a fresh dict when omitted).
        """
        if context is None:
            context = {}
        return asyncio.gather(
            *[cls.dump(item, context=context) for item in items],
        )

    @classmethod
    async def load(cls, data, *, only=None, exclude=None, context=None):
        """Load ``data`` into a new dict keyed by field attribute name.

        Each field reads from its ``load_from`` key when set, otherwise from
        its attribute name.

        :param data: mapping to read from (after ``pre_load``).
        :param only: optional iterable of field names to restrict the load to.
        :param exclude: optional iterable of field names to skip, in addition
            to ``Meta.exclude``.
        :param context: object passed to every field and hook; a fresh dict
            when omitted.
        :return: the result of ``post_load`` applied to the loaded dict.
        """
        if context is None:
            context = {}
        data = await cls.pre_load(data, context)
        fields, excluded = cls._select_fields(only, exclude)
        values = await cls.gather(
            [
                field.load(data, field.load_from or attr, context)
                for attr, field in fields
            ],
            context,
        )
        # Identity comparison against the sentinel, as in ``dump``.
        result = {
            attr: value
            for (attr, _), value in zip(fields, values)
            if value is not OPTIONAL
        }
        if cls.__opts__.unknown == INCLUDE:
            known_fields = {
                field.load_from or attr
                for attr, field in cls.__fields__.items()
            } | excluded
            result.update(
                {
                    key: value
                    for key, value in data.items()
                    if key not in known_fields
                },
            )
        return await cls.post_load(result, context)

    @classmethod
    async def pre_load(cls, value, context):
        """Hook run on the input before loading; returns it unchanged."""
        return value

    @classmethod
    async def post_load(cls, value, context):
        """Hook run on the loaded dict; returns it unchanged."""
        return value

    @classmethod
    def load_many(cls, items, *, context=None):
        """Return the ``asyncio.gather`` awaitable loading every item.

        All items share the same ``context`` (a fresh dict when omitted).
        """
        if context is None:
            context = {}
        return asyncio.gather(
            *[cls.load(item, context=context) for item in items],
        )
def schema_from_dict(*args: _TFields, **fields: _TField):
    """Build a ``Schema`` subclass named ``DictSchema`` from field mappings.

    Positional mappings are merged in order, then the keyword fields; a later
    entry replaces an earlier one with the same name. The merged mapping
    becomes the class namespace of the new schema.
    """
    namespace: _TFields = {}
    for mapping in (*args, fields):
        namespace.update(mapping)
    schema_cls = type('DictSchema', (Schema,), namespace)
    return typing.cast(typing.Type[Schema], schema_cls)