/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.idp.saml.attribute.impl;

import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.annotation.concurrent.NotThreadSafe;

import org.opensaml.core.xml.XMLObject;
import org.opensaml.saml.ext.saml2mdattr.EntityAttributes;
import org.opensaml.saml.metadata.resolver.filter.FilterException;
import org.opensaml.saml.metadata.resolver.filter.MetadataNodeProcessor;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.metadata.AttributeConsumingService;
import org.opensaml.saml.saml2.metadata.EntitiesDescriptor;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml.saml2.metadata.Extensions;
import org.opensaml.saml.saml2.metadata.RequestedAttribute;
import org.slf4j.Logger;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;

import net.shibboleth.idp.attribute.AttributeDecodingException;
import net.shibboleth.idp.attribute.AttributesMapContainer;
import net.shibboleth.idp.attribute.IdPAttribute;
import net.shibboleth.idp.attribute.IdPRequestedAttribute;
import net.shibboleth.idp.attribute.transcoding.AttributeTranscoder;
import net.shibboleth.idp.attribute.transcoding.AttributeTranscoderRegistry;
import net.shibboleth.idp.attribute.transcoding.TranscoderSupport;
import net.shibboleth.idp.attribute.transcoding.TranscodingRule;
import net.shibboleth.idp.saml.attribute.transcoding.SAML2AttributeTranscoder;
import net.shibboleth.idp.saml.attribute.transcoding.impl.SAML2StringAttributeTranscoder;
import net.shibboleth.shared.collection.CollectionSupport;
import net.shibboleth.shared.component.ComponentInitializationException;
import net.shibboleth.shared.logic.Constraint;
import net.shibboleth.shared.logic.ConstraintViolationException;
import net.shibboleth.shared.primitive.LoggerFactory;
import net.shibboleth.shared.service.ReloadableService;
import net.shibboleth.shared.service.ServiceException;
import net.shibboleth.shared.service.ServiceableComponent;

/**
 * An implementation of {@link MetadataNodeProcessor} which extracts {@link IdPRequestedAttribute}s from any
 * {@link AttributeConsumingService} we find and {@link IdPAttribute}s from any {@link EntityDescriptor} that we find.
 *
 * <p>Decoded results are attached to the object metadata of the relevant metadata element as an
 * {@link AttributesMapContainer}. When an {@link EntityDescriptor} is processed, the entity attributes found in the
 * extensions of each enclosing {@link EntitiesDescriptor} are decoded as well and attached to that ancestor.</p>
 */
@NotThreadSafe
public class AttributeMappingNodeProcessor implements MetadataNodeProcessor {

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(AttributeMappingNodeProcessor.class);

    /** Service used to get the registry of decoding rules. */
    @Nonnull private final ReloadableService<AttributeTranscoderRegistry> transcoderRegistry;
    
    /** Fallback for URI-named entity tags. */
    @Nonnull private final AttributeTranscoder<Attribute> defaultTranscoder;

    /**
     * Constructor.
     * 
     * @param registry the service for the decoding rules
     * 
     * @throws ConstraintViolationException if the default fallback transcoder cannot be initialized
     */
    public AttributeMappingNodeProcessor(@Nonnull final ReloadableService<AttributeTranscoderRegistry> registry) {
        transcoderRegistry = Constraint.isNotNull(registry, "AttributeTranscoderRegistry cannot be null");
        
        defaultTranscoder = new SAML2StringAttributeTranscoder();
        try {
            defaultTranscoder.initialize();
        } catch (final ComponentInitializationException e) {
            // Keep the underlying failure attached so the root cause is visible when diagnosing startup.
            final ConstraintViolationException wrapped =
                    new ConstraintViolationException("Error initializing default transcoder");
            wrapped.initCause(e);
            throw wrapped;
        }
    }
    
    /** {@inheritDoc} */
    @Override
    public void process(@Nonnull final XMLObject metadataNode) throws FilterException {

        // Only acquire the registry for node types this processor actually handles.
        if (metadataNode instanceof AttributeConsumingService || metadataNode instanceof EntityDescriptor) {
            try (final ServiceableComponent<AttributeTranscoderRegistry> component =
                    transcoderRegistry.getServiceableComponent()) {
                final AttributeTranscoderRegistry registry = component.getComponent();
                if (metadataNode instanceof AttributeConsumingService) {
                    handleAttributeConsumingService(registry, (AttributeConsumingService) metadataNode);
                } else {
                    // The outer guard guarantees this is an EntityDescriptor.
                    handleEntityAttributes(registry, ((EntityDescriptor) metadataNode).getExtensions());
                    
                    // Also decode entity attributes on every enclosing EntitiesDescriptor; each group's
                    // results are stored on that group element itself.
                    // NOTE(review): ancestors are re-decoded for every child EntityDescriptor processed.
                    XMLObject parent = metadataNode.getParent();
                    while (parent instanceof EntitiesDescriptor) {
                        handleEntityAttributes(registry, ((EntitiesDescriptor) parent).getExtensions());
                        parent = parent.getParent();
                    }
                }
            } catch (final ServiceException e) {
                log.warn("Invalid AttributeTranscoderRegistry configuration", e);
            }
        }
    }

    /**
     * Look inside the {@link AttributeConsumingService} for any {@link RequestedAttribute}s and map them.
     * 
     * <p>Any decoded results are stored in the service's object metadata as an {@link AttributesMapContainer}.
     * A failure to decode one attribute is logged and does not prevent the others from being decoded.</p>
     * 
     * @param registry the registry service
     * @param acs the {@link AttributeConsumingService} to look at
     */
    private void handleAttributeConsumingService(@Nonnull final AttributeTranscoderRegistry registry,
            @Nonnull final AttributeConsumingService acs) {
        
        final List<RequestedAttribute> requestedAttributes = acs.getRequestedAttributes();
        if (null == requestedAttributes || requestedAttributes.isEmpty()) {
            return;
        }
        
        final Multimap<String,IdPAttribute> results = HashMultimap.create();
        assert results != null;
        
        for (final RequestedAttribute req : requestedAttributes) {
            try {
                assert req != null;
                decodeAttribute(registry.getTranscodingRules(req), req, results);
            } catch (final AttributeDecodingException e) {
                log.warn("Error decoding RequestedAttribute '{}'", req.getName(), e);
            }
        }
        
        if (!results.isEmpty()) {
            acs.getObjectMetadata().put(new AttributesMapContainer(results));
        }
    }

    /**
     * Look inside the {@link Extensions} for {@link EntityAttributes} and map them.
     * 
     * <p>Any decoded results are stored in the object metadata of the element that owns the extensions block.
     * URI-named attributes with no registered transcoding rules are decoded with a default string rule whose
     * IdP attribute ID is the SAML attribute name.</p>
     * 
     * @param registry the registry service
     * @param extensions the extensions block
     */
//CheckStyle: CyclomaticComplexity OFF
    private void handleEntityAttributes(@Nonnull final AttributeTranscoderRegistry registry,
            @Nullable final Extensions extensions) {
        if (null == extensions) {
            return;
        }
        
        final List<XMLObject> entityAttributesList =
                extensions.getUnknownXMLObjects(EntityAttributes.DEFAULT_ELEMENT_NAME);
        if (null == entityAttributesList || entityAttributesList.isEmpty()) {
            return;
        }

        // Results are attached to whatever element owns this Extensions block.
        final XMLObject parent = extensions.getParent();
        if (parent == null) {
            log.warn("Extensions object had no parent to store results");
            return;
        }

        final Multimap<String,IdPAttribute> results = HashMultimap.create();
        assert results != null;
        
        for (final XMLObject xmlObj : entityAttributesList) {
            if (xmlObj instanceof EntityAttributes) {
                final EntityAttributes ea = (EntityAttributes) xmlObj;
                for (final Attribute attr : ea.getAttributes()) {
                    try {
                        assert attr != null;
                        @Nonnull Collection<TranscodingRule> rulesets = registry.getTranscodingRules(attr);
                        if (rulesets.isEmpty() && Attribute.URI_REFERENCE.equals(attr.getNameFormat())) {
                            // No configured rule: fall back to a plain string decoding keyed by the URI name.
                            log.trace("Applying default decoding rule for URI-named attribute {}", attr.getName());
                            final Map<String,Object> rulemap = new HashMap<>();
                            rulemap.put(AttributeTranscoderRegistry.PROP_ID, attr.getName());
                            rulemap.put(AttributeTranscoderRegistry.PROP_TRANSCODER, defaultTranscoder);
                            rulemap.put(SAML2AttributeTranscoder.PROP_NAME, attr.getName());
                            final TranscodingRule defaultRule = new TranscodingRule(rulemap);
                            rulesets = CollectionSupport.singletonList(defaultRule);
                        }

                        assert rulesets != null;
                        decodeAttribute(rulesets, attr, results);
                    } catch (final AttributeDecodingException e) {
                        log.warn("Error decoding entity attribute '{}'", attr.getName(), e);
                    }
                }
            }
        }
        
        if (!results.isEmpty()) {
            parent.getObjectMetadata().put(new AttributesMapContainer(results));
        }
    }
//CheckStyle: CyclomaticComplexity ON

    /**
     * Apply each of the supplied transcoding rules to decode the input object, adding every non-null
     * decoded attribute to the results under its attribute ID.
     * 
     * <p>Rules whose transcoder decodes to null are silently skipped; no error is raised if nothing
     * at all is decoded.</p>
     * 
     * @param <T> input type
     * @param rules transcoding rules to apply, in iteration order
     * @param input input object
     * @param results collection to add results to
     * 
     * @throws AttributeDecodingException if an error occurs while obtaining a transcoder or decoding the input;
     *  any rules after the failing one are not applied
     */
    private <T> void decodeAttribute(@Nonnull final Collection<TranscodingRule> rules, @Nonnull final T input,
            @Nonnull final Multimap<String,IdPAttribute> results) throws AttributeDecodingException {
        
        for (final TranscodingRule rule : rules) {
            assert rule != null;
            final AttributeTranscoder<T> transcoder = TranscoderSupport.getTranscoder(rule);
            final IdPAttribute decodedAttribute = transcoder.decode(null, input, rule);
            if (decodedAttribute != null) {
                results.put(decodedAttribute.getId(), decodedAttribute);
            }
        }
    }

}