Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class OverlappingFieldsCanBeMerged extends ValidationRule {
// A cache for the "field map" and list of fragment names found in any given
// selection set. Selection sets may be asked for this information multiple
// times, so this improves the performance of this validator.
val cachedFieldsAndFragmentNames = new MutableMap[Vector[ast.Selection], (MutableMap[String, ListBuffer[AstAndDef]], LinkedHashSet[String])]()
val cachedFieldsAndFragmentNames = new MutableMap[(Set[String], Vector[ast.Selection]), (MutableMap[String, ListBuffer[AstAndDef]], LinkedHashSet[String])]()

/**
* Algorithm:
Expand Down Expand Up @@ -85,32 +85,37 @@ class OverlappingFieldsCanBeMerged extends ValidationRule {
*/
override val onEnter: ValidationVisit = {
case selCont: ast.SelectionContainer if selCont.selections.nonEmpty ⇒
val conflicts = findConflictsWithinSelectionSet(ctx.typeInfo.parentType, selCont)
val conflicts = findConflictsWithinSelectionSet(ctx.typeInfo.parentType, selCont, Set.empty)

if (conflicts.nonEmpty)
Left(conflicts.toVector.map(c ⇒ FieldsConflictViolation(c.reason.fieldName, c.reason.reason, ctx.sourceMapper, (c.fields1 ++ c.fields2) flatMap (_.position))))
else
AstVisitorCommand.RightContinue
}

def findConflictsWithinSelectionSet(parentType: Option[Type], selCont: ast.SelectionContainer): ListBuffer[Conflict] = {
def findConflictsWithinSelectionSet(parentType: Option[Type], selCont: ast.SelectionContainer, visitedFragments: Set[String]): ListBuffer[Conflict] = {
val conflicts = ListBuffer[Conflict]()

val (fieldMap, fragmentNames) = getFieldsAndFragmentNames(parentType.asInstanceOf[Option[CompositeType[_]]], selCont)
val (fieldMap, fragmentNames) = getFieldsAndFragmentNames(parentType.asInstanceOf[Option[CompositeType[_]]], selCont, visitedFragments)

// (A) Find find all conflicts "within" the fields of this selection set.
// Note: this is the *only place* `collectConflictsWithin` is called.
collectConflictsWithin(conflicts, fieldMap)
collectConflictsWithin(conflicts, fieldMap, visitedFragments)

val fragmentNamesList = fragmentNames.toVector

// (B) Then collect conflicts between these fields and those represented by
// each spread fragment name found.
fragmentNames.zipWithIndex foreach { case (fragmentName, idx) ⇒
collectConflictsBetweenFieldsAndFragment(conflicts, fieldMap, fragmentName, false)
collectConflictsBetweenFieldsAndFragment(conflicts, fieldMap, fragmentName, false, visitedFragments + fragmentName)

for (i ← (idx + 1) until fragmentNamesList.size)
collectConflictsBetweenFragments(conflicts, fragmentName, fragmentNamesList(i), false)
collectConflictsBetweenFragments(
conflicts,
fragmentName,
fragmentNamesList(i),
visitedFragments + fragmentName,
visitedFragments + fragmentNamesList(i), false)
}

conflicts
Expand All @@ -125,6 +130,8 @@ class OverlappingFieldsCanBeMerged extends ValidationRule {
conflicts: ListBuffer[Conflict],
fieldMap1: MutableMap[String, ListBuffer[AstAndDef]],
fieldMap2: MutableMap[String, ListBuffer[AstAndDef]],
visitedFragments1: Set[String],
visitedFragments2: Set[String],
parentFieldsAreMutuallyExclusive: Boolean): Unit = {
// A field map is a keyed collection, where each key represents a response
// name and the value at that key is a list of all fields which provide that
Expand All @@ -139,15 +146,15 @@ class OverlappingFieldsCanBeMerged extends ValidationRule {
for {
f1 ← fields1
f2 ← fields2
} findConflict(outputName, f1, f2, parentFieldsAreMutuallyExclusive) foreach (conflicts += _)
} findConflict(outputName, f1, f2, visitedFragments1, visitedFragments2, parentFieldsAreMutuallyExclusive) foreach (conflicts += _)

case None ⇒ // It's ok, do nothing
}
}
}

// Collect all Conflicts "within" one collection of fields.
def collectConflictsWithin(conflicts: ListBuffer[Conflict], fieldMap: MutableMap[String, ListBuffer[AstAndDef]]): Unit = {
def collectConflictsWithin(conflicts: ListBuffer[Conflict], fieldMap: MutableMap[String, ListBuffer[AstAndDef]], visitedFragments: Set[String]): Unit = {
// A field map is a keyed collection, where each key represents a response
// name and the value at that key is a list of all fields which provide that
// response name. For every response name, if there are multiple fields, they
Expand All @@ -159,45 +166,49 @@ class OverlappingFieldsCanBeMerged extends ValidationRule {
for {
i ← 0 until fields.size
j ← (i + 1) until fields.size
} findConflict(outputName, fields(i), fields(j), false) foreach (conflicts += _)
} findConflict(outputName, fields(i), fields(j), visitedFragments, visitedFragments, false) foreach (conflicts += _)
}
}

def getFieldsAndFragmentNames(
parentType: Option[CompositeType[_]],
selCont: ast.SelectionContainer): (MutableMap[String, ListBuffer[AstAndDef]], LinkedHashSet[String]) = {
cachedFieldsAndFragmentNames.get(selCont.selections) match {
selCont: ast.SelectionContainer,
visitedFragments: Set[String]): (MutableMap[String, ListBuffer[AstAndDef]], LinkedHashSet[String]) = {
val cacheKey = visitedFragments → selCont.selections

cachedFieldsAndFragmentNames.get(cacheKey) match {
case Some(cached) ⇒ cached
case None ⇒
val astAndDefs = MutableMap[String, ListBuffer[AstAndDef]]()
val fragmentNames = mutable.LinkedHashSet[String]()

collectFieldsAndFragmentNames(parentType, selCont, astAndDefs, fragmentNames)
collectFieldsAndFragmentNames(parentType, selCont, astAndDefs, fragmentNames, visitedFragments)

val cached = astAndDefs → fragmentNames

cachedFieldsAndFragmentNames(selCont.selections) = cached
cachedFieldsAndFragmentNames(cacheKey) = cached
cached
}
}

// Given a reference to a fragment, return the represented collection of fields
// as well as a list of nested fragment names referenced via fragment spreads.
def getReferencedFieldsAndFragmentNames(fragment: ast.FragmentDefinition): (MutableMap[String, ListBuffer[AstAndDef]], LinkedHashSet[String]) = {
cachedFieldsAndFragmentNames.get(fragment.selections) match {
def getReferencedFieldsAndFragmentNames(fragment: ast.FragmentDefinition, visitedFragments: Set[String]): (MutableMap[String, ListBuffer[AstAndDef]], LinkedHashSet[String]) = {
cachedFieldsAndFragmentNames.get(visitedFragments → fragment.selections) match {
case Some(cached) ⇒ cached
case None ⇒
val fragmentType = ctx.schema.getOutputType(fragment.typeCondition, true).asInstanceOf[Option[CompositeType[_]]]

getFieldsAndFragmentNames(fragmentType, fragment)
getFieldsAndFragmentNames(fragmentType, fragment, visitedFragments)
}
}

def collectFieldsAndFragmentNames(
parentType: Option[OutputType[_]],
selCont: ast.SelectionContainer,
astAndDefs: MutableMap[String, ListBuffer[AstAndDef]],
fragmentNames: MutableSet[String]): Unit = {
fragmentNames: MutableSet[String],
visitedFragments: Set[String]): Unit = {
selCont.selections foreach {
case field: ast.Field ⇒
val fieldDef: Option[Field[_, _]] = parentType flatMap {
Expand All @@ -215,20 +226,26 @@ class OverlappingFieldsCanBeMerged extends ValidationRule {

astAndDef += AstAndDef(field, parentType, fieldDef)

case fragment: ast.FragmentSpread if visitedFragments contains fragment.name ⇒
// This means a fragment spread in itself. We're going to infinite loop
// if we try and collect all fields. Pretend we did not index that fragment

case fragment: ast.FragmentSpread ⇒
fragmentNames += fragment.name

case fragment: ast.InlineFragment ⇒
val inlineFragmentType = fragment.typeCondition flatMap (ctx.schema.getOutputType(_, true)) orElse parentType

collectFieldsAndFragmentNames(inlineFragmentType, fragment, astAndDefs, fragmentNames)
collectFieldsAndFragmentNames(inlineFragmentType, fragment, astAndDefs, fragmentNames, visitedFragments)
}
}

def findConflict(
outputName: String,
fieldInfo1: AstAndDef,
fieldInfo2: AstAndDef,
visitedFragments1: Set[String],
visitedFragments2: Set[String],
parentFieldsAreMutuallyExclusive: Boolean): Option[Conflict] = {
val AstAndDef(ast1, parentType1, def1) = fieldInfo1
val AstAndDef(ast2, parentType2, def2) = fieldInfo2
Expand Down Expand Up @@ -266,7 +283,14 @@ class OverlappingFieldsCanBeMerged extends ValidationRule {
case None ⇒
val type1 = def1 map (d ⇒ d.fieldType.namedType)
val type2 = def2 map (d ⇒ d.fieldType.namedType)
val conflicts = findConflictsBetweenSubSelectionSets(areMutuallyExclusive, type1.asInstanceOf[Option[CompositeType[_]]], ast1, type2.asInstanceOf[Option[CompositeType[_]]], ast2)
val conflicts = findConflictsBetweenSubSelectionSets(
areMutuallyExclusive,
type1.asInstanceOf[Option[CompositeType[_]]],
ast1,
type2.asInstanceOf[Option[CompositeType[_]]],
ast2,
visitedFragments1,
visitedFragments2)

subfieldConflicts(conflicts, outputName, ast1, ast2)
}
Expand All @@ -281,32 +305,34 @@ class OverlappingFieldsCanBeMerged extends ValidationRule {
parentType1: Option[CompositeType[_]],
selCont1: ast.SelectionContainer,
parentType2: Option[CompositeType[_]],
selCont2: ast.SelectionContainer): ListBuffer[Conflict] = {
selCont2: ast.SelectionContainer,
visitedFragments1: Set[String],
visitedFragments2: Set[String]): ListBuffer[Conflict] = {
val conflicts = ListBuffer[Conflict]()

val (fieldMap1, fragmentNames1) = getFieldsAndFragmentNames(parentType1, selCont1)
val (fieldMap2, fragmentNames2) = getFieldsAndFragmentNames(parentType2, selCont2)
val (fieldMap1, fragmentNames1) = getFieldsAndFragmentNames(parentType1, selCont1, visitedFragments1)
val (fieldMap2, fragmentNames2) = getFieldsAndFragmentNames(parentType2, selCont2, visitedFragments2)

// (H) First, collect all conflicts between these two collections of field.
collectConflictsBetween(conflicts, fieldMap1, fieldMap2, areMutuallyExclusive)
collectConflictsBetween(conflicts, fieldMap1, fieldMap2, visitedFragments1, visitedFragments2, areMutuallyExclusive)

// (I) Then collect conflicts between the first collection of fields and
// those referenced by each fragment name associated with the second.
fragmentNames2 foreach (fragmentName ⇒
collectConflictsBetweenFieldsAndFragment(conflicts, fieldMap1, fragmentName, areMutuallyExclusive))
collectConflictsBetweenFieldsAndFragment(conflicts, fieldMap1, fragmentName, areMutuallyExclusive, visitedFragments2 + fragmentName))

// (I) Then collect conflicts between the second collection of fields and
// those referenced by each fragment name associated with the first.
fragmentNames1 foreach (fragmentName ⇒
collectConflictsBetweenFieldsAndFragment(conflicts, fieldMap2, fragmentName, areMutuallyExclusive))
collectConflictsBetweenFieldsAndFragment(conflicts, fieldMap2, fragmentName, areMutuallyExclusive, visitedFragments1 + fragmentName))

// (J) Also collect conflicts between any fragment names by the first and
// fragment names by the second. This compares each item in the first set of
// names to each item in the second set of names.
for {
fragmentName1 ← fragmentNames1
fragmentName2 ← fragmentNames2
} collectConflictsBetweenFragments(conflicts, fragmentName1, fragmentName2, areMutuallyExclusive)
} collectConflictsBetweenFragments(conflicts, fragmentName1, fragmentName2, visitedFragments1 + fragmentName1, visitedFragments2 + fragmentName2, areMutuallyExclusive)

conflicts
}
Expand All @@ -317,6 +343,8 @@ class OverlappingFieldsCanBeMerged extends ValidationRule {
conflicts: ListBuffer[Conflict],
fragmentName1: String,
fragmentName2: String,
visitedFragments1: Set[String],
visitedFragments2: Set[String],
areMutuallyExclusive: Boolean): Unit = {
(ctx.doc.fragments.get(fragmentName1), ctx.doc.fragments.get(fragmentName2)) match {
case (None, _) | (_, None) ⇒ // do nothing
Expand All @@ -329,42 +357,43 @@ class OverlappingFieldsCanBeMerged extends ValidationRule {
case (Some(f1), Some(f2)) ⇒
comparedFragments.add(f1.name, f2.name, areMutuallyExclusive)

val (fieldMap1, fragmentNames1) = getReferencedFieldsAndFragmentNames(f1)
val (fieldMap2, fragmentNames2) = getReferencedFieldsAndFragmentNames(f2)
val (fieldMap1, fragmentNames1) = getReferencedFieldsAndFragmentNames(f1, visitedFragments1)
val (fieldMap2, fragmentNames2) = getReferencedFieldsAndFragmentNames(f2, visitedFragments2)

// (F) First, collect all conflicts between these two collections of fields
// (not including any nested fragments).
collectConflictsBetween(conflicts, fieldMap1, fieldMap2, areMutuallyExclusive)
collectConflictsBetween(conflicts, fieldMap1, fieldMap2, visitedFragments1, visitedFragments2, areMutuallyExclusive)

// (G) Then collect conflicts between the first fragment and any nested
// fragments spread in the second fragment.
fragmentNames2 foreach (fragmentName ⇒
collectConflictsBetweenFragments(conflicts, fragmentName1, fragmentName, areMutuallyExclusive))
collectConflictsBetweenFragments(conflicts, fragmentName1, fragmentName, visitedFragments1, visitedFragments2 + fragmentName, areMutuallyExclusive))

// (G) Then collect conflicts between the first fragment and any nested
// fragments spread in the second fragment.
fragmentNames1 foreach (fragmentName ⇒
collectConflictsBetweenFragments(conflicts, fragmentName, fragmentName2, areMutuallyExclusive))
collectConflictsBetweenFragments(conflicts, fragmentName, fragmentName2, visitedFragments1 + fragmentName, visitedFragments2, areMutuallyExclusive))
}
}

def collectConflictsBetweenFieldsAndFragment(
conflicts: ListBuffer[Conflict],
fieldMap: MutableMap[String, ListBuffer[AstAndDef]],
fragmentName: String,
areMutuallyExclusive: Boolean): Unit = {
areMutuallyExclusive: Boolean,
visitedFragments: Set[String]): Unit = {
ctx.doc.fragments.get(fragmentName) match {
case Some(fragment) ⇒
val (fieldMap2, fragmentNames2) = getReferencedFieldsAndFragmentNames(fragment)
val (fieldMap2, fragmentNames2) = getReferencedFieldsAndFragmentNames(fragment, visitedFragments)

// (D) First collect any conflicts between the provided collection of fields
// and the collection of fields represented by the given fragment.
collectConflictsBetween(conflicts, fieldMap, fieldMap2, areMutuallyExclusive)
collectConflictsBetween(conflicts, fieldMap, fieldMap2, visitedFragments, visitedFragments, areMutuallyExclusive)

// (E) Then collect any conflicts between the provided collection of fields
// and any fragment names found in the given fragment.
fragmentNames2 foreach (fragmentName ⇒
collectConflictsBetweenFieldsAndFragment(conflicts, fieldMap, fragmentName, areMutuallyExclusive))
collectConflictsBetweenFieldsAndFragment(conflicts, fieldMap, fragmentName, areMutuallyExclusive, visitedFragments + fragmentName))

case None ⇒ // do nothing
}
Expand Down
59 changes: 58 additions & 1 deletion src/test/scala/sangria/execution/ExecutorSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import sangria.marshalling.InputUnmarshaller
import sangria.parser.QueryParser
import sangria.schema._
import sangria.macros._
import sangria.util.FutureResultSupport
import sangria.util.{DebugUtil, FutureResultSupport, Pos, SimpleGraphQlSupport}
import sangria.validation.QueryValidator
import InputUnmarshaller.mapVars
import sangria.execution.deferred.{Deferred, DeferredResolver, Fetcher, HasId}
import sangria.util.SimpleGraphQlSupport.checkContainsErrors

import scala.collection.immutable.Map
import scala.concurrent.{ExecutionContext, Future}
Expand Down Expand Up @@ -805,6 +806,62 @@ class ExecutorSpec extends WordSpec with Matchers with FutureResultSupport {
result.ctx.acc should be ("OneTwoThree")
}

"validate recursive fragments" in checkContainsErrors(schema = Schema(DataType), data = (), userContext = Ctx(),
query =
"""
{
...c
}

fragment one on DataType {
pic(size: 1)
}

fragment two on DataType {
pic(size: 3)
}

fragment c on DataType {
...c
...one
...two
}
""",
expectedData = null,
expectedErrorStrings = List(
"Cannot spread fragment 'c' within itself." → List(Pos(15, 13)),
"Field 'pic' conflict because they have differing arguments." → List(Pos(11, 13), Pos(7, 13))))

"validate mutually recursive fragments" in checkContainsErrors(schema = Schema(DataType), data = (), userContext = Ctx(),
query =
"""
{
...c
}

fragment one on DataType {
pic(size: 1)
}

fragment two on DataType {
pic(size: 3)
}

fragment c on DataType {
...d
}

fragment d on DataType {
...c
...one
...two
}
""",
expectedData = null,
expectedErrorStrings = List(
"Cannot spread fragment 'c' within itself via 'd'." → List(Pos(15, 13), Pos(19, 13)),
"Field 'pic' conflict because they have differing arguments." → List(Pos(11, 13), Pos(7, 13))))

"support `Action.sequence` in queries and mutations" in {
val error = new IllegalStateException("foo")

Expand Down
Loading