Skip to content

Commit a746893

Browse files
allow overriding produces/consumes with @swagger_auto_schema decorator (#916)
Co-authored-by: Joel Lefkowitz <[email protected]>
1 parent e747ad6 commit a746893

File tree

3 files changed

+53
-10
lines changed

3 files changed

+53
-10
lines changed

src/drf_yasg/inspectors/view.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -473,11 +473,13 @@ def get_consumes(self):
473473
474474
:rtype: list[str]
475475
"""
476-
return get_consumes(self.get_parser_classes())
476+
return self.overrides.get("consumes") or get_consumes(self.get_parser_classes())
477477

478478
def get_produces(self):
479479
"""Return the MIME types this endpoint can produce.
480480
481481
:rtype: list[str]
482482
"""
483-
return get_produces(self.get_renderer_classes())
483+
return self.overrides.get("produces") or get_produces(
484+
self.get_renderer_classes()
485+
)

src/drf_yasg/utils.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def swagger_auto_schema(
6565
filter_inspectors=None,
6666
paginator_inspectors=None,
6767
tags=None,
68+
produces=None,
69+
consumes=None,
6870
**extra_overrides,
6971
):
7072
"""Decorate a view method to customize the :class:`.Operation` object generated from
@@ -164,6 +166,8 @@ def swagger_auto_schema(
164166
:attr:`.ViewInspector.paginator_inspectors` on the
165167
:class:`.inspectors.SwaggerAutoSchema`
166168
:param list[str] tags: tags override
169+
:param list[str] produces: produces override
170+
:param list[str] consumes: consumes override
167171
:param extra_overrides: extra values that will be saved into the ``overrides`` dict;
168172
these values will be available in the handling
169173
:class:`.inspectors.SwaggerAutoSchema` instance via ``self.overrides``
@@ -174,20 +178,22 @@ def decorator(view_method):
174178
"HTTP method names not allowed here"
175179
)
176180
data = {
177-
"request_body": request_body,
178-
"query_serializer": query_serializer,
181+
"consumes": consumes,
182+
"deprecated": deprecated,
183+
"field_inspectors": list(field_inspectors) if field_inspectors else None,
184+
"filter_inspectors": list(filter_inspectors) if filter_inspectors else None,
179185
"manual_parameters": manual_parameters,
186+
"operation_description": operation_description,
180187
"operation_id": operation_id,
181188
"operation_summary": operation_summary,
182-
"deprecated": deprecated,
183-
"operation_description": operation_description,
184-
"security": security,
185-
"responses": responses,
186-
"filter_inspectors": list(filter_inspectors) if filter_inspectors else None,
187189
"paginator_inspectors": list(paginator_inspectors)
188190
if paginator_inspectors
189191
else None,
190-
"field_inspectors": list(field_inspectors) if field_inspectors else None,
192+
"produces": produces,
193+
"query_serializer": query_serializer,
194+
"request_body": request_body,
195+
"responses": responses,
196+
"security": security,
191197
"tags": list(tags) if tags else None,
192198
}
193199
data = filter_none(data)

tests/test_schema_generator.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,41 @@ def _basename_or_base_name(basename):
122122
return {"base_name": basename}
123123

124124

125+
@pytest.mark.parametrize(
126+
"key,value,override_schema_key",
127+
(
128+
("consumes", ["application/vnd.x-override"], None),
129+
("deprecated", True, None),
130+
("operation_description", "description override", "description"),
131+
("operation_id", "id override", "operationId"),
132+
("operation_summary", "summary override", "summary"),
133+
("produces", ["application/vnd.x-override"], None),
134+
("tags", ["tag override"], None),
135+
),
136+
)
137+
def test_overrides(key, value, override_schema_key):
138+
@swagger_auto_schema(method="get", **{key: value})
139+
@api_view()
140+
def test_override(request, pk=None):
141+
return Response({"message": "Hello, world!"})
142+
143+
generator = OpenAPISchemaGenerator(
144+
info=openapi.Info(title="Test generator", default_version="v1"),
145+
version="v2",
146+
url="",
147+
patterns=[path("test/", test_override)],
148+
)
149+
150+
assert generator.get_schema(None, True)["paths"]["/test/"]["get"] == {
151+
"description": "",
152+
"operationId": "test_list",
153+
"parameters": [],
154+
"responses": openapi.Responses({200: openapi.Response("")}),
155+
"tags": ["test"],
156+
(key if override_schema_key is None else override_schema_key): value,
157+
}
158+
159+
125160
def test_replaced_serializer():
126161
class DetailSerializer(serializers.Serializer):
127162
detail = serializers.CharField()

0 commit comments

Comments
 (0)