Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 75 additions & 3 deletions graphene_django/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,23 @@
from urllib.parse import urlencode


def url_string(**url_params):
string = '/graphql'

def url_string(string='/graphql', **url_params):
if url_params:
string += '?' + urlencode(url_params)

return string


def batch_url_string(**url_params):
return url_string('/graphql/batch', **url_params)


def response_json(response):
return json.loads(response.content.decode())


j = lambda **kwargs: json.dumps(kwargs)
jl = lambda **kwargs: json.dumps([kwargs])


def test_graphiql_is_enabled(client):
Expand Down Expand Up @@ -169,6 +172,17 @@ def test_allows_post_with_json_encoding(client):
}


def test_batch_allows_post_with_json_encoding(client):
response = client.post(batch_url_string(), jl(id=1, query='{test}'), 'application/json')

assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'payload': { 'data': {'test': "Hello World"} },
'status': 200,
}]


def test_allows_sending_a_mutation_via_post(client):
response = client.post(url_string(), j(query='mutation TestMutation { writeTest { test } }'), 'application/json')

Expand Down Expand Up @@ -199,6 +213,22 @@ def test_supports_post_json_query_with_string_variables(client):
}



def test_batch_supports_post_json_query_with_string_variables(client):
response = client.post(batch_url_string(), jl(
id=1,
query='query helloWho($who: String){ test(who: $who) }',
variables=json.dumps({'who': "Dolly"})
), 'application/json')

assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'payload': { 'data': {'test': "Hello Dolly"} },
'status': 200,
}]


def test_supports_post_json_query_with_json_variables(client):
response = client.post(url_string(), j(
query='query helloWho($who: String){ test(who: $who) }',
Expand All @@ -211,6 +241,21 @@ def test_supports_post_json_query_with_json_variables(client):
}


def test_batch_supports_post_json_query_with_json_variables(client):
response = client.post(batch_url_string(), jl(
id=1,
query='query helloWho($who: String){ test(who: $who) }',
variables={'who': "Dolly"}
), 'application/json')

assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'payload': { 'data': {'test': "Hello Dolly"} },
'status': 200,
}]


def test_supports_post_url_encoded_query_with_string_variables(client):
response = client.post(url_string(), urlencode(dict(
query='query helloWho($who: String){ test(who: $who) }',
Expand Down Expand Up @@ -285,6 +330,33 @@ def test_allows_post_with_operation_name(client):
}


def test_batch_allows_post_with_operation_name(client):
response = client.post(batch_url_string(), jl(
id=1,
query='''
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
''',
operationName='helloWorld'
), 'application/json')

assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'payload': {
'data': {
'test': 'Hello World',
'shared': 'Hello Everyone'
}
},
'status': 200,
}]


def test_allows_post_with_get_operation_name(client):
response = client.post(url_string(
operationName='helloWorld'
Expand Down
1 change: 1 addition & 0 deletions graphene_django/tests/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from ..views import GraphQLView

urlpatterns = [
url(r'^graphql/batch', GraphQLView.as_view(batch=True)),
url(r'^graphql', GraphQLView.as_view(graphiql=True)),
]
82 changes: 54 additions & 28 deletions graphene_django/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ class GraphQLView(View):
middleware = None
root_value = None
pretty = False
batch = False

def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False):
def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False,
batch=False):
if not schema:
schema = graphene_settings.SCHEMA

Expand All @@ -77,8 +79,10 @@ def __init__(self, schema=None, executor=None, middleware=None, root_value=None,
self.root_value = root_value
self.pretty = pretty
self.graphiql = graphiql
self.batch = batch

assert isinstance(self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.'
assert not all((graphiql, batch)), 'Use either graphiql or batch processing'

# noinspection PyUnusedLocal
def get_root_value(self, request):
Expand All @@ -99,34 +103,15 @@ def dispatch(self, request, *args, **kwargs):
data = self.parse_body(request)
show_graphiql = self.graphiql and self.can_display_graphiql(request, data)

query, variables, operation_name = self.get_graphql_params(request, data)

execution_result = self.execute_graphql_request(
request,
data,
query,
variables,
operation_name,
show_graphiql
)

if execution_result:
response = {}

if execution_result.errors:
response['errors'] = [self.format_error(e) for e in execution_result.errors]

if execution_result.invalid:
status_code = 400
else:
status_code = 200
response['data'] = execution_result.data

result = self.json_encode(request, response, pretty=show_graphiql)
if self.batch:
responses = [self.get_response(request, entry) for entry in data]
result = '[{}]'.format(','.join([response[0] for response in responses]))
status_code = max(responses, key=lambda response: response[1])[1]
else:
result = None
result, status_code = self.get_response(request, data, show_graphiql)

if show_graphiql:
query, variables, operation_name, id = self.get_graphql_params(request, data)
return self.render_graphiql(
request,
graphiql_version=self.graphiql_version,
Expand All @@ -150,6 +135,43 @@ def dispatch(self, request, *args, **kwargs):
})
return response

def get_response(self, request, data, show_graphiql=False):
query, variables, operation_name, id = self.get_graphql_params(request, data)

execution_result = self.execute_graphql_request(
request,
data,
query,
variables,
operation_name,
show_graphiql
)

status_code = 200
if execution_result:
response = {}

if execution_result.errors:
response['errors'] = [self.format_error(e) for e in execution_result.errors]

if execution_result.invalid:
status_code = 400
else:
response['data'] = execution_result.data

if self.batch:
response = {
'id': id,
'payload': response,
'status': status_code,
}

result = self.json_encode(request, response, pretty=show_graphiql)
else:
result = None

return result, status_code

def render_graphiql(self, request, **data):
return render(request, self.graphiql_template, data)

Expand All @@ -170,7 +192,10 @@ def parse_body(self, request):
elif content_type == 'application/json':
try:
request_json = json.loads(request.body.decode('utf-8'))
assert isinstance(request_json, dict)
if self.batch:
assert isinstance(request_json, list)
else:
assert isinstance(request_json, dict)
return request_json
except:
raise HttpError(HttpResponseBadRequest('POST body sent invalid JSON.'))
Expand Down Expand Up @@ -242,6 +267,7 @@ def request_wants_html(cls, request):
def get_graphql_params(request, data):
query = request.GET.get('query') or data.get('query')
variables = request.GET.get('variables') or data.get('variables')
id = request.GET.get('id') or data.get('id')

if variables and isinstance(variables, six.text_type):
try:
Expand All @@ -251,7 +277,7 @@ def get_graphql_params(request, data):

operation_name = request.GET.get('operationName') or data.get('operationName')

return query, variables, operation_name
return query, variables, operation_name, id

@staticmethod
def format_error(error):
Expand Down